Compare commits

...

125 Commits

Author SHA1 Message Date
Blake Mizerany
2bed62926e types/model: remove Digest (for now) (#3970)
The Digest type needs more thought and is not necessary at the moment.
2024-04-26 21:14:28 -07:00
Jeffrey Morgan
aad8d128a0 also look at cwd as a root for windows runners (#3959) 2024-04-26 19:14:08 -04:00
Daniel Hiltgen
ec1acbb867 Merge pull request #3968 from dhiltgen/win_generate
Fine grain control over windows generate steps
2024-04-26 16:03:38 -07:00
Daniel Hiltgen
e4859c4563 Fine grain control over windows generate steps
This will speed up CI which already tries to only build static for unit tests
2024-04-26 15:49:46 -07:00
Nataly Merezhuk
8e30eb26bd Updates the setup command to use llama3. (#3962) 2024-04-26 18:41:01 -04:00
Daniel Hiltgen
0b5c589ca2 Merge pull request #3966 from dhiltgen/bump
Fix target in gen_windows.ps1
2024-04-26 15:36:53 -07:00
Michael Yang
65fadddc85 Merge pull request #3964 from ollama/mxyng/weights
fix gemma, command-r layer weights
2024-04-26 15:23:33 -07:00
Daniel Hiltgen
ed5fb088c4 Fix target in gen_windows.ps1 2024-04-26 15:10:42 -07:00
Michael Yang
f81f308118 fix gemma, command-r layer weights 2024-04-26 15:00:55 -07:00
Blake Mizerany
b1390a7b37 types/model: export ParseNameBare and Merge (#3957)
These are useful outside this package.
2024-04-26 14:58:07 -07:00
Michael Yang
11d83386a5 Merge pull request #3951 from ollama/mxyng/zip
check file type before zip
2024-04-26 14:51:23 -07:00
Jeffrey Morgan
bb31def011 return code 499 when user cancels request while a model is loading (#3955) 2024-04-26 17:38:29 -04:00
Michael Yang
41e03ede95 check file type before zip 2024-04-26 14:18:07 -07:00
Michael Yang
7fea1ecdf6 Merge pull request #3958 from ollama/mxyng/fix-workflow
use merge base for diff-tree
2024-04-26 14:17:56 -07:00
Blake Mizerany
054894271d .github/workflows/test.yaml: add in-flight cancellations on new push (#3956)
Also, remove a superfluous 'go get'
2024-04-26 13:54:24 -07:00
Michael Yang
6fef042f0b use merge base for diff-tree 2024-04-26 13:54:15 -07:00
Daniel Hiltgen
5c0c2d1d09 Merge pull request #3954 from dhiltgen/ci_fixes
Put back non-avx CPU build for windows
2024-04-26 13:09:03 -07:00
Blake Mizerany
37f9c8ad99 types/model: overhaul Name and Digest types (#3924) 2024-04-26 13:08:32 -07:00
Quinten van Buul
2a80f55e2a Update windows.md (#3855)
Fixed a typo
2024-04-26 16:04:15 -04:00
Daniel Hiltgen
421c878a2d Put back non-avx CPU build for windows 2024-04-26 12:44:07 -07:00
Daniel Hiltgen
36666c2142 Merge pull request #3925 from dhiltgen/bump
Bump llama.cpp to b2737
2024-04-26 10:09:38 -07:00
Daniel Hiltgen
85801317d1 Fix clip log import 2024-04-26 09:43:46 -07:00
Daniel Hiltgen
2ed0d65948 Bump llama.cpp to b2737 2024-04-26 09:43:28 -07:00
Daniel Hiltgen
d459dc4ad1 Merge pull request #3950 from dhiltgen/windows_packaging
Fix exe name for zip packaging on windows
2024-04-26 09:27:37 -07:00
Daniel Hiltgen
40bc4622ef Fix exe name for zip packaging on windows
The zip file encodes the OS and architecture, so keep the short exe name
2024-04-26 09:18:05 -07:00
Daniel Hiltgen
c0f818a07a Merge pull request #3948 from dhiltgen/win_generate
Refactor windows generate for more modular usage
2024-04-26 09:17:20 -07:00
Daniel Hiltgen
8671fdeda6 Refactor windows generate for more modular usage 2024-04-26 08:35:50 -07:00
Daniel Hiltgen
2619850fb4 Merge pull request #3933 from dhiltgen/ci_fixes
Move cuda/rocm dependency gathering into generate script
2024-04-26 07:01:24 -07:00
Daniel Hiltgen
8feb97dc0d Move cuda/rocm dependency gathering into generate script
This will make it simpler for CI to accumulate artifacts from prior steps
2024-04-25 22:38:44 -07:00
Daniel Hiltgen
4e1ff6dcbb Merge pull request #3926 from dhiltgen/ci_fixes
Fix release CI
2024-04-25 17:42:31 -07:00
Daniel Hiltgen
8589d752ac Fix release CI
download-artifact path was being used incorrectly.  It is where to
extract the zip not the files in the zip to extract.  Default is
workspace dir which is what we want, so omit it
2024-04-25 17:27:11 -07:00
Michael Yang
de4ded68b0 Merge pull request #3923 from ollama/mxyng/mem
only count output tensors
2024-04-25 16:34:17 -07:00
Daniel Hiltgen
9b5a3c5991 Merge pull request #3914 from dhiltgen/mac_perf
Improve mac parallel performance
2024-04-25 16:28:31 -07:00
Jeffrey Morgan
00b0699c75 Reload model if num_gpu changes (#3920)
* reload model if `num_gpu` changes

* dont reload on -1

* fix tests
2024-04-25 19:02:40 -04:00
Jeffrey Morgan
993cf8bf55 llm: limit generation to 10x context size to avoid run on generations (#3918)
* llm: limit generation to 10x context size to avoid run on generations

* add comment

* simplify condition statement
2024-04-25 19:02:30 -04:00
Michael Yang
7bb7cb8a60 only count output tensors 2024-04-25 15:24:08 -07:00
Daniel Hiltgen
b123be5b71 Adjust context size for parallelism 2024-04-25 13:58:54 -07:00
jmorganca
ddf5c09a9b use matrix multiplcation kernels in more cases 2024-04-25 13:58:54 -07:00
Roy Yang
5f73c08729 Remove trailing spaces (#3889) 2024-04-25 14:32:26 -04:00
Daniel Hiltgen
f503a848c2 Merge pull request #3895 from brycereitano/shiftloading
Move ggml loading to when attempting to fit
2024-04-25 09:24:08 -07:00
Bryce Reitano
36a6daccab Restructure loading conditional chain 2024-04-24 17:37:03 -06:00
Bryce Reitano
ceb0e26e5e Provide variable ggml for TestLoad 2024-04-24 17:19:55 -06:00
Bryce Reitano
284e02bed0 Move ggml loading to when we attempt fitting 2024-04-24 17:17:24 -06:00
Michael Yang
3450a57d4a Merge pull request #3713 from ollama/mxyng/modelname
update copy handler to use model.Name
2024-04-24 16:00:32 -07:00
Michael Yang
592dae31c8 update copy to use model.Name 2024-04-24 15:54:54 -07:00
Michael Yang
2010cbc5fa Merge pull request #3833 from ollama/mxyng/fix-from
fix: from blob
2024-04-24 15:13:47 -07:00
Michael Yang
ac0801eced only replace if it matches command 2024-04-24 14:49:26 -07:00
Michael Yang
ad66e5b060 split temp zip files 2024-04-24 14:18:01 -07:00
Blake Mizerany
ade4b55520 types/model: make ParseName use default without question (#3886) 2024-04-24 11:52:55 -07:00
Daniel Hiltgen
a6d62e0617 Merge pull request #3882 from dhiltgen/amd_gfx
AMD gfx patch rev is hex
2024-04-24 11:07:49 -07:00
Daniel Hiltgen
6e76348df7 Merge pull request #3834 from dhiltgen/not_found_in_path
Report errors on server lookup instead of path lookup failure
2024-04-24 10:50:48 -07:00
Daniel Hiltgen
0d6687f84c AMD gfx patch rev is hex
Correctly handle gfx90a discovery
2024-04-24 09:43:52 -07:00
Patrick Devine
74d2a9ef9a add OLLAMA_KEEP_ALIVE env variable to FAQ (#3865) 2024-04-23 21:06:51 -07:00
Patrick Devine
14476d48cc fixes for gguf (#3863) 2024-04-23 20:57:20 -07:00
Patrick Devine
ce8ce82567 add mixtral 8x7b model conversion (#3859) 2024-04-23 20:17:04 -07:00
Blake Mizerany
4dc4f1be34 types/model: restrict digest hash part to a minimum of 2 characters (#3858)
This allows users of a valid Digest to know it has a minimum of 2
characters in the hash part for use when sharding.

This is a reasonable restriction as the hash part is a SHA256 hash which
is 64 characters long, which is the common hash used. There is no
anticipation of using a hash with less than 2 characters.

Also, add MustParseDigest.

Also, replace Digest.Type with Digest.Split for getting both the type
and hash parts together, which is most the common case when asking for
either.
2024-04-23 18:24:17 -07:00
Daniel Hiltgen
16b52331a4 Merge pull request #3857 from dhiltgen/mem_escape_valve
Add back memory escape valve
2024-04-23 17:32:24 -07:00
Daniel Hiltgen
5445aaa94e Add back memory escape valve
If we get our predictions wrong, this can be used to
set a lower memory limit as a workaround.  Recent multi-gpu
refactoring accidentally removed it, so this adds it back.
2024-04-23 17:09:02 -07:00
Daniel Hiltgen
2ac3dd6853 Merge pull request #3850 from dhiltgen/windows_packaging
Move nested payloads to installer and zip file on windows
2024-04-23 16:35:20 -07:00
Daniel Hiltgen
d8851cb7a0 Harden sched TestLoad
Give the go routine a moment to deliver the expired event
2024-04-23 16:14:47 -07:00
Daniel Hiltgen
058f6cd2cc Move nested payloads to installer and zip file on windows
Now that the llm runner is an executable and not just a dll, more users are facing
problems with security policy configurations on windows that prevent users
writing to directories and then executing binaries from the same location.
This change removes payloads from the main executable on windows and shifts them
over to be packaged in the installer and discovered based on the executables location.
This also adds a new zip file for people who want to "roll their own" installation model.
2024-04-23 16:14:47 -07:00
Daniel Hiltgen
790cf34d17 Merge pull request #3846 from dhiltgen/missing_runner
Detect and recover if runner removed
2024-04-23 13:14:12 -07:00
Michael
928d844896 adding phi-3 mini to readme
adding phi-3 mini to readme
2024-04-23 13:58:31 -04:00
Daniel Hiltgen
939d6a8606 Make CI lint verbvose 2024-04-23 10:17:42 -07:00
Daniel Hiltgen
58888a74bc Detect and recover if runner removed
Tmp cleaners can nuke the file out from underneath us.  This detects the missing
runner, and re-initializes the payloads.
2024-04-23 10:05:26 -07:00
Daniel Hiltgen
cc5a71e0e3 Merge pull request #3709 from remy415/custom-gpu-defs
Adds support for customizing GPU build flags in llama.cpp
2024-04-23 09:28:34 -07:00
Michael Yang
e83bcf7f9a Merge pull request #3836 from ollama/mxyng/mixtral
fix: mixtral graph
2024-04-23 09:15:10 -07:00
Daniel Hiltgen
5690e5ce99 Merge pull request #3418 from dhiltgen/concurrency
Request and model concurrency
2024-04-23 08:31:38 -07:00
Daniel Hiltgen
f2ea8470e5 Local unicode test case 2024-04-22 19:29:12 -07:00
Daniel Hiltgen
34b9db5afc Request and model concurrency
This change adds support for multiple concurrent requests, as well as
loading multiple models by spawning multiple runners. The default
settings are currently set at 1 concurrent request per model and only 1
loaded model at a time, but these can be adjusted by setting
OLLAMA_NUM_PARALLEL and OLLAMA_MAX_LOADED_MODELS.
2024-04-22 19:29:12 -07:00
Daniel Hiltgen
8711d03df7 Report errors on server lookup instead of path lookup failure 2024-04-22 19:08:47 -07:00
Daniel Hiltgen
ee448deaba Merge pull request #3835 from dhiltgen/harden_llm_override
Trim spaces and quotes from llm lib override
2024-04-22 19:06:54 -07:00
Bruce MacDonald
6e8db04716 tidy community integrations
- move some popular integrations to the top of the lists
2024-04-22 17:29:08 -07:00
Bruce MacDonald
658e60cf73 Revert "stop running model on interactive exit"
This reverts commit fad00a85e5.
2024-04-22 17:23:11 -07:00
Bruce MacDonald
4c78f028f8 Merge branch 'main' of https://github.com/ollama/ollama 2024-04-22 17:22:28 -07:00
Michael Yang
435cc866a3 fix: mixtral graph 2024-04-22 17:19:44 -07:00
Hao Wu
c7d3a558f6 docs: update README to add chat (web UI) for LLM (#3810)
* add chat (web UI) for LLM

I have used chat with llama3 in local successfully and the code is MIT licensed.

* Update README.md

---------

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2024-04-22 20:19:39 -04:00
Maple Gao
089cdb2877 docs: Update README for Lobe-chat integration. (#3817)
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2024-04-22 20:18:15 -04:00
Võ Đình Đạt
ea1e9aa36b Update README.md (#3655) 2024-04-22 20:16:55 -04:00
Jonathan Smoley
d0d28ef90d Update README.md with Discord-Ollama project (#3633)
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2024-04-22 20:14:20 -04:00
Eric Curtin
6654186a7c Add podman-ollama to terminal apps (#3626)
The goal of podman-ollama is to make AI even more boring.

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
2024-04-22 20:13:23 -04:00
Daniel Hiltgen
aa72281eae Trim spaces and quotes from llm lib override 2024-04-22 17:11:14 -07:00
reid41
74bcbf828f add qa-pilot link (#3612)
* add qa-pilot link

* format the link

* add shell-pilot
2024-04-22 20:10:34 -04:00
Christian Neff
fe39147e64 Add Chatbot UI v2 to Community Integrations (#3503) 2024-04-22 20:09:55 -04:00
Bruce MacDonald
fad00a85e5 stop running model on interactive exit 2024-04-22 16:22:14 -07:00
Jeremy
9c0db4cc83 Update gen_windows.ps1
Fixed improper env references
2024-04-21 16:13:41 -04:00
Cheng
62be2050dd chore: use errors.New to replace fmt.Errorf will much better (#3789) 2024-04-20 22:11:06 -04:00
Blake Mizerany
56f8aa6912 types/model: export IsValidNamePart (#3788) 2024-04-20 18:26:34 -07:00
Sri Siddhaarth
e6f9bfc0e8 Update api.md (#3705) 2024-04-20 15:17:03 -04:00
Jeremy
6f18297b3a Update gen_windows.ps1
Forgot a " on the write-host
2024-04-18 19:47:44 -04:00
Jeremy
15016413de Update gen_windows.ps1
Added OLLAMA_CUSTOM_CUDA_DEFS and OLLAMA_CUSTOM_ROCM_DEFS to customize GPU builds on Windows
2024-04-18 19:27:16 -04:00
Jeremy
440b7190ed Update gen_linux.sh
Added OLLAMA_CUSTOM_CUDA_DEFS and OLLAMA_CUSTOM_ROCM_DEFS instead of OLLAMA_CUSTOM_GPU_DEFS
2024-04-18 19:18:10 -04:00
Daniel Hiltgen
8d1995c625 Merge pull request #3708 from remy415/arm64static
move Ollama static build to its own flag
2024-04-18 16:04:12 -07:00
Daniel Hiltgen
fd01fbf038 Merge pull request #3710 from remy415/update-jetson-docs
update jetson tutorial
2024-04-18 16:02:08 -07:00
Blake Mizerany
0408205c1c types/model: accept former : as a separator in digest (#3724)
This also converges the old sep `:` to the new sep `-`.
2024-04-18 14:17:46 -07:00
Jeffrey Morgan
63a7edd771 Update README.md 2024-04-18 16:09:38 -04:00
Michael
554ffdcce3 add llama3 to readme
add llama3 to readme
2024-04-18 15:18:48 -04:00
Jeremy
9850a4ce08 Merge branch 'ollama:main' into update-jetson-docs 2024-04-18 09:55:17 -04:00
Jeremy
3934c15895 Merge branch 'ollama:main' into custom-gpu-defs 2024-04-18 09:55:10 -04:00
Jeremy
fd048f1367 Merge branch 'ollama:main' into arm64static 2024-04-18 09:55:04 -04:00
Michael Yang
8645076a71 Merge pull request #3712 from ollama/mxyng/mem
add stablelm graph calculation
2024-04-17 15:57:51 -07:00
Michael Yang
05e9424824 Merge pull request #3664 from ollama/mxyng/fix-padding-2
fix padding to only return padding
2024-04-17 15:57:40 -07:00
Michael Yang
52ebe67a98 Merge pull request #3714 from ollama/mxyng/model-name-host
types/model: support : in PartHost for host:port
2024-04-17 15:34:03 -07:00
Michael Yang
889b31ab78 types/model: support : in PartHost for host:port 2024-04-17 15:16:07 -07:00
Michael Yang
3cf483fe48 add stablelm graph calculation 2024-04-17 13:57:19 -07:00
Jeremy
8dca03173d Merge remote-tracking branch 'upstream/main' into update-jetson-docs 2024-04-17 16:18:50 -04:00
Jeremy
85bdf14b56 update jetson tutorial 2024-04-17 16:17:42 -04:00
Jeremy
d524e5ef5e Merge branch 'custom-gpu-defs' of https://github.com/remy415/ollama into custom-gpu-defs 2024-04-17 16:01:03 -04:00
Jeremy
52f5370c48 add support for custom gpu build flags for llama.cpp 2024-04-17 16:00:48 -04:00
Jeremy
da8a0c7657 Merge branch 'ollama:main' into arm64static 2024-04-17 15:22:34 -04:00
Jeremy
1b42b4b59a Merge branch 'ollama:main' into custom-gpu-defs 2024-04-17 15:21:56 -04:00
Jeremy
7c000ec3ed adds support for OLLAMA_CUSTOM_GPU_DEFS to customize GPU build flags 2024-04-17 15:21:05 -04:00
jmorganca
c8afe7168c use correct extension for feature and model request issue templates 2024-04-17 15:18:40 -04:00
jmorganca
28d3cd0148 simpler feature and model request forms 2024-04-17 15:17:08 -04:00
jmorganca
eb5554232a simpler feature and model request forms 2024-04-17 15:14:49 -04:00
Jeremy
ea4c284a48 Merge branch 'ollama:main' into arm64static 2024-04-17 15:11:38 -04:00
jmorganca
2bdc320216 add descriptions to issue templates 2024-04-17 15:08:36 -04:00
jmorganca
32561aed09 simplify github issue templates a bit 2024-04-17 15:07:03 -04:00
Michael Yang
71548d9829 Merge pull request #3706 from ollama/mxyng/mem
account for all non-repeating layers
2024-04-17 11:58:20 -07:00
Jeremy
8aec92fa6d rearranged conditional logic for static build, dockerfile updated 2024-04-17 14:43:28 -04:00
Michael Yang
a8b9b930b4 account for all non-repeating layers 2024-04-17 11:21:21 -07:00
Michael
9755cf9173 acknowledge the amazing work done by Georgi and team! 2024-04-17 13:48:14 -04:00
Jeremy
70261b9bb6 move static build to its own flag 2024-04-17 13:04:28 -04:00
Blake Mizerany
9df6c85c3a types/model: add FilepathNoBuild (#3680)
Also, add test for DisplayLongest.

Also, plumb fill param to ParseName in MustParseName
2024-04-16 18:35:43 -07:00
Michael Yang
e74163af4c fix padding to only return padding 2024-04-16 15:43:26 -07:00
71 changed files with 4034 additions and 3373 deletions

View File

@@ -0,0 +1,60 @@
name: Bug report
labels: [bug]
description: Something isn't working right.
body:
- type: textarea
id: description
attributes:
label: What is the issue?
description: What happened? What did you expect to happen?
validations:
required: true
- type: dropdown
id: os
attributes:
label: OS
description: Which operating system are you using?
multiple: true
options:
- Linux
- macOS
- Windows
- Docker
- WSL2
validations:
required: false
- type: dropdown
id: gpu
attributes:
label: GPU
description: Which GPU are you using?
multiple: true
options:
- Nvidia
- AMD
- Intel
- Apple
- Other
validations:
required: false
- type: dropdown
id: cpu
attributes:
label: CPU
description: Which CPU are you using?
multiple: true
options:
- Intel
- AMD
- Apple
- Other
validations:
required: false
- type: input
id: version
attributes:
label: Ollama version
description: What version of Ollama are you using? (`ollama --version`)
placeholder: e.g., 0.1.32
validations:
required: false

View File

@@ -1,18 +0,0 @@
name: Model request
description: Request a new model for the library
labels: [mr]
body:
- type: markdown
attributes:
value: |
Please check if your Model request is [already available](https://ollama.com/search) or that you cannot [import it](https://github.com/ollama/ollama/blob/main/docs/import.md#import-a-model) yourself.
Tell us about which Model you'd like to see in the library!
- type: textarea
id: problem
attributes:
label: What model would you like?
description: Please provide a link to the model.
- type: markdown
attributes:
value: |
Thanks for filing a model request!

View File

@@ -0,0 +1,6 @@
---
name: Feature request
about: Request a new feature
labels: feature request
---

View File

@@ -1,41 +0,0 @@
name: Feature request
description: Propose a new feature
labels: [needs-triage, fr]
body:
- type: markdown
attributes:
value: |
Please check if your feature request is [already filed](https://github.com/ollama/ollama/issues).
Tell us about your idea!
- type: textarea
id: problem
attributes:
label: What are you trying to do?
description: Tell us about the problem you're trying to solve.
validations:
required: false
- type: textarea
id: solution
attributes:
label: How should we solve this?
description: If you have an idea of how you'd like to see this feature work, let us know.
validations:
required: false
- type: textarea
id: alternative
attributes:
label: What is the impact of not solving this?
description: (How) Are you currently working around the issue?
validations:
required: false
- type: textarea
id: context
attributes:
label: Anything else?
description: Any additional context to share, e.g., links
validations:
required: false
- type: markdown
attributes:
value: |
Thanks for filing a feature request!

View File

@@ -0,0 +1,5 @@
---
name: Model request
about: Request support for a new model to be added to Ollama
labels: model request
---

View File

@@ -1,125 +0,0 @@
name: Bug report
description: File a bug report. If you need help, please join our Discord server.
labels: [needs-triage, bug]
body:
- type: markdown
attributes:
value: |
Please check if your bug is [already filed](https://github.com/ollama/ollama/issues) before filing a new one.
- type: textarea
id: what-happened
attributes:
label: What is the issue?
description: What happened? What did you expect to happen?
validations:
required: true
- type: textarea
id: what-was-expected
attributes:
label: What did you expect to see?
description: What did you expect to see/happen instead?
validations:
required: false
- type: textarea
id: steps
attributes:
label: Steps to reproduce
description: What are the steps you took that hit this issue?
validations:
required: false
- type: textarea
id: changes
attributes:
label: Are there any recent changes that introduced the issue?
description: If so, what are those changes?
validations:
required: false
- type: dropdown
id: os
attributes:
label: OS
description: What OS are you using? You may select more than one.
multiple: true
options:
- Linux
- macOS
- Windows
- Other
validations:
required: false
- type: dropdown
id: architecture
attributes:
label: Architecture
description: What architecture are you using? You may select more than one.
multiple: true
options:
- arm64
- amd64
- x86
- Other
- type: dropdown
id: platform
attributes:
label: Platform
description: What platform are you using? You may select more than one.
multiple: true
options:
- Docker
- WSL
- WSL2
validations:
required: false
- type: input
id: ollama-version
attributes:
label: Ollama version
description: What Ollama version are you using? (`ollama --version`)
placeholder: e.g., 1.14.4
validations:
required: false
- type: dropdown
id: gpu
attributes:
label: GPU
description: What GPU, if any, are you using? You may select more than one.
multiple: true
options:
- Nvidia
- AMD
- Intel
- Apple
- Other
validations:
required: false
- type: textarea
id: gpu-info
attributes:
label: GPU info
description: What GPU info do you have? (`nvidia-smi`, `rocminfo`, `system_profiler SPDisplaysDataType`, etc.)
validations:
required: false
- type: dropdown
id: cpu
attributes:
label: CPU
description: What CPU are you using? You may select more than one.
multiple: true
options:
- Intel
- AMD
- Apple
- Other
validations:
required: false
- type: textarea
id: other-software
attributes:
label: Other software
description: What other software are you using that might be related to this issue?
validations:
required: false
- type: markdown
attributes:
value: |
Thanks for filing a bug report!

View File

@@ -103,6 +103,7 @@ jobs:
path: |
llm/build/**/bin/*
llm/build/**/*.a
dist/windows-amd64/**
# ROCm generation step
generate-windows-rocm:
@@ -173,7 +174,9 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: generate-windows-rocm
path: llm/build/**/bin/*
path: |
llm/build/**/bin/*
dist/windows-amd64/**
- uses: actions/upload-artifact@v4
with:
name: windows-rocm-deps
@@ -253,7 +256,9 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: generate-windows-cuda
path: llm/build/**/bin/*
path: |
llm/build/**/bin/*
dist/windows-amd64/**
- uses: actions/upload-artifact@v4
with:
name: windows-cuda-deps
@@ -306,23 +311,18 @@ jobs:
- uses: actions/download-artifact@v4
with:
name: generate-windows-cpu
path: llm/build
- uses: actions/download-artifact@v4
with:
name: generate-windows-cuda
path: llm/build
- uses: actions/download-artifact@v4
with:
name: windows-cuda-deps
path: dist/deps
- uses: actions/download-artifact@v4
with:
name: windows-rocm-deps
path: dist/deps
- uses: actions/download-artifact@v4
with:
name: generate-windows-rocm
path: llm/build
- run: dir llm/build
- run: |
$gopath=(get-command go).source | split-path -parent
@@ -331,13 +331,13 @@ jobs:
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
$env:PATH="$gopath;$env:PATH"
$env:OLLAMA_SKIP_GENERATE="1"
$env:NVIDIA_DIR=$(resolve-path ".\dist\deps")
$env:HIP_PATH=$(resolve-path ".\dist\deps")
& .\scripts\build_windows.ps1
- uses: actions/upload-artifact@v4
with:
name: dist-windows
path: dist/*.exe
path: |
dist/OllamaSetup.exe
dist/ollama-windows-*.zip
# Linux x86 assets built using the container based build
build-linux-amd64:

View File

@@ -1,5 +1,15 @@
name: test
concurrency:
# For PRs, later CI runs preempt previous ones. e.g. a force push on a PR
# cancels running CI jobs and starts all new ones.
#
# For non-PR pushes, concurrency.group needs to be unique for every distinct
# CI run we want to have happen. Use run_id, which in practice means all
# non-PR CI runs will be allowed to run without preempting each other.
group: ${{ github.workflow }}-$${{ github.pull_request.number || github.run_id }}
cancel-in-progress: true
on:
pull_request:
paths:
@@ -21,7 +31,9 @@ jobs:
- id: changes
run: |
changed() {
git diff-tree -r --no-commit-id --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
git diff-tree -r --no-commit-id --name-only \
$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
${{ github.event.pull_request.head.sha }} \
| xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
}
@@ -103,7 +115,9 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: cuda-${{ matrix.cuda-version }}-libraries
path: llm/build/**/bin/*
path: |
llm/build/**/bin/*
dist/windows-amd64/**
generate-rocm:
needs: [changes]
if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
@@ -134,7 +148,9 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: rocm-${{ matrix.rocm-version }}-libraries
path: llm/build/**/bin/*
path: |
llm/build/**/bin/*
dist/windows-amd64/**
# ROCm generation step
generate-windows-rocm:
@@ -253,14 +269,9 @@ jobs:
mkdir -p llm/build/darwin/$ARCH/stub/bin
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/build/windows/$ARCH/stub/bin
touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'windows-') }}
shell: bash
- uses: golangci/golangci-lint-action@v4
with:
args: --timeout 8m0s
args: --timeout 8m0s -v
test:
strategy:
matrix:
@@ -284,7 +295,6 @@ jobs:
with:
go-version-file: go.mod
cache: true
- run: go get
- run: |
case ${{ matrix.arch }} in
amd64) echo ARCH=x86_64 ;;
@@ -299,10 +309,6 @@ jobs:
mkdir -p llm/build/darwin/$ARCH/stub/bin
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/build/windows/$ARCH/stub/bin
touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'windows-') }}
shell: bash
- run: go generate ./...
- run: go build

View File

@@ -18,7 +18,7 @@ ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
ARG CMAKE_VERSION
@@ -28,7 +28,7 @@ ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64
ARG CMAKE_VERSION
@@ -40,7 +40,7 @@ COPY --from=llm-code / /go/src/github.com/ollama/ollama/
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
ARG CGO_CFLAGS
ARG AMDGPU_TARGETS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
RUN mkdir /tmp/scratch && \
for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
cp ${dep} /tmp/scratch/ || exit 1 ; \
@@ -64,11 +64,11 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
ARG CMAKE_VERSION
@@ -84,7 +84,7 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
# Intermediate stage used for ./scripts/build_linux.sh

View File

@@ -35,10 +35,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
## Quickstart
To run and chat with [Llama 2](https://ollama.com/library/llama2):
To run and chat with [Llama 3](https://ollama.com/library/llama3):
```
ollama run llama2
ollama run llama3
```
## Model library
@@ -49,18 +49,14 @@ Here are some example models that can be downloaded:
| Model | Parameters | Size | Download |
| ------------------ | ---------- | ----- | ------------------------------ |
| Llama 2 | 7B | 3.8GB | `ollama run llama2` |
| Llama 3 | 8B | 4.7GB | `ollama run llama3` |
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
| Phi-3 | 3,8B | 2.3GB | `ollama run phi3` |
| Mistral | 7B | 4.1GB | `ollama run mistral` |
| Dolphin Phi | 2.7B | 1.6GB | `ollama run dolphin-phi` |
| Phi-2 | 2.7B | 1.7GB | `ollama run phi` |
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
| Starling | 7B | 4.1GB | `ollama run starling-lm` |
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| Llama 2 13B | 13B | 7.3GB | `ollama run llama2:13b` |
| Llama 2 70B | 70B | 39GB | `ollama run llama2:70b` |
| Orca Mini | 3B | 1.9GB | `ollama run orca-mini` |
| Vicuna | 7B | 3.8GB | `ollama run vicuna` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
| Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
@@ -98,16 +94,16 @@ See the [guide](docs/import.md) on importing models for more information.
### Customize a prompt
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model:
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3` model:
```
ollama pull llama2
ollama pull llama3
```
Create a `Modelfile`:
```
FROM llama2
FROM llama3
# set the temperature to 1 [higher is more creative, lower is more coherent]
PARAMETER temperature 1
@@ -142,7 +138,7 @@ ollama create mymodel -f ./Modelfile
### Pull a model
```
ollama pull llama2
ollama pull llama3
```
> This command can also be used to update a local model. Only the diff will be pulled.
@@ -150,13 +146,13 @@ ollama pull llama2
### Remove a model
```
ollama rm llama2
ollama rm llama3
```
### Copy a model
```
ollama cp llama2 my-llama2
ollama cp llama3 my-model
```
### Multiline input
@@ -180,7 +176,7 @@ The image features a yellow smiley face, which is likely the central focus of th
### Pass in prompt as arguments
```
$ ollama run llama2 "Summarize this file: $(cat README.md)"
$ ollama run llama3 "Summarize this file: $(cat README.md)"
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
```
@@ -227,7 +223,7 @@ Next, start the server:
Finally, in a separate shell, run a model:
```
./ollama run llama2
./ollama run llama3
```
## REST API
@@ -238,7 +234,7 @@ Ollama has a REST API for running and managing models.
```
curl http://localhost:11434/api/generate -d '{
"model": "llama2",
"model": "llama3",
"prompt":"Why is the sky blue?"
}'
```
@@ -247,7 +243,7 @@ curl http://localhost:11434/api/generate -d '{
```
curl http://localhost:11434/api/chat -d '{
"model": "mistral",
"model": "llama3",
"messages": [
{ "role": "user", "content": "why is the sky blue?" }
]
@@ -260,16 +256,17 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Open WebUI](https://github.com/open-webui/open-webui)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
- [LibreChat](https://github.com/danny-avila/LibreChat)
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
@@ -291,9 +288,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [QA-Pilot: Chat with Code Repository](https://github.com/reid41/QA-Pilot)
- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
- [chat: chat web app for teams](https://github.com/swuecho/chat)
- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
- [Ollama RAG Chatbot: Local Chat with multiples PDFs using Ollama and RAG.](https://github.com/datvodinh/rag-chatbot.git)
### Terminal
@@ -309,11 +310,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
- [cmdh](https://github.com/pgibler/cmdh)
- [ooo](https://github.com/npahlfer/ooo)
- [shell-pilot](https://github.com/reid41/shell-pilot)
- [tenere](https://github.com/pythops/tenere)
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
- [typechat-cli](https://github.com/anaisbetts/typechat-cli)
- [ShellOracle](https://github.com/djcopley/ShellOracle)
- [tlm](https://github.com/yusufcanb/tlm)
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
### Database
@@ -378,3 +381,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
### Supported backends
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.

View File

@@ -91,6 +91,13 @@ func ClientFromEnvironment() (*Client, error) {
}, nil
}
func NewClient(base *url.URL, http *http.Client) *Client {
return &Client{
base: base,
http: http,
}
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
var reqBody io.Reader
var data []byte

View File

@@ -2,6 +2,7 @@ package api
import (
"encoding/json"
"errors"
"fmt"
"math"
"os"
@@ -307,7 +308,7 @@ func (m *Metrics) Summary() {
}
}
var ErrInvalidOpts = fmt.Errorf("invalid options")
var ErrInvalidOpts = errors.New("invalid options")
func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
@@ -395,8 +396,10 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
func DefaultOptions() Options {
return Options{
// options set on request to runner
NumPredict: -1,
NumKeep: 0,
NumPredict: -1,
// set a minimal num_keep to avoid issues on context shifts
NumKeep: 4,
Temperature: 0.8,
TopK: 40,
TopP: 0.9,

View File

@@ -88,15 +88,12 @@ DialogFontSize=12
[Files]
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windeps\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-amd64\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-amd64\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
; Assumes v5.7, may need adjustments for v6
#if GetEnv("HIP_PATH") != ""
Source: "{#GetEnv('HIP_PATH')}\bin\hipblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
Source: "{#GetEnv('HIP_PATH')}\bin\rocblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
; amdhip64.dll dependency comes from the driver and must be installed already
Source: "{#GetEnv('HIP_PATH')}\bin\rocblas\library\*"; DestDir: "{app}\rocm\rocblas\library\"; Flags: ignoreversion
#if DirExists("..\dist\windows-amd64\rocm")
Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
#endif
@@ -132,7 +129,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi
;FinishedHeadingLabel=Run your first model
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama2
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3
;ClickFinish=%n
[Registry]

View File

@@ -17,6 +17,7 @@ import (
"os"
"os/signal"
"path/filepath"
"regexp"
"runtime"
"strings"
"syscall"
@@ -53,8 +54,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
modelfile, err := os.ReadFile(filename)
if err != nil {
return err
@@ -95,95 +94,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
// TODO make this work w/ adapters
if fi.IsDir() {
tf, err := os.CreateTemp("", "ollama-tf")
// this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
tempfile, err := tempZipFiles(path)
if err != nil {
return err
}
defer os.RemoveAll(tf.Name())
defer os.RemoveAll(tempfile)
zf := zip.NewWriter(tf)
files := []string{}
tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
if err != nil {
return err
} else if len(tfiles) == 0 {
tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
if err != nil {
return err
}
}
files = append(files, tfiles...)
if len(files) == 0 {
return fmt.Errorf("no models were found in '%s'", path)
}
// add the safetensor/torch config file + tokenizer
files = append(files, filepath.Join(path, "config.json"))
files = append(files, filepath.Join(path, "params.json"))
files = append(files, filepath.Join(path, "added_tokens.json"))
files = append(files, filepath.Join(path, "tokenizer.model"))
for _, fn := range files {
f, err := os.Open(fn)
// just skip whatever files aren't there
if os.IsNotExist(err) {
if strings.HasSuffix(fn, "tokenizer.model") {
// try the parent dir before giving up
parentDir := filepath.Dir(path)
newFn := filepath.Join(parentDir, "tokenizer.model")
f, err = os.Open(newFn)
if os.IsNotExist(err) {
continue
} else if err != nil {
return err
}
} else {
continue
}
} else if err != nil {
return err
}
fi, err := f.Stat()
if err != nil {
return err
}
h, err := zip.FileInfoHeader(fi)
if err != nil {
return err
}
h.Name = filepath.Base(fn)
h.Method = zip.Store
w, err := zf.CreateHeader(h)
if err != nil {
return err
}
_, err = io.Copy(w, f)
if err != nil {
return err
}
}
if err := zf.Close(); err != nil {
return err
}
if err := tf.Close(); err != nil {
return err
}
path = tf.Name()
path = tempfile
}
digest, err := createBlob(cmd, client, path)
@@ -191,10 +111,17 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
name := c.Name
if c.Name == "model" {
name = "from"
}
re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
}
}
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
spinner.Stop()
@@ -228,6 +155,114 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}
func tempZipFiles(path string) (string, error) {
tempfile, err := os.CreateTemp("", "ollama-tf")
if err != nil {
return "", err
}
defer tempfile.Close()
zipfile := zip.NewWriter(tempfile)
defer zipfile.Close()
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
}
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
for _, safetensor := range matches {
if ct, err := detectContentType(safetensor); err != nil {
return nil, err
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
}
}
return matches, nil
}
var files []string
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else {
return "", errors.New("no safetensors or torch files found")
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
if err != nil {
return "", err
}
files = append(files, js...)
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
for _, file := range files {
f, err := os.Open(file)
if err != nil {
return "", err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return "", err
}
zfi, err := zip.FileInfoHeader(fi)
if err != nil {
return "", err
}
zf, err := zipfile.CreateHeader(zfi)
if err != nil {
return "", err
}
if _, err := io.Copy(zf, f); err != nil {
return "", err
}
}
return tempfile.Name(), nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
bin, err := os.Open(path)
if err != nil {

View File

@@ -18,19 +18,23 @@ import (
)
type Params struct {
Architectures []string `json:"architectures"`
VocabSize int `json:"vocab_size"`
HiddenSize int `json:"hidden_size"` // n_embd
HiddenLayers int `json:"num_hidden_layers"` // n_layer
ContextSize int `json:"max_position_embeddings"`
IntermediateSize int `json:"intermediate_size"`
AttentionHeads int `json:"num_attention_heads"` // n_head
KeyValHeads int `json:"num_key_value_heads"`
NormEPS float64 `json:"rms_norm_eps"`
BoSTokenID int `json:"bos_token_id"`
EoSTokenID int `json:"eos_token_id"`
HeadDimension int `json:"head_dim"`
PaddingTokenID int `json:"pad_token_id"`
Architectures []string `json:"architectures"`
VocabSize int `json:"vocab_size"`
HiddenSize int `json:"hidden_size"` // n_embd
HiddenLayers int `json:"num_hidden_layers"` // n_layer
ContextSize int `json:"max_position_embeddings"`
IntermediateSize int `json:"intermediate_size"`
AttentionHeads int `json:"num_attention_heads"` // n_head
KeyValHeads int `json:"num_key_value_heads"`
NormEPS float64 `json:"rms_norm_eps"`
BoSTokenID int `json:"bos_token_id"`
EoSTokenID int `json:"eos_token_id"`
HeadDimension int `json:"head_dim"`
PaddingTokenID int `json:"pad_token_id"`
RopeFrequencyBase float64 `json:"rope_theta"`
Experts int `json:"num_local_experts"`
ExpertsUsed int `json:"num_experts_per_tok"`
ByteOrder
}

96
convert/mixtral.go Normal file
View File

@@ -0,0 +1,96 @@
package convert
import (
"os"
"regexp"
"github.com/ollama/ollama/llm"
)
type MixtralModel struct {
ModelData
}
func (m *MixtralModel) GetTensors() error {
t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil {
return err
}
m.Tensors = []llm.Tensor{}
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
return err
}
for _, l := range t {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
}
return nil
}
func (m *MixtralModel) LoadVocab() error {
v, err := LoadSentencePieceTokens(m.Path, m.Params)
if err != nil {
return err
}
m.Vocab = v
return nil
}
func (m *MixtralModel) WriteGGUF() (string, error) {
kv := llm.KV{
"general.architecture": "llama",
"general.name": m.Name,
"llama.block_count": uint32(m.Params.HiddenLayers),
"llama.context_length": uint32(m.Params.ContextSize),
"llama.embedding_length": uint32(m.Params.HiddenSize),
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"llama.expert_count": uint32(m.Params.Experts),
"llama.expert_used_count": uint32(m.Params.ExpertsUsed),
"llama.vocab_size": uint32(len(m.Vocab.Tokens)),
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.scores": m.Vocab.Scores,
"tokenizer.ggml.token_type": m.Vocab.Types,
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.unknown_token_id": uint32(0),
"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false,
}
f, err := os.CreateTemp("", "ollama-gguf")
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
return f.Name(), nil
}

View File

@@ -93,7 +93,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
}
slices.Sort(keys)
slog.Info("converting layers")
var tensors []llm.Tensor
@@ -105,7 +104,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
return nil, 0, err
}
slog.Debug(fmt.Sprintf("metadata = %#v", data))
var size uint64
var kind uint32
switch len(data.Shape) {
@@ -150,11 +148,13 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
padding: 8 + jsonSize,
}
tensors = append(tensors, t)
offset += size
tensors = append(tensors, t)
}
slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
slog.Debug(fmt.Sprintf("offset = %d", offset))
return tensors, offset, nil
}
@@ -185,15 +185,19 @@ func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
}
tMap := map[string]string{
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
"model.layers.(\\d+).block_sparse_moe.gate.weight": "blk.$1.ffn_gate_inp.weight",
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
}
v, ok := directMap[n]
@@ -286,6 +290,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
Format: m,
},
}, nil
case "MixtralForCausalLM":
return &MixtralModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
Format: m,
},
}, nil
case "GemmaForCausalLM":
return &GemmaModel{
ModelData{

View File

@@ -90,7 +90,7 @@ The final response in the stream also includes additional data about the generat
- `load_duration`: time spent in nanoseconds loading the model
- `prompt_eval_count`: number of tokens in the prompt
- `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
- `eval_count`: number of tokens the response
- `eval_count`: number of tokens in the response
- `eval_duration`: time in nanoseconds spent generating the response
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
- `response`: empty if the response was streamed, if not streamed, this will contain the full response

View File

@@ -228,3 +228,7 @@ To unload the model and free up memory use:
```shell
curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": 0}'
```
Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable uses the same parameter types as the `keep_alive` parameter types mentioned above. Refer to section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.
If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.

View File

@@ -1,38 +1,15 @@
# Running Ollama on NVIDIA Jetson Devices
With some minor configuration, Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/). The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack).
Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions.
NVIDIA Jetson devices are Linux-based embedded AI computers that are purpose-built for AI applications.
Jetsons have an integrated GPU that is wired directly to the memory controller of the machine. For this reason, the `nvidia-smi` command is unrecognized, and Ollama proceeds to operate in "CPU only"
mode. This can be verified by using a monitoring tool like jtop.
In order to address this, we simply pass the path to the Jetson's pre-installed CUDA libraries into `ollama serve` (while in a tmux session). We then hardcode the num_gpu parameters into a cloned
version of our target model.
Prerequisites:
- curl
- tmux
Here are the steps:
The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.
- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
- Stop the Ollama service: `sudo systemctl stop ollama`
- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson
'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`
- Pull the model you want to use (e.g. mistral): `ollama pull mistral`
- Create a new Modelfile specifically for enabling GPU support on the Jetson: `touch ModelfileMistralJetson`
- In the ModelfileMistralJetson file, specify the FROM model and the num_gpu PARAMETER as shown below:
```
FROM mistral
PARAMETER num_gpu 999
```
- Create a new model from your Modelfile: `ollama create mistral-jetson -f ./ModelfileMistralJetson`
- Run the new model: `ollama run mistral-jetson`
If you run a monitoring tool like jtop you should now see that Ollama is using the Jetson's integrated GPU.
- Start an interactive session: `ollama run mistral`
And that's it!
# Running Ollama in Docker
When running GPU accelerated applications in Docker, it is highly recommended to use [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).

View File

@@ -14,7 +14,7 @@ As this is a preview release, you should expect a few bugs here and there. If
you run into a problem you can reach out on
[Discord](https://discord.gg/ollama), or file an
[issue](https://github.com/ollama/ollama/issues).
Logs will often be helpful in dianosing the problem (see
Logs will often be helpful in diagnosing the problem (see
[Troubleshooting](#troubleshooting) below)
## System Requirements

View File

@@ -15,6 +15,7 @@ const (
KibiByte = Byte * 1024
MebiByte = KibiByte * 1024
GibiByte = MebiByte * 1024
)
func HumanBytes(b int64) string {

View File

@@ -7,7 +7,7 @@ import (
"log/slog"
"os"
"path/filepath"
"strconv"
"runtime"
"strings"
)
@@ -35,22 +35,64 @@ func GetSupportedGFX(libDir string) ([]string, error) {
return ret, nil
}
func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
// Set the visible devices if not already set
// TODO - does sort order matter?
devices := []string{}
for i := range ids {
if _, skipped := skip[i]; skipped {
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "rocm" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
continue
}
devices = append(devices, strconv.Itoa(i))
ids = append(ids, info.ID)
}
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
}
func commonAMDValidateLibDir() (string, error) {
// We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// This gives users a more recovery options if versions have subtle problems at runtime
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
}
val := strings.Join(devices, ",")
err := os.Setenv("HIP_VISIBLE_DEVICES", val)
if err != nil {
slog.Warn(fmt.Sprintf("failed to set env: %s", err))
} else {
slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
// Scan the LD_LIBRARY_PATH or PATH
pathEnv := "LD_LIBRARY_PATH"
if runtime.GOOS == "windows" {
pathEnv = "PATH"
}
paths := os.Getenv(pathEnv)
for _, path := range filepath.SplitList(paths) {
d, err := filepath.Abs(path)
if err != nil {
continue
}
if rocmLibUsable(d) {
return d, nil
}
}
// Well known location(s)
if rocmLibUsable(RocmStandardLocation) {
return RocmStandardLocation, nil
}
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}

View File

@@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
func (hl *HipLib) Release() {
err := windows.FreeLibrary(hl.dll)
if err != nil {
slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
slog.Warn("failed to unload amdhip64.dll", "error", err)
}
hl.dll = 0
}
@@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
return 0
}
if status != hipSuccess {
slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
}
return count
}

View File

@@ -11,6 +11,8 @@ import (
"slices"
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
// Discovery logic for AMD/ROCm GPUs
@@ -24,9 +26,6 @@ const (
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
RocmStandardLocation = "/opt/rocm/lib"
// TODO find a better way to detect iGPU instead of minimum memory
IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
)
var (
@@ -35,14 +34,11 @@ var (
)
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
// and the user hasn't already set this variable
func AMDGetGPUInfo(resp *GpuInfo) {
// TODO - DRY this out with windows
func AMDGetGPUInfo() []GpuInfo {
resp := []GpuInfo{}
if !AMDDetected() {
return
return resp
}
skip := map[int]interface{}{}
// Opportunistic logging of driver version to aid in troubleshooting
ver, err := AMDDriverVersion()
@@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
slog.Info("AMD Driver: " + ver)
} else {
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
}
// If the user has specified exactly which GPUs to use, look up their memory
visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
if visibleDevices != "" {
ids := []int{}
for _, idStr := range strings.Split(visibleDevices, ",") {
id, err := strconv.Atoi(idStr)
if err != nil {
slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
} else {
ids = append(ids, id)
}
}
amdProcMemLookup(resp, nil, ids)
return
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
var visibleDevices []string
hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
switch {
// TODO is this priorty order right?
case hipVD != "":
visibleDevices = strings.Split(hipVD, ",")
case rocrVD != "":
visibleDevices = strings.Split(rocrVD, ",")
// TODO - since we don't yet support UUIDs, consider detecting and reporting here
// all our test systems show GPU-XX indicating UUID is not supported
case gpuDO != "":
visibleDevices = strings.Split(gpuDO, ",")
}
// Gather GFX version information from all detected cards
gfx := AMDGFXVersions()
verStrings := []string{}
for i, v := range gfx {
verStrings = append(verStrings, v.ToGFXString())
if v.Major == 0 {
// Silently skip CPUs
skip[i] = struct{}{}
continue
}
if v.Major < 9 {
// TODO consider this a build-time setting if we can support 8xx family GPUs
slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
skip[i] = struct{}{}
}
}
slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
// Abort if all GPUs are skipped
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
libDir, err := AMDValidateLibDir()
if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
return
}
updateLibPath(libDir)
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
if gfxOverride == "" {
supported, err := GetSupportedGFX(libDir)
var supported []string
libDir := ""
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
cpuCount := 0
for _, match := range matches {
slog.Debug("evaluating amdgpu node " + match)
fp, err := os.Open(match)
if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
return
}
slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
for i, v := range gfx {
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
}
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
ids := make([]int, len(gfx))
i := 0
for k := range gfx {
ids[i] = k
i++
}
amdProcMemLookup(resp, skip, ids)
if resp.memInfo.DeviceCount == 0 {
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
}
func updateLibPath(libDir string) {
ldPaths := []string{}
if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
ldPaths = strings.Split(val, ":")
}
for _, d := range ldPaths {
if d == libDir {
return
}
}
val := strings.Join(append(ldPaths, libDir), ":")
slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
os.Setenv("LD_LIBRARY_PATH", val)
}
// Walk the sysfs nodes for the available GPUs and gather information from them
// skipping over any devices in the skip map
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
slog.Debug("discovering VRAM for amdgpu devices")
if len(ids) == 0 {
entries, err := os.ReadDir(AMDNodesSysfsDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
return
}
for _, node := range entries {
if !node.IsDir() {
continue
}
id, err := strconv.Atoi(node.Name())
if err != nil {
slog.Warn("malformed amdgpu sysfs node id " + node.Name())
continue
}
ids = append(ids, id)
}
}
slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
for _, id := range ids {
if _, skipped := skip[id]; skipped {
slog.Debug("failed to open sysfs node", "file", match, "error", err)
continue
}
defer fp.Close()
nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil {
slog.Debug("failed to parse node ID", "error", err)
continue
}
scanner := bufio.NewScanner(fp)
isCPU := false
var major, minor, patch uint64
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
// Detect CPUs
if len(ver) == 2 && ver[1] == "0" {
slog.Debug("detected CPU " + match)
isCPU = true
break
}
if len(ver) != 2 || len(ver[1]) < 5 {
slog.Warn("malformed "+match, "gfx_target_version", line)
// If this winds up being a CPU, our offsets may be wrong
continue
}
l := len(ver[1])
var err1, err2, err3 error
patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
}
// TODO - any other properties we want to extract and record?
// vendor_id + device_id -> pci lookup for "Name"
// Other metrics that may help us understand relative performance between multiple GPUs
}
if isCPU {
cpuCount++
continue
}
// CPUs are always first in the list
gpuID := nodeID - cpuCount
// Shouldn't happen, but just in case...
if gpuID < 0 {
slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
return []GpuInfo{}
}
if int(major) < RocmComputeMin {
slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%x", major, minor, patch), "gpu", gpuID)
continue
}
// Look up the memory for the current node
totalMemory := uint64(0)
usedMemory := uint64(0)
// Adjust for sysfs vs HIP ids
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
propFiles, err := filepath.Glob(propGlob)
if err != nil {
slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
}
// 1 or more memory banks - sum the values of all of them
for _, propFile := range propFiles {
fp, err := os.Open(propFile)
if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
continue
}
defer fp.Close()
@@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
}
}
if totalMemory == 0 {
slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
skip[id] = struct{}{}
slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
continue
}
if totalMemory < IGPUMemLimit {
slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
skip[id] = struct{}{}
continue
}
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
usedFiles, err := filepath.Glob(usedGlob)
if err != nil {
slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
continue
}
for _, usedFile := range usedFiles {
fp, err := os.Open(usedFile)
if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
continue
}
defer fp.Close()
data, err := io.ReadAll(fp)
if err != nil {
slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
continue
}
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
if err != nil {
slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
slog.Warn("malformed used memory", "data", string(data), "error", err)
continue
}
usedMemory += used
}
slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024))
resp.memInfo.DeviceCount++
resp.memInfo.TotalMemory += totalMemory
resp.memInfo.FreeMemory += (totalMemory - usedMemory)
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
if totalMemory < IGPUMemLimit {
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
continue
}
slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
gpuInfo := GpuInfo{
Library: "rocm",
memInfo: memInfo{
TotalMemory: totalMemory,
FreeMemory: (totalMemory - usedMemory),
},
ID: fmt.Sprintf("%d", gpuID),
// Name: not exposed in sysfs directly, would require pci device id lookup
Major: int(major),
Minor: int(minor),
Patch: int(patch),
MinimumMemory: rocmMinimumMemory,
}
// If the user wants to filter to a subset of devices, filter out if we aren't a match
if len(visibleDevices) > 0 {
include := false
for _, visible := range visibleDevices {
if visible == gpuInfo.ID {
include = true
break
}
}
if !include {
slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
continue
}
}
// Final validation is gfx compatibility - load the library if we haven't already loaded it
// even if the user overrides, we still need to validate the library
if libDir == "" {
libDir, err = AMDValidateLibDir()
if err != nil {
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
return []GpuInfo{}
}
}
gpuInfo.DependencyPath = libDir
if gfxOverride == "" {
// Only load supported list once
if len(supported) == 0 {
supported, err = GetSupportedGFX(libDir)
if err != nil {
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
return []GpuInfo{}
}
slog.Debug("rocm supported GPUs", "types", supported)
}
gfx := fmt.Sprintf("gfx%d%d%x", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
continue
} else {
slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
// The GPU has passed all the verification steps and is supported
resp = append(resp, gpuInfo)
}
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm"
if len(resp) == 0 {
slog.Info("no compatible amdgpu devices detected")
}
return resp
}
// Quick check for AMD driver so we can skip amdgpu discovery if not present
@@ -280,87 +297,24 @@ func AMDDetected() bool {
slog.Debug("amdgpu driver not detected " + sysfsDir)
return false
} else if err != nil {
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
return false
}
return true
}
func setupLink(source, target string) error {
if err := os.RemoveAll(target); err != nil {
return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
}
if err := os.Symlink(source, target); err != nil {
return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
}
slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
return nil
}
// Ensure the AMD rocm lib dir is wired up
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
// failing that, tell the user how to download it on their own
func AMDValidateLibDir() (string, error) {
// We rely on the rpath compiled into our library to find rocm
// so we establish a symlink to wherever we find it on the system
// to <payloads>/rocm
payloadsDir, err := PayloadsDir()
if err != nil {
return "", err
}
// If we already have a rocm dependency wired, nothing more to do
rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
if rocmLibUsable(rocmTargetDir) {
return rocmTargetDir, nil
}
// next to the running binary
exe, err := os.Executable()
libDir, err := commonAMDValidateLibDir()
if err == nil {
peerDir := filepath.Dir(exe)
if rocmLibUsable(peerDir) {
slog.Debug("detected ROCM next to ollama executable " + peerDir)
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
}
peerDir = filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(peerDir) {
slog.Debug("detected ROCM next to ollama executable " + peerDir)
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
}
return libDir, nil
}
// Well known ollama installer path
installedRocmDir := "/usr/share/ollama/lib/rocm"
if rocmLibUsable(installedRocmDir) {
return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "lib")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
}
}
// Scan the library path for potential matches
ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
for _, ldPath := range ldPaths {
d, err := filepath.Abs(ldPath)
if err != nil {
continue
}
if rocmLibUsable(d) {
return rocmTargetDir, setupLink(d, rocmTargetDir)
}
}
// Well known location(s)
if rocmLibUsable("/opt/rocm/lib") {
return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
return installedRocmDir, nil
}
// If we still haven't found a usable rocm, the user will have to install it on their own
@@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
}
return strings.TrimSpace(string(verString)), nil
}
func AMDGFXVersions() map[int]Version {
// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
res := map[int]Version{}
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
for _, match := range matches {
fp, err := os.Open(match)
if err != nil {
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
continue
}
defer fp.Close()
i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil {
slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
continue
}
if i == 0 {
// Skipping the CPU
continue
}
// Align with HIP IDs (zero is first GPU, not CPU)
i -= 1
scanner := bufio.NewScanner(fp)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
if len(ver) != 2 || len(ver[1]) < 5 {
if ver[1] != "0" {
slog.Debug("malformed " + line)
}
res[i] = Version{
Major: 0,
Minor: 0,
Patch: 0,
}
continue
}
l := len(ver[1])
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
res[i] = Version{
Major: uint(major),
Minor: uint(minor),
Patch: uint(patch),
}
}
}
}
return res
}
func (v Version) ToGFXString() string {
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
}

View File

@@ -7,7 +7,10 @@ import (
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
const (
@@ -22,36 +25,32 @@ var (
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
)
func AMDGetGPUInfo(resp *GpuInfo) {
func AMDGetGPUInfo() []GpuInfo {
resp := []GpuInfo{}
hl, err := NewHipLib()
if err != nil {
slog.Debug(err.Error())
return
return nil
}
defer hl.Release()
skip := map[int]interface{}{}
ids := []int{}
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
ver, err := hl.AMDDriverVersion()
if err == nil {
slog.Info("AMD Driver: " + ver)
} else {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
slog.Debug("error looking up amd driver version", "error", err)
}
// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount()
if count == 0 {
return
return nil
}
libDir, err := AMDValidateLibDir()
if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
return
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
return nil
}
var supported []string
@@ -59,95 +58,120 @@ func AMDGetGPUInfo(resp *GpuInfo) {
if gfxOverride == "" {
supported, err = GetSupportedGFX(libDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
return
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
return nil
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
slog.Info(fmt.Sprintf("detected %d hip devices", count))
slog.Info("detected hip devices", "count", count)
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
for i := 0; i < count; i++ {
ids = append(ids, i)
err = hl.HipSetDevice(i)
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
skip[i] = struct{}{}
slog.Warn("set device", "id", i, "error", err)
continue
}
props, err := hl.HipGetDeviceProperties(i)
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
skip[i] = struct{}{}
slog.Warn("get properties", "id", i, "error", err)
continue
}
n := bytes.IndexByte(props.Name[:], 0)
name := string(props.Name[:n])
slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
// TODO is UUID actually populated on windows?
// Can luid be used on windows for setting visible devices (and is it actually set?)
n = bytes.IndexByte(props.GcnArchName[:], 0)
gfx := string(props.GcnArchName[:n])
slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
var major, minor, patch string
switch len(gfx) {
case 6:
major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
case 7:
major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
}
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
// TODO Why isn't props.iGPU accurate!?
if strings.EqualFold(name, iGPUName) {
slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
skip[i] = struct{}{}
slog.Info("iGPU detected skipping", "id", i)
continue
}
if gfxOverride == "" {
if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
continue
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
}
}
totalMemory, freeMemory, err := hl.HipMemGetInfo()
freeMemory, totalMemory, err := hl.HipMemGetInfo()
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
slog.Warn("get mem info", "id", i, "error", err)
continue
}
// TODO according to docs, freeMem may lie on windows!
slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory))
resp.memInfo.DeviceCount++
resp.memInfo.TotalMemory += totalMemory
resp.memInfo.FreeMemory += freeMemory
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
if totalMemory < IGPUMemLimit {
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
continue
}
// TODO revisit this once ROCm v6 is available on windows.
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
gpuInfo := GpuInfo{
Library: "rocm",
memInfo: memInfo{
TotalMemory: totalMemory,
FreeMemory: freeMemory,
},
ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
DependencyPath: libDir,
MinimumMemory: rocmMinimumMemory,
}
if major != "" {
gpuInfo.Major, err = strconv.Atoi(major)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if minor != "" {
gpuInfo.Minor, err = strconv.Atoi(minor)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if patch != "" {
// Patch rev is hex; e.g. gfx90a
p, err := strconv.ParseInt(patch, 16, 0)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
} else {
gpuInfo.Patch = int(p)
}
}
if gpuInfo.Major < RocmComputeMin {
slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%x", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
continue
}
resp = append(resp, gpuInfo)
}
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm"
}
// Abort if all GPUs are skipped
if len(skip) >= count {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
UpdatePath(libDir)
return resp
}
func AMDValidateLibDir() (string, error) {
// On windows non-admins typically can't create links
// so instead of trying to rely on rpath and a link in
// $LibDir/rocm, we instead rely on setting PATH to point
// to the location of the ROCm library
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
libDir, err := commonAMDValidateLibDir()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
return libDir, nil
}
// Installer payload (if we're running from some other location)
@@ -159,21 +183,6 @@ func AMDValidateLibDir() (string, error) {
return rocmTargetDir, nil
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
}
// Well known location(s)
if rocmLibUsable(RocmStandardLocation) {
return RocmStandardLocation, nil
}
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")

View File

@@ -24,6 +24,51 @@ func PayloadsDir() (string, error) {
defer lock.Unlock()
var err error
if payloadsDir == "" {
runnersDir := os.Getenv("OLLAMA_RUNNERS_DIR")
// On Windows we do not carry the payloads inside the main executable
if runtime.GOOS == "windows" && runnersDir == "" {
appExe, err := os.Executable()
if err != nil {
slog.Error("failed to lookup executable path", "error", err)
return "", err
}
cwd, err := os.Getwd()
if err != nil {
slog.Error("failed to lookup working directory", "error", err)
return "", err
}
var paths []string
for _, root := range []string{appExe, cwd} {
paths = append(paths,
filepath.Join(root),
filepath.Join(root, "windows-"+runtime.GOARCH),
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
)
}
// Try a few variations to improve developer experience when building from source in the local tree
for _, p := range paths {
candidate := filepath.Join(p, "ollama_runners")
_, err := os.Stat(candidate)
if err == nil {
runnersDir = candidate
break
}
}
if runnersDir == "" {
err = fmt.Errorf("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
slog.Error("incomplete distribution", "error", err)
return "", err
}
}
if runnersDir != "" {
payloadsDir = runnersDir
return payloadsDir, nil
}
// The remainder only applies on non-windows where we still carry payloads in the main executable
cleanupTmpDirs()
tmpDir := os.Getenv("OLLAMA_TMPDIR")
if tmpDir == "" {
@@ -80,7 +125,7 @@ func cleanupTmpDirs() {
}
err = os.RemoveAll(d)
if err != nil {
slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err))
slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
}
}
}
@@ -88,7 +133,8 @@ func cleanupTmpDirs() {
func Cleanup() {
lock.Lock()
defer lock.Unlock()
if payloadsDir != "" {
runnersDir := os.Getenv("OLLAMA_RUNNERS_DIR")
if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
// We want to fully clean up the tmpdir parent of the payloads dir
tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
slog.Debug("cleaning up", "dir", tmpDir)
@@ -120,7 +166,7 @@ func UpdatePath(dir string) {
}
}
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
slog.Info("updating", "PATH", newPath)
os.Setenv("PATH", newPath)
}
// linux and darwin rely on rpath

22
gpu/cuda_common.go Normal file
View File

@@ -0,0 +1,22 @@
//go:build linux || windows
package gpu
import (
"log/slog"
"strings"
)
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "cuda" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
continue
}
ids = append(ids, info.ID)
}
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
}

View File

@@ -16,7 +16,6 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"unsafe"
@@ -25,8 +24,8 @@ import (
)
type handles struct {
nvml *C.nvml_handle_t
cudart *C.cudart_handle_t
deviceCount int
cudart *C.cudart_handle_t
}
const (
@@ -39,26 +38,10 @@ var gpuMutex sync.Mutex
// With our current CUDA compile flags, older than 5.0 will not work properly
var CudaComputeMin = [2]C.int{5, 0}
// Possible locations for the nvidia-ml library
var NvmlLinuxGlobs = []string{
"/usr/local/cuda/lib64/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
"/usr/lib/wsl/lib/libnvidia-ml.so*",
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
"/opt/cuda/lib64/libnvidia-ml.so*",
"/usr/lib*/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
"/usr/local/lib*/libnvidia-ml.so*",
var RocmComputeMin = 9
// TODO: are these stubs ever valid?
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
}
var NvmlWindowsGlobs = []string{
"c:\\Windows\\System32\\nvml.dll",
}
// TODO find a better way to detect iGPU instead of minimum memory
const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
var CudartLinuxGlobs = []string{
"/usr/local/cuda/lib64/libcudart.so*",
@@ -88,26 +71,18 @@ func initGPUHandles() *handles {
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
gpuHandles := &handles{nil, nil}
var nvmlMgmtName string
var nvmlMgmtPatterns []string
gpuHandles := &handles{}
var cudartMgmtName string
var cudartMgmtPatterns []string
tmpDir, _ := PayloadsDir()
switch runtime.GOOS {
case "windows":
nvmlMgmtName = "nvml.dll"
nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
cudartMgmtName = "cudart64_*.dll"
localAppData := os.Getenv("LOCALAPPDATA")
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
case "linux":
nvmlMgmtName = "libnvidia-ml.so"
nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
cudartMgmtName = "libcudart.so*"
if tmpDir != "" {
// TODO - add "payloads" for subprocess
@@ -118,31 +93,21 @@ func initGPUHandles() *handles {
return gpuHandles
}
slog.Info("Detecting GPU type")
slog.Info("Detecting GPUs")
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
if len(cudartLibPaths) > 0 {
cudart := LoadCUDARTMgmt(cudartLibPaths)
deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
if cudart != nil {
slog.Info("Nvidia GPU detected via cudart")
slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
gpuHandles.cudart = cudart
return gpuHandles
}
}
// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
if len(nvmlLibPaths) > 0 {
nvml := LoadNVMLMgmt(nvmlLibPaths)
if nvml != nil {
slog.Info("Nvidia GPU detected via nvidia-ml")
gpuHandles.nvml = nvml
gpuHandles.deviceCount = deviceCount
return gpuHandles
}
}
return gpuHandles
}
func GetGPUInfo() GpuInfo {
func GetGPUInfo() GpuInfoList {
// TODO - consider exploring lspci (and equivalent on windows) to check for
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
gpuMutex.Lock()
@@ -150,9 +115,6 @@ func GetGPUInfo() GpuInfo {
gpuHandles := initGPUHandles()
defer func() {
if gpuHandles.nvml != nil {
C.nvml_release(*gpuHandles.nvml)
}
if gpuHandles.cudart != nil {
C.cudart_release(*gpuHandles.cudart)
}
@@ -165,72 +127,63 @@ func GetGPUInfo() GpuInfo {
}
var memInfo C.mem_info_t
resp := GpuInfo{}
if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.nvml_check_vram(*gpuHandles.nvml, &memInfo)
resp := []GpuInfo{}
// NVIDIA first
for i := 0; i < gpuHandles.deviceCount; i++ {
// TODO once we support CPU compilation variants of GPU libraries refine this...
if cpuVariant == "" && runtime.GOARCH == "amd64" {
continue
}
gpuInfo := GpuInfo{
Library: "cuda",
}
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err)))
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.count > 0 {
// Verify minimum compute capability
var cc C.nvml_compute_capability_t
C.nvml_compute_capability(*gpuHandles.nvml, &cc)
if cc.err != nil {
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
continue
}
} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.cudart_check_vram(*gpuHandles.cudart, &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.count > 0 {
// Verify minimum compute capability
var cc C.cudart_compute_capability_t
C.cudart_compute_capability(*gpuHandles.cudart, &cc)
if cc.err != nil {
slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
continue
}
} else {
AMDGetGPUInfo(&resp)
if resp.Library != "" {
resp.MinimumMemory = rocmMinimumMemory
return resp
}
}
if resp.Library == "" {
C.cpu_check_ram(&memInfo)
resp.Library = "cpu"
resp.Variant = cpuVariant
}
if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
return resp
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
gpuInfo.Major = int(memInfo.major)
gpuInfo.Minor = int(memInfo.minor)
gpuInfo.MinimumMemory = cudaMinimumMemory
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
resp = append(resp, gpuInfo)
}
// Then AMD
resp = append(resp, AMDGetGPUInfo()...)
if len(resp) == 0 {
C.cpu_check_ram(&memInfo)
if memInfo.err != nil {
slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
return resp
}
gpuInfo := GpuInfo{
Library: "cpu",
Variant: cpuVariant,
}
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
resp = append(resp, gpuInfo)
}
resp.DeviceCount = uint32(memInfo.count)
resp.FreeMemory = uint64(memInfo.free)
resp.TotalMemory = uint64(memInfo.total)
return resp
}
func getCPUMem() (memInfo, error) {
func GetCPUMem() (memInfo, error) {
var ret memInfo
var info C.mem_info_t
C.cpu_check_ram(&info)
@@ -243,29 +196,11 @@ func getCPUMem() (memInfo, error) {
return ret, nil
}
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return uint64(avail), nil
}
gpuInfo := GetGPUInfo()
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
return gpuInfo.FreeMemory, nil
}
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
}
func FindGPULibs(baseLibName string, patterns []string) []string {
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
var ldPaths []string
gpuLibPaths := []string{}
slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
slog.Debug("Searching for GPU library", "name", baseLibName)
switch runtime.GOOS {
case "windows":
@@ -283,7 +218,7 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
}
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
}
slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
slog.Debug("gpu library search", "globs", patterns)
for _, pattern := range patterns {
// Ignore glob discovery errors
matches, _ := filepath.Glob(pattern)
@@ -311,28 +246,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
}
}
}
slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
return gpuLibPaths
}
func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t {
var resp C.nvml_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range nvmlLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.nvml_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.ch
}
}
return nil
}
func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
var resp C.cudart_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range cudartLibPaths {
@@ -340,13 +258,13 @@ func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
defer C.free(unsafe.Pointer(lib))
C.cudart_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err)))
slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.ch
return int(resp.num_devices), &resp.ch, libPath
}
}
return nil
return 0, nil, ""
}
func getVerboseState() C.uint16_t {
@@ -355,3 +273,22 @@ func getVerboseState() C.uint16_t {
}
return C.uint16_t(0)
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variable
//
// If different libraries are detected, the first one is what we use
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
if len(l) == 0 {
return "", ""
}
switch l[0].Library {
case "cuda":
return cudaGetVisibleDevicesEnv(l)
case "rocm":
return rocmGetVisibleDevicesEnv(l)
default:
slog.Debug("no filter required for library " + l[0].Library)
return "", ""
}
}

View File

@@ -9,52 +9,41 @@ package gpu
*/
import "C"
import (
"fmt"
"log/slog"
"os"
"runtime"
"strconv"
)
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return uint64(avail), nil
}
func GetGPUInfo() GpuInfoList {
mem, _ := GetCPUMem()
if runtime.GOARCH == "amd64" {
// gpu not supported, this may not be metal
return 0, nil
}
return uint64(C.getRecommendedMaxVRAM()), nil
}
func GetGPUInfo() GpuInfo {
mem, _ := getCPUMem()
if runtime.GOARCH == "amd64" {
return GpuInfo{
Library: "cpu",
Variant: GetCPUVariant(),
memInfo: mem,
return []GpuInfo{
{
Library: "cpu",
Variant: GetCPUVariant(),
memInfo: mem,
},
}
}
return GpuInfo{
info := GpuInfo{
Library: "metal",
memInfo: mem,
ID: "0",
}
info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
info.FreeMemory = info.TotalMemory
info.MinimumMemory = 0
return []GpuInfo{info}
}
func getCPUMem() (memInfo, error) {
func GetCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: 0,
DeviceCount: 1,
}, nil
}
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
// No-op on darwin
return "", ""
}

View File

@@ -38,12 +38,17 @@
extern "C" {
#endif
#define GPU_ID_LEN 64
typedef struct mem_info {
char *err; // If non-nill, caller responsible for freeing
char gpu_id[GPU_ID_LEN];
uint64_t total;
uint64_t free;
unsigned int count;
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
char *err; // If non-nill, caller responsible for freeing
// Compute Capability
int major;
int minor;
} mem_info_t;
void cpu_check_ram(mem_info_t *resp);
@@ -52,7 +57,6 @@ void cpu_check_ram(mem_info_t *resp);
}
#endif
#include "gpu_info_nvml.h"
#include "gpu_info_cudart.h"
#endif // __GPU_INFO_H__

View File

@@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
MEMORYSTATUSEX info;
info.dwLength = sizeof(info);
if (GlobalMemoryStatusEx(&info) != 0) {
resp->count = 1;
resp->total = info.ullTotalPhys;
resp->free = info.ullAvailPhys;
resp->major = 0;
resp->minor = 0;
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
} else {
resp->err = LOAD_ERR();
}
@@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
if (sysinfo(&info) != 0) {
resp->err = strdup(strerror(errno));
} else {
resp->count = 1;
resp->total = info.totalram * info.mem_unit;
resp->free = info.freeram * info.mem_unit;
resp->major = 0;
resp->minor = 0;
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
}
return;
}

View File

@@ -6,6 +6,7 @@
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
cudartReturn_t ret;
resp->err = NULL;
resp->num_devices = 0;
const int buflen = 256;
char buf[buflen + 1];
int i;
@@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
{"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
{NULL, NULL},
};
@@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) {
char *msg = LOAD_ERR();
@@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return;
}
snprintf(buf, buflen, "cudart init failure: %d", ret);
@@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
}
ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
if (ret != CUDART_SUCCESS) {
LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
}
void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
resp->err = NULL;
cudartMemory_t memInfo = {0,0,0};
cudartReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("cudart handle isn't initialized");
return;
}
// cudaGetDeviceCount takes int type, resp-> count is uint
int deviceCount;
ret = (*h.cudaGetDeviceCount)(&deviceCount);
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
}
cudaDeviceProp_t props;
ret = (*h.cudaGetDeviceProperties)(&props, i);
if (ret != CUDART_SUCCESS) {
LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
resp->major = 0;
resp->minor = 0;
} else {
resp->count = (unsigned int)deviceCount;
}
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp-> count; i++) {
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
int allNull = 1;
for (int j = 0; j < 16; j++) {
if (props.uuid.bytes[j] != 0) {
allNull = 0;
break;
}
}
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
resp->err = strdup(buf);
return;
if (allNull != 0) {
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
} else {
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
props.uuid.bytes[0],
props.uuid.bytes[1],
props.uuid.bytes[2],
props.uuid.bytes[3],
props.uuid.bytes[4],
props.uuid.bytes[5],
props.uuid.bytes[6],
props.uuid.bytes[7],
props.uuid.bytes[8],
props.uuid.bytes[9],
props.uuid.bytes[10],
props.uuid.bytes[11],
props.uuid.bytes[12],
props.uuid.bytes[13],
props.uuid.bytes[14],
props.uuid.bytes[15]
);
}
resp->major = props.major;
resp->minor = props.minor;
LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total);
LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
resp->total += memInfo.total;
resp->free += memInfo.free;
// TODO add other useful properties from props
}
}
void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
resp->err = NULL;
resp->major = 0;
resp->minor = 0;
int major = 0;
int minor = 0;
cudartReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("cudart handle not initialized");
return;
}
int devices;
ret = (*h.cudaGetDeviceCount)(&devices);
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
resp->err = strdup(buf);
return;
}
for (i = 0; i < devices; i++) {
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
}
resp->total = memInfo.total;
resp->free = memInfo.free;
ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
// Report the lowest major.minor we detect as that limits our compatibility
if (resp->major == 0 || resp->major > major ) {
resp->major = major;
resp->minor = minor;
} else if ( resp->major == major && resp->minor > minor ) {
resp->minor = minor;
}
}
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
}
void cudart_release(cudart_handle_t h) {

View File

@@ -6,7 +6,8 @@
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum cudartReturn_enum {
CUDART_SUCCESS = 0,
CUDART_UNSUPPORTED = 1,
CUDA_ERROR_INVALID_VALUE = 1,
CUDA_ERROR_MEMORY_ALLOCATION = 2,
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
// Other values omitted for now...
} cudartReturn_t;
@@ -14,6 +15,11 @@ typedef enum cudartReturn_enum {
typedef enum cudartDeviceAttr_enum {
cudartDevAttrComputeCapabilityMajor = 75,
cudartDevAttrComputeCapabilityMinor = 76,
// TODO - not yet wired up but may be useful for Jetson or other
// integrated GPU scenarios with shared memory
cudaDevAttrIntegrated = 18
} cudartDeviceAttr_t;
typedef void *cudartDevice_t; // Opaque is sufficient
@@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
int minor;
} cudartDriverVersion_t;
typedef struct cudaUUID {
unsigned char bytes[16];
} cudaUUID_t;
typedef struct cudaDeviceProp {
char name[256]; /**< ASCII string identifying device */
cudaUUID_t uuid; /**< 16-byte unique identifier */
char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
size_t totalGlobalMem; /**< Global memory available on device in bytes */
size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */
int regsPerBlock; /**< 32-bit registers available per block */
int warpSize; /**< Warp size in threads */
size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */
int maxThreadsPerBlock; /**< Maximum number of threads per block */
int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */
int maxGridSize[3]; /**< Maximum size of each dimension of a grid */
int clockRate; /**< Clock frequency in kilohertz */
size_t totalConstMem; /**< Constant memory available on device in bytes */
int major; /**< Major compute capability */
int minor; /**< Minor compute capability */
size_t textureAlignment; /**< Alignment requirement for textures */
size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */
int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
int multiProcessorCount; /**< Number of multiprocessors on device */
int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */
int integrated; /**< Device is integrated as opposed to discrete */
int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
int computeMode; /**< Compute mode (See ::cudaComputeMode) */
int maxTexture1D; /**< Maximum 1D texture size */
int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */
int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
int maxTexture2D[2]; /**< Maximum 2D texture dimensions */
int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */
int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
int maxTexture3D[3]; /**< Maximum 3D texture dimensions */
int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */
int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */
int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */
int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */
int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
int maxSurface1D; /**< Maximum 1D surface size */
int maxSurface2D[2]; /**< Maximum 2D surface dimensions */
int maxSurface3D[3]; /**< Maximum 3D surface dimensions */
int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */
int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */
int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */
int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
size_t surfaceAlignment; /**< Alignment requirements for surfaces */
int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */
int ECCEnabled; /**< Device has ECC support enabled */
int pciBusID; /**< PCI bus ID of the device */
int pciDeviceID; /**< PCI device ID of the device */
int pciDomainID; /**< PCI domain ID of the device */
int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
int asyncEngineCount; /**< Number of asynchronous engines */
int unifiedAddressing; /**< Device shares a unified address space with the host */
int memoryClockRate; /**< Peak memory clock frequency in kilohertz */
int memoryBusWidth; /**< Global memory bus width in bits */
int l2CacheSize; /**< Size of L2 cache in bytes */
int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */
int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
int streamPrioritiesSupported; /**< Device supports stream priorities */
int globalL1CacheSupported; /**< Device supports caching globals in L1 */
int localL1CacheSupported; /**< Device supports caching locals in L1 */
size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */
int managedMemory; /**< Device supports allocating managed memory on this system */
int isMultiGpuBoard; /**< Device is on a multi-GPU board */
int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */
int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */
int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */
int computePreemptionSupported; /**< Device supports Compute Preemption */
int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */
int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */
} cudaDeviceProp_t;
typedef struct cudart_handle {
void *handle;
uint16_t verbose;
@@ -38,23 +130,17 @@ typedef struct cudart_handle {
cudartReturn_t (*cudaGetDeviceCount)(int *);
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
} cudart_handle_t;
typedef struct cudart_init_resp {
char *err; // If err is non-null handle is invalid
cudart_handle_t ch;
int num_devices;
} cudart_init_resp_t;
typedef struct cudart_compute_capability {
char *err;
int major;
int minor;
} cudart_compute_capability_t;
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
void cudart_release(cudart_handle_t ch);
#endif // __GPU_INFO_CUDART_H__

View File

@@ -1,221 +0,0 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include "gpu_info_nvml.h"
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
nvmlReturn_t ret;
resp->err = NULL;
const int buflen = 256;
char buf[buflen + 1];
int i;
struct lookup {
char *s;
void **p;
} l[] = {
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
{NULL, NULL},
};
resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
if (!resp->ch.handle) {
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
snprintf(buf, buflen,
"Unable to load %s library to query for Nvidia GPUs: %s",
nvml_lib_path, msg);
free(msg);
resp->err = strdup(buf);
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) {
resp->ch.handle = NULL;
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->ch.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg);
free(msg);
resp->err = strdup(buf);
return;
}
}
ret = (*resp->ch.nvmlInit_v2)();
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
resp->err = strdup(buf);
return;
}
// Report driver version if we're in verbose mode, ignore errors
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
} else {
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
}
}
void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
resp->err = NULL;
nvmlDevice_t device;
nvmlMemory_t memInfo = {0};
nvmlReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("nvml handle isn't initialized");
return;
}
ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp->count; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
if (h.verbose) {
nvmlBrandType_t brand = 0;
// When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
}
}
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
resp->total += memInfo.total;
resp->free += memInfo.free;
}
}
void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
resp->err = NULL;
resp->major = 0;
resp->minor = 0;
nvmlDevice_t device;
int major = 0;
int minor = 0;
nvmlReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("nvml handle not initialized");
return;
}
unsigned int devices;
ret = (*h.nvmlDeviceGetCount_v2)(&devices);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
for (i = 0; i < devices; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
// Report the lowest major.minor we detect as that limits our compatibility
if (resp->major == 0 || resp->major > major ) {
resp->major = major;
resp->minor = minor;
} else if ( resp->major == major && resp->minor > minor ) {
resp->minor = minor;
}
}
}
void nvml_release(nvml_handle_t h) {
LOG(h.verbose, "releasing nvml library\n");
UNLOAD_LIBRARY(h.handle);
h.handle = NULL;
}
#endif // __APPLE__

