mirror of
https://github.com/ollama/ollama.git
synced 2026-01-02 12:38:15 -05:00
Compare commits
39 Commits
v0.2.2
...
jyan/progr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b6c7d01af3 | ||
|
|
9d517cf556 | ||
|
|
6bab0e2368 | ||
|
|
c4cccaf936 | ||
|
|
9fe5c393e4 | ||
|
|
007c988dba | ||
|
|
91d21e7c7b | ||
|
|
3e64284f69 | ||
|
|
39910f2ab2 | ||
|
|
96d0cd92f2 | ||
|
|
3a724a7c80 | ||
|
|
f520f0056e | ||
|
|
d25f85ede4 | ||
|
|
b48420b74b | ||
|
|
784958a1cb | ||
|
|
ae65cc8dea | ||
|
|
a037528bba | ||
|
|
04bf41deb5 | ||
|
|
c23cec9547 | ||
|
|
8377dc48d0 | ||
|
|
3aee405dfa | ||
|
|
9b3f47b674 | ||
|
|
f5441f01a2 | ||
|
|
ab165df43a | ||
|
|
79cc4c9585 | ||
|
|
bc3f59a6ad | ||
|
|
1a85cb904c | ||
|
|
10ea0987e9 | ||
|
|
413d368a6a | ||
|
|
cabf375059 | ||
|
|
ca0ee1d4fe | ||
|
|
1142999aab | ||
|
|
0d5a72aba9 | ||
|
|
ea837412c2 | ||
|
|
736ad6f438 | ||
|
|
64607d16a5 | ||
|
|
a6cfe7f00b | ||
|
|
c3b411a515 | ||
|
|
928f37e3ae |
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -147,7 +147,7 @@ jobs:
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading AMD HIP Installer"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
write-host "Installing AMD HIP"
|
||||
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
||||
write-host "Completed AMD HIP"
|
||||
|
||||
2
.github/workflows/test.yaml
vendored
2
.github/workflows/test.yaml
vendored
@@ -169,7 +169,7 @@ jobs:
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading AMD HIP Installer"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
write-host "Installing AMD HIP"
|
||||
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
||||
write-host "Completed AMD HIP"
|
||||
|
||||
@@ -17,14 +17,20 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -374,3 +380,27 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
||||
|
||||
return version.Version, nil
|
||||
}
|
||||
|
||||
func Authorization(ctx context.Context, request *http.Request) (string, error) {
|
||||
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer knownHostsFile.Close()
|
||||
|
||||
token, err := auth.Sign(ctx, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// interleave request data into the token
|
||||
key, sig, _ := strings.Cut(token, ":")
|
||||
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
|
||||
}
|
||||
|
||||
@@ -127,10 +127,6 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
|
||||
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
|
||||
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
|
||||
|
||||
[InstallDelete]
|
||||
Type: filesandordirs; Name: "{%TEMP}\ollama*"
|
||||
Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
|
||||
|
||||
[Messages]
|
||||
WizardReady=Ollama Windows Preview
|
||||
ReadyLabel1=%nLet's get you up and running with your own large language models.
|
||||
|
||||
54
auth/auth.go
54
auth/auth.go
@@ -10,42 +10,37 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
func keyPath() (string, error) {
|
||||
func keyPath() (ssh.Signer, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||
return ssh.ParsePrivateKey(privateKeyFile)
|
||||
}
|
||||
|
||||
func GetPublicKey() (ssh.PublicKey, error) {
|
||||
privateKey, err := keyPath()
|
||||
// if privateKey, try public key directly
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||
|
||||
return strings.TrimSpace(string(publicKey)), nil
|
||||
return privateKey.PublicKey(), nil
|
||||
}
|
||||
|
||||
func NewNonce(r io.Reader, length int) (string, error) {
|
||||
@@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||
}
|
||||
|
||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||
privateKey, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// get the pubkey, but remove the type
|
||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||
parts := bytes.Split(publicKey, []byte(" "))
|
||||
publicKey, err := GetPublicKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
|
||||
|
||||
parts := bytes.Split(publicKeyBytes, []byte(" "))
|
||||
if len(parts) < 2 {
|
||||
return "", fmt.Errorf("malformed public key")
|
||||
}
|
||||
|
||||
163
cmd/cmd.go
163
cmd/cmd.go
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
@@ -78,6 +80,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
status := "transferring model data"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
defer p.Stop()
|
||||
|
||||
for i := range modelfile.Commands {
|
||||
switch modelfile.Commands[i].Name {
|
||||
@@ -112,11 +115,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
path = tempfile
|
||||
}
|
||||
|
||||
digest, err := createBlob(cmd, client, path)
|
||||
digest, err := createBlob(cmd, client, path, spinner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelfile.Commands[i].Args = "@" + digest
|
||||
}
|
||||
}
|
||||
@@ -138,7 +140,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
spinner.Stop()
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
@@ -263,13 +265,22 @@ func tempZipFiles(path string) (string, error) {
|
||||
return tempfile.Name(), nil
|
||||
}
|
||||
|
||||
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
||||
var ErrBlobExists = errors.New("blob exists")
|
||||
|
||||
func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
|
||||
bin, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
// Get file info to retrieve the size
|
||||
fileInfo, err := bin.Stat()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fileSize := fileInfo.Size()
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, bin); err != nil {
|
||||
return "", err
|
||||
@@ -279,13 +290,151 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
||||
return "", err
|
||||
}
|
||||
|
||||
var pw progressWriter
|
||||
status := "transferring model data 0%"
|
||||
spinner.SetMessage(status)
|
||||
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n/fileSize)))
|
||||
case <-done:
|
||||
spinner.SetMessage("transferring model data 100%")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
||||
|
||||
// We check if we can find the models directory locally
|
||||
// If we can, we return the path to the directory
|
||||
// If we can't, we return an error
|
||||
// If the blob exists already, we return the digest
|
||||
dest, err := getLocalPath(cmd.Context(), digest)
|
||||
|
||||
if errors.Is(err, ErrBlobExists) {
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
// Successfully found the model directory
|
||||
if err == nil {
|
||||
// Copy blob in via OS specific copy
|
||||
// Linux errors out to use io.copy
|
||||
err = localCopy(path, dest)
|
||||
if err == nil {
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
// Default copy using io.copy
|
||||
err = defaultCopy(path, dest)
|
||||
if err == nil {
|
||||
return digest, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If at any point copying the blob over locally fails, we default to the copy through the server
|
||||
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
n int64
|
||||
}
|
||||
|
||||
func (w *progressWriter) Write(p []byte) (n int, err error) {
|
||||
w.n += int64(len(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func getLocalPath(ctx context.Context, digest string) (string, error) {
|
||||
ollamaHost := envconfig.Host
|
||||
|
||||
client := http.DefaultClient
|
||||
base := &url.URL{
|
||||
Scheme: ollamaHost.Scheme,
|
||||
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(digest)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
reqBody := bytes.NewReader(data)
|
||||
path := fmt.Sprintf("/api/blobs/%s", digest)
|
||||
requestURL := base.JoinPath(path)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
authz, err := api.Authorization(ctx, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
request.Header.Set("Authorization", authz)
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
request.Header.Set("X-Redirect-Create", "1")
|
||||
|
||||
resp, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusTemporaryRedirect {
|
||||
dest := resp.Header.Get("LocalLocation")
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
return "", ErrBlobExists
|
||||
}
|
||||
|
||||
func defaultCopy(path string, dest string) error {
|
||||
// This function should be called if the server is local
|
||||
// It should find the model directory, copy the blob over, and return the digest
|
||||
dirPath := filepath.Dir(dest)
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy blob over
|
||||
sourceFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not open source file: %v", err)
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create destination file: %v", err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error copying file: %v", err)
|
||||
}
|
||||
|
||||
err = destFile.Sync()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error flushing file: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
@@ -379,11 +528,13 @@ func errFromUnknownKey(unknownKeyErr error) error {
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
publicKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
|
||||
23
cmd/copy_darwin.go
Normal file
23
cmd/copy_darwin.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func localCopy(src, target string) error {
|
||||
dirPath := filepath.Dir(target)
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := unix.Clonefile(src, target, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
7
cmd/copy_linux.go
Normal file
7
cmd/copy_linux.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package cmd
|
||||
|
||||
import "errors"
|
||||
|
||||
func localCopy(src, target string) error {
|
||||
return errors.New("no local copy implementation for linux")
|
||||
}
|
||||
67
cmd/copy_windows.go
Normal file
67
cmd/copy_windows.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func localCopy(src, target string) error {
|
||||
// Create target directory if it doesn't exist
|
||||
dirPath := filepath.Dir(target)
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Open source file
|
||||
sourceFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
// Create target file
|
||||
targetFile, err := os.Create(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer targetFile.Close()
|
||||
|
||||
// Use CopyFileExW to copy the file
|
||||
err = copyFileEx(src, target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyFileEx(src, dst string) error {
|
||||
kernel32 := syscall.NewLazyDLL("kernel32.dll")
|
||||
copyFileEx := kernel32.NewProc("CopyFileExW")
|
||||
|
||||
srcPtr, err := syscall.UTF16PtrFromString(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dstPtr, err := syscall.UTF16PtrFromString(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err := copyFileEx.Call(
|
||||
uintptr(unsafe.Pointer(srcPtr)),
|
||||
uintptr(unsafe.Pointer(dstPtr)),
|
||||
0, 0, 0, 0)
|
||||
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -272,4 +272,4 @@ The following server settings may be used to adjust how Ollama handles concurren
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
||||
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
||||
@@ -49,17 +49,9 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
}
|
||||
|
||||
func commonAMDValidateLibDir() (string, error) {
|
||||
// Favor our bundled version
|
||||
|
||||
// Installer payload location if we're running the installed binary
|
||||
exe, err := os.Executable()
|
||||
if err == nil {
|
||||
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
||||
if rocmLibUsable(rocmTargetDir) {
|
||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||
return rocmTargetDir, nil
|
||||
}
|
||||
}
|
||||
// We try to favor system paths first, so that we can wire up the subprocess to use
|
||||
// the system version. Only use our bundled version if the system version doesn't work
|
||||
// This gives users a more recovery options if versions have subtle problems at runtime
|
||||
|
||||
// Prefer explicit HIP env var
|
||||
hipPath := os.Getenv("HIP_PATH")
|
||||
@@ -95,5 +87,14 @@ func commonAMDValidateLibDir() (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Installer payload location if we're running the installed binary
|
||||
exe, err := os.Executable()
|
||||
if err == nil {
|
||||
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
||||
if rocmLibUsable(rocmTargetDir) {
|
||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||
return rocmTargetDir, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||||
}
|
||||
|
||||
@@ -84,8 +84,9 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
||||
}
|
||||
|
||||
slog.Debug("hipDriverGetVersion", "version", version)
|
||||
driverMajor = version / 10000000
|
||||
driverMinor = (version - (driverMajor * 10000000)) / 100000
|
||||
// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
|
||||
driverMajor = version / 1000
|
||||
driverMinor = (version - (driverMajor * 1000)) / 10
|
||||
|
||||
return driverMajor, driverMinor, nil
|
||||
}
|
||||
|
||||
@@ -22,8 +22,8 @@ const (
|
||||
|
||||
var (
|
||||
// Used to validate if the given ROCm lib is usable
|
||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
|
||||
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
|
||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
|
||||
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
|
||||
)
|
||||
|
||||
func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
@@ -35,11 +35,12 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
}
|
||||
defer hl.Release()
|
||||
|
||||
driverMajor, driverMinor, err := hl.AMDDriverVersion()
|
||||
if err != nil {
|
||||
// For now this is benign, but we may eventually need to fail compatibility checks
|
||||
slog.Debug("error looking up amd driver version", "error", err)
|
||||
}
|
||||
// TODO - this reports incorrect version information, so omitting for now
|
||||
// driverMajor, driverMinor, err := hl.AMDDriverVersion()
|
||||
// if err != nil {
|
||||
// // For now this is benign, but we may eventually need to fail compatibility checks
|
||||
// slog.Debug("error looking up amd driver version", "error", err)
|
||||
// }
|
||||
|
||||
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
|
||||
count := hl.HipGetDeviceCount()
|
||||
@@ -131,8 +132,10 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
Name: name,
|
||||
Compute: gfx,
|
||||
DriverMajor: driverMajor,
|
||||
DriverMinor: driverMinor,
|
||||
|
||||
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
|
||||
// DriverMajor: driverMajor,
|
||||
// DriverMinor: driverMinor,
|
||||
},
|
||||
index: i,
|
||||
}
|
||||
|
||||
30
gpu/gpu.go
30
gpu/gpu.go
@@ -274,28 +274,6 @@ func GetGPUInfo() GpuInfoList {
|
||||
gpuInfo.DriverMajor = driverMajor
|
||||
gpuInfo.DriverMinor = driverMinor
|
||||
|
||||
// query the management library as well so we can record any skew between the two
|
||||
// which represents overhead on the GPU we must set aside on subsequent updates
|
||||
if cHandles.nvml != nil {
|
||||
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
|
||||
if memInfo.err != nil {
|
||||
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
} else {
|
||||
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
|
||||
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
|
||||
slog.Info("detected OS VRAM overhead",
|
||||
"id", gpuInfo.ID,
|
||||
"library", gpuInfo.Library,
|
||||
"compute", gpuInfo.Compute,
|
||||
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
|
||||
"name", gpuInfo.Name,
|
||||
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||
cudaGPUs = append(cudaGPUs, gpuInfo)
|
||||
}
|
||||
@@ -360,17 +338,14 @@ func GetGPUInfo() GpuInfoList {
|
||||
"before",
|
||||
"total", format.HumanBytes2(cpus[0].TotalMemory),
|
||||
"free", format.HumanBytes2(cpus[0].FreeMemory),
|
||||
"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
|
||||
),
|
||||
slog.Group(
|
||||
"now",
|
||||
"total", format.HumanBytes2(mem.TotalMemory),
|
||||
"free", format.HumanBytes2(mem.FreeMemory),
|
||||
"free_swap", format.HumanBytes2(mem.FreeSwap),
|
||||
),
|
||||
)
|
||||
cpus[0].FreeMemory = mem.FreeMemory
|
||||
cpus[0].FreeSwap = mem.FreeSwap
|
||||
}
|
||||
|
||||
var memInfo C.mem_info_t
|
||||
@@ -399,14 +374,9 @@ func GetGPUInfo() GpuInfoList {
|
||||
slog.Warn("error looking up nvidia GPU memory")
|
||||
continue
|
||||
}
|
||||
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
|
||||
// When using the management library update based on recorded overhead
|
||||
memInfo.free -= C.uint64_t(gpu.OSOverhead)
|
||||
}
|
||||
slog.Debug("updating cuda memory data",
|
||||
"gpu", gpu.ID,
|
||||
"name", gpu.Name,
|
||||
"overhead", format.HumanBytes2(gpu.OSOverhead),
|
||||
slog.Group(
|
||||
"before",
|
||||
"total", format.HumanBytes2(gpu.TotalMemory),
|
||||
|
||||
@@ -57,7 +57,6 @@ func GetCPUMem() (memInfo, error) {
|
||||
return memInfo{
|
||||
TotalMemory: uint64(C.getPhysicalMemory()),
|
||||
FreeMemory: uint64(C.getFreeMemory()),
|
||||
// FreeSwap omitted as Darwin uses dynamic paging
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ var OneapiMgmtName = "libze_intel_gpu.so"
|
||||
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
var mem memInfo
|
||||
var total, available, free, buffers, cached, freeSwap uint64
|
||||
var total, available, free, buffers, cached uint64
|
||||
f, err := os.Open("/proc/meminfo")
|
||||
if err != nil {
|
||||
return mem, err
|
||||
@@ -70,21 +70,20 @@ func GetCPUMem() (memInfo, error) {
|
||||
_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
|
||||
case strings.HasPrefix(line, "Cached:"):
|
||||
_, err = fmt.Sscanf(line, "Cached:%d", &cached)
|
||||
case strings.HasPrefix(line, "SwapFree:"):
|
||||
_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return mem, err
|
||||
}
|
||||
|
||||
if total > 0 && available > 0 {
|
||||
mem.TotalMemory = total * format.KibiByte
|
||||
mem.FreeMemory = available * format.KibiByte
|
||||
return mem, nil
|
||||
}
|
||||
}
|
||||
mem.TotalMemory = total * format.KibiByte
|
||||
mem.FreeSwap = freeSwap * format.KibiByte
|
||||
if available > 0 {
|
||||
mem.FreeMemory = available * format.KibiByte
|
||||
} else {
|
||||
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
|
||||
}
|
||||
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
|
||||
return mem, nil
|
||||
}
|
||||
|
||||
@@ -51,5 +51,5 @@ func GetCPUMem() (memInfo, error) {
|
||||
if r1 == 0 {
|
||||
return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
|
||||
}
|
||||
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
|
||||
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
type memInfo struct {
|
||||
TotalMemory uint64 `json:"total_memory,omitempty"`
|
||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||
FreeSwap uint64 `json:"free_swap,omitempty"`
|
||||
}
|
||||
|
||||
// Beginning of an `ollama info` command
|
||||
@@ -53,8 +52,7 @@ type CPUInfo struct {
|
||||
|
||||
type CudaGPUInfo struct {
|
||||
GpuInfo
|
||||
OSOverhead uint64 // Memory overhead between the driver library and management library
|
||||
index int //nolint:unused,nolintlint
|
||||
index int //nolint:unused,nolintlint
|
||||
}
|
||||
type CudaGPUInfoList []CudaGPUInfo
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
|
||||
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
|
||||
echo "Building custom CUDA GPU"
|
||||
else
|
||||
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
|
||||
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DGGML_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} -DCMAKE_LIBRARY_PATH=/usr/local/cuda/compat"
|
||||
fi
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
|
||||
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
||||
|
||||
@@ -6,9 +6,18 @@ function amdGPUs {
|
||||
if ($env:AMDGPU_TARGETS) {
|
||||
return $env:AMDGPU_TARGETS
|
||||
}
|
||||
# Current supported rocblas list from ROCm v6.1.2 on windows
|
||||
# TODO - load from some common data file for linux + windows build consistency
|
||||
$GPU_LIST = @(
|
||||
"gfx900"
|
||||
"gfx906:xnack-"
|
||||
"gfx908:xnack-"
|
||||
"gfx90a:xnack+"
|
||||
"gfx90a:xnack-"
|
||||
"gfx940"
|
||||
"gfx941"
|
||||
"gfx942"
|
||||
"gfx1010"
|
||||
"gfx1012"
|
||||
"gfx1030"
|
||||
"gfx1100"
|
||||
"gfx1101"
|
||||
@@ -386,6 +395,7 @@ function build_rocm() {
|
||||
sign
|
||||
install
|
||||
|
||||
# Assumes v5.7, may need adjustments for v6
|
||||
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
|
||||
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||
|
||||
26
llm/ggml.go
26
llm/ggml.go
@@ -424,32 +424,6 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
||||
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
|
||||
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
|
||||
)
|
||||
case "chatglm":
|
||||
fullOffload = 4 * batch * (embedding + vocab)
|
||||
partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
|
||||
if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
|
||||
fullOffload = max(
|
||||
fullOffload,
|
||||
4*batch*(2+
|
||||
2*embedding+
|
||||
context+
|
||||
context*heads+
|
||||
embeddingHeadsK*heads+
|
||||
qkvBias.Shape[0]),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
partialOffload,
|
||||
4*batch*(1+
|
||||
2*embedding+
|
||||
embeddingHeadsK*heads+
|
||||
context+
|
||||
context*heads)+
|
||||
4*embeddingHeadsK*context+
|
||||
4*context*embeddingHeadsK+
|
||||
4*qkvBias.Shape[0],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
@@ -33,7 +33,7 @@ func Quantize(infile, outfile string, ftype fileType) error {
|
||||
params.ftype = ftype.Value()
|
||||
|
||||
if rc := C.llama_model_quantize(cinfile, coutfile, ¶ms); rc != 0 {
|
||||
return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
|
||||
return fmt.Errorf("llama_model_quantize: %d", rc)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -88,7 +88,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
var estimate MemoryEstimate
|
||||
var systemTotalMemory uint64
|
||||
var systemFreeMemory uint64
|
||||
var systemSwapFreeMemory uint64
|
||||
|
||||
systemMemInfo, err := gpu.GetCPUMem()
|
||||
if err != nil {
|
||||
@@ -96,8 +95,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
} else {
|
||||
systemTotalMemory = systemMemInfo.TotalMemory
|
||||
systemFreeMemory = systemMemInfo.FreeMemory
|
||||
systemSwapFreeMemory = systemMemInfo.FreeSwap
|
||||
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
|
||||
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory)
|
||||
}
|
||||
|
||||
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
|
||||
@@ -124,16 +122,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
}
|
||||
}
|
||||
|
||||
// On linux, over-allocating CPU memory will almost always result in an error
|
||||
if runtime.GOOS == "linux" {
|
||||
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
|
||||
available := min(systemTotalMemory, systemFreeMemory+systemSwapFreeMemory)
|
||||
if systemMemoryRequired > available {
|
||||
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
|
||||
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
|
||||
}
|
||||
}
|
||||
|
||||
estimate.log()
|
||||
|
||||
// Loop through potential servers
|
||||
@@ -266,6 +254,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||
}
|
||||
|
||||
if estimate.TensorSplit != "" {
|
||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||
}
|
||||
|
||||
for i := range len(servers) {
|
||||
dir := availableServers[servers[i]]
|
||||
if dir == "" {
|
||||
|
||||
@@ -31,6 +31,10 @@ func NewSpinner(message string) *Spinner {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Spinner) SetMessage(message string) {
|
||||
s.message = message
|
||||
}
|
||||
|
||||
func (s *Spinner) String() string {
|
||||
var sb strings.Builder
|
||||
if len(s.message) > 0 {
|
||||
|
||||
@@ -107,12 +107,9 @@ function gatherDependencies() {
|
||||
|
||||
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
|
||||
# currently works for Win11 + MSVC 2019 + Cuda V11
|
||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
|
||||
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||
}
|
||||
|
||||
|
||||
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var errCapabilityCompletion = errors.New("completion")
|
||||
@@ -1064,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
if anonymous {
|
||||
// no user is associated with the public key, and the request requires non-anonymous access
|
||||
pubKey, nestedErr := auth.GetPublicKey()
|
||||
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
|
||||
if nestedErr != nil {
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
||||
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
|
||||
}
|
||||
// user is associated with the public key, but is not authorized to make the request
|
||||
return nil, errUnauthorized
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -22,8 +24,10 @@ import (
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/llm"
|
||||
@@ -770,7 +774,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = os.Stat(path)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
@@ -783,6 +786,12 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
|
||||
c.Header("LocalLocation", path)
|
||||
c.Status(http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
layer, err := NewLayer(c.Request.Body, "")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -797,6 +806,54 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
}
|
||||
|
||||
func (s *Server) IsLocal(c *gin.Context) bool {
|
||||
if authz := c.GetHeader("Authorization"); authz != "" {
|
||||
parts := strings.Split(authz, ":")
|
||||
if len(parts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
|
||||
requestData, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
partialRequestDataParts := strings.Split(string(requestData), ",")
|
||||
if len(partialRequestDataParts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
signature, err := base64.StdEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
serverPublicKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
|
||||
return true
|
||||
}
|
||||
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isLocalIP(ip netip.Addr) bool {
|
||||
if interfaces, err := net.Interfaces(); err == nil {
|
||||
for _, iface := range interfaces {
|
||||
|
||||
@@ -546,8 +546,8 @@ func TestCreateDetectTemplate(t *testing.T) {
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
|
||||
filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
|
||||
filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
|
||||
filepath.Join(p, "blobs", "sha256-9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b7"),
|
||||
filepath.Join(p, "blobs", "sha256-b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce57cdac893db8139"),
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -135,6 +135,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
cpus := s.getCpuFn()
|
||||
var systemMem gpu.GpuInfo
|
||||
if len(cpus) > 0 {
|
||||
systemMem = cpus[0]
|
||||
}
|
||||
var runnerToExpire *runnerRef
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[pending.model.ModelPath]
|
||||
@@ -188,6 +193,38 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
break
|
||||
}
|
||||
|
||||
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
|
||||
maxSize := systemMem.FreeMemory
|
||||
|
||||
// Add available GPU memory to the total pool
|
||||
// macOS hardware has unified memory so don't double count
|
||||
if runtime.GOOS != "darwin" {
|
||||
for _, gpu := range gpus {
|
||||
if gpu.Library == "cpu" {
|
||||
continue
|
||||
}
|
||||
if loadedCount == 0 {
|
||||
// If no other models are loaded, set the limit based on what's available
|
||||
maxSize += gpu.FreeMemory
|
||||
} else {
|
||||
// Other models could be unloaded, favor total memory for limit
|
||||
maxSize += gpu.TotalMemory
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Block attempting to load a model larger than system memory + GPU memory
|
||||
if estimate.TotalSize > maxSize {
|
||||
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
|
||||
|
||||
// Linux will crash if over-allocating memory - return an error to the user.
|
||||
// TODO (jmorganca): add reasonable upper limits for darwin and windows as well
|
||||
if runtime.GOOS == "linux" {
|
||||
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
|
||||
if len(gpus) == 1 && gpus[0].Library == "cpu" {
|
||||
// simplifying assumption of defaultParallel when in CPU mode
|
||||
|
||||
@@ -1 +1,8 @@
|
||||
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}<start_system>{{ .System }}<end_message>
|
||||
{{- end }}
|
||||
{{- range .Messages }}<start_{{ .Role }}>{{ .Content }}<end_message>
|
||||
{{- end }}<start_assistant>
|
||||
{{- else }}
|
||||
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
|
||||
{{- end }}
|
||||
@@ -1,3 +1,14 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}{{ .System }}
|
||||
{{- end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}### Instruction:
|
||||
{{- else if eq .Role "assistant" }}### Response:
|
||||
{{- end }}
|
||||
{{ .Content }}
|
||||
|
||||
{{ end }}### Response:
|
||||
{{ else }}
|
||||
{{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ if .Prompt }}### Instruction:
|
||||
@@ -5,4 +16,4 @@
|
||||
|
||||
{{ end }}### Response:
|
||||
{{ .Response }}
|
||||
|
||||
{{- end }}
|
||||
@@ -1,6 +1,15 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range .Messages }}<|im_start|>{{ .Role }}
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ else }}
|
||||
{{ if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ .Response }}<|im_end|>
|
||||
{{- end }}
|
||||
@@ -1,6 +1,17 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}System: {{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}User:
|
||||
{{- else if eq .Role "assistant" }}Assistant:
|
||||
{{- end }} {{ .Content }}
|
||||
|
||||
{{ end }}Assistant:
|
||||
{{- else }}
|
||||
{{ if .System }}System: {{ .System }}
|
||||
|
||||
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
|
||||
|
||||
{{ end }}Assistant: {{ .Response }}
|
||||
|
||||
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,10 +1,19 @@
|
||||
{{ if .System }}Source: system
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}Source: system
|
||||
|
||||
{{ .System }} <step> {{ end }}Source: user
|
||||
{{ .System }} <step> {{ end }}
|
||||
{{- range .Messages }}Source: {{ .Role }}
|
||||
|
||||
{{ .Content }} <step> {{ end }}Source: assistant
|
||||
Destination: user
|
||||
|
||||
{{ else }}
|
||||
{{ if .System }} Source: system
|
||||
|
||||
{{ .System }} <step>{{ end }} Source: user
|
||||
|
||||
{{ .Prompt }} <step> Source: assistant
|
||||
{{- if not .Response }}
|
||||
Destination: user
|
||||
{{- end }}
|
||||
|
||||
{{ .Response }} <step>
|
||||
{{ .Response }}<step>
|
||||
{{- end }}
|
||||
@@ -1,5 +1,13 @@
|
||||
{{ if .System }}System: {{ .System }}
|
||||
{{ end }}{{ if .Prompt }}User:
|
||||
{{ .Prompt }}
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}System: {{ .System }}
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}User:
|
||||
{{ else if eq .Role "assistant" }}Falcon:
|
||||
{{ end }}{{ .Content }}
|
||||
{{ end }}Falcon:
|
||||
{{ .Response }}
|
||||
{{ else }}
|
||||
{{ if .System }}{{ .System }}
|
||||
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
|
||||
{{ end }}Assistant: {{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,5 +1,16 @@
|
||||
{{- if .Messages }}
|
||||
{{- range $index, $_ := .Messages }}<start_of_turn>
|
||||
{{- if eq .Role "user" }}user
|
||||
{{- if and $.System (eq $index 0) }}
|
||||
{{ $.System }}
|
||||
{{- end }}
|
||||
{{- else if eq .Role "assistant" }}model
|
||||
{{- end }}
|
||||
{{ .Content }}<end_of_turn>
|
||||
{{ end }}<start_of_turn>model
|
||||
{{ else }}
|
||||
<start_of_turn>user
|
||||
{{ if .System }}{{ .System }}
|
||||
{{ end }}{{ .Prompt }}<end_of_turn>
|
||||
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
|
||||
<start_of_turn>model
|
||||
{{ .Response }}<end_of_turn>
|
||||
{{- end }}
|
||||
@@ -1,4 +1,18 @@
|
||||
{{ if .System }}System:
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}System:
|
||||
{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}Question:
|
||||
{{- else if eq .Role "assistant" }}Answer:
|
||||
{{- end }}
|
||||
{{ .Content }}
|
||||
|
||||
{{ end }}Answer:
|
||||
{{ else }}
|
||||
{{ if .System }}
|
||||
System:
|
||||
{{ .System }}
|
||||
|
||||
{{ end }}{{ if .Prompt }}Question:
|
||||
@@ -6,4 +20,4 @@
|
||||
|
||||
{{ end }}Answer:
|
||||
{{ .Response }}
|
||||
|
||||
{{- end }}
|
||||
@@ -1,6 +1,16 @@
|
||||
[INST] <<SYS>>
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- if .Messages }}
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}[INST] {{ if eq $index 0 }}<<SYS>>
|
||||
{{- if $.System }}
|
||||
{{ $.System }}
|
||||
{{ end }}<</SYS>>
|
||||
|
||||
{{ .Prompt }} [/INST] {{ .Response }}</s><s>
|
||||
{{ end }}{{ .Content }}
|
||||
{{- else }} [/INST] {{ .Content }}</s><s>
|
||||
{{- end }}
|
||||
{{- end }} [/INST]
|
||||
{{- else }}
|
||||
[INST] <<SYS>>{{ .System }}<</SYS>>
|
||||
|
||||
{{ .Prompt }} [/INST] {{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,7 +1,19 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{{ .System }}<|eot_id|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
|
||||
|
||||
{{ .Content }}<|eot_id|>
|
||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ else }}
|
||||
{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ .Response }}<|eot_id|>
|
||||
{{ .Response }}<|eot_id|>
|
||||
{{- end }}
|
||||
@@ -1,3 +1,15 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}@@ Instruction
|
||||
{{- else if eq .Role "assistant" }}@@ Response
|
||||
{{- end }}
|
||||
{{ .Content }}
|
||||
|
||||
{{ end }}@@ Response
|
||||
{{ else }}
|
||||
{{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ if .Prompt }}@@ Instruction
|
||||
@@ -5,4 +17,4 @@
|
||||
|
||||
{{ end }}@@ Response
|
||||
{{ .Response }}
|
||||
|
||||
{{- end }}
|
||||
@@ -1,3 +1,9 @@
|
||||
[INST] {{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}</s>
|
||||
{{- if .Messages }}
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}
|
||||
{{ end }}{{ .Content }}
|
||||
{{- else if eq .Role "assistant" }}[/INST] {{ .Content }}</s>
|
||||
{{- end }}
|
||||
{{- end }}[/INST]
|
||||
{{- else }}[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1 +1,11 @@
|
||||
{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>{{ end }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}GPT Correct System: {{ .System }}<|end_of_turn|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}GPT Correct
|
||||
{{- if eq .Role "user" }} User:
|
||||
{{- else if eq .Role "assistant" }} Assistant:
|
||||
{{- end }} {{ .Content }}<|end_of_turn|>
|
||||
{{- end }}GPT Correct Assistant:
|
||||
{{- else }}
|
||||
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
|
||||
{{- end }}
|
||||
@@ -1,6 +1,15 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}<|system|>
|
||||
{{ .System }}<|end|>
|
||||
{{ end }}
|
||||
{{- range .Messages }}<|{{ .Role }}|>
|
||||
{{ .Content }}<|end|>
|
||||
{{ end }}<|assistant|>
|
||||
{{ else }}
|
||||
{{ if .System }}<|system|>
|
||||
{{ .System }}<|end|>
|
||||
{{ end }}{{ if .Prompt }}<|user|>
|
||||
{{ .Prompt }}<|end|>
|
||||
{{ end }}<|assistant|>
|
||||
{{ .Response }}<|end|>
|
||||
{{- end }}
|
||||
@@ -1,3 +1,16 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}### System:
|
||||
{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}### User:
|
||||
{{ .Content }}
|
||||
{{ else if eq .Role "assistant" }}### Assistant:
|
||||
{{ .Content }}</s>
|
||||
{{ end }}
|
||||
{{ end }}### Assistant:
|
||||
{{ else }}
|
||||
{{ if .System }}### System:
|
||||
{{ .System }}
|
||||
|
||||
@@ -5,5 +18,5 @@
|
||||
{{ .Prompt }}
|
||||
|
||||
{{ end }}### Assistant:
|
||||
{{ .Response }}</s>
|
||||
|
||||
{{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,8 +1,24 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}### Instruction
|
||||
{{ .Content }}
|
||||
|
||||
{{ else if eq .Role "assistant" }}### Response
|
||||
{{ .Content }}<|endoftext|>
|
||||
|
||||
{{ end }}
|
||||
{{- end }}### Response
|
||||
{{ else }}
|
||||
{{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ if .Prompt }}### Instruction
|
||||
{{ .Prompt }}
|
||||
|
||||
|
||||
{{ end }}### Response
|
||||
{{ .Response }}<|endoftext|>
|
||||
|
||||
{{- end }}
|
||||
@@ -143,14 +143,11 @@ func (t *Template) Vars() []string {
|
||||
|
||||
type Values struct {
|
||||
Messages []api.Message
|
||||
|
||||
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||
forceLegacy bool
|
||||
}
|
||||
|
||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
system, collated := collate(v.Messages)
|
||||
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||
if slices.Contains(t.Vars(), "messages") {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
"System": system,
|
||||
"Messages": collated,
|
||||
@@ -160,46 +157,39 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
var b bytes.Buffer
|
||||
var prompt, response string
|
||||
for i, m := range collated {
|
||||
switch m.Role {
|
||||
case "system":
|
||||
system = m.Content
|
||||
case "user":
|
||||
if m.Role == "user" {
|
||||
prompt = m.Content
|
||||
case "assistant":
|
||||
} else {
|
||||
response = m.Content
|
||||
}
|
||||
|
||||
if i != len(collated)-1 && prompt != "" && response != "" {
|
||||
if err := t.Template.Execute(&b, map[string]any{
|
||||
"System": system,
|
||||
"System": "",
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
system = ""
|
||||
prompt = ""
|
||||
response = ""
|
||||
}
|
||||
}
|
||||
|
||||
var cut bool
|
||||
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
||||
switch t := n.(type) {
|
||||
case *parse.ActionNode:
|
||||
case *parse.FieldNode:
|
||||
if slices.Contains(t.Ident, "Response") {
|
||||
cut = true
|
||||
}
|
||||
tree := t.Template.Copy()
|
||||
// for the last message, cut everything after "{{ .Response }}"
|
||||
tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
|
||||
if slices.Contains(parseNode(n), "Response") {
|
||||
cut = true
|
||||
}
|
||||
|
||||
return cut
|
||||
})
|
||||
|
||||
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
|
||||
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||
"System": "",
|
||||
if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
}); err != nil {
|
||||
return err
|
||||
@@ -209,16 +199,25 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// collate messages based on role. consecutive messages of the same role are merged
|
||||
// into a single message. collate also collects and returns all system messages.
|
||||
// collate mutates message content adding image tags ([img-%d]) as needed
|
||||
func collate(msgs []api.Message) (string, []*api.Message) {
|
||||
var n int
|
||||
type messages []*api.Message
|
||||
|
||||
var system []string
|
||||
var collated []*api.Message
|
||||
// collate messages based on role. consecutive messages of the same role are merged
|
||||
// into a single message. collate also pulls out and merges messages with Role == "system"
|
||||
// which are templated separately. As a side effect, it mangles message content adding image
|
||||
// tags ([img-%d]) as needed
|
||||
func collate(msgs []api.Message) (system string, collated messages) {
|
||||
var n int
|
||||
for i := range msgs {
|
||||
msg := msgs[i]
|
||||
if msg.Role == "system" {
|
||||
if system != "" {
|
||||
system += "\n\n"
|
||||
}
|
||||
|
||||
system += msg.Content
|
||||
continue
|
||||
}
|
||||
|
||||
for range msg.Images {
|
||||
imageTag := fmt.Sprintf("[img-%d]", n)
|
||||
if !strings.Contains(msg.Content, "[img]") {
|
||||
@@ -229,10 +228,6 @@ func collate(msgs []api.Message) (string, []*api.Message) {
|
||||
n++
|
||||
}
|
||||
|
||||
if msg.Role == "system" {
|
||||
system = append(system, msg.Content)
|
||||
}
|
||||
|
||||
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
|
||||
collated[len(collated)-1].Content += "\n\n" + msg.Content
|
||||
} else {
|
||||
@@ -240,7 +235,7 @@ func collate(msgs []api.Message) (string, []*api.Message) {
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(system, "\n\n"), collated
|
||||
return
|
||||
}
|
||||
|
||||
func parseNode(n parse.Node) []string {
|
||||
@@ -291,72 +286,3 @@ func parseNode(n parse.Node) []string {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteNode walks the node list and deletes nodes that match the predicate
|
||||
// this is currently to remove the {{ .Response }} node from templates
|
||||
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
|
||||
var walk func(n parse.Node) parse.Node
|
||||
walk = func(n parse.Node) parse.Node {
|
||||
if fn(n) {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch t := n.(type) {
|
||||
case *parse.ListNode:
|
||||
var nodes []parse.Node
|
||||
for _, c := range t.Nodes {
|
||||
if n := walk(c); n != nil {
|
||||
nodes = append(nodes, n)
|
||||
}
|
||||
}
|
||||
|
||||
t.Nodes = nodes
|
||||
return t
|
||||
case *parse.IfNode:
|
||||
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
|
||||
case *parse.WithNode:
|
||||
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
|
||||
case *parse.RangeNode:
|
||||
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
|
||||
case *parse.BranchNode:
|
||||
t.List = walk(t.List).(*parse.ListNode)
|
||||
if t.ElseList != nil {
|
||||
t.ElseList = walk(t.ElseList).(*parse.ListNode)
|
||||
}
|
||||
case *parse.ActionNode:
|
||||
n := walk(t.Pipe)
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Pipe = n.(*parse.PipeNode)
|
||||
case *parse.PipeNode:
|
||||
var commands []*parse.CommandNode
|
||||
for _, c := range t.Cmds {
|
||||
var args []parse.Node
|
||||
for _, a := range c.Args {
|
||||
if n := walk(a); n != nil {
|
||||
args = append(args, n)
|
||||
}
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.Args = args
|
||||
commands = append(commands, c)
|
||||
}
|
||||
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Cmds = commands
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
return walk(n)
|
||||
}
|
||||
|
||||
@@ -105,8 +105,8 @@ func TestTemplate(t *testing.T) {
|
||||
}
|
||||
|
||||
for n, tt := range cases {
|
||||
var actual bytes.Buffer
|
||||
t.Run(n, func(t *testing.T) {
|
||||
var actual bytes.Buffer
|
||||
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -116,34 +116,7 @@ func TestTemplate(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bts := actual.Bytes()
|
||||
|
||||
if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
|
||||
t.Log("removing trailing space from output")
|
||||
bts = bts[:len(bts)-1]
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(bts, expect); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy", func(t *testing.T) {
|
||||
t.Skip("legacy outputs are currently default outputs")
|
||||
var legacy bytes.Buffer
|
||||
if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
legacyBytes := legacy.Bytes()
|
||||
if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
|
||||
t.Log("removing trailing space from legacy output")
|
||||
legacyBytes = legacyBytes[:len(legacyBytes)-1]
|
||||
} else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
|
||||
t.Skip("legacy outputs cannot be compared to messages outputs")
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
|
||||
if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -162,24 +135,7 @@ func TestParse(t *testing.T) {
|
||||
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
|
||||
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
|
||||
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
|
||||
{`{{- range .Messages }}
|
||||
{{- if eq .Role "system" }}SYSTEM:
|
||||
{{- else if eq .Role "user" }}USER:
|
||||
{{- else if eq .Role "assistant" }}ASSISTANT:
|
||||
{{- end }} {{ .Content }}
|
||||
{{- end }}`, []string{"content", "messages", "role"}},
|
||||
{`{{- if .Messages }}
|
||||
{{- range .Messages }}<|im_start|>{{ .Role }}
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ else -}}
|
||||
{{ if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ .Response }}<|im_end|>
|
||||
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
|
||||
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
@@ -189,8 +145,9 @@ func TestParse(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
vars := tmpl.Vars()
|
||||
if !slices.Equal(tt.vars, vars) {
|
||||
t.Errorf("expected %v, got %v", tt.vars, vars)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -210,17 +167,12 @@ func TestExecuteWithMessages(t *testing.T) {
|
||||
{
|
||||
"mistral",
|
||||
[]template{
|
||||
{"no response", `[INST] {{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ .Prompt }}[/INST] `},
|
||||
{"response", `[INST] {{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||
{"messages", `[INST] {{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
|
||||
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
|
||||
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||
{"messages", `{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
|
||||
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
||||
{{- end }}
|
||||
{{- end }}`},
|
||||
},
|
||||
Values{
|
||||
@@ -235,17 +187,13 @@ func TestExecuteWithMessages(t *testing.T) {
|
||||
{
|
||||
"mistral system",
|
||||
[]template{
|
||||
{"no response", `[INST] {{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ .Prompt }}[/INST] `},
|
||||
{"response", `[INST] {{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||
{"messages", `[INST] {{ if .System }}{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
|
||||
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
|
||||
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||
{"messages", `
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
|
||||
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
||||
{{- end }}
|
||||
{{- end }}`},
|
||||
},
|
||||
Values{
|
||||
@@ -256,9 +204,9 @@ func TestExecuteWithMessages(t *testing.T) {
|
||||
{Role: "user", Content: "What is your name?"},
|
||||
},
|
||||
},
|
||||
`[INST] You are a helpful assistant!
|
||||
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
|
||||
|
||||
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||||
What is your name?[/INST] `,
|
||||
},
|
||||
{
|
||||
"chatml",
|
||||
@@ -272,9 +220,12 @@ Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||||
{{ .Response }}<|im_end|>
|
||||
`},
|
||||
{"messages", `
|
||||
{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
|
||||
{{ $.System }}<|im_end|>{{ "\n" }}
|
||||
{{- end }}<|im_start|>{{ .Role }}
|
||||
{{ .Content }}<|im_end|>{{ "\n" }}
|
||||
{{- end }}<|im_start|>assistant
|
||||
`},
|
||||
},
|
||||
Values{
|
||||
@@ -285,12 +236,12 @@ Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||||
{Role: "user", Content: "What is your name?"},
|
||||
},
|
||||
},
|
||||
`<|im_start|>system
|
||||
You are a helpful assistant!<|im_end|>
|
||||
<|im_start|>user
|
||||
`<|im_start|>user
|
||||
Hello friend!<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hello human!<|im_end|>
|
||||
<|im_start|>system
|
||||
You are a helpful assistant!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is your name?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
@@ -307,11 +258,9 @@ What is your name?<|im_end|>
|
||||
`},
|
||||
{"messages", `
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}Question: {{ .Content }}
|
||||
|
||||
{{ else if eq .Role "assistant" }}Answer: {{ .Content }}
|
||||
|
||||
{{ end }}
|
||||
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
|
||||
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
|
||||
{{- end }}
|
||||
{{- end }}Answer: `},
|
||||
},
|
||||
Values{
|
||||
@@ -351,8 +300,8 @@ Answer: `,
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
if b.String() != tt.expected {
|
||||
t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
You are a helpful assistant.
|
||||
|
||||
### Instruction:
|
||||
You are a helpful assistant.### Instruction:
|
||||
Hello, how are you?
|
||||
|
||||
### Response:
|
||||
|
||||
@@ -9,4 +9,3 @@ Source: system
|
||||
I'd like to show off how chat templating works! <step> Source: assistant
|
||||
Destination: user
|
||||
|
||||
|
||||
@@ -3,4 +3,3 @@ Source: user
|
||||
Hello, how are you? <step> Source: assistant
|
||||
Destination: user
|
||||
|
||||
|
||||
@@ -7,4 +7,3 @@ Source: user
|
||||
I'd like to show off how chat templating works! <step> Source: assistant
|
||||
Destination: user
|
||||
|
||||
|
||||
@@ -2,6 +2,4 @@
|
||||
You are a helpful assistant.
|
||||
<</SYS>>
|
||||
|
||||
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] <<SYS>><</SYS>>
|
||||
|
||||
I'd like to show off how chat templating works! [/INST]
|
||||
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]
|
||||
@@ -1,5 +1,3 @@
|
||||
[INST] <<SYS>><</SYS>>
|
||||
|
||||
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] <<SYS>><</SYS>>
|
||||
|
||||
I'd like to show off how chat templating works! [/INST]
|
||||
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]
|
||||
@@ -1,3 +1,2 @@
|
||||
[INST] You are a helpful assistant.
|
||||
|
||||
Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works![/INST]
|
||||
[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] You are a helpful assistant.
|
||||
I'd like to show off how chat templating works![/INST]
|
||||
@@ -1 +1 @@
|
||||
GPT4 Correct System: You are a helpful assistant.<|end_of_turn|>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT4 Correct Assistant:
|
||||
GPT Correct System: You are a helpful assistant.<|end_of_turn|>GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:
|
||||
2
template/testdata/openchat.gotmpl/user
vendored
2
template/testdata/openchat.gotmpl/user
vendored
@@ -1 +1 @@
|
||||
GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant:
|
||||
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant:
|
||||
@@ -1 +1 @@
|
||||
GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT4 Correct Assistant:
|
||||
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:
|
||||
@@ -1,4 +1,14 @@
|
||||
{{ if .System }}{{ .System }}
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}{{ .System }}
|
||||
|
||||
{{ end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}USER: {{ .Content }}
|
||||
{{ else if eq .Role "assistant" }}ASSISTANT: {{ .Content }}</s>
|
||||
{{ end }}
|
||||
{{- end }}ASSISTANT:
|
||||
{{- else }}
|
||||
{{ if .System }}{{ .System }}
|
||||
{{ end }}{{ if .Prompt }}USER: {{ .Prompt }}
|
||||
{{ end }}ASSISTANT: {{ .Response }}</s>
|
||||
{{ end }}ASSISTANT: {{ .Response }}
|
||||
{{- end }}
|
||||
@@ -1,6 +1,15 @@
|
||||
{{- if .Messages }}
|
||||
{{- if .System }}<|system|>
|
||||
{{ .System }}</s>
|
||||
{{ end }}
|
||||
{{- range .Messages }}<|{{ .Role }}|>
|
||||
{{ .Content }}</s>
|
||||
{{ end }}<|assistant|>
|
||||
{{ else }}
|
||||
{{ if .System }}<|system|>
|
||||
{{ .System }}</s>
|
||||
{{ end }}{{ if .Prompt }}<|user|>
|
||||
{{ .Prompt }}</s>
|
||||
{{ end }}<|assistant|>
|
||||
{{ .Response }}</s>
|
||||
{{- end }}
|
||||
Reference in New Issue
Block a user