mirror of
https://github.com/ollama/ollama.git
synced 2026-01-02 04:29:51 -05:00
Compare commits
33 Commits
native
...
bmizerany/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
889c5b1a75 | ||
|
|
844217bcf1 | ||
|
|
8afe873f17 | ||
|
|
7ba71c3989 | ||
|
|
fdef9a0eb2 | ||
|
|
b9f74ff3d6 | ||
|
|
fcf4d60eee | ||
|
|
e33d5c2dbc | ||
|
|
18d9a7e1f1 | ||
|
|
8488388cbd | ||
|
|
588901f449 | ||
|
|
0a7fdbe533 | ||
|
|
5950c176ca | ||
|
|
23d23409a0 | ||
|
|
9009bedf13 | ||
|
|
d4ac57e240 | ||
|
|
7b59d1770f | ||
|
|
95ead8ffba | ||
|
|
7aa08a77ca | ||
|
|
7e432cdfac | ||
|
|
586672f490 | ||
|
|
b03408de74 | ||
|
|
1e6a28bf5b | ||
|
|
d6e3b64582 | ||
|
|
114c932a8e | ||
|
|
7f7103de06 | ||
|
|
c631a9c726 | ||
|
|
8fd9e56804 | ||
|
|
8a65717f55 | ||
|
|
6d3152a98a | ||
|
|
b438d485f1 | ||
|
|
204349b17b | ||
|
|
86e67fc4a9 |
@@ -51,7 +51,7 @@ Here are some example models that can be downloaded:
|
||||
| ------------------ | ---------- | ----- | ------------------------------ |
|
||||
| Llama 3 | 8B | 4.7GB | `ollama run llama3` |
|
||||
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
|
||||
| Phi-3 | 3,8B | 2.3GB | `ollama run phi3` |
|
||||
| Phi-3 | 3.8B | 2.3GB | `ollama run phi3` |
|
||||
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
||||
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
||||
| Starling | 7B | 4.1GB | `ollama run starling-lm` |
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
@@ -57,12 +58,36 @@ func checkError(resp *http.Response, body []byte) error {
|
||||
// If the variable is not specified, a default ollama host and port will be
|
||||
// used.
|
||||
func ClientFromEnvironment() (*Client, error) {
|
||||
ollamaHost, err := GetOllamaHost()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Client{
|
||||
base: &url.URL{
|
||||
Scheme: ollamaHost.Scheme,
|
||||
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
||||
},
|
||||
http: http.DefaultClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type OllamaHost struct {
|
||||
Scheme string
|
||||
Host string
|
||||
Port string
|
||||
}
|
||||
|
||||
func GetOllamaHost() (OllamaHost, error) {
|
||||
defaultPort := "11434"
|
||||
|
||||
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
|
||||
hostVar := os.Getenv("OLLAMA_HOST")
|
||||
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
|
||||
|
||||
scheme, hostport, ok := strings.Cut(hostVar, "://")
|
||||
switch {
|
||||
case !ok:
|
||||
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
|
||||
scheme, hostport = "http", hostVar
|
||||
case scheme == "http":
|
||||
defaultPort = "80"
|
||||
case scheme == "https":
|
||||
@@ -82,12 +107,14 @@ func ClientFromEnvironment() (*Client, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
base: &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: net.JoinHostPort(host, port),
|
||||
},
|
||||
http: http.DefaultClient,
|
||||
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
|
||||
return OllamaHost{}, ErrInvalidHostPort
|
||||
}
|
||||
|
||||
return OllamaHost{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
Port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package api
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClientFromEnvironment(t *testing.T) {
|
||||
type testCase struct {
|
||||
@@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
hostTestCases := map[string]*testCase{
|
||||
"empty": {value: "", expect: "127.0.0.1:11434"},
|
||||
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
|
||||
"only port": {value: ":1234", expect: ":1234"},
|
||||
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
|
||||
"hostname": {value: "example.com", expect: "example.com:11434"},
|
||||
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
|
||||
"zero port": {value: ":0", expect: ":0"},
|
||||
"too large port": {value: ":66000", err: ErrInvalidHostPort},
|
||||
"too small port": {value: ":-1", err: ErrInvalidHostPort},
|
||||
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
|
||||
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
|
||||
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
|
||||
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
|
||||
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
|
||||
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
|
||||
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
|
||||
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
|
||||
}
|
||||
|
||||
for k, v := range hostTestCases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", v.value)
|
||||
|
||||
oh, err := GetOllamaHost()
|
||||
if err != v.err {
|
||||
t.Fatalf("expected %s, got %s", v.err, err)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
host := net.JoinHostPort(oh.Host, oh.Port)
|
||||
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -309,6 +309,7 @@ func (m *Metrics) Summary() {
|
||||
}
|
||||
|
||||
var ErrInvalidOpts = errors.New("invalid options")
|
||||
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
|
||||
@@ -43,37 +43,36 @@ func getCLIFullPath(command string) string {
|
||||
return command
|
||||
}
|
||||
|
||||
func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
done := make(chan int)
|
||||
|
||||
logDir := filepath.Dir(ServerLogFile)
|
||||
_, err := os.Stat(logDir)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
func start(ctx context.Context, command string) (*exec.Cmd, error) {
|
||||
cmd := getCmd(ctx, getCLIFullPath(command))
|
||||
// send stdout and stderr to a file
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stdout pipe %s", err)
|
||||
return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stderr pipe %s", err)
|
||||
}
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
|
||||
return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// TODO - rotation
|
||||
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to create server log %w", err)
|
||||
return nil, fmt.Errorf("failed to create server log: %w", err)
|
||||
}
|
||||
|
||||
logDir := filepath.Dir(ServerLogFile)
|
||||
_, err = os.Stat(logDir)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
|
||||
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer logFile.Close()
|
||||
io.Copy(logFile, stdout) //nolint:errcheck
|
||||
@@ -117,19 +116,33 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
|
||||
// run the command and wait for it to finish
|
||||
if err := cmd.Start(); err != nil {
|
||||
return done, fmt.Errorf("failed to start server %w", err)
|
||||
return nil, fmt.Errorf("failed to start server %w", err)
|
||||
}
|
||||
if cmd.Process != nil {
|
||||
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
|
||||
}
|
||||
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
done := make(chan int)
|
||||
|
||||
go func() {
|
||||
// Keep the server running unless we're shutting down the app
|
||||
crashCount := 0
|
||||
for {
|
||||
slog.Info("starting server...")
|
||||
cmd, err := start(ctx, command)
|
||||
if err != nil {
|
||||
crashCount++
|
||||
slog.Error(fmt.Sprintf("failed to start server %s", err))
|
||||
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
|
||||
continue
|
||||
}
|
||||
|
||||
cmd.Wait() //nolint:errcheck
|
||||
stdin.Close()
|
||||
var code int
|
||||
if cmd.ProcessState != nil {
|
||||
code = cmd.ProcessState.ExitCode()
|
||||
@@ -143,15 +156,12 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
default:
|
||||
crashCount++
|
||||
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := cmd.Start(); err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to restart server %s", err))
|
||||
// Keep trying, but back off if we keep failing
|
||||
time.Sleep(time.Duration(crashCount) * time.Second)
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return done, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -88,8 +88,8 @@ DialogFontSize=12
|
||||
[Files]
|
||||
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
|
||||
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-amd64\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-amd64\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
|
||||
Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
|
||||
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
|
||||
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
|
||||
#if DirExists("..\dist\windows-amd64\rocm")
|
||||
|
||||
36
auth/auth.go
36
auth/auth.go
@@ -10,12 +10,44 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
func keyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||
|
||||
return strings.TrimSpace(string(publicKey)), nil
|
||||
}
|
||||
|
||||
func NewNonce(r io.Reader, length int) (string, error) {
|
||||
nonce := make([]byte, length)
|
||||
if _, err := io.ReadFull(r, nonce); err != nil {
|
||||
@@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||
}
|
||||
|
||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
|
||||
95
client/registry/apitype/apitype.go
Normal file
95
client/registry/apitype/apitype.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package apitype
|
||||
|
||||
import (
	"cmp"
	"encoding/json"
	"log/slog"
	"net/url"
	"slices"
	"strconv"
)
|
||||
|
||||
type Manifest struct {
|
||||
Layers []*Layer `json:"layers"`
|
||||
}
|
||||
|
||||
type CompletePart struct {
|
||||
URL string `json:"url"` // contains partNumber and uploadId from server
|
||||
ETag string `json:"etag"`
|
||||
}
|
||||
|
||||
func queryFromString(s string) url.Values {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return u.Query()
|
||||
}
|
||||
|
||||
func (cp *CompletePart) Compare(o *CompletePart) int {
|
||||
qa := queryFromString(cp.URL)
|
||||
qb := queryFromString(o.URL)
|
||||
return cmp.Or(
|
||||
cmp.Compare(qa.Get("partNumber"), qb.Get("partNumber")),
|
||||
cmp.Compare(qa.Get("uploadId"), qb.Get("uploadId")),
|
||||
cmp.Compare(cp.ETag, o.ETag),
|
||||
)
|
||||
}
|
||||
|
||||
func SortCompleteParts(a []*CompletePart) {
|
||||
slices.SortFunc(a, (*CompletePart).Compare)
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
Digest string `json:"digest"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Size int64 `json:"size"`
|
||||
|
||||
// If present, URL is a remote location of the layer for fetching.
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
func (l *Layer) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.String("digest", l.Digest),
|
||||
slog.String("mediaType", l.MediaType),
|
||||
slog.Int64("size", l.Size),
|
||||
slog.String("url", l.URL),
|
||||
)
|
||||
}
|
||||
|
||||
type PushRequest struct {
|
||||
Name string `json:"ref"`
|
||||
Manifest json.RawMessage `json:"manifest,omitempty"`
|
||||
|
||||
// CompleteParts is a list of upload parts that the client uploaded in the previous
|
||||
// push.
|
||||
CompleteParts []*CompletePart `json:"part_uploads"`
|
||||
}
|
||||
|
||||
type Need struct {
|
||||
Digest string `json:"digest"`
|
||||
|
||||
Start int64 `json:"start"`
|
||||
End int64 `json:"end"`
|
||||
|
||||
// URL is the url to PUT the layer to.
|
||||
//
|
||||
// Clients must include it as the URL, along with the ETag in the
|
||||
// response headers from the PUT request, in the next push request
|
||||
// in the CompleteParts field.
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type PushResponse struct {
|
||||
// Needs is a list of digests that the client needs to push before
|
||||
// repushing the manifest.
|
||||
Needs []*Need `json:"requirements,omitempty"`
|
||||
}
|
||||
|
||||
type PullResponse struct {
|
||||
// Name is the name of the model being pulled.
|
||||
Name string `json:"name"`
|
||||
|
||||
// Manifest is the manifest of the model being pulled.
|
||||
Manifest *Manifest `json:"manifest"`
|
||||
}
|
||||
421
client/registry/registry.go
Normal file
421
client/registry/registry.go
Normal file
@@ -0,0 +1,421 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/client/ollama"
|
||||
"github.com/ollama/ollama/client/registry/apitype"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"golang.org/x/exp/constraints"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrLayerNotFound = errors.New("layer not found")
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
BaseURL string
|
||||
|
||||
Logger *slog.Logger
|
||||
|
||||
// NameFill is a string that is used to fill in the missing parts of
|
||||
// a name when it is not fully qualified. It is used to make a name
|
||||
// fully qualified before pushing or pulling it. The default is
|
||||
// "registry.ollama.ai/library/_:latest".
|
||||
//
|
||||
// Most users can ignore this field. It is intended for use by
|
||||
// clients that need to push or pull names to registries other than
|
||||
// registry.ollama.ai, and for testing.
|
||||
NameFill string
|
||||
}
|
||||
|
||||
func (c *Client) log() *slog.Logger {
|
||||
return cmp.Or(c.Logger, slog.Default())
|
||||
}
|
||||
|
||||
func (c *Client) oclient() *ollama.Client {
|
||||
return &ollama.Client{
|
||||
BaseURL: c.BaseURL,
|
||||
}
|
||||
}
|
||||
|
||||
type ReadAtSeekCloser interface {
|
||||
io.ReaderAt
|
||||
io.Seeker
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type Cache interface {
|
||||
// LayerFile returns the absolute file path to the layer file for
|
||||
// the given model digest.
|
||||
//
|
||||
// If the digest is invalid, or the layer does not exist, the empty
|
||||
// string is returned.
|
||||
LayerFile(model.Digest) string
|
||||
|
||||
// OpenLayer opens the layer file for the given model digest and
|
||||
// returns it, or an error, if any. The caller is responsible for closing
|
||||
// the returned file.
|
||||
OpenLayer(model.Digest) (ReadAtSeekCloser, error)
|
||||
|
||||
// PutLayerFile moves the layer file at fromPath to the cache for
|
||||
// the given model digest. It is a hack intended to short circuit a
|
||||
// file copy operation.
|
||||
//
|
||||
// The file returned is expected to exist for the lifetime of the
|
||||
// cache.
|
||||
//
|
||||
// TODO(bmizerany): remove this; find a better way. Once we move
|
||||
// this into a build package, we should be able to get rid of this.
|
||||
PutLayerFile(_ model.Digest, fromPath string) error
|
||||
|
||||
// SetManifestData sets the provided manifest data for the given
|
||||
// model name. If the manifest data is empty, the manifest is
|
||||
// removed. If the manifest exists, it is overwritten.
|
||||
//
|
||||
// It is an error to call SetManifestData with a name that is not
|
||||
// complete.
|
||||
SetManifestData(model.Name, []byte) error
|
||||
|
||||
// ManifestData returns the manifest data for the given model name.
|
||||
//
|
||||
// If the name is incomplete, or the manifest does not exist, the empty
|
||||
// string is returned.
|
||||
ManifestData(name model.Name) []byte
|
||||
}
|
||||
|
||||
// Pull pulls the manifest for name, and downloads any of its required
|
||||
// layers that are not already in the cache. It returns an error if any part
|
||||
// of the process fails.
|
||||
func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
|
||||
mn := parseNameFill(name, c.NameFill)
|
||||
if !mn.IsFullyQualified() {
|
||||
return fmt.Errorf("ollama: pull: invalid name: %s", name)
|
||||
}
|
||||
|
||||
log := c.log().With("name", name)
|
||||
|
||||
pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ollama: pull: %w: %s", err, name)
|
||||
}
|
||||
|
||||
if pr.Manifest == nil || len(pr.Manifest.Layers) == 0 {
|
||||
return fmt.Errorf("ollama: pull: invalid manifest: %s: no layers found", name)
|
||||
}
|
||||
|
||||
// download required layers we do not already have
|
||||
for _, l := range pr.Manifest.Layers {
|
||||
d, err := model.ParseDigest(l.Digest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ollama: reading manifest: %w: %s", err, l.Digest)
|
||||
}
|
||||
if cache.LayerFile(d) != "" {
|
||||
continue
|
||||
}
|
||||
err = func() error {
|
||||
log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size)
|
||||
log.Debug("starting download")
|
||||
|
||||
// TODO(bmizerany): stop using temp which might not
|
||||
// be on same device as cache.... instead let cache
|
||||
// give us a place to store parts...
|
||||
tmpFile, err := os.CreateTemp("", "ollama-download-")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpFile.Name()) // in case we fail before committing
|
||||
}()
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(8) // TODO(bmizerany): make this configurable
|
||||
|
||||
// TODO(bmizerany): make chunk size configurable
|
||||
const chunkSize = 50 * 1024 * 1024 // 50MB
|
||||
chunks(l.Size, chunkSize)(func(_ int, rng chunkRange[int64]) bool {
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
safeURL := redactAmzSignature(l.URL)
|
||||
err = fmt.Errorf("%w: %s %s bytes=%s: %s", err, pr.Name, l.Digest, rng, safeURL)
|
||||
}()
|
||||
|
||||
log.Debug("downloading", "range", rng)
|
||||
|
||||
// TODO(bmizerany): retry
|
||||
// TODO(bmizerany): use real http client
|
||||
// TODO(bmizerany): resumable
|
||||
// TODO(bmizerany): multipart download
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", l.URL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", "bytes="+rng.String())
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode/100 != 2 {
|
||||
log.Debug("unexpected non-2XX status code", "status", res.StatusCode)
|
||||
return fmt.Errorf("unexpected status code fetching layer: %d", res.StatusCode)
|
||||
}
|
||||
if res.ContentLength != rng.Size() {
|
||||
return fmt.Errorf("unexpected content length: %d", res.ContentLength)
|
||||
}
|
||||
w := io.NewOffsetWriter(tmpFile, rng.Start)
|
||||
_, err = io.Copy(w, res.Body)
|
||||
return err
|
||||
})
|
||||
return true
|
||||
})
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmpFile.Close() // release our hold on the file before moving it
|
||||
return cache.PutLayerFile(d, tmpFile.Name())
|
||||
}()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ollama: pull: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// do not store the presigned URLs in the cache
|
||||
for i := range pr.Manifest.Layers {
|
||||
pr.Manifest.Layers[i].URL = ""
|
||||
}
|
||||
data, err := json.Marshal(pr.Manifest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(bmizerany): remove dep on model.Name
|
||||
return cache.SetManifestData(mn, data)
|
||||
}
|
||||
|
||||
type nopSeeker struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (nopSeeker) Seek(int64, int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func parseNameFill(name, fill string) model.Name {
|
||||
fill = cmp.Or(fill, "bllamo.com/library/_:latest")
|
||||
f := model.ParseNameBare(fill)
|
||||
if !f.IsFullyQualified() {
|
||||
panic(fmt.Errorf("invalid fill: %q", fill))
|
||||
}
|
||||
return model.Merge(model.ParseNameBare(name), f)
|
||||
}
|
||||
|
||||
// Push pushes a manifest to the server and responds to the server's
|
||||
// requests for layer uploads, if any, and finally commits the manifest for
|
||||
// name. It returns an error if any part of the process fails, specifically:
|
||||
//
|
||||
// If the server requests layers not found in the cache, ErrLayerNotFound is
|
||||
// returned.
|
||||
func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
|
||||
mn := parseNameFill(name, c.NameFill)
|
||||
if !mn.IsFullyQualified() {
|
||||
return fmt.Errorf("ollama: push: invalid name: %s", name)
|
||||
}
|
||||
manifest := cache.ManifestData(mn)
|
||||
if len(manifest) == 0 {
|
||||
return fmt.Errorf("manifest not found: %s", name)
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var completed []*apitype.CompletePart
|
||||
push := func() (*apitype.PushResponse, error) {
|
||||
v, err := ollama.Do[*apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
|
||||
Name: name,
|
||||
Manifest: manifest,
|
||||
CompleteParts: completed,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Do: %w", err)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
pr, err := push()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var g errgroup.Group
|
||||
for _, need := range pr.Needs {
|
||||
g.Go(func() error {
|
||||
nd, err := model.ParseDigest(need.Digest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ParseDigest: %w: %s", err, need.Digest)
|
||||
}
|
||||
f, err := cache.OpenLayer(nd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
c.log().Info("pushing layer", "digest", need.Digest, "start", need.Start, "end", need.End)
|
||||
cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End)
|
||||
if err != nil {
|
||||
return fmt.Errorf("PushLayer: %w: %s", err, need.Digest)
|
||||
}
|
||||
mu.Lock()
|
||||
completed = append(completed, cp)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
return fmt.Errorf("Push: Required: %w", err)
|
||||
}
|
||||
|
||||
if len(completed) > 0 {
|
||||
pr, err := push()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(pr.Needs) > 0 {
|
||||
var errs []error
|
||||
for _, r := range pr.Needs {
|
||||
errs = append(errs, fmt.Errorf("Push: server failed to find part: %q", r.Digest))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
}
|
||||
|
||||
return cache.SetManifestData(mn, manifest)
|
||||
}
|
||||
|
||||
func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) {
|
||||
if start < 0 || end < start {
|
||||
return nil, errors.New("start must satisfy 0 <= start <= end")
|
||||
}
|
||||
|
||||
file := io.NewSectionReader(body, start, end-start+1)
|
||||
req, err := http.NewRequest("PUT", url, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.ContentLength = end - start + 1
|
||||
|
||||
// TODO(bmizerany): take content type param
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
|
||||
if start != 0 || end != 0 {
|
||||
req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", start, end))
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
e := parseS3Error(res)
|
||||
return nil, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
|
||||
}
|
||||
cp := &apitype.CompletePart{
|
||||
URL: url,
|
||||
ETag: res.Header.Get("ETag"),
|
||||
// TODO(bmizerany): checksum
|
||||
}
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
type s3Error struct {
|
||||
XMLName xml.Name `xml:"Error"`
|
||||
Code string `xml:"Code"`
|
||||
Message string `xml:"Message"`
|
||||
Resource string `xml:"Resource"`
|
||||
RequestId string `xml:"RequestId"`
|
||||
}
|
||||
|
||||
func (e *s3Error) Error() string {
|
||||
return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message)
|
||||
}
|
||||
|
||||
// parseS3Error parses an XML error response from S3.
|
||||
func parseS3Error(res *http.Response) error {
|
||||
var se *s3Error
|
||||
if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
|
||||
return err
|
||||
}
|
||||
return se
|
||||
}
|
||||
|
||||
// TODO: replace below by using upload pkg after we have rangefunc; until
|
||||
// then, we need to keep this free of rangefunc for now.
|
||||
type chunkRange[I constraints.Integer] struct {
|
||||
// Start is the byte offset of the chunk.
|
||||
Start I
|
||||
|
||||
// End is the byte offset of the last byte in the chunk.
|
||||
End I
|
||||
}
|
||||
|
||||
func (c chunkRange[I]) Size() I {
|
||||
return c.End - c.Start + 1
|
||||
}
|
||||
|
||||
func (c chunkRange[I]) String() string {
|
||||
return fmt.Sprintf("%d-%d", c.Start, c.End)
|
||||
}
|
||||
|
||||
func (c chunkRange[I]) LogValue() slog.Value {
|
||||
return slog.StringValue(c.String())
|
||||
}
|
||||
|
||||
// Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset
|
||||
// and size of the chunk. The last chunk may be smaller than chunkSize if size is
|
||||
// not a multiple of chunkSize.
|
||||
//
|
||||
// The first part number is 1 and increases monotonically.
|
||||
func chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, chunkRange[I]] {
|
||||
return func(yield func(int, chunkRange[I]) bool) {
|
||||
var n int
|
||||
for off := I(0); off < size; off += chunkSize {
|
||||
n++
|
||||
if !yield(n, chunkRange[I]{
|
||||
Start: off,
|
||||
End: off + min(chunkSize, size-off) - 1,
|
||||
}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// redactAmzSignature returns s with the value of its X-Amz-Signature query
// parameter replaced by "REDACTED", so the URL can be included in logs and
// error messages without leaking the signature.
//
// If s has no X-Amz-Signature parameter it is returned unchanged, rather than
// gaining a placeholder signature it never had. If s cannot be parsed as a
// URL, the empty string is returned so that nothing unvetted is echoed back.
func redactAmzSignature(s string) string {
	const sigKey = "X-Amz-Signature"

	u, err := url.Parse(s)
	if err != nil {
		return ""
	}
	q := u.Query()
	if _, ok := q[sigKey]; !ok {
		return s
	}
	q.Set(sigKey, "REDACTED")
	u.RawQuery = q.Encode()
	return u.String()
}
|
||||
68
cmd/cmd.go
68
cmd/cmd.go
@@ -32,10 +32,13 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -357,6 +360,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
|
||||
func errFromUnknownKey(unknownKeyErr error) error {
|
||||
// find SSH public key in the error message
|
||||
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||
re := regexp.MustCompile(sshKeyPattern)
|
||||
matches := re.FindStringSubmatch(unknownKeyErr.Error())
|
||||
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
localPubKey = strings.TrimSpace(string(svcPubKey))
|
||||
}
|
||||
|
||||
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
|
||||
if serverPubKey != localPubKey {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
var msg strings.Builder
|
||||
msg.WriteString(unknownKeyErr.Error())
|
||||
msg.WriteString("\n\nYour ollama key is:\n")
|
||||
msg.WriteString(localPubKey)
|
||||
msg.WriteString("\nAdd your key at:\n")
|
||||
msg.WriteString("https://ollama.com/settings/keys")
|
||||
|
||||
return errors.New(msg.String())
|
||||
}
|
||||
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -404,6 +448,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
host := model.ParseName(args[0]).Host
|
||||
isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
|
||||
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// re-throw an error with a more user-friendly message
|
||||
return errFromUnknownKey(err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -831,19 +889,17 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
|
||||
// retrieve the OLLAMA_HOST environment variable
|
||||
ollamaHost, err := api.GetOllamaHost()
|
||||
if err != nil {
|
||||
host, port = "127.0.0.1", "11434"
|
||||
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
|
||||
host = ip.String()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := initializeKeypair(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -17,10 +17,12 @@ Let's start by asking a simple question that we can get an answer to from the **
|
||||
Then we can create a model and ask the question:
|
||||
|
||||
```python
|
||||
from langchain.llms import Ollama
|
||||
ollama = Ollama(base_url='http://localhost:11434',
|
||||
model="llama2")
|
||||
print(ollama("why is the sky blue"))
|
||||
from langchain_community.llms import Ollama
|
||||
ollama = Ollama(
|
||||
base_url='http://localhost:11434',
|
||||
model="llama3"
|
||||
)
|
||||
print(ollama.invoke("why is the sky blue"))
|
||||
```
|
||||
|
||||
Notice that we are defining the model and the base URL for Ollama.
|
||||
|
||||
@@ -40,7 +40,7 @@ func PayloadsDir() (string, error) {
|
||||
}
|
||||
|
||||
var paths []string
|
||||
for _, root := range []string{appExe, cwd} {
|
||||
for _, root := range []string{filepath.Dir(appExe), cwd} {
|
||||
paths = append(paths,
|
||||
filepath.Join(root),
|
||||
filepath.Join(root, "windows-"+runtime.GOARCH),
|
||||
|
||||
15
llm/ext_server/server.cpp
vendored
15
llm/ext_server/server.cpp
vendored
@@ -1032,7 +1032,7 @@ struct llama_server_context
|
||||
slot.has_next_token = false;
|
||||
}
|
||||
|
||||
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
|
||||
if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok))
|
||||
{
|
||||
slot.stopped_eos = true;
|
||||
slot.has_next_token = false;
|
||||
@@ -1144,12 +1144,15 @@ struct llama_server_context
|
||||
|
||||
res.result_json = json
|
||||
{
|
||||
{"content", tkn.text_to_send},
|
||||
{"stop", false},
|
||||
{"slot_id", slot.id},
|
||||
{"multimodal", multimodal}
|
||||
};
|
||||
|
||||
if (!llama_token_is_eog(model, tkn.tok)) {
|
||||
res.result_json["content"] = tkn.text_to_send;
|
||||
}
|
||||
|
||||
if (slot.sparams.n_probs > 0)
|
||||
{
|
||||
std::vector<completion_token_output> probs_output = {};
|
||||
@@ -2644,18 +2647,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||
if (strncmp(sep, "int:", 4) == 0) {
|
||||
sep += 4;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
|
||||
kvo.int_value = std::atol(sep);
|
||||
kvo.val_i64 = std::atol(sep);
|
||||
} else if (strncmp(sep, "float:", 6) == 0) {
|
||||
sep += 6;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
|
||||
kvo.float_value = std::atof(sep);
|
||||
kvo.val_f64 = std::atof(sep);
|
||||
} else if (strncmp(sep, "bool:", 5) == 0) {
|
||||
sep += 5;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
|
||||
if (std::strcmp(sep, "true") == 0) {
|
||||
kvo.bool_value = true;
|
||||
kvo.val_bool = true;
|
||||
} else if (std::strcmp(sep, "false") == 0) {
|
||||
kvo.bool_value = false;
|
||||
kvo.val_bool = false;
|
||||
} else {
|
||||
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
|
||||
invalid_param = true;
|
||||
|
||||
@@ -42,7 +42,7 @@ function init_vars {
|
||||
"-DLLAMA_NATIVE=off"
|
||||
)
|
||||
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
|
||||
$script:ARCH = "amd64" # arm not yet supported.
|
||||
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
|
||||
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
|
||||
md "$script:DIST_BASE" -ea 0 > $null
|
||||
if ($env:CGO_CFLAGS -contains "-g") {
|
||||
@@ -213,11 +213,11 @@ function build_static() {
|
||||
}
|
||||
}
|
||||
|
||||
function build_cpu() {
|
||||
function build_cpu($gen_arch) {
|
||||
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
|
||||
# remaining llama.cpp builds use MSVC
|
||||
init_vars
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||
$script:buildDir="../build/windows/${script:ARCH}/cpu"
|
||||
$script:distDir="$script:DIST_BASE\cpu"
|
||||
write-host "Building LCD CPU"
|
||||
@@ -349,11 +349,15 @@ if ($($args.count) -eq 0) {
|
||||
git_module_setup
|
||||
apply_patches
|
||||
build_static
|
||||
build_cpu
|
||||
build_cpu_avx
|
||||
build_cpu_avx2
|
||||
build_cuda
|
||||
build_rocm
|
||||
if ($script:ARCH -eq "arm64") {
|
||||
build_cpu("ARM64")
|
||||
} else { # amd64
|
||||
build_cpu("x64")
|
||||
build_cpu_avx
|
||||
build_cpu_avx2
|
||||
build_cuda
|
||||
build_rocm
|
||||
}
|
||||
|
||||
cleanup
|
||||
write-host "`ngo generate completed. LLM runners: $(get-childitem -path $script:DIST_BASE)"
|
||||
|
||||
Submodule llm/llama.cpp updated: 46e12c4692...952d03dbea
@@ -4,6 +4,7 @@ package llm
|
||||
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
|
||||
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
|
||||
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
|
||||
// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
|
||||
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
|
||||
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
|
||||
// #include <stdlib.h>
|
||||
|
||||
@@ -73,8 +73,7 @@ func LoadModel(model string) (*GGML, error) {
|
||||
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
|
||||
var err error
|
||||
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
||||
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
|
||||
opts.NumCtx = int(ggml.KV().ContextLength())
|
||||
slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
|
||||
}
|
||||
|
||||
if opts.NumCtx < 4 {
|
||||
|
||||
@@ -19,7 +19,7 @@ export default function () {
|
||||
const [step, setStep] = useState<Step>(Step.WELCOME)
|
||||
const [commandCopied, setCommandCopied] = useState<boolean>(false)
|
||||
|
||||
const command = 'ollama run llama2'
|
||||
const command = 'ollama run llama3'
|
||||
|
||||
return (
|
||||
<div className='drag'>
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
function checkEnv() {
|
||||
$script:TARGET_ARCH=$Env:PROCESSOR_ARCHITECTURE.ToLower()
|
||||
Write-host "Building for ${script:TARGET_ARCH}"
|
||||
write-host "Locating required tools and paths"
|
||||
$script:SRC_DIR=$PWD
|
||||
if (!$env:VCToolsRedistDir) {
|
||||
@@ -30,7 +32,7 @@ function checkEnv() {
|
||||
|
||||
$script:INNO_SETUP_DIR=(get-item "C:\Program Files*\Inno Setup*\")[0]
|
||||
|
||||
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-amd64"
|
||||
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
|
||||
$env:CGO_ENABLED="1"
|
||||
echo "Checking version"
|
||||
if (!$env:VERSION) {
|
||||
@@ -81,8 +83,8 @@ function buildOllama() {
|
||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
New-Item -ItemType Directory -Path .\dist\windows-amd64\ -Force
|
||||
cp .\ollama.exe .\dist\windows-amd64\
|
||||
New-Item -ItemType Directory -Path .\dist\windows-${script:TARGET_ARCH}\ -Force
|
||||
cp .\ollama.exe .\dist\windows-${script:TARGET_ARCH}\
|
||||
}
|
||||
|
||||
function buildApp() {
|
||||
@@ -127,16 +129,16 @@ function buildInstaller() {
|
||||
cd "${script:SRC_DIR}\app"
|
||||
$env:PKG_VERSION=$script:PKG_VERSION
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
|
||||
} else {
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" .\ollama.iss
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH .\ollama.iss
|
||||
}
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
|
||||
function distZip() {
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
|
||||
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip"
|
||||
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip" -Force
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
75
server/cache.go
Normal file
75
server/cache.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/client/registry"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// cache is a simple demo disk cache. It does not validate anything.
|
||||
type cache struct {
|
||||
dir string
|
||||
}
|
||||
|
||||
func defaultCache() registry.Cache {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
if homeDir == "" {
|
||||
panic("could not determine home directory")
|
||||
}
|
||||
modelsDir := cmp.Or(
|
||||
os.Getenv("OLLAMA_MODELS"),
|
||||
filepath.Join(homeDir, ".ollama", "models"),
|
||||
)
|
||||
return &cache{modelsDir}
|
||||
}
|
||||
|
||||
// invalidDigest returns an error whose message is "invalid digest: "
// followed by the given digest string.
func invalidDigest(digest string) error {
|
||||
return fmt.Errorf("invalid digest: %s", digest)
|
||||
}
|
||||
|
||||
// OpenLayer opens the file that LayerFile names for d and returns the
// result of os.Open unchanged, including its error.
func (c *cache) OpenLayer(d model.Digest) (registry.ReadAtSeekCloser, error) {
|
||||
return os.Open(c.LayerFile(d))
|
||||
}
|
||||
|
||||
// LayerFile returns the path used for the blob of d: d.String() joined
// under the "blobs" directory of c.dir. It does not touch the filesystem.
func (c *cache) LayerFile(d model.Digest) string {
|
||||
return filepath.Join(c.dir, "blobs", d.String())
|
||||
}
|
||||
|
||||
func (c *cache) PutLayerFile(d model.Digest, fromPath string) error {
|
||||
if !d.IsValid() {
|
||||
return invalidDigest(d.String())
|
||||
}
|
||||
bfile := c.LayerFile(d)
|
||||
dir, _ := filepath.Split(bfile)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(fromPath, bfile)
|
||||
}
|
||||
|
||||
func (c *cache) ManifestData(name model.Name) []byte {
|
||||
if !name.IsFullyQualified() {
|
||||
return nil
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(c.dir, "manifests", name.Filepath()))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (c *cache) SetManifestData(name model.Name, data []byte) error {
|
||||
if !name.IsFullyQualified() {
|
||||
return fmt.Errorf("invalid name: %s", name)
|
||||
}
|
||||
filep := filepath.Join(c.dir, "manifests", name.Filepath())
|
||||
dir, _ := filepath.Split(filep)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filep, data, 0644)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -25,10 +26,12 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -710,6 +713,10 @@ func CopyModel(src, dst model.Name) error {
|
||||
return model.Unqualified(src)
|
||||
}
|
||||
|
||||
if src.Filepath() == dst.Filepath() {
|
||||
return nil
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -976,9 +983,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
for _, layer := range layers {
|
||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
||||
if errors.Is(err, errUnauthorized) {
|
||||
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -1141,9 +1145,40 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||
}
|
||||
|
||||
var errUnauthorized = errors.New("unauthorized")
|
||||
var errUnauthorized = fmt.Errorf("unauthorized: access denied")
|
||||
|
||||
// getTokenSubject returns the "sub" (subject) claim of a JWT. It does not
// validate the token: the header and signature parts are ignored and only
// the payload is decoded.
//
// It logs an error and returns "" when the token does not have exactly three
// dot-separated parts, when the payload is not unpadded base64url, when the
// payload is not a JSON object, or when the "sub" claim is missing or is not
// a JSON string.
func getTokenSubject(token string) string {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		slog.Error("jwt token does not contain 3 parts")
		return ""
	}

	payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
		return ""
	}

	var payloadMap map[string]any
	if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
		slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
		return ""
	}

	sub, ok := payloadMap["sub"]
	if !ok {
		slog.Error("jwt does not contain 'sub' field")
		return ""
	}

	// Formatting a non-string claim with %s would yield text such as
	// "%!s(float64=1)"; treat it as "no subject" instead.
	subject, ok := sub.(string)
	if !ok {
		slog.Error(fmt.Sprintf("jwt 'sub' field is not a string: %T", sub))
		return ""
	}
	return subject
}
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||
anonymous := true // access will default to anonymous if no user is found associated with the public key
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
if err != nil {
|
||||
@@ -1162,6 +1197,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
anonymous = getTokenSubject(token) == "anonymous"
|
||||
regOpts.Token = token
|
||||
if body != nil {
|
||||
_, err = body.Seek(0, io.SeekStart)
|
||||
@@ -1182,6 +1218,16 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
}
|
||||
}
|
||||
|
||||
if anonymous {
|
||||
// no user is associated with the public key, and the request requires non-anonymous access
|
||||
pubKey, nestedErr := auth.GetPublicKey()
|
||||
if nestedErr != nil {
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
||||
}
|
||||
// user is associated with the public key, but is not authorized to make the request
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +27,8 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/client/ollama"
|
||||
"github.com/ollama/ollama/client/registry"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -33,6 +37,23 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
// envs
|
||||
var (
|
||||
envRegistryBaseURL = cmp.Or(os.Getenv("OLLAMA_REGISTRY_BASE_URL"), "https://bllamo.com")
|
||||
)
|
||||
|
||||
func init() {
|
||||
ollama.I_Acknowledge_This_API_Is_Unstable = true
|
||||
}
|
||||
|
||||
// experiments returns the experiment flags enabled through the
// OLLAMA_EXPERIMENT environment variable. The variable is read once, on
// first use, as a comma-separated list; entries are lower-cased, trimmed of
// surrounding whitespace, and empty entries are dropped.
var experiments = sync.OnceValue(func() []string {
	var flags []string
	for _, f := range strings.Split(strings.ToLower(os.Getenv("OLLAMA_EXPERIMENT")), ",") {
		if f = strings.TrimSpace(f); f != "" {
			flags = append(flags, f)
		}
	}
	return flags
})

// useExperiment reports whether flag is among the enabled experiments.
// flag is compared as given, so callers pass it in lower case.
func useExperiment(flag string) bool {
	return slices.Contains(experiments(), flag)
}
|
||||
|
||||
var mode string = gin.DebugMode
|
||||
|
||||
type Server struct {
|
||||
@@ -444,6 +465,25 @@ func (s *Server) PullModelHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if useExperiment("pull") {
|
||||
rc := ®istry.Client{
|
||||
BaseURL: envRegistryBaseURL,
|
||||
}
|
||||
modelsDir, err := modelsDir()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
cache := &cache{dir: modelsDir}
|
||||
println("DIR: ", modelsDir)
|
||||
// TODO(bmizerany): progress updates
|
||||
if err := rc.Pull(c.Request.Context(), cache, model); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
@@ -149,6 +149,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
break
|
||||
}
|
||||
|
||||
// If we're in CPU-only mode, just limit by loadedMax above
|
||||
// TODO handle system memory exhaustion
|
||||
if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 {
|
||||
slog.Debug("cpu mode with existing models, loading")
|
||||
s.loadFn(pending, ggml, gpus)
|
||||
break
|
||||
}
|
||||
|
||||
// No models loaded. Load the model but prefer the best fit.
|
||||
if loadedCount == 0 {
|
||||
slog.Debug("loading first model", "model", pending.model.ModelPath)
|
||||
|
||||
@@ -28,19 +28,33 @@ func TestInitScheduler(t *testing.T) {
|
||||
ctx, done := context.WithCancel(context.Background())
|
||||
defer done()
|
||||
initialMax := loadedMax
|
||||
initialParallel := numParallel
|
||||
s := InitScheduler(ctx)
|
||||
require.Equal(t, initialMax, loadedMax)
|
||||
s.loadedMu.Lock()
|
||||
require.NotNil(t, s.loaded)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
|
||||
s = InitScheduler(ctx)
|
||||
require.Equal(t, initialMax, loadedMax)
|
||||
s.loadedMu.Lock()
|
||||
require.NotNil(t, s.loaded)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
|
||||
s = InitScheduler(ctx)
|
||||
require.Equal(t, 0, loadedMax)
|
||||
s.loadedMu.Lock()
|
||||
require.NotNil(t, s.loaded)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
os.Setenv("OLLAMA_NUM_PARALLEL", "blue")
|
||||
_ = InitScheduler(ctx)
|
||||
require.Equal(t, initialParallel, numParallel)
|
||||
os.Setenv("OLLAMA_NUM_PARALLEL", "10")
|
||||
_ = InitScheduler(ctx)
|
||||
require.Equal(t, 10, numParallel)
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
@@ -51,6 +65,7 @@ func TestLoad(t *testing.T) {
|
||||
req := &LlmRequest{
|
||||
ctx: ctx,
|
||||
model: &Model{ModelPath: "foo"},
|
||||
opts: api.DefaultOptions(),
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
sessionDuration: 2,
|
||||
@@ -63,7 +78,9 @@ func TestLoad(t *testing.T) {
|
||||
s.load(req, ggml, gpus)
|
||||
require.Len(t, req.successCh, 0)
|
||||
require.Len(t, req.errCh, 1)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 0)
|
||||
s.loadedMu.Unlock()
|
||||
err := <-req.errCh
|
||||
require.Contains(t, err.Error(), "this model may be incompatible")
|
||||
|
||||
@@ -78,7 +95,9 @@ func TestLoad(t *testing.T) {
|
||||
case resp := <-req.successCh:
|
||||
require.Equal(t, uint64(10), resp.estimatedVRAM)
|
||||
require.Equal(t, uint(1), resp.refCount)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
}
|
||||
|
||||
req.model.ModelPath = "dummy_model_path"
|
||||
@@ -90,7 +109,9 @@ func TestLoad(t *testing.T) {
|
||||
case resp := <-req.successCh:
|
||||
t.Errorf("unexpected success %v", resp)
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded["dummy_model_path"]
|
||||
s.loadedMu.Unlock()
|
||||
require.NotNil(t, runner)
|
||||
require.Equal(t, uint(0), runner.refCount)
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
@@ -143,6 +164,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
||||
scenario.req = &LlmRequest{
|
||||
ctx: scenario.ctx,
|
||||
model: model,
|
||||
opts: api.DefaultOptions(),
|
||||
sessionDuration: 5 * time.Millisecond,
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
@@ -171,7 +193,9 @@ func TestRequests(t *testing.T) {
|
||||
// Multiple loaded models
|
||||
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
|
||||
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
|
||||
scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
||||
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
|
||||
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
||||
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = func() gpu.GpuInfoList {
|
||||
@@ -240,7 +264,9 @@ func TestRequests(t *testing.T) {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
loadedMax = 0
|
||||
s.newServerFn = scenario3b.newServer
|
||||
@@ -254,19 +280,14 @@ func TestRequests(t *testing.T) {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Try to load a model that wont fit
|
||||
// This is a CPU load with NumGPU = 0 so it should load
|
||||
s.newServerFn = scenario3c.newServer
|
||||
slog.Info("scenario3c")
|
||||
require.Len(t, s.loaded, 2)
|
||||
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
s.pendingReqCh <- scenario3c.req
|
||||
// finish prior request, so new model can load
|
||||
time.Sleep(6 * time.Millisecond)
|
||||
require.Len(t, s.loaded, 1)
|
||||
scenario3b.ctxDone()
|
||||
select {
|
||||
case resp := <-scenario3c.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario3c.srv)
|
||||
@@ -275,7 +296,36 @@ func TestRequests(t *testing.T) {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 3)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Try to load a model that won't fit
|
||||
s.newServerFn = scenario3d.newServer
|
||||
slog.Info("scenario3d")
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 3)
|
||||
s.loadedMu.Unlock()
|
||||
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
s.pendingReqCh <- scenario3d.req
|
||||
// finish prior request, so new model can load
|
||||
time.Sleep(6 * time.Millisecond)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
s.loadedMu.Unlock()
|
||||
scenario3b.ctxDone()
|
||||
select {
|
||||
case resp := <-scenario3d.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario3d.srv)
|
||||
require.Len(t, s.pendingReqCh, 0)
|
||||
require.Len(t, scenario3d.req.errCh, 0)
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
s.loadedMu.Unlock()
|
||||
}
|
||||
|
||||
func TestGetRunner(t *testing.T) {
|
||||
@@ -318,7 +368,9 @@ func TestGetRunner(t *testing.T) {
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
scenario1a.ctxDone()
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
scenario1c.req.model.ModelPath = "bad path"
|
||||
slog.Info("scenario1c")
|
||||
@@ -328,7 +380,9 @@ func TestGetRunner(t *testing.T) {
|
||||
require.Len(t, errCh1c, 0)
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 0)
|
||||
s.loadedMu.Unlock()
|
||||
require.Len(t, errCh1c, 1)
|
||||
err = <-errCh1c
|
||||
require.Contains(t, err.Error(), "bad path")
|
||||
@@ -358,7 +412,9 @@ func TestPrematureExpired(t *testing.T) {
|
||||
require.Equal(t, resp.llama, scenario1a.srv)
|
||||
require.Len(t, s.pendingReqCh, 0)
|
||||
require.Len(t, errCh1a, 0)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
slog.Info("sending premature expired event now")
|
||||
s.expiredCh <- resp // Shouldn't happen in real life, but make sure it's safe
|
||||
case <-ctx.Done():
|
||||
@@ -383,6 +439,7 @@ func TestUseLoadedRunner(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
req := &LlmRequest{
|
||||
ctx: ctx,
|
||||
opts: api.DefaultOptions(),
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
sessionDuration: 2,
|
||||
}
|
||||
@@ -426,8 +483,10 @@ func TestUpdateFreeSpace(t *testing.T) {
|
||||
r2 := &runnerRef{llama: llm2, gpus: gpus}
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["a"] = r1
|
||||
s.loaded["b"] = r2
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
s.updateFreeSpace(gpus)
|
||||
require.Equal(t, uint64(850), gpus[0].FreeMemory)
|
||||
@@ -437,13 +496,18 @@ func TestUpdateFreeSpace(t *testing.T) {
|
||||
func TestFindRunnerToUnload(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
defer done()
|
||||
req := &LlmRequest{ctx: ctx}
|
||||
req := &LlmRequest{
|
||||
ctx: ctx,
|
||||
opts: api.DefaultOptions(),
|
||||
}
|
||||
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
|
||||
r2 := &runnerRef{sessionDuration: 2}
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["a"] = r1
|
||||
s.loaded["b"] = r2
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
resp := s.findRunnerToUnload(req)
|
||||
require.Equal(t, r2, resp)
|
||||
@@ -458,10 +522,11 @@ func TestNeedsReload(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
llm := &mockLlm{}
|
||||
do := api.DefaultOptions()
|
||||
runner := &runnerRef{
|
||||
adapters: []string{"adapter1"},
|
||||
projectors: []string{"projector1"},
|
||||
Options: &api.Options{},
|
||||
Options: &do,
|
||||
llama: llm,
|
||||
}
|
||||
req := &LlmRequest{
|
||||
@@ -469,7 +534,7 @@ func TestNeedsReload(t *testing.T) {
|
||||
AdapterPaths: []string{"adapter2"},
|
||||
ProjectorPaths: []string{"projector2"},
|
||||
},
|
||||
opts: api.Options{},
|
||||
opts: api.DefaultOptions(),
|
||||
}
|
||||
resp := runner.needsReload(ctx, req)
|
||||
require.True(t, resp)
|
||||
@@ -508,8 +573,10 @@ func TestUnloadAllRunners(t *testing.T) {
|
||||
r1 := &runnerRef{llama: llm1}
|
||||
r2 := &runnerRef{llama: llm2}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["a"] = r1
|
||||
s.loaded["b"] = r2
|
||||
s.loadedMu.Unlock()
|
||||
s.unloadAllRunners()
|
||||
|
||||
require.True(t, llm1.closeCalled)
|
||||
|
||||
18
types/errtypes/errtypes.go
Normal file
18
types/errtypes/errtypes.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Package errtypes contains custom error types
|
||||
package errtypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// UnknownOllamaKeyErrMsg is the fixed phrase embedded in every
// UnknownOllamaKey error message.
const UnknownOllamaKeyErrMsg = "unknown ollama key"

// UnknownOllamaKey is an error that carries the key it reports as unknown.
//
// TODO: This should have a structured response from the API
type UnknownOllamaKey struct {
	Key string
}

// Error formats the message as "unauthorized: ", the fixed phrase, and the
// key with surrounding whitespace trimmed and rendered as a quoted string.
func (e *UnknownOllamaKey) Error() string {
	key := strings.TrimSpace(e.Key)
	return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, key)
}
|
||||
@@ -4,6 +4,7 @@ package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -80,9 +81,6 @@ func (k partKind) String() string {
|
||||
//
|
||||
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
|
||||
// is valid.
|
||||
//
|
||||
// It is not directly comparable with other Names. Use [Name.Equal] and
|
||||
// [Name.MapHash] for determining equality and using as a map key.
|
||||
type Name struct {
|
||||
Host string
|
||||
Namespace string
|
||||
@@ -109,20 +107,20 @@ type Name struct {
|
||||
// { model }
|
||||
// "@" { digest }
|
||||
// host:
|
||||
// pattern: alphanum { alphanum | "-" | "_" | "." | ":" }*
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }*
|
||||
// length: [1, 350]
|
||||
// namespace:
|
||||
// pattern: alphanum { alphanum | "-" | "_" }*
|
||||
// length: [2, 80]
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
|
||||
// length: [1, 80]
|
||||
// model:
|
||||
// pattern: alphanum { alphanum | "-" | "_" | "." }*
|
||||
// length: [2, 80]
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// tag:
|
||||
// pattern: alphanum { alphanum | "-" | "_" | "." }*
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// digest:
|
||||
// pattern: alphanum { alphanum | "-" | ":" }*
|
||||
// length: [2, 80]
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
|
||||
// length: [1, 80]
|
||||
//
|
||||
// Most users should use [ParseName] instead, unless they need to support
|
||||
// different defaults than DefaultName.
|
||||
@@ -234,12 +232,12 @@ func (n Name) Filepath() string {
|
||||
if !n.IsFullyQualified() {
|
||||
panic("illegal attempt to get filepath of invalid name")
|
||||
}
|
||||
return filepath.Join(
|
||||
strings.ToLower(n.Host),
|
||||
strings.ToLower(n.Namespace),
|
||||
strings.ToLower(n.Model),
|
||||
strings.ToLower(n.Tag),
|
||||
)
|
||||
return strings.ToLower(filepath.Join(
|
||||
n.Host,
|
||||
n.Namespace,
|
||||
n.Model,
|
||||
n.Tag,
|
||||
))
|
||||
}
|
||||
|
||||
// LogValue returns a slog.Value that represents the name as a string.
|
||||
@@ -254,7 +252,7 @@ func isValidLen(kind partKind, s string) bool {
|
||||
case kindTag:
|
||||
return len(s) >= 1 && len(s) <= 80
|
||||
default:
|
||||
return len(s) >= 2 && len(s) <= 80
|
||||
return len(s) >= 1 && len(s) <= 80
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +262,7 @@ func isValidPart(kind partKind, s string) bool {
|
||||
}
|
||||
for i := range s {
|
||||
if i == 0 {
|
||||
if !isAlphanumeric(s[i]) {
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
@@ -280,7 +278,7 @@ func isValidPart(kind partKind, s string) bool {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if !isAlphanumeric(s[i]) {
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -288,8 +286,8 @@ func isValidPart(kind partKind, s string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func isAlphanumeric(c byte) bool {
|
||||
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9'
|
||||
func isAlphanumericOrUnderscore(c byte) bool {
|
||||
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
|
||||
}
|
||||
|
||||
func cutLast(s, sep string) (before, after string, ok bool) {
|
||||
@@ -311,3 +309,57 @@ func cutPromised(s, sep string) (before, after string, ok bool) {
|
||||
}
|
||||
return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
|
||||
}
|
||||
|
||||
// DigestType identifies the hash algorithm used by a [Digest].
type DigestType byte

const (
	// DigestTypeInvalid is the zero value and marks a Digest that has not
	// been successfully parsed.
	DigestTypeInvalid DigestType = iota

	// DigestTypeSHA256 marks a Digest whose Sum was parsed from a "sha256"
	// digest string.
	DigestTypeSHA256
)

// String returns "sha256" for DigestTypeSHA256 and "invalid" for every other
// value.
func (t DigestType) String() string {
	switch t {
	case DigestTypeSHA256:
		return "sha256"
	default:
		return "invalid"
	}
}

// Digest is a checksum together with the type of hash that produced it.
// The zero value has Type DigestTypeInvalid and reports IsValid() == false.
type Digest struct {
	Type DigestType
	Sum  [32]byte
}

// ParseDigest parses s as a digest of the form "sha256-<hex>" or
// "sha256:<hex>". The string is split at the first '-' or ':'; the part
// before it must be "sha256" and the part after it must be hex that decodes
// to exactly 32 bytes.
//
// On any failure it returns the zero Digest and a non-nil error.
func ParseDigest(s string) (Digest, error) {
	i := strings.IndexAny(s, "-:")
	if i < 0 {
		return Digest{}, fmt.Errorf("invalid digest %q", s)
	}
	typ, encSum := s[:i], s[i+1:]
	if typ != "sha256" {
		return Digest{}, fmt.Errorf("unsupported digest type %q", typ)
	}
	d := Digest{
		Type: DigestTypeSHA256,
	}
	// hex.Decode writes DecodedLen(len(src)) bytes and indexes out of range
	// (panics) when dst is smaller than that, so reject oversized input
	// before decoding rather than after.
	if maxLen := hex.EncodedLen(len(d.Sum)); len(encSum) > maxLen {
		return Digest{}, fmt.Errorf("digest %q has %d hex characters; want %d", encSum, len(encSum), maxLen)
	}
	n, err := hex.Decode(d.Sum[:], []byte(encSum))
	if err != nil {
		return Digest{}, err
	}
	if n != 32 {
		return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n)
	}
	return d, nil
}

// String returns the digest as "sha256-" followed by the lowercase hex
// encoding of Sum, or the empty string when the digest's Type is
// DigestTypeInvalid.
func (d Digest) String() string {
	if d.Type == DigestTypeInvalid {
		return ""
	}
	return fmt.Sprintf("sha256-%x", d.Sum)
}

// IsValid reports whether the digest's Type is anything other than
// DigestTypeInvalid.
func (d Digest) IsValid() bool {
	return d.Type != DigestTypeInvalid
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -101,6 +102,13 @@ func TestParseNameParts(t *testing.T) {
|
||||
}
|
||||
|
||||
var testCases = map[string]bool{ // name -> valid
|
||||
"": false,
|
||||
|
||||
"_why/_the/_lucky:_stiff": true,
|
||||
|
||||
// minimal
|
||||
"h/n/m:t@d": true,
|
||||
|
||||
"host/namespace/model:tag": true,
|
||||
"host/namespace/model": false,
|
||||
"namespace/model": false,
|
||||
@@ -116,11 +124,12 @@ var testCases = map[string]bool{ // name -> valid
|
||||
"h/nn/mm:t@sha256-1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
|
||||
"h/nn/mm:t@sha256:1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
|
||||
|
||||
"m": false, // model too short
|
||||
"n/mm:": false, // namespace too short
|
||||
"h/n/mm:t": false, // namespace too short
|
||||
"@t": false, // digest too short
|
||||
"mm@d": false, // digest too short
|
||||
// unqualified
|
||||
"m": false,
|
||||
"n/m:": false,
|
||||
"h/n/m": false,
|
||||
"@t": false,
|
||||
"m@d": false,
|
||||
|
||||
// invalids
|
||||
"^": false,
|
||||
@@ -140,8 +149,6 @@ var testCases = map[string]bool{ // name -> valid
|
||||
"hh/nn/mm:-tt@dd": false,
|
||||
"hh/nn/mm:tt@-dd": false,
|
||||
|
||||
"": false,
|
||||
|
||||
// hosts
|
||||
"host:https/namespace/model:tag": true,
|
||||
|
||||
@@ -163,7 +170,6 @@ func TestNameIsValid(t *testing.T) {
|
||||
var numStringTests int
|
||||
for s, want := range testCases {
|
||||
n := ParseNameBare(s)
|
||||
t.Logf("n: %#v", n)
|
||||
got := n.IsValid()
|
||||
if got != want {
|
||||
t.Errorf("parseName(%q).IsValid() = %v; want %v", s, got, want)
|
||||
@@ -212,6 +218,54 @@ func TestNameIsValidPart(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestFilepathAllocs(t *testing.T) {
|
||||
n := ParseNameBare("HOST/NAMESPACE/MODEL:TAG")
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
n.Filepath()
|
||||
})
|
||||
allowedAllocs := 2.0
|
||||
if runtime.GOOS == "windows" {
|
||||
allowedAllocs = 4
|
||||
}
|
||||
if allocs > allowedAllocs {
|
||||
t.Errorf("allocs = %v; allowed %v", allocs, allowedAllocs)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
validSha256 = "sha256-1000000000000000000000000000000000000000000000000000000000000000"
|
||||
validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000"
|
||||
)
|
||||
|
||||
func TestParseDigest(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", ""}, // empty
|
||||
{"sha123-12", ""}, // invalid type
|
||||
{"sha256-", ""}, // invalid sum
|
||||
{"sha256-123", ""}, // invalid odd length sum
|
||||
|
||||
{validSha256, validSha256},
|
||||
{validSha256Old, validSha256},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
got, err := ParseDigest(tt.in)
|
||||
if err != nil {
|
||||
if tt.want != "" {
|
||||
t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got.String() != tt.want {
|
||||
t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzName(f *testing.F) {
|
||||
for s := range testCases {
|
||||
f.Add(s)
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package structs contains the Incomparable type.
|
||||
package structs
|
||||
|
||||
// Incomparable is a zero-width incomparable type. If added as the
|
||||
// first field in a struct, it marks that struct as not comparable
|
||||
// (can't do == or be a map key) and usually doesn't add any width to
|
||||
// the struct (unless the struct has only small fields).
|
||||
//
|
||||
// By making a struct incomparable, you can prevent misuse (prevent
|
||||
// people from using ==), but also you can shrink generated binaries,
|
||||
// as the compiler can omit equality funcs from the binary.
|
||||
type Incomparable [0]func()
|
||||
Reference in New Issue
Block a user