Compare commits

...

33 Commits

Author SHA1 Message Date
Blake Mizerany
889c5b1a75 ... 2024-05-01 15:56:59 -07:00
Blake Mizerany
844217bcf1 weaving in 2024-05-01 15:32:24 -07:00
Blake Mizerany
8afe873f17 ... 2024-04-30 16:53:47 -07:00
Blake Mizerany
7ba71c3989 ... 2024-04-30 16:45:36 -07:00
Blake Mizerany
fdef9a0eb2 ... 2024-04-30 16:45:36 -07:00
Blake Mizerany
b9f74ff3d6 types/model: reintroduce Digest (#4065) 2024-04-30 16:38:03 -07:00
jmorganca
fcf4d60eee llm: add back check for empty token cache 2024-04-30 17:38:44 -04:00
jmorganca
e33d5c2dbc update llama.cpp commit to 952d03d 2024-04-30 17:31:20 -04:00
Jeffrey Morgan
18d9a7e1f1 update llama.cpp submodule to f364eb6 (#4060) 2024-04-30 17:25:39 -04:00
Michael
8488388cbd Update README.md 2024-04-30 15:45:56 -04:00
Blake Mizerany
588901f449 types/model: reduce Name.Filepath allocs from 5 to 2 (#4039) 2024-04-30 11:09:19 -07:00
Bruce MacDonald
0a7fdbe533 prompt to display and add local ollama keys to account (#3717)
- return descriptive error messages when unauthorized to create blob or push a model
- display the local public key associated with the request that was denied
2024-04-30 11:02:08 -07:00
Christian Frantzen
5950c176ca Update langchainpy.md (#4037)
Updated the code a bit
2024-04-29 23:19:06 -04:00
Daniel Hiltgen
23d23409a0 Update llama.cpp (#4036)
* Bump llama.cpp to b2761

* Adjust types for bump
2024-04-29 23:18:48 -04:00
Patrick Devine
9009bedf13 better checking for OLLAMA_HOST variable (#3661) 2024-04-29 19:14:07 -04:00
Daniel Hiltgen
d4ac57e240 Merge pull request #4035 from dhiltgen/fix_relative_paths
Fix relative path lookup
2024-04-29 16:08:06 -07:00
Daniel Hiltgen
7b59d1770f Fix relative path lookup 2024-04-29 16:00:08 -07:00
Jeffrey Morgan
95ead8ffba Restart server on failure when running Windows app (#3985)
* app: restart server on failure

* fix linter

* address comments

* refactor log directory creation to be where logs are written

* check all log dir creation errors
2024-04-29 10:07:52 -04:00
Jeffrey Morgan
7aa08a77ca llm: dont cap context window limit to training context window (#3988) 2024-04-29 10:07:30 -04:00
Blake Mizerany
7e432cdfac types/model: remove old comment (#4020) 2024-04-28 20:52:26 -07:00
Jeffrey Morgan
586672f490 fix copying model to itself (#4019) 2024-04-28 23:47:49 -04:00
Daniel Hiltgen
b03408de74 Merge pull request #3972 from hmartinez82/win_arm64
Add support for building on Windows ARM64
2024-04-28 14:52:58 -07:00
Daniel Hiltgen
1e6a28bf5b Merge pull request #4009 from dhiltgen/cpu_concurrency
Fix concurrency for CPU mode
2024-04-28 14:20:27 -07:00
Daniel Hiltgen
d6e3b64582 Fix concurrency for CPU mode
Prior refactoring passes accidentally removed the logic to bypass VRAM
checks for CPU loads.  This adds that back, along with test coverage.

This also fixes loaded map access in the unit test to be behind the mutex which was
likely the cause of various flakes in the tests.
2024-04-28 13:42:39 -07:00
Blake Mizerany
114c932a8e types/model: allow _ as starter character in Name parts (#3991) 2024-04-27 21:24:52 -07:00
Jeffrey Morgan
7f7103de06 mac: update setup command to llama3 (#3986) 2024-04-27 22:52:10 -04:00
Blake Mizerany
c631a9c726 types/model: relax name length constraint from 2 to 1 (#3984) 2024-04-27 17:58:41 -07:00
Blake Mizerany
8fd9e56804 types/structs: drop unused structs package (#3981) 2024-04-27 14:06:11 -07:00
Hernan Martinez
8a65717f55 Do not build AVX runners on ARM64 2024-04-26 23:55:32 -06:00
Hernan Martinez
6d3152a98a Use architecture specific folders in installer script 2024-04-26 23:35:16 -06:00
Hernan Martinez
b438d485f1 Use architecture specific folders in the generate script 2024-04-26 23:34:12 -06:00
Hernan Martinez
204349b17b Use architecture specific folders in the build script 2024-04-26 23:26:03 -06:00
Hernan Martinez
86e67fc4a9 Add import declaration for windows,arm64 to llm.go 2024-04-26 23:23:53 -06:00
28 changed files with 1178 additions and 140 deletions

View File

@@ -51,7 +51,7 @@ Here are some example models that can be downloaded:
| ------------------ | ---------- | ----- | ------------------------------ |
| Llama 3 | 8B | 4.7GB | `ollama run llama3` |
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
| Phi-3 | 3,8B | 2.3GB | `ollama run phi3` |
| Phi-3 | 3.8B | 2.3GB | `ollama run phi3` |
| Mistral | 7B | 4.1GB | `ollama run mistral` |
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
| Starling | 7B | 4.1GB | `ollama run starling-lm` |

View File

@@ -18,6 +18,7 @@ import (
"net/url"
"os"
"runtime"
"strconv"
"strings"
"github.com/ollama/ollama/format"
@@ -57,12 +58,36 @@ func checkError(resp *http.Response, body []byte) error {
// If the variable is not specified, a default ollama host and port will be
// used.
func ClientFromEnvironment() (*Client, error) {
ollamaHost, err := GetOllamaHost()
if err != nil {
return nil, err
}
return &Client{
base: &url.URL{
Scheme: ollamaHost.Scheme,
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
},
http: http.DefaultClient,
}, nil
}
type OllamaHost struct {
Scheme string
Host string
Port string
}
func GetOllamaHost() (OllamaHost, error) {
defaultPort := "11434"
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
hostVar := os.Getenv("OLLAMA_HOST")
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
scheme, hostport, ok := strings.Cut(hostVar, "://")
switch {
case !ok:
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
scheme, hostport = "http", hostVar
case scheme == "http":
defaultPort = "80"
case scheme == "https":
@@ -82,12 +107,14 @@ func ClientFromEnvironment() (*Client, error) {
}
}
return &Client{
base: &url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
http: http.DefaultClient,
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
return OllamaHost{}, ErrInvalidHostPort
}
return OllamaHost{
Scheme: scheme,
Host: host,
Port: port,
}, nil
}

View File

@@ -1,6 +1,12 @@
package api
import "testing"
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestClientFromEnvironment(t *testing.T) {
type testCase struct {
@@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
}
})
}
hostTestCases := map[string]*testCase{
"empty": {value: "", expect: "127.0.0.1:11434"},
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
"only port": {value: ":1234", expect: ":1234"},
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
"hostname": {value: "example.com", expect: "example.com:11434"},
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
"zero port": {value: ":0", expect: ":0"},
"too large port": {value: ":66000", err: ErrInvalidHostPort},
"too small port": {value: ":-1", err: ErrInvalidHostPort},
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
}
for k, v := range hostTestCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
oh, err := GetOllamaHost()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
}
if err == nil {
host := net.JoinHostPort(oh.Host, oh.Port)
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
}
})
}
}

View File

@@ -309,6 +309,7 @@ func (m *Metrics) Summary() {
}
var ErrInvalidOpts = errors.New("invalid options")
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct

View File

@@ -43,37 +43,36 @@ func getCLIFullPath(command string) string {
return command
}
func SpawnServer(ctx context.Context, command string) (chan int, error) {
done := make(chan int)
logDir := filepath.Dir(ServerLogFile)
_, err := os.Stat(logDir)
if errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(logDir, 0o755); err != nil {
return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
}
}
func start(ctx context.Context, command string) (*exec.Cmd, error) {
cmd := getCmd(ctx, getCLIFullPath(command))
// send stdout and stderr to a file
stdout, err := cmd.StdoutPipe()
if err != nil {
return done, fmt.Errorf("failed to spawn server stdout pipe %s", err)
return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
return done, fmt.Errorf("failed to spawn server stderr pipe %s", err)
}
stdin, err := cmd.StdinPipe()
if err != nil {
return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
}
// TODO - rotation
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
if err != nil {
return done, fmt.Errorf("failed to create server log %w", err)
return nil, fmt.Errorf("failed to create server log: %w", err)
}
logDir := filepath.Dir(ServerLogFile)
_, err = os.Stat(logDir)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
}
if err := os.MkdirAll(logDir, 0o755); err != nil {
return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
}
}
go func() {
defer logFile.Close()
io.Copy(logFile, stdout) //nolint:errcheck
@@ -117,19 +116,33 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
// run the command and wait for it to finish
if err := cmd.Start(); err != nil {
return done, fmt.Errorf("failed to start server %w", err)
return nil, fmt.Errorf("failed to start server %w", err)
}
if cmd.Process != nil {
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
}
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
return cmd, nil
}
func SpawnServer(ctx context.Context, command string) (chan int, error) {
done := make(chan int)
go func() {
// Keep the server running unless we're shuttind down the app
crashCount := 0
for {
slog.Info("starting server...")
cmd, err := start(ctx, command)
if err != nil {
crashCount++
slog.Error(fmt.Sprintf("failed to start server %s", err))
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
continue
}
cmd.Wait() //nolint:errcheck
stdin.Close()
var code int
if cmd.ProcessState != nil {
code = cmd.ProcessState.ExitCode()
@@ -143,15 +156,12 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
default:
crashCount++
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
time.Sleep(500 * time.Millisecond)
if err := cmd.Start(); err != nil {
slog.Error(fmt.Sprintf("failed to restart server %s", err))
// Keep trying, but back off if we keep failing
time.Sleep(time.Duration(crashCount) * time.Second)
}
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
break
}
}
}()
return done, nil
}

View File

@@ -88,8 +88,8 @@ DialogFontSize=12
[Files]
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-amd64\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-amd64\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
#if DirExists("..\dist\windows-amd64\rocm")

View File

@@ -10,12 +10,44 @@ import (
"log/slog"
"os"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
)
const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
}
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
if err != nil {
return "", err
}
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
return strings.TrimSpace(string(publicKey)), nil
}
func NewNonce(r io.Reader, length int) (string, error) {
nonce := make([]byte, length)
if _, err := io.ReadFull(r, nonce); err != nil {
@@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
}
func Sign(ctx context.Context, bts []byte) (string, error) {
home, err := os.UserHomeDir()
keyPath, err := keyPath()
if err != nil {
return "", err
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))

View File

@@ -0,0 +1,95 @@
package apitype
import (
"cmp"
"encoding/json"
"log/slog"
"net/url"
"slices"
)
type Manifest struct {
Layers []*Layer `json:"layers"`
}
type CompletePart struct {
URL string `json:"url"` // contains partNumber and uploadId from server
ETag string `json:"etag"`
}
func queryFromString(s string) url.Values {
u, err := url.Parse(s)
if err != nil {
return nil
}
return u.Query()
}
func (cp *CompletePart) Compare(o *CompletePart) int {
qa := queryFromString(cp.URL)
qb := queryFromString(o.URL)
return cmp.Or(
cmp.Compare(qa.Get("partNumber"), qb.Get("partNumber")),
cmp.Compare(qa.Get("uploadId"), qb.Get("uploadId")),
cmp.Compare(cp.ETag, o.ETag),
)
}
func SortCompleteParts(a []*CompletePart) {
slices.SortFunc(a, (*CompletePart).Compare)
}
type Layer struct {
Digest string `json:"digest"`
MediaType string `json:"mediaType"`
Size int64 `json:"size"`
// If present, URL is a remote location of the layer for fetching.
URL string `json:"url,omitempty"`
}
func (l *Layer) LogValue() slog.Value {
return slog.GroupValue(
slog.String("digest", l.Digest),
slog.String("mediaType", l.MediaType),
slog.Int64("size", l.Size),
slog.String("url", l.URL),
)
}
type PushRequest struct {
Name string `json:"ref"`
Manifest json.RawMessage `json:"manifest,omitempty"`
// Parts is a list of upload parts that the client upload in the previous
// push.
CompleteParts []*CompletePart `json:"part_uploads"`
}
type Need struct {
Digest string `json:"digest"`
Start int64 `json:"start"`
End int64 `json:"end"`
// URL is the url to PUT the layer to.
//
// Clients must include it as the URL, along with the ETag in the
// response headers from the PUT request, in the next push request
// in the Uploaded field.
URL string `json:"url"`
}
type PushResponse struct {
// Needs is a list of digests that the client needs to push before
// repushing the manifest.
Needs []*Need `json:"requirements,omitempty"`
}
type PullResponse struct {
// Name is the name of the model being pulled.
Name string `json:"name"`
// Manifest is the manifest of the model being pulled.
Manifest *Manifest `json:"manifest"`
}

421
client/registry/registry.go Normal file
View File

@@ -0,0 +1,421 @@
package registry
import (
"cmp"
"context"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"iter"
"log/slog"
"net/http"
"net/url"
"os"
"sync"
"github.com/ollama/ollama/client/ollama"
"github.com/ollama/ollama/client/registry/apitype"
"github.com/ollama/ollama/types/model"
"golang.org/x/exp/constraints"
"golang.org/x/sync/errgroup"
)
// Errors
var (
ErrLayerNotFound = errors.New("layer not found")
)
type Client struct {
BaseURL string
Logger *slog.Logger
// NameFill is a string that is used to fill in the missing parts of
// a name when it is not fully qualified. It is used to make a name
// fully qualified before pushing or pulling it. The default is
// "registry.ollama.ai/library/_:latest".
//
// Most users can ignore this field. It is intended for use by
// clients that need to push or pull names to registries other than
// registry.ollama.ai, and for testing.
NameFill string
}
func (c *Client) log() *slog.Logger {
return cmp.Or(c.Logger, slog.Default())
}
func (c *Client) oclient() *ollama.Client {
return &ollama.Client{
BaseURL: c.BaseURL,
}
}
type ReadAtSeekCloser interface {
io.ReaderAt
io.Seeker
io.Closer
}
type Cache interface {
// LayerFile returns the absolute file path to the layer file for
// the given model digest.
//
// If the digest is invalid, or the layer does not exist, the empty
// string is returned.
LayerFile(model.Digest) string
// OpenLayer opens the layer file for the given model digest and
// returns it, or an if any. The caller is responsible for closing
// the returned file.
OpenLayer(model.Digest) (ReadAtSeekCloser, error)
// PutLayerFile moves the layer file at fromPath to the cache for
// the given model digest. It is a hack intended to short circuit a
// file copy operation.
//
// The file returned is expected to exist for the lifetime of the
// cache.
//
// TODO(bmizerany): remove this; find a better way. Once we move
// this into a build package, we should be able to get rid of this.
PutLayerFile(_ model.Digest, fromPath string) error
// SetManifestData sets the provided manifest data for the given
// model name. If the manifest data is empty, the manifest is
// removed. If the manifeest exists, it is overwritten.
//
// It is an error to call SetManifestData with a name that is not
// complete.
SetManifestData(model.Name, []byte) error
// ManifestData returns the manifest data for the given model name.
//
// If the name incomplete, or the manifest does not exist, the empty
// string is returned.
ManifestData(name model.Name) []byte
}
// Pull pulls the manifest for name, and downloads any of its required
// layers that are not already in the cache. It returns an error if any part
// of the process fails, specifically:
func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
mn := parseNameFill(name, c.NameFill)
if !mn.IsFullyQualified() {
return fmt.Errorf("ollama: pull: invalid name: %s", name)
}
log := c.log().With("name", name)
pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
if err != nil {
return fmt.Errorf("ollama: pull: %w: %s", err, name)
}
if pr.Manifest == nil || len(pr.Manifest.Layers) == 0 {
return fmt.Errorf("ollama: pull: invalid manifest: %s: no layers found", name)
}
// download required layers we do not already have
for _, l := range pr.Manifest.Layers {
d, err := model.ParseDigest(l.Digest)
if err != nil {
return fmt.Errorf("ollama: reading manifest: %w: %s", err, l.Digest)
}
if cache.LayerFile(d) != "" {
continue
}
err = func() error {
log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size)
log.Debug("starting download")
// TODO(bmizerany): stop using temp which might not
// be on same device as cache.... instead let cache
// give us a place to store parts...
tmpFile, err := os.CreateTemp("", "ollama-download-")
if err != nil {
return err
}
defer func() {
tmpFile.Close()
os.Remove(tmpFile.Name()) // in case we fail before committing
}()
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(8) // TODO(bmizerany): make this configurable
// TODO(bmizerany): make chunk size configurable
const chunkSize = 50 * 1024 * 1024 // 50MB
chunks(l.Size, chunkSize)(func(_ int, rng chunkRange[int64]) bool {
g.Go(func() (err error) {
defer func() {
if err == nil {
return
}
safeURL := redactAmzSignature(l.URL)
err = fmt.Errorf("%w: %s %s bytes=%s: %s", err, pr.Name, l.Digest, rng, safeURL)
}()
log.Debug("downloading", "range", rng)
// TODO(bmizerany): retry
// TODO(bmizerany): use real http client
// TODO(bmizerany): resumable
// TODO(bmizerany): multipart download
req, err := http.NewRequestWithContext(ctx, "GET", l.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", "bytes="+rng.String())
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
log.Debug("unexpected non-2XX status code", "status", res.StatusCode)
return fmt.Errorf("unexpected status code fetching layer: %d", res.StatusCode)
}
if res.ContentLength != rng.Size() {
return fmt.Errorf("unexpected content length: %d", res.ContentLength)
}
w := io.NewOffsetWriter(tmpFile, rng.Start)
_, err = io.Copy(w, res.Body)
return err
})
return true
})
if err := g.Wait(); err != nil {
return err
}
tmpFile.Close() // release our hold on the file before moving it
return cache.PutLayerFile(d, tmpFile.Name())
}()
if err != nil {
return fmt.Errorf("ollama: pull: %w", err)
}
}
// do not store the presigned URLs in the cache
for i := range pr.Manifest.Layers {
pr.Manifest.Layers[i].URL = ""
}
data, err := json.Marshal(pr.Manifest)
if err != nil {
return err
}
// TODO(bmizerany): remove dep on model.Name
return cache.SetManifestData(mn, data)
}
type nopSeeker struct {
io.Reader
}
func (nopSeeker) Seek(int64, int) (int64, error) {
return 0, nil
}
func parseNameFill(name, fill string) model.Name {
fill = cmp.Or(fill, "bllamo.com/library/_:latest")
f := model.ParseNameBare(fill)
if !f.IsFullyQualified() {
panic(fmt.Errorf("invalid fill: %q", fill))
}
return model.Merge(model.ParseNameBare(name), f)
}
// Push pushes a manifest to the server and responds to the server's
// requests for layer uploads, if any, and finally commits the manifest for
// name. It returns an error if any part of the process fails, specifically:
//
// If the server requests layers not found in the cache, ErrLayerNotFound is
// returned.
func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
mn := parseNameFill(name, c.NameFill)
if !mn.IsFullyQualified() {
return fmt.Errorf("ollama: push: invalid name: %s", name)
}
manifest := cache.ManifestData(mn)
if len(manifest) == 0 {
return fmt.Errorf("manifest not found: %s", name)
}
var mu sync.Mutex
var completed []*apitype.CompletePart
push := func() (*apitype.PushResponse, error) {
v, err := ollama.Do[*apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
Name: name,
Manifest: manifest,
CompleteParts: completed,
})
if err != nil {
return nil, fmt.Errorf("Do: %w", err)
}
return v, nil
}
pr, err := push()
if err != nil {
return err
}
var g errgroup.Group
for _, need := range pr.Needs {
g.Go(func() error {
nd, err := model.ParseDigest(need.Digest)
if err != nil {
return fmt.Errorf("ParseDigest: %w: %s", err, need.Digest)
}
f, err := cache.OpenLayer(nd)
if err != nil {
return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest)
}
defer f.Close()
c.log().Info("pushing layer", "digest", need.Digest, "start", need.Start, "end", need.End)
cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End)
if err != nil {
return fmt.Errorf("PushLayer: %w: %s", err, need.Digest)
}
mu.Lock()
completed = append(completed, cp)
mu.Unlock()
return nil
})
}
if err := g.Wait(); err != nil {
return fmt.Errorf("Push: Required: %w", err)
}
if len(completed) > 0 {
pr, err := push()
if err != nil {
return err
}
if len(pr.Needs) > 0 {
var errs []error
for _, r := range pr.Needs {
errs = append(errs, fmt.Errorf("Push: server failed to find part: %q", r.Digest))
}
return errors.Join(errs...)
}
}
return cache.SetManifestData(mn, manifest)
}
func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) {
if start < 0 || end < start {
return nil, errors.New("start must satisfy 0 <= start <= end")
}
file := io.NewSectionReader(body, start, end-start+1)
req, err := http.NewRequest("PUT", url, file)
if err != nil {
return nil, err
}
req.ContentLength = end - start + 1
// TODO(bmizerany): take content type param
req.Header.Set("Content-Type", "text/plain")
if start != 0 || end != 0 {
req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", start, end))
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != 200 {
e := parseS3Error(res)
return nil, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
}
cp := &apitype.CompletePart{
URL: url,
ETag: res.Header.Get("ETag"),
// TODO(bmizerany): checksum
}
return cp, nil
}
type s3Error struct {
XMLName xml.Name `xml:"Error"`
Code string `xml:"Code"`
Message string `xml:"Message"`
Resource string `xml:"Resource"`
RequestId string `xml:"RequestId"`
}
func (e *s3Error) Error() string {
return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message)
}
// parseS3Error parses an XML error response from S3.
func parseS3Error(res *http.Response) error {
var se *s3Error
if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
return err
}
return se
}
// TODO: replace below by using upload pkg after we have rangefunc; until
// then, we need to keep this free of rangefunc for now.
type chunkRange[I constraints.Integer] struct {
// Start is the byte offset of the chunk.
Start I
// End is the byte offset of the last byte in the chunk.
End I
}
func (c chunkRange[I]) Size() I {
return c.End - c.Start + 1
}
func (c chunkRange[I]) String() string {
return fmt.Sprintf("%d-%d", c.Start, c.End)
}
func (c chunkRange[I]) LogValue() slog.Value {
return slog.StringValue(c.String())
}
// Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset
// and size of the chunk. The last chunk may be smaller than chunkSize if size is
// not a multiple of chunkSize.
//
// The first part number is 1 and increases monotonically.
func chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, chunkRange[I]] {
return func(yield func(int, chunkRange[I]) bool) {
var n int
for off := I(0); off < size; off += chunkSize {
n++
if !yield(n, chunkRange[I]{
Start: off,
End: off + min(chunkSize, size-off) - 1,
}) {
return
}
}
}
}
func redactAmzSignature(s string) string {
u, err := url.Parse(s)
if err != nil {
return ""
}
q := u.Query()
q.Set("X-Amz-Signature", "REDACTED")
u.RawQuery = q.Encode()
return u.String()
}