View File

@@ -1,57 +0,0 @@
#ifndef __APPLE__
#ifndef __GPU_INFO_NVML_H__
#define __GPU_INFO_NVML_H__
#include "gpu_info.h"
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum nvmlReturn_enum {
NVML_SUCCESS = 0,
// Other values omitted for now...
} nvmlReturn_t;
typedef void *nvmlDevice_t; // Opaque is sufficient
typedef struct nvmlMemory_st {
unsigned long long total;
unsigned long long free;
unsigned long long used;
} nvmlMemory_t;
typedef enum nvmlBrandType_enum
{
NVML_BRAND_UNKNOWN = 0,
} nvmlBrandType_t;
typedef struct nvml_handle {
void *handle;
uint16_t verbose;
nvmlReturn_t (*nvmlInit_v2)(void);
nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
} nvml_handle_t;
typedef struct nvml_init_resp {
char *err; // If err is non-null handle is invalid
nvml_handle_t ch;
} nvml_init_resp_t;
typedef struct nvml_compute_capability {
char *err;
int major;
int minor;
} nvml_compute_capability_t;
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
void nvml_release(nvml_handle_t ch);
#endif // __GPU_INFO_NVML_H__
#endif // __APPLE__

View File

@@ -9,23 +9,16 @@ import (
func TestBasicGetGPUInfo(t *testing.T) {
info := GetGPUInfo()
assert.Contains(t, "cuda rocm cpu metal", info.Library)
switch runtime.GOOS {
case "darwin":
// TODO - remove this once MacOS returns some size for CPU
return
case "linux", "windows":
assert.Greater(t, info.TotalMemory, uint64(0))
assert.Greater(t, info.FreeMemory, uint64(0))
assert.Greater(t, info.DeviceCount, uint32(0))
default:
return
assert.Greater(t, len(info), 0)
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
if info[0].Library != "cpu" {
assert.Greater(t, info[0].TotalMemory, uint64(0))
assert.Greater(t, info[0].FreeMemory, uint64(0))
}
}
func TestCPUMemInfo(t *testing.T) {
info, err := getCPUMem()
info, err := GetCPUMem()
assert.NoError(t, err)
switch runtime.GOOS {
case "darwin":

View File

@@ -3,7 +3,6 @@ package gpu
type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"`
DeviceCount uint32 `json:"device_count,omitempty"`
}
// Beginning of an `ollama info` command
@@ -17,11 +16,49 @@ type GpuInfo struct {
// MinimumMemory represents the minimum memory required to use the GPU
MinimumMemory uint64 `json:"-"`
// TODO add other useful attributes about the card here for discovery information
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
DependencyPath string `json:"lib_path,omitempty"`
// GPU information
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
Name string `json:"name"` // user friendly name if available
Major int `json:"major,omitempty"` // Major compatibility version (CC or gfx)
Minor int `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
Patch int `json:"patch,omitempty"` // Patch compatibility only matters on AMD
// TODO other performance capability info to help in scheduling decisions
}
type Version struct {
Major uint
Minor uint
Patch uint
type GpuInfoList []GpuInfo
// Split up the set of gpu info's by Library and variant
func (l GpuInfoList) ByLibrary() []GpuInfoList {
resp := []GpuInfoList{}
libs := []string{}
for _, info := range l {
found := false
requested := info.Library
if info.Variant != "" {
requested += "_" + info.Variant
}
for i, lib := range libs {
if lib == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
libs = append(libs, info.Library)
resp = append(resp, []GpuInfo{info})
}
}
return resp
}
// Sort by Free Space
type ByFreeMemory []GpuInfo
func (a ByFreeMemory) Len() int { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }

View File

@@ -4,11 +4,14 @@ package integration
import (
"context"
"net/http"
"log/slog"
"os"
"runtime"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestOrcaMiniBlueSky(t *testing.T) {
@@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) {
"seed": 123,
},
}
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
}
func TestUnicodeModelDir(t *testing.T) {
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
if runtime.GOOS != "windows" {
t.Skip("Unicode test only applicable to windows")
}
// Only works for local testing
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
t.Skip("TestUnicodeModelDir only works for local testing, skipping")
}
modelDir, err := os.MkdirTemp("", "ollama_埃")
require.NoError(t, err)
defer os.RemoveAll(modelDir)
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
oldModelsDir := os.Getenv("OLLAMA_MODELS")
if oldModelsDir == "" {
defer os.Unsetenv("OLLAMA_MODELS")
} else {
defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
}
err = os.Setenv("OLLAMA_MODELS", modelDir)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.GenerateRequest{
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
}

View File

@@ -0,0 +1,225 @@
//go:build integration
package integration
import (
"context"
"log/slog"
"os"
"strconv"
"sync"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestMultiModelConcurrency(t *testing.T) {
var (
req = [2]api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "tinydolphin",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
}
resp = [2][]string{
[]string{"sunlight"},
[]string{"england", "english", "massachusetts", "pilgrims"},
}
)
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, req[i], resp[i])
}(i)
}
wg.Wait()
}
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req, resp := GenerateRequests()
// Get the server running (if applicable) warm the model up with a single initial request
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
var wg sync.WaitGroup
wg.Add(len(req))
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
for j := 0; j < 5; j++ {
slog.Info("Starting", "req", i, "iter", j)
// On slower GPUs it can take a while to process the 4 concurrent requests
// so we allow a much longer initial timeout
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
}
}(i)
}
wg.Wait()
}
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
vram := os.Getenv("OLLAMA_MAX_VRAM")
if vram == "" {
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
}
max, err := strconv.ParseUint(vram, 10, 64)
require.NoError(t, err)
const MB = uint64(1024 * 1024)
type model struct {
name string
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
}
smallModels := []model{
{
name: "orca-mini",
size: 2992 * MB,
},
{
name: "phi",
size: 2616 * MB,
},
{
name: "gemma:2b",
size: 2364 * MB,
},
{
name: "stable-code:3b",
size: 2608 * MB,
},
{
name: "starcoder2:3b",
size: 2166 * MB,
},
}
mediumModels := []model{
{
name: "llama2",
size: 5118 * MB,
},
{
name: "mistral",
size: 4620 * MB,
},
{
name: "orca-mini:7b",
size: 5118 * MB,
},
{
name: "dolphin-mistral",
size: 4620 * MB,
},
{
name: "gemma:7b",
size: 5000 * MB,
},
// TODO - uncomment this once #3565 is merged and this is rebased on it
// {
// name: "codellama:7b",
// size: 5118 * MB,
// },
}
// These seem to be too slow to be useful...
// largeModels := []model{
// {
// name: "llama2:13b",
// size: 7400 * MB,
// },
// {
// name: "codellama:13b",
// size: 7400 * MB,
// },
// {
// name: "orca-mini:13b",
// size: 7400 * MB,
// },
// {
// name: "gemma:7b",
// size: 5000 * MB,
// },
// {
// name: "starcoder2:15b",
// size: 9100 * MB,
// },
// }
var chosenModels []model
switch {
case max < 10000*MB:
slog.Info("selecting small models")
chosenModels = smallModels
// case max < 30000*MB:
default:
slog.Info("selecting medium models")
chosenModels = mediumModels
// default:
// slog.Info("selecting large models")
// chosenModels = largModels
}
req, resp := GenerateRequests()
for i := range req {
if i > len(chosenModels) {
break
}
req[i].Model = chosenModels[i].name
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Make sure all the models are pulled before we get started
for _, r := range req {
require.NoError(t, PullIfMissing(ctx, client, r.Model))
}
var wg sync.WaitGroup
consumed := uint64(256 * MB) // Assume some baseline usage
for i := 0; i < len(req); i++ {
// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
if i > 1 && consumed > max {
slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
break
}
consumed += chosenModels[i].size
slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
wg.Add(1)
go func(i int) {
defer wg.Done()
for j := 0; j < 3; j++ {
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
}
}(i)
}
wg.Wait()
}

View File

@@ -4,7 +4,6 @@ package integration
import (
"context"
"net/http"
"testing"
"time"
@@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
"num_ctx": 128,
},
}
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
}

View File

@@ -5,7 +5,6 @@ package integration
import (
"context"
"encoding/base64"
"net/http"
"testing"
"time"
@@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
},
}
resp := "the ollamas"
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
GenerateTestHelper(ctx, t, req, []string{resp})
}
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

View File

@@ -4,8 +4,6 @@ package integration
import (
"context"
"net/http"
"sync"
"testing"
"time"
@@ -45,25 +43,5 @@ var (
func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
GenerateTestHelper(ctx, t, req[0], resp[0])
}
// TODO
// The server always loads a new runner and closes the old one, which forces serial execution
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
}(i)
}
wg.Wait()
}
// TODO - create a parallel test with 2 different models once we support concurrency

View File

@@ -5,13 +5,14 @@ package integration
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
@@ -23,9 +24,13 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Init() {
lifecycle.InitLogging()
}
func FindPort() string {
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@@ -41,7 +46,7 @@ func FindPort() string {
return strconv.Itoa(port)
}
func GetTestEndpoint() (string, string) {
func GetTestEndpoint() (*api.Client, string) {
defaultPort := "11434"
ollamaHost := os.Getenv("OLLAMA_HOST")
@@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
port = FindPort()
}
url := fmt.Sprintf("%s:%s", host, port)
slog.Info("server connection", "url", url)
return scheme, url
slog.Info("server connection", "host", host, "port", port)
return api.NewClient(
&url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
}
// TODO make fanicier, grab logs, etc.
var serverMutex sync.Mutex
var serverReady bool
func StartServer(ctx context.Context, ollamaHost string) error {
func startServer(ctx context.Context, ollamaHost string) error {
// Make sure the server has been built
CLIName, err := filepath.Abs("../ollama")
if err != nil {
@@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
return nil
}
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
slog.Info("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName}
requestJSON, err := json.Marshal(showReq)
if err != nil {
return err
}
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
if err != nil {
showCtx, cancel := context.WithDeadlineCause(
ctx,
time.Now().Add(5*time.Second),
fmt.Errorf("show for existing model %s took too long", modelName),
)
defer cancel()
_, err := client.Show(showCtx, showReq)
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
break
case err != nil:
return err
}
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode == 200 {
default:
slog.Info("model already present", "model", modelName)
return nil
}
slog.Info("model missing", "status", response.StatusCode)
slog.Info("model missing", "model", modelName)
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
stallTimer := time.NewTimer(stallDuration)
fn := func(resp api.ProgressResponse) error {
// fmt.Print(".")
if !stallTimer.Reset(stallDuration) {
return fmt.Errorf("stall was detected, aborting status reporting")
}
return nil
}
stream := true
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
requestJSON, err = json.Marshal(pullReq)
if err != nil {
return err
}
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
if err != nil {
return err
}
slog.Info("pulling", "model", modelName)
var pullError error
response, err = client.Do(req.WithContext(ctx))
if err != nil {
return err
done := make(chan int)
go func() {
pullError = client.Pull(ctx, pullReq, fn)
done <- 0
}()
select {
case <-stallTimer.C:
return fmt.Errorf("download stalled")
case <-done:
return pullError
}
defer response.Body.Close()
if response.StatusCode != 200 {
return fmt.Errorf("failed to pull model") // TODO more details perhaps
}
slog.Info("model pulled", "model", modelName)
return nil
}
var serverProcMutex sync.Mutex
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
// TODO maybe stuff in an init routine?
lifecycle.InitLogging()
requestJSON, err := json.Marshal(genReq)
if err != nil {
t.Fatalf("Error serializing request: %v", err)
// Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
// Starts the server if needed
func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
client, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
require.NoError(t, startServer(ctx, testEndpoint))
}
defer func() {
return client, testEndpoint, func() {
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
defer serverProcMutex.Unlock()
if t.Failed() {
@@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
os.Stderr.Write(data)
slog.Warn("END OF SERVER")
}
err = os.Remove(lifecycle.ServerLogFile)
err := os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
}
}
}()
scheme, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
assert.NoError(t, StartServer(ctx, testEndpoint))
}
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
if err != nil {
t.Fatalf("Error pulling model: %v", err)
}
// Make the request and get the response
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
if err != nil {
t.Fatalf("Error creating request: %v", err)
}
// Set the content type for the request
req.Header.Set("Content-Type", "application/json")
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
t.Fatalf("Error making request: %v", err)
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
assert.NoError(t, err)
assert.Equal(t, response.StatusCode, 200, string(body))
// Verify the response is valid JSON
var payload api.GenerateResponse
err = json.Unmarshal(body, &payload)
if err != nil {
assert.NoError(t, err, body)
}
// Verify the response contains the expected data
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(payload.Response), resp) {
atLeastOne = true
break
}
}
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
}
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
}
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
fn := func(response api.GenerateResponse) error {
// fmt.Print(".")
buf.Write([]byte(response.Response))
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
genReq.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Generate(ctx, &genReq, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
}
// Generate a set of requests
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
return []api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "why is the color of dirt brown?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of independence day?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the composition of air?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
},
[][]string{
[]string{"sunlight"},
[]string{"soil", "organic", "earth", "black", "tan"},
[]string{"england", "english", "massachusetts", "pilgrims"},
[]string{"fourth", "july", "declaration", "independence"},
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
}
}

View File

@@ -21,7 +21,7 @@ init_vars() {
# TODO - add additional optimization flags...
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off ${CMAKE_DEFS}"
fi
case $(uname -s) in
case $(uname -s) in
"Darwin")
LIB_EXT="dylib"
WHOLE_ARCHIVE="-Wl,-force_load"

View File

@@ -57,21 +57,21 @@ init_vars
git_module_setup
apply_patches
init_vars
if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
# Builds by default, allows skipping, forces build if OLLAMA_CPU_TARGET="static"
# Enables optimized Dockerfile builds using a blanket skip and targeted overrides
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}_static"
echo "Building static library"
build
fi
init_vars
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}_static"
echo "Building static library"
build
fi
# Users building from source can tune the exact flags we pass to cmake for configuring
# llama.cpp, and we'll build only 1 CPU variant in that case as the default.
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
@@ -165,14 +165,22 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
fi
if [ "${ARCH}" == "arm64" ]; then
echo "ARM CPU detected - disabling unsupported AVX instructions"
# ARM-based CPUs such as M1 and Tegra do not support AVX extensions.
#
# CUDA compute < 6.0 lacks proper FP16 support on ARM.
# Disabling has minimal performance effect while maintaining compatibility.
# CUDA compute < 6.0 lacks proper FP16 support on ARM.
# Disabling has minimal performance effect while maintaining compatibility.
ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
fi
CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then
echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\""
CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
echo "Building custom CUDA GPU"
else
CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
fi
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
build
@@ -217,6 +225,12 @@ if [ -d "${ROCM_PATH}" ]; then
fi
init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
CMAKE_DEFS="${CMAKE_DEFS} ${OLLAMA_CUSTOM_ROCM_DEFS}"
echo "Building custom ROCM GPU"
fi
BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}"
EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
build

View File

@@ -26,15 +26,25 @@ function amdGPUs {
$GPU_LIST -join ';'
}
function init_vars {
$script:SRC_DIR = $(resolve-path "..\..\")
$script:llamacppDir = "../llama.cpp"
if (!$script:SRC_DIR) {
$script:SRC_DIR = $(resolve-path "..\..\")
}
if (!$script:llamacppDir) {
$script:llamacppDir = "../llama.cpp"
}
if (!$script:cmakeTargets) {
$script:cmakeTargets = @("ollama_llama_server")
}
$script:cmakeDefs = @(
"-DBUILD_SHARED_LIBS=on",
"-DLLAMA_NATIVE=off"
)
$script:cmakeTargets = @("ollama_llama_server")
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
$script:ARCH = "amd64" # arm not yet supported.
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
md "$script:DIST_BASE" -ea 0 > $null
if ($env:CGO_CFLAGS -contains "-g") {
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
$script:config = "RelWithDebInfo"
@@ -55,7 +65,6 @@ function init_vars {
} else {
$script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR
}
$script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
@@ -134,21 +143,18 @@ function sign {
}
}
function compress {
if ($script:GZIP -eq $null) {
write-host "gzip not installed, not compressing files"
return
}
write-host "Compressing binaries..."
function install {
write-host "Installing binaries to dist dir ${script:distDir}"
mkdir ${script:distDir} -ErrorAction SilentlyContinue
$binaries = dir "${script:buildDir}/bin/*.exe"
foreach ($file in $binaries) {
& "$script:GZIP" --best -f $file
copy-item -Path $file -Destination ${script:distDir} -Force
}
write-host "Compressing dlls..."
write-host "Installing dlls to dist dir ${script:distDir}"
$dlls = dir "${script:buildDir}/bin/*.dll"
foreach ($file in $dlls) {
& "$script:GZIP" --best -f $file
copy-item -Path $file -Destination ${script:distDir} -Force
}
}
@@ -169,123 +175,191 @@ function cleanup {
}
}
init_vars
git_module_setup
apply_patches
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
function build_static() {
if ((-not "${env:OLLAMA_SKIP_STATIC_GENERATE}") -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "static"))) {
# GCC build for direct linking into the Go binary
init_vars
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
# as we need this to be compiled by gcc for golang to be able to link with itx
write-host "Checking for MinGW..."
# error action ensures we exit on failure
get-command gcc
get-command mingw32-make
$oldTargets = $script:cmakeTargets
$script:cmakeTargets = @("llama", "ggml")
$script:cmakeDefs = @(
"-G", "MinGW Makefiles"
"-DCMAKE_C_COMPILER=gcc.exe",
"-DCMAKE_CXX_COMPILER=g++.exe",
"-DBUILD_SHARED_LIBS=off",
"-DLLAMA_NATIVE=off",
"-DLLAMA_AVX=off",
"-DLLAMA_AVX2=off",
"-DLLAMA_AVX512=off",
"-DLLAMA_F16C=off",
"-DLLAMA_FMA=off")
$script:buildDir="../build/windows/${script:ARCH}_static"
write-host "Building static library"
build
$script:cmakeTargets = $oldTargets
} else {
write-host "Skipping CPU generation step as requested"
}
}
function build_cpu() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
# remaining llama.cpp builds use MSVC
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu"
$script:distDir="$script:DIST_BASE\cpu"
write-host "Building LCD CPU"
build
sign
install
} else {
write-host "Skipping CPU generation step as requested"
}
}
function build_cpu_avx() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
$script:distDir="$script:DIST_BASE\cpu_avx"
write-host "Building AVX CPU"
build
sign
install
} else {
write-host "Skipping CPU AVX generation step as requested"
}
}
function build_cpu_avx2() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
$script:distDir="$script:DIST_BASE\cpu_avx2"
write-host "Building AVX2 CPU"
build
sign
install
} else {
write-host "Skipping CPU AVX2 generation step as requested"
}
}
function build_cuda() {
if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${script:CUDA_LIB_DIR}")) {
# Then build cuda as a dynamically loaded library
$nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
if ($null -ne $script:CUDA_VERSION) {
$script:CUDA_VARIANT="_"+$script:CUDA_VERSION
}
init_vars
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
if ($null -ne $env:OLLAMA_CUSTOM_CUDA_DEFS) {
write-host "OLLAMA_CUSTOM_CUDA_DEFS=`"${env:OLLAMA_CUSTOM_CUDA_DEFS}`""
$script:cmakeDefs +=@("${env:OLLAMA_CUSTOM_CUDA_DEFS}")
write-host "building custom CUDA GPU"
}
build
sign
install
write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
} else {
write-host "Skipping CUDA generation step"
}
}
function build_rocm() {
if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
if ($null -ne $script:ROCM_VERSION) {
$script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
}
init_vars
$script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
$script:distDir="$script:DIST_BASE\rocm$script:ROCM_VARIANT"
$script:cmakeDefs += @(
"-G", "Ninja",
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DLLAMA_HIPBLAS=on",
"-DHIP_PLATFORM=amd",
"-DLLAMA_AVX=on",
"-DLLAMA_AVX2=off",
"-DCMAKE_POSITION_INDEPENDENT_CODE=on",
"-DAMDGPU_TARGETS=$(amdGPUs)",
"-DGPU_TARGETS=$(amdGPUs)"
)
# Make sure the ROCm binary dir is first in the path
$env:PATH="$env:HIP_PATH\bin;$env:PATH"
# We have to clobber the LIB var from the developer shell for clang to work properly
$env:LIB=""
if ($null -ne $env:OLLAMA_CUSTOM_ROCM_DEFS) {
write-host "OLLAMA_CUSTOM_ROCM_DEFS=`"${env:OLLAMA_CUSTOM_ROCM_DEFS}`""
$script:cmakeDefs += @("${env:OLLAMA_CUSTOM_ROCM_DEFS}")
write-host "building custom ROCM GPU"
}
write-host "Building ROCm"
build
# Ninja doesn't prefix with config name
${script:config}=""
if ($null -ne $script:DUMPBIN) {
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
}
sign
install
# Assumes v5.7, may need adjustments for v6
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
cp "${env:HIP_PATH}\bin\rocblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
# amdhip64.dll dependency comes from the driver and must be installed on the host to use AMD GPUs
cp "${env:HIP_PATH}\bin\rocblas\library\*" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\"
} else {
write-host "Skipping ROCm generation step"
}
}
# GCC build for direct linking into the Go binary
init_vars
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
# as we need this to be compiled by gcc for golang to be able to link with itx
write-host "Checking for MinGW..."
# error action ensures we exit on failure
get-command gcc
get-command mingw32-make
$script:cmakeTargets = @("llama", "ggml")
$script:cmakeDefs = @(
"-G", "MinGW Makefiles"
"-DCMAKE_C_COMPILER=gcc.exe",
"-DCMAKE_CXX_COMPILER=g++.exe",
"-DBUILD_SHARED_LIBS=off",
"-DLLAMA_NATIVE=off",
"-DLLAMA_AVX=off",
"-DLLAMA_AVX2=off",
"-DLLAMA_AVX512=off",
"-DLLAMA_F16C=off",
"-DLLAMA_FMA=off")
$script:buildDir="../build/windows/${script:ARCH}_static"
write-host "Building static library"
build
if ($($args.count) -eq 0) {
git_module_setup
apply_patches
build_static
build_cpu
build_cpu_avx
build_cpu_avx2
build_cuda
build_rocm
# remaining llama.cpp builds use MSVC
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu"
write-host "Building LCD CPU"
build
sign
compress
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
write-host "Building AVX CPU"
build
sign
compress
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
write-host "Building AVX2 CPU"
build
sign
compress
cleanup
write-host "`ngo generate completed. LLM runners: $(get-childitem -path $script:DIST_BASE)"
} else {
write-host "Skipping CPU generation step as requested"
}
if ($null -ne $script:CUDA_LIB_DIR) {
# Then build cuda as a dynamically loaded library
$nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
if ($null -ne $script:CUDA_VERSION) {
$script:CUDA_VARIANT="_"+$script:CUDA_VERSION
}
init_vars
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
build
sign
compress
}
if ($null -ne $env:HIP_PATH) {
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
if ($null -ne $script:ROCM_VERSION) {
$script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
}
init_vars
$script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
$script:cmakeDefs += @(
"-G", "Ninja",
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DLLAMA_HIPBLAS=on",
"-DHIP_PLATFORM=amd",
"-DLLAMA_AVX=on",
"-DLLAMA_AVX2=off",
"-DCMAKE_POSITION_INDEPENDENT_CODE=on",
"-DAMDGPU_TARGETS=$(amdGPUs)",
"-DGPU_TARGETS=$(amdGPUs)"
)
# Make sure the ROCm binary dir is first in the path
$env:PATH="$env:HIP_PATH\bin;$env:PATH"
# We have to clobber the LIB var from the developer shell for clang to work properly
$env:LIB=""
write-host "Building ROCm"
build
# Ninja doesn't prefix with config name
${script:config}=""
if ($null -ne $script:DUMPBIN) {
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
}
sign
compress
}
cleanup
write-host "`ngo generate completed. LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})"
for ( $i = 0; $i -lt $args.count; $i++ ) {
write-host "performing $($args[$i])"
& $($args[$i])
}
}

View File

@@ -164,7 +164,8 @@ func (ts Tensors) Layers() map[string]Layer {
for _, t := range ts {
parts := strings.Split(t.Name, ".")
if parts[0] == "blk" {
parts = parts[1:]
// join first and second part, e.g. blk.%d
parts = append([]string{fmt.Sprintf("%s.%s", parts[0], parts[1])}, parts[2:]...)
}
if _, ok := layers[parts[0]]; !ok {
@@ -342,7 +343,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
// mixtral 8x22b
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
partialOffload = max(
3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
)
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
// mixtral 8x7b
ffnGateWeight1 := ffnGateWeight.Shape[1]
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
partialOffload = max(
@@ -380,6 +389,12 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
)
partialOffload = 4*batch*(2*embedding+vocab) + embedding*vocab*105/128
case "stablelm":
fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
partialOffload = max(
4*batch*(vocab+2*embedding),
fullOffload,
)
}
return

View File

@@ -190,8 +190,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
llm.kv[k] = v
}
slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"]))
// decode tensors
for i := 0; uint64(i) < llm.numTensor(); i++ {
name, err := readGGUFString(llm, rs)
@@ -248,13 +246,17 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
}
padding := llm.padding(offset, int64(alignment))
if _, err := rs.Seek(padding-offset, io.SeekCurrent); err != nil {
if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
return err
}
for _, tensor := range llm.tensors {
padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
if _, err := rs.Seek(padded, io.SeekCurrent); err != nil {
if _, err := rs.Seek(int64(tensor.size()), io.SeekCurrent); err != nil {
return err
}
padding := llm.padding(int64(tensor.size()), int64(alignment))
if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
return err
}
}
@@ -461,11 +463,13 @@ var ggufKVOrder = map[string][]string{
"llama.embedding_length",
"llama.block_count",
"llama.feed_forward_length",
"llama.rope.dimension_count",
"llama.attention.head_count",
"llama.attention.head_count_kv",
"llama.attention.layer_norm_rms_epsilon",
"llama.rope.freq_base",
"llama.rope.dimension_count",
"llama.expert_count",
"llama.expert_used_count",
"gemma.context_length",
"gemma.embedding_length",
"gemma.block_count",
@@ -573,6 +577,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
}
default:
return fmt.Errorf("improper type for '%s'", k)
}
if err != nil {
return err
@@ -594,9 +600,11 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
dims := 1
if tensor.Shape[1] > 0 {
dims = 2
dims := 0
for cnt := 0; cnt < len(tensor.Shape); cnt++ {
if tensor.Shape[cnt] > 0 {
dims++
}
}
if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {
@@ -623,8 +631,9 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
padding := llm.padding(offset, 32)
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding-offset))); err != nil {
var alignment int64 = 32
padding := llm.padding(offset, alignment)
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
return err
}
@@ -638,8 +647,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
padding := llm.padding(offset, 32)
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding-offset))); err != nil {
padding := llm.padding(offset, alignment)
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
return err
}
}
@@ -648,5 +657,5 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
}
func (gguf) padding(offset, align int64) int64 {
return (offset + align - 1) / align * align
return (align - offset%align) % align
}

