Mirror of https://github.com/ollama/ollama.git (synced 2026-01-21 05:48:35 -05:00)

Compare commits: 50 commits, bmizerany/...rmdisplayl
| Author | SHA1 | Date |
|---|---|---|
|  | 42f2cc408e |  |
|  | 9446b795b5 |  |
|  | 62f8cda3b3 |  |
|  | 6a1de23175 |  |
|  | a7b431e743 |  |
|  | 5a25f93522 |  |
|  | 7e33a017c0 |  |
|  | 8b2c10061c |  |
|  | c5c451ca3b |  |
|  | 2b4ca6cf36 |  |
|  | ad90b9ab3d |  |
|  | 4340f8eba4 |  |
|  | 4c7db6b7e9 |  |
|  | c03f0e3c3d |  |
|  | c5ff443b9f |  |
|  | 01114b4526 |  |
|  | 1524f323a3 |  |
|  | fccf3eecaa |  |
|  | c77d45d836 |  |
|  | 5ec12cec6c |  |
|  | d9578d2bad |  |
|  | cb8352d6b4 |  |
|  | fc6558f47f |  |
|  | 9502e5661f |  |
|  | e1c9a2a00f |  |
|  | 1341ee1b56 |  |
|  | 63efa075a0 |  |
|  | cb03fc9571 |  |
|  | a5ec9cfc0f |  |
|  | be517e491c |  |
|  | fc8e108642 |  |
|  | c5d5c4a96c |  |
|  | dfe330fa1c |  |
|  | 01f77ae25d |  |
|  | 483b81a863 |  |
|  | 36bd967722 |  |
|  | b0e7d35db8 |  |
|  | aeb1fb5192 |  |
|  | a2e60ebcaf |  |
|  | 883ec4d1ef |  |
|  | 4de0126719 |  |
|  | 9768e2dc75 |  |
|  | 08600d5bec |  |
|  | a624e672d2 |  |
|  | e4a7e5b2ca |  |
|  | a0a15cfd5b |  |
|  | 12e923e158 |  |
|  | cd135317d2 |  |
|  | 4f895d633f |  |
|  | 90f071c658 |  |
.github/workflows/release.yaml (vendored): 67 lines changed

@@ -8,7 +8,7 @@ on:
 jobs:
   # Full build of the Mac assets
   build-darwin:
-    runs-on: macos-latest
+    runs-on: macos-12
     environment: release
     steps:
       - uses: actions/checkout@v4
@@ -30,7 +30,7 @@ jobs:
           security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - name: Build Darwin
         env:
@@ -38,9 +38,11 @@ jobs:
           APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
           APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
          APPLE_ID: ${{ vars.APPLE_ID }}
+          SDKROOT: /Applications/Xcode_13.4.1.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
+          DEVELOPER_DIR: /Applications/Xcode_13.4.1.app/Contents/Developer
         run: |
           ./scripts/build_darwin.sh
       - uses: actions/upload-artifact@v4
         with:
           name: dist-darwin
@@ -48,7 +50,6 @@ jobs:
             dist/*arwin*
             !dist/*-cov
 
-
   # Windows builds take a long time to both install the dependencies and build, so parallelize
   # CPU generation step
   generate-windows-cpu:
@@ -85,7 +86,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - run: go get ./...
       - run: |
@@ -99,7 +100,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-cpu
-          path: llm/llama.cpp/build/**/lib/*
+          path: |
+            llm/build/**/bin/*
+            llm/build/**/*.a
 
   # ROCm generation step
   generate-windows-rocm:
@@ -136,9 +139,9 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
-      - name: "Install ROCm"
+      - name: 'Install ROCm'
        run: |
          $ErrorActionPreference = "Stop"
          write-host "downloading AMD HIP Installer"
@@ -146,7 +149,7 @@ jobs:
          write-host "Installing AMD HIP"
          Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
          write-host "Completed AMD HIP"
-      - name: "Verify ROCm"
+      - name: 'Verify ROCm'
        run: |
          & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
      - run: go get ./...
@@ -160,7 +163,7 @@ jobs:
          $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
          go generate -x ./...
        name: go generate
-      - name: "gather rocm dependencies"
+      - name: 'gather rocm dependencies'
        run: |
          $HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
          md "dist\deps\bin\rocblas\library"
@@ -170,7 +173,7 @@ jobs:
      - uses: actions/upload-artifact@v4
        with:
          name: generate-windows-rocm
-          path: llm/llama.cpp/build/**/lib/*
+          path: llm/build/**/bin/*
      - uses: actions/upload-artifact@v4
        with:
          name: windows-rocm-deps
@@ -211,9 +214,9 @@ jobs:
          write-host "plugin installed"
      - uses: actions/setup-go@v5
        with:
-          go-version: '1.22'
+          go-version-file: go.mod
          cache: true
-      - name: "Install CUDA"
+      - name: 'Install CUDA'
        run: |
          $ErrorActionPreference = "Stop"
          write-host "downloading CUDA Installer"
@@ -227,7 +230,7 @@ jobs:
          echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
          echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
          echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
-      - name: "Verify CUDA"
+      - name: 'Verify CUDA'
        run: nvcc -V
      - run: go get ./...
      - name: go generate
@@ -240,7 +243,7 @@ jobs:
          $env:PATH="$gopath;$cudabin;$env:PATH"
          $env:OLLAMA_SKIP_CPU_GENERATE="1"
          go generate -x ./...
-      - name: "gather cuda dependencies"
+      - name: 'gather cuda dependencies'
        run: |
          $NVIDIA_DIR=(resolve-path 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*\bin\')[0]
          md "dist\deps"
@@ -250,7 +253,7 @@ jobs:
      - uses: actions/upload-artifact@v4
        with:
          name: generate-windows-cuda
-          path: llm/llama.cpp/build/**/lib/*
+          path: llm/build/**/bin/*
      - uses: actions/upload-artifact@v4
        with:
          name: windows-cuda-deps
@@ -297,17 +300,17 @@ jobs:
          write-host "plugin installed"
      - uses: actions/setup-go@v5
        with:
-          go-version: '1.22'
+          go-version-file: go.mod
          cache: true
      - run: go get
      - uses: actions/download-artifact@v4
        with:
          name: generate-windows-cpu
-          path: llm/llama.cpp/build
+          path: llm/build
      - uses: actions/download-artifact@v4
        with:
          name: generate-windows-cuda
-          path: llm/llama.cpp/build
+          path: llm/build
      - uses: actions/download-artifact@v4
        with:
          name: windows-cuda-deps
@@ -319,8 +322,8 @@ jobs:
      - uses: actions/download-artifact@v4
        with:
          name: generate-windows-rocm
-          path: llm/llama.cpp/build
-      - run: dir llm/llama.cpp/build
+          path: llm/build
+      - run: dir llm/build
      - run: |
          $gopath=(get-command go).source | split-path -parent
          & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
@@ -336,14 +339,14 @@ jobs:
          name: dist-windows
          path: dist/*.exe
 
   # Linux x86 assets built using the container based build
   build-linux-amd64:
     environment: release
     runs-on: linux
     env:
-      OLLAMA_SKIP_MANIFEST_CREATE: "1"
+      OLLAMA_SKIP_MANIFEST_CREATE: '1'
       BUILD_ARCH: amd64
-      PUSH: "1"
+      PUSH: '1'
     steps:
       - uses: actions/checkout@v4
         with:
@@ -373,9 +376,9 @@ jobs:
     environment: release
     runs-on: linux-arm64
     env:
-      OLLAMA_SKIP_MANIFEST_CREATE: "1"
+      OLLAMA_SKIP_MANIFEST_CREATE: '1'
       BUILD_ARCH: arm64
-      PUSH: "1"
+      PUSH: '1'
     steps:
       - uses: actions/checkout@v4
         with:
@@ -383,7 +386,7 @@ jobs:
      - name: Set Version
        shell: bash
        run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
-      - name: "Install Docker"
+      - name: 'Install Docker'
        run: |
          # Add Docker's official GPG key:
          env
@@ -420,7 +423,7 @@ jobs:
             !dist/*-cov
 
   # Aggregate all the assets and ship a release
   release:
     needs:
       - build-darwin
       - build-windows
@@ -431,8 +434,8 @@ jobs:
     permissions:
       contents: write
     env:
-      OLLAMA_SKIP_IMAGE_BUILD: "1"
-      PUSH: "1"
+      OLLAMA_SKIP_IMAGE_BUILD: '1'
+      PUSH: '1'
     steps:
       - uses: actions/checkout@v4
       - name: Set Version
@@ -460,11 +463,11 @@ jobs:
         with:
           name: ${{ env.RELEASE_VERSION }}
           allowUpdates: true
-          artifacts: "dist/*"
+          artifacts: 'dist/*'
           draft: true
           prerelease: true
           omitBodyDuringUpdate: true
           generateReleaseNotes: true
           omitDraftDuringUpdate: true
           omitPrereleaseDuringUpdate: true
           replacesArtifacts: true
.github/workflows/test.yaml (vendored): 60 lines changed

@@ -5,7 +5,6 @@ on:
     paths:
       - '**/*'
       - '!docs/**'
       - '!examples/**'
       - '!README.md'
 
 jobs:
@@ -51,7 +50,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - run: go get ./...
       - run: |
@@ -64,10 +63,10 @@ jobs:
           echo $env:PATH
           go generate -x ./...
         if: ${{ startsWith(matrix.os, 'windows-') }}
-        name: "Windows Go Generate"
+        name: 'Windows Go Generate'
       - run: go generate -x ./...
         if: ${{ ! startsWith(matrix.os, 'windows-') }}
-        name: "Unix Go Generate"
+        name: 'Unix Go Generate'
       - uses: actions/upload-artifact@v4
         with:
           name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
@@ -93,7 +92,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v4
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - run: go get ./...
       - run: |
@@ -124,7 +123,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v4
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - run: go get ./...
       - run: |
@@ -135,7 +134,7 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: rocm-${{ matrix.rocm-version }}-libraries
-          path: llm/build/**/lib/*
+          path: llm/build/**/bin/*
 
   # ROCm generation step
   generate-windows-rocm:
@@ -146,9 +145,9 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
-      - name: "Install ROCm"
+      - name: 'Install ROCm'
        run: |
          $ErrorActionPreference = "Stop"
          write-host "downloading AMD HIP Installer"
@@ -156,7 +155,7 @@ jobs:
          write-host "Installing AMD HIP"
          Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
          write-host "Completed AMD HIP"
-      - name: "Verify ROCm"
+      - name: 'Verify ROCm'
        run: |
          & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
      - run: go get ./...
@@ -183,9 +182,9 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
-      - name: "Install CUDA"
+      - name: 'Install CUDA'
        run: |
          $ErrorActionPreference = "Stop"
          write-host "downloading CUDA Installer"
@@ -199,7 +198,7 @@ jobs:
          echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
          echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
          echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
-      - name: "Verify CUDA"
+      - name: 'Verify CUDA'
        run: nvcc -V
      - run: go get ./...
      - name: go generate
@@ -216,7 +215,6 @@ jobs:
           OLLAMA_SKIP_CPU_GENERATE: '1'
           # TODO - do we need any artifacts?
 
-
   lint:
     strategy:
       matrix:
@@ -239,7 +237,7 @@ jobs:
           submodules: recursive
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: false
       - run: |
           case ${{ matrix.arch }} in
@@ -248,18 +246,18 @@ jobs:
           esac >>$GITHUB_ENV
         shell: bash
       - run: |
-          mkdir -p llm/build/linux/$ARCH/stub/bin/
-          touch llm/build/linux/$ARCH/stub/bin/stub.so
+          mkdir -p llm/build/linux/$ARCH/stub/bin
+          touch llm/build/linux/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'ubuntu-') }}
       - run: |
-          mkdir -p llm/build/darwin/$ARCH/stub/bin/
-          touch llm/build/darwin/$ARCH/stub/bin/stub.dylib
-          touch llm/ggml-metal.metal
+          mkdir -p llm/build/darwin/$ARCH/stub/bin
+          touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'macos-') }}
       - run: |
-          mkdir -p llm/build/windows/$ARCH/stub/stub/bin/
-          touch llm/build/windows/$ARCH/stub/stub/bin/stub.dll
+          mkdir -p llm/build/windows/$ARCH/stub/bin
+          touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'windows-') }}
         shell: bash
       - uses: golangci/golangci-lint-action@v4
         with:
           args: --timeout 8m0s
@@ -277,14 +275,14 @@ jobs:
     env:
       GOARCH: ${{ matrix.arch }}
       CGO_ENABLED: '1'
-      OLLAMA_CPU_TARGET: "static"
+      OLLAMA_CPU_TARGET: 'static'
     steps:
       - uses: actions/checkout@v4
         with:
           submodules: recursive
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version-file: go.mod
           cache: true
       - run: go get
       - run: |
@@ -294,18 +292,18 @@ jobs:
           esac >>$GITHUB_ENV
         shell: bash
       - run: |
-          mkdir -p llm/build/linux/$ARCH/stub/bin/
-          touch llm//build/linux/$ARCH/stub/bin/stub.so
+          mkdir -p llm/build/linux/$ARCH/stub/bin
+          touch llm/build/linux/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'ubuntu-') }}
       - run: |
-          mkdir -p llm/build/darwin/$ARCH/stub/bin/
-          touch llm/build/darwin/$ARCH/stub/bin/stub.dylib
-          touch llm/ggml-metal.metal
+          mkdir -p llm/build/darwin/$ARCH/stub/bin
+          touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'macos-') }}
       - run: |
-          mkdir -p llm/build/windows/$ARCH/stub/stub/bin/
-          touch llm/build/windows/$ARCH/stub/stub/bin/stub.dll
+          mkdir -p llm/build/windows/$ARCH/stub/bin
+          touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'windows-') }}
         shell: bash
       - run: go generate ./...
       - run: go build
       - run: go test -v ./...
@@ -292,6 +292,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
 - [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
 - [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
+- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
 
 ### Terminal
@@ -1,3 +1,9 @@
+// Package api implements the client-side API for code wishing to interact
+// with the ollama service. The methods of the [Client] type correspond to
+// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
+//
+// The ollama command-line client itself uses this package to interact with
+// the backend service.
 package api
 
 import (
@@ -5,7 +11,6 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -19,6 +24,8 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
+// Client encapsulates client state for interacting with the ollama
+// service. Use [ClientFromEnvironment] to create new Clients.
 type Client struct {
 	base *url.URL
 	http *http.Client
@@ -40,6 +47,15 @@ func checkError(resp *http.Response, body []byte) error {
 	return apiError
 }
 
+// ClientFromEnvironment creates a new [Client] using configuration from the
+// environment variable OLLAMA_HOST, which points to the network host and
+// port on which the ollama service is listenting. The format of this variable
+// is:
+//
+//	<scheme>://<host>:<port>
+//
+// If the variable is not specified, a default ollama host and port will be
+// used.
 func ClientFromEnvironment() (*Client, error) {
 	defaultPort := "11434"
 
@@ -191,8 +207,14 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 	return nil
 }
 
+// GenerateResponseFunc is a function that [Client.Generate] invokes every time
+// a response is received from the service. If this function returns an error,
+// [Client.Generate] will stop generating and return this error.
 type GenerateResponseFunc func(GenerateResponse) error
 
+// Generate generates a response for a given prompt. The req parameter should
+// be populated with prompt details. fn is called for each response (there may
+// be multiple responses, e.g. in case streaming is enabled).
 func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
 		var resp GenerateResponse
@@ -204,8 +226,15 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
 	})
 }
 
+// ChatResponseFunc is a function that [Client.Chat] invokes every time
+// a response is received from the service. If this function returns an error,
+// [Client.Chat] will stop generating and return this error.
 type ChatResponseFunc func(ChatResponse) error
 
+// Chat generates the next message in a chat. [ChatRequest] may contain a
+// sequence of messages which can be used to maintain chat history with a model.
+// fn is called for each response (there may be multiple responses, e.g. if case
+// streaming is enabled).
 func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
 		var resp ChatResponse
@@ -217,8 +246,14 @@ func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc
 	})
 }
 
+// PullProgressFunc is a function that [Client.Pull] invokes every time there
+// is progress with a "pull" request sent to the service. If this function
+// returns an error, [Client.Pull] will stop the process and return this error.
 type PullProgressFunc func(ProgressResponse) error
 
+// Pull downloads a model from the ollama library. fn is called each time
+// progress is made on the request and can be used to display a progress bar,
+// etc.
 func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
 		var resp ProgressResponse
@@ -301,18 +336,7 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
 }
 
 func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
-	if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil {
-		var statusError StatusError
-		if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
-			return err
-		}
-
-		if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil {
-			return err
-		}
-	}
-
-	return nil
+	return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
 }
 
 func (c *Client) Version(ctx context.Context) (string, error) {
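The doc comments added above describe the full client flow end to end. A minimal sketch of that flow, assuming a locally running ollama service and an already pulled model (the model name here is illustrative):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Honors OLLAMA_HOST (<scheme>://<host>:<port>); defaults apply when unset.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model: "llama2", // hypothetical model name for illustration
		Messages: []api.Message{
			{Role: "user", Content: "hello"},
		},
	}

	// The callback matches ChatResponseFunc: it runs once per streamed
	// response, and returning an error stops generation.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}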
api/types.go: 110 lines changed

@@ -33,18 +33,46 @@ func (e StatusError) Error() string {
 
 type ImageData []byte
 
+// GenerateRequest describes a request sent by [Client.Generate]. While you
+// have to specify the Model and Prompt fields, all the other fields have
+// reasonable defaults for basic uses.
 type GenerateRequest struct {
-	Model     string      `json:"model"`
-	Prompt    string      `json:"prompt"`
-	System    string      `json:"system"`
-	Template  string      `json:"template"`
-	Context   []int       `json:"context,omitempty"`
-	Stream    *bool       `json:"stream,omitempty"`
-	Raw       bool        `json:"raw,omitempty"`
-	Format    string      `json:"format"`
-	KeepAlive *Duration   `json:"keep_alive,omitempty"`
-	Images    []ImageData `json:"images,omitempty"`
+	// Model is the model name; it should be a name familiar to Ollama from
+	// the library at https://ollama.com/library
+	Model string `json:"model"`
+
+	// Prompt is the textual prompt to send to the model.
+	Prompt string `json:"prompt"`
+
+	// System overrides the model's default system message/prompt.
+	System string `json:"system"`
+
+	// Template overrides the model's default prompt template.
+	Template string `json:"template"`
+
+	// Context is the context parameter returned from a previous call to
+	// Generate call. It can be used to keep a short conversational memory.
+	Context []int `json:"context,omitempty"`
+
+	// Stream specifies whether the response is streaming; it is true by default.
+	Stream *bool `json:"stream,omitempty"`
+
+	// Raw set to true means that no formatting will be applied to the prompt.
+	Raw bool `json:"raw,omitempty"`
+
+	// Format specifies the format to return a response in.
+	Format string `json:"format"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
+
+	// Images is an optional list of base64-encoded images accompanying this
+	// request, for multimodal models.
+	Images []ImageData `json:"images,omitempty"`
 
+	// Options lists model-specific options. For example, temperature can be
+	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
 }
@@ -109,21 +137,24 @@ type Options struct {
 
 // Runner options which must be set when the model is loaded into memory
 type Runner struct {
-	UseNUMA            bool    `json:"numa,omitempty"`
-	NumCtx             int     `json:"num_ctx,omitempty"`
-	NumBatch           int     `json:"num_batch,omitempty"`
-	NumGQA             int     `json:"num_gqa,omitempty"`
-	NumGPU             int     `json:"num_gpu,omitempty"`
-	MainGPU            int     `json:"main_gpu,omitempty"`
-	LowVRAM            bool    `json:"low_vram,omitempty"`
-	F16KV              bool    `json:"f16_kv,omitempty"`
-	LogitsAll          bool    `json:"logits_all,omitempty"`
-	VocabOnly          bool    `json:"vocab_only,omitempty"`
-	UseMMap            bool    `json:"use_mmap,omitempty"`
-	UseMLock           bool    `json:"use_mlock,omitempty"`
-	RopeFrequencyBase  float32 `json:"rope_frequency_base,omitempty"`
-	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
-	NumThread          int     `json:"num_thread,omitempty"`
+	UseNUMA   bool `json:"numa,omitempty"`
+	NumCtx    int  `json:"num_ctx,omitempty"`
+	NumBatch  int  `json:"num_batch,omitempty"`
+	NumGQA    int  `json:"num_gqa,omitempty"`
+	NumGPU    int  `json:"num_gpu,omitempty"`
+	MainGPU   int  `json:"main_gpu,omitempty"`
+	LowVRAM   bool `json:"low_vram,omitempty"`
+	F16KV     bool `json:"f16_kv,omitempty"`
+	LogitsAll bool `json:"logits_all,omitempty"`
+	VocabOnly bool `json:"vocab_only,omitempty"`
+	UseMMap   bool `json:"use_mmap,omitempty"`
+	UseMLock  bool `json:"use_mlock,omitempty"`
+	NumThread int  `json:"num_thread,omitempty"`
+
+	// Unused: RopeFrequencyBase is ignored. Instead the value in the model will be used
+	RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
+	// Unused: RopeFrequencyScale is ignored. Instead the value in the model will be used
+	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
 }
 
 type EmbeddingRequest struct {
@@ -139,10 +170,11 @@ type EmbeddingResponse struct {
 }
 
 type CreateRequest struct {
-	Model     string `json:"model"`
-	Path      string `json:"path"`
-	Modelfile string `json:"modelfile"`
-	Stream    *bool  `json:"stream,omitempty"`
+	Model        string `json:"model"`
+	Path         string `json:"path"`
+	Modelfile    string `json:"modelfile"`
+	Stream       *bool  `json:"stream,omitempty"`
+	Quantization string `json:"quantization,omitempty"`
 
 	// Name is deprecated, see Model
 	Name string `json:"name"`
@@ -382,18 +414,16 @@ func DefaultOptions() Options {
 
 		Runner: Runner{
 			// options set when the model is loaded
-			NumCtx:             2048,
-			RopeFrequencyBase:  10000.0,
-			RopeFrequencyScale: 1.0,
-			NumBatch:           512,
-			NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
-			NumGQA:             1,
-			NumThread:          0, // let the runtime decide
-			LowVRAM:            false,
-			F16KV:              true,
-			UseMLock:           false,
-			UseMMap:            true,
-			UseNUMA:            false,
+			NumCtx:    2048,
+			NumBatch:  512,
+			NumGPU:    -1, // -1 here indicates that NumGPU should be set dynamically
+			NumGQA:    1,
+			NumThread: 0, // let the runtime decide
+			LowVRAM:   false,
+			F16KV:     true,
+			UseMLock:  false,
+			UseMMap:   true,
+			UseNUMA:   false,
 		},
 	}
 }
@@ -24,10 +24,5 @@ func NewTray() (commontray.OllamaTray, error) {
 		return nil, fmt.Errorf("failed to load icon %s: %w", iconName, err)
 	}
 
-	tray, err := InitPlatformTray(icon, updateIcon)
-	if err != nil {
-		return nil, err
-	}
-
-	return tray, nil
+	return InitPlatformTray(icon, updateIcon)
 }
@@ -194,7 +194,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return nil
 	}
 
-	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)}
+	quantization, _ := cmd.Flags().GetString("quantization")
+
+	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
 	if err := client.Create(cmd.Context(), &request, fn); err != nil {
 		return err
 	}
@@ -943,6 +945,7 @@ func NewCLI() *cobra.Command {
 	}
 
 	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
+	createCmd.Flags().StringP("quantization", "q", "", "Quantization level.")
 
 	showCmd := &cobra.Command{
 		Use: "show MODEL",
@@ -32,7 +32,6 @@ type Params struct {
 	AttentionHeads int     `json:"num_attention_heads"` // n_head
 	KeyValHeads    int     `json:"num_key_value_heads"`
 	NormEPS        float64 `json:"rms_norm_eps"`
-	RopeFreqBase   float64 `json:"rope_theta"`
 	BoSTokenID     int     `json:"bos_token_id"`
 	EoSTokenID     int     `json:"eos_token_id"`
 	HeadDimension  int     `json:"head_dim"`
@@ -144,7 +144,6 @@ func (m *MistralModel) WriteGGUF() (string, error) {
 		"llama.attention.head_count":             uint32(m.Params.AttentionHeads),
 		"llama.attention.head_count_kv":          uint32(m.Params.KeyValHeads),
 		"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
-		"llama.rope.freq_base":                   float32(m.Params.RopeFreqBase),
 		"general.file_type":                      uint32(1),
 		"tokenizer.ggml.model":                   "llama",
@@ -394,7 +394,6 @@ Advanced parameters (optional):
 
 - `format`: the format to return a response in. Currently the only accepted value is `json`
 - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
 - `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
examples/go-generate-streaming/main.go (new file): 40 lines

@@ -0,0 +1,40 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+
+	"github.com/ollama/ollama/api"
+)
+
+func main() {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// By default, GenerateRequest is streaming.
+	req := &api.GenerateRequest{
+		Model:  "gemma",
+		Prompt: "how many planets are there?",
+	}
+
+	ctx := context.Background()
+	respFunc := func(resp api.GenerateResponse) error {
+		// Only print the response here; GenerateResponse has a number of other
+		// interesting fields you want to examine.
+
+		// In streaming mode, responses are partial so we call fmt.Print (and not
+		// Println) in order to avoid spurious newlines being introduced. The
+		// model will insert its own newlines if it wants.
+		fmt.Print(resp.Response)
+		return nil
+	}
+
+	err = client.Generate(ctx, req, respFunc)
+	if err != nil {
+		log.Fatal(err)
+	}
+	fmt.Println()
+}
examples/go-generate/main.go (new file): 37 lines

@@ -0,0 +1,37 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+
+	"github.com/ollama/ollama/api"
+)
+
+func main() {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	req := &api.GenerateRequest{
+		Model:  "gemma",
+		Prompt: "how many planets are there?",
+
+		// set streaming to false
+		Stream: new(bool),
+	}
+
+	ctx := context.Background()
+	respFunc := func(resp api.GenerateResponse) error {
+		// Only print the response here; GenerateResponse has a number of other
+		// interesting fields you want to examine.
+		fmt.Println(resp.Response)
+		return nil
+	}
+
+	err = client.Generate(ctx, req, respFunc)
+	if err != nil {
+		log.Fatal(err)
+	}
+}
@@ -50,7 +50,7 @@ func HumanBytes(b int64) string {
 	}
 }
 
-func HumanBytes2(b int64) string {
+func HumanBytes2(b uint64) string {
 	switch {
 	case b >= MebiByte:
 		return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
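HumanBytes2 now takes a uint64, matching the unsigned byte counts the GPU discovery code returns elsewhere in this comparison. A small usage sketch, assuming the github.com/ollama/ollama/format import path:

package main

import (
	"fmt"

	"github.com/ollama/ollama/format"
)

func main() {
	// Hypothetical value: 5 MiB of memory reported as an unsigned count.
	var freeMemory uint64 = 5 << 20
	fmt.Println(format.HumanBytes2(freeMemory)) // prints "5.0 MiB"
}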
go.mod: 28 lines changed

@@ -10,7 +10,7 @@ require (
 	github.com/emirpasic/gods v1.18.1
 	github.com/gin-gonic/gin v1.9.1
 	github.com/golang/protobuf v1.5.0 // indirect
-	github.com/google/uuid v1.6.0
+	github.com/google/uuid v1.0.0
 	github.com/mitchellh/mapstructure v1.5.0
 	github.com/olekukonko/tablewriter v0.0.5
 	github.com/spf13/cobra v1.7.0
@@ -19,35 +19,23 @@ require (
 	golang.org/x/sync v0.3.0
 )
 
-require (
-	github.com/minio/minio-go/v7 v7.0.69
-	github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
-	kr.dev/diff v0.3.0
-)
+require github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
 
 require (
 	github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc // indirect
 	github.com/chewxy/hm v1.0.0 // indirect
 	github.com/chewxy/math32 v1.0.8 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
-	github.com/dustin/go-humanize v1.0.1 // indirect
 	github.com/gogo/protobuf v1.3.2 // indirect
 	github.com/google/flatbuffers v1.12.0 // indirect
-	github.com/klauspost/compress v1.17.6 // indirect
 	github.com/mattn/go-runewidth v0.0.14 // indirect
-	github.com/minio/md5-simd v1.1.2 // indirect
-	github.com/minio/sha256-simd v1.0.1 // indirect
-	github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect
 	github.com/pkg/errors v0.9.1 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/rivo/uniseg v0.2.0 // indirect
-	github.com/rogpeppe/go-internal v1.8.1 // indirect
-	github.com/rs/xid v1.5.0 // indirect
 	github.com/xtgo/set v1.0.0 // indirect
 	go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
 	gonum.org/v1/gonum v0.8.2 // indirect
-	gopkg.in/ini.v1 v1.67.0 // indirect
 	gorgonia.org/vecf32 v0.9.0 // indirect
 	gorgonia.org/vecf64 v0.9.0 // indirect
 )
@@ -65,7 +53,7 @@ require (
 	github.com/google/go-cmp v0.5.9 // indirect
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
-	github.com/klauspost/cpuid/v2 v2.2.6 // indirect
+	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
 	github.com/leodido/go-urn v1.2.4 // indirect
 	github.com/mattn/go-isatty v0.0.19 // indirect
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
@@ -75,12 +63,12 @@ require (
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.11 // indirect
 	golang.org/x/arch v0.3.0 // indirect
-	golang.org/x/crypto v0.19.0
+	golang.org/x/crypto v0.14.0
 	golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
-	golang.org/x/net v0.21.0 // indirect
-	golang.org/x/sys v0.17.0
-	golang.org/x/term v0.17.0
-	golang.org/x/text v0.14.0 // indirect
+	golang.org/x/net v0.17.0 // indirect
+	golang.org/x/sys v0.13.0
+	golang.org/x/term v0.13.0
+	golang.org/x/text v0.13.0 // indirect
 	google.golang.org/protobuf v1.30.0
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
go.sum: 51 lines changed

@@ -26,8 +26,6 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
-github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
 github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
 github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -88,8 +86,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
 github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
 github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
-github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
-github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=
+github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
 github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
@@ -97,12 +95,9 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
 github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
-github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
-github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
 github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
 github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
-github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
-github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
+github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
+github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
 github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
 github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
 github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
@@ -120,12 +115,6 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
 github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
 github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
 github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
-github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
-github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
-github.com/minio/minio-go/v7 v7.0.69 h1:l8AnsQFyY1xiwa/DaQskY4NXSLA2yrGsW5iD9nRPVS0=
-github.com/minio/minio-go/v7 v7.0.69/go.mod h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ=
-github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
-github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
 github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
 github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -140,7 +129,6 @@ github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9/go.mod h1:nR7l3gM6u
 github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
 github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
 github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
-github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -150,11 +138,8 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:
 github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
 github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
 github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
+github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
+github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
-github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
-github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
-github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
-github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
 github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
 github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
@@ -196,8 +181,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
-golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
+golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
+golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
 golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -220,8 +205,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
-golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
+golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
+golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -241,18 +226,18 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
-golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
+golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
-golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
+golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
+golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
-golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
+golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
+golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -307,8 +292,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
 gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
-gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
-gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
 gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
@@ -320,6 +303,4 @@ gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A=
 gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA=
 honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
 honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
-kr.dev/diff v0.3.0 h1:o/T8/tkAq9IuRIuFqCupyKPC5iSY3WXpVZ2p6ZK3Emw=
-kr.dev/diff v0.3.0/go.mod h1:XiTaLOg2/PD0cmXY7WQXUR8RAF3RwWpqIQEj910J2NY=
 rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
@@ -243,7 +243,7 @@ func getCPUMem() (memInfo, error) {
 	return ret, nil
 }
 
-func CheckVRAM() (int64, error) {
+func CheckVRAM() (uint64, error) {
 	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
 	if userLimit != "" {
 		avail, err := strconv.ParseInt(userLimit, 10, 64)
@@ -251,11 +251,11 @@ func CheckVRAM() (int64, error) {
 			return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
 		}
 		slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
-		return avail, nil
+		return uint64(avail), nil
 	}
 	gpuInfo := GetGPUInfo()
 	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
-		return int64(gpuInfo.FreeMemory), nil
+		return gpuInfo.FreeMemory, nil
 	}
 
 	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
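Both CheckVRAM variants in this comparison keep the same override pattern while switching to unsigned results: the env var is still parsed as a signed int64 and then converted to the uint64 callers now expect. A minimal standalone sketch of that pattern:

package main

import (
	"fmt"
	"os"
	"strconv"
)

// maxVRAM mirrors the OLLAMA_MAX_VRAM override handling after the signature
// change: parse as int64, return as uint64. The fallback error stands in for
// the real GPU-probing paths.
func maxVRAM() (uint64, error) {
	if userLimit := os.Getenv("OLLAMA_MAX_VRAM"); userLimit != "" {
		avail, err := strconv.ParseInt(userLimit, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
		}
		return uint64(avail), nil
	}
	return 0, fmt.Errorf("no override set")
}

func main() {
	os.Setenv("OLLAMA_MAX_VRAM", "4294967296") // hypothetical 4 GiB cap
	fmt.Println(maxVRAM())
}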
@@ -17,7 +17,7 @@ import (
 )
 
 // CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
-func CheckVRAM() (int64, error) {
+func CheckVRAM() (uint64, error) {
 	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
 	if userLimit != "" {
 		avail, err := strconv.ParseInt(userLimit, 10, 64)
@@ -25,15 +25,14 @@ func CheckVRAM() (int64, error) {
 			return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
 		}
 		slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
-		return avail, nil
+		return uint64(avail), nil
 	}
 
 	if runtime.GOARCH == "amd64" {
 		// gpu not supported, this may not be metal
 		return 0, nil
 	}
-	recommendedMaxVRAM := int64(C.getRecommendedMaxVRAM())
-	return recommendedMaxVRAM, nil
+	return uint64(C.getRecommendedMaxVRAM()), nil
 }
 
 func GetGPUInfo() GpuInfo {
@@ -15,7 +15,7 @@ type GpuInfo struct {
 	Variant string `json:"variant,omitempty"`
 
 	// MinimumMemory represents the minimum memory required to use the GPU
-	MinimumMemory int64 `json:"-"`
+	MinimumMemory uint64 `json:"-"`
 
 	// TODO add other useful attributes about the card here for discovery information
 }
integration/context_test.go (new file): 29 lines

@@ -0,0 +1,29 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+func TestContextExhaustion(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) // TODO maybe shorter?
+	defer cancel()
+	// Set up the test data
+	req := api.GenerateRequest{
+		Model:  "llama2",
+		Prompt: "Write me a story with a ton of emojis?",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+			"num_ctx":     128,
+		},
+	}
+	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
+}
@@ -15,10 +15,6 @@ import (
 // TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
 // package to avoid circular dependencies
 
-// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
-//
-// TODO - Fix this ^^
-
 var (
 	stream = false
 	req    = [2]api.GenerateRequest{
@@ -18,7 +18,7 @@ sign() {
     fi
 }
 
-COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on"
+COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on"
 
 case "${GOARCH}" in
 "amd64")
@@ -41,7 +41,7 @@ case "${GOARCH}" in
         BUILD_DIR="../build/darwin/${ARCH}/cpu"
         echo "Building LCD CPU"
         build
-        sign ${BUILD_DIR}/lib/libext_server.dylib
+        sign ${BUILD_DIR}/bin/ollama_llama_server
         compress
 
         #
@@ -53,7 +53,7 @@ case "${GOARCH}" in
         BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
         echo "Building AVX CPU"
         build
-        sign ${BUILD_DIR}/lib/libext_server.dylib
+        sign ${BUILD_DIR}/bin/ollama_llama_server
         compress
 
         #
@@ -66,7 +66,7 @@ case "${GOARCH}" in
         echo "Building AVX2 CPU"
         EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
         build
-        sign ${BUILD_DIR}/lib/libext_server.dylib
+        sign ${BUILD_DIR}/bin/ollama_llama_server
        compress
        ;;
 "arm64")
@@ -74,17 +74,17 @@ case "${GOARCH}" in
     # Static build for linking into the Go binary
     init_vars
     CMAKE_TARGETS="--target llama --target ggml"
-    CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DBUILD_SHARED_LIBS=off -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
     BUILD_DIR="../build/darwin/${ARCH}_static"
     echo "Building static library"
     build
 
     init_vars
-    CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_METAL_EMBED_LIBRARY=on -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
+    CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
     BUILD_DIR="../build/darwin/${ARCH}/metal"
     EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
     build
-    sign ${BUILD_DIR}/lib/libext_server.dylib
+    sign ${BUILD_DIR}/bin/ollama_llama_server
     compress
     ;;
 *)
@@ -172,7 +172,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
         # Disabling has minimal performance effect while maintaining compatibility.
         ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
     fi
-    CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
+    CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
     BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
     EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
     build
@@ -146,7 +146,7 @@ function compress {
 }
 
 write-host "Compressing dlls..."
-$binaries = dir "${script:buildDir}/bin/*.dll"
+$dlls = dir "${script:buildDir}/bin/*.dll"
 foreach ($file in $dlls) {
     & "$script:GZIP" --best -f $file
 }
@@ -183,9 +183,17 @@ if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
 
 # GCC build for direct linking into the Go binary
 init_vars
+# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
+# as we need this to be compiled by gcc for golang to be able to link with itx
+write-host "Checking for MinGW..."
+# error action ensures we exit on failure
+get-command gcc
+get-command mingw32-make
+$script:cmakeTargets = @("llama", "ggml")
 $script:cmakeDefs = @(
     "-G", "MinGW Makefiles"
     "-DCMAKE_C_COMPILER=gcc.exe",
     "-DCMAKE_CXX_COMPILER=g++.exe",
     "-DBUILD_SHARED_LIBS=off",
     "-DLLAMA_NATIVE=off",
     "-DLLAMA_AVX=off",
@@ -234,7 +242,7 @@ if ($null -ne $script:CUDA_LIB_DIR) {
 }
 init_vars
 $script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
-$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
+$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
 build
 sign
 compress
@@ -253,6 +261,7 @@ if ($null -ne $env:HIP_PATH) {
         "-DCMAKE_C_COMPILER=clang.exe",
         "-DCMAKE_CXX_COMPILER=clang++.exe",
         "-DLLAMA_HIPBLAS=on",
+        "-DHIP_PLATFORM=amd",
         "-DLLAMA_AVX=on",
         "-DLLAMA_AVX2=off",
         "-DCMAKE_POSITION_INDEPENDENT_CODE=on",
@@ -49,7 +49,7 @@ func (llm *ggla) KV() KV {
    return llm.kv
}

func (llm *ggla) Tensors() []*Tensor {
func (llm *ggla) Tensors() Tensors {
    return llm.tensors
}

llm/ggml.go
@@ -13,16 +13,6 @@ type GGML struct {
    model
}

func (ggml *GGML) LayerSize(prefix string) (n int64) {
    for _, t := range ggml.Tensors() {
        if strings.HasPrefix(t.Name, prefix) {
            n += int64(t.size())
        }
    }

    return
}

const (
    fileTypeF32 uint32 = iota
    fileTypeF16
@@ -101,7 +91,7 @@ func fileType(fileType uint32) string {

type model interface {
    KV() KV
    Tensors() []*Tensor
    Tensors() Tensors
}

type KV map[string]any
@@ -148,15 +138,15 @@ func (kv KV) HeadCount() uint64 {
}

func (kv KV) HeadCountKV() uint64 {
    return kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture()))
    if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
        return headCountKV
    }

    return 1
}

func (kv KV) GQA() uint64 {
    if headCountKV := kv.HeadCountKV(); headCountKV > 0 {
        return kv.HeadCount() / headCountKV
    }

    return 0
    return kv.HeadCount() / kv.HeadCountKV()
}

func (kv KV) EmbeddingLength() uint64 {
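The rewritten HeadCountKV falls back to 1 when the `attention.head_count_kv` key is absent or zero, which keeps the division in GQA safe and lets GQA collapse to a one-liner. A minimal sketch of the resulting arithmetic, with hypothetical head counts chosen for illustration:

```go
package main

import "fmt"

// gqa mirrors the patched behavior: a missing head_count_kv (read as zero)
// falls back to 1, so the division below can never panic.
func gqa(headCount, headCountKV uint64) uint64 {
	if headCountKV == 0 {
		headCountKV = 1
	}
	return headCount / headCountKV
}

func main() {
	fmt.Println(gqa(32, 32)) // MHA-style model: 1
	fmt.Println(gqa(64, 8))  // grouped-query model: 8
	fmt.Println(gqa(32, 0))  // key absent: treated as one KV head -> 32
}
```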
@@ -167,6 +157,36 @@ func (kv KV) ContextLength() uint64 {
    return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
}

type Tensors []*Tensor

func (ts Tensors) Layers() map[string]Layer {
    layers := make(map[string]Layer)
    for _, t := range ts {
        parts := strings.Split(t.Name, ".")
        if parts[0] == "blk" {
            parts = parts[1:]
        }

        if _, ok := layers[parts[0]]; !ok {
            layers[parts[0]] = make(Layer)
        }

        layers[parts[0]][strings.Join(parts[1:], ".")] = t
    }

    return layers
}

type Layer map[string]*Tensor

func (l Layer) size() (size uint64) {
    for _, t := range l {
        size += t.size()
    }

    return size
}

type Tensor struct {
    Name string `json:"name"`
    Kind uint32 `json:"kind"`
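Layers groups the flat tensor list by dotted name prefix: a leading `blk.` is stripped so repeating blocks key on their index, and everything after the first remaining dot becomes the per-layer tensor name. A standalone sketch of the same grouping rule (the tensor names are hypothetical):

```go
package main

import (
	"fmt"
	"strings"
)

// layersOf reproduces the grouping rule from Tensors.Layers on plain strings:
// "blk.0.attn_q.weight" lands in layers["0"]["attn_q.weight"], while
// "output.weight" lands in layers["output"]["weight"].
func layersOf(names []string) map[string]map[string]bool {
	layers := make(map[string]map[string]bool)
	for _, name := range names {
		parts := strings.Split(name, ".")
		if parts[0] == "blk" {
			parts = parts[1:]
		}
		if layers[parts[0]] == nil {
			layers[parts[0]] = make(map[string]bool)
		}
		layers[parts[0]][strings.Join(parts[1:], ".")] = true
	}
	return layers
}

func main() {
	fmt.Println(layersOf([]string{"blk.0.attn_q.weight", "blk.0.attn_k.weight", "output.weight"}))
}
```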
@@ -303,3 +323,53 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
        model: model,
    }, offset, nil
}

func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
    embedding := llm.KV().EmbeddingLength()
    heads := llm.KV().HeadCount()
    headsKV := llm.KV().HeadCountKV()
    vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))

    switch llm.KV().Architecture() {
    case "llama":
        fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))

        partialOffload = 4 * batch * embedding
        partialOffload += max(
            4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
            4*batch*(embedding+vocab)+embedding*vocab*105/128,
        )
    case "gemma":
        fullOffload = 4 * batch * (embedding + vocab)
        partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
    case "command-r":
        fullOffload = max(
            4*batch*(embedding+vocab),
            4*batch*(2+4*embedding+context*(1+heads)),
        )

        partialOffload = max(
            4*batch*(embedding+vocab)+embedding*vocab*105/128,
            4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
        )
    case "qwen2":
        fullOffload = max(
            4*batch*(embedding+vocab),
            4*batch*(1+2*embedding+context+context*heads),
        )

        partialOffload = max(
            4*batch*(embedding+vocab)+embedding*vocab*105/128,
            4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
        )
    case "phi2":
        fullOffload = max(
            4*batch*(embedding+vocab),
            4*batch*(1+4*embedding+context+context*heads),
        )

        partialOffload = 4*batch*(2*embedding+vocab) + embedding*vocab*105/128
    }

    return
}
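GraphSize estimates the compute-graph scratch memory for partial and full offload from the model's metadata; for an architecture that is not listed, both results stay zero and the caller falls back to the older GQA-based heuristic. A worked sketch of the `llama` fullOffload term, using hypothetical 7B-class numbers:

```go
package main

import "fmt"

func main() {
	// Hypothetical metadata for illustration: 4096-dim embeddings,
	// 32 heads, a 2048-token context, and a 512-token batch.
	var batch, embedding, context, heads uint64 = 512, 4096, 2048, 32

	// Same expression as the "llama" case above.
	fullOffload := 4 * batch * (1 + 4*embedding + context*(1+heads))
	fmt.Printf("full offload graph: %d MiB\n", fullOffload/(1<<20)) // ~164 MiB
}
```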
@@ -109,7 +109,7 @@ func (llm *gguf) KV() KV {
    return llm.kv
}

func (llm *gguf) Tensors() []*Tensor {
func (llm *gguf) Tensors() Tensors {
    return llm.tensors
}
Submodule llm/llama.cpp updated: 37e7854c10...1b67731e18
llm/llm.go
@@ -6,10 +6,81 @@ package llm
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
// #include <stdlib.h>
// #include "llama.h"
import "C"
import (
    "fmt"
    "unsafe"
)

// SystemInfo is an unused example of calling llama.cpp functions using CGo
func SystemInfo() string {
    return C.GoString(C.llama_print_system_info())
}

func Quantize(infile, outfile, filetype string) error {
    cinfile := C.CString(infile)
    defer C.free(unsafe.Pointer(cinfile))

    coutfile := C.CString(outfile)
    defer C.free(unsafe.Pointer(coutfile))

    params := C.llama_model_quantize_default_params()
    params.nthread = -1

    switch filetype {
    case "F32":
        params.ftype = fileTypeF32
    case "F16":
        params.ftype = fileTypeF16
    case "Q4_0":
        params.ftype = fileTypeQ4_0
    case "Q4_1":
        params.ftype = fileTypeQ4_1
    case "Q4_1_F16":
        params.ftype = fileTypeQ4_1_F16
    case "Q8_0":
        params.ftype = fileTypeQ8_0
    case "Q5_0":
        params.ftype = fileTypeQ5_0
    case "Q5_1":
        params.ftype = fileTypeQ5_1
    case "Q2_K":
        params.ftype = fileTypeQ2_K
    case "Q3_K_S":
        params.ftype = fileTypeQ3_K_S
    case "Q3_K_M":
        params.ftype = fileTypeQ3_K_M
    case "Q3_K_L":
        params.ftype = fileTypeQ3_K_L
    case "Q4_K_S":
        params.ftype = fileTypeQ4_K_S
    case "Q4_K_M":
        params.ftype = fileTypeQ4_K_M
    case "Q5_K_S":
        params.ftype = fileTypeQ5_K_S
    case "Q5_K_M":
        params.ftype = fileTypeQ5_K_M
    case "Q6_K":
        params.ftype = fileTypeQ6_K
    case "IQ2_XXS":
        params.ftype = fileTypeIQ2_XXS
    case "IQ2_XS":
        params.ftype = fileTypeIQ2_XS
    case "Q2_K_S":
        params.ftype = fileTypeQ2_K_S
    case "Q3_K_XS":
        params.ftype = fileTypeQ3_K_XS
    case "IQ3_XXS":
        params.ftype = fileTypeIQ3_XXS
    default:
        return fmt.Errorf("unknown filetype: %s", filetype)
    }

    if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
        return fmt.Errorf("llama_model_quantize: %d", retval)
    }

    return nil
}
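Quantize resolves the human-readable file type string to llama.cpp's ftype enum and runs the quantizer in-process over cgo. A minimal usage sketch (the paths are placeholders, and the input is assumed to be an already-converted F16 GGUF):

```go
package main

import (
	"log"

	"github.com/ollama/ollama/llm"
)

func main() {
	// "Q4_K_M" must match one of the cases in the switch above;
	// any other string returns an "unknown filetype" error.
	if err := llm.Quantize("model-f16.gguf", "model-q4_k_m.gguf", "Q4_K_M"); err != nil {
		log.Fatal(err)
	}
}
```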
llm/server.go
@@ -41,10 +41,6 @@ var cpuOnlyFamilies = []string{
}

func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
    if _, err := os.Stat(model); err != nil {
        return nil, err
    }

    f, err := os.Open(model)
    if err != nil {
        return nil, err
@@ -65,66 +61,79 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
        opts.NumCtx = 4
    }

    availableMemory, _ := gpu.CheckVRAM()
    memoryAvailable, _ := gpu.CheckVRAM()
    info := gpu.GetGPUInfo()

    usedMemory := info.MinimumMemory
    memoryMinimum := info.MinimumMemory
    for _, projector := range projectors {
        usedMemory += projectorMemoryRequirements(projector)
        memoryMinimum += projectorMemoryRequirements(projector)

        // multimodal models require at least 2048 context
        opts.NumCtx = max(opts.NumCtx, 2048)
    }

    // fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
    kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.KV().BlockCount()) * int64(ggml.KV().EmbeddingLength()) / int64(ggml.KV().HeadCount()) * int64(ggml.KV().HeadCountKV())
    var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()

    // this amount is the overhead + tensors in memory
    // TODO: get this from the llama.cpp's graph calculations instead of
    // estimating it's 1/6 * kv_cache_size * num_gqa
    graph := int64(ggml.KV().GQA()) * kv / 6
    usedMemory += graph

    if (usedMemory > availableMemory || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture())) && info.Library != "metal" {
        info.Library = "cpu"
    graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
    if graphPartialOffload == 0 {
        graphPartialOffload = ggml.KV().GQA() * kv / 6
    }

    requiredMemory := usedMemory
    if graphFullOffload == 0 {
        graphFullOffload = graphPartialOffload
    }

    var layers int
    for i := 0; i < int(ggml.KV().BlockCount()); i++ {
        layerMemory := ggml.LayerSize(fmt.Sprintf("blk.%d.", i)) + kv/int64(ggml.KV().BlockCount())
        requiredMemory += layerMemory
    // memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
    memoryRequiredTotal := memoryMinimum + graphFullOffload

        if availableMemory > usedMemory+layerMemory && (opts.NumGPU < 0 || layers < opts.NumGPU) {
            usedMemory += layerMemory
            layers++
    // memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
    memoryRequiredPartial := memoryMinimum + graphPartialOffload

    if info.Library != "metal" {
        if memoryRequiredPartial > memoryAvailable || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture()) {
            info.Library = "cpu"
        }
    }

    memOutputLayer := ggml.LayerSize("output.")
    requiredMemory += memOutputLayer
    var layerCount int
    layers := ggml.Tensors().Layers()
    for i := 0; i < int(ggml.KV().BlockCount()); i++ {
        memoryLayer := layers[fmt.Sprintf("%d", i)].size()

        // only offload output layer if all repeating layers are offloaded
        if layers >= int(ggml.KV().BlockCount()) && availableMemory > usedMemory+memOutputLayer {
            usedMemory += memOutputLayer
            layers++
        // KV is proportional to the number of layers
        memoryLayer += kv / ggml.KV().BlockCount()

        memoryRequiredTotal += memoryLayer
        if memoryAvailable > memoryRequiredPartial+memoryLayer {
            memoryRequiredPartial += memoryLayer
            layerCount++
        }
    }

    memoryLayerOutput := layers["output"].size()
    memoryRequiredTotal += memoryLayerOutput
    if memoryAvailable > memoryRequiredTotal {
        layerCount = int(ggml.KV().BlockCount()) + 1
        memoryRequiredPartial = memoryRequiredTotal
    }

    if opts.NumGPU < 0 {
        opts.NumGPU = layerCount
    }

    slog.Info(
        "offload to gpu",
        "layers", layers,
        "required", format.HumanBytes2(requiredMemory),
        "used", format.HumanBytes2(usedMemory),
        "available", format.HumanBytes2(availableMemory),
        "reallayers", opts.NumGPU,
        "layers", layerCount,
        "required", format.HumanBytes2(memoryRequiredTotal),
        "used", format.HumanBytes2(memoryRequiredPartial),
        "available", format.HumanBytes2(memoryAvailable),
        "kv", format.HumanBytes2(kv),
        "graph", format.HumanBytes2(graph),
        "fulloffload", format.HumanBytes2(graphFullOffload),
        "partialoffload", format.HumanBytes2(graphPartialOffload),
    )

    if opts.NumGPU < 0 && info.Library != "cpu" {
        opts.NumGPU = layers
    }

    if len(adapters) > 1 {
        return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
    }
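The KV-cache comment in the hunk above encodes: 2 tensors (K and V) × 2 bytes (fp16) × n_ctx × n_layer × (n_embd / n_head) × n_head_kv. A worked sketch with hypothetical 7B-class metadata:

```go
package main

import "fmt"

func main() {
	// Hypothetical metadata for illustration only.
	var (
		numCtx     uint64 = 2048 // n_ctx
		blockCount uint64 = 32   // n_layer
		embedding  uint64 = 4096 // n_embd
		heads      uint64 = 32   // n_head
		headsKV    uint64 = 32   // n_head_kv (no grouped-query attention)
	)

	// Same expression as the patch: 2 (k and v) * 2 (fp16 bytes) * ...
	kv := 2 * 2 * numCtx * blockCount * embedding / heads * headsKV
	fmt.Printf("fp16 kv cache: %d MiB\n", kv/(1<<20)) // 1024 MiB
}
```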
@@ -171,14 +180,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
        params = append(params, "--main-gpu", fmt.Sprintf("%d", opts.MainGPU))
    }

    if opts.RopeFrequencyBase > 0 {
        params = append(params, "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase))
    }

    if opts.RopeFrequencyScale > 0 {
        params = append(params, "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale))
    }

    if len(adapters) > 0 {
        // TODO: applying multiple adapters is not supported by the llama.cpp server yet
        params = append(params, "--lora", adapters[0])
@@ -289,7 +290,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
    return nil, finalErr
}

func projectorMemoryRequirements(filename string) int64 {
func projectorMemoryRequirements(filename string) uint64 {
    file, err := os.Open(filename)
    if err != nil {
        return 0
@@ -301,18 +302,12 @@ func projectorMemoryRequirements(filename string) int64 {
        return 0
    }

    prefixes := make(map[string]struct{})
    for _, layer := range ggml.Tensors() {
        parts := strings.Split(layer.Name, ".")
        prefixes[strings.Join(parts[:2], ".")] = struct{}{}
    var mem uint64
    for _, layer := range ggml.Tensors().Layers() {
        mem += layer.size()
    }

    var ask int64
    for prefix := range prefixes {
        ask += ggml.LayerSize(prefix)
    }

    return ask
    return mem
}

type ServerStatus int

@@ -390,7 +385,8 @@ func (s *LlamaServer) Ping(ctx context.Context) error {

func (s *LlamaServer) waitUntilRunning() error {
    start := time.Now()
    expiresAt := time.Now().Add(3 * time.Minute) // be generous with timeout, large models can take a while to load
    // TODO we need to wire up a better way to detect hangs during model load and startup of the server
    expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
    ticker := time.NewTicker(50 * time.Millisecond)
    defer ticker.Stop()
@@ -14,7 +14,7 @@ go build .

Then run the desktop app with `npm start`:

```
cd app
cd macapp
npm install
npm start
```
@@ -247,7 +248,8 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
}

if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
    slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
    const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
    slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
    // reset last updated
    part.lastUpdated = time.Time{}
    return errPartStalled
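The stall check acts as a watchdog: a zero lastUpdated means no progress has been recorded yet, so only a stale nonzero timestamp aborts the chunk with errPartStalled, letting the retry loop reopen the connection. The same pattern in isolation:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

var errPartStalled = errors.New("part stalled")

// checkStalled mirrors the download watchdog above: a zero time means
// "no progress recorded yet", so only a stale nonzero timestamp counts
// as a stall.
func checkStalled(lastUpdated time.Time, limit time.Duration) error {
	if !lastUpdated.IsZero() && time.Since(lastUpdated) > limit {
		return errPartStalled
	}
	return nil
}

func main() {
	fmt.Println(checkStalled(time.Time{}, 5*time.Second))                      // <nil>
	fmt.Println(checkStalled(time.Now().Add(-10*time.Second), 5*time.Second)) // part stalled
}
```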
@@ -284,7 +284,7 @@ func realpath(mfDir, from string) string {
    return abspath
}

func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
    deleteMap := make(map[string]struct{})
    if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
        for _, layer := range append(manifest.Layers, manifest.Config) {
@@ -337,8 +337,27 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars

    if ggufName != "" {
        pathName = ggufName
        slog.Debug(fmt.Sprintf("new image layer path: %s", pathName))
        defer os.RemoveAll(ggufName)

        if quantization != "" {
            quantization = strings.ToUpper(quantization)
            fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
            tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
            if err != nil {
                return err
            }
            defer os.RemoveAll(tempfile.Name())

            if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
                return err
            }

            if err := tempfile.Close(); err != nil {
                return err
            }

            pathName = tempfile.Name()
        }
    }

    bin, err := os.Open(pathName)
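With the new parameter, a create request can ask the server to quantize an F16 GGUF to a smaller file type before the layer is written, as the handler change shown next wires through. A hedged client-side sketch (the JSON field names are inferred from the handler; a real client would go through the api package):

```go
package main

import (
	"bytes"
	"encoding/json"
	"log"
	"net/http"
)

func main() {
	// Field names here are an assumption based on the handler's use of
	// req.Path and req.Quantization; they are not confirmed by this diff.
	body, err := json.Marshal(map[string]string{
		"name":         "mymodel",
		"path":         "/path/to/Modelfile",
		"quantization": "Q4_K_M",
	})
	if err != nil {
		log.Fatal(err)
	}
	resp, err := http.Post("http://localhost:11434/api/create", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
}
```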
@@ -647,7 +647,7 @@ func CreateModelHandler(c *gin.Context) {
        ctx, cancel := context.WithCancel(c.Request.Context())
        defer cancel()

        if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
        if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil {
            ch <- gin.H{"error": err.Error()}
        }
    }()
@@ -913,6 +913,24 @@ func HeadBlobHandler(c *gin.Context) {
}

func CreateBlobHandler(c *gin.Context) {
    path, err := GetBlobsPath(c.Param("digest"))
    if err != nil {
        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }

    _, err = os.Stat(path)
    switch {
    case errors.Is(err, os.ErrNotExist):
        // noop
    case err != nil:
        c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    default:
        c.Status(http.StatusOK)
        return
    }

    layer, err := NewLayer(c.Request.Body, "")
    if err != nil {
        c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -61,7 +61,7 @@ func Test_Routes(t *testing.T) {
    fn := func(resp api.ProgressResponse) {
        t.Logf("Status: %s", resp.Status)
    }
    err = CreateModel(context.TODO(), name, "", commands, fn)
    err = CreateModel(context.TODO(), name, "", "", commands, fn)
    assert.Nil(t, err)
}
@@ -38,52 +38,15 @@ func (d Digest) String() string { return d.s }
// ParseName(name).Digest().
func (d Digest) IsValid() bool { return d.s != "" }

// MarshalText implements encoding.TextMarshaler.
func (d Digest) MarshalText() ([]byte, error) {
    return []byte(d.String()), nil
}

// UnmarshalText implements encoding.TextUnmarshaler.
func (d *Digest) UnmarshalText(text []byte) error {
    if d.IsValid() {
        return errors.New("model.Digest: illegal UnmarshalText on valid Digest")
    }
    *d = ParseDigest(string(text))
    return nil
}

// LogValue implements slog.Value.
func (d Digest) LogValue() slog.Value {
    return slog.StringValue(d.String())
}

var (
    _ driver.Valuer  = Digest{}
    _ sql.Scanner    = (*Digest)(nil)
    _ slog.LogValuer = Digest{}
)

// Scan implements the sql.Scanner interface.
func (d *Digest) Scan(src any) error {
    if d.IsValid() {
        return errors.New("model.Digest: illegal Scan on valid Digest")
    }
    switch v := src.(type) {
    case string:
        *d = ParseDigest(v)
        return nil
    case []byte:
        *d = ParseDigest(string(v))
        return nil
    }
    return fmt.Errorf("model.Digest: invalid Scan source %T", src)
}

// Value implements the driver.Valuer interface.
func (d Digest) Value() (driver.Value, error) {
    return d.String(), nil
}

// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
// Digest.
func ParseDigest(s string) Digest {
@@ -94,20 +57,6 @@ func ParseDigest(s string) Digest {
    return Digest{}
}

// isValidDigest returns true if the given string is in the form of
// "<digest-type>-<digest>", and <digest-type> is in the form of [a-z0-9]+
// and <digest> is a valid hex string.
//
// It does not check if the digest is a valid hash for the given digest
// type, or restrict the digest type to a known set of types. This is left
// up to users of this package.
func isValidDigest(s string) bool {
    typ, digest, ok := strings.Cut(s, "-")
    res := ok && isValidDigestType(typ) && isValidHex(digest)
    fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res)
    return res
}

func isValidDigestType(s string) bool {
    if len(s) == 0 {
        return false
@@ -2,18 +2,6 @@ package model

import "testing"

// - test scan
// - test marshal text
// - test unmarshal text
// - test log value
// - test string
// - test type
// - test digest
// - test valid
// - test driver valuer
// - test sql scanner
// - test parse digest

var testDigests = map[string]Digest{
    "":            {},
    "sha256-1234": {s: "sha256-1234"},
@@ -56,28 +44,3 @@ func TestDigestString(t *testing.T) {
        }
    }
}

func TestDigestUnmarshalText(t *testing.T) {
    const testDigest = "sha256-1234"
    t.Run("UnmarshalText (into Valid)", func(t *testing.T) {
        d := ParseDigest(testDigest)
        if !d.IsValid() {
            panic("invalid test")
        }
        if err := d.UnmarshalText(nil); err == nil {
            t.Errorf("UnmarshalText on valid Digest did not return error")
        }
        if d.String() != testDigest {
            t.Errorf("UnmarshalText on valid Digest changed Digest: %q", d.String())
        }
    })
    t.Run("UnmarshalText make safe copy", func(t *testing.T) {
        data := []byte(testDigest)
        var d Digest
        d.UnmarshalText(data)
        data[0] = 'x'
        if d.String() != testDigest {
            t.Errorf("UnmarshalText did not make a safe copy")
        }
    })
}
@@ -1,31 +1,37 @@
package model

import (
    "bytes"
    "cmp"
    "database/sql"
    "database/sql/driver"
    "errors"
    "hash/maphash"
    "io"
    "iter"
    "log/slog"
    "slices"
    "strings"
    "sync"

    "github.com/ollama/ollama/x/types/structs"
    "github.com/ollama/ollama/types/structs"
)

// Errors
var (
    // ErrInvalidName is not used by this package, but is exported so that
    // other packages do not need to invent their own error type when they
    // need to return an error for an invalid name.
    // ErrInvalidName, ErrIncompleteName, and ErrInvalidDigest are not
    // used by this package, but are exported so that other packages can
    // use them, instead of defining their own errors for them.
    ErrInvalidName    = errors.New("invalid model name")
    ErrIncompleteName = errors.New("incomplete model name")
    ErrInvalidDigest  = errors.New("invalid digest")
)

// Defaults
const (
    // DefaultMask is the default mask used by [Name.DisplayShortest].
    DefaultMask = "registry.ollama.ai/library/_:latest"

    // DefaultFill is the default fill used by [ParseName].
    DefaultFill = "registry.ollama.ai/library/_:latest"
)

const MaxNamePartLen = 128

type PartKind int

@@ -64,7 +70,7 @@ func (k PartKind) String() string {

// Name is an opaque reference to a model. It holds the parts of a model
// with the case preserved, but is not directly comparable with other Names
// since model names can be represented with different caseing depending on
// since model names can be represented with different casing depending on
// the use case. For instance, "Mistral" and "mistral" are the same model
// but each version may have come from different sources (e.g. copied from a
// Web page, or from a file path).
@@ -94,16 +100,17 @@ func (k PartKind) String() string {
// To make a Name by filling in missing parts from another Name, use [Fill].
type Name struct {
    _     structs.Incomparable
    parts [6]string // host, namespace, model, tag, build
    parts [6]string // host, namespace, model, tag, build, digest

    // TODO(bmizerany): track offsets and hold s (raw string) here? We
    // could pack the offests all into a single uint64 since the first
    // could pack the offsets all into a single uint64 since the first
    // parts take less bits since their max offset is less than the max
    // offset of the next part. This would save a ton of bytes per Name
    // and mean zero allocations for String.
}

// ParseName parses s into a Name. The input string must be a valid string
// ParseNameFill parses s into a Name, and returns the result of filling it with
// defaults. The input string must be a valid string
// representation of a model name in the form:
//
//	[host/][namespace/]<model>[:tag][+build][@<digest-type>-<digest>]
@@ -121,7 +128,7 @@ type Name struct {
//	"mistral:7b+x"
//	"example.com/mike/mistral:latest+Q4_0"
//	"example.com/bruce/mistral:latest"
//	"example.com/mistral:7b+Q4_0@sha256-1234567890abcdef"
//	"example.com/pdevine/thisisfine:7b+Q4_0@sha256-1234567890abcdef"
//
// Examples of invalid paths:
//
@@ -135,25 +142,38 @@ type Name struct {
// As a rule of thumb, a valid name is one that can be round-tripped with
// the [Name.String] method. That means ("x+") is invalid because
// [Name.String] will not print a "+" if the build is empty.
func ParseName(s string) Name {
//
// For more about filling in missing parts, see [Fill].
func ParseNameFill(s, defaults string) Name {
    var r Name
    for kind, part := range Parts(s) {
    parts(s)(func(kind PartKind, part string) bool {
        if kind == PartInvalid {
            return Name{}
            r = Name{}
            return false
        }
        if kind == PartDigest && !ParseDigest(part).IsValid() {
            return Name{}
            r = Name{}
            return false
        }
        r.parts[kind] = part
    }
        return true
    })
    if r.IsValid() || r.IsResolved() {
        return r
        if defaults == "" {
            return r
        }
        return Fill(r, ParseNameFill(defaults, ""))
    }
    return Name{}
}

func MustParseName(s string) Name {
    r := ParseName(s)
// ParseName is equal to ParseNameFill(s, DefaultFill).
func ParseName(s string) Name {
    return ParseNameFill(s, DefaultFill)
}

func MustParseNameFill(s, defaults string) Name {
    r := ParseNameFill(s, "")
    if !r.IsValid() {
        panic("model.MustParseName: invalid name: " + s)
    }
@@ -185,7 +205,9 @@ func (r Name) WithDigest(digest Digest) Name {
var mapHashSeed = maphash.MakeSeed()

// MapHash returns a case insensitive hash for use in maps and equality
// checks. For a convienent way to compare names, use [Name.EqualFold].
// checks. For a convenient way to compare names, use [Name.EqualFold].
//
//nolint:errcheck
func (r Name) MapHash() uint64 {
    // correctly hash the parts with case insensitive comparison
    var h maphash.Hash
@@ -209,41 +231,34 @@ func (r Name) slice(from, to PartKind) Name {
    return v
}

// DisplayModel returns a display string composed of the model only.
func (r Name) DisplayModel() string {
    return r.parts[PartModel]
}

// DisplayFullest returns the fullest possible display string in form:
// DisplayShortest returns the shortest possible display string in form:
//
//	<host>/<namespace>/<model>:<tag>
//	[host/][<namespace>/]<model>[:<tag>]
//
// If any part is missing, it is omitted from the display string.
//
// It does not include the build part. For the fullest possible display
// string with the build, use [Name.String].
func (r Name) DisplayFullest() string {
// The host is omitted if it is the same as the mask's host.
// The namespace is omitted if the host and namespace are the same as the mask's.
// The tag is omitted if it is the same as the mask's tag.
func (r Name) DisplayShortest(mask string) string {
    mask = cmp.Or(mask, DefaultMask)
    d := ParseName(mask)
    if !d.IsValid() {
        panic("mask is an invalid Name")
    }
    equalSlice := func(form, to PartKind) bool {
        return r.slice(form, to).EqualFold(d.slice(form, to))
    }
    if equalSlice(PartHost, PartNamespace) {
        r.parts[PartNamespace] = ""
    }
    if equalSlice(PartHost, PartHost) {
        r.parts[PartHost] = ""
    }
    if equalSlice(PartTag, PartTag) {
        r.parts[PartTag] = ""
    }
    return r.slice(PartHost, PartTag).String()
}
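In short, any part of r that case-insensitively matches the corresponding part of the mask is dropped from the output. A small sketch of the intended behavior (the import path is assumed for illustration; see ExampleName_DisplayShortest near the end of this diff for the authoritative version):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/model" // import path assumed from this branch
)

func main() {
	name := model.ParseNameFill("example.com/library/mistral:latest+Q4_0", "")
	fmt.Println(name.DisplayShortest("example.com/library/_:latest")) // mistral
	fmt.Println(name.DisplayShortest("example.com/_/_:latest"))       // library/mistral
	fmt.Println(name.DisplayShortest(""))                             // DefaultMask applies: example.com/library/mistral
}
```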
// DisplayShort returns the fullest possible display string in form:
//
//	<model>:<tag>
//
// If any part is missing, it is omitted from the display string.
func (r Name) DisplayShort() string {
    return r.slice(PartModel, PartTag).String()
}

// DisplayLong returns the fullest possible display string in form:
//
//	<namespace>/<model>:<tag>
//
// If any part is missing, it is omitted from the display string.
func (r Name) DisplayLong() string {
    return r.slice(PartNamespace, PartTag).String()
}

var seps = [...]string{
    PartHost:      "/",
    PartNamespace: "/",
@@ -258,23 +273,28 @@ var seps = [...]string{
//
//	<host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
//
// Missing parts and their seperators are not written.
// Missing parts and their separators are not written.
//
// The full digest is always prefixed with "@". That is if [Name.IsValid]
// reports false and [Name.IsResolved] reports true, then the string is
// returned as "@<digest-type>-<digest>".
func (r Name) writeTo(w io.StringWriter) {
func (r Name) writeTo(w io.StringWriter) error {
    var partsWritten int
    for i := range r.parts {
        if r.parts[i] == "" {
            continue
        }
        if partsWritten > 0 || i == int(PartDigest) {
            w.WriteString(seps[i-1])
            if _, err := w.WriteString(seps[i-1]); err != nil {
                return err
            }
        }
        if _, err := w.WriteString(r.parts[i]); err != nil {
            return err
        }
        w.WriteString(r.parts[i])
        partsWritten++
    }
    return nil
}

var builderPool = sync.Pool{
@@ -296,7 +316,7 @@ func (r Name) String() string {
    defer builderPool.Put(b)
    b.Reset()
    b.Grow(50) // arbitrarily long enough for most names
    r.writeTo(b)
    _ = r.writeTo(b)
    return b.String()
}

@@ -316,71 +336,6 @@ func (r Name) LogValue() slog.Value {
    return slog.StringValue(r.GoString())
}

var bufPool = sync.Pool{
    New: func() interface{} {
        return new(bytes.Buffer)
    },
}

// MarshalText implements [encoding.TextMarshaler].
func (r Name) MarshalText() ([]byte, error) {
    b := bufPool.Get().(*bytes.Buffer)
    b.Reset()
    b.Grow(50) // arbitrarily long enough for most names
    defer bufPool.Put(b)
    r.writeTo(b)
    // TODO: We can remove this alloc if/when
    // https://github.com/golang/go/issues/62384 lands.
    return b.Bytes(), nil
}

// UnmarshalText implements [encoding.TextUnmarshaler].
//
// It is an error to call UnmarshalText on a valid Name.
func (r *Name) UnmarshalText(text []byte) error {
    if r.IsValid() {
        // The invariant of UnmarshalText is that it should only be
        // called on an invalid/zero Name. If we allow UnmarshalText
        // on a valid Name, then the Name will be mutated, breaking
        // the immutability of the Name.
        return errors.New("model.Name: illegal UnmarshalText on valid Name")
    }

    // The contract of UnmarshalText is that we copy to keep the text.
    *r = ParseName(string(text))
    return nil
}

var (
    _ driver.Valuer = Name{}
    _ sql.Scanner   = (*Name)(nil)
)

// Scan implements [database/sql.Scanner].
func (r *Name) Scan(src any) error {
    if r.IsValid() {
        // The invariant of Scan is that it should only be called on an
        // invalid/zero Name. If we allow Scan on a valid Name, then the
        // Name will be mutated, breaking the immutability of the Name.
        return errors.New("model.Name: illegal Scan on valid Name")
    }
    switch v := src.(type) {
    case string:
        *r = ParseName(v)
        return nil
    case []byte:
        *r = ParseName(string(v))
        return nil
    }
    return errors.New("model.Name: invalid Scan source")
}

// Value implements [database/sql/driver.Valuer].
func (r Name) Value() (driver.Value, error) {
    return r.String(), nil
}

// IsComplete reports whether the Name is fully qualified. That is it has a
// domain, namespace, name, tag, and build.
func (r Name) IsComplete() bool {
@@ -447,16 +402,27 @@ func (r Name) Parts() []string {
    return slices.Clone(r.parts[:])
}

// iter_Seq2 is an iter.Seq2 defined here to avoid the current build
// restrictions in the go1.22 iter package requiring the
// goexperiment.rangefunc tag to be set via the GOEXPERIMENT=rangefunc flag,
// which we are not yet ready to support.
//
// Once we are ready to support rangefunc, this can be removed and replaced
// with the iter.Seq2 type.
type iter_Seq2[A, B any] func(func(A, B) bool)

// Parts returns a sequence of the parts of a Name string from most specific
// to least specific.
//
// It normalizes the input string by removing "http://" and "https://" only.
// No other normalization is done.
func Parts(s string) iter.Seq2[PartKind, string] {
// No other normalizations are performed.
func parts(s string) iter_Seq2[PartKind, string] {
    return func(yield func(PartKind, string) bool) {
        //nolint:gosimple
        if strings.HasPrefix(s, "http://") {
            s = s[len("http://"):]
        }
        //nolint:gosimple
        if strings.HasPrefix(s, "https://") {
            s = s[len("https://"):]
        }
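Without range-over-func, an iter_Seq2 is consumed by invoking it with an explicit yield callback; returning false from the callback stops iteration early, which is exactly how ParseNameFill bails out on an invalid part. A self-contained sketch of the calling convention:

```go
package main

import "fmt"

// iterSeq2 mirrors the iter_Seq2 shape from the diff.
type iterSeq2[A, B any] func(func(A, B) bool)

// pairs yields index/name pairs until the callback returns false.
func pairs() iterSeq2[int, string] {
	return func(yield func(int, string) bool) {
		for i, s := range []string{"host", "namespace", "model"} {
			if !yield(i, s) {
				return
			}
		}
	}
}

func main() {
	pairs()(func(i int, s string) bool {
		fmt.Println(i, s)
		return s != "namespace" // returning false stops the iteration early
	})
}
```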
@@ -473,6 +439,7 @@ func Parts(s string) iter.Seq2[PartKind, string] {
            return yield(kind, part)
        }

        numConsecutiveDots := 0
        partLen := 0
        state, j := PartDigest, len(s)
        for i := len(s) - 1; i >= 0; i-- {
@@ -543,7 +510,15 @@ func Parts(s string) iter.Seq2[PartKind, string] {
                    return
                }
            default:
                if !isValidByte(state, s[i]) {
                if s[i] == '.' {
                    if numConsecutiveDots++; numConsecutiveDots > 1 {
                        yield(PartInvalid, "")
                        return
                    }
                } else {
                    numConsecutiveDots = 0
                }
                if !isValidByteFor(state, s[i]) {
                    yield(PartInvalid, "")
                    return
                }
@@ -558,28 +533,32 @@
    }
}

// IsValid returns true if the Name hPartas a valid nick. To know if a Name is
// "complete", use [Name.IsComplete].
func (r Name) IsZero() bool {
    return r.parts == [6]string{}
}

// IsValid reports if a model has at minimum a valid model part.
func (r Name) IsValid() bool {
    // Parts ensures we only have valid parts, so no need to validate
    // them here, only check if we have a name or not.
    return r.parts[PartModel] != ""
}

// isValidPart returns Parttrue if given part is valid ascii [a-zA-Z0-9_\.-]
// isValidPart reports if s contains all valid characters for the given
// part kind.
func isValidPart(kind PartKind, s string) bool {
    if s == "" {
        return false
    }
    for _, c := range []byte(s) {
        if !isValidByte(kind, c) {
        if !isValidByteFor(kind, c) {
            return false
        }
    }
    return true
}

func isValidByte(kind PartKind, c byte) bool {
func isValidByteFor(kind PartKind, c byte) bool {
    if kind == PartNamespace && c == '.' {
        return false
    }
@@ -3,7 +3,6 @@ package model
import (
    "bytes"
    "cmp"
    "errors"
    "fmt"
    "log/slog"
    "slices"
@@ -89,10 +88,40 @@ var testNames = map[string]fields{
    "file:///etc/passwd:latest":   {},
    "file:///etc/passwd:latest+u": {},

    ":x": {},
    "+x": {},
    "x+": {},

    // Disallow ("\.+") in any part to prevent path traversal anywhere
    // we convert the name to a path.
    "../etc/passwd":  {},
    ".../etc/passwd": {},
    "./../passwd":    {},
    "./0+..":         {},

    strings.Repeat("a", MaxNamePartLen):   {model: strings.Repeat("a", MaxNamePartLen)},
    strings.Repeat("a", MaxNamePartLen+1): {},
}

// TestConsecutiveDots tests that consecutive dots are not allowed in any
// part, to avoid path traversal. There are also some tests in testNames, but
// this test is more exhaustive and exists to emphasize the importance of
// preventing path traversal.
func TestNameConsecutiveDots(t *testing.T) {
    for i := 1; i < 10; i++ {
        s := strings.Repeat(".", i)
        if i > 1 {
            if g := ParseNameFill(s, "").String(); g != "" {
                t.Errorf("ParseName(%q) = %q; want empty string", s, g)
            }
        } else {
            if g := ParseNameFill(s, "").String(); g != s {
                t.Errorf("ParseName(%q) = %q; want %q", s, g, s)
            }
        }
    }
}

func TestNameParts(t *testing.T) {
    var p Name
    if w, g := int(PartDigest+1), len(p.Parts()); w != g {
@@ -119,32 +148,16 @@ func TestParseName(t *testing.T) {
            s := prefix + baseName

            t.Run(s, func(t *testing.T) {
                for kind, part := range Parts(s) {
                    t.Logf("Part: %s: %q", kind, part)
                }

                name := ParseName(s)
                name := ParseNameFill(s, "")
                got := fieldsFromName(name)
                if got != want {
                    t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
                }

                // test round-trip
                if !ParseName(name.String()).EqualFold(name) {
                if !ParseNameFill(name.String(), "").EqualFold(name) {
                    t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName)
                }

                if name.IsValid() && name.DisplayModel() == "" {
                    t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model)
                } else if !name.IsValid() && name.DisplayModel() != "" {
                    t.Errorf("Valid() = false; Model() = %q; want empty name", got.model)
                }

                if name.IsResolved() && !name.Digest().IsValid() {
                    t.Errorf("Resolved() = true; Digest() = %q; want non-empty digest", got.digest)
                } else if !name.IsResolved() && name.Digest().IsValid() {
                    t.Errorf("Resolved() = false; Digest() = %q; want empty digest", got.digest)
                }
            })
        }
}
@@ -166,7 +179,7 @@ func TestCompleteWithAndWithoutBuild(t *testing.T) {

    for _, tt := range cases {
        t.Run(tt.in, func(t *testing.T) {
            p := ParseName(tt.in)
            p := ParseNameFill(tt.in, "")
            t.Logf("ParseName(%q) = %#v", tt.in, p)
            if g := p.IsComplete(); g != tt.complete {
                t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
@@ -181,7 +194,7 @@ func TestCompleteWithAndWithoutBuild(t *testing.T) {
    // inlined when used in Complete, preventing any allocations or
    // escaping to the heap.
    allocs := testing.AllocsPerRun(1000, func() {
        keep(ParseName("complete.com/x/mistral:latest+Q4_0").IsComplete())
        keep(ParseNameFill("complete.com/x/mistral:latest+Q4_0", "").IsComplete())
    })
    if allocs > 0 {
        t.Errorf("Complete allocs = %v; want 0", allocs)
@@ -198,7 +211,7 @@ func TestNameLogValue(t *testing.T) {
        t.Run(s, func(t *testing.T) {
            var b bytes.Buffer
            log := slog.New(slog.NewTextHandler(&b, nil))
            name := ParseName(s)
            name := ParseNameFill(s, "")
            log.Info("", "name", name)
            want := fmt.Sprintf("name=%s", name.GoString())
            got := b.String()
@@ -209,83 +222,43 @@
    }
}

func TestNameDisplay(t *testing.T) {
func TestNameGoString(t *testing.T) {
    cases := []struct {
        name         string
        in           string
        wantShort    string
        wantLong     string
        wantComplete string
        wantString   string
        wantModel    string
        wantGoString string // default is tt.in
    }{
        {
            name:         "Complete Name",
            in:           "example.com/library/mistral:latest+Q4_0",
            wantShort:    "mistral:latest",
            wantLong:     "library/mistral:latest",
            wantComplete: "example.com/library/mistral:latest",
            wantModel:    "mistral",
            wantGoString: "example.com/library/mistral:latest+Q4_0@?",
        },
        {
            name:         "Short Name",
            in:           "mistral:latest",
            wantShort:    "mistral:latest",
            wantLong:     "mistral:latest",
            wantComplete: "mistral:latest",
            wantModel:    "mistral",
            wantGoString: "?/?/mistral:latest+?@?",
        },
        {
            name:         "Long Name",
            in:           "library/mistral:latest",
            wantShort:    "mistral:latest",
            wantLong:     "library/mistral:latest",
            wantComplete: "library/mistral:latest",
            wantModel:    "mistral",
            wantGoString: "?/library/mistral:latest+?@?",
        },
        {
            name:         "Case Preserved",
            in:           "Library/Mistral:Latest",
            wantShort:    "Mistral:Latest",
            wantLong:     "Library/Mistral:Latest",
            wantComplete: "Library/Mistral:Latest",
            wantModel:    "Mistral",
            wantGoString: "?/Library/Mistral:Latest+?@?",
        },
        {
            name:         "With digest",
            in:           "Library/Mistral:Latest@sha256-123456",
            wantShort:    "Mistral:Latest",
            wantLong:     "Library/Mistral:Latest",
            wantComplete: "Library/Mistral:Latest",
            wantModel:    "Mistral",
            wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
        },
    }

    for _, tt := range cases {
        t.Run(tt.name, func(t *testing.T) {
            p := ParseName(tt.in)
            if g := p.DisplayShort(); g != tt.wantShort {
                t.Errorf("DisplayShort = %q; want %q", g, tt.wantShort)
            }
            if g := p.DisplayLong(); g != tt.wantLong {
                t.Errorf("DisplayLong = %q; want %q", g, tt.wantLong)
            }
            if g := p.DisplayFullest(); g != tt.wantComplete {
                t.Errorf("DisplayFullest = %q; want %q", g, tt.wantComplete)
            }
            if g := p.String(); g != tt.in {
                t.Errorf("String(%q) = %q; want %q", tt.in, g, tt.in)
            }
            if g := p.DisplayModel(); g != tt.wantModel {
                t.Errorf("Model = %q; want %q", g, tt.wantModel)
            }

            p := ParseNameFill(tt.in, "")
            tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
            if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
                t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
@@ -294,9 +267,60 @@ func TestNameDisplay(t *testing.T) {
        }
    }
func TestDisplayShortest(t *testing.T) {
    cases := []struct {
        in        string
        mask      string
        want      string
        wantPanic bool
    }{
        {"example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
        {"example.com/library/mistral:latest+Q4_0", "example.com/_/_:latest", "library/mistral", false},
        {"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
        {"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},

        // case-insensitive
        {"Example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
        {"example.com/Library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
        {"example.com/library/Mistral:latest+Q4_0", "example.com/library/_:latest", "Mistral", false},
        {"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false},
        {"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false},

        // invalid mask
        {"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true},

        // DefaultMask
        {"registry.ollama.ai/library/mistral:latest+Q4_0", DefaultMask, "mistral", false},

        // Auto-Fill
        {"x", "example.com/library/_:latest", "x", false},
        {"x", "example.com/library/_:latest+Q4_0", "x", false},
        {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
        {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
    }

    for _, tt := range cases {
        t.Run("", func(t *testing.T) {
            defer func() {
                if tt.wantPanic {
                    if recover() == nil {
                        t.Errorf("expected panic")
                    }
                }
            }()

            p := ParseNameFill(tt.in, "")
            t.Logf("ParseName(%q) = %#v", tt.in, p)
            if g := p.DisplayShortest(tt.mask); g != tt.want {
                t.Errorf("got = %q; want %q", g, tt.want)
            }
        })
    }
}

func TestParseNameAllocs(t *testing.T) {
    allocs := testing.AllocsPerRun(1000, func() {
        keep(ParseName("example.com/mistral:7b+Q4_0"))
        keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
    })
    if allocs > 0 {
        t.Errorf("ParseName allocs = %v; want 0", allocs)
@@ -307,29 +331,28 @@ func BenchmarkParseName(b *testing.B) {
    b.ReportAllocs()

    for range b.N {
        keep(ParseName("example.com/mistral:7b+Q4_0"))
        keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
    }
}

func BenchmarkNameDisplay(b *testing.B) {
    b.ReportAllocs()

    r := ParseName("example.com/mistral:7b+Q4_0")
    b.Run("Short", func(b *testing.B) {
        for range b.N {
            keep(r.DisplayShort())
        }
    })
}

func FuzzParseName(f *testing.F) {
    f.Add("example.com/mistral:7b+Q4_0")
    f.Add("example.com/mistral:7b+q4_0")
    f.Add("example.com/mistral:7b+x")
    f.Add("x/y/z:8n+I")
    f.Add(":x")
    f.Add("@sha256-123456")
    f.Add("example.com/mistral:latest+Q4_0@sha256-123456")
    f.Add(":@!@")
    f.Add("...")
    f.Fuzz(func(t *testing.T, s string) {
        r0 := ParseName(s)
        if !r0.IsValid() {
        r0 := ParseNameFill(s, "")

        if strings.Contains(s, "..") && !r0.IsZero() {
            t.Fatalf("non-zero value for path with '..': %q", s)
        }

        if !r0.IsValid() && !r0.IsResolved() {
            if !r0.EqualFold(Name{}) {
                t.Errorf("expected invalid path to be zero value; got %#v", r0)
            }
@@ -346,7 +369,7 @@ func FuzzParseName(f *testing.F) {
            t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s)
        }

        r1 := ParseName(r0.String())
        r1 := ParseNameFill(r0.String(), "")
        if !r0.EqualFold(r1) {
            t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
        }
@@ -366,7 +389,7 @@ func TestFill(t *testing.T) {

    for _, tt := range cases {
        t.Run(tt.dst, func(t *testing.T) {
            r := Fill(ParseName(tt.dst), ParseName(tt.src))
            r := Fill(ParseNameFill(tt.dst, ""), ParseNameFill(tt.src, ""))
            if r.String() != tt.want {
                t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
            }
@@ -374,118 +397,8 @@
    }
}

func TestNameTextMarshal(t *testing.T) {
    cases := []struct {
        in      string
        want    string
        wantErr error
    }{
        {"example.com/mistral:latest+Q4_0", "", nil},
        {"mistral:latest+Q4_0", "mistral:latest+Q4_0", nil},
        {"mistral:latest", "mistral:latest", nil},
        {"mistral", "mistral", nil},
        {"mistral:7b", "mistral:7b", nil},
        {"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest+Q4_0", nil},
    }

    for _, tt := range cases {
        t.Run(tt.in, func(t *testing.T) {
            p := ParseName(tt.in)
            got, err := p.MarshalText()
            if !errors.Is(err, tt.wantErr) {
                t.Fatalf("MarshalText() error = %v; want %v", err, tt.wantErr)
            }
            if string(got) != tt.want {
                t.Errorf("MarshalText() = %q; want %q", got, tt.want)
            }

            var r Name
            if err := r.UnmarshalText(got); err != nil {
                t.Fatalf("UnmarshalText() error = %v; want nil", err)
            }
            if !r.EqualFold(p) {
                t.Errorf("UnmarshalText() = %q; want %q", r, p)
            }
        })
    }

    t.Run("UnmarshalText into valid Name", func(t *testing.T) {
        // UnmarshalText should not be called on a valid Name.
        p := MustParseName("x")
        if err := p.UnmarshalText([]byte("mistral:latest+Q4_0")); err == nil {
            t.Error("UnmarshalText() = nil; want error")
        }
    })

    t.Run("TextMarshal allocs", func(t *testing.T) {
        var data []byte
        name := ParseName("example.com/ns/mistral:latest+Q4_0")
        if !name.IsComplete() {
            // sanity check
            panic("sanity check failed")
        }

        allocs := testing.AllocsPerRun(1000, func() {
            var err error
            data, err = name.MarshalText()
            if err != nil {
                t.Fatal(err)
            }
            if len(data) == 0 {
                t.Fatal("MarshalText() = 0; want non-zero")
            }
        })
        if allocs > 0 {
            // TODO: Update when/if this lands:
            // https://github.com/golang/go/issues/62384
            //
            // Currently, the best we can do is 1 alloc.
            t.Errorf("MarshalText allocs = %v; want <= 1", allocs)
        }
    })

    t.Run("UnmarshalTest makes safe copy", func(t *testing.T) {
        // UnmarshalText should make a copy of the data.
        data := []byte("mistral:latest+Q4_0")
        p := Name{}
        if err := p.UnmarshalText(data); err != nil {
            t.Fatal(err)
        }
        data[0] = 'x'
        if p.String() != "mistral:latest+Q4_0" {
            t.Errorf("UnmarshalText() did not make a copy")
        }
    })
}

func TestSQL(t *testing.T) {
    t.Run("Scan for already valid Name", func(t *testing.T) {
        p := MustParseName("x")
        if err := p.Scan("mistral:latest+Q4_0"); err == nil {
            t.Error("Scan() = nil; want error")
        }
    })
    t.Run("Scan for invalid Name", func(t *testing.T) {
        p := Name{}
        if err := p.Scan("mistral:latest+Q4_0"); err != nil {
            t.Errorf("Scan() = %v; want nil", err)
        }
        if p.String() != "mistral:latest+Q4_0" {
            t.Errorf("String() = %q; want %q", p, "mistral:latest+Q4_0")
        }
    })
    t.Run("Value", func(t *testing.T) {
        p := MustParseName("x")
        if g, err := p.Value(); err != nil {
            t.Errorf("Value() error = %v; want nil", err)
        } else if g != "x" {
            t.Errorf("Value() = %q; want %q", g, "x")
        }
    })
}

func TestNameStringAllocs(t *testing.T) {
    name := ParseName("example.com/ns/mistral:latest+Q4_0")
    name := ParseNameFill("example.com/ns/mistral:latest+Q4_0", "")
    allocs := testing.AllocsPerRun(1000, func() {
        keep(name.String())
    })
@@ -495,8 +408,8 @@ func TestNameStringAllocs(t *testing.T) {
}

func ExampleFill() {
    defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
    r := Fill(ParseName("mistral"), defaults)
    defaults := ParseNameFill("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0", "")
    r := Fill(ParseNameFill("mistral", ""), defaults)
    fmt.Println(r)

    // Output:
@@ -507,12 +420,12 @@ func ExampleName_MapHash() {
    m := map[uint64]bool{}

    // key 1
    m[ParseName("mistral:latest+q4").MapHash()] = true
    m[ParseName("miSTRal:latest+Q4").MapHash()] = true
    m[ParseName("mistral:LATest+Q4").MapHash()] = true
    m[ParseNameFill("mistral:latest+q4", "").MapHash()] = true
    m[ParseNameFill("miSTRal:latest+Q4", "").MapHash()] = true
    m[ParseNameFill("mistral:LATest+Q4", "").MapHash()] = true

    // key 2
    m[ParseName("mistral:LATest").MapHash()] = true
    m[ParseNameFill("mistral:LATest", "").MapHash()] = true

    fmt.Println(len(m))
    // Output:
@@ -521,9 +434,9 @@

func ExampleName_CompareFold_sort() {
    names := []Name{
        ParseName("mistral:latest"),
        ParseName("mistRal:7b+q4"),
        ParseName("MIstral:7b"),
        ParseNameFill("mistral:latest", ""),
        ParseNameFill("mistRal:7b+q4", ""),
        ParseNameFill("MIstral:7b", ""),
    }

    slices.SortFunc(names, Name.CompareFold)
@@ -544,8 +457,8 @@ func ExampleName_completeAndResolved() {
        "x/y/z:latest+q4_0",
        "@sha123-1",
    } {
        p := ParseName(s)
        fmt.Printf("complete:%v resolved:%v digest:%s\n", p.IsComplete(), p.IsResolved(), p.Digest())
        name := ParseNameFill(s, "")
        fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
    }

    // Output:
@@ -554,19 +467,24 @@ func ExampleName_completeAndResolved() {
    // complete:false resolved:true digest:sha123-1
}

func ExampleName_DisplayFullest() {
    for _, s := range []string{
        "example.com/jmorganca/mistral:latest+Q4_0",
        "mistral:latest+Q4_0",
        "mistral:latest",
    } {
        fmt.Println(ParseName(s).DisplayFullest())
    }
func ExampleName_DisplayShortest() {
    name := ParseNameFill("example.com/jmorganca/mistral:latest+Q4_0", "")

    fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest"))
    fmt.Println(name.DisplayShortest("example.com/_/_:latest"))
    fmt.Println(name.DisplayShortest("example.com/_/_:_"))
    fmt.Println(name.DisplayShortest("_/_/_:_"))

    // Default
    name = ParseNameFill("registry.ollama.ai/library/mistral:latest+Q4_0", "")
    fmt.Println(name.DisplayShortest(""))

    // Output:
    // mistral
    // jmorganca/mistral
    // jmorganca/mistral:latest
    // example.com/jmorganca/mistral:latest
    // mistral:latest
    // mistral:latest
    // mistral
}

func keep[T any](v T) T { return v }
x/api/api.go
@@ -1,113 +0,0 @@
package api

import (
    "errors"
    "fmt"
    "net/http"
    "os"

    "github.com/ollama/ollama/x/build"
    "github.com/ollama/ollama/x/client/ollama/apitype"
    "github.com/ollama/ollama/x/oweb"
    "github.com/ollama/ollama/x/registry"
    regtype "github.com/ollama/ollama/x/registry/apitype"
)

// Common API Errors
var (
    errUnqualifiedRef = oweb.Invalid("invalid", "name", "must be fully qualified")
    errRefNotFound    = oweb.Invalid("not_found", "name", "no such model")
)

type Server struct {
    Build *build.Server
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    oweb.Serve(s.serveHTTP, w, r)
}

func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
    switch r.URL.Path {
    case "/v1/push":
        return s.handlePush(w, r)
    default:
        return oweb.ErrNotFound
    }
}

func want(r *http.Request, method, path string) bool {
    return r.Method == method && r.URL.Path == path
}

func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
    if r.Method != "POST" {
        return oweb.ErrMethodNotAllowed
    }

    params, err := oweb.DecodeJSON[apitype.PushRequest](r.Body)
    if err != nil {
        return err
    }

    if params.Name == "" {
        return oweb.Missing("name")
    }

    const registryURLTODO = "http://localhost:8888"

    man, err := s.Build.ManifestData(params.Name)
    if err != nil {
        if errors.Is(err, build.ErrNotFound) {
            return errRefNotFound
        }
        return err
    }

    c := registry.Client{BaseURL: registryURLTODO}
    requirements, err := c.Push(r.Context(), params.Name, man, nil)
    if err != nil {
        return err
    }

    var uploads []regtype.CompletePart
    for _, rq := range requirements {
        l, err := s.Build.LayerFile(rq.Digest)
        if err != nil {
            return err
        }
        err = func() error {
            f, err := os.Open(l)
            if err != nil {
                return err
            }
            defer f.Close()
            cp, err := registry.PushLayer(r.Context(), f, rq.URL, rq.Offset, rq.Size)
            if err != nil {
                return err
            }
            uploads = append(uploads, cp)
            return nil
        }()
        if err != nil {
            return err
        }
    }

    // commit the manifest to the registry
    requirements, err = c.Push(r.Context(), params.Name, man, &registry.PushParams{
        CompleteParts: uploads,
    })
    if err != nil {
        return err
    }
    for _, r := range requirements {
        err = errors.Join(err, fmt.Errorf("push failed for %q", r.Digest))
    }
    return err
}

func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
    return oweb.ErrNotFound
}
209
x/build/build.go
209
x/build/build.go
@@ -1,209 +0,0 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/build/internal/blobstore"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrIncompleteRef = errors.New("unqualified ref")
|
||||
ErrBuildPresentInRef = errors.New("build present in ref")
|
||||
ErrUnsupportedModelFormat = errors.New("unsupported model format")
|
||||
ErrMissingFileType = errors.New("missing 'general.file_type' key")
|
||||
ErrNotFound = errors.New("not found")
|
||||
)
|
||||
|
||||
type mediaType string
|
||||
|
||||
// Known media types
|
||||
const (
|
||||
mediaTypeModel mediaType = "application/vnd.ollama.image.model"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
st *blobstore.Store
|
||||
}
|
||||
|
||||
// Open starts a new build server that uses dir as the base directory for all
|
||||
// build artifacts. If dir is empty, DefaultDir is used.
|
||||
//
|
||||
// It returns an error if the provided or default dir cannot be initialized.
|
||||
func Open(dir string) (*Server, error) {
|
||||
if dir == "" {
|
||||
var err error
|
||||
dir, err = DefaultDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
st, err := blobstore.Open(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Server{st: st}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Build(ref string, f model.File) error {
|
||||
mp := model.ParseName(ref)
|
||||
if !mp.IsCompleteNoBuild() {
|
||||
return fmt.Errorf("%w: %q", ErrIncompleteRef, ref)
|
||||
}
|
||||
|
||||
// 1. Resolve FROM
|
||||
// a. If it's a local file (gguf), hash it and add it to the store.
|
||||
// c. If it's a remote file (http), refuse.
|
||||
// 2. Turn other pragmas into layers, and add them to the store.
|
||||
// 3. Create a manifest from the layers.
|
||||
// 4. Store the manifest in the manifest cache
|
||||
// 5. Done.
|
||||
|
||||
if f.From == "" {
|
||||
return &model.FileError{Pragma: "FROM", Message: "missing"}
|
||||
}
|
||||
|
||||
var layers []layerJSON
|
||||
|
||||
id, info, size, err := s.importModel(f.From)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
layers = append(layers, layerJSON{
|
||||
ID: id,
|
||||
MediaType: mediaTypeModel,
|
||||
Size: size,
|
||||
})
|
||||
|
||||
id, size, err = blobstore.PutString(s.st, f.License)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
layers = append(layers, layerJSON{
|
||||
ID: id,
|
||||
MediaType: "text/plain",
|
||||
Size: size,
|
||||
})
|
||||
|
||||
data, err := json.Marshal(manifestJSON{Layers: layers})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.setManifestData(
|
||||
mp.WithBuild(info.FileType.String()),
|
||||
data,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) LayerFile(digest string) (string, error) {
|
||||
fileName := s.st.OutputFilename(blobstore.ParseID(digest))
|
||||
_, err := os.Stat(fileName)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return "", fmt.Errorf("%w: %q", ErrNotFound, digest)
|
||||
}
|
||||
return fileName, nil
|
||||
}
|
||||
|
||||
func (s *Server) ManifestData(ref string) ([]byte, error) {
|
||||
data, _, err := s.resolve(model.ParseName(ref))
|
||||
return data, err
|
||||
}
|
||||
|
||||
// WeightFile returns the absolute path to the weights file for the given model ref.
|
||||
func (s *Server) WeightsFile(ref string) (string, error) {
|
||||
m, err := s.getManifest(model.ParseName(ref))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
if l.MediaType == mediaTypeModel {
|
||||
return s.st.OutputFilename(l.ID), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("missing weights layer for %q", ref)
|
||||
}
|
||||
|
||||
// resolve returns the data for the given ref, if any.
|
||||
//
|
||||
// TODO: This should ideally return an ID, but the current on
|
||||
// disk layout is that the actual manifest is stored in the "ref" instead of
|
||||
// a pointer to a content-addressed blob. I (bmizerany) think we should
|
||||
// change the on-disk layout to store the manifest in a content-addressed
|
||||
// blob, and then have the ref point to that blob. This would simplify the
|
||||
// code, allow us to have integrity checks on the manifest, and clean up
|
||||
// this interface.
|
||||
func (s *Server) resolve(ref model.Name) (data []byte, fileName string, err error) {
|
||||
fileName, err = s.refFileName(ref)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
data, err = os.ReadFile(fileName)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref)
|
||||
}
|
||||
if err != nil {
|
||||
// do not wrap the error here, as it is likely an I/O error
|
||||
// and we want to preserve the absraction since we may not
|
||||
// be on disk later.
|
||||
return nil, "", fmt.Errorf("manifest read error: %v", err)
|
||||
}
|
||||
return data, fileName, nil
|
||||
}
|
||||
|
||||
func (s *Server) SetManifestData(ref string, data []byte) error {
|
||||
return s.setManifestData(model.ParseName(ref), data)
|
||||
}
|
||||
|
||||
// Set sets the data for the given ref.
|
||||
func (s *Server) setManifestData(mp model.Name, data []byte) error {
|
||||
path, err := s.refFileName(mp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0666); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) refFileName(mp model.Name) (string, error) {
|
||||
if !mp.IsComplete() {
|
||||
return "", fmt.Errorf("ref not fully qualified: %q", mp)
|
||||
}
|
||||
return filepath.Join(s.st.Dir(), "manifests", filepath.Join(mp.Parts()...)), nil
|
||||
}
|
||||
|
||||
type manifestJSON struct {
|
||||
// Layers is the list of layers in the manifest.
|
||||
Layers []layerJSON `json:"layers"`
|
||||
}
|
||||
|
||||
// Layer is a layer in a model manifest.
|
||||
type layerJSON struct {
|
||||
// ID is the ID of the layer.
|
||||
ID blobstore.ID `json:"digest"`
|
||||
MediaType mediaType `json:"mediaType"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
func (s *Server) getManifest(ref model.Name) (manifestJSON, error) {
|
||||
data, path, err := s.resolve(ref)
|
||||
if err != nil {
|
||||
return manifestJSON{}, err
|
||||
}
|
||||
var m manifestJSON
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/encoding/gguf"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
)
|
||||
|
||||
const qualifiedRef = "x/y/z:latest+Q4_0"
|
||||
|
||||
func TestServerBuildErrors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
s, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("unqualified ref", func(t *testing.T) {
|
||||
err := s.Build("x", model.File{})
|
||||
if !errors.Is(err, ErrIncompleteRef) {
|
||||
t.Fatalf("Build() err = %v; want unqualified ref", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM pragma missing", func(t *testing.T) {
|
||||
err := s.Build(qualifiedRef, model.File{})
|
||||
var e *model.FileError
|
||||
if !errors.As(err, &e) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if e.Pragma != "FROM" {
|
||||
t.Errorf("e.Pragma = %s; want FROM", e.Pragma)
|
||||
}
|
||||
if e.Message != "missing" {
|
||||
t.Errorf("e.Message = %s; want missing", e.Message)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM file not found", func(t *testing.T) {
|
||||
err := s.Build(qualifiedRef, model.File{From: "bar"})
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("Build() err = %v; want file not found", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM gguf", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
// Write a gguf file without general.file_type metadata.
|
||||
w.write("gguf", ""+
|
||||
"GGUF"+ // magic
|
||||
"\x03\x00\x00\x00"+ // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
|
||||
"",
|
||||
)
|
||||
|
||||
err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")})
|
||||
if !errors.Is(err, ErrMissingFileType) {
|
||||
t.Fatalf("Build() err = %#v; want missing file type", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM obscure dir", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
w.mkdirAll("unknown")
|
||||
if err := s.Build(qualifiedRef, model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat {
|
||||
t.Fatalf("Build() err = %#v; want unsupported model type", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM unsupported model type", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
from := w.write("unknown", "unknown content")
|
||||
err := s.Build(qualifiedRef, model.File{From: from})
|
||||
if !errors.Is(err, ErrUnsupportedModelFormat) {
|
||||
t.Fatalf("Build() err = %#v; want unsupported model type", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildBasicGGUF(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
w.write("gguf", ""+
|
||||
"GGUF"+ // magic
|
||||
"\x03\x00\x00\x00"+ // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
|
||||
|
||||
// general.file_type key
|
||||
"\x11\x00\x00\x00\x00\x00\x00\x00"+ // key length
|
||||
"general.file_type"+ // key
|
||||
"\x04\x00\x00\x00"+ // type (uint32)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00"+ // uint32 value
|
||||
"",
|
||||
)
|
||||
|
||||
dir := t.TempDir()
|
||||
s, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
filepath.Walk(dir, func(p string, info os.FileInfo, err error) error {
|
||||
t.Logf("file: %s", p)
|
||||
return nil
|
||||
})
|
||||
|
||||
_, err = s.WeightsFile("unknown/y/z:latest+Q4_0")
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Fatalf("WeightsFile() err = %v; want not found", err)
|
||||
}
|
||||
|
||||
path, err := s.WeightsFile("x/y/z:latest+Q4_0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
info, err := gguf.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if info.FileType != gguf.TypeQ4_0 {
|
||||
t.Errorf("info.FileType = %d; want 1", info.FileType)
|
||||
}
|
||||
}
|
||||
|
||||
type work struct {
|
||||
t testing.TB
|
||||
dir string
|
||||
}
|
||||
|
||||
func newWorkDir(t *testing.T) work {
|
||||
return work{t: t, dir: t.TempDir()}
|
||||
}
|
||||
|
||||
func (w work) write(name, content string) (path string) {
|
||||
w.t.Helper()
|
||||
path = w.fileName(name)
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
w.t.Fatal(err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (w work) fileName(name string) string {
|
||||
w.t.Helper()
|
||||
return filepath.Join(w.dir, name)
|
||||
}
|
||||
|
||||
func (w work) mkdirAll(path string) {
|
||||
w.t.Helper()
|
||||
if err := os.MkdirAll(filepath.Join(w.dir, path), 0755); err != nil {
|
||||
w.t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package build
|
||||
|
||||
func convertSafeTensorToGGUF(path string) (ggufPath string, err error) {
|
||||
// TODO: decine on hueristic for converting safetensor to gguf and
|
||||
// the errors that can be returned. For now, we just say
|
||||
// "unsupported", however it may be intended to be a valid safe
|
||||
// tensor but we hit an error in the conversion.
|
||||
//
|
||||
// I (bmizernay) think this will naturally evolve as we implement
|
||||
// the conversion.
|
||||
return "", ErrUnsupportedModelFormat
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDir = sync.OnceValues(func() (string, error) {
|
||||
dir := os.Getenv("OLLAMA_MODELS")
|
||||
if dir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
dir = filepath.Join(home, ".ollama", "models")
|
||||
}
|
||||
return dir, nil
|
||||
})
|
||||
)
|
||||
|
||||
// DefaultDir returns the default directory for models. It returns the value
|
||||
// of the OLLAMA_MODELS environment variable if set; otherwise it returns
|
||||
// "$HOME/.ollama/models".
|
||||
func DefaultDir() (string, error) {
|
||||
return defaultDir()
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/build/internal/blobstore"
|
||||
"github.com/ollama/ollama/x/encoding/gguf"
|
||||
)
|
||||
|
||||
func importError(err error) (blobstore.ID, gguf.Info, int64, error) {
|
||||
return blobstore.ID{}, gguf.Info{}, 0, err
|
||||
}
|
||||
|
||||
func (s *Server) importModel(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return s.importSafeTensor(path)
|
||||
} else {
|
||||
return s.importGGUF(path)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) importGGUF(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
info, err := gguf.StatReader(f)
|
||||
if errors.Is(err, gguf.ErrBadMagic) {
|
||||
return importError(ErrUnsupportedModelFormat)
|
||||
}
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
|
||||
if info.FileType == 0 {
|
||||
return importError(fmt.Errorf("%w: %q", ErrMissingFileType, path))
|
||||
}
|
||||
id, size, err := s.st.Put(f)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
return id, info, size, nil
|
||||
}
|
||||
|
||||
func (s *Server) importSafeTensor(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
path, err := convertSafeTensorToGGUF(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
return s.importGGUF(path)
|
||||
}
|
||||
@@ -1,329 +0,0 @@
|
||||
// Package blobstore implements a blob store.
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidID = errors.New("invalid ID")
|
||||
)
|
||||
|
||||
const HashSize = 32
|
||||
|
||||
// An ID is a blob output key, the hash of an output of a computation.
|
||||
type ID struct {
|
||||
a [HashSize]byte
|
||||
}
|
||||
|
||||
func (id ID) MarshalText() ([]byte, error) {
|
||||
return []byte(id.String()), nil
|
||||
}
|
||||
|
||||
func (id *ID) UnmarshalText(text []byte) error {
|
||||
*id = ParseID(string(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseID(s string) ID {
|
||||
const prefix = "sha256-"
|
||||
h, ok := strings.CutPrefix(s, prefix)
|
||||
if !ok {
|
||||
return ID{}
|
||||
}
|
||||
|
||||
if len(h) != HashSize*2 {
|
||||
return ID{}
|
||||
}
|
||||
|
||||
var b []byte
|
||||
_, err := fmt.Sscanf(h, "%x", &b)
|
||||
if err != nil {
|
||||
return ID{}
|
||||
}
|
||||
|
||||
var id ID
|
||||
copy(id.a[:], b)
|
||||
return id
|
||||
}
|
||||
|
||||
func (id ID) String() string {
|
||||
if !id.Valid() {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("sha256-%x", id.a[:])
|
||||
}
|
||||
|
||||
func (id ID) Valid() bool {
|
||||
return id != ID{}
|
||||
}
|
||||
|
||||
func (id ID) Match(h [HashSize]byte) bool {
|
||||
return id.a == h
|
||||
}
|
||||
|
||||
// A Store is a blob store, backed by a file system directory tree.
|
||||
type Store struct {
|
||||
dir string
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// Open opens and returns the store in the given directory.
|
||||
//
|
||||
// It is safe for multiple processes on a single machine to use the
|
||||
// same store directory in a local file system simultaneously.
|
||||
// They will coordinate using operating system file locks and may
|
||||
// duplicate effort but will not corrupt the store.
|
||||
//
|
||||
// However, it is NOT safe for multiple processes on different machines
|
||||
// to share a store directory (for example, if the directory were stored
|
||||
// in a network file system). File locking is notoriously unreliable in
|
||||
// network file systems and may not suffice to protect the store.
|
||||
func Open(dir string) (*Store, error) {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, &fs.PathError{Op: "open", Path: dir, Err: fmt.Errorf("not a directory")}
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(dir, "blobs"), 0777); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &Store{
|
||||
dir: dir,
|
||||
now: time.Now,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *Store) Dir() string {
|
||||
return s.dir
|
||||
}
|
||||
|
||||
// fileName returns the name of the blob file corresponding to the given id.
|
||||
func (s *Store) fileName(id ID) string {
|
||||
return filepath.Join(s.dir, "blobs", fmt.Sprintf("sha256-%x", id.a[:]))
|
||||
}
|
||||
|
||||
// An entryNotFoundError indicates that a store entry was not found, with an
|
||||
// optional underlying reason.
|
||||
type entryNotFoundError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *entryNotFoundError) Error() string {
|
||||
if e.Err == nil {
|
||||
return "store entry not found"
|
||||
}
|
||||
return fmt.Sprintf("store entry not found: %v", e.Err)
|
||||
}
|
||||
|
||||
func (e *entryNotFoundError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
_ structs.Incomparable
|
||||
|
||||
ID ID
|
||||
Size int64
|
||||
Time time.Time // when added to store
|
||||
}
|
||||
|
||||
// GetFile looks up the blob ID in the store and returns
|
||||
// the name of the corresponding data file.
|
||||
func GetFile(s *Store, id ID) (file string, entry Entry, err error) {
|
||||
entry, err = s.Get(id)
|
||||
if err != nil {
|
||||
return "", Entry{}, err
|
||||
}
|
||||
file = s.OutputFilename(entry.ID)
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return "", Entry{}, &entryNotFoundError{Err: err}
|
||||
}
|
||||
if info.Size() != entry.Size {
|
||||
return "", Entry{}, &entryNotFoundError{Err: errors.New("file incomplete")}
|
||||
}
|
||||
return file, entry, nil
|
||||
}
|
||||
|
||||
// GetBytes looks up the blob ID in the store and returns
|
||||
// the corresponding output bytes.
|
||||
// GetBytes should only be used for data that can be expected to fit in memory.
|
||||
func GetBytes(s *Store, id ID) ([]byte, Entry, error) {
|
||||
entry, err := s.Get(id)
|
||||
if err != nil {
|
||||
return nil, entry, err
|
||||
}
|
||||
data, _ := os.ReadFile(s.OutputFilename(entry.ID))
|
||||
if entry.ID.Match(sha256.Sum256(data)) {
|
||||
return nil, entry, &entryNotFoundError{Err: errors.New("bad checksum")}
|
||||
}
|
||||
return data, entry, nil
|
||||
}
|
||||
|
||||
// OutputFilename returns the name of the blob file for the given ID.
|
||||
func (s *Store) OutputFilename(id ID) string {
|
||||
file := s.fileName(id)
|
||||
// TODO(bmizerany): touch as "used" for cache trimming. (see
|
||||
// cache.go in cmd/go/internal/cache for the full reference implementation to go off of.
|
||||
return file
|
||||
}
|
||||
|
||||
// Get looks up the blob ID in the store,
|
||||
// returning the corresponding output ID and file size, if any.
|
||||
// Note that finding an output ID does not guarantee that the
|
||||
// saved file for that output ID is still available.
|
||||
func (s *Store) Get(id ID) (Entry, error) {
|
||||
file := s.fileName(id)
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return Entry{}, &entryNotFoundError{Err: err}
|
||||
}
|
||||
return Entry{
|
||||
ID: id,
|
||||
Size: info.Size(),
|
||||
Time: info.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error {
|
||||
// TODO(bmizerany): return c.Trim()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put stores the data read from the given file into the store as ID.
|
||||
//
|
||||
// It may read file twice. The content of file must not change between the
|
||||
// two passes.
|
||||
func (s *Store) Put(file io.ReadSeeker) (ID, int64, error) {
|
||||
return s.put(file)
|
||||
}
|
||||
|
||||
func PutBytes(s *Store, data []byte) (ID, int64, error) {
|
||||
return s.Put(bytes.NewReader(data))
|
||||
}
|
||||
|
||||
func PutString(s *Store, data string) (ID, int64, error) {
|
||||
return s.Put(strings.NewReader(data))
|
||||
}
|
||||
|
||||
func (s *Store) put(file io.ReadSeeker) (ID, int64, error) {
|
||||
// Compute output ID.
|
||||
h := sha256.New()
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
return ID{}, 0, err
|
||||
}
|
||||
size, err := io.Copy(h, file)
|
||||
if err != nil {
|
||||
return ID{}, 0, err
|
||||
}
|
||||
var out ID
|
||||
h.Sum(out.a[:0])
|
||||
|
||||
// Copy to blob file (if not already present).
|
||||
if err := s.copyFile(file, out, size); err != nil {
|
||||
return out, size, err
|
||||
}
|
||||
|
||||
// TODO: Add to manifest index.
|
||||
return out, size, nil
|
||||
}
|
||||
|
||||
// copyFile copies file into the store, expecting it to have the given
|
||||
// output ID and size, if that file is not present already.
|
||||
func (s *Store) copyFile(file io.ReadSeeker, out ID, size int64) error {
|
||||
name := s.fileName(out)
|
||||
println("name", name)
|
||||
info, err := os.Stat(name)
|
||||
if err == nil && info.Size() == size {
|
||||
// Check hash.
|
||||
if f, err := os.Open(name); err == nil {
|
||||
h := sha256.New()
|
||||
io.Copy(h, f)
|
||||
f.Close()
|
||||
var out2 ID
|
||||
h.Sum(out2.a[:0])
|
||||
if out == out2 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Hash did not match. Fall through and rewrite file.
|
||||
}
|
||||
|
||||
// Copy file to blobs directory.
|
||||
mode := os.O_RDWR | os.O_CREATE
|
||||
if err == nil && info.Size() > size { // shouldn't happen but fix in case
|
||||
mode |= os.O_TRUNC
|
||||
}
|
||||
f, err := os.OpenFile(name, mode, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if size == 0 {
|
||||
// File now exists with correct size.
|
||||
// Only one possible zero-length file, so contents are OK too.
|
||||
// Early return here makes sure there's a "last byte" for code below.
|
||||
return nil
|
||||
}
|
||||
|
||||
// From here on, if any of the I/O writing the file fails,
|
||||
// we make a best-effort attempt to truncate the file f
|
||||
// before returning, to avoid leaving bad bytes in the file.
|
||||
|
||||
// Copy file to f, but also into h to double-check hash.
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
h := sha256.New()
|
||||
w := io.MultiWriter(f, h)
|
||||
if _, err := io.CopyN(w, file, size-1); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
// Check last byte before writing it; writing it will make the size match
|
||||
// what other processes expect to find and might cause them to start
|
||||
// using the file.
|
||||
buf := make([]byte, 1)
|
||||
if _, err := file.Read(buf); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
h.Write(buf)
|
||||
sum := h.Sum(nil)
|
||||
if !bytes.Equal(sum, out.a[:]) {
|
||||
f.Truncate(0)
|
||||
return fmt.Errorf("file content changed underfoot")
|
||||
}
|
||||
|
||||
// Commit manifest entry.
|
||||
if _, err := f.Write(buf); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
// Data might not have been written,
|
||||
// but file may look like it is the right size.
|
||||
// To be extra careful, remove stored file.
|
||||
os.Remove(name)
|
||||
return err
|
||||
}
|
||||
os.Chtimes(name, s.now(), s.now()) // mainly for tests
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseID(t *testing.T) {
|
||||
const valid = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
var invalid = strings.Repeat("\x00", HashSize*2)
|
||||
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", invalid},
|
||||
{"sha256-", invalid},
|
||||
{"sha256-" + valid, valid},
|
||||
|
||||
{"" + valid, invalid}, // no prefix
|
||||
{"sha123-" + valid, invalid}, // invalid prefix
|
||||
{"sha256-" + valid[1:], invalid}, // too short
|
||||
{"sha256-" + valid + "a", invalid}, // too long
|
||||
{"sha256-!" + valid[1:], invalid}, // invalid hex
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
// sanity check
|
||||
if len(tt.want) > HashSize*2 {
|
||||
panic("invalid test")
|
||||
}
|
||||
|
||||
got := ParseID(tt.in)
|
||||
|
||||
wantValid := tt.want != invalid
|
||||
if wantValid {
|
||||
if !got.Valid() {
|
||||
t.Errorf("ParseID(%q).Valid() = false; want true", tt.in)
|
||||
}
|
||||
if got.String() != "sha256-"+tt.want {
|
||||
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "sha256-"+tt.want)
|
||||
}
|
||||
} else {
|
||||
if got.Valid() {
|
||||
t.Errorf("ParseID(%q).Valid() = true; want false", tt.in)
|
||||
}
|
||||
if got.String() != "" {
|
||||
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,128 +0,0 @@
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"iter"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
const (
|
||||
blobNameHello = "sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
|
||||
)
|
||||
|
||||
func TestStoreBasicBlob(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
checkDir(t, dir, nil)
|
||||
|
||||
st, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
st.now = func() time.Time { return now }
|
||||
|
||||
checkDir(t, dir, []string{
|
||||
"blobs/",
|
||||
})
|
||||
|
||||
id, size, err := PutBytes(st, []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if id != ParseID(blobNameHello) {
|
||||
t.Errorf("unexpected ID: %s", id)
|
||||
}
|
||||
if size != 5 {
|
||||
t.Errorf("unexpected size: %d", size)
|
||||
}
|
||||
|
||||
checkDir(t, dir, []string{
|
||||
"blobs/",
|
||||
"blobs/" + blobNameHello,
|
||||
})
|
||||
|
||||
got, err := st.Get(id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, got, Entry{
|
||||
ID: id,
|
||||
Size: 5,
|
||||
Time: now,
|
||||
})
|
||||
|
||||
file := st.OutputFilename(id)
|
||||
wantFile := filepath.Join(dir, "blobs", blobNameHello)
|
||||
if file != wantFile {
|
||||
t.Errorf("unexpected file: %s", file)
|
||||
}
|
||||
|
||||
// Check tags
|
||||
name := model.ParseName("registry.ollama.ai/library/test:latest+KQED")
|
||||
|
||||
t.Logf("RESOLVING: %q", name.Parts())
|
||||
|
||||
}
|
||||
|
||||
// checkDir checks that the directory at dir contains the files in want. The
|
||||
// files in want must be relative to dir.
|
||||
//
|
||||
// direcotories are suffixed with a slash (e.g. "foo/" instead of "foo").
|
||||
//
|
||||
// want must be in lexicographic order.
|
||||
func checkDir(t testing.TB, dir string, want []string) {
|
||||
t.Helper()
|
||||
|
||||
var matches []string
|
||||
for path, err := range walkDir(dir) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("found %s", path)
|
||||
if path == "./" {
|
||||
continue
|
||||
}
|
||||
path = filepath.ToSlash(path)
|
||||
matches = append(matches, path)
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, matches, want)
|
||||
}
|
||||
|
||||
var errStop = errors.New("stop")
|
||||
|
||||
func walkDir(dir string) iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
err := filepath.WalkDir(dir, func(path string, info os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path, err = filepath.Rel(dir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path = filepath.ToSlash(path)
|
||||
if info.IsDir() {
|
||||
path += "/"
|
||||
}
|
||||
if !yield(path, nil) {
|
||||
return errStop
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if !errors.Is(err, errStop) && err != nil {
|
||||
yield("", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
package apitype
|
||||
|
||||
import "time"
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Ref string `json:"ref"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
ModifiedAt int64 `json:"modified"`
|
||||
}
|
||||
|
||||
func (m Model) Modifed() time.Time {
|
||||
return time.Unix(0, m.ModifiedAt)
|
||||
}
|
||||
|
||||
type PushRequest struct {
|
||||
Name string `json:"name"` // Ref is the official term, "name" is for backward compatibility with exiting clients.
|
||||
Insecure bool `json:"insecure"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type PushStatus struct {
|
||||
Status string `json:"status"`
|
||||
Digest string `json:"digest"`
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
@@ -1,173 +0,0 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/client/ollama/apitype"
|
||||
"github.com/ollama/ollama/x/types/empty"
|
||||
)
|
||||
|
||||
// TODO(bmizerany): PROGRESS INDICATORS!!!!
|
||||
|
||||
const DefaultBaseURL = "http://localhost:11434"
|
||||
|
||||
var envBaseURL = cmp.Or(os.Getenv("OLLAMA_BASE_URL"), DefaultBaseURL)
|
||||
|
||||
// Default returns a new client with the default base URL.
|
||||
func Default() *Client {
|
||||
return &Client{BaseURL: envBaseURL}
|
||||
}
|
||||
|
||||
// I_Acknowledge_This_API_Is_Under_Development is a flag that must be set to
|
||||
// true for any instance of Client to work.
|
||||
var I_Acknowledge_This_API_Is_Under_Development bool
|
||||
|
||||
// Client is a client for the Ollama API.
|
||||
type Client struct {
|
||||
// BaseURL is the base URL of the Ollama API.
|
||||
BaseURL string
|
||||
|
||||
HTTPClient *http.Client // The HTTP client to use. If nil, http.DefaultClient is used.
|
||||
}
|
||||
|
||||
// Build requests the remote Ollama service to build a model. It uploads any
|
||||
// source files the server needs.
|
||||
func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source fs.FS) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// Push requests the remote Ollama service to push a model to the server.
|
||||
func (c *Client) Push(ctx context.Context, ref string) error {
|
||||
_, err := Do[empty.Message](ctx, c, "POST", "/v1/push", apitype.PushRequest{Name: ref})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) Pull(ctx context.Context, ref string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) List(ctx context.Context) iter.Seq2[apitype.Model, error] {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Show(ctx context.Context, ref string) (*apitype.Model, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Remove(ctx context.Context, ref string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Copy(ctx context.Context, dstRef, srcRef string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
// Status is the HTTP status code returned by the server.
|
||||
Status int `json:"status"`
|
||||
|
||||
// Code specifies a machine readable code indicating the class of
|
||||
// error this error is. See http://docs.ollama.com/errors for a full
|
||||
// list of error codes.
|
||||
Code string `json:"code"`
|
||||
|
||||
// Message is a humage readable message that describes the error. It
|
||||
// may change across versions of the API, so it should not be used for
|
||||
// programmatic decisions.
|
||||
Message string `json:"message,omitempty"`
|
||||
|
||||
// Field is the field in the request that caused the error, if any.
|
||||
Field string `json:"field,omitempty"`
|
||||
|
||||
// Value is the value of the field that caused the error, if any.
|
||||
Value string `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("ollama: ")
|
||||
b.WriteString(e.Code)
|
||||
if e.Field != "" {
|
||||
b.WriteString(" ")
|
||||
b.WriteString(e.Field)
|
||||
}
|
||||
if e.Value != "" {
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.Value)
|
||||
}
|
||||
if e.Message != "" {
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.Message)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Do encodes in and sends it in a request to the Ollama server and decodes
|
||||
// the response into Res, or an error response (non-2xx) into an *Error, or
|
||||
// any error encounted decoding the response.
|
||||
func Do[Res any](ctx context.Context, c *Client, method, path string, in any) (*Res, error) {
|
||||
var body bytes.Buffer
|
||||
// TODO(bmizerany): pool and reuse this buffer AND the encoder
|
||||
if err := encodeJSON(&body, in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
urlStr := c.BaseURL + path
|
||||
req, err := http.NewRequestWithContext(ctx, method, urlStr, &body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hc := cmp.Or(c.HTTPClient, http.DefaultClient)
|
||||
res, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
var buf bytes.Buffer
|
||||
body := io.TeeReader(res.Body, &buf)
|
||||
e, err := decodeJSON[Error](body)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("ollama: invalid error response from server (status %d): %q", res.StatusCode, buf.String())
|
||||
return nil, err
|
||||
}
|
||||
return nil, e
|
||||
}
|
||||
|
||||
return decodeJSON[Res](res.Body)
|
||||
}
|
||||
|
||||
// decodeJSON decodes JSON from r into a new value of type T.
|
||||
//
|
||||
// NOTE: This is (and encodeJSON) are copies and paste from oweb.go, please
|
||||
// do not try and consolidate so we can keep ollama/client free from
|
||||
// dependencies which are moving targets and not pulling enough weight to
|
||||
// justify their inclusion.
|
||||
func decodeJSON[T any](r io.Reader) (*T, error) {
|
||||
var v T
|
||||
if err := json.NewDecoder(r).Decode(&v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// NOTE: see NOT above decodeJSON
|
||||
func encodeJSON(w io.Writer, v any) error {
|
||||
// TODO(bmizerany): pool and reuse encoder
|
||||
return json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
// Bllamo is a (new) tool for managing Ollama models.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// bllamo <command> [arguments]
|
||||
//
|
||||
// The commands are:
|
||||
//
|
||||
// build build a model from a Modelfile
|
||||
// list list all models
|
||||
// push push a model from an ollama registry
|
||||
// pull pull a model from an ollama registry
|
||||
// delete delete a model from an ollama registry
|
||||
// help display help for a command
|
||||
package main
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/api"
|
||||
"github.com/ollama/ollama/x/build"
|
||||
"github.com/ollama/ollama/x/client/ollama"
|
||||
"github.com/ollama/ollama/x/registry"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
args := flag.Args()
|
||||
if len(args) < 1 {
|
||||
fmt.Fprintln(os.Stderr, "bllamo: no command provided")
|
||||
os.Exit(2)
|
||||
}
|
||||
if err := Main(flag.Args()...); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
var TODOUsage = fmt.Errorf("TODO: usage")
|
||||
|
||||
var commands = map[string]func(ctx context.Context, args ...string) error{
|
||||
"build": cmdBuild,
|
||||
"push": cmdPush,
|
||||
"serve": cmdServe,
|
||||
"registry": cmdRegistry,
|
||||
}
|
||||
|
||||
// Main is the entry point for the blammo command.
|
||||
func Main(args ...string) error {
|
||||
cmd := args[0]
|
||||
args = args[1:]
|
||||
if f, ok := commands[cmd]; ok {
|
||||
ctx := context.TODO()
|
||||
return f(ctx, args...)
|
||||
}
|
||||
return fmt.Errorf("blammo: unknown command %q", cmd)
|
||||
}
|
||||
|
||||
func cmdBuild(ctx context.Context, args ...string) error {
|
||||
var v struct {
|
||||
Modelfile string `flag:"f,the Modelfile to use"`
|
||||
}
|
||||
|
||||
fs := readFlags("build", args, &v)
|
||||
if fs.NArg() != 1 {
|
||||
return TODOUsage
|
||||
}
|
||||
|
||||
modelfile, err := os.ReadFile(cmp.Or(v.Modelfile, "Modelfile"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ollama.Default().Build(ctx, args[0], modelfile, os.DirFS("."))
|
||||
}
|
||||
|
||||
func cmdRegistry(_ context.Context, _ ...string) error {
|
||||
var s registry.Server
|
||||
return http.ListenAndServe(":8888", &s)
|
||||
}
|
||||
|
||||
func cmdServe(ctx context.Context, args ...string) error {
|
||||
bs, err := build.Open("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return http.ListenAndServe(":11434", &api.Server{Build: bs})
|
||||
}
|
||||
|
||||
func cmdPush(ctx context.Context, args ...string) error {
|
||||
fs := readFlags("push", args, nil)
|
||||
if fs.NArg() != 1 {
|
||||
return TODOUsage
|
||||
}
|
||||
return ollama.Default().Push(ctx, fs.Arg(0))
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// parseArgs parses the provided args using a flag.FlagSet that is
|
||||
// dynamically build using reflection for the provided type. The type fields
|
||||
// that have a "flag" tag are used to build the flags. The flag tag should
|
||||
// include either a ('-'). Example usage:
|
||||
//
|
||||
// func main() {
|
||||
// var flags struct {
|
||||
// Modelfile string `flag:"f,path to the Modelfile"`
|
||||
// }
|
||||
//
|
||||
// fs := readFlags(os.Args[1:], &flags)
|
||||
// fs.Parse(os.Args[1:])
|
||||
// }
|
||||
func readFlags(name string, args []string, v any) *flag.FlagSet {
|
||||
fs := flag.NewFlagSet(name, flag.ExitOnError)
|
||||
defer fs.Parse(args)
|
||||
if v == nil {
|
||||
return fs
|
||||
}
|
||||
|
||||
for i := 0; i < reflect.ValueOf(v).NumField(); i++ {
|
||||
f := reflect.ValueOf(v).Field(i)
|
||||
if !f.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
tag := f.Type().Field(i).Tag.Get("flag")
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
var name, usage string
|
||||
if i := strings.Index(tag, ","); i != -1 {
|
||||
name = tag[:i]
|
||||
usage = tag[i+1:]
|
||||
} else {
|
||||
name = tag
|
||||
}
|
||||
|
||||
// TODO(bmizerany): add more types as needed
|
||||
switch f.Kind() {
|
||||
case reflect.String:
|
||||
fs.StringVar(f.Addr().Interface().(*string), name, "", usage)
|
||||
case reflect.Bool:
|
||||
fs.BoolVar(f.Addr().Interface().(*bool), name, false, usage)
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported type %v", f.Kind()))
|
||||
}
|
||||
}
|
||||
return fs
|
||||
}
|
||||
@@ -1,97 +0,0 @@
|
||||
// Gguf is a tool for learning about GGUF files.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// gguf [flags] <file>
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/ollama/ollama/x/encoding/gguf"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := Main(os.Stdout, os.Args[1:]...); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Main(stdout io.Writer, args ...string) error {
|
||||
fs := flag.NewFlagSet("gguf", flag.ExitOnError)
|
||||
flagGPU := fs.Uint64("gpu", 0, "use N bytes of GPU memory (default is 0)")
|
||||
|
||||
fs.Usage = func() {
|
||||
io.WriteString(stdout, "Gguf is a tool for learning about GGUF files.\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
io.WriteString(stdout, "Usage:\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
io.WriteString(stdout, "\tgguf [flags] <file>\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
var numFlags int
|
||||
fs.VisitAll(func(*flag.Flag) { numFlags++ })
|
||||
if numFlags > 0 {
|
||||
io.WriteString(stdout, "Flags:\n")
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
}
|
||||
fs.Parse(args)
|
||||
|
||||
if fs.NArg() != 1 {
|
||||
fs.Usage()
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
file := fs.Arg(0)
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
g, err := gguf.ReadFile(f)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tw := tabwriter.NewWriter(stdout, 0, 2, 2, ' ', 0)
|
||||
defer tw.Flush()
|
||||
|
||||
fmt.Fprintf(tw, "version:\t%d\n", g.Version())
|
||||
|
||||
for m, err := range g.Metadata {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if len(m.Values) > 5 {
|
||||
fmt.Fprintf(tw, "meta:\t%q: ... (%d values)\n", m.Key, len(m.Values))
|
||||
} else {
|
||||
fmt.Fprintf(tw, "meta:\t%q: %v\n", m.Key, m.Values)
|
||||
}
|
||||
}
|
||||
|
||||
var i int
|
||||
var totalLayerBytes uint64
|
||||
var offGPU bool
|
||||
for t, err := range g.Tensors {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
totalLayerBytes += t.Size
|
||||
if totalLayerBytes > *flagGPU {
|
||||
offGPU = true
|
||||
}
|
||||
|
||||
const msg = "tensor (layer %000d):\t%q\t%s\tdims=%v\toffset=%d\tsize=%d\tonGPU=%v\n"
|
||||
fmt.Fprintf(tw, msg, i, t.Name, t.Type, t.Dimensions, t.Offset, t.Size, !offGPU)
|
||||
|
||||
i++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,376 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
|
||||
// TODO(bmizerany): determine a more reasonable value for MaxDimensions.
|
||||
|
||||
// MaxDimensions is the maximum number of dimensions a tensor can have.
|
||||
const MaxDimensions uint32 = 1e6
|
||||
|
||||
// Errors
|
||||
var (
|
||||
// ErrBadMagic is returned when the magic bytes at the start of the
|
||||
// file. This is useful for detecting if the file is not a gguf
|
||||
// file.
|
||||
ErrBadMagic = errors.New("gguf: bad magic")
|
||||
|
||||
ErrUnsupportedVersion = errors.New("gguf: unsupported version")
|
||||
ErrMangled = errors.New("gguf: mangled data")
|
||||
)
|
||||
|
||||
type Type uint32
|
||||
|
||||
const (
|
||||
TypeF32 Type = 0
|
||||
TypeF16 Type = 1
|
||||
TypeQ4_0 Type = 2
|
||||
TypeQ4_1 Type = 3
|
||||
TypeQ5_0 Type = 6
|
||||
TypeQ5_1 Type = 7
|
||||
TypeQ8_0 Type = 8
|
||||
TypeQ8_1 Type = 9
|
||||
TypeQ2_K Type = 10
|
||||
TypeQ3_K Type = 11
|
||||
TypeQ4_K Type = 12
|
||||
TypeQ5_K Type = 13
|
||||
TypeQ6_K Type = 14
|
||||
TypeQ8_K Type = 15
|
||||
TypeI8 Type = 16
|
||||
TypeI16 Type = 17
|
||||
TypeI32 Type = 18
|
||||
TypeCount Type = 19
|
||||
)
|
||||
|
||||
var typeNames = map[Type]string{
|
||||
TypeF32: "F32",
|
||||
TypeF16: "F16",
|
||||
TypeQ4_0: "Q4_0",
|
||||
TypeQ4_1: "Q4_1",
|
||||
TypeQ5_0: "Q5_0",
|
||||
TypeQ5_1: "Q5_1",
|
||||
TypeQ8_0: "Q8_0",
|
||||
TypeQ8_1: "Q8_1",
|
||||
TypeQ2_K: "Q2_K",
|
||||
TypeQ3_K: "Q3_K",
|
||||
TypeQ4_K: "Q4_K",
|
||||
TypeQ5_K: "Q5_K",
|
||||
TypeQ6_K: "Q6_K",
|
||||
TypeQ8_K: "Q8_K",
|
||||
TypeI8: "I8",
|
||||
TypeI16: "I16",
|
||||
TypeI32: "I32",
|
||||
TypeCount: "COUNT",
|
||||
}
|
||||
|
||||
func (t Type) String() string {
|
||||
if name := typeNames[t]; name != "" {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("(!unknown_type %d!)", t)
|
||||
}
|
||||
|
||||
// ValueType is the type of a metadata value.
|
||||
type ValueType uint32
|
||||
|
||||
func (t ValueType) String() string {
|
||||
if name := metaTypeNames[t]; name != "" {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("(!unknown_value_type %d!)", t)
|
||||
}
|
||||
|
||||
const (
|
||||
ValueTypeUint8 ValueType = 0
|
||||
ValueTypeInt8 ValueType = 1
|
||||
ValueTypeUint16 ValueType = 2
|
||||
ValueTypeInt16 ValueType = 3
|
||||
ValueTypeUint32 ValueType = 4
|
||||
ValueTypeInt32 ValueType = 5
|
||||
ValueTypeFloat32 ValueType = 6
|
||||
ValueTypeBool ValueType = 7
|
||||
ValueTypeString ValueType = 8
|
||||
ValueTypeArray ValueType = 9
|
||||
ValueTypeUint64 ValueType = 10
|
||||
ValueTypeInt64 ValueType = 11
|
||||
ValueTypeFloat64 ValueType = 12
|
||||
)
|
||||
|
||||
var metaTypeNames = map[ValueType]string{
|
||||
ValueTypeUint8: "uint8",
|
||||
ValueTypeInt8: "int8",
|
||||
ValueTypeUint16: "uint16",
|
||||
ValueTypeInt16: "int16",
|
||||
ValueTypeUint32: "uint32",
|
||||
ValueTypeInt32: "int32",
|
||||
ValueTypeFloat32: "float32",
|
||||
ValueTypeBool: "bool",
|
||||
ValueTypeString: "string",
|
||||
ValueTypeArray: "array",
|
||||
ValueTypeUint64: "uint64",
|
||||
ValueTypeInt64: "int64",
|
||||
ValueTypeFloat64: "float64",
|
||||
}
|
||||
|
||||
type TensorInfo struct {
|
||||
Name string
|
||||
Dimensions []uint64
|
||||
Type Type
|
||||
Offset uint64
|
||||
Size uint64
|
||||
}
|
||||
|
||||
type MetaValue struct {
|
||||
Type ValueType
|
||||
Value []byte
|
||||
}
|
||||
|
||||
func (v MetaValue) String() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(v.Type.String())
|
||||
b.WriteString("(")
|
||||
switch v.Type {
|
||||
case ValueTypeArray:
|
||||
b.WriteString("[...]")
|
||||
case ValueTypeString:
|
||||
b.WriteString(strconv.Quote(string(v.Value)))
|
||||
case ValueTypeBool:
|
||||
if len(v.Value) == 0 {
|
||||
b.WriteString("(!invalid bool)")
|
||||
}
|
||||
switch v.Value[0] {
|
||||
case 0:
|
||||
b.WriteString("false")
|
||||
case 1:
|
||||
b.WriteString("true")
|
||||
default:
|
||||
b.WriteString("!invalid bool")
|
||||
}
|
||||
case ValueTypeUint8, ValueTypeInt8, ValueTypeUint16, ValueTypeInt16, ValueTypeUint32, ValueTypeInt32, ValueTypeUint64, ValueTypeInt64, ValueTypeFloat32, ValueTypeFloat64:
|
||||
var buf [8]byte
|
||||
if len(v.Value) < 8 {
|
||||
copy(buf[:], v.Value)
|
||||
}
|
||||
fmt.Fprintf(&b, "%v", binary.LittleEndian.Uint64(buf[:]))
|
||||
default:
|
||||
fmt.Fprintf(&b, "%v", v.Value)
|
||||
}
|
||||
b.WriteString(")")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type MetaEntry struct {
|
||||
Key string
|
||||
Type ValueType
|
||||
Values []MetaValue
|
||||
}
|
||||
|
||||
func (e MetaEntry) String() string {
|
||||
if len(e.Values) == 0 {
|
||||
return ""
|
||||
}
|
||||
return string(e.Values[0].Value)
|
||||
}
|
||||
|
||||
func (e MetaEntry) Uint32() uint32 {
|
||||
if len(e.Values) == 0 {
|
||||
return 0
|
||||
}
|
||||
return binary.LittleEndian.Uint32(e.Values[0].Value)
|
||||
}
|
||||
|
||||
func (e MetaEntry) FileType() Type {
|
||||
if len(e.Values) == 0 {
|
||||
return TypeCount
|
||||
}
|
||||
return Type(e.Uint32())
|
||||
}
|
||||
|
||||
func (e MetaEntry) GoString() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(e.Key)
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.Type.String())
|
||||
b.WriteString("(")
|
||||
for i, v := range e.Values {
|
||||
if i > 0 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
b.WriteString(v.String())
|
||||
}
|
||||
b.WriteString(")")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type Info struct {
|
||||
_ structs.Incomparable // prevent comparison of Info values so we can change the implementation later
|
||||
|
||||
Version int
|
||||
FileType Type
|
||||
}
|
||||
|
||||
func Stat(path string) (Info, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
defer f.Close()
|
||||
return StatReader(f)
|
||||
}
|
||||
|
||||
// StatReader reads the header information from r and returns an Info
|
||||
// struct with the version and file type.
|
||||
//
|
||||
// It returns an error if any.
|
||||
//
|
||||
// As a special case, it returns ErrBadMagic if the file does not start with
|
||||
// the magic bytes. This can be used to detect if the file is not a GGUF
|
||||
// file.
|
||||
func StatReader(r io.ReadSeeker) (Info, error) {
|
||||
if _, err := r.Seek(0, 0); err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
f, err := ReadFile(r)
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
info := Info{Version: f.Version()}
|
||||
for m, err := range f.Metadata {
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
if m.Key == "general.file_type" {
|
||||
if m.Type != ValueTypeUint32 {
|
||||
return Info{}, fmt.Errorf("unexpected type for metadata key %q: %v, want %v", m.Key, m.Type, ValueTypeUint32)
|
||||
}
|
||||
info.FileType = m.FileType()
|
||||
}
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
type File struct {
|
||||
version uint32
|
||||
numMetaValues uint64
|
||||
numTensors uint64
|
||||
|
||||
gr *ggufReader
|
||||
}
|
||||
|
||||
// ReadFile reads header information from r and returns a File, ready for
|
||||
// iteration over Metadata and Tensors.
|
||||
func ReadFile(r io.Reader) (*File, error) {
|
||||
f, err := readFile(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *File) Version() int {
|
||||
return int(f.version)
|
||||
}
|
||||
|
||||
// Metadata iterates over the metadata in the file. It must be exhausted
|
||||
// before calling Tensors.
|
||||
//
|
||||
// It is not resumable.
|
||||
func (f *File) Metadata(yield func(MetaEntry, error) bool) {
|
||||
var n int
|
||||
for range f.numMetaValues {
|
||||
meta, err := f.gr.readMetaEntry()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading metadata entry %d: %w", n, err)
|
||||
yield(MetaEntry{}, err)
|
||||
return
|
||||
}
|
||||
if !yield(meta, nil) {
|
||||
return
|
||||
}
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
// Tensors iterates over the tensors in the file. It must only be called
|
||||
// after exhausting the metadata iterator.
|
||||
//
|
||||
// It is not resumable.
|
||||
func (f *File) Tensors(yield func(TensorInfo, error) bool) {
|
||||
var last TensorInfo
|
||||
for range f.numTensors {
|
||||
info, err := f.gr.readTensorInfo()
|
||||
|
||||
// If the last tensor had a valid offset, yield it.
|
||||
//
|
||||
// NOTE: No tensor should have an offset of 0 because the
|
||||
// offset is the start of the tensor data which is always
|
||||
// afer the magic bytes, version, numMetaValues, and
|
||||
// numTensors, which MUST all be non-zero bytes as per the
|
||||
// GGUF spec.
|
||||
if last.Offset > 0 {
|
||||
if !yield(last, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
yield(TensorInfo{}, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Tensor data does not include size, so we need to
|
||||
// calculate it based on the offset of the previous tensor
|
||||
// offset to the current.
|
||||
offset0 := last.Offset
|
||||
last = info
|
||||
last.Size = info.Offset - offset0
|
||||
}
|
||||
if last.Offset > 0 {
|
||||
yield(last, nil)
|
||||
}
|
||||
}
|
||||
|
||||
var magicBytes = []byte{0x47, 0x47, 0x55, 0x46}
|
||||
|
||||
func readFile(r io.Reader) (*File, error) {
|
||||
gr := &ggufReader{r: &reader{r: r}}
|
||||
magic, err := gr.next(4)
|
||||
if err != nil {
|
||||
return nil, errors.Join(err, ErrBadMagic)
|
||||
}
|
||||
if !bytes.Equal(magic, magicBytes) {
|
||||
return nil, ErrBadMagic
|
||||
}
|
||||
version, err := gr.readUint32()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != 3 {
|
||||
return nil, fmt.Errorf("%w: %d", ErrUnsupportedVersion, version)
|
||||
}
|
||||
numTensors, err := gr.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
numMetaValues, err := gr.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info := &File{
|
||||
version: version,
|
||||
|
||||
numMetaValues: numMetaValues,
|
||||
numTensors: numTensors,
|
||||
gr: gr,
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
@@ -1,345 +0,0 @@
package gguf

import (
	"errors"
	"io"
	"strings"
	"testing"

	"kr.dev/diff"
)

func TestStat(t *testing.T) {
	cases := []struct {
		name     string
		data     string
		wantInfo Info
		wantErr  error
	}{
		{
			name:    "empty",
			wantErr: ErrBadMagic,
		},
		{
			name:    "bad magic",
			data:    "\xBB\xAA\xDD\x00",
			wantErr: ErrBadMagic,
		},
		{
			name: "bad version",
			data: string(magicBytes) +
				"\x02\x00\x00\x00" + // version
				"",
			wantErr: ErrUnsupportedVersion,
		},
		{
			name: "valid general.file_type",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// general.file_type key
				"\x11\x00\x00\x00\x00\x00\x00\x00" + // key length
				"general.file_type" + // key
				"\x04\x00\x00\x00" + // type (uint32)
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value
				"",
			wantInfo: Info{
				Version:  3,
				FileType: 1,
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			info, err := StatReader(strings.NewReader(tt.data))
			if tt.wantErr != nil {
				if !errors.Is(err, tt.wantErr) {
					t.Fatalf("err = %v; want %q", err, tt.wantErr)
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			diff.Test(t, t.Errorf, info, tt.wantInfo)
		})
	}
}

func TestReadInfo(t *testing.T) {
	cases := []struct {
		name string
		data string

		wantMeta      []MetaEntry
		wantTensor    []TensorInfo
		wantReadErr   error
		wantMetaErr   error
		wantTensorErr error
		wantInfo      Info
	}{
		{
			name:        "empty",
			wantReadErr: io.ErrUnexpectedEOF,
		},
		{
			name:        "bad magic",
			data:        "\xBB\xAA\xDD\x00",
			wantReadErr: ErrBadMagic,
		},
		{
			name: "bad version",
			data: string(magicBytes) +
				"\x02\x00\x00\x00" + // version
				"",
			wantReadErr: ErrUnsupportedVersion,
		},
		{
			name: "no metadata or tensors",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
				"",
			wantReadErr: nil,
		},
		{
			name: "good metadata",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
				"K" + // key
				"\x08\x00\x00\x00" + // type (string)
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
				"VV" + // string value
				"",
			wantMeta: []MetaEntry{
				{Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}},
			},
		},
		{
			name: "good metadata with multiple values",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// MetaEntry 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
				"x" + // key
				"\x08\x00\x00\x00" + // type (string)
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
				"XX" + // string value

				// MetaEntry 2
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
				"y" + // key
				"\x04\x00\x00\x00" + // type (uint32)
				"\x99\x88\x77\x66" + // uint32 value
				"",
			wantMeta: []MetaEntry{
				{Key: "x", Type: ValueTypeString, Values: []MetaValue{{
					Type:  ValueTypeString,
					Value: []byte("XX"),
				}}},
				{Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{
					Type:  ValueTypeUint32,
					Value: []byte{0x99, 0x88, 0x77, 0x66},
				}}},
			},
		},
		{
			name: "negative string length in meta key",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
				"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length
				"K" + // key
				"\x08\x00\x00\x00" + // type (string)
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
				"VV" + // string value
				"",
			wantMetaErr: ErrMangled,
		},

		// Tensor tests
		{
			name: "good tensor",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// Tensor 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +

				// dimensions
				"\x01\x00\x00\x00" + // dimensions length
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]

				"\x03\x00\x00\x00" + // type (q4_1)
				"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
				"",
			wantTensor: []TensorInfo{
				{
					Name:       "t",
					Dimensions: []uint64{1},
					Type:       TypeQ4_1,
					Offset:     256,
					Size:       256,
				},
			},
		},
		{
			name: "too many dimensions",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// Tensor 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +

				"\x00\x00\x00\x01" + // dimensions length
				"",
			wantTensorErr: ErrMangled,
		},
		{
			name: "size computed",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// Tensor 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +
				"\x01\x00\x00\x00" + // dimensions length
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
				"\x03\x00\x00\x00" + // type (q4_1)
				"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset

				// Tensor 2
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +
				"\x01\x00\x00\x00" + // dimensions length
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
				"\x03\x00\x00\x00" + // type (q4_1)
				"\x00\x03\x00\x00\x00\x00\x00\x00" + // offset
				"",
			wantTensor: []TensorInfo{
				{
					Name:       "t",
					Dimensions: []uint64{1},
					Type:       TypeQ4_1,
					Offset:     256,
					Size:       256,
				},
				{
					Name:       "t",
					Dimensions: []uint64{1},
					Type:       TypeQ4_1,
					Offset:     768,
					Size:       512,
				},
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			f, err := ReadFile(strings.NewReader(tt.data))
			if err != nil {
				if !errors.Is(err, tt.wantReadErr) {
					t.Fatalf("unexpected ReadFile error: %v", err)
				}
				return
			}

			var got []MetaEntry
			for meta, err := range f.Metadata {
				if !errors.Is(err, tt.wantMetaErr) {
					t.Fatalf("err = %v; want %v", err, tt.wantMetaErr)
				}
				if err != nil {
					return
				}
				got = append(got, meta)
			}
			diff.Test(t, t.Errorf, got, tt.wantMeta)

			var gotT []TensorInfo
			for tinfo, err := range f.Tensors {
				if !errors.Is(err, tt.wantTensorErr) {
					t.Fatalf("err = %v; want %v", err, tt.wantTensorErr)
				}
				if err != nil {
					return
				}
				gotT = append(gotT, tinfo)
			}
			diff.Test(t, t.Errorf, gotT, tt.wantTensor)
		})
	}
}

func FuzzReadInfo(f *testing.F) {
	f.Add(string(magicBytes))
	f.Add(string(magicBytes) +
		"\x03\x00\x00\x00" + // version
		"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
		"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
		"")
	f.Add(string(magicBytes) +
		"\x03\x00\x00\x00" + // version
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
		"K" + // key
		"\x08\x00\x00\x00" + // type (string)
		"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
		"VV" + // string value
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
		"t" +
		"\x01\x00\x00\x00" + // dimensions length
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
		"\x03\x00\x00\x00" + // type (q4_1)
		"\x05\x00\x00\x00\x00\x00\x00\x00" + // offset
		"")

	f.Fuzz(func(t *testing.T, data string) {
		gf, err := ReadFile(strings.NewReader(data))
		if err != nil {
			t.Logf("ReadFile error: %v", err)
			t.Skip()
		}
		for _, err := range gf.Metadata {
			if err != nil {
				t.Logf("metadata error: %v", err)
				t.Skip()
			}
		}
		for tinfo, err := range gf.Tensors {
			if err != nil {
				t.Logf("tensor error: %v", err)
				t.Skip()
			}
			if tinfo.Offset <= 0 {
				t.Logf("invalid tensor offset: %+v", tinfo)
				t.Skip()
			}
			if tinfo.Size <= 0 {
				t.Logf("invalid tensor size: %+v", tinfo)
				t.Skip()
			}
		}
	})
}
@@ -1,195 +0,0 @@
package gguf

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"iter"
)

type ggufReader struct {
	r *reader
	n int
}

func (r *ggufReader) readMetaEntry() (MetaEntry, error) {
	key, err := r.readString()
	if err != nil {
		return MetaEntry{}, err
	}
	typ, err := r.readValueType()
	if err != nil {
		return MetaEntry{}, err
	}
	var values []MetaValue
	for v, err := range r.readMetaValues(typ) {
		if err != nil {
			err = fmt.Errorf("(key=%q type=%s): %w", key, typ, err)
			return MetaEntry{}, err
		}
		values = append(values, v)
	}
	return MetaEntry{
		Key:    string(key),
		Type:   typ,
		Values: values,
	}, nil
}

func (r *ggufReader) readMetaValue(typ ValueType) (MetaValue, error) {
	var value []byte
	var err error
	switch typ {
	case ValueTypeUint8, ValueTypeInt8:
		value, err = r.next(1)
	case ValueTypeUint16, ValueTypeInt16:
		value, err = r.next(2)
	case ValueTypeUint32, ValueTypeInt32, ValueTypeFloat32:
		value, err = r.next(4)
	case ValueTypeUint64, ValueTypeInt64, ValueTypeFloat64:
		value, err = r.next(8)
	case ValueTypeBool:
		value, err = r.next(1)
	case ValueTypeString:
		value, err = r.readString()
	case ValueTypeArray:
		err = errors.New("nested arrays are not supported")
	default:
		err = fmt.Errorf("unsupported metadata type: %d", typ)
	}
	if err != nil {
		return MetaValue{}, err
	}
	return MetaValue{
		Type:  typ,
		Value: bytes.Clone(value),
	}, nil
}

func (r *ggufReader) readMetaValues(typ ValueType) iter.Seq2[MetaValue, error] {
	return func(yield func(MetaValue, error) bool) {
		if typ == ValueTypeArray {
			atyp, err := r.readValueType()
			if err != nil {
				err = fmt.Errorf("invalid type: %w", err)
				yield(MetaValue{}, err)
				return
			}
			n, err := r.readUint64()
			if err != nil {
				err = fmt.Errorf("invalid length: %w", err)
				yield(MetaValue{}, err)
				return
			}
			for i := range n {
				v, err := r.readMetaValue(atyp)
				if err != nil {
					err = fmt.Errorf("invalid entry (type=%s) %d: %w", atyp, i, err)
					yield(MetaValue{}, err)
					return
				}
				if !yield(v, nil) {
					return
				}
			}
		} else {
			v, err := r.readMetaValue(typ)
			if err != nil {
				err = fmt.Errorf("error reading metadata value: %w", err)
				yield(MetaValue{}, err)
				return
			}
			yield(v, nil)
		}
	}
}

func (r *ggufReader) readValueType() (ValueType, error) {
	typ, err := r.readUint32()
	return ValueType(typ), err
}

func (r *ggufReader) readTensorInfo() (TensorInfo, error) {
	name, err := r.readString()
	if err != nil {
		return TensorInfo{}, err
	}

	numDimensions, err := r.readUint32()
	if err != nil {
		return TensorInfo{}, err
	}
	if numDimensions > MaxDimensions {
		return TensorInfo{}, fmt.Errorf("%w: dimensions length (%d) exceeds %d", ErrMangled, numDimensions, MaxDimensions)
	}

	dims := make([]uint64, numDimensions)
	for i := range dims {
		d, err := r.readUint64()
		if err != nil {
			return TensorInfo{}, err
		}
		dims[i] = d
	}
	typ, err := r.readUint32()
	if err != nil {
		return TensorInfo{}, err
	}
	offset, err := r.readUint64()
	if err != nil {
		return TensorInfo{}, err
	}

	// TODO(bmizerany): check offset is multiple of ALIGNMENT
	return TensorInfo{
		Name:       string(name),
		Dimensions: dims,
		Type:       Type(typ),
		Offset:     offset,
	}, nil
}

func (r *ggufReader) next(n int) ([]byte, error) {
	if n < 0 {
		return nil, errors.Join(fmt.Errorf("invalid read length: %d", n), ErrMangled)
	}
	w := r.r.window()
	for len(w) < n {
		if r.r.extend() == 0 {
			return nil, io.ErrUnexpectedEOF
		}
		w = r.r.window()
	}
	r.r.release(n)
	r.n += n
	return w[:n], nil
}

func (r *ggufReader) readString() ([]byte, error) {
	n, err := r.readUint64()
	if err != nil {
		return nil, err
	}
	// TODO(bmizerany): limit max string length
	return r.next(int(n))
}

func (r *ggufReader) readUint32() (uint32, error) {
	b, err := r.next(4)
	if err != nil {
		return 0, err
	}
	n := binary.LittleEndian.Uint32(b)
	return n, nil
}

func (r *ggufReader) readUint64() (uint64, error) {
	b, err := r.next(8)
	if err != nil {
		return 0, err
	}
	n := binary.LittleEndian.Uint64(b)
	return n, nil
}
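For reference, the byte layout readMetaEntry consumes for a single string-typed entry; this mirrors the "good metadata" test case above, with all integers little-endian:

// exampleMetaEntry is illustrative only; it is the wire form of the
// entry {Key: "K", Type: string, Value: "VV"} used in the tests.
var exampleMetaEntry = "" +
	"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length (uint64) = 1
	"K" + // key bytes
	"\x08\x00\x00\x00" + // value type (uint32) = string
	"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length (uint64) = 2
	"VV" // string bytes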
@@ -1,70 +0,0 @@
package gguf

import "io"

// A reader implements a sliding window over an io.Reader.
type reader struct {
	data   []byte
	offset int
	r      io.Reader
	err    error
}

// release discards n bytes from the front of the window.
func (b *reader) release(n int) {
	b.offset += n
}

// window returns the current window.
// The window is invalidated by calls to release or extend.
func (b *reader) window() []byte {
	return b.data[b.offset:]
}

// tuning constants for reader.extend.
const (
	newBufferSize = 8 << 10
	minReadSize   = newBufferSize >> 2
)

// extend extends the window with data from the underlying reader.
func (b *reader) extend() int {
	if b.err != nil {
		return 0
	}

	remaining := len(b.data) - b.offset
	if remaining == 0 {
		b.data = b.data[:0]
		b.offset = 0
	}
	if cap(b.data)-len(b.data) >= minReadSize {
		// nothing to do, enough space exists between len and cap.
	} else if cap(b.data)-remaining >= minReadSize {
		// buffer has enough space if we move the data to the front.
		b.compact()
	} else {
		// otherwise, we must allocate a new, larger buffer.
		b.grow()
	}
	remaining += b.offset
	n, err := b.r.Read(b.data[remaining:cap(b.data)])
	// reduce length to the existing plus the data we read.
	b.data = b.data[:remaining+n]
	b.err = err
	return n
}

// grow grows the buffer, moving the active data to the front.
func (b *reader) grow() {
	buf := make([]byte, max(cap(b.data)*2, newBufferSize))
	copy(buf, b.data[b.offset:])
	b.data = buf
	b.offset = 0
}

// compact moves the active data to the front of the buffer.
func (b *reader) compact() {
	copy(b.data, b.data[b.offset:])
	b.offset = 0
}
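A sketch of the aliasing contract, in hypothetical caller code: bytes obtained from window are only valid until the next extend or release, which is why readMetaValue above clones them with bytes.Clone before returning.

// copy4 (hypothetical) returns a stable copy of the next four bytes.
func copy4(b *reader) ([]byte, error) {
	for len(b.window()) < 4 {
		if b.extend() == 0 {
			return nil, io.ErrUnexpectedEOF
		}
	}
	w := b.window()
	out := append([]byte(nil), w[:4]...) // copy: w is invalidated by the next release/extend
	b.release(4)
	return out, nil
}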
@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x02\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xfd\xff\xff\xff\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00K\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00VV\x01\x00\x00\x00\x00\\x00\\x00\\x00\\x00")
@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x0000000000000000000000000\xe5")
@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x0000000000\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x000\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x0000\x01\x00\x00\x00\x00\x00\x00\x000\x01\x00\x001\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\a")
x/model/file.go
@@ -1,132 +0,0 @@
// Package model implements the File and Name types for working with and
// representing Modelfiles and model Names.
//
// The Name type should be used when working with model names, and the File
// type should be used when working with Modelfiles.
package model

import (
	"bufio"
	"io"
	"iter"
	"strings"
)

type ParamPragma struct {
	Key   string
	Value string
}

type MessagePragma struct {
	Role    string
	Content string
}

type File struct {
	// From is a required pragma that specifies the source of the model,
	// either on disk, or by reference (see model.ParseName).
	From string

	// Optional
	Params   []ParamPragma
	Template string
	System   string
	Adapter  string
	Messages []MessagePragma

	License string
}

type FileError struct {
	Pragma  string
	Message string
}

func (e *FileError) Error() string {
	return e.Pragma + ": " + e.Message
}

// Pragma represents a single pragma in a Modelfile.
type Pragma struct {
	// The pragma name
	Name string

	// Args contains the user-defined arguments for the pragma. If no
	// arguments were provided, it is nil.
	Args []string
}

func (p Pragma) Arg(i int) string {
	if i >= len(p.Args) {
		return ""
	}
	return p.Args[i]
}

func FilePragmas(r io.Reader) iter.Seq2[Pragma, error] {
	return func(yield func(Pragma, error) bool) {
		sc := bufio.NewScanner(r)
		for sc.Scan() {
			line := sc.Text()

			// TODO(bmizerany): set a max num fields/args to
			// prevent mem bloat
			args := strings.Fields(line)
			if len(args) == 0 {
				continue
			}

			p := Pragma{
				Name: strings.ToUpper(args[0]),
			}
			if p.Name == "MESSAGE" {
				// handle special case where message content
				// is space separated on the _rest_ of the
				// line like: `MESSAGE user Is Ontario in
				// Canada?`
				panic("TODO")
			}
			if len(args) > 1 {
				p.Args = args[1:]
			}
			if !yield(p, nil) {
				return
			}
		}
		if sc.Err() != nil {
			yield(Pragma{}, sc.Err())
		}
	}
}

func ParseFile(r io.Reader) (File, error) {
	var f File
	for p, err := range FilePragmas(r) {
		if err != nil {
			return File{}, err
		}
		switch p.Name {
		case "FROM":
			f.From = p.Arg(0)
		case "PARAMETER":
			f.Params = append(f.Params, ParamPragma{
				Key:   strings.ToLower(p.Arg(0)),
				Value: p.Arg(1),
			})
		case "TEMPLATE":
			f.Template = p.Arg(0)
		case "SYSTEM":
			f.System = p.Arg(0)
		case "ADAPTER":
			f.Adapter = p.Arg(0)
		case "MESSAGE":
			f.Messages = append(f.Messages, MessagePragma{
				Role:    p.Arg(0),
				Content: p.Arg(1),
			})
		case "LICENSE":
			f.License = p.Arg(0)
		}
	}
	return f, nil
}
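A small usage sketch; the Modelfile text is illustrative, and fmt is assumed to be imported by the caller:

func exampleParseFile() {
	src := "FROM mistral\nPARAMETER temperature 0.7\nLICENSE MIT\n"
	f, err := ParseFile(strings.NewReader(src))
	if err != nil {
		panic(err)
	}
	fmt.Println(f.From)    // mistral
	fmt.Println(f.Params)  // [{temperature 0.7}]
	fmt.Println(f.License) // MIT
}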
@@ -1,89 +0,0 @@
package oweb

import (
	"cmp"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"

	"github.com/ollama/ollama/x/client/ollama"
)

func Missing(field string) error {
	return &ollama.Error{
		Status:  400,
		Code:    "missing",
		Field:   field,
		Message: fmt.Sprintf("%s is required", field),
	}
}

func Invalid(field, value, format string, args ...any) error {
	return &ollama.Error{
		Status:  400,
		Code:    "invalid",
		Field:   field,
		Value:   value,
		Message: fmt.Sprintf(format, args...),
	}
}

// Convenience errors
var (
	ErrNotFound         = &ollama.Error{Status: 404, Code: "not_found"}
	ErrInternal         = &ollama.Error{Status: 500, Code: "internal_error"}
	ErrMethodNotAllowed = &ollama.Error{Status: 405, Code: "method_not_allowed"}
)

type HandlerFunc func(w http.ResponseWriter, r *http.Request) error

func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) {
	if err := h(w, r); err != nil {
		// TODO: take a slog.Logger
		log.Printf("error: %v", err)
		var oe *ollama.Error
		if !errors.As(err, &oe) {
			oe = ErrInternal
		}
		oe.Status = cmp.Or(oe.Status, 400)
		w.WriteHeader(oe.Status)
		if err := EncodeJSON(w, oe); err != nil {
			log.Printf("error encoding error: %v", err)
		}
	}
}

func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) {
	v, err := DecodeJSON[T](r)

	// Handle common JSON syntax errors
	var e *json.SyntaxError
	if errors.As(err, &e) {
		return nil, Invalid(field, "", e.Error())
	}

	// Handle type errors
	var se *json.UnmarshalTypeError
	if errors.As(err, &se) {
		return nil, Invalid(field, se.Value, "expected %s", se.Type)
	}

	// Return v and err as they were.
	return v, err
}

func DecodeJSON[T any](r io.Reader) (*T, error) {
	var v *T
	if err := json.NewDecoder(r).Decode(&v); err != nil {
		var zero T
		return &zero, err
	}
	return v, nil
}

func EncodeJSON(w io.Writer, v any) error {
	return json.NewEncoder(w).Encode(v)
}
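A sketch of how HandlerFunc and Serve wire into net/http; handleHello and routes are hypothetical, not part of the package:

func handleHello(w http.ResponseWriter, r *http.Request) error {
	if r.Method != "GET" {
		return ErrMethodNotAllowed
	}
	return EncodeJSON(w, map[string]string{"message": "hello"})
}

func routes(mux *http.ServeMux) {
	mux.HandleFunc("/v1/hello", func(w http.ResponseWriter, r *http.Request) {
		// Serve logs the error and writes the JSON error body.
		Serve(handleHello, w, r)
	})
}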
@@ -1,46 +0,0 @@
package apitype

import "encoding/json"

type Manifest struct {
	Layers []Layer `json:"layers"`
}

type CompletePart struct {
	URL  string `json:"url"` // contains partNumber and uploadId from server
	ETag string `json:"etag"`
}

type Layer struct {
	Digest    string `json:"digest"`
	MediaType string `json:"mediaType"`
	Size      int64  `json:"size"`
}

type PushRequest struct {
	Name     string          `json:"ref"`
	Manifest json.RawMessage `json:"manifest"`

	// CompleteParts is a list of upload parts that the client uploaded
	// in a previous push.
	CompleteParts []CompletePart `json:"part_uploads"`
}

type Requirement struct {
	Digest string `json:"digest"`
	Offset int64  `json:"offset"`
	Size   int64  `json:"size"`

	// URL is the url to PUT the layer to.
	//
	// Clients must include it as the URL, along with the ETag from the
	// response headers of the PUT request, in the CompleteParts field
	// of the next push request.
	URL string `json:"url"`
}

type PushResponse struct {
	// Requirements is a list of digests that the client needs to push
	// before repushing the manifest.
	Requirements []Requirement `json:"requirements,omitempty"`
}
@@ -1,102 +0,0 @@
package registry

import (
	"cmp"
	"context"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/ollama/ollama/x/client/ollama"
	"github.com/ollama/ollama/x/registry/apitype"
)

type Client struct {
	BaseURL    string
	HTTPClient *http.Client
}

func (c *Client) oclient() *ollama.Client {
	return (*ollama.Client)(c)
}

type PushParams struct {
	CompleteParts []apitype.CompletePart
}

// Push pushes a manifest to the server.
func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushParams) ([]apitype.Requirement, error) {
	p = cmp.Or(p, &PushParams{})
	// TODO(bmizerany): backoff
	v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
		Name:          ref,
		Manifest:      manifest,
		CompleteParts: p.CompleteParts,
	})
	if err != nil {
		return nil, err
	}
	return v.Requirements, nil
}

func PushLayer(ctx context.Context, body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
	var zero apitype.CompletePart
	if off < 0 {
		return zero, errors.New("off must be >= 0")
	}

	file := io.NewSectionReader(body, off, n)
	req, err := http.NewRequestWithContext(ctx, "PUT", url, file)
	if err != nil {
		return zero, err
	}
	req.ContentLength = n

	// TODO(bmizerany): take content type param
	req.Header.Set("Content-Type", "text/plain")

	if n >= 0 {
		req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
	}

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return zero, err
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		e := parseS3Error(res)
		return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
	}
	etag := strings.Trim(res.Header.Get("ETag"), `"`)
	cp := apitype.CompletePart{
		URL:  url,
		ETag: etag,
		// TODO(bmizerany): checksum
	}
	return cp, nil
}

type s3Error struct {
	XMLName   xml.Name `xml:"Error"`
	Code      string   `xml:"Code"`
	Message   string   `xml:"Message"`
	Resource  string   `xml:"Resource"`
	RequestId string   `xml:"RequestId"`
}

func (e *s3Error) Error() string {
	return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message)
}

// parseS3Error parses an XML error response from S3.
func parseS3Error(res *http.Response) error {
	var se *s3Error
	if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
		return err
	}
	return se
}
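The intended two-round flow, sketched; pushAll is hypothetical, and feeding every layer from one ReaderAt is a simplification (TestPushBasic below exercises the real thing):

func pushAll(ctx context.Context, c *Client, ref string, manifest []byte, blob io.ReaderAt) error {
	// Round 1: learn which byte ranges the server still needs.
	reqs, err := c.Push(ctx, ref, manifest, nil)
	if err != nil {
		return err
	}
	// Upload each required range, collecting the completed parts.
	var done []apitype.CompletePart
	for _, rq := range reqs {
		cp, err := PushLayer(ctx, blob, rq.URL, rq.Offset, rq.Size)
		if err != nil {
			return err
		}
		done = append(done, cp)
	}
	// Round 2: re-push with the completed parts; the server commits
	// the manifest once no requirements remain.
	_, err = c.Push(ctx, ref, manifest, &PushParams{CompleteParts: done})
	return err
}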
@@ -1,256 +0,0 @@
// Package registry implements an Ollama registry client and server.
package registry

import (
	"bytes"
	"cmp"
	"context"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"path"
	"strconv"
	"strings"
	"time"

	"github.com/minio/minio-go/v7"
	"github.com/minio/minio-go/v7/pkg/credentials"
	"github.com/ollama/ollama/x/client/ollama"
	"github.com/ollama/ollama/x/model"
	"github.com/ollama/ollama/x/oweb"
	"github.com/ollama/ollama/x/registry/apitype"
	"github.com/ollama/ollama/x/utils/upload"
)

// Defaults
const (
	DefaultUploadChunkSize = 50 * 1024 * 1024
)

type Server struct {
	UploadChunkSize int64 // default is DefaultUploadChunkSize
	S3Client        *minio.Client
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if err := s.serveHTTP(w, r); err != nil {
		log.Printf("error: %v", err) // TODO(bmizerany): take a slog.Logger
		var e *ollama.Error
		if !errors.As(err, &e) {
			e = oweb.ErrInternal
		}
		w.WriteHeader(cmp.Or(e.Status, 400))
		if err := oweb.EncodeJSON(w, e); err != nil {
			log.Printf("error encoding error: %v", err)
		}
	}
}

func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
	switch r.URL.Path {
	case "/v1/push":
		return s.handlePush(w, r)
	case "/v1/pull":
		return s.handlePull(w, r)
	default:
		return oweb.ErrNotFound
	}
}

func (s *Server) uploadChunkSize() int64 {
	return cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize)
}

func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
	const bucketTODO = "test"
	const minimumMultipartSize = 5 * 1024 * 1024 // S3 spec

	pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body)
	if err != nil {
		return err
	}

	mp := model.ParseName(pr.Name)
	if !mp.IsComplete() {
		return oweb.Invalid("name", pr.Name, "must be complete")
	}

	m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest))
	if err != nil {
		return err
	}

	mcc := &minio.Core{Client: s.s3()}
	// TODO(bmizerany): complete uploads before stats for any with ETag

	type completeParts struct {
		key   string
		parts []minio.CompletePart
	}

	completePartsByUploadID := make(map[string]completeParts)
	for _, mcp := range pr.CompleteParts {
		// parse the URL
		u, err := url.Parse(mcp.URL)
		if err != nil {
			return err
		}

		q := u.Query()

		// Check if this is a part upload; if not, skip it.
		uploadID := q.Get("uploadId")
		if uploadID == "" {
			// not a part upload
			continue
		}

		// PartNumber is required
		queryPartNumber := q.Get("partNumber")
		partNumber, err := strconv.Atoi(queryPartNumber)
		if err != nil {
			return oweb.Invalid("partNumber", queryPartNumber, "")
		}
		if partNumber < 1 {
			return oweb.Invalid("partNumber", queryPartNumber, "must be >= 1")
		}

		// ETag is required
		if mcp.ETag == "" {
			return oweb.Missing("etag")
		}

		cp := completePartsByUploadID[uploadID]
		cp.key = u.Path
		cp.parts = append(cp.parts, minio.CompletePart{
			PartNumber: partNumber,
			ETag:       mcp.ETag,
		})
		completePartsByUploadID[uploadID] = cp
	}

	for uploadID, cp := range completePartsByUploadID {
		var zeroOpts minio.PutObjectOptions

		// TODO: gross fix!!!!!!!!!!!!!!!
		key := strings.TrimPrefix(cp.key, "/"+bucketTODO+"/")

		fmt.Printf("Completing multipart upload %s %s %v\n", bucketTODO, key, cp.parts)
		_, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, key, uploadID, cp.parts, zeroOpts)
		if err != nil {
			var e minio.ErrorResponse
			if errors.As(err, &e) && e.Code == "NoSuchUpload" {
				return oweb.Invalid("uploadId", uploadID, "")
			}
			return err
		}
	}

	var requirements []apitype.Requirement
	for _, l := range m.Layers {
		// TODO(bmizerany): do in parallel
		if l.Size == 0 {
			continue
		}

		// TODO(bmizerany): "global" throttle of rate of transfer
		pushed, err := s.statObject(r.Context(), l.Digest)
		if err != nil {
			println("ERROR:", "statObject", err)
			return err
		}
		if !pushed {
			key := path.Join("blobs", l.Digest)
			if l.Size < minimumMultipartSize {
				// single part upload
				fmt.Printf("Presigning single %s %s\n", bucketTODO, key)
				signedURL, err := s.s3().PresignedPutObject(r.Context(), bucketTODO, key, 15*time.Minute)
				if err != nil {
					println("ERROR:", "presign single", err)
					return err
				}
				requirements = append(requirements, apitype.Requirement{
					Digest: l.Digest,
					Size:   l.Size,
					URL:    signedURL.String(),
				})
			} else {
				uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{})
				if err != nil {
					return err
				}
				fmt.Printf("Presigning multi %s %s %s\n", bucketTODO, key, uploadID)
				for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) {
					const timeToStartUpload = 15 * time.Minute

					signedURL, err := s.s3().Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{
						"partNumber": []string{strconv.Itoa(partNumber)},
						"uploadId":   []string{uploadID},
					})
					if err != nil {
						println("ERROR:", "presign multi", err)
						return err
					}

					requirements = append(requirements, apitype.Requirement{
						Digest: l.Digest,
						Offset: c.Offset,
						Size:   c.N,
						URL:    signedURL.String(),
					})
				}
			}
		}
	}

	if len(requirements) == 0 {
		// Commit the manifest
		body := bytes.NewReader(pr.Manifest)
		path := path.Join("manifests", path.Join(mp.Parts()...))
		_, err := s.s3().PutObject(r.Context(), bucketTODO, path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{})
		if err != nil {
			return err
		}
	}

	return oweb.EncodeJSON(w, &apitype.PushResponse{Requirements: requirements})
}

func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
	// lookup manifest
	panic("TODO")
}

func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, err error) {
	// HEAD the object
	path := path.Join("blobs", digest)
	_, err = s.s3().StatObject(ctx, "test", path, minio.StatObjectOptions{})
	if err != nil {
		if isNoSuchKey(err) {
			err = nil
		}
		return false, err
	}
	return true, nil
}

func isNoSuchKey(err error) bool {
	var e minio.ErrorResponse
	return errors.As(err, &e) && e.Code == "NoSuchKey"
}

func (s *Server) s3() *minio.Client {
	if s.S3Client != nil {
		return s.S3Client
	}
	s3, err := minio.New("localhost:9000", &minio.Options{
		Creds:  credentials.NewStaticV4("minioadmin", "minioadmin", ""),
		Secure: false,
	})
	if err != nil {
		panic(err)
	}
	return s3
}
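For reference, the shape of the presigned part URL that handlePush expects back from clients; the URL values here are illustrative:

func exampleParsePartURL() {
	u, err := url.Parse("http://localhost:9000/test/blobs/sha256-3?partNumber=2&uploadId=abc123")
	if err != nil {
		panic(err)
	}
	q := u.Query()
	fmt.Println(q.Get("uploadId"), q.Get("partNumber"), u.Path)
	// abc123 2 /test/blobs/sha256-3
}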
@@ -1,473 +0,0 @@
package registry

import (
	"bufio"
	"bytes"
	"cmp"
	"context"
	"crypto/sha256"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http/httptest"
	"net/url"
	"os"
	"os/exec"
	"strconv"
	"syscall"
	"testing"
	"time"

	"github.com/minio/minio-go/v7"
	"github.com/minio/minio-go/v7/pkg/credentials"
	"github.com/ollama/ollama/x/registry/apitype"
	"github.com/ollama/ollama/x/utils/backoff"
	"github.com/ollama/ollama/x/utils/upload"
	"kr.dev/diff"
)

// const ref = "registry.ollama.ai/x/y:latest+Z"
// const manifest = `{
//	"layers": [
//		{"digest": "sha256-1", "size": 1},
//		{"digest": "sha256-2", "size": 2},
//		{"digest": "sha256-3", "size": 3}
//	]
// }`

// ts := newTestServer(t)
// ts.pushNotOK(ref, `{}`, &ollama.Error{
//	Status:  400,
//	Code:    "invalid",
//	Message: "name must be fully qualified",
// })

// ts.push(ref, `{
//	"layers": [
//		{"digest": "sha256-1", "size": 1},
//		{"digest": "sha256-2", "size": 2},
//		{"digest": "sha256-3", "size": 3}
//	]
// }`)

type tWriter struct {
	t *testing.T
}

func (w tWriter) Write(p []byte) (n int, err error) {
	w.t.Logf("%s", p)
	return len(p), nil
}

func TestPushBasic(t *testing.T) {
	const MB = 1024 * 1024

	mc := startMinio(t, true)

	defer func() {
		mcc := &minio.Core{Client: mc}
		// fail if there are any incomplete uploads
		for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
			t.Errorf("incomplete: %v", x)
		}
	}()

	const ref = "registry.ollama.ai/x/y:latest+Z"

	// Upload two small layers and one large layer that will
	// trigger a multipart upload.
	manifest := []byte(`{
		"layers": [
			{"digest": "sha256-1", "size": 1},
			{"digest": "sha256-2", "size": 2},
			{"digest": "sha256-3", "size": 11000000}
		]
	}`)

	hs := httptest.NewServer(&Server{
		S3Client:        mc,
		UploadChunkSize: 5 * MB,
	})
	t.Cleanup(hs.Close)
	c := &Client{BaseURL: hs.URL}

	requirements, err := c.Push(context.Background(), ref, manifest, nil)
	if err != nil {
		t.Fatal(err)
	}

	if len(requirements) < 3 {
		t.Errorf("expected at least 3 requirements; got %d", len(requirements))
		t.Logf("requirements: %v", requirements)
	}

	var uploaded []apitype.CompletePart
	for i, r := range requirements {
		t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)

		cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size)
		if err != nil {
			t.Fatal(err)
		}
		uploaded = append(uploaded, cp)
	}

	requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
		CompleteParts: uploaded,
	})
	if err != nil {
		t.Fatal(err)
	}
	if len(requirements) != 0 {
		t.Errorf("unexpected requirements: %v", requirements)
	}

	var paths []string
	keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
		Recursive: true,
	})
	for k := range keys {
		paths = append(paths, k.Key)
	}

	t.Logf("paths: %v", paths)

	diff.Test(t, t.Errorf, paths, []string{
		"blobs/sha256-1",
		"blobs/sha256-2",
		"blobs/sha256-3",
		"manifests/registry.ollama.ai/x/y/latest/Z",
	})

	obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	defer obj.Close()

	var gotM apitype.Manifest
	if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
		t.Fatal(err)
	}

	diff.Test(t, t.Errorf, gotM, apitype.Manifest{
		Layers: []apitype.Layer{
			{Digest: "sha256-1", Size: 1},
			{Digest: "sha256-2", Size: 2},
			{Digest: "sha256-3", Size: 11000000},
		},
	})

	// checksum the blobs
	for i, l := range gotM.Layers {
		obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
		if err != nil {
			t.Fatal(err)
		}
		defer obj.Close()

		info, err := obj.Stat()
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)

		if msg := checkABCs(obj, int(l.Size)); msg != "" {
			t.Errorf("[%d] %s", i, msg)
		}
	}
}

// TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of
// presigning a multipart upload, uploading the parts, and completing the
// upload. It is for future reference and should not be deleted. This flow
// is tricky and if we get it wrong in our server, we can refer back to this
// as a "back to basics" test/reference.
func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) {
	t.Skip("skipping reference test; unskip when needed")

	mc := startMinio(t, true)
	mcc := &minio.Core{Client: mc}

	uploadID, err := mcc.NewMultipartUpload(context.Background(), "test", "theKey", minio.PutObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}

	var completed []minio.CompletePart
	const size int64 = 10 * 1024 * 1024
	const chunkSize = 5 * 1024 * 1024

	for partNumber, c := range upload.Chunks(size, chunkSize) {
		u, err := mcc.Presign(context.Background(), "PUT", "test", "theKey", 15*time.Minute, url.Values{
			"partNumber": {strconv.Itoa(partNumber)},
			"uploadId":   {uploadID},
		})
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		t.Logf("[partNumber=%d]: %v", partNumber, u)

		var body abcReader
		cp, err := PushLayer(context.Background(), &body, u.String(), c.Offset, c.N)
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		t.Logf("completed part: %v", cp)

		// behave like server here (don't cheat and use partNumber);
		// instead get partNumber from the URL
		retPartNumber, err := strconv.Atoi(u.Query().Get("partNumber"))
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}

		completed = append(completed, minio.CompletePart{
			PartNumber: retPartNumber,
			ETag:       cp.ETag,
		})
	}

	defer func() {
		// fail if there are any incomplete uploads
		for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
			t.Errorf("incomplete: %v", x)
		}
	}()

	info, err := mcc.CompleteMultipartUpload(context.Background(), "test", "theKey", uploadID, completed, minio.PutObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("completed: %v", info)

	// Check key in bucket
	obj, err := mc.GetObject(context.Background(), "test", "theKey", minio.GetObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	defer obj.Close()

	h := sha256.New()
	if _, err := io.Copy(h, obj); err != nil {
		t.Fatal(err)
	}
	gotSum := h.Sum(nil)

	h.Reset()
	var body abcReader
	if _, err := io.CopyN(h, &body, size); err != nil {
		t.Fatal(err)
	}
	wantSum := h.Sum(nil)

	if !bytes.Equal(gotSum, wantSum) {
		t.Errorf("got sum = %x; want %x", gotSum, wantSum)
	}
}

func availableAddr() string {
	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		panic(err)
	}
	defer l.Close()
	return l.Addr().String()
}

// Tracing is "experimental" and may be removed in the future. I can't get it
// to work consistently, but I'm leaving it in for now.
func startMinio(t *testing.T, trace bool) *minio.Client {
	t.Helper()

	// Trace is enabled by setting the OLLAMA_MINIO_TRACE environment or
	// explicitly setting trace to true.
	trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "")

	dir := t.TempDir()

	t.Cleanup(func() {
		// TODO(bmizerany): trim temp dir based on dates so that
		// future runs may be able to inspect results for some time.
	})

	waitAndMaybeLogError := func(cmd *exec.Cmd) {
		if err := cmd.Wait(); err != nil {
			var e *exec.ExitError
			if errors.As(err, &e) {
				if e.Exited() {
					return
				}
				t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
				t.Logf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode())
				t.Logf("startMinio: %s exited: %v", cmd.Path, e.Exited())
			} else {
				if errors.Is(err, context.Canceled) {
					return
				}
				t.Logf("startMinio: %s exit error: %v", cmd.Path, err)
			}
		}
	}

	// Cancel must be called first, so wait to add it to the Cleanup
	// stack as the last cleanup.
	ctx, cancel := context.WithCancel(context.Background())
	deadline, ok := t.Deadline()
	if ok {
		ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond))
	}

	t.Logf(">> minio: minio server %s", dir)

	addr := availableAddr()
	cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir)
	cmd.Env = os.Environ()
	cmd.WaitDelay = 3 * time.Second
	cmd.Cancel = func() error {
		return cmd.Process.Signal(syscall.SIGQUIT)
	}
	if err := cmd.Start(); err != nil {
		t.Fatalf("startMinio: %v", err)
	}
	t.Cleanup(func() {
		cancel()
		waitAndMaybeLogError(cmd)
	})

	mc, err := minio.New(addr, &minio.Options{
		Creds:  credentials.NewStaticV4("minioadmin", "minioadmin", ""),
		Secure: false,
	})
	if err != nil {
		t.Fatalf("startMinio: %v", err)
	}

	// wait for the server to start, with exponential backoff
	for _, err := range backoff.Upto(ctx, 1*time.Second) {
		if err != nil {
			t.Fatalf("startMinio: %v", err)
		}
		// try listing buckets to see if the server is up
		if _, err := mc.ListBuckets(ctx); err == nil {
			break
		}
		t.Logf("startMinio: server is offline; retrying")
	}

	if trace {
		cmd := exec.CommandContext(ctx, "mc", "admin", "trace", "--verbose", "test")
		cmd.Env = append(os.Environ(),
			"MC_HOST_test=http://minioadmin:minioadmin@"+addr,
		)
		cmd.WaitDelay = 3 * time.Second
		cmd.Cancel = func() error {
			return cmd.Process.Signal(syscall.SIGQUIT)
		}

		stdout, err := cmd.StdoutPipe()
		if err != nil {
			t.Fatalf("startMinio: %v", err)
		}
		if err := cmd.Start(); err != nil {
			t.Fatalf("startMinio: %v", err)
		}

		doneLogging := make(chan struct{})
		sc := bufio.NewScanner(stdout)
		go func() {
			defer close(doneLogging)

			// Scan lines until the process exits.
			for sc.Scan() {
				t.Logf("startMinio: mc trace: %s", sc.Text())
			}
			_ = sc.Err() // ignore (not important)
		}()
		t.Cleanup(func() {
			cancel()
			waitAndMaybeLogError(cmd)

			// Make sure we do not log after the test exits, to
			// avoid a panic.
			<-doneLogging
		})
	}

	if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
		t.Fatalf("startMinio: %v", err)
	}
	return mc
}

// contextForTest returns a context that is canceled when the test deadline,
// if any, is reached. The returned doneLogging function should be called
// after all Log/Error/Fatalf calls are done before the test returns.
func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
	done := make(chan struct{})
	deadline, ok := t.Deadline()
	if !ok {
		return context.Background(), func() {}
	}

	ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond))
	t.Cleanup(func() {
		cancel()
		<-done
	})
	return ctx, func() { close(done) }
}

// abcReader repeats the lowercase alphabet infinitely.
type abcReader struct {
	pos int
}

const theABCs = "abcdefghijklmnopqrstuvwxyz"

func (r *abcReader) Read(p []byte) (n int, err error) {
	for i := range p {
		p[i] = theABCs[r.pos]
		r.pos++
		if r.pos == len(theABCs) {
			r.pos = 0
		}
	}
	return len(p), nil
}

func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) {
	for i := range p {
		p[i] = theABCs[(off+int64(i))%int64(len(theABCs))]
	}
	return len(p), nil
}

func checkABCs(r io.Reader, size int) (reason string) {
	h := sha256.New()
	n, err := io.CopyN(h, &abcReader{}, int64(size))
	if err != nil {
		return err.Error()
	}
	if n != int64(size) {
		panic("short read; should not happen")
	}
	want := h.Sum(nil)
	h = sha256.New()
	n, err = io.Copy(h, r)
	if err != nil {
		return err.Error()
	}
	if n != int64(size) {
		return fmt.Sprintf("got len(r) = %d; want %d", n, size)
	}
	got := h.Sum(nil)
	if !bytes.Equal(got, want) {
		return fmt.Sprintf("got sum = %x; want %x", got, want)
	}
	return ""
}
@@ -1,4 +0,0 @@
package empty

// Message is a placeholder type used when encoding json messages.
type Message struct{}
@@ -1,12 +0,0 @@
package they

import (
	"net/http"
	"strings"
)

// Want returns true if the request method is method and the request path
// starts with pathPrefix.
func Want(r *http.Request, method string, pathPrefix string) bool {
	return r.Method == method && strings.HasPrefix(r.URL.Path, pathPrefix)
}
@@ -1,58 +0,0 @@
package backoff

import (
	"context"
	"errors"
	"iter"
	"math/rand"
	"time"
)

// Errors
var (
	// ErrMaxAttempts is not used by backoff but is available for use by
	// callers that want to signal that a maximum number of retries has
	// been exceeded. This should eliminate the need for callers to invent
	// their own error.
	ErrMaxAttempts = errors.New("max retries exceeded")
)

// Upto implements a backoff strategy that yields nil errors until the
// context is canceled or yield returns false.
//
// The backoff strategy is a simple exponential backoff with a maximum
// backoff of maxBackoff. The backoff is randomized between 0.5-1.5 times
// the current backoff, in order to prevent accidental "thundering herd"
// problems.
func Upto(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
	var n int
	return func(yield func(int, error) bool) {
		for {
			if ctx.Err() != nil {
				yield(n, ctx.Err())
				return
			}

			n++

			// n^2 backoff timer is a little smoother than the
			// common choice of 2^n.
			d := time.Duration(n*n) * 10 * time.Millisecond
			if d > maxBackoff {
				d = maxBackoff
			}
			// Randomize the delay between 0.5-1.5 x msec, in order
			// to prevent accidental "thundering herd" problems.
			d = time.Duration(float64(d) * (rand.Float64() + 0.5))
			t := time.NewTimer(d)
			select {
			case <-ctx.Done():
				t.Stop()
			case <-t.C:
				if !yield(n, nil) {
					return
				}
			}
		}
	}
}
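Typical caller-side use, sketched; tryOnce is hypothetical, and context, time, and backoff are assumed imports. startMinio in the registry tests uses the same pattern to wait for the server:

func waitReady(ctx context.Context, tryOnce func() error) error {
	for n, err := range backoff.Upto(ctx, time.Second) {
		if err != nil {
			return err // context canceled or deadline exceeded
		}
		if tryOnce() == nil {
			return nil
		}
		if n >= 10 {
			return backoff.ErrMaxAttempts
		}
	}
	return nil
}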
@@ -1,29 +0,0 @@
package upload

import (
	"iter"

	"golang.org/x/exp/constraints"
)

type Chunk[I constraints.Integer] struct {
	Offset I
	N      I
}

// Chunks yields a sequence of a part number and a Chunk. The Chunk is the
// offset and size of the chunk. The last chunk may be smaller than chunkSize
// if size is not a multiple of chunkSize.
//
// The first part number is 1 and increases monotonically.
func Chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, Chunk[I]] {
	return func(yield func(int, Chunk[I]) bool) {
		var n int
		for off := I(0); off < size; off += chunkSize {
			n++
			if !yield(n, Chunk[I]{off, min(chunkSize, size-off)}) {
				return
			}
		}
	}
}
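Worked example: an 11,000,000-byte layer with the 5 MiB chunk size used in TestPushBasic splits into three parts, the last one short (fmt is an assumed import here):

func exampleChunks() {
	for n, c := range Chunks(11_000_000, 5<<20) {
		fmt.Println(n, c.Offset, c.N)
	}
	// Output:
	// 1 0 5242880
	// 2 5242880 5242880
	// 3 10485760 514240
}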
@@ -1,44 +0,0 @@
package upload

import (
	"testing"

	"kr.dev/diff"
)

func TestChunks(t *testing.T) {
	const size = 101
	const chunkSize = 10
	var got []Chunk[int]
	var lastN int
	for n, c := range Chunks(size, chunkSize) {
		if n != lastN+1 {
			t.Errorf("n = %d; want %d", n, lastN+1)
		}
		got = append(got, c)
		lastN = n
	}

	want := []Chunk[int]{
		{0, 10},
		{10, 10},
		{20, 10},
		{30, 10},
		{40, 10},
		{50, 10},
		{60, 10},
		{70, 10},
		{80, 10},
		{90, 10},
		{100, 1},
	}

	diff.Test(t, t.Errorf, got, want)
}

func TestChunksBreak(t *testing.T) {
	for range Chunks(1, 1) {
		return
	}
	t.Fatal("expected break")
}