mirror of
https://github.com/ollama/ollama.git
synced 2025-12-27 01:30:39 -05:00
Compare commits
254 Commits
modelfile-
...
native
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
201a987ff9 | ||
|
|
2d8125042a | ||
|
|
776e7bb5e4 | ||
|
|
b8d7ca1a7b | ||
|
|
2bed62926e | ||
|
|
aad8d128a0 | ||
|
|
ec1acbb867 | ||
|
|
e4859c4563 | ||
|
|
8e30eb26bd | ||
|
|
0b5c589ca2 | ||
|
|
65fadddc85 | ||
|
|
ed5fb088c4 | ||
|
|
f81f308118 | ||
|
|
b1390a7b37 | ||
|
|
11d83386a5 | ||
|
|
bb31def011 | ||
|
|
41e03ede95 | ||
|
|
7fea1ecdf6 | ||
|
|
054894271d | ||
|
|
6fef042f0b | ||
|
|
5c0c2d1d09 | ||
|
|
37f9c8ad99 | ||
|
|
2a80f55e2a | ||
|
|
421c878a2d | ||
|
|
36666c2142 | ||
|
|
85801317d1 | ||
|
|
2ed0d65948 | ||
|
|
d459dc4ad1 | ||
|
|
40bc4622ef | ||
|
|
c0f818a07a | ||
|
|
8671fdeda6 | ||
|
|
2619850fb4 | ||
|
|
8feb97dc0d | ||
|
|
4e1ff6dcbb | ||
|
|
8589d752ac | ||
|
|
de4ded68b0 | ||
|
|
9b5a3c5991 | ||
|
|
00b0699c75 | ||
|
|
993cf8bf55 | ||
|
|
7bb7cb8a60 | ||
|
|
b123be5b71 | ||
|
|
ddf5c09a9b | ||
|
|
5f73c08729 | ||
|
|
f503a848c2 | ||
|
|
36a6daccab | ||
|
|
ceb0e26e5e | ||
|
|
284e02bed0 | ||
|
|
3450a57d4a | ||
|
|
592dae31c8 | ||
|
|
2010cbc5fa | ||
|
|
ac0801eced | ||
|
|
ad66e5b060 | ||
|
|
ade4b55520 | ||
|
|
a6d62e0617 | ||
|
|
6e76348df7 | ||
|
|
0d6687f84c | ||
|
|
74d2a9ef9a | ||
|
|
14476d48cc | ||
|
|
ce8ce82567 | ||
|
|
4dc4f1be34 | ||
|
|
16b52331a4 | ||
|
|
5445aaa94e | ||
|
|
2ac3dd6853 | ||
|
|
d8851cb7a0 | ||
|
|
058f6cd2cc | ||
|
|
790cf34d17 | ||
|
|
928d844896 | ||
|
|
939d6a8606 | ||
|
|
58888a74bc | ||
|
|
cc5a71e0e3 | ||
|
|
e83bcf7f9a | ||
|
|
5690e5ce99 | ||
|
|
f2ea8470e5 | ||
|
|
34b9db5afc | ||
|
|
8711d03df7 | ||
|
|
ee448deaba | ||
|
|
6e8db04716 | ||
|
|
658e60cf73 | ||
|
|
4c78f028f8 | ||
|
|
435cc866a3 | ||
|
|
c7d3a558f6 | ||
|
|
089cdb2877 | ||
|
|
ea1e9aa36b | ||
|
|
d0d28ef90d | ||
|
|
6654186a7c | ||
|
|
aa72281eae | ||
|
|
74bcbf828f | ||
|
|
fe39147e64 | ||
|
|
fad00a85e5 | ||
|
|
9c0db4cc83 | ||
|
|
62be2050dd | ||
|
|
56f8aa6912 | ||
|
|
e6f9bfc0e8 | ||
|
|
6f18297b3a | ||
|
|
15016413de | ||
|
|
440b7190ed | ||
|
|
8d1995c625 | ||
|
|
fd01fbf038 | ||
|
|
0408205c1c | ||
|
|
63a7edd771 | ||
|
|
554ffdcce3 | ||
|
|
9850a4ce08 | ||
|
|
3934c15895 | ||
|
|
fd048f1367 | ||
|
|
8645076a71 | ||
|
|
05e9424824 | ||
|
|
52ebe67a98 | ||
|
|
889b31ab78 | ||
|
|
3cf483fe48 | ||
|
|
8dca03173d | ||
|
|
85bdf14b56 | ||
|
|
d524e5ef5e | ||
|
|
52f5370c48 | ||
|
|
da8a0c7657 | ||
|
|
1b42b4b59a | ||
|
|
7c000ec3ed | ||
|
|
c8afe7168c | ||
|
|
28d3cd0148 | ||
|
|
eb5554232a | ||
|
|
ea4c284a48 | ||
|
|
2bdc320216 | ||
|
|
32561aed09 | ||
|
|
71548d9829 | ||
|
|
8aec92fa6d | ||
|
|
a8b9b930b4 | ||
|
|
9755cf9173 | ||
|
|
70261b9bb6 | ||
|
|
9df6c85c3a | ||
|
|
e74163af4c | ||
|
|
fb9580df85 | ||
|
|
26df674785 | ||
|
|
7c9792a6e0 | ||
|
|
7afb2e125a | ||
|
|
41a272de9f | ||
|
|
f335722275 | ||
|
|
6d53b67c2c | ||
|
|
969238b19e | ||
|
|
949d7832cf | ||
|
|
99d227c9db | ||
|
|
a27e419b47 | ||
|
|
e4d0db5a97 | ||
|
|
ba460802c2 | ||
|
|
e54a3c7fcd | ||
|
|
9f8691c6c8 | ||
|
|
a0b8a32eb4 | ||
|
|
7027f264fb | ||
|
|
9bee3b63b1 | ||
|
|
309aef7fee | ||
|
|
08655170aa | ||
|
|
2b341069a7 | ||
|
|
c00fee6936 | ||
|
|
c2d813bdc3 | ||
|
|
786f3a1c44 | ||
|
|
3397eff0cd | ||
|
|
0efb7931c7 | ||
|
|
42f2cc408e | ||
|
|
9446b795b5 | ||
|
|
62f8cda3b3 | ||
|
|
6a1de23175 | ||
|
|
a7b431e743 | ||
|
|
5a25f93522 | ||
|
|
7e33a017c0 | ||
|
|
8b2c10061c | ||
|
|
c5c451ca3b | ||
|
|
2b4ca6cf36 | ||
|
|
ad90b9ab3d | ||
|
|
4340f8eba4 | ||
|
|
4c7db6b7e9 | ||
|
|
c03f0e3c3d | ||
|
|
c5ff443b9f | ||
|
|
01114b4526 | ||
|
|
1524f323a3 | ||
|
|
fccf3eecaa | ||
|
|
c77d45d836 | ||
|
|
5ec12cec6c | ||
|
|
d9578d2bad | ||
|
|
cb8352d6b4 | ||
|
|
fc6558f47f | ||
|
|
9502e5661f | ||
|
|
e1c9a2a00f | ||
|
|
1341ee1b56 | ||
|
|
63efa075a0 | ||
|
|
cb03fc9571 | ||
|
|
a5ec9cfc0f | ||
|
|
be517e491c | ||
|
|
fc8e108642 | ||
|
|
c5d5c4a96c | ||
|
|
dfe330fa1c | ||
|
|
01f77ae25d | ||
|
|
483b81a863 | ||
|
|
36bd967722 | ||
|
|
b0e7d35db8 | ||
|
|
aeb1fb5192 | ||
|
|
a2e60ebcaf | ||
|
|
883ec4d1ef | ||
|
|
4de0126719 | ||
|
|
9768e2dc75 | ||
|
|
08600d5bec | ||
|
|
a624e672d2 | ||
|
|
e4a7e5b2ca | ||
|
|
a0a15cfd5b | ||
|
|
12e923e158 | ||
|
|
cd135317d2 | ||
|
|
4f895d633f | ||
|
|
7d05a6ee8f | ||
|
|
464d817824 | ||
|
|
531324a9be | ||
|
|
6589eb8a8c | ||
|
|
90f071c658 | ||
|
|
a039e383cd | ||
|
|
80163ebcb5 | ||
|
|
a57818d93e | ||
|
|
841adda157 | ||
|
|
0035e31af8 | ||
|
|
c863c6a96d | ||
|
|
1f11b52511 | ||
|
|
526d4eb204 | ||
|
|
0a74cb31d5 | ||
|
|
10ed1b6292 | ||
|
|
4fec5816d6 | ||
|
|
0a0e9f3e0f | ||
|
|
58d95cc9bd | ||
|
|
3b6a9154dd | ||
|
|
d6dd2ff839 | ||
|
|
e57a6ba89f | ||
|
|
12ec2346ef | ||
|
|
1ec0df1069 | ||
|
|
91b3e4d282 | ||
|
|
d338d70492 | ||
|
|
011bb67351 | ||
|
|
d124627202 | ||
|
|
b0a8246a69 | ||
|
|
e6fb39c182 | ||
|
|
e1f1c374ea | ||
|
|
06a1508bfe | ||
|
|
5a5efee46b | ||
|
|
97ae517fbf | ||
|
|
44b813e459 | ||
|
|
539043f5e0 | ||
|
|
dbcace6847 | ||
|
|
c91a4ebcff | ||
|
|
b79c7e4528 | ||
|
|
035b274b70 | ||
|
|
9c6a254945 | ||
|
|
f31f2bedf4 | ||
|
|
756c257553 | ||
|
|
5255d0af8a | ||
|
|
af8a8a6b59 | ||
|
|
461ad25015 | ||
|
|
8838ae787d | ||
|
|
db75402ade | ||
|
|
1e85a140a3 | ||
|
|
c363282fdc | ||
|
|
5b0c48d29e |
60
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
Normal file
60
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
name: Bug report
|
||||
labels: [bug]
|
||||
description: Something isn't working right.
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: What is the issue?
|
||||
description: What happened? What did you expect to happen?
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
id: os
|
||||
attributes:
|
||||
label: OS
|
||||
description: Which operating system are you using?
|
||||
multiple: true
|
||||
options:
|
||||
- Linux
|
||||
- macOS
|
||||
- Windows
|
||||
- Docker
|
||||
- WSL2
|
||||
validations:
|
||||
required: false
|
||||
- type: dropdown
|
||||
id: gpu
|
||||
attributes:
|
||||
label: GPU
|
||||
description: Which GPU are you using?
|
||||
multiple: true
|
||||
options:
|
||||
- Nvidia
|
||||
- AMD
|
||||
- Intel
|
||||
- Apple
|
||||
- Other
|
||||
validations:
|
||||
required: false
|
||||
- type: dropdown
|
||||
id: cpu
|
||||
attributes:
|
||||
label: CPU
|
||||
description: Which CPU are you using?
|
||||
multiple: true
|
||||
options:
|
||||
- Intel
|
||||
- AMD
|
||||
- Apple
|
||||
- Other
|
||||
validations:
|
||||
required: false
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: Ollama version
|
||||
description: What version of Ollama are you using? (`ollama --version`)
|
||||
placeholder: e.g., 0.1.32
|
||||
validations:
|
||||
required: false
|
||||
18
.github/ISSUE_TEMPLATE/10_model_request.yml
vendored
18
.github/ISSUE_TEMPLATE/10_model_request.yml
vendored
@@ -1,18 +0,0 @@
|
||||
name: Model request
|
||||
description: Request a new model for the library
|
||||
labels: [mr]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Please check if your Model request is [already available](https://ollama.com/search) or that you cannot [import it](https://github.com/ollama/ollama/blob/main/docs/import.md#import-a-model) yourself.
|
||||
Tell us about which Model you'd like to see in the library!
|
||||
- type: textarea
|
||||
id: problem
|
||||
attributes:
|
||||
label: What model would you like?
|
||||
description: Please provide a link to the model.
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for filing a model request!
|
||||
6
.github/ISSUE_TEMPLATE/20_feature_request.md
vendored
Normal file
6
.github/ISSUE_TEMPLATE/20_feature_request.md
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Request a new feature
|
||||
labels: feature request
|
||||
---
|
||||
|
||||
41
.github/ISSUE_TEMPLATE/20_feature_request.yml
vendored
41
.github/ISSUE_TEMPLATE/20_feature_request.yml
vendored
@@ -1,41 +0,0 @@
|
||||
name: Feature request
|
||||
description: Propose a new feature
|
||||
labels: [needs-triage, fr]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Please check if your feature request is [already filed](https://github.com/ollama/ollama/issues).
|
||||
Tell us about your idea!
|
||||
- type: textarea
|
||||
id: problem
|
||||
attributes:
|
||||
label: What are you trying to do?
|
||||
description: Tell us about the problem you're trying to solve.
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: solution
|
||||
attributes:
|
||||
label: How should we solve this?
|
||||
description: If you have an idea of how you'd like to see this feature work, let us know.
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: alternative
|
||||
attributes:
|
||||
label: What is the impact of not solving this?
|
||||
description: (How) Are you currently working around the issue?
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: context
|
||||
attributes:
|
||||
label: Anything else?
|
||||
description: Any additional context to share, e.g., links
|
||||
validations:
|
||||
required: false
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for filing a feature request!
|
||||
5
.github/ISSUE_TEMPLATE/30_model_request.md
vendored
Normal file
5
.github/ISSUE_TEMPLATE/30_model_request.md
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
---
|
||||
name: Model request
|
||||
about: Request support for a new model to be added to Ollama
|
||||
labels: model request
|
||||
---
|
||||
125
.github/ISSUE_TEMPLATE/90_bug_report.yml
vendored
125
.github/ISSUE_TEMPLATE/90_bug_report.yml
vendored
@@ -1,125 +0,0 @@
|
||||
name: Bug report
|
||||
description: File a bug report. If you need help, please join our Discord server.
|
||||
labels: [needs-triage, bug]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Please check if your bug is [already filed](https://github.com/ollama/ollama/issues) before filing a new one.
|
||||
- type: textarea
|
||||
id: what-happened
|
||||
attributes:
|
||||
label: What is the issue?
|
||||
description: What happened? What did you expect to happen?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: what-was-expected
|
||||
attributes:
|
||||
label: What did you expect to see?
|
||||
description: What did you expect to see/happen instead?
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
label: Steps to reproduce
|
||||
description: What are the steps you took that hit this issue?
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: changes
|
||||
attributes:
|
||||
label: Are there any recent changes that introduced the issue?
|
||||
description: If so, what are those changes?
|
||||
validations:
|
||||
required: false
|
||||
- type: dropdown
|
||||
id: os
|
||||
attributes:
|
||||
label: OS
|
||||
description: What OS are you using? You may select more than one.
|
||||
multiple: true
|
||||
options:
|
||||
- Linux
|
||||
- macOS
|
||||
- Windows
|
||||
- Other
|
||||
validations:
|
||||
required: false
|
||||
- type: dropdown
|
||||
id: architecture
|
||||
attributes:
|
||||
label: Architecture
|
||||
description: What architecture are you using? You may select more than one.
|
||||
multiple: true
|
||||
options:
|
||||
- arm64
|
||||
- amd64
|
||||
- x86
|
||||
- Other
|
||||
- type: dropdown
|
||||
id: platform
|
||||
attributes:
|
||||
label: Platform
|
||||
description: What platform are you using? You may select more than one.
|
||||
multiple: true
|
||||
options:
|
||||
- Docker
|
||||
- WSL
|
||||
- WSL2
|
||||
validations:
|
||||
required: false
|
||||
- type: input
|
||||
id: ollama-version
|
||||
attributes:
|
||||
label: Ollama version
|
||||
description: What Ollama version are you using? (`ollama --version`)
|
||||
placeholder: e.g., 1.14.4
|
||||
validations:
|
||||
required: false
|
||||
- type: dropdown
|
||||
id: gpu
|
||||
attributes:
|
||||
label: GPU
|
||||
description: What GPU, if any, are you using? You may select more than one.
|
||||
multiple: true
|
||||
options:
|
||||
- Nvidia
|
||||
- AMD
|
||||
- Intel
|
||||
- Apple
|
||||
- Other
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: gpu-info
|
||||
attributes:
|
||||
label: GPU info
|
||||
description: What GPU info do you have? (`nvidia-smi`, `rocminfo`, `system_profiler SPDisplaysDataType`, etc.)
|
||||
validations:
|
||||
required: false
|
||||
- type: dropdown
|
||||
id: cpu
|
||||
attributes:
|
||||
label: CPU
|
||||
description: What CPU are you using? You may select more than one.
|
||||
multiple: true
|
||||
options:
|
||||
- Intel
|
||||
- AMD
|
||||
- Apple
|
||||
- Other
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: other-software
|
||||
attributes:
|
||||
label: Other software
|
||||
description: What other software are you using that might be related to this issue?
|
||||
validations:
|
||||
required: false
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for filing a bug report!
|
||||
24
.github/workflows/latest.yaml
vendored
Normal file
24
.github/workflows/latest.yaml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: latest
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [released]
|
||||
|
||||
jobs:
|
||||
update-latest:
|
||||
environment: release
|
||||
runs-on: linux
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ vars.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_ACCESS_TOKEN }}
|
||||
- name: Tag images as latest
|
||||
env:
|
||||
PUSH: "1"
|
||||
shell: bash
|
||||
run: |
|
||||
export "VERSION=${GITHUB_REF_NAME#v}"
|
||||
./scripts/tag_latest.sh
|
||||
100
.github/workflows/release.yaml
vendored
100
.github/workflows/release.yaml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
jobs:
|
||||
# Full build of the Mac assets
|
||||
build-darwin:
|
||||
runs-on: macos-latest
|
||||
runs-on: macos-12
|
||||
environment: release
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- name: Build Darwin
|
||||
env:
|
||||
@@ -38,9 +38,11 @@ jobs:
|
||||
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
|
||||
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
|
||||
APPLE_ID: ${{ vars.APPLE_ID }}
|
||||
SDKROOT: /Applications/Xcode_13.4.1.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
|
||||
DEVELOPER_DIR: /Applications/Xcode_13.4.1.app/Contents/Developer
|
||||
run: |
|
||||
./scripts/build_darwin.sh
|
||||
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-darwin
|
||||
@@ -48,7 +50,6 @@ jobs:
|
||||
dist/*arwin*
|
||||
!dist/*-cov
|
||||
|
||||
|
||||
# Windows builds take a long time to both install the dependencies and build, so parallelize
|
||||
# CPU generation step
|
||||
generate-windows-cpu:
|
||||
@@ -85,7 +86,7 @@ jobs:
|
||||
write-host "plugin installed"
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- run: go get ./...
|
||||
- run: |
|
||||
@@ -99,7 +100,10 @@ jobs:
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: generate-windows-cpu
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
path: |
|
||||
llm/build/**/bin/*
|
||||
llm/build/**/*.a
|
||||
dist/windows-amd64/**
|
||||
|
||||
# ROCm generation step
|
||||
generate-windows-rocm:
|
||||
@@ -136,9 +140,9 @@ jobs:
|
||||
write-host "plugin installed"
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- name: "Install ROCm"
|
||||
- name: 'Install ROCm'
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading AMD HIP Installer"
|
||||
@@ -146,7 +150,7 @@ jobs:
|
||||
write-host "Installing AMD HIP"
|
||||
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
||||
write-host "Completed AMD HIP"
|
||||
- name: "Verify ROCm"
|
||||
- name: 'Verify ROCm'
|
||||
run: |
|
||||
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
|
||||
- run: go get ./...
|
||||
@@ -160,7 +164,7 @@ jobs:
|
||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
go generate -x ./...
|
||||
name: go generate
|
||||
- name: "gather rocm dependencies"
|
||||
- name: 'gather rocm dependencies'
|
||||
run: |
|
||||
$HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
md "dist\deps\bin\rocblas\library"
|
||||
@@ -170,7 +174,9 @@ jobs:
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: generate-windows-rocm
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
path: |
|
||||
llm/build/**/bin/*
|
||||
dist/windows-amd64/**
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-rocm-deps
|
||||
@@ -211,29 +217,36 @@ jobs:
|
||||
write-host "plugin installed"
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
# TODO - consider replacing this action with a ps1 snippet to install
|
||||
# This actions seems to fail sometimes with "no tools in cache" but a re-run of the failed job clears it
|
||||
# https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||
- name: "Install CUDA"
|
||||
uses: Jimver/cuda-toolkit@v0.2.14
|
||||
id: cuda-toolkit
|
||||
with:
|
||||
cuda: '11.3.1'
|
||||
- name: "Verify CUDA"
|
||||
- name: 'Install CUDA'
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading CUDA Installer"
|
||||
Invoke-WebRequest -Uri "https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe" -OutFile "${env:RUNNER_TEMP}\cuda-install.exe"
|
||||
write-host "Installing CUDA"
|
||||
Start-Process "${env:RUNNER_TEMP}\cuda-install.exe" -ArgumentList '-s' -NoNewWindow -Wait
|
||||
write-host "Completed CUDA"
|
||||
$cudaPath=((resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path)
|
||||
$cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2'
|
||||
echo "$cudaPath\bin" >> $env:GITHUB_PATH
|
||||
echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
|
||||
echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
|
||||
echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
|
||||
- name: 'Verify CUDA'
|
||||
run: nvcc -V
|
||||
- run: go get ./...
|
||||
- name: go generate
|
||||
run: |
|
||||
$gopath=(get-command go).source | split-path -parent
|
||||
$cudabin=(get-command nvcc).source | split-path
|
||||
& "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
|
||||
cd $env:GITHUB_WORKSPACE
|
||||
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
||||
$env:PATH="$gopath;$env:PATH"
|
||||
$env:PATH="$gopath;$cudabin;$env:PATH"
|
||||
$env:OLLAMA_SKIP_CPU_GENERATE="1"
|
||||
go generate -x ./...
|
||||
- name: "gather cuda dependencies"
|
||||
- name: 'gather cuda dependencies'
|
||||
run: |
|
||||
$NVIDIA_DIR=(resolve-path 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*\bin\')[0]
|
||||
md "dist\deps"
|
||||
@@ -243,7 +256,9 @@ jobs:
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: generate-windows-cuda
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
path: |
|
||||
llm/build/**/bin/*
|
||||
dist/windows-amd64/**
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-cuda-deps
|
||||
@@ -290,30 +305,25 @@ jobs:
|
||||
write-host "plugin installed"
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- run: go get
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: generate-windows-cpu
|
||||
path: llm/llama.cpp/build
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: generate-windows-cuda
|
||||
path: llm/llama.cpp/build
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: windows-cuda-deps
|
||||
path: dist/deps
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: windows-rocm-deps
|
||||
path: dist/deps
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: generate-windows-rocm
|
||||
path: llm/llama.cpp/build
|
||||
- run: dir llm/llama.cpp/build
|
||||
- run: dir llm/build
|
||||
- run: |
|
||||
$gopath=(get-command go).source | split-path -parent
|
||||
& "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
|
||||
@@ -321,22 +331,22 @@ jobs:
|
||||
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
||||
$env:PATH="$gopath;$env:PATH"
|
||||
$env:OLLAMA_SKIP_GENERATE="1"
|
||||
$env:NVIDIA_DIR=$(resolve-path ".\dist\deps")
|
||||
$env:HIP_PATH=$(resolve-path ".\dist\deps")
|
||||
& .\scripts\build_windows.ps1
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
path: dist/*.exe
|
||||
path: |
|
||||
dist/OllamaSetup.exe
|
||||
dist/ollama-windows-*.zip
|
||||
|
||||
# Linux x86 assets built using the container based build
|
||||
# Linux x86 assets built using the container based build
|
||||
build-linux-amd64:
|
||||
environment: release
|
||||
runs-on: linux
|
||||
env:
|
||||
OLLAMA_SKIP_MANIFEST_CREATE: "1"
|
||||
OLLAMA_SKIP_MANIFEST_CREATE: '1'
|
||||
BUILD_ARCH: amd64
|
||||
PUSH: "1"
|
||||
PUSH: '1'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -366,9 +376,9 @@ jobs:
|
||||
environment: release
|
||||
runs-on: linux-arm64
|
||||
env:
|
||||
OLLAMA_SKIP_MANIFEST_CREATE: "1"
|
||||
OLLAMA_SKIP_MANIFEST_CREATE: '1'
|
||||
BUILD_ARCH: arm64
|
||||
PUSH: "1"
|
||||
PUSH: '1'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -376,7 +386,7 @@ jobs:
|
||||
- name: Set Version
|
||||
shell: bash
|
||||
run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
|
||||
- name: "Install Docker"
|
||||
- name: 'Install Docker'
|
||||
run: |
|
||||
# Add Docker's official GPG key:
|
||||
env
|
||||
@@ -413,7 +423,7 @@ jobs:
|
||||
!dist/*-cov
|
||||
|
||||
# Aggregate all the assets and ship a release
|
||||
release:
|
||||
release:
|
||||
needs:
|
||||
- build-darwin
|
||||
- build-windows
|
||||
@@ -424,8 +434,8 @@ jobs:
|
||||
permissions:
|
||||
contents: write
|
||||
env:
|
||||
OLLAMA_SKIP_IMAGE_BUILD: "1"
|
||||
PUSH: "1"
|
||||
OLLAMA_SKIP_IMAGE_BUILD: '1'
|
||||
PUSH: '1'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set Version
|
||||
@@ -453,11 +463,11 @@ jobs:
|
||||
with:
|
||||
name: ${{ env.RELEASE_VERSION }}
|
||||
allowUpdates: true
|
||||
artifacts: "dist/*"
|
||||
artifacts: 'dist/*'
|
||||
draft: true
|
||||
prerelease: true
|
||||
omitBodyDuringUpdate: true
|
||||
generateReleaseNotes: true
|
||||
omitDraftDuringUpdate: true
|
||||
omitPrereleaseDuringUpdate: true
|
||||
replacesArtifacts: true
|
||||
replacesArtifacts: true
|
||||
|
||||
200
.github/workflows/test.yaml
vendored
200
.github/workflows/test.yaml
vendored
@@ -1,15 +1,51 @@
|
||||
name: test
|
||||
|
||||
concurrency:
|
||||
# For PRs, later CI runs preempt previous ones. e.g. a force push on a PR
|
||||
# cancels running CI jobs and starts all new ones.
|
||||
#
|
||||
# For non-PR pushes, concurrency.group needs to be unique for every distinct
|
||||
# CI run we want to have happen. Use run_id, which in practice means all
|
||||
# non-PR CI runs will be allowed to run without preempting each other.
|
||||
group: ${{ github.workflow }}-$${{ github.pull_request.number || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- '**/*'
|
||||
- '!docs/**'
|
||||
- '!examples/**'
|
||||
- '!README.md'
|
||||
|
||||
jobs:
|
||||
changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
GENERATE: ${{ steps.changes.outputs.GENERATE }}
|
||||
GENERATE_CUDA: ${{ steps.changes.outputs.GENERATE_CUDA }}
|
||||
GENERATE_ROCM: ${{ steps.changes.outputs.GENERATE_ROCM }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- id: changes
|
||||
run: |
|
||||
changed() {
|
||||
git diff-tree -r --no-commit-id --name-only \
|
||||
$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
|
||||
${{ github.event.pull_request.head.sha }} \
|
||||
| xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
|
||||
}
|
||||
|
||||
{
|
||||
echo GENERATE=$(changed llm/)
|
||||
echo GENERATE_CUDA=$(changed llm/)
|
||||
echo GENERATE_ROCM=$(changed llm/)
|
||||
} >>$GITHUB_OUTPUT
|
||||
|
||||
generate:
|
||||
needs: [changes]
|
||||
if: ${{ needs.changes.outputs.GENERATE == 'True' }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-2019]
|
||||
@@ -26,26 +62,32 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- run: go get ./...
|
||||
- run: |
|
||||
$gopath=(get-command go).source | split-path -parent
|
||||
$gccpath=(get-command gcc).source | split-path -parent
|
||||
& "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
|
||||
cd $env:GITHUB_WORKSPACE
|
||||
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
||||
$env:PATH="$gopath;$env:PATH"
|
||||
$env:PATH="$gopath;$gccpath;$env:PATH"
|
||||
echo $env:PATH
|
||||
go generate -x ./...
|
||||
if: ${{ startsWith(matrix.os, 'windows-') }}
|
||||
name: "Windows Go Generate"
|
||||
name: 'Windows Go Generate'
|
||||
- run: go generate -x ./...
|
||||
if: ${{ ! startsWith(matrix.os, 'windows-') }}
|
||||
name: "Unix Go Generate"
|
||||
name: 'Unix Go Generate'
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
path: |
|
||||
llm/build/**/bin/*
|
||||
llm/build/**/*.a
|
||||
generate-cuda:
|
||||
needs: [changes]
|
||||
if: ${{ needs.changes.outputs.GENERATE_CUDA == 'True' }}
|
||||
strategy:
|
||||
matrix:
|
||||
cuda-version:
|
||||
@@ -62,7 +104,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- run: go get ./...
|
||||
- run: |
|
||||
@@ -73,12 +115,16 @@ jobs:
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: cuda-${{ matrix.cuda-version }}-libraries
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
path: |
|
||||
llm/build/**/bin/*
|
||||
dist/windows-amd64/**
|
||||
generate-rocm:
|
||||
needs: [changes]
|
||||
if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
|
||||
strategy:
|
||||
matrix:
|
||||
rocm-version:
|
||||
- '6.0'
|
||||
- '6.0.2'
|
||||
runs-on: linux
|
||||
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
|
||||
steps:
|
||||
@@ -91,7 +137,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- run: go get ./...
|
||||
- run: |
|
||||
@@ -102,7 +148,89 @@ jobs:
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: rocm-${{ matrix.rocm-version }}-libraries
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
path: |
|
||||
llm/build/**/bin/*
|
||||
dist/windows-amd64/**
|
||||
|
||||
# ROCm generation step
|
||||
generate-windows-rocm:
|
||||
needs: [changes]
|
||||
if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
|
||||
runs-on: windows
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- name: 'Install ROCm'
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading AMD HIP Installer"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
write-host "Installing AMD HIP"
|
||||
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
||||
write-host "Completed AMD HIP"
|
||||
- name: 'Verify ROCm'
|
||||
run: |
|
||||
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
|
||||
- run: go get ./...
|
||||
- run: |
|
||||
$gopath=(get-command go).source | split-path -parent
|
||||
& "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
|
||||
cd $env:GITHUB_WORKSPACE
|
||||
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
||||
$env:PATH="$gopath;$env:PATH"
|
||||
$env:OLLAMA_SKIP_CPU_GENERATE="1"
|
||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
go generate -x ./...
|
||||
name: go generate
|
||||
env:
|
||||
OLLAMA_SKIP_CPU_GENERATE: '1'
|
||||
# TODO - do we need any artifacts?
|
||||
|
||||
# CUDA generation step
|
||||
generate-windows-cuda:
|
||||
needs: [changes]
|
||||
if: ${{ needs.changes.outputs.GENERATE_CUDA == 'True' }}
|
||||
runs-on: windows
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- name: 'Install CUDA'
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading CUDA Installer"
|
||||
Invoke-WebRequest -Uri "https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe" -OutFile "${env:RUNNER_TEMP}\cuda-install.exe"
|
||||
write-host "Installing CUDA"
|
||||
Start-Process "${env:RUNNER_TEMP}\cuda-install.exe" -ArgumentList '-s' -NoNewWindow -Wait
|
||||
write-host "Completed CUDA"
|
||||
$cudaPath=((resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path)
|
||||
$cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2'
|
||||
echo "$cudaPath\bin" >> $env:GITHUB_PATH
|
||||
echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
|
||||
echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
|
||||
echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
|
||||
- name: 'Verify CUDA'
|
||||
run: nvcc -V
|
||||
- run: go get ./...
|
||||
- name: go generate
|
||||
run: |
|
||||
$gopath=(get-command go).source | split-path -parent
|
||||
$cudabin=(get-command nvcc).source | split-path
|
||||
& "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
|
||||
cd $env:GITHUB_WORKSPACE
|
||||
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
||||
$env:PATH="$gopath;$cudabin;$env:PATH"
|
||||
$env:OLLAMA_SKIP_CPU_GENERATE="1"
|
||||
go generate -x ./...
|
||||
env:
|
||||
OLLAMA_SKIP_CPU_GENERATE: '1'
|
||||
# TODO - do we need any artifacts?
|
||||
|
||||
lint:
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -125,24 +253,26 @@ jobs:
|
||||
submodules: recursive
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version-file: go.mod
|
||||
cache: false
|
||||
- run: |
|
||||
mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/
|
||||
touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so
|
||||
case ${{ matrix.arch }} in
|
||||
amd64) echo ARCH=x86_64 ;;
|
||||
arm64) echo ARCH=arm64 ;;
|
||||
esac >>$GITHUB_ENV
|
||||
shell: bash
|
||||
- run: |
|
||||
mkdir -p llm/build/linux/$ARCH/stub/bin
|
||||
touch llm/build/linux/$ARCH/stub/bin/ollama_llama_server
|
||||
if: ${{ startsWith(matrix.os, 'ubuntu-') }}
|
||||
- run: |
|
||||
mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/
|
||||
touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib
|
||||
touch llm/llama.cpp/ggml-metal.metal
|
||||
mkdir -p llm/build/darwin/$ARCH/stub/bin
|
||||
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
|
||||
if: ${{ startsWith(matrix.os, 'macos-') }}
|
||||
- run: |
|
||||
mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/
|
||||
touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll
|
||||
if: ${{ startsWith(matrix.os, 'windows-') }}
|
||||
- uses: golangci/golangci-lint-action@v3
|
||||
- uses: golangci/golangci-lint-action@v4
|
||||
with:
|
||||
args: --timeout 8m0s -v
|
||||
test:
|
||||
needs: generate
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-2019]
|
||||
@@ -156,19 +286,31 @@ jobs:
|
||||
env:
|
||||
GOARCH: ${{ matrix.arch }}
|
||||
CGO_ENABLED: '1'
|
||||
OLLAMA_CPU_TARGET: 'static'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- run: go get
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
|
||||
path: llm/llama.cpp/build
|
||||
- run: |
|
||||
case ${{ matrix.arch }} in
|
||||
amd64) echo ARCH=x86_64 ;;
|
||||
arm64) echo ARCH=arm64 ;;
|
||||
esac >>$GITHUB_ENV
|
||||
shell: bash
|
||||
- run: |
|
||||
mkdir -p llm/build/linux/$ARCH/stub/bin
|
||||
touch llm/build/linux/$ARCH/stub/bin/ollama_llama_server
|
||||
if: ${{ startsWith(matrix.os, 'ubuntu-') }}
|
||||
- run: |
|
||||
mkdir -p llm/build/darwin/$ARCH/stub/bin
|
||||
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
|
||||
if: ${{ startsWith(matrix.os, 'macos-') }}
|
||||
shell: bash
|
||||
- run: go generate ./...
|
||||
- run: go build
|
||||
- run: go test -v ./...
|
||||
- uses: actions/upload-artifact@v4
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -10,4 +10,5 @@ ggml-metal.metal
|
||||
*.exe
|
||||
.idea
|
||||
test_data
|
||||
*.crt
|
||||
*.crt
|
||||
llm/build
|
||||
|
||||
@@ -15,13 +15,3 @@ linters:
|
||||
- misspell
|
||||
- nilerr
|
||||
- unused
|
||||
linters-settings:
|
||||
errcheck:
|
||||
# exclude the following functions since we don't generally
|
||||
# need to be concerned with the returned errors
|
||||
exclude-functions:
|
||||
- encoding/binary.Read
|
||||
- (*os.File).Seek
|
||||
- (*bufio.Writer).WriteString
|
||||
- (*github.com/spf13/pflag.FlagSet).Set
|
||||
- (*github.com/ollama/ollama/llm.readSeekOffset).Seek
|
||||
|
||||
43
Dockerfile
43
Dockerfile
@@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
|
||||
ARG CMAKE_VERSION=3.22.1
|
||||
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md
|
||||
ARG CUDA_VERSION=11.3.1
|
||||
ARG ROCM_VERSION=6.0
|
||||
ARG ROCM_VERSION=6.0.2
|
||||
|
||||
# Copy the minimal context we need to run the generate scripts
|
||||
FROM scratch AS llm-code
|
||||
@@ -18,7 +18,7 @@ ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
ARG CGO_CFLAGS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
|
||||
ARG CMAKE_VERSION
|
||||
@@ -28,7 +28,7 @@ ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
ARG CGO_CFLAGS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64
|
||||
ARG CMAKE_VERSION
|
||||
@@ -40,9 +40,9 @@ COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
ARG CGO_CFLAGS
|
||||
ARG AMDGPU_TARGETS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
RUN mkdir /tmp/scratch && \
|
||||
for dep in $(cat /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/x86_64/rocm*/lib/deps.txt) ; do \
|
||||
for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
|
||||
cp ${dep} /tmp/scratch/ || exit 1 ; \
|
||||
done && \
|
||||
(cd /opt/rocm/lib && tar cf - rocblas/library) | (cd /tmp/scratch/ && tar xf - ) && \
|
||||
@@ -61,35 +61,42 @@ ARG OLLAMA_CUSTOM_CPU_DEFS
|
||||
ARG CGO_CFLAGS
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/arm64 centos:7 AS cpu-build-arm64
|
||||
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
|
||||
ARG CMAKE_VERSION
|
||||
ARG GOLANG_VERSION
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
# Note, we only build the "base" CPU variant on arm since avx/avx2 are x86 features
|
||||
ARG OLLAMA_CUSTOM_CPU_DEFS
|
||||
ARG CGO_CFLAGS
|
||||
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
|
||||
FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
|
||||
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
|
||||
FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
|
||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
|
||||
|
||||
# Intermediate stage used for ./scripts/build_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64
|
||||
ENV CGO_ENABLED 1
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY . .
|
||||
COPY --from=cpu_avx-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=cpu_avx2-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=cuda-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=static-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=cpu_avx-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=cpu_avx2-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=cuda-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/dist/deps/ ./dist/deps/
|
||||
ARG GOFLAGS
|
||||
ARG CGO_CFLAGS
|
||||
@@ -101,8 +108,8 @@ ENV CGO_ENABLED 1
|
||||
ARG GOLANG_VERSION
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY . .
|
||||
COPY --from=cuda-build-arm64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
RUN mkdir -p /go/src/github.com/ollama/ollama/dist/deps/
|
||||
COPY --from=static-build-arm64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=cuda-build-arm64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
ARG GOFLAGS
|
||||
ARG CGO_CFLAGS
|
||||
RUN go build -trimpath .
|
||||
|
||||
58
README.md
58
README.md
@@ -35,10 +35,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
||||
|
||||
## Quickstart
|
||||
|
||||
To run and chat with [Llama 2](https://ollama.com/library/llama2):
|
||||
To run and chat with [Llama 3](https://ollama.com/library/llama3):
|
||||
|
||||
```
|
||||
ollama run llama2
|
||||
ollama run llama3
|
||||
```
|
||||
|
||||
## Model library
|
||||
@@ -49,21 +49,18 @@ Here are some example models that can be downloaded:
|
||||
|
||||
| Model | Parameters | Size | Download |
|
||||
| ------------------ | ---------- | ----- | ------------------------------ |
|
||||
| Llama 2 | 7B | 3.8GB | `ollama run llama2` |
|
||||
| Llama 3 | 8B | 4.7GB | `ollama run llama3` |
|
||||
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
|
||||
| Phi-3 | 3,8B | 2.3GB | `ollama run phi3` |
|
||||
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
||||
| Dolphin Phi | 2.7B | 1.6GB | `ollama run dolphin-phi` |
|
||||
| Phi-2 | 2.7B | 1.7GB | `ollama run phi` |
|
||||
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
||||
| Starling | 7B | 4.1GB | `ollama run starling-lm` |
|
||||
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
||||
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
||||
| Llama 2 13B | 13B | 7.3GB | `ollama run llama2:13b` |
|
||||
| Llama 2 70B | 70B | 39GB | `ollama run llama2:70b` |
|
||||
| Orca Mini | 3B | 1.9GB | `ollama run orca-mini` |
|
||||
| Vicuna | 7B | 3.8GB | `ollama run vicuna` |
|
||||
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
||||
| Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
|
||||
| Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
|
||||
| Solar | 10.7B | 6.1GB | `ollama run solar` |
|
||||
|
||||
> Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
|
||||
|
||||
@@ -97,16 +94,16 @@ See the [guide](docs/import.md) on importing models for more information.
|
||||
|
||||
### Customize a prompt
|
||||
|
||||
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model:
|
||||
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3` model:
|
||||
|
||||
```
|
||||
ollama pull llama2
|
||||
ollama pull llama3
|
||||
```
|
||||
|
||||
Create a `Modelfile`:
|
||||
|
||||
```
|
||||
FROM llama2
|
||||
FROM llama3
|
||||
|
||||
# set the temperature to 1 [higher is more creative, lower is more coherent]
|
||||
PARAMETER temperature 1
|
||||
@@ -141,7 +138,7 @@ ollama create mymodel -f ./Modelfile
|
||||
### Pull a model
|
||||
|
||||
```
|
||||
ollama pull llama2
|
||||
ollama pull llama3
|
||||
```
|
||||
|
||||
> This command can also be used to update a local model. Only the diff will be pulled.
|
||||
@@ -149,13 +146,13 @@ ollama pull llama2
|
||||
### Remove a model
|
||||
|
||||
```
|
||||
ollama rm llama2
|
||||
ollama rm llama3
|
||||
```
|
||||
|
||||
### Copy a model
|
||||
|
||||
```
|
||||
ollama cp llama2 my-llama2
|
||||
ollama cp llama3 my-model
|
||||
```
|
||||
|
||||
### Multiline input
|
||||
@@ -179,7 +176,7 @@ The image features a yellow smiley face, which is likely the central focus of th
|
||||
### Pass in prompt as arguments
|
||||
|
||||
```
|
||||
$ ollama run llama2 "Summarize this file: $(cat README.md)"
|
||||
$ ollama run llama3 "Summarize this file: $(cat README.md)"
|
||||
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
|
||||
```
|
||||
|
||||
@@ -226,7 +223,7 @@ Next, start the server:
|
||||
Finally, in a separate shell, run a model:
|
||||
|
||||
```
|
||||
./ollama run llama2
|
||||
./ollama run llama3
|
||||
```
|
||||
|
||||
## REST API
|
||||
@@ -237,7 +234,7 @@ Ollama has a REST API for running and managing models.
|
||||
|
||||
```
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"prompt":"Why is the sky blue?"
|
||||
}'
|
||||
```
|
||||
@@ -246,7 +243,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
|
||||
```
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "mistral",
|
||||
"model": "llama3",
|
||||
"messages": [
|
||||
{ "role": "user", "content": "why is the sky blue?" }
|
||||
]
|
||||
@@ -259,15 +256,17 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
### Web & Desktop
|
||||
|
||||
- [Open WebUI](https://github.com/open-webui/open-webui)
|
||||
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
||||
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
|
||||
- [LibreChat](https://github.com/danny-avila/LibreChat)
|
||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
- [Saddle](https://github.com/jikkuatwork/saddle)
|
||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
|
||||
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
||||
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
|
||||
- [Open WebUI](https://github.com/open-webui/open-webui)
|
||||
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
|
||||
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
|
||||
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
|
||||
@@ -289,6 +288,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
|
||||
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
|
||||
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
|
||||
- [QA-Pilot: Chat with Code Repository](https://github.com/reid41/QA-Pilot)
|
||||
- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
|
||||
- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
|
||||
- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
|
||||
- [chat: chat web app for teams](https://github.com/swuecho/chat)
|
||||
- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
|
||||
- [Ollama RAG Chatbot: Local Chat with multiples PDFs using Ollama and RAG.](https://github.com/datvodinh/rag-chatbot.git)
|
||||
|
||||
### Terminal
|
||||
|
||||
@@ -304,15 +310,18 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
|
||||
- [cmdh](https://github.com/pgibler/cmdh)
|
||||
- [ooo](https://github.com/npahlfer/ooo)
|
||||
- [shell-pilot](https://github.com/reid41/shell-pilot)
|
||||
- [tenere](https://github.com/pythops/tenere)
|
||||
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
|
||||
- [typechat-cli](https://github.com/anaisbetts/typechat-cli)
|
||||
- [ShellOracle](https://github.com/djcopley/ShellOracle)
|
||||
- [tlm](https://github.com/yusufcanb/tlm)
|
||||
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
|
||||
|
||||
### Database
|
||||
|
||||
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md)
|
||||
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
|
||||
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
|
||||
|
||||
### Package managers
|
||||
|
||||
@@ -371,3 +380,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
|
||||
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
|
||||
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
|
||||
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
|
||||
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
|
||||
|
||||
### Supported backends
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
// Package api implements the client-side API for code wishing to interact
|
||||
// with the ollama service. The methods of the [Client] type correspond to
|
||||
// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
|
||||
//
|
||||
// The ollama command-line client itself uses this package to interact with
|
||||
// the backend service.
|
||||
package api
|
||||
|
||||
import (
|
||||
@@ -5,7 +11,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -19,6 +24,8 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
// Client encapsulates client state for interacting with the ollama
|
||||
// service. Use [ClientFromEnvironment] to create new Clients.
|
||||
type Client struct {
|
||||
base *url.URL
|
||||
http *http.Client
|
||||
@@ -40,6 +47,15 @@ func checkError(resp *http.Response, body []byte) error {
|
||||
return apiError
|
||||
}
|
||||
|
||||
// ClientFromEnvironment creates a new [Client] using configuration from the
|
||||
// environment variable OLLAMA_HOST, which points to the network host and
|
||||
// port on which the ollama service is listenting. The format of this variable
|
||||
// is:
|
||||
//
|
||||
// <scheme>://<host>:<port>
|
||||
//
|
||||
// If the variable is not specified, a default ollama host and port will be
|
||||
// used.
|
||||
func ClientFromEnvironment() (*Client, error) {
|
||||
defaultPort := "11434"
|
||||
|
||||
@@ -75,6 +91,13 @@ func ClientFromEnvironment() (*Client, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewClient(base *url.URL, http *http.Client) *Client {
|
||||
return &Client{
|
||||
base: base,
|
||||
http: http,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
|
||||
var reqBody io.Reader
|
||||
var data []byte
|
||||
@@ -191,8 +214,14 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateResponseFunc is a function that [Client.Generate] invokes every time
|
||||
// a response is received from the service. If this function returns an error,
|
||||
// [Client.Generate] will stop generating and return this error.
|
||||
type GenerateResponseFunc func(GenerateResponse) error
|
||||
|
||||
// Generate generates a response for a given prompt. The req parameter should
|
||||
// be populated with prompt details. fn is called for each response (there may
|
||||
// be multiple responses, e.g. in case streaming is enabled).
|
||||
func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
|
||||
return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
|
||||
var resp GenerateResponse
|
||||
@@ -204,8 +233,15 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
|
||||
})
|
||||
}
|
||||
|
||||
// ChatResponseFunc is a function that [Client.Chat] invokes every time
|
||||
// a response is received from the service. If this function returns an error,
|
||||
// [Client.Chat] will stop generating and return this error.
|
||||
type ChatResponseFunc func(ChatResponse) error
|
||||
|
||||
// Chat generates the next message in a chat. [ChatRequest] may contain a
|
||||
// sequence of messages which can be used to maintain chat history with a model.
|
||||
// fn is called for each response (there may be multiple responses, e.g. if case
|
||||
// streaming is enabled).
|
||||
func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
|
||||
return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
|
||||
var resp ChatResponse
|
||||
@@ -217,8 +253,14 @@ func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc
|
||||
})
|
||||
}
|
||||
|
||||
// PullProgressFunc is a function that [Client.Pull] invokes every time there
|
||||
// is progress with a "pull" request sent to the service. If this function
|
||||
// returns an error, [Client.Pull] will stop the process and return this error.
|
||||
type PullProgressFunc func(ProgressResponse) error
|
||||
|
||||
// Pull downloads a model from the ollama library. fn is called each time
|
||||
// progress is made on the request and can be used to display a progress bar,
|
||||
// etc.
|
||||
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
|
||||
return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
|
||||
var resp ProgressResponse
|
||||
@@ -301,18 +343,7 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
|
||||
}
|
||||
|
||||
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
|
||||
if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil {
|
||||
var statusError StatusError
|
||||
if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
|
||||
}
|
||||
|
||||
func (c *Client) Version(ctx context.Context) (string, error) {
|
||||
|
||||
119
api/types.go
119
api/types.go
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
@@ -33,18 +34,46 @@ func (e StatusError) Error() string {
|
||||
|
||||
type ImageData []byte
|
||||
|
||||
// GenerateRequest describes a request sent by [Client.Generate]. While you
|
||||
// have to specify the Model and Prompt fields, all the other fields have
|
||||
// reasonable defaults for basic uses.
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
System string `json:"system"`
|
||||
Template string `json:"template"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Format string `json:"format"`
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
// Model is the model name; it should be a name familiar to Ollama from
|
||||
// the library at https://ollama.com/library
|
||||
Model string `json:"model"`
|
||||
|
||||
// Prompt is the textual prompt to send to the model.
|
||||
Prompt string `json:"prompt"`
|
||||
|
||||
// System overrides the model's default system message/prompt.
|
||||
System string `json:"system"`
|
||||
|
||||
// Template overrides the model's default prompt template.
|
||||
Template string `json:"template"`
|
||||
|
||||
// Context is the context parameter returned from a previous call to
|
||||
// Generate call. It can be used to keep a short conversational memory.
|
||||
Context []int `json:"context,omitempty"`
|
||||
|
||||
// Stream specifies whether the response is streaming; it is true by default.
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
|
||||
// Raw set to true means that no formatting will be applied to the prompt.
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
|
||||
// Format specifies the format to return a response in.
|
||||
Format string `json:"format"`
|
||||
|
||||
// KeepAlive controls how long the model will stay loaded in memory following
|
||||
// this request.
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
// Images is an optional list of base64-encoded images accompanying this
|
||||
// request, for multimodal models.
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
|
||||
// Options lists model-specific options. For example, temperature can be
|
||||
// set through this field, if the model supports it.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
@@ -109,21 +138,24 @@ type Options struct {
|
||||
|
||||
// Runner options which must be set when the model is loaded into memory
|
||||
type Runner struct {
|
||||
UseNUMA bool `json:"numa,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
NumBatch int `json:"num_batch,omitempty"`
|
||||
NumGQA int `json:"num_gqa,omitempty"`
|
||||
NumGPU int `json:"num_gpu,omitempty"`
|
||||
MainGPU int `json:"main_gpu,omitempty"`
|
||||
LowVRAM bool `json:"low_vram,omitempty"`
|
||||
F16KV bool `json:"f16_kv,omitempty"`
|
||||
LogitsAll bool `json:"logits_all,omitempty"`
|
||||
VocabOnly bool `json:"vocab_only,omitempty"`
|
||||
UseMMap bool `json:"use_mmap,omitempty"`
|
||||
UseMLock bool `json:"use_mlock,omitempty"`
|
||||
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
|
||||
UseNUMA bool `json:"numa,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
NumBatch int `json:"num_batch,omitempty"`
|
||||
NumGQA int `json:"num_gqa,omitempty"`
|
||||
NumGPU int `json:"num_gpu,omitempty"`
|
||||
MainGPU int `json:"main_gpu,omitempty"`
|
||||
LowVRAM bool `json:"low_vram,omitempty"`
|
||||
F16KV bool `json:"f16_kv,omitempty"`
|
||||
LogitsAll bool `json:"logits_all,omitempty"`
|
||||
VocabOnly bool `json:"vocab_only,omitempty"`
|
||||
UseMMap bool `json:"use_mmap,omitempty"`
|
||||
UseMLock bool `json:"use_mlock,omitempty"`
|
||||
NumThread int `json:"num_thread,omitempty"`
|
||||
|
||||
// Unused: RopeFrequencyBase is ignored. Instead the value in the model will be used
|
||||
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
|
||||
// Unused: RopeFrequencyScale is ignored. Instead the value in the model will be used
|
||||
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
|
||||
NumThread int `json:"num_thread,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
@@ -139,10 +171,11 @@ type EmbeddingResponse struct {
|
||||
}
|
||||
|
||||
type CreateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Path string `json:"path"`
|
||||
Modelfile string `json:"modelfile"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Path string `json:"path"`
|
||||
Modelfile string `json:"modelfile"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Quantization string `json:"quantization,omitempty"`
|
||||
|
||||
// Name is deprecated, see Model
|
||||
Name string `json:"name"`
|
||||
@@ -275,7 +308,7 @@ func (m *Metrics) Summary() {
|
||||
}
|
||||
}
|
||||
|
||||
var ErrInvalidOpts = fmt.Errorf("invalid options")
|
||||
var ErrInvalidOpts = errors.New("invalid options")
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
@@ -363,8 +396,10 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
func DefaultOptions() Options {
|
||||
return Options{
|
||||
// options set on request to runner
|
||||
NumPredict: -1,
|
||||
NumKeep: 0,
|
||||
NumPredict: -1,
|
||||
|
||||
// set a minimal num_keep to avoid issues on context shifts
|
||||
NumKeep: 4,
|
||||
Temperature: 0.8,
|
||||
TopK: 40,
|
||||
TopP: 0.9,
|
||||
@@ -382,18 +417,16 @@ func DefaultOptions() Options {
|
||||
|
||||
Runner: Runner{
|
||||
// options set when the model is loaded
|
||||
NumCtx: 2048,
|
||||
RopeFrequencyBase: 10000.0,
|
||||
RopeFrequencyScale: 1.0,
|
||||
NumBatch: 512,
|
||||
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
|
||||
NumGQA: 1,
|
||||
NumThread: 0, // let the runtime decide
|
||||
LowVRAM: false,
|
||||
F16KV: true,
|
||||
UseMLock: false,
|
||||
UseMMap: true,
|
||||
UseNUMA: false,
|
||||
NumCtx: 2048,
|
||||
NumBatch: 512,
|
||||
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
|
||||
NumGQA: 1,
|
||||
NumThread: 0, // let the runtime decide
|
||||
LowVRAM: false,
|
||||
F16KV: true,
|
||||
UseMLock: false,
|
||||
UseMMap: true,
|
||||
UseNUMA: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
1
app/.gitignore
vendored
1
app/.gitignore
vendored
@@ -1 +1,2 @@
|
||||
ollama.syso
|
||||
app
|
||||
7
app/AppDelegate.h
Normal file
7
app/AppDelegate.h
Normal file
@@ -0,0 +1,7 @@
|
||||
#import <Cocoa/Cocoa.h>
|
||||
|
||||
@interface AppDelegate : NSObject <NSApplicationDelegate>
|
||||
|
||||
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification;
|
||||
|
||||
@end
|
||||
@@ -1,10 +1,6 @@
|
||||
# Ollama App
|
||||
|
||||
## Linux
|
||||
|
||||
TODO
|
||||
|
||||
## MacOS
|
||||
## macOS
|
||||
|
||||
TODO
|
||||
|
||||
|
||||
76
app/app_darwin.go
Normal file
76
app/app_darwin.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package main
|
||||
|
||||
// #cgo CFLAGS: -x objective-c
|
||||
// #cgo LDFLAGS: -framework Cocoa -framework LocalAuthentication -framework ServiceManagement
|
||||
// #include "app_darwin.h"
|
||||
import "C"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func init() {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ServerLogFile = filepath.Join(home, ".ollama", "logs", "server.log")
|
||||
}
|
||||
|
||||
func run() {
|
||||
initLogging()
|
||||
slog.Info("ollama macOS app started")
|
||||
|
||||
// Ask to move to applications directory
|
||||
moving := C.askToMoveToApplications()
|
||||
if moving {
|
||||
return
|
||||
}
|
||||
|
||||
C.killOtherInstances()
|
||||
|
||||
code := C.installSymlink()
|
||||
if code != 0 {
|
||||
slog.Error("Failed to install symlink")
|
||||
}
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var options ServerOptions
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var done chan int
|
||||
|
||||
done, err = SpawnServer(ctx, filepath.Join(filepath.Dir(exe), "..", "Resources", "ollama"), options)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("Failed to spawn ollama server %s", err))
|
||||
done = make(chan int, 1)
|
||||
done <- 1
|
||||
}
|
||||
|
||||
// Run the native macOS app
|
||||
// Note: this will block until the app is closed
|
||||
C.run()
|
||||
|
||||
slog.Info("ollama macOS app closed")
|
||||
|
||||
cancel()
|
||||
slog.Info("Waiting for ollama server to shutdown...")
|
||||
if done != nil {
|
||||
<-done
|
||||
}
|
||||
slog.Info("Ollama app exiting")
|
||||
}
|
||||
|
||||
//export Quit
|
||||
func Quit() {
|
||||
syscall.Kill(os.Getpid(), syscall.SIGTERM)
|
||||
}
|
||||
13
app/app_darwin.h
Normal file
13
app/app_darwin.h
Normal file
@@ -0,0 +1,13 @@
|
||||
#import <Cocoa/Cocoa.h>
|
||||
|
||||
@interface AppDelegate : NSObject <NSApplicationDelegate>
|
||||
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification;
|
||||
@end
|
||||
|
||||
void run();
|
||||
void killOtherInstances();
|
||||
bool askToMoveToApplications();
|
||||
int createSymlinkWithAuthorization();
|
||||
int installSymlink();
|
||||
extern void Restart();
|
||||
extern void Quit();
|
||||
282
app/app_darwin.m
Normal file
282
app/app_darwin.m
Normal file
@@ -0,0 +1,282 @@
|
||||
#import <AppKit/AppKit.h>
|
||||
#import <Cocoa/Cocoa.h>
|
||||
#import <CoreServices/CoreServices.h>
|
||||
#import <Security/Security.h>
|
||||
#import <ServiceManagement/ServiceManagement.h>
|
||||
#import "app_darwin.h"
|
||||
|
||||
@interface AppDelegate ()
|
||||
|
||||
@property (strong, nonatomic) NSStatusItem *statusItem;
|
||||
|
||||
@end
|
||||
|
||||
@implementation AppDelegate
|
||||
|
||||
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
|
||||
// show status menu
|
||||
NSMenu *menu = [[NSMenu alloc] init];
|
||||
|
||||
NSMenuItem *aboutMenuItem = [[NSMenuItem alloc] initWithTitle:@"About Ollama" action:@selector(aboutOllama) keyEquivalent:@""];
|
||||
[aboutMenuItem setTarget:self];
|
||||
[menu addItem:aboutMenuItem];
|
||||
|
||||
// Settings submenu
|
||||
NSMenu *settingsMenu = [[NSMenu alloc] initWithTitle:@"Settings"];
|
||||
|
||||
// Submenu items
|
||||
NSMenuItem *chooseModelDirectoryItem = [[NSMenuItem alloc] initWithTitle:@"Choose model directory..." action:@selector(chooseModelDirectory) keyEquivalent:@""];
|
||||
[chooseModelDirectoryItem setTarget:self];
|
||||
[chooseModelDirectoryItem setEnabled:YES];
|
||||
[settingsMenu addItem:chooseModelDirectoryItem];
|
||||
|
||||
NSMenuItem *exposeExternallyItem = [[NSMenuItem alloc] initWithTitle:@"Allow external connections" action:@selector(toggleExposeExternally:) keyEquivalent:@""];
|
||||
[exposeExternallyItem setTarget:self];
|
||||
[exposeExternallyItem setState:NSOffState]; // Set initial state to off
|
||||
[exposeExternallyItem setEnabled:YES];
|
||||
[settingsMenu addItem:exposeExternallyItem];
|
||||
|
||||
NSMenuItem *allowCrossOriginItem = [[NSMenuItem alloc] initWithTitle:@"Allow browser requests" action:@selector(toggleCrossOrigin:) keyEquivalent:@""];
|
||||
[allowCrossOriginItem setTarget:self];
|
||||
[allowCrossOriginItem setState:NSOffState]; // Set initial state to off
|
||||
[allowCrossOriginItem setEnabled:YES];
|
||||
[settingsMenu addItem:allowCrossOriginItem];
|
||||
|
||||
NSMenuItem *settingsMenuItem = [[NSMenuItem alloc] initWithTitle:@"Settings" action:nil keyEquivalent:@""];
|
||||
[settingsMenuItem setSubmenu:settingsMenu];
|
||||
[menu addItem:settingsMenuItem];
|
||||
|
||||
[menu addItemWithTitle:@"Quit Ollama" action:@selector(quit) keyEquivalent:@"q"];
|
||||
|
||||
self.statusItem = [[NSStatusBar systemStatusBar] statusItemWithLength:NSVariableStatusItemLength];
|
||||
[self.statusItem addObserver:self forKeyPath:@"button.effectiveAppearance" options:NSKeyValueObservingOptionNew|NSKeyValueObservingOptionInitial context:nil];
|
||||
|
||||
self.statusItem.menu = menu;
|
||||
[self showIcon];
|
||||
}
|
||||
|
||||
- (void)aboutOllama {
|
||||
[[NSApplication sharedApplication] orderFrontStandardAboutPanel:nil];
|
||||
}
|
||||
|
||||
- (void)toggleCrossOrigin:(id)sender {
|
||||
NSMenuItem *item = (NSMenuItem *)sender;
|
||||
if ([item state] == NSOffState) {
|
||||
// Do something when cross-origin requests are allowed
|
||||
[item setState:NSOnState];
|
||||
} else {
|
||||
// Do something when cross-origin requests are disallowed
|
||||
[item setState:NSOffState];
|
||||
}
|
||||
}
|
||||
|
||||
- (void)toggleExposeExternally:(id)sender {
|
||||
NSMenuItem *item = (NSMenuItem *)sender;
|
||||
if ([item state] == NSOffState) {
|
||||
// Do something when Ollama is exposed externally
|
||||
[item setState:NSOnState];
|
||||
} else {
|
||||
// Do something when Ollama is not exposed externally
|
||||
[item setState:NSOffState];
|
||||
}
|
||||
}
|
||||
|
||||
- (void)chooseModelDirectory {
|
||||
NSOpenPanel *openPanel = [NSOpenPanel openPanel];
|
||||
[openPanel setCanChooseFiles:NO];
|
||||
[openPanel setCanChooseDirectories:YES];
|
||||
[openPanel setAllowsMultipleSelection:NO];
|
||||
|
||||
NSInteger result = [openPanel runModal];
|
||||
if (result == NSModalResponseOK) {
|
||||
NSURL *selectedDirectoryURL = [openPanel URLs].firstObject;
|
||||
// Do something with the selected directory URL
|
||||
}
|
||||
}
|
||||
|
||||
-(void) showIcon {
|
||||
NSAppearance* appearance = self.statusItem.button.effectiveAppearance;
|
||||
NSString* appearanceName = (NSString*)(appearance.name);
|
||||
NSString* iconName = [[appearanceName lowercaseString] containsString:@"dark"] ? @"iconDark" : @"icon";
|
||||
NSImage* statusImage = [NSImage imageNamed:iconName];
|
||||
[statusImage setTemplate:YES];
|
||||
self.statusItem.button.image = statusImage;
|
||||
}
|
||||
|
||||
-(void)observeValueForKeyPath:(NSString *)keyPath ofObject:(id)object change:(NSDictionary<NSKeyValueChangeKey,id> *)change context:(void *)context {
|
||||
[self showIcon];
|
||||
}
|
||||
|
||||
- (void)quit {
|
||||
[NSApp stop:nil];
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
void run() {
|
||||
@autoreleasepool {
|
||||
[NSApplication sharedApplication];
|
||||
AppDelegate *appDelegate = [[AppDelegate alloc] init];
|
||||
[NSApp setDelegate:appDelegate];
|
||||
[NSApp run];
|
||||
}
|
||||
}
|
||||
|
||||
// killOtherInstances kills all other instances of the app currently
|
||||
// running. This way we can ensure that only the most recently started
|
||||
// instance of Ollama is running
|
||||
void killOtherInstances() {
|
||||
pid_t pid = getpid();
|
||||
NSArray *all = [[NSWorkspace sharedWorkspace] runningApplications];
|
||||
NSMutableArray *apps = [NSMutableArray array];
|
||||
|
||||
for (NSRunningApplication *app in all) {
|
||||
if ([app.bundleIdentifier isEqualToString:[[NSBundle mainBundle] bundleIdentifier]] ||
|
||||
[app.bundleIdentifier isEqualToString:@"ai.ollama.ollama"] ||
|
||||
[app.bundleIdentifier isEqualToString:@"com.electron.ollama"]) {
|
||||
if (app.processIdentifier != pid) {
|
||||
[apps addObject:app];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (NSRunningApplication *app in apps) {
|
||||
kill(app.processIdentifier, SIGTERM);
|
||||
}
|
||||
|
||||
NSDate *startTime = [NSDate date];
|
||||
for (NSRunningApplication *app in apps) {
|
||||
while (!app.terminated) {
|
||||
if (-[startTime timeIntervalSinceNow] >= 5) {
|
||||
kill(app.processIdentifier, SIGKILL);
|
||||
break;
|
||||
}
|
||||
|
||||
[[NSRunLoop currentRunLoop] runUntilDate:[NSDate dateWithTimeIntervalSinceNow:0.1]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool askToMoveToApplications() {
|
||||
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
|
||||
if ([bundlePath hasPrefix:@"/Applications"]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
NSAlert *alert = [[NSAlert alloc] init];
|
||||
[alert setMessageText:@"Move to Applications?"];
|
||||
[alert setInformativeText:@"Ollama works best when run from the Applications directory."];
|
||||
[alert addButtonWithTitle:@"Move to Applications"];
|
||||
[alert addButtonWithTitle:@"Don't move"];
|
||||
|
||||
[NSApp activateIgnoringOtherApps:YES];
|
||||
|
||||
if ([alert runModal] != NSAlertFirstButtonReturn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// move to applications
|
||||
NSString *applicationsPath = @"/Applications";
|
||||
NSString *newPath = [applicationsPath stringByAppendingPathComponent:@"Ollama.app"];
|
||||
NSFileManager *fileManager = [NSFileManager defaultManager];
|
||||
|
||||
// Check if the newPath already exists
|
||||
if ([fileManager fileExistsAtPath:newPath]) {
|
||||
NSError *removeError = nil;
|
||||
[fileManager removeItemAtPath:newPath error:&removeError];
|
||||
if (removeError) {
|
||||
NSLog(@"Error removing file at %@: %@", newPath, removeError);
|
||||
return false; // or handle the error
|
||||
}
|
||||
}
|
||||
|
||||
NSError *moveError = nil;
|
||||
[fileManager moveItemAtPath:bundlePath toPath:newPath error:&moveError];
|
||||
if (moveError) {
|
||||
NSLog(@"Error moving file from %@ to %@: %@", bundlePath, newPath, moveError);
|
||||
return false;
|
||||
}
|
||||
|
||||
NSLog(@"Opening %@", newPath);
|
||||
NSError *error = nil;
|
||||
NSWorkspace *workspace = [NSWorkspace sharedWorkspace];
|
||||
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
[workspace launchApplicationAtURL:[NSURL fileURLWithPath:newPath]
|
||||
options:NSWorkspaceLaunchNewInstance | NSWorkspaceLaunchDefault
|
||||
configuration:@{}
|
||||
error:&error];
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int installSymlink() {
|
||||
NSString *linkPath = @"/usr/local/bin/ollama";
|
||||
NSError *error = nil;
|
||||
|
||||
NSFileManager *fileManager = [NSFileManager defaultManager];
|
||||
NSString *symlinkPath = [fileManager destinationOfSymbolicLinkAtPath:linkPath error:&error];
|
||||
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
|
||||
NSString *execPath = [[NSBundle mainBundle] executablePath];
|
||||
NSString *resPath = [[NSBundle mainBundle] pathForResource:@"ollama" ofType:nil];
|
||||
|
||||
// if the symlink already exists and points to the right place, don't prompt
|
||||
if ([symlinkPath isEqualToString:resPath]) {
|
||||
NSLog(@"symbolic link already exists and points to the right place");
|
||||
return 0;
|
||||
}
|
||||
|
||||
NSString *authorizationPrompt = @"Ollama is trying to install its command line interface (CLI) tool.";
|
||||
|
||||
AuthorizationRef auth = NULL;
|
||||
OSStatus createStatus = AuthorizationCreate(NULL, kAuthorizationEmptyEnvironment, kAuthorizationFlagDefaults, &auth);
|
||||
if (createStatus != errAuthorizationSuccess) {
|
||||
NSLog(@"Error creating authorization");
|
||||
return -1;
|
||||
}
|
||||
|
||||
NSString * bundleIdentifier = [[NSBundle mainBundle] bundleIdentifier];
|
||||
NSString *rightNameString = [NSString stringWithFormat:@"%@.%@", bundleIdentifier, @"auth3"];
|
||||
const char *rightName = rightNameString.UTF8String;
|
||||
|
||||
OSStatus getRightResult = AuthorizationRightGet(rightName, NULL);
|
||||
if (getRightResult == errAuthorizationDenied) {
|
||||
if (AuthorizationRightSet(auth, rightName, (__bridge CFTypeRef _Nonnull)(@(kAuthorizationRuleAuthenticateAsAdmin)), (__bridge CFStringRef _Nullable)(authorizationPrompt), NULL, NULL) != errAuthorizationSuccess) {
|
||||
NSLog(@"Failed to set right");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
AuthorizationItem right = { .name = rightName, .valueLength = 0, .value = NULL, .flags = 0 };
|
||||
AuthorizationRights rights = { .count = 1, .items = &right };
|
||||
AuthorizationFlags flags = (AuthorizationFlags)(kAuthorizationFlagExtendRights | kAuthorizationFlagInteractionAllowed);
|
||||
AuthorizationItem iconAuthorizationItem = {.name = kAuthorizationEnvironmentIcon, .valueLength = 0, .value = NULL, .flags = 0};
|
||||
AuthorizationEnvironment authorizationEnvironment = {.count = 0, .items = NULL};
|
||||
|
||||
BOOL failedToUseSystemDomain = NO;
|
||||
OSStatus copyStatus = AuthorizationCopyRights(auth, &rights, &authorizationEnvironment, flags, NULL);
|
||||
if (copyStatus != errAuthorizationSuccess) {
|
||||
failedToUseSystemDomain = YES;
|
||||
|
||||
if (copyStatus == errAuthorizationCanceled) {
|
||||
NSLog(@"User cancelled authorization");
|
||||
return -1;
|
||||
} else {
|
||||
NSLog(@"Failed copying system domain rights: %d", copyStatus);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
const char *toolPath = "/bin/ln";
|
||||
const char *args[] = {"-s", "-F", [resPath UTF8String], "/usr/local/bin/ollama", NULL};
|
||||
FILE *pipe = NULL;
|
||||
|
||||
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
OSStatus status = AuthorizationExecuteWithPrivileges(auth, toolPath, kAuthorizationFlagDefaults, (char *const *)args, &pipe);
|
||||
if (status != errAuthorizationSuccess) {
|
||||
NSLog(@"Failed to create symlink");
|
||||
return -1;
|
||||
}
|
||||
|
||||
AuthorizationFree(auth, kAuthorizationFlagDestroyRights);
|
||||
return 0;
|
||||
}
|
||||
166
app/app_windows.go
Normal file
166
app/app_windows.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/ollama/ollama/app/store"
|
||||
"github.com/ollama/ollama/app/tray"
|
||||
"github.com/ollama/ollama/app/updater"
|
||||
)
|
||||
|
||||
func init() {
|
||||
AppName += ".exe"
|
||||
CLIName += ".exe"
|
||||
// Logs, configs, downloads go to LOCALAPPDATA
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
AppDataDir = filepath.Join(localAppData, "Ollama")
|
||||
AppLogFile = filepath.Join(AppDataDir, "app.log")
|
||||
ServerLogFile = filepath.Join(AppDataDir, "server.log")
|
||||
|
||||
// Executables are stored in APPDATA
|
||||
AppDir = filepath.Join(localAppData, "Programs", "Ollama")
|
||||
|
||||
// Make sure we have PATH set correctly for any spawned children
|
||||
paths := strings.Split(os.Getenv("PATH"), ";")
|
||||
// Start with whatever we find in the PATH/LD_LIBRARY_PATH
|
||||
found := false
|
||||
for _, path := range paths {
|
||||
d, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(AppDir, d) {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
paths = append(paths, AppDir)
|
||||
|
||||
pathVal := strings.Join(paths, ";")
|
||||
slog.Debug("setting PATH=" + pathVal)
|
||||
err := os.Setenv("PATH", pathVal)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to update PATH: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure our logging dir exists
|
||||
_, err := os.Stat(AppDataDir)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.MkdirAll(AppDataDir, 0o755); err != nil {
|
||||
slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ShowLogs() {
|
||||
cmd_path := "c:\\Windows\\system32\\cmd.exe"
|
||||
slog.Debug(fmt.Sprintf("viewing logs with start %s", AppDataDir))
|
||||
cmd := exec.Command(cmd_path, "/c", "start", AppDataDir)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: false, CreationFlags: 0x08000000}
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("Failed to open log dir: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
func Start() {
|
||||
cmd_path := "c:\\Windows\\system32\\cmd.exe"
|
||||
slog.Debug(fmt.Sprintf("viewing logs with start %s", AppDataDir))
|
||||
cmd := exec.Command(cmd_path, "/c", "start", AppDataDir)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: false, CreationFlags: 0x08000000}
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("Failed to open log dir: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
func run() {
|
||||
initLogging()
|
||||
|
||||
slog.Info("ollama windows app started")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var done chan int
|
||||
|
||||
t, err := tray.NewTray()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to start: %s", err)
|
||||
}
|
||||
callbacks := t.GetCallbacks()
|
||||
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
slog.Debug("starting callback loop")
|
||||
for {
|
||||
select {
|
||||
case <-callbacks.Quit:
|
||||
slog.Debug("quit called")
|
||||
t.Quit()
|
||||
case <-signals:
|
||||
slog.Debug("shutting down due to signal")
|
||||
t.Quit()
|
||||
case <-callbacks.Update:
|
||||
err := updater.DoUpgrade(cancel, done)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("upgrade attempt failed: %s", err))
|
||||
}
|
||||
case <-callbacks.ShowLogs:
|
||||
ShowLogs()
|
||||
case <-callbacks.DoFirstUse:
|
||||
err := lifecycle.GetStarted()
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("Failed to launch getting started shell: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if !store.GetFirstTimeRun() {
|
||||
slog.Debug("First time run")
|
||||
err = t.DisplayFirstUseNotification()
|
||||
if err != nil {
|
||||
slog.Debug(fmt.Sprintf("XXX failed to display first use notification %v", err))
|
||||
}
|
||||
store.SetFirstTimeRun(true)
|
||||
} else {
|
||||
slog.Debug("Not first time, skipping first run notification")
|
||||
}
|
||||
|
||||
if isServerRunning(ctx) {
|
||||
slog.Info("Detected another instance of ollama running, exiting")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
done, err = SpawnServer(ctx, CLIName)
|
||||
if err != nil {
|
||||
// TODO - should we retry in a backoff loop?
|
||||
// TODO - should we pop up a warning and maybe add a menu item to view application logs?
|
||||
slog.Error(fmt.Sprintf("Failed to spawn ollama server %s", err))
|
||||
done = make(chan int, 1)
|
||||
done <- 1
|
||||
}
|
||||
|
||||
updater.StartBackgroundUpdaterChecker(ctx, t.UpdateAvailable)
|
||||
|
||||
t.Run()
|
||||
cancel()
|
||||
slog.Info("Waiting for ollama server to shutdown...")
|
||||
if done != nil {
|
||||
<-done
|
||||
}
|
||||
slog.Info("Ollama app exiting")
|
||||
}
|
||||
40
app/darwin/Ollama.app/Contents/Info.plist
Normal file
40
app/darwin/Ollama.app/Contents/Info.plist
Normal file
@@ -0,0 +1,40 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>CFBundleDisplayName</key>
|
||||
<string>Ollama</string>
|
||||
<key>CFBundleExecutable</key>
|
||||
<string>Ollama</string>
|
||||
<key>CFBundleIconFile</key>
|
||||
<string>icon.icns</string>
|
||||
<key>CFBundleIdentifier</key>
|
||||
<string>com.ollama.ollama</string>
|
||||
<key>CFBundleInfoDictionaryVersion</key>
|
||||
<string>6.0</string>
|
||||
<key>CFBundleName</key>
|
||||
<string>Ollama</string>
|
||||
<key>CFBundlePackageType</key>
|
||||
<string>APPL</string>
|
||||
<key>CFBundleShortVersionString</key>
|
||||
<string>0.0.0</string>
|
||||
<key>CFBundleVersion</key>
|
||||
<string>0.0.0</string>
|
||||
<key>DTCompiler</key>
|
||||
<string>com.apple.compilers.llvm.clang.1_0</string>
|
||||
<key>DTSDKBuild</key>
|
||||
<string>22E245</string>
|
||||
<key>DTSDKName</key>
|
||||
<string>macosx13.3</string>
|
||||
<key>DTXcode</key>
|
||||
<string>1431</string>
|
||||
<key>DTXcodeBuild</key>
|
||||
<string>14E300c</string>
|
||||
<key>LSApplicationCategoryType</key>
|
||||
<string>public.app-category.developer-tools</string>
|
||||
<key>LSMinimumSystemVersion</key>
|
||||
<string>11.0</string>
|
||||
<key>LSUIElement</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>
|
||||
BIN
app/darwin/Ollama.app/Contents/Resources/icon.png
Normal file
BIN
app/darwin/Ollama.app/Contents/Resources/icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 382 B |
BIN
app/darwin/Ollama.app/Contents/Resources/icon@2x.png
Normal file
BIN
app/darwin/Ollama.app/Contents/Resources/icon@2x.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 691 B |
BIN
app/darwin/Ollama.app/Contents/Resources/iconDark.png
Normal file
BIN
app/darwin/Ollama.app/Contents/Resources/iconDark.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 382 B |
BIN
app/darwin/Ollama.app/Contents/Resources/iconDark@2x.png
Normal file
BIN
app/darwin/Ollama.app/Contents/Resources/iconDark@2x.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 721 B |
@@ -1,5 +1,3 @@
|
||||
//go:build !windows
|
||||
|
||||
package lifecycle
|
||||
|
||||
import "fmt"
|
||||
@@ -1,92 +0,0 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/ollama/ollama/app/store"
|
||||
"github.com/ollama/ollama/app/tray"
|
||||
)
|
||||
|
||||
func Run() {
|
||||
InitLogging()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var done chan int
|
||||
|
||||
t, err := tray.NewTray()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to start: %s", err)
|
||||
}
|
||||
callbacks := t.GetCallbacks()
|
||||
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
slog.Debug("starting callback loop")
|
||||
for {
|
||||
select {
|
||||
case <-callbacks.Quit:
|
||||
slog.Debug("quit called")
|
||||
t.Quit()
|
||||
case <-signals:
|
||||
slog.Debug("shutting down due to signal")
|
||||
t.Quit()
|
||||
case <-callbacks.Update:
|
||||
err := DoUpgrade(cancel, done)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("upgrade attempt failed: %s", err))
|
||||
}
|
||||
case <-callbacks.ShowLogs:
|
||||
ShowLogs()
|
||||
case <-callbacks.DoFirstUse:
|
||||
err := GetStarted()
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("Failed to launch getting started shell: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Are we first use?
|
||||
if !store.GetFirstTimeRun() {
|
||||
slog.Debug("First time run")
|
||||
err = t.DisplayFirstUseNotification()
|
||||
if err != nil {
|
||||
slog.Debug(fmt.Sprintf("XXX failed to display first use notification %v", err))
|
||||
}
|
||||
store.SetFirstTimeRun(true)
|
||||
} else {
|
||||
slog.Debug("Not first time, skipping first run notification")
|
||||
}
|
||||
|
||||
if IsServerRunning(ctx) {
|
||||
slog.Info("Detected another instance of ollama running, exiting")
|
||||
os.Exit(1)
|
||||
} else {
|
||||
done, err = SpawnServer(ctx, CLIName)
|
||||
if err != nil {
|
||||
// TODO - should we retry in a backoff loop?
|
||||
// TODO - should we pop up a warning and maybe add a menu item to view application logs?
|
||||
slog.Error(fmt.Sprintf("Failed to spawn ollama server %s", err))
|
||||
done = make(chan int, 1)
|
||||
done <- 1
|
||||
}
|
||||
}
|
||||
|
||||
StartBackgroundUpdaterChecker(ctx, t.UpdateAvailable)
|
||||
|
||||
t.Run()
|
||||
cancel()
|
||||
slog.Info("Waiting for ollama server to shutdown...")
|
||||
if done != nil {
|
||||
<-done
|
||||
}
|
||||
slog.Info("Ollama app exiting")
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package lifecycle
|
||||
|
||||
import "log/slog"
|
||||
|
||||
func ShowLogs() {
|
||||
slog.Warn("ShowLogs not yet implemented")
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func ShowLogs() {
|
||||
cmd_path := "c:\\Windows\\system32\\cmd.exe"
|
||||
slog.Debug(fmt.Sprintf("viewing logs with start %s", AppDataDir))
|
||||
cmd := exec.Command(cmd_path, "/c", "start", AppDataDir)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: false, CreationFlags: 0x08000000}
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("Failed to open log dir: %s", err))
|
||||
}
|
||||
}
|
||||
@@ -70,10 +70,5 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
} else if runtime.GOOS == "darwin" {
|
||||
// TODO
|
||||
AppName += ".app"
|
||||
// } else if runtime.GOOS == "linux" {
|
||||
// TODO
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func getCLIFullPath(command string) string {
|
||||
cmdPath := ""
|
||||
appExe, err := os.Executable()
|
||||
if err == nil {
|
||||
cmdPath = filepath.Join(filepath.Dir(appExe), command)
|
||||
_, err := os.Stat(cmdPath)
|
||||
if err == nil {
|
||||
return cmdPath
|
||||
}
|
||||
}
|
||||
cmdPath, err = exec.LookPath(command)
|
||||
if err == nil {
|
||||
_, err := os.Stat(cmdPath)
|
||||
if err == nil {
|
||||
return cmdPath
|
||||
}
|
||||
}
|
||||
pwd, err := os.Getwd()
|
||||
if err == nil {
|
||||
cmdPath = filepath.Join(pwd, command)
|
||||
_, err = os.Stat(cmdPath)
|
||||
if err == nil {
|
||||
return cmdPath
|
||||
}
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
done := make(chan int)
|
||||
|
||||
logDir := filepath.Dir(ServerLogFile)
|
||||
_, err := os.Stat(logDir)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
cmd := getCmd(ctx, getCLIFullPath(command))
|
||||
// send stdout and stderr to a file
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stdout pipe %s", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stderr pipe %s", err)
|
||||
}
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
|
||||
}
|
||||
|
||||
// TODO - rotation
|
||||
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to create server log %w", err)
|
||||
}
|
||||
go func() {
|
||||
defer logFile.Close()
|
||||
io.Copy(logFile, stdout) //nolint:errcheck
|
||||
}()
|
||||
go func() {
|
||||
defer logFile.Close()
|
||||
io.Copy(logFile, stderr) //nolint:errcheck
|
||||
}()
|
||||
|
||||
// run the command and wait for it to finish
|
||||
if err := cmd.Start(); err != nil {
|
||||
return done, fmt.Errorf("failed to start server %w", err)
|
||||
}
|
||||
if cmd.Process != nil {
|
||||
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
|
||||
}
|
||||
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
|
||||
|
||||
go func() {
|
||||
// Keep the server running unless we're shuttind down the app
|
||||
crashCount := 0
|
||||
for {
|
||||
cmd.Wait() //nolint:errcheck
|
||||
stdin.Close()
|
||||
var code int
|
||||
if cmd.ProcessState != nil {
|
||||
code = cmd.ProcessState.ExitCode()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug(fmt.Sprintf("server shutdown with exit code %d", code))
|
||||
done <- code
|
||||
return
|
||||
default:
|
||||
crashCount++
|
||||
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := cmd.Start(); err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to restart server %s", err))
|
||||
// Keep trying, but back off if we keep failing
|
||||
time.Sleep(time.Duration(crashCount) * time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return done, nil
|
||||
}
|
||||
|
||||
func IsServerRunning(ctx context.Context) bool {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
slog.Info("unable to connect to server")
|
||||
return false
|
||||
}
|
||||
err = client.Heartbeat(ctx)
|
||||
if err != nil {
|
||||
slog.Debug(fmt.Sprintf("heartbeat from server: %s", err))
|
||||
slog.Info("unable to connect to server")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func getCmd(ctx context.Context, cmd string) *exec.Cmd {
|
||||
return exec.CommandContext(ctx, cmd, "serve")
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func getCmd(ctx context.Context, exePath string) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, exePath, "serve")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true, CreationFlags: 0x08000000}
|
||||
return cmd
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package lifecycle
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func InitLogging() {
|
||||
func initLogging() {
|
||||
level := slog.LevelInfo
|
||||
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
@@ -41,6 +41,4 @@ func InitLogging() {
|
||||
})
|
||||
|
||||
slog.SetDefault(slog.New(handler))
|
||||
|
||||
slog.Info("ollama app started")
|
||||
}
|
||||
12
app/main.go
12
app/main.go
@@ -2,11 +2,15 @@ package main
|
||||
|
||||
// Compile with the following to get rid of the cmd pop up on windows
|
||||
// go build -ldflags="-H windowsgui" .
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
var (
|
||||
AppName string
|
||||
CLIName string
|
||||
AppDir string
|
||||
AppDataDir string
|
||||
AppLogFile string
|
||||
ServerLogFile string
|
||||
)
|
||||
|
||||
func main() {
|
||||
lifecycle.Run()
|
||||
run()
|
||||
}
|
||||
|
||||
163
app/server.go
Normal file
163
app/server.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type ServerOptions struct {
|
||||
Cors bool
|
||||
Expose bool
|
||||
ModelsPath string
|
||||
}
|
||||
|
||||
func start(ctx context.Context, command string, options ServerOptions) (*exec.Cmd, error) {
|
||||
cmd := getCmd(ctx, command)
|
||||
|
||||
// set environment variables
|
||||
if options.ModelsPath != "" {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("OLLAMA_MODELS=%s", options.ModelsPath))
|
||||
}
|
||||
|
||||
if options.Cors {
|
||||
cmd.Env = append(cmd.Env, "OLLAMA_ORIGINS=*")
|
||||
}
|
||||
|
||||
if options.Expose {
|
||||
cmd.Env = append(cmd.Env, "OLLAMA_HOST=0.0.0.0")
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// TODO - rotation
|
||||
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create server log: %w", err)
|
||||
}
|
||||
go func() {
|
||||
defer logFile.Close()
|
||||
io.Copy(logFile, stdout) //nolint:errcheck
|
||||
}()
|
||||
go func() {
|
||||
defer logFile.Close()
|
||||
io.Copy(logFile, stderr) //nolint:errcheck
|
||||
}()
|
||||
|
||||
// Re-wire context done behavior to attempt a graceful shutdown of the server
|
||||
cmd.Cancel = func() error {
|
||||
if cmd.Process != nil {
|
||||
err := terminate(cmd)
|
||||
if err != nil {
|
||||
slog.Warn("error trying to gracefully terminate server", "err", err)
|
||||
return cmd.Process.Kill()
|
||||
}
|
||||
|
||||
tick := time.NewTicker(10 * time.Millisecond)
|
||||
defer tick.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
exited, err := isProcessExited(cmd.Process.Pid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exited {
|
||||
return nil
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
|
||||
return cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// run the command and wait for it to finish
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start server %w", err)
|
||||
}
|
||||
if cmd.Process != nil {
|
||||
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
|
||||
}
|
||||
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func SpawnServer(ctx context.Context, command string, options ServerOptions) (chan int, error) {
|
||||
logDir := filepath.Dir(ServerLogFile)
|
||||
_, err := os.Stat(logDir)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
done := make(chan int)
|
||||
|
||||
go func() {
|
||||
// Keep the server running unless we're shuttind down the app
|
||||
crashCount := 0
|
||||
for {
|
||||
slog.Info(fmt.Sprintf("starting server..."))
|
||||
cmd, err := start(ctx, command, options)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to start server %s", err))
|
||||
}
|
||||
|
||||
cmd.Wait() //nolint:errcheck
|
||||
var code int
|
||||
if cmd.ProcessState != nil {
|
||||
code = cmd.ProcessState.ExitCode()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info(fmt.Sprintf("server shutdown with exit code %d", code))
|
||||
done <- code
|
||||
return
|
||||
default:
|
||||
crashCount++
|
||||
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
|
||||
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return done, nil
|
||||
}
|
||||
|
||||
func isServerRunning(ctx context.Context) bool {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
slog.Info("unable to connect to server")
|
||||
return false
|
||||
}
|
||||
err = client.Heartbeat(ctx)
|
||||
if err != nil {
|
||||
slog.Debug(fmt.Sprintf("heartbeat from server: %s", err))
|
||||
slog.Info("unable to connect to server")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
36
app/server_darwin.go
Normal file
36
app/server_darwin.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func getCmd(ctx context.Context, cmd string) *exec.Cmd {
|
||||
return exec.CommandContext(ctx, cmd, "serve")
|
||||
}
|
||||
|
||||
func terminate(cmd *exec.Cmd) error {
|
||||
return cmd.Process.Signal(os.Interrupt)
|
||||
}
|
||||
|
||||
func isProcessExited(pid int) (bool, error) {
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to find process: %v", err)
|
||||
}
|
||||
|
||||
err = proc.Signal(syscall.Signal(0))
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrProcessDone) || errors.Is(err, syscall.ESRCH) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("error signaling process: %v", err)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
89
app/server_windows.go
Normal file
89
app/server_windows.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func getCmd(ctx context.Context, exePath string) *exec.Cmd {
|
||||
cmd := exec.CommandContext(ctx, exePath, "serve")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
HideWindow: true,
|
||||
CreationFlags: windows.CREATE_NEW_PROCESS_GROUP,
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func terminate(cmd *exec.Cmd) error {
|
||||
dll, err := windows.LoadDLL("kernel32.dll")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dll.Release() // nolint: errcheck
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
|
||||
f, err := dll.FindProc("AttachConsole")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err := f.Call(uintptr(pid))
|
||||
if r1 == 0 && err != syscall.ERROR_ACCESS_DENIED {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err = dll.FindProc("SetConsoleCtrlHandler")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err = f.Call(0, 1)
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err = dll.FindProc("GenerateConsoleCtrlEvent")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err = f.Call(windows.CTRL_BREAK_EVENT, uintptr(pid))
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err = f.Call(windows.CTRL_C_EVENT, uintptr(pid))
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const STILL_ACTIVE = 259
|
||||
|
||||
func isProcessExited(pid int) (bool, error) {
|
||||
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to open process: %v", err)
|
||||
}
|
||||
defer windows.CloseHandle(hProcess) // nolint: errcheck
|
||||
|
||||
var exitCode uint32
|
||||
err = windows.GetExitCodeProcess(hProcess, &exitCode)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get exit code: %v", err)
|
||||
}
|
||||
|
||||
if exitCode == STILL_ACTIVE {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@@ -24,10 +24,5 @@ func NewTray() (commontray.OllamaTray, error) {
|
||||
return nil, fmt.Errorf("failed to load icon %s: %w", iconName, err)
|
||||
}
|
||||
|
||||
tray, err := InitPlatformTray(icon, updateIcon)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tray, nil
|
||||
return InitPlatformTray(icon, updateIcon)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build !windows
|
||||
|
||||
package tray
|
||||
|
||||
import (
|
||||
@@ -1,4 +1,4 @@
|
||||
package lifecycle
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -22,6 +22,10 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
var (
|
||||
UpdateStageDir string
|
||||
)
|
||||
|
||||
var (
|
||||
UpdateCheckURLBase = "https://ollama.com/api/update"
|
||||
UpdateDownloaded = false
|
||||
@@ -123,7 +127,7 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
||||
slog.Debug("no etag detected, falling back to filename based dedup")
|
||||
etag = "_"
|
||||
}
|
||||
filename := Installer
|
||||
filename := "OllamaSetup.exe"
|
||||
_, params, err := mime.ParseMediaType(resp.Header.Get("content-disposition"))
|
||||
if err == nil {
|
||||
filename = params["filename"]
|
||||
@@ -1,6 +1,4 @@
|
||||
//go:build !windows
|
||||
|
||||
package lifecycle
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package lifecycle
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,7 +9,13 @@ import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func init() {
|
||||
UpdateStageDir = filepath.Join(os.Getenv("LOCALAPPDATA"), "Ollama", "updates")
|
||||
}
|
||||
|
||||
func DoUpgrade(cancel context.CancelFunc, done chan int) error {
|
||||
logFile := filepath.Join(os.Getenv("LOCALAPPDATA"), "Ollama", "upgrade.log")
|
||||
|
||||
files, err := filepath.Glob(filepath.Join(UpdateStageDir, "*", "*.exe")) // TODO generalize for multiplatform
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lookup downloads: %s", err)
|
||||
@@ -23,13 +29,13 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
|
||||
installerExe := files[0]
|
||||
|
||||
slog.Info("starting upgrade with " + installerExe)
|
||||
slog.Info("upgrade log file " + UpgradeLogFile)
|
||||
slog.Info("upgrade log file " + logFile)
|
||||
|
||||
// When running in debug mode, we'll be "verbose" and let the installer pop up and prompt
|
||||
installArgs := []string{
|
||||
"/CLOSEAPPLICATIONS", // Quit the tray app if it's still running
|
||||
"/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd
|
||||
"/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed
|
||||
"/CLOSEAPPLICATIONS", // Quit the tray app if it's still running
|
||||
"/LOG=" + filepath.Base(logFile), // Only relative seems reliable, so set pwd
|
||||
"/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed
|
||||
}
|
||||
// When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts)
|
||||
// TODO - temporarily disable since we're pinning in debug mode for the preview
|
||||
@@ -53,7 +59,7 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("starting installer: %s %v", installerExe, installArgs))
|
||||
os.Chdir(filepath.Dir(UpgradeLogFile)) //nolint:errcheck
|
||||
os.Chdir(filepath.Dir(logFile)) //nolint:errcheck
|
||||
cmd := exec.Command(installerExe, installArgs...)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
@@ -88,15 +88,12 @@ DialogFontSize=12
|
||||
[Files]
|
||||
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
|
||||
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windeps\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-amd64\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-amd64\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
|
||||
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
|
||||
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
|
||||
; Assumes v5.7, may need adjustments for v6
|
||||
#if GetEnv("HIP_PATH") != ""
|
||||
Source: "{#GetEnv('HIP_PATH')}\bin\hipblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
|
||||
Source: "{#GetEnv('HIP_PATH')}\bin\rocblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
|
||||
; amdhip64.dll dependency comes from the driver and must be installed already
|
||||
Source: "{#GetEnv('HIP_PATH')}\bin\rocblas\library\*"; DestDir: "{app}\rocm\rocblas\library\"; Flags: ignoreversion
|
||||
#if DirExists("..\dist\windows-amd64\rocm")
|
||||
Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
|
||||
#endif
|
||||
|
||||
|
||||
@@ -132,7 +129,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi
|
||||
|
||||
|
||||
;FinishedHeadingLabel=Run your first model
|
||||
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama2
|
||||
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3
|
||||
;ClickFinish=%n
|
||||
|
||||
[Registry]
|
||||
196
cmd/cmd.go
196
cmd/cmd.go
@@ -17,6 +17,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
@@ -53,8 +54,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
|
||||
modelfile, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -95,71 +94,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO make this work w/ adapters
|
||||
if fi.IsDir() {
|
||||
tf, err := os.CreateTemp("", "ollama-tf")
|
||||
// this is likely a safetensors or pytorch directory
|
||||
// TODO make this work w/ adapters
|
||||
tempfile, err := tempZipFiles(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.RemoveAll(tf.Name())
|
||||
defer os.RemoveAll(tempfile)
|
||||
|
||||
zf := zip.NewWriter(tf)
|
||||
|
||||
files, err := filepath.Glob(filepath.Join(path, "model-*.safetensors"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
return fmt.Errorf("no safetensors files were found in '%s'", path)
|
||||
}
|
||||
|
||||
// add the safetensor config file + tokenizer
|
||||
files = append(files, filepath.Join(path, "config.json"))
|
||||
files = append(files, filepath.Join(path, "added_tokens.json"))
|
||||
files = append(files, filepath.Join(path, "tokenizer.model"))
|
||||
|
||||
for _, fn := range files {
|
||||
f, err := os.Open(fn)
|
||||
if os.IsNotExist(err) && strings.HasSuffix(fn, "added_tokens.json") {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h, err := zip.FileInfoHeader(fi)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.Name = filepath.Base(fn)
|
||||
h.Method = zip.Store
|
||||
|
||||
w, err := zf.CreateHeader(h)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(w, f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if err := zf.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tf.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
path = tf.Name()
|
||||
path = tempfile
|
||||
}
|
||||
|
||||
digest, err := createBlob(cmd, client, path)
|
||||
@@ -167,10 +111,17 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
|
||||
name := c.Name
|
||||
if c.Name == "model" {
|
||||
name = "from"
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
|
||||
modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
|
||||
}
|
||||
}
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
spinner.Stop()
|
||||
@@ -194,7 +145,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)}
|
||||
quantization, _ := cmd.Flags().GetString("quantization")
|
||||
|
||||
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
|
||||
if err := client.Create(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -202,6 +155,114 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func tempZipFiles(path string) (string, error) {
|
||||
tempfile, err := os.CreateTemp("", "ollama-tf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer tempfile.Close()
|
||||
|
||||
zipfile := zip.NewWriter(tempfile)
|
||||
defer zipfile.Close()
|
||||
|
||||
detectContentType := func(path string) (string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var b bytes.Buffer
|
||||
b.Grow(512)
|
||||
|
||||
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
|
||||
return contentType, nil
|
||||
}
|
||||
|
||||
glob := func(pattern, contentType string) ([]string, error) {
|
||||
matches, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, safetensor := range matches {
|
||||
if ct, err := detectContentType(safetensor); err != nil {
|
||||
return nil, err
|
||||
} else if ct != contentType {
|
||||
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
|
||||
}
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
var files []string
|
||||
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
||||
// safetensors files might be unresolved git lfs references; skip if they are
|
||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||
files = append(files, st...)
|
||||
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
||||
files = append(files, pt...)
|
||||
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
|
||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||
// covers consolidated.x.pth, consolidated.pth
|
||||
files = append(files, pt...)
|
||||
} else {
|
||||
return "", errors.New("no safetensors or torch files found")
|
||||
}
|
||||
|
||||
// add configuration files, json files are detected as text/plain
|
||||
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
files = append(files, js...)
|
||||
|
||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||
files = append(files, tks...)
|
||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||
files = append(files, tks...)
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
zfi, err := zip.FileInfoHeader(fi)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
zf, err := zipfile.CreateHeader(zfi)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(zf, f); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return tempfile.Name(), nil
|
||||
}
|
||||
|
||||
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
||||
bin, err := os.Open(path)
|
||||
if err != nil {
|
||||
@@ -213,7 +274,10 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
||||
if _, err := io.Copy(hash, bin); err != nil {
|
||||
return "", err
|
||||
}
|
||||
bin.Seek(0, io.SeekStart)
|
||||
|
||||
if _, err := bin.Seek(0, io.SeekStart); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
||||
@@ -932,6 +996,7 @@ func NewCLI() *cobra.Command {
|
||||
}
|
||||
|
||||
createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
|
||||
createCmd.Flags().StringP("quantization", "q", "", "Quantization level.")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show MODEL",
|
||||
@@ -973,6 +1038,7 @@ Environment Variables:
|
||||
OLLAMA_ORIGINS A comma separated list of allowed origins.
|
||||
OLLAMA_MODELS The path to the models directory (default is "~/.ollama/models")
|
||||
OLLAMA_KEEP_ALIVE The duration that models stay loaded in memory (default is "5m")
|
||||
OLLAMA_DEBUG Set to 1 to enable additional debug logging
|
||||
`)
|
||||
|
||||
pullCmd := &cobra.Command{
|
||||
|
||||
@@ -295,10 +295,14 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
opts.WordWrap = false
|
||||
fmt.Println("Set 'nowordwrap' mode.")
|
||||
case "verbose":
|
||||
cmd.Flags().Set("verbose", "true")
|
||||
if err := cmd.Flags().Set("verbose", "true"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'verbose' mode.")
|
||||
case "quiet":
|
||||
cmd.Flags().Set("verbose", "false")
|
||||
if err := cmd.Flags().Set("verbose", "false"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/ollama/ollama/convert/sentencepiece"
|
||||
@@ -21,146 +18,71 @@ import (
|
||||
)
|
||||
|
||||
type Params struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize int `json:"vocab_size"`
|
||||
HiddenSize int `json:"hidden_size"` // n_embd
|
||||
HiddenLayers int `json:"num_hidden_layers"` // n_layer
|
||||
ContextSize int `json:"max_position_embeddings"`
|
||||
IntermediateSize int `json:"intermediate_size"`
|
||||
AttentionHeads int `json:"num_attention_heads"` // n_head
|
||||
KeyValHeads int `json:"num_key_value_heads"`
|
||||
NormEPS float64 `json:"rms_norm_eps"`
|
||||
RopeFreqBase float64 `json:"rope_theta"`
|
||||
BoSTokenID int `json:"bos_token_id"`
|
||||
EoSTokenID int `json:"eos_token_id"`
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize int `json:"vocab_size"`
|
||||
HiddenSize int `json:"hidden_size"` // n_embd
|
||||
HiddenLayers int `json:"num_hidden_layers"` // n_layer
|
||||
ContextSize int `json:"max_position_embeddings"`
|
||||
IntermediateSize int `json:"intermediate_size"`
|
||||
AttentionHeads int `json:"num_attention_heads"` // n_head
|
||||
KeyValHeads int `json:"num_key_value_heads"`
|
||||
NormEPS float64 `json:"rms_norm_eps"`
|
||||
BoSTokenID int `json:"bos_token_id"`
|
||||
EoSTokenID int `json:"eos_token_id"`
|
||||
HeadDimension int `json:"head_dim"`
|
||||
PaddingTokenID int `json:"pad_token_id"`
|
||||
RopeFrequencyBase float64 `json:"rope_theta"`
|
||||
|
||||
Experts int `json:"num_local_experts"`
|
||||
ExpertsUsed int `json:"num_experts_per_tok"`
|
||||
|
||||
ByteOrder
|
||||
}
|
||||
|
||||
type MetaData struct {
|
||||
Type string `mapstructure:"dtype"`
|
||||
Shape []int `mapstructure:"shape"`
|
||||
Offsets []int `mapstructure:"data_offsets"`
|
||||
type ByteOrder interface {
|
||||
binary.ByteOrder
|
||||
binary.AppendByteOrder
|
||||
}
|
||||
|
||||
func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var jsonSize uint64
|
||||
binary.Read(f, binary.LittleEndian, &jsonSize)
|
||||
|
||||
buf := make([]byte, jsonSize)
|
||||
_, err = io.ReadFull(f, buf)
|
||||
if err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
}
|
||||
|
||||
d := json.NewDecoder(bytes.NewBuffer(buf))
|
||||
d.UseNumber()
|
||||
var parsed map[string]interface{}
|
||||
if err = d.Decode(&parsed); err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
}
|
||||
|
||||
var keys []string
|
||||
for k := range parsed {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
slices.Sort(keys)
|
||||
|
||||
slog.Info("converting layers")
|
||||
|
||||
var tensors []llm.Tensor
|
||||
for _, k := range keys {
|
||||
vals := parsed[k].(map[string]interface{})
|
||||
var data MetaData
|
||||
if err = mapstructure.Decode(vals, &data); err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
}
|
||||
|
||||
var size uint64
|
||||
var kind uint32
|
||||
switch len(data.Shape) {
|
||||
case 0:
|
||||
// metadata
|
||||
continue
|
||||
case 1:
|
||||
// convert to float32
|
||||
kind = 0
|
||||
size = uint64(data.Shape[0] * 4)
|
||||
case 2:
|
||||
// convert to float16
|
||||
kind = 1
|
||||
size = uint64(data.Shape[0] * data.Shape[1] * 2)
|
||||
}
|
||||
|
||||
ggufName, err := GetTensorName(k)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return []llm.Tensor{}, 0, err
|
||||
}
|
||||
|
||||
shape := []uint64{0, 0, 0, 0}
|
||||
for i := range data.Shape {
|
||||
shape[i] = uint64(data.Shape[i])
|
||||
}
|
||||
|
||||
t := llm.Tensor{
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset,
|
||||
Shape: shape[:],
|
||||
FileName: fn,
|
||||
OffsetPadding: 8 + jsonSize,
|
||||
FileOffsets: []uint64{uint64(data.Offsets[0]), uint64(data.Offsets[1])},
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("%v", t))
|
||||
tensors = append(tensors, t)
|
||||
offset += size
|
||||
}
|
||||
return tensors, offset, nil
|
||||
type ModelArch interface {
|
||||
GetTensors() error
|
||||
LoadVocab() error
|
||||
WriteGGUF() (string, error)
|
||||
}
|
||||
|
||||
func GetSafeTensors(dirpath string) ([]llm.Tensor, error) {
|
||||
var tensors []llm.Tensor
|
||||
files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
|
||||
if err != nil {
|
||||
return []llm.Tensor{}, err
|
||||
}
|
||||
|
||||
var offset uint64
|
||||
for _, f := range files {
|
||||
var t []llm.Tensor
|
||||
var err error
|
||||
t, offset, err = ReadSafeTensors(f, offset)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return []llm.Tensor{}, err
|
||||
}
|
||||
tensors = append(tensors, t...)
|
||||
}
|
||||
return tensors, nil
|
||||
type ModelFormat interface {
|
||||
GetLayerName(string) (string, error)
|
||||
GetTensors(string, *Params) ([]llm.Tensor, error)
|
||||
GetParams(string) (*Params, error)
|
||||
GetModelArch(string, string, *Params) (ModelArch, error)
|
||||
}
|
||||
|
||||
func GetParams(dirpath string) (*Params, error) {
|
||||
f, err := os.Open(filepath.Join(dirpath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
type ModelData struct {
|
||||
Path string
|
||||
Name string
|
||||
Params *Params
|
||||
Vocab *Vocab
|
||||
Tensors []llm.Tensor
|
||||
Format ModelFormat
|
||||
}
|
||||
|
||||
var params Params
|
||||
|
||||
d := json.NewDecoder(f)
|
||||
err = d.Decode(¶ms)
|
||||
func GetModelFormat(dirname string) (ModelFormat, error) {
|
||||
files, err := filepath.Glob(filepath.Join(dirname, "*"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ¶ms, nil
|
||||
for _, fn := range files {
|
||||
slog.Debug(fmt.Sprintf("file = %s", fn))
|
||||
if strings.HasSuffix(fn, ".safetensors") {
|
||||
return &SafetensorFormat{}, nil
|
||||
} else if strings.HasSuffix(fn, ".bin") {
|
||||
slog.Debug("model is torch")
|
||||
return &TorchFormat{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("couldn't determine model format")
|
||||
}
|
||||
|
||||
// Details on gguf's tokenizer can be found at:
|
||||
@@ -171,7 +93,7 @@ type Vocab struct {
|
||||
Types []int32
|
||||
}
|
||||
|
||||
func LoadTokens(dirpath string) (*Vocab, error) {
|
||||
func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
|
||||
slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
|
||||
in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
|
||||
if err != nil {
|
||||
@@ -196,6 +118,14 @@ func LoadTokens(dirpath string) (*Vocab, error) {
|
||||
v.Tokens = append(v.Tokens, p.GetPiece())
|
||||
v.Scores = append(v.Scores, p.GetScore())
|
||||
t := p.GetType()
|
||||
switch t {
|
||||
case sentencepiece.ModelProto_SentencePiece_UNKNOWN:
|
||||
case sentencepiece.ModelProto_SentencePiece_CONTROL:
|
||||
case sentencepiece.ModelProto_SentencePiece_UNUSED:
|
||||
case sentencepiece.ModelProto_SentencePiece_BYTE:
|
||||
default:
|
||||
t = sentencepiece.ModelProto_SentencePiece_NORMAL
|
||||
}
|
||||
v.Types = append(v.Types, int32(t))
|
||||
}
|
||||
|
||||
@@ -243,89 +173,15 @@ func LoadTokens(dirpath string) (*Vocab, error) {
|
||||
}
|
||||
slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func GetTensorName(n string) (string, error) {
|
||||
tMap := map[string]string{
|
||||
"model.embed_tokens.weight": "token_embd.weight",
|
||||
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
|
||||
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
|
||||
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
|
||||
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
|
||||
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
|
||||
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
|
||||
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
|
||||
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
|
||||
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
"model.norm.weight": "output_norm.weight",
|
||||
}
|
||||
|
||||
v, ok := tMap[n]
|
||||
if ok {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// quick hack to rename the layers to gguf format
|
||||
for k, v := range tMap {
|
||||
re := regexp.MustCompile(k)
|
||||
newName := re.ReplaceAllString(n, v)
|
||||
if newName != n {
|
||||
return newName, nil
|
||||
if params.VocabSize > len(v.Tokens) {
|
||||
missingTokens := params.VocabSize - len(v.Tokens)
|
||||
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
|
||||
for cnt := 0; cnt < missingTokens; cnt++ {
|
||||
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
|
||||
v.Scores = append(v.Scores, -1)
|
||||
v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
|
||||
}
|
||||
|
||||
func WriteGGUF(name string, tensors []llm.Tensor, params *Params, vocab *Vocab) (string, error) {
|
||||
c := llm.ContainerGGUF{
|
||||
ByteOrder: binary.LittleEndian,
|
||||
}
|
||||
|
||||
m := llm.NewGGUFModel(&c)
|
||||
m.Tensors = tensors
|
||||
m.KV["general.architecture"] = "llama"
|
||||
m.KV["general.name"] = name
|
||||
m.KV["llama.context_length"] = uint32(params.ContextSize)
|
||||
m.KV["llama.embedding_length"] = uint32(params.HiddenSize)
|
||||
m.KV["llama.block_count"] = uint32(params.HiddenLayers)
|
||||
m.KV["llama.feed_forward_length"] = uint32(params.IntermediateSize)
|
||||
m.KV["llama.rope.dimension_count"] = uint32(128)
|
||||
m.KV["llama.attention.head_count"] = uint32(params.AttentionHeads)
|
||||
m.KV["llama.attention.head_count_kv"] = uint32(params.KeyValHeads)
|
||||
m.KV["llama.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
|
||||
m.KV["llama.rope.freq_base"] = float32(params.RopeFreqBase)
|
||||
m.KV["general.file_type"] = uint32(1)
|
||||
m.KV["tokenizer.ggml.model"] = "llama"
|
||||
|
||||
m.KV["tokenizer.ggml.tokens"] = vocab.Tokens
|
||||
m.KV["tokenizer.ggml.scores"] = vocab.Scores
|
||||
m.KV["tokenizer.ggml.token_type"] = vocab.Types
|
||||
|
||||
m.KV["tokenizer.ggml.bos_token_id"] = uint32(params.BoSTokenID)
|
||||
m.KV["tokenizer.ggml.eos_token_id"] = uint32(params.EoSTokenID)
|
||||
m.KV["tokenizer.ggml.unknown_token_id"] = uint32(0)
|
||||
m.KV["tokenizer.ggml.add_bos_token"] = true
|
||||
m.KV["tokenizer.ggml.add_eos_token"] = false
|
||||
|
||||
// llamacpp sets the chat template, however we don't need to set it since we pass it in through a layer
|
||||
// m.KV["tokenizer.chat_template"] = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" // XXX removeme
|
||||
|
||||
c.V3.NumTensor = uint64(len(tensors))
|
||||
c.V3.NumKV = uint64(len(m.KV))
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
err = m.Encode(f)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
return v, nil
|
||||
}
|
||||
|
||||
137
convert/gemma.go
Normal file
137
convert/gemma.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type GemmaModel struct {
|
||||
ModelData
|
||||
}
|
||||
|
||||
func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
|
||||
slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
|
||||
|
||||
data := make([]byte, r.end-r.start)
|
||||
if err := binary.Read(f, r.bo, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tDataF32 := bfloat16.DecodeFloat32(data)
|
||||
|
||||
var err error
|
||||
tDataF32, err = addOnes(tDataF32, int(r.t.Shape[0]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, r.bo, tDataF32); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func addOnes(data []float32, vectorSize int) ([]float32, error) {
|
||||
n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
|
||||
ones := tensor.Ones(tensor.Float32, vectorSize)
|
||||
|
||||
var err error
|
||||
n, err = n.Add(ones)
|
||||
if err != nil {
|
||||
return []float32{}, err
|
||||
}
|
||||
|
||||
newN, err := native.SelectF32(n, 0)
|
||||
if err != nil {
|
||||
return []float32{}, err
|
||||
}
|
||||
|
||||
var fullTensor []float32
|
||||
for _, v := range newN {
|
||||
fullTensor = append(fullTensor, v...)
|
||||
}
|
||||
|
||||
return fullTensor, nil
|
||||
}
|
||||
|
||||
func (m *GemmaModel) GetTensors() error {
|
||||
t, err := m.Format.GetTensors(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("Total tensors: %d", len(t)))
|
||||
|
||||
m.Tensors = []llm.Tensor{}
|
||||
for _, l := range t {
|
||||
if strings.HasSuffix(l.Name, "norm.weight") {
|
||||
wt := l.WriterTo.(safetensorWriterTo)
|
||||
wt.handler = gemmaLayerHandler
|
||||
l.WriterTo = wt
|
||||
}
|
||||
m.Tensors = append(m.Tensors, l)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *GemmaModel) LoadVocab() error {
|
||||
v, err := LoadSentencePieceTokens(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Vocab = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *GemmaModel) WriteGGUF() (string, error) {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "gemma",
|
||||
"general.name": m.Name,
|
||||
"gemma.context_length": uint32(m.Params.ContextSize),
|
||||
"gemma.embedding_length": uint32(m.Params.HiddenSize),
|
||||
"gemma.block_count": uint32(m.Params.HiddenLayers),
|
||||
"gemma.feed_forward_length": uint32(m.Params.IntermediateSize),
|
||||
"gemma.attention.head_count": uint32(m.Params.AttentionHeads),
|
||||
"gemma.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
||||
"gemma.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
||||
"gemma.attention.key_length": uint32(m.Params.HeadDimension),
|
||||
"gemma.attention.value_length": uint32(m.Params.HeadDimension),
|
||||
"general.file_type": uint32(1),
|
||||
"tokenizer.ggml.model": "llama",
|
||||
|
||||
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||||
"tokenizer.ggml.scores": m.Vocab.Scores,
|
||||
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||||
|
||||
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
||||
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
||||
"tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
|
||||
"tokenizer.ggml.unknown_token_id": uint32(3),
|
||||
"tokenizer.ggml.add_bos_token": true,
|
||||
"tokenizer.ggml.add_eos_token": false,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
176
convert/llama.go
Normal file
176
convert/llama.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/nlpodyssey/gopickle/pytorch"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"github.com/x448/float16"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type LlamaModel struct {
|
||||
ModelData
|
||||
}
|
||||
|
||||
func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
|
||||
slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
|
||||
|
||||
data := r.storage.(*pytorch.HalfStorage).Data
|
||||
tData := make([]uint16, len(data))
|
||||
for cnt, v := range data {
|
||||
tData[cnt] = uint16(float16.Fromfloat32(v))
|
||||
}
|
||||
|
||||
var err error
|
||||
var heads uint32
|
||||
if strings.Contains(r.t.Name, "attn_q") {
|
||||
heads = uint32(r.params.AttentionHeads)
|
||||
} else if strings.Contains(r.t.Name, "attn_k") {
|
||||
heads = uint32(r.params.KeyValHeads)
|
||||
if heads == 0 {
|
||||
heads = uint32(r.params.AttentionHeads)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unknown layer type")
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("heads = %d", heads))
|
||||
|
||||
tData, err = llamaRepack(tData, int(heads), r.t.Shape)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = binary.Write(w, r.bo, tData); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func llamaRepack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
|
||||
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
|
||||
origShape := n.Shape().Clone()
|
||||
|
||||
// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
|
||||
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Reshape(origShape...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newN, err := native.SelectU16(n, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fullTensor []uint16
|
||||
for _, v := range newN {
|
||||
fullTensor = append(fullTensor, v...)
|
||||
}
|
||||
return fullTensor, nil
|
||||
}
|
||||
|
||||
func (m *LlamaModel) GetTensors() error {
|
||||
t, err := m.Format.GetTensors(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Tensors = []llm.Tensor{}
|
||||
|
||||
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, l := range t {
|
||||
matches := re.FindAllStringSubmatch(l.Name, -1)
|
||||
if len(matches) > 0 {
|
||||
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
|
||||
wt := l.WriterTo.(torchWriterTo)
|
||||
wt.handler = llamaLayerHandler
|
||||
l.WriterTo = wt
|
||||
}
|
||||
m.Tensors = append(m.Tensors, l)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *LlamaModel) LoadVocab() error {
|
||||
var v *Vocab
|
||||
var err error
|
||||
|
||||
slog.Debug("loading vocab")
|
||||
v, err = LoadSentencePieceTokens(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Debug("vocab loaded")
|
||||
|
||||
m.Vocab = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *LlamaModel) WriteGGUF() (string, error) {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"general.name": m.Name,
|
||||
"llama.vocab_size": uint32(len(m.Vocab.Tokens)),
|
||||
"llama.context_length": uint32(m.Params.ContextSize),
|
||||
"llama.embedding_length": uint32(m.Params.HiddenSize),
|
||||
"llama.block_count": uint32(m.Params.HiddenLayers),
|
||||
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
|
||||
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
|
||||
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
|
||||
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
||||
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
||||
"general.file_type": uint32(1),
|
||||
"tokenizer.ggml.model": "llama",
|
||||
|
||||
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||||
"tokenizer.ggml.scores": m.Vocab.Scores,
|
||||
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||||
|
||||
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
||||
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
||||
"tokenizer.ggml.unknown_token_id": uint32(0),
|
||||
"tokenizer.ggml.add_bos_token": true,
|
||||
"tokenizer.ggml.add_eos_token": false,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
173
convert/mistral.go
Normal file
173
convert/mistral.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"github.com/x448/float16"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type MistralModel struct {
|
||||
ModelData
|
||||
}
|
||||
|
||||
func mistralLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
|
||||
layerSize := r.end - r.start
|
||||
|
||||
var err error
|
||||
tData := make([]uint16, layerSize/2)
|
||||
if err = binary.Read(f, r.bo, tData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var heads uint32
|
||||
if strings.Contains(r.t.Name, "attn_q") {
|
||||
heads = uint32(r.params.AttentionHeads)
|
||||
} else if strings.Contains(r.t.Name, "attn_k") {
|
||||
heads = uint32(r.params.KeyValHeads)
|
||||
if heads == 0 {
|
||||
heads = uint32(r.params.AttentionHeads)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unknown layer type")
|
||||
}
|
||||
|
||||
tData, err = repack(tData, int(heads), r.t.Shape)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
for _, n := range tData {
|
||||
buf = r.bo.AppendUint16(buf, n)
|
||||
}
|
||||
|
||||
tempBuf := make([]uint16, len(tData))
|
||||
tDataF32 := bfloat16.DecodeFloat32(buf)
|
||||
for cnt, v := range tDataF32 {
|
||||
tDataF16 := float16.Fromfloat32(v)
|
||||
tempBuf[cnt] = uint16(tDataF16)
|
||||
}
|
||||
|
||||
if err = binary.Write(w, r.bo, tempBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
|
||||
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
|
||||
origShape := n.Shape().Clone()
|
||||
|
||||
// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
|
||||
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Reshape(origShape...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newN, err := native.SelectU16(n, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fullTensor []uint16
|
||||
for _, v := range newN {
|
||||
fullTensor = append(fullTensor, v...)
|
||||
}
|
||||
return fullTensor, nil
|
||||
}
|
||||
|
||||
func (m *MistralModel) GetTensors() error {
|
||||
t, err := m.Format.GetTensors(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Tensors = []llm.Tensor{}
|
||||
|
||||
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, l := range t {
|
||||
matches := re.FindAllStringSubmatch(l.Name, -1)
|
||||
if len(matches) > 0 {
|
||||
wt := l.WriterTo.(safetensorWriterTo)
|
||||
wt.handler = mistralLayerHandler
|
||||
l.WriterTo = wt
|
||||
}
|
||||
m.Tensors = append(m.Tensors, l)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MistralModel) LoadVocab() error {
|
||||
v, err := LoadSentencePieceTokens(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Vocab = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MistralModel) WriteGGUF() (string, error) {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"general.name": m.Name,
|
||||
"llama.context_length": uint32(m.Params.ContextSize),
|
||||
"llama.embedding_length": uint32(m.Params.HiddenSize),
|
||||
"llama.block_count": uint32(m.Params.HiddenLayers),
|
||||
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
|
||||
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
|
||||
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
|
||||
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
||||
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
||||
"general.file_type": uint32(1),
|
||||
"tokenizer.ggml.model": "llama",
|
||||
|
||||
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||||
"tokenizer.ggml.scores": m.Vocab.Scores,
|
||||
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||||
|
||||
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
||||
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
||||
"tokenizer.ggml.add_bos_token": true,
|
||||
"tokenizer.ggml.add_eos_token": false,
|
||||
"tokenizer.ggml.unknown_token_id": uint32(0),
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
96
convert/mixtral.go
Normal file
96
convert/mixtral.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"os"
|
||||
"regexp"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type MixtralModel struct {
|
||||
ModelData
|
||||
}
|
||||
|
||||
func (m *MixtralModel) GetTensors() error {
|
||||
t, err := m.Format.GetTensors(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Tensors = []llm.Tensor{}
|
||||
|
||||
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, l := range t {
|
||||
matches := re.FindAllStringSubmatch(l.Name, -1)
|
||||
if len(matches) > 0 {
|
||||
wt := l.WriterTo.(safetensorWriterTo)
|
||||
wt.handler = mistralLayerHandler
|
||||
l.WriterTo = wt
|
||||
}
|
||||
m.Tensors = append(m.Tensors, l)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MixtralModel) LoadVocab() error {
|
||||
v, err := LoadSentencePieceTokens(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Vocab = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MixtralModel) WriteGGUF() (string, error) {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"general.name": m.Name,
|
||||
"llama.block_count": uint32(m.Params.HiddenLayers),
|
||||
"llama.context_length": uint32(m.Params.ContextSize),
|
||||
"llama.embedding_length": uint32(m.Params.HiddenSize),
|
||||
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
|
||||
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
|
||||
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
||||
|
||||
"llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
|
||||
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
||||
|
||||
"llama.expert_count": uint32(m.Params.Experts),
|
||||
"llama.expert_used_count": uint32(m.Params.ExpertsUsed),
|
||||
|
||||
"llama.vocab_size": uint32(len(m.Vocab.Tokens)),
|
||||
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
|
||||
|
||||
"general.file_type": uint32(1),
|
||||
"tokenizer.ggml.model": "llama",
|
||||
|
||||
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||||
"tokenizer.ggml.scores": m.Vocab.Scores,
|
||||
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||||
|
||||
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
||||
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
||||
"tokenizer.ggml.unknown_token_id": uint32(0),
|
||||
"tokenizer.ggml.add_bos_token": true,
|
||||
"tokenizer.ggml.add_eos_token": false,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
317
convert/safetensors.go
Normal file
317
convert/safetensors.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/x448/float16"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type safetensorWriterTo struct {
|
||||
t *llm.Tensor
|
||||
|
||||
params *Params
|
||||
bo ByteOrder
|
||||
|
||||
filename string
|
||||
|
||||
start, end, padding uint64
|
||||
handler func(w io.Writer, r safetensorWriterTo, f *os.File) error
|
||||
}
|
||||
|
||||
type tensorMetaData struct {
|
||||
Type string `mapstructure:"dtype"`
|
||||
Shape []int `mapstructure:"shape"`
|
||||
Offsets []int `mapstructure:"data_offsets"`
|
||||
}
|
||||
|
||||
type SafetensorFormat struct{}
|
||||
|
||||
func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
|
||||
slog.Debug("getting tensor data")
|
||||
var tensors []llm.Tensor
|
||||
files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var offset uint64
|
||||
for _, f := range files {
|
||||
var t []llm.Tensor
|
||||
var err error
|
||||
t, offset, err = m.readTensors(f, offset, params)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return nil, err
|
||||
}
|
||||
tensors = append(tensors, t...)
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("all tensors = %d", len(tensors)))
|
||||
return tensors, nil
|
||||
}
|
||||
|
||||
func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var jsonSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
buf := make([]byte, jsonSize)
|
||||
_, err = io.ReadFull(f, buf)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
d := json.NewDecoder(bytes.NewBuffer(buf))
|
||||
d.UseNumber()
|
||||
var parsed map[string]interface{}
|
||||
if err = d.Decode(&parsed); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var keys []string
|
||||
for k := range parsed {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
slices.Sort(keys)
|
||||
slog.Info("converting layers")
|
||||
|
||||
var tensors []llm.Tensor
|
||||
for _, k := range keys {
|
||||
vals := parsed[k].(map[string]interface{})
|
||||
var data tensorMetaData
|
||||
if err = mapstructure.Decode(vals, &data); err != nil {
|
||||
slog.Error("couldn't decode properly")
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var size uint64
|
||||
var kind uint32
|
||||
switch len(data.Shape) {
|
||||
case 0:
|
||||
// metadata
|
||||
continue
|
||||
case 1:
|
||||
// convert to float32
|
||||
kind = 0
|
||||
size = uint64(data.Shape[0] * 4)
|
||||
case 2:
|
||||
// convert to float16
|
||||
kind = 1
|
||||
size = uint64(data.Shape[0] * data.Shape[1] * 2)
|
||||
}
|
||||
|
||||
ggufName, err := m.GetLayerName(k)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
shape := []uint64{0, 0, 0, 0}
|
||||
for i := range data.Shape {
|
||||
shape[i] = uint64(data.Shape[i])
|
||||
}
|
||||
|
||||
t := llm.Tensor{
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset,
|
||||
Shape: shape[:],
|
||||
}
|
||||
|
||||
t.WriterTo = safetensorWriterTo{
|
||||
t: &t,
|
||||
params: params,
|
||||
bo: params.ByteOrder,
|
||||
filename: fn,
|
||||
start: uint64(data.Offsets[0]),
|
||||
end: uint64(data.Offsets[1]),
|
||||
padding: 8 + jsonSize,
|
||||
}
|
||||
|
||||
offset += size
|
||||
tensors = append(tensors, t)
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
|
||||
slog.Debug(fmt.Sprintf("offset = %d", offset))
|
||||
|
||||
return tensors, offset, nil
|
||||
}
|
||||
|
||||
func (m *SafetensorFormat) GetParams(dirpath string) (*Params, error) {
|
||||
f, err := os.Open(filepath.Join(dirpath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var params Params
|
||||
|
||||
d := json.NewDecoder(f)
|
||||
err = d.Decode(¶ms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.ByteOrder = binary.LittleEndian
|
||||
return ¶ms, nil
|
||||
}
|
||||
|
||||
func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
|
||||
directMap := map[string]string{
|
||||
"model.embed_tokens.weight": "token_embd.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
"model.norm.weight": "output_norm.weight",
|
||||
}
|
||||
|
||||
tMap := map[string]string{
|
||||
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
|
||||
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
|
||||
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
|
||||
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
|
||||
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
|
||||
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
|
||||
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
|
||||
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
|
||||
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
|
||||
"model.layers.(\\d+).block_sparse_moe.gate.weight": "blk.$1.ffn_gate_inp.weight",
|
||||
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
|
||||
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
|
||||
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
|
||||
}
|
||||
|
||||
v, ok := directMap[n]
|
||||
if ok {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// quick hack to rename the layers to gguf format
|
||||
for k, v := range tMap {
|
||||
re := regexp.MustCompile(k)
|
||||
newName := re.ReplaceAllString(n, v)
|
||||
if newName != n {
|
||||
return newName, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
|
||||
}
|
||||
|
||||
func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
|
||||
f, err := os.Open(r.filename)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// use the handler if one is present
|
||||
if r.handler != nil {
|
||||
return 0, r.handler(w, r, f)
|
||||
}
|
||||
|
||||
remaining := r.end - r.start
|
||||
|
||||
bufSize := uint64(10240)
|
||||
var finished bool
|
||||
for {
|
||||
data := make([]byte, min(bufSize, remaining))
|
||||
|
||||
b, err := io.ReadFull(f, data)
|
||||
remaining -= uint64(b)
|
||||
|
||||
if err == io.EOF || remaining <= 0 {
|
||||
finished = true
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// convert bfloat16 -> ieee float32
|
||||
tDataF32 := bfloat16.DecodeFloat32(data)
|
||||
|
||||
switch r.t.Kind {
|
||||
case 0:
|
||||
if err := binary.Write(w, r.bo, tDataF32); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case 1:
|
||||
// convert float32 -> float16
|
||||
tempBuf := make([]uint16, len(data)/2)
|
||||
for cnt, v := range tDataF32 {
|
||||
tDataF16 := float16.Fromfloat32(v)
|
||||
tempBuf[cnt] = uint16(tDataF16)
|
||||
}
|
||||
if err := binary.Write(w, r.bo, tempBuf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if finished {
|
||||
break
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
|
||||
switch len(params.Architectures) {
|
||||
case 0:
|
||||
return nil, fmt.Errorf("No architecture specified to convert")
|
||||
case 1:
|
||||
switch params.Architectures[0] {
|
||||
case "MistralForCausalLM":
|
||||
return &MistralModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
Format: m,
|
||||
},
|
||||
}, nil
|
||||
case "MixtralForCausalLM":
|
||||
return &MixtralModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
Format: m,
|
||||
},
|
||||
}, nil
|
||||
case "GemmaForCausalLM":
|
||||
return &GemmaModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
Format: m,
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown error")
|
||||
}
|
||||
286
convert/torch.go
Normal file
286
convert/torch.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/nlpodyssey/gopickle/pytorch"
|
||||
"github.com/nlpodyssey/gopickle/types"
|
||||
"github.com/x448/float16"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type torchWriterTo struct {
|
||||
t *llm.Tensor
|
||||
|
||||
params *Params
|
||||
bo ByteOrder
|
||||
|
||||
storage pytorch.StorageInterface
|
||||
handler func(w io.Writer, r torchWriterTo) error
|
||||
}
|
||||
|
||||
type TorchFormat struct{}
|
||||
|
||||
func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
|
||||
slog.Debug("getting torch tensors")
|
||||
|
||||
files, err := filepath.Glob(filepath.Join(dirpath, "pytorch_model-*.bin"))
|
||||
if err != nil {
|
||||
slog.Error("didn't find any torch files")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var offset uint64
|
||||
|
||||
var tensors []llm.Tensor
|
||||
for _, fn := range files {
|
||||
m, err := pytorch.Load(fn)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("error unpickling: %q", err))
|
||||
return []llm.Tensor{}, err
|
||||
}
|
||||
|
||||
for _, k := range m.(*types.Dict).Keys() {
|
||||
if strings.HasSuffix(k.(string), "self_attn.rotary_emb.inv_freq") {
|
||||
continue
|
||||
}
|
||||
|
||||
t, _ := m.(*types.Dict).Get(k)
|
||||
tshape := t.(*pytorch.Tensor).Size
|
||||
|
||||
var size uint64
|
||||
var kind uint32
|
||||
switch len(tshape) {
|
||||
case 0:
|
||||
continue
|
||||
case 1:
|
||||
// convert to float32
|
||||
kind = 0
|
||||
size = uint64(tshape[0] * 4)
|
||||
case 2:
|
||||
// convert to float16
|
||||
kind = 1
|
||||
size = uint64(tshape[0] * tshape[1] * 2)
|
||||
}
|
||||
|
||||
ggufName, err := tf.GetLayerName(k.(string))
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))
|
||||
|
||||
shape := []uint64{0, 0, 0, 0}
|
||||
for i := range tshape {
|
||||
shape[i] = uint64(tshape[i])
|
||||
}
|
||||
|
||||
tensor := llm.Tensor{
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset, // calculate the offset
|
||||
Shape: shape[:],
|
||||
}
|
||||
|
||||
tensor.WriterTo = torchWriterTo{
|
||||
t: &tensor,
|
||||
params: params,
|
||||
bo: params.ByteOrder,
|
||||
storage: t.(*pytorch.Tensor).Source,
|
||||
}
|
||||
|
||||
tensors = append(tensors, tensor)
|
||||
offset += size
|
||||
}
|
||||
}
|
||||
|
||||
return tensors, nil
|
||||
|
||||
}
|
||||
|
||||
func getAltParams(dirpath string) (*Params, error) {
|
||||
f, err := os.Open(filepath.Join(dirpath, "params.json"))
|
||||
if err != nil {
|
||||
slog.Error("no params.json")
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
type TorchParams struct {
|
||||
HiddenSize int `json:"dim"`
|
||||
AttentionHeads int `json:"n_heads"`
|
||||
KeyValHeads int `json:"n_kv_heads"`
|
||||
HiddenLayers int `json:"n_layers"`
|
||||
RopeTheta int `json:"rope_theta"`
|
||||
NormEPS float64 `json:"norm_eps"`
|
||||
}
|
||||
|
||||
var tparams TorchParams
|
||||
|
||||
d := json.NewDecoder(f)
|
||||
err = d.Decode(&tparams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := &Params{
|
||||
HiddenSize: tparams.HiddenSize,
|
||||
AttentionHeads: tparams.AttentionHeads,
|
||||
KeyValHeads: tparams.KeyValHeads,
|
||||
HiddenLayers: tparams.HiddenLayers,
|
||||
NormEPS: tparams.NormEPS,
|
||||
}
|
||||
|
||||
switch {
|
||||
case tparams.RopeTheta == 1000000:
|
||||
// Codellama
|
||||
params.ContextSize = 16384
|
||||
case tparams.NormEPS == 1e-06:
|
||||
// llama2
|
||||
slog.Debug("Found llama2 - setting context size to 4096")
|
||||
params.ContextSize = 4096
|
||||
default:
|
||||
params.ContextSize = 2048
|
||||
}
|
||||
|
||||
params.ByteOrder = binary.LittleEndian
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func (m *TorchFormat) GetParams(dirpath string) (*Params, error) {
|
||||
f, err := os.Open(filepath.Join(dirpath, "config.json"))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// try params.json instead
|
||||
return getAltParams(dirpath)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var params Params
|
||||
d := json.NewDecoder(f)
|
||||
err = d.Decode(¶ms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.ByteOrder = binary.LittleEndian
|
||||
return ¶ms, nil
|
||||
}
|
||||
|
||||
func (m *TorchFormat) GetLayerName(n string) (string, error) {
|
||||
directMap := map[string]string{
|
||||
"tok_embeddings.weight": "token_embd.weight",
|
||||
"output.weight": "output.weight",
|
||||
"norm.weight": "output_norm.weight",
|
||||
"rope.freqs": "rope_freqs.weight",
|
||||
"model.embed_tokens.weight": "token_embd.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
"model.norm.weight": "output_norm.weight",
|
||||
}
|
||||
|
||||
lMap := map[string]string{
|
||||
"layers.(\\d+).attention_norm.weight": "blk.$1.attn_norm.weight",
|
||||
"layers.(\\d+).attention_output_norm.weight": "blk.$1.attn_norm.weight",
|
||||
"layers.(\\d+).feed_forward.w2.weight": "blk.$1.ffn_down.weight",
|
||||
"layers.(\\d+).feed_forward.w1.weight": "blk.$1.ffn_gate.weight",
|
||||
"layers.(\\d+).feed_forward.w3.weight": "blk.$1.ffn_up.weight",
|
||||
"layers.(\\d+).ffn_norm.weight": "blk.$1.ffn_norm.weight",
|
||||
"layers.(\\d+).attention.wk.weight": "blk.$1.attn_k.weight",
|
||||
"layers.(\\d+).attention.wo.weight": "blk.$1.attn_output.weight",
|
||||
"layers.(\\d+).attention.wq.weight": "blk.$1.attn_q.weight",
|
||||
"layers.(\\d+).attention.wv.weight": "blk.$1.attn_v.weight",
|
||||
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
|
||||
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
|
||||
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
|
||||
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
|
||||
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
|
||||
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
|
||||
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
|
||||
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
|
||||
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
|
||||
}
|
||||
|
||||
v, ok := directMap[n]
|
||||
if ok {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// quick hack to rename the layers to gguf format
|
||||
for k, v := range lMap {
|
||||
re := regexp.MustCompile(k)
|
||||
newName := re.ReplaceAllString(n, v)
|
||||
if newName != n {
|
||||
return newName, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
|
||||
}
|
||||
|
||||
func (r torchWriterTo) WriteTo(w io.Writer) (n int64, err error) {
|
||||
// use the handler if one is present
|
||||
if r.handler != nil {
|
||||
return 0, r.handler(w, r)
|
||||
}
|
||||
|
||||
switch r.storage.(type) {
|
||||
case *pytorch.FloatStorage:
|
||||
slog.Warn(fmt.Sprintf("unexpected storage found for layer '%s'; skipping", r.t.Name))
|
||||
return 0, nil
|
||||
case *pytorch.HalfStorage:
|
||||
switch r.t.Kind {
|
||||
case 0:
|
||||
data := r.storage.(*pytorch.HalfStorage).Data
|
||||
slog.Debug(fmt.Sprintf("%35s F32 (%d)", r.t.Name, len(data)))
|
||||
if err := binary.Write(w, r.bo, data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case 1:
|
||||
data := r.storage.(*pytorch.HalfStorage).Data
|
||||
tData := make([]uint16, len(data))
|
||||
for cnt, v := range data {
|
||||
tData[cnt] = uint16(float16.Fromfloat32(v))
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("%35s F16 (%d)", r.t.Name, len(tData)))
|
||||
if err := binary.Write(w, r.bo, tData); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *TorchFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
|
||||
switch len(params.Architectures) {
|
||||
case 0:
|
||||
return nil, fmt.Errorf("No architecture specified to convert")
|
||||
case 1:
|
||||
switch params.Architectures[0] {
|
||||
case "LlamaForCausalLM":
|
||||
return &LlamaModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
Format: m,
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown error")
|
||||
}
|
||||
@@ -90,7 +90,7 @@ The final response in the stream also includes additional data about the generat
|
||||
- `load_duration`: time spent in nanoseconds loading the model
|
||||
- `prompt_eval_count`: number of tokens in the prompt
|
||||
- `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
|
||||
- `eval_count`: number of tokens the response
|
||||
- `eval_count`: number of tokens in the response
|
||||
- `eval_duration`: time in nanoseconds spent generating the response
|
||||
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
|
||||
- `response`: empty if the response was streamed, if not streamed, this will contain the full response
|
||||
@@ -394,7 +394,6 @@ Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Currently the only accepted value is `json`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
|
||||
|
||||
@@ -228,3 +228,7 @@ To unload the model and free up memory use:
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": 0}'
|
||||
```
|
||||
|
||||
Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable uses the same parameter types as the `keep_alive` parameter types mentioned above. Refer to section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.
|
||||
|
||||
If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.
|
||||
|
||||
@@ -139,9 +139,6 @@ PARAMETER <parameter> <parametervalue>
|
||||
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
|
||||
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| num_gqa | The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b | int | num_gqa 1 |
|
||||
| num_gpu | The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable. | int | num_gpu 50 |
|
||||
| num_thread | Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). | int | num_thread 8 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||
|
||||
@@ -76,3 +76,10 @@ install script which version to install.
|
||||
```sh
|
||||
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh
|
||||
```
|
||||
|
||||
## Linux tmp noexec
|
||||
|
||||
If your system is configured with the "noexec" flag where Ollama stores its
|
||||
temporary executable files, you can specify an alternate location by setting
|
||||
OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example
|
||||
OLLAMA_TMPDIR=/usr/share/ollama/
|
||||
@@ -18,7 +18,7 @@ const ollama = new Ollama({
|
||||
model: "llama2",
|
||||
});
|
||||
|
||||
const answer = await ollama.call(`why is the sky blue?`);
|
||||
const answer = await ollama.invoke(`why is the sky blue?`);
|
||||
|
||||
console.log(answer);
|
||||
```
|
||||
|
||||
@@ -1,38 +1,15 @@
|
||||
# Running Ollama on NVIDIA Jetson Devices
|
||||
|
||||
With some minor configuration, Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/). The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack).
|
||||
Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions.
|
||||
|
||||
NVIDIA Jetson devices are Linux-based embedded AI computers that are purpose-built for AI applications.
|
||||
|
||||
Jetsons have an integrated GPU that is wired directly to the memory controller of the machine. For this reason, the `nvidia-smi` command is unrecognized, and Ollama proceeds to operate in "CPU only"
|
||||
mode. This can be verified by using a monitoring tool like jtop.
|
||||
|
||||
In order to address this, we simply pass the path to the Jetson's pre-installed CUDA libraries into `ollama serve` (while in a tmux session). We then hardcode the num_gpu parameters into a cloned
|
||||
version of our target model.
|
||||
|
||||
Prerequisites:
|
||||
|
||||
- curl
|
||||
- tmux
|
||||
|
||||
Here are the steps:
|
||||
The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.
|
||||
|
||||
- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
|
||||
- Stop the Ollama service: `sudo systemctl stop ollama`
|
||||
- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson
|
||||
'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`
|
||||
- Pull the model you want to use (e.g. mistral): `ollama pull mistral`
|
||||
- Create a new Modelfile specifically for enabling GPU support on the Jetson: `touch ModelfileMistralJetson`
|
||||
- In the ModelfileMistralJetson file, specify the FROM model and the num_gpu PARAMETER as shown below:
|
||||
|
||||
```
|
||||
FROM mistral
|
||||
PARAMETER num_gpu 999
|
||||
```
|
||||
|
||||
- Create a new model from your Modelfile: `ollama create mistral-jetson -f ./ModelfileMistralJetson`
|
||||
- Run the new model: `ollama run mistral-jetson`
|
||||
|
||||
If you run a monitoring tool like jtop you should now see that Ollama is using the Jetson's integrated GPU.
|
||||
- Start an interactive session: `ollama run mistral`
|
||||
|
||||
And that's it!
|
||||
|
||||
# Running Ollama in Docker
|
||||
|
||||
When running GPU accelerated applications in Docker, it is highly recommended to use [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).
|
||||
@@ -14,7 +14,7 @@ As this is a preview release, you should expect a few bugs here and there. If
|
||||
you run into a problem you can reach out on
|
||||
[Discord](https://discord.gg/ollama), or file an
|
||||
[issue](https://github.com/ollama/ollama/issues).
|
||||
Logs will often be helpful in dianosing the problem (see
|
||||
Logs will often be helpful in diagnosing the problem (see
|
||||
[Troubleshooting](#troubleshooting) below)
|
||||
|
||||
## System Requirements
|
||||
|
||||
51
examples/go-chat/main.go
Normal file
51
examples/go-chat/main.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
messages := []api.Message{
|
||||
api.Message{
|
||||
Role: "system",
|
||||
Content: "Provide very brief, concise responses",
|
||||
},
|
||||
api.Message{
|
||||
Role: "user",
|
||||
Content: "Name some unusual animals",
|
||||
},
|
||||
api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Monotreme, platypus, echidna",
|
||||
},
|
||||
api.Message{
|
||||
Role: "user",
|
||||
Content: "which of these is the most dangerous?",
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := &api.ChatRequest{
|
||||
Model: "llama2",
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
respFunc := func(resp api.ChatResponse) error {
|
||||
fmt.Print(resp.Message.Content)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Chat(ctx, req, respFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
40
examples/go-generate-streaming/main.go
Normal file
40
examples/go-generate-streaming/main.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// By default, GenerateRequest is streaming.
|
||||
req := &api.GenerateRequest{
|
||||
Model: "gemma",
|
||||
Prompt: "how many planets are there?",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
respFunc := func(resp api.GenerateResponse) error {
|
||||
// Only print the response here; GenerateResponse has a number of other
|
||||
// interesting fields you want to examine.
|
||||
|
||||
// In streaming mode, responses are partial so we call fmt.Print (and not
|
||||
// Println) in order to avoid spurious newlines being introduced. The
|
||||
// model will insert its own newlines if it wants.
|
||||
fmt.Print(resp.Response)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Generate(ctx, req, respFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
37
examples/go-generate/main.go
Normal file
37
examples/go-generate/main.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: "gemma",
|
||||
Prompt: "how many planets are there?",
|
||||
|
||||
// set streaming to false
|
||||
Stream: new(bool),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
respFunc := func(resp api.GenerateResponse) error {
|
||||
// Only print the response here; GenerateResponse has a number of other
|
||||
// interesting fields you want to examine.
|
||||
fmt.Println(resp.Response)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Generate(ctx, req, respFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
47
examples/go-multimodal/main.go
Normal file
47
examples/go-multimodal/main.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) <= 1 {
|
||||
log.Fatal("usage: <image name>")
|
||||
}
|
||||
|
||||
imgData, err := os.ReadFile(os.Args[1])
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: "llava",
|
||||
Prompt: "describe this image",
|
||||
Images: []api.ImageData{imgData},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
respFunc := func(resp api.GenerateResponse) error {
|
||||
// In streaming mode, responses are partial so we call fmt.Print (and not
|
||||
// Println) in order to avoid spurious newlines being introduced. The
|
||||
// model will insert its own newlines if it wants.
|
||||
fmt.Print(resp.Response)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Generate(ctx, req, respFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
31
examples/go-pull-progress/main.go
Normal file
31
examples/go-pull-progress/main.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
req := &api.PullRequest{
|
||||
Model: "mistral",
|
||||
}
|
||||
progressFunc := func(resp api.ProgressResponse) error {
|
||||
fmt.Printf("Progress: status=%v, total=%v, completed=%v\n", resp.Status, resp.Total, resp.Completed)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Pull(ctx, req, progressFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -6,11 +6,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Byte = 1
|
||||
Byte = 1
|
||||
|
||||
KiloByte = Byte * 1000
|
||||
MegaByte = KiloByte * 1000
|
||||
GigaByte = MegaByte * 1000
|
||||
TeraByte = GigaByte * 1000
|
||||
|
||||
KibiByte = Byte * 1024
|
||||
MebiByte = KibiByte * 1024
|
||||
GibiByte = MebiByte * 1024
|
||||
)
|
||||
|
||||
func HumanBytes(b int64) string {
|
||||
@@ -45,3 +50,14 @@ func HumanBytes(b int64) string {
|
||||
return fmt.Sprintf("%d %s", int(value), unit)
|
||||
}
|
||||
}
|
||||
|
||||
func HumanBytes2(b uint64) string {
|
||||
switch {
|
||||
case b >= MebiByte:
|
||||
return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
|
||||
case b >= KibiByte:
|
||||
return fmt.Sprintf("%.1f KiB", float64(b)/KibiByte)
|
||||
default:
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
}
|
||||
|
||||
9
go.mod
9
go.mod
@@ -9,7 +9,7 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang/protobuf v1.5.0
|
||||
github.com/golang/protobuf v1.5.0 // indirect
|
||||
github.com/google/uuid v1.0.0
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
@@ -19,7 +19,10 @@ require (
|
||||
golang.org/x/sync v0.3.0
|
||||
)
|
||||
|
||||
require github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
|
||||
require (
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc // indirect
|
||||
@@ -68,7 +71,7 @@ require (
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0
|
||||
golang.org/x/term v0.13.0
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
6
go.sum
6
go.sum
@@ -122,6 +122,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
|
||||
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9 h1:DV4iXjNn6fGeDl1AkZ1I0QB/0DBjrc7kPpxHrmuDzW4=
|
||||
@@ -236,8 +238,8 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -35,22 +35,64 @@ func GetSupportedGFX(libDir string) ([]string, error) {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
|
||||
// Set the visible devices if not already set
|
||||
// TODO - does sort order matter?
|
||||
devices := []string{}
|
||||
for i := range ids {
|
||||
if _, skipped := skip[i]; skipped {
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "rocm" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
devices = append(devices, strconv.Itoa(i))
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
}
|
||||
|
||||
func commonAMDValidateLibDir() (string, error) {
|
||||
// We try to favor system paths first, so that we can wire up the subprocess to use
|
||||
// the system version. Only use our bundled version if the system version doesn't work
|
||||
// This gives users a more recovery options if versions have subtle problems at runtime
|
||||
|
||||
// Prefer explicit HIP env var
|
||||
hipPath := os.Getenv("HIP_PATH")
|
||||
if hipPath != "" {
|
||||
hipLibDir := filepath.Join(hipPath, "bin")
|
||||
if rocmLibUsable(hipLibDir) {
|
||||
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
||||
return hipLibDir, nil
|
||||
}
|
||||
}
|
||||
|
||||
val := strings.Join(devices, ",")
|
||||
err := os.Setenv("HIP_VISIBLE_DEVICES", val)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to set env: %s", err))
|
||||
} else {
|
||||
slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
|
||||
// Scan the LD_LIBRARY_PATH or PATH
|
||||
pathEnv := "LD_LIBRARY_PATH"
|
||||
if runtime.GOOS == "windows" {
|
||||
pathEnv = "PATH"
|
||||
}
|
||||
|
||||
paths := os.Getenv(pathEnv)
|
||||
for _, path := range filepath.SplitList(paths) {
|
||||
d, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if rocmLibUsable(d) {
|
||||
return d, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Well known location(s)
|
||||
if rocmLibUsable(RocmStandardLocation) {
|
||||
return RocmStandardLocation, nil
|
||||
}
|
||||
|
||||
// Installer payload location if we're running the installed binary
|
||||
exe, err := os.Executable()
|
||||
if err == nil {
|
||||
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
||||
if rocmLibUsable(rocmTargetDir) {
|
||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||
return rocmTargetDir, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
|
||||
func (hl *HipLib) Release() {
|
||||
err := windows.FreeLibrary(hl.dll)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
|
||||
slog.Warn("failed to unload amdhip64.dll", "error", err)
|
||||
}
|
||||
hl.dll = 0
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
|
||||
return 0
|
||||
}
|
||||
if status != hipSuccess {
|
||||
slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
|
||||
slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
466
gpu/amd_linux.go
466
gpu/amd_linux.go
@@ -11,6 +11,8 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
// Discovery logic for AMD/ROCm GPUs
|
||||
@@ -24,9 +26,6 @@ const (
|
||||
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
|
||||
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
|
||||
RocmStandardLocation = "/opt/rocm/lib"
|
||||
|
||||
// TODO find a better way to detect iGPU instead of minimum memory
|
||||
IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -35,14 +34,11 @@ var (
|
||||
)
|
||||
|
||||
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
|
||||
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
|
||||
// and the user hasn't already set this variable
|
||||
func AMDGetGPUInfo(resp *GpuInfo) {
|
||||
// TODO - DRY this out with windows
|
||||
func AMDGetGPUInfo() []GpuInfo {
|
||||
resp := []GpuInfo{}
|
||||
if !AMDDetected() {
|
||||
return
|
||||
return resp
|
||||
}
|
||||
skip := map[int]interface{}{}
|
||||
|
||||
// Opportunistic logging of driver version to aid in troubleshooting
|
||||
ver, err := AMDDriverVersion()
|
||||
@@ -50,143 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
|
||||
slog.Info("AMD Driver: " + ver)
|
||||
} else {
|
||||
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
|
||||
slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
|
||||
slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
|
||||
}
|
||||
|
||||
// If the user has specified exactly which GPUs to use, look up their memory
|
||||
visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
|
||||
if visibleDevices != "" {
|
||||
ids := []int{}
|
||||
for _, idStr := range strings.Split(visibleDevices, ",") {
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
|
||||
} else {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
amdProcMemLookup(resp, nil, ids)
|
||||
return
|
||||
}
|
||||
|
||||
// Gather GFX version information from all detected cards
|
||||
gfx := AMDGFXVersions()
|
||||
verStrings := []string{}
|
||||
for i, v := range gfx {
|
||||
verStrings = append(verStrings, v.ToGFXString())
|
||||
if v.Major == 0 {
|
||||
// Silently skip CPUs
|
||||
skip[i] = struct{}{}
|
||||
continue
|
||||
}
|
||||
if v.Major < 9 {
|
||||
// TODO consider this a build-time setting if we can support 8xx family GPUs
|
||||
slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
|
||||
skip[i] = struct{}{}
|
||||
}
|
||||
}
|
||||
slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
|
||||
|
||||
// Abort if all GPUs are skipped
|
||||
if len(skip) >= len(gfx) {
|
||||
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
||||
return
|
||||
}
|
||||
|
||||
// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
|
||||
libDir, err := AMDValidateLibDir()
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
|
||||
return
|
||||
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
|
||||
var visibleDevices []string
|
||||
hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
|
||||
rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
|
||||
gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
|
||||
switch {
|
||||
// TODO is this priorty order right?
|
||||
case hipVD != "":
|
||||
visibleDevices = strings.Split(hipVD, ",")
|
||||
case rocrVD != "":
|
||||
visibleDevices = strings.Split(rocrVD, ",")
|
||||
// TODO - since we don't yet support UUIDs, consider detecting and reporting here
|
||||
// all our test systems show GPU-XX indicating UUID is not supported
|
||||
case gpuDO != "":
|
||||
visibleDevices = strings.Split(gpuDO, ",")
|
||||
}
|
||||
|
||||
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
|
||||
if gfxOverride == "" {
|
||||
supported, err := GetSupportedGFX(libDir)
|
||||
var supported []string
|
||||
libDir := ""
|
||||
|
||||
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
|
||||
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
|
||||
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
|
||||
cpuCount := 0
|
||||
for _, match := range matches {
|
||||
slog.Debug("evaluating amdgpu node " + match)
|
||||
fp, err := os.Open(match)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
|
||||
return
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
|
||||
|
||||
for i, v := range gfx {
|
||||
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
|
||||
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
|
||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
||||
skip[i] = struct{}{}
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
||||
}
|
||||
|
||||
if len(skip) >= len(gfx) {
|
||||
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
||||
return
|
||||
}
|
||||
|
||||
ids := make([]int, len(gfx))
|
||||
i := 0
|
||||
for k := range gfx {
|
||||
ids[i] = k
|
||||
i++
|
||||
}
|
||||
amdProcMemLookup(resp, skip, ids)
|
||||
if resp.memInfo.DeviceCount == 0 {
|
||||
return
|
||||
}
|
||||
if len(skip) > 0 {
|
||||
amdSetVisibleDevices(ids, skip)
|
||||
}
|
||||
}
|
||||
|
||||
// Walk the sysfs nodes for the available GPUs and gather information from them
|
||||
// skipping over any devices in the skip map
|
||||
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
|
||||
resp.memInfo.DeviceCount = 0
|
||||
resp.memInfo.TotalMemory = 0
|
||||
resp.memInfo.FreeMemory = 0
|
||||
slog.Debug("discovering VRAM for amdgpu devices")
|
||||
if len(ids) == 0 {
|
||||
entries, err := os.ReadDir(AMDNodesSysfsDir)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
|
||||
return
|
||||
}
|
||||
for _, node := range entries {
|
||||
if !node.IsDir() {
|
||||
continue
|
||||
}
|
||||
id, err := strconv.Atoi(node.Name())
|
||||
if err != nil {
|
||||
slog.Warn("malformed amdgpu sysfs node id " + node.Name())
|
||||
continue
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
|
||||
|
||||
for _, id := range ids {
|
||||
if _, skipped := skip[id]; skipped {
|
||||
slog.Debug("failed to open sysfs node", "file", match, "error", err)
|
||||
continue
|
||||
}
|
||||
defer fp.Close()
|
||||
nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
|
||||
if err != nil {
|
||||
slog.Debug("failed to parse node ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(fp)
|
||||
isCPU := false
|
||||
var major, minor, patch uint64
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
|
||||
if strings.HasPrefix(line, "gfx_target_version") {
|
||||
ver := strings.Fields(line)
|
||||
|
||||
// Detect CPUs
|
||||
if len(ver) == 2 && ver[1] == "0" {
|
||||
slog.Debug("detected CPU " + match)
|
||||
isCPU = true
|
||||
break
|
||||
}
|
||||
|
||||
if len(ver) != 2 || len(ver[1]) < 5 {
|
||||
slog.Warn("malformed "+match, "gfx_target_version", line)
|
||||
// If this winds up being a CPU, our offsets may be wrong
|
||||
continue
|
||||
}
|
||||
l := len(ver[1])
|
||||
var err1, err2, err3 error
|
||||
patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
|
||||
minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
|
||||
major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
slog.Debug("malformed int " + line)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// TODO - any other properties we want to extract and record?
|
||||
// vendor_id + device_id -> pci lookup for "Name"
|
||||
// Other metrics that may help us understand relative performance between multiple GPUs
|
||||
}
|
||||
|
||||
if isCPU {
|
||||
cpuCount++
|
||||
continue
|
||||
}
|
||||
|
||||
// CPUs are always first in the list
|
||||
gpuID := nodeID - cpuCount
|
||||
|
||||
// Shouldn't happen, but just in case...
|
||||
if gpuID < 0 {
|
||||
slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
|
||||
return []GpuInfo{}
|
||||
}
|
||||
|
||||
if int(major) < RocmComputeMin {
|
||||
slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%x", major, minor, patch), "gpu", gpuID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Look up the memory for the current node
|
||||
totalMemory := uint64(0)
|
||||
usedMemory := uint64(0)
|
||||
// Adjust for sysfs vs HIP ids
|
||||
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
|
||||
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
|
||||
propFiles, err := filepath.Glob(propGlob)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
|
||||
slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
|
||||
}
|
||||
// 1 or more memory banks - sum the values of all of them
|
||||
for _, propFile := range propFiles {
|
||||
fp, err := os.Open(propFile)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
|
||||
slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
|
||||
continue
|
||||
}
|
||||
defer fp.Close()
|
||||
@@ -209,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
|
||||
}
|
||||
}
|
||||
if totalMemory == 0 {
|
||||
slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
|
||||
skip[id] = struct{}{}
|
||||
slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
|
||||
continue
|
||||
}
|
||||
if totalMemory < IGPUMemLimit {
|
||||
slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
|
||||
skip[id] = struct{}{}
|
||||
continue
|
||||
}
|
||||
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
|
||||
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
|
||||
usedFiles, err := filepath.Glob(usedGlob)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
|
||||
slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
|
||||
continue
|
||||
}
|
||||
for _, usedFile := range usedFiles {
|
||||
fp, err := os.Open(usedFile)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
|
||||
slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
|
||||
continue
|
||||
}
|
||||
defer fp.Close()
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
|
||||
slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
|
||||
continue
|
||||
}
|
||||
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
|
||||
slog.Warn("malformed used memory", "data", string(data), "error", err)
|
||||
continue
|
||||
}
|
||||
usedMemory += used
|
||||
}
|
||||
slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
|
||||
slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024))
|
||||
resp.memInfo.DeviceCount++
|
||||
resp.memInfo.TotalMemory += totalMemory
|
||||
resp.memInfo.FreeMemory += (totalMemory - usedMemory)
|
||||
|
||||
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||
if totalMemory < IGPUMemLimit {
|
||||
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
||||
gpuInfo := GpuInfo{
|
||||
Library: "rocm",
|
||||
memInfo: memInfo{
|
||||
TotalMemory: totalMemory,
|
||||
FreeMemory: (totalMemory - usedMemory),
|
||||
},
|
||||
ID: fmt.Sprintf("%d", gpuID),
|
||||
// Name: not exposed in sysfs directly, would require pci device id lookup
|
||||
Major: int(major),
|
||||
Minor: int(minor),
|
||||
Patch: int(patch),
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
}
|
||||
|
||||
// If the user wants to filter to a subset of devices, filter out if we aren't a match
|
||||
if len(visibleDevices) > 0 {
|
||||
include := false
|
||||
for _, visible := range visibleDevices {
|
||||
if visible == gpuInfo.ID {
|
||||
include = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !include {
|
||||
slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Final validation is gfx compatibility - load the library if we haven't already loaded it
|
||||
// even if the user overrides, we still need to validate the library
|
||||
if libDir == "" {
|
||||
libDir, err = AMDValidateLibDir()
|
||||
if err != nil {
|
||||
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
|
||||
return []GpuInfo{}
|
||||
}
|
||||
}
|
||||
gpuInfo.DependencyPath = libDir
|
||||
|
||||
if gfxOverride == "" {
|
||||
// Only load supported list once
|
||||
if len(supported) == 0 {
|
||||
supported, err = GetSupportedGFX(libDir)
|
||||
if err != nil {
|
||||
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
|
||||
return []GpuInfo{}
|
||||
}
|
||||
slog.Debug("rocm supported GPUs", "types", supported)
|
||||
}
|
||||
gfx := fmt.Sprintf("gfx%d%d%x", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
|
||||
if !slices.Contains[[]string, string](supported, gfx) {
|
||||
slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
|
||||
continue
|
||||
} else {
|
||||
slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
|
||||
}
|
||||
} else {
|
||||
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
||||
}
|
||||
|
||||
// The GPU has passed all the verification steps and is supported
|
||||
resp = append(resp, gpuInfo)
|
||||
}
|
||||
if resp.memInfo.DeviceCount > 0 {
|
||||
resp.Library = "rocm"
|
||||
if len(resp) == 0 {
|
||||
slog.Info("no compatible amdgpu devices detected")
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Quick check for AMD driver so we can skip amdgpu discovery if not present
|
||||
@@ -263,87 +297,24 @@ func AMDDetected() bool {
|
||||
slog.Debug("amdgpu driver not detected " + sysfsDir)
|
||||
return false
|
||||
} else if err != nil {
|
||||
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
|
||||
slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func setupLink(source, target string) error {
|
||||
if err := os.RemoveAll(target); err != nil {
|
||||
return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
|
||||
}
|
||||
if err := os.Symlink(source, target); err != nil {
|
||||
return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure the AMD rocm lib dir is wired up
|
||||
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
|
||||
// failing that, tell the user how to download it on their own
|
||||
func AMDValidateLibDir() (string, error) {
|
||||
// We rely on the rpath compiled into our library to find rocm
|
||||
// so we establish a symlink to wherever we find it on the system
|
||||
// to <payloads>/rocm
|
||||
payloadsDir, err := PayloadsDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// If we already have a rocm dependency wired, nothing more to do
|
||||
rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
|
||||
if rocmLibUsable(rocmTargetDir) {
|
||||
return rocmTargetDir, nil
|
||||
}
|
||||
|
||||
// next to the running binary
|
||||
exe, err := os.Executable()
|
||||
libDir, err := commonAMDValidateLibDir()
|
||||
if err == nil {
|
||||
peerDir := filepath.Dir(exe)
|
||||
if rocmLibUsable(peerDir) {
|
||||
slog.Debug("detected ROCM next to ollama executable " + peerDir)
|
||||
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
|
||||
}
|
||||
peerDir = filepath.Join(filepath.Dir(exe), "rocm")
|
||||
if rocmLibUsable(peerDir) {
|
||||
slog.Debug("detected ROCM next to ollama executable " + peerDir)
|
||||
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
|
||||
}
|
||||
return libDir, nil
|
||||
}
|
||||
|
||||
// Well known ollama installer path
|
||||
installedRocmDir := "/usr/share/ollama/lib/rocm"
|
||||
if rocmLibUsable(installedRocmDir) {
|
||||
return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
|
||||
}
|
||||
|
||||
// Prefer explicit HIP env var
|
||||
hipPath := os.Getenv("HIP_PATH")
|
||||
if hipPath != "" {
|
||||
hipLibDir := filepath.Join(hipPath, "lib")
|
||||
if rocmLibUsable(hipLibDir) {
|
||||
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
||||
return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
|
||||
}
|
||||
}
|
||||
|
||||
// Scan the library path for potential matches
|
||||
ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
|
||||
for _, ldPath := range ldPaths {
|
||||
d, err := filepath.Abs(ldPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if rocmLibUsable(d) {
|
||||
return rocmTargetDir, setupLink(d, rocmTargetDir)
|
||||
}
|
||||
}
|
||||
|
||||
// Well known location(s)
|
||||
if rocmLibUsable("/opt/rocm/lib") {
|
||||
return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
|
||||
return installedRocmDir, nil
|
||||
}
|
||||
|
||||
// If we still haven't found a usable rocm, the user will have to install it on their own
|
||||
@@ -367,68 +338,3 @@ func AMDDriverVersion() (string, error) {
|
||||
}
|
||||
return strings.TrimSpace(string(verString)), nil
|
||||
}
|
||||
|
||||
func AMDGFXVersions() map[int]Version {
|
||||
// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
|
||||
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
|
||||
res := map[int]Version{}
|
||||
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
|
||||
for _, match := range matches {
|
||||
fp, err := os.Open(match)
|
||||
if err != nil {
|
||||
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
|
||||
continue
|
||||
}
|
||||
defer fp.Close()
|
||||
i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
|
||||
if err != nil {
|
||||
slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
|
||||
continue
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
// Skipping the CPU
|
||||
continue
|
||||
}
|
||||
// Align with HIP IDs (zero is first GPU, not CPU)
|
||||
i -= 1
|
||||
|
||||
scanner := bufio.NewScanner(fp)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "gfx_target_version") {
|
||||
ver := strings.Fields(line)
|
||||
if len(ver) != 2 || len(ver[1]) < 5 {
|
||||
if ver[1] != "0" {
|
||||
slog.Debug("malformed " + line)
|
||||
}
|
||||
res[i] = Version{
|
||||
Major: 0,
|
||||
Minor: 0,
|
||||
Patch: 0,
|
||||
}
|
||||
continue
|
||||
}
|
||||
l := len(ver[1])
|
||||
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
|
||||
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
|
||||
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
slog.Debug("malformed int " + line)
|
||||
continue
|
||||
}
|
||||
|
||||
res[i] = Version{
|
||||
Major: uint(major),
|
||||
Minor: uint(minor),
|
||||
Patch: uint(patch),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (v Version) ToGFXString() string {
|
||||
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,10 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,36 +25,32 @@ var (
|
||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
|
||||
)
|
||||
|
||||
func AMDGetGPUInfo(resp *GpuInfo) {
|
||||
func AMDGetGPUInfo() []GpuInfo {
|
||||
resp := []GpuInfo{}
|
||||
hl, err := NewHipLib()
|
||||
if err != nil {
|
||||
slog.Debug(err.Error())
|
||||
return
|
||||
return nil
|
||||
}
|
||||
defer hl.Release()
|
||||
skip := map[int]interface{}{}
|
||||
ids := []int{}
|
||||
resp.memInfo.DeviceCount = 0
|
||||
resp.memInfo.TotalMemory = 0
|
||||
resp.memInfo.FreeMemory = 0
|
||||
|
||||
ver, err := hl.AMDDriverVersion()
|
||||
if err == nil {
|
||||
slog.Info("AMD Driver: " + ver)
|
||||
} else {
|
||||
// For now this is benign, but we may eventually need to fail compatibility checks
|
||||
slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
|
||||
slog.Debug("error looking up amd driver version", "error", err)
|
||||
}
|
||||
|
||||
// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
|
||||
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
|
||||
count := hl.HipGetDeviceCount()
|
||||
if count == 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
libDir, err := AMDValidateLibDir()
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
|
||||
return
|
||||
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var supported []string
|
||||
@@ -59,95 +58,120 @@ func AMDGetGPUInfo(resp *GpuInfo) {
|
||||
if gfxOverride == "" {
|
||||
supported, err = GetSupportedGFX(libDir)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
|
||||
return
|
||||
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
||||
}
|
||||
|
||||
slog.Info(fmt.Sprintf("detected %d hip devices", count))
|
||||
slog.Info("detected hip devices", "count", count)
|
||||
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
|
||||
for i := 0; i < count; i++ {
|
||||
ids = append(ids, i)
|
||||
err = hl.HipSetDevice(i)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
|
||||
skip[i] = struct{}{}
|
||||
slog.Warn("set device", "id", i, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
props, err := hl.HipGetDeviceProperties(i)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
|
||||
skip[i] = struct{}{}
|
||||
slog.Warn("get properties", "id", i, "error", err)
|
||||
continue
|
||||
}
|
||||
n := bytes.IndexByte(props.Name[:], 0)
|
||||
name := string(props.Name[:n])
|
||||
slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
|
||||
// TODO is UUID actually populated on windows?
|
||||
// Can luid be used on windows for setting visible devices (and is it actually set?)
|
||||
n = bytes.IndexByte(props.GcnArchName[:], 0)
|
||||
gfx := string(props.GcnArchName[:n])
|
||||
slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
|
||||
slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
|
||||
var major, minor, patch string
|
||||
switch len(gfx) {
|
||||
case 6:
|
||||
major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
|
||||
case 7:
|
||||
major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
|
||||
}
|
||||
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
|
||||
// TODO Why isn't props.iGPU accurate!?
|
||||
if strings.EqualFold(name, iGPUName) {
|
||||
slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
|
||||
skip[i] = struct{}{}
|
||||
slog.Info("iGPU detected skipping", "id", i)
|
||||
continue
|
||||
}
|
||||
if gfxOverride == "" {
|
||||
if !slices.Contains[[]string, string](supported, gfx) {
|
||||
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
|
||||
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
||||
skip[i] = struct{}{}
|
||||
continue
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
|
||||
slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
|
||||
}
|
||||
}
|
||||
|
||||
totalMemory, freeMemory, err := hl.HipMemGetInfo()
|
||||
freeMemory, totalMemory, err := hl.HipMemGetInfo()
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
|
||||
slog.Warn("get mem info", "id", i, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO according to docs, freeMem may lie on windows!
|
||||
slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
|
||||
slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory))
|
||||
resp.memInfo.DeviceCount++
|
||||
resp.memInfo.TotalMemory += totalMemory
|
||||
resp.memInfo.FreeMemory += freeMemory
|
||||
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||
if totalMemory < IGPUMemLimit {
|
||||
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO revisit this once ROCm v6 is available on windows.
|
||||
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
|
||||
slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
|
||||
gpuInfo := GpuInfo{
|
||||
Library: "rocm",
|
||||
memInfo: memInfo{
|
||||
TotalMemory: totalMemory,
|
||||
FreeMemory: freeMemory,
|
||||
},
|
||||
ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
|
||||
DependencyPath: libDir,
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
}
|
||||
if major != "" {
|
||||
gpuInfo.Major, err = strconv.Atoi(major)
|
||||
if err != nil {
|
||||
slog.Info("failed to parse version", "version", gfx, "error", err)
|
||||
}
|
||||
}
|
||||
if minor != "" {
|
||||
gpuInfo.Minor, err = strconv.Atoi(minor)
|
||||
if err != nil {
|
||||
slog.Info("failed to parse version", "version", gfx, "error", err)
|
||||
}
|
||||
}
|
||||
if patch != "" {
|
||||
// Patch rev is hex; e.g. gfx90a
|
||||
p, err := strconv.ParseInt(patch, 16, 0)
|
||||
if err != nil {
|
||||
slog.Info("failed to parse version", "version", gfx, "error", err)
|
||||
} else {
|
||||
gpuInfo.Patch = int(p)
|
||||
}
|
||||
}
|
||||
if gpuInfo.Major < RocmComputeMin {
|
||||
slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%x", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
|
||||
continue
|
||||
}
|
||||
|
||||
resp = append(resp, gpuInfo)
|
||||
}
|
||||
if resp.memInfo.DeviceCount > 0 {
|
||||
resp.Library = "rocm"
|
||||
}
|
||||
// Abort if all GPUs are skipped
|
||||
if len(skip) >= count {
|
||||
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
||||
return
|
||||
}
|
||||
if len(skip) > 0 {
|
||||
amdSetVisibleDevices(ids, skip)
|
||||
}
|
||||
UpdatePath(libDir)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func AMDValidateLibDir() (string, error) {
|
||||
// On windows non-admins typically can't create links
|
||||
// so instead of trying to rely on rpath and a link in
|
||||
// $LibDir/rocm, we instead rely on setting PATH to point
|
||||
// to the location of the ROCm library
|
||||
|
||||
// Installer payload location if we're running the installed binary
|
||||
exe, err := os.Executable()
|
||||
libDir, err := commonAMDValidateLibDir()
|
||||
if err == nil {
|
||||
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
||||
if rocmLibUsable(rocmTargetDir) {
|
||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||
return rocmTargetDir, nil
|
||||
}
|
||||
return libDir, nil
|
||||
}
|
||||
|
||||
// Installer payload (if we're running from some other location)
|
||||
@@ -159,21 +183,6 @@ func AMDValidateLibDir() (string, error) {
|
||||
return rocmTargetDir, nil
|
||||
}
|
||||
|
||||
// Prefer explicit HIP env var
|
||||
hipPath := os.Getenv("HIP_PATH")
|
||||
if hipPath != "" {
|
||||
hipLibDir := filepath.Join(hipPath, "bin")
|
||||
if rocmLibUsable(hipLibDir) {
|
||||
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
||||
return hipLibDir, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Well known location(s)
|
||||
if rocmLibUsable(RocmStandardLocation) {
|
||||
return RocmStandardLocation, nil
|
||||
}
|
||||
|
||||
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
|
||||
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
|
||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -21,11 +22,65 @@ var (
|
||||
func PayloadsDir() (string, error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
var err error
|
||||
if payloadsDir == "" {
|
||||
runnersDir := os.Getenv("OLLAMA_RUNNERS_DIR")
|
||||
// On Windows we do not carry the payloads inside the main executable
|
||||
if runtime.GOOS == "windows" && runnersDir == "" {
|
||||
appExe, err := os.Executable()
|
||||
if err != nil {
|
||||
slog.Error("failed to lookup executable path", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
slog.Error("failed to lookup working directory", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
var paths []string
|
||||
for _, root := range []string{appExe, cwd} {
|
||||
paths = append(paths,
|
||||
filepath.Join(root),
|
||||
filepath.Join(root, "windows-"+runtime.GOARCH),
|
||||
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
|
||||
)
|
||||
}
|
||||
|
||||
// Try a few variations to improve developer experience when building from source in the local tree
|
||||
for _, p := range paths {
|
||||
candidate := filepath.Join(p, "ollama_runners")
|
||||
_, err := os.Stat(candidate)
|
||||
if err == nil {
|
||||
runnersDir = candidate
|
||||
break
|
||||
}
|
||||
}
|
||||
if runnersDir == "" {
|
||||
err = fmt.Errorf("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
|
||||
slog.Error("incomplete distribution", "error", err)
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if runnersDir != "" {
|
||||
payloadsDir = runnersDir
|
||||
return payloadsDir, nil
|
||||
}
|
||||
|
||||
// The remainder only applies on non-windows where we still carry payloads in the main executable
|
||||
cleanupTmpDirs()
|
||||
tmpDir, err := os.MkdirTemp("", "ollama")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate tmp dir: %w", err)
|
||||
tmpDir := os.Getenv("OLLAMA_TMPDIR")
|
||||
if tmpDir == "" {
|
||||
tmpDir, err = os.MkdirTemp("", "ollama")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate tmp dir: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = os.MkdirAll(tmpDir, 0755)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate tmp dir %s: %w", tmpDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Track our pid so we can clean up orphaned tmpdirs
|
||||
@@ -70,7 +125,7 @@ func cleanupTmpDirs() {
|
||||
}
|
||||
err = os.RemoveAll(d)
|
||||
if err != nil {
|
||||
slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err))
|
||||
slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -78,13 +133,19 @@ func cleanupTmpDirs() {
|
||||
func Cleanup() {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
if payloadsDir != "" {
|
||||
runnersDir := os.Getenv("OLLAMA_RUNNERS_DIR")
|
||||
if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
|
||||
// We want to fully clean up the tmpdir parent of the payloads dir
|
||||
tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
|
||||
slog.Debug("cleaning up", "dir", tmpDir)
|
||||
err := os.RemoveAll(tmpDir)
|
||||
if err != nil {
|
||||
slog.Warn("failed to clean up", "dir", tmpDir, "err", err)
|
||||
// On windows, if we remove too quickly the llama.dll may still be in-use and fail to remove
|
||||
time.Sleep(1000 * time.Millisecond)
|
||||
err = os.RemoveAll(tmpDir)
|
||||
if err != nil {
|
||||
slog.Warn("failed to clean up", "dir", tmpDir, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -105,7 +166,7 @@ func UpdatePath(dir string) {
|
||||
}
|
||||
}
|
||||
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
|
||||
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
|
||||
slog.Info("updating", "PATH", newPath)
|
||||
os.Setenv("PATH", newPath)
|
||||
}
|
||||
// linux and darwin rely on rpath
|
||||
|
||||
22
gpu/cuda_common.go
Normal file
22
gpu/cuda_common.go
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build linux || windows
|
||||
|
||||
package gpu
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "cuda" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
|
||||
}
|
||||
266
gpu/gpu.go
266
gpu/gpu.go
@@ -16,43 +16,32 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
type handles struct {
|
||||
nvml *C.nvml_handle_t
|
||||
cudart *C.cudart_handle_t
|
||||
deviceCount int
|
||||
cudart *C.cudart_handle_t
|
||||
}
|
||||
|
||||
const (
|
||||
cudaMinimumMemory = 457 * format.MebiByte
|
||||
rocmMinimumMemory = 457 * format.MebiByte
|
||||
)
|
||||
|
||||
var gpuMutex sync.Mutex
|
||||
var gpuHandles *handles = nil
|
||||
|
||||
// With our current CUDA compile flags, older than 5.0 will not work properly
|
||||
var CudaComputeMin = [2]C.int{5, 0}
|
||||
|
||||
// Possible locations for the nvidia-ml library
|
||||
var NvmlLinuxGlobs = []string{
|
||||
"/usr/local/cuda/lib64/libnvidia-ml.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
|
||||
"/usr/lib/wsl/lib/libnvidia-ml.so*",
|
||||
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
|
||||
"/opt/cuda/lib64/libnvidia-ml.so*",
|
||||
"/usr/lib*/libnvidia-ml.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
|
||||
"/usr/local/lib*/libnvidia-ml.so*",
|
||||
var RocmComputeMin = 9
|
||||
|
||||
// TODO: are these stubs ever valid?
|
||||
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
|
||||
}
|
||||
|
||||
var NvmlWindowsGlobs = []string{
|
||||
"c:\\Windows\\System32\\nvml.dll",
|
||||
}
|
||||
// TODO find a better way to detect iGPU instead of minimum memory
|
||||
const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
|
||||
|
||||
var CudartLinuxGlobs = []string{
|
||||
"/usr/local/cuda/lib64/libcudart.so*",
|
||||
@@ -78,30 +67,22 @@ var CudartWindowsGlobs = []string{
|
||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||
|
||||
// Note: gpuMutex must already be held
|
||||
func initGPUHandles() {
|
||||
func initGPUHandles() *handles {
|
||||
|
||||
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
||||
|
||||
gpuHandles = &handles{nil, nil}
|
||||
var nvmlMgmtName string
|
||||
var nvmlMgmtPatterns []string
|
||||
gpuHandles := &handles{}
|
||||
var cudartMgmtName string
|
||||
var cudartMgmtPatterns []string
|
||||
|
||||
tmpDir, _ := PayloadsDir()
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
nvmlMgmtName = "nvml.dll"
|
||||
nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
|
||||
copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
|
||||
cudartMgmtName = "cudart64_*.dll"
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
|
||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
|
||||
case "linux":
|
||||
nvmlMgmtName = "libnvidia-ml.so"
|
||||
nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
|
||||
copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
|
||||
cudartMgmtName = "libcudart.so*"
|
||||
if tmpDir != "" {
|
||||
// TODO - add "payloads" for subprocess
|
||||
@@ -109,41 +90,35 @@ func initGPUHandles() {
|
||||
}
|
||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
|
||||
default:
|
||||
return
|
||||
return gpuHandles
|
||||
}
|
||||
|
||||
slog.Info("Detecting GPU type")
|
||||
slog.Info("Detecting GPUs")
|
||||
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
|
||||
if len(cudartLibPaths) > 0 {
|
||||
cudart := LoadCUDARTMgmt(cudartLibPaths)
|
||||
deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
|
||||
if cudart != nil {
|
||||
slog.Info("Nvidia GPU detected via cudart")
|
||||
slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
|
||||
gpuHandles.cudart = cudart
|
||||
return
|
||||
gpuHandles.deviceCount = deviceCount
|
||||
return gpuHandles
|
||||
}
|
||||
}
|
||||
|
||||
// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
|
||||
nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
|
||||
if len(nvmlLibPaths) > 0 {
|
||||
nvml := LoadNVMLMgmt(nvmlLibPaths)
|
||||
if nvml != nil {
|
||||
slog.Info("Nvidia GPU detected via nvidia-ml")
|
||||
gpuHandles.nvml = nvml
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return gpuHandles
|
||||
}
|
||||
|
||||
func GetGPUInfo() GpuInfo {
|
||||
func GetGPUInfo() GpuInfoList {
|
||||
// TODO - consider exploring lspci (and equivalent on windows) to check for
|
||||
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
|
||||
gpuMutex.Lock()
|
||||
defer gpuMutex.Unlock()
|
||||
if gpuHandles == nil {
|
||||
initGPUHandles()
|
||||
}
|
||||
|
||||
gpuHandles := initGPUHandles()
|
||||
defer func() {
|
||||
if gpuHandles.cudart != nil {
|
||||
C.cudart_release(*gpuHandles.cudart)
|
||||
}
|
||||
}()
|
||||
|
||||
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
|
||||
cpuVariant := GetCPUVariant()
|
||||
@@ -152,69 +127,63 @@ func GetGPUInfo() GpuInfo {
|
||||
}
|
||||
|
||||
var memInfo C.mem_info_t
|
||||
resp := GpuInfo{}
|
||||
if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
||||
C.nvml_check_vram(*gpuHandles.nvml, &memInfo)
|
||||
resp := []GpuInfo{}
|
||||
|
||||
// NVIDIA first
|
||||
for i := 0; i < gpuHandles.deviceCount; i++ {
|
||||
// TODO once we support CPU compilation variants of GPU libraries refine this...
|
||||
if cpuVariant == "" && runtime.GOARCH == "amd64" {
|
||||
continue
|
||||
}
|
||||
gpuInfo := GpuInfo{
|
||||
Library: "cuda",
|
||||
}
|
||||
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
|
||||
if memInfo.err != nil {
|
||||
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err)))
|
||||
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
} else if memInfo.count > 0 {
|
||||
// Verify minimum compute capability
|
||||
var cc C.nvml_compute_capability_t
|
||||
C.nvml_compute_capability(*gpuHandles.nvml, &cc)
|
||||
if cc.err != nil {
|
||||
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
|
||||
C.free(unsafe.Pointer(cc.err))
|
||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
||||
slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
resp.Library = "cuda"
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
}
|
||||
continue
|
||||
}
|
||||
} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
||||
C.cudart_check_vram(*gpuHandles.cudart, &memInfo)
|
||||
if memInfo.err != nil {
|
||||
slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err)))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
} else if memInfo.count > 0 {
|
||||
// Verify minimum compute capability
|
||||
var cc C.cudart_compute_capability_t
|
||||
C.cudart_compute_capability(*gpuHandles.cudart, &cc)
|
||||
if cc.err != nil {
|
||||
slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
|
||||
C.free(unsafe.Pointer(cc.err))
|
||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
||||
slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
resp.Library = "cuda"
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
}
|
||||
if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
|
||||
slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
AMDGetGPUInfo(&resp)
|
||||
if resp.Library != "" {
|
||||
return resp
|
||||
}
|
||||
}
|
||||
if resp.Library == "" {
|
||||
C.cpu_check_ram(&memInfo)
|
||||
resp.Library = "cpu"
|
||||
resp.Variant = cpuVariant
|
||||
}
|
||||
if memInfo.err != nil {
|
||||
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
return resp
|
||||
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||
gpuInfo.Major = int(memInfo.major)
|
||||
gpuInfo.Minor = int(memInfo.minor)
|
||||
gpuInfo.MinimumMemory = cudaMinimumMemory
|
||||
|
||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||
resp = append(resp, gpuInfo)
|
||||
}
|
||||
|
||||
// Then AMD
|
||||
resp = append(resp, AMDGetGPUInfo()...)
|
||||
|
||||
if len(resp) == 0 {
|
||||
C.cpu_check_ram(&memInfo)
|
||||
if memInfo.err != nil {
|
||||
slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
return resp
|
||||
}
|
||||
gpuInfo := GpuInfo{
|
||||
Library: "cpu",
|
||||
Variant: cpuVariant,
|
||||
}
|
||||
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||
|
||||
resp = append(resp, gpuInfo)
|
||||
}
|
||||
|
||||
resp.DeviceCount = uint32(memInfo.count)
|
||||
resp.FreeMemory = uint64(memInfo.free)
|
||||
resp.TotalMemory = uint64(memInfo.total)
|
||||
return resp
|
||||
}
|
||||
|
||||
func getCPUMem() (memInfo, error) {
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
var ret memInfo
|
||||
var info C.mem_info_t
|
||||
C.cpu_check_ram(&info)
|
||||
@@ -227,42 +196,11 @@ func getCPUMem() (memInfo, error) {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func CheckVRAM() (int64, error) {
|
||||
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
|
||||
if userLimit != "" {
|
||||
avail, err := strconv.ParseInt(userLimit, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
|
||||
}
|
||||
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
|
||||
return avail, nil
|
||||
}
|
||||
gpuInfo := GetGPUInfo()
|
||||
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
|
||||
// leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead
|
||||
overhead := gpuInfo.FreeMemory / 10
|
||||
gpus := uint64(gpuInfo.DeviceCount)
|
||||
if overhead < gpus*1024*1024*1024 {
|
||||
overhead = gpus * 1024 * 1024 * 1024
|
||||
}
|
||||
// Assigning full reported free memory for Tegras due to OS controlled caching.
|
||||
if CudaTegra != "" {
|
||||
// Setting overhead for non-Tegra devices
|
||||
overhead = 0
|
||||
}
|
||||
avail := int64(gpuInfo.FreeMemory - overhead)
|
||||
slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
|
||||
return avail, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
|
||||
}
|
||||
|
||||
func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
||||
var ldPaths []string
|
||||
gpuLibPaths := []string{}
|
||||
slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
|
||||
slog.Debug("Searching for GPU library", "name", baseLibName)
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
@@ -280,7 +218,7 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||
}
|
||||
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
|
||||
slog.Debug("gpu library search", "globs", patterns)
|
||||
for _, pattern := range patterns {
|
||||
// Ignore glob discovery errors
|
||||
matches, _ := filepath.Glob(pattern)
|
||||
@@ -308,28 +246,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
|
||||
slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
|
||||
return gpuLibPaths
|
||||
}
|
||||
|
||||
func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t {
|
||||
var resp C.nvml_init_resp_t
|
||||
resp.ch.verbose = getVerboseState()
|
||||
for _, libPath := range nvmlLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
C.nvml_init(lib, &resp)
|
||||
if resp.err != nil {
|
||||
slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
return &resp.ch
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
|
||||
func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
|
||||
var resp C.cudart_init_resp_t
|
||||
resp.ch.verbose = getVerboseState()
|
||||
for _, libPath := range cudartLibPaths {
|
||||
@@ -337,13 +258,13 @@ func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
C.cudart_init(lib, &resp)
|
||||
if resp.err != nil {
|
||||
slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err)))
|
||||
slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
return &resp.ch
|
||||
return int(resp.num_devices), &resp.ch, libPath
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return 0, nil, ""
|
||||
}
|
||||
|
||||
func getVerboseState() C.uint16_t {
|
||||
@@ -352,3 +273,22 @@ func getVerboseState() C.uint16_t {
|
||||
}
|
||||
return C.uint16_t(0)
|
||||
}
|
||||
|
||||
// Given the list of GPUs this instantiation is targeted for,
|
||||
// figure out the visible devices environment variable
|
||||
//
|
||||
// If different libraries are detected, the first one is what we use
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||
if len(l) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
switch l[0].Library {
|
||||
case "cuda":
|
||||
return cudaGetVisibleDevicesEnv(l)
|
||||
case "rocm":
|
||||
return rocmGetVisibleDevicesEnv(l)
|
||||
default:
|
||||
slog.Debug("no filter required for library " + l[0].Library)
|
||||
return "", ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build darwin
|
||||
|
||||
package gpu
|
||||
|
||||
/*
|
||||
@@ -9,52 +7,41 @@ package gpu
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
|
||||
func CheckVRAM() (int64, error) {
|
||||
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
|
||||
if userLimit != "" {
|
||||
avail, err := strconv.ParseInt(userLimit, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
|
||||
}
|
||||
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
|
||||
return avail, nil
|
||||
}
|
||||
|
||||
func GetGPUInfo() GpuInfoList {
|
||||
mem, _ := GetCPUMem()
|
||||
if runtime.GOARCH == "amd64" {
|
||||
// gpu not supported, this may not be metal
|
||||
return 0, nil
|
||||
}
|
||||
recommendedMaxVRAM := int64(C.getRecommendedMaxVRAM())
|
||||
return recommendedMaxVRAM, nil
|
||||
}
|
||||
|
||||
func GetGPUInfo() GpuInfo {
|
||||
mem, _ := getCPUMem()
|
||||
if runtime.GOARCH == "amd64" {
|
||||
return GpuInfo{
|
||||
Library: "cpu",
|
||||
Variant: GetCPUVariant(),
|
||||
memInfo: mem,
|
||||
return []GpuInfo{
|
||||
{
|
||||
Library: "cpu",
|
||||
Variant: GetCPUVariant(),
|
||||
memInfo: mem,
|
||||
},
|
||||
}
|
||||
}
|
||||
return GpuInfo{
|
||||
info := GpuInfo{
|
||||
Library: "metal",
|
||||
memInfo: mem,
|
||||
ID: "0",
|
||||
}
|
||||
info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
|
||||
|
||||
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
|
||||
info.FreeMemory = info.TotalMemory
|
||||
|
||||
info.MinimumMemory = 0
|
||||
return []GpuInfo{info}
|
||||
}
|
||||
|
||||
func getCPUMem() (memInfo, error) {
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
return memInfo{
|
||||
TotalMemory: 0,
|
||||
TotalMemory: uint64(C.getPhysicalMemory()),
|
||||
FreeMemory: 0,
|
||||
DeviceCount: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||
// No-op on darwin
|
||||
return "", ""
|
||||
}
|
||||
|
||||
@@ -38,12 +38,17 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define GPU_ID_LEN 64
|
||||
|
||||
typedef struct mem_info {
|
||||
char *err; // If non-nill, caller responsible for freeing
|
||||
char gpu_id[GPU_ID_LEN];
|
||||
uint64_t total;
|
||||
uint64_t free;
|
||||
unsigned int count;
|
||||
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
|
||||
char *err; // If non-nill, caller responsible for freeing
|
||||
|
||||
// Compute Capability
|
||||
int major;
|
||||
int minor;
|
||||
} mem_info_t;
|
||||
|
||||
void cpu_check_ram(mem_info_t *resp);
|
||||
@@ -52,7 +57,6 @@ void cpu_check_ram(mem_info_t *resp);
|
||||
}
|
||||
#endif
|
||||
|
||||
#include "gpu_info_nvml.h"
|
||||
#include "gpu_info_cudart.h"
|
||||
|
||||
#endif // __GPU_INFO_H__
|
||||
|
||||
@@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
|
||||
MEMORYSTATUSEX info;
|
||||
info.dwLength = sizeof(info);
|
||||
if (GlobalMemoryStatusEx(&info) != 0) {
|
||||
resp->count = 1;
|
||||
resp->total = info.ullTotalPhys;
|
||||
resp->free = info.ullAvailPhys;
|
||||
resp->major = 0;
|
||||
resp->minor = 0;
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
|
||||
} else {
|
||||
resp->err = LOAD_ERR();
|
||||
}
|
||||
@@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
|
||||
if (sysinfo(&info) != 0) {
|
||||
resp->err = strdup(strerror(errno));
|
||||
} else {
|
||||
resp->count = 1;
|
||||
resp->total = info.totalram * info.mem_unit;
|
||||
resp->free = info.freeram * info.mem_unit;
|
||||
resp->major = 0;
|
||||
resp->minor = 0;
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
cudartReturn_t ret;
|
||||
resp->err = NULL;
|
||||
resp->num_devices = 0;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
@@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
|
||||
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
|
||||
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
|
||||
{"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
@@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
||||
|
||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||
if (!l[i].p) {
|
||||
char *msg = LOAD_ERR();
|
||||
@@ -62,6 +58,10 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
|
||||
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
|
||||
return;
|
||||
}
|
||||
snprintf(buf, buflen, "cudart init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
@@ -81,110 +81,101 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
|
||||
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
|
||||
}
|
||||
|
||||
ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
|
||||
void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
|
||||
resp->err = NULL;
|
||||
cudartMemory_t memInfo = {0,0,0};
|
||||
cudartReturn_t ret;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
if (h.handle == NULL) {
|
||||
resp->err = strdup("cudart handle isn't initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
// cudaGetDeviceCount takes int type, resp-> count is uint
|
||||
int deviceCount;
|
||||
ret = (*h.cudaGetDeviceCount)(&deviceCount);
|
||||
ret = (*h.cudaSetDevice)(i);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
snprintf(buf, buflen, "cudart device failed to initialize");
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
cudaDeviceProp_t props;
|
||||
ret = (*h.cudaGetDeviceProperties)(&props, i);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
|
||||
resp->major = 0;
|
||||
resp->minor = 0;
|
||||
} else {
|
||||
resp->count = (unsigned int)deviceCount;
|
||||
int allNull = 1;
|
||||
for (int j = 0; j < 16; j++) {
|
||||
if (props.uuid.bytes[j] != 0) {
|
||||
allNull = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (allNull != 0) {
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
|
||||
} else {
|
||||
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
|
||||
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||
props.uuid.bytes[0],
|
||||
props.uuid.bytes[1],
|
||||
props.uuid.bytes[2],
|
||||
props.uuid.bytes[3],
|
||||
props.uuid.bytes[4],
|
||||
props.uuid.bytes[5],
|
||||
props.uuid.bytes[6],
|
||||
props.uuid.bytes[7],
|
||||
props.uuid.bytes[8],
|
||||
props.uuid.bytes[9],
|
||||
props.uuid.bytes[10],
|
||||
props.uuid.bytes[11],
|
||||
props.uuid.bytes[12],
|
||||
props.uuid.bytes[13],
|
||||
props.uuid.bytes[14],
|
||||
props.uuid.bytes[15]
|
||||
);
|
||||
}
|
||||
resp->major = props.major;
|
||||
resp->minor = props.minor;
|
||||
|
||||
// TODO add other useful properties from props
|
||||
}
|
||||
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
resp->total = 0;
|
||||
resp->free = 0;
|
||||
for (i = 0; i < resp-> count; i++) {
|
||||
ret = (*h.cudaSetDevice)(i);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
snprintf(buf, buflen, "cudart device failed to initialize");
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
resp->total = memInfo.total;
|
||||
resp->free = memInfo.free;
|
||||
|
||||
LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total);
|
||||
LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
|
||||
|
||||
resp->total += memInfo.total;
|
||||
resp->free += memInfo.free;
|
||||
}
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
|
||||
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||
}
|
||||
|
||||
void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
|
||||
resp->err = NULL;
|
||||
resp->major = 0;
|
||||
resp->minor = 0;
|
||||
int major = 0;
|
||||
int minor = 0;
|
||||
cudartReturn_t ret;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
if (h.handle == NULL) {
|
||||
resp->err = strdup("cudart handle not initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
int devices;
|
||||
ret = (*h.cudaGetDeviceCount)(&devices);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; i < devices; i++) {
|
||||
ret = (*h.cudaSetDevice)(i);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
snprintf(buf, buflen, "cudart device failed to initialize");
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
// Report the lowest major.minor we detect as that limits our compatibility
|
||||
if (resp->major == 0 || resp->major > major ) {
|
||||
resp->major = major;
|
||||
resp->minor = minor;
|
||||
} else if ( resp->major == major && resp->minor > minor ) {
|
||||
resp->minor = minor;
|
||||
}
|
||||
}
|
||||
void cudart_release(cudart_handle_t h) {
|
||||
LOG(h.verbose, "releasing cudart library\n");
|
||||
UNLOAD_LIBRARY(h.handle);
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
@@ -6,13 +6,20 @@
|
||||
// Just enough typedef's to dlopen/dlsym for memory information
|
||||
typedef enum cudartReturn_enum {
|
||||
CUDART_SUCCESS = 0,
|
||||
CUDART_UNSUPPORTED = 1,
|
||||
CUDA_ERROR_INVALID_VALUE = 1,
|
||||
CUDA_ERROR_MEMORY_ALLOCATION = 2,
|
||||
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
|
||||
// Other values omitted for now...
|
||||
} cudartReturn_t;
|
||||
|
||||
typedef enum cudartDeviceAttr_enum {
|
||||
cudartDevAttrComputeCapabilityMajor = 75,
|
||||
cudartDevAttrComputeCapabilityMinor = 76,
|
||||
|
||||
// TODO - not yet wired up but may be useful for Jetson or other
|
||||
// integrated GPU scenarios with shared memory
|
||||
cudaDevAttrIntegrated = 18
|
||||
|
||||
} cudartDeviceAttr_t;
|
||||
|
||||
typedef void *cudartDevice_t; // Opaque is sufficient
|
||||
@@ -27,6 +34,92 @@ typedef struct cudartDriverVersion {
|
||||
int minor;
|
||||
} cudartDriverVersion_t;
|
||||
|
||||
typedef struct cudaUUID {
|
||||
unsigned char bytes[16];
|
||||
} cudaUUID_t;
|
||||
typedef struct cudaDeviceProp {
|
||||
char name[256]; /**< ASCII string identifying device */
|
||||
cudaUUID_t uuid; /**< 16-byte unique identifier */
|
||||
char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
|
||||
unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
|
||||
size_t totalGlobalMem; /**< Global memory available on device in bytes */
|
||||
size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */
|
||||
int regsPerBlock; /**< 32-bit registers available per block */
|
||||
int warpSize; /**< Warp size in threads */
|
||||
size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */
|
||||
int maxThreadsPerBlock; /**< Maximum number of threads per block */
|
||||
int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */
|
||||
int maxGridSize[3]; /**< Maximum size of each dimension of a grid */
|
||||
int clockRate; /**< Clock frequency in kilohertz */
|
||||
size_t totalConstMem; /**< Constant memory available on device in bytes */
|
||||
int major; /**< Major compute capability */
|
||||
int minor; /**< Minor compute capability */
|
||||
size_t textureAlignment; /**< Alignment requirement for textures */
|
||||
size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */
|
||||
int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
|
||||
int multiProcessorCount; /**< Number of multiprocessors on device */
|
||||
int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */
|
||||
int integrated; /**< Device is integrated as opposed to discrete */
|
||||
int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
|
||||
int computeMode; /**< Compute mode (See ::cudaComputeMode) */
|
||||
int maxTexture1D; /**< Maximum 1D texture size */
|
||||
int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */
|
||||
int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
|
||||
int maxTexture2D[2]; /**< Maximum 2D texture dimensions */
|
||||
int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */
|
||||
int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
|
||||
int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
|
||||
int maxTexture3D[3]; /**< Maximum 3D texture dimensions */
|
||||
int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */
|
||||
int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */
|
||||
int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */
|
||||
int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */
|
||||
int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
|
||||
int maxSurface1D; /**< Maximum 1D surface size */
|
||||
int maxSurface2D[2]; /**< Maximum 2D surface dimensions */
|
||||
int maxSurface3D[3]; /**< Maximum 3D surface dimensions */
|
||||
int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */
|
||||
int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */
|
||||
int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */
|
||||
int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
|
||||
size_t surfaceAlignment; /**< Alignment requirements for surfaces */
|
||||
int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */
|
||||
int ECCEnabled; /**< Device has ECC support enabled */
|
||||
int pciBusID; /**< PCI bus ID of the device */
|
||||
int pciDeviceID; /**< PCI device ID of the device */
|
||||
int pciDomainID; /**< PCI domain ID of the device */
|
||||
int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
|
||||
int asyncEngineCount; /**< Number of asynchronous engines */
|
||||
int unifiedAddressing; /**< Device shares a unified address space with the host */
|
||||
int memoryClockRate; /**< Peak memory clock frequency in kilohertz */
|
||||
int memoryBusWidth; /**< Global memory bus width in bits */
|
||||
int l2CacheSize; /**< Size of L2 cache in bytes */
|
||||
int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */
|
||||
int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
|
||||
int streamPrioritiesSupported; /**< Device supports stream priorities */
|
||||
int globalL1CacheSupported; /**< Device supports caching globals in L1 */
|
||||
int localL1CacheSupported; /**< Device supports caching locals in L1 */
|
||||
size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
|
||||
int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */
|
||||
int managedMemory; /**< Device supports allocating managed memory on this system */
|
||||
int isMultiGpuBoard; /**< Device is on a multi-GPU board */
|
||||
int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */
|
||||
int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */
|
||||
int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
|
||||
int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
|
||||
int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */
|
||||
int computePreemptionSupported; /**< Device supports Compute Preemption */
|
||||
int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
|
||||
int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
|
||||
int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
|
||||
size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */
|
||||
int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
|
||||
int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
|
||||
int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
|
||||
int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
|
||||
size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */
|
||||
} cudaDeviceProp_t;
|
||||
|
||||
typedef struct cudart_handle {
|
||||
void *handle;
|
||||
uint16_t verbose;
|
||||
@@ -37,23 +130,18 @@ typedef struct cudart_handle {
|
||||
cudartReturn_t (*cudaGetDeviceCount)(int *);
|
||||
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
|
||||
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
|
||||
cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
|
||||
} cudart_handle_t;
|
||||
|
||||
typedef struct cudart_init_resp {
|
||||
char *err; // If err is non-null handle is invalid
|
||||
cudart_handle_t ch;
|
||||
int num_devices;
|
||||
} cudart_init_resp_t;
|
||||
|
||||
typedef struct cudart_compute_capability {
|
||||
char *err;
|
||||
int major;
|
||||
int minor;
|
||||
} cudart_compute_capability_t;
|
||||
|
||||
|
||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
|
||||
void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
|
||||
void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
|
||||
void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
|
||||
void cudart_release(cudart_handle_t ch);
|
||||
|
||||
#endif // __GPU_INFO_CUDART_H__
|
||||
#endif // __APPLE__
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#import <Metal/Metal.h>
|
||||
#include <stdint.h>
|
||||
uint64_t getRecommendedMaxVRAM();
|
||||
uint64_t getPhysicalMemory();
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
//go:build darwin
|
||||
// go:build darwin
|
||||
#include "gpu_info_darwin.h"
|
||||
|
||||
uint64_t getRecommendedMaxVRAM()
|
||||
{
|
||||
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||
uint64_t result = device.recommendedMaxWorkingSetSize;
|
||||
CFRelease(device);
|
||||
return result;
|
||||
uint64_t getRecommendedMaxVRAM() {
|
||||
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
||||
uint64_t result = device.recommendedMaxWorkingSetSize;
|
||||
CFRelease(device);
|
||||
return result;
|
||||
}
|
||||
|
||||
uint64_t getPhysicalMemory() {
|
||||
return [[NSProcessInfo processInfo] physicalMemory];
|
||||
}
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "gpu_info_nvml.h"
|
||||
|
||||
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
|
||||
nvmlReturn_t ret;
|
||||
resp->err = NULL;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[] = {
|
||||
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
|
||||
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
|
||||
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
|
||||
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
|
||||
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
|
||||
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
|
||||
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
|
||||
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
|
||||
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
|
||||
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
|
||||
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
|
||||
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
|
||||
if (!resp->ch.handle) {
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Nvidia GPUs: %s",
|
||||
nvml_lib_path, msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
||||
|
||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||
if (!l[i].p) {
|
||||
resp->ch.handle = NULL;
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
||||
msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*resp->ch.nvmlInit_v2)();
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
// Report driver version if we're in verbose mode, ignore errors
|
||||
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
|
||||
}
|
||||
}
|
||||
|
||||
void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
|
||||
resp->err = NULL;
|
||||
nvmlDevice_t device;
|
||||
nvmlMemory_t memInfo = {0};
|
||||
nvmlReturn_t ret;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
if (h.handle == NULL) {
|
||||
resp->err = strdup("nvml handle isn't initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
resp->total = 0;
|
||||
resp->free = 0;
|
||||
for (i = 0; i < resp->count; i++) {
|
||||
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
if (h.verbose) {
|
||||
nvmlBrandType_t brand = 0;
|
||||
// When in verbose mode, report more information about
|
||||
// the card we discover, but don't fail on error
|
||||
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
|
||||
}
|
||||
}
|
||||
|
||||
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
|
||||
LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
|
||||
|
||||
resp->total += memInfo.total;
|
||||
resp->free += memInfo.free;
|
||||
}
|
||||
}
|
||||
|
||||
void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
|
||||
resp->err = NULL;
|
||||
resp->major = 0;
|
||||
resp->minor = 0;
|
||||
nvmlDevice_t device;
|
||||
int major = 0;
|
||||
int minor = 0;
|
||||
nvmlReturn_t ret;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
if (h.handle == NULL) {
|
||||
resp->err = strdup("nvml handle not initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
unsigned int devices;
|
||||
ret = (*h.nvmlDeviceGetCount_v2)(&devices);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; i < devices; i++) {
|
||||
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
// Report the lowest major.minor we detect as that limits our compatibility
|
||||
if (resp->major == 0 || resp->major > major ) {
|
||||
resp->major = major;
|
||||
resp->minor = minor;
|
||||
} else if ( resp->major == major && resp->minor > minor ) {
|
||||
resp->minor = minor;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // __APPLE__
|
||||
@@ -1,56 +0,0 @@
|
||||
#ifndef __APPLE__
|
||||
#ifndef __GPU_INFO_NVML_H__
|
||||
#define __GPU_INFO_NVML_H__
|
||||
#include "gpu_info.h"
|
||||
|
||||
// Just enough typedef's to dlopen/dlsym for memory information
|
||||
typedef enum nvmlReturn_enum {
|
||||
NVML_SUCCESS = 0,
|
||||
// Other values omitted for now...
|
||||
} nvmlReturn_t;
|
||||
typedef void *nvmlDevice_t; // Opaque is sufficient
|
||||
typedef struct nvmlMemory_st {
|
||||
unsigned long long total;
|
||||
unsigned long long free;
|
||||
unsigned long long used;
|
||||
} nvmlMemory_t;
|
||||
|
||||
typedef enum nvmlBrandType_enum
|
||||
{
|
||||
NVML_BRAND_UNKNOWN = 0,
|
||||
} nvmlBrandType_t;
|
||||
|
||||
typedef struct nvml_handle {
|
||||
void *handle;
|
||||
uint16_t verbose;
|
||||
nvmlReturn_t (*nvmlInit_v2)(void);
|
||||
nvmlReturn_t (*nvmlShutdown)(void);
|
||||
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
|
||||
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
|
||||
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
|
||||
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
|
||||
} nvml_handle_t;
|
||||
|
||||
typedef struct nvml_init_resp {
|
||||
char *err; // If err is non-null handle is invalid
|
||||
nvml_handle_t ch;
|
||||
} nvml_init_resp_t;
|
||||
|
||||
typedef struct nvml_compute_capability {
|
||||
char *err;
|
||||
int major;
|
||||
int minor;
|
||||
} nvml_compute_capability_t;
|
||||
|
||||
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
|
||||
void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
|
||||
void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
|
||||
|
||||
#endif // __GPU_INFO_NVML_H__
|
||||
#endif // __APPLE__
|
||||
@@ -9,23 +9,16 @@ import (
|
||||
|
||||
func TestBasicGetGPUInfo(t *testing.T) {
|
||||
info := GetGPUInfo()
|
||||
assert.Contains(t, "cuda rocm cpu metal", info.Library)
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
// TODO - remove this once MacOS returns some size for CPU
|
||||
return
|
||||
case "linux", "windows":
|
||||
assert.Greater(t, info.TotalMemory, uint64(0))
|
||||
assert.Greater(t, info.FreeMemory, uint64(0))
|
||||
assert.Greater(t, info.DeviceCount, uint32(0))
|
||||
default:
|
||||
return
|
||||
assert.Greater(t, len(info), 0)
|
||||
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
|
||||
if info[0].Library != "cpu" {
|
||||
assert.Greater(t, info[0].TotalMemory, uint64(0))
|
||||
assert.Greater(t, info[0].FreeMemory, uint64(0))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCPUMemInfo(t *testing.T) {
|
||||
info, err := getCPUMem()
|
||||
info, err := GetCPUMem()
|
||||
assert.NoError(t, err)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
|
||||
52
gpu/types.go
52
gpu/types.go
@@ -3,7 +3,6 @@ package gpu
|
||||
type memInfo struct {
|
||||
TotalMemory uint64 `json:"total_memory,omitempty"`
|
||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||
DeviceCount uint32 `json:"device_count,omitempty"`
|
||||
}
|
||||
|
||||
// Beginning of an `ollama info` command
|
||||
@@ -14,11 +13,52 @@ type GpuInfo struct {
|
||||
// Optional variant to select (e.g. versions, cpu feature flags)
|
||||
Variant string `json:"variant,omitempty"`
|
||||
|
||||
// TODO add other useful attributes about the card here for discovery information
|
||||
// MinimumMemory represents the minimum memory required to use the GPU
|
||||
MinimumMemory uint64 `json:"-"`
|
||||
|
||||
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
|
||||
DependencyPath string `json:"lib_path,omitempty"`
|
||||
|
||||
// GPU information
|
||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
Major int `json:"major,omitempty"` // Major compatibility version (CC or gfx)
|
||||
Minor int `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
|
||||
Patch int `json:"patch,omitempty"` // Patch compatibility only matters on AMD
|
||||
|
||||
// TODO other performance capability info to help in scheduling decisions
|
||||
}
|
||||
|
||||
type Version struct {
|
||||
Major uint
|
||||
Minor uint
|
||||
Patch uint
|
||||
type GpuInfoList []GpuInfo
|
||||
|
||||
// Split up the set of gpu info's by Library and variant
|
||||
func (l GpuInfoList) ByLibrary() []GpuInfoList {
|
||||
resp := []GpuInfoList{}
|
||||
libs := []string{}
|
||||
for _, info := range l {
|
||||
found := false
|
||||
requested := info.Library
|
||||
if info.Variant != "" {
|
||||
requested += "_" + info.Variant
|
||||
}
|
||||
for i, lib := range libs {
|
||||
if lib == requested {
|
||||
resp[i] = append(resp[i], info)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
libs = append(libs, info.Library)
|
||||
resp = append(resp, []GpuInfo{info})
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Sort by Free Space
|
||||
type ByFreeMemory []GpuInfo
|
||||
|
||||
func (a ByFreeMemory) Len() int { return len(a) }
|
||||
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
|
||||
|
||||
@@ -4,11 +4,14 @@ package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOrcaMiniBlueSky(t *testing.T) {
|
||||
@@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) {
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh"})
|
||||
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||
}
|
||||
|
||||
func TestUnicodeModelDir(t *testing.T) {
|
||||
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("Unicode test only applicable to windows")
|
||||
}
|
||||
// Only works for local testing
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||
t.Skip("TestUnicodeModelDir only works for local testing, skipping")
|
||||
}
|
||||
|
||||
modelDir, err := os.MkdirTemp("", "ollama_埃")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(modelDir)
|
||||
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
|
||||
|
||||
oldModelsDir := os.Getenv("OLLAMA_MODELS")
|
||||
if oldModelsDir == "" {
|
||||
defer os.Unsetenv("OLLAMA_MODELS")
|
||||
} else {
|
||||
defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
|
||||
}
|
||||
err = os.Setenv("OLLAMA_MODELS", modelDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the sky blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||
}
|
||||
|
||||
225
integration/concurrency_test.go
Normal file
225
integration/concurrency_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMultiModelConcurrency(t *testing.T) {
|
||||
var (
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "tinydolphin",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
[]string{"sunlight"},
|
||||
[]string{"england", "english", "massachusetts", "pilgrims"},
|
||||
}
|
||||
)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
GenerateTestHelper(ctx, t, req[i], resp[i])
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req, resp := GenerateRequests()
|
||||
// Get the server running (if applicable) warm the model up with a single initial request
|
||||
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 5; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the 4 concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
|
||||
func TestMultiModelStress(t *testing.T) {
|
||||
vram := os.Getenv("OLLAMA_MAX_VRAM")
|
||||
if vram == "" {
|
||||
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
|
||||
}
|
||||
max, err := strconv.ParseUint(vram, 10, 64)
|
||||
require.NoError(t, err)
|
||||
const MB = uint64(1024 * 1024)
|
||||
type model struct {
|
||||
name string
|
||||
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
|
||||
}
|
||||
|
||||
smallModels := []model{
|
||||
{
|
||||
name: "orca-mini",
|
||||
size: 2992 * MB,
|
||||
},
|
||||
{
|
||||
name: "phi",
|
||||
size: 2616 * MB,
|
||||
},
|
||||
{
|
||||
name: "gemma:2b",
|
||||
size: 2364 * MB,
|
||||
},
|
||||
{
|
||||
name: "stable-code:3b",
|
||||
size: 2608 * MB,
|
||||
},
|
||||
{
|
||||
name: "starcoder2:3b",
|
||||
size: 2166 * MB,
|
||||
},
|
||||
}
|
||||
mediumModels := []model{
|
||||
{
|
||||
name: "llama2",
|
||||
size: 5118 * MB,
|
||||
},
|
||||
{
|
||||
name: "mistral",
|
||||
size: 4620 * MB,
|
||||
},
|
||||
{
|
||||
name: "orca-mini:7b",
|
||||
size: 5118 * MB,
|
||||
},
|
||||
{
|
||||
name: "dolphin-mistral",
|
||||
size: 4620 * MB,
|
||||
},
|
||||
{
|
||||
name: "gemma:7b",
|
||||
size: 5000 * MB,
|
||||
},
|
||||
// TODO - uncomment this once #3565 is merged and this is rebased on it
|
||||
// {
|
||||
// name: "codellama:7b",
|
||||
// size: 5118 * MB,
|
||||
// },
|
||||
}
|
||||
|
||||
// These seem to be too slow to be useful...
|
||||
// largeModels := []model{
|
||||
// {
|
||||
// name: "llama2:13b",
|
||||
// size: 7400 * MB,
|
||||
// },
|
||||
// {
|
||||
// name: "codellama:13b",
|
||||
// size: 7400 * MB,
|
||||
// },
|
||||
// {
|
||||
// name: "orca-mini:13b",
|
||||
// size: 7400 * MB,
|
||||
// },
|
||||
// {
|
||||
// name: "gemma:7b",
|
||||
// size: 5000 * MB,
|
||||
// },
|
||||
// {
|
||||
// name: "starcoder2:15b",
|
||||
// size: 9100 * MB,
|
||||
// },
|
||||
// }
|
||||
|
||||
var chosenModels []model
|
||||
switch {
|
||||
case max < 10000*MB:
|
||||
slog.Info("selecting small models")
|
||||
chosenModels = smallModels
|
||||
// case max < 30000*MB:
|
||||
default:
|
||||
slog.Info("selecting medium models")
|
||||
chosenModels = mediumModels
|
||||
// default:
|
||||
// slog.Info("selecting large models")
|
||||
// chosenModels = largModels
|
||||
}
|
||||
|
||||
req, resp := GenerateRequests()
|
||||
|
||||
for i := range req {
|
||||
if i > len(chosenModels) {
|
||||
break
|
||||
}
|
||||
req[i].Model = chosenModels[i].name
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Make sure all the models are pulled before we get started
|
||||
for _, r := range req {
|
||||
require.NoError(t, PullIfMissing(ctx, client, r.Model))
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
consumed := uint64(256 * MB) // Assume some baseline usage
|
||||
for i := 0; i < len(req); i++ {
|
||||
// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
|
||||
if i > 1 && consumed > max {
|
||||
slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
|
||||
break
|
||||
}
|
||||
consumed += chosenModels[i].size
|
||||
slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
|
||||
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 3; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
28
integration/context_test.go
Normal file
28
integration/context_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestContextExhaustion(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) // TODO maybe shorter?
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "llama2",
|
||||
Prompt: "Write me a story with a ton of emojis?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_ctx": 128,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
|
||||
}
|
||||
@@ -5,7 +5,6 @@ package integration
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
resp := "the ollamas"
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
|
||||
GenerateTestHelper(ctx, t, req, []string{resp})
|
||||
}
|
||||
|
||||
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
||||
|
||||
@@ -4,8 +4,6 @@ package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -15,10 +13,6 @@ import (
|
||||
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
|
||||
// package to avoid circular dependencies
|
||||
|
||||
// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
|
||||
//
|
||||
// TODO - Fix this ^^
|
||||
|
||||
var (
|
||||
stream = false
|
||||
req = [2]api.GenerateRequest{
|
||||
@@ -49,25 +43,5 @@ var (
|
||||
func TestIntegrationSimpleOrcaMini(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
|
||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||
}
|
||||
|
||||
// TODO
|
||||
// The server always loads a new runner and closes the old one, which forces serial execution
|
||||
// At present this test case fails with concurrency problems. Eventually we should try to
|
||||
// get true concurrency working with n_parallel support in the backend
|
||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TODO - create a parallel test with 2 different models once we support concurrency
|
||||
|
||||
@@ -5,13 +5,14 @@ package integration
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -23,9 +24,13 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Init() {
|
||||
lifecycle.InitLogging()
|
||||
}
|
||||
|
||||
func FindPort() string {
|
||||
port := 0
|
||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
@@ -41,7 +46,7 @@ func FindPort() string {
|
||||
return strconv.Itoa(port)
|
||||
}
|
||||
|
||||
func GetTestEndpoint() (string, string) {
|
||||
func GetTestEndpoint() (*api.Client, string) {
|
||||
defaultPort := "11434"
|
||||
ollamaHost := os.Getenv("OLLAMA_HOST")
|
||||
|
||||
@@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
|
||||
port = FindPort()
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s:%s", host, port)
|
||||
slog.Info("server connection", "url", url)
|
||||
return scheme, url
|
||||
slog.Info("server connection", "host", host, "port", port)
|
||||
|
||||
return api.NewClient(
|
||||
&url.URL{
|
||||
Scheme: scheme,
|
||||
Host: net.JoinHostPort(host, port),
|
||||
},
|
||||
http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
|
||||
}
|
||||
|
||||
// TODO make fanicier, grab logs, etc.
|
||||
var serverMutex sync.Mutex
|
||||
var serverReady bool
|
||||
|
||||
func StartServer(ctx context.Context, ollamaHost string) error {
|
||||
func startServer(ctx context.Context, ollamaHost string) error {
|
||||
// Make sure the server has been built
|
||||
CLIName, err := filepath.Abs("../ollama")
|
||||
if err != nil {
|
||||
@@ -125,126 +134,205 @@ func StartServer(ctx context.Context, ollamaHost string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
|
||||
slog.Debug("checking status of model", "model", modelName)
|
||||
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
|
||||
slog.Info("checking status of model", "model", modelName)
|
||||
showReq := &api.ShowRequest{Name: modelName}
|
||||
requestJSON, err := json.Marshal(showReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
|
||||
if err != nil {
|
||||
showCtx, cancel := context.WithDeadlineCause(
|
||||
ctx,
|
||||
time.Now().Add(5*time.Second),
|
||||
fmt.Errorf("show for existing model %s took too long", modelName),
|
||||
)
|
||||
defer cancel()
|
||||
_, err := client.Show(showCtx, showReq)
|
||||
var statusError api.StatusError
|
||||
switch {
|
||||
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
|
||||
break
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
// Make the request with the HTTP client
|
||||
response, err := client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode == 200 {
|
||||
default:
|
||||
slog.Info("model already present", "model", modelName)
|
||||
return nil
|
||||
}
|
||||
slog.Info("model missing", "status", response.StatusCode)
|
||||
slog.Info("model missing", "model", modelName)
|
||||
|
||||
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
|
||||
stallTimer := time.NewTimer(stallDuration)
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
// fmt.Print(".")
|
||||
if !stallTimer.Reset(stallDuration) {
|
||||
return fmt.Errorf("stall was detected, aborting status reporting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
|
||||
requestJSON, err = json.Marshal(pullReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Info("pulling", "model", modelName)
|
||||
var pullError error
|
||||
|
||||
response, err = client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != 200 {
|
||||
return fmt.Errorf("failed to pull model") // TODO more details perhaps
|
||||
}
|
||||
slog.Info("model pulled", "model", modelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
|
||||
requestJSON, err := json.Marshal(genReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Error serializing request: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if t.Failed() && os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
// TODO
|
||||
fp, err := os.Open(lifecycle.ServerLogFile)
|
||||
if err != nil {
|
||||
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("SERVER LOG FOLLOWS")
|
||||
os.Stderr.Write(data)
|
||||
slog.Warn("END OF SERVER")
|
||||
}
|
||||
err = os.Remove(lifecycle.ServerLogFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
}
|
||||
done := make(chan int)
|
||||
go func() {
|
||||
pullError = client.Pull(ctx, pullReq, fn)
|
||||
done <- 0
|
||||
}()
|
||||
scheme, testEndpoint := GetTestEndpoint()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
return fmt.Errorf("download stalled")
|
||||
case <-done:
|
||||
return pullError
|
||||
}
|
||||
}
|
||||
|
||||
var serverProcMutex sync.Mutex
|
||||
|
||||
// Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
|
||||
// Starts the server if needed
|
||||
func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
|
||||
client, testEndpoint := GetTestEndpoint()
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
assert.NoError(t, StartServer(ctx, testEndpoint))
|
||||
serverProcMutex.Lock()
|
||||
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate log file: %s", err)
|
||||
}
|
||||
lifecycle.ServerLogFile = fp.Name()
|
||||
fp.Close()
|
||||
require.NoError(t, startServer(ctx, testEndpoint))
|
||||
}
|
||||
|
||||
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
|
||||
if err != nil {
|
||||
t.Fatalf("Error pulling model: %v", err)
|
||||
}
|
||||
|
||||
// Make the request and get the response
|
||||
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating request: %v", err)
|
||||
}
|
||||
|
||||
// Set the content type for the request
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Make the request with the HTTP client
|
||||
response, err := client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
t.Fatalf("Error making request: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
body, err := io.ReadAll(response.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, response.StatusCode, 200, string(body))
|
||||
|
||||
// Verify the response is valid JSON
|
||||
var payload api.GenerateResponse
|
||||
err = json.Unmarshal(body, &payload)
|
||||
if err != nil {
|
||||
assert.NoError(t, err, body)
|
||||
}
|
||||
|
||||
// Verify the response contains the expected data
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(payload.Response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
return client, testEndpoint, func() {
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
defer serverProcMutex.Unlock()
|
||||
if t.Failed() {
|
||||
fp, err := os.Open(lifecycle.ServerLogFile)
|
||||
if err != nil {
|
||||
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("SERVER LOG FOLLOWS")
|
||||
os.Stderr.Write(data)
|
||||
slog.Warn("END OF SERVER")
|
||||
}
|
||||
err := os.Remove(lifecycle.ServerLogFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
|
||||
}
|
||||
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
|
||||
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
// fmt.Print(".")
|
||||
buf.Write([]byte(response.Response))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
genReq.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Generate(ctx, &genReq, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
|
||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a set of requests
|
||||
// By default each request uses orca-mini as the model
|
||||
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
return []api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the color of dirt brown?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of independence day?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the composition of air?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
[][]string{
|
||||
[]string{"sunlight"},
|
||||
[]string{"soil", "organic", "earth", "black", "tan"},
|
||||
[]string{"england", "english", "massachusetts", "pilgrims"},
|
||||
[]string{"fourth", "july", "declaration", "independence"},
|
||||
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
#include "dyn_ext_server.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef __linux__
|
||||
#include <dlfcn.h>
|
||||
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
|
||||
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
|
||||
#define LOAD_ERR() strdup(dlerror())
|
||||
#define UNLOAD_LIBRARY(handle) dlclose(handle)
|
||||
#elif _WIN32
|
||||
#include <windows.h>
|
||||
#define LOAD_LIBRARY(lib, flags) LoadLibrary(lib)
|
||||
#define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym)
|
||||
#define UNLOAD_LIBRARY(handle) FreeLibrary(handle)
|
||||
#define LOAD_ERR() ({\
|
||||
LPSTR messageBuffer = NULL; \
|
||||
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, \
|
||||
NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); \
|
||||
char *resp = strdup(messageBuffer); \
|
||||
LocalFree(messageBuffer); \
|
||||
resp; \
|
||||
})
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
|
||||
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
|
||||
#define LOAD_ERR() strdup(dlerror())
|
||||
#define UNLOAD_LIBRARY(handle) dlclose(handle)
|
||||
#endif
|
||||
|
||||
void dyn_init(const char *libPath, struct dynamic_llama_server *s,
|
||||
ext_server_resp_t *err) {
|
||||
int i = 0;
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[] = {
|
||||
{"llama_server_init", (void *)&s->llama_server_init},
|
||||
{"llama_server_start", (void *)&s->llama_server_start},
|
||||
{"llama_server_stop", (void *)&s->llama_server_stop},
|
||||
{"llama_server_completion", (void *)&s->llama_server_completion},
|
||||
{"llama_server_completion_next_result",
|
||||
(void *)&s->llama_server_completion_next_result},
|
||||
{"llama_server_completion_cancel",
|
||||
(void *)&s->llama_server_completion_cancel},
|
||||
{"llama_server_release_task_result",
|
||||
(void *)&s->llama_server_release_task_result},
|
||||
{"llama_server_tokenize", (void *)&s->llama_server_tokenize},
|
||||
{"llama_server_detokenize", (void *)&s->llama_server_detokenize},
|
||||
{"llama_server_embedding", (void *)&s->llama_server_embedding},
|
||||
{"llama_server_release_json_resp",
|
||||
(void *)&s->llama_server_release_json_resp},
|
||||
{"", NULL},
|
||||
};
|
||||
|
||||
printf("loading library %s\n", libPath);
|
||||
s->handle = LOAD_LIBRARY(libPath, RTLD_LOCAL|RTLD_NOW);
|
||||
if (!s->handle) {
|
||||
err->id = -1;
|
||||
char *msg = LOAD_ERR();
|
||||
snprintf(err->msg, err->msg_len,
|
||||
"Unable to load dynamic server library: %s", msg);
|
||||
free(msg);
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; l[i].p != NULL; i++) {
|
||||
*l[i].p = LOAD_SYMBOL(s->handle, l[i].s);
|
||||
if (!l[i].p) {
|
||||
UNLOAD_LIBRARY(s->handle);
|
||||
err->id = -1;
|
||||
char *msg = LOAD_ERR();
|
||||
snprintf(err->msg, err->msg_len, "symbol lookup for %s failed: %s",
|
||||
l[i].s, msg);
|
||||
free(msg);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_init(struct dynamic_llama_server s,
|
||||
ext_server_params_t *sparams,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_init(sparams, err);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_start(struct dynamic_llama_server s) {
|
||||
s.llama_server_start();
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_stop(struct dynamic_llama_server s) {
|
||||
s.llama_server_stop();
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_completion(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
ext_server_resp_t *resp) {
|
||||
s.llama_server_completion(json_req, resp);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_completion_next_result(
|
||||
struct dynamic_llama_server s, const int task_id,
|
||||
ext_server_task_result_t *result) {
|
||||
s.llama_server_completion_next_result(task_id, result);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_completion_cancel(
|
||||
struct dynamic_llama_server s, const int task_id, ext_server_resp_t *err) {
|
||||
s.llama_server_completion_cancel(task_id, err);
|
||||
}
|
||||
inline void dyn_llama_server_release_task_result(
|
||||
struct dynamic_llama_server s, ext_server_task_result_t *result) {
|
||||
s.llama_server_release_task_result(result);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_tokenize(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_tokenize(json_req, json_resp, err);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_detokenize(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_detokenize(json_req, json_resp, err);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_embedding(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_embedding(json_req, json_resp, err);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_release_json_resp(
|
||||
struct dynamic_llama_server s, char **json_resp) {
|
||||
s.llama_server_release_json_resp(json_resp);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user