View File

@@ -2,5 +2,5 @@ package llm
import "embed"
//go:embed build/windows/*/*/bin/*
// unused on windows
var libEmbed embed.FS

188
llm/memory.go Normal file
View File

@@ -0,0 +1,188 @@
package llm
import (
"fmt"
"log/slog"
"os"
"strconv"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
)
// This algorithm looks for a complete fit to determine if we need to unload other models
func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
var estimatedVRAM uint64
if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
}
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
// Split up the GPUs by type and try them
for _, gpus := range allGpus.ByLibrary() {
var layerCount int
layerCount, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
if opts.NumGPU < 0 {
if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
return true, estimatedVRAM
}
} else {
if layerCount > 0 && layerCount >= opts.NumGPU {
return true, estimatedVRAM
}
}
}
return false, estimatedVRAM
}
// Given a model and one or more GPU targets, predict how many layers and bytes we can load
// The GPUs provided must all be the same Library
func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64) {
if gpus[0].Library == "cpu" {
return 0, 0
}
var memoryAvailable uint64
for _, info := range gpus {
memoryAvailable += info.FreeMemory
}
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseUint(userLimit, 10, 64)
if err != nil {
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
} else {
slog.Info("user override memory limit", "OLLAMA_MAX_VRAM", avail, "actual", memoryAvailable)
memoryAvailable = avail
}
}
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
memoryMinimum := gpus[0].MinimumMemory
for _, projector := range projectors {
memoryMinimum += projectorMemoryRequirements(projector)
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
if graphPartialOffload == 0 {
graphPartialOffload = ggml.KV().GQA() * kv / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
graphFullOffload *= uint64(len(gpus))
graphPartialOffload *= uint64(len(gpus))
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload
if memoryRequiredPartial > memoryAvailable {
slog.Debug("insufficient VRAM to load any model layers")
return 0, 0
}
layers := ggml.Tensors().Layers()
var memoryLayerOutput uint64
if layer, ok := layers["output_norm"]; ok {
memoryLayerOutput += layer.size()
}
if layer, ok := layers["output"]; ok {
memoryLayerOutput += layer.size()
} else if layer, ok := layers["token_embd"]; ok {
memoryLayerOutput += layer.size()
}
if gpus[0].Library == "metal" && opts.UseMMap {
// memory is preallocated for output tensors
memoryRequiredTotal += memoryLayerOutput
memoryRequiredPartial += memoryLayerOutput
}
var layerCount int
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
if gpus[0].Library != "metal" || !opts.UseMMap {
// memory was not preallocated for output tensors
memoryRequiredTotal += memoryLayerOutput
}
if memoryAvailable > memoryRequiredTotal {
layerCount = int(ggml.KV().BlockCount()) + 1
memoryRequiredPartial = memoryRequiredTotal
}
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
slog.Info(
"offload to gpu",
slog.Group(
"layers",
// actual number of layers offloaded
"real", opts.NumGPU,
// estimated number of layers that can be offloaded
"estimate", layerCount,
),
slog.Group(
"memory",
// memory available for offloading
"available", format.HumanBytes2(memoryAvailable),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(memoryRequiredTotal),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(memoryRequiredPartial),
// memory of KV cache
"kv", format.HumanBytes2(kv),
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(graphPartialOffload),
),
),
)
return layerCount, uint64(memoryRequiredPartial)
}

