mirror of
https://github.com/ollama/ollama.git
synced 2025-12-24 08:10:54 -05:00
Compare commits
68 Commits
native
...
mxyng/spli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc474f9b83 | ||
|
|
41ae232e10 | ||
|
|
122b35c784 | ||
|
|
3244a25c79 | ||
|
|
b535afe35c | ||
|
|
fd071eab8b | ||
|
|
da0bb5d772 | ||
|
|
1909e624ce | ||
|
|
1d8c850f38 | ||
|
|
e9ae607ece | ||
|
|
93707fa3f2 | ||
|
|
94c369095f | ||
|
|
9164b0161b | ||
|
|
bf4fc25f7b | ||
|
|
5b806d8d24 | ||
|
|
cb1e072643 | ||
|
|
45b6a12e45 | ||
|
|
68755f1f5e | ||
|
|
997a455039 | ||
|
|
88775e1ff9 | ||
|
|
8867e744ff | ||
|
|
4fd064bea6 | ||
|
|
59fbceedcc | ||
|
|
321d57e1a0 | ||
|
|
ba26c7aa00 | ||
|
|
63c763685f | ||
|
|
34a4a94f13 | ||
|
|
f4a73d57a4 | ||
|
|
948114e3e3 | ||
|
|
a3e60d9058 | ||
|
|
5ea844964e | ||
|
|
bd8eed57fc | ||
|
|
9cf0f2e973 | ||
|
|
176ad3aa6e | ||
|
|
4d08363580 | ||
|
|
8907bf51d2 | ||
|
|
abe614c705 | ||
|
|
238715037d | ||
|
|
c0a00f68ae | ||
|
|
f0c454ab57 | ||
|
|
b9f74ff3d6 | ||
|
|
fcf4d60eee | ||
|
|
e33d5c2dbc | ||
|
|
18d9a7e1f1 | ||
|
|
8488388cbd | ||
|
|
588901f449 | ||
|
|
0a7fdbe533 | ||
|
|
5950c176ca | ||
|
|
23d23409a0 | ||
|
|
9009bedf13 | ||
|
|
d4ac57e240 | ||
|
|
7b59d1770f | ||
|
|
95ead8ffba | ||
|
|
7aa08a77ca | ||
|
|
7e432cdfac | ||
|
|
586672f490 | ||
|
|
b03408de74 | ||
|
|
1e6a28bf5b | ||
|
|
d6e3b64582 | ||
|
|
114c932a8e | ||
|
|
7f7103de06 | ||
|
|
c631a9c726 | ||
|
|
8fd9e56804 | ||
|
|
8a65717f55 | ||
|
|
6d3152a98a | ||
|
|
b438d485f1 | ||
|
|
204349b17b | ||
|
|
86e67fc4a9 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,4 +11,5 @@ ggml-metal.metal
|
||||
.idea
|
||||
test_data
|
||||
*.crt
|
||||
llm/build
|
||||
llm/build
|
||||
__debug_bin*
|
||||
10
README.md
10
README.md
@@ -1,5 +1,5 @@
|
||||
<div align="center">
|
||||
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||
</div>
|
||||
|
||||
# Ollama
|
||||
@@ -51,7 +51,7 @@ Here are some example models that can be downloaded:
|
||||
| ------------------ | ---------- | ----- | ------------------------------ |
|
||||
| Llama 3 | 8B | 4.7GB | `ollama run llama3` |
|
||||
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
|
||||
| Phi-3 | 3,8B | 2.3GB | `ollama run phi3` |
|
||||
| Phi-3 | 3.8B | 2.3GB | `ollama run phi3` |
|
||||
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
||||
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
||||
| Starling | 7B | 4.1GB | `ollama run starling-lm` |
|
||||
@@ -173,7 +173,7 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
|
||||
The image features a yellow smiley face, which is likely the central focus of the picture.
|
||||
```
|
||||
|
||||
### Pass in prompt as arguments
|
||||
### Pass the prompt as an argument
|
||||
|
||||
```
|
||||
$ ollama run llama3 "Summarize this file: $(cat README.md)"
|
||||
@@ -294,7 +294,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
|
||||
- [chat: chat web app for teams](https://github.com/swuecho/chat)
|
||||
- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
|
||||
- [Ollama RAG Chatbot: Local Chat with multiples PDFs using Ollama and RAG.](https://github.com/datvodinh/rag-chatbot.git)
|
||||
- [Ollama RAG Chatbot: Local Chat with multiple PDFs using Ollama and RAG.](https://github.com/datvodinh/rag-chatbot.git)
|
||||
|
||||
### Terminal
|
||||
|
||||
@@ -384,4 +384,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
|
||||
|
||||
### Supported backends
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
@@ -57,12 +58,36 @@ func checkError(resp *http.Response, body []byte) error {
|
||||
// If the variable is not specified, a default ollama host and port will be
|
||||
// used.
|
||||
func ClientFromEnvironment() (*Client, error) {
|
||||
ollamaHost, err := GetOllamaHost()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Client{
|
||||
base: &url.URL{
|
||||
Scheme: ollamaHost.Scheme,
|
||||
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
||||
},
|
||||
http: http.DefaultClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type OllamaHost struct {
|
||||
Scheme string
|
||||
Host string
|
||||
Port string
|
||||
}
|
||||
|
||||
func GetOllamaHost() (OllamaHost, error) {
|
||||
defaultPort := "11434"
|
||||
|
||||
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
|
||||
hostVar := os.Getenv("OLLAMA_HOST")
|
||||
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
|
||||
|
||||
scheme, hostport, ok := strings.Cut(hostVar, "://")
|
||||
switch {
|
||||
case !ok:
|
||||
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
|
||||
scheme, hostport = "http", hostVar
|
||||
case scheme == "http":
|
||||
defaultPort = "80"
|
||||
case scheme == "https":
|
||||
@@ -82,12 +107,14 @@ func ClientFromEnvironment() (*Client, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
base: &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: net.JoinHostPort(host, port),
|
||||
},
|
||||
http: http.DefaultClient,
|
||||
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
|
||||
return OllamaHost{}, ErrInvalidHostPort
|
||||
}
|
||||
|
||||
return OllamaHost{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
Port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package api
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClientFromEnvironment(t *testing.T) {
|
||||
type testCase struct {
|
||||
@@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
hostTestCases := map[string]*testCase{
|
||||
"empty": {value: "", expect: "127.0.0.1:11434"},
|
||||
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
|
||||
"only port": {value: ":1234", expect: ":1234"},
|
||||
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
|
||||
"hostname": {value: "example.com", expect: "example.com:11434"},
|
||||
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
|
||||
"zero port": {value: ":0", expect: ":0"},
|
||||
"too large port": {value: ":66000", err: ErrInvalidHostPort},
|
||||
"too small port": {value: ":-1", err: ErrInvalidHostPort},
|
||||
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
|
||||
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
|
||||
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
|
||||
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
|
||||
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
|
||||
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
|
||||
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
|
||||
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
|
||||
}
|
||||
|
||||
for k, v := range hostTestCases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", v.value)
|
||||
|
||||
oh, err := GetOllamaHost()
|
||||
if err != v.err {
|
||||
t.Fatalf("expected %s, got %s", v.err, err)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
host := net.JoinHostPort(oh.Host, oh.Port)
|
||||
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -309,6 +309,7 @@ func (m *Metrics) Summary() {
|
||||
}
|
||||
|
||||
var ErrInvalidOpts = errors.New("invalid options")
|
||||
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
|
||||
@@ -43,37 +43,36 @@ func getCLIFullPath(command string) string {
|
||||
return command
|
||||
}
|
||||
|
||||
func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
done := make(chan int)
|
||||
|
||||
logDir := filepath.Dir(ServerLogFile)
|
||||
_, err := os.Stat(logDir)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
func start(ctx context.Context, command string) (*exec.Cmd, error) {
|
||||
cmd := getCmd(ctx, getCLIFullPath(command))
|
||||
// send stdout and stderr to a file
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stdout pipe %s", err)
|
||||
return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stderr pipe %s", err)
|
||||
}
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
|
||||
return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// TODO - rotation
|
||||
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to create server log %w", err)
|
||||
return nil, fmt.Errorf("failed to create server log: %w", err)
|
||||
}
|
||||
|
||||
logDir := filepath.Dir(ServerLogFile)
|
||||
_, err = os.Stat(logDir)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
|
||||
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer logFile.Close()
|
||||
io.Copy(logFile, stdout) //nolint:errcheck
|
||||
@@ -117,19 +116,33 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
|
||||
// run the command and wait for it to finish
|
||||
if err := cmd.Start(); err != nil {
|
||||
return done, fmt.Errorf("failed to start server %w", err)
|
||||
return nil, fmt.Errorf("failed to start server %w", err)
|
||||
}
|
||||
if cmd.Process != nil {
|
||||
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
|
||||
}
|
||||
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
done := make(chan int)
|
||||
|
||||
go func() {
|
||||
// Keep the server running unless we're shuttind down the app
|
||||
crashCount := 0
|
||||
for {
|
||||
slog.Info("starting server...")
|
||||
cmd, err := start(ctx, command)
|
||||
if err != nil {
|
||||
crashCount++
|
||||
slog.Error(fmt.Sprintf("failed to start server %s", err))
|
||||
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
|
||||
continue
|
||||
}
|
||||
|
||||
cmd.Wait() //nolint:errcheck
|
||||
stdin.Close()
|
||||
var code int
|
||||
if cmd.ProcessState != nil {
|
||||
code = cmd.ProcessState.ExitCode()
|
||||
@@ -143,15 +156,12 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
default:
|
||||
crashCount++
|
||||
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := cmd.Start(); err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to restart server %s", err))
|
||||
// Keep trying, but back off if we keep failing
|
||||
time.Sleep(time.Duration(crashCount) * time.Second)
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return done, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -88,8 +88,8 @@ DialogFontSize=12
|
||||
[Files]
|
||||
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
|
||||
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-amd64\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-amd64\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
|
||||
Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
|
||||
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
|
||||
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
|
||||
#if DirExists("..\dist\windows-amd64\rocm")
|
||||
|
||||
36
auth/auth.go
36
auth/auth.go
@@ -10,12 +10,44 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
func keyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||
|
||||
return strings.TrimSpace(string(publicKey)), nil
|
||||
}
|
||||
|
||||
func NewNonce(r io.Reader, length int) (string, error) {
|
||||
nonce := make([]byte, length)
|
||||
if _, err := io.ReadFull(r, nonce); err != nil {
|
||||
@@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||
}
|
||||
|
||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
|
||||
93
cmd/cmd.go
93
cmd/cmd.go
@@ -32,10 +32,13 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -54,12 +57,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
modelfile, err := os.ReadFile(filename)
|
||||
modelfile, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer modelfile.Close()
|
||||
|
||||
commands, err := parser.Parse(bytes.NewReader(modelfile))
|
||||
commands, err := parser.Parse(modelfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -73,10 +77,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
|
||||
for _, c := range commands {
|
||||
switch c.Name {
|
||||
for i := range commands {
|
||||
switch commands[i].Name {
|
||||
case "model", "adapter":
|
||||
path := c.Args
|
||||
path := commands[i].Args
|
||||
if path == "~" {
|
||||
path = home
|
||||
} else if strings.HasPrefix(path, "~/") {
|
||||
@@ -88,7 +92,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
|
||||
if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return err
|
||||
@@ -111,13 +115,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
name := c.Name
|
||||
if c.Name == "model" {
|
||||
name = "from"
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
|
||||
modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
|
||||
commands[i].Args = "@"+digest
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,7 +145,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
quantization, _ := cmd.Flags().GetString("quantization")
|
||||
|
||||
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
|
||||
request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization}
|
||||
if err := client.Create(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -357,6 +355,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
|
||||
func errFromUnknownKey(unknownKeyErr error) error {
|
||||
// find SSH public key in the error message
|
||||
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||
re := regexp.MustCompile(sshKeyPattern)
|
||||
matches := re.FindStringSubmatch(unknownKeyErr.Error())
|
||||
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
localPubKey = strings.TrimSpace(string(svcPubKey))
|
||||
}
|
||||
|
||||
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
|
||||
if serverPubKey != localPubKey {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
var msg strings.Builder
|
||||
msg.WriteString(unknownKeyErr.Error())
|
||||
msg.WriteString("\n\nYour ollama key is:\n")
|
||||
msg.WriteString(localPubKey)
|
||||
msg.WriteString("\nAdd your key at:\n")
|
||||
msg.WriteString("https://ollama.com/settings/keys")
|
||||
|
||||
return errors.New(msg.String())
|
||||
}
|
||||
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -404,6 +443,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
host := model.ParseName(args[0]).Host
|
||||
isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
|
||||
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// re-throw an error with a more user-friendly message
|
||||
return errFromUnknownKey(err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -831,19 +884,17 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
|
||||
// retrieve the OLLAMA_HOST environment variable
|
||||
ollamaHost, err := api.GetOllamaHost()
|
||||
if err != nil {
|
||||
host, port = "127.0.0.1", "11434"
|
||||
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
|
||||
host = ip.String()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := initializeKeypair(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1069,7 +1120,7 @@ Environment Variables:
|
||||
RunE: ListHandler,
|
||||
}
|
||||
copyCmd := &cobra.Command{
|
||||
Use: "cp SOURCE TARGET",
|
||||
Use: "cp SOURCE DESTINATION",
|
||||
Short: "Copy a model",
|
||||
Args: cobra.ExactArgs(2),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
|
||||
@@ -94,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
||||
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
@@ -280,6 +281,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
fmt.Printf("Created new model '%s'\n", args[1])
|
||||
continue
|
||||
case strings.HasPrefix(line, "/clear"):
|
||||
opts.Messages = []api.Message{}
|
||||
fmt.Println("Cleared session context")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/set"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -47,7 +48,7 @@ type ByteOrder interface {
|
||||
type ModelArch interface {
|
||||
GetTensors() error
|
||||
LoadVocab() error
|
||||
WriteGGUF() (string, error)
|
||||
WriteGGUF(io.WriteSeeker) error
|
||||
}
|
||||
|
||||
type ModelFormat interface {
|
||||
|
||||
@@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *GemmaModel) WriteGGUF() (string, error) {
|
||||
func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "gemma",
|
||||
"general.name": m.Name,
|
||||
@@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) {
|
||||
"tokenizer.ggml.add_eos_token": false,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func (m *LlamaModel) LoadVocab() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *LlamaModel) WriteGGUF() (string, error) {
|
||||
func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"general.name": m.Name,
|
||||
@@ -161,16 +161,9 @@ func (m *LlamaModel) WriteGGUF() (string, error) {
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
|
||||
|
||||
return f.Name(), nil
|
||||
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(f, kv, m.Tensors)
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MistralModel) WriteGGUF() (string, error) {
|
||||
func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"general.name": m.Name,
|
||||
@@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
|
||||
"tokenizer.ggml.unknown_token_id": uint32(0),
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"os"
|
||||
"io"
|
||||
"regexp"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
@@ -47,7 +47,7 @@ func (m *MixtralModel) LoadVocab() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MixtralModel) WriteGGUF() (string, error) {
|
||||
func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"general.name": m.Name,
|
||||
@@ -81,16 +81,5 @@ func (m *MixtralModel) WriteGGUF() (string, error) {
|
||||
"tokenizer.ggml.add_eos_token": false,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ Typically the build scripts will auto-detect CUDA, however, if your Linux distro
|
||||
or installation approach uses unusual paths, you can specify the location by
|
||||
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
|
||||
libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
|
||||
set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
|
||||
a set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
|
||||
|
||||
Then generate dependencies:
|
||||
|
||||
@@ -142,4 +142,4 @@ In addition to the common Windows development tools described above, install AMD
|
||||
- [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
|
||||
- [Strawberry Perl](https://strawberryperl.com/)
|
||||
|
||||
Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).
|
||||
Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).
|
||||
|
||||
@@ -17,10 +17,12 @@ Let's start by asking a simple question that we can get an answer to from the **
|
||||
Then we can create a model and ask the question:
|
||||
|
||||
```python
|
||||
from langchain.llms import Ollama
|
||||
ollama = Ollama(base_url='http://localhost:11434',
|
||||
model="llama2")
|
||||
print(ollama("why is the sky blue"))
|
||||
from langchain_community.llms import Ollama
|
||||
ollama = Ollama(
|
||||
base_url='http://localhost:11434',
|
||||
model="llama3"
|
||||
)
|
||||
print(ollama.invoke("why is the sky blue"))
|
||||
```
|
||||
|
||||
Notice that we are defining the model and the base URL for Ollama.
|
||||
|
||||
@@ -1,47 +1,47 @@
|
||||
# Ollama Windows Preview
|
||||
|
||||
Welcome to the Ollama Windows preview.
|
||||
|
||||
No more WSL required!
|
||||
|
||||
Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
|
||||
After installing Ollama Windows Preview, Ollama will run in the background and
|
||||
the `ollama` command line is available in `cmd`, `powershell` or your favorite
|
||||
terminal application. As usual the Ollama [api](./api.md) will be served on
|
||||
`http://localhost:11434`.
|
||||
|
||||
As this is a preview release, you should expect a few bugs here and there. If
|
||||
you run into a problem you can reach out on
|
||||
[Discord](https://discord.gg/ollama), or file an
|
||||
[issue](https://github.com/ollama/ollama/issues).
|
||||
Logs will often be helpful in diagnosing the problem (see
|
||||
[Troubleshooting](#troubleshooting) below)
|
||||
|
||||
## System Requirements
|
||||
|
||||
* Windows 10 or newer, Home or Pro
|
||||
* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
|
||||
* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
|
||||
|
||||
## API Access
|
||||
|
||||
Here's a quick example showing API access from `powershell`
|
||||
```powershell
|
||||
(Invoke-WebRequest -method POST -Body '{"model":"llama2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
|
||||
a "view logs" menu item to the app, and increses logging for the GUI app and
|
||||
server.
|
||||
|
||||
Ollama on Windows stores files in a few different locations. You can view them in
|
||||
the explorer window by hitting `<cmd>+R` and type in:
|
||||
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
|
||||
- *app.log* contains logs from the GUI application
|
||||
- *server.log* contains the server logs
|
||||
- *upgrade.log* contains log output for upgrades
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` contains models and configuration
|
||||
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
|
||||
# Ollama Windows Preview
|
||||
|
||||
Welcome to the Ollama Windows preview.
|
||||
|
||||
No more WSL required!
|
||||
|
||||
Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
|
||||
After installing Ollama Windows Preview, Ollama will run in the background and
|
||||
the `ollama` command line is available in `cmd`, `powershell` or your favorite
|
||||
terminal application. As usual the Ollama [api](./api.md) will be served on
|
||||
`http://localhost:11434`.
|
||||
|
||||
As this is a preview release, you should expect a few bugs here and there. If
|
||||
you run into a problem you can reach out on
|
||||
[Discord](https://discord.gg/ollama), or file an
|
||||
[issue](https://github.com/ollama/ollama/issues).
|
||||
Logs will often be helpful in diagnosing the problem (see
|
||||
[Troubleshooting](#troubleshooting) below)
|
||||
|
||||
## System Requirements
|
||||
|
||||
* Windows 10 or newer, Home or Pro
|
||||
* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
|
||||
* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
|
||||
|
||||
## API Access
|
||||
|
||||
Here's a quick example showing API access from `powershell`
|
||||
```powershell
|
||||
(Invoke-WebRequest -method POST -Body '{"model":"llama2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
|
||||
a "view logs" menu item to the app, and increses logging for the GUI app and
|
||||
server.
|
||||
|
||||
Ollama on Windows stores files in a few different locations. You can view them in
|
||||
the explorer window by hitting `<cmd>+R` and type in:
|
||||
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
|
||||
- *app.log* contains logs from the GUI application
|
||||
- *server.log* contains the server logs
|
||||
- *upgrade.log* contains log output for upgrades
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` contains models and configuration
|
||||
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
|
||||
|
||||
@@ -40,7 +40,7 @@ func PayloadsDir() (string, error) {
|
||||
}
|
||||
|
||||
var paths []string
|
||||
for _, root := range []string{appExe, cwd} {
|
||||
for _, root := range []string{filepath.Dir(appExe), cwd} {
|
||||
paths = append(paths,
|
||||
filepath.Join(root),
|
||||
filepath.Join(root, "windows-"+runtime.GOARCH),
|
||||
|
||||
@@ -10,6 +10,12 @@ package gpu
|
||||
import "C"
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
const (
|
||||
metalMinimumMemory = 512 * format.MebiByte
|
||||
)
|
||||
|
||||
func GetGPUInfo() GpuInfoList {
|
||||
@@ -32,7 +38,7 @@ func GetGPUInfo() GpuInfoList {
|
||||
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
|
||||
info.FreeMemory = info.TotalMemory
|
||||
|
||||
info.MinimumMemory = 0
|
||||
info.MinimumMemory = metalMinimumMemory
|
||||
return []GpuInfo{info}
|
||||
}
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ func startServer(ctx context.Context, ollamaHost string) error {
|
||||
|
||||
if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
|
||||
slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
|
||||
os.Setenv("OLLAMA_HOST", ollamaHost)
|
||||
t.Setenv("OLLAMA_HOST", ollamaHost)
|
||||
}
|
||||
|
||||
slog.Info("starting server", "url", ollamaHost)
|
||||
|
||||
15
llm/ext_server/server.cpp
vendored
15
llm/ext_server/server.cpp
vendored
@@ -1032,7 +1032,7 @@ struct llama_server_context
|
||||
slot.has_next_token = false;
|
||||
}
|
||||
|
||||
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
|
||||
if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok))
|
||||
{
|
||||
slot.stopped_eos = true;
|
||||
slot.has_next_token = false;
|
||||
@@ -1144,12 +1144,15 @@ struct llama_server_context
|
||||
|
||||
res.result_json = json
|
||||
{
|
||||
{"content", tkn.text_to_send},
|
||||
{"stop", false},
|
||||
{"slot_id", slot.id},
|
||||
{"multimodal", multimodal}
|
||||
};
|
||||
|
||||
if (!llama_token_is_eog(model, tkn.tok)) {
|
||||
res.result_json["content"] = tkn.text_to_send;
|
||||
}
|
||||
|
||||
if (slot.sparams.n_probs > 0)
|
||||
{
|
||||
std::vector<completion_token_output> probs_output = {};
|
||||
@@ -2644,18 +2647,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||
if (strncmp(sep, "int:", 4) == 0) {
|
||||
sep += 4;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
|
||||
kvo.int_value = std::atol(sep);
|
||||
kvo.val_i64 = std::atol(sep);
|
||||
} else if (strncmp(sep, "float:", 6) == 0) {
|
||||
sep += 6;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
|
||||
kvo.float_value = std::atof(sep);
|
||||
kvo.val_f64 = std::atof(sep);
|
||||
} else if (strncmp(sep, "bool:", 5) == 0) {
|
||||
sep += 5;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
|
||||
if (std::strcmp(sep, "true") == 0) {
|
||||
kvo.bool_value = true;
|
||||
kvo.val_bool = true;
|
||||
} else if (std::strcmp(sep, "false") == 0) {
|
||||
kvo.bool_value = false;
|
||||
kvo.val_bool = false;
|
||||
} else {
|
||||
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
|
||||
invalid_param = true;
|
||||
|
||||
140
llm/filetype.go
Normal file
140
llm/filetype.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package llm
|
||||
|
||||
import "fmt"
|
||||
|
||||
type fileType uint32
|
||||
|
||||
const (
|
||||
fileTypeF32 fileType = iota
|
||||
fileTypeF16
|
||||
fileTypeQ4_0
|
||||
fileTypeQ4_1
|
||||
fileTypeQ4_1_F16
|
||||
fileTypeQ4_2 // unused
|
||||
fileTypeQ4_3 // unused
|
||||
fileTypeQ8_0
|
||||
fileTypeQ5_0
|
||||
fileTypeQ5_1
|
||||
fileTypeQ2_K
|
||||
fileTypeQ3_K_S
|
||||
fileTypeQ3_K_M
|
||||
fileTypeQ3_K_L
|
||||
fileTypeQ4_K_S
|
||||
fileTypeQ4_K_M
|
||||
fileTypeQ5_K_S
|
||||
fileTypeQ5_K_M
|
||||
fileTypeQ6_K
|
||||
fileTypeIQ2_XXS
|
||||
fileTypeIQ2_XS
|
||||
fileTypeQ2_K_S
|
||||
fileTypeQ3_K_XS
|
||||
fileTypeIQ3_XXS
|
||||
|
||||
fileTypeUnknown
|
||||
)
|
||||
|
||||
func ParseFileType(s string) (fileType, error) {
|
||||
switch s {
|
||||
case "F32":
|
||||
return fileTypeF32, nil
|
||||
case "F16":
|
||||
return fileTypeF16, nil
|
||||
case "Q4_0":
|
||||
return fileTypeQ4_0, nil
|
||||
case "Q4_1":
|
||||
return fileTypeQ4_1, nil
|
||||
case "Q4_1_F16":
|
||||
return fileTypeQ4_1_F16, nil
|
||||
case "Q8_0":
|
||||
return fileTypeQ8_0, nil
|
||||
case "Q5_0":
|
||||
return fileTypeQ5_0, nil
|
||||
case "Q5_1":
|
||||
return fileTypeQ5_1, nil
|
||||
case "Q2_K":
|
||||
return fileTypeQ2_K, nil
|
||||
case "Q3_K_S":
|
||||
return fileTypeQ3_K_S, nil
|
||||
case "Q3_K_M":
|
||||
return fileTypeQ3_K_M, nil
|
||||
case "Q3_K_L":
|
||||
return fileTypeQ3_K_L, nil
|
||||
case "Q4_K_S":
|
||||
return fileTypeQ4_K_S, nil
|
||||
case "Q4_K_M":
|
||||
return fileTypeQ4_K_M, nil
|
||||
case "Q5_K_S":
|
||||
return fileTypeQ5_K_S, nil
|
||||
case "Q5_K_M":
|
||||
return fileTypeQ5_K_M, nil
|
||||
case "Q6_K":
|
||||
return fileTypeQ6_K, nil
|
||||
case "IQ2_XXS":
|
||||
return fileTypeIQ2_XXS, nil
|
||||
case "IQ2_XS":
|
||||
return fileTypeIQ2_XS, nil
|
||||
case "Q2_K_S":
|
||||
return fileTypeQ2_K_S, nil
|
||||
case "Q3_K_XS":
|
||||
return fileTypeQ3_K_XS, nil
|
||||
case "IQ3_XXS":
|
||||
return fileTypeIQ3_XXS, nil
|
||||
default:
|
||||
return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (t fileType) String() string {
|
||||
switch t {
|
||||
case fileTypeF32:
|
||||
return "F32"
|
||||
case fileTypeF16:
|
||||
return "F16"
|
||||
case fileTypeQ4_0:
|
||||
return "Q4_0"
|
||||
case fileTypeQ4_1:
|
||||
return "Q4_1"
|
||||
case fileTypeQ4_1_F16:
|
||||
return "Q4_1_F16"
|
||||
case fileTypeQ8_0:
|
||||
return "Q8_0"
|
||||
case fileTypeQ5_0:
|
||||
return "Q5_0"
|
||||
case fileTypeQ5_1:
|
||||
return "Q5_1"
|
||||
case fileTypeQ2_K:
|
||||
return "Q2_K"
|
||||
case fileTypeQ3_K_S:
|
||||
return "Q3_K_S"
|
||||
case fileTypeQ3_K_M:
|
||||
return "Q3_K_M"
|
||||
case fileTypeQ3_K_L:
|
||||
return "Q3_K_L"
|
||||
case fileTypeQ4_K_S:
|
||||
return "Q4_K_S"
|
||||
case fileTypeQ4_K_M:
|
||||
return "Q4_K_M"
|
||||
case fileTypeQ5_K_S:
|
||||
return "Q5_K_S"
|
||||
case fileTypeQ5_K_M:
|
||||
return "Q5_K_M"
|
||||
case fileTypeQ6_K:
|
||||
return "Q6_K"
|
||||
case fileTypeIQ2_XXS:
|
||||
return "IQ2_XXS"
|
||||
case fileTypeIQ2_XS:
|
||||
return "IQ2_XS"
|
||||
case fileTypeQ2_K_S:
|
||||
return "Q2_K_S"
|
||||
case fileTypeQ3_K_XS:
|
||||
return "Q3_K_XS"
|
||||
case fileTypeIQ3_XXS:
|
||||
return "IQ3_XXS"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t fileType) Value() uint32 {
|
||||
return uint32(t)
|
||||
}
|
||||
@@ -42,7 +42,7 @@ function init_vars {
|
||||
"-DLLAMA_NATIVE=off"
|
||||
)
|
||||
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
|
||||
$script:ARCH = "amd64" # arm not yet supported.
|
||||
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
|
||||
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
|
||||
md "$script:DIST_BASE" -ea 0 > $null
|
||||
if ($env:CGO_CFLAGS -contains "-g") {
|
||||
@@ -213,11 +213,11 @@ function build_static() {
|
||||
}
|
||||
}
|
||||
|
||||
function build_cpu() {
|
||||
function build_cpu($gen_arch) {
|
||||
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
|
||||
# remaining llama.cpp builds use MSVC
|
||||
init_vars
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||
$script:buildDir="../build/windows/${script:ARCH}/cpu"
|
||||
$script:distDir="$script:DIST_BASE\cpu"
|
||||
write-host "Building LCD CPU"
|
||||
@@ -349,11 +349,15 @@ if ($($args.count) -eq 0) {
|
||||
git_module_setup
|
||||
apply_patches
|
||||
build_static
|
||||
build_cpu
|
||||
build_cpu_avx
|
||||
build_cpu_avx2
|
||||
build_cuda
|
||||
build_rocm
|
||||
if ($script:ARCH -eq "arm64") {
|
||||
build_cpu("ARM64")
|
||||
} else { # amd64
|
||||
build_cpu("x64")
|
||||
build_cpu_avx
|
||||
build_cpu_avx2
|
||||
build_cuda
|
||||
build_rocm
|
||||
}
|
||||
|
||||
cleanup
|
||||
write-host "`ngo generate completed. LLM runners: $(get-childitem -path $script:DIST_BASE)"
|
||||
|
||||
12
llm/ggla.go
12
llm/ggla.go
@@ -33,6 +33,7 @@ func (c *containerGGLA) Decode(rs io.ReadSeeker) (model, error) {
|
||||
|
||||
type ggla struct {
|
||||
*containerGGLA
|
||||
offset int64
|
||||
|
||||
kv KV
|
||||
tensors []*Tensor
|
||||
@@ -53,6 +54,10 @@ func (llm *ggla) Tensors() Tensors {
|
||||
return llm.tensors
|
||||
}
|
||||
|
||||
func (llm *ggla) Offset() int64 {
|
||||
return llm.offset
|
||||
}
|
||||
|
||||
func (llm *ggla) decode(rs io.ReadSeeker) error {
|
||||
var r uint32
|
||||
if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
|
||||
@@ -66,6 +71,13 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
|
||||
}
|
||||
llm.kv["alpha"] = alpha
|
||||
|
||||
offset, err := rs.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
llm.offset = offset
|
||||
|
||||
for {
|
||||
var dims uint32
|
||||
if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil {
|
||||
|
||||
96
llm/ggml.go
96
llm/ggml.go
@@ -13,85 +13,10 @@ type GGML struct {
|
||||
model
|
||||
}
|
||||
|
||||
const (
|
||||
fileTypeF32 uint32 = iota
|
||||
fileTypeF16
|
||||
fileTypeQ4_0
|
||||
fileTypeQ4_1
|
||||
fileTypeQ4_1_F16
|
||||
fileTypeQ8_0 uint32 = iota + 2
|
||||
fileTypeQ5_0
|
||||
fileTypeQ5_1
|
||||
fileTypeQ2_K
|
||||
fileTypeQ3_K_S
|
||||
fileTypeQ3_K_M
|
||||
fileTypeQ3_K_L
|
||||
fileTypeQ4_K_S
|
||||
fileTypeQ4_K_M
|
||||
fileTypeQ5_K_S
|
||||
fileTypeQ5_K_M
|
||||
fileTypeQ6_K
|
||||
fileTypeIQ2_XXS
|
||||
fileTypeIQ2_XS
|
||||
fileTypeQ2_K_S
|
||||
fileTypeQ3_K_XS
|
||||
fileTypeIQ3_XXS
|
||||
)
|
||||
|
||||
func fileType(fileType uint32) string {
|
||||
switch fileType {
|
||||
case fileTypeF32:
|
||||
return "F32"
|
||||
case fileTypeF16:
|
||||
return "F16"
|
||||
case fileTypeQ4_0:
|
||||
return "Q4_0"
|
||||
case fileTypeQ4_1:
|
||||
return "Q4_1"
|
||||
case fileTypeQ4_1_F16:
|
||||
return "Q4_1_F16"
|
||||
case fileTypeQ8_0:
|
||||
return "Q8_0"
|
||||
case fileTypeQ5_0:
|
||||
return "Q5_0"
|
||||
case fileTypeQ5_1:
|
||||
return "Q5_1"
|
||||
case fileTypeQ2_K:
|
||||
return "Q2_K"
|
||||
case fileTypeQ3_K_S:
|
||||
return "Q3_K_S"
|
||||
case fileTypeQ3_K_M:
|
||||
return "Q3_K_M"
|
||||
case fileTypeQ3_K_L:
|
||||
return "Q3_K_L"
|
||||
case fileTypeQ4_K_S:
|
||||
return "Q4_K_S"
|
||||
case fileTypeQ4_K_M:
|
||||
return "Q4_K_M"
|
||||
case fileTypeQ5_K_S:
|
||||
return "Q5_K_S"
|
||||
case fileTypeQ5_K_M:
|
||||
return "Q5_K_M"
|
||||
case fileTypeQ6_K:
|
||||
return "Q6_K"
|
||||
case fileTypeIQ2_XXS:
|
||||
return "IQ2_XXS"
|
||||
case fileTypeIQ2_XS:
|
||||
return "IQ2_XS"
|
||||
case fileTypeQ2_K_S:
|
||||
return "Q2_K_S"
|
||||
case fileTypeQ3_K_XS:
|
||||
return "Q3_K_XS"
|
||||
case fileTypeIQ3_XXS:
|
||||
return "IQ3_XXS"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type model interface {
|
||||
KV() KV
|
||||
Tensors() Tensors
|
||||
Offset() int64
|
||||
}
|
||||
|
||||
type KV map[string]any
|
||||
@@ -123,7 +48,7 @@ func (kv KV) ParameterCount() uint64 {
|
||||
|
||||
func (kv KV) FileType() string {
|
||||
if u64 := kv.u64("general.file_type"); u64 > 0 {
|
||||
return fileType(uint32(u64))
|
||||
return fileType(uint32(u64)).String()
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
@@ -286,6 +211,23 @@ const (
|
||||
|
||||
var ErrUnsupportedFormat = errors.New("unsupported model format")
|
||||
|
||||
func DetectGGMLType(b []byte) string {
|
||||
switch binary.LittleEndian.Uint32(b[:4]) {
|
||||
case FILE_MAGIC_GGML:
|
||||
return "ggml"
|
||||
case FILE_MAGIC_GGMF:
|
||||
return "ggmf"
|
||||
case FILE_MAGIC_GGJT:
|
||||
return "ggjt"
|
||||
case FILE_MAGIC_GGLA:
|
||||
return "ggla"
|
||||
case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
|
||||
return "gguf"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
|
||||
var magic uint32
|
||||
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
|
||||
|
||||
11
llm/gguf.go
11
llm/gguf.go
@@ -55,7 +55,7 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
|
||||
|
||||
model := newGGUF(c)
|
||||
slog.Debug(fmt.Sprintf("model = %#v", model))
|
||||
if err := model.Decode(rs); err != nil {
|
||||
if err := model.decode(rs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -90,6 +90,7 @@ const (
|
||||
|
||||
type gguf struct {
|
||||
*containerGGUF
|
||||
offset int64
|
||||
|
||||
kv KV
|
||||
tensors []*Tensor
|
||||
@@ -116,6 +117,10 @@ func (llm *gguf) Tensors() Tensors {
|
||||
return llm.tensors
|
||||
}
|
||||
|
||||
func (llm *gguf) Offset() int64 {
|
||||
return llm.offset
|
||||
}
|
||||
|
||||
func (llm *gguf) numTensor() uint64 {
|
||||
switch llm.Version {
|
||||
case 1:
|
||||
@@ -138,7 +143,7 @@ func (llm *gguf) numKV() uint64 {
|
||||
}
|
||||
}
|
||||
|
||||
func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
||||
func (llm *gguf) decode(rs io.ReadSeeker) error {
|
||||
// decode key-values
|
||||
for i := 0; uint64(i) < llm.numKV(); i++ {
|
||||
k, err := readGGUFString(llm, rs)
|
||||
@@ -250,6 +255,8 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
||||
return err
|
||||
}
|
||||
|
||||
llm.offset = offset + padding
|
||||
|
||||
for _, tensor := range llm.tensors {
|
||||
if _, err := rs.Seek(int64(tensor.size()), io.SeekCurrent); err != nil {
|
||||
return err
|
||||
|
||||
Submodule llm/llama.cpp updated: 46e12c4692...952d03dbea
57
llm/llm.go
57
llm/llm.go
@@ -4,6 +4,7 @@ package llm
|
||||
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
|
||||
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
|
||||
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
|
||||
// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
|
||||
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
|
||||
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
|
||||
// #include <stdlib.h>
|
||||
@@ -19,7 +20,7 @@ func SystemInfo() string {
|
||||
return C.GoString(C.llama_print_system_info())
|
||||
}
|
||||
|
||||
func Quantize(infile, outfile, filetype string) error {
|
||||
func Quantize(infile, outfile string, ftype fileType) error {
|
||||
cinfile := C.CString(infile)
|
||||
defer C.free(unsafe.Pointer(cinfile))
|
||||
|
||||
@@ -28,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error {
|
||||
|
||||
params := C.llama_model_quantize_default_params()
|
||||
params.nthread = -1
|
||||
params.ftype = ftype.Value()
|
||||
|
||||
switch filetype {
|
||||
case "F32":
|
||||
params.ftype = fileTypeF32
|
||||
case "F16":
|
||||
params.ftype = fileTypeF16
|
||||
case "Q4_0":
|
||||
params.ftype = fileTypeQ4_0
|
||||
case "Q4_1":
|
||||
params.ftype = fileTypeQ4_1
|
||||
case "Q4_1_F16":
|
||||
params.ftype = fileTypeQ4_1_F16
|
||||
case "Q8_0":
|
||||
params.ftype = fileTypeQ8_0
|
||||
case "Q5_0":
|
||||
params.ftype = fileTypeQ5_0
|
||||
case "Q5_1":
|
||||
params.ftype = fileTypeQ5_1
|
||||
case "Q2_K":
|
||||
params.ftype = fileTypeQ2_K
|
||||
case "Q3_K_S":
|
||||
params.ftype = fileTypeQ3_K_S
|
||||
case "Q3_K_M":
|
||||
params.ftype = fileTypeQ3_K_M
|
||||
case "Q3_K_L":
|
||||
params.ftype = fileTypeQ3_K_L
|
||||
case "Q4_K_S":
|
||||
params.ftype = fileTypeQ4_K_S
|
||||
case "Q4_K_M":
|
||||
params.ftype = fileTypeQ4_K_M
|
||||
case "Q5_K_S":
|
||||
params.ftype = fileTypeQ5_K_S
|
||||
case "Q5_K_M":
|
||||
params.ftype = fileTypeQ5_K_M
|
||||
case "Q6_K":
|
||||
params.ftype = fileTypeQ6_K
|
||||
case "IQ2_XXS":
|
||||
params.ftype = fileTypeIQ2_XXS
|
||||
case "IQ2_XS":
|
||||
params.ftype = fileTypeIQ2_XS
|
||||
case "Q2_K_S":
|
||||
params.ftype = fileTypeQ2_K_S
|
||||
case "Q3_K_XS":
|
||||
params.ftype = fileTypeQ3_K_XS
|
||||
case "IQ3_XXS":
|
||||
params.ftype = fileTypeIQ3_XXS
|
||||
default:
|
||||
return fmt.Errorf("unknown filetype: %s", filetype)
|
||||
}
|
||||
|
||||
if retval := C.llama_model_quantize(cinfile, coutfile, ¶ms); retval != 0 {
|
||||
return fmt.Errorf("llama_model_quantize: %d", retval)
|
||||
if rc := C.llama_model_quantize(cinfile, coutfile, ¶ms); rc != 0 {
|
||||
return fmt.Errorf("llama_model_quantize: %d", rc)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -88,6 +88,11 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
|
||||
graphFullOffload *= uint64(len(gpus))
|
||||
graphPartialOffload *= uint64(len(gpus))
|
||||
|
||||
// on metal there's no partial offload overhead
|
||||
if gpus[0].Library == "metal" {
|
||||
graphPartialOffload = graphFullOffload
|
||||
}
|
||||
|
||||
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
|
||||
memoryRequiredTotal := memoryMinimum + graphFullOffload
|
||||
|
||||
|
||||
@@ -73,8 +73,7 @@ func LoadModel(model string) (*GGML, error) {
|
||||
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
|
||||
var err error
|
||||
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
||||
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
|
||||
opts.NumCtx = int(ggml.KV().ContextLength())
|
||||
slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
|
||||
}
|
||||
|
||||
if opts.NumCtx < 4 {
|
||||
@@ -301,12 +300,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
continue
|
||||
}
|
||||
|
||||
// reap subprocess when it exits
|
||||
go func() {
|
||||
// Exit status managed via getServerStatus
|
||||
_ = s.cmd.Wait()
|
||||
}()
|
||||
|
||||
// TODO - make sure this is all wired up correctly
|
||||
// if err = s.WaitUntilRunning(); err != nil {
|
||||
// slog.Error("error starting llama server", "server", servers[i], "error", err)
|
||||
@@ -900,7 +893,13 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
|
||||
func (s *llmServer) Close() error {
|
||||
if s.cmd != nil {
|
||||
slog.Debug("stopping llama server")
|
||||
return s.cmd.Process.Kill()
|
||||
if err := s.cmd.Process.Kill(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_ = s.cmd.Wait()
|
||||
|
||||
slog.Debug("llama server stopped")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -19,7 +19,7 @@ export default function () {
|
||||
const [step, setStep] = useState<Step>(Step.WELCOME)
|
||||
const [commandCopied, setCommandCopied] = useState<boolean>(false)
|
||||
|
||||
const command = 'ollama run llama2'
|
||||
const command = 'ollama run llama3'
|
||||
|
||||
return (
|
||||
<div className='drag'>
|
||||
|
||||
363
parser/parser.go
363
parser/parser.go
@@ -6,8 +6,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Command struct {
|
||||
@@ -15,118 +15,283 @@ type Command struct {
|
||||
Args string
|
||||
}
|
||||
|
||||
func (c *Command) Reset() {
|
||||
c.Name = ""
|
||||
c.Args = ""
|
||||
}
|
||||
type state int
|
||||
|
||||
func Parse(reader io.Reader) ([]Command, error) {
|
||||
var commands []Command
|
||||
var command, modelCommand Command
|
||||
const (
|
||||
stateNil state = iota
|
||||
stateName
|
||||
stateValue
|
||||
stateParameter
|
||||
stateMessage
|
||||
stateComment
|
||||
)
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
|
||||
scanner.Split(scanModelfile)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
var (
|
||||
errMissingFrom = errors.New("no FROM line")
|
||||
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
|
||||
)
|
||||
|
||||
fields := bytes.SplitN(line, []byte(" "), 2)
|
||||
if len(fields) == 0 || len(fields[0]) == 0 {
|
||||
continue
|
||||
}
|
||||
func Format(cmds []Command) string {
|
||||
var sb strings.Builder
|
||||
for _, cmd := range cmds {
|
||||
name := cmd.Name
|
||||
args := cmd.Args
|
||||
|
||||
switch string(bytes.ToUpper(fields[0])) {
|
||||
case "FROM":
|
||||
command.Name = "model"
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
// copy command for validation
|
||||
modelCommand = command
|
||||
case "ADAPTER":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
command.Args = string(fields[1])
|
||||
case "PARAMETER":
|
||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
||||
if len(fields) < 2 {
|
||||
return nil, fmt.Errorf("missing value for %s", fields)
|
||||
}
|
||||
|
||||
command.Name = string(fields[0])
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
case "EMBED":
|
||||
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
|
||||
case "MESSAGE":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
||||
if len(fields) < 2 {
|
||||
return nil, fmt.Errorf("should be in the format <role> <message>")
|
||||
}
|
||||
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
|
||||
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
}
|
||||
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
|
||||
switch cmd.Name {
|
||||
case "model":
|
||||
name = "from"
|
||||
args = cmd.Args
|
||||
case "license", "template", "system", "adapter":
|
||||
args = quote(args)
|
||||
case "message":
|
||||
role, message, _ := strings.Cut(cmd.Args, ": ")
|
||||
args = role + " " + quote(message)
|
||||
default:
|
||||
if !bytes.HasPrefix(fields[0], []byte("#")) {
|
||||
// log a warning for unknown commands
|
||||
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
|
||||
}
|
||||
continue
|
||||
name = "parameter"
|
||||
args = cmd.Name + " " + quote(cmd.Args)
|
||||
}
|
||||
|
||||
commands = append(commands, command)
|
||||
command.Reset()
|
||||
fmt.Fprintln(&sb, strings.ToUpper(name), args)
|
||||
}
|
||||
|
||||
if modelCommand.Args == "" {
|
||||
return nil, errors.New("no FROM line for the model was specified")
|
||||
}
|
||||
|
||||
return commands, scanner.Err()
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
func Parse(r io.Reader) (cmds []Command, err error) {
|
||||
var cmd Command
|
||||
var curr state
|
||||
var b bytes.Buffer
|
||||
var role string
|
||||
|
||||
if advance > 0 && token != nil {
|
||||
return advance, token, nil
|
||||
}
|
||||
|
||||
advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if advance > 0 && token != nil {
|
||||
return advance, token, nil
|
||||
}
|
||||
|
||||
return bufio.ScanLines(data, atEOF)
|
||||
}
|
||||
|
||||
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
newline := bytes.IndexByte(data, '\n')
|
||||
|
||||
if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
|
||||
end := bytes.Index(data[start+len(openBytes):], closeBytes)
|
||||
if end < 0 {
|
||||
if atEOF {
|
||||
return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
|
||||
} else {
|
||||
return 0, nil, nil
|
||||
}
|
||||
br := bufio.NewReader(r)
|
||||
for {
|
||||
r, _, err := br.ReadRune()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n := start + len(openBytes) + end + len(closeBytes)
|
||||
next, r, err := parseRuneForState(r, curr)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, fmt.Errorf("%w: %s", err, b.String())
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newData := data[:start]
|
||||
newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
|
||||
return n, newData, nil
|
||||
// process the state transition, some transitions need to be intercepted and redirected
|
||||
if next != curr {
|
||||
switch curr {
|
||||
case stateName:
|
||||
if !isValidCommand(b.String()) {
|
||||
return nil, errInvalidCommand
|
||||
}
|
||||
|
||||
// next state sometimes depends on the current buffer value
|
||||
switch s := strings.ToLower(b.String()); s {
|
||||
case "from":
|
||||
cmd.Name = "model"
|
||||
case "parameter":
|
||||
// transition to stateParameter which sets command name
|
||||
next = stateParameter
|
||||
case "message":
|
||||
// transition to stateMessage which validates the message role
|
||||
next = stateMessage
|
||||
fallthrough
|
||||
default:
|
||||
cmd.Name = s
|
||||
}
|
||||
case stateParameter:
|
||||
cmd.Name = b.String()
|
||||
case stateMessage:
|
||||
if !isValidMessageRole(b.String()) {
|
||||
return nil, errInvalidMessageRole
|
||||
}
|
||||
|
||||
role = b.String()
|
||||
case stateComment, stateNil:
|
||||
// pass
|
||||
case stateValue:
|
||||
s, ok := unquote(b.String())
|
||||
if !ok || isSpace(r) {
|
||||
if _, err := b.WriteRune(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if role != "" {
|
||||
s = role + ": " + s
|
||||
role = ""
|
||||
}
|
||||
|
||||
cmd.Args = s
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
b.Reset()
|
||||
curr = next
|
||||
}
|
||||
|
||||
if strconv.IsPrint(r) {
|
||||
if _, err := b.WriteRune(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil, nil
|
||||
// flush the buffer
|
||||
switch curr {
|
||||
case stateComment, stateNil:
|
||||
// pass; nothing to flush
|
||||
case stateValue:
|
||||
s, ok := unquote(b.String())
|
||||
if !ok {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
if role != "" {
|
||||
s = role + ": " + s
|
||||
}
|
||||
|
||||
cmd.Args = s
|
||||
cmds = append(cmds, cmd)
|
||||
default:
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
for _, cmd := range cmds {
|
||||
if cmd.Name == "model" {
|
||||
return cmds, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errMissingFrom
|
||||
}
|
||||
|
||||
func parseRuneForState(r rune, cs state) (state, rune, error) {
|
||||
switch cs {
|
||||
case stateNil:
|
||||
switch {
|
||||
case r == '#':
|
||||
return stateComment, 0, nil
|
||||
case isSpace(r), isNewline(r):
|
||||
return stateNil, 0, nil
|
||||
default:
|
||||
return stateName, r, nil
|
||||
}
|
||||
case stateName:
|
||||
switch {
|
||||
case isAlpha(r):
|
||||
return stateName, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
default:
|
||||
return stateNil, 0, errInvalidCommand
|
||||
}
|
||||
case stateValue:
|
||||
switch {
|
||||
case isNewline(r):
|
||||
return stateNil, r, nil
|
||||
case isSpace(r):
|
||||
return stateNil, r, nil
|
||||
default:
|
||||
return stateValue, r, nil
|
||||
}
|
||||
case stateParameter:
|
||||
switch {
|
||||
case isAlpha(r), isNumber(r), r == '_':
|
||||
return stateParameter, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
default:
|
||||
return stateNil, 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
case stateMessage:
|
||||
switch {
|
||||
case isAlpha(r):
|
||||
return stateMessage, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
default:
|
||||
return stateNil, 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
case stateComment:
|
||||
switch {
|
||||
case isNewline(r):
|
||||
return stateNil, 0, nil
|
||||
default:
|
||||
return stateComment, 0, nil
|
||||
}
|
||||
default:
|
||||
return stateNil, 0, errors.New("")
|
||||
}
|
||||
}
|
||||
|
||||
func quote(s string) string {
|
||||
if strings.Contains(s, "\n") || strings.HasPrefix(s, " ") || strings.HasSuffix(s, " ") {
|
||||
if strings.Contains(s, "\"") {
|
||||
return `"""` + s + `"""`
|
||||
}
|
||||
|
||||
return `"` + s + `"`
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func unquote(s string) (string, bool) {
|
||||
if len(s) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// TODO: single quotes
|
||||
if len(s) >= 3 && s[:3] == `"""` {
|
||||
if len(s) >= 6 && s[len(s)-3:] == `"""` {
|
||||
return s[3 : len(s)-3], true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
if len(s) >= 1 && s[0] == '"' {
|
||||
if len(s) >= 2 && s[len(s)-1] == '"' {
|
||||
return s[1 : len(s)-1], true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
return s, true
|
||||
}
|
||||
|
||||
func isAlpha(r rune) bool {
|
||||
return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
|
||||
}
|
||||
|
||||
func isNumber(r rune) bool {
|
||||
return r >= '0' && r <= '9'
|
||||
}
|
||||
|
||||
func isSpace(r rune) bool {
|
||||
return r == ' ' || r == '\t'
|
||||
}
|
||||
|
||||
func isNewline(r rune) bool {
|
||||
return r == '\r' || r == '\n'
|
||||
}
|
||||
|
||||
func isValidMessageRole(role string) bool {
|
||||
return role == "system" || role == "user" || role == "assistant"
|
||||
}
|
||||
|
||||
func isValidCommand(cmd string) bool {
|
||||
switch strings.ToLower(cmd) {
|
||||
case "from", "license", "template", "system", "adapter", "parameter", "message":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Parser(t *testing.T) {
|
||||
|
||||
func TestParser(t *testing.T) {
|
||||
input := `
|
||||
FROM model1
|
||||
ADAPTER adapter1
|
||||
@@ -35,21 +37,62 @@ TEMPLATE template1
|
||||
assert.Equal(t, expectedCommands, commands)
|
||||
}
|
||||
|
||||
func Test_Parser_NoFromLine(t *testing.T) {
|
||||
func TestParserFrom(t *testing.T) {
|
||||
var cases = []struct {
|
||||
input string
|
||||
expected []Command
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"FROM foo",
|
||||
[]Command{{Name: "model", Args: "foo"}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"FROM /path/to/model",
|
||||
[]Command{{Name: "model", Args: "/path/to/model"}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"FROM /path/to/model/fp16.bin",
|
||||
[]Command{{Name: "model", Args: "/path/to/model/fp16.bin"}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"FROM llama3:latest",
|
||||
[]Command{{Name: "model", Args: "llama3:latest"}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"FROM llama3:7b-instruct-q4_K_M",
|
||||
[]Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"", nil, errMissingFrom,
|
||||
},
|
||||
{
|
||||
"PARAMETER param1 value1",
|
||||
nil,
|
||||
errMissingFrom,
|
||||
},
|
||||
{
|
||||
"PARAMETER param1 value1\nFROM foo",
|
||||
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
input := `
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "no FROM line")
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
commands, err := Parse(strings.NewReader(c.input))
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
assert.Equal(t, c.expected, commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Parser_MissingValue(t *testing.T) {
|
||||
|
||||
func TestParserParametersMissingValue(t *testing.T) {
|
||||
input := `
|
||||
FROM foo
|
||||
PARAMETER param1
|
||||
@@ -58,41 +101,401 @@ PARAMETER param1
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "missing value for [param1]")
|
||||
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
func TestParserBadCommand(t *testing.T) {
|
||||
input := `
|
||||
FROM foo
|
||||
BADCOMMAND param1 value1
|
||||
`
|
||||
_, err := Parse(strings.NewReader(input))
|
||||
assert.ErrorIs(t, err, errInvalidCommand)
|
||||
|
||||
}
|
||||
|
||||
func Test_Parser_Messages(t *testing.T) {
|
||||
|
||||
input := `
|
||||
func TestParserMessages(t *testing.T) {
|
||||
var cases = []struct {
|
||||
input string
|
||||
expected []Command
|
||||
err error
|
||||
}{
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system You are a Parser. Always Parse things.
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system You are a Parser. Always Parse things.`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system You are a Parser. Always Parse things.
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
commands, err := Parse(reader)
|
||||
assert.Nil(t, err)
|
||||
|
||||
expectedCommands := []Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||
{Name: "message", Args: "user: Hey there!"},
|
||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedCommands, commands)
|
||||
}
|
||||
|
||||
func Test_Parser_Messages_BadRole(t *testing.T) {
|
||||
|
||||
input := `
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||
{Name: "message", Args: "user: Hey there!"},
|
||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system """
|
||||
You are a multiline Parser. Always Parse things.
|
||||
"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE badguy I'm a bad guy!
|
||||
`
|
||||
`,
|
||||
nil,
|
||||
errInvalidMessageRole,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system
|
||||
`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
commands, err := Parse(strings.NewReader(c.input))
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
assert.Equal(t, c.expected, commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserQuoted(t *testing.T) {
|
||||
var cases = []struct {
|
||||
multiline string
|
||||
expected []Command
|
||||
err error
|
||||
}{
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """
|
||||
This is a
|
||||
multiline system.
|
||||
"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "\nThis is a\nmultiline system.\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """
|
||||
This is a
|
||||
multiline system."""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "\nThis is a\nmultiline system."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """This is a
|
||||
multiline system."""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "This is a\nmultiline system."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """This is a multiline system."""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "This is a multiline system."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """This is a multiline system.""
|
||||
`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM "
|
||||
`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """
|
||||
This is a multiline system with "quotes".
|
||||
"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: ""},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM ""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: ""},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM "'"
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "'"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """''"'""'""'"'''''""'""'"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: `''"'""'""'"'''''""'""'`},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE """
|
||||
{{ .Prompt }}
|
||||
"""`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "template", Args: "\n{{ .Prompt }}\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
commands, err := Parse(strings.NewReader(c.multiline))
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
assert.Equal(t, c.expected, commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserParameters(t *testing.T) {
|
||||
var cases = map[string]struct {
|
||||
name, value string
|
||||
}{
|
||||
"numa true": {"numa", "true"},
|
||||
"num_ctx 1": {"num_ctx", "1"},
|
||||
"num_batch 1": {"num_batch", "1"},
|
||||
"num_gqa 1": {"num_gqa", "1"},
|
||||
"num_gpu 1": {"num_gpu", "1"},
|
||||
"main_gpu 1": {"main_gpu", "1"},
|
||||
"low_vram true": {"low_vram", "true"},
|
||||
"f16_kv true": {"f16_kv", "true"},
|
||||
"logits_all true": {"logits_all", "true"},
|
||||
"vocab_only true": {"vocab_only", "true"},
|
||||
"use_mmap true": {"use_mmap", "true"},
|
||||
"use_mlock true": {"use_mlock", "true"},
|
||||
"num_thread 1": {"num_thread", "1"},
|
||||
"num_keep 1": {"num_keep", "1"},
|
||||
"seed 1": {"seed", "1"},
|
||||
"num_predict 1": {"num_predict", "1"},
|
||||
"top_k 1": {"top_k", "1"},
|
||||
"top_p 1.0": {"top_p", "1.0"},
|
||||
"tfs_z 1.0": {"tfs_z", "1.0"},
|
||||
"typical_p 1.0": {"typical_p", "1.0"},
|
||||
"repeat_last_n 1": {"repeat_last_n", "1"},
|
||||
"temperature 1.0": {"temperature", "1.0"},
|
||||
"repeat_penalty 1.0": {"repeat_penalty", "1.0"},
|
||||
"presence_penalty 1.0": {"presence_penalty", "1.0"},
|
||||
"frequency_penalty 1.0": {"frequency_penalty", "1.0"},
|
||||
"mirostat 1": {"mirostat", "1"},
|
||||
"mirostat_tau 1.0": {"mirostat_tau", "1.0"},
|
||||
"mirostat_eta 1.0": {"mirostat_eta", "1.0"},
|
||||
"penalize_newline true": {"penalize_newline", "true"},
|
||||
"stop ### User:": {"stop", "### User:"},
|
||||
"stop ### User: ": {"stop", "### User: "},
|
||||
"stop \"### User:\"": {"stop", "### User:"},
|
||||
"stop \"### User: \"": {"stop", "### User: "},
|
||||
"stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
|
||||
"stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"},
|
||||
"stop <|endoftext|>": {"stop", "<|endoftext|>"},
|
||||
"stop <|eot_id|>": {"stop", "<|eot_id|>"},
|
||||
"stop </s>": {"stop", "</s>"},
|
||||
}
|
||||
|
||||
for k, v := range cases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
fmt.Fprintln(&b, "FROM foo")
|
||||
fmt.Fprintln(&b, "PARAMETER", k)
|
||||
commands, err := Parse(&b)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, []Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: v.name, Args: v.value},
|
||||
}, commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserComments(t *testing.T) {
|
||||
var cases = []struct {
|
||||
input string
|
||||
expected []Command
|
||||
}{
|
||||
{
|
||||
`
|
||||
# comment
|
||||
FROM foo
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
commands, err := Parse(strings.NewReader(c.input))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, c.expected, commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFormatParse(t *testing.T) {
|
||||
var cases = []string{
|
||||
`
|
||||
FROM foo
|
||||
ADAPTER adapter1
|
||||
LICENSE MIT
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
TEMPLATE template1
|
||||
MESSAGE system You are a Parser. Always Parse things.
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`,
|
||||
`
|
||||
FROM foo
|
||||
ADAPTER adapter1
|
||||
LICENSE MIT
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
TEMPLATE template1
|
||||
MESSAGE system """
|
||||
You are a store greeter. Always responsed with "Hello!".
|
||||
"""
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`,
|
||||
`
|
||||
FROM foo
|
||||
ADAPTER adapter1
|
||||
LICENSE """
|
||||
Very long and boring legal text.
|
||||
Blah blah blah.
|
||||
"Oh look, a quote!"
|
||||
"""
|
||||
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
TEMPLATE template1
|
||||
MESSAGE system """
|
||||
You are a store greeter. Always responsed with "Hello!".
|
||||
"""
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`,
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
commands, err := Parse(strings.NewReader(c))
|
||||
assert.NoError(t, err)
|
||||
|
||||
commands2, err := Parse(strings.NewReader(Format(commands)))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, commands, commands2)
|
||||
})
|
||||
}
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
function checkEnv() {
|
||||
$script:TARGET_ARCH=$Env:PROCESSOR_ARCHITECTURE.ToLower()
|
||||
Write-host "Building for ${script:TARGET_ARCH}"
|
||||
write-host "Locating required tools and paths"
|
||||
$script:SRC_DIR=$PWD
|
||||
if (!$env:VCToolsRedistDir) {
|
||||
@@ -30,7 +32,7 @@ function checkEnv() {
|
||||
|
||||
$script:INNO_SETUP_DIR=(get-item "C:\Program Files*\Inno Setup*\")[0]
|
||||
|
||||
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-amd64"
|
||||
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
|
||||
$env:CGO_ENABLED="1"
|
||||
echo "Checking version"
|
||||
if (!$env:VERSION) {
|
||||
@@ -81,8 +83,8 @@ function buildOllama() {
|
||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
New-Item -ItemType Directory -Path .\dist\windows-amd64\ -Force
|
||||
cp .\ollama.exe .\dist\windows-amd64\
|
||||
New-Item -ItemType Directory -Path .\dist\windows-${script:TARGET_ARCH}\ -Force
|
||||
cp .\ollama.exe .\dist\windows-${script:TARGET_ARCH}\
|
||||
}
|
||||
|
||||
function buildApp() {
|
||||
@@ -127,16 +129,16 @@ function buildInstaller() {
|
||||
cd "${script:SRC_DIR}\app"
|
||||
$env:PKG_VERSION=$script:PKG_VERSION
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
|
||||
} else {
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" .\ollama.iss
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH .\ollama.iss
|
||||
}
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
|
||||
function distZip() {
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
|
||||
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip"
|
||||
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip" -Force
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
839
server/images.go
839
server/images.go
File diff suppressed because it is too large
Load Diff
@@ -5,39 +5,18 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type Layers struct {
|
||||
items []*Layer
|
||||
}
|
||||
|
||||
func (ls *Layers) Add(layer *Layer) {
|
||||
if layer.Size > 0 {
|
||||
ls.items = append(ls.items, layer)
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *Layers) Replace(layer *Layer) {
|
||||
if layer.Size > 0 {
|
||||
mediatype := layer.MediaType
|
||||
layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
|
||||
return l.MediaType == mediatype
|
||||
})
|
||||
|
||||
ls.items = append(layers, layer)
|
||||
}
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
From string `json:"from,omitempty"`
|
||||
|
||||
tempFileName string
|
||||
Intermediate bool `json:"intermediate,omitempty"`
|
||||
MergeBase string `json:"merge_base,omitempty"`
|
||||
|
||||
message string
|
||||
}
|
||||
|
||||
func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
|
||||
@@ -46,14 +25,12 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
const delimiter = "-"
|
||||
|
||||
pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
|
||||
temp, err := os.CreateTemp(blobs, pattern)
|
||||
temp, err := os.CreateTemp(blobs, "sha256-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer temp.Close()
|
||||
defer os.Remove(temp.Name())
|
||||
|
||||
sha256sum := sha256.New()
|
||||
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
|
||||
@@ -61,11 +38,29 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := temp.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
|
||||
blob, err := GetBlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := "using existing layer"
|
||||
if _, err := os.Stat(blob); err != nil {
|
||||
status = "creating new layer"
|
||||
if err := os.Rename(temp.Name(), blob); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Layer{
|
||||
MediaType: mediatype,
|
||||
Digest: fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
|
||||
Size: n,
|
||||
tempFileName: temp.Name(),
|
||||
MediaType: mediatype,
|
||||
Digest: digest,
|
||||
Size: n,
|
||||
message: fmt.Sprintf("%s %s", status, digest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -85,21 +80,15 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
|
||||
Digest: digest,
|
||||
Size: fi.Size(),
|
||||
From: from,
|
||||
message: fmt.Sprintf("using existing layer %s", digest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *Layer) Commit() (bool, error) {
|
||||
// always remove temp
|
||||
defer os.Remove(l.tempFileName)
|
||||
|
||||
func (l *Layer) Open() (*os.File, error) {
|
||||
blob, err := GetBlobsPath(l.Digest)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := os.Stat(blob); err != nil {
|
||||
return true, os.Rename(l.tempFileName, blob)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return os.Open(blob)
|
||||
}
|
||||
259
server/model.go
Normal file
259
server/model.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type layerWithGGML struct {
|
||||
*Layer
|
||||
*llm.GGML
|
||||
}
|
||||
|
||||
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
|
||||
modelpath := ParseModelPath(name.String())
|
||||
manifest, _, err := GetManifest(modelpath)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelpath = ParseModelPath(name.String())
|
||||
manifest, _, err = GetManifest(modelpath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case err != nil:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch layer.MediaType {
|
||||
case "application/vnd.ollama.image.model",
|
||||
"application/vnd.ollama.image.projector",
|
||||
"application/vnd.ollama.image.adapter":
|
||||
blobpath, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob, err := os.Open(blobpath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer blob.Close()
|
||||
|
||||
ggml, _, err := llm.DecodeGGML(blob)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerWithGGML{layer, ggml})
|
||||
default:
|
||||
layers = append(layers, &layerWithGGML{layer, nil})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := zip.NewReader(file, stat.Size())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer os.RemoveAll(tempdir)
|
||||
|
||||
fn(api.ProgressResponse{Status: "unpacking model metadata"})
|
||||
for _, f := range r.File {
|
||||
// TODO(mxyng): this should not write out all files to disk
|
||||
outfile, err := os.Create(filepath.Join(tempdir, f.Name))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
infile, err := f.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err = io.Copy(outfile, infile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := outfile.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := infile.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
mf, err := convert.GetModelFormat(tempdir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params, err := mf.GetParams(tempdir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mArch, err := mf.GetModelArch("", tempdir, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "processing tensors"})
|
||||
if err := mArch.GetTensors(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mArch.LoadVocab(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "converting model"})
|
||||
|
||||
// TODO(mxyng): this should write directly into a layer
|
||||
// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
|
||||
temp, err := os.CreateTemp(tempdir, "fp16")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer temp.Close()
|
||||
defer os.Remove(temp.Name())
|
||||
|
||||
if err = mArch.WriteGGUF(temp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := temp.Seek(0, io.SeekStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(temp, "application/vnd.ollama.image.model")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aaa: %w", err)
|
||||
}
|
||||
|
||||
blobpath, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bin, err := os.Open(blobpath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
ggml, _, err := llm.DecodeGGML(bin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerWithGGML{layer, ggml})
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
|
||||
sr := io.NewSectionReader(file, 0, 512)
|
||||
contentType, err := detectContentType(sr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch contentType {
|
||||
case "gguf", "ggla":
|
||||
// noop
|
||||
case "application/zip":
|
||||
return parseFromZipFile(ctx, file, fn)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported content type: %s", contentType)
|
||||
}
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var offset int64
|
||||
for offset < stat.Size() {
|
||||
ggml, n, err := llm.DecodeGGML(file)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mediatype := "application/vnd.ollama.image.model"
|
||||
if ggml.Name() == "ggla" {
|
||||
mediatype = "application/vnd.ollama.image.adapter"
|
||||
} else if ggml.KV().Architecture() == "clip" {
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerWithGGML{layer, ggml})
|
||||
offset = n
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func detectContentType(r io.Reader) (string, error) {
|
||||
var b bytes.Buffer
|
||||
if _, err := io.Copy(&b, r); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" {
|
||||
return contentType, nil
|
||||
}
|
||||
|
||||
if contentType := http.DetectContentType(b.Bytes()); contentType != "application/octet-stream" {
|
||||
return contentType, nil
|
||||
}
|
||||
|
||||
return "unknown", nil
|
||||
}
|
||||
@@ -580,7 +580,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil {
|
||||
if err := CreateModel(ctx, model, filepath.Dir(req.Path), strings.ToUpper(req.Quantization), commands, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@@ -728,12 +728,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
}
|
||||
|
||||
mf, err := ShowModelfile(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.Modelfile = mf
|
||||
var sb strings.Builder
|
||||
fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
|
||||
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
|
||||
fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
|
||||
fmt.Fprint(&sb, parser.Format(model.Commands()))
|
||||
resp.Modelfile = sb.String()
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
@@ -810,16 +810,13 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
|
||||
|
||||
src := model.ParseName(r.Source)
|
||||
if !src.IsValid() {
|
||||
_ = c.Error(fmt.Errorf("source %q is invalid", r.Source))
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
|
||||
return
|
||||
}
|
||||
|
||||
dst := model.ParseName(r.Destination)
|
||||
if !dst.IsValid() {
|
||||
_ = c.Error(fmt.Errorf("destination %q is invalid", r.Destination))
|
||||
}
|
||||
|
||||
if len(c.Errors) > 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": c.Errors.Errors()})
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Source)})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -875,11 +872,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := layer.Commit(); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusCreated)
|
||||
}
|
||||
|
||||
|
||||
@@ -124,14 +124,12 @@ func Test_Routes(t *testing.T) {
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/create",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
f, err := os.CreateTemp(t.TempDir(), "ollama-model")
|
||||
assert.Nil(t, err)
|
||||
defer f.Close()
|
||||
fname := createTestFile(t, "ollama-model")
|
||||
|
||||
stream := false
|
||||
createReq := api.CreateRequest{
|
||||
Name: "t-bone",
|
||||
Modelfile: fmt.Sprintf("FROM %s", f.Name()),
|
||||
Modelfile: fmt.Sprintf("FROM %s", fname),
|
||||
Stream: &stream,
|
||||
}
|
||||
jsonData, err := json.Marshal(createReq)
|
||||
@@ -216,28 +214,25 @@ func Test_Routes(t *testing.T) {
|
||||
httpSrv := httptest.NewServer(router)
|
||||
t.Cleanup(httpSrv.Close)
|
||||
|
||||
workDir, err := os.MkdirTemp("", "ollama-test")
|
||||
assert.Nil(t, err)
|
||||
defer os.RemoveAll(workDir)
|
||||
os.Setenv("OLLAMA_MODELS", workDir)
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Logf("Running Test: [%s]", tc.Name)
|
||||
u := httpSrv.URL + tc.Path
|
||||
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
|
||||
assert.Nil(t, err)
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
u := httpSrv.URL + tc.Path
|
||||
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if tc.Setup != nil {
|
||||
tc.Setup(t, req)
|
||||
}
|
||||
if tc.Setup != nil {
|
||||
tc.Setup(t, req)
|
||||
}
|
||||
|
||||
resp, err := httpSrv.Client().Do(req)
|
||||
assert.Nil(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if tc.Expected != nil {
|
||||
tc.Expected(t, resp)
|
||||
}
|
||||
resp, err := httpSrv.Client().Do(req)
|
||||
assert.Nil(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if tc.Expected != nil {
|
||||
tc.Expected(t, resp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,6 +149,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
break
|
||||
}
|
||||
|
||||
// If we're CPU only mode, just limit by loadedMax above
|
||||
// TODO handle system memory exhaustion
|
||||
if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 {
|
||||
slog.Debug("cpu mode with existing models, loading")
|
||||
s.loadFn(pending, ggml, gpus)
|
||||
break
|
||||
}
|
||||
|
||||
// No models loaded. Load the model but prefer the best fit.
|
||||
if loadedCount == 0 {
|
||||
slog.Debug("loading first model", "model", pending.model.ModelPath)
|
||||
@@ -242,6 +250,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
defer runner.refMu.Unlock()
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
s.expiredCh <- runner
|
||||
})
|
||||
@@ -288,6 +297,10 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
|
||||
runner.refMu.Lock()
|
||||
defer runner.refMu.Unlock()
|
||||
runner.refCount++
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
runner.sessionDuration = pending.sessionDuration
|
||||
pending.successCh <- runner
|
||||
go func() {
|
||||
@@ -418,6 +431,10 @@ type runnerRef struct {
|
||||
|
||||
// The refMu must already be held when calling unload
|
||||
func (runner *runnerRef) unload() {
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
if runner.llama != nil {
|
||||
runner.llama.Close()
|
||||
}
|
||||
|
||||
@@ -28,19 +28,33 @@ func TestInitScheduler(t *testing.T) {
|
||||
ctx, done := context.WithCancel(context.Background())
|
||||
defer done()
|
||||
initialMax := loadedMax
|
||||
initialParallel := numParallel
|
||||
s := InitScheduler(ctx)
|
||||
require.Equal(t, initialMax, loadedMax)
|
||||
s.loadedMu.Lock()
|
||||
require.NotNil(t, s.loaded)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
|
||||
s = InitScheduler(ctx)
|
||||
require.Equal(t, initialMax, loadedMax)
|
||||
s.loadedMu.Lock()
|
||||
require.NotNil(t, s.loaded)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
|
||||
s = InitScheduler(ctx)
|
||||
require.Equal(t, 0, loadedMax)
|
||||
s.loadedMu.Lock()
|
||||
require.NotNil(t, s.loaded)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
os.Setenv("OLLAMA_NUM_PARALLEL", "blue")
|
||||
_ = InitScheduler(ctx)
|
||||
require.Equal(t, initialParallel, numParallel)
|
||||
os.Setenv("OLLAMA_NUM_PARALLEL", "10")
|
||||
_ = InitScheduler(ctx)
|
||||
require.Equal(t, 10, numParallel)
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
@@ -51,6 +65,7 @@ func TestLoad(t *testing.T) {
|
||||
req := &LlmRequest{
|
||||
ctx: ctx,
|
||||
model: &Model{ModelPath: "foo"},
|
||||
opts: api.DefaultOptions(),
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
sessionDuration: 2,
|
||||
@@ -63,7 +78,9 @@ func TestLoad(t *testing.T) {
|
||||
s.load(req, ggml, gpus)
|
||||
require.Len(t, req.successCh, 0)
|
||||
require.Len(t, req.errCh, 1)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 0)
|
||||
s.loadedMu.Unlock()
|
||||
err := <-req.errCh
|
||||
require.Contains(t, err.Error(), "this model may be incompatible")
|
||||
|
||||
@@ -78,7 +95,9 @@ func TestLoad(t *testing.T) {
|
||||
case resp := <-req.successCh:
|
||||
require.Equal(t, uint64(10), resp.estimatedVRAM)
|
||||
require.Equal(t, uint(1), resp.refCount)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
}
|
||||
|
||||
req.model.ModelPath = "dummy_model_path"
|
||||
@@ -90,7 +109,9 @@ func TestLoad(t *testing.T) {
|
||||
case resp := <-req.successCh:
|
||||
t.Errorf("unexpected success %v", resp)
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded["dummy_model_path"]
|
||||
s.loadedMu.Unlock()
|
||||
require.NotNil(t, runner)
|
||||
require.Equal(t, uint(0), runner.refCount)
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
@@ -143,6 +164,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
||||
scenario.req = &LlmRequest{
|
||||
ctx: scenario.ctx,
|
||||
model: model,
|
||||
opts: api.DefaultOptions(),
|
||||
sessionDuration: 5 * time.Millisecond,
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
@@ -171,7 +193,9 @@ func TestRequests(t *testing.T) {
|
||||
// Multiple loaded models
|
||||
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
|
||||
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
|
||||
scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
||||
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
|
||||
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
||||
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = func() gpu.GpuInfoList {
|
||||
@@ -240,7 +264,9 @@ func TestRequests(t *testing.T) {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
loadedMax = 0
|
||||
s.newServerFn = scenario3b.newServer
|
||||
@@ -254,19 +280,14 @@ func TestRequests(t *testing.T) {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Try to load a model that wont fit
|
||||
// This is a CPU load with NumGPU = 0 so it should load
|
||||
s.newServerFn = scenario3c.newServer
|
||||
slog.Info("scenario3c")
|
||||
require.Len(t, s.loaded, 2)
|
||||
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
s.pendingReqCh <- scenario3c.req
|
||||
// finish prior request, so new model can load
|
||||
time.Sleep(6 * time.Millisecond)
|
||||
require.Len(t, s.loaded, 1)
|
||||
scenario3b.ctxDone()
|
||||
select {
|
||||
case resp := <-scenario3c.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario3c.srv)
|
||||
@@ -275,7 +296,36 @@ func TestRequests(t *testing.T) {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 3)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Try to load a model that wont fit
|
||||
s.newServerFn = scenario3d.newServer
|
||||
slog.Info("scenario3d")
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 3)
|
||||
s.loadedMu.Unlock()
|
||||
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
s.pendingReqCh <- scenario3d.req
|
||||
// finish prior request, so new model can load
|
||||
time.Sleep(6 * time.Millisecond)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
s.loadedMu.Unlock()
|
||||
scenario3b.ctxDone()
|
||||
select {
|
||||
case resp := <-scenario3d.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario3d.srv)
|
||||
require.Len(t, s.pendingReqCh, 0)
|
||||
require.Len(t, scenario3d.req.errCh, 0)
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
s.loadedMu.Unlock()
|
||||
}
|
||||
|
||||
func TestGetRunner(t *testing.T) {
|
||||
@@ -318,7 +368,9 @@ func TestGetRunner(t *testing.T) {
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
scenario1a.ctxDone()
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
scenario1c.req.model.ModelPath = "bad path"
|
||||
slog.Info("scenario1c")
|
||||
@@ -328,7 +380,9 @@ func TestGetRunner(t *testing.T) {
|
||||
require.Len(t, errCh1c, 0)
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 0)
|
||||
s.loadedMu.Unlock()
|
||||
require.Len(t, errCh1c, 1)
|
||||
err = <-errCh1c
|
||||
require.Contains(t, err.Error(), "bad path")
|
||||
@@ -358,7 +412,9 @@ func TestPrematureExpired(t *testing.T) {
|
||||
require.Equal(t, resp.llama, scenario1a.srv)
|
||||
require.Len(t, s.pendingReqCh, 0)
|
||||
require.Len(t, errCh1a, 0)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
slog.Info("sending premature expired event now")
|
||||
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
||||
case <-ctx.Done():
|
||||
@@ -383,6 +439,7 @@ func TestUseLoadedRunner(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
req := &LlmRequest{
|
||||
ctx: ctx,
|
||||
opts: api.DefaultOptions(),
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
sessionDuration: 2,
|
||||
}
|
||||
@@ -426,8 +483,10 @@ func TestUpdateFreeSpace(t *testing.T) {
|
||||
r2 := &runnerRef{llama: llm2, gpus: gpus}
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["a"] = r1
|
||||
s.loaded["b"] = r2
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
s.updateFreeSpace(gpus)
|
||||
require.Equal(t, uint64(850), gpus[0].FreeMemory)
|
||||
@@ -437,13 +496,18 @@ func TestUpdateFreeSpace(t *testing.T) {
|
||||
func TestFindRunnerToUnload(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
defer done()
|
||||
req := &LlmRequest{ctx: ctx}
|
||||
req := &LlmRequest{
|
||||
ctx: ctx,
|
||||
opts: api.DefaultOptions(),
|
||||
}
|
||||
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
|
||||
r2 := &runnerRef{sessionDuration: 2}
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["a"] = r1
|
||||
s.loaded["b"] = r2
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
resp := s.findRunnerToUnload(req)
|
||||
require.Equal(t, r2, resp)
|
||||
@@ -458,10 +522,11 @@ func TestNeedsReload(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
llm := &mockLlm{}
|
||||
do := api.DefaultOptions()
|
||||
runner := &runnerRef{
|
||||
adapters: []string{"adapter1"},
|
||||
projectors: []string{"projector1"},
|
||||
Options: &api.Options{},
|
||||
Options: &do,
|
||||
llama: llm,
|
||||
}
|
||||
req := &LlmRequest{
|
||||
@@ -469,7 +534,7 @@ func TestNeedsReload(t *testing.T) {
|
||||
AdapterPaths: []string{"adapter2"},
|
||||
ProjectorPaths: []string{"projector2"},
|
||||
},
|
||||
opts: api.Options{},
|
||||
opts: api.DefaultOptions(),
|
||||
}
|
||||
resp := runner.needsReload(ctx, req)
|
||||
require.True(t, resp)
|
||||
@@ -508,8 +573,10 @@ func TestUnloadAllRunners(t *testing.T) {
|
||||
r1 := &runnerRef{llama: llm1}
|
||||
r2 := &runnerRef{llama: llm2}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["a"] = r1
|
||||
s.loaded["b"] = r2
|
||||
s.loadedMu.Unlock()
|
||||
s.unloadAllRunners()
|
||||
|
||||
require.True(t, llm1.closeCalled)
|
||||
|
||||
18
types/errtypes/errtypes.go
Normal file
18
types/errtypes/errtypes.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Package errtypes contains custom error types
|
||||
package errtypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const UnknownOllamaKeyErrMsg = "unknown ollama key"
|
||||
|
||||
// TODO: This should have a structured response from the API
|
||||
type UnknownOllamaKey struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e *UnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -80,9 +81,6 @@ func (k partKind) String() string {
|
||||
//
|
||||
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
|
||||
// is valid.
|
||||
//
|
||||
// It is not directly comparable with other Names. Use [Name.Equal] and
|
||||
// [Name.MapHash] for determining equality and using as a map key.
|
||||
type Name struct {
|
||||
Host string
|
||||
Namespace string
|
||||
@@ -109,20 +107,20 @@ type Name struct {
|
||||
// { model }
|
||||
// "@" { digest }
|
||||
// host:
|
||||
// pattern: alphanum { alphanum | "-" | "_" | "." | ":" }*
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }*
|
||||
// length: [1, 350]
|
||||
// namespace:
|
||||
// pattern: alphanum { alphanum | "-" | "_" }*
|
||||
// length: [2, 80]
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
|
||||
// length: [1, 80]
|
||||
// model:
|
||||
// pattern: alphanum { alphanum | "-" | "_" | "." }*
|
||||
// length: [2, 80]
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// tag:
|
||||
// pattern: alphanum { alphanum | "-" | "_" | "." }*
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// digest:
|
||||
// pattern: alphanum { alphanum | "-" | ":" }*
|
||||
// length: [2, 80]
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
|
||||
// length: [1, 80]
|
||||
//
|
||||
// Most users should use [ParseName] instead, unless need to support
|
||||
// different defaults than DefaultName.
|
||||
@@ -145,18 +143,28 @@ func ParseNameBare(s string) Name {
|
||||
n.RawDigest = MissingPart
|
||||
}
|
||||
|
||||
s, n.Tag, _ = cutPromised(s, ":")
|
||||
// "/" is an illegal tag character, so we can use it to split the host
|
||||
if strings.LastIndex(s, ":") > strings.LastIndex(s, "/") {
|
||||
s, n.Tag, _ = cutPromised(s, ":")
|
||||
}
|
||||
|
||||
s, n.Model, promised = cutPromised(s, "/")
|
||||
if !promised {
|
||||
n.Model = s
|
||||
return n
|
||||
}
|
||||
|
||||
s, n.Namespace, promised = cutPromised(s, "/")
|
||||
if !promised {
|
||||
n.Namespace = s
|
||||
return n
|
||||
}
|
||||
n.Host = s
|
||||
|
||||
scheme, host, ok := strings.Cut(s, "://")
|
||||
if ! ok {
|
||||
host = scheme
|
||||
}
|
||||
n.Host = host
|
||||
|
||||
return n
|
||||
}
|
||||
@@ -234,12 +242,12 @@ func (n Name) Filepath() string {
|
||||
if !n.IsFullyQualified() {
|
||||
panic("illegal attempt to get filepath of invalid name")
|
||||
}
|
||||
return filepath.Join(
|
||||
strings.ToLower(n.Host),
|
||||
strings.ToLower(n.Namespace),
|
||||
strings.ToLower(n.Model),
|
||||
strings.ToLower(n.Tag),
|
||||
)
|
||||
return strings.ToLower(filepath.Join(
|
||||
n.Host,
|
||||
n.Namespace,
|
||||
n.Model,
|
||||
n.Tag,
|
||||
))
|
||||
}
|
||||
|
||||
// LogValue returns a slog.Value that represents the name as a string.
|
||||
@@ -254,7 +262,7 @@ func isValidLen(kind partKind, s string) bool {
|
||||
case kindTag:
|
||||
return len(s) >= 1 && len(s) <= 80
|
||||
default:
|
||||
return len(s) >= 2 && len(s) <= 80
|
||||
return len(s) >= 1 && len(s) <= 80
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +272,7 @@ func isValidPart(kind partKind, s string) bool {
|
||||
}
|
||||
for i := range s {
|
||||
if i == 0 {
|
||||
if !isAlphanumeric(s[i]) {
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
@@ -280,7 +288,7 @@ func isValidPart(kind partKind, s string) bool {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if !isAlphanumeric(s[i]) {
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -288,8 +296,8 @@ func isValidPart(kind partKind, s string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func isAlphanumeric(c byte) bool {
|
||||
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9'
|
||||
func isAlphanumericOrUnderscore(c byte) bool {
|
||||
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
|
||||
}
|
||||
|
||||
func cutLast(s, sep string) (before, after string, ok bool) {
|
||||
@@ -311,3 +319,57 @@ func cutPromised(s, sep string) (before, after string, ok bool) {
|
||||
}
|
||||
return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
|
||||
}
|
||||
|
||||
type DigestType byte
|
||||
|
||||
const (
|
||||
DigestTypeInvalid DigestType = iota
|
||||
DigestTypeSHA256
|
||||
)
|
||||
|
||||
func (t DigestType) String() string {
|
||||
switch t {
|
||||
case DigestTypeSHA256:
|
||||
return "sha256"
|
||||
default:
|
||||
return "invalid"
|
||||
}
|
||||
}
|
||||
|
||||
type Digest struct {
|
||||
Type DigestType
|
||||
Sum [32]byte
|
||||
}
|
||||
|
||||
func ParseDigest(s string) (Digest, error) {
|
||||
i := strings.IndexAny(s, "-:")
|
||||
if i < 0 {
|
||||
return Digest{}, fmt.Errorf("invalid digest %q", s)
|
||||
}
|
||||
typ, encSum := s[:i], s[i+1:]
|
||||
if typ != "sha256" {
|
||||
return Digest{}, fmt.Errorf("unsupported digest type %q", typ)
|
||||
}
|
||||
d := Digest{
|
||||
Type: DigestTypeSHA256,
|
||||
}
|
||||
n, err := hex.Decode(d.Sum[:], []byte(encSum))
|
||||
if err != nil {
|
||||
return Digest{}, err
|
||||
}
|
||||
if n != 32 {
|
||||
return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d Digest) String() string {
|
||||
if d.Type == DigestTypeInvalid {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("sha256-%x", d.Sum)
|
||||
}
|
||||
|
||||
func (d Digest) IsValid() bool {
|
||||
return d.Type != DigestTypeInvalid
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -14,8 +16,19 @@ func TestParseNameParts(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want Name
|
||||
wantFilepath string
|
||||
wantValidDigest bool
|
||||
}{
|
||||
{
|
||||
in: "scheme://host:port/namespace/model:tag",
|
||||
want: Name{
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
},
|
||||
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
|
||||
},
|
||||
{
|
||||
in: "host/namespace/model:tag",
|
||||
want: Name{
|
||||
@@ -24,6 +37,17 @@ func TestParseNameParts(t *testing.T) {
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
},
|
||||
wantFilepath: filepath.Join("host", "namespace", "model", "tag"),
|
||||
},
|
||||
{
|
||||
in: "host:port/namespace/model:tag",
|
||||
want: Name{
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
},
|
||||
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
|
||||
},
|
||||
{
|
||||
in: "host/namespace/model",
|
||||
@@ -32,6 +56,16 @@ func TestParseNameParts(t *testing.T) {
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
},
|
||||
wantFilepath: filepath.Join("host", "namespace", "model", "latest"),
|
||||
},
|
||||
{
|
||||
in: "host:port/namespace/model",
|
||||
want: Name{
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
},
|
||||
wantFilepath: filepath.Join("host:port", "namespace", "model", "latest"),
|
||||
},
|
||||
{
|
||||
in: "namespace/model",
|
||||
@@ -39,12 +73,14 @@ func TestParseNameParts(t *testing.T) {
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
},
|
||||
wantFilepath: filepath.Join("registry.ollama.ai", "namespace", "model", "latest"),
|
||||
},
|
||||
{
|
||||
in: "model",
|
||||
want: Name{
|
||||
Model: "model",
|
||||
},
|
||||
wantFilepath: filepath.Join("registry.ollama.ai", "library", "model", "latest"),
|
||||
},
|
||||
{
|
||||
in: "h/nn/mm:t",
|
||||
@@ -54,6 +90,7 @@ func TestParseNameParts(t *testing.T) {
|
||||
Model: "mm",
|
||||
Tag: "t",
|
||||
},
|
||||
wantFilepath: filepath.Join("h", "nn", "mm", "t"),
|
||||
},
|
||||
{
|
||||
in: part80 + "/" + part80 + "/" + part80 + ":" + part80,
|
||||
@@ -63,6 +100,7 @@ func TestParseNameParts(t *testing.T) {
|
||||
Model: part80,
|
||||
Tag: part80,
|
||||
},
|
||||
wantFilepath: filepath.Join(part80, part80, part80, part80),
|
||||
},
|
||||
{
|
||||
in: part350 + "/" + part80 + "/" + part80 + ":" + part80,
|
||||
@@ -72,6 +110,7 @@ func TestParseNameParts(t *testing.T) {
|
||||
Model: part80,
|
||||
Tag: part80,
|
||||
},
|
||||
wantFilepath: filepath.Join(part350, part80, part80, part80),
|
||||
},
|
||||
{
|
||||
in: "@digest",
|
||||
@@ -96,11 +135,23 @@ func TestParseNameParts(t *testing.T) {
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("parseName(%q) = %v; want %v", tt.in, got, tt.want)
|
||||
}
|
||||
|
||||
got = ParseName(tt.in)
|
||||
if tt.wantFilepath != "" && got.Filepath() != tt.wantFilepath {
|
||||
t.Errorf("parseName(%q).Filepath() = %q; want %q", tt.in, got.Filepath(), tt.wantFilepath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var testCases = map[string]bool{ // name -> valid
|
||||
"": false,
|
||||
|
||||
"_why/_the/_lucky:_stiff": true,
|
||||
|
||||
// minimal
|
||||
"h/n/m:t@d": true,
|
||||
|
||||
"host/namespace/model:tag": true,
|
||||
"host/namespace/model": false,
|
||||
"namespace/model": false,
|
||||
@@ -116,11 +167,12 @@ var testCases = map[string]bool{ // name -> valid
|
||||
"h/nn/mm:t@sha256-1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
|
||||
"h/nn/mm:t@sha256:1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
|
||||
|
||||
"m": false, // model too short
|
||||
"n/mm:": false, // namespace too short
|
||||
"h/n/mm:t": false, // namespace too short
|
||||
"@t": false, // digest too short
|
||||
"mm@d": false, // digest too short
|
||||
// unqualified
|
||||
"m": false,
|
||||
"n/m:": false,
|
||||
"h/n/m": false,
|
||||
"@t": false,
|
||||
"m@d": false,
|
||||
|
||||
// invalids
|
||||
"^": false,
|
||||
@@ -140,8 +192,6 @@ var testCases = map[string]bool{ // name -> valid
|
||||
"hh/nn/mm:-tt@dd": false,
|
||||
"hh/nn/mm:tt@-dd": false,
|
||||
|
||||
"": false,
|
||||
|
||||
// hosts
|
||||
"host:https/namespace/model:tag": true,
|
||||
|
||||
@@ -163,7 +213,6 @@ func TestNameIsValid(t *testing.T) {
|
||||
var numStringTests int
|
||||
for s, want := range testCases {
|
||||
n := ParseNameBare(s)
|
||||
t.Logf("n: %#v", n)
|
||||
got := n.IsValid()
|
||||
if got != want {
|
||||
t.Errorf("parseName(%q).IsValid() = %v; want %v", s, got, want)
|
||||
@@ -212,6 +261,54 @@ func TestNameIsValidPart(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestFilepathAllocs(t *testing.T) {
|
||||
n := ParseNameBare("HOST/NAMESPACE/MODEL:TAG")
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
n.Filepath()
|
||||
})
|
||||
allowedAllocs := 2.0
|
||||
if runtime.GOOS == "windows" {
|
||||
allowedAllocs = 4
|
||||
}
|
||||
if allocs > allowedAllocs {
|
||||
t.Errorf("allocs = %v; allowed %v", allocs, allowedAllocs)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
validSha256 = "sha256-1000000000000000000000000000000000000000000000000000000000000000"
|
||||
validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000"
|
||||
)
|
||||
|
||||
func TestParseDigest(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", ""}, // empty
|
||||
{"sha123-12", ""}, // invalid type
|
||||
{"sha256-", ""}, // invalid sum
|
||||
{"sha256-123", ""}, // invalid odd length sum
|
||||
|
||||
{validSha256, validSha256},
|
||||
{validSha256Old, validSha256},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
got, err := ParseDigest(tt.in)
|
||||
if err != nil {
|
||||
if tt.want != "" {
|
||||
t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got.String() != tt.want {
|
||||
t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzName(f *testing.F) {
|
||||
for s := range testCases {
|
||||
f.Add(s)
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package structs contains the Incomparable type.
|
||||
package structs
|
||||
|
||||
// Incomparable is a zero-width incomparable type. If added as the
|
||||
// first field in a struct, it marks that struct as not comparable
|
||||
// (can't do == or be a map key) and usually doesn't add any width to
|
||||
// the struct (unless the struct has only small fields).
|
||||
//
|
||||
// By making a struct incomparable, you can prevent misuse (prevent
|
||||
// people from using ==), but also you can shrink generated binaries,
|
||||
// as the compiler can omit equality funcs from the binary.
|
||||
type Incomparable [0]func()
|
||||
Reference in New Issue
Block a user