mirror of
https://github.com/ollama/ollama.git
synced 2026-01-01 20:18:52 -05:00
Compare commits
57 Commits
pdevine/fi
...
jyan/quant
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a548eb6003 | ||
|
|
f92818d90d | ||
|
|
1ef59057d0 | ||
|
|
106fe6b4ae | ||
|
|
5fd359d117 | ||
|
|
b0e4e8d76c | ||
|
|
e59453982d | ||
|
|
369113970a | ||
|
|
26ed829415 | ||
|
|
542134bf50 | ||
|
|
9e0b8f1fe2 | ||
|
|
c498609ba3 | ||
|
|
c800a67f1b | ||
|
|
dfc62648f3 | ||
|
|
24e8292e94 | ||
|
|
c63b4ecbf7 | ||
|
|
ee2b9b076c | ||
|
|
bec9100f32 | ||
|
|
1344843515 | ||
|
|
e87eafe5cd | ||
|
|
6bab0e2368 | ||
|
|
c4cccaf936 | ||
|
|
9fe5c393e4 | ||
|
|
007c988dba | ||
|
|
91d21e7c7b | ||
|
|
3e64284f69 | ||
|
|
39910f2ab2 | ||
|
|
96d0cd92f2 | ||
|
|
3a724a7c80 | ||
|
|
f520f0056e | ||
|
|
d25f85ede4 | ||
|
|
b48420b74b | ||
|
|
784958a1cb | ||
|
|
ae65cc8dea | ||
|
|
a037528bba | ||
|
|
04bf41deb5 | ||
|
|
c23cec9547 | ||
|
|
8377dc48d0 | ||
|
|
3aee405dfa | ||
|
|
9b3f47b674 | ||
|
|
f5441f01a2 | ||
|
|
ab165df43a | ||
|
|
79cc4c9585 | ||
|
|
bc3f59a6ad | ||
|
|
1a85cb904c | ||
|
|
10ea0987e9 | ||
|
|
413d368a6a | ||
|
|
cabf375059 | ||
|
|
ca0ee1d4fe | ||
|
|
1142999aab | ||
|
|
0d5a72aba9 | ||
|
|
ea837412c2 | ||
|
|
736ad6f438 | ||
|
|
64607d16a5 | ||
|
|
a6cfe7f00b | ||
|
|
c3b411a515 | ||
|
|
928f37e3ae |
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -147,7 +147,7 @@ jobs:
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading AMD HIP Installer"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
write-host "Installing AMD HIP"
|
||||
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
||||
write-host "Completed AMD HIP"
|
||||
|
||||
2
.github/workflows/test.yaml
vendored
2
.github/workflows/test.yaml
vendored
@@ -169,7 +169,7 @@ jobs:
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading AMD HIP Installer"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
write-host "Installing AMD HIP"
|
||||
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
||||
write-host "Completed AMD HIP"
|
||||
|
||||
@@ -17,14 +17,20 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -374,3 +380,27 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
||||
|
||||
return version.Version, nil
|
||||
}
|
||||
|
||||
func Authorization(ctx context.Context, request *http.Request) (string, error) {
|
||||
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer knownHostsFile.Close()
|
||||
|
||||
token, err := auth.Sign(ctx, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// interleave request data into the token
|
||||
key, sig, _ := strings.Cut(token, ":")
|
||||
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
|
||||
}
|
||||
|
||||
@@ -84,9 +84,6 @@ type ChatRequest struct {
|
||||
// Model is the model name, as in [GenerateRequest].
|
||||
Model string `json:"model"`
|
||||
|
||||
// Template overrides the model's default prompt template.
|
||||
Template string `json:"template"`
|
||||
|
||||
// Messages is the messages of the chat - can be used to keep a chat memory.
|
||||
Messages []Message `json:"messages"`
|
||||
|
||||
@@ -270,6 +267,7 @@ type PullRequest struct {
|
||||
type ProgressResponse struct {
|
||||
Status string `json:"status"`
|
||||
Digest string `json:"digest,omitempty"`
|
||||
Quantize string `json:"quantize,omitempty"`
|
||||
Total int64 `json:"total,omitempty"`
|
||||
Completed int64 `json:"completed,omitempty"`
|
||||
}
|
||||
|
||||
54
auth/auth.go
54
auth/auth.go
@@ -10,42 +10,37 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
func keyPath() (string, error) {
|
||||
func keyPath() (ssh.Signer, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||
return ssh.ParsePrivateKey(privateKeyFile)
|
||||
}
|
||||
|
||||
func GetPublicKey() (ssh.PublicKey, error) {
|
||||
privateKey, err := keyPath()
|
||||
// if privateKey, try public key directly
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||
|
||||
return strings.TrimSpace(string(publicKey)), nil
|
||||
return privateKey.PublicKey(), nil
|
||||
}
|
||||
|
||||
func NewNonce(r io.Reader, length int) (string, error) {
|
||||
@@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||
}
|
||||
|
||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||
privateKey, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// get the pubkey, but remove the type
|
||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||
parts := bytes.Split(publicKey, []byte(" "))
|
||||
publicKey, err := GetPublicKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
|
||||
|
||||
parts := bytes.Split(publicKeyBytes, []byte(" "))
|
||||
if len(parts) < 2 {
|
||||
return "", fmt.Errorf("malformed public key")
|
||||
}
|
||||
|
||||
181
cmd/cmd.go
181
cmd/cmd.go
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
@@ -78,6 +80,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
status := "transferring model data"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
defer p.Stop()
|
||||
|
||||
for i := range modelfile.Commands {
|
||||
switch modelfile.Commands[i].Name {
|
||||
@@ -112,16 +115,17 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
path = tempfile
|
||||
}
|
||||
|
||||
digest, err := createBlob(cmd, client, path)
|
||||
// spinner.Stop()
|
||||
digest, err := createBlob(cmd, client, path, spinner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelfile.Commands[i].Args = "@" + digest
|
||||
}
|
||||
}
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
var quantizeSpin *progress.Spinner
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
spinner.Stop()
|
||||
@@ -134,11 +138,20 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if resp.Quantize != "" {
|
||||
spinner.Stop()
|
||||
|
||||
if quantizeSpin != nil {
|
||||
quantizeSpin.SetMessage(resp.Status)
|
||||
} else {
|
||||
quantizeSpin = progress.NewSpinner(resp.Status)
|
||||
p.Add("quantize", quantizeSpin)
|
||||
}
|
||||
} else if status != resp.Status {
|
||||
spinner.Stop()
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
@@ -263,13 +276,22 @@ func tempZipFiles(path string) (string, error) {
|
||||
return tempfile.Name(), nil
|
||||
}
|
||||
|
||||
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
||||
var ErrBlobExists = errors.New("blob exists")
|
||||
|
||||
func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
|
||||
bin, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
// Get file info to retrieve the size
|
||||
fileInfo, err := bin.Stat()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fileSize := fileInfo.Size()
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, bin); err != nil {
|
||||
return "", err
|
||||
@@ -279,13 +301,157 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
||||
return "", err
|
||||
}
|
||||
|
||||
var pw progressWriter
|
||||
// Create a progress bar and start a goroutine to update it
|
||||
// JK Let's use a percentage
|
||||
|
||||
//bar := progress.NewBar("transferring model data...", fileSize, 0)
|
||||
//p.Add("transferring model data", bar)
|
||||
|
||||
status := "transferring model data 0%"
|
||||
spinner.SetMessage(status)
|
||||
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n/fileSize)))
|
||||
case <-done:
|
||||
spinner.SetMessage("transferring model data 100%")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
||||
|
||||
// We check if we can find the models directory locally
|
||||
// If we can, we return the path to the directory
|
||||
// If we can't, we return an error
|
||||
// If the blob exists already, we return the digest
|
||||
dest, err := getLocalPath(cmd.Context(), digest)
|
||||
|
||||
if errors.Is(err, ErrBlobExists) {
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
// Successfully found the model directory
|
||||
if err == nil {
|
||||
// Copy blob in via OS specific copy
|
||||
// Linux errors out to use io.copy
|
||||
err = localCopy(path, dest)
|
||||
if err == nil {
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
// Default copy using io.copy
|
||||
err = defaultCopy(path, dest)
|
||||
if err == nil {
|
||||
return digest, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If at any point copying the blob over locally fails, we default to the copy through the server
|
||||
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
n int64
|
||||
}
|
||||
|
||||
func (w *progressWriter) Write(p []byte) (n int, err error) {
|
||||
w.n += int64(len(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func getLocalPath(ctx context.Context, digest string) (string, error) {
|
||||
ollamaHost := envconfig.Host
|
||||
|
||||
client := http.DefaultClient
|
||||
base := &url.URL{
|
||||
Scheme: ollamaHost.Scheme,
|
||||
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(digest)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
reqBody := bytes.NewReader(data)
|
||||
path := fmt.Sprintf("/api/blobs/%s", digest)
|
||||
requestURL := base.JoinPath(path)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
authz, err := api.Authorization(ctx, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
request.Header.Set("Authorization", authz)
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
request.Header.Set("X-Redirect-Create", "1")
|
||||
|
||||
resp, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusTemporaryRedirect {
|
||||
dest := resp.Header.Get("LocalLocation")
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
return "", ErrBlobExists
|
||||
}
|
||||
|
||||
func defaultCopy(path string, dest string) error {
|
||||
// This function should be called if the server is local
|
||||
// It should find the model directory, copy the blob over, and return the digest
|
||||
dirPath := filepath.Dir(dest)
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy blob over
|
||||
sourceFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not open source file: %v", err)
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create destination file: %v", err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error copying file: %v", err)
|
||||
}
|
||||
|
||||
err = destFile.Sync()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error flushing file: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
@@ -379,11 +545,13 @@ func errFromUnknownKey(unknownKeyErr error) error {
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
publicKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
@@ -947,7 +1115,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
|
||||
req := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Template: opts.Template,
|
||||
Messages: opts.Messages,
|
||||
Format: opts.Format,
|
||||
Options: opts.Options,
|
||||
|
||||
23
cmd/copy_darwin.go
Normal file
23
cmd/copy_darwin.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func localCopy(src, target string) error {
|
||||
dirPath := filepath.Dir(target)
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := unix.Clonefile(src, target, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
7
cmd/copy_linux.go
Normal file
7
cmd/copy_linux.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package cmd
|
||||
|
||||
import "errors"
|
||||
|
||||
func localCopy(src, target string) error {
|
||||
return errors.New("no local copy implementation for linux")
|
||||
}
|
||||
67
cmd/copy_windows.go
Normal file
67
cmd/copy_windows.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func localCopy(src, target string) error {
|
||||
// Create target directory if it doesn't exist
|
||||
dirPath := filepath.Dir(target)
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Open source file
|
||||
sourceFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
// Create target file
|
||||
targetFile, err := os.Create(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer targetFile.Close()
|
||||
|
||||
// Use CopyFileExW to copy the file
|
||||
err = copyFileEx(src, target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyFileEx(src, dst string) error {
|
||||
kernel32 := syscall.NewLazyDLL("kernel32.dll")
|
||||
copyFileEx := kernel32.NewProc("CopyFileExW")
|
||||
|
||||
srcPtr, err := syscall.UTF16PtrFromString(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dstPtr, err := syscall.UTF16PtrFromString(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r1, _, err := copyFileEx.Call(
|
||||
uintptr(unsafe.Pointer(srcPtr)),
|
||||
uintptr(unsafe.Pointer(dstPtr)),
|
||||
0, 0, 0, 0)
|
||||
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
@@ -206,17 +205,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
case MultilineTemplate:
|
||||
mTemplate := sb.String()
|
||||
sb.Reset()
|
||||
_, err := template.Parse(mTemplate)
|
||||
if err != nil {
|
||||
multiline = MultilineNone
|
||||
scanner.Prompt.UseAlt = false
|
||||
fmt.Println("The template is invalid.")
|
||||
continue
|
||||
}
|
||||
opts.Template = mTemplate
|
||||
opts.Template = sb.String()
|
||||
fmt.Println("Set prompt template.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
multiline = MultilineNone
|
||||
@@ -378,15 +369,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
} else if args[1] == "template" {
|
||||
mTemplate := sb.String()
|
||||
sb.Reset()
|
||||
_, err := template.Parse(mTemplate)
|
||||
if err != nil {
|
||||
fmt.Println("The template is invalid.")
|
||||
continue
|
||||
}
|
||||
opts.Template = mTemplate
|
||||
opts.Template = sb.String()
|
||||
fmt.Println("Set prompt template.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
|
||||
@@ -272,4 +272,4 @@ The following server settings may be used to adjust how Ollama handles concurren
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
||||
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
||||
@@ -49,17 +49,9 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
}
|
||||
|
||||
func commonAMDValidateLibDir() (string, error) {
|
||||
// Favor our bundled version
|
||||
|
||||
// Installer payload location if we're running the installed binary
|
||||
exe, err := os.Executable()
|
||||
if err == nil {
|
||||
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
||||
if rocmLibUsable(rocmTargetDir) {
|
||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||
return rocmTargetDir, nil
|
||||
}
|
||||
}
|
||||
// We try to favor system paths first, so that we can wire up the subprocess to use
|
||||
// the system version. Only use our bundled version if the system version doesn't work
|
||||
// This gives users a more recovery options if versions have subtle problems at runtime
|
||||
|
||||
// Prefer explicit HIP env var
|
||||
hipPath := os.Getenv("HIP_PATH")
|
||||
@@ -95,5 +87,14 @@ func commonAMDValidateLibDir() (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Installer payload location if we're running the installed binary
|
||||
exe, err := os.Executable()
|
||||
if err == nil {
|
||||
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
||||
if rocmLibUsable(rocmTargetDir) {
|
||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||
return rocmTargetDir, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||||
}
|
||||
|
||||
@@ -84,8 +84,9 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
||||
}
|
||||
|
||||
slog.Debug("hipDriverGetVersion", "version", version)
|
||||
driverMajor = version / 10000000
|
||||
driverMinor = (version - (driverMajor * 10000000)) / 100000
|
||||
// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
|
||||
driverMajor = version / 1000
|
||||
driverMinor = (version - (driverMajor * 1000)) / 10
|
||||
|
||||
return driverMajor, driverMinor, nil
|
||||
}
|
||||
|
||||
@@ -22,8 +22,8 @@ const (
|
||||
|
||||
var (
|
||||
// Used to validate if the given ROCm lib is usable
|
||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
|
||||
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
|
||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
|
||||
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
|
||||
)
|
||||
|
||||
func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
@@ -35,11 +35,12 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
}
|
||||
defer hl.Release()
|
||||
|
||||
driverMajor, driverMinor, err := hl.AMDDriverVersion()
|
||||
if err != nil {
|
||||
// For now this is benign, but we may eventually need to fail compatibility checks
|
||||
slog.Debug("error looking up amd driver version", "error", err)
|
||||
}
|
||||
// TODO - this reports incorrect version information, so omitting for now
|
||||
// driverMajor, driverMinor, err := hl.AMDDriverVersion()
|
||||
// if err != nil {
|
||||
// // For now this is benign, but we may eventually need to fail compatibility checks
|
||||
// slog.Debug("error looking up amd driver version", "error", err)
|
||||
// }
|
||||
|
||||
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
|
||||
count := hl.HipGetDeviceCount()
|
||||
@@ -131,8 +132,10 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
Name: name,
|
||||
Compute: gfx,
|
||||
DriverMajor: driverMajor,
|
||||
DriverMinor: driverMinor,
|
||||
|
||||
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
|
||||
// DriverMajor: driverMajor,
|
||||
// DriverMinor: driverMinor,
|
||||
},
|
||||
index: i,
|
||||
}
|
||||
|
||||
27
gpu/gpu.go
27
gpu/gpu.go
@@ -274,28 +274,6 @@ func GetGPUInfo() GpuInfoList {
|
||||
gpuInfo.DriverMajor = driverMajor
|
||||
gpuInfo.DriverMinor = driverMinor
|
||||
|
||||
// query the management library as well so we can record any skew between the two
|
||||
// which represents overhead on the GPU we must set aside on subsequent updates
|
||||
if cHandles.nvml != nil {
|
||||
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
|
||||
if memInfo.err != nil {
|
||||
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
} else {
|
||||
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
|
||||
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
|
||||
slog.Info("detected OS VRAM overhead",
|
||||
"id", gpuInfo.ID,
|
||||
"library", gpuInfo.Library,
|
||||
"compute", gpuInfo.Compute,
|
||||
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
|
||||
"name", gpuInfo.Name,
|
||||
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||
cudaGPUs = append(cudaGPUs, gpuInfo)
|
||||
}
|
||||
@@ -396,14 +374,9 @@ func GetGPUInfo() GpuInfoList {
|
||||
slog.Warn("error looking up nvidia GPU memory")
|
||||
continue
|
||||
}
|
||||
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
|
||||
// When using the management library update based on recorded overhead
|
||||
memInfo.free -= C.uint64_t(gpu.OSOverhead)
|
||||
}
|
||||
slog.Debug("updating cuda memory data",
|
||||
"gpu", gpu.ID,
|
||||
"name", gpu.Name,
|
||||
"overhead", format.HumanBytes2(gpu.OSOverhead),
|
||||
slog.Group(
|
||||
"before",
|
||||
"total", format.HumanBytes2(gpu.TotalMemory),
|
||||
|
||||
@@ -52,8 +52,7 @@ type CPUInfo struct {
|
||||
|
||||
type CudaGPUInfo struct {
|
||||
GpuInfo
|
||||
OSOverhead uint64 // Memory overhead between the driver library and management library
|
||||
index int //nolint:unused,nolintlint
|
||||
index int //nolint:unused,nolintlint
|
||||
}
|
||||
type CudaGPUInfoList []CudaGPUInfo
|
||||
|
||||
|
||||
@@ -6,9 +6,18 @@ function amdGPUs {
|
||||
if ($env:AMDGPU_TARGETS) {
|
||||
return $env:AMDGPU_TARGETS
|
||||
}
|
||||
# Current supported rocblas list from ROCm v6.1.2 on windows
|
||||
# TODO - load from some common data file for linux + windows build consistency
|
||||
$GPU_LIST = @(
|
||||
"gfx900"
|
||||
"gfx906:xnack-"
|
||||
"gfx908:xnack-"
|
||||
"gfx90a:xnack+"
|
||||
"gfx90a:xnack-"
|
||||
"gfx940"
|
||||
"gfx941"
|
||||
"gfx942"
|
||||
"gfx1010"
|
||||
"gfx1012"
|
||||
"gfx1030"
|
||||
"gfx1100"
|
||||
"gfx1101"
|
||||
@@ -386,6 +395,7 @@ function build_rocm() {
|
||||
sign
|
||||
install
|
||||
|
||||
# Assumes v5.7, may need adjustments for v6
|
||||
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
|
||||
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||
|
||||
43
llm/llm.go
43
llm/llm.go
@@ -10,10 +10,17 @@ package llm
|
||||
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
|
||||
// #include <stdlib.h>
|
||||
// #include "llama.h"
|
||||
// bool update_quantize_progress(float progress, void* data) {
|
||||
// *((float*)data) = progress;
|
||||
// return true;
|
||||
// }
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// SystemInfo is an unused example of calling llama.cpp functions using CGo
|
||||
@@ -21,7 +28,7 @@ func SystemInfo() string {
|
||||
return C.GoString(C.llama_print_system_info())
|
||||
}
|
||||
|
||||
func Quantize(infile, outfile string, ftype fileType) error {
|
||||
func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error {
|
||||
cinfile := C.CString(infile)
|
||||
defer C.free(unsafe.Pointer(cinfile))
|
||||
|
||||
@@ -32,6 +39,40 @@ func Quantize(infile, outfile string, ftype fileType) error {
|
||||
params.nthread = -1
|
||||
params.ftype = ftype.Value()
|
||||
|
||||
// Initialize "global" to store progress
|
||||
store := C.malloc(C.sizeof_float)
|
||||
defer C.free(unsafe.Pointer(store))
|
||||
|
||||
// Initialize store value, e.g., setting initial progress to 0
|
||||
*(*C.float)(store) = 0.0
|
||||
|
||||
params.quantize_callback_data = store
|
||||
params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
|
||||
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("quantizing model %d/%d", int(*((*C.float)(store))), tensorCount),
|
||||
Quantize: "quant",
|
||||
})
|
||||
fmt.Println("Progress: ", *((*C.float)(store)))
|
||||
case <-done:
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("quantizing model %d/%d", tensorCount, tensorCount),
|
||||
Quantize: "quant",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if rc := C.llama_model_quantize(cinfile, coutfile, ¶ms); rc != 0 {
|
||||
return fmt.Errorf("llama_model_quantize: %d", rc)
|
||||
}
|
||||
|
||||
53
llm/patches/10-quantize-progress.diff
Normal file
53
llm/patches/10-quantize-progress.diff
Normal file
@@ -0,0 +1,53 @@
|
||||
From fa509abf281177eacdc71a2a14432c4e6ed74a47 Mon Sep 17 00:00:00 2001
|
||||
From: Josh Yan <jyan00017@gmail.com>
|
||||
Date: Wed, 10 Jul 2024 12:58:31 -0700
|
||||
Subject: [PATCH] quantize callback
|
||||
|
||||
---
|
||||
llama.cpp | 8 ++++++++
|
||||
llama.h | 3 +++
|
||||
2 files changed, 11 insertions(+)
|
||||
|
||||
diff --git a/llama.cpp b/llama.cpp
|
||||
index 61948751..d3126510 100644
|
||||
--- a/llama.cpp
|
||||
+++ b/llama.cpp
|
||||
@@ -15586,6 +15586,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
const auto tn = LLM_TN(model.arch);
|
||||
new_ofstream(0);
|
||||
for (int i = 0; i < ml.n_tensors; ++i) {
|
||||
+ if (params->quantize_callback){
|
||||
+ if (!params->quantize_callback(i, params->quantize_callback_data)) {
|
||||
+ return;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
auto weight = ml.get_weight(i);
|
||||
struct ggml_tensor * tensor = weight->tensor;
|
||||
if (weight->idx != cur_split && params->keep_split) {
|
||||
@@ -16119,6 +16125,8 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
|
||||
/*.keep_split =*/ false,
|
||||
/*.imatrix =*/ nullptr,
|
||||
/*.kv_overrides =*/ nullptr,
|
||||
+ /*.quantize_callback =*/ nullptr,
|
||||
+ /*.quantize_callback_data =*/ nullptr,
|
||||
};
|
||||
|
||||
return result;
|
||||
diff --git a/llama.h b/llama.h
|
||||
index da310ffa..3cbe6023 100644
|
||||
--- a/llama.h
|
||||
+++ b/llama.h
|
||||
@@ -337,6 +337,9 @@ extern "C" {
|
||||
bool keep_split; // quantize to the same number of shards
|
||||
void * imatrix; // pointer to importance matrix data
|
||||
void * kv_overrides; // pointer to vector containing overrides
|
||||
+
|
||||
+ llama_progress_callback quantize_callback; // callback to report quantization progress
|
||||
+ void * quantize_callback_data; // user data for the callback
|
||||
} llama_model_quantize_params;
|
||||
|
||||
// grammar types
|
||||
--
|
||||
2.39.3 (Apple Git-146)
|
||||
|
||||
@@ -254,6 +254,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||
}
|
||||
|
||||
if estimate.TensorSplit != "" {
|
||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||
}
|
||||
|
||||
for i := range len(servers) {
|
||||
dir := availableServers[servers[i]]
|
||||
if dir == "" {
|
||||
|
||||
@@ -31,6 +31,10 @@ func NewSpinner(message string) *Spinner {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Spinner) SetMessage(message string) {
|
||||
s.message = message
|
||||
}
|
||||
|
||||
func (s *Spinner) String() string {
|
||||
var sb strings.Builder
|
||||
if len(s.message) > 0 {
|
||||
|
||||
@@ -107,12 +107,9 @@ function gatherDependencies() {
|
||||
|
||||
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
|
||||
# currently works for Win11 + MSVC 2019 + Cuda V11
|
||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
|
||||
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||
}
|
||||
|
||||
|
||||
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var errCapabilityCompletion = errors.New("completion")
|
||||
@@ -421,13 +422,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tensorCount := len(baseLayer.GGML.Tensors())
|
||||
|
||||
ft := baseLayer.GGML.KV().FileType()
|
||||
if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
|
||||
return errors.New("quantization is only supported for F16 and F32 models")
|
||||
} else if want != ft {
|
||||
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})
|
||||
|
||||
blob, err := GetBlobsPath(baseLayer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -440,7 +440,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
defer temp.Close()
|
||||
defer os.Remove(temp.Name())
|
||||
|
||||
if err := llm.Quantize(blob, temp.Name(), want); err != nil {
|
||||
if err := llm.Quantize(blob, temp.Name(), want, fn, tensorCount); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -473,6 +473,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
|
||||
layers = append(layers, baseLayer.Layer)
|
||||
}
|
||||
|
||||
case "license", "template", "system":
|
||||
if c.Name != "license" {
|
||||
// replace
|
||||
@@ -1064,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
if anonymous {
|
||||
// no user is associated with the public key, and the request requires non-anonymous access
|
||||
pubKey, nestedErr := auth.GetPublicKey()
|
||||
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
|
||||
if nestedErr != nil {
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
||||
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
|
||||
}
|
||||
// user is associated with the public key, but is not authorized to make the request
|
||||
return nil, errUnauthorized
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -22,8 +24,10 @@ import (
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/llm"
|
||||
@@ -71,7 +75,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
|
||||
|
||||
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
|
||||
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
|
||||
func (s *Server) scheduleRunner(ctx context.Context, name string, mTemplate string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
|
||||
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
|
||||
if name == "" {
|
||||
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
|
||||
}
|
||||
@@ -81,13 +85,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, mTemplate stri
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if mTemplate != "" {
|
||||
model.Template, err = template.Parse(mTemplate)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := model.CheckCapabilities(caps...); err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
||||
}
|
||||
@@ -127,7 +124,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
caps := []Capability{CapabilityCompletion}
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, "", caps, req.Options, req.KeepAlive)
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||
return
|
||||
@@ -263,7 +260,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, "", []Capability{}, req.Options, req.KeepAlive)
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
@@ -777,7 +774,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = os.Stat(path)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
@@ -790,6 +786,12 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
|
||||
c.Header("LocalLocation", path)
|
||||
c.Status(http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
layer, err := NewLayer(c.Request.Body, "")
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -804,6 +806,54 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
}
|
||||
|
||||
func (s *Server) IsLocal(c *gin.Context) bool {
|
||||
if authz := c.GetHeader("Authorization"); authz != "" {
|
||||
parts := strings.Split(authz, ":")
|
||||
if len(parts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
|
||||
requestData, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
partialRequestDataParts := strings.Split(string(requestData), ",")
|
||||
if len(partialRequestDataParts) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
signature, err := base64.StdEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
serverPublicKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
|
||||
return true
|
||||
}
|
||||
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isLocalIP(ip netip.Addr) bool {
|
||||
if interfaces, err := net.Interfaces(); err == nil {
|
||||
for _, iface := range interfaces {
|
||||
@@ -1139,7 +1189,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
caps := []Capability{CapabilityCompletion}
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, req.Template, caps, req.Options, req.KeepAlive)
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user