View File

@@ -32,10 +32,13 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -357,6 +360,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts)
}
func errFromUnknownKey(unknownKeyErr error) error {
// find SSH public key in the error message
sshKeyPattern := `ssh-\w+ [^\s"]+`
re := regexp.MustCompile(sshKeyPattern)
matches := re.FindStringSubmatch(unknownKeyErr.Error())
if len(matches) > 0 {
serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey()
if err != nil {
return unknownKeyErr
}
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
if err != nil {
return unknownKeyErr
}
localPubKey = strings.TrimSpace(string(svcPubKey))
}
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
if serverPubKey != localPubKey {
return unknownKeyErr
}
var msg strings.Builder
msg.WriteString(unknownKeyErr.Error())
msg.WriteString("\n\nYour ollama key is:\n")
msg.WriteString(localPubKey)
msg.WriteString("\nAdd your key at:\n")
msg.WriteString("https://ollama.com/settings/keys")
return errors.New(msg.String())
}
return unknownKeyErr
}
func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -404,6 +448,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
request := api.PushRequest{Name: args[0], Insecure: insecure}
if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil {
spinner.Stop()
}
if strings.Contains(err.Error(), "access denied") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
}
host := model.ParseName(args[0]).Host
isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// re-throw an error with a more user-friendly message
return errFromUnknownKey(err)
}
return err
}
@@ -831,19 +889,17 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}
func RunServer(cmd *cobra.Command, _ []string) error {
host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
// retrieve the OLLAMA_HOST environment variable
ollamaHost, err := api.GetOllamaHost()
if err != nil {
host, port = "127.0.0.1", "11434"
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
host = ip.String()
}
return err
}
if err := initializeKeypair(); err != nil {
return err
}
ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
if err != nil {
return err
}