View File

@@ -0,0 +1,12 @@
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index e431c7f7..f077e688 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -3,6 +3,7 @@
// I'll gradually clean and extend it
// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
#include "clip.h"
+#include "common.h"
#include "log.h"
#include "ggml.h"
#include "ggml-alloc.h"

45
llm/patches/04-metal.diff Normal file
View File

@@ -0,0 +1,45 @@
diff --git a/ggml-metal.m b/ggml-metal.m
index 0207b787..b5e9884b 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
// to the matrix-vector kernel
int ne11_mm_min = 1;
-#if 0
// the numbers below are measured on M2 Ultra for 7B and 13B models
// these numbers do not translate to other devices or model sizes
// TODO: need to find a better approach
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
- switch (src0t) {
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
- case GGML_TYPE_Q5_0: // not tested yet
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
- default: ne11_mm_min = 1; break;
- }
+ switch (src0t) {
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+ case GGML_TYPE_Q5_0: // not tested yet
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
+ default: ne11_mm_min = 1; break;
}
-#endif
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel

View File

@@ -9,6 +9,7 @@ import (
"log/slog"
"os"
"path/filepath"
"runtime"
"strings"
"golang.org/x/exp/slices"
@@ -17,7 +18,7 @@ import (
"github.com/ollama/ollama/gpu"
)
var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama")
var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
func Init() error {
payloadsDir, err := gpu.PayloadsDir()
@@ -25,13 +26,15 @@ func Init() error {
return err
}
slog.Info("extracting embedded files", "dir", payloadsDir)
binGlob := "build/*/*/*/bin/*"
if runtime.GOOS != "windows" {
slog.Info("extracting embedded files", "dir", payloadsDir)
binGlob := "build/*/*/*/bin/*"
// extract server libraries
err = extractFiles(payloadsDir, binGlob)
if err != nil {
return fmt.Errorf("extract binaries: %v", err)
// extract server libraries
err = extractFiles(payloadsDir, binGlob)
if err != nil {
return fmt.Errorf("extract binaries: %v", err)
}
}
var variants []string
@@ -138,6 +141,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
return servers
}
// Return the optimal server for this CPU architecture
func serverForCpu() string {
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
return "metal"
}
variant := gpu.GetCPUVariant()
availableServers := availableServers()
if variant != "" {
for cmp := range availableServers {
if cmp == "cpu_"+variant {
return cmp
}
}
}
return "cpu"
}
// extract extracts the embedded files to the target directory
func extractFiles(targetDir string, glob string) error {
files, err := fs.Glob(libEmbed, glob)

View File

@@ -21,21 +21,43 @@ import (
"strings"
"time"
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
)
// LlamaServer is an instance of the llama.cpp server
type LlamaServer struct {
type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
EstimatedVRAM() uint64
}
// llmServer is an instance of the llama.cpp server
type llmServer struct {
port int
cmd *exec.Cmd
done chan error // Channel to signal when the process exits
status *StatusWriter
options api.Options
// TODO - this should be broken down by GPU
estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
sem *semaphore.Weighted
}
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
func LoadModel(model string) (*GGML, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
f, err := os.Open(model)
if err != nil {
return nil, err
@@ -43,10 +65,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
defer f.Close()
ggml, _, err := DecodeGGML(f)
if err != nil {
return nil, err
}
return ggml, err
}
// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
var err error
if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
@@ -56,94 +81,51 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
opts.NumCtx = 4
}
memoryAvailable, _ := gpu.CheckVRAM()
info := gpu.GetGPUInfo()
cpuRunner := ""
var estimatedVRAM uint64
var systemMemory uint64
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
memoryMinimum := info.MinimumMemory
for _, projector := range projectors {
memoryMinimum += projectorMemoryRequirements(projector)
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
cpuRunner = serverForCpu()
} else {
if gpus[0].Library == "metal" {
memInfo, err := gpu.GetCPUMem()
if err != nil {
slog.Error("failed to lookup system memory", "error", err)
} else {
systemMemory = memInfo.TotalMemory
slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
}
}
var layers int
layers, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
if graphPartialOffload == 0 {
graphPartialOffload = ggml.KV().GQA() * kv / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
graphFullOffload *= uint64(info.DeviceCount)
graphPartialOffload *= uint64(info.DeviceCount)
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload
if info.Library != "metal" {
if memoryRequiredPartial > memoryAvailable {
info.Library = "cpu"
if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
// disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system
opts.NumGPU = 0
} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
opts.NumGPU = layers
}
}
var layerCount int
layers := ggml.Tensors().Layers()
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
memoryLayer := layers[fmt.Sprintf("%d", i)].size()
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
memoryLayerOutput := layers["output"].size()
memoryRequiredTotal += memoryLayerOutput
if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
// disable partial offloading when model is greater than total system memory
opts.NumGPU = 0
} else if memoryAvailable > memoryRequiredTotal {
layerCount = int(ggml.KV().BlockCount()) + 1
memoryRequiredPartial = memoryRequiredTotal
}
if opts.NumGPU < 0 {
opts.NumGPU = layerCount
}
slog.Info(
"offload to gpu",
"reallayers", opts.NumGPU,
"layers", layerCount,
"required", format.HumanBytes2(memoryRequiredTotal),
"used", format.HumanBytes2(memoryRequiredPartial),
"available", format.HumanBytes2(memoryAvailable),
"kv", format.HumanBytes2(kv),
"fulloffload", format.HumanBytes2(graphFullOffload),
"partialoffload", format.HumanBytes2(graphPartialOffload),
)
// Loop through potential servers
finalErr := fmt.Errorf("no suitable llama servers found")
if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
}
availableServers := availableServers()
servers := serversForGpu(info)
demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
var servers []string
if cpuRunner != "" {
servers = []string{cpuRunner}
} else {
servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
}
demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
if demandLib != "" {
serverPath := availableServers[demandLib]
if serverPath == "" {
@@ -155,7 +137,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
}
if len(servers) == 0 {
return nil, fmt.Errorf("no servers found for %v", info)
return nil, fmt.Errorf("no servers found for %v", gpus)
}
params := []string{
@@ -212,10 +194,26 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
params = append(params, "--numa")
}
// Loop through potential servers
var finalErr error
// "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
numParallel := 1
if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
numParallel, err = strconv.Atoi(onp)
if err != nil || numParallel <= 0 {
err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err)
slog.Error("misconfiguration", "error", err)
return nil, err
}
}
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
for i := 0; i < len(servers); i++ {
dir := availableServers[servers[i]]
if dir == "" {
// Shouldn't happen
finalErr = fmt.Errorf("[%d] server %s not listed in available servers %v", i, servers[i], availableServers)
slog.Error("sever list inconsistent", "error", finalErr)
continue
}
// Find an availableServers port, retry on each iterration in case the failure was a port conflict race
port := 0
@@ -238,30 +236,60 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
}
// append the server directory to LD_LIBRARY_PATH/PATH
libraryPaths := []string{dir}
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
// Append our runner directory to the path
// This will favor system libraries over our bundled library dependencies
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
}
// Note: we always put the dependency path first
// since this was the exact version we verified for AMD GPUs
// and we favor what the user had in their path
if gpus[0].DependencyPath != "" {
// TODO refine for multi-gpu support
libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
}
server := filepath.Join(dir, "ollama_llama_server")
if runtime.GOOS == "windows" {
server = server + ".exe"
}
s := &LlamaServer{
port: port,
cmd: exec.Command(server, finalParams...),
status: NewStatusWriter(os.Stderr),
options: opts,
// Detect tmp cleaners wiping out the file
_, err := os.Stat(server)
if errors.Is(err, os.ErrNotExist) {
slog.Warn("llama server disappeared, reinitializing payloads", "path", server, "error", err)
err = Init()
if err != nil {
slog.Warn("failed to reinitialize payloads", "error", err)
return nil, err
}
}
s := &llmServer{
port: port,
cmd: exec.Command(server, finalParams...),
status: NewStatusWriter(os.Stderr),
options: opts,
estimatedVRAM: estimatedVRAM,
sem: semaphore.NewWeighted(int64(numParallel)),
}
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
slog.Debug(libEnv)
s.cmd.Env = append(os.Environ(), libEnv)
s.cmd.Stdout = os.Stdout
s.cmd.Stderr = s.status
// TODO - multiple GPU selection logic...
key, val := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
if key != "" {
s.cmd.Env = append(s.cmd.Env, key+"="+val)
}
slog.Info("starting llama server", "cmd", s.cmd.String())
// Log at debug as the environment is inherited and might contain sensitive information
slog.Debug("subprocess", "environment", s.cmd.Env)
if err = s.cmd.Start(); err != nil {
msg := ""
@@ -279,6 +307,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
_ = s.cmd.Wait()
}()
// TODO - make sure this is all wired up correctly
// if err = s.WaitUntilRunning(); err != nil {
// slog.Error("error starting llama server", "server", servers[i], "error", err)
// s.Close()
// finalErr = err
// continue
// }
return s, nil
}
@@ -316,6 +351,21 @@ const ( // iota is reset to 0
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "llm server ready"
case ServerStatusNoSlotsAvaialble:
return "llm busy - no slots available"
case ServerStatusLoadingModel:
return "llm server loading model"
case ServerStatusNotResponding:
return "llm server not responding"
default:
return "llm server error"
}
}
type ServerStatusResp struct {
Status string `json:"status"`
SlotsIdle int `json:"slots_idle"`
@@ -323,7 +373,7 @@ type ServerStatusResp struct {
Error string `json:"error"`
}
func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
// Fail fast if its exited
if s.cmd.ProcessState != nil {
msg := ""
@@ -370,7 +420,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
}
}
func (s *LlamaServer) Ping(ctx context.Context) error {
func (s *llmServer) Ping(ctx context.Context) error {
_, err := s.getServerStatus(ctx)
if err != nil {
slog.Debug("server unhealthy", "error", err)
@@ -379,7 +429,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
return nil
}
func (s *LlamaServer) WaitUntilRunning() error {
func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
start := time.Now()
// TODO we need to wire up a better way to detect hangs during model load and startup of the server
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
@@ -390,6 +440,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
var lastStatus ServerStatus = -1
for {
select {
case <-ctx.Done():
slog.Info("context expired before server started")
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
case err := <-s.done:
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
@@ -413,9 +466,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
c, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
defer cancel()
status, err := s.getServerStatus(ctx)
status, err := s.getServerStatus(c)
if err != nil && lastStatus != status {
slog.Debug("server not yet available", "error", err)
lastStatus = status
@@ -501,7 +554,19 @@ type CompletionResponse struct {
EvalDuration time.Duration
}
func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return err
}
defer s.sem.Release(1)
// only allow maximum 10 "context shifts" to avoid infinite generation
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
}
request := map[string]any{
"prompt": req.Prompt,
"stream": true,
@@ -532,7 +597,7 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
if err != nil {
return err
} else if status != ServerStatusReady {
return fmt.Errorf("unexpected server status: %d", status)
return fmt.Errorf("unexpected server status: %s", status.ToString())
}
if req.Format == "json" {
@@ -679,13 +744,18 @@ type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
}
func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
}
defer s.sem.Release(1)
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return nil, err
} else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %d", status)
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: prompt})
@@ -731,13 +801,13 @@ type TokenizeResponse struct {
Tokens []int `json:"tokens"`
}
func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return nil, err
} else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %d", status)
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: content})
@@ -783,13 +853,13 @@ type DetokenizeResponse struct {
Content string `json:"content"`
}
func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return "", err
} else if status != ServerStatusReady {
return "", fmt.Errorf("unexpected server status: %d", status)
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@@ -827,7 +897,7 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
return decoded.Content, nil
}
func (s *LlamaServer) Close() error {
func (s *llmServer) Close() error {
if s.cmd != nil {
slog.Debug("stopping llama server")
return s.cmd.Process.Kill()
@@ -836,6 +906,10 @@ func (s *LlamaServer) Close() error {
return nil
}
func (s *llmServer) EstimatedVRAM() uint64 {
return s.estimatedVRAM
}
func parseDurationMs(ms float64) time.Duration {
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil {

View File

@@ -30,7 +30,7 @@ function checkEnv() {
$script:INNO_SETUP_DIR=(get-item "C:\Program Files*\Inno Setup*\")[0]
$script:DEPS_DIR="${script:SRC_DIR}\dist\windeps"
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-amd64"
$env:CGO_ENABLED="1"
echo "Checking version"
if (!$env:VERSION) {
@@ -81,8 +81,8 @@ function buildOllama() {
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
New-Item -ItemType Directory -Path .\dist -Force
cp .\ollama.exe .\dist\ollama-windows-amd64.exe
New-Item -ItemType Directory -Path .\dist\windows-amd64\ -Force
cp .\ollama.exe .\dist\windows-amd64\
}
function buildApp() {
@@ -101,7 +101,6 @@ function buildApp() {
function gatherDependencies() {
write-host "Gathering runtime dependencies"
cd "${script:SRC_DIR}"
rm -ea 0 -recurse -force -path "${script:DEPS_DIR}"
md "${script:DEPS_DIR}" -ea 0 > $null
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
@@ -110,9 +109,6 @@ function gatherDependencies() {
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\"
cp "${script:NVIDIA_DIR}\cudart64_*.dll" "${script:DEPS_DIR}\"
cp "${script:NVIDIA_DIR}\cublas64_*.dll" "${script:DEPS_DIR}\"
cp "${script:NVIDIA_DIR}\cublasLt64_*.dll" "${script:DEPS_DIR}\"
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
if ("${env:KEY_CONTAINER}") {
@@ -124,7 +120,6 @@ function gatherDependencies() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildInstaller() {
@@ -139,12 +134,18 @@ function buildInstaller() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
function distZip() {
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
}
try {
checkEnv
buildOllama
buildApp
gatherDependencies
buildInstaller
distZip
} catch {
write-host "Build Failed"
write-host $_

View File

@@ -29,6 +29,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -701,36 +702,39 @@ func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string
return path, nil
}
func CopyModel(src, dest string) error {
srcModelPath := ParseModelPath(src)
srcPath, err := srcModelPath.GetManifestPath()
func CopyModel(src, dst model.Name) error {
if !dst.IsFullyQualified() {
return model.Unqualified(dst)
}
if !src.IsFullyQualified() {
return model.Unqualified(src)
}
manifests, err := GetManifestPath()
if err != nil {
return err
}
destModelPath := ParseModelPath(dest)
destPath, err := destModelPath.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
dstpath := filepath.Join(manifests, dst.Filepath())
if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
return err
}
// copy the file
input, err := os.ReadFile(srcPath)
srcpath := filepath.Join(manifests, src.Filepath())
srcfile, err := os.Open(srcpath)
if err != nil {
fmt.Println("Error reading file:", err)
return err
}
defer srcfile.Close()
err = os.WriteFile(destPath, input, 0o644)
dstfile, err := os.Create(dstpath)
if err != nil {
fmt.Println("Error reading file:", err)
return err
}
defer dstfile.Close()
return nil
_, err = io.Copy(dstfile, srcfile)
return err
}
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
@@ -1137,7 +1141,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
var errUnauthorized = fmt.Errorf("unauthorized")
var errUnauthorized = errors.New("unauthorized")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
for i := 0; i < 2; i++ {
@@ -1255,7 +1259,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
}
}
var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)

View File

@@ -15,11 +15,8 @@ import (
"os"
"os/signal"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"time"
@@ -32,13 +29,15 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
var mode string = gin.DebugMode
type Server struct {
addr net.Addr
addr net.Addr
sched *Scheduler
}
func init() {
@@ -53,88 +52,8 @@ func init() {
gin.SetMode(mode)
}
var loaded struct {
mu sync.Mutex
llama *llm.LlamaServer
expireTimer *time.Timer
model string
adapters []string
projectors []string
*api.Options
}
var defaultSessionDuration = 5 * time.Minute
func unload() {
if loaded.llama != nil {
loaded.llama.Close()
}
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
}
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
ctx, cancel := context.WithTimeout(c, 10*time.Second)
defer cancel()
needLoad := loaded.llama == nil || // is there a model loaded?
loaded.model != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
loaded.llama.Ping(ctx) != nil
if needLoad {
if loaded.llama != nil {
slog.Info("changing loaded model")
unload()
}
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
return err
}
loaded.model = model.ModelPath
loaded.adapters = model.AdapterPaths
loaded.projectors = model.ProjectorPaths
loaded.llama = llama
loaded.Options = &opts
if err = llama.WaitUntilRunning(); err != nil {
slog.Error("error loading llama server", "error", err)
unload()
return err
}
}
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()
unload()
})
}
loaded.expireTimer.Reset(sessionDuration)
return nil
}
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -154,9 +73,7 @@ func isSupportedImageType(image []byte) bool {
return slices.Contains(allowedTypes, contentType)
}
func GenerateHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
@@ -224,7 +141,16 @@ func GenerateHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
if errors.Is(err, context.Canceled) {
c.JSON(499, gin.H{"error": "request canceled"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -275,7 +201,7 @@ func GenerateHandler(c *gin.Context) {
sb.Reset()
if req.Context != nil {
prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -297,9 +223,6 @@ func GenerateHandler(c *gin.Context) {
defer close(ch)
fn := func(r llm.CompletionResponse) {
// Update model expiration
loaded.expireTimer.Reset(sessionDuration)
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
@@ -331,7 +254,7 @@ func GenerateHandler(c *gin.Context) {
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
@@ -359,7 +282,7 @@ func GenerateHandler(c *gin.Context) {
Images: images,
Options: opts,
}
if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -421,10 +344,7 @@ func getDefaultSessionDuration() time.Duration {
return defaultSessionDuration
}
func EmbeddingsHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -469,7 +389,16 @@ func EmbeddingsHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
if errors.Is(err, context.Canceled) {
c.JSON(499, gin.H{"error": "request canceled"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -480,7 +409,7 @@ func EmbeddingsHandler(c *gin.Context) {
return
}
embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -493,7 +422,7 @@ func EmbeddingsHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func PullModelHandler(c *gin.Context) {
func (s *Server) PullModelHandler(c *gin.Context) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -542,7 +471,7 @@ func PullModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func PushModelHandler(c *gin.Context) {
func (s *Server) PushModelHandler(c *gin.Context) {
var req api.PushRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -591,7 +520,7 @@ func PushModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func CreateModelHandler(c *gin.Context) {
func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -664,7 +593,7 @@ func CreateModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func DeleteModelHandler(c *gin.Context) {
func (s *Server) DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -709,7 +638,7 @@ func DeleteModelHandler(c *gin.Context) {
c.JSON(http.StatusOK, nil)
}
func ShowModelHandler(c *gin.Context) {
func (s *Server) ShowModelHandler(c *gin.Context) {
var req api.ShowRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -809,7 +738,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
func ListModelsHandler(c *gin.Context) {
func (s *Server) ListModelsHandler(c *gin.Context) {
models := make([]api.ModelResponse, 0)
manifestsPath, err := GetManifestPath()
if err != nil {
@@ -869,39 +798,39 @@ func ListModelsHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{Models: models})
}
func CopyModelHandler(c *gin.Context) {
var req api.CopyRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
func (s *Server) CopyModelHandler(c *gin.Context) {
var r api.CopyRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Source == "" || req.Destination == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
src := model.ParseName(r.Source)
if !src.IsValid() {
_ = c.Error(fmt.Errorf("source %q is invalid", r.Source))
}
dst := model.ParseName(r.Destination)
if !dst.IsValid() {
_ = c.Error(fmt.Errorf("destination %q is invalid", r.Destination))
}
if len(c.Errors) > 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": c.Errors.Errors()})
return
}
if err := ParseModelPath(req.Destination).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := CopyModel(req.Source, req.Destination); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
} else if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
func HeadBlobHandler(c *gin.Context) {
func (s *Server) HeadBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -916,7 +845,7 @@ func HeadBlobHandler(c *gin.Context) {
c.Status(http.StatusOK)
}
func CreateBlobHandler(c *gin.Context) {
func (s *Server) CreateBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -1063,27 +992,27 @@ func (s *Server) GenerateRoutes() http.Handler {
allowedHostsMiddleware(s.addr),
)
r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler)
r.POST("/api/chat", ChatHandler)
r.POST("/api/embeddings", EmbeddingsHandler)
r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler)
r.POST("/api/copy", CopyModelHandler)
r.DELETE("/api/delete", DeleteModelHandler)
r.POST("/api/show", ShowModelHandler)
r.POST("/api/blobs/:digest", CreateBlobHandler)
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
r.POST("/api/pull", s.PullModelHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler)
r.POST("/api/copy", s.CopyModelHandler)
r.DELETE("/api/delete", s.DeleteModelHandler)
r.POST("/api/show", s.ShowModelHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.Handle(method, "/api/tags", ListModelsHandler)
r.Handle(method, "/api/tags", s.ListModelsHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
@@ -1137,7 +1066,9 @@ func Serve(ln net.Listener) error {
}
}
s := &Server{addr: ln.Addr()}
ctx, done := context.WithCancel(context.Background())
sched := InitScheduler(ctx)
s := &Server{addr: ln.Addr(), sched: sched}
r := s.GenerateRoutes()
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
@@ -1150,7 +1081,8 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
unload()
done()
sched.unloadAllRunners()
gpu.Cleanup()
os.Exit(0)
}()
@@ -1158,12 +1090,12 @@ func Serve(ln net.Listener) error {
if err := llm.Init(); err != nil {
return fmt.Errorf("unable to initialize llm library %w", err)
}
if runtime.GOOS == "linux" { // TODO - windows too
// check compatibility to log warnings
if _, err := gpu.CheckVRAM(); err != nil {
slog.Info(err.Error())
}
}
s.sched.Run(ctx)
// At startup we retrieve GPU information so we can get log messages before loading a model
// This will log warnings to the log in case we have problems with detected GPUs
_ = gpu.GetGPUInfo()
return srvr.Serve(ln)
}
@@ -1219,9 +1151,9 @@ func streamResponse(c *gin.Context, ch chan any) {
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return loaded.llama.Tokenize(ctx, s)
return runner.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
@@ -1232,10 +1164,7 @@ func chatPrompt(ctx context.Context, template string, messages []api.Message, nu
return prompt, nil
}
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
@@ -1292,7 +1221,16 @@ func ChatHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
if errors.Is(err, context.Canceled) {
c.JSON(499, gin.H{"error": "request canceled"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -1309,7 +1247,7 @@ func ChatHandler(c *gin.Context) {
}, req.Messages...)
}
prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1352,8 +1290,6 @@ func ChatHandler(c *gin.Context) {
defer close(ch)
fn := func(r llm.CompletionResponse) {
// Update model expiration
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
Model: req.Model,
@@ -1376,7 +1312,7 @@ func ChatHandler(c *gin.Context) {
ch <- resp
}
if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,

546
server/sched.go Normal file
View File

@@ -0,0 +1,546 @@
package server
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"reflect"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"golang.org/x/exp/slices"
)
type LlmRequest struct {
ctx context.Context //nolint:containedctx
model *Model
opts api.Options
sessionDuration time.Duration
successCh chan *runnerRef
errCh chan error
}
type Scheduler struct {
pendingReqCh chan *LlmRequest
finishedReqCh chan *LlmRequest
expiredCh chan *runnerRef
unloadedCh chan interface{}
loaded map[string]*runnerRef
loadedMu sync.Mutex
loadFn func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
getGpuFn func() gpu.GpuInfoList
}
// TODO set this to zero after a release or two, to enable multiple models by default
var loadedMax = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
var maxQueuedRequests = 10 // TODO configurable
var numParallel = 1
func InitScheduler(ctx context.Context) *Scheduler {
maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS")
if maxRunners != "" {
m, err := strconv.Atoi(maxRunners)
if err != nil {
slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
} else {
loadedMax = m
}
}
if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
p, err := strconv.Atoi(onp)
if err != nil || p <= 0 {
slog.Error("invalid parallel setting, must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
} else {
numParallel = p
}
}
sched := &Scheduler{
pendingReqCh: make(chan *LlmRequest, maxQueuedRequests),
finishedReqCh: make(chan *LlmRequest, maxQueuedRequests),
expiredCh: make(chan *runnerRef, maxQueuedRequests),
unloadedCh: make(chan interface{}, maxQueuedRequests),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: gpu.GetGPUInfo,
}
sched.loadFn = sched.load
return sched
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
req := &LlmRequest{
ctx: c,
model: model,
opts: opts,
sessionDuration: sessionDuration,
successCh: make(chan *runnerRef),
errCh: make(chan error, 1),
}
// context split across parallel threads
opts.NumCtx = opts.NumCtx * numParallel
select {
case s.pendingReqCh <- req:
default:
req.errCh <- fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
}
return req.successCh, req.errCh
}
// Returns immediately, spawns go routines for the scheduler which will shutdown when ctx is done
func (s *Scheduler) Run(ctx context.Context) {
slog.Debug("starting llm scheduler")
go func() {
s.processPending(ctx)
}()
go func() {
s.processCompleted(ctx)
}()
}
func (s *Scheduler) processPending(ctx context.Context) {
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case pending := <-s.pendingReqCh:
// Block other requests until we get this pending request running
for {
var runnerToExpire *runnerRef
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
loadedCount := len(s.loaded)
s.loadedMu.Unlock()
if runner != nil {
if runner.needsReload(ctx, pending) {
runnerToExpire = runner
} else {
// Runner is usable, return it
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
} else if loadedMax > 0 && loadedCount >= loadedMax {
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
runnerToExpire = s.findRunnerToUnload(pending)
} else {
// Either no models are loaded or below loadedMax
// Get a refreshed GPU list
gpus := s.getGpuFn()
// Load model for fitting
ggml, err := llm.LoadModel(pending.model.ModelPath)
if err != nil {
pending.errCh <- err
break
}
// No models loaded. Load the model but prefer the best fit.
if loadedCount == 0 {
slog.Debug("loading first model", "model", pending.model.ModelPath)
g := pickBestFitGPUs(pending, ggml, gpus)
if g != nil {
gpus = g
}
s.loadFn(pending, ggml, gpus)
break
}
// More than one loaded model, so we have to see if the new one fits
// Update free memory from currently loaded models
s.updateFreeSpace(gpus)
gpus = pickBestFitGPUs(pending, ggml, gpus)
if gpus != nil {
slog.Debug("new model fits with existing models, loading")
s.loadFn(pending, ggml, gpus)
break
}
runnerToExpire = s.findRunnerToUnload(pending)
}
if runnerToExpire == nil {
// Shouildn't happen
slog.Error("runner to expire was nil!")
continue
}
// Trigger an expiration to unload once it's done
runnerToExpire.refMu.Lock()
slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount)
if runnerToExpire.expireTimer != nil {
runnerToExpire.expireTimer.Stop()
runnerToExpire.expireTimer = nil
}
runnerToExpire.sessionDuration = 0
if runnerToExpire.refCount <= 0 {
s.expiredCh <- runnerToExpire
}
runnerToExpire.refMu.Unlock()
// Wait for the unload to happen
// Note: at this point we're queueing up all incoming requests, even if they were for
// a different model that's loaded and not scheduled to be removed.
slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model)
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case <-s.unloadedCh:
slog.Debug("unload completed", "model", runnerToExpire.model)
continue
}
}
case <-s.unloadedCh:
// An unload request when there are no pending request can be ignored
slog.Debug("ignoring unload event with no pending requests")
}
}
}
func (s *Scheduler) processCompleted(ctx context.Context) {
// Process completed requests, expired timers, and unloading models
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath)
continue
}
runner.refMu.Lock()
runner.refCount--
if runner.refCount <= 0 {
if runner.sessionDuration <= 0 {
slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model)
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
s.expiredCh <- runner
} else if runner.expireTimer == nil {
slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration)
runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
slog.Debug("timer expired, expiring to unload", "model", runner.model)
runner.refMu.Lock()
defer runner.refMu.Unlock()
if runner.expireTimer != nil {
runner.expireTimer.Stop()
}
s.expiredCh <- runner
})
} else {
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration)
runner.expireTimer.Reset(runner.sessionDuration)
}
}
slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount)
runner.refMu.Unlock()
case runner := <-s.expiredCh:
slog.Debug("runner expired event received", "model", runner.model)
runner.refMu.Lock()
if runner.refCount > 0 {
// Shouldn't happen, but safeguard to ensure no leaked runners
slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount)
go func(runner *runnerRef) {
// We can't unload yet, but want to as soon as the current request completes
// So queue up another expired event
time.Sleep(10 * time.Millisecond)
s.expiredCh <- runner
}(runner)
runner.refMu.Unlock()
continue
}
slog.Debug("got lock to unload", "model", runner.model)
runner.unload()
s.loadedMu.Lock()
delete(s.loaded, runner.model)
s.loadedMu.Unlock()
slog.Debug("runner released", "model", runner.model)
runner.refMu.Unlock()
slog.Debug("sending an unloaded event", "model", runner.model)
s.unloadedCh <- struct{}{}
}
}
}
// Complete the pending request and send the runner back to the requester
// Wires up a finished event after the request context is completed
// Updates session duration, and resets expiration timer
func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
runner.refMu.Lock()
defer runner.refMu.Unlock()
runner.refCount++
runner.sessionDuration = pending.sessionDuration
pending.successCh <- runner
go func() {
<-pending.ctx.Done()
slog.Debug("context for request finished")
finished <- pending
}()
}
func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) {
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
req.errCh <- err
return
}
runner := &runnerRef{}
runner.model = req.model.ModelPath
runner.adapters = req.model.AdapterPaths
runner.projectors = req.model.ProjectorPaths
runner.llama = llama
runner.Options = &req.opts
runner.sessionDuration = req.sessionDuration
runner.gpus = gpus
runner.estimatedVRAM = llama.EstimatedVRAM()
runner.loading = true
runner.refCount = 1
runner.refMu.Lock()
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
go func() {
defer runner.refMu.Unlock()
if err = llama.WaitUntilRunning(req.ctx); err != nil {
slog.Error("error loading llama server", "error", err)
runner.refCount--
req.errCh <- err
slog.Debug("triggering expiration for failed load", "model", runner.model)
s.expiredCh <- runner
return
}
slog.Debug("finished setting up runner", "model", req.model.ModelPath)
runner.loading = false
go func() {
<-req.ctx.Done()
slog.Debug("context for request finished")
s.finishedReqCh <- req
}()
req.successCh <- runner
}()
}
func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
type predKey struct {
Library string
ID string
}
predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
s.loadedMu.Lock()
for _, r := range s.loaded {
r.refMu.Lock()
gpuIDs := make([]string, 0, len(r.gpus))
if r.llama != nil {
// TODO this should be broken down by GPU instead of assuming uniform spread
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
for _, gpu := range r.gpus {
gpuIDs = append(gpuIDs, gpu.ID)
}
for _, gpu := range allGpus {
if slices.Contains(gpuIDs, gpu.ID) {
predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
}
}
} else {
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
}
r.refMu.Unlock()
}
s.loadedMu.Unlock()
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
for i := range allGpus {
if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
if p > allGpus[i].TotalMemory {
// Shouldn't happen
slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p)
allGpus[i].FreeMemory = 0
} else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it
// TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy
// and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded
// after we start our first runner, then we'll never acount for that, so picking the smallest free value seems prudent.
allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
}
slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
}
}
}
type runnerRef struct {
refMu sync.Mutex
// refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
refCount uint // prevent unloading if > 0
// unloading bool // set to true when we are trying to unload the runner
llama llm.LlamaServer
loading bool // True only during initial load, then false forever
gpus gpu.GpuInfoList // Recorded at time of provisioning
estimatedVRAM uint64
sessionDuration time.Duration
expireTimer *time.Timer
model string
adapters []string
projectors []string
*api.Options
}
// The refMu must already be held when calling unload
func (runner *runnerRef) unload() {
if runner.llama != nil {
runner.llama.Close()
}
runner.llama = nil
runner.adapters = nil
runner.projectors = nil
runner.Options = nil
runner.gpus = nil
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
runner.refMu.Lock()
defer runner.refMu.Unlock()
timeout := 10 * time.Second
if runner.loading {
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
}
// Don't reload runner if num_gpu=-1 was provided
optsExisting := runner.Options.Runner
optsNew := req.opts.Runner
if optsNew.NumGPU < 0 {
optsExisting.NumGPU = -1
optsNew.NumGPU = -1
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
runner.llama.Ping(ctx) != nil {
return true
}
return false
}
type ByDuration []*runnerRef
func (a ByDuration) Len() int { return len(a) }
func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByDuration) Less(i, j int) bool {
// uint64 to turn negative time (never unload) to largest
return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration)
}
// TODO - future consideration to pick runners based on size
// type BySize []*runnerRef
// func (a BySize) Len() int { return len(a) }
// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
// If the model can not be fit fully within the available GPU(s) nil is returned
func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.GpuInfoList {
var estimatedVRAM uint64
for _, gl := range gpus.ByLibrary() {
var ok bool
sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
// TODO - potentially sort by performance capability, existing models loaded, etc.
// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
// First attempt to fit the model into a single GPU
for _, g := range sgl {
if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
return []gpu.GpuInfo{g}
}
}
// TODO future refinements
// - if multiple Libraries, see if any single GPU in any Library will fit
// - try subsets of GPUs instead of just falling back to 1 or all in a family
// Now try all the GPUs
if ok, estimatedVRAM = llm.PredictServerFit(gl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
return gl
}
}
return nil
}
// findRunnerToUnload finds a runner to unload to make room for a new model
func (s *Scheduler) findRunnerToUnload(req *LlmRequest) *runnerRef {
s.loadedMu.Lock()
runnerList := make([]*runnerRef, 0, len(s.loaded))
for _, r := range s.loaded {
runnerList = append(runnerList, r)
}
s.loadedMu.Unlock()
// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
// e.g., if we have multiple options, will one make room for the request?
sort.Sort(ByDuration(runnerList))
// First try to find a runner that's already idle
for _, runner := range runnerList {
runner.refMu.Lock()
rc := runner.refCount
runner.refMu.Unlock()
if rc == 0 {
slog.Debug("found an idle runner to unload")
return runner
}
}
// None appear idle, just wait for the one with the shortest duration
slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
return runnerList[0]
}
func (s *Scheduler) unloadAllRunners() {
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
for model, runner := range s.loaded {
if runner.llama != nil {
slog.Debug("shutting down runner", "model", model)
runner.llama.Close()
}
}
}

562
server/sched_test.go Normal file
View File

@@ -0,0 +1,562 @@
package server
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"log/slog"
"os"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
os.Setenv("OLLAMA_DEBUG", "1")
lifecycle.InitLogging()
}
func TestInitScheduler(t *testing.T) {
ctx, done := context.WithCancel(context.Background())
defer done()
initialMax := loadedMax
s := InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
require.NotNil(t, s.loaded)
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
s = InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
require.NotNil(t, s.loaded)
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
s = InitScheduler(ctx)
require.Equal(t, 0, loadedMax)
require.NotNil(t, s.loaded)
}
func TestLoad(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
s := InitScheduler(ctx)
var ggml *llm.GGML // value not used in tests
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: 2,
}
// Fail to load model first
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return nil, fmt.Errorf("something failed to load model blah")
}
gpus := gpu.GpuInfoList{}
s.load(req, ggml, gpus)
require.Len(t, req.successCh, 0)
require.Len(t, req.errCh, 1)
require.Len(t, s.loaded, 0)
err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible")
server := &mockLlm{estimatedVRAM: 10}
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return server, nil
}
s.load(req, ggml, gpus)
select {
case err := <-req.errCh:
require.NoError(t, err)
case resp := <-req.successCh:
require.Equal(t, uint64(10), resp.estimatedVRAM)
require.Equal(t, uint(1), resp.refCount)
require.Len(t, s.loaded, 1)
}
req.model.ModelPath = "dummy_model_path"
server.waitResp = fmt.Errorf("wait failure")
s.load(req, ggml, gpus)
select {
case err := <-req.errCh:
require.Contains(t, err.Error(), "wait failure")
case resp := <-req.successCh:
t.Errorf("unexpected success %v", resp)
}
runner := s.loaded["dummy_model_path"]
require.NotNil(t, runner)
require.Equal(t, uint(0), runner.refCount)
time.Sleep(1 * time.Millisecond)
require.Len(t, s.expiredCh, 1)
}
type bundle struct {
ctx context.Context //nolint:containedctx
ctxDone func()
srv *mockLlm
req *LlmRequest
ggml *llm.GGML
}
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return scenario.srv, nil
}
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
scenario := &bundle{}
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
assert.Nil(t, err)
defer f.Close()
gguf := llm.NewGGUFV3(binary.LittleEndian)
err = gguf.Encode(f, llm.KV{
"general.architecture": "llama",
"general.name": "name",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
})
assert.Nil(t, err)
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
scenario.ggml, err = llm.LoadModel(model.ModelPath)
require.NoError(t, err)
scenario.req = &LlmRequest{
ctx: scenario.ctx,
model: model,
sessionDuration: 5 * time.Millisecond,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
return scenario
}
func TestRequests(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = 0
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
scenario2a.req.model = scenario1a.req.model
scenario2a.ggml = scenario1a.ggml
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
// Same runner as first request due to not needing a reload
s.newServerFn = scenario1b.newServer
slog.Info("scenario1b")
s.pendingReqCh <- scenario1b.req
select {
case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1b.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
// Trigger a reload
s.newServerFn = scenario2a.newServer
scenario2a.req.model.AdapterPaths = []string{"new"}
slog.Info("scenario2a")
s.pendingReqCh <- scenario2a.req
// finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond)
scenario1a.ctxDone()
scenario1b.ctxDone()
select {
case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario2a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
loadedMax = 1
s.newServerFn = scenario3a.newServer
slog.Info("scenario3a")
s.pendingReqCh <- scenario3a.req
// finish prior request, so new model can load
time.Sleep(1 * time.Millisecond)
scenario2a.ctxDone()
select {
case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 1)
loadedMax = 0
s.newServerFn = scenario3b.newServer
slog.Info("scenario3b")
s.pendingReqCh <- scenario3b.req
select {
case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, scenario3b.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3b.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 2)
// Try to load a model that wont fit
s.newServerFn = scenario3c.newServer
slog.Info("scenario3c")
require.Len(t, s.loaded, 2)
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- scenario3c.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
require.Len(t, s.loaded, 1)
scenario3b.ctxDone()
select {
case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3c.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 1)
}
func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = 0
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = 0
maxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Len(t, successCh1b, 0)
require.Len(t, errCh1b, 1)
err := <-errCh1b
require.Contains(t, err.Error(), "server busy")
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
scenario1a.ctxDone()
require.Len(t, s.loaded, 1)
scenario1c.req.model.ModelPath = "bad path"
slog.Info("scenario1c")
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, successCh1c, 0)
require.Len(t, errCh1c, 0)
time.Sleep(5 * time.Millisecond)
require.Len(t, s.loaded, 0)
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
scenario1b.ctxDone()
}
// TODO - add one scenario that triggers the bogus finished event with positive ref count
func TestPrematureExpired(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
require.Len(t, s.loaded, 1)
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case <-ctx.Done():
t.Errorf("timeout")
}
time.Sleep(scenario1a.req.sessionDuration)
scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1)
time.Sleep(10 * time.Millisecond)
require.Len(t, s.finishedReqCh, 0)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
s.loadedMu.Unlock()
// also shouldn't happen in real life
s.finishedReqCh <- scenario1a.req
time.Sleep(5 * time.Millisecond)
}
func TestUseLoadedRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
req := &LlmRequest{
ctx: ctx,
successCh: make(chan *runnerRef, 1),
sessionDuration: 2,
}
finished := make(chan *LlmRequest)
llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1, sessionDuration: 1}
req.useLoadedRunner(r1, finished)
require.Equal(t, uint(1), r1.refCount)
require.Equal(t, time.Duration(2), r1.sessionDuration)
select {
case success := <-req.successCh:
require.Equal(t, r1, success)
case <-ctx.Done():
t.Errorf("timeout")
}
done()
fin := <-finished
require.Equal(t, req, fin)
}
func TestUpdateFreeSpace(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
gpus := gpu.GpuInfoList{
{
Library: "a",
ID: "1",
},
{
Library: "a",
ID: "2",
},
}
gpus[0].TotalMemory = 1000
gpus[0].FreeMemory = 900
gpus[1].TotalMemory = 2000
gpus[1].FreeMemory = 1900
llm1 := &mockLlm{estimatedVRAM: 100}
llm2 := &mockLlm{estimatedVRAM: 200}
r1 := &runnerRef{llama: llm1, gpus: gpus}
r2 := &runnerRef{llama: llm2, gpus: gpus}
s := InitScheduler(ctx)
s.loaded["a"] = r1
s.loaded["b"] = r2
s.updateFreeSpace(gpus)
require.Equal(t, uint64(850), gpus[0].FreeMemory)
require.Equal(t, uint64(1850), gpus[1].FreeMemory)
}
func TestFindRunnerToUnload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
req := &LlmRequest{ctx: ctx}
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
r2 := &runnerRef{sessionDuration: 2}
s := InitScheduler(ctx)
s.loaded["a"] = r1
s.loaded["b"] = r2
resp := s.findRunnerToUnload(req)
require.Equal(t, r2, resp)
r2.refCount = 1
resp = s.findRunnerToUnload(req)
require.Equal(t, r1, resp)
}
func TestNeedsReload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
llm := &mockLlm{}
runner := &runnerRef{
adapters: []string{"adapter1"},
projectors: []string{"projector1"},
Options: &api.Options{},
llama: llm,
}
req := &LlmRequest{
model: &Model{
AdapterPaths: []string{"adapter2"},
ProjectorPaths: []string{"projector2"},
},
opts: api.Options{},
}
resp := runner.needsReload(ctx, req)
require.True(t, resp)
req.model.AdapterPaths = runner.adapters
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.model.ProjectorPaths = runner.projectors
runner.loading = true
req.opts.NumBatch = 1234
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.opts.NumBatch = runner.Options.NumBatch
llm.pingResp = fmt.Errorf("foo")
resp = runner.needsReload(ctx, req)
require.True(t, resp)
llm.pingResp = nil
resp = runner.needsReload(ctx, req)
require.False(t, resp)
req.opts.NumGPU = 99
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.opts.NumGPU = -1
resp = runner.needsReload(ctx, req)
require.False(t, resp)
}
func TestUnloadAllRunners(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
llm1 := &mockLlm{}
llm2 := &mockLlm{}
s := InitScheduler(ctx)
s.unloadAllRunners()
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{llama: llm2}
s.loaded["a"] = r1
s.loaded["b"] = r2
s.unloadAllRunners()
require.True(t, llm1.closeCalled)
require.True(t, llm2.closeCalled)
}
func TestUnload(t *testing.T) {
llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{adapters: []string{"A"}}
r1.unload()
require.True(t, llm1.closeCalled)
r2.unload()
require.Nil(t, r2.adapters)
}
type mockLlm struct {
pingResp error
waitResp error
completionResp error
embeddingResp []float64
embeddingRespErr error
tokenizeResp []int
tokenizeRespErr error
detokenizeResp string
detonekizeRespErr error
closeResp error
closeCalled bool
estimatedVRAM uint64
}
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.tokenizeResp, s.tokenizeRespErr
}
func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
return s.detokenizeResp, s.detonekizeRespErr
}
func (s *mockLlm) Close() error {
s.closeCalled = true
return s.closeResp
}
func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }

View File

@@ -1,79 +0,0 @@
package model
import (
"log/slog"
"strings"
"unicode"
)
// Digest represents a digest of a model Manifest. It is a comparable value
// type and is immutable.
//
// The zero Digest is not a valid digest.
type Digest struct {
s string
}
// Type returns the digest type of the digest.
//
// Example:
//
// ParseDigest("sha256-1234").Type() // returns "sha256"
func (d Digest) Type() string {
typ, _, _ := strings.Cut(d.s, "-")
return typ
}
// String returns the digest in the form of "<digest-type>-<digest>", or the
// empty string if the digest is invalid.
func (d Digest) String() string { return d.s }
// IsValid returns true if the digest is valid (not zero).
//
// A valid digest may be created only by ParseDigest, or
// ParseName(name).Digest().
func (d Digest) IsValid() bool { return d.s != "" }
// LogValue implements slog.Value.
func (d Digest) LogValue() slog.Value {
return slog.StringValue(d.String())
}
var (
_ slog.LogValuer = Digest{}
)
// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
// Digest.
func ParseDigest(s string) Digest {
typ, digest, ok := strings.Cut(s, "-")
if ok && isValidDigestType(typ) && isValidHex(digest) {
return Digest{s: s}
}
return Digest{}
}
func isValidDigestType(s string) bool {
if len(s) == 0 {
return false
}
for _, r := range s {
if !unicode.IsLower(r) && !unicode.IsDigit(r) {
return false
}
}
return true
}
func isValidHex(s string) bool {
if len(s) == 0 {
return false
}
for i := range s {
c := s[i]
if c < '0' || c > '9' && c < 'a' || c > 'f' {
return false
}
}
return true
}

View File

@@ -1,46 +0,0 @@
package model
import "testing"
var testDigests = map[string]Digest{
"": {},
"sha256-1234": {s: "sha256-1234"},
"sha256-5678": {s: "sha256-5678"},
"blake2-9abc": {s: "blake2-9abc"},
"-1234": {},
"sha256-": {},
"sha256-1234-5678": {},
"sha256-P": {}, // invalid hex
"sha256-1234P": {},
"---": {},
}
func TestDigestParse(t *testing.T) {
// Test cases.
for s, want := range testDigests {
got := ParseDigest(s)
t.Logf("ParseDigest(%q) = %#v", s, got)
if got != want {
t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
}
}
}
func TestDigestString(t *testing.T) {
// Test cases.
for s, d := range testDigests {
want := s
if !d.IsValid() {
want = ""
}
got := d.String()
if got != want {
t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
}
got = ParseDigest(s).String()
if got != want {
t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want)
}
}
}

View File

@@ -1,679 +1,313 @@
// Package model contains types and utilities for parsing, validating, and
// working with model names and digests.
package model
import (
"cmp"
"errors"
"fmt"
"hash/maphash"
"io"
"log/slog"
"path/filepath"
"slices"
"strings"
"sync"
"github.com/ollama/ollama/types/structs"
)
// Errors
var (
// ErrInvalidName, ErrIncompleteName, and ErrInvalidDigest are not
// used by this package, but are exported so that other packages can
// use them, instead of defining their own errors for them.
ErrInvalidName = errors.New("invalid model name")
ErrIncompleteName = errors.New("incomplete model name")
ErrInvalidDigest = errors.New("invalid digest")
// ErrUnqualifiedName represents an error where a name is not fully
// qualified. It is not used directly in this package, but is here
// to avoid other packages inventing their own error type.
// Additionally, it can be conveniently used via [Unqualified].
ErrUnqualifiedName = errors.New("unqualified name")
)
// Defaults
const (
// MaskDefault is the default mask used by [Name.DisplayShortest].
MaskDefault = "registry.ollama.ai/library/?:latest"
// MaskNothing is a mask that masks nothing.
MaskNothing = "?/?/?:?"
// DefaultFill is the default fill used by [ParseName].
FillDefault = "registry.ollama.ai/library/?:latest+Q4_0"
// FillNothing is a fill that fills nothing.
FillNothing = "?/?/?:?+?"
)
const MaxNamePartLen = 128
type PartKind int
// Levels of concreteness
const (
// Each value aligns with its index in the Name.parts array.
PartHost PartKind = iota
PartNamespace
PartModel
PartTag
PartBuild
PartDigest
// NumParts is the number of parts in a Name. In this list, it must
// follow the final part.
NumParts
PartExtraneous = -1
)
var kindNames = map[PartKind]string{
PartHost: "Host",
PartNamespace: "Namespace",
PartModel: "Name",
PartTag: "Tag",
PartBuild: "Build",
PartDigest: "Digest",
// Unqualified is a helper function that returns an error with
// ErrUnqualifiedName as the cause and the name as the message.
func Unqualified(n Name) error {
return fmt.Errorf("%w: %s", ErrUnqualifiedName, n)
}
func (k PartKind) String() string {
return cmp.Or(kindNames[k], "Unknown")
// MissingPart is used to indicate any part of a name that was "promised" by
// the presence of a separator, but is missing.
//
// The value was chosen because it is deemed unlikely to be set by a user,
// not a valid part name valid when checked by [Name.IsValid], and easy to
// spot in logs.
const MissingPart = "!MISSING!"
// DefaultName returns a name with the default values for the host, namespace,
// and tag parts. The model and digest parts are empty.
//
// - The default host is ("registry.ollama.ai")
// - The default namespace is ("library")
// - The default tag is ("latest")
func DefaultName() Name {
return Name{
Host: "registry.ollama.ai",
Namespace: "library",
Tag: "latest",
}
}
// Name is an opaque reference to a model. It holds the parts of a model
// with the case preserved, but is not directly comparable with other Names
// since model names can be represented with different casing depending on
// the use case. For instance, "Mistral" and "mistral" are the same model
// but each version may have come from different sources (e.g. copied from a
// Web page, or from a file path).
type partKind int
const (
kindHost partKind = iota
kindNamespace
kindModel
kindTag
kindDigest
)
func (k partKind) String() string {
switch k {
case kindHost:
return "host"
case kindNamespace:
return "namespace"
case kindModel:
return "model"
case kindTag:
return "tag"
case kindDigest:
return "digest"
default:
return "unknown"
}
}
// Name is a structured representation of a model name string, as defined by
// [ParseNameNoDefaults].
//
// Valid Names can ONLY be constructed by calling [ParseName].
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
// is valid.
//
// A Name is valid if and only if is have a valid Model part. The other parts
// are optional.
//
// A Name is considered "complete" if it has all parts present. To check if a
// Name is complete, use [Name.IsComplete].
//
// To compare two names in a case-insensitive manner, use [Name.EqualFold].
//
// The parts of a Name are:
//
// - Host: the domain of the model (optional)
// - Namespace: the namespace of the model (optional)
// - Model: the name of the model (required)
// - Tag: the tag of the model (optional)
// - Build: the build of the model; usually the quantization or "file type" (optional)
//
// The parts can be obtained in their original form by calling [Name.Parts].
//
// To check if a Name has at minimum a valid model part, use [Name.IsValid].
// It is not directly comparable with other Names. Use [Name.Equal] and
// [Name.MapHash] for determining equality and using as a map key.
type Name struct {
_ structs.Incomparable
parts [NumParts]string // host, namespace, model, tag, build, digest
// TODO(bmizerany): track offsets and hold s (raw string) here? We
// could pack the offsets all into a single uint64 since the first
// parts take less bits since their max offset is less than the max
// offset of the next part. This would save a ton of bytes per Name
// and mean zero allocations for String.
Host string
Namespace string
Model string
Tag string
RawDigest string
}
// ParseName parses s into a Name, and returns the result of filling it with
// defaults. The input string must be a valid string
// representation of a model name in the form:
// ParseName parses and assembles a Name from a name string. The
// format of a valid name string is:
//
// [host/][namespace/]<model>[:tag][+build][@<digest-type>-<digest>]
// s:
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
// { host } "/" { namespace } "/" { model } ":" { tag }
// { host } "/" { namespace } "/" { model } "@" { digest }
// { host } "/" { namespace } "/" { model }
// { namespace } "/" { model } ":" { tag } "@" { digest }
// { namespace } "/" { model } ":" { tag }
// { namespace } "/" { model } "@" { digest }
// { namespace } "/" { model }
// { model } ":" { tag } "@" { digest }
// { model } ":" { tag }
// { model } "@" { digest }
// { model }
// "@" { digest }
// host:
// pattern: alphanum { alphanum | "-" | "_" | "." | ":" }*
// length: [1, 350]
// namespace:
// pattern: alphanum { alphanum | "-" | "_" }*
// length: [2, 80]
// model:
// pattern: alphanum { alphanum | "-" | "_" | "." }*
// length: [2, 80]
// tag:
// pattern: alphanum { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// digest:
// pattern: alphanum { alphanum | "-" | ":" }*
// length: [2, 80]
//
// The name part is required, all others are optional. If a part is missing,
// it is left empty in the returned Name. If a part is invalid, the zero Ref
// value is returned.
// Most users should use [ParseName] instead, unless need to support
// different defaults than DefaultName.
//
// The build part is normalized to uppercase.
//
// Examples of valid paths:
//
// "example.com/library/mistral:7b+x"
// "example.com/eva/mistral:7b+Q4_0"
// "mistral:7b+x"
// "example.com/mike/mistral:latest+Q4_0"
// "example.com/bruce/mistral:latest"
// "example.com/pdevine/thisisfine:7b+Q4_0@sha256-1234567890abcdef"
//
// Examples of invalid paths:
//
// "example.com/mistral:7b+"
// "example.com/mistral:7b+Q4_0+"
// "x/y/z/z:8n+I"
// ""
//
// It returns the zero value if any part is invalid.
//
// # Fills
//
// For any valid s, the fill string is used to fill in missing parts of the
// Name. The fill string must be a valid Name with the exception that any part
// may be the string ("?"), which will not be considered for filling.
func ParseName(s, fill string) Name {
var r Name
parts(s)(func(kind PartKind, part string) bool {
if kind == PartDigest && !ParseDigest(part).IsValid() {
r = Name{}
return false
}
if kind == PartExtraneous || !isValidPart(kind, part) {
r = Name{}
return false
}
r.parts[kind] = part
return true
})
if r.IsValid() || r.IsResolved() {
return fillName(r, fill)
// The name returned is not guaranteed to be valid. If it is not valid, the
// field values are left in an undefined state. Use [Name.IsValid] to check
// if the name is valid.
func ParseName(s string) Name {
return Merge(ParseNameBare(s), DefaultName())
}
// ParseNameBare parses s as a name string and returns a Name. No merge with
// [DefaultName] is performed.
func ParseNameBare(s string) Name {
var n Name
var promised bool
s, n.RawDigest, promised = cutLast(s, "@")
if promised && n.RawDigest == "" {
n.RawDigest = MissingPart
}
return Name{}
}
func parseMask(s string) Name {
var r Name
parts(s)(func(kind PartKind, part string) bool {
if part == "?" {
// mask part; treat as empty but valid
return true
}
if !isValidPart(kind, part) {
panic(fmt.Errorf("invalid mask part %s: %q", kind, part))
}
r.parts[kind] = part
return true
})
return r
}
func MustParseName(s, defaults string) Name {
r := ParseName(s, "")
if !r.IsValid() {
panic("invalid Name: " + s)
s, n.Tag, _ = cutPromised(s, ":")
s, n.Model, promised = cutPromised(s, "/")
if !promised {
n.Model = s
return n
}
return r
}
// fillName fills in the missing parts of dst with the parts of src.
//
// The returned Name will only be valid if dst is valid.
//
// It skipps fill parts that are "?".
func fillName(r Name, fill string) Name {
fill = cmp.Or(fill, FillDefault)
f := parseMask(fill)
if fill != FillNothing && f.IsZero() {
panic("invalid fill")
s, n.Namespace, promised = cutPromised(s, "/")
if !promised {
n.Namespace = s
return n
}
for i := range r.parts {
if f.parts[i] == "?" {
continue
}
r.parts[i] = cmp.Or(r.parts[i], f.parts[i])
n.Host = s
return n
}
// Merge merges the host, namespace, and tag parts of the two names,
// preferring the non-empty parts of a.
func Merge(a, b Name) Name {
a.Host = cmp.Or(a.Host, b.Host)
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
a.Tag = cmp.Or(a.Tag, b.Tag)
return a
}
// String returns the name string, in the format that [ParseNameNoDefaults]
// accepts as valid, if [Name.IsValid] reports true; otherwise the empty
// string is returned.
func (n Name) String() string {
var b strings.Builder
if n.Host != "" {
b.WriteString(n.Host)
b.WriteByte('/')
}
return r
}
// WithBuild returns a copy of r with the build set to the given string.
func (r Name) WithBuild(build string) Name {
r.parts[PartBuild] = build
return r
}
func (r Name) WithDigest(digest Digest) Name {
r.parts[PartDigest] = digest.String()
return r
}
var mapHashSeed = maphash.MakeSeed()
// MapHash returns a case insensitive hash for use in maps and equality
// checks. For a convenient way to compare names, use [Name.EqualFold].
//
//nolint:errcheck
func (r Name) MapHash() uint64 {
// correctly hash the parts with case insensitive comparison
var h maphash.Hash
h.SetSeed(mapHashSeed)
for _, part := range r.parts {
// downcase the part for hashing
for i := range part {
c := part[i]
if c >= 'A' && c <= 'Z' {
c = c - 'A' + 'a'
}
h.WriteByte(c)
}
if n.Namespace != "" {
b.WriteString(n.Namespace)
b.WriteByte('/')
}
return h.Sum64()
}
func (r Name) slice(from, to PartKind) Name {
var v Name
copy(v.parts[from:to+1], r.parts[from:to+1])
return v
}
// DisplayShortest returns the shortest possible, masked display string in form:
//
// [host/][<namespace>/]<model>[:<tag>]
//
// # Masks
//
// The mask is a string that specifies which parts of the name to omit based
// on case-insensitive comparison. [Name.DisplayShortest] omits parts of the name
// that are the same as the mask, moving from left to right until the first
// unequal part is found. It then moves right to left until the first unequal
// part is found. The result is the shortest possible display string.
//
// Unlike a [Name] the mask can contain "?" characters which are treated as
// wildcards. A "?" will never match a part of the name, since a valid name
// can never contain a "?" character.
//
// For example: Given a Name ("registry.ollama.ai/library/mistral:latest") masked
// with ("registry.ollama.ai/library/?:latest") will produce the display string
// ("mistral").
//
// If mask is the empty string, then [MaskDefault] is used.
//
// DisplayShortest panics if the mask is not the empty string, MaskNothing, and
// invalid.
//
// # Builds
//
// For now, DisplayShortest does consider the build or return one in the
// result. We can lift this restriction when needed.
func (r Name) DisplayShortest(mask string) string {
mask = cmp.Or(mask, MaskDefault)
d := parseMask(mask)
if mask != MaskNothing && r.IsZero() {
panic("invalid Name")
b.WriteString(n.Model)
if n.Tag != "" {
b.WriteByte(':')
b.WriteString(n.Tag)
}
for i := range PartTag {
if !strings.EqualFold(r.parts[i], d.parts[i]) {
break
}
r.parts[i] = ""
if n.RawDigest != "" {
b.WriteByte('@')
b.WriteString(n.RawDigest)
}
for i := PartTag; i >= 0; i-- {
if !strings.EqualFold(r.parts[i], d.parts[i]) {
break
}
r.parts[i] = ""
}
return r.slice(PartHost, PartTag).DisplayLong()
}
// DisplayLongest returns the result of r.DisplayShortest(MaskNothing).
func (r Name) DisplayLongest() string {
return r.DisplayShortest(MaskNothing)
}
var seps = [...]string{
PartHost: "/",
PartNamespace: "/",
PartModel: ":",
PartTag: "+",
PartBuild: "@",
PartDigest: "",
}
// WriteTo implements io.WriterTo. It writes the fullest possible display
// string in form:
//
// <host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
//
// Missing parts and their separators are not written.
//
// The full digest is always prefixed with "@". That is if [Name.IsValid]
// reports false and [Name.IsResolved] reports true, then the string is
// returned as "@<digest-type>-<digest>".
func (r Name) writeTo(w io.StringWriter) error {
var partsWritten int
for i := range r.parts {
if r.parts[i] == "" {
continue
}
if partsWritten > 0 || i == int(PartDigest) {
if _, err := w.WriteString(seps[i-1]); err != nil {
return err
}
}
if _, err := w.WriteString(r.parts[i]); err != nil {
return err
}
partsWritten++
}
return nil
}
var builderPool = sync.Pool{
New: func() interface{} {
return &strings.Builder{}
},
}
// DisplayLong returns the fullest possible display string in form:
//
// <host>/<namespace>/<model>:<tag>+<build>
//
// If any part is missing, it is omitted from the display string.
func (r Name) DisplayLong() string {
b := builderPool.Get().(*strings.Builder)
defer builderPool.Put(b)
b.Reset()
b.Grow(50) // arbitrarily long enough for most names
_ = r.writeTo(b)
return b.String()
}
// GoString implements fmt.GoStringer. It returns a string suitable for
// debugging and logging. It is similar to [Name.DisplayLong] but it always
// returns a string that includes all parts of the Name, with missing parts
// replaced with a ("?").
func (r Name) GoString() string {
for i := range r.parts {
r.parts[i] = cmp.Or(r.parts[i], "?")
}
return r.DisplayLong()
}
// LogValue implements slog.Valuer.
func (r Name) LogValue() slog.Value {
return slog.StringValue(r.GoString())
}
// IsComplete reports whether the Name is fully qualified. That is it has a
// domain, namespace, name, tag, and build.
func (r Name) IsComplete() bool {
return !slices.Contains(r.parts[:PartDigest], "")
}
// IsCompleteNoBuild is like [Name.IsComplete] but it does not require the
// build part to be present.
func (r Name) IsCompleteNoBuild() bool {
return !slices.Contains(r.parts[:PartBuild], "")
}
// IsResolved reports true if the Name has a valid digest.
//
// It is possible to have a valid Name, or a complete Name that is not
// resolved.
func (r Name) IsResolved() bool {
return r.Digest().IsValid()
}
// Digest returns the digest part of the Name, if any.
//
// If Digest returns a non-empty string, then [Name.IsResolved] will return
// true, and digest is considered valid.
func (r Name) Digest() Digest {
// This was already validated by ParseName, so we can just return it.
return Digest{r.parts[PartDigest]}
}
// EqualFold reports whether r and o are equivalent model names, ignoring
// case.
func (r Name) EqualFold(o Name) bool {
return r.CompareFold(o) == 0
}
// CompareFold performs a case-insensitive cmp.Compare on r and o.
//
// This can be used with [slices.SortFunc].
//
// For simple equality checks, use [Name.EqualFold].
func (r Name) CompareFold(o Name) int {
return slices.CompareFunc(r.parts[:], o.parts[:], compareFold)
}
func compareFold(a, b string) int {
return slices.CompareFunc([]rune(a), []rune(b), func(a, b rune) int {
return cmp.Compare(downcase(a), downcase(b))
})
}
func downcase(r rune) rune {
if r >= 'A' && r <= 'Z' {
return r - 'A' + 'a'
}
return r
}
func (r Name) Host() string { return r.parts[PartHost] }
func (r Name) Namespace() string { return r.parts[PartNamespace] }
func (r Name) Model() string { return r.parts[PartModel] }
func (r Name) Build() string { return r.parts[PartBuild] }
func (r Name) Tag() string { return r.parts[PartTag] }
// iter_Seq2 is a iter.Seq2 defined here to avoid the current build
// restrictions in the go1.22 iter package requiring the
// goexperiment.rangefunc tag to be set via the GOEXPERIMENT=rangefunc flag,
// which we are not yet ready to support.
//
// Once we are ready to support rangefunc, this can be removed and replaced
// with the iter.Seq2 type.
type iter_Seq2[A, B any] func(func(A, B) bool)
// Parts returns a sequence of the parts of a Name string from most specific
// to least specific.
//
// It normalizes the input string by removing "http://" and "https://" only.
// No other normalizations are performed.
func parts(s string) iter_Seq2[PartKind, string] {
return func(yield func(PartKind, string) bool) {
if strings.HasPrefix(s, "http://") {
s = strings.TrimPrefix(s, "http://")
} else {
s = strings.TrimPrefix(s, "https://")
}
if len(s) > MaxNamePartLen || len(s) == 0 {
return
}
numConsecutiveDots := 0
partLen := 0
state, j := PartDigest, len(s)
for i := len(s) - 1; i >= 0; i-- {
if partLen++; partLen > MaxNamePartLen {
// catch a part that is too long early, so
// we don't keep spinning on it, waiting for
// an isInValidPart check which would scan
// over it again.
yield(state, s[i+1:j])
return
}
switch s[i] {
case '@':
switch state {
case PartDigest:
if !yield(PartDigest, s[i+1:j]) {
return
}
if i == 0 {
// This is the form
// "@<digest>" which is valid.
//
// We're done.
return
}
state, j, partLen = PartBuild, i, 0
default:
yield(PartExtraneous, s[i+1:j])
return
}
case '+':
switch state {
case PartBuild, PartDigest:
if !yield(PartBuild, s[i+1:j]) {
return
}
state, j, partLen = PartTag, i, 0
default:
yield(PartExtraneous, s[i+1:j])
return
}
case ':':
switch state {
case PartTag, PartBuild, PartDigest:
if !yield(PartTag, s[i+1:j]) {
return
}
state, j, partLen = PartModel, i, 0
default:
yield(PartExtraneous, s[i+1:j])
return
}
case '/':
switch state {
case PartModel, PartTag, PartBuild, PartDigest:
if !yield(PartModel, s[i+1:j]) {
return
}
state, j = PartNamespace, i
case PartNamespace:
if !yield(PartNamespace, s[i+1:j]) {
return
}
state, j, partLen = PartHost, i, 0
default:
yield(PartExtraneous, s[i+1:j])
return
}
default:
if s[i] == '.' {
if numConsecutiveDots++; numConsecutiveDots > 1 {
yield(state, "")
return
}
} else {
numConsecutiveDots = 0
}
}
}
if state <= PartNamespace {
yield(state, s[:j])
} else {
yield(PartModel, s[:j])
}
}
}
func (r Name) IsZero() bool {
return r.parts == [NumParts]string{}
}
// IsValid reports if a model has at minimum a valid model part.
func (r Name) IsValid() bool {
// Parts ensures we only have valid parts, so no need to validate
// them here, only check if we have a name or not.
return r.parts[PartModel] != ""
}
// ParseNameFromURLPath parses forms of a URL path into a Name. Specifically,
// it trims any leading "/" and then calls [ParseName] with fill.
func ParseNameFromURLPath(s, fill string) Name {
s = strings.TrimPrefix(s, "/")
return ParseName(s, fill)
}
// URLPath returns a complete, canonicalized, relative URL path using the parts of a
// complete Name.
//
// The parts maintain their original case.
//
// Example:
//
// ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag"
func (r Name) URLPath() string {
return r.DisplayShortest(MaskNothing)
}
// ParseNameFromFilepath parses a file path into a Name. The input string must be a
// valid file path representation of a model name in the form:
//
// host/namespace/model/tag/build
//
// The zero valid is returned if s does not contain all path elements
// leading up to the model part, or if any path element is an invalid part
// for the its corresponding part kind.
//
// The fill string is used to fill in missing parts of any constructed Name.
// See [ParseName] for more information on the fill string.
func ParseNameFromFilepath(s, fill string) Name {
var r Name
for i := range PartBuild + 1 {
part, rest, _ := strings.Cut(s, string(filepath.Separator))
if !isValidPart(i, part) {
return Name{}
}
r.parts[i] = part
s = rest
if s == "" {
break
}
}
if s != "" {
return Name{}
}
if !r.IsValid() {
return Name{}
}
return fillName(r, fill)
}
// Filepath returns a complete, canonicalized, relative file path using the
// parts of a complete Name.
//
// Each parts is downcased, except for the build part which is upcased.
//
// Example:
//
// ParseName("example.com/namespace/model:tag+build").Filepath() // returns "example.com/namespace/model/tag/BUILD"
func (r Name) Filepath() string {
for i := range r.parts {
if PartKind(i) == PartBuild {
r.parts[i] = strings.ToUpper(r.parts[i])
} else {
r.parts[i] = strings.ToLower(r.parts[i])
}
}
return filepath.Join(r.parts[:]...)
}
// isValidPart reports if s contains all valid characters for the given
// part kind.
func isValidPart(kind PartKind, s string) bool {
if s == "" {
// IsValid reports whether all parts of the name are present and valid. The
// digest is a special case, and is checked for validity only if present.
func (n Name) IsValid() bool {
if n.RawDigest != "" && !isValidPart(kindDigest, n.RawDigest) {
return false
}
var consecutiveDots int
for _, c := range []byte(s) {
if c == '.' {
if consecutiveDots++; consecutiveDots >= 2 {
return false
}
} else {
consecutiveDots = 0
}
if !isValidByteFor(kind, c) {
return n.IsFullyQualified()
}
// IsFullyQualified returns true if all parts of the name are present and
// valid without the digest.
func (n Name) IsFullyQualified() bool {
var parts = []string{
n.Host,
n.Namespace,
n.Model,
n.Tag,
}
for i, part := range parts {
if !isValidPart(partKind(i), part) {
return false
}
}
return true
}
func isValidByteFor(kind PartKind, c byte) bool {
if kind == PartNamespace && c == '.' {
// Filepath returns a canonical filepath that represents the name with each part from
// host to tag as a directory in the form:
//
// {host}/{namespace}/{model}/{tag}
//
// It uses the system's filepath separator and ensures the path is clean.
//
// It panics if the name is not fully qualified. Use [Name.IsFullyQualified]
// to check if the name is fully qualified.
func (n Name) Filepath() string {
if !n.IsFullyQualified() {
panic("illegal attempt to get filepath of invalid name")
}
return filepath.Join(
strings.ToLower(n.Host),
strings.ToLower(n.Namespace),
strings.ToLower(n.Model),
strings.ToLower(n.Tag),
)
}
// LogValue returns a slog.Value that represents the name as a string.
func (n Name) LogValue() slog.Value {
return slog.StringValue(n.String())
}
func isValidLen(kind partKind, s string) bool {
switch kind {
case kindHost:
return len(s) >= 1 && len(s) <= 350
case kindTag:
return len(s) >= 1 && len(s) <= 80
default:
return len(s) >= 2 && len(s) <= 80
}
}
func isValidPart(kind partKind, s string) bool {
if !isValidLen(kind, s) {
return false
}
if c == '.' || c == '-' {
return true
for i := range s {
if i == 0 {
if !isAlphanumeric(s[i]) {
return false
}
continue
}
switch s[i] {
case '_', '-':
case '.':
if kind == kindNamespace {
return false
}
case ':':
if kind != kindHost && kind != kindDigest {
return false
}
default:
if !isAlphanumeric(s[i]) {
return false
}
}
}
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
return true
}
return false
return true
}
func isAlphanumeric(c byte) bool {
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9'
}
func cutLast(s, sep string) (before, after string, ok bool) {
i := strings.LastIndex(s, sep)
if i >= 0 {
return s[:i], s[i+len(sep):], true
}
return s, "", false
}
// cutPromised cuts the last part of s at the last occurrence of sep. If sep is
// found, the part before and after sep are returned as-is unless empty, in
// which case they are returned as MissingPart, which will cause
// [Name.IsValid] to return false.
func cutPromised(s, sep string) (before, after string, ok bool) {
before, after, ok = cutLast(s, sep)
if !ok {
return before, after, false
}
return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
}

View File

@@ -1,690 +1,237 @@
package model
import (
"bytes"
"cmp"
"fmt"
"log/slog"
"path/filepath"
"slices"
"strings"
"reflect"
"testing"
)
type fields struct {
host, namespace, model, tag, build string
digest string
}
const (
part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
)
func fieldsFromName(p Name) fields {
return fields{
host: p.parts[PartHost],
namespace: p.parts[PartNamespace],
model: p.parts[PartModel],
tag: p.parts[PartTag],
build: p.parts[PartBuild],
digest: p.parts[PartDigest],
}
}
var testNames = map[string]fields{
"mistral:latest": {model: "mistral", tag: "latest"},
"mistral": {model: "mistral"},
"mistral:30B": {model: "mistral", tag: "30B"},
"mistral:7b": {model: "mistral", tag: "7b"},
"mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"},
"mistral+KQED": {model: "mistral", build: "KQED"},
"mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
"mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"},
"llama2": {model: "llama2"},
"user/model": {namespace: "user", model: "model"},
"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
"example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
// invalid digest
"mistral:latest@invalid256-": {},
"mistral:latest@-123": {},
"mistral:latest@!-123": {},
"mistral:latest@1-!": {},
"mistral:latest@": {},
// resolved
"x@sha123-1": {model: "x", digest: "sha123-1"},
"@sha456-2": {digest: "sha456-2"},
"@@sha123-1": {},
// preserves case for build
"x+b": {model: "x", build: "b"},
// invalid (includes fuzzing trophies)
" / / : + ": {},
" / : + ": {},
" : + ": {},
" + ": {},
" : ": {},
" / ": {},
" /": {},
"/ ": {},
"/": {},
":": {},
"+": {},
// (".") in namepsace is not allowed
"invalid.com/7b+x": {},
"invalid:7b+Q4_0:latest": {},
"in valid": {},
"invalid/y/z/foo": {},
"/0": {},
"0 /0": {},
"0 /": {},
"0/": {},
":/0": {},
"+0/00000": {},
"0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
"0//0": {},
"m+^^^": {},
"file:///etc/passwd": {},
"file:///etc/passwd:latest": {},
"file:///etc/passwd:latest+u": {},
":x": {},
"+x": {},
"x+": {},
// Disallow ("\.+") in any part to prevent path traversal anywhere
// we convert the name to a path.
"../etc/passwd": {},
".../etc/passwd": {},
"./../passwd": {},
"./0+..": {},
strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
strings.Repeat("a", MaxNamePartLen+1): {},
}
// TestConsecutiveDots tests that consecutive dots are not allowed in any
// part, to avoid path traversal. There also are some tests in testNames, but
// this test is more exhaustive and exists to emphasize the importance of
// preventing path traversal.
func TestNameConsecutiveDots(t *testing.T) {
for i := 1; i < 10; i++ {
s := strings.Repeat(".", i)
if i > 1 {
if g := ParseName(s, FillNothing).DisplayLong(); g != "" {
t.Errorf("ParseName(%q) = %q; want empty string", s, g)
}
} else {
if g := ParseName(s, FillNothing).DisplayLong(); g != s {
t.Errorf("ParseName(%q) = %q; want %q", s, g, s)
}
}
}
}
func TestNameParts(t *testing.T) {
var p Name
if w, g := int(NumParts), len(p.parts); w != g {
t.Errorf("Parts() = %d; want %d", g, w)
}
}
func TestNamePartString(t *testing.T) {
if g := PartKind(-2).String(); g != "Unknown" {
t.Errorf("Unknown part = %q; want %q", g, "Unknown")
}
for kind, name := range kindNames {
if g := kind.String(); g != name {
t.Errorf("%s = %q; want %q", kind, g, name)
}
}
}
func TestParseName(t *testing.T) {
for baseName, want := range testNames {
for _, prefix := range []string{"", "https://", "http://"} {
// We should get the same results with or without the
// http(s) prefixes
s := prefix + baseName
t.Run(s, func(t *testing.T) {
name := ParseName(s, FillNothing)
got := fieldsFromName(name)
if got != want {
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
}
// test round-trip
if !ParseName(name.DisplayLong(), FillNothing).EqualFold(name) {
t.Errorf("ParseName(%q).String() = %s; want %s", s, name.DisplayLong(), baseName)
}
})
}
}
}
func TestParseNameFill(t *testing.T) {
cases := []struct {
in string
fill string
want string
}{
{"mistral", "example.com/library/?:latest+Q4_0", "example.com/library/mistral:latest+Q4_0"},
{"mistral", "example.com/library/?:latest", "example.com/library/mistral:latest"},
{"llama2:x", "example.com/library/?:latest+Q4_0", "example.com/library/llama2:x+Q4_0"},
// Invalid
{"", "example.com/library/?:latest+Q4_0", ""},
{"llama2:?", "example.com/library/?:latest+Q4_0", ""},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
name := ParseName(tt.in, tt.fill)
if g := name.DisplayLong(); g != tt.want {
t.Errorf("ParseName(%q, %q) = %q; want %q", tt.in, tt.fill, g, tt.want)
}
})
}
t.Run("invalid fill", func(t *testing.T) {
defer func() {
if recover() == nil {
t.Fatal("expected panic")
}
}()
ParseName("x", "^")
})
}
func TestParseNameHTTPDoublePrefixStrip(t *testing.T) {
cases := []string{
"http://https://valid.com/valid/valid:latest",
"https://http://valid.com/valid/valid:latest",
}
for _, s := range cases {
t.Run(s, func(t *testing.T) {
name := ParseName(s, FillNothing)
if name.IsValid() {
t.Errorf("expected invalid path; got %#v", name)
}
})
}
}
func TestCompleteWithAndWithoutBuild(t *testing.T) {
func TestParseNameParts(t *testing.T) {
cases := []struct {
in string
complete bool
completeNoBuild bool
want Name
wantValidDigest bool
}{
{"", false, false},
{"incomplete/mistral:7b+x", false, false},
{"incomplete/mistral:7b+Q4_0", false, false},
{"incomplete:7b+x", false, false},
{"complete.com/x/mistral:latest+Q4_0", true, true},
{"complete.com/x/mistral:latest", false, true},
{
in: "host/namespace/model:tag",
want: Name{
Host: "host",
Namespace: "namespace",
Model: "model",
Tag: "tag",
},
},
{
in: "host/namespace/model",
want: Name{
Host: "host",
Namespace: "namespace",
Model: "model",
},
},
{
in: "namespace/model",
want: Name{
Namespace: "namespace",
Model: "model",
},
},
{
in: "model",
want: Name{
Model: "model",
},
},
{
in: "h/nn/mm:t",
want: Name{
Host: "h",
Namespace: "nn",
Model: "mm",
Tag: "t",
},
},
{
in: part80 + "/" + part80 + "/" + part80 + ":" + part80,
want: Name{
Host: part80,
Namespace: part80,
Model: part80,
Tag: part80,
},
},
{
in: part350 + "/" + part80 + "/" + part80 + ":" + part80,
want: Name{
Host: part350,
Namespace: part80,
Model: part80,
Tag: part80,
},
},
{
in: "@digest",
want: Name{
RawDigest: "digest",
},
wantValidDigest: false,
},
{
in: "model@sha256:123",
want: Name{
Model: "model",
RawDigest: "sha256:123",
},
wantValidDigest: true,
},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.IsComplete(); g != tt.complete {
t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
}
if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
}
})
}
// Complete uses Parts which returns a slice, but it should be
// inlined when used in Complete, preventing any allocations or
// escaping to the heap.
allocs := testing.AllocsPerRun(1000, func() {
keep(ParseName("complete.com/x/mistral:latest+Q4_0", FillNothing).IsComplete())
})
if allocs > 0 {
t.Errorf("Complete allocs = %v; want 0", allocs)
}
}
func TestNameLogValue(t *testing.T) {
cases := []string{
"example.com/library/mistral:latest+Q4_0",
"mistral:latest",
"mistral:7b+Q4_0",
}
for _, s := range cases {
t.Run(s, func(t *testing.T) {
var b bytes.Buffer
log := slog.New(slog.NewTextHandler(&b, nil))
name := ParseName(s, FillNothing)
log.Info("", "name", name)
want := fmt.Sprintf("name=%s", name.GoString())
got := b.String()
if !strings.Contains(got, want) {
t.Errorf("expected log output to contain %q; got %q", want, got)
got := ParseNameBare(tt.in)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseName(%q) = %v; want %v", tt.in, got, tt.want)
}
})
}
}
func TestNameGoString(t *testing.T) {
var testCases = map[string]bool{ // name -> valid
"host/namespace/model:tag": true,
"host/namespace/model": false,
"namespace/model": false,
"model": false,
"@sha256-1000000000000000000000000000000000000000000000000000000000000000": false,
"model@sha256-1000000000000000000000000000000000000000000000000000000000000000": false,
"model@sha256:1000000000000000000000000000000000000000000000000000000000000000": false,
// long (but valid)
part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
"h/nn/mm:t@sha256-1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
"h/nn/mm:t@sha256:1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
"m": false, // model too short
"n/mm:": false, // namespace too short
"h/n/mm:t": false, // namespace too short
"@t": false, // digest too short
"mm@d": false, // digest too short
// invalids
"^": false,
"mm:": false,
"/nn/mm": false,
"//": false,
"//mm": false,
"hh//": false,
"//mm:@": false,
"00@": false,
"@": false,
// not starting with alphanum
"-hh/nn/mm:tt@dd": false,
"hh/-nn/mm:tt@dd": false,
"hh/nn/-mm:tt@dd": false,
"hh/nn/mm:-tt@dd": false,
"hh/nn/mm:tt@-dd": false,
"": false,
// hosts
"host:https/namespace/model:tag": true,
// colon in non-host part before tag
"host/name:space/model:tag": false,
}
func TestNameparseNameDefault(t *testing.T) {
const name = "xx"
n := ParseName(name)
got := n.String()
want := "registry.ollama.ai/library/xx:latest"
if got != want {
t.Errorf("parseName(%q).String() = %q; want %q", name, got, want)
}
}
func TestNameIsValid(t *testing.T) {
var numStringTests int
for s, want := range testCases {
n := ParseNameBare(s)
t.Logf("n: %#v", n)
got := n.IsValid()
if got != want {
t.Errorf("parseName(%q).IsValid() = %v; want %v", s, got, want)
}
// Test roundtrip with String
if got {
got := ParseNameBare(s).String()
if got != s {
t.Errorf("parseName(%q).String() = %q; want %q", s, got, s)
}
numStringTests++
}
}
if numStringTests == 0 {
t.Errorf("no tests for Name.String")
}
}
func TestNameIsValidPart(t *testing.T) {
cases := []struct {
name string
in string
wantString string
wantGoString string // default is tt.in
kind partKind
s string
want bool
}{
{
name: "Complete Name",
in: "example.com/library/mistral:latest+Q4_0",
wantGoString: "example.com/library/mistral:latest+Q4_0@?",
},
{
name: "Short Name",
in: "mistral:latest",
wantGoString: "?/?/mistral:latest+?@?",
},
{
name: "Long Name",
in: "library/mistral:latest",
wantGoString: "?/library/mistral:latest+?@?",
},
{
name: "Case Preserved",
in: "Library/Mistral:Latest",
wantGoString: "?/Library/Mistral:Latest+?@?",
},
{
name: "With digest",
in: "Library/Mistral:Latest@sha256-123456",
wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
},
{kind: kindHost, s: "", want: false},
{kind: kindHost, s: "a", want: true},
{kind: kindHost, s: "a.", want: true},
{kind: kindHost, s: "a.b", want: true},
{kind: kindHost, s: "a:123", want: true},
{kind: kindHost, s: "a:123/aa/bb", want: false},
{kind: kindNamespace, s: "bb", want: true},
{kind: kindNamespace, s: "a.", want: false},
{kind: kindModel, s: "-h", want: false},
{kind: kindDigest, s: "sha256-1000000000000000000000000000000000000000000000000000000000000000", want: true},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
p := ParseName(tt.in, FillNothing)
tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
t.Run(tt.s, func(t *testing.T) {
got := isValidPart(tt.kind, tt.s)
if got != tt.want {
t.Errorf("isValidPart(%s, %q) = %v; want %v", tt.kind, tt.s, got, tt.want)
}
})
}
}
func TestDisplayShortest(t *testing.T) {
cases := []struct {
in string
mask string
want string
wantPanic bool
}{
{"example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
{"example.com/library/mistral:latest+Q4_0", "example.com/_/_:latest", "library/mistral", false},
{"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
{"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
// case-insensitive
{"Example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
{"example.com/Library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
{"example.com/library/Mistral:latest+Q4_0", "example.com/library/_:latest", "Mistral", false},
{"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false},
{"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false},
// zero value
{"", MaskDefault, "", true},
// invalid mask
{"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true},
// DefaultMask
{"registry.ollama.ai/library/mistral:latest+Q4_0", MaskDefault, "mistral", false},
// Auto-Fill
{"x", "example.com/library/_:latest", "x", false},
{"x", "example.com/library/_:latest+Q4_0", "x", false},
{"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
{"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
func FuzzName(f *testing.F) {
for s := range testCases {
f.Add(s)
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
defer func() {
if tt.wantPanic {
if recover() == nil {
t.Errorf("expected panic")
}
f.Fuzz(func(t *testing.T, s string) {
n := ParseNameBare(s)
if n.IsValid() {
parts := [...]string{n.Host, n.Namespace, n.Model, n.Tag, n.RawDigest}
for _, part := range parts {
if part == ".." {
t.Errorf("unexpected .. as valid part")
}
if len(part) > 350 {
t.Errorf("part too long: %q", part)
}
}()
p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.DisplayShortest(tt.mask); g != tt.want {
t.Errorf("got = %q; want %q", g, tt.want)
}
})
}
}
func TestParseNameAllocs(t *testing.T) {
allocs := testing.AllocsPerRun(1000, func() {
keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
})
if allocs > 0 {
t.Errorf("ParseName allocs = %v; want 0", allocs)
}
}
func BenchmarkParseName(b *testing.B) {
b.ReportAllocs()
for range b.N {
keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
}
}
func FuzzParseNameFromFilepath(f *testing.F) {
f.Add("example.com/library/mistral/7b/Q4_0")
f.Add("example.com/../mistral/7b/Q4_0")
f.Add("example.com/x/../7b/Q4_0")
f.Add("example.com/x/../7b")
f.Fuzz(func(t *testing.T, s string) {
name := ParseNameFromFilepath(s, FillNothing)
if strings.Contains(s, "..") && !name.IsZero() {
t.Fatalf("non-zero value for path with '..': %q", s)
}
if name.IsValid() == name.IsZero() {
t.Errorf("expected valid path to be non-zero value; got %#v", name)
if n.String() != s {
t.Errorf("String() = %q; want %q", n.String(), s)
}
}
})
}
func FuzzParseName(f *testing.F) {
f.Add("example.com/mistral:7b+Q4_0")
f.Add("example.com/mistral:7b+q4_0")
f.Add("example.com/mistral:7b+x")
f.Add("x/y/z:8n+I")
f.Add(":x")
f.Add("@sha256-123456")
f.Add("example.com/mistral:latest+Q4_0@sha256-123456")
f.Add(":@!@")
f.Add("...")
f.Fuzz(func(t *testing.T, s string) {
r0 := ParseName(s, FillNothing)
if strings.Contains(s, "..") && !r0.IsZero() {
t.Fatalf("non-zero value for path with '..': %q", s)
}
if !r0.IsValid() && !r0.IsResolved() {
if !r0.EqualFold(Name{}) {
t.Errorf("expected invalid path to be zero value; got %#v", r0)
}
t.Skipf("invalid path: %q", s)
}
for _, p := range r0.parts {
if len(p) > MaxNamePartLen {
t.Errorf("part too long: %q", p)
}
}
if !strings.EqualFold(r0.DisplayLong(), s) {
t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.DisplayLong(), s)
}
r1 := ParseName(r0.DisplayLong(), FillNothing)
if !r0.EqualFold(r1) {
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
}
})
}
func TestNameStringAllocs(t *testing.T) {
name := ParseName("example.com/ns/mistral:latest+Q4_0", FillNothing)
allocs := testing.AllocsPerRun(1000, func() {
keep(name.DisplayLong())
})
if allocs > 1 {
t.Errorf("String allocs = %v; want 0", allocs)
}
}
func TestNamePath(t *testing.T) {
cases := []struct {
in string
want string
}{
{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest"},
// incomplete
{"example.com/library/mistral:latest", "example.com/library/mistral:latest"},
{"", ""},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.URLPath(); g != tt.want {
t.Errorf("got = %q; want %q", g, tt.want)
}
})
}
}
func TestNameFromFilepath(t *testing.T) {
cases := []struct {
in string
want string
}{
{
in: "example.com/library/mistral:latest+Q4_0",
want: "example.com/library/mistral/latest/Q4_0",
},
{
in: "Example.Com/Library/Mistral:Latest+Q4_0",
want: "example.com/library/mistral/latest/Q4_0",
},
{
in: "Example.Com/Library/Mistral:Latest+Q4_0",
want: "example.com/library/mistral/latest/Q4_0",
},
{
in: "example.com/library/mistral:latest",
want: "example.com/library/mistral/latest",
},
{
in: "",
want: "",
},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p)
g := p.Filepath()
g = filepath.ToSlash(g)
if g != tt.want {
t.Errorf("got = %q; want %q", g, tt.want)
}
})
}
}
func TestParseNameFilepath(t *testing.T) {
cases := []struct {
in string
fill string // default is FillNothing
want string
}{
{
in: "example.com/library/mistral/latest/Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "example.com/library/mistral/latest",
fill: "?/?/?:latest+Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "example.com/library/mistral",
fill: "?/?/?:latest+Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "example.com/library",
want: "",
},
{
in: "example.com/",
want: "",
},
{
in: "example.com/^/mistral/latest/Q4_0",
want: "",
},
{
in: "example.com/library/mistral/../Q4_0",
want: "",
},
{
in: "example.com/library/mistral/latest/Q4_0/extra",
want: "",
},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
in := strings.ReplaceAll(tt.in, "/", string(filepath.Separator))
fill := cmp.Or(tt.fill, FillNothing)
want := ParseName(tt.want, fill)
if g := ParseNameFromFilepath(in, fill); !g.EqualFold(want) {
t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want)
}
})
}
}
func TestParseNameFromPath(t *testing.T) {
cases := []struct {
in string
want string
fill string // default is FillNothing
}{
{
in: "example.com/library/mistral:latest+Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "/example.com/library/mistral:latest+Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "/example.com/library/mistral",
want: "example.com/library/mistral",
},
{
in: "/example.com/library/mistral",
fill: "?/?/?:latest+Q4_0",
want: "example.com/library/mistral:latest+Q4_0",
},
{
in: "/example.com/library",
want: "",
},
{
in: "/example.com/",
want: "",
},
{
in: "/example.com/^/mistral/latest",
want: "",
},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
fill := cmp.Or(tt.fill, FillNothing)
if g := ParseNameFromURLPath(tt.in, fill); g.DisplayLong() != tt.want {
t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want)
}
})
}
}
func ExampleName_MapHash() {
m := map[uint64]bool{}
// key 1
m[ParseName("mistral:latest+q4", FillNothing).MapHash()] = true
m[ParseName("miSTRal:latest+Q4", FillNothing).MapHash()] = true
m[ParseName("mistral:LATest+Q4", FillNothing).MapHash()] = true
// key 2
m[ParseName("mistral:LATest", FillNothing).MapHash()] = true
fmt.Println(len(m))
// Output:
// 2
}
func ExampleName_CompareFold_sort() {
names := []Name{
ParseName("mistral:latest", FillNothing),
ParseName("mistRal:7b+q4", FillNothing),
ParseName("MIstral:7b", FillNothing),
}
slices.SortFunc(names, Name.CompareFold)
for _, n := range names {
fmt.Println(n.DisplayLong())
}
// Output:
// MIstral:7b
// mistRal:7b+q4
// mistral:latest
}
func ExampleName_completeAndResolved() {
for _, s := range []string{
"x/y/z:latest+q4_0@sha123-1",
"x/y/z:latest+q4_0",
"@sha123-1",
} {
name := ParseName(s, FillNothing)
fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
}
// Output:
// complete:true resolved:true digest:sha123-1
// complete:true resolved:false digest:
// complete:false resolved:true digest:sha123-1
}
func ExampleName_DisplayShortest() {
name := ParseName("example.com/jmorganca/mistral:latest+Q4_0", FillNothing)
fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest"))
fmt.Println(name.DisplayShortest("example.com/_/_:latest"))
fmt.Println(name.DisplayShortest("example.com/_/_:_"))
fmt.Println(name.DisplayShortest("_/_/_:_"))
// Default
name = ParseName("registry.ollama.ai/library/mistral:latest+Q4_0", FillNothing)
fmt.Println(name.DisplayShortest(""))
// Output:
// mistral
// jmorganca/mistral
// jmorganca/mistral:latest
// example.com/jmorganca/mistral:latest
// mistral
}
func keep[T any](v T) T { return v }

View File

@@ -1,2 +1,2 @@
go test fuzz v1
string("/0")
string("00@")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("0//0")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("0 /0")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("+0/00000")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string(":")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91")