Compare commits

...

109 Commits

Author SHA1 Message Date
Bruce MacDonald
6ee8c80199 restore model load duration on generate response (#1524)
* restore model load duration on generate response

- set model load duration on generate and chat done response
- calculate createdAt time when the response is created

* remove checkpoints predict opts

* Update routes.go
2023-12-14 12:15:50 -05:00
Jeffrey Morgan
31f0551dab Update runner to support mixtral and mixture of experts (MoE) (#1475) 2023-12-13 17:15:10 -05:00
Jeffrey Morgan
4a1abfe4fa fix tests 2023-12-13 14:42:30 -05:00
Jeffrey Morgan
bbd41494bf add multimodal to README.md 2023-12-13 14:38:47 -05:00
Jeffrey Morgan
fedba24a63 Docs for multimodal support (#1485)
* add multimodal docs

* add chat api docs

* consistency between `/api/generate` and `/api/chat`

* simplify docs
2023-12-13 13:59:33 -05:00
pepperoni21
e3b090dbc5 Added message format for chat api (#1488) 2023-12-13 11:21:23 -05:00
Patrick Devine
d9e60f634b add image support to the chat api (#1490) 2023-12-12 13:28:58 -08:00
Michael Yang
4251b342de Merge pull request #1469 from jmorganca/mxyng/model-types
remove per-model types
2023-12-12 12:27:03 -08:00
Jeffrey Morgan
0a9d348023 Fix issues with /set template and /set system (#1486) 2023-12-12 14:43:19 -05:00
Bruce MacDonald
3144e2a439 exponential back-off (#1484) 2023-12-12 12:33:02 -05:00
Bruce MacDonald
c0960e29b5 retry on concurrent request failure (#1483)
- remove parallel
2023-12-12 12:14:35 -05:00
ruecat
5314fc9b63 Fix Readme "Database -> MindsDB" link (#1479) 2023-12-12 10:26:13 -05:00
Jorge Torres
a36b5fef3b Update README.md (#1412) 2023-12-11 18:05:10 -05:00
Patrick Devine
910e9401d0 Multimodal support (#1216)
---------

Co-authored-by: Matt Apperson <mattapperson@Matts-MacBook-Pro.local>
2023-12-11 13:56:22 -08:00
Michael Yang
56ffc3023a remove per-model types
mostly replaced by decoding tensors, except for ggml models, which only
support llama
2023-12-11 09:40:21 -08:00
Bruce MacDonald
7a1b37ac64 os specific ctrl-z (#1420) 2023-12-11 10:48:14 -05:00
Jeffrey Morgan
5d4d2e2c60 update docs with chat completion api 2023-12-10 13:53:36 -05:00
Jeffrey Morgan
7db5bcf73b fix go-staticcheck warning 2023-12-10 11:44:27 -05:00
Jeffrey Morgan
fa2f095bd9 fix model name returned by /api/generate being different than the model name provided 2023-12-10 11:42:15 -05:00
Jeffrey Morgan
045b855db9 fix error on accumulating final chat response 2023-12-10 11:24:39 -05:00
Jeffrey Morgan
32064a0646 fix empty response when receiving runner error 2023-12-10 10:53:38 -05:00
Jeffrey Morgan
d9a250e9b5 seek to end of file when decoding older model formats 2023-12-09 21:14:35 -05:00
Jeffrey Morgan
944519ed16 seek to eof for older model binaries 2023-12-09 20:48:57 -05:00
Jeffrey Morgan
2dd040d04c do not use --parallel 2 for old runners 2023-12-09 20:17:33 -05:00
Bruce MacDonald
bbe41ce41a fix: parallel queueing race condition caused silent failure (#1445)
* fix: queued request failures

- increase parallel requests to 2 so queued requests can complete; queueing is managed in ollama

* log stream errors
2023-12-09 14:14:02 -05:00
Jeffrey Morgan
9e1406e4ed Don't expose model information in /api/generate 2023-12-09 02:05:43 -08:00
Jeffrey Morgan
b74580c913 Update api.md 2023-12-08 16:02:07 -08:00
Bruce MacDonald
7e9405fd07 fix: encode full previous prompt in context (#1424) 2023-12-08 16:53:51 -05:00
Bruce MacDonald
3b0b8930d4 fix: only flush template in chat when current role encountered (#1426) 2023-12-08 16:44:24 -05:00
Bruce MacDonald
e3f925fc1b fix: restore modelfile system in prompt template (#1425) 2023-12-08 14:20:19 -05:00
Jeffrey Morgan
2a2289fb6b Update api.md 2023-12-08 09:36:45 -08:00
Matt Williams
dd427f499a Merge pull request #1419 from jmorganca/mattw/typescript-simplechat
Simple chat example for typescript
2023-12-07 14:42:24 -08:00
Michael Yang
2ae573c7ed Merge pull request #1421 from jmorganca/mxyng/fix-newline
fix redundant newline
2023-12-07 13:47:23 -08:00
Matt Williams
02fe26c44b update the readme as per bruce
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-07 13:46:30 -08:00
Michael Yang
16c7548460 fix redundant newline 2023-12-07 13:44:45 -08:00
Matt Williams
fa75998c0d Update examples/typescript-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:40:54 -08:00
Matt Williams
5344f886c8 Update examples/typescript-simplechat/client.ts
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:40:37 -08:00
Matt Williams
6cc823c9b5 Update examples/typescript-simplechat/client.ts
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:39:59 -08:00
Matt Williams
b84d34e632 Update examples/typescript-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:39:33 -08:00
Matt Williams
30229a913c Update examples/typescript-simplechat/client.ts
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:39:24 -08:00
Matt Williams
1ade380bd7 Simple chat example for typescript
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-07 11:48:25 -08:00
Jeffrey Morgan
ba264e9da8 add future version note to chat api docs 2023-12-07 09:42:15 -08:00
Matt Williams
a2405ec831 Merge pull request #1409 from jmorganca/mattw/python-simplechat
Simple chat example
2023-12-06 15:49:45 -08:00
Matt Williams
ce809bb529 Merge branch 'mattw/python-simplechat' of github.com:jmorganca/ollama into mattw/python-simplechat 2023-12-06 15:48:42 -08:00
Matt Williams
76bc4d0458 Cleanup as per Bruce
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-06 15:44:40 -08:00
Bruce MacDonald
4a02945a15 Update examples/python-simplechat/client.py 2023-12-06 18:36:45 -05:00
Matt Williams
aec742b6d2 Update examples/python-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-06 15:30:45 -08:00
Matt Williams
f337642e94 Update examples/python-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-06 15:30:35 -08:00
Matt Williams
51131cc6e2 Update examples/python-simplechat/client.py
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-06 15:30:10 -08:00
Matt Williams
43027789dc Simple chat example
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-06 14:35:58 -08:00
Xe Iaso
f9b7d65e2b docs/tutorials: add bit on how to use Fly GPUs on-demand with Ollama (#1406)
Signed-off-by: Xe Iaso <xe@camellia.finch-kitefin.ts.net>
2023-12-06 14:14:02 -08:00
Michael Yang
1f05d77110 Merge pull request #1244 from jmorganca/brucemacd/no-fail-template
do not fail on unsupported template variables
2023-12-06 13:23:04 -08:00
Michael Yang
c3ff36088b Merge pull request #774 from jmorganca/mxyng/server-version
add version api and show server version in cli
2023-12-06 13:22:55 -08:00
Samuel Calderon
13524b5e72 List "Send chat messages" in table of contents (#1399)
Thank you @calderonsamuel
2023-12-06 12:34:27 -08:00
Michael Yang
f1b049fed8 Merge pull request #1377 from jmorganca/mxyng/qwen
update for qwen
2023-12-06 12:31:51 -08:00
Jeffrey Morgan
97c5696945 fix base urls in chat examples 2023-12-06 12:10:20 -08:00
Bruce MacDonald
47d4e22673 use missingkey in set empty interface when missing 2023-12-05 15:49:05 -08:00
Michael Yang
32f62fbb8e Merge pull request #1334 from jmorganca/mxyng/load-projectors
load projectors
2023-12-05 14:40:53 -08:00
Michael Yang
5d75505ebd return model configuration in generate 2023-12-05 14:39:02 -08:00
Michael Yang
b9495ea162 load projectors 2023-12-05 14:36:12 -08:00
Michael Yang
409bb9674e Merge pull request #1308 from jmorganca/mxyng/split-from
split from into one or more models
2023-12-05 14:33:03 -08:00
Michael Yang
d3479c07a1 Merge pull request #1250 from jmorganca/mxyng/create-layer
refactor layer creation
2023-12-05 14:32:52 -08:00
Michael Yang
b12f1b984f Merge pull request #1393 from jmorganca/mxyng/fix-whitespace
fix: trim space in modelfile fields
2023-12-05 12:18:01 -08:00
Bruce MacDonald
195e3d9dbd chat api endpoint (#1392) 2023-12-05 14:57:33 -05:00
Michael Yang
38fe1a368b fix: trim space in modelfile fields 2023-12-05 11:57:29 -08:00
Michael Yang
4b77fcb2b9 comments 2023-12-05 09:43:50 -08:00
Michael Yang
cde13bcdea cmd: only print server version when different 2023-12-05 09:36:01 -08:00
Michael Yang
0f0cd265a7 cmd: add server version 2023-12-05 09:36:01 -08:00
Michael Yang
0db4706ec2 api: add version api handler 2023-12-05 09:36:01 -08:00
Michael Yang
1ebdbd9694 server: add version handler 2023-12-05 09:36:01 -08:00
Michael Yang
5c59455b59 cmd: use existing cmd context 2023-12-05 09:36:01 -08:00
Jeffrey Morgan
00d06619a1 Revert "chat api (#991)" while context variable is fixed
This reverts commit 7a0899d62d.
2023-12-04 21:16:27 -08:00
Matt Williams
f1ef3f9947 remove mention of gpt-neox in import (#1381)
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-04 20:58:10 -08:00
Michael Yang
5a5dca13b2 comments 2023-12-04 16:59:23 -08:00
Michael Yang
7232f1fa41 go mod tidy 2023-12-04 16:59:23 -08:00
Michael Yang
72e7a49aa9 seek instead of copyn 2023-12-04 16:59:23 -08:00
Michael Yang
a3737cbd33 use NewLayer for CreateBlobHandler 2023-12-04 16:59:23 -08:00
Michael Yang
998f1785b6 add modelfamilies 2023-12-04 16:59:23 -08:00
Michael Yang
70a93057cd refactor layer creation
previous layer creation was not ideal because:

1. it required reading the input file multiple times, once to calculate
   the sha256 checksum, another to write it to disk, and potentially one
   more to decode the underlying gguf
2. used io.ReadSeeker which is prone to user error. if the file isn't
   reset correctly or in the right place, it could end up reading an
   empty file

there is also some brittleness when reading existing layers: otherwise,
writing the inherited layers will error when reading an already closed file

this commit aims to fix these issues by restructuring layer creation.

1. it now writes the layer to a temporary file and the hash function at the
   same time, then moves it to the final location on Commit
2. layers are read only once when copied to the destination. the exception
   is raw model files, which still require a second read to decode the
   model metadata

(a minimal sketch of this write-once pattern follows this entry)
2023-12-04 16:59:23 -08:00
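For illustration, a minimal Go sketch of the write-once pattern this commit message describes — data is teed into a temporary file and a sha256 hash while it is written, then renamed into place on Commit. Names and paths are hypothetical; this is not the actual server code.

```go
package blob

import (
	"crypto/sha256"
	"encoding/hex"
	"hash"
	"io"
	"os"
	"path/filepath"
)

// Layer tees incoming data into a temporary file and a sha256 hash at the
// same time, so the input only has to be read once.
type Layer struct {
	dir  string
	tmp  *os.File
	hash hash.Hash
}

func NewLayer(dir string) (*Layer, error) {
	tmp, err := os.CreateTemp(dir, "layer-")
	if err != nil {
		return nil, err
	}
	return &Layer{dir: dir, tmp: tmp, hash: sha256.New()}, nil
}

func (l *Layer) Write(p []byte) (int, error) {
	return io.MultiWriter(l.tmp, l.hash).Write(p)
}

// Commit closes the temporary file and renames it to a content-addressed
// path derived from the digest computed while writing.
func (l *Layer) Commit() (string, error) {
	if err := l.tmp.Close(); err != nil {
		return "", err
	}
	sum := hex.EncodeToString(l.hash.Sum(nil))
	dst := filepath.Join(l.dir, "sha256-"+sum)
	return "sha256:" + sum, os.Rename(l.tmp.Name(), dst)
}
```

A caller would `io.Copy` the source into the Layer and call Commit once, instead of reading the input again just to compute the checksum.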
Michael Yang
2cb0fa7d40 split from into one or more models 2023-12-04 16:59:23 -08:00
Michael Yang
b2816bca67 unnecessary ReadSeeker for DecodeGGML 2023-12-04 16:59:23 -08:00
Patrick Devine
bf704423c5 revert cli to use /api/generate (#1383) 2023-12-04 16:35:29 -08:00
Bruce MacDonald
7a0899d62d chat api (#991)
- update chat docs
- add messages chat endpoint
- remove deprecated context and template generate parameters from docs
- context and template are still supported for the time being and will continue to work as expected
- add partial response to chat history
2023-12-04 18:01:06 -05:00
Michael Yang
0cca1486dd Merge pull request #1376 from jmorganca/mxyng/rocky-install
install: fix rocky kernel packages
2023-12-04 14:23:43 -08:00
Patrick Devine
2113c9d31a make linewrap still work when the terminal width has changed (#1350) 2023-12-04 14:14:56 -08:00
Michael Yang
6deebf2489 update for qwen 2023-12-04 11:38:05 -08:00
Michael Yang
95cb38ae47 install: fix rocky kernel packages 2023-12-04 11:10:42 -08:00
ruecat
1f126afb2d Ollama Telegram Bot (#1364)
* Add "ollama-telegram" to Extensions & Plugins

* Update README.md
2023-12-03 11:19:55 -08:00
Jeffrey Morgan
f6201a7a6c remove duplicate community integration in README.md 2023-12-02 21:18:13 -08:00
Michael Yang
b3f6c6598f Merge pull request #1349 from jmorganca/mxyng/ctrl-z
handle ctrl+z
2023-12-01 16:21:49 -08:00
Michael Yang
88620e983a handle ctrl+z 2023-12-01 16:15:20 -08:00
Michael Yang
cedae0d17a Merge pull request #1347 from jshph/adapter-hash
Fix adapter loading from SHA hash
2023-12-01 11:08:25 -08:00
Joshua Pham
bb80a597db Fix adapter loading from SHA hash 2023-12-01 13:50:55 -05:00
Patrick Devine
6681d37861 allow setting the system and template for prompts in the repl (#1335) 2023-12-01 09:28:35 -08:00
Michael Yang
0409c1fa59 docker: set PATH, LD_LIBRARY_PATH, and capabilities (#1336)
* docker: set PATH, LD_LIBRARY_PATH, and capabilities

* example: update k8s gpu manifest
2023-11-30 21:16:56 -08:00
Michael Yang
b56e92470a Merge pull request #1229 from jmorganca/mxyng/calculate-as-you-go
revert checksum calculation to calculate-as-you-go
2023-11-30 10:54:38 -08:00
Jeffrey Morgan
5687f1a0cf fix unexpected end of response errors when cancelling in ollama run 2023-11-30 00:30:21 -05:00
James Radtke
7eda3d0c55 Corrected transposed 129 to 192 for OLLAMA_ORIGINS example (#1325) 2023-11-29 22:44:17 -05:00
Bruce MacDonald
7194a07d4d Add chatd to example projects 2023-11-29 21:18:21 -05:00
Michael Yang
13efd5f218 upload: fix PUT retry 2023-11-29 16:38:35 -08:00
Michael Yang
c4bdfffd96 upload: separate progress tracking 2023-11-29 16:38:33 -08:00
Michael Yang
26c63418e0 new hasher 2023-11-29 14:52:41 -08:00
Michael Yang
2799784ac8 revert checksum calculation to calculate-as-you-go 2023-11-29 13:47:58 -08:00
Alec Hammond
91897a606f Add OllamaEmbeddings to python LangChain example (#994)
* Add OllamaEmbeddings to python LangChain example

* typo

---------

Co-authored-by: Alec Hammond <alechammond@fb.com>
2023-11-29 16:25:39 -05:00
Bruce MacDonald
96122b7271 validate model tags on copy (#1323) 2023-11-29 15:54:29 -05:00
jeremiahbuckley
39be7fdb98 fix rhel cuda install (#1321)
Co-authored-by: Cloud User <azureuser@testgpu2.hqzwom21okjenksna4y3c4ymjd.phxx.internal.cloudapp.net>
2023-11-29 14:55:15 -05:00
Timothy Jaeryang Baek
c2e3b89176 fix: disable ':' in tag names (#1280)
Co-authored-by: rootedbox
2023-11-29 13:33:45 -05:00
Patrick Devine
cde31cb220 Allow setting parameters in the REPL (#1294) 2023-11-29 09:56:42 -08:00
ToasterUwU
63097607b2 Correct MacOS Host port example (#1301) 2023-11-29 11:44:03 -05:00
38 changed files with 2281 additions and 825 deletions

View File

@@ -19,5 +19,11 @@ RUN apt-get update && apt-get install -y ca-certificates
COPY --from=0 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0
# set some environment variables for better NVIDIA compatibility
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]

View File

@@ -57,6 +57,7 @@ Here are some example open-source models that can be downloaded:
| Llama 2 70B | 70B | 39GB | `ollama run llama2:70b` |
| Orca Mini | 3B | 1.9GB | `ollama run orca-mini` |
| Vicuna | 7B | 3.8GB | `ollama run vicuna` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
> Note: You should have at least 8 GB of RAM to run the 3B models, 16 GB to run the 7B models, and 32 GB to run the 13B models.
@@ -104,7 +105,7 @@ FROM llama2
# set the temperature to 1 [higher is more creative, lower is more coherent]
PARAMETER temperature 1
# set the system prompt
# set the system message
SYSTEM """
You are Mario from Super Mario Bros. Answer as Mario, the assistant, only.
"""
@@ -158,6 +159,13 @@ For multiline input, you can wrap text with `"""`:
I'm a basic program that prints the famous "Hello, world!" message to the console.
```
### Multimodal models
```
>>> What's in this image? /Users/jmorgan/Desktop/smile.png
The image features a yellow smiley face, which is likely the central focus of the picture.
```
### Pass in prompt as arguments
```
@@ -205,7 +213,8 @@ Finally, in a separate shell, run a model:
## REST API
Ollama has a REST API for running and managing models.
For example, to generate text from a model:
### Generate a response
```
curl http://localhost:11434/api/generate -d '{
@@ -214,14 +223,21 @@ curl http://localhost:11434/api/generate -d '{
}'
```
### Chat with a model
```
curl http://localhost:11434/api/chat -d '{
"model": "mistral",
"messages": [
{ "role": "user", "content": "why is the sky blue?" }
]
}'
```
See the [API documentation](./docs/api.md) for all endpoints.
## Community Integrations
### Mobile
- [Mobile Artificial Intelligence Distribution](https://github.com/MaidFoundation/Maid) (Maid)
### Web & Desktop
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
@@ -233,6 +249,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [big-AGI](https://github.com/enricoros/big-agi/blob/main/docs/config-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
### Terminal
@@ -245,6 +262,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [gptel Emacs client](https://github.com/karthink/gptel)
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
### Database
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md)
### Package managers
- [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/)
@@ -276,6 +297,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Logseq Ollama plugin](https://github.com/omagdy7/ollama-logseq)
- [Dagger Chatbot](https://github.com/samalba/dagger-chatbot)
- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot)
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)

View File

@@ -221,6 +221,19 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
})
}
type ChatResponseFunc func(ChatResponse) error
func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
var resp ChatResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
}
type PullProgressFunc func(ProgressResponse) error
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
@@ -311,3 +324,15 @@ func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) err
return nil
}
func (c *Client) Version(ctx context.Context) (string, error) {
var version struct {
Version string `json:"version"`
}
if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
return "", err
}
return version.Version, nil
}

View File

@@ -6,6 +6,7 @@ import (
"math"
"os"
"reflect"
"strconv"
"strings"
"time"
)
@@ -30,19 +31,56 @@ func (e StatusError) Error() string {
}
}
type ImageData []byte
type GenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Images []ImageData `json:"images,omitempty"`
Options map[string]interface{} `json:"options"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Format string `json:"format"`
Options map[string]interface{} `json:"options"`
}
type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"]
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
}
type ChatResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Message *Message `json:"message,omitempty"`
Done bool `json:"done"`
Metrics
}
type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
// Options specified in GenerateRequest. If you add a new option here, add it to the API docs also
type Options struct {
Runner
@@ -114,11 +152,12 @@ type ShowRequest struct {
}
type ShowResponse struct {
License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"`
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"`
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"`
}
type CopyRequest struct {
@@ -154,10 +193,11 @@ type ListResponse struct {
}
type ModelResponse struct {
Name string `json:"name"`
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Name string `json:"name"`
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
}
type TokenResponse struct {
@@ -172,39 +212,42 @@ type GenerateResponse struct {
Done bool `json:"done"`
Context []int `json:"context,omitempty"`
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
Metrics
}
func (r *GenerateResponse) Summary() {
if r.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
type ModelDetails struct {
Format string `json:"format"`
Family string `json:"family"`
Families []string `json:"families"`
ParameterSize string `json:"parameter_size"`
QuantizationLevel string `json:"quantization_level"`
}
func (m *Metrics) Summary() {
if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
}
if r.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
if m.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
}
if r.PromptEvalCount > 0 {
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
if m.PromptEvalCount > 0 {
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", m.PromptEvalCount)
}
if r.PromptEvalDuration > 0 {
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
if m.PromptEvalDuration > 0 {
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration)
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds())
}
if r.EvalCount > 0 {
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount)
if m.EvalCount > 0 {
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", m.EvalCount)
}
if r.EvalDuration > 0 {
fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration)
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
if m.EvalDuration > 0 {
fmt.Fprintf(os.Stderr, "eval duration: %s\n", m.EvalDuration)
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds())
}
}
@@ -360,3 +403,63 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
return nil
}
// FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
if jsonTag != "" {
jsonOpts[jsonTag] = field
}
}
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; !ok {
return nil, fmt.Errorf("unknown parameter '%s'", key)
} else {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", vals)
}
out[key] = float32(floatVal)
case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals)
}
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
out[key] = boolVal
case reflect.String:
out[key] = vals[0]
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
}
}
}
return out, nil
}

View File

@@ -17,7 +17,9 @@ import (
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"runtime"
"slices"
"strings"
"syscall"
"time"
@@ -36,6 +38,8 @@ import (
"github.com/jmorganca/ollama/version"
)
type ImageData []byte
func CreateHandler(cmd *cobra.Command, args []string) error {
filename, _ := cmd.Flags().GetString("file")
filename, err := filepath.Abs(filename)
@@ -133,7 +137,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)}
if err := client.Create(context.Background(), &request, fn); err != nil {
if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -148,7 +152,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
name := args[0]
// check if the model exists on the server
_, err = client.Show(context.Background(), &api.ShowRequest{Name: name})
_, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
@@ -208,7 +212,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
}
request := api.PushRequest{Name: args[0], Insecure: insecure}
if err := client.Push(context.Background(), &request, fn); err != nil {
if err := client.Push(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -222,7 +226,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
return err
}
models, err := client.List(context.Background())
models, err := client.List(cmd.Context())
if err != nil {
return err
}
@@ -257,7 +261,7 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
for _, name := range args {
req := api.DeleteRequest{Name: name}
if err := client.Delete(context.Background(), &req); err != nil {
if err := client.Delete(cmd.Context(), &req); err != nil {
return err
}
fmt.Printf("deleted '%s'\n", name)
@@ -322,7 +326,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
}
req := api.ShowRequest{Name: args[0]}
resp, err := client.Show(context.Background(), &req)
resp, err := client.Show(cmd.Context(), &req)
if err != nil {
return err
}
@@ -350,7 +354,7 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
}
req := api.CopyRequest{Source: args[0], Destination: args[1]}
if err := client.Copy(context.Background(), &req); err != nil {
if err := client.Copy(cmd.Context(), &req); err != nil {
return err
}
fmt.Printf("copied '%s' to '%s'\n", args[0], args[1])
@@ -404,7 +408,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
if err := client.Pull(context.Background(), &request, fn); err != nil {
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -412,13 +416,22 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
func RunGenerate(cmd *cobra.Command, args []string) error {
interactive := true
opts := generateOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
Images: []ImageData{},
}
format, err := cmd.Flags().GetString("format")
if err != nil {
return err
}
opts.Format = format
prompts := args[1:]
// prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
@@ -427,34 +440,41 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
}
prompts = append([]string{string(in)}, prompts...)
opts.WordWrap = false
interactive = false
}
// output is being piped
if !term.IsTerminal(int(os.Stdout.Fd())) {
return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
opts.Prompt = strings.Join(prompts, " ")
if len(prompts) > 0 {
interactive = false
}
wordWrap := os.Getenv("TERM") == "xterm-256color"
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
if nowrap {
wordWrap = false
opts.WordWrap = !nowrap
if !interactive {
return generate(cmd, opts)
}
// prompts are provided via stdin or args so don't enter interactive mode
if len(prompts) > 0 {
return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
}
return generateInteractive(cmd, args[0], wordWrap, format)
return generateInteractive(cmd, opts)
}
type generateContextKey string
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error {
type generateOptions struct {
Model string
Prompt string
WordWrap bool
Format string
System string
Template string
Images []ImageData
Options map[string]interface{}
}
func generate(cmd *cobra.Command, opts generateOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -475,34 +495,39 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
wordWrap = false
opts.WordWrap = false
}
cancelCtx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
var abort bool
go func() {
<-sigChan
cancel()
abort = true
}()
var currentLineLength int
var wordBuffer string
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format}
fn := func(response api.GenerateResponse) error {
p.StopAndClear()
latest = response
if wordWrap {
termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
if opts.WordWrap && termWidth >= 10 {
for _, ch := range response.Response {
if currentLineLength+1 > termWidth-5 {
if len(wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", wordBuffer, ch)
wordBuffer = ""
currentLineLength = 0
continue
}
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
fmt.Printf("%s%c", wordBuffer, ch)
@@ -522,28 +547,43 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
}
}
} else {
fmt.Print(response.Response)
fmt.Printf("%s%s", wordBuffer, response.Response)
if len(wordBuffer) > 0 {
wordBuffer = ""
}
}
return nil
}
if err := client.Generate(cancelCtx, &request, fn); err != nil {
if strings.Contains(err.Error(), "context canceled") && abort {
images := make([]api.ImageData, 0)
for _, i := range opts.Images {
images = append(images, api.ImageData(i))
}
request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Context: generateContext,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
Images: images,
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
return err
}
if prompt != "" {
if opts.Prompt != "" {
fmt.Println()
fmt.Println()
}
if !latest.Done {
if abort {
return nil
}
return errors.New("unexpected end of response")
return nil
}
verbose, err := cmd.Flags().GetBool("verbose")
@@ -555,16 +595,48 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
latest.Summary()
}
ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
return nil
}
func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilinePrompt
MultilineSystem
MultilineTemplate
)
func modelIsMultiModal(cmd *cobra.Command, name string) bool {
// get model details
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return false
}
req := api.ShowRequest{Name: name}
resp, err := client.Show(cmd.Context(), &req)
if err != nil {
return false
}
return slices.Contains(resp.Details.Families, "clip")
}
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
multiModal := modelIsMultiModal(cmd, opts.Model)
// load the model
if err := generate(cmd, model, "", false, ""); err != nil {
loadOpts := generateOptions{
Model: opts.Model,
Prompt: "",
Images: []ImageData{},
}
if err := generate(cmd, loadOpts); err != nil {
return err
}
@@ -581,14 +653,17 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
usageSet := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode")
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode")
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, "")
}
@@ -597,11 +672,27 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Fprintln(os.Stderr, " /show license Show model license")
fmt.Fprintln(os.Stderr, " /show modelfile Show Modelfile for this model")
fmt.Fprintln(os.Stderr, " /show parameters Show parameters for this model")
fmt.Fprintln(os.Stderr, " /show system Show system prompt")
fmt.Fprintln(os.Stderr, " /show system Show system message")
fmt.Fprintln(os.Stderr, " /show template Show prompt template")
fmt.Fprintln(os.Stderr, "")
}
// only list out the most common parameters
usageParameters := func() {
fmt.Fprintln(os.Stderr, "Available Parameters:")
fmt.Fprintln(os.Stderr, " /set parameter seed <int> Random number seed")
fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict")
fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens")
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters")
fmt.Fprintln(os.Stderr, "")
}
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
@@ -615,6 +706,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
var multiline MultilineState
var prompt string
for {
@@ -641,16 +733,30 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
// if the prompt so far starts with """ then we're in multiline mode
// and we need to keep reading until we find a line that ends with """
cut, found := strings.CutSuffix(line, `"""`)
prompt += cut + "\n"
prompt += cut
if !found {
prompt += "\n"
continue
}
prompt = strings.TrimPrefix(prompt, `"""`)
scanner.Prompt.UseAlt = false
switch multiline {
case MultilineSystem:
opts.System = prompt
prompt = ""
fmt.Println("Set system message.")
case MultilineTemplate:
opts.Template = prompt
prompt = ""
fmt.Println("Set prompt template.")
}
multiline = MultilineNone
case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
scanner.Prompt.UseAlt = true
multiline = MultilinePrompt
prompt += line + "\n"
continue
case scanner.Pasting:
@@ -670,10 +776,10 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
case "nohistory":
scanner.HistoryDisable()
case "wordwrap":
wordWrap = true
opts.WordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
wordWrap = false
opts.WordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
cmd.Flags().Set("verbose", "true")
@@ -685,12 +791,60 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
format = args[2]
opts.Format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
format = ""
opts.Format = ""
fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
usageParameters()
continue
}
var params []string
for _, p := range args[3:] {
params = append(params, p)
}
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]]
case "system", "template":
if len(args) < 3 {
usageSet()
continue
}
line := strings.Join(args[2:], " ")
line = strings.TrimPrefix(line, `"""`)
if strings.HasPrefix(args[2], `"""`) {
cut, found := strings.CutSuffix(line, `"""`)
prompt += cut
if found {
if args[1] == "system" {
opts.System = prompt
fmt.Println("Set system message.")
} else {
opts.Template = prompt
fmt.Println("Set prompt template.")
}
prompt = ""
} else {
prompt = `"""` + prompt + "\n"
if args[1] == "system" {
multiline = MultilineSystem
} else {
multiline = MultilineTemplate
}
scanner.Prompt.UseAlt = true
}
} else {
opts.System = line
fmt.Println("Set system message.")
}
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
@@ -705,7 +859,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Println("error: couldn't connect to ollama server")
return err
}
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model})
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model})
if err != nil {
fmt.Println("error: couldn't get model")
return err
@@ -724,19 +878,33 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
if resp.Parameters == "" {
fmt.Print("No parameters were specified for this model.\n\n")
} else {
if len(opts.Options) > 0 {
fmt.Println("User defined parameters:")
for k, v := range opts.Options {
fmt.Printf("%-*s %v\n", 30, k, v)
}
fmt.Println()
}
fmt.Println("Model defined parameters:")
fmt.Println(resp.Parameters)
}
case "system":
if resp.System == "" {
fmt.Print("No system prompt was specified for this model.\n\n")
} else {
fmt.Println(resp.System)
switch {
case opts.System != "":
fmt.Println(opts.System + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Print("No system message was specified for this model.\n\n")
}
case "template":
if resp.Template == "" {
fmt.Print("No prompt template was specified for this model.\n\n")
} else {
switch {
case opts.Template != "":
fmt.Println(opts.Template + "\n")
case resp.Template != "":
fmt.Println(resp.Template)
default:
fmt.Print("No prompt template was specified for this model.\n\n")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
@@ -766,8 +934,30 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
prompt += line
}
if len(prompt) > 0 && prompt[0] != '/' {
if err := generate(cmd, model, prompt, wordWrap, format); err != nil {
if len(prompt) > 0 && multiline == MultilineNone {
opts.Prompt = prompt
if multiModal {
newPrompt, images, err := extractFileNames(prompt)
if err != nil {
return err
}
opts.Prompt = newPrompt
// reset the context if we find another image
if len(images) > 0 {
opts.Images = images
ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
cmd.SetContext(ctx)
}
if len(opts.Images) == 0 {
fmt.Println("This model requires you to add a jpeg, png, or svg image.")
fmt.Println()
prompt = ""
continue
}
}
if err := generate(cmd, opts); err != nil {
return err
}
@@ -776,6 +966,57 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
}
}
func normalizeFilePath(fp string) string {
// Define a map of escaped characters and their replacements
replacements := map[string]string{
"\\ ": " ", // Escaped space
"\\(": "(", // Escaped left parenthesis
"\\)": ")", // Escaped right parenthesis
"\\[": "[", // Escaped left square bracket
"\\]": "]", // Escaped right square bracket
"\\{": "{", // Escaped left curly brace
"\\}": "}", // Escaped right curly brace
"\\$": "$", // Escaped dollar sign
"\\&": "&", // Escaped ampersand
"\\;": ";", // Escaped semicolon
"\\'": "'", // Escaped single quote
"\\\\": "\\", // Escaped backslash
"\\*": "*", // Escaped asterisk
"\\?": "?", // Escaped question mark
}
for escaped, actual := range replacements {
fp = strings.ReplaceAll(fp, escaped, actual)
}
return fp
}
func extractFileNames(input string) (string, []ImageData, error) {
// Regex to match file paths starting with / or ./ and include escaped spaces (\ or %20)
// and followed by more characters and a file extension
regexPattern := `(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
re := regexp.MustCompile(regexPattern)
filePaths := re.FindAllString(input, -1)
var imgs []ImageData
for _, fp := range filePaths {
nfp := normalizeFilePath(fp)
data, err := getImageData(nfp)
if err != nil {
if os.IsNotExist(err) {
continue
}
fmt.Printf("Couldn't process image: %q\n", err)
return "", imgs, err
}
fmt.Printf("Added image '%s'\n", nfp)
input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data)
}
return input, imgs, nil
}
func RunServer(cmd *cobra.Command, _ []string) error {
host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
if err != nil {
@@ -802,6 +1043,50 @@ func RunServer(cmd *cobra.Command, _ []string) error {
return server.Serve(ln, origins)
}
func getImageData(filePath string) ([]byte, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()
buf := make([]byte, 512)
_, err = file.Read(buf)
if err != nil {
return nil, err
}
contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType)
}
info, err := file.Stat()
if err != nil {
return nil, err
}
// Check if the file size exceeds 100MB
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
if info.Size() > maxSize {
return nil, fmt.Errorf("file size exceeds maximum limit (100MB)")
}
buf = make([]byte, info.Size())
_, err = file.Seek(0, 0)
if err != nil {
return nil, err
}
_, err = io.ReadFull(file, buf)
if err != nil {
return nil, err
}
return buf, nil
}
func initializeKeypair() error {
home, err := os.UserHomeDir()
if err != nil {
@@ -851,7 +1136,7 @@ func initializeKeypair() error {
return nil
}
func startMacApp(client *api.Client) error {
func startMacApp(ctx context.Context, client *api.Client) error {
exe, err := os.Executable()
if err != nil {
return err
@@ -875,24 +1160,24 @@ func startMacApp(client *api.Client) error {
case <-timeout:
return errors.New("timed out waiting for server to start")
case <-tick:
if err := client.Heartbeat(context.Background()); err == nil {
if err := client.Heartbeat(ctx); err == nil {
return nil // server has started
}
}
}
}
func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if err := client.Heartbeat(context.Background()); err != nil {
if err := client.Heartbeat(cmd.Context()); err != nil {
if !strings.Contains(err.Error(), "connection refused") {
return err
}
if runtime.GOOS == "darwin" {
if err := startMacApp(client); err != nil {
if err := startMacApp(cmd.Context(), client); err != nil {
return fmt.Errorf("could not connect to ollama app, is it running?")
}
} else {
@@ -902,8 +1187,29 @@ func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
return nil
}
func versionHandler(cmd *cobra.Command, _ []string) {
client, err := api.ClientFromEnvironment()
if err != nil {
return
}
serverVersion, err := client.Version(cmd.Context())
if err != nil {
fmt.Println("Warning: could not connect to a running Ollama instance")
}
if serverVersion != "" {
fmt.Printf("ollama version is %s\n", serverVersion)
}
if serverVersion != version.Version {
fmt.Printf("Warning: client version is %s\n", version.Version)
}
}
func NewCLI() *cobra.Command {
log.SetFlags(log.LstdFlags | log.Lshortfile)
cobra.EnableCommandSorting = false
rootCmd := &cobra.Command{
Use: "ollama",
@@ -913,10 +1219,17 @@ func NewCLI() *cobra.Command {
CompletionOptions: cobra.CompletionOptions{
DisableDefaultCmd: true,
},
Version: version.Version,
Run: func(cmd *cobra.Command, args []string) {
if version, _ := cmd.Flags().GetBool("version"); version {
versionHandler(cmd, args)
return
}
cmd.Print(cmd.UsageString())
},
}
cobra.EnableCommandSorting = false
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{
Use: "create MODEL",
@@ -940,7 +1253,7 @@ func NewCLI() *cobra.Command {
showCmd.Flags().Bool("modelfile", false, "Show Modelfile of a model")
showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
showCmd.Flags().Bool("template", false, "Show template of a model")
showCmd.Flags().Bool("system", false, "Show system prompt of a model")
showCmd.Flags().Bool("system", false, "Show system message of a model")
runCmd := &cobra.Command{
Use: "run MODEL [PROMPT]",

View File

@@ -3,6 +3,7 @@
## Endpoints
- [Generate a completion](#generate-a-completion)
- [Generate a chat completion](#generate-a-chat-completion)
- [Create a Model](#create-a-model)
- [List Local Models](#list-local-models)
- [Show Model Information](#show-model-information)
@@ -24,7 +25,7 @@ All durations are returned in nanoseconds.
### Streaming responses
Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character.
Certain endpoints stream responses as JSON objects.
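For illustration, a minimal Go sketch of consuming such a stream, assuming each response object arrives on its own line as the original wording above describes. The model name and endpoint are just examples.

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Request a streamed completion from a locally running server.
	body := []byte(`{"model": "llama2", "prompt": "Why is the sky blue?"}`)
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Each line of the response body is a standalone JSON object.
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // allow long lines
	for scanner.Scan() {
		var chunk struct {
			Response string `json:"response"`
			Done     bool   `json:"done"`
		}
		if err := json.Unmarshal(scanner.Bytes(), &chunk); err != nil {
			panic(err)
		}
		fmt.Print(chunk.Response)
		if chunk.Done {
			fmt.Println()
		}
	}
	if err := scanner.Err(); err != nil {
		panic(err)
	}
}
```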
## Generate a completion
@@ -32,22 +33,23 @@ Certain endpoints stream responses as JSON objects delineated with the newline (
POST /api/generate
```
Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request.
Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
### Parameters
- `model`: (required) the [model name](#model-names)
- `prompt`: the prompt to generate a response for
- `images`: a list of base64-encoded images (for multimodal models such as `llava`)
Advanced parameters (optional):
- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system prompt to (overrides what is defined in the `Modelfile`)
- `system`: system message (overrides what is defined in the `Modelfile`)
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`; this can be used to keep a short conversational memory (see the client sketch after this list)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
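For illustration, a minimal sketch of carrying `context` between turns with the Go client. This is a hedged sketch only; it assumes the module path `github.com/jmorganca/ollama/api` for the client and request types that appear elsewhere in this compare.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	var convo []int // context returned by the previous request

	ask := func(prompt string) error {
		req := &api.GenerateRequest{
			Model:   "llama2",
			Prompt:  prompt,
			Context: convo, // carry over short conversational memory
		}
		return client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
			fmt.Print(resp.Response)
			if resp.Done {
				convo = resp.Context // save for the next turn
				fmt.Println()
			}
			return nil
		})
	}

	if err := ask("Why is the sky blue?"); err != nil {
		log.Fatal(err)
	}
	if err := ask("How is that different from Mie scattering?"); err != nil {
		log.Fatal(err)
	}
}
```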
### JSON mode
@@ -114,6 +116,8 @@ To calculate how fast the response is generated in tokens per second (token/s),
#### Request (No streaming)
A response can be received in one reply when streaming is off.
```shell
curl http://localhost:11434/api/generate -d '{
"model": "llama2",
@@ -144,9 +148,40 @@ If `stream` is set to `false`, the response will be a single JSON object:
}
```
#### Request (Raw mode)
#### Request (with images)
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
To submit images to multimodal models such as `llava` or `bakllava`, provide a list of base64-encoded `images`:
```shell
curl http://localhost:11434/api/generate -d '{
"model": "llava",
"prompt":"What is in this picture?",
"stream": false,
"images": ["iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF
169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"]
}'
```
#### Response
```
{
"model": "llava",
"created_at": "2023-11-03T15:36:02.583064Z",
"response": "A happy cartoon character, which is cute and cheerful.",
"context": [1, 2, 3],
"done": true,
"total_duration": 14648695333,
"load_duration": 3302671417,
"prompt_eval_count": 14,
"prompt_eval_duration": 286243000,
"eval_count": 129,
"eval_duration": 10931424000
}
```
#### Request (Raw Mode)
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting.
```shell
curl http://localhost:11434/api/generate -d '{
@@ -164,6 +199,7 @@ curl http://localhost:11434/api/generate -d '{
"model": "mistral",
"created_at": "2023-11-03T15:36:02.583064Z",
"response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
"context": [1, 2, 3],
"done": true,
"total_duration": 14648695333,
"load_duration": 3302671417,
@@ -249,7 +285,7 @@ curl http://localhost:11434/api/generate -d '{
"penalize_newline": true,
"stop": ["\n", "user:"],
"numa": false,
"num_ctx": 4,
"num_ctx": 1024,
"num_batch": 2,
"num_gqa": 1,
"num_gpu": 1,
@@ -264,7 +300,7 @@ curl http://localhost:11434/api/generate -d '{
"rope_frequency_base": 1.1,
"rope_frequency_scale": 0.8,
"num_thread": 8
}
}
}'
```
@@ -275,7 +311,6 @@ curl http://localhost:11434/api/generate -d '{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.",
"context": [1, 2, 3],
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
@@ -288,6 +323,159 @@ curl http://localhost:11434/api/generate -d '{
}
```
## Generate a chat completion
```shell
POST /api/chat
```
Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
### Parameters
- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat; this can be used to keep a chat memory
The `message` object has the following fields:
- `role`: the role of the message, either `system`, `user` or `assistant`
- `content`: the content of the message
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
Advanced parameters (optional):
- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
### Examples
#### Request
Send a chat message with a streaming response.
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"messages": [
{
"role": "user",
"content": "why is the sky blue?"
}
]
}'
```
#### Response
A stream of JSON objects is returned:
```json
{
"model": "llama2",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assisant",
"content": "The"
},
"done": false
}
```
Final response:
```json
{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 113,
"eval_duration": 1325948000
}
```
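Each streamed object carries a fragment of the reply in `message.content`, and the client concatenates the fragments itself. A minimal consumer sketch (assuming the `requests` package and a local server on the default port):
```python
import json
import requests

# Sketch: stream /api/chat and accumulate the assistant's reply.
with requests.post(
    "http://localhost:11434/api/chat",
    json={
        "model": "llama2",
        "messages": [{"role": "user", "content": "why is the sky blue?"}],
    },
    stream=True,
) as resp:
    resp.raise_for_status()
    reply = ""
    for line in resp.iter_lines():
        if not line:
            continue
        body = json.loads(line)
        if not body.get("done"):
            reply += body["message"]["content"]
    print(reply)
```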
#### Request (With History)
Send a chat message with a conversation history.
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"messages": [
{
"role": "user",
"content": "why is the sky blue?"
},
{
"role": "assistant",
"content": "due to rayleigh scattering."
},
{
"role": "user",
"content": "how is that different than mie scattering?"
}
]
}'
```
#### Response
A stream of JSON objects is returned:
```json
{
"model": "llama2",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assisant",
"content": "The"
},
"done": false
}
```
Final response:
```json
{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 113,
"eval_duration": 1325948000
}
```
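If streaming is not needed, the same conversation can be sent with `"stream": false`, in which case a single response object is returned; the aggregated reply is then expected in `message.content`. A sketch, again assuming the `requests` package:
```python
import requests

# Sketch: non-streaming chat with prior turns supplied as history.
resp = requests.post(
    "http://localhost:11434/api/chat",
    json={
        "model": "llama2",
        "stream": False,
        "messages": [
            {"role": "user", "content": "why is the sky blue?"},
            {"role": "assistant", "content": "due to rayleigh scattering."},
            {"role": "user", "content": "how is that different than mie scattering?"},
        ],
    },
)
print(resp.json()["message"]["content"])
```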
#### Request (with images)
Send a chat message with an image. Images are passed to multimodal models such as `llava` as base64-encoded strings in the `images` field.
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"messages": [
{
"role": "user",
"content": "what is in this image?",
"images": ["iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF
169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"]
}
]
}'
```
## Create a Model
```shell
@@ -415,7 +603,7 @@ A single JSON object will be returned.
POST /api/show
```
Show details about a model including modelfile, template, parameters, license, and system prompt.
Show information about a model including details, modelfile, template, parameters, license, and system prompt.
### Parameters
@@ -435,10 +623,16 @@ curl http://localhost:11434/api/show -d '{
```json
{
"license": "<contents of license block>",
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llama2:latest\n\nFROM /Users/username/.ollama/models/blobs/sha256:8daa9615cce30c259a9555b1cc250d461d1bc69980a274b44d7eda0be78076d8\nTEMPLATE \"\"\"[INST] {{ if and .First .System }}<<SYS>>{{ .System }}<</SYS>>\n\n{{ end }}{{ .Prompt }} [/INST] \"\"\"\nSYSTEM \"\"\"\"\"\"\nPARAMETER stop [INST]\nPARAMETER stop [/INST]\nPARAMETER stop <<SYS>>\nPARAMETER stop <</SYS>>\n",
"parameters": "stop [INST]\nstop [/INST]\nstop <<SYS>>\nstop <</SYS>>",
"template": "[INST] {{ if and .First .System }}<<SYS>>{{ .System }}<</SYS>>\n\n{{ end }}{{ .Prompt }} [/INST] "
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM mike/llava:latest\nTEMPLATE \"\"\"\nUSER:{{ .Prompt }}\nASSISTANT:\n\"\"\"\nPARAMETER num_ctx 4096",
"parameters": "num_ctx 4096",
"template": "\nUSER:{{ .Prompt }}\nASSISTANT:\n",
"license:": "<license>",
"details": {
"format": "gguf",
"families": ["llama", "clip"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
```
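To read these fields programmatically, a small sketch (assuming the `requests` package and that the request body references the model by `name`):
```python
import requests

# Sketch: fetch model information from /api/show.
resp = requests.post("http://localhost:11434/api/show", json={"name": "llava"})
info = resp.json()
details = info.get("details", {})
print(details.get("format"), details.get("parameter_size"), details.get("quantization_level"))
print(info.get("template"))
```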


@@ -23,7 +23,7 @@ Ollama binds to 127.0.0.1 port 11434 by default. Change the bind address with th
On macOS:
```bash
OLLAMA_HOST=0.0.0.0:11435 ollama serve
OLLAMA_HOST=0.0.0.0:11434 ollama serve
```
On Linux:
@@ -59,7 +59,7 @@ OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com ollama serve
On Linux:
```bash
echo 'Environment="OLLAMA_ORIGINS=http://129.168.1.1:*,https://example.com"' >>/etc/systemd/system/ollama.service.d/environment.conf
echo 'Environment="OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com"' >>/etc/systemd/system/ollama.service.d/environment.conf
```
Reload `systemd` and restart Ollama:


@@ -43,7 +43,6 @@ Ollama supports a set of model architectures, with support for more coming soon:
- Llama & Mistral
- Falcon & RW
- GPT-NeoX
- BigCode
To view a model's architecture, check the `config.json` file in its HuggingFace repo. You should see an entry under `architectures` (e.g. `LlamaForCausalLM`).
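For example, a quick local check might look like this (a sketch; the path is a placeholder for wherever the model repo is checked out):
```python
import json

# Sketch: print the declared architecture(s) of a downloaded HuggingFace model.
with open("path/to/model/config.json") as f:  # placeholder path
    config = json.load(f)
print(config.get("architectures"))  # e.g. ["LlamaForCausalLM"]
```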
@@ -184,9 +183,6 @@ python convert.py <path to model directory>
# FalconForCausalLM
python convert-falcon-hf-to-gguf.py <path to model directory>
# GPTNeoXForCausalLM
python convert-gptneox-hf-to-gguf.py <path to model directory>
# GPTBigCodeForCausalLM
python convert-starcoder-hf-to-gguf.py <path to model directory>
```


@@ -30,14 +30,14 @@ The format of the `Modelfile`:
INSTRUCTION arguments
```
| Instruction | Description |
| ----------------------------------- | ------------------------------------------------------------- |
| [`FROM`](#from-required) (required) | Defines the base model to use. |
| [`PARAMETER`](#parameter) | Sets the parameters for how Ollama will run the model. |
| [`TEMPLATE`](#template) | The full prompt template to be sent to the model. |
| [`SYSTEM`](#system) | Specifies the system prompt that will be set in the template. |
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
| [`LICENSE`](#license) | Specifies the legal license. |
| Instruction | Description |
| ----------------------------------- | -------------------------------------------------------------- |
| [`FROM`](#from-required) (required) | Defines the base model to use. |
| [`PARAMETER`](#parameter) | Sets the parameters for how Ollama will run the model. |
| [`TEMPLATE`](#template) | The full prompt template to be sent to the model. |
| [`SYSTEM`](#system) | Specifies the system message that will be set in the template. |
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
| [`LICENSE`](#license) | Specifies the legal license. |
## Examples
@@ -52,7 +52,7 @@ PARAMETER temperature 1
# sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token
PARAMETER num_ctx 4096
# sets a custom system prompt to specify the behavior of the chat assistant
# sets a custom system message to specify the behavior of the chat assistant
SYSTEM You are Mario from super mario bros, acting as an assistant.
```
@@ -70,9 +70,9 @@ More examples are available in the [examples directory](../examples).
There are two ways to view `Modelfile`s underlying the models in [ollama.ai/library][1]:
- Option 1: view a details page from a model's tags page:
1. Go to a particular model's tags (e.g. https://ollama.ai/library/llama2/tags)
2. Click on a tag (e.g. https://ollama.ai/library/llama2:13b)
3. Scroll down to "Layers"
1. Go to a particular model's tags (e.g. https://ollama.ai/library/llama2/tags)
2. Click on a tag (e.g. https://ollama.ai/library/llama2:13b)
3. Scroll down to "Layers"
- Note: if the [`FROM` instruction](#from-required) is not present,
it means the model was created from a local file
- Option 2: use `ollama show` to print the `Modelfile` like so:
@@ -152,15 +152,15 @@ PARAMETER <parameter> <parametervalue>
### TEMPLATE
`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system prompt and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model.
`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model.
#### Template Variables
| Variable | Description |
| --------------- | ------------------------------------------------------------------------------------------------------------ |
| `{{ .System }}` | The system prompt used to specify custom behavior, this must also be set in the Modelfile as an instruction. |
| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. |
| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |
| Variable | Description |
| --------------- | ------------------------------------------------------------------------------------------------------------- |
| `{{ .System }}` | The system message used to specify custom behavior, this must also be set in the Modelfile as an instruction. |
| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. |
| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |
```modelfile
TEMPLATE """
@@ -180,7 +180,7 @@ SYSTEM """<system message>"""
### SYSTEM
The `SYSTEM` instruction specifies the system prompt to be used in the template, if applicable.
The `SYSTEM` instruction specifies the system message to be used in the template, if applicable.
```modelfile
SYSTEM """<system message>"""

docs/tutorials/fly-gpu.md

@@ -0,0 +1,83 @@
# Running Ollama on Fly.io GPU Instances
Ollama runs with little to no configuration on [Fly.io GPU instances](https://fly.io/docs/gpus/gpu-quickstart/). If you don't have access to GPUs yet, you'll need to [apply for access](https://fly.io/gpu/) on the waitlist. Once you're accepted, you'll get an email with instructions on how to get started.
Create a new app with `fly apps create`:
```bash
fly apps create
```
Then create a `fly.toml` file in a new folder that looks like this:
```toml
app = "sparkling-violet-709"
primary_region = "ord"
vm.size = "a100-40gb" # see https://fly.io/docs/gpus/gpu-quickstart/ for more info
[build]
image = "ollama/ollama"
[http_service]
internal_port = 11434
force_https = false
auto_stop_machines = true
auto_start_machines = true
min_machines_running = 0
processes = ["app"]
[mounts]
source = "models"
destination = "/root/.ollama"
initial_size = "100gb"
```
Then create a [new private IPv6 address](https://fly.io/docs/reference/private-networking/#flycast-private-load-balancing) for your app:
```bash
fly ips allocate-v6 --private
```
Then deploy your app:
```bash
fly deploy
```
And finally you can access it interactively with a new Fly.io Machine:
```
fly machine run -e OLLAMA_HOST=http://your-app-name.flycast --shell ollama/ollama
```
```bash
$ ollama run openchat:7b-v3.5-fp16
>>> How do I bake chocolate chip cookies?
To bake chocolate chip cookies, follow these steps:
1. Preheat the oven to 375°F (190°C) and line a baking sheet with parchment paper or silicone baking mat.
2. In a large bowl, mix together 1 cup of unsalted butter (softened), 3/4 cup granulated sugar, and 3/4
cup packed brown sugar until light and fluffy.
3. Add 2 large eggs, one at a time, to the butter mixture, beating well after each addition. Stir in 1
teaspoon of pure vanilla extract.
4. In a separate bowl, whisk together 2 cups all-purpose flour, 1/2 teaspoon baking soda, and 1/2 teaspoon
salt. Gradually add the dry ingredients to the wet ingredients, stirring until just combined.
5. Fold in 2 cups of chocolate chips (or chunks) into the dough.
6. Drop rounded tablespoons of dough onto the prepared baking sheet, spacing them about 2 inches apart.
7. Bake for 10-12 minutes, or until the edges are golden brown. The centers should still be slightly soft.
8. Allow the cookies to cool on the baking sheet for a few minutes before transferring them to a wire rack
to cool completely.
Enjoy your homemade chocolate chip cookies!
```
When you set it up like this, it will automatically turn off when you're done using it. Then when you access it again, it will automatically turn back on. This is a great way to save money on GPU instances when you're not using them. If you want a persistent wake-on-use connection to your Ollama instance, you can set up a [connection to your Fly network using WireGuard](https://fly.io/docs/reference/private-networking/#discovering-apps-through-dns-on-a-wireguard-connection). Then you can access your Ollama instance at `http://your-app-name.flycast`.
And that's it!


@@ -42,12 +42,13 @@ text_splitter=RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)
```
It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. For now, we don't have embeddings built in to Ollama, though we will be adding that soon, so for now, we can use the GPT4All library for that. We will use ChromaDB in this example for a vector database. `pip install GPT4All chromadb`
It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb`
```python
from langchain.embeddings import GPT4AllEmbeddings
from langchain.embeddings import OllamaEmbeddings
from langchain.vectorstores import Chroma
vectorstore = Chroma.from_documents(documents=all_splits, embedding=GPT4AllEmbeddings())
oembed = OllamaEmbeddings(base_url="http://localhost:11434", model="llama2")
vectorstore = Chroma.from_documents(documents=all_splits, embedding=oembed)
```
Now let's ask a question from the document. **Who was Neleus, and who is in his family?** Neleus is a character in the Odyssey, and the answer can be found in our text.


@@ -25,9 +25,11 @@ spec:
image: ollama/ollama:latest
env:
- name: PATH
value: /usr/local/nvidia/bin:/usr/local/nvidia/lib64:/usr/bin:/usr/sbin:/bin:/sbin
value: /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
- name: LD_LIBRARY_PATH
value: /usr/local/nvidia/lib64
value: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
- name: NVIDIA_DRIVER_CAPABILITIES
value: compute,utility
ports:
- name: http
containerPort: 11434


@@ -0,0 +1,46 @@
import json
import requests
# NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
model = "llama2" # TODO: update this for whatever model you wish to use
def chat(messages):
r = requests.post(
"http://0.0.0.0:11434/api/chat",
json={"model": model, "messages": messages, "stream": True},
)
r.raise_for_status()
output = ""
for line in r.iter_lines():
body = json.loads(line)
if "error" in body:
raise Exception(body["error"])
if body.get("done") is False:
message = body.get("message", "")
content = message.get("content", "")
output += content
# the response streams one token at a time, print that as we receive it
print(content, end="", flush=True)
if body.get("done", False):
message["content"] = output
return message
def main():
messages = []
while True:
user_input = input("Enter a prompt: ")
print()
messages.append({"role": "user", "content": user_input})
message = chat(messages)
messages.append(message)
print("\n\n")
if __name__ == "__main__":
main()


@@ -0,0 +1,24 @@
# Simple Chat Example
The **chat** endpoint is one of two ways to generate text from an LLM with Ollama. At a high level you provide the endpoint an array of objects with a role and content specified. Then with each output and prompt, you add more of those role/content objects, which builds up the history.
## Review the Code
You can see in the **chat** function that calling the endpoint is done simply with:
```python
r = requests.post(
"http://0.0.0.0:11434/api/chat",
json={"model": model, "messages": messages, "stream": True},
)
```
With the **generate** endpoint, you need to provide a `prompt`. But with **chat**, you provide `messages`. And the resulting stream of responses includes a `message` object with a `content` field.
The final JSON object doesn't provide the full content, so you will need to build the content yourself.
In the **main** function, we collect `user_input`, append it to our messages, and pass the messages to the **chat** function. When the LLM is done responding, the output is added as another message.
## Next Steps
In this example, all generations are kept. You might want to experiment with summarizing everything older than 10 conversations to enable longer history with less context being used.
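One simple variation (not part of the example code) is to trim the history before each call instead of summarizing it, sketched here with an arbitrary cap:
```python
# Sketch: keep only the most recent turns before calling chat().
MAX_MESSAGES = 20  # arbitrary cap; tune to the model's context window

def trim_history(messages):
    # A more refined approach could summarize the dropped turns instead.
    return messages[-MAX_MESSAGES:]
```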


@@ -0,0 +1,77 @@
import * as readline from "readline";
const model = "llama2";
type Message = {
role: "assistant" | "user" | "system";
content: string;
}
const messages: Message[] = [{
role: "system",
content: "You are a helpful AI agent."
}]
const rl = readline.createInterface({
input: process.stdin,
output: process.stdout
})
async function chat(messages: Message[]): Promise<Message> {
const body = {
model: model,
messages: messages
}
const response = await fetch("http://localhost:11434/api/chat", {
method: "POST",
body: JSON.stringify(body)
})
const reader = response.body?.getReader()
if (!reader) {
throw new Error("Failed to read response body")
}
let content = ""
while (true) {
const { done, value } = await reader.read()
if (done) {
break;
}
const rawjson = new TextDecoder().decode(value);
const json = JSON.parse(rawjson)
if (json.done === false) {
process.stdout.write(json.message.content);
content += json.message.content
}
}
return { role: "assistant", content: content };
}
async function askQuestion(): Promise<void> {
return new Promise<void>((resolve) => {
rl.question("\n\nAsk a question: (press enter alone to quit)\n\n", async (user_input) => {
if (user_input.trim() === "") {
rl.close();
console.log("Thankyou. Goodbye.\n")
console.log("=======\nHere is the message history that was used in this conversation.\n=======\n")
messages.forEach(message => {
console.log(message)
})
resolve();
} else {
console.log();
messages.push({ role: "user", content: user_input });
messages.push(await chat(messages));
await askQuestion(); // Ask the next question
}
});
});
}
async function main() {
await askQuestion();
}
main();


@@ -0,0 +1 @@
{ "dependencies": { "@types/node": "^20.10.4", "prompt-sync": "^4.2.0", "readline": "^1.3.0" } }


@@ -0,0 +1,39 @@
# Simple Chat Example
The **chat** endpoint is one of two ways to generate text from an LLM with Ollama. At a high level you provide the endpoint an array of message objects with a role and content specified. Then with each output and prompt, you add more messages, which builds up the history.
## Run the Example
There are a few ways to run this, just like any TypeScript code:
1. Compile with `tsc` and then run it with `node client.js`.
2. Install `tsx` and run it with `tsx client.ts`.
3. Install `bun` and run it with `bun client.ts`.
## Review the Code
You can see in the **chat** function that calling the endpoint is done simply with:
```typescript
const body = {
model: model,
messages: messages
}
const response = await fetch("http://localhost:11434/api/chat", {
method: "POST",
body: JSON.stringify(body)
})
```
With the **generate** endpoint, you need to provide a `prompt`. But with **chat**, you provide `messages`. And the resulting stream of responses includes a `message` object with a `content` field.
The final JSON object doesn't provide the full content, so you will need to build the content yourself. In this example, **chat** takes the full array of messages and outputs the resulting message from this call of the chat endpoint.
In the **askQuestion** function, we collect `user_input`, append it to our messages, and pass the messages to the **chat** function. When the LLM is done responding, the output is added to the messages array as another message.
At the end, you will see a printout of all the messages.
## Next Steps
In this example, all generations are kept. You might want to experiment with summarizing everything older than 10 conversations to enable longer history with less context being used.

go.mod

@@ -5,14 +5,15 @@ go 1.20
require (
github.com/emirpasic/gods v1.18.1
github.com/gin-gonic/gin v1.9.1
github.com/mattn/go-runewidth v0.0.14
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
golang.org/x/sync v0.3.0
)
require github.com/rivo/uniseg v0.2.0 // indirect
require (
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
)
require (
github.com/bytedance/sonic v1.9.1 // indirect

go.sum

@@ -63,8 +63,6 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=


@@ -1,20 +0,0 @@
package llm
const (
falconModelType7B = 32
falconModelType40B = 60
falconModelType180B = 80
)
func falconModelType(numLayer uint32) string {
switch numLayer {
case 32:
return "7B"
case 60:
return "40B"
case 80:
return "180B"
default:
return "unknown"
}
}


@@ -7,9 +7,10 @@ import (
)
type GGML struct {
magic uint32
container
model
Size int64
}
const (
@@ -82,7 +83,7 @@ type model interface {
type container interface {
Name() string
Decode(io.Reader) (model, error)
Decode(*readSeekOffset) (model, error)
}
type containerGGML struct{}
@@ -91,7 +92,9 @@ func (c *containerGGML) Name() string {
return "ggml"
}
func (c *containerGGML) Decode(r io.Reader) (model, error) {
func (c *containerGGML) Decode(ro *readSeekOffset) (model, error) {
// file contents aren't decoded
ro.Seek(0, io.SeekEnd)
return nil, nil
}
@@ -103,9 +106,9 @@ func (c *containerGGMF) Name() string {
return "ggmf"
}
func (c *containerGGMF) Decode(r io.Reader) (model, error) {
func (c *containerGGMF) Decode(ro *readSeekOffset) (model, error) {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
binary.Read(ro, binary.LittleEndian, &version)
switch version {
case 1:
@@ -114,6 +117,10 @@ func (c *containerGGMF) Decode(r io.Reader) (model, error) {
}
c.version = version
// remaining file contents aren't decoded
ro.Seek(0, io.SeekEnd)
return nil, nil
}
@@ -125,9 +132,9 @@ func (c *containerGGJT) Name() string {
return "ggjt"
}
func (c *containerGGJT) Decode(r io.Reader) (model, error) {
func (c *containerGGJT) Decode(ro *readSeekOffset) (model, error) {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
binary.Read(ro, binary.LittleEndian, &version)
switch version {
case 1, 2, 3:
@@ -139,7 +146,11 @@ func (c *containerGGJT) Decode(r io.Reader) (model, error) {
// different model types may have different layouts for hyperparameters
var llama llamaModel
binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
binary.Read(ro, binary.LittleEndian, &llama.hyperparameters)
// remaining file contents aren't decoded
ro.Seek(0, io.SeekEnd)
return &llama, nil
}
@@ -151,9 +162,9 @@ func (c *containerLORA) Name() string {
return "ggla"
}
func (c *containerLORA) Decode(r io.Reader) (model, error) {
func (c *containerLORA) Decode(ro *readSeekOffset) (model, error) {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
binary.Read(ro, binary.LittleEndian, &version)
switch version {
case 1:
@@ -162,6 +173,10 @@ func (c *containerLORA) Decode(r io.Reader) (model, error) {
}
c.version = version
// remaining file contents aren't decoded
ro.Seek(0, io.SeekEnd)
return nil, nil
}
@@ -180,33 +195,61 @@ const (
)
func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
var ggml GGML
binary.Read(r, binary.LittleEndian, &ggml.magic)
ro := readSeekOffset{ReadSeeker: r}
switch ggml.magic {
var magic uint32
if err := binary.Read(&ro, binary.LittleEndian, &magic); err != nil {
return nil, err
}
var c container
switch magic {
case FILE_MAGIC_GGML:
ggml.container = &containerGGML{}
c = &containerGGML{}
case FILE_MAGIC_GGMF:
ggml.container = &containerGGMF{}
c = &containerGGMF{}
case FILE_MAGIC_GGJT:
ggml.container = &containerGGJT{}
c = &containerGGJT{}
case FILE_MAGIC_GGLA:
ggml.container = &containerLORA{}
c = &containerLORA{}
case FILE_MAGIC_GGUF_LE:
ggml.container = &containerGGUF{bo: binary.LittleEndian}
c = &containerGGUF{bo: binary.LittleEndian}
case FILE_MAGIC_GGUF_BE:
ggml.container = &containerGGUF{bo: binary.BigEndian}
c = &containerGGUF{bo: binary.BigEndian}
default:
return nil, errors.New("invalid file magic")
}
model, err := ggml.Decode(r)
model, err := c.Decode(&ro)
if err != nil {
return nil, err
}
ggml.model = model
// final model type
return &ggml, nil
return &GGML{
container: c,
model: model,
Size: ro.offset,
}, nil
}
type readSeekOffset struct {
io.ReadSeeker
offset int64
}
func (rso *readSeekOffset) Seek(offset int64, whence int) (int64, error) {
offset, err := rso.ReadSeeker.Seek(offset, whence)
if err != nil {
return 0, err
}
rso.offset = offset
return offset, nil
}
func (rso *readSeekOffset) Read(p []byte) (int, error) {
n, err := rso.ReadSeeker.Read(p)
rso.offset += int64(n)
return n, err
}


@@ -23,26 +23,24 @@ type containerGGUF struct {
NumTensor uint64
NumKV uint64
}
parameters uint64
}
func (c *containerGGUF) Name() string {
return "gguf"
}
func (c *containerGGUF) Decode(r io.Reader) (model, error) {
binary.Read(r, c.bo, &c.Version)
func (c *containerGGUF) Decode(rso *readSeekOffset) (model, error) {
binary.Read(rso, c.bo, &c.Version)
switch c.Version {
case 1:
binary.Read(r, c.bo, &c.V1)
binary.Read(rso, c.bo, &c.V1)
default:
binary.Read(r, c.bo, &c.V2)
binary.Read(rso, c.bo, &c.V2)
}
model := newGGUFModel(c)
if err := model.Decode(r); err != nil {
if err := model.Decode(rso); err != nil {
return nil, err
}
@@ -67,9 +65,23 @@ const (
type kv map[string]any
type tensor struct {
name string
kind uint32
offset uint64
size uint64
// shape is the number of elements in each dimension
shape [4]uint64
}
type ggufModel struct {
*containerGGUF
kv
tensors []tensor
parameters uint64
}
func newGGUFModel(container *containerGGUF) *ggufModel {
@@ -96,8 +108,7 @@ func (llm *ggufModel) NumKV() uint64 {
}
func (llm *ggufModel) ModelFamily() string {
t, ok := llm.kv["general.architecture"].(string)
if ok {
if t, ok := llm.kv["general.architecture"].(string); ok {
return t
}
@@ -109,82 +120,60 @@ func (llm *ggufModel) ModelType() string {
return format.HumanNumber(llm.parameters)
}
switch llm.ModelFamily() {
case "llama":
if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
heads, headsOK := llm.kv["llama.head_count"].(uint32)
headKVs, headsKVsOK := llm.kv["llama.head_count_kv"].(uint32)
if headsOK && headsKVsOK && heads/headKVs == 8 {
return "70B"
}
return llamaModelType(blocks)
}
case "falcon":
if blocks, ok := llm.kv["falcon.block_count"].(uint32); ok {
return falconModelType(blocks)
}
case "starcoder":
if blocks, ok := llm.kv["starcoder.block_count"].(uint32); ok {
return starCoderModelType(blocks)
}
}
return "unknown"
}
func (llm *ggufModel) FileType() string {
t, ok := llm.kv["general.file_type"].(uint32)
if ok {
if t, ok := llm.kv["general.file_type"].(uint32); ok {
return fileType(t)
}
return "unknown"
}
func (llm *ggufModel) Decode(r io.Reader) error {
func (llm *ggufModel) Decode(rso *readSeekOffset) error {
// decode key-values
for i := 0; uint64(i) < llm.NumKV(); i++ {
k, err := llm.readString(r)
k, err := llm.readString(rso)
if err != nil {
return err
}
vtype := llm.readU32(r)
vtype := llm.readU32(rso)
var v any
switch vtype {
case ggufTypeUint8:
v = llm.readU8(r)
v = llm.readU8(rso)
case ggufTypeInt8:
v = llm.readI8(r)
v = llm.readI8(rso)
case ggufTypeUint16:
v = llm.readU16(r)
v = llm.readU16(rso)
case ggufTypeInt16:
v = llm.readI16(r)
v = llm.readI16(rso)
case ggufTypeUint32:
v = llm.readU32(r)
v = llm.readU32(rso)
case ggufTypeInt32:
v = llm.readI32(r)
v = llm.readI32(rso)
case ggufTypeUint64:
v = llm.readU64(r)
v = llm.readU64(rso)
case ggufTypeInt64:
v = llm.readI64(r)
v = llm.readI64(rso)
case ggufTypeFloat32:
v = llm.readF32(r)
v = llm.readF32(rso)
case ggufTypeFloat64:
v = llm.readF64(r)
v = llm.readF64(rso)
case ggufTypeBool:
v = llm.readBool(r)
v = llm.readBool(rso)
case ggufTypeString:
s, err := llm.readString(r)
s, err := llm.readString(rso)
if err != nil {
return err
}
v = s
case ggufTypeArray:
a, err := llm.readArray(r)
a, err := llm.readArray(rso)
if err != nil {
return err
}
@@ -199,21 +188,85 @@ func (llm *ggufModel) Decode(r io.Reader) error {
// decode tensors
for i := 0; uint64(i) < llm.NumTensor(); i++ {
if _, err := llm.readString(r); err != nil {
name, err := llm.readString(rso)
if err != nil {
return err
}
dimensions := llm.readU32(r)
// dims is the number of dimensions in the tensor
dims := llm.readU32(rso)
var elements uint64 = 1
for i := 0; uint32(i) < dimensions; i++ {
elements *= llm.readU64(r)
shape := [4]uint64{1, 1, 1, 1}
for i := 0; uint32(i) < dims; i++ {
shape[i] = llm.readU64(rso)
}
llm.readU32(r) // type
llm.readU64(r) // offset
kind := llm.readU32(rso)
offset := llm.readU64(rso)
llm.parameters += elements
var blockSize uint64
switch {
case kind < 2:
blockSize = 1
case kind < 10:
blockSize = 32
default:
blockSize = 256
}
var typeSize uint64
switch kind {
case 0: // FP32
typeSize = 4
case 1: // FP16
typeSize = 2
case 2: // Q4_0
typeSize = 2 + blockSize/2
case 3: // Q4_1
typeSize = 2 + 2 + blockSize/2
case 6: // Q5_0
typeSize = 2 + 4 + blockSize/2
case 7: // Q5_1
typeSize = 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
typeSize = 2 + blockSize
case 9: // Q8_1
typeSize = 4 + 4 + blockSize
case 10: // Q2_K
typeSize = blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
typeSize = blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
typeSize = 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
}
parameters := shape[0] * shape[1] * shape[2] * shape[3]
size := parameters * typeSize / blockSize
llm.tensors = append(llm.tensors, tensor{
name: name,
kind: kind,
offset: offset,
size: size,
shape: shape,
})
llm.parameters += parameters
}
alignment, ok := llm.kv["general.alignment"].(uint32)
if !ok {
alignment = 32
}
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
for _, tensor := range llm.tensors {
padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
rso.Seek(padded, io.SeekCurrent)
}
return nil


@@ -59,6 +59,7 @@ ws ::= ([ \t\n] ws)?
var llamaCppEmbed embed.FS
type ModelRunner struct {
Type string // "gguf" or "ggml"
Path string // path to the model runner executable
Accelerated bool
}
@@ -72,25 +73,25 @@ func chooseRunners(workDir, runnerType string) []ModelRunner {
switch runtime.GOOS {
case "darwin":
if runtime.GOARCH == "arm64" {
runners = []ModelRunner{{Path: path.Join(buildPath, "metal", "bin", "ollama-runner")}}
runners = []ModelRunner{{Type: runnerType, Path: path.Join(buildPath, "metal", "bin", "ollama-runner")}}
} else {
runners = []ModelRunner{{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")}}
runners = []ModelRunner{{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")}}
}
case "linux":
runners = []ModelRunner{
{Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
{Type: runnerType, Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
}
case "windows":
// TODO: select windows GPU runner here when available
runners = []ModelRunner{
{Path: path.Join(buildPath, "cuda", "bin", "Release", "ollama-runner.exe"), Accelerated: true},
{Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
{Type: runnerType, Path: path.Join(buildPath, "cuda", "bin", "Release", "ollama-runner.exe"), Accelerated: true},
{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
}
default:
log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
runners = []ModelRunner{
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
}
}
@@ -148,6 +149,7 @@ func chooseRunners(workDir, runnerType string) []ModelRunner {
for _, r := range runners {
// clean the ModelRunner paths so that they match the OS we are running on
localRunnersByPriority = append(localRunnersByPriority, ModelRunner{
Type: r.Type,
Path: filepath.Clean(path.Join(workDir, r.Path)),
Accelerated: r.Accelerated,
})
@@ -221,8 +223,14 @@ type Running struct {
*StatusWriter // captures error messages from the llama runner process
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
}
type llama struct {
api.Options
ImageData []ImageData
Running
}
@@ -325,7 +333,7 @@ func (w *StatusWriter) Write(b []byte) (int, error) {
return os.Stderr.Write(b)
}
func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
func newLlama(model string, adapters, projectors []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
fileInfo, err := os.Stat(model)
if err != nil {
return nil, err
@@ -365,6 +373,11 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
params = append(params, "--lora", adapters[0])
}
if len(projectors) > 0 {
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
params = append(params, "--mmproj", projectors[0])
}
if opts.NumThread > 0 {
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
}
@@ -397,11 +410,13 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
}
port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
params := append(params, "--port", strconv.Itoa(port))
ctx, cancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(
ctx,
runner.Path,
append(params, "--port", strconv.Itoa(port))...,
params...,
)
var libraryPaths []string
@@ -530,22 +545,39 @@ type prediction struct {
}
const maxBufferSize = 512 * format.KiloByte
const maxRetries = 6
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext)
if err != nil {
return err
type PredictOpts struct {
Prompt string
Format string
Images []api.ImageData
}
type PredictResult struct {
Content string
Done bool
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
}
// IsRetryable checks if the line matches a condition that can be retried
func isRetryable(line []byte) bool {
return bytes.Contains(line, []byte("slot unavailable"))
}
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
imageData := llm.ImageData
if len(predict.Images) > 0 {
for cnt, i := range predict.Images {
imageData = append(imageData, ImageData{Data: i, ID: cnt})
}
}
// Remove leading spaces from prevConvo if present
prevConvo = strings.TrimPrefix(prevConvo, " ")
var nextContext strings.Builder
nextContext.WriteString(prevConvo)
nextContext.WriteString(prompt)
log.Printf("loaded %d images", len(imageData))
request := map[string]any{
"prompt": nextContext.String(),
"prompt": predict.Prompt,
"stream": true,
"n_predict": llm.NumPredict,
"n_keep": llm.NumKeep,
@@ -565,103 +597,121 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
"penalize_nl": llm.PenalizeNewline,
"seed": llm.Seed,
"stop": llm.Stop,
"image_data": imageData,
}
if format == "json" {
if predict.Format == "json" {
request["grammar"] = jsonGrammar
}
// Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)
if err := enc.Encode(request); err != nil {
return fmt.Errorf("failed to marshal data: %v", err)
}
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil {
return fmt.Errorf("error creating POST request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("POST predict: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed reading llm error response: %w", err)
retryDelay := 100 * time.Microsecond
for retries := 0; retries < maxRetries; retries++ {
if retries > 0 {
time.Sleep(retryDelay) // wait before retrying
retryDelay *= 2 // exponential backoff
}
log.Printf("llm predict error: %s", bodyBytes)
return fmt.Errorf("%s", bodyBytes)
}
scanner := bufio.NewScanner(resp.Body)
// increase the buffer size to avoid running out of space
buf := make([]byte, 0, maxBufferSize)
scanner.Buffer(buf, maxBufferSize)
for scanner.Scan() {
select {
case <-ctx.Done():
// This handles the request cancellation
return ctx.Err()
default:
line := scanner.Bytes()
if len(line) == 0 {
continue
// Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)
if err := enc.Encode(request); err != nil {
return fmt.Errorf("failed to marshal data: %v", err)
}
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil {
return fmt.Errorf("error creating POST request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("POST predict: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed reading llm error response: %w", err)
}
log.Printf("llm predict error: %s", bodyBytes)
return fmt.Errorf("%s", bodyBytes)
}
scanner := bufio.NewScanner(resp.Body)
// increase the buffer size to avoid running out of space
buf := make([]byte, 0, maxBufferSize)
scanner.Buffer(buf, maxBufferSize)
retryNeeded := false
for scanner.Scan() {
select {
case <-ctx.Done():
// This handles the request cancellation
return ctx.Err()
default:
line := scanner.Bytes()
if len(line) == 0 {
continue
}
if isRetryable(line) {
retryNeeded = true
break
}
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
return fmt.Errorf("error parsing llm response stream: %s", line)
}
if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
var p prediction
if err := json.Unmarshal(evt, &p); err != nil {
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
}
if p.Content != "" {
fn(api.GenerateResponse{Response: p.Content})
nextContext.WriteString(p.Content)
fn(PredictResult{
Content: p.Content,
})
}
if p.Stop {
embd, err := llm.Encode(ctx, nextContext.String())
if err != nil {
return fmt.Errorf("encoding context: %v", err)
}
fn(api.GenerateResponse{
fn(PredictResult{
Done: true,
Context: embd,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
})
return nil
}
}
}
}
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "unexpected EOF") {
// this means the llama runner subprocess crashed
llm.Close()
if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "unexpected EOF") {
// this means the llama runner subprocess crashed
llm.Close()
if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
}
return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
}
return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
return fmt.Errorf("error reading llm response: %v", err)
}
if !retryNeeded {
return nil // success
}
return fmt.Errorf("error reading llm response: %v", err)
}
return nil
// should never reach here ideally
return fmt.Errorf("max retries exceeded")
}
type TokenizeRequest struct {


@@ -14,7 +14,7 @@ import (
)
type LLM interface {
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
Predict(context.Context, PredictOpts, func(PredictResult)) error
Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error)
@@ -23,7 +23,7 @@ type LLM interface {
Ping(context.Context) error
}
func New(workDir, model string, adapters []string, opts api.Options) (LLM, error) {
func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
@@ -82,9 +82,9 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
opts.NumGQA = 0
opts.RopeFrequencyBase = 0.0
opts.RopeFrequencyScale = 0.0
return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
return newLlama(model, adapters, projectors, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
case "ggml", "ggmf", "ggjt", "ggla":
return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
return newLlama(model, adapters, projectors, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
default:
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
}


@@ -1,23 +0,0 @@
package llm
const (
starCoderModelType1B = 24
starCoderModelType3B = 36
starCoderModelType7B = 42
starCoderModelType15B = 40
)
func starCoderModelType(numLayer uint32) string {
switch numLayer {
case 24:
return "1B"
case 36:
return "3B"
case 42:
return "7B"
case 40:
return "15B"
default:
return "unknown"
}
}


@@ -37,10 +37,13 @@ func Parse(reader io.Reader) ([]Command, error) {
switch string(bytes.ToUpper(fields[0])) {
case "FROM":
command.Name = "model"
command.Args = string(fields[1])
command.Args = string(bytes.TrimSpace(fields[1]))
// copy command for validation
modelCommand = command
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "ADAPTER":
case "ADAPTER":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(bytes.TrimSpace(fields[1]))
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1])
case "PARAMETER":
@@ -50,7 +53,7 @@ func Parse(reader io.Reader) ([]Command, error) {
}
command.Name = string(fields[0])
command.Args = string(fields[1])
command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
default:


@@ -191,6 +191,8 @@ func (i *Instance) Readline() (string, error) {
buf.ClearScreen()
case CharCtrlW:
buf.DeleteWord()
case CharCtrlZ:
return handleCharCtrlZ(fd, termios)
case CharEnter:
output := buf.String()
if output != "" {

readline/readline_unix.go

@@ -0,0 +1,18 @@
//go:build !windows
package readline
import (
"syscall"
)
func handleCharCtrlZ(fd int, termios *Termios) (string, error) {
if err := UnsetRawMode(fd, termios); err != nil {
return "", err
}
syscall.Kill(0, syscall.SIGSTOP)
// on resume...
return "", nil
}


@@ -0,0 +1,6 @@
package readline
func handleCharCtrlZ(fd int, state *State) (string, error) {
// not supported
return "", nil
}


@@ -217,7 +217,7 @@ fi
if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
case $OS_NAME in
centos|rhel) install_cuda_driver_yum 'rhel' $OS_VERSION ;;
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
fedora) install_cuda_driver_yum $OS_NAME $OS_VERSION ;;
amzn) install_cuda_driver_yum 'fedora' '35' ;;
@@ -230,7 +230,8 @@ fi
if ! lsmod | grep -q nvidia; then
KERNEL_RELEASE="$(uname -r)"
case $OS_NAME in
centos|rhel|rocky|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;;
centos|rhel|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
fedora) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE ;;
debian|ubuntu) $SUDO apt-get -y install linux-headers-$KERNEL_RELEASE ;;
*) exit ;;


@@ -14,7 +14,6 @@ import (
"net/url"
"os"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
@@ -36,80 +35,160 @@ type RegistryOptions struct {
}
type Model struct {
Name string `json:"name"`
ShortName string
ModelPath string
OriginalModel string
AdapterPaths []string
Template string
System string
License []string
Digest string
Options map[string]interface{}
Name string `json:"name"`
Config ConfigV2
ShortName string
ModelPath string
OriginalModel string
AdapterPaths []string
ProjectorPaths []string
Template string
System string
License []string
Digest string
Size int64
Options map[string]interface{}
}
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
t := m.Template
if request.Template != "" {
t = request.Template
}
type PromptVars struct {
System string
Prompt string
Response string
First bool
}
tmpl, err := template.New("").Parse(t)
func (m *Model) Prompt(p PromptVars) (string, error) {
var prompt strings.Builder
// Use the "missingkey=zero" option to handle missing variables without panicking
tmpl, err := template.New("").Option("missingkey=zero").Parse(m.Template)
if err != nil {
return "", err
}
var vars struct {
First bool
System string
Prompt string
if p.System == "" {
// use the default system message for this model if one is not specified
p.System = m.System
}
vars.First = len(request.Context) == 0
vars.System = m.System
vars.Prompt = request.Prompt
if request.System != "" {
vars.System = request.System
vars := map[string]any{
"System": p.System,
"Prompt": p.Prompt,
"Response": p.Response,
"First": p.First,
}
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
prompt.WriteString(sb.String())
prompt.WriteString(p.Response)
return prompt.String(), nil
}
return sb.String(), nil
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
// build the prompt from the list of messages
var prompt strings.Builder
var currentImages []api.ImageData
currentVars := PromptVars{
First: true,
}
writePrompt := func() error {
p, err := m.Prompt(currentVars)
if err != nil {
return err
}
prompt.WriteString(p)
currentVars = PromptVars{}
return nil
}
for _, msg := range msgs {
switch strings.ToLower(msg.Role) {
case "system":
if currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", nil, err
}
}
currentVars.System = msg.Content
case "user":
if currentVars.Prompt != "" {
if err := writePrompt(); err != nil {
return "", nil, err
}
}
currentVars.Prompt = msg.Content
currentImages = msg.Images
case "assistant":
currentVars.Response = msg.Content
if err := writePrompt(); err != nil {
return "", nil, err
}
default:
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", nil, err
}
}
return prompt.String(), currentImages, nil
}
type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config Layer `json:"config"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
}
type Layer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
}
type LayerReader struct {
Layer
io.Reader
}
type ConfigV2 struct {
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
RootFS RootFS `json:"rootfs"`
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
ModelFamilies []string `json:"model_families"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
// required by spec
Architecture string `json:"architecture"`
OS string `json:"os"`
RootFS RootFS `json:"rootfs"`
}
func (c *ConfigV2) SetModelFormat(format string) {
if c.ModelFormat == "" {
c.ModelFormat = format
}
}
func (c *ConfigV2) SetModelFamily(families ...string) {
for _, family := range families {
if c.ModelFamily == "" {
c.ModelFamily = family
}
if !slices.Contains(c.ModelFamilies, family) {
c.ModelFamilies = append(c.ModelFamilies, family)
}
}
}
func (c *ConfigV2) SetModelType(modelType string) {
if c.ModelType == "" {
c.ModelType = modelType
}
}
func (c *ConfigV2) SetFileType(fileType string) {
if c.FileType == "" {
c.FileType = fileType
}
}
type RootFS struct {
@@ -166,6 +245,22 @@ func GetModel(name string) (*Model, error) {
Digest: digest,
Template: "{{ .Prompt }}",
License: []string{},
Size: manifest.GetTotalSize(),
}
filename, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return nil, err
}
configFile, err := os.Open(filename)
if err != nil {
return nil, err
}
defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err
}
for _, layer := range manifest.Layers {
@@ -184,6 +279,8 @@ func GetModel(name string) (*Model, error) {
log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
if err != nil {
@@ -257,11 +354,14 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
}
deleteMap := make(map[string]struct{})
var layers Layers
params := make(map[string][]string)
fromParams := make(map[string]any)
@@ -318,10 +418,10 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
return err
}
config.SetModelFormat(fromConfig.ModelFormat)
config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
config.SetModelType(fromConfig.ModelType)
config.SetFileType(fromConfig.FileType)
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
@@ -342,13 +442,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
}
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
if err != nil {
return err
}
layers.Add(layer)
}
deleteMap[manifest.Config.Digest] = struct{}{}
@@ -356,26 +455,48 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
defer bin.Close()
fn(api.ProgressResponse{Status: "creating model layer"})
ggml, err := llm.DecodeGGML(bin)
if err != nil {
return err
var offset int64
for {
fn(api.ProgressResponse{Status: "creating model layer"})
bin.Seek(offset, io.SeekStart)
ggml, err := llm.DecodeGGML(bin)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return err
}
config.SetModelFormat(ggml.Name())
config.SetModelFamily(ggml.ModelFamily())
config.SetModelType(ggml.ModelType())
config.SetFileType(ggml.FileType())
mediatype := mediatype
if ggml.ModelFamily() == "clip" {
mediatype = "application/vnd.ollama.image.projector"
}
sr := io.NewSectionReader(bin, offset, ggml.Size)
layer, err := NewLayer(sr, mediatype)
if err != nil {
return err
}
layers.Add(layer)
offset += ggml.Size
}
config.ModelFormat = ggml.Name()
config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType()
bin.Seek(0, io.SeekStart)
layer, err := CreateLayer(bin)
if err != nil {
return err
}
layer.MediaType = mediatype
layers = append(layers, layer)
case "adapter":
if strings.HasPrefix(c.Args, "@") {
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}
c.Args = blobPath
}
fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(modelFileDir, c.Args))
if err != nil {
@@ -383,41 +504,32 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
defer bin.Close()
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
layers.Add(layer)
case "license":
fn(api.ProgressResponse{Status: "creating license layer"})
bin := strings.NewReader(c.Args)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
layers.Add(layer)
case "template", "system":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
bin := strings.NewReader(c.Args)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
layers.Replace(layer)
default:
params[c.Name] = append(params[c.Name], c.Args)
}
@@ -426,7 +538,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
formattedParams, err := api.FormatParams(params)
if err != nil {
return err
}
@@ -437,6 +549,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
}
// xxx - can this be removed?
if config.ModelType == "65B" {
if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
config.ModelType = "70B"
@@ -449,40 +562,51 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return err
}
layers.Replace(layer)
}
digests := make([]string, len(layers.items))
for i, layer := range layers.items {
digests[i] = layer.Digest
}
config.RootFS.DiffIDs = digests
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(config); err != nil {
return err
}
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return err
}
for _, layer := range append(layers.items, configLayer) {
committed, err := layer.Commit()
if err != nil {
return err
}
status := "writing layer"
if !committed {
status = "using already created layer"
}
fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
delete(deleteMap, layer.Digest)
}
fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, configLayer, layers.items); err != nil {
return err
}
@@ -496,177 +620,6 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
return nil
}
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
return layer.MediaType == mediaType
})
}
func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
// Write each of the layers to disk
for _, layer := range layers {
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
_, err = os.Stat(fp)
if os.IsNotExist(err) || force {
fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
out, err := os.Create(fp)
if err != nil {
log.Printf("couldn't create %s", fp)
return err
}
defer out.Close()
if _, err = io.Copy(out, layer.Reader); err != nil {
return err
}
} else {
fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
}
}
return nil
}
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
mp := ParseModelPath(name)
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: Layer{
MediaType: cfg.MediaType,
Size: cfg.Size,
Digest: cfg.Digest,
},
Layers: layers,
}
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
return os.WriteFile(fp, manifestJSON, 0o644)
}
func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
file, err := os.Open(fp)
if err != nil {
return nil, fmt.Errorf("could not open blob: %w", err)
}
defer file.Close()
newLayer, err := CreateLayer(file)
if err != nil {
return nil, err
}
newLayer.MediaType = layer.MediaType
return newLayer, nil
}
// formatParams converts specified parameter options to their correct types
func formatParams(params map[string][]string) (map[string]interface{}, error) {
opts := api.Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
if jsonTag != "" {
jsonOpts[jsonTag] = field
}
}
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; ok {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", vals)
}
out[key] = float32(floatVal)
case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals)
}
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
out[key] = boolVal
case reflect.String:
out[key] = vals[0]
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
}
}
}
return out, nil
}
func getLayerDigests(layers []*LayerReader) ([]string, error) {
var digests []string
for _, l := range layers {
if l.Digest == "" {
return nil, fmt.Errorf("layer is missing a digest")
}
digests = append(digests, l.Digest)
}
return digests, nil
}
// CreateLayer creates a Layer object from a given file
func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
digest, size := GetSHA256Digest(f)
f.Seek(0, io.SeekStart)
layer := &LayerReader{
Layer: Layer{
MediaType: "application/vnd.docker.image.rootfs.diff.tar",
Digest: digest,
Size: size,
},
Reader: f,
}
return layer, nil
}
func CopyModel(src, dest string) error {
srcModelPath := ParseModelPath(src)
srcPath, err := srcModelPath.GetManifestPath()
@@ -934,7 +887,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
@@ -1005,7 +958,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)
for _, layer := range layers {
if err := downloadBlob(
@@ -1093,30 +1046,6 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio
return m, err
}
func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
config.RootFS = RootFS{
Type: "layers",
DiffIDs: layers,
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
layer := &LayerReader{
Layer: Layer{
MediaType: "application/vnd.docker.container.image.v1+json",
Digest: digest,
Size: size,
},
Reader: bytes.NewBuffer(configJSON),
}
return layer, nil
}
// GetSHA256Digest returns the SHA256 hash of a given buffer, along with its size
func GetSHA256Digest(r io.Reader) (string, int64) {
h := sha256.New()

@@ -1,23 +1,98 @@
package server
import (
"strings"
"testing"
"github.com/jmorganca/ollama/api"
)
func TestChat(t *testing.T) {
tests := []struct {
name string
template string
msgs []api.Message
want string
wantErr string
}{
{
name: "Single Message",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
{
Role: "assistant",
Content: "sugar",
},
{
Role: "user",
Content: "Anything else?",
},
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
},
{
name: "Assistant Only",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "assistant",
Content: "everything nice",
},
},
want: "[INST] [/INST]everything nice",
},
{
name: "Invalid Role",
msgs: []api.Message{
{
Role: "not-a-role",
Content: "howdy",
},
},
wantErr: "invalid role: not-a-role",
},
}
for _, tt := range tests {
m := Model{
Template: tt.template,
}
t.Run(tt.name, func(t *testing.T) {
got, _, err := m.ChatPrompt(tt.msgs)
if tt.wantErr != "" {
if err == nil {
t.Errorf("ChatPrompt() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
}
}
if got != tt.want {
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
}
})
}
}

server/layers.go (new file)

@@ -0,0 +1,109 @@
package server
import (
"crypto/sha256"
"fmt"
"io"
"os"
"runtime"
"strings"
"golang.org/x/exp/slices"
)
type Layers struct {
items []*Layer
}
func (ls *Layers) Add(layer *Layer) {
if layer.Size > 0 {
ls.items = append(ls.items, layer)
}
}
func (ls *Layers) Replace(layer *Layer) {
if layer.Size > 0 {
mediatype := layer.MediaType
layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
return l.MediaType == mediatype
})
ls.items = append(layers, layer)
}
}
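A brief sketch of how Add and Replace differ (not part of the diff; assumes package server, with layers built by NewLayer below, and a hypothetical function name): Add keeps every non-empty layer, while Replace keeps at most one layer per media type, which is what CreateModel relies on for the template, system, and params layers.

func exampleLayers(license, template1, template2 *Layer) Layers {
    var ls Layers
    ls.Add(license)       // kept, provided license.Size > 0
    ls.Replace(template1) // kept
    ls.Replace(template2) // drops template1: same media type, only the latest survives
    return ls
}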
type Layer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
tempFileName string
}
func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {
return nil, err
}
delimiter := ":"
if runtime.GOOS == "windows" {
delimiter = "-"
}
pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
temp, err := os.CreateTemp(blobs, pattern)
if err != nil {
return nil, err
}
defer temp.Close()
sha256sum := sha256.New()
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
if err != nil {
return nil, err
}
return &Layer{
MediaType: mediatype,
Digest: fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
Size: n,
tempFileName: temp.Name(),
}, nil
}
func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
blob, err := GetBlobsPath(digest)
if err != nil {
return nil, err
}
fi, err := os.Stat(blob)
if err != nil {
return nil, err
}
return &Layer{
MediaType: mediatype,
Digest: digest,
Size: fi.Size(),
From: from,
}, nil
}
func (l *Layer) Commit() (bool, error) {
// always remove temp
defer os.Remove(l.tempFileName)
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return false, err
}
if _, err := os.Stat(blob); err != nil {
return true, os.Rename(l.tempFileName, blob)
}
return false, nil
}
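To spell out the lifecycle above (a sketch, not part of the diff; assumes package server with the usual imports such as io and log available): NewLayer hashes the incoming data into a sha256:*-partial temp file under the blobs directory, and Commit renames that temp file into place only if no blob with the same digest exists yet, otherwise it simply discards the temp file.

func exampleLayerLifecycle(r io.Reader, mediatype string) error {
    layer, err := NewLayer(r, mediatype) // streams r into a "-partial" temp blob while hashing it
    if err != nil {
        return err
    }

    committed, err := layer.Commit() // renames the temp file only if the blob is new
    if err != nil {
        return err
    }
    log.Printf("layer %s committed: %v", layer.Digest, committed) // false means the blob already existed
    return nil
}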

server/manifests.go (new file)

@@ -0,0 +1,34 @@
package server
import (
"bytes"
"encoding/json"
"os"
"path/filepath"
)
func WriteManifest(name string, config *Layer, layers []*Layer) error {
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
return err
}
modelpath := ParseModelPath(name)
manifestPath, err := modelpath.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
return err
}
return os.WriteFile(manifestPath, b.Bytes(), 0644)
}
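For reference (not part of the diff), this is roughly how CreateModel above ends up using it; the example names are hypothetical and assume package server:

func exampleWriteManifest(configLayer *Layer, content []*Layer) error {
    // Serializes a ManifestV2 (schemaVersion, mediaType, config, layers) and writes it
    // to the path returned by ParseModelPath(name).GetManifestPath().
    return WriteManifest("mymodel:latest", configLayer, content)
}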


@@ -67,6 +67,20 @@ func ParseModelPath(name string) ModelPath {
return mp
}
var errModelPathInvalid = errors.New("invalid model path")
func (mp ModelPath) Validate() error {
if mp.Repository == "" {
return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
}
if strings.Contains(mp.Tag, ":") {
return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
}
return nil
}
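As a small illustration of the new validation (not part of the diff; assumes package server and that the zero values of the other ModelPath fields are acceptable to Validate):

func exampleValidate() error {
    if err := (ModelPath{Repository: "mymodel", Tag: "latest"}).Validate(); err != nil {
        return err // not reached: the repository is set and the tag has no ':'
    }

    // Rejected with errModelPathInvalid: ':' is not allowed in tag names.
    return (ModelPath{Repository: "mymodel", Tag: "v1:beta"}).Validate()
}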
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}


@@ -2,7 +2,6 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
@@ -60,17 +59,26 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
model, err := GetModel(modelName)
if err != nil {
return nil, err
}
workDir := c.GetString("workDir")
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err)
return nil, err
}
if err := opts.FromMap(reqOpts); err != nil {
return nil, err
}
ctx := c.Request.Context()
// check if the loaded model is still running in a subprocess, in case something unexpected happened
if loaded.runner != nil {
if err := loaded.runner.Ping(ctx); err != nil {
@@ -97,7 +105,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
loaded.Options = nil
}
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
@@ -106,7 +114,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
return nil, err
}
loaded.Model = model
@@ -140,7 +148,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
}
loaded.expireTimer.Reset(sessionDuration)
return model, nil
}
func GenerateHandler(c *gin.Context) {
@@ -148,9 +156,9 @@ func GenerateHandler(c *gin.Context) {
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.GenerateRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -173,88 +181,150 @@ func GenerateHandler(c *gin.Context) {
return
}
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true})
return
}
checkpointLoaded := time.Now()
var prompt string
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template != "" {
// override the default model template
model.Template = req.Template
}
var rebuild strings.Builder
if req.Context != nil {
// TODO: context is deprecated, at some point the context logic within this conditional should be removed
prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Remove leading spaces from prevCtx if present
prevCtx = strings.TrimPrefix(prevCtx, " ")
rebuild.WriteString(prevCtx)
}
p, err := model.Prompt(PromptVars{
System: req.System,
Prompt: req.Prompt,
First: len(req.Context) == 0,
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rebuild.WriteString(p)
prompt = rebuild.String()
}
ch := make(chan any)
var generated strings.Builder
go func() {
defer close(ch)
fn := func(r llm.PredictResult) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
r.Model = req.Model
r.CreatedAt = time.Now().UTC()
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: r.Done,
Response: r.Content,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = embd
}
}
ch <- resp
}
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
Images: req.Images,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.GenerateResponse
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
case api.GenerateResponse:
sb.WriteString(r.Response)
final = r
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
}
}
final.Response = sb.String()
c.JSON(http.StatusOK, final)
return
}
@@ -281,15 +351,18 @@ func EmbeddingHandler(c *gin.Context) {
return
}
sessionDuration := defaultSessionDuration
_, err = load(c, req.Model, req.Options, sessionDuration)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
@@ -416,6 +489,11 @@ func CreateModelHandler(c *gin.Context) {
return
}
if err := ParseModelPath(req.Name).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
@@ -538,10 +616,19 @@ func GetModelInfo(name string) (*api.ShowResponse, error) {
return nil, err
}
modelDetails := api.ModelDetails{
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
resp := &api.ShowResponse{
License: strings.Join(model.License, "\n"),
System: model.System,
Template: model.Template,
Details: modelDetails,
}
mf, err := ShowModelfile(model)
@@ -591,25 +678,42 @@ func ListModelsHandler(c *gin.Context) {
return
}
modelResponse := func(modelName string) (api.ModelResponse, error) {
model, err := GetModel(modelName)
if err != nil {
return api.ModelResponse{}, err
}
modelDetails := api.ModelDetails{
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
return api.ModelResponse{
Name: model.ShortName,
Size: model.Size,
Digest: model.Digest,
Details: modelDetails,
}, nil
}
walkFunc := func(path string, info os.FileInfo, _ error) error {
if !info.IsDir() {
dir, file := filepath.Split(path)
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
tag := strings.Join([]string{dir, file}, ":")
resp, err := modelResponse(tag)
if err != nil {
log.Printf("skipping file: %s", fp)
return nil
}
resp.ModifiedAt = info.ModTime()
models = append(models, resp)
}
return nil
@@ -640,6 +744,11 @@ func CopyModelHandler(c *gin.Context) {
return
}
if err := ParseModelPath(req.Destination).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := CopyModel(req.Source, req.Destination); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
@@ -666,37 +775,18 @@ func HeadBlobHandler(c *gin.Context) {
}
func CreateBlobHandler(c *gin.Context) {
targetPath, err := GetBlobsPath(c.Param("digest"))
layer, err := NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
hash := sha256.New()
temp, err := os.CreateTemp(filepath.Dir(targetPath), c.Param("digest")+"-")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer temp.Close()
defer os.Remove(temp.Name())
if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
if layer.Digest != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
return
}
if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
return
}
if err := temp.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := os.Rename(temp.Name(), targetPath); err != nil {
if _, err := layer.Commit(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -757,6 +847,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler)
r.POST("/api/chat", ChatHandler)
r.POST("/api/embeddings", EmbeddingHandler)
r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler)
@@ -772,6 +863,9 @@ func Serve(ln net.Listener, allowOrigins []string) error {
})
r.Handle(method, "/api/tags", ListModelsHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
}
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
@@ -794,7 +888,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
if runtime.GOOS == "linux" {
// check compatibility to log warnings
if _, err := llm.CheckVRAM(); err != nil {
log.Print(err.Error())
}
}
@@ -850,3 +944,136 @@ func streamResponse(c *gin.Context, ch chan any) {
return true
})
}
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.ChatRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
}
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
// an empty request loads the model
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
return
}
checkpointLoaded := time.Now()
prompt, images, err := model.ChatPrompt(req.Messages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r llm.PredictResult) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} else {
resp.Message = &api.Message{Role: "assistant", Content: r.Content}
}
ch <- resp
}
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
Images: images,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.ChatResponse
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
case api.ChatResponse:
if r.Message != nil {
sb.WriteString(r.Message.Content)
}
final = r
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
}
}
final.Message = &api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, final)
return
}
streamResponse(c, ch)
}
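For reference, a minimal standalone client for the /api/chat route registered below (a sketch using only the standard library; it assumes the server is listening on the default localhost:11434, that the model has already been pulled, and that the JSON field names follow api.ChatRequest and api.Message as used in this handler):

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

func main() {
    body, _ := json.Marshal(map[string]any{
        "model": "llama2",
        "messages": []map[string]string{
            {"role": "user", "content": "Why is the sky blue?"},
        },
    })

    // Streaming is the default: the response body is a sequence of JSON ChatResponse objects.
    resp, err := http.Post("http://localhost:11434/api/chat", "application/json", bytes.NewReader(body))
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    dec := json.NewDecoder(resp.Body)
    for {
        var chunk map[string]any
        if err := dec.Decode(&chunk); err != nil {
            break // io.EOF once the stream ends
        }
        if msg, ok := chunk["message"].(map[string]any); ok {
            fmt.Print(msg["content"])
        }
    }
    fmt.Println()
}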


@@ -5,6 +5,7 @@ import (
"crypto/md5"
"errors"
"fmt"
"hash"
"io"
"log"
"math"
@@ -102,7 +103,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
}
// set part.N to the current number of parts
b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size})
offset += size
}
@@ -147,14 +148,13 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
g.Go(func() error {
var err error
for try := 0; try < maxRetries; try++ {
err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
switch {
case errors.Is(err, context.Canceled):
return err
case errors.Is(err, errMaxRetriesExceeded):
return err
case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep)
@@ -176,17 +176,10 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
requestURL := <-b.nextURL
// calculate md5 checksum and add it to the commit request
var sb strings.Builder
for _, part := range b.Parts {
sb.Write(part.Sum(nil))
}
md5sum := md5.Sum([]byte(sb.String()))
@@ -201,27 +194,25 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
headers.Set("Content-Length", "0")
for try := 0; try < maxRetries; try++ {
var resp *http.Response
resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
if errors.Is(err, context.Canceled) {
break
} else if err != nil {
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep)
time.Sleep(sleep)
continue
}
defer resp.Body.Close()
break
}
b.err = err
b.done = true
}
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -232,8 +223,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
}
sr := io.NewSectionReader(b.file, part.Offset, part.Size)
md5sum := md5.New()
w := &progressWriter{blobUpload: b}
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
if err != nil {
w.Rollback()
return err
}
defer resp.Body.Close()
@@ -245,11 +241,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
nextURL, err := url.Parse(location)
if err != nil {
w.Rollback()
return err
}
switch {
case resp.StatusCode == http.StatusTemporaryRedirect:
w.Rollback()
b.nextURL <- nextURL
redirectURL, err := resp.Location()
@@ -259,14 +257,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
// retry uploading to the redirect URL
for try := 0; try < maxRetries; try++ {
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
switch {
case errors.Is(err, context.Canceled):
return err
case errors.Is(err, errMaxRetriesExceeded):
return err
case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep)
@@ -279,6 +276,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
@@ -289,6 +287,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
opts.Token = token
fallthrough
case resp.StatusCode >= http.StatusBadRequest:
w.Rollback()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
@@ -301,6 +300,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
b.nextURL <- nextURL
}
part.Hash = md5sum
return nil
}
@@ -341,22 +341,26 @@ func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) er
type blobUploadPart struct {
// N is the part number
N int
Offset int64
Size int64
hash.Hash
}
type progressWriter struct {
written int64
*blobUpload
}
func (p *progressWriter) Write(b []byte) (n int, err error) {
n = len(b)
p.written += int64(n)
p.Completed.Add(int64(n))
return n, nil
}
func (p *progressWriter) Rollback() {
p.Completed.Add(-p.written)
p.written = 0
}
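A closing note on the progress accounting above (a sketch under the assumption that Completed on blobUpload behaves like an atomic counter, as the Add calls suggest): every byte written through progressWriter is credited immediately, and Rollback un-credits exactly what this writer added, so a failed or redirected part upload does not inflate the reported progress. A standalone illustration of the same pattern:

package main

import (
    "fmt"
    "sync/atomic"
)

type progress struct {
    written   int64
    completed *atomic.Int64
}

func (p *progress) Write(b []byte) (int, error) {
    n := len(b)
    p.written += int64(n)
    p.completed.Add(int64(n)) // credit bytes as they are uploaded
    return n, nil
}

func (p *progress) Rollback() {
    p.completed.Add(-p.written) // give back everything this writer credited
    p.written = 0
}

func main() {
    var completed atomic.Int64
    w := &progress{completed: &completed}
    w.Write(make([]byte, 1024))   // progress climbs by 1024
    w.Rollback()                  // the retry path takes it back
    fmt.Println(completed.Load()) // 0
}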