View File

@@ -17,10 +17,12 @@ Let's start by asking a simple question that we can get an answer to from the **
Then we can create a model and ask the question:
```python
from langchain.llms import Ollama
ollama = Ollama(base_url='http://localhost:11434',
model="llama2")
print(ollama("why is the sky blue"))
from langchain_community.llms import Ollama
ollama = Ollama(
base_url='http://localhost:11434',
model="llama3"
)
print(ollama.invoke("why is the sky blue"))
```
Notice that we are defining the model and the base URL for Ollama.

View File

@@ -40,7 +40,7 @@ func PayloadsDir() (string, error) {
}
var paths []string
for _, root := range []string{appExe, cwd} {
for _, root := range []string{filepath.Dir(appExe), cwd} {
paths = append(paths,
filepath.Join(root),
filepath.Join(root, "windows-"+runtime.GOARCH),

View File

@@ -1032,7 +1032,7 @@ struct llama_server_context
slot.has_next_token = false;
}
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok))
{
slot.stopped_eos = true;
slot.has_next_token = false;
@@ -1144,12 +1144,15 @@ struct llama_server_context
res.result_json = json
{
{"content", tkn.text_to_send},
{"stop", false},
{"slot_id", slot.id},
{"multimodal", multimodal}
};
if (!llama_token_is_eog(model, tkn.tok)) {
res.result_json["content"] = tkn.text_to_send;
}
if (slot.sparams.n_probs > 0)
{
std::vector<completion_token_output> probs_output = {};
@@ -2644,18 +2647,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
if (strncmp(sep, "int:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = std::atol(sep);
kvo.val_i64 = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) {
sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.float_value = std::atof(sep);
kvo.val_f64 = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true;
kvo.val_bool = true;
} else if (std::strcmp(sep, "false") == 0) {
kvo.bool_value = false;
kvo.val_bool = false;
} else {
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
invalid_param = true;

View File

@@ -42,7 +42,7 @@ function init_vars {
"-DLLAMA_NATIVE=off"
)
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
$script:ARCH = "amd64" # arm not yet supported.
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
md "$script:DIST_BASE" -ea 0 > $null
if ($env:CGO_CFLAGS -contains "-g") {
@@ -213,11 +213,11 @@ function build_static() {
}
}
function build_cpu() {
function build_cpu($gen_arch) {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
# remaining llama.cpp builds use MSVC
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu"
$script:distDir="$script:DIST_BASE\cpu"
write-host "Building LCD CPU"
@@ -349,11 +349,15 @@ if ($($args.count) -eq 0) {
git_module_setup
apply_patches
build_static
build_cpu
build_cpu_avx
build_cpu_avx2
build_cuda
build_rocm
if ($script:ARCH -eq "arm64") {
build_cpu("ARM64")
} else { # amd64
build_cpu("x64")
build_cpu_avx
build_cpu_avx2
build_cuda
build_rocm
}
cleanup
write-host "`ngo generate completed. LLM runners: $(get-childitem -path $script:DIST_BASE)"

View File

@@ -4,6 +4,7 @@ package llm
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
// #include <stdlib.h>

View File

@@ -73,8 +73,7 @@ func LoadModel(model string) (*GGML, error) {
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
var err error
if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
}
if opts.NumCtx < 4 {

View File

@@ -19,7 +19,7 @@ export default function () {
const [step, setStep] = useState<Step>(Step.WELCOME)
const [commandCopied, setCommandCopied] = useState<boolean>(false)
const command = 'ollama run llama2'
const command = 'ollama run llama3'
return (
<div className='drag'>

View File

@@ -7,6 +7,8 @@
$ErrorActionPreference = "Stop"
function checkEnv() {
$script:TARGET_ARCH=$Env:PROCESSOR_ARCHITECTURE.ToLower()
Write-host "Building for ${script:TARGET_ARCH}"
write-host "Locating required tools and paths"
$script:SRC_DIR=$PWD
if (!$env:VCToolsRedistDir) {
@@ -30,7 +32,7 @@ function checkEnv() {
$script:INNO_SETUP_DIR=(get-item "C:\Program Files*\Inno Setup*\")[0]
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-amd64"
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
$env:CGO_ENABLED="1"
echo "Checking version"
if (!$env:VERSION) {
@@ -81,8 +83,8 @@ function buildOllama() {
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
New-Item -ItemType Directory -Path .\dist\windows-amd64\ -Force
cp .\ollama.exe .\dist\windows-amd64\
New-Item -ItemType Directory -Path .\dist\windows-${script:TARGET_ARCH}\ -Force
cp .\ollama.exe .\dist\windows-${script:TARGET_ARCH}\
}
function buildApp() {
@@ -127,16 +129,16 @@ function buildInstaller() {
cd "${script:SRC_DIR}\app"
$env:PKG_VERSION=$script:PKG_VERSION
if ("${env:KEY_CONTAINER}") {
& "${script:INNO_SETUP_DIR}\ISCC.exe" /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
} else {
& "${script:INNO_SETUP_DIR}\ISCC.exe" .\ollama.iss
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH .\ollama.iss
}
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
function distZip() {
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip"
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip" -Force
}
try {

75
server/cache.go Normal file
View File

@@ -0,0 +1,75 @@
package server
import (
"cmp"
"fmt"
"os"
"path/filepath"
"github.com/ollama/ollama/client/registry"
"github.com/ollama/ollama/types/model"
)
// cache is a simple demo disk cache. it does not validate anything
type cache struct {
dir string
}
func defaultCache() registry.Cache {
homeDir, _ := os.UserHomeDir()
if homeDir == "" {
panic("could not determine home directory")
}
modelsDir := cmp.Or(
os.Getenv("OLLAMA_MODELS"),
filepath.Join(homeDir, ".ollama", "models"),
)
return &cache{modelsDir}
}
func invalidDigest(digest string) error {
return fmt.Errorf("invalid digest: %s", digest)
}
func (c *cache) OpenLayer(d model.Digest) (registry.ReadAtSeekCloser, error) {
return os.Open(c.LayerFile(d))
}
func (c *cache) LayerFile(d model.Digest) string {
return filepath.Join(c.dir, "blobs", d.String())
}
func (c *cache) PutLayerFile(d model.Digest, fromPath string) error {
if !d.IsValid() {
return invalidDigest(d.String())
}
bfile := c.LayerFile(d)
dir, _ := filepath.Split(bfile)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
return os.Rename(fromPath, bfile)
}
func (c *cache) ManifestData(name model.Name) []byte {
if !name.IsFullyQualified() {
return nil
}
data, err := os.ReadFile(filepath.Join(c.dir, "manifests", name.Filepath()))
if err != nil {
return nil
}
return data
}
func (c *cache) SetManifestData(name model.Name, data []byte) error {
if !name.IsFullyQualified() {
return fmt.Errorf("invalid name: %s", name)
}
filep := filepath.Join(c.dir, "manifests", name.Filepath())
dir, _ := filepath.Split(filep)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
return os.WriteFile(filep, data, 0644)
}

View File

@@ -5,6 +5,7 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
@@ -25,10 +26,12 @@ import (
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -710,6 +713,10 @@ func CopyModel(src, dst model.Name) error {
return model.Unqualified(src)
}
if src.Filepath() == dst.Filepath() {
return nil
}
manifests, err := GetManifestPath()
if err != nil {
return err
@@ -976,9 +983,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
if errors.Is(err, errUnauthorized) {
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
}
return err
}
}
@@ -1141,9 +1145,40 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
var errUnauthorized = errors.New("unauthorized")
var errUnauthorized = fmt.Errorf("unauthorized: access denied")
// getTokenSubject returns the subject of a JWT token, it does not validate the token
func getTokenSubject(token string) string {
parts := strings.Split(token, ".")
if len(parts) != 3 {
slog.Error("jwt token does not contain 3 parts")
return ""
}
payload := parts[1]
payloadBytes, err := base64.RawURLEncoding.DecodeString(payload)
if err != nil {
slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
return ""
}
var payloadMap map[string]interface{}
if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
return ""
}
sub, ok := payloadMap["sub"]
if !ok {
slog.Error("jwt does not contain 'sub' field")
return ""
}
return fmt.Sprintf("%s", sub)
}
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
anonymous := true // access will default to anonymous if no user is found associated with the public key
for i := 0; i < 2; i++ {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
@@ -1162,6 +1197,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if err != nil {
return nil, err
}
anonymous = getTokenSubject(token) == "anonymous"
regOpts.Token = token
if body != nil {
_, err = body.Seek(0, io.SeekStart)
@@ -1182,6 +1218,16 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
}
}
if anonymous {
// no user is associated with the public key, and the request requires non-anonymous access
pubKey, nestedErr := auth.GetPublicKey()
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, errUnauthorized
}
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
}
// user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized
}

View File

@@ -1,6 +1,7 @@
package server
import (
"cmp"
"context"
"encoding/json"
"errors"
@@ -17,6 +18,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"time"
@@ -25,6 +27,8 @@ import (
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/client/ollama"
"github.com/ollama/ollama/client/registry"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
@@ -33,6 +37,23 @@ import (
"github.com/ollama/ollama/version"
)
// envs
var (
envRegistryBaseURL = cmp.Or(os.Getenv("OLLAMA_REGISTRY_BASE_URL"), "https://bllamo.com")
)
func init() {
ollama.I_Acknowledge_This_API_Is_Unstable = true
}
var experiments = sync.OnceValue(func() []string {
return strings.Split(strings.ToLower(os.Getenv("OLLAMA_EXPERIMENT")), ",")
})
func useExperiment(flag string) bool {
return slices.Contains(experiments(), flag)
}
var mode string = gin.DebugMode
type Server struct {
@@ -444,6 +465,25 @@ func (s *Server) PullModelHandler(c *gin.Context) {
return
}
if useExperiment("pull") {
rc := &registry.Client{
BaseURL: envRegistryBaseURL,
}
modelsDir, err := modelsDir()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
cache := &cache{dir: modelsDir}
println("DIR: ", modelsDir)
// TODO(bmizerany): progress updates
if err := rc.Pull(c.Request.Context(), cache, model); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
return
}
ch := make(chan any)
go func() {
defer close(ch)

View File

@@ -149,6 +149,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
break
}
// If we're CPU only mode, just limit by loadedMax above
// TODO handle system memory exhaustion
if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 {
slog.Debug("cpu mode with existing models, loading")
s.loadFn(pending, ggml, gpus)
break
}
// No models loaded. Load the model but prefer the best fit.
if loadedCount == 0 {
slog.Debug("loading first model", "model", pending.model.ModelPath)

View File

@@ -28,19 +28,33 @@ func TestInitScheduler(t *testing.T) {
ctx, done := context.WithCancel(context.Background())
defer done()
initialMax := loadedMax
initialParallel := numParallel
s := InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
s.loadedMu.Lock()
require.NotNil(t, s.loaded)
s.loadedMu.Unlock()
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
s = InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
s.loadedMu.Lock()
require.NotNil(t, s.loaded)
s.loadedMu.Unlock()
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
s = InitScheduler(ctx)
require.Equal(t, 0, loadedMax)
s.loadedMu.Lock()
require.NotNil(t, s.loaded)
s.loadedMu.Unlock()
os.Setenv("OLLAMA_NUM_PARALLEL", "blue")
_ = InitScheduler(ctx)
require.Equal(t, initialParallel, numParallel)
os.Setenv("OLLAMA_NUM_PARALLEL", "10")
_ = InitScheduler(ctx)
require.Equal(t, 10, numParallel)
}
func TestLoad(t *testing.T) {
@@ -51,6 +65,7 @@ func TestLoad(t *testing.T) {
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: 2,
@@ -63,7 +78,9 @@ func TestLoad(t *testing.T) {
s.load(req, ggml, gpus)
require.Len(t, req.successCh, 0)
require.Len(t, req.errCh, 1)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
s.loadedMu.Unlock()
err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible")
@@ -78,7 +95,9 @@ func TestLoad(t *testing.T) {
case resp := <-req.successCh:
require.Equal(t, uint64(10), resp.estimatedVRAM)
require.Equal(t, uint(1), resp.refCount)
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
}
req.model.ModelPath = "dummy_model_path"
@@ -90,7 +109,9 @@ func TestLoad(t *testing.T) {
case resp := <-req.successCh:
t.Errorf("unexpected success %v", resp)
}
s.loadedMu.Lock()
runner := s.loaded["dummy_model_path"]
s.loadedMu.Unlock()
require.NotNil(t, runner)
require.Equal(t, uint(0), runner.refCount)
time.Sleep(1 * time.Millisecond)
@@ -143,6 +164,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
scenario.req = &LlmRequest{
ctx: scenario.ctx,
model: model,
opts: api.DefaultOptions(),
sessionDuration: 5 * time.Millisecond,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
@@ -171,7 +193,9 @@ func TestRequests(t *testing.T) {
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
@@ -240,7 +264,9 @@ func TestRequests(t *testing.T) {
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
loadedMax = 0
s.newServerFn = scenario3b.newServer
@@ -254,19 +280,14 @@ func TestRequests(t *testing.T) {
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
// Try to load a model that wont fit
// This is a CPU load with NumGPU = 0 so it should load
s.newServerFn = scenario3c.newServer
slog.Info("scenario3c")
require.Len(t, s.loaded, 2)
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- scenario3c.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
require.Len(t, s.loaded, 1)
scenario3b.ctxDone()
select {
case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv)
@@ -275,7 +296,36 @@ func TestRequests(t *testing.T) {
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 1)
s.loadedMu.Lock()
require.Len(t, s.loaded, 3)
s.loadedMu.Unlock()
// Try to load a model that wont fit
s.newServerFn = scenario3d.newServer
slog.Info("scenario3d")
s.loadedMu.Lock()
require.Len(t, s.loaded, 3)
s.loadedMu.Unlock()
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- scenario3d.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
scenario3b.ctxDone()
select {
case resp := <-scenario3d.req.successCh:
require.Equal(t, resp.llama, scenario3d.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3d.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
}
func TestGetRunner(t *testing.T) {
@@ -318,7 +368,9 @@ func TestGetRunner(t *testing.T) {
t.Errorf("timeout")
}
scenario1a.ctxDone()
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
scenario1c.req.model.ModelPath = "bad path"
slog.Info("scenario1c")
@@ -328,7 +380,9 @@ func TestGetRunner(t *testing.T) {
require.Len(t, errCh1c, 0)
time.Sleep(5 * time.Millisecond)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
s.loadedMu.Unlock()
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
@@ -358,7 +412,9 @@ func TestPrematureExpired(t *testing.T) {
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case <-ctx.Done():
@@ -383,6 +439,7 @@ func TestUseLoadedRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
req := &LlmRequest{
ctx: ctx,
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
sessionDuration: 2,
}
@@ -426,8 +483,10 @@ func TestUpdateFreeSpace(t *testing.T) {
r2 := &runnerRef{llama: llm2, gpus: gpus}
s := InitScheduler(ctx)
s.loadedMu.Lock()
s.loaded["a"] = r1
s.loaded["b"] = r2
s.loadedMu.Unlock()
s.updateFreeSpace(gpus)
require.Equal(t, uint64(850), gpus[0].FreeMemory)
@@ -437,13 +496,18 @@ func TestUpdateFreeSpace(t *testing.T) {
func TestFindRunnerToUnload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
req := &LlmRequest{ctx: ctx}
req := &LlmRequest{
ctx: ctx,
opts: api.DefaultOptions(),
}
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
r2 := &runnerRef{sessionDuration: 2}
s := InitScheduler(ctx)
s.loadedMu.Lock()
s.loaded["a"] = r1
s.loaded["b"] = r2
s.loadedMu.Unlock()
resp := s.findRunnerToUnload(req)
require.Equal(t, r2, resp)
@@ -458,10 +522,11 @@ func TestNeedsReload(t *testing.T) {
defer done()
llm := &mockLlm{}
do := api.DefaultOptions()
runner := &runnerRef{
adapters: []string{"adapter1"},
projectors: []string{"projector1"},
Options: &api.Options{},
Options: &do,
llama: llm,
}
req := &LlmRequest{
@@ -469,7 +534,7 @@ func TestNeedsReload(t *testing.T) {
AdapterPaths: []string{"adapter2"},
ProjectorPaths: []string{"projector2"},
},
opts: api.Options{},
opts: api.DefaultOptions(),
}
resp := runner.needsReload(ctx, req)
require.True(t, resp)
@@ -508,8 +573,10 @@ func TestUnloadAllRunners(t *testing.T) {
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{llama: llm2}
s.loadedMu.Lock()
s.loaded["a"] = r1
s.loaded["b"] = r2
s.loadedMu.Unlock()
s.unloadAllRunners()
require.True(t, llm1.closeCalled)

View File

@@ -0,0 +1,18 @@
// Package errtypes contains custom error types
package errtypes
import (
"fmt"
"strings"
)
const UnknownOllamaKeyErrMsg = "unknown ollama key"
// TODO: This should have a structured response from the API
type UnknownOllamaKey struct {
Key string
}
func (e *UnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
}

View File

@@ -4,6 +4,7 @@ package model
import (
"cmp"
"encoding/hex"
"errors"
"fmt"
"log/slog"
@@ -80,9 +81,6 @@ func (k partKind) String() string {
//
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
// is valid.
//
// It is not directly comparable with other Names. Use [Name.Equal] and
// [Name.MapHash] for determining equality and using as a map key.
type Name struct {
Host string
Namespace string
@@ -109,20 +107,20 @@ type Name struct {
// { model }
// "@" { digest }
// host:
// pattern: alphanum { alphanum | "-" | "_" | "." | ":" }*
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }*
// length: [1, 350]
// namespace:
// pattern: alphanum { alphanum | "-" | "_" }*
// length: [2, 80]
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
// length: [1, 80]
// model:
// pattern: alphanum { alphanum | "-" | "_" | "." }*
// length: [2, 80]
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// tag:
// pattern: alphanum { alphanum | "-" | "_" | "." }*
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// digest:
// pattern: alphanum { alphanum | "-" | ":" }*
// length: [2, 80]
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
// length: [1, 80]
//
// Most users should use [ParseName] instead, unless need to support
// different defaults than DefaultName.
@@ -234,12 +232,12 @@ func (n Name) Filepath() string {
if !n.IsFullyQualified() {
panic("illegal attempt to get filepath of invalid name")
}
return filepath.Join(
strings.ToLower(n.Host),
strings.ToLower(n.Namespace),
strings.ToLower(n.Model),
strings.ToLower(n.Tag),
)
return strings.ToLower(filepath.Join(
n.Host,
n.Namespace,
n.Model,
n.Tag,
))
}
// LogValue returns a slog.Value that represents the name as a string.
@@ -254,7 +252,7 @@ func isValidLen(kind partKind, s string) bool {
case kindTag:
return len(s) >= 1 && len(s) <= 80
default:
return len(s) >= 2 && len(s) <= 80
return len(s) >= 1 && len(s) <= 80
}
}
@@ -264,7 +262,7 @@ func isValidPart(kind partKind, s string) bool {
}
for i := range s {
if i == 0 {
if !isAlphanumeric(s[i]) {
if !isAlphanumericOrUnderscore(s[i]) {
return false
}
continue
@@ -280,7 +278,7 @@ func isValidPart(kind partKind, s string) bool {
return false
}
default:
if !isAlphanumeric(s[i]) {
if !isAlphanumericOrUnderscore(s[i]) {
return false
}
}
@@ -288,8 +286,8 @@ func isValidPart(kind partKind, s string) bool {
return true
}
func isAlphanumeric(c byte) bool {
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9'
func isAlphanumericOrUnderscore(c byte) bool {
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
}
func cutLast(s, sep string) (before, after string, ok bool) {
@@ -311,3 +309,57 @@ func cutPromised(s, sep string) (before, after string, ok bool) {
}
return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
}
type DigestType byte
const (
DigestTypeInvalid DigestType = iota
DigestTypeSHA256
)
func (t DigestType) String() string {
switch t {
case DigestTypeSHA256:
return "sha256"
default:
return "invalid"
}
}
type Digest struct {
Type DigestType
Sum [32]byte
}
func ParseDigest(s string) (Digest, error) {
i := strings.IndexAny(s, "-:")
if i < 0 {
return Digest{}, fmt.Errorf("invalid digest %q", s)
}
typ, encSum := s[:i], s[i+1:]
if typ != "sha256" {
return Digest{}, fmt.Errorf("unsupported digest type %q", typ)
}
d := Digest{
Type: DigestTypeSHA256,
}
n, err := hex.Decode(d.Sum[:], []byte(encSum))
if err != nil {
return Digest{}, err
}
if n != 32 {
return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n)
}
return d, nil
}
func (d Digest) String() string {
if d.Type == DigestTypeInvalid {
return ""
}
return fmt.Sprintf("sha256-%x", d.Sum)
}
func (d Digest) IsValid() bool {
return d.Type != DigestTypeInvalid
}

View File

@@ -2,6 +2,7 @@ package model
import (
"reflect"
"runtime"
"testing"
)
@@ -101,6 +102,13 @@ func TestParseNameParts(t *testing.T) {
}
var testCases = map[string]bool{ // name -> valid
"": false,
"_why/_the/_lucky:_stiff": true,
// minimal
"h/n/m:t@d": true,
"host/namespace/model:tag": true,
"host/namespace/model": false,
"namespace/model": false,
@@ -116,11 +124,12 @@ var testCases = map[string]bool{ // name -> valid
"h/nn/mm:t@sha256-1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
"h/nn/mm:t@sha256:1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
"m": false, // model too short
"n/mm:": false, // namespace too short
"h/n/mm:t": false, // namespace too short
"@t": false, // digest too short
"mm@d": false, // digest too short
// unqualified
"m": false,
"n/m:": false,
"h/n/m": false,
"@t": false,
"m@d": false,
// invalids
"^": false,
@@ -140,8 +149,6 @@ var testCases = map[string]bool{ // name -> valid
"hh/nn/mm:-tt@dd": false,
"hh/nn/mm:tt@-dd": false,
"": false,
// hosts
"host:https/namespace/model:tag": true,
@@ -163,7 +170,6 @@ func TestNameIsValid(t *testing.T) {
var numStringTests int
for s, want := range testCases {
n := ParseNameBare(s)
t.Logf("n: %#v", n)
got := n.IsValid()
if got != want {
t.Errorf("parseName(%q).IsValid() = %v; want %v", s, got, want)
@@ -212,6 +218,54 @@ func TestNameIsValidPart(t *testing.T) {
}
func TestFilepathAllocs(t *testing.T) {
n := ParseNameBare("HOST/NAMESPACE/MODEL:TAG")
allocs := testing.AllocsPerRun(1000, func() {
n.Filepath()
})
allowedAllocs := 2.0
if runtime.GOOS == "windows" {
allowedAllocs = 4
}
if allocs > allowedAllocs {
t.Errorf("allocs = %v; allowed %v", allocs, allowedAllocs)
}
}
const (
validSha256 = "sha256-1000000000000000000000000000000000000000000000000000000000000000"
validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000"
)
func TestParseDigest(t *testing.T) {
cases := []struct {
in string
want string
}{
{"", ""}, // empty
{"sha123-12", ""}, // invalid type
{"sha256-", ""}, // invalid sum
{"sha256-123", ""}, // invalid odd length sum
{validSha256, validSha256},
{validSha256Old, validSha256},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
got, err := ParseDigest(tt.in)
if err != nil {
if tt.want != "" {
t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want)
}
return
}
if got.String() != tt.want {
t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want)
}
})
}
}
func FuzzName(f *testing.F) {
for s := range testCases {
f.Add(s)

View File

@@ -1,15 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package structs contains the Incomparable type.
package structs
// Incomparable is a zero-width incomparable type. If added as the
// first field in a struct, it marks that struct as not comparable
// (can't do == or be a map key) and usually doesn't add any width to
// the struct (unless the struct has only small fields).
//
// By making a struct incomparable, you can prevent misuse (prevent
// people from using ==), but also you can shrink generated binaries,
// as the compiler can omit equality funcs from the binary.
type Incomparable [0]func()