Compare commits

...

67 Commits

Author SHA1 Message Date
Blake Mizerany
425219f1ca use vanity imports 2024-04-20 20:04:00 -07:00
Michael Yang
8645076a71 Merge pull request #3712 from ollama/mxyng/mem
add stablelm graph calculation
2024-04-17 15:57:51 -07:00
Michael Yang
05e9424824 Merge pull request #3664 from ollama/mxyng/fix-padding-2
fix padding to only return padding
2024-04-17 15:57:40 -07:00
Michael Yang
52ebe67a98 Merge pull request #3714 from ollama/mxyng/model-name-host
types/model: support : in PartHost for host:port
2024-04-17 15:34:03 -07:00
Michael Yang
889b31ab78 types/model: support : in PartHost for host:port 2024-04-17 15:16:07 -07:00
Michael Yang
3cf483fe48 add stablelm graph calculation 2024-04-17 13:57:19 -07:00
jmorganca
c8afe7168c use correct extension for feature and model request issue templates 2024-04-17 15:18:40 -04:00
jmorganca
28d3cd0148 simpler feature and model request forms 2024-04-17 15:17:08 -04:00
jmorganca
eb5554232a simpler feature and model request forms 2024-04-17 15:14:49 -04:00
jmorganca
2bdc320216 add descriptions to issue templates 2024-04-17 15:08:36 -04:00
jmorganca
32561aed09 simplify github issue templates a bit 2024-04-17 15:07:03 -04:00
Michael Yang
71548d9829 Merge pull request #3706 from ollama/mxyng/mem
account for all non-repeating layers
2024-04-17 11:58:20 -07:00
Michael Yang
a8b9b930b4 account for all non-repeating layers 2024-04-17 11:21:21 -07:00
Michael
9755cf9173 acknowledge the amazing work done by Georgi and team! 2024-04-17 13:48:14 -04:00
Blake Mizerany
9df6c85c3a types/model: add FilepathNoBuild (#3680)
Also, add test for DisplayLongest.

Also, plumb fill param to ParseName in MustParseName
2024-04-16 18:35:43 -07:00
Michael Yang
e74163af4c fix padding to only return padding 2024-04-16 15:43:26 -07:00
Michael Yang
fb9580df85 Merge pull request #3684 from ollama/mxyng/scale-graph
scale graph based on gpu count
2024-04-16 14:57:09 -07:00
Michael Yang
26df674785 scale graph based on gpu count 2024-04-16 14:44:13 -07:00
Jeffrey Morgan
7c9792a6e0 Support unicode characters in model path (#3681)
* parse wide argv characters on windows

* cleanup

* move cleanup to end of `main`
2024-04-16 17:00:12 -04:00
Michael Yang
7afb2e125a Merge pull request #3678 from ollama/mxyng/fix-darwin-partial-offloading
darwin: no partial offloading if required memory greater than system
2024-04-16 12:05:56 -07:00
Michael Yang
41a272de9f darwin: no partial offloading if required memory greater than system 2024-04-16 11:22:38 -07:00
Jeffrey Morgan
f335722275 update llama.cpp submodule to 7593639 (#3665) 2024-04-15 23:04:43 -04:00
Michael Yang
6d53b67c2c Merge pull request #3663 from ollama/mxyng/fix-padding 2024-04-15 17:44:54 -07:00
Michael Yang
969238b19e fix padding in decode
TODO: update padding() to _only_ returning the padding
2024-04-15 17:27:06 -07:00
Blake Mizerany
949d7832cf Revert "cmd: provide feedback if OLLAMA_MODELS is set on non-serve command (#3470)" (#3662)
This reverts commit 7d05a6ee8f.

This proved to be more painful than useful.

See: https://github.com/ollama/ollama/issues/3624
2024-04-15 16:58:00 -07:00
Sung Kim
99d227c9db Added Solar example at README.md (#3610)
Added just one line

| Solar              | 10.7B      | 6.1GB | `ollama run solar`             |
2024-04-15 19:54:23 -04:00
Carlos Gamez
a27e419b47 Update langchainjs.md (#2030)
Changed ollama.call() for ollama.invoke() as per deprecated documentation from langchain
2024-04-15 18:37:30 -04:00
Chandre Van Der Westhuizen
e4d0db5a97 Added MindsDB information (#3595)
* Added MindsDB information

Added more details to MindsDB so that Ollama users can know that they can connect their Ollama model with 200+ databases and apps

* updated text for mindsdb
2024-04-15 18:35:29 -04:00
Eli Bendersky
ba460802c2 examples: add more Go examples using the API (#3599)
* examples: go-multimodal

* examples: add go-pull-progress

* examples: add go-chat

* fix
2024-04-15 18:34:54 -04:00
Jeffrey Morgan
e54a3c7fcd Update modelfile.md
Remove Modelfile parameters that are decided at runtime
2024-04-15 15:35:44 -04:00
Patrick Devine
9f8691c6c8 Add llama2 / torch models for ollama create (#3607) 2024-04-15 11:26:42 -07:00
Jeffrey Morgan
a0b8a32eb4 Terminate subprocess if receiving SIGINT or SIGTERM signals while model is loading (#3653)
* terminate subprocess if receiving `SIGINT` or `SIGTERM` signals while model is loading

* use `unload` in signal handler
2024-04-15 12:09:32 -04:00
Jeffrey Morgan
7027f264fb app: gracefully shut down ollama serve on windows (#3641)
* app: gracefully shut down `ollama serve` on windows

* fix linter errors

* bring back `HideWindow`

* remove creation flags

* restore `windows.CREATE_NEW_PROCESS_GROUP`
2024-04-14 18:33:25 -04:00
Blake Mizerany
9bee3b63b1 types/model: add path helpers (#3619)
This commit adds path helpers for working with Names in URL and file
paths. The new helpers are ParseNameFromPath, ParseNameFromFilePath,
Name.Path, and Name.FilePath.

This commit also adds Name.DisplayLongest, and Name.DisplayLong.

Also, be it updates a place where strings.StripPrefix is more consistent
with the surrounding code.

Also, replace Parts with specific methods
2024-04-13 12:59:19 -07:00
Jeffrey Morgan
309aef7fee update llama.cpp submodule to 4bd0f93 (#3627) 2024-04-13 10:43:02 -07:00
Blake Mizerany
08655170aa types/model: make ParseName variants less confusing (#3617)
Also, fix http stripping bug.

Also, improve upon docs about fills and masks.
2024-04-12 13:57:57 -07:00
Blake Mizerany
2b341069a7 types/model: remove (*Digest).Scan and Digest.Value (#3605) 2024-04-11 13:32:31 -07:00
Daniel Hiltgen
c00fee6936 Merge pull request #3604 from dhiltgen/fix_rocm_deps
Fix rocm deps with new subprocess paths
2024-04-11 13:08:29 -07:00
Daniel Hiltgen
c2d813bdc3 Fix rocm deps with new subprocess paths 2024-04-11 12:52:06 -07:00
Michael Yang
786f3a1c44 Merge pull request #3600 from ollama/mxyng/mixtral 2024-04-11 12:23:37 -07:00
Michael Yang
3397eff0cd mixtral mem 2024-04-11 11:10:41 -07:00
Blake Mizerany
0efb7931c7 Revert "types/model: remove (*Digest).Scan and Digest.Value (#3589)"
This reverts commit 42f2cc408e.
2024-04-11 00:45:07 -07:00
Blake Mizerany
42f2cc408e types/model: remove (*Digest).Scan and Digest.Value (#3589) 2024-04-11 00:37:26 -07:00
Blake Mizerany
9446b795b5 types/model: remove DisplayLong (#3587) 2024-04-10 16:55:12 -07:00
Blake Mizerany
62f8cda3b3 types/model: remove MarshalText/UnmarshalText from Digest (#3586) 2024-04-10 16:52:49 -07:00
Blake Mizerany
6a1de23175 types/model: init with Name and Digest types (#3541) 2024-04-10 16:30:05 -07:00
Blake Mizerany
a7b431e743 server: provide helpful workaround hint when stalling on pull (#3584)
This is a quick fix to help users who are stuck on the "pull" step at
99%.

In the near future we're introducing a new registry client that
should/will hopefully be smarter. In the meantime, this should unblock
the users hitting issue #1736.
2024-04-10 16:24:37 -07:00
Michael Yang
5a25f93522 Merge pull request #3478 from ollama/mxyng/tensor-layer
refactor tensor query
2024-04-10 12:45:03 -07:00
Michael Yang
7e33a017c0 partial offloading 2024-04-10 11:37:20 -07:00
Michael Yang
8b2c10061c refactor tensor query 2024-04-10 11:37:20 -07:00
Michael Yang
c5c451ca3b Merge pull request #3579 from ollama/mxyng/fix-ci
fix ci
2024-04-10 11:37:01 -07:00
Michael Yang
2b4ca6cf36 fix ci 2024-04-10 11:35:12 -07:00
Eli Bendersky
ad90b9ab3d api: start adding documentation to package api (#2878)
* api: start adding documentation to package api

Updates #2840

* Fix lint typo report
2024-04-10 13:31:55 -04:00
Eli Bendersky
4340f8eba4 examples: start adding Go examples using api/ (#2879)
We can have the same examples as e.g. https://github.com/ollama/ollama-python/tree/main/examples
here. Using consistent naming and renaming the existing example to have -http-
since it uses direct HTTP requests rather than api/

Updates #2840
2024-04-10 13:26:45 -04:00
Daniel Hiltgen
4c7db6b7e9 Merge pull request #3566 from dhiltgen/more_time
Handle very slow model loads
2024-04-09 16:53:49 -07:00
Michael Yang
c03f0e3c3d Merge pull request #3565 from ollama/mxyng/rope
fix: rope
2024-04-09 16:36:55 -07:00
Daniel Hiltgen
c5ff443b9f Handle very slow model loads
During testing, we're seeing some models take over 3 minutes.
2024-04-09 16:35:10 -07:00
Michael Yang
01114b4526 fix: rope 2024-04-09 16:15:24 -07:00
Blake Mizerany
1524f323a3 Revert "build.go: introduce a friendlier way to build Ollama (#3548)" (#3564) 2024-04-09 15:57:45 -07:00
Blake Mizerany
fccf3eecaa build.go: introduce a friendlier way to build Ollama (#3548)
This commit introduces a more friendly way to build Ollama dependencies
and the binary without abusing `go generate` and removing the
unnecessary extra steps it brings with it.

This script also provides nicer feedback to the user about what is
happening during the build process.

At the end, it prints a helpful message to the user about what to do
next (e.g. run the new local Ollama).
2024-04-09 14:18:47 -07:00
Michael Yang
c77d45d836 Merge pull request #3506 from ollama/mxyng/quantize-redux
cgo quantize
2024-04-09 12:32:53 -07:00
Jeffrey Morgan
5ec12cec6c update llama.cpp submodule to 1b67731 (#3561) 2024-04-09 15:10:17 -04:00
Michael Yang
d9578d2bad Merge pull request #3559 from ollama/mxyng/ci
ci: use go-version-file
2024-04-09 11:03:18 -07:00
Michael Yang
cb8352d6b4 ci: use go-version-file 2024-04-09 09:50:12 -07:00
Alex Mavrogiannis
fc6558f47f Correct directory reference in macapp/README (#3555) 2024-04-09 09:48:46 -04:00
Michael Yang
9502e5661f cgo quantize 2024-04-08 15:31:08 -07:00
Michael Yang
e1c9a2a00f no blob create if already exists 2024-04-08 15:09:48 -07:00
87 changed files with 3407 additions and 807 deletions

View File

@@ -0,0 +1,60 @@
name: Bug report
labels: [bug]
description: Something isn't working right.
body:
- type: textarea
id: description
attributes:
label: What is the issue?
description: What happened? What did you expect to happen?
validations:
required: true
- type: dropdown
id: os
attributes:
label: OS
description: Which operating system are you using?
multiple: true
options:
- Linux
- macOS
- Windows
- Docker
- WSL2
validations:
required: false
- type: dropdown
id: gpu
attributes:
label: GPU
description: Which GPU are you using?
multiple: true
options:
- Nvidia
- AMD
- Intel
- Apple
- Other
validations:
required: false
- type: dropdown
id: cpu
attributes:
label: CPU
description: Which CPU are you using?
multiple: true
options:
- Intel
- AMD
- Apple
- Other
validations:
required: false
- type: input
id: version
attributes:
label: Ollama version
description: What version of Ollama are you using? (`ollama --version`)
placeholder: e.g., 0.1.32
validations:
required: false

View File

@@ -1,18 +0,0 @@
name: Model request
description: Request a new model for the library
labels: [mr]
body:
- type: markdown
attributes:
value: |
Please check if your Model request is [already available](https://ollama.com/search) or that you cannot [import it](https://github.com/ollama/ollama/blob/main/docs/import.md#import-a-model) yourself.
Tell us about which Model you'd like to see in the library!
- type: textarea
id: problem
attributes:
label: What model would you like?
description: Please provide a link to the model.
- type: markdown
attributes:
value: |
Thanks for filing a model request!

View File

@@ -0,0 +1,6 @@
---
name: Feature request
about: Request a new feature
labels: feature request
---

View File

@@ -1,41 +0,0 @@
name: Feature request
description: Propose a new feature
labels: [needs-triage, fr]
body:
- type: markdown
attributes:
value: |
Please check if your feature request is [already filed](https://github.com/ollama/ollama/issues).
Tell us about your idea!
- type: textarea
id: problem
attributes:
label: What are you trying to do?
description: Tell us about the problem you're trying to solve.
validations:
required: false
- type: textarea
id: solution
attributes:
label: How should we solve this?
description: If you have an idea of how you'd like to see this feature work, let us know.
validations:
required: false
- type: textarea
id: alternative
attributes:
label: What is the impact of not solving this?
description: (How) Are you currently working around the issue?
validations:
required: false
- type: textarea
id: context
attributes:
label: Anything else?
description: Any additional context to share, e.g., links
validations:
required: false
- type: markdown
attributes:
value: |
Thanks for filing a feature request!

View File

@@ -0,0 +1,5 @@
---
name: Model request
about: Request support for a new model to be added to Ollama
labels: model request
---

View File

@@ -1,125 +0,0 @@
name: Bug report
description: File a bug report. If you need help, please join our Discord server.
labels: [needs-triage, bug]
body:
- type: markdown
attributes:
value: |
Please check if your bug is [already filed](https://github.com/ollama/ollama/issues) before filing a new one.
- type: textarea
id: what-happened
attributes:
label: What is the issue?
description: What happened? What did you expect to happen?
validations:
required: true
- type: textarea
id: what-was-expected
attributes:
label: What did you expect to see?
description: What did you expect to see/happen instead?
validations:
required: false
- type: textarea
id: steps
attributes:
label: Steps to reproduce
description: What are the steps you took that hit this issue?
validations:
required: false
- type: textarea
id: changes
attributes:
label: Are there any recent changes that introduced the issue?
description: If so, what are those changes?
validations:
required: false
- type: dropdown
id: os
attributes:
label: OS
description: What OS are you using? You may select more than one.
multiple: true
options:
- Linux
- macOS
- Windows
- Other
validations:
required: false
- type: dropdown
id: architecture
attributes:
label: Architecture
description: What architecture are you using? You may select more than one.
multiple: true
options:
- arm64
- amd64
- x86
- Other
- type: dropdown
id: platform
attributes:
label: Platform
description: What platform are you using? You may select more than one.
multiple: true
options:
- Docker
- WSL
- WSL2
validations:
required: false
- type: input
id: ollama-version
attributes:
label: Ollama version
description: What Ollama version are you using? (`ollama --version`)
placeholder: e.g., 1.14.4
validations:
required: false
- type: dropdown
id: gpu
attributes:
label: GPU
description: What GPU, if any, are you using? You may select more than one.
multiple: true
options:
- Nvidia
- AMD
- Intel
- Apple
- Other
validations:
required: false
- type: textarea
id: gpu-info
attributes:
label: GPU info
description: What GPU info do you have? (`nvidia-smi`, `rocminfo`, `system_profiler SPDisplaysDataType`, etc.)
validations:
required: false
- type: dropdown
id: cpu
attributes:
label: CPU
description: What CPU are you using? You may select more than one.
multiple: true
options:
- Intel
- AMD
- Apple
- Other
validations:
required: false
- type: textarea
id: other-software
attributes:
label: Other software
description: What other software are you using that might be related to this issue?
validations:
required: false
- type: markdown
attributes:
value: |
Thanks for filing a bug report!

View File

@@ -30,7 +30,7 @@ jobs:
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- name: Build Darwin
env:
@@ -86,7 +86,7 @@ jobs:
write-host "plugin installed"
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- run: go get ./...
- run: |
@@ -139,7 +139,7 @@ jobs:
write-host "plugin installed"
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- name: 'Install ROCm'
run: |
@@ -214,7 +214,7 @@ jobs:
write-host "plugin installed"
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- name: 'Install CUDA'
run: |
@@ -300,7 +300,7 @@ jobs:
write-host "plugin installed"
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- run: go get
- uses: actions/download-artifact@v4

View File

@@ -5,7 +5,6 @@ on:
paths:
- '**/*'
- '!docs/**'
- '!examples/**'
- '!README.md'
jobs:
@@ -51,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- run: go get ./...
- run: |
@@ -93,7 +92,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- run: go get ./...
- run: |
@@ -124,7 +123,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- run: go get ./...
- run: |
@@ -146,7 +145,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- name: 'Install ROCm'
run: |
@@ -183,7 +182,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- name: 'Install CUDA'
run: |
@@ -238,7 +237,7 @@ jobs:
submodules: recursive
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version-file: go.mod
cache: false
- run: |
case ${{ matrix.arch }} in
@@ -283,7 +282,7 @@ jobs:
submodules: recursive
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version-file: go.mod
cache: true
- run: go get
- run: |

View File

@@ -42,7 +42,7 @@ ARG CGO_CFLAGS
ARG AMDGPU_TARGETS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
RUN mkdir /tmp/scratch && \
for dep in $(cat /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/x86_64/rocm*/lib/deps.txt) ; do \
for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
cp ${dep} /tmp/scratch/ || exit 1 ; \
done && \
(cd /opt/rocm/lib && tar cf - rocblas/library) | (cd /tmp/scratch/ && tar xf - ) && \

View File

@@ -60,10 +60,10 @@ Here are some example models that can be downloaded:
| Llama 2 13B | 13B | 7.3GB | `ollama run llama2:13b` |
| Llama 2 70B | 70B | 39GB | `ollama run llama2:70b` |
| Orca Mini | 3B | 1.9GB | `ollama run orca-mini` |
| Vicuna | 7B | 3.8GB | `ollama run vicuna` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
| Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
| Solar | 10.7B | 6.1GB | `ollama run solar` |
> Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -316,7 +316,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
### Package managers
@@ -377,3 +377,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
### Supported backends
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.

View File

@@ -1,3 +1,9 @@
// Package api implements the client-side API for code wishing to interact
// with the ollama service. The methods of the [Client] type correspond to
// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
//
// The ollama command-line client itself uses this package to interact with
// the backend service.
package api
import (
@@ -5,7 +11,6 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@@ -15,10 +20,12 @@ import (
"runtime"
"strings"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/version"
"ollama.com/format"
"ollama.com/version"
)
// Client encapsulates client state for interacting with the ollama
// service. Use [ClientFromEnvironment] to create new Clients.
type Client struct {
base *url.URL
http *http.Client
@@ -40,6 +47,15 @@ func checkError(resp *http.Response, body []byte) error {
return apiError
}
// ClientFromEnvironment creates a new [Client] using configuration from the
// environment variable OLLAMA_HOST, which points to the network host and
// port on which the ollama service is listenting. The format of this variable
// is:
//
// <scheme>://<host>:<port>
//
// If the variable is not specified, a default ollama host and port will be
// used.
func ClientFromEnvironment() (*Client, error) {
defaultPort := "11434"
@@ -191,8 +207,14 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
return nil
}
// GenerateResponseFunc is a function that [Client.Generate] invokes every time
// a response is received from the service. If this function returns an error,
// [Client.Generate] will stop generating and return this error.
type GenerateResponseFunc func(GenerateResponse) error
// Generate generates a response for a given prompt. The req parameter should
// be populated with prompt details. fn is called for each response (there may
// be multiple responses, e.g. in case streaming is enabled).
func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
var resp GenerateResponse
@@ -204,8 +226,15 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
})
}
// ChatResponseFunc is a function that [Client.Chat] invokes every time
// a response is received from the service. If this function returns an error,
// [Client.Chat] will stop generating and return this error.
type ChatResponseFunc func(ChatResponse) error
// Chat generates the next message in a chat. [ChatRequest] may contain a
// sequence of messages which can be used to maintain chat history with a model.
// fn is called for each response (there may be multiple responses, e.g. if case
// streaming is enabled).
func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
var resp ChatResponse
@@ -217,8 +246,14 @@ func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc
})
}
// PullProgressFunc is a function that [Client.Pull] invokes every time there
// is progress with a "pull" request sent to the service. If this function
// returns an error, [Client.Pull] will stop the process and return this error.
type PullProgressFunc func(ProgressResponse) error
// Pull downloads a model from the ollama library. fn is called each time
// progress is made on the request and can be used to display a progress bar,
// etc.
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
var resp ProgressResponse
@@ -301,18 +336,7 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
}
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil {
var statusError StatusError
if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
return err
}
if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil {
return err
}
}
return nil
return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
}
func (c *Client) Version(ctx context.Context) (string, error) {

View File

@@ -33,18 +33,46 @@ func (e StatusError) Error() string {
type ImageData []byte
// GenerateRequest describes a request sent by [Client.Generate]. While you
// have to specify the Model and Prompt fields, all the other fields have
// reasonable defaults for basic uses.
type GenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Images []ImageData `json:"images,omitempty"`
// Model is the model name; it should be a name familiar to Ollama from
// the library at https://ollama.com/library
Model string `json:"model"`
// Prompt is the textual prompt to send to the model.
Prompt string `json:"prompt"`
// System overrides the model's default system message/prompt.
System string `json:"system"`
// Template overrides the model's default prompt template.
Template string `json:"template"`
// Context is the context parameter returned from a previous call to
// Generate call. It can be used to keep a short conversational memory.
Context []int `json:"context,omitempty"`
// Stream specifies whether the response is streaming; it is true by default.
Stream *bool `json:"stream,omitempty"`
// Raw set to true means that no formatting will be applied to the prompt.
Raw bool `json:"raw,omitempty"`
// Format specifies the format to return a response in.
Format string `json:"format"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Images is an optional list of base64-encoded images accompanying this
// request, for multimodal models.
Images []ImageData `json:"images,omitempty"`
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
}
@@ -109,19 +137,24 @@ type Options struct {
// Runner options which must be set when the model is loaded into memory
type Runner struct {
UseNUMA bool `json:"numa,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
NumBatch int `json:"num_batch,omitempty"`
NumGQA int `json:"num_gqa,omitempty"`
NumGPU int `json:"num_gpu,omitempty"`
MainGPU int `json:"main_gpu,omitempty"`
LowVRAM bool `json:"low_vram,omitempty"`
F16KV bool `json:"f16_kv,omitempty"`
LogitsAll bool `json:"logits_all,omitempty"`
VocabOnly bool `json:"vocab_only,omitempty"`
UseMMap bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"`
UseNUMA bool `json:"numa,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
NumBatch int `json:"num_batch,omitempty"`
NumGQA int `json:"num_gqa,omitempty"`
NumGPU int `json:"num_gpu,omitempty"`
MainGPU int `json:"main_gpu,omitempty"`
LowVRAM bool `json:"low_vram,omitempty"`
F16KV bool `json:"f16_kv,omitempty"`
LogitsAll bool `json:"logits_all,omitempty"`
VocabOnly bool `json:"vocab_only,omitempty"`
UseMMap bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"`
// Unused: RopeFrequencyBase is ignored. Instead the value in the model will be used
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
// Unused: RopeFrequencyScale is ignored. Instead the value in the model will be used
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
}
type EmbeddingRequest struct {
@@ -137,10 +170,11 @@ type EmbeddingResponse struct {
}
type CreateRequest struct {
Model string `json:"model"`
Path string `json:"path"`
Modelfile string `json:"modelfile"`
Stream *bool `json:"stream,omitempty"`
Model string `json:"model"`
Path string `json:"path"`
Modelfile string `json:"modelfile"`
Stream *bool `json:"stream,omitempty"`
Quantization string `json:"quantization,omitempty"`
// Name is deprecated, see Model
Name string `json:"name"`
@@ -380,16 +414,16 @@ func DefaultOptions() Options {
Runner: Runner{
// options set when the model is loaded
NumCtx: 2048,
NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumGQA: 1,
NumThread: 0, // let the runtime decide
LowVRAM: false,
F16KV: true,
UseMLock: false,
UseMMap: true,
UseNUMA: false,
NumCtx: 2048,
NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumGQA: 1,
NumThread: 0, // let the runtime decide
LowVRAM: false,
F16KV: true,
UseMLock: false,
UseMMap: true,
UseNUMA: false,
},
}
}

View File

@@ -9,8 +9,8 @@ import (
"os/signal"
"syscall"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/tray"
"ollama.com/app/store"
"ollama.com/app/tray"
)
func Run() {

View File

@@ -9,10 +9,9 @@ import (
"os"
"os/exec"
"path/filepath"
"syscall"
"time"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
func getCLIFullPath(command string) string {
@@ -87,19 +86,29 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
// Re-wire context done behavior to attempt a graceful shutdown of the server
cmd.Cancel = func() error {
if cmd.Process != nil {
cmd.Process.Signal(os.Interrupt) //nolint:errcheck
err := terminate(cmd)
if err != nil {
slog.Warn("error trying to gracefully terminate server", "err", err)
return cmd.Process.Kill()
}
tick := time.NewTicker(10 * time.Millisecond)
defer tick.Stop()
for {
select {
case <-tick.C:
// OS agnostic "is it still running"
if proc, err := os.FindProcess(int(cmd.Process.Pid)); err != nil || errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
return nil //nolint:nilerr
exited, err := isProcessExited(cmd.Process.Pid)
if err != nil {
return err
}
if exited {
return nil
}
case <-time.After(5 * time.Second):
slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
cmd.Process.Kill() //nolint:errcheck
return cmd.Process.Kill()
}
}
}

View File

@@ -4,9 +4,35 @@ package lifecycle
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"syscall"
)
func getCmd(ctx context.Context, cmd string) *exec.Cmd {
return exec.CommandContext(ctx, cmd, "serve")
}
func terminate(cmd *exec.Cmd) error {
return cmd.Process.Signal(os.Interrupt)
}
func isProcessExited(pid int) (bool, error) {
proc, err := os.FindProcess(pid)
if err != nil {
return false, fmt.Errorf("failed to find process: %v", err)
}
err = proc.Signal(syscall.Signal(0))
if err != nil {
if errors.Is(err, os.ErrProcessDone) || errors.Is(err, syscall.ESRCH) {
return true, nil
}
return false, fmt.Errorf("error signaling process: %v", err)
}
return false, nil
}

View File

@@ -2,12 +2,88 @@ package lifecycle
import (
"context"
"fmt"
"os/exec"
"syscall"
"golang.org/x/sys/windows"
)
func getCmd(ctx context.Context, exePath string) *exec.Cmd {
cmd := exec.CommandContext(ctx, exePath, "serve")
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true, CreationFlags: 0x08000000}
cmd.SysProcAttr = &syscall.SysProcAttr{
HideWindow: true,
CreationFlags: windows.CREATE_NEW_PROCESS_GROUP,
}
return cmd
}
func terminate(cmd *exec.Cmd) error {
dll, err := windows.LoadDLL("kernel32.dll")
if err != nil {
return err
}
defer dll.Release() // nolint: errcheck
pid := cmd.Process.Pid
f, err := dll.FindProc("AttachConsole")
if err != nil {
return err
}
r1, _, err := f.Call(uintptr(pid))
if r1 == 0 && err != syscall.ERROR_ACCESS_DENIED {
return err
}
f, err = dll.FindProc("SetConsoleCtrlHandler")
if err != nil {
return err
}
r1, _, err = f.Call(0, 1)
if r1 == 0 {
return err
}
f, err = dll.FindProc("GenerateConsoleCtrlEvent")
if err != nil {
return err
}
r1, _, err = f.Call(windows.CTRL_BREAK_EVENT, uintptr(pid))
if r1 == 0 {
return err
}
r1, _, err = f.Call(windows.CTRL_C_EVENT, uintptr(pid))
if r1 == 0 {
return err
}
return nil
}
const STILL_ACTIVE = 259
func isProcessExited(pid int) (bool, error) {
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid))
if err != nil {
return false, fmt.Errorf("failed to open process: %v", err)
}
defer windows.CloseHandle(hProcess) // nolint: errcheck
var exitCode uint32
err = windows.GetExitCodeProcess(hProcess, &exitCode)
if err != nil {
return false, fmt.Errorf("failed to get exit code: %v", err)
}
if exitCode == STILL_ACTIVE {
return false, nil
}
return true, nil
}

View File

@@ -18,8 +18,8 @@ import (
"strings"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/version"
"ollama.com/auth"
"ollama.com/version"
)
var (

View File

@@ -4,7 +4,7 @@ package main
// go build -ldflags="-H windowsgui" .
import (
"github.com/ollama/ollama/app/lifecycle"
"ollama.com/app/lifecycle"
)
func main() {

View File

@@ -4,8 +4,8 @@ import (
"fmt"
"runtime"
"github.com/ollama/ollama/app/assets"
"github.com/ollama/ollama/app/tray/commontray"
"ollama.com/app/assets"
"ollama.com/app/tray/commontray"
)
func NewTray() (commontray.OllamaTray, error) {
@@ -24,10 +24,5 @@ func NewTray() (commontray.OllamaTray, error) {
return nil, fmt.Errorf("failed to load icon %s: %w", iconName, err)
}
tray, err := InitPlatformTray(icon, updateIcon)
if err != nil {
return nil, err
}
return tray, nil
return InitPlatformTray(icon, updateIcon)
}

View File

@@ -5,7 +5,7 @@ package tray
import (
"fmt"
"github.com/ollama/ollama/app/tray/commontray"
"ollama.com/app/tray/commontray"
)
func InitPlatformTray(icon, updateIcon []byte) (commontray.OllamaTray, error) {

View File

@@ -1,8 +1,8 @@
package tray
import (
"github.com/ollama/ollama/app/tray/commontray"
"github.com/ollama/ollama/app/tray/wintray"
"ollama.com/app/tray/commontray"
"ollama.com/app/tray/wintray"
)
func InitPlatformTray(icon, updateIcon []byte) (commontray.OllamaTray, error) {

View File

@@ -13,8 +13,8 @@ import (
"sync"
"unsafe"
"github.com/ollama/ollama/app/tray/commontray"
"golang.org/x/sys/windows"
"ollama.com/app/tray/commontray"
)
// Helpful sources: https://github.com/golang/exp/blob/master/shiny/driver/internal/win32

View File

@@ -30,12 +30,12 @@ import (
"golang.org/x/exp/slices"
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/version"
"ollama.com/api"
"ollama.com/format"
"ollama.com/parser"
"ollama.com/progress"
"ollama.com/server"
"ollama.com/version"
)
func CreateHandler(cmd *cobra.Command, args []string) error {
@@ -105,24 +105,48 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
zf := zip.NewWriter(tf)
files, err := filepath.Glob(filepath.Join(path, "model-*.safetensors"))
files := []string{}
tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
if err != nil {
return err
} else if len(tfiles) == 0 {
tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
if err != nil {
return err
}
}
files = append(files, tfiles...)
if len(files) == 0 {
return fmt.Errorf("no safetensors files were found in '%s'", path)
return fmt.Errorf("no models were found in '%s'", path)
}
// add the safetensor config file + tokenizer
// add the safetensor/torch config file + tokenizer
files = append(files, filepath.Join(path, "config.json"))
files = append(files, filepath.Join(path, "params.json"))
files = append(files, filepath.Join(path, "added_tokens.json"))
files = append(files, filepath.Join(path, "tokenizer.model"))
for _, fn := range files {
f, err := os.Open(fn)
if os.IsNotExist(err) && strings.HasSuffix(fn, "added_tokens.json") {
continue
// just skip whatever files aren't there
if os.IsNotExist(err) {
if strings.HasSuffix(fn, "tokenizer.model") {
// try the parent dir before giving up
parentDir := filepath.Dir(path)
newFn := filepath.Join(parentDir, "tokenizer.model")
f, err = os.Open(newFn)
if os.IsNotExist(err) {
continue
} else if err != nil {
return err
}
} else {
continue
}
} else if err != nil {
return err
}
@@ -194,7 +218,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)}
quantization, _ := cmd.Flags().GetString("quantization")
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -226,14 +252,6 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
}
func RunHandler(cmd *cobra.Command, args []string) error {
if os.Getenv("OLLAMA_MODELS") != "" {
return errors.New("OLLAMA_MODELS must only be set for 'ollama serve'")
}
if err := checkServerHeartbeat(cmd, args); err != nil {
return err
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -943,6 +961,7 @@ func NewCLI() *cobra.Command {
}
createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("quantization", "q", "", "Quantization level.")
showCmd := &cobra.Command{
Use: "show MODEL",
@@ -959,10 +978,11 @@ func NewCLI() *cobra.Command {
showCmd.Flags().Bool("system", false, "Show system message of a model")
runCmd := &cobra.Command{
Use: "run MODEL [PROMPT]",
Short: "Run a model",
Args: cobra.MinimumNArgs(1),
RunE: RunHandler,
Use: "run MODEL [PROMPT]",
Short: "Run a model",
Args: cobra.MinimumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: RunHandler,
}
runCmd.Flags().Bool("verbose", false, "Show timings for response")

View File

@@ -14,9 +14,9 @@ import (
"github.com/spf13/cobra"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"ollama.com/api"
"ollama.com/progress"
"ollama.com/readline"
)
type MultilineState int

View File

@@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
func TestExtractFilenames(t *testing.T) {

View File

@@ -7,7 +7,7 @@ import (
"os/exec"
"strings"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
func startApp(ctx context.Context, client *api.Client) error {

View File

@@ -6,7 +6,7 @@ import (
"context"
"fmt"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
func startApp(ctx context.Context, client *api.Client) error {

View File

@@ -10,7 +10,7 @@ import (
"strings"
"syscall"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
func startApp(ctx context.Context, client *api.Client) error {

View File

@@ -1,25 +1,20 @@
package convert
import (
"bytes"
"cmp"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"slices"
"strings"
"github.com/d4l3k/go-bfloat16"
"github.com/mitchellh/mapstructure"
"github.com/x448/float16"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
"github.com/ollama/ollama/llm"
"ollama.com/convert/sentencepiece"
"ollama.com/llm"
)
type Params struct {
@@ -45,157 +40,45 @@ type ByteOrder interface {
binary.AppendByteOrder
}
type MetaData struct {
Type string `mapstructure:"dtype"`
Shape []int `mapstructure:"shape"`
Offsets []int `mapstructure:"data_offsets"`
}
type ModelArch interface {
GetTensors() error
LoadVocab() error
WriteGGUF() (string, error)
}
type ModelFormat interface {
GetLayerName(string) (string, error)
GetTensors(string, *Params) ([]llm.Tensor, error)
GetParams(string) (*Params, error)
GetModelArch(string, string, *Params) (ModelArch, error)
}
type ModelData struct {
Path string
Name string
Params *Params
Vocab *Vocab
Tensors []llm.Tensor
Format ModelFormat
}
func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
f, err := os.Open(fn)
if err != nil {
return nil, 0, err
}
defer f.Close()
var jsonSize uint64
if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
return nil, 0, err
}
buf := make([]byte, jsonSize)
_, err = io.ReadFull(f, buf)
if err != nil {
return nil, 0, err
}
d := json.NewDecoder(bytes.NewBuffer(buf))
d.UseNumber()
var parsed map[string]interface{}
if err = d.Decode(&parsed); err != nil {
return nil, 0, err
}
var keys []string
for k := range parsed {
keys = append(keys, k)
}
slices.Sort(keys)
slog.Info("converting layers")
var tensors []llm.Tensor
for _, k := range keys {
vals := parsed[k].(map[string]interface{})
var data MetaData
if err = mapstructure.Decode(vals, &data); err != nil {
return nil, 0, err
}
var size uint64
var kind uint32
switch len(data.Shape) {
case 0:
// metadata
continue
case 1:
// convert to float32
kind = 0
size = uint64(data.Shape[0] * 4)
case 2:
// convert to float16
kind = 1
size = uint64(data.Shape[0] * data.Shape[1] * 2)
}
ggufName, err := GetTensorName(k)
if err != nil {
slog.Error("%v", err)
return nil, 0, err
}
shape := []uint64{0, 0, 0, 0}
for i := range data.Shape {
shape[i] = uint64(data.Shape[i])
}
t := llm.Tensor{
Name: ggufName,
Kind: kind,
Offset: offset,
Shape: shape[:],
}
t.WriterTo = safetensorWriterTo{
t: &t,
params: params,
bo: params.ByteOrder,
filename: fn,
start: uint64(data.Offsets[0]),
end: uint64(data.Offsets[1]),
padding: 8 + jsonSize,
}
slog.Debug(fmt.Sprintf("%v", t))
tensors = append(tensors, t)
offset += size
}
return tensors, offset, nil
}
func GetSafeTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
var tensors []llm.Tensor
files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
func GetModelFormat(dirname string) (ModelFormat, error) {
files, err := filepath.Glob(filepath.Join(dirname, "*"))
if err != nil {
return nil, err
}
var offset uint64
for _, f := range files {
var t []llm.Tensor
var err error
t, offset, err = ReadSafeTensors(f, offset, params)
if err != nil {
slog.Error("%v", err)
return nil, err
for _, fn := range files {
slog.Debug(fmt.Sprintf("file = %s", fn))
if strings.HasSuffix(fn, ".safetensors") {
return &SafetensorFormat{}, nil
} else if strings.HasSuffix(fn, ".bin") {
slog.Debug("model is torch")
return &TorchFormat{}, nil
}
tensors = append(tensors, t...)
}
return tensors, nil
}
func GetParams(dirpath string) (*Params, error) {
f, err := os.Open(filepath.Join(dirpath, "config.json"))
if err != nil {
return nil, err
}
defer f.Close()
var params Params
d := json.NewDecoder(f)
err = d.Decode(&params)
if err != nil {
return nil, err
}
params.ByteOrder = binary.LittleEndian
return &params, nil
return nil, fmt.Errorf("couldn't determine model format")
}
// Details on gguf's tokenizer can be found at:
@@ -206,7 +89,7 @@ type Vocab struct {
Types []int32
}
func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
if err != nil {
@@ -286,8 +169,8 @@ func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
}
slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
if vocabSize > len(v.Tokens) {
missingTokens := vocabSize - len(v.Tokens)
if params.VocabSize > len(v.Tokens) {
missingTokens := params.VocabSize - len(v.Tokens)
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
for cnt := 0; cnt < missingTokens; cnt++ {
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
@@ -298,136 +181,3 @@ func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
return v, nil
}
func GetTensorName(n string) (string, error) {
tMap := map[string]string{
"model.embed_tokens.weight": "token_embd.weight",
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
"lm_head.weight": "output.weight",
"model.norm.weight": "output_norm.weight",
}
v, ok := tMap[n]
if ok {
return v, nil
}
// quick hack to rename the layers to gguf format
for k, v := range tMap {
re := regexp.MustCompile(k)
newName := re.ReplaceAllString(n, v)
if newName != n {
return newName, nil
}
}
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
}
type safetensorWriterTo struct {
t *llm.Tensor
params *Params
bo ByteOrder
filename string
start, end, padding uint64
handler func(w io.Writer, r safetensorWriterTo, f *os.File) error
}
func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
f, err := os.Open(r.filename)
if err != nil {
return 0, err
}
defer f.Close()
if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
return 0, err
}
// use the handler if one is present
if r.handler != nil {
return 0, r.handler(w, r, f)
}
remaining := r.end - r.start
bufSize := uint64(10240)
var finished bool
for {
data := make([]byte, min(bufSize, remaining))
b, err := io.ReadFull(f, data)
remaining -= uint64(b)
if err == io.EOF || remaining <= 0 {
finished = true
} else if err != nil {
return 0, err
}
// convert bfloat16 -> ieee float32
tDataF32 := bfloat16.DecodeFloat32(data)
switch r.t.Kind {
case 0:
if err := binary.Write(w, r.bo, tDataF32); err != nil {
return 0, err
}
case 1:
// convert float32 -> float16
tempBuf := make([]uint16, len(data)/2)
for cnt, v := range tDataF32 {
tDataF16 := float16.Fromfloat32(v)
tempBuf[cnt] = uint16(tDataF16)
}
if err := binary.Write(w, binary.LittleEndian, tempBuf); err != nil {
return 0, err
}
}
if finished {
break
}
}
return 0, nil
}
func GetModelArchFromParams(name, dirPath string, params *Params) (ModelArch, error) {
switch len(params.Architectures) {
case 0:
return nil, fmt.Errorf("No architecture specified to convert")
case 1:
switch params.Architectures[0] {
case "MistralForCausalLM":
return &MistralModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
},
}, nil
case "GemmaForCausalLM":
return &GemmaModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
},
}, nil
default:
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
}
}
return nil, fmt.Errorf("Unknown error")
}

View File

@@ -12,7 +12,7 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
"ollama.com/llm"
)
type GemmaModel struct {
@@ -65,13 +65,14 @@ func addOnes(data []float32, vectorSize int) ([]float32, error) {
}
func (m *GemmaModel) GetTensors() error {
t, err := GetSafeTensors(m.Path, m.Params)
t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil {
return err
}
m.Tensors = []llm.Tensor{}
slog.Debug(fmt.Sprintf("Total tensors: %d", len(t)))
m.Tensors = []llm.Tensor{}
for _, l := range t {
if strings.HasSuffix(l.Name, "norm.weight") {
wt := l.WriterTo.(safetensorWriterTo)
@@ -85,7 +86,7 @@ func (m *GemmaModel) GetTensors() error {
}
func (m *GemmaModel) LoadVocab() error {
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
v, err := LoadSentencePieceTokens(m.Path, m.Params)
if err != nil {
return err
}

176
convert/llama.go Normal file
View File

@@ -0,0 +1,176 @@
package convert
import (
"encoding/binary"
"fmt"
"io"
"log/slog"
"os"
"regexp"
"strings"
"github.com/nlpodyssey/gopickle/pytorch"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/x448/float16"
"ollama.com/llm"
)
type LlamaModel struct {
ModelData
}
func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
data := r.storage.(*pytorch.HalfStorage).Data
tData := make([]uint16, len(data))
for cnt, v := range data {
tData[cnt] = uint16(float16.Fromfloat32(v))
}
var err error
var heads uint32
if strings.Contains(r.t.Name, "attn_q") {
heads = uint32(r.params.AttentionHeads)
} else if strings.Contains(r.t.Name, "attn_k") {
heads = uint32(r.params.KeyValHeads)
if heads == 0 {
heads = uint32(r.params.AttentionHeads)
}
} else {
return fmt.Errorf("unknown layer type")
}
slog.Debug(fmt.Sprintf("heads = %d", heads))
tData, err = llamaRepack(tData, int(heads), r.t.Shape)
if err != nil {
return err
}
if err = binary.Write(w, r.bo, tData); err != nil {
return err
}
return nil
}
func llamaRepack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
origShape := n.Shape().Clone()
// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(origShape...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
newN, err := native.SelectU16(n, 1)
if err != nil {
return nil, err
}
var fullTensor []uint16
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}
return fullTensor, nil
}
func (m *LlamaModel) GetTensors() error {
t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil {
return err
}
m.Tensors = []llm.Tensor{}
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
return err
}
for _, l := range t {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
wt := l.WriterTo.(torchWriterTo)
wt.handler = llamaLayerHandler
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
}
return nil
}
func (m *LlamaModel) LoadVocab() error {
var v *Vocab
var err error
slog.Debug("loading vocab")
v, err = LoadSentencePieceTokens(m.Path, m.Params)
if err != nil {
return err
}
slog.Debug("vocab loaded")
m.Vocab = v
return nil
}
func (m *LlamaModel) WriteGGUF() (string, error) {
kv := llm.KV{
"general.architecture": "llama",
"general.name": m.Name,
"llama.vocab_size": uint32(len(m.Vocab.Tokens)),
"llama.context_length": uint32(m.Params.ContextSize),
"llama.embedding_length": uint32(m.Params.HiddenSize),
"llama.block_count": uint32(m.Params.HiddenLayers),
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.scores": m.Vocab.Scores,
"tokenizer.ggml.token_type": m.Vocab.Types,
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.unknown_token_id": uint32(0),
"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false,
}
f, err := os.CreateTemp("", "ollama-gguf")
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
return f.Name(), nil
}

View File

@@ -13,7 +13,7 @@ import (
"github.com/pdevine/tensor/native"
"github.com/x448/float16"
"github.com/ollama/ollama/llm"
"ollama.com/llm"
)
type MistralModel struct {
@@ -97,7 +97,7 @@ func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
}
func (m *MistralModel) GetTensors() error {
t, err := GetSafeTensors(m.Path, m.Params)
t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil {
return err
}
@@ -124,7 +124,7 @@ func (m *MistralModel) GetTensors() error {
}
func (m *MistralModel) LoadVocab() error {
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
v, err := LoadSentencePieceTokens(m.Path, m.Params)
if err != nil {
return err
}

304
convert/safetensors.go Normal file
View File

@@ -0,0 +1,304 @@
package convert
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"slices"
"github.com/d4l3k/go-bfloat16"
"github.com/mitchellh/mapstructure"
"github.com/x448/float16"
"ollama.com/llm"
)
type safetensorWriterTo struct {
t *llm.Tensor
params *Params
bo ByteOrder
filename string
start, end, padding uint64
handler func(w io.Writer, r safetensorWriterTo, f *os.File) error
}
type tensorMetaData struct {
Type string `mapstructure:"dtype"`
Shape []int `mapstructure:"shape"`
Offsets []int `mapstructure:"data_offsets"`
}
type SafetensorFormat struct{}
func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
slog.Debug("getting tensor data")
var tensors []llm.Tensor
files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
if err != nil {
return nil, err
}
var offset uint64
for _, f := range files {
var t []llm.Tensor
var err error
t, offset, err = m.readTensors(f, offset, params)
if err != nil {
slog.Error("%v", err)
return nil, err
}
tensors = append(tensors, t...)
}
slog.Debug(fmt.Sprintf("all tensors = %d", len(tensors)))
return tensors, nil
}
func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
f, err := os.Open(fn)
if err != nil {
return nil, 0, err
}
defer f.Close()
var jsonSize uint64
if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
return nil, 0, err
}
buf := make([]byte, jsonSize)
_, err = io.ReadFull(f, buf)
if err != nil {
return nil, 0, err
}
d := json.NewDecoder(bytes.NewBuffer(buf))
d.UseNumber()
var parsed map[string]interface{}
if err = d.Decode(&parsed); err != nil {
return nil, 0, err
}
var keys []string
for k := range parsed {
keys = append(keys, k)
}
slices.Sort(keys)
slog.Info("converting layers")
var tensors []llm.Tensor
for _, k := range keys {
vals := parsed[k].(map[string]interface{})
var data tensorMetaData
if err = mapstructure.Decode(vals, &data); err != nil {
slog.Error("couldn't decode properly")
return nil, 0, err
}
slog.Debug(fmt.Sprintf("metadata = %#v", data))
var size uint64
var kind uint32
switch len(data.Shape) {
case 0:
// metadata
continue
case 1:
// convert to float32
kind = 0
size = uint64(data.Shape[0] * 4)
case 2:
// convert to float16
kind = 1
size = uint64(data.Shape[0] * data.Shape[1] * 2)
}
ggufName, err := m.GetLayerName(k)
if err != nil {
slog.Error("%v", err)
return nil, 0, err
}
shape := []uint64{0, 0, 0, 0}
for i := range data.Shape {
shape[i] = uint64(data.Shape[i])
}
t := llm.Tensor{
Name: ggufName,
Kind: kind,
Offset: offset,
Shape: shape[:],
}
t.WriterTo = safetensorWriterTo{
t: &t,
params: params,
bo: params.ByteOrder,
filename: fn,
start: uint64(data.Offsets[0]),
end: uint64(data.Offsets[1]),
padding: 8 + jsonSize,
}
tensors = append(tensors, t)
offset += size
}
slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
slog.Debug(fmt.Sprintf("offset = %d", offset))
return tensors, offset, nil
}
func (m *SafetensorFormat) GetParams(dirpath string) (*Params, error) {
f, err := os.Open(filepath.Join(dirpath, "config.json"))
if err != nil {
return nil, err
}
defer f.Close()
var params Params
d := json.NewDecoder(f)
err = d.Decode(&params)
if err != nil {
return nil, err
}
params.ByteOrder = binary.LittleEndian
return &params, nil
}
func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
directMap := map[string]string{
"model.embed_tokens.weight": "token_embd.weight",
"lm_head.weight": "output.weight",
"model.norm.weight": "output_norm.weight",
}
tMap := map[string]string{
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
}
v, ok := directMap[n]
if ok {
return v, nil
}
// quick hack to rename the layers to gguf format
for k, v := range tMap {
re := regexp.MustCompile(k)
newName := re.ReplaceAllString(n, v)
if newName != n {
return newName, nil
}
}
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
}
func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
f, err := os.Open(r.filename)
if err != nil {
return 0, err
}
defer f.Close()
if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
return 0, err
}
// use the handler if one is present
if r.handler != nil {
return 0, r.handler(w, r, f)
}
remaining := r.end - r.start
bufSize := uint64(10240)
var finished bool
for {
data := make([]byte, min(bufSize, remaining))
b, err := io.ReadFull(f, data)
remaining -= uint64(b)
if err == io.EOF || remaining <= 0 {
finished = true
} else if err != nil {
return 0, err
}
// convert bfloat16 -> ieee float32
tDataF32 := bfloat16.DecodeFloat32(data)
switch r.t.Kind {
case 0:
if err := binary.Write(w, r.bo, tDataF32); err != nil {
return 0, err
}
case 1:
// convert float32 -> float16
tempBuf := make([]uint16, len(data)/2)
for cnt, v := range tDataF32 {
tDataF16 := float16.Fromfloat32(v)
tempBuf[cnt] = uint16(tDataF16)
}
if err := binary.Write(w, r.bo, tempBuf); err != nil {
return 0, err
}
}
if finished {
break
}
}
return 0, nil
}
func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
switch len(params.Architectures) {
case 0:
return nil, fmt.Errorf("No architecture specified to convert")
case 1:
switch params.Architectures[0] {
case "MistralForCausalLM":
return &MistralModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
Format: m,
},
}, nil
case "GemmaForCausalLM":
return &GemmaModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
Format: m,
},
}, nil
default:
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
}
}
return nil, fmt.Errorf("Unknown error")
}

286
convert/torch.go Normal file
View File

@@ -0,0 +1,286 @@
package convert
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/nlpodyssey/gopickle/pytorch"
"github.com/nlpodyssey/gopickle/types"
"github.com/x448/float16"
"ollama.com/llm"
)
type torchWriterTo struct {
t *llm.Tensor
params *Params
bo ByteOrder
storage pytorch.StorageInterface
handler func(w io.Writer, r torchWriterTo) error
}
type TorchFormat struct{}
func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
slog.Debug("getting torch tensors")
files, err := filepath.Glob(filepath.Join(dirpath, "pytorch_model-*.bin"))
if err != nil {
slog.Error("didn't find any torch files")
return nil, err
}
var offset uint64
var tensors []llm.Tensor
for _, fn := range files {
m, err := pytorch.Load(fn)
if err != nil {
slog.Error(fmt.Sprintf("error unpickling: %q", err))
return []llm.Tensor{}, err
}
for _, k := range m.(*types.Dict).Keys() {
if strings.HasSuffix(k.(string), "self_attn.rotary_emb.inv_freq") {
continue
}
t, _ := m.(*types.Dict).Get(k)
tshape := t.(*pytorch.Tensor).Size
var size uint64
var kind uint32
switch len(tshape) {
case 0:
continue
case 1:
// convert to float32
kind = 0
size = uint64(tshape[0] * 4)
case 2:
// convert to float16
kind = 1
size = uint64(tshape[0] * tshape[1] * 2)
}
ggufName, err := tf.GetLayerName(k.(string))
if err != nil {
slog.Error("%v", err)
return nil, err
}
slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))
shape := []uint64{0, 0, 0, 0}
for i := range tshape {
shape[i] = uint64(tshape[i])
}
tensor := llm.Tensor{
Name: ggufName,
Kind: kind,
Offset: offset, // calculate the offset
Shape: shape[:],
}
tensor.WriterTo = torchWriterTo{
t: &tensor,
params: params,
bo: params.ByteOrder,
storage: t.(*pytorch.Tensor).Source,
}
tensors = append(tensors, tensor)
offset += size
}
}
return tensors, nil
}
func getAltParams(dirpath string) (*Params, error) {
f, err := os.Open(filepath.Join(dirpath, "params.json"))
if err != nil {
slog.Error("no params.json")
return nil, err
}
defer f.Close()
type TorchParams struct {
HiddenSize int `json:"dim"`
AttentionHeads int `json:"n_heads"`
KeyValHeads int `json:"n_kv_heads"`
HiddenLayers int `json:"n_layers"`
RopeTheta int `json:"rope_theta"`
NormEPS float64 `json:"norm_eps"`
}
var tparams TorchParams
d := json.NewDecoder(f)
err = d.Decode(&tparams)
if err != nil {
return nil, err
}
params := &Params{
HiddenSize: tparams.HiddenSize,
AttentionHeads: tparams.AttentionHeads,
KeyValHeads: tparams.KeyValHeads,
HiddenLayers: tparams.HiddenLayers,
NormEPS: tparams.NormEPS,
}
switch {
case tparams.RopeTheta == 1000000:
// Codellama
params.ContextSize = 16384
case tparams.NormEPS == 1e-06:
// llama2
slog.Debug("Found llama2 - setting context size to 4096")
params.ContextSize = 4096
default:
params.ContextSize = 2048
}
params.ByteOrder = binary.LittleEndian
return params, nil
}
func (m *TorchFormat) GetParams(dirpath string) (*Params, error) {
f, err := os.Open(filepath.Join(dirpath, "config.json"))
if err != nil {
if os.IsNotExist(err) {
// try params.json instead
return getAltParams(dirpath)
} else {
return nil, err
}
}
var params Params
d := json.NewDecoder(f)
err = d.Decode(&params)
if err != nil {
return nil, err
}
params.ByteOrder = binary.LittleEndian
return &params, nil
}
func (m *TorchFormat) GetLayerName(n string) (string, error) {
directMap := map[string]string{
"tok_embeddings.weight": "token_embd.weight",
"output.weight": "output.weight",
"norm.weight": "output_norm.weight",
"rope.freqs": "rope_freqs.weight",
"model.embed_tokens.weight": "token_embd.weight",
"lm_head.weight": "output.weight",
"model.norm.weight": "output_norm.weight",
}
lMap := map[string]string{
"layers.(\\d+).attention_norm.weight": "blk.$1.attn_norm.weight",
"layers.(\\d+).attention_output_norm.weight": "blk.$1.attn_norm.weight",
"layers.(\\d+).feed_forward.w2.weight": "blk.$1.ffn_down.weight",
"layers.(\\d+).feed_forward.w1.weight": "blk.$1.ffn_gate.weight",
"layers.(\\d+).feed_forward.w3.weight": "blk.$1.ffn_up.weight",
"layers.(\\d+).ffn_norm.weight": "blk.$1.ffn_norm.weight",
"layers.(\\d+).attention.wk.weight": "blk.$1.attn_k.weight",
"layers.(\\d+).attention.wo.weight": "blk.$1.attn_output.weight",
"layers.(\\d+).attention.wq.weight": "blk.$1.attn_q.weight",
"layers.(\\d+).attention.wv.weight": "blk.$1.attn_v.weight",
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
}
v, ok := directMap[n]
if ok {
return v, nil
}
// quick hack to rename the layers to gguf format
for k, v := range lMap {
re := regexp.MustCompile(k)
newName := re.ReplaceAllString(n, v)
if newName != n {
return newName, nil
}
}
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
}
func (r torchWriterTo) WriteTo(w io.Writer) (n int64, err error) {
// use the handler if one is present
if r.handler != nil {
return 0, r.handler(w, r)
}
switch r.storage.(type) {
case *pytorch.FloatStorage:
slog.Warn(fmt.Sprintf("unexpected storage found for layer '%s'; skipping", r.t.Name))
return 0, nil
case *pytorch.HalfStorage:
switch r.t.Kind {
case 0:
data := r.storage.(*pytorch.HalfStorage).Data
slog.Debug(fmt.Sprintf("%35s F32 (%d)", r.t.Name, len(data)))
if err := binary.Write(w, r.bo, data); err != nil {
return 0, err
}
case 1:
data := r.storage.(*pytorch.HalfStorage).Data
tData := make([]uint16, len(data))
for cnt, v := range data {
tData[cnt] = uint16(float16.Fromfloat32(v))
}
slog.Debug(fmt.Sprintf("%35s F16 (%d)", r.t.Name, len(tData)))
if err := binary.Write(w, r.bo, tData); err != nil {
return 0, err
}
}
}
return 0, nil
}
func (m *TorchFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
switch len(params.Architectures) {
case 0:
return nil, fmt.Errorf("No architecture specified to convert")
case 1:
switch params.Architectures[0] {
case "LlamaForCausalLM":
return &LlamaModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
Format: m,
},
}, nil
default:
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
}
}
return nil, fmt.Errorf("Unknown error")
}

View File

@@ -139,9 +139,6 @@ PARAMETER <parameter> <parametervalue>
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| num_gqa | The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b | int | num_gqa 1 |
| num_gpu | The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable. | int | num_gpu 50 |
| num_thread | Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). | int | num_thread 8 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |

View File

@@ -18,7 +18,7 @@ const ollama = new Ollama({
model: "llama2",
});
const answer = await ollama.call(`why is the sky blue?`);
const answer = await ollama.invoke(`why is the sky blue?`);
console.log(answer);
```

51
examples/go-chat/main.go Normal file
View File

@@ -0,0 +1,51 @@
package main
import (
"context"
"fmt"
"log"
"ollama.com/api"
)
func main() {
client, err := api.ClientFromEnvironment()
if err != nil {
log.Fatal(err)
}
messages := []api.Message{
api.Message{
Role: "system",
Content: "Provide very brief, concise responses",
},
api.Message{
Role: "user",
Content: "Name some unusual animals",
},
api.Message{
Role: "assistant",
Content: "Monotreme, platypus, echidna",
},
api.Message{
Role: "user",
Content: "which of these is the most dangerous?",
},
}
ctx := context.Background()
req := &api.ChatRequest{
Model: "llama2",
Messages: messages,
}
respFunc := func(resp api.ChatResponse) error {
fmt.Print(resp.Message.Content)
return nil
}
err = client.Chat(ctx, req, respFunc)
if err != nil {
log.Fatal(err)
}
}

View File

@@ -0,0 +1,40 @@
package main
import (
"context"
"fmt"
"log"
"ollama.com/api"
)
func main() {
client, err := api.ClientFromEnvironment()
if err != nil {
log.Fatal(err)
}
// By default, GenerateRequest is streaming.
req := &api.GenerateRequest{
Model: "gemma",
Prompt: "how many planets are there?",
}
ctx := context.Background()
respFunc := func(resp api.GenerateResponse) error {
// Only print the response here; GenerateResponse has a number of other
// interesting fields you want to examine.
// In streaming mode, responses are partial so we call fmt.Print (and not
// Println) in order to avoid spurious newlines being introduced. The
// model will insert its own newlines if it wants.
fmt.Print(resp.Response)
return nil
}
err = client.Generate(ctx, req, respFunc)
if err != nil {
log.Fatal(err)
}
fmt.Println()
}

View File

@@ -0,0 +1,37 @@
package main
import (
"context"
"fmt"
"log"
"ollama.com/api"
)
func main() {
client, err := api.ClientFromEnvironment()
if err != nil {
log.Fatal(err)
}
req := &api.GenerateRequest{
Model: "gemma",
Prompt: "how many planets are there?",
// set streaming to false
Stream: new(bool),
}
ctx := context.Background()
respFunc := func(resp api.GenerateResponse) error {
// Only print the response here; GenerateResponse has a number of other
// interesting fields you want to examine.
fmt.Println(resp.Response)
return nil
}
err = client.Generate(ctx, req, respFunc)
if err != nil {
log.Fatal(err)
}
}

View File

@@ -0,0 +1,47 @@
package main
import (
"context"
"fmt"
"log"
"os"
"ollama.com/api"
)
func main() {
if len(os.Args) <= 1 {
log.Fatal("usage: <image name>")
}
imgData, err := os.ReadFile(os.Args[1])
if err != nil {
log.Fatal(err)
}
client, err := api.ClientFromEnvironment()
if err != nil {
log.Fatal(err)
}
req := &api.GenerateRequest{
Model: "llava",
Prompt: "describe this image",
Images: []api.ImageData{imgData},
}
ctx := context.Background()
respFunc := func(resp api.GenerateResponse) error {
// In streaming mode, responses are partial so we call fmt.Print (and not
// Println) in order to avoid spurious newlines being introduced. The
// model will insert its own newlines if it wants.
fmt.Print(resp.Response)
return nil
}
err = client.Generate(ctx, req, respFunc)
if err != nil {
log.Fatal(err)
}
fmt.Println()
}

View File

@@ -0,0 +1,31 @@
package main
import (
"context"
"fmt"
"log"
"ollama.com/api"
)
func main() {
client, err := api.ClientFromEnvironment()
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
req := &api.PullRequest{
Model: "mistral",
}
progressFunc := func(resp api.ProgressResponse) error {
fmt.Printf("Progress: status=%v, total=%v, completed=%v\n", resp.Status, resp.Total, resp.Completed)
return nil
}
err = client.Pull(ctx, req, progressFunc)
if err != nil {
log.Fatal(err)
}
}

View File

@@ -50,7 +50,7 @@ func HumanBytes(b int64) string {
}
}
func HumanBytes2(b int64) string {
func HumanBytes2(b uint64) string {
switch {
case b >= MebiByte:
return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)

9
go.mod
View File

@@ -1,4 +1,4 @@
module github.com/ollama/ollama
module ollama.com
go 1.22
@@ -19,7 +19,10 @@ require (
golang.org/x/sync v0.3.0
)
require github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
require (
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
)
require (
github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc // indirect
@@ -68,7 +71,7 @@ require (
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0
golang.org/x/term v0.13.0
golang.org/x/text v0.13.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0
gopkg.in/yaml.v3 v3.0.1 // indirect
)

6
go.sum
View File

@@ -122,6 +122,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9 h1:DV4iXjNn6fGeDl1AkZ1I0QB/0DBjrc7kPpxHrmuDzW4=
@@ -236,8 +238,8 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -21,7 +21,7 @@ import (
"sync"
"unsafe"
"github.com/ollama/ollama/format"
"ollama.com/format"
)
type handles struct {
@@ -243,7 +243,7 @@ func getCPUMem() (memInfo, error) {
return ret, nil
}
func CheckVRAM() (int64, error) {
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
@@ -251,11 +251,11 @@ func CheckVRAM() (int64, error) {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return avail, nil
return uint64(avail), nil
}
gpuInfo := GetGPUInfo()
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
return int64(gpuInfo.FreeMemory), nil
return gpuInfo.FreeMemory, nil
}
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation

View File

@@ -17,7 +17,7 @@ import (
)
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
@@ -25,15 +25,15 @@ func CheckVRAM() (int64, error) {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return avail, nil
return uint64(avail), nil
}
if runtime.GOARCH == "amd64" {
// gpu not supported, this may not be metal
return 0, nil
}
recommendedMaxVRAM := int64(C.getRecommendedMaxVRAM())
return recommendedMaxVRAM, nil
return uint64(C.getRecommendedMaxVRAM()), nil
}
func GetGPUInfo() GpuInfo {
@@ -53,8 +53,8 @@ func GetGPUInfo() GpuInfo {
func getCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: 0,
TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: 0,
DeviceCount: 0,
DeviceCount: 1,
}, nil
}

View File

@@ -1,3 +1,4 @@
#import <Metal/Metal.h>
#include <stdint.h>
uint64_t getRecommendedMaxVRAM();
uint64_t getPhysicalMemory();

View File

@@ -1,11 +1,13 @@
//go:build darwin
// go:build darwin
#include "gpu_info_darwin.h"
uint64_t getRecommendedMaxVRAM()
{
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
uint64_t result = device.recommendedMaxWorkingSetSize;
CFRelease(device);
return result;
uint64_t getRecommendedMaxVRAM() {
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
uint64_t result = device.recommendedMaxWorkingSetSize;
CFRelease(device);
return result;
}
uint64_t getPhysicalMemory() {
return [[NSProcessInfo processInfo] physicalMemory];
}

View File

@@ -15,7 +15,7 @@ type GpuInfo struct {
Variant string `json:"variant,omitempty"`
// MinimumMemory represents the minimum memory required to use the GPU
MinimumMemory int64 `json:"-"`
MinimumMemory uint64 `json:"-"`
// TODO add other useful attributes about the card here for discovery information
}

View File

@@ -8,7 +8,7 @@ import (
"testing"
"time"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
func TestOrcaMiniBlueSky(t *testing.T) {

View File

@@ -8,7 +8,7 @@ import (
"testing"
"time"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
func TestContextExhaustion(t *testing.T) {

View File

@@ -9,8 +9,8 @@ import (
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
"ollama.com/api"
)
func TestIntegrationMultimodal(t *testing.T) {

View File

@@ -9,7 +9,7 @@ import (
"testing"
"time"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server

View File

@@ -21,9 +21,9 @@ import (
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/stretchr/testify/assert"
"ollama.com/api"
"ollama.com/app/lifecycle"
)
func FindPort() string {

View File

@@ -39,6 +39,10 @@
#include "httplib.h"
#include "json.hpp"
#if defined(_WIN32)
#include <windows.h>
#endif
#include <cstddef>
#include <thread>
#include <chrono>
@@ -2770,8 +2774,28 @@ inline void signal_handler(int signal) {
shutdown_handler(signal);
}
int main(int argc, char **argv)
{
#if defined(_WIN32)
char* wchar_to_char(const wchar_t* wstr) {
if (wstr == nullptr) return nullptr;
// Determine the number of bytes needed for the UTF-8 string
int bytes = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, nullptr, 0, nullptr, nullptr);
char* str = new char[bytes];
// Convert the wide-character string to a UTF-8 string
WideCharToMultiByte(CP_UTF8, 0, wstr, -1, str, bytes, nullptr, nullptr);
return str;
}
int wmain(int argc, wchar_t **wargv) {
char** argv = new char*[argc];
for (int i = 0; i < argc; ++i) {
argv[i] = wchar_to_char(wargv[i]);
}
#else
int main(int argc, char **argv) {
#endif
#if SERVER_VERBOSE != 1
log_disable();
#endif
@@ -3282,6 +3306,11 @@ int main(int argc, char **argv)
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
for (int i = 0; i < argc; ++i) {
delete[] argv[i];
}
delete[] argv;
#endif
llama.queue_tasks.start_loop();
svr.stop();

View File

@@ -49,7 +49,7 @@ func (llm *ggla) KV() KV {
return llm.kv
}
func (llm *ggla) Tensors() []*Tensor {
func (llm *ggla) Tensors() Tensors {
return llm.tensors
}

View File

@@ -13,16 +13,6 @@ type GGML struct {
model
}
func (ggml *GGML) LayerSize(prefix string) (n int64) {
for _, t := range ggml.Tensors() {
if strings.HasPrefix(t.Name, prefix) {
n += int64(t.size())
}
}
return
}
const (
fileTypeF32 uint32 = iota
fileTypeF16
@@ -101,7 +91,7 @@ func fileType(fileType uint32) string {
type model interface {
KV() KV
Tensors() []*Tensor
Tensors() Tensors
}
type KV map[string]any
@@ -167,6 +157,37 @@ func (kv KV) ContextLength() uint64 {
return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
}
type Tensors []*Tensor
func (ts Tensors) Layers() map[string]Layer {
layers := make(map[string]Layer)
for _, t := range ts {
parts := strings.Split(t.Name, ".")
if parts[0] == "blk" {
// join first and second part, e.g. blk.%d
parts = append([]string{fmt.Sprintf("%s.%s", parts[0], parts[1])}, parts[2:]...)
}
if _, ok := layers[parts[0]]; !ok {
layers[parts[0]] = make(Layer)
}
layers[parts[0]][strings.Join(parts[1:], ".")] = t
}
return layers
}
type Layer map[string]*Tensor
func (l Layer) size() (size uint64) {
for _, t := range l {
size += t.size()
}
return size
}
type Tensor struct {
Name string `json:"name"`
Kind uint32 `json:"kind"`
@@ -304,49 +325,69 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
}, offset, nil
}
func (llm GGML) GraphSize(context, batch int) (int64, bool) {
embeddingLength := llm.KV().EmbeddingLength()
headCount := llm.KV().HeadCount()
headCountKV := llm.KV().HeadCountKV()
vocabLength := len(llm.KV()["tokenizer.ggml.tokens"].([]any))
func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
embedding := llm.KV().EmbeddingLength()
heads := llm.KV().HeadCount()
headsKV := llm.KV().HeadCountKV()
vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
var attnQKVWeight1 uint64 = 0
for _, t := range llm.Tensors() {
if strings.HasSuffix(t.Name, ".attn_qkv.weight") && len(t.Shape) >= 2 {
attnQKVWeight1 = t.Shape[1]
break
}
}
var ffnGate1 uint64 = 0
for _, t := range llm.Tensors() {
if strings.Index(t.Name, ".ffn_gate") > 0 && len(t.Shape) >= 2 {
ffnGate1 = t.Shape[1]
break
}
}
layers := llm.Tensors().Layers()
switch llm.KV().Architecture() {
case "gemma", "command-r":
return 4 * int64(batch) * int64(embeddingLength+uint64(vocabLength)), true
case "phi2":
return max(
4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
4*int64(batch)*int64(1+4*embeddingLength+uint64(context)+attnQKVWeight1+uint64(context)*headCount),
), true
case "qwen2":
return max(
4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
4*int64(batch)*int64(1+2*embeddingLength+uint64(context)+uint64(context)*headCount),
), true
case "llama":
if ffnGate1 > 0 {
// moe
return 4 * int64(batch) * int64(2+3*embeddingLength+uint64(context)+uint64(context)*headCount+2*headCountKV+ffnGate1), true
fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
partialOffload = 4 * batch * embedding
partialOffload += max(
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
ffnGateWeight1 := ffnGateWeight.Shape[1]
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
partialOffload = max(
4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
)
}
return 4 * int64(batch) * int64(1+4*embeddingLength+uint64(context)+uint64(context)*headCount), true
case "gemma":
fullOffload = 4 * batch * (embedding + vocab)
partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
case "command-r":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(2+4*embedding+context*(1+heads)),
)
partialOffload = max(
4*batch*(embedding+vocab)+embedding*vocab*105/128,
4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
)
case "qwen2":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(1+2*embedding+context+context*heads),
)
partialOffload = max(
4*batch*(embedding+vocab)+embedding*vocab*105/128,
4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
)
case "phi2":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(1+4*embedding+context+context*heads),
)
partialOffload = 4*batch*(2*embedding+vocab) + embedding*vocab*105/128
case "stablelm":
fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
partialOffload = max(
4*batch*(vocab+2*embedding),
fullOffload,
)
}
return 0, false
return
}

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"io"
"strings"
"log/slog"
)
type containerGGUF struct {
@@ -52,6 +54,7 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
}
model := newGGUF(c)
slog.Debug(fmt.Sprintf("model = %#v", model))
if err := model.Decode(rs); err != nil {
return nil, err
}
@@ -109,7 +112,7 @@ func (llm *gguf) KV() KV {
return llm.kv
}
func (llm *gguf) Tensors() []*Tensor {
func (llm *gguf) Tensors() Tensors {
return llm.tensors
}
@@ -187,6 +190,8 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
llm.kv[k] = v
}
slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"]))
// decode tensors
for i := 0; uint64(i) < llm.numTensor(); i++ {
name, err := readGGUFString(llm, rs)
@@ -248,8 +253,12 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
}
for _, tensor := range llm.tensors {
padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
if _, err := rs.Seek(padded, io.SeekCurrent); err != nil {
if _, err := rs.Seek(int64(tensor.size()), io.SeekCurrent); err != nil {
return err
}
padding := llm.padding(int64(tensor.size()), int64(alignment))
if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
return err
}
}
@@ -451,6 +460,7 @@ var ggufKVOrder = map[string][]string{
"llama": {
"general.architecture",
"general.name",
"llama.vocab_size",
"llama.context_length",
"llama.embedding_length",
"llama.block_count",
@@ -509,11 +519,17 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
kvCheck := make(map[string]bool)
for k := range kv {
kvCheck[k] = false
}
for _, k := range ggufKVOrder["llama"] {
v, ok := kv[k]
if !ok {
continue
}
kvCheck[k] = true
if err := binary.Write(ws, llm.ByteOrder, uint64(len(k))); err != nil {
return err
@@ -567,6 +583,12 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
}
}
for k, v := range kvCheck {
if !v {
return fmt.Errorf("Didn't know how to write kv %s", k)
}
}
for _, tensor := range tensors {
if err := binary.Write(ws, llm.ByteOrder, uint64(len(tensor.Name))); err != nil {
return err
@@ -605,8 +627,9 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
padding := llm.padding(offset, 32)
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding-offset))); err != nil {
var alignment int64 = 32
padding := llm.padding(offset, alignment)
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
return err
}
@@ -620,8 +643,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
padding := llm.padding(offset, 32)
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding-offset))); err != nil {
padding := llm.padding(offset, alignment)
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
return err
}
}
@@ -630,5 +653,5 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
}
func (gguf) padding(offset, align int64) int64 {
return (offset + align - 1) / align * align
return (align - offset%align) % align
}

View File

@@ -6,10 +6,81 @@ package llm
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
// #include <stdlib.h>
// #include "llama.h"
import "C"
import (
"fmt"
"unsafe"
)
// SystemInfo is an unused example of calling llama.cpp functions using CGo
func SystemInfo() string {
return C.GoString(C.llama_print_system_info())
}
func Quantize(infile, outfile, filetype string) error {
cinfile := C.CString(infile)
defer C.free(unsafe.Pointer(cinfile))
coutfile := C.CString(outfile)
defer C.free(unsafe.Pointer(coutfile))
params := C.llama_model_quantize_default_params()
params.nthread = -1
switch filetype {
case "F32":
params.ftype = fileTypeF32
case "F16":
params.ftype = fileTypeF16
case "Q4_0":
params.ftype = fileTypeQ4_0
case "Q4_1":
params.ftype = fileTypeQ4_1
case "Q4_1_F16":
params.ftype = fileTypeQ4_1_F16
case "Q8_0":
params.ftype = fileTypeQ8_0
case "Q5_0":
params.ftype = fileTypeQ5_0
case "Q5_1":
params.ftype = fileTypeQ5_1
case "Q2_K":
params.ftype = fileTypeQ2_K
case "Q3_K_S":
params.ftype = fileTypeQ3_K_S
case "Q3_K_M":
params.ftype = fileTypeQ3_K_M
case "Q3_K_L":
params.ftype = fileTypeQ3_K_L
case "Q4_K_S":
params.ftype = fileTypeQ4_K_S
case "Q4_K_M":
params.ftype = fileTypeQ4_K_M
case "Q5_K_S":
params.ftype = fileTypeQ5_K_S
case "Q5_K_M":
params.ftype = fileTypeQ5_K_M
case "Q6_K":
params.ftype = fileTypeQ6_K
case "IQ2_XXS":
params.ftype = fileTypeIQ2_XXS
case "IQ2_XS":
params.ftype = fileTypeIQ2_XS
case "Q2_K_S":
params.ftype = fileTypeQ2_K_S
case "Q3_K_XS":
params.ftype = fileTypeQ3_K_XS
case "IQ3_XXS":
params.ftype = fileTypeIQ3_XXS
default:
return fmt.Errorf("unknown filetype: %s", filetype)
}
if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
return fmt.Errorf("llama_model_quantize: %d", retval)
}
return nil
}

View File

@@ -14,7 +14,7 @@ import (
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/gpu"
"ollama.com/gpu"
)
var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama")

View File

@@ -17,14 +17,13 @@ import (
"os/exec"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"ollama.com/api"
"ollama.com/format"
"ollama.com/gpu"
)
// LlamaServer is an instance of the llama.cpp server
@@ -36,15 +35,7 @@ type LlamaServer struct {
options api.Options
}
var cpuOnlyFamilies = []string{
"mamba",
}
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
f, err := os.Open(model)
if err != nil {
return nil, err
@@ -65,67 +56,123 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
opts.NumCtx = 4
}
availableMemory, _ := gpu.CheckVRAM()
memoryAvailable, _ := gpu.CheckVRAM()
info := gpu.GetGPUInfo()
usedMemory := info.MinimumMemory
memoryMinimum := info.MinimumMemory
for _, projector := range projectors {
usedMemory += projectorMemoryRequirements(projector)
memoryMinimum += projectorMemoryRequirements(projector)
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.KV().BlockCount()) * int64(ggml.KV().EmbeddingLength()) / int64(ggml.KV().HeadCount()) * int64(ggml.KV().HeadCountKV())
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
graph, ok := ggml.GraphSize(opts.NumCtx, min(opts.NumCtx, opts.NumBatch))
if !ok {
graph = int64(ggml.KV().GQA()) * kv / 6
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
if graphPartialOffload == 0 {
graphPartialOffload = ggml.KV().GQA() * kv / 6
}
usedMemory += graph
if (usedMemory > availableMemory || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture())) && info.Library != "metal" {
info.Library = "cpu"
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
requiredMemory := usedMemory
graphFullOffload *= uint64(info.DeviceCount)
graphPartialOffload *= uint64(info.DeviceCount)
var layers int
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
layerMemory := ggml.LayerSize(fmt.Sprintf("blk.%d.", i)) + kv/int64(ggml.KV().BlockCount())
requiredMemory += layerMemory
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload
if availableMemory > usedMemory+layerMemory && (opts.NumGPU < 0 || layers < opts.NumGPU) {
usedMemory += layerMemory
layers++
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload
if info.Library != "metal" {
if memoryRequiredPartial > memoryAvailable {
info.Library = "cpu"
}
}
memOutputLayer := ggml.LayerSize("output.")
requiredMemory += memOutputLayer
var layerCount int
layers := ggml.Tensors().Layers()
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
// only offload output layer if all repeating layers are offloaded
if layers >= int(ggml.KV().BlockCount()) && availableMemory > usedMemory+memOutputLayer {
usedMemory += memOutputLayer
layers++
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
var memoryLayerOutput uint64
for k, v := range layers {
if !strings.HasPrefix(k, "blk.") {
memoryLayerOutput += v.size()
}
}
memoryRequiredTotal += memoryLayerOutput
if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
// disable partial offloading when model is greater than total system memory
opts.NumGPU = 0
} else if memoryAvailable > memoryRequiredTotal {
layerCount = int(ggml.KV().BlockCount()) + 1
memoryRequiredPartial = memoryRequiredTotal
}
if opts.NumGPU < 0 {
opts.NumGPU = layerCount
}
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
slog.Info(
"offload to gpu",
"layers", layers,
"required", format.HumanBytes2(requiredMemory),
"used", format.HumanBytes2(usedMemory),
"available", format.HumanBytes2(availableMemory),
"kv", format.HumanBytes2(kv),
"graph", format.HumanBytes2(graph),
slog.Group(
"layers",
// actual number of layers offloaded
"real", opts.NumGPU,
// estimated number of layers that can be offloaded
"estimate", layerCount,
),
slog.Group(
"memory",
// memory available for offloading
"available", format.HumanBytes2(memoryAvailable),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(memoryRequiredTotal),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(memoryRequiredPartial),
// memory of KV cache
"kv", format.HumanBytes2(kv),
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(graphPartialOffload),
),
),
)
if opts.NumGPU < 0 && info.Library != "cpu" {
opts.NumGPU = layers
}
if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
}
@@ -269,12 +316,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
_ = s.cmd.Wait()
}()
if err = s.waitUntilRunning(); err != nil {
slog.Error("error starting llama server", "server", servers[i], "error", err)
s.Close()
finalErr = err
continue
}
return s, nil
}
@@ -282,7 +323,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
return nil, finalErr
}
func projectorMemoryRequirements(filename string) int64 {
func projectorMemoryRequirements(filename string) uint64 {
file, err := os.Open(filename)
if err != nil {
return 0
@@ -294,18 +335,12 @@ func projectorMemoryRequirements(filename string) int64 {
return 0
}
prefixes := make(map[string]struct{})
for _, layer := range ggml.Tensors() {
parts := strings.Split(layer.Name, ".")
prefixes[strings.Join(parts[:2], ".")] = struct{}{}
var mem uint64
for _, layer := range ggml.Tensors().Layers() {
mem += layer.size()
}
var ask int64
for prefix := range prefixes {
ask += ggml.LayerSize(prefix)
}
return ask
return mem
}
type ServerStatus int
@@ -381,9 +416,10 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
return nil
}
func (s *LlamaServer) waitUntilRunning() error {
func (s *LlamaServer) WaitUntilRunning() error {
start := time.Now()
expiresAt := time.Now().Add(3 * time.Minute) // be generous with timeout, large models can take a while to load
// TODO we need to wire up a better way to detect hangs during model load and startup of the server
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()

View File

@@ -14,7 +14,7 @@ go build .
Then run the desktop app with `npm start`:
```
cd app
cd macapp
npm install
npm start
```

View File

@@ -3,8 +3,8 @@ package main
import (
"context"
"github.com/ollama/ollama/cmd"
"github.com/spf13/cobra"
"ollama.com/cmd"
)
func main() {

View File

@@ -11,7 +11,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
type Error struct {

View File

@@ -6,8 +6,8 @@ import (
"strings"
"time"
"github.com/ollama/ollama/format"
"golang.org/x/term"
"ollama.com/format"
)
type Bar struct {

View File

@@ -15,8 +15,8 @@ import (
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"ollama.com/api"
"ollama.com/auth"
)
type registryChallenge struct {

View File

@@ -21,8 +21,8 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"ollama.com/api"
"ollama.com/format"
)
const maxRetries = 6
@@ -247,7 +247,8 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
}
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
// reset last updated
part.lastUpdated = time.Time{}
return errPartStalled

View File

@@ -24,12 +24,12 @@ import (
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/version"
"ollama.com/api"
"ollama.com/convert"
"ollama.com/format"
"ollama.com/llm"
"ollama.com/parser"
"ollama.com/version"
)
type registryOptions struct {
@@ -284,7 +284,7 @@ func realpath(mfDir, from string) string {
return abspath
}
func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
deleteMap := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range append(manifest.Layers, manifest.Config) {
@@ -322,7 +322,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
pathName := realpath(modelFileDir, c.Args)
ggufName, err := convertSafetensors(name, pathName, fn)
ggufName, err := convertModel(name, pathName, fn)
if err != nil {
var pathErr *fs.PathError
switch {
@@ -337,8 +337,27 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
if ggufName != "" {
pathName = ggufName
slog.Debug(fmt.Sprintf("new image layer path: %s", pathName))
defer os.RemoveAll(ggufName)
if quantization != "" {
quantization = strings.ToUpper(quantization)
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
if err != nil {
return err
}
defer os.RemoveAll(tempfile.Name())
if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
return err
}
if err := tempfile.Close(); err != nil {
return err
}
pathName = tempfile.Name()
}
}
bin, err := os.Open(pathName)
@@ -614,7 +633,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
return nil
}
func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
r, err := zip.OpenReader(path)
if err != nil {
return "", err
@@ -649,17 +668,22 @@ func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (
rc.Close()
}
params, err := convert.GetParams(tempDir)
mf, err := convert.GetModelFormat(tempDir)
if err != nil {
return "", err
}
mArch, err := convert.GetModelArchFromParams(name, tempDir, params)
params, err := mf.GetParams(tempDir)
if err != nil {
return "", err
}
fn(api.ProgressResponse{Status: "processing safetensors"})
mArch, err := mf.GetModelArch(name, tempDir, params)
if err != nil {
return "", err
}
fn(api.ProgressResponse{Status: "processing tensors"})
if err := mArch.GetTensors(); err != nil {
return "", err
}

View File

@@ -7,7 +7,7 @@ import (
"text/template"
"text/template/parse"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
// isResponseNode checks if the node contains .Response

View File

@@ -4,7 +4,7 @@ import (
"strings"
"testing"
"github.com/ollama/ollama/api"
"ollama.com/api"
)
func TestPrompt(t *testing.T) {

View File

@@ -27,12 +27,12 @@ import (
"github.com/gin-gonic/gin"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/version"
"ollama.com/api"
"ollama.com/gpu"
"ollama.com/llm"
"ollama.com/openai"
"ollama.com/parser"
"ollama.com/version"
)
var mode string = gin.DebugMode
@@ -68,6 +68,18 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute
func unload() {
if loaded.llama != nil {
loaded.llama.Close()
}
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
}
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
ctx, cancel := context.WithTimeout(c, 10*time.Second)
@@ -83,12 +95,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
if needLoad {
if loaded.llama != nil {
slog.Info("changing loaded model")
loaded.llama.Close()
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
unload()
}
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
@@ -108,22 +115,19 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
loaded.projectors = model.ProjectorPaths
loaded.llama = llama
loaded.Options = &opts
if err = llama.WaitUntilRunning(); err != nil {
slog.Error("error loading llama server", "error", err)
unload()
return err
}
}
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()
if loaded.llama != nil {
loaded.llama.Close()
}
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
unload()
})
}
@@ -647,7 +651,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -913,6 +917,24 @@ func HeadBlobHandler(c *gin.Context) {
}
func CreateBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err = os.Stat(path)
switch {
case errors.Is(err, os.ErrNotExist):
// noop
case err != nil:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
default:
c.Status(http.StatusOK)
return
}
layer, err := NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1128,9 +1150,7 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
if loaded.llama != nil {
loaded.llama.Close()
}
unload()
gpu.Cleanup()
os.Exit(0)
}()

View File

@@ -16,9 +16,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/version"
"ollama.com/api"
"ollama.com/parser"
"ollama.com/version"
)
func Test_Routes(t *testing.T) {
@@ -61,7 +61,7 @@ func Test_Routes(t *testing.T) {
fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status)
}
err = CreateModel(context.TODO(), name, "", commands, fn)
err = CreateModel(context.TODO(), name, "", "", commands, fn)
assert.Nil(t, err)
}

View File

@@ -16,9 +16,9 @@ import (
"sync/atomic"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"golang.org/x/sync/errgroup"
"ollama.com/api"
"ollama.com/format"
)
var blobUploadManager sync.Map

79
types/model/digest.go Normal file
View File

@@ -0,0 +1,79 @@
package model
import (
"log/slog"
"strings"
"unicode"
)
// Digest represents a digest of a model Manifest. It is a comparable value
// type and is immutable.
//
// The zero Digest is not a valid digest.
type Digest struct {
s string
}
// Type returns the digest type of the digest.
//
// Example:
//
// ParseDigest("sha256-1234").Type() // returns "sha256"
func (d Digest) Type() string {
typ, _, _ := strings.Cut(d.s, "-")
return typ
}
// String returns the digest in the form of "<digest-type>-<digest>", or the
// empty string if the digest is invalid.
func (d Digest) String() string { return d.s }
// IsValid returns true if the digest is valid (not zero).
//
// A valid digest may be created only by ParseDigest, or
// ParseName(name).Digest().
func (d Digest) IsValid() bool { return d.s != "" }
// LogValue implements slog.Value.
func (d Digest) LogValue() slog.Value {
return slog.StringValue(d.String())
}
var (
_ slog.LogValuer = Digest{}
)
// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
// Digest.
func ParseDigest(s string) Digest {
typ, digest, ok := strings.Cut(s, "-")
if ok && isValidDigestType(typ) && isValidHex(digest) {
return Digest{s: s}
}
return Digest{}
}
func isValidDigestType(s string) bool {
if len(s) == 0 {
return false
}
for _, r := range s {
if !unicode.IsLower(r) && !unicode.IsDigit(r) {
return false
}
}
return true
}
func isValidHex(s string) bool {
if len(s) == 0 {
return false
}
for i := range s {
c := s[i]
if c < '0' || c > '9' && c < 'a' || c > 'f' {
return false
}
}
return true
}

View File

@@ -0,0 +1,46 @@
package model
import "testing"
var testDigests = map[string]Digest{
"": {},
"sha256-1234": {s: "sha256-1234"},
"sha256-5678": {s: "sha256-5678"},
"blake2-9abc": {s: "blake2-9abc"},
"-1234": {},
"sha256-": {},
"sha256-1234-5678": {},
"sha256-P": {}, // invalid hex
"sha256-1234P": {},
"---": {},
}
func TestDigestParse(t *testing.T) {
// Test cases.
for s, want := range testDigests {
got := ParseDigest(s)
t.Logf("ParseDigest(%q) = %#v", s, got)
if got != want {
t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
}
}
}
func TestDigestString(t *testing.T) {
// Test cases.
for s, d := range testDigests {
want := s
if !d.IsValid() {
want = ""
}
got := d.String()
if got != want {
t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
}
got = ParseDigest(s).String()
if got != want {
t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want)
}
}
}

693
types/model/name.go Normal file
View File

@@ -0,0 +1,693 @@
package model
import (
"cmp"
"errors"
"fmt"
"hash/maphash"
"io"
"log/slog"
"path/filepath"
"slices"
"strings"
"sync"
"ollama.com/types/structs"
)
// Errors
var (
// ErrInvalidName, ErrIncompleteName, and ErrInvalidDigest are not
// used by this package, but are exported so that other packages can
// use them, instead of defining their own errors for them.
ErrInvalidName = errors.New("invalid model name")
ErrIncompleteName = errors.New("incomplete model name")
ErrInvalidDigest = errors.New("invalid digest")
)
// Defaults
const (
// MaskDefault is the default mask used by [Name.DisplayShortest].
MaskDefault = "registry.ollama.ai/library/?:latest"
// MaskNothing is a mask that masks nothing.
MaskNothing = "?/?/?:?"
// DefaultFill is the default fill used by [ParseName].
FillDefault = "registry.ollama.ai/library/?:latest+Q4_0"
// FillNothing is a fill that fills nothing.
FillNothing = "?/?/?:?+?"
)
const MaxNamePartLen = 128
type PartKind int
// Levels of concreteness
const (
// Each value aligns with its index in the Name.parts array.
PartHost PartKind = iota
PartNamespace
PartModel
PartTag
PartBuild
PartDigest
// NumParts is the number of parts in a Name. In this list, it must
// follow the final part.
NumParts
PartExtraneous = -1
)
var kindNames = map[PartKind]string{
PartHost: "Host",
PartNamespace: "Namespace",
PartModel: "Name",
PartTag: "Tag",
PartBuild: "Build",
PartDigest: "Digest",
}
func (k PartKind) String() string {
return cmp.Or(kindNames[k], "Unknown")
}
// Name is an opaque reference to a model. It holds the parts of a model
// with the case preserved, but is not directly comparable with other Names
// since model names can be represented with different casing depending on
// the use case. For instance, "Mistral" and "mistral" are the same model
// but each version may have come from different sources (e.g. copied from a
// Web page, or from a file path).
//
// Valid Names can ONLY be constructed by calling [ParseName].
//
// A Name is valid if and only if is have a valid Model part. The other parts
// are optional.
//
// A Name is considered "complete" if it has all parts present. To check if a
// Name is complete, use [Name.IsComplete].
//
// To compare two names in a case-insensitive manner, use [Name.EqualFold].
//
// The parts of a Name are:
//
// - Host: the domain of the model (optional)
// - Namespace: the namespace of the model (optional)
// - Model: the name of the model (required)
// - Tag: the tag of the model (optional)
// - Build: the build of the model; usually the quantization or "file type" (optional)
//
// The parts can be obtained in their original form by calling [Name.Parts].
//
// To check if a Name has at minimum a valid model part, use [Name.IsValid].
type Name struct {
_ structs.Incomparable
parts [NumParts]string // host, namespace, model, tag, build, digest
// TODO(bmizerany): track offsets and hold s (raw string) here? We
// could pack the offsets all into a single uint64 since the first
// parts take less bits since their max offset is less than the max
// offset of the next part. This would save a ton of bytes per Name
// and mean zero allocations for String.
}
// ParseName parses s into a Name, and returns the result of filling it with
// defaults. The input string must be a valid string
// representation of a model name in the form:
//
// [host/][namespace/]<model>[:tag][+build][@<digest-type>-<digest>]
//
// The name part is required, all others are optional. If a part is missing,
// it is left empty in the returned Name. If a part is invalid, the zero Ref
// value is returned.
//
// The build part is normalized to uppercase.
//
// Examples of valid paths:
//
// "example.com/library/mistral:7b+x"
// "example.com/eva/mistral:7b+Q4_0"
// "mistral:7b+x"
// "example.com/mike/mistral:latest+Q4_0"
// "example.com/bruce/mistral:latest"
// "example.com/pdevine/thisisfine:7b+Q4_0@sha256-1234567890abcdef"
//
// Examples of invalid paths:
//
// "example.com/mistral:7b+"
// "example.com/mistral:7b+Q4_0+"
// "x/y/z/z:8n+I"
// ""
//
// It returns the zero value if any part is invalid.
//
// # Fills
//
// For any valid s, the fill string is used to fill in missing parts of the
// Name. The fill string must be a valid Name with the exception that any part
// may be the string ("?"), which will not be considered for filling.
func ParseName(s, fill string) Name {
var r Name
parts(s)(func(kind PartKind, part string) bool {
if kind == PartDigest && !ParseDigest(part).IsValid() {
r = Name{}
return false
}
if kind == PartExtraneous || !isValidPart(kind, part) {
r = Name{}
return false
}
r.parts[kind] = part
return true
})
if r.IsValid() || r.IsResolved() {
return fillName(r, fill)
}
return Name{}
}
func parseMask(s string) Name {
var r Name
parts(s)(func(kind PartKind, part string) bool {
if part == "?" {
// mask part; treat as empty but valid
return true
}
if !isValidPart(kind, part) {
panic(fmt.Errorf("invalid mask part %s: %q", kind, part))
}
r.parts[kind] = part
return true
})
return r
}
func MustParseName(s, fill string) Name {
r := ParseName(s, fill)
if !r.IsValid() {
panic("invalid Name: " + s)
}
return r
}
// fillName fills in the missing parts of dst with the parts of src.
//
// The returned Name will only be valid if dst is valid.
//
// It skipps fill parts that are "?".
func fillName(r Name, fill string) Name {
fill = cmp.Or(fill, FillDefault)
f := parseMask(fill)
if fill != FillNothing && f.IsZero() {
panic("invalid fill")
}
for i := range r.parts {
if f.parts[i] == "?" {
continue
}
r.parts[i] = cmp.Or(r.parts[i], f.parts[i])
}
return r
}
// WithBuild returns a copy of r with the build set to the given string.
func (r Name) WithBuild(build string) Name {
r.parts[PartBuild] = build
return r
}
func (r Name) WithDigest(digest Digest) Name {
r.parts[PartDigest] = digest.String()
return r
}
var mapHashSeed = maphash.MakeSeed()
// MapHash returns a case insensitive hash for use in maps and equality
// checks. For a convenient way to compare names, use [Name.EqualFold].
//
//nolint:errcheck
func (r Name) MapHash() uint64 {
// correctly hash the parts with case insensitive comparison
var h maphash.Hash
h.SetSeed(mapHashSeed)
for _, part := range r.parts {
// downcase the part for hashing
for i := range part {
c := part[i]
if c >= 'A' && c <= 'Z' {
c = c - 'A' + 'a'
}
h.WriteByte(c)
}
}
return h.Sum64()
}
func (r Name) slice(from, to PartKind) Name {
var v Name
copy(v.parts[from:to+1], r.parts[from:to+1])
return v
}
// DisplayShortest returns the shortest possible, masked display string in form:
//
// [host/][<namespace>/]<model>[:<tag>]
//
// # Masks
//
// The mask is a string that specifies which parts of the name to omit based
// on case-insensitive comparison. [Name.DisplayShortest] omits parts of the name
// that are the same as the mask, moving from left to right until the first
// unequal part is found. It then moves right to left until the first unequal
// part is found. The result is the shortest possible display string.
//
// Unlike a [Name] the mask can contain "?" characters which are treated as
// wildcards. A "?" will never match a part of the name, since a valid name
// can never contain a "?" character.
//
// For example: Given a Name ("registry.ollama.ai/library/mistral:latest") masked
// with ("registry.ollama.ai/library/?:latest") will produce the display string
// ("mistral").
//
// If mask is the empty string, then [MaskDefault] is used.
//
// DisplayShortest panics if the mask is not the empty string, MaskNothing, and
// invalid.
//
// # Builds
//
// For now, DisplayShortest does consider the build or return one in the
// result. We can lift this restriction when needed.
func (r Name) DisplayShortest(mask string) string {
mask = cmp.Or(mask, MaskDefault)
d := parseMask(mask)
if mask != MaskNothing && r.IsZero() {
panic("invalid Name")
}
for i := range PartTag {
if !strings.EqualFold(r.parts[i], d.parts[i]) {
break
}
r.parts[i] = ""
}
for i := PartTag; i >= 0; i-- {
if !strings.EqualFold(r.parts[i], d.parts[i]) {
break
}
r.parts[i] = ""
}
return r.slice(PartHost, PartTag).DisplayLong()
}
// DisplayLongest returns the result of r.DisplayShortest(MaskNothing).
func (r Name) DisplayLongest() string {
return r.DisplayShortest(MaskNothing)
}
var seps = [...]string{
PartHost: "/",
PartNamespace: "/",
PartModel: ":",
PartTag: "+",
PartBuild: "@",
PartDigest: "",
}
// WriteTo implements io.WriterTo. It writes the fullest possible display
// string in form:
//
// <host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
//
// Missing parts and their separators are not written.
//
// The full digest is always prefixed with "@". That is if [Name.IsValid]
// reports false and [Name.IsResolved] reports true, then the string is
// returned as "@<digest-type>-<digest>".
func (r Name) writeTo(w io.StringWriter) error {
var partsWritten int
for i := range r.parts {
if r.parts[i] == "" {
continue
}
if partsWritten > 0 || i == int(PartDigest) {
if _, err := w.WriteString(seps[i-1]); err != nil {
return err
}
}
if _, err := w.WriteString(r.parts[i]); err != nil {
return err
}
partsWritten++
}
return nil
}
var builderPool = sync.Pool{
New: func() interface{} {
return &strings.Builder{}
},
}
// DisplayLong returns the fullest possible display string in form:
//
// <host>/<namespace>/<model>:<tag>+<build>
//
// If any part is missing, it is omitted from the display string.
func (r Name) DisplayLong() string {
b := builderPool.Get().(*strings.Builder)
defer builderPool.Put(b)
b.Reset()
b.Grow(50) // arbitrarily long enough for most names
_ = r.writeTo(b)
return b.String()
}
// GoString implements fmt.GoStringer. It returns a string suitable for
// debugging and logging. It is similar to [Name.DisplayLong] but it always
// returns a string that includes all parts of the Name, with missing parts
// replaced with a ("?").
func (r Name) GoString() string {
for i := range r.parts {
r.parts[i] = cmp.Or(r.parts[i], "?")
}
return r.DisplayLong()
}
// LogValue implements slog.Valuer.
func (r Name) LogValue() slog.Value {
return slog.StringValue(r.GoString())
}
// IsComplete reports whether the Name is fully qualified. That is it has a
// domain, namespace, name, tag, and build.
func (r Name) IsComplete() bool {
return !slices.Contains(r.parts[:PartDigest], "")
}
// IsCompleteNoBuild is like [Name.IsComplete] but it does not require the
// build part to be present.
func (r Name) IsCompleteNoBuild() bool {
return !slices.Contains(r.parts[:PartBuild], "")
}
// IsResolved reports true if the Name has a valid digest.
//
// It is possible to have a valid Name, or a complete Name that is not
// resolved.
func (r Name) IsResolved() bool {
return r.Digest().IsValid()
}
// Digest returns the digest part of the Name, if any.
//
// If Digest returns a non-empty string, then [Name.IsResolved] will return
// true, and digest is considered valid.
func (r Name) Digest() Digest {
// This was already validated by ParseName, so we can just return it.
return Digest{r.parts[PartDigest]}
}
// EqualFold reports whether r and o are equivalent model names, ignoring
// case.
func (r Name) EqualFold(o Name) bool {
return r.CompareFold(o) == 0
}
// CompareFold performs a case-insensitive cmp.Compare on r and o.
//
// This can be used with [slices.SortFunc].
//
// For simple equality checks, use [Name.EqualFold].
func (r Name) CompareFold(o Name) int {
return slices.CompareFunc(r.parts[:], o.parts[:], compareFold)
}
func compareFold(a, b string) int {
return slices.CompareFunc([]rune(a), []rune(b), func(a, b rune) int {
return cmp.Compare(downcase(a), downcase(b))
})
}
func downcase(r rune) rune {
if r >= 'A' && r <= 'Z' {
return r - 'A' + 'a'
}
return r
}
func (r Name) Host() string { return r.parts[PartHost] }
func (r Name) Namespace() string { return r.parts[PartNamespace] }
func (r Name) Model() string { return r.parts[PartModel] }
func (r Name) Build() string { return r.parts[PartBuild] }
func (r Name) Tag() string { return r.parts[PartTag] }
// iter_Seq2 is a iter.Seq2 defined here to avoid the current build
// restrictions in the go1.22 iter package requiring the
// goexperiment.rangefunc tag to be set via the GOEXPERIMENT=rangefunc flag,
// which we are not yet ready to support.
//
// Once we are ready to support rangefunc, this can be removed and replaced
// with the iter.Seq2 type.
type iter_Seq2[A, B any] func(func(A, B) bool)
// Parts returns a sequence of the parts of a Name string from most specific
// to least specific.
//
// It normalizes the input string by removing "http://" and "https://" only.
// No other normalizations are performed.
func parts(s string) iter_Seq2[PartKind, string] {
return func(yield func(PartKind, string) bool) {
if strings.HasPrefix(s, "http://") {
s = strings.TrimPrefix(s, "http://")
} else {
s = strings.TrimPrefix(s, "https://")
}
if len(s) > MaxNamePartLen || len(s) == 0 {
return
}
numConsecutiveDots := 0
partLen := 0
state, j := PartDigest, len(s)
for i := len(s) - 1; i >= 0; i-- {
if partLen++; partLen > MaxNamePartLen {
// catch a part that is too long early, so
// we don't keep spinning on it, waiting for
// an isInValidPart check which would scan
// over it again.
yield(state, s[i+1:j])
return
}
switch s[i] {
case '@':
switch state {
case PartDigest:
if !yield(PartDigest, s[i+1:j]) {
return
}
if i == 0 {
// This is the form
// "@<digest>" which is valid.
//
// We're done.
return
}
state, j, partLen = PartBuild, i, 0
default:
yield(PartExtraneous, s[i+1:j])
return
}
case '+':
switch state {
case PartBuild, PartDigest:
if !yield(PartBuild, s[i+1:j]) {
return
}
state, j, partLen = PartTag, i, 0
default:
yield(PartExtraneous, s[i+1:j])
return
}
case ':':
switch state {
case PartTag, PartBuild, PartDigest:
if !yield(PartTag, s[i+1:j]) {
return
}
state, j, partLen = PartModel, i, 0
case PartHost:
// noop: support for host:port
default:
yield(PartExtraneous, s[i+1:j])
return
}
case '/':
switch state {
case PartModel, PartTag, PartBuild, PartDigest:
if !yield(PartModel, s[i+1:j]) {
return
}
state, j = PartNamespace, i
case PartNamespace:
if !yield(PartNamespace, s[i+1:j]) {
return
}
state, j, partLen = PartHost, i, 0
default:
yield(PartExtraneous, s[i+1:j])
return
}
default:
if s[i] == '.' {
if numConsecutiveDots++; numConsecutiveDots > 1 {
yield(state, "")
return
}
} else {
numConsecutiveDots = 0
}
}
}
if state <= PartNamespace {
yield(state, s[:j])
} else {
yield(PartModel, s[:j])
}
}
}
func (r Name) IsZero() bool {
return r.parts == [NumParts]string{}
}
// IsValid reports if a model has at minimum a valid model part.
func (r Name) IsValid() bool {
// Parts ensures we only have valid parts, so no need to validate
// them here, only check if we have a name or not.
return r.parts[PartModel] != ""
}
// ParseNameFromURLPath parses forms of a URL path into a Name. Specifically,
// it trims any leading "/" and then calls [ParseName] with fill.
func ParseNameFromURLPath(s, fill string) Name {
s = strings.TrimPrefix(s, "/")
return ParseName(s, fill)
}
// URLPath returns a complete, canonicalized, relative URL path using the parts of a
// complete Name.
//
// The parts maintain their original case.
//
// Example:
//
// ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag"
func (r Name) URLPath() string {
return r.DisplayShortest(MaskNothing)
}
// ParseNameFromFilepath parses a file path into a Name. The input string must be a
// valid file path representation of a model name in the form:
//
// host/namespace/model/tag/build
//
// The zero valid is returned if s does not contain all path elements
// leading up to the model part, or if any path element is an invalid part
// for the its corresponding part kind.
//
// The fill string is used to fill in missing parts of any constructed Name.
// See [ParseName] for more information on the fill string.
func ParseNameFromFilepath(s, fill string) Name {
var r Name
for i := range PartBuild + 1 {
part, rest, _ := strings.Cut(s, string(filepath.Separator))
if !isValidPart(i, part) {
return Name{}
}
r.parts[i] = part
s = rest
if s == "" {
break
}
}
if s != "" {
return Name{}
}
if !r.IsValid() {
return Name{}
}
return fillName(r, fill)
}
// Filepath returns a complete, canonicalized, relative file path using the
// parts of a complete Name.
//
// Each parts is downcased, except for the build part which is upcased.
//
// Example:
//
// ParseName("example.com/namespace/model:tag+build").Filepath() // returns "example.com/namespace/model/tag/BUILD"
func (r Name) Filepath() string {
for i := range r.parts {
if PartKind(i) == PartBuild {
r.parts[i] = strings.ToUpper(r.parts[i])
} else {
r.parts[i] = strings.ToLower(r.parts[i])
}
}
return filepath.Join(r.parts[:]...)
}
// FilepathNoBuild returns a complete, canonicalized, relative file path using
// the parts of a complete Name, but without the build part.
func (r Name) FilepathNoBuild() string {
for i := range PartBuild {
r.parts[i] = strings.ToLower(r.parts[i])
}
return filepath.Join(r.parts[:PartBuild]...)
}
// isValidPart reports if s contains all valid characters for the given
// part kind.
func isValidPart(kind PartKind, s string) bool {
if s == "" {
return false
}
var consecutiveDots int
for _, c := range []byte(s) {
if c == '.' {
if consecutiveDots++; consecutiveDots >= 2 {
return false
}
} else {
consecutiveDots = 0
}
if !isValidByteFor(kind, c) {
return false
}
}
return true
}
func isValidByteFor(kind PartKind, c byte) bool {
if kind == PartNamespace && c == '.' {
return false
}
if kind == PartHost && c == ':' {
return true
}
if c == '.' || c == '-' {
return true
}
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
return true
}
return false
}

709
types/model/name_test.go Normal file
View File

@@ -0,0 +1,709 @@
package model
import (
"bytes"
"cmp"
"fmt"
"log/slog"
"path/filepath"
"slices"
"strings"
"testing"
)
type fields struct {
host, namespace, model, tag, build string
digest string
}
func fieldsFromName(p Name) fields {
return fields{
host: p.parts[PartHost],
namespace: p.parts[PartNamespace],
model: p.parts[PartModel],
tag: p.parts[PartTag],
build: p.parts[PartBuild],
digest: p.parts[PartDigest],
}
}
var testNames = map[string]fields{
"mistral:latest": {model: "mistral", tag: "latest"},
"mistral": {model: "mistral"},
"mistral:30B": {model: "mistral", tag: "30B"},
"mistral:7b": {model: "mistral", tag: "7b"},
"mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"},
"mistral+KQED": {model: "mistral", build: "KQED"},
"mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
"mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"},
"llama2": {model: "llama2"},
"user/model": {namespace: "user", model: "model"},
"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
"example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
"localhost:5000/ns/mistral": {host: "localhost:5000", namespace: "ns", model: "mistral"},
// invalid digest
"mistral:latest@invalid256-": {},
"mistral:latest@-123": {},
"mistral:latest@!-123": {},
"mistral:latest@1-!": {},
"mistral:latest@": {},
// resolved
"x@sha123-1": {model: "x", digest: "sha123-1"},
"@sha456-2": {digest: "sha456-2"},
"@@sha123-1": {},
// preserves case for build
"x+b": {model: "x", build: "b"},
// invalid (includes fuzzing trophies)
" / / : + ": {},
" / : + ": {},
" : + ": {},
" + ": {},
" : ": {},
" / ": {},
" /": {},
"/ ": {},
"/": {},
":": {},
"+": {},
// (".") in namepsace is not allowed
"invalid.com/7b+x": {},
"invalid:7b+Q4_0:latest": {},
"in valid": {},
"invalid/y/z/foo": {},
"/0": {},
"0 /0": {},
"0 /": {},
"0/": {},
":/0": {},
"+0/00000": {},
"0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
"0//0": {},
"m+^^^": {},
"file:///etc/passwd": {},
"file:///etc/passwd:latest": {},
"file:///etc/passwd:latest+u": {},
":x": {},
"+x": {},
"x+": {},
// Disallow ("\.+") in any part to prevent path traversal anywhere
// we convert the name to a path.
"../etc/passwd": {},
".../etc/passwd": {},
"./../passwd": {},
"./0+..": {},
strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
strings.Repeat("a", MaxNamePartLen+1): {},
}
// TestConsecutiveDots tests that consecutive dots are not allowed in any
// part, to avoid path traversal. There also are some tests in testNames, but
// this test is more exhaustive and exists to emphasize the importance of
// preventing path traversal.
func TestNameConsecutiveDots(t *testing.T) {
for i := 1; i < 10; i++ {
s := strings.Repeat(".", i)
if i > 1 {
if g := ParseName(s, FillNothing).DisplayLong(); g != "" {
t.Errorf("ParseName(%q) = %q; want empty string", s, g)
}
} else {
if g := ParseName(s, FillNothing).DisplayLong(); g != s {
t.Errorf("ParseName(%q) = %q; want %q", s, g, s)
}
}
}
}
func TestNameParts(t *testing.T) {
var p Name
if w, g := int(NumParts), len(p.parts); w != g {
t.Errorf("Parts() = %d; want %d", g, w)
}
}
func TestNamePartString(t *testing.T) {
if g := PartKind(-2).String(); g != "Unknown" {
t.Errorf("Unknown part = %q; want %q", g, "Unknown")
}
for kind, name := range kindNames {
if g := kind.String(); g != name {
t.Errorf("%s = %q; want %q", kind, g, name)
}
}
}
func TestParseName(t *testing.T) {
for baseName, want := range testNames {
for _, prefix := range []string{"", "https://", "http://"} {
// We should get the same results with or without the
// http(s) prefixes
s := prefix + baseName
t.Run(s, func(t *testing.T) {
name := ParseName(s, FillNothing)
got := fieldsFromName(name)
if got != want {
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
}
// test round-trip
if !ParseName(name.DisplayLong(), FillNothing).EqualFold(name) {
t.Errorf("ParseName(%q).String() = %s; want %s", s, name.DisplayLong(), baseName)
}
})
}
}
}
func TestParseNameFill(t *testing.T) {
cases := []struct {
in string
fill string
want string
}{
{"mistral", "example.com/library/?:latest+Q4_0", "example.com/library/mistral:latest+Q4_0"},
{"mistral", "example.com/library/?:latest", "example.com/library/mistral:latest"},
{"llama2:x", "example.com/library/?:latest+Q4_0", "example.com/library/llama2:x+Q4_0"},
// Invalid
{"", "example.com/library/?:latest+Q4_0", ""},
{"llama2:?", "example.com/library/?:latest+Q4_0", ""},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
name := ParseName(tt.in, tt.fill)
if g := name.DisplayLong(); g != tt.want {
t.Errorf("ParseName(%q, %q) = %q; want %q", tt.in, tt.fill, g, tt.want)
}
})
}
t.Run("invalid fill", func(t *testing.T) {
defer func() {
if recover() == nil {
t.Fatal("expected panic")
}
}()
ParseName("x", "^")
})
}
func TestParseNameHTTPDoublePrefixStrip(t *testing.T) {
cases := []string{
"http://https://valid.com/valid/valid:latest",
"https://http://valid.com/valid/valid:latest",
}
for _, s := range cases {
t.Run(s, func(t *testing.T) {
name := ParseName(s, FillNothing)
if name.IsValid() {
t.Errorf("expected invalid path; got %#v", name)
}
})
}
}
func TestCompleteWithAndWithoutBuild(t *testing.T) {
cases := []struct {
in string
complete bool
completeNoBuild bool
}{
{"", false, false},
{"incomplete/mistral:7b+x", false, false},
{"incomplete/mistral:7b+Q4_0", false, false},
{"incomplete:7b+x", false, false},
{"complete.com/x/mistral:latest+Q4_0", true, true},
{"complete.com/x/mistral:latest", false, true},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.IsComplete(); g != tt.complete {
t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
}
if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
}
})
}
// Complete uses Parts which returns a slice, but it should be
// inlined when used in Complete, preventing any allocations or
// escaping to the heap.
allocs := testing.AllocsPerRun(1000, func() {
keep(ParseName("complete.com/x/mistral:latest+Q4_0", FillNothing).IsComplete())
})
if allocs > 0 {
t.Errorf("Complete allocs = %v; want 0", allocs)
}
}
func TestNameLogValue(t *testing.T) {
cases := []string{
"example.com/library/mistral:latest+Q4_0",
"mistral:latest",
"mistral:7b+Q4_0",
}
for _, s := range cases {
t.Run(s, func(t *testing.T) {
var b bytes.Buffer
log := slog.New(slog.NewTextHandler(&b, nil))
name := ParseName(s, FillNothing)
log.Info("", "name", name)
want := fmt.Sprintf("name=%s", name.GoString())
got := b.String()
if !strings.Contains(got, want) {
t.Errorf("expected log output to contain %q; got %q", want, got)
}
})
}
}
func TestNameGoString(t *testing.T) {
cases := []struct {
name string
in string
wantString string
wantGoString string // default is tt.in
}{
{
name: "Complete Name",
in: "example.com/library/mistral:latest+Q4_0",
wantGoString: "example.com/library/mistral:latest+Q4_0@?",
},
{
name: "Short Name",
in: "mistral:latest",
wantGoString: "?/?/mistral:latest+?@?",
},
{
name: "Long Name",
in: "library/mistral:latest",
wantGoString: "?/library/mistral:latest+?@?",
},
{
name: "Case Preserved",
in: "Library/Mistral:Latest",
wantGoString: "?/Library/Mistral:Latest+?@?",
},
{
name: "With digest",
in: "Library/Mistral:Latest@sha256-123456",
wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
p := ParseName(tt.in, FillNothing)
tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
}
})
}
}
func TestDisplayLongest(t *testing.T) {
g := ParseName("example.com/library/mistral:latest+Q4_0", FillNothing).DisplayLongest()
if g != "example.com/library/mistral:latest" {
t.Errorf("got = %q; want %q", g, "example.com/library/mistral:latest")
}
}
func TestDisplayShortest(t *testing.T) {
cases := []struct {
in string
mask string
want string
wantPanic bool
}{
{"example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
{"example.com/library/mistral:latest+Q4_0", "example.com/_/_:latest", "library/mistral", false},
{"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
{"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
// case-insensitive
{"Example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
{"example.com/Library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
{"example.com/library/Mistral:latest+Q4_0", "example.com/library/_:latest", "Mistral", false},
{"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false},
{"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false},
// zero value
{"", MaskDefault, "", true},
// invalid mask
{"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true},
// DefaultMask
{"registry.ollama.ai/library/mistral:latest+Q4_0", MaskDefault, "mistral", false},
// Auto-Fill
{"x", "example.com/library/_:latest", "x", false},
{"x", "example.com/library/_:latest+Q4_0", "x", false},
{"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
{"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
defer func() {
if tt.wantPanic {
if recover() == nil {
t.Errorf("expected panic")
}
}
}()
p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.DisplayShortest(tt.mask); g != tt.want {
t.Errorf("got = %q; want %q", g, tt.want)
}
})
}
}
func TestParseNameAllocs(t *testing.T) {
allocs := testing.AllocsPerRun(1000, func() {
keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
})
if allocs > 0 {
t.Errorf("ParseName allocs = %v; want 0", allocs)
}
}
func BenchmarkParseName(b *testing.B) {
b.ReportAllocs()
for range b.N {
keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
}
}
func FuzzParseNameFromFilepath(f *testing.F) {
f.Add("example.com/library/mistral/7b/Q4_0")
f.Add("example.com/../mistral/7b/Q4_0")
f.Add("example.com/x/../7b/Q4_0")
f.Add("example.com/x/../7b")
f.Fuzz(func(t *testing.T, s string) {
name := ParseNameFromFilepath(s, FillNothing)
if strings.Contains(s, "..") && !name.IsZero() {
t.Fatalf("non-zero value for path with '..': %q", s)
}
if name.IsValid() == name.IsZero() {
t.Errorf("expected valid path to be non-zero value; got %#v", name)
}
})
}
func FuzzParseName(f *testing.F) {
f.Add("example.com/mistral:7b+Q4_0")
f.Add("example.com/mistral:7b+q4_0")
f.Add("example.com/mistral:7b+x")
f.Add("x/y/z:8n+I")
f.Add(":x")
f.Add("@sha256-123456")
f.Add("example.com/mistral:latest+Q4_0@sha256-123456")
f.Add(":@!@")
f.Add("...")
f.Fuzz(func(t *testing.T, s string) {
r0 := ParseName(s, FillNothing)
if strings.Contains(s, "..") && !r0.IsZero() {
t.Fatalf("non-zero value for path with '..': %q", s)
}
if !r0.IsValid() && !r0.IsResolved() {
if !r0.EqualFold(Name{}) {
t.Errorf("expected invalid path to be zero value; got %#v", r0)
}
t.Skipf("invalid path: %q", s)
}
for _, p := range r0.parts {
if len(p) > MaxNamePartLen {
t.Errorf("part too long: %q", p)
}
}
if !strings.EqualFold(r0.DisplayLong(), s) {
t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.DisplayLong(), s)
}
r1 := ParseName(r0.DisplayLong(), FillNothing)
if !r0.EqualFold(r1) {
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
}
})
}
func TestNameStringAllocs(t *testing.T) {
name := ParseName("example.com/ns/mistral:latest+Q4_0", FillNothing)
allocs := testing.AllocsPerRun(1000, func() {
keep(name.DisplayLong())
})
if allocs > 1 {
t.Errorf("String allocs = %v; want 0", allocs)
}
}
func TestNamePath(t *testing.T) {
cases := []struct {
in string
want string
}{
{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest"},
// incomplete
{"example.com/library/mistral:latest", "example.com/library/mistral:latest"},
{"", ""},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.URLPath(); g != tt.want {
t.Errorf("got = %q; want %q", g, tt.want)
}
})
}
}
func TestNameFilepath(t *testing.T) {
cases := []struct {
in string
want string
wantNoBuild string
}{
{
in: "example.com/library/mistral:latest+Q4_0",
want: "example.com/library/mistral/latest/Q4_0",
wantNoBuild: "example.com/library/mistral/latest",
},
{
in: "Example.Com/Library/Mistral:Latest+Q4_0",
want: "example.com/library/mistral/latest/Q4_0",
wantNoBuild: "example.com/library/mistral/latest",
},
{
in: "Example.Com/Library/Mistral:Latest+Q4_0",
want: "example.com/library/mistral/latest/Q4_0",
wantNoBuild: "example.com/library/mistral/latest",
},
{
in: "example.com/library/mistral:latest",
want: "example.com/library/mistral/latest",
wantNoBuild: "example.com/library/mistral/latest",
},
{
in: "",
want: "",
wantNoBuild: "",
},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p)
g := p.Filepath()
g = filepath.ToSlash(g)
if g != tt.want {
t.Errorf("got = %q; want %q", g, tt.want)
}
g = p.FilepathNoBuild()
g = filepath.ToSlash(g)
if g != tt.wantNoBuild {
t.Errorf("got = %q; want %q", g, tt.wantNoBuild)
}
})
}
}
func TestParseNameFilepath(t *testing.T) {
cases := []struct {
in string
fill string // default is FillNothing
want string
}{
{
in: "example.com/library/mistral/latest/Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "example.com/library/mistral/latest",
fill: "?/?/?:latest+Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "example.com/library/mistral",
fill: "?/?/?:latest+Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "example.com/library",
want: "",
},
{
in: "example.com/",
want: "",
},
{
in: "example.com/^/mistral/latest/Q4_0",
want: "",
},
{
in: "example.com/library/mistral/../Q4_0",
want: "",
},
{
in: "example.com/library/mistral/latest/Q4_0/extra",
want: "",
},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
in := strings.ReplaceAll(tt.in, "/", string(filepath.Separator))
fill := cmp.Or(tt.fill, FillNothing)
want := ParseName(tt.want, fill)
if g := ParseNameFromFilepath(in, fill); !g.EqualFold(want) {
t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want)
}
})
}
}
func TestParseNameFromPath(t *testing.T) {
cases := []struct {
in string
want string
fill string // default is FillNothing
}{
{
in: "example.com/library/mistral:latest+Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "/example.com/library/mistral:latest+Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "/example.com/library/mistral",
want: "example.com/library/mistral",
},
{
in: "/example.com/library/mistral",
fill: "?/?/?:latest+Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "/example.com/library",
want: "",
},
{
in: "/example.com/",
want: "",
},
{
in: "/example.com/^/mistral/latest",
want: "",
},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
fill := cmp.Or(tt.fill, FillNothing)
if g := ParseNameFromURLPath(tt.in, fill); g.DisplayLong() != tt.want {
t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want)
}
})
}
}
func ExampleName_MapHash() {
m := map[uint64]bool{}
// key 1
m[ParseName("mistral:latest+q4", FillNothing).MapHash()] = true
m[ParseName("miSTRal:latest+Q4", FillNothing).MapHash()] = true
m[ParseName("mistral:LATest+Q4", FillNothing).MapHash()] = true
// key 2
m[ParseName("mistral:LATest", FillNothing).MapHash()] = true
fmt.Println(len(m))
// Output:
// 2
}
func ExampleName_CompareFold_sort() {
names := []Name{
ParseName("mistral:latest", FillNothing),
ParseName("mistRal:7b+q4", FillNothing),
ParseName("MIstral:7b", FillNothing),
}
slices.SortFunc(names, Name.CompareFold)
for _, n := range names {
fmt.Println(n.DisplayLong())
}
// Output:
// MIstral:7b
// mistRal:7b+q4
// mistral:latest
}
func ExampleName_completeAndResolved() {
for _, s := range []string{
"x/y/z:latest+q4_0@sha123-1",
"x/y/z:latest+q4_0",
"@sha123-1",
} {
name := ParseName(s, FillNothing)
fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
}
// Output:
// complete:true resolved:true digest:sha123-1
// complete:true resolved:false digest:
// complete:false resolved:true digest:sha123-1
}
func ExampleName_DisplayShortest() {
name := ParseName("example.com/jmorganca/mistral:latest+Q4_0", FillNothing)
fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest"))
fmt.Println(name.DisplayShortest("example.com/_/_:latest"))
fmt.Println(name.DisplayShortest("example.com/_/_:_"))
fmt.Println(name.DisplayShortest("_/_/_:_"))
// Default
name = ParseName("registry.ollama.ai/library/mistral:latest+Q4_0", FillNothing)
fmt.Println(name.DisplayShortest(""))
// Output:
// mistral
// jmorganca/mistral
// jmorganca/mistral:latest
// example.com/jmorganca/mistral:latest
// mistral
}
func keep[T any](v T) T { return v }

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("/0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("0//0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("0 /0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("+0/00000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string(":")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91")

15
types/structs/structs.go Normal file
View File

@@ -0,0 +1,15 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package structs contains the Incomparable type.
package structs
// Incomparable is a zero-width incomparable type. If added as the
// first field in a struct, it marks that struct as not comparable
// (can't do == or be a map key) and usually doesn't add any width to
// the struct (unless the struct has only small fields).
//
// By making a struct incomparable, you can prevent misuse (prevent
// people from using ==), but also you can shrink generated binaries,
// as the compiler can omit equality funcs from the binary.
type Incomparable [0]func()