Compare commits

..

46 Commits

Author SHA1 Message Date
Roy Han
907b038ff0 reduce error footprint 2024-07-15 10:57:01 -07:00
royjhan
1f73889f34 Merge branch 'royh-batchembed' into royh-embed-parallel 2024-07-12 16:44:12 -07:00
Roy Han
7e313e5964 remove redundant error check 2024-07-12 16:37:29 -07:00
Roy Han
5a8f8e96e0 clean up 2024-07-12 16:35:25 -07:00
Roy Han
7cddd6d741 parallelized 2024-07-12 16:08:12 -07:00
Roy Han
1f3aefd323 remove function closure 2024-07-12 14:45:16 -07:00
Roy Han
2d7048f410 Revert "remove function closure"
This reverts commit 55d48c6ed1.
2024-07-12 14:40:40 -07:00
Roy Han
55d48c6ed1 remove function closure 2024-07-12 14:35:43 -07:00
Roy Han
c0b5bf0a36 testing clean up 2024-07-12 11:45:45 -07:00
Roy Han
53e9576f46 testing clean up 2024-07-11 20:20:14 -07:00
Roy Han
dbe9527305 clean up 2024-07-11 17:28:55 -07:00
Roy Han
694388db90 set context length 2024-07-10 15:21:46 -07:00
Roy Han
cdb9fe9b06 test values 2024-07-10 09:57:36 -07:00
Roy Han
8f6d0242b6 refactoring 2024-07-09 16:19:02 -07:00
Roy Han
c697eb2a9b fix hanging on single string 2024-07-09 15:51:55 -07:00
Roy Han
b686ac144c merge conflicts 2024-07-09 14:00:13 -07:00
royjhan
786848dfd3 Merge branch 'main' into royh-batchembed 2024-07-09 13:48:06 -07:00
Roy Han
fb390b8902 embedding type 64 2024-07-09 13:41:48 -07:00
Roy Han
bcb63e6e0e touches 2024-07-09 13:37:00 -07:00
Roy Han
3342e5f035 merge conflicts 2024-07-08 15:15:09 -07:00
royjhan
b7c622dd32 Merge branch 'main' into royh-batchembed 2024-07-08 15:10:52 -07:00
Roy Han
6caac01494 clear comments 2024-07-03 14:05:34 -07:00
Roy Han
17de2b4405 Refactoring of legacy and new 2024-07-03 14:02:25 -07:00
Roy Han
922b8f2584 input handling and handler testing 2024-07-03 12:48:54 -07:00
Roy Han
c0fa2236cf integration float32 2024-07-03 12:47:57 -07:00
Roy Han
a413014aaf refactoring 2024-07-03 11:21:06 -07:00
royjhan
a5f23d766e Merge branch 'main' into royh-batchembed 2024-07-03 11:20:24 -07:00
Roy Han
95e46eeedf move normalize test 2024-07-03 09:45:42 -07:00
Roy Han
3d060e0ae9 move normalize 2024-07-02 10:35:02 -07:00
Roy Han
00a4cb26ca use float32 2024-07-02 10:30:29 -07:00
Roy Han
512e0a7bde Clean up 2024-07-01 16:29:54 -07:00
Roy Han
1a0c8b363c Truncation Integration Tests 2024-07-01 16:26:30 -07:00
Roy Han
e068e7f698 Integration Test Template 2024-07-01 15:24:26 -07:00
Roy Han
aee25acb5b move normalization to go 2024-07-01 14:10:58 -07:00
Roy Han
9c32b6b9ed Truncation 2024-07-01 11:59:44 -07:00
Roy Han
1daac52651 Truncation 2024-07-01 11:55:16 -07:00
Roy Han
80c1a3f812 playing around with truncate stuff 2024-06-28 18:17:09 -07:00
Roy Han
c111d8bb51 normalization 2024-06-28 17:19:04 -07:00
Roy Han
5213c12354 clean up 2024-06-28 15:26:58 -07:00
Roy Han
b9c74df37b check normalization 2024-06-28 15:10:58 -07:00
Roy Han
49e341147d add server function 2024-06-28 15:03:53 -07:00
Roy Han
c406fa7a4c api/embed draft 2024-06-28 14:54:21 -07:00
Roy Han
22458c573a mock up notes 2024-06-28 14:21:45 -07:00
Roy Han
ff191d7cba Initial Draft 2024-06-25 13:29:47 -07:00
Roy Han
0f87628b6d Revert "Initial Batch Embedding"
This reverts commit c22d54895a.
2024-06-24 15:26:05 -07:00
Roy Han
c22d54895a Initial Batch Embedding 2024-06-18 17:34:36 -07:00
100 changed files with 1104 additions and 3892 deletions

View File

@@ -147,7 +147,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

View File

@@ -126,7 +126,7 @@ jobs:
strategy:
matrix:
rocm-version:
- '6.1.2'
- '6.1.1'
runs-on: linux
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
steps:
@@ -169,7 +169,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

View File

@@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
ARG CMAKE_VERSION=3.22.1
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md
ARG CUDA_VERSION=11.3.1
ARG ROCM_VERSION=6.1.2
ARG ROCM_VERSION=6.1.1
# Copy the minimal context we need to run the generate scripts
FROM scratch AS llm-code

View File

@@ -293,9 +293,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
### Terminal

View File

@@ -17,7 +17,6 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -25,10 +24,7 @@ import (
"net/http"
"net/url"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/version"
@@ -360,7 +356,7 @@ func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse,
return &resp, nil
}
// Embeddings generates an embedding from a model.
// Embeddings generates embeddings from a model. (Legacy)
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
@@ -387,16 +383,3 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil
}
func Authorization(ctx context.Context, request *http.Request) (string, error) {
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
token, err := auth.Sign(ctx, data)
if err != nil {
return "", err
}
// interleave request data into the token
key, sig, _ := strings.Cut(token, ":")
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
}

View File

@@ -47,9 +47,6 @@ type GenerateRequest struct {
// Prompt is the textual prompt to send to the model.
Prompt string `json:"prompt"`
// Suffix is the text that comes after the inserted text.
Suffix string `json:"suffix"`
// System overrides the model's default system message/prompt.
System string `json:"system"`
@@ -100,80 +97,17 @@ type ChatRequest struct {
// followin the request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Tools is an optional list of tools the model has access to.
Tools `json:"tools,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
type Tools []Tool
func (t Tools) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
// Message is a single message in a chat sequence. The message contains the
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
type Alias Message
var a Alias
if err := json.Unmarshal(b, &a); err != nil {
return err
}
*m = Message(a)
m.Role = strings.ToLower(m.Role)
return nil
}
type ToolCall struct {
Function ToolCallFunction `json:"function"`
}
type ToolCallFunction struct {
Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"`
}
type ToolCallFunctionArguments map[string]any
func (t *ToolCallFunctionArguments) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
type Tool struct {
Type string `json:"type"`
Function ToolFunction `json:"function"`
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
}
func (t *ToolFunction) String() string {
bts, _ := json.Marshal(t)
return string(bts)
Role string `json:"role"`
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
@@ -260,7 +194,7 @@ type EmbedRequest struct {
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings"`
Embeddings [][]float32 `json:"embeddings,omitempty"`
}
// EmbeddingRequest is the request passed to [Client.Embeddings].
@@ -309,10 +243,8 @@ type DeleteRequest struct {
// ShowRequest is the request passed to [Client.Show].
type ShowRequest struct {
Model string `json:"model"`
System string `json:"system"`
// Template is deprecated
Model string `json:"model"`
System string `json:"system"`
Template string `json:"template"`
Verbose bool `json:"verbose"`

View File

@@ -208,26 +208,3 @@ func TestUseMmapFormatParams(t *testing.T) {
})
}
}
func TestMessage_UnmarshalJSON(t *testing.T) {
tests := []struct {
input string
expected string
}{
{`{"role": "USER", "content": "Hello!"}`, "user"},
{`{"role": "System", "content": "Initialization complete."}`, "system"},
{`{"role": "assistant", "content": "How can I help you?"}`, "assistant"},
{`{"role": "TOOl", "content": "Access granted."}`, "tool"},
}
for _, test := range tests {
var msg Message
if err := json.Unmarshal([]byte(test.input), &msg); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if msg.Role != test.expected {
t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected)
}
}
}

View File

@@ -127,10 +127,6 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
[InstallDelete]
Type: filesandordirs; Name: "{%TEMP}\ollama*"
Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
[Messages]
WizardReady=Ollama Windows Preview
ReadyLabel1=%nLet's get you up and running with your own large language models.

View File

@@ -3,67 +3,49 @@ package auth
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/pem"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
)
const defaultPrivateKey = "id_ed25519"
func privateKey() (ssh.Signer, error) {
func keyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
return "", err
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if os.IsNotExist(err) {
err := initializeKeypair()
if err != nil {
return nil, err
}
return privateKey()
} else if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return nil, err
}
return ssh.ParsePrivateKey(privateKeyFile)
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func GetPublicKey() (ssh.PublicKey, error) {
// try to read pubkey first
home, err := os.UserHomeDir()
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return nil, err
return "", err
}
pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
pubKeyFile, err := os.ReadFile(pubkeyPath)
if os.IsNotExist(err) {
// try from privateKey
privateKey, err := privateKey()
if err != nil {
return nil, fmt.Errorf("failed to read public key: %w", err)
}
return privateKey.PublicKey(), nil
} else if err != nil {
return nil, fmt.Errorf("failed to read public key: %w", err)
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}
pubKey, _, _, _, err := ssh.ParseAuthorizedKey(pubKeyFile)
return pubKey, err
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
if err != nil {
return "", err
}
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
return strings.TrimSpace(string(publicKey)), nil
}
func NewNonce(r io.Reader, length int) (string, error) {
@@ -76,20 +58,25 @@ func NewNonce(r io.Reader, length int) (string, error) {
}
func Sign(ctx context.Context, bts []byte) (string, error) {
privateKey, err := privateKey()
keyPath, err := keyPath()
if err != nil {
return "", err
}
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
if err != nil {
return "", err
}
// get the pubkey, but remove the type
publicKey, err := GetPublicKey()
if err != nil {
return "", err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
parts := bytes.Split(publicKeyBytes, []byte(" "))
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
parts := bytes.Split(publicKey, []byte(" "))
if len(parts) < 2 {
return "", fmt.Errorf("malformed public key")
}
@@ -102,49 +89,3 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
// signature is <pubkey>:<signature>
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
}
func initializeKeypair() error {
home, err := os.UserHomeDir()
if err != nil {
return err
}
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
_, err = os.Stat(privKeyPath)
if os.IsNotExist(err) {
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return err
}
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
return fmt.Errorf("could not create directory %w", err)
}
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
return err
}
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
if err != nil {
return err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
return err
}
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
}
return nil
}

View File

@@ -4,8 +4,10 @@ import (
"archive/zip"
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
@@ -13,7 +15,6 @@ import (
"math"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
@@ -262,8 +263,6 @@ func tempZipFiles(path string) (string, error) {
return tempfile.Name(), nil
}
var ErrBlobExists = errors.New("blob exists")
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
bin, err := os.Open(path)
if err != nil {
@@ -281,120 +280,12 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
}
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
// We check if we can find the models directory locally
// If we can, we return the path to the directory
// If we can't, we return an error
// If the blob exists already, we return the digest
dest, err := getLocalPath(cmd.Context(), digest)
if errors.Is(err, ErrBlobExists) {
return digest, nil
}
// Successfully found the model directory
if err == nil {
// Copy blob in via OS specific copy
// Linux errors out to use io.copy
err = localCopy(path, dest)
if err == nil {
return digest, nil
}
// Default copy using io.copy
err = defaultCopy(path, dest)
if err == nil {
return digest, nil
}
}
// If at any point copying the blob over locally fails, we default to the copy through the server
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
return "", err
}
return digest, nil
}
func getLocalPath(ctx context.Context, digest string) (string, error) {
ollamaHost := envconfig.Host
client := http.DefaultClient
base := &url.URL{
Scheme: ollamaHost.Scheme,
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
}
data, err := json.Marshal(digest)
if err != nil {
return "", err
}
reqBody := bytes.NewReader(data)
path := fmt.Sprintf("/api/blobs/%s", digest)
requestURL := base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
if err != nil {
return "", err
}
authz, err := api.Authorization(ctx, request)
if err != nil {
return "", err
}
request.Header.Set("Authorization", authz)
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
request.Header.Set("X-Redirect-Create", "1")
resp, err := client.Do(request)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusTemporaryRedirect {
dest := resp.Header.Get("LocalLocation")
return dest, nil
}
return "", ErrBlobExists
}
func defaultCopy(path string, dest string) error {
// This function should be called if the server is local
// It should find the model directory, copy the blob over, and return the digest
dirPath := filepath.Dir(dest)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
// Copy blob over
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("could not open source file: %v", err)
}
defer sourceFile.Close()
destFile, err := os.Create(dest)
if err != nil {
return fmt.Errorf("could not create destination file: %v", err)
}
defer destFile.Close()
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
if err != nil {
return fmt.Errorf("error copying file: %v", err)
}
err = destFile.Sync()
if err != nil {
return fmt.Errorf("error flushing file: %v", err)
}
return nil
}
func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
@@ -488,12 +379,11 @@ func errFromUnknownKey(unknownKeyErr error) error {
if len(matches) > 0 {
serverPubKey := matches[0]
publicKey, err := auth.GetPublicKey()
localPubKey, err := auth.GetPublicKey()
if err != nil {
return unknownKeyErr
}
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
@@ -953,6 +843,7 @@ type runOptions struct {
WordWrap bool
Format string
System string
Template string
Images []api.ImageData
Options map[string]interface{}
MultiModal bool
@@ -1146,6 +1037,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
Images: opts.Images,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
}
@@ -1182,7 +1074,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}
func RunServer(cmd *cobra.Command, _ []string) error {
if _, err := auth.GetPublicKey(); err != nil {
if err := initializeKeypair(); err != nil {
return err
}
@@ -1199,6 +1091,52 @@ func RunServer(cmd *cobra.Command, _ []string) error {
return err
}
func initializeKeypair() error {
home, err := os.UserHomeDir()
if err != nil {
return err
}
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
_, err = os.Stat(privKeyPath)
if os.IsNotExist(err) {
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return err
}
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
return fmt.Errorf("could not create directory %w", err)
}
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
return err
}
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
if err != nil {
return err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
return err
}
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
}
return nil
}
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -1408,6 +1346,7 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_MAX_VRAM"],
})
default:
appendEnvDocs(cmd, envs)

View File

@@ -1,23 +0,0 @@
package cmd
import (
"os"
"path/filepath"
"golang.org/x/sys/unix"
)
func localCopy(src, target string) error {
dirPath := filepath.Dir(target)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
err := unix.Clonefile(src, target, 0)
if err != nil {
return err
}
return nil
}

View File

@@ -1,7 +0,0 @@
package cmd
import "errors"
func localCopy(src, target string) error {
return errors.New("no local copy implementation for linux")
}

View File

@@ -1,67 +0,0 @@
//go:build windows
// +build windows
package cmd
import (
"os"
"path/filepath"
"syscall"
"unsafe"
)
func localCopy(src, target string) error {
// Create target directory if it doesn't exist
dirPath := filepath.Dir(target)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
// Open source file
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
// Create target file
targetFile, err := os.Create(target)
if err != nil {
return err
}
defer targetFile.Close()
// Use CopyFileExW to copy the file
err = copyFileEx(src, target)
if err != nil {
return err
}
return nil
}
func copyFileEx(src, dst string) error {
kernel32 := syscall.NewLazyDLL("kernel32.dll")
copyFileEx := kernel32.NewProc("CopyFileExW")
srcPtr, err := syscall.UTF16PtrFromString(src)
if err != nil {
return err
}
dstPtr, err := syscall.UTF16PtrFromString(dst)
if err != nil {
return err
}
r1, _, err := copyFileEx.Call(
uintptr(unsafe.Pointer(srcPtr)),
uintptr(unsafe.Pointer(dstPtr)),
0, 0, 0, 0)
if r1 == 0 {
return err
}
return nil
}

View File

@@ -27,6 +27,7 @@ const (
MultilineNone MultilineState = iota
MultilinePrompt
MultilineSystem
MultilineTemplate
)
func loadModel(cmd *cobra.Command, opts *runOptions) error {
@@ -93,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
@@ -202,6 +204,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.")
sb.Reset()
case MultilineTemplate:
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
}
multiline = MultilineNone
@@ -320,13 +326,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]]
case "system":
case "system", "template":
if len(args) < 3 {
usageSet()
continue
}
multiline = MultilineSystem
if args[1] == "system" {
multiline = MultilineSystem
} else if args[1] == "template" {
multiline = MultilineTemplate
}
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
@@ -346,17 +356,23 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
continue
}
opts.System = sb.String() // for display in modelfile
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
// Replace the last message
opts.Messages[len(opts.Messages)-1] = newMessage
} else {
opts.Messages = append(opts.Messages, newMessage)
if args[1] == "system" {
opts.System = sb.String() // for display in modelfile
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
// Replace the last message
opts.Messages[len(opts.Messages)-1] = newMessage
} else {
opts.Messages = append(opts.Messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
}
fmt.Println("Set system message.")
sb.Reset()
sb.Reset()
continue
@@ -377,6 +393,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
req := &api.ShowRequest{
Name: opts.Model,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
}
resp, err := client.Show(cmd.Context(), req)
@@ -420,9 +437,12 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("No system message was specified for this model.")
}
case "template":
if resp.Template != "" {
switch {
case opts.Template != "":
fmt.Println(opts.Template + "\n")
case resp.Template != "":
fmt.Println(resp.Template)
} else {
default:
fmt.Println("No prompt template was specified for this model.")
}
default:
@@ -516,6 +536,10 @@ func buildModelfile(opts runOptions) string {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
}
if opts.Template != "" {
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
}
keys := make([]string, 0)
for k := range opts.Options {
keys = append(keys, k)

View File

@@ -59,6 +59,7 @@ func TestModelfileBuilder(t *testing.T) {
opts := runOptions{
Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
@@ -74,6 +75,7 @@ func TestModelfileBuilder(t *testing.T) {
mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
@@ -95,6 +97,7 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]

View File

@@ -272,4 +272,4 @@ The following server settings may be used to adjust how Ollama handles concurren
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.

View File

@@ -46,24 +46,13 @@ sudo modprobe nvidia_uvm`
## AMD Radeon
Ollama supports the following AMD GPUs:
### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
### Windows Support
With ROCm v6.1, the following GPUs are supported on Windows.
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
### Overrides on Linux
### Overrides
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
some cases you can force the system to try to use a similar LLVM target that is
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
@@ -74,7 +63,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
server. If you have an unsupported AMD GPU you can experiment using the list of
supported types below.
At this time, the known supported GPU types on linux are the following LLVM Targets.
At this time, the known supported GPU types are the following LLVM Targets.
This table shows some example GPUs that map to these LLVM targets:
| **LLVM Target** | **An Example GPU** |
|-----------------|---------------------|

View File

@@ -103,6 +103,10 @@ curl http://localhost:11434/v1/chat/completions \
- [ ] `user`
- [ ] `n`
#### Notes
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
## Models
Before using a model, pull it locally `ollama pull`:

View File

@@ -43,6 +43,8 @@ var (
MaxRunners int
// Set via OLLAMA_MAX_QUEUE in the environment
MaxQueuedRequests int
// Set via OLLAMA_MAX_VRAM in the environment
MaxVRAM uint64
// Set via OLLAMA_MODELS in the environment
ModelsDir string
// Set via OLLAMA_NOHISTORY in the environment
@@ -87,6 +89,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
@@ -191,6 +194,16 @@ func LoadConfig() {
TmpDir = clean("OLLAMA_TMPDIR")
userLimit := clean("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseUint(userLimit, 10, 64)
if err != nil {
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
} else {
MaxVRAM = avail
}
}
LLMLibrary = clean("OLLAMA_LLM_LIBRARY")
if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {

View File

@@ -49,17 +49,9 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
}
func commonAMDValidateLibDir() (string, error) {
// Favor our bundled version
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
// We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// This gives users a more recovery options if versions have subtle problems at runtime
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
@@ -95,5 +87,14 @@ func commonAMDValidateLibDir() (string, error) {
}
}
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}

View File

@@ -33,10 +33,9 @@ type HipLib struct {
}
func NewHipLib() (*HipLib, error) {
// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
h, err := windows.LoadLibrary("amdhip64_6.dll")
h, err := windows.LoadLibrary("amdhip64.dll")
if err != nil {
return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
}
hl := &HipLib{}
hl.dll = h
@@ -85,8 +84,9 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
}
slog.Debug("hipDriverGetVersion", "version", version)
driverMajor = version / 10000000
driverMinor = (version - (driverMajor * 10000000)) / 100000
// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
driverMajor = version / 1000
driverMinor = (version - (driverMajor * 1000)) / 10
return driverMajor, driverMinor, nil
}

View File

@@ -22,8 +22,8 @@ const (
var (
// Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
)
func AMDGetGPUInfo() []RocmGPUInfo {
@@ -35,11 +35,12 @@ func AMDGetGPUInfo() []RocmGPUInfo {
}
defer hl.Release()
driverMajor, driverMinor, err := hl.AMDDriverVersion()
if err != nil {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug("error looking up amd driver version", "error", err)
}
// TODO - this reports incorrect version information, so omitting for now
// driverMajor, driverMinor, err := hl.AMDDriverVersion()
// if err != nil {
// // For now this is benign, but we may eventually need to fail compatibility checks
// slog.Debug("error looking up amd driver version", "error", err)
// }
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount()
@@ -92,8 +93,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
continue
}
if gfxOverride == "" {
// Strip off Target Features when comparing
if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
@@ -132,8 +132,10 @@ func AMDGetGPUInfo() []RocmGPUInfo {
MinimumMemory: rocmMinimumMemory,
Name: name,
Compute: gfx,
DriverMajor: driverMajor,
DriverMinor: driverMinor,
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
// DriverMajor: driverMajor,
// DriverMinor: driverMinor,
},
index: i,
}

View File

@@ -274,28 +274,6 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor
// query the management library as well so we can record any skew between the two
// which represents overhead on the GPU we must set aside on subsequent updates
if cHandles.nvml != nil {
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
if memInfo.err != nil {
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
slog.Info("detected OS VRAM overhead",
"id", gpuInfo.ID,
"library", gpuInfo.Library,
"compute", gpuInfo.Compute,
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
"name", gpuInfo.Name,
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
)
}
}
}
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo)
}
@@ -360,17 +338,14 @@ func GetGPUInfo() GpuInfoList {
"before",
"total", format.HumanBytes2(cpus[0].TotalMemory),
"free", format.HumanBytes2(cpus[0].FreeMemory),
"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
),
slog.Group(
"now",
"total", format.HumanBytes2(mem.TotalMemory),
"free", format.HumanBytes2(mem.FreeMemory),
"free_swap", format.HumanBytes2(mem.FreeSwap),
),
)
cpus[0].FreeMemory = mem.FreeMemory
cpus[0].FreeSwap = mem.FreeSwap
}
var memInfo C.mem_info_t
@@ -399,14 +374,9 @@ func GetGPUInfo() GpuInfoList {
slog.Warn("error looking up nvidia GPU memory")
continue
}
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
// When using the management library update based on recorded overhead
memInfo.free -= C.uint64_t(gpu.OSOverhead)
}
slog.Debug("updating cuda memory data",
"gpu", gpu.ID,
"name", gpu.Name,
"overhead", format.HumanBytes2(gpu.OSOverhead),
slog.Group(
"before",
"total", format.HumanBytes2(gpu.TotalMemory),

View File

@@ -57,7 +57,6 @@ func GetCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: uint64(C.getFreeMemory()),
// FreeSwap omitted as Darwin uses dynamic paging
}, nil
}

View File

@@ -50,7 +50,7 @@ var OneapiMgmtName = "libze_intel_gpu.so"
func GetCPUMem() (memInfo, error) {
var mem memInfo
var total, available, free, buffers, cached, freeSwap uint64
var total, available, free, buffers, cached uint64
f, err := os.Open("/proc/meminfo")
if err != nil {
return mem, err
@@ -70,21 +70,20 @@ func GetCPUMem() (memInfo, error) {
_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
case strings.HasPrefix(line, "Cached:"):
_, err = fmt.Sscanf(line, "Cached:%d", &cached)
case strings.HasPrefix(line, "SwapFree:"):
_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
default:
continue
}
if err != nil {
return mem, err
}
if total > 0 && available > 0 {
mem.TotalMemory = total * format.KibiByte
mem.FreeMemory = available * format.KibiByte
return mem, nil
}
}
mem.TotalMemory = total * format.KibiByte
mem.FreeSwap = freeSwap * format.KibiByte
if available > 0 {
mem.FreeMemory = available * format.KibiByte
} else {
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
}
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
return mem, nil
}

View File

@@ -51,5 +51,5 @@ func GetCPUMem() (memInfo, error) {
if r1 == 0 {
return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
}
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
}

View File

@@ -10,7 +10,6 @@ import (
type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"`
FreeSwap uint64 `json:"free_swap,omitempty"`
}
// Beginning of an `ollama info` command
@@ -53,8 +52,7 @@ type CPUInfo struct {
type CudaGPUInfo struct {
GpuInfo
OSOverhead uint64 // Memory overhead between the driver library and management library
index int //nolint:unused,nolintlint
index int //nolint:unused,nolintlint
}
type CudaGPUInfoList []CudaGPUInfo

View File

@@ -69,7 +69,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
reqLimit := len(req)
iterLimit := 5
vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
vram := os.Getenv("OLLAMA_MAX_VRAM")
if vram != "" {
max, err := strconv.ParseUint(vram, 10, 64)
require.NoError(t, err)
@@ -106,7 +106,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
vram := os.Getenv("OLLAMA_MAX_VRAM")
if vram == "" {
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
}

View File

@@ -12,7 +12,7 @@ import (
func TestContextExhaustion(t *testing.T) {
// Longer needed for small footprint GPUs
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
@@ -25,10 +25,5 @@ func TestContextExhaustion(t *testing.T) {
"num_ctx": 128,
},
}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
}

View File

@@ -178,7 +178,7 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
echo "Building custom CUDA GPU"
else
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DGGML_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} -DCMAKE_LIBRARY_PATH=/usr/local/cuda/compat"
fi
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
fi
init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""

View File

@@ -6,9 +6,18 @@ function amdGPUs {
if ($env:AMDGPU_TARGETS) {
return $env:AMDGPU_TARGETS
}
# Current supported rocblas list from ROCm v6.1.2 on windows
# https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html#windows-supported-gpus
# TODO - load from some common data file for linux + windows build consistency
$GPU_LIST = @(
"gfx900"
"gfx906:xnack-"
"gfx908:xnack-"
"gfx90a:xnack+"
"gfx90a:xnack-"
"gfx940"
"gfx941"
"gfx942"
"gfx1010"
"gfx1012"
"gfx1030"
"gfx1100"
"gfx1101"
@@ -357,7 +366,6 @@ function build_rocm() {
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DGGML_HIPBLAS=on",
"-DLLAMA_CUDA_NO_PEER_COPY=on",
"-DHIP_PLATFORM=amd",
"-DGGML_AVX=on",
"-DGGML_AVX2=off",
@@ -386,6 +394,7 @@ function build_rocm() {
sign
install
# Assumes v5.7, may need adjustments for v6
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"

View File

@@ -424,32 +424,6 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
)
case "chatglm":
fullOffload = 4 * batch * (embedding + vocab)
partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
fullOffload = max(
fullOffload,
4*batch*(2+
2*embedding+
context+
context*heads+
embeddingHeadsK*heads+
qkvBias.Shape[0]),
)
partialOffload = max(
partialOffload,
4*batch*(1+
2*embedding+
embeddingHeadsK*heads+
context+
context*heads)+
4*embeddingHeadsK*context+
4*context*embeddingHeadsK+
4*qkvBias.Shape[0],
)
}
}
return

View File

@@ -537,7 +537,6 @@ var ggufKVOrder = map[string][]string{
"tokenizer.ggml.add_bos_token",
"tokenizer.ggml.add_eos_token",
"tokenizer.chat_template",
"bert.pooling_type",
},
}

View File

@@ -33,7 +33,7 @@ func Quantize(infile, outfile string, ftype fileType) error {
params.ftype = ftype.Value()
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
return fmt.Errorf("llama_model_quantize: %d", rc)
}
return nil

View File

@@ -1,8 +1,8 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 8fe51971..7113ba64 100644
index 2b9ace28..172640e2 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -5433,16 +5433,7 @@ static void llm_load_vocab(
@@ -5357,16 +5357,7 @@ static void llm_load_vocab(
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
vocab.tokenizer_add_space_prefix = false;
vocab.tokenizer_clean_spaces = true;
@@ -20,9 +20,9 @@ index 8fe51971..7113ba64 100644
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if (
tokenizer_pre == "llama3" ||
@@ -5526,7 +5517,8 @@ static void llm_load_vocab(
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
vocab.tokenizer_clean_spaces = false;
@@ -5439,7 +5430,8 @@ static void llm_load_vocab(
tokenizer_pre == "jais") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
} else {
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);

13
llm/patches/06-qwen2.diff Normal file
View File

@@ -0,0 +1,13 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 40d2ec2c..f34eb79a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

View File

@@ -1,360 +0,0 @@
diff --git a/common/common.cpp b/common/common.cpp
index dbb724fb..c26fe6ee 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2087,14 +2087,29 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
+
+ // try to load as gguf
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
if (adapter == nullptr) {
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
- llama_free(lctx);
- llama_free_model(model);
- return std::make_tuple(nullptr, nullptr);
+ fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__);
+
+ // if that fails, try loading as ggla for compatibility
+ int err = llama_model_apply_lora_from_file(model,
+ lora_adapter.c_str(),
+ lora_scale,
+ ((i > 0) || params.lora_base.empty())
+ ? NULL
+ : params.lora_base.c_str(),
+ params.n_threads);
+ if (err != 0) {
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+ llama_free(lctx);
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
+ }
+ } else {
+ llama_lora_adapter_set(lctx, adapter, lora_scale);
}
- llama_lora_adapter_set(lctx, adapter, lora_scale);
}
if (params.ignore_eos) {
diff --git a/include/llama.h b/include/llama.h
index 93fd77ca..b0fb37a6 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1160,6 +1160,20 @@ extern "C" {
LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
+ // Apply a LoRA adapter to a loaded model
+ // path_base_model is the path to a higher quality model to use as a base for
+ // the layers modified by the adapter. Can be NULL to use the current loaded model.
+ // The model needs to be reloaded before applying a new adapter, otherwise the adapter
+ // will be applied on top of the previous one
+ // Returns 0 on success
+ LLAMA_API int32_t llama_model_apply_lora_from_file(
+ const struct llama_model * model,
+ const char * path_lora,
+ float scale,
+ const char * path_base_model,
+ int32_t n_threads);
+
+
#ifdef __cplusplus
}
#endif
diff --git a/src/llama.cpp b/src/llama.cpp
index 80a0dd0f..9d7b0e17 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21880,3 +21880,290 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
fputs(text, stderr);
fflush(stderr);
}
+
+static int llama_apply_lora_from_file_internal(
+ const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
+) {
+ LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
+
+ const int64_t t_start_lora_us = ggml_time_us();
+
+ llama_file fin(path_lora, "rb");
+
+ // verify magic and version
+ {
+ uint32_t magic = fin.read_u32();
+ if (magic != LLAMA_FILE_MAGIC_GGLA) {
+ LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
+ return 1;
+ }
+
+ uint32_t format_version = fin.read_u32();
+ if (format_version != 1) {
+ LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
+ return 1;
+ }
+ }
+
+ int32_t lora_r = fin.read_u32();
+ int32_t lora_alpha = fin.read_u32();
+ float scaling = scale * (float)lora_alpha / (float)lora_r;
+
+ LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
+
+ // load base model
+ std::unique_ptr<llama_model_loader> ml;
+ if (path_base_model) {
+ LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
+ ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
+ ml->init_mappings(/*prefetch*/ false); // no prefetching
+ }
+
+ struct tensor_meta {
+ std::string name;
+ ggml_type type;
+ int32_t ne[2];
+ size_t offset;
+ };
+ std::map<std::string, tensor_meta> tensor_meta_map;
+
+ // load all tensor meta
+ while (true) {
+ if (fin.tell() == fin.size) {
+ // eof
+ break;
+ }
+
+ int32_t n_dims;
+ int32_t name_len;
+ int32_t ftype;
+
+ fin.read_raw(&n_dims, sizeof(n_dims));
+ fin.read_raw(&name_len, sizeof(name_len));
+ fin.read_raw(&ftype, sizeof(ftype));
+
+ if (n_dims != 1 && n_dims != 2) {
+ LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
+ return 1;
+ }
+
+ int32_t ne[2] = { 1, 1 };
+ for (int i = 0; i < n_dims; ++i) {
+ fin.read_raw(&ne[i], sizeof(ne[i]));
+ }
+
+ std::string name;
+ {
+ GGML_ASSERT(name_len < GGML_MAX_NAME);
+ char buf[GGML_MAX_NAME];
+ fin.read_raw(buf, name_len);
+ name = std::string(buf, name_len);
+ }
+
+ // check for lora suffix
+ std::string lora_suffix;
+ if (name.length() > 6) {
+ lora_suffix = name.substr(name.length() - 6);
+ }
+ if (lora_suffix != ".loraA" && lora_suffix != ".loraB") {
+ LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
+ return 1;
+ }
+
+ // tensor type
+ ggml_type wtype;
+ switch (ftype) {
+ case 0: wtype = GGML_TYPE_F32; break;
+ case 1: wtype = GGML_TYPE_F16; break;
+ default:
+ {
+ LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
+ __func__, ftype);
+ return 1;
+ }
+ }
+
+ // data offset
+ size_t offset = fin.tell();
+ offset = (offset + 31) & -32;
+
+ // skip tensor data
+ fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET);
+
+ tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset });
+ }
+
+ bool warned = false;
+ int n_tensors = 0;
+
+ // apply
+ ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+ if (backend_cpu == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__);
+ return 1;
+ }
+ ggml_backend_cpu_set_n_threads(backend_cpu, n_threads);
+
+ std::vector<no_init<uint8_t>> read_buf;
+ for (const auto & it : model.tensors_by_name) {
+ const std::string & base_name = it.first;
+ ggml_tensor * model_t = it.second;
+
+ if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() ||
+ tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) {
+ continue;
+ }
+
+ tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA");
+ tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB");
+
+ ggml_init_params lora_init_params = {
+ /* .mem_size */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
+ /* .mem_buffer */ nullptr,
+ /* .no_alloc */ true,
+ };
+ ggml_context * lora_ctx = ggml_init(lora_init_params);
+ if (lora_ctx == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__);
+ ggml_backend_free(backend_cpu);
+ return 1;
+ }
+
+ // create tensors
+ ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]);
+ ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]);
+ ggml_set_name(loraA, metaA.name.c_str());
+ ggml_set_name(loraB, metaB.name.c_str());
+
+ ggml_tensor * base_t;
+ if (ml) {
+ if (!ml->get_tensor_meta(base_name.c_str())) {
+ LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
+ return 1;
+ }
+ base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str()));
+ } else {
+ base_t = ggml_dup_tensor(lora_ctx, model_t);
+ }
+ ggml_set_name(base_t, base_name.c_str());
+
+ // allocate in backend buffer
+ ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
+ if (lora_buf == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__);
+ return 1;
+ }
+
+ // load tensor data
+ auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) {
+ read_buf.resize(ggml_nbytes(tensor));
+ fin.seek(tensor_meta.offset, SEEK_SET);
+ fin.read_raw(read_buf.data(), ggml_nbytes(tensor));
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size());
+ };
+ load_tensor(metaA, loraA);
+ load_tensor(metaB, loraB);
+
+ // load base model tensor data
+ if (ml) {
+ ml->load_data_for(base_t);
+ } else {
+ ggml_backend_tensor_copy(model_t, base_t);
+ }
+
+ if (ggml_is_quantized(base_t->type) && !warned) {
+ LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, "
+ "use a f16 or f32 base model with --lora-base\n", __func__);
+ warned = true;
+ }
+
+ if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
+ LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
+ " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
+ ggml_free(lora_ctx);
+ ggml_backend_buffer_free(lora_buf);
+ ggml_backend_free(backend_cpu);
+ return 1;
+ }
+
+ auto build_lora_graph = [&]() {
+ // w = w + BA*s
+ ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
+ ggml_set_name(BA, "BA");
+
+ if (scaling != 1.0f) {
+ BA = ggml_scale(lora_ctx, BA, scaling);
+ ggml_set_name(BA, "BA_scaled");
+ }
+
+ ggml_tensor * r;
+ r = ggml_add_inplace(lora_ctx, base_t, BA);
+ ggml_set_name(r, "r_add");
+
+ if (base_t->type != model_t->type) {
+ // convert the result to the model type
+ r = ggml_cast(lora_ctx, r, model_t->type);
+ ggml_set_name(r, "r_cast");
+ }
+
+ return r;
+ };
+
+ ggml_cgraph * gf = ggml_new_graph(lora_ctx);
+ ggml_tensor * r = build_lora_graph();
+ ggml_build_forward_expand(gf, r);
+
+ ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
+ if (graph_buf == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__);
+ ggml_free(lora_ctx);
+ ggml_backend_buffer_free(lora_buf);
+ ggml_backend_free(backend_cpu);
+ return 1;
+ }
+
+ ggml_backend_graph_compute(backend_cpu, gf);
+
+ ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r));
+
+#if 0
+ // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU
+ //ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE);
+
+ // sched compute
+ ggml_build_forward_expand(gf, build_graph());
+ ggml_backend_sched_init_measure(sched, gf);
+
+ // create the graph again, since the previous one was destroyed by the measure
+ ggml_graph_clear(gf);
+ ggml_build_forward_expand(gf, build_graph());
+ ggml_backend_sched_graph_compute(sched, gf);
+ ggml_backend_sched_free(sched);
+#endif
+
+ ggml_backend_buffer_free(lora_buf);
+ ggml_backend_buffer_free(graph_buf);
+ ggml_free(lora_ctx);
+
+ n_tensors++;
+ if (n_tensors % 4 == 0) {
+ LLAMA_LOG_INFO(".");
+ }
+ }
+
+ ggml_backend_free(backend_cpu);
+
+ const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
+ LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
+
+ return 0;
+}
+
+int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
+ try {
+ return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
+ return 1;
+ }
+}
\ No newline at end of file

View File

@@ -88,7 +88,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
var estimate MemoryEstimate
var systemTotalMemory uint64
var systemFreeMemory uint64
var systemSwapFreeMemory uint64
systemMemInfo, err := gpu.GetCPUMem()
if err != nil {
@@ -96,8 +95,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} else {
systemTotalMemory = systemMemInfo.TotalMemory
systemFreeMemory = systemMemInfo.FreeMemory
systemSwapFreeMemory = systemMemInfo.FreeSwap
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory)
}
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
@@ -124,16 +122,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
}
}
// On linux, over-allocating CPU memory will almost always result in an error
if runtime.GOOS == "linux" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := systemFreeMemory + systemSwapFreeMemory
if systemMemoryRequired > available {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
}
}
estimate.log()
// Loop through potential servers
@@ -266,6 +254,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--tensor-split", estimate.TensorSplit)
}
if estimate.TensorSplit != "" {
params = append(params, "--tensor-split", estimate.TensorSplit)
}
for i := range len(servers) {
dir := availableServers[servers[i]]
if dir == "" {
@@ -385,10 +377,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
filteredEnv := []string{}
for _, ev := range s.cmd.Env {
if strings.HasPrefix(ev, "CUDA_") ||
strings.HasPrefix(ev, "ROCR_") ||
strings.HasPrefix(ev, "ROCM_") ||
strings.HasPrefix(ev, "HIP_") ||
strings.HasPrefix(ev, "GPU_") ||
strings.HasPrefix(ev, "HSA_") ||
strings.HasPrefix(ev, "GGML_") ||
strings.HasPrefix(ev, "PATH=") ||
@@ -417,17 +407,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
// reap subprocess when it exits
go func() {
err := s.cmd.Wait()
// Favor a more detailed message over the process exit status
if err != nil && s.status != nil && s.status.LastErrMsg != "" {
slog.Debug("llama runner terminated", "error", err)
if strings.Contains(s.status.LastErrMsg, "unknown model") {
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
}
s.done <- fmt.Errorf(s.status.LastErrMsg)
} else {
s.done <- err
}
s.done <- s.cmd.Wait()
}()
return s, nil
@@ -590,7 +570,14 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
slog.Warn("client connection closed before server finished loading, aborting load")
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
case err := <-s.done:
return fmt.Errorf("llama runner process has terminated: %w", err)
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
}
if strings.Contains(msg, "unknown model") {
return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
}
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
default:
}
if time.Now().After(stallTimer) {

View File

@@ -3,14 +3,11 @@ package openai
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log/slog"
"math/rand"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -30,9 +27,8 @@ type ErrorResponse struct {
}
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Role string `json:"role"`
Content string `json:"content"`
}
type Choice struct {
@@ -63,11 +59,6 @@ type ResponseFormat struct {
Type string `json:"type"`
}
type EmbedRequest struct {
Input any `json:"input"`
Model string `json:"model"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
@@ -80,7 +71,6 @@ type ChatCompletionRequest struct {
PresencePenalty *float64 `json:"presence_penalty_penalty"`
TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"`
Tools []api.Tool `json:"tools"`
}
type ChatCompletion struct {
@@ -114,7 +104,6 @@ type CompletionRequest struct {
Stream bool `json:"stream"`
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
}
type Completion struct {
@@ -136,15 +125,6 @@ type CompletionChunk struct {
SystemFingerprint string `json:"system_fingerprint"`
}
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
}
type Model struct {
Id string `json:"id"`
Object string `json:"object"`
@@ -152,23 +132,11 @@ type Model struct {
OwnedBy string `json:"owned_by"`
}
type Embedding struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
type ListCompletion struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
type EmbeddingList struct {
Object string `json:"object"`
Data []Embedding `json:"data"`
Model string `json:"model"`
}
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
@@ -183,31 +151,7 @@ func NewError(code int, message string) ErrorResponse {
return ErrorResponse{Error{Type: etype, Message: message}}
}
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return "call_" + strings.ToLower(string(b))
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
for i, tc := range r.Message.ToolCalls {
toolCalls[i].ID = toolCallId()
toolCalls[i].Type = "function"
toolCalls[i].Function.Name = tc.Function.Name
args, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("could not marshall function arguments to json", "error", err)
continue
}
toolCalls[i].Function.Arguments = string(args)
}
return ChatCompletion{
Id: id,
Object: "chat.completion",
@@ -216,7 +160,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
SystemFingerprint: "fp_ollama",
Choices: []Choice{{
Index: 0,
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
Message: Message{Role: r.Message.Role, Content: r.Message.Content},
FinishReason: func(reason string) *string {
if len(reason) > 0 {
return &reason
@@ -225,6 +169,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}(r.DoneReason),
}},
Usage: Usage{
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
@@ -270,6 +215,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
}(r.DoneReason),
}},
Usage: Usage{
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
@@ -314,27 +260,6 @@ func toListCompletion(r api.ListResponse) ListCompletion {
}
}
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
if r.Embeddings != nil {
var data []Embedding
for i, e := range r.Embeddings {
data = append(data, Embedding{
Object: "embedding",
Embedding: e,
Index: i,
})
}
return EmbeddingList{
Object: "list",
Data: data,
Model: model,
}
}
return EmbeddingList{}
}
func toModel(r api.ShowResponse, m string) Model {
return Model{
Id: m,
@@ -344,77 +269,10 @@ func toModel(r api.ShowResponse, m string) Model {
}
}
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
var messages []api.Message
for _, msg := range r.Messages {
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: msg.Role, Content: content})
case []any:
for _, c := range content {
data, ok := c.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
switch data["type"] {
case "text":
text, ok := data["text"].(string)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
messages = append(messages, api.Message{Role: msg.Role, Content: text})
case "image_url":
var url string
if urlMap, ok := data["image_url"].(map[string]any); ok {
if url, ok = urlMap["url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
} else {
if url, ok = data["image_url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
}
types := []string{"jpeg", "jpg", "png"}
valid := false
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, fmt.Errorf("invalid image input")
}
img, err := base64.StdEncoding.DecodeString(url)
if err != nil {
return nil, fmt.Errorf("invalid message format")
}
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
default:
return nil, fmt.Errorf("invalid message format")
}
}
default:
if msg.ToolCalls == nil {
return nil, fmt.Errorf("invalid message content type: %T", content)
}
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
for i, tc := range msg.ToolCalls {
toolCalls[i].Function.Name = tc.Function.Name
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
if err != nil {
return nil, fmt.Errorf("invalid tool call arguments")
}
}
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
}
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
}
options := make(map[string]interface{})
@@ -465,14 +323,13 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
format = "json"
}
return &api.ChatRequest{
return api.ChatRequest{
Model: r.Model,
Messages: messages,
Format: format,
Options: options,
Stream: &r.Stream,
Tools: r.Tools,
}, nil
}
}
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
@@ -481,16 +338,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
switch stop := r.Stop.(type) {
case string:
options["stop"] = []string{stop}
case []any:
var stops []string
for _, s := range stop {
if str, ok := s.(string); ok {
stops = append(stops, str)
} else {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
}
case []string:
options["stop"] = stop
default:
if r.Stop != nil {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
}
options["stop"] = stops
}
if r.MaxTokens != nil {
@@ -522,7 +375,6 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
Prompt: r.Prompt,
Options: options,
Stream: &r.Stream,
Suffix: r.Suffix,
}, nil
}
@@ -551,11 +403,6 @@ type RetrieveWriter struct {
model string
}
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
@@ -721,33 +568,6 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
return w.writeResponse(data)
}
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
@@ -811,47 +631,6 @@ func CompletionsMiddleware() gin.HandlerFunc {
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
}
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w
c.Next()
@@ -873,14 +652,7 @@ func ChatMiddleware() gin.HandlerFunc {
}
var b bytes.Buffer
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}

View File

@@ -2,8 +2,8 @@ package openai
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -16,387 +16,7 @@ import (
"github.com/stretchr/testify/assert"
)
const prefix = `data:image/jpeg;base64,`
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image
func prepareRequest(req *http.Request, body any) {
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
}
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
err := json.Unmarshal(bodyBytes, capturedRequest)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
}
c.Next()
}
}
func TestChatMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.ChatRequest
testCases := []testCase{
{
Name: "chat handler",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
},
},
{
Name: "chat handler with image content",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{
Role: "user", Content: []map[string]any{
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
},
},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if req.Messages[1].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
}
if !bytes.Equal(req.Messages[1].Images[0], img) {
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
}
},
},
{
Name: "chat handler with tools",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{Role: "user", Content: "What's the weather like in Paris Today?"},
{Role: "assistant", ToolCalls: []ToolCall{{
ID: "id",
Type: "function",
Function: struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}{
Name: "get_current_weather",
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
},
}}},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != 200 {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
}
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
}
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
}
},
},
{
Name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: 2}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid message content type") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
func TestCompletionsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
Name: "completions handler",
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if req.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Prompt)
}
if req.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
}
stopTokens, ok := req.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
if req.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", req.Suffix)
}
},
},
{
Name: "completions handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: nil,
Stop: []int{1, 2},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.EmbedRequest
testCases := []testCase{
{
Name: "embed handler single input",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: "Hello",
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if req.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Input)
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
},
},
{
Name: "embed handler batch input",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: []string{"Hello", "World"},
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
input, ok := req.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
}
if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}
if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
},
},
{
Name: "embed handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid input") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
func TestMiddlewareResponses(t *testing.T) {
func TestMiddleware(t *testing.T) {
type testCase struct {
Name string
Method string
@@ -409,6 +29,188 @@ func TestMiddlewareResponses(t *testing.T) {
}
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
TestPath: "/api/chat",
Handler: ChatMiddleware,
Endpoint: func(c *gin.Context) {
var chatReq api.ChatRequest
if err := c.ShouldBindJSON(&chatReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
userMessage := chatReq.Messages[0].Content
var assistantMessage string
switch userMessage {
case "Hello":
assistantMessage = "Hello!"
default:
assistantMessage = "I'm not sure how to respond to that."
}
c.JSON(http.StatusOK, api.ChatResponse{
Message: api.Message{
Role: "assistant",
Content: assistantMessage,
},
})
},
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var chatResp ChatCompletion
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
t.Fatal(err)
}
if chatResp.Object != "chat.completion" {
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
}
if chatResp.Choices[0].Message.Content != "Hello!" {
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.GenerateResponse{
Response: "Hello!",
})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Hello!" {
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with params",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
var generateReq api.GenerateRequest
if err := c.ShouldBindJSON(&generateReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
temperature := generateReq.Options["temperature"].(float64)
var assistantMessage string
switch temperature {
case 1.6:
assistantMessage = "Received temperature of 1.6"
default:
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
}
c.JSON(http.StatusOK, api.GenerateResponse{
Response: assistantMessage,
})
},
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with error",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
t.Fatalf("error was not forwarded")
}
},
},
{
Name: "list handler",
Method: http.MethodGet,
@@ -425,6 +227,8 @@ func TestMiddlewareResponses(t *testing.T) {
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var listResp ListCompletion
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
t.Fatal(err)
@@ -488,8 +292,6 @@ func TestMiddlewareResponses(t *testing.T) {
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
tc.Expected(t, resp)
})
}

View File

@@ -107,12 +107,9 @@ function gatherDependencies() {
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
# currently works for Win11 + MSVC 2019 + Cuda V11
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
}
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"

View File

@@ -32,23 +32,13 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"golang.org/x/crypto/ssh"
)
var (
errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion")
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
)
var errCapabilityCompletion = errors.New("completion")
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
)
const CapabilityCompletion = Capability("completion")
type registryOptions struct {
Insecure bool
@@ -98,15 +88,6 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
errs = append(errs, errCapabilityCompletion)
}
case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errCapabilityTools)
}
case CapabilityInsert:
vars := m.Template.Vars()
if !slices.Contains(vars, "suffix") {
errs = append(errs, errCapabilityInsert)
}
default:
slog.Error("unknown capability", "capability", cap)
return fmt.Errorf("unknown capability: %s", cap)
@@ -114,7 +95,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
}
if err := errors.Join(errs...); err != nil {
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
}
return nil
@@ -493,12 +474,6 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
layers = append(layers, baseLayer.Layer)
}
case "license", "template", "system":
if c.Name == "template" {
if _, err := template.Parse(c.Args); err != nil {
return fmt.Errorf("%w: %s", errBadTemplate, err)
}
}
if c.Name != "license" {
// replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
@@ -1089,12 +1064,11 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if anonymous {
// no user is associated with the public key, and the request requires non-anonymous access
pubKey, nestedErr := auth.GetPublicKey()
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, errUnauthorized
}
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
}
// user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized

View File

@@ -4,7 +4,6 @@ import (
"archive/zip"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -12,9 +11,6 @@ import (
"net/http"
"os"
"path/filepath"
"slices"
"strings"
"text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
@@ -293,113 +289,3 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil
}
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
}); err != nil {
return nil, false
}
var kv map[string]any
// execute the subtree with placeholders to identify the keys
// trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
switch v.(type) {
case string:
name = k
case map[string]any:
arguments = k
}
}
if name == "" || arguments == "" {
return nil, false
}
var objs []map[string]any
for offset := 0; offset < len(s); {
var obj map[string]any
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
slog.Error("parseToolCalls", "error", err)
return nil, false
} else {
offset += int(decoder.InputOffset())
// collect all nested objects
var collect func(any) []map[string]any
collect = func(obj any) (all []map[string]any) {
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}
}
return all
}
objs = append(objs, collect(obj)...)
}
}
var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
},
})
}
}
return toolCalls, len(toolCalls) > 0
}

View File

@@ -3,9 +3,7 @@ package server
import (
"archive/zip"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path/filepath"
@@ -13,9 +11,7 @@ import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
func createZipFile(t *testing.T, name string) *os.File {
@@ -114,123 +110,3 @@ func TestExtractFromZipFile(t *testing.T) {
})
}
}
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
bts, err := os.ReadFile(filepath.Join(base, name))
if err != nil {
t.Fatal(err)
}
return bytes.NewBuffer(bts)
}
func TestExecuteWithTools(t *testing.T) {
p := filepath.Join("testdata", "tools")
cases := []struct {
model string
output string
ok bool
}{
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"command-r-plus", "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```", true},
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"llama3-groq-tool-use", `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
</tool_call>`, true},
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
}
var tools []api.Tool
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
t.Fatal(err)
}
var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}
calls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},
},
},
}
for _, tt := range cases {
t.Run(tt.model, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
}
t.Run("template", func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output)
if ok != tt.ok {
t.Fatalf("expected %t, got %t", tt.ok, ok)
}
if tt.ok {
if diff := cmp.Diff(actual, calls); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
})
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"log/slog"
"slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
@@ -15,21 +16,29 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// pull out any system messages which should always be included in the prompt
var system []api.Message
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
system = append(system, m)
return true
}
return false
})
if len(system) == 0 && m.System != "" {
// add model system prompt since it wasn't provided
system = append(system, api.Message{Role: "system", Content: m.System})
}
// always include the last message
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
return "", nil, err
}
@@ -57,7 +66,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
return "", nil, err
}

View File

@@ -3,13 +3,21 @@ package server
import (
"bytes"
"context"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
func tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
return
}
func TestChatPrompt(t *testing.T) {
type expect struct {
prompt string
@@ -152,19 +160,6 @@ func TestChatPrompt(t *testing.T) {
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
{
name: "out of order system",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
@@ -183,13 +178,13 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
if tt.prompt != prompt {
t.Errorf("expected %q, got %q", tt.prompt, prompt)
}
if len(images) != len(tt.images) {

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"cmp"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -19,15 +18,14 @@ import (
"path/filepath"
"slices"
"strings"
"sync"
"syscall"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/ssh"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
@@ -59,7 +57,6 @@ func init() {
}
var errRequired = errors.New("is required")
var errBadTemplate = errors.New("template error")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
@@ -107,7 +104,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -126,10 +122,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
caps := []Capability{CapabilityCompletion}
if req.Suffix != "" {
caps = append(caps, CapabilityInsert)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
@@ -139,8 +131,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
checkpointLoaded := time.Now()
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
@@ -158,6 +148,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt := req.Prompt
if !req.Raw {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
tmpl := m.Template
if req.Template != "" {
tmpl, err = template.Parse(req.Template)
@@ -178,26 +181,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
b.WriteString(s)
}
var values template.Values
if req.Suffix != "" {
values.Prompt = prompt
values.Suffix = req.Suffix
} else {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
}
if err := tmpl.Execute(&b, values); err != nil {
if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -209,48 +193,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
ch := make(chan any)
go func() {
// TODO (jmorganca): avoid building the response twice both here and below
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
}, func(r llm.CompletionResponse) {
ch <- api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
DoneReason: cr.DoneReason,
Response: r.Content,
Done: r.Done,
DoneReason: r.DoneReason,
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration,
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
}
if cr.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Context = append(req.Context, tokens...)
}
}
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
@@ -298,37 +260,38 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
truncate := true
if req.Truncate != nil && !*req.Truncate {
truncate = false
if req.Truncate == nil {
truncate := true
req.Truncate = &truncate
}
var input []string
reqEmbed := []string{}
switch i := req.Input.(type) {
switch embeddings := req.Input.(type) {
case string:
if len(i) > 0 {
input = append(input, i)
if embeddings == "" {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
reqEmbed = []string{embeddings}
case []any:
for _, v := range i {
if len(embeddings) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
for _, v := range embeddings {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
input = append(input, v.(string))
reqEmbed = append(reqEmbed, v.(string))
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
@@ -341,34 +304,64 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
reqEmbedArray := make([]string, len(reqEmbed))
errCh := make(chan error, 1)
successCh := make(chan bool, 1)
sem := make(chan struct{}, 2)
var wg sync.WaitGroup
var mu sync.Mutex
for i, s := range reqEmbed {
wg.Add(1)
sem <- struct{}{}
go func(i int, s string) {
defer wg.Done()
defer func() { <-sem }()
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
errCh <- err
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if *req.Truncate {
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
errCh <- err
return
}
} else {
errCh <- err
return
}
}
mu.Lock()
reqEmbedArray[i] = s
mu.Unlock()
}(i, s)
}
go func() {
wg.Wait()
successCh <- true
close(errCh)
}()
select {
case err := <-errCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
case success := <-successCh:
if !success {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to process all embeddings"})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
input[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), input)
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
if err != nil {
slog.Error("embedding generation failed", "error", err)
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
@@ -423,7 +416,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
embedding, err := r.Embed(c.Request.Context(), []string{req.Prompt})
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
@@ -431,14 +424,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embedding := make([]float64, len(embeddings[0]))
embedding64 := make([]float64, len(embedding[0]))
for i, v := range embeddings[0] {
embedding[i] = float64(v)
for i, v := range embedding[0] {
embedding64[i] = float64(v)
}
resp := api.EmbeddingResponse{
Embedding: embedding,
Embedding: embedding64,
}
c.JSON(http.StatusOK, resp)
}
@@ -613,9 +606,6 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
quantization := cmp.Or(r.Quantize, r.Quantization)
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
if errors.Is(err, errBadTemplate) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
}
ch <- gin.H{"error": err.Error()}
}
}()
@@ -717,6 +707,13 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m.System = req.System
}
if req.Template != "" {
m.Template, err = template.Parse(req.Template)
if err != nil {
return nil, err
}
}
msgs := make([]api.Message, len(m.Messages))
for i, msg := range m.Messages {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
@@ -931,6 +928,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err = os.Stat(path)
switch {
case errors.Is(err, os.ErrNotExist):
@@ -942,11 +940,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusOK)
return
}
if c.GetHeader("X-Redirect-Create") == "1" && s.isLocal(c) {
c.Header("LocalLocation", path)
c.Status(http.StatusTemporaryRedirect)
return
}
layer, err := NewLayer(c.Request.Body, "")
if err != nil {
@@ -962,54 +955,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusCreated)
}
func (s *Server) isLocal(c *gin.Context) bool {
if authz := c.GetHeader("Authorization"); authz != "" {
parts := strings.Split(authz, ":")
if len(parts) != 3 {
return false
}
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
if err != nil {
return false
}
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
requestData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return false
}
partialRequestDataParts := strings.Split(string(requestData), ",")
if len(partialRequestDataParts) != 3 {
return false
}
signature, err := base64.StdEncoding.DecodeString(parts[2])
if err != nil {
return false
}
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
return false
}
serverPublicKey, err := auth.GetPublicKey()
if err != nil {
slog.Error(fmt.Sprintf("failed to get server public key: %v", err))
return false
}
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
return true
}
return false
}
return false
}
func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces {
@@ -1115,7 +1060,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy
r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler)
r.POST("/api/copy", s.CopyModelHandler)
@@ -1128,7 +1073,6 @@ func (s *Server) GenerateRoutes() http.Handler {
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
@@ -1255,15 +1199,11 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
return
}
case gin.H:
status, ok := r["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
if errorMsg, ok := r["error"].(string); ok {
c.JSON(status, gin.H{"error": errorMsg})
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
return
}
default:
@@ -1341,8 +1281,6 @@ func (s *Server) ProcessHandler(c *gin.Context) {
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -1353,10 +1291,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
caps := []Capability{CapabilityCompletion}
if len(req.Tools) > 0 {
caps = append(caps, CapabilityTools)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1366,8 +1300,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
checkpointLoaded := time.Now()
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
@@ -1379,11 +1311,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
if req.Messages[0].Role != "system" && m.System != "" {
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1400,7 +1328,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
ch <- api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
@@ -1413,26 +1341,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
if r.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse
var r api.ChatResponse
var sb strings.Builder
for rr := range ch {
switch t := rr.(type) {
case api.ChatResponse:
sb.WriteString(t.Message.Content)
resp = t
r = t
case gin.H:
msg, ok := t["error"].(string)
if !ok {
@@ -1447,16 +1368,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}
resp.Message.Content = sb.String()
if len(req.Tools) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
}
c.JSON(http.StatusOK, resp)
r.Message.Content = sb.String()
c.JSON(http.StatusOK, r)
return
}
@@ -1465,7 +1378,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
func handleScheduleError(c *gin.Context, name string, err error) {
switch {
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
case errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"})

View File

@@ -85,8 +85,6 @@ func checkFileExists(t *testing.T, p string, expect []string) {
}
func TestCreateFromBin(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -113,8 +111,6 @@ func TestCreateFromBin(t *testing.T) {
}
func TestCreateFromModel(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -156,8 +152,6 @@ func TestCreateFromModel(t *testing.T) {
}
func TestCreateRemovesLayers(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -205,8 +199,6 @@ func TestCreateRemovesLayers(t *testing.T) {
}
func TestCreateUnsetsSystem(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -263,8 +255,6 @@ func TestCreateUnsetsSystem(t *testing.T) {
}
func TestCreateMergeParameters(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -368,8 +358,6 @@ func TestCreateMergeParameters(t *testing.T) {
}
func TestCreateReplacesMessages(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -446,8 +434,6 @@ func TestCreateReplacesMessages(t *testing.T) {
}
func TestCreateTemplateSystem(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -491,47 +477,9 @@ func TestCreateTemplateSystem(t *testing.T) {
if string(system) != "Say bye!" {
t.Errorf("expected \"Say bye!\", actual %s", system)
}
t.Run("incomplete template", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with unclosed if", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with undefined function", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
}
func TestCreateLicenses(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -578,8 +526,6 @@ func TestCreateLicenses(t *testing.T) {
}
func TestCreateDetectTemplate(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -600,8 +546,8 @@ func TestCreateDetectTemplate(t *testing.T) {
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
filepath.Join(p, "blobs", "sha256-9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b7"),
filepath.Join(p, "blobs", "sha256-b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce57cdac893db8139"),
})
})

View File

@@ -8,15 +8,12 @@ import (
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
func TestDelete(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -80,8 +77,6 @@ func TestDelete(t *testing.T) {
}
func TestDeleteDuplicateLayers(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server

View File

@@ -1,714 +0,0 @@
package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
)
type mockRunner struct {
llm.LlamaServer
// CompletionRequest is only valid until the next call to Completion
llm.CompletionRequest
llm.CompletionResponse
}
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
m.CompletionRequest = r
fn(m.CompletionResponse)
return nil
}
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
return
}
func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return mock, nil
}
}
func TestGenerateChat(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities chat", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "bert",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var actual api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != "test" {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done true, got false")
}
if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
t.Helper()
var actual api.ChatResponse
if err := json.NewDecoder(body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != model {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done false, got true")
}
if actual.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
}
if diff := cmp.Diff(actual.Message, api.Message{
Role: "assistant",
Content: content,
}); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if actual.PromptEvalCount == 0 {
t.Errorf("expected prompt eval count > 0, got 0")
}
if actual.PromptEvalDuration == 0 {
t.Errorf("expected prompt eval duration > 0, got 0")
}
if actual.EvalCount == 0 {
t.Errorf("expected eval count > 0, got 0")
}
if actual.EvalDuration == 0 {
t.Errorf("expected eval duration > 0, got 0")
}
if actual.LoadDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
}
if actual.TotalDuration == 0 {
t.Errorf("expected total duration > 0, got 0")
}
}
mock.CompletionResponse.Content = "Hi!"
t.Run("messages", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test", "Hi!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("messages with model system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Hi!")
})
mock.CompletionResponse.Content = "Abra kadabra!"
t.Run("messages with system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
})
t.Run("messages with interleaved system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "assistant", Content: "I can help you with that."},
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Help me write tests."},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
})
}
func TestGenerate(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities generate", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "bert",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var actual api.GenerateResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != "test" {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done true, got false")
}
if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
t.Helper()
var actual api.GenerateResponse
if err := json.NewDecoder(body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != model {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done false, got true")
}
if actual.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
}
if actual.Response != content {
t.Errorf("expected response %s, got %s", content, actual.Response)
}
if actual.Context == nil {
t.Errorf("expected context not nil")
}
if actual.PromptEvalCount == 0 {
t.Errorf("expected prompt eval count > 0, got 0")
}
if actual.PromptEvalDuration == 0 {
t.Errorf("expected prompt eval duration > 0, got 0")
}
if actual.EvalCount == 0 {
t.Errorf("expected eval count > 0, got 0")
}
if actual.EvalDuration == 0 {
t.Errorf("expected eval duration > 0, got 0")
}
if actual.LoadDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
}
if actual.TotalDuration == 0 {
t.Errorf("expected total duration > 0, got 0")
}
}
mock.CompletionResponse.Content = "Hi!"
t.Run("prompt", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test", "Hi!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with model system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
})
mock.CompletionResponse.Content = "Abra kadabra!"
t.Run("prompt with system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
System: "You can perform magic tricks.",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
t.Run("prompt with template", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Help me write tests.",
System: "You can perform magic tricks.",
Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-suffix",
Modelfile: `FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}"""`,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("prompt without suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("raw", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Help me write tests.",
Raw: true,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}

View File

@@ -7,14 +7,11 @@ import (
"slices"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
func TestList(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("OLLAMA_MODELS", t.TempDir())
envconfig.LoadConfig()

View File

@@ -10,18 +10,15 @@ import (
"math"
"net/http"
"net/http/httptest"
"net/url"
"os"
"sort"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
@@ -309,12 +306,8 @@ func Test_Routes(t *testing.T) {
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
}
if embedResp.Embeddings == nil {
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
}
if len(embedResp.Embeddings) != 0 {
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
if embedResp.Embeddings != nil {
t.Fatalf("expected embeddings to be nil, got %v", embedResp.Embeddings)
}
},
},
@@ -530,62 +523,3 @@ func TestNormalize(t *testing.T) {
})
}
}
func TestIsLocalReal(t *testing.T) {
gin.SetMode(gin.TestMode)
clientPubLoc := t.TempDir()
t.Setenv("HOME", clientPubLoc)
_, err := auth.GetPublicKey()
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(w)
ctx.Request = &http.Request{
Header: make(http.Header),
}
requestURL := url.URL{
Scheme: "http",
Host: "localhost:8080",
Path: "/api/blobs",
}
request := &http.Request{
Method: http.MethodPost,
URL: &requestURL,
}
s := &Server{}
authz, err := api.Authorization(ctx, request)
if err != nil {
t.Fatal(err)
}
// Set client authorization header
ctx.Request.Header.Set("Authorization", authz)
if !s.isLocal(ctx) {
t.Fatal("Expected isLocal to return true")
}
t.Run("different server pubkey", func(t *testing.T) {
serverPubLoc := t.TempDir()
t.Setenv("HOME", serverPubLoc)
_, err := auth.GetPublicKey()
if err != nil {
t.Fatal(err)
}
if s.isLocal(ctx) {
t.Fatal("Expected isLocal to return false")
}
})
t.Run("invalid pubkey", func(t *testing.T) {
ctx.Request.Header.Set("Authorization", "sha-25616:invalid")
if s.isLocal(ctx) {
t.Fatal("Expected isLocal to return false")
}
})
}

View File

@@ -135,6 +135,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
for {
cpus := s.getCpuFn()
var systemMem gpu.GpuInfo
if len(cpus) > 0 {
systemMem = cpus[0]
}
var runnerToExpire *runnerRef
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
@@ -188,6 +193,38 @@ func (s *Scheduler) processPending(ctx context.Context) {
break
}
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
maxSize := systemMem.FreeMemory
// Add available GPU memory to the total pool
// macOS hardware has unified memory so don't double count
if runtime.GOOS != "darwin" {
for _, gpu := range gpus {
if gpu.Library == "cpu" {
continue
}
if loadedCount == 0 {
// If no other models are loaded, set the limit based on what's available
maxSize += gpu.FreeMemory
} else {
// Other models could be unloaded, favor total memory for limit
maxSize += gpu.TotalMemory
}
}
}
// Block attempting to load a model larger than system memory + GPU memory
if estimate.TotalSize > maxSize {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
// Linux will crash if over-allocating memory - return an error to the user.
// TODO (jmorganca): add reasonable upper limits for darwin and windows as well
if runtime.GOOS == "linux" {
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
break
}
}
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
if len(gpus) == 1 && gpus[0].Library == "cpu" {
// simplifying assumption of defaultParallel when in CPU mode

View File

@@ -94,7 +94,7 @@ func TestLoad(t *testing.T) {
require.Len(t, s.expiredCh, 1)
}
type reqBundle struct {
type bundle struct {
ctx context.Context //nolint:containedctx
ctxDone func()
srv *mockLlm
@@ -102,13 +102,13 @@ type reqBundle struct {
ggml *llm.GGML
}
func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return scenario.srv, nil
}
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle {
b := &reqBundle{}
b.ctx, b.ctxDone = context.WithCancel(ctx)
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
scenario := &bundle{}
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
@@ -135,154 +135,124 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
b.ggml, err = llm.LoadModel(model.ModelPath, 0)
scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
require.NoError(t, err)
if duration == nil {
duration = &api.Duration{Duration: 5 * time.Millisecond}
}
b.req = &LlmRequest{
ctx: b.ctx,
scenario.req = &LlmRequest{
ctx: scenario.ctx,
model: model,
opts: api.DefaultOptions(),
sessionDuration: duration,
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
return b
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
return scenario
}
func getGpuFn() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
func getCpuFn() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "cpu"}
g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte
return []gpu.GpuInfo{g}
}
func TestRequestsSameModelSameRequest(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
func TestRequests(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
b.req.model = a.req.model
b.ggml = a.ggml
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.getCpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "cpu"}
g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-a.req.successCh:
require.Equal(t, resp.llama, a.srv)
case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, a.req.errCh)
case err := <-a.req.errCh:
require.Empty(t, scenario1a.req.errCh)
case err := <-scenario1a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
// Same runner as first request due to not needing a reload
s.newServerFn = b.newServer
slog.Info("b")
s.pendingReqCh <- b.req
s.newServerFn = scenario1b.newServer
slog.Info("scenario1b")
s.pendingReqCh <- scenario1b.req
select {
case resp := <-b.req.successCh:
require.Equal(t, resp.llama, a.srv)
case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, b.req.errCh)
case err := <-b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
}
func TestRequestsSimpleReloadSameModel(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
tmpModel := *a.req.model
b.req.model = &tmpModel
b.ggml = a.ggml
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-a.req.successCh:
require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, a.req.errCh)
case err := <-a.req.errCh:
require.Empty(t, scenario1b.req.errCh)
case err := <-scenario1b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
// Trigger a reload
s.newServerFn = b.newServer
b.req.model.AdapterPaths = []string{"new"}
slog.Info("b")
s.pendingReqCh <- b.req
s.newServerFn = scenario2a.newServer
scenario2a.req.model.AdapterPaths = []string{"new"}
slog.Info("scenario2a")
s.pendingReqCh <- scenario2a.req
// finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond)
a.ctxDone()
scenario1a.ctxDone()
scenario1b.ctxDone()
select {
case resp := <-b.req.successCh:
require.Equal(t, resp.llama, b.srv)
case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, b.req.errCh)
case err := <-b.req.errCh:
require.Empty(t, scenario2a.req.errCh)
case err := <-scenario2a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
}
func TestRequestsMultipleLoadedModels(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
// Multiple loaded models
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
c.req.opts.NumGPU = 0 // CPU load, will be allowed
d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
envconfig.MaxRunners = 1
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
s.Run(ctx)
s.newServerFn = scenario3a.newServer
slog.Info("scenario3a")
s.pendingReqCh <- scenario3a.req
// finish prior request, so new model can load
time.Sleep(1 * time.Millisecond)
scenario2a.ctxDone()
select {
case resp := <-a.req.successCh:
require.Equal(t, resp.llama, a.srv)
case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, a.req.errCh)
case err := <-a.req.errCh:
require.Empty(t, scenario3a.req.errCh)
case err := <-scenario3a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
@@ -292,15 +262,15 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
s.loadedMu.Unlock()
envconfig.MaxRunners = 0
s.newServerFn = b.newServer
slog.Info("b")
s.pendingReqCh <- b.req
s.newServerFn = scenario3b.newServer
slog.Info("scenario3b")
s.pendingReqCh <- scenario3b.req
select {
case resp := <-b.req.successCh:
require.Equal(t, resp.llama, b.srv)
case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, scenario3b.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, b.req.errCh)
case err := <-b.req.errCh:
require.Empty(t, scenario3b.req.errCh)
case err := <-scenario3b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
@@ -310,15 +280,15 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
s.loadedMu.Unlock()
// This is a CPU load with NumGPU = 0 so it should load
s.newServerFn = c.newServer
slog.Info("c")
s.pendingReqCh <- c.req
s.newServerFn = scenario3c.newServer
slog.Info("scenario3c")
s.pendingReqCh <- scenario3c.req
select {
case resp := <-c.req.successCh:
require.Equal(t, resp.llama, c.srv)
case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, c.req.errCh)
case err := <-c.req.errCh:
require.Empty(t, scenario3c.req.errCh)
case err := <-scenario3c.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
@@ -328,25 +298,25 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
s.loadedMu.Unlock()
// Try to load a model that wont fit
s.newServerFn = d.newServer
slog.Info("d")
s.newServerFn = scenario3d.newServer
slog.Info("scenario3d")
s.loadedMu.Lock()
require.Len(t, s.loaded, 3)
s.loadedMu.Unlock()
a.ctxDone() // Won't help since this one isn't big enough to make room
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- d.req
s.pendingReqCh <- scenario3d.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
b.ctxDone()
scenario3b.ctxDone()
select {
case resp := <-d.req.successCh:
require.Equal(t, resp.llama, d.srv)
case resp := <-scenario3d.req.successCh:
require.Equal(t, resp.llama, scenario3d.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, d.req.errCh)
require.Empty(t, scenario3d.req.errCh)
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -359,19 +329,26 @@ func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
s.newServerFn = a.newServer
slog.Info("a")
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
slog.Info("b")
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1)
@@ -380,24 +357,22 @@ func TestGetRunner(t *testing.T) {
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, a.srv)
require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, errCh1a)
case err := <-errCh1a:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
a.ctxDone() // Set "a" model to idle so it can unload
scenario1a.ctxDone()
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
c.req.model.ModelPath = "bad path"
slog.Info("c")
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
scenario1c.req.model.ModelPath = "bad path"
slog.Info("scenario1c")
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
// Starts in pending channel, then should be quickly processsed to return an error
time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload
time.Sleep(5 * time.Millisecond)
require.Empty(t, successCh1c)
s.loadedMu.Lock()
require.Empty(t, s.loaded)
@@ -405,7 +380,7 @@ func TestGetRunner(t *testing.T) {
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
b.ctxDone()
scenario1b.ctxDone()
}
// TODO - add one scenario that triggers the bogus finished event with positive ref count
@@ -414,7 +389,7 @@ func TestPrematureExpired(t *testing.T) {
defer done()
// Same model, same request
scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
@@ -436,8 +411,6 @@ func TestPrematureExpired(t *testing.T) {
s.loadedMu.Unlock()
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case err := <-errCh1a:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -473,8 +446,6 @@ func TestUseLoadedRunner(t *testing.T) {
select {
case success := <-req.successCh:
require.Equal(t, r1, success)
case err := <-req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -654,7 +625,8 @@ func TestAlreadyCanceled(t *testing.T) {
defer done()
dctx, done2 := context.WithCancel(ctx)
done2()
scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
s := InitScheduler(ctx)
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req

View File

@@ -1,67 +0,0 @@
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
{{- if .Tools }}# Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
{{ if .System }}# User Preamble
{{ .System }}
{{- end }}
## Available Tools
Here is a list of tools that you have available to you:
{{- range .Tools }}
```python
def {{ .Function.Name }}(
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
"""{{ .Function.Description }}
{{- if .Function.Parameters.Properties }}
Args:
{{- range $name, $property := .Function.Parameters.Properties }}
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
{{- end }}
{{- end }}
"""
pass
```
{{- end }}
{{- else if .System }}{{ .System }}
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- range .Messages }}
{{- if eq .Role "system" }}
{{- continue }}
{{- end }}<|START_OF_TURN_TOKEN|>
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}
Action: ```json
[
{{- range .ToolCalls }}
{
"tool_name": "{{ .Function.Name }}",
"parameters": {{ .Function.Arguments }}
}
{{- end }}
]```
{{ continue }}
{{ end }}
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
{{ .Content }}</results>
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@@ -1,39 +0,0 @@
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
# User Preamble
You are a knowledgable assistant. You can answer questions and perform tasks.
## Available Tools
Here is a list of tools that you have available to you:
```python
def get_current_weather(format: string, location: string, ) -> List[Dict]:
"""Get the current weather
Args:
format (string): The temperature unit to use. Infer this from the users location.
location (string): The city and state, e.g. San Francisco, CA
"""
pass
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
Action: ```json
[
{
"tool_name": "get_current_weather",
"parameters": {"format":"celsius","location":"Paris, France"}
}
]```
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@@ -1,31 +0,0 @@
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{- if .System }}
{{ .System }}
{{- end }}
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
{{- if .Tools }}
{{ .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,17 +0,0 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgable assistant. You can answer questions and perform tasks.
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,43 +0,0 @@
{{- if .Messages }}
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{ .System }}
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Here are the available tools:
<tools>
{{- range .Tools }} {{ .Function }}
{{- end }} </tools>
{{- end }}
{{- end }}<|eot_id|>
{{- range .Messages }}
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ if eq .Role "user" }}{{ .Content }}
{{- else if eq .Role "assistant" }}
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
</tool_call>
{{- end }}
{{- else if eq .Role "tool" }}<tool_response>
{{ .Content }}
</tool_response>
{{- end }}<|eot_id|>
{{- end }}
{{- end }}<|start_header_id|>assistant<|end_header_id|>
{{ else }}
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}{{ .Response }}
{{- if .Response }}<|eot_id|>
{{- end }}

View File

@@ -1,24 +0,0 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Here are the available tools:
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<tool_call>
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
<tool_response>
22
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,39 +0,0 @@
[
{
"role": "system",
"content": "You are a knowledgable assistant. You can answer questions and perform tasks."
},
{
"role": "user",
"content": "What's the weather like today in Paris?"
},
{
"role": "assistant",
"tool_calls": [
{
"id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": {
"location": "Paris, France",
"format": "celsius"
}
}
}
]
},
{
"role": "tool",
"tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"content": "22"
},
{
"role": "assistant",
"content": "The current temperature in Paris, France is 22 degrees Celsius."
},
{
"role": "user",
"content": "What's the weather like today in San Francisco and Toronto?"
}
]

View File

@@ -1,15 +0,0 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}

View File

@@ -1,3 +0,0 @@
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgable assistant. You can answer questions and perform tasks.
What's the weather like today in San Francisco and Toronto?[/INST]

View File

@@ -1,30 +0,0 @@
[
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": [
"location",
"format"
]
}
}
}
]

View File

@@ -1,45 +0,0 @@
{{- if .System }}{{ .System }}
{{ end }}
{{- range $i, $_ := .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
{{ $.Tools }}
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
{{ .Content }}
[END OF QUERY]
{{ else }}
{{ .Content }}
{{ end }}
{{- else if .ToolCalls }}### Response:
{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
<|EOT|>
{{ else if eq .Role "assistant" }}### Response:
{{ .Content }}
<|EOT|>
{{ end }}
{{- end }}### Response:

View File

@@ -1,40 +0,0 @@
You are a knowledgable assistant. You can answer questions and perform tasks.
### Instruction:
What's the weather like today in Paris?
### Response:
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
<|EOT|>
### Response:
The current temperature in Paris, France is 22 degrees Celsius.
<|EOT|>
### Instruction:
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
What's the weather like today in San Francisco and Toronto?
[END OF QUERY]
### Response:

View File

@@ -1 +1,8 @@
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
{{- if .Messages }}
{{- if .System }}<start_system>{{ .System }}<end_message>
{{- end }}
{{- range .Messages }}<start_{{ .Role }}>{{ .Content }}<end_message>
{{- end }}<start_assistant>
{{- else }}
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
{{- end }}

View File

@@ -1,3 +1,14 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{- end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- else if eq .Role "assistant" }}### Response:
{{- end }}
{{ .Content }}
{{ end }}### Response:
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
@@ -5,4 +16,4 @@
{{ end }}### Response:
{{ .Response }}
{{- end }}

View File

@@ -1,6 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end }}

View File

@@ -1,6 +1,17 @@
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{- else if eq .Role "assistant" }}Assistant:
{{- end }} {{ .Content }}
{{ end }}Assistant:
{{- else }}
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
{{- end }}

View File

@@ -1,10 +1,19 @@
{{ if .System }}Source: system
{{- if .Messages }}
{{- if .System }}Source: system
{{ .System }} <step> {{ end }}Source: user
{{ .System }} <step> {{ end }}
{{- range .Messages }}Source: {{ .Role }}
{{ .Content }} <step> {{ end }}Source: assistant
Destination: user
{{ else }}
{{ if .System }} Source: system
{{ .System }} <step>{{ end }} Source: user
{{ .Prompt }} <step> Source: assistant
{{- if not .Response }}
Destination: user
{{- end }}
{{ .Response }} <step>
{{ .Response }}<step>
{{- end }}

View File

@@ -1,5 +1,13 @@
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User:
{{ .Prompt }}
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{ else if eq .Role "assistant" }}Falcon:
{{ end }}{{ .Content }}
{{ end }}Falcon:
{{ .Response }}
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}
{{- end }}

View File

@@ -1,5 +1,16 @@
{{- if .Messages }}
{{- range $index, $_ := .Messages }}<start_of_turn>
{{- if eq .Role "user" }}user
{{- if and $.System (eq $index 0) }}
{{ $.System }}
{{- end }}
{{- else if eq .Role "assistant" }}model
{{- end }}
{{ .Content }}<end_of_turn>
{{ end }}<start_of_turn>model
{{ else }}
<start_of_turn>user
{{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}<end_of_turn>
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>
{{- end }}

View File

@@ -1,4 +1,18 @@
{{ if .System }}System:
{{- if .Messages }}
{{- if .System }}System:
{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}Question:
{{- else if eq .Role "assistant" }}Answer:
{{- end }}
{{ .Content }}
{{ end }}Answer:
{{ else }}
{{ if .System }}
System:
{{ .System }}
{{ end }}{{ if .Prompt }}Question:
@@ -6,4 +20,4 @@
{{ end }}Answer:
{{ .Response }}
{{- end }}

View File

@@ -1,6 +1,16 @@
[INST] <<SYS>>
{{- if .System }}
{{ .System }}
{{- if .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if eq $index 0 }}<<SYS>>
{{- if $.System }}
{{ $.System }}
{{ end }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }}</s><s>
{{ end }}{{ .Content }}
{{- else }} [/INST] {{ .Content }}</s><s>
{{- end }}
{{- end }} [/INST]
{{- else }}
[INST] <<SYS>>{{ .System }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }}
{{- end }}

View File

@@ -1,7 +1,19 @@
{{- if .Messages }}
{{- if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ .Content }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>
{{ else }}
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>
{{ .Response }}<|eot_id|>
{{- end }}

View File

@@ -1,3 +1,15 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}@@ Instruction
{{- else if eq .Role "assistant" }}@@ Response
{{- end }}
{{ .Content }}
{{ end }}@@ Response
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}@@ Instruction
@@ -5,4 +17,4 @@
{{ end }}@@ Response
{{ .Response }}
{{- end }}

View File

@@ -1,3 +1,9 @@
[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}</s>
{{- if .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}
{{ end }}{{ .Content }}
{{- else if eq .Role "assistant" }}[/INST] {{ .Content }}</s>
{{- end }}
{{- end }}[/INST]
{{- else }}[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}
{{- end }}

View File

@@ -1 +1,11 @@
{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>{{ end }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
{{- if .Messages }}
{{- if .System }}GPT Correct System: {{ .System }}<|end_of_turn|>
{{- end }}
{{- range .Messages }}GPT Correct
{{- if eq .Role "user" }} User:
{{- else if eq .Role "assistant" }} Assistant:
{{- end }} {{ .Content }}<|end_of_turn|>
{{- end }}GPT Correct Assistant:
{{- else }}
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
{{- end }}

View File

@@ -1,6 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}
{{- range .Messages }}<|{{ .Role }}|>
{{ .Content }}<|end|>
{{ end }}<|assistant|>
{{ else }}
{{ if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}<|end|>
{{ end }}<|assistant|>
{{ .Response }}<|end|>
{{- end }}

View File

@@ -1,3 +1,16 @@
{{- if .Messages }}
{{- if .System }}### System:
{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### User:
{{ .Content }}
{{ else if eq .Role "assistant" }}### Assistant:
{{ .Content }}</s>
{{ end }}
{{ end }}### Assistant:
{{ else }}
{{ if .System }}### System:
{{ .System }}
@@ -5,5 +18,5 @@
{{ .Prompt }}
{{ end }}### Assistant:
{{ .Response }}</s>
{{ .Response }}
{{- end }}

View File

@@ -1,8 +1,24 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction
{{ .Content }}
{{ else if eq .Role "assistant" }}### Response
{{ .Content }}<|endoftext|>
{{ end }}
{{- end }}### Response
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction
{{ .Prompt }}
{{ end }}### Response
{{ .Response }}<|endoftext|>
{{- end }}

View File

@@ -102,15 +102,8 @@ var response = parse.ActionNode{
},
}
var funcs = template.FuncMap{
"json": func(v any) string {
b, _ := json.Marshal(v)
return string(b)
},
}
func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
tmpl := template.New("").Option("missingkey=zero")
tmpl, err := tmpl.Parse(s)
if err != nil {
@@ -134,7 +127,7 @@ func (t *Template) Vars() []string {
var vars []string
for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes {
vars = append(vars, Identifiers(n)...)
vars = append(vars, parseNode(n)...)
}
}
@@ -150,131 +143,54 @@ func (t *Template) Vars() []string {
type Values struct {
Messages []api.Message
api.Tools
Prompt string
Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
}
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
var walk func(parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
return n
}
switch t := n.(type) {
case *parse.ListNode:
for _, c := range t.Nodes {
if n := walk(c); n != nil {
return n
}
}
case *parse.BranchNode:
for _, n := range []*parse.ListNode{t.List, t.ElseList} {
if n != nil {
if n := walk(n); n != nil {
return n
}
}
}
case *parse.IfNode:
return walk(&t.BranchNode)
case *parse.WithNode:
return walk(&t.BranchNode)
case *parse.RangeNode:
return walk(&t.BranchNode)
}
return nil
}
if n := walk(t.Tree.Root); n != nil {
return (&template.Template{
Tree: &parse.Tree{
Root: &parse.ListNode{
Nodes: []parse.Node{n},
},
},
}).Funcs(funcs)
}
return nil
}
func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages)
if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
system, collated := collate(v.Messages)
if slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
"Tools": v.Tools,
"Response": "",
"Messages": collated,
})
}
system = ""
var b bytes.Buffer
var prompt, response string
for _, m := range messages {
execute := func() error {
for i, m := range collated {
if m.Role == "user" {
prompt = m.Content
} else {
response = m.Content
}
if i != len(collated)-1 && prompt != "" && response != "" {
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"System": "",
"Prompt": prompt,
"Response": response,
}); err != nil {
return err
}
system = ""
prompt = ""
response = ""
return nil
}
switch m.Role {
case "system":
if prompt != "" || response != "" {
if err := execute(); err != nil {
return err
}
}
system = m.Content
case "user":
if response != "" {
if err := execute(); err != nil {
return err
}
}
prompt = m.Content
case "assistant":
response = m.Content
}
}
var cut bool
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
tree := t.Template.Copy()
// for the last message, cut everything after "{{ .Response }}"
tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
if slices.Contains(parseNode(n), "Response") {
cut = true
return false
}
return cut
})
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
}); err != nil {
return err
}
@@ -283,16 +199,25 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err
}
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also collects and returns all system messages.
// collate mutates message content adding image tags ([img-%d]) as needed
func collate(msgs []api.Message) (string, []*api.Message) {
var n int
type messages []*api.Message
var system []string
var collated []*api.Message
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed
func collate(msgs []api.Message) (system string, collated messages) {
var n int
for i := range msgs {
msg := msgs[i]
if msg.Role == "system" {
if system != "" {
system += "\n\n"
}
system += msg.Content
continue
}
for range msg.Images {
imageTag := fmt.Sprintf("[img-%d]", n)
if !strings.Contains(msg.Content, "[img]") {
@@ -303,10 +228,6 @@ func collate(msgs []api.Message) (string, []*api.Message) {
n++
}
if msg.Role == "system" {
system = append(system, msg.Content)
}
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
collated[len(collated)-1].Content += "\n\n" + msg.Content
} else {
@@ -314,119 +235,54 @@ func collate(msgs []api.Message) (string, []*api.Message) {
}
}
return strings.Join(system, "\n\n"), collated
return
}
// Identifiers walks the node tree returning any identifiers it finds along the way
func Identifiers(n parse.Node) []string {
func parseNode(n parse.Node) []string {
switch n := n.(type) {
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, Identifiers(n)...)
}
return names
case *parse.TemplateNode:
return Identifiers(n.Pipe)
case *parse.ActionNode:
return Identifiers(n.Pipe)
case *parse.BranchNode:
names := Identifiers(n.Pipe)
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
if n != nil {
names = append(names, Identifiers(n)...)
}
return parseNode(n.Pipe)
case *parse.IfNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.IfNode:
return Identifiers(&n.BranchNode)
case *parse.RangeNode:
return Identifiers(&n.BranchNode)
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.WithNode:
return Identifiers(&n.BranchNode)
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.PipeNode:
var names []string
for _, c := range n.Cmds {
for _, a := range c.Args {
names = append(names, Identifiers(a)...)
names = append(names, parseNode(a)...)
}
}
return names
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, parseNode(n)...)
}
return names
case *parse.FieldNode:
return n.Ident
case *parse.VariableNode:
return n.Ident
case *parse.TemplateNode:
return parseNode(n.Pipe)
}
return nil
}
// deleteNode walks the node list and deletes nodes that match the predicate
// this is currently to remove the {{ .Response }} node from templates
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
var walk func(n parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
return nil
}
switch t := n.(type) {
case *parse.ListNode:
var nodes []parse.Node
for _, c := range t.Nodes {
if n := walk(c); n != nil {
nodes = append(nodes, n)
}
}
t.Nodes = nodes
return t
case *parse.IfNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.WithNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.RangeNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.BranchNode:
t.List = walk(t.List).(*parse.ListNode)
if t.ElseList != nil {
t.ElseList = walk(t.ElseList).(*parse.ListNode)
}
case *parse.ActionNode:
n := walk(t.Pipe)
if n == nil {
return nil
}
t.Pipe = n.(*parse.PipeNode)
case *parse.PipeNode:
var commands []*parse.CommandNode
for _, c := range t.Cmds {
var args []parse.Node
for _, a := range c.Args {
if n := walk(a); n != nil {
args = append(args, n)
}
}
if len(args) == 0 {
return nil
}
c.Args = args
commands = append(commands, c)
}
if len(commands) == 0 {
return nil
}
t.Cmds = commands
}
return n
}
return walk(n)
}

View File

@@ -105,8 +105,8 @@ func TestTemplate(t *testing.T) {
}
for n, tt := range cases {
var actual bytes.Buffer
t.Run(n, func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
t.Fatal(err)
}
@@ -116,34 +116,7 @@ func TestTemplate(t *testing.T) {
t.Fatal(err)
}
bts := actual.Bytes()
if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
t.Log("removing trailing space from output")
bts = bts[:len(bts)-1]
}
if diff := cmp.Diff(bts, expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("legacy", func(t *testing.T) {
t.Skip("legacy outputs are currently default outputs")
var legacy bytes.Buffer
if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
t.Fatal(err)
}
legacyBytes := legacy.Bytes()
if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
t.Log("removing trailing space from legacy output")
legacyBytes = legacyBytes[:len(legacyBytes)-1]
} else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
t.Skip("legacy outputs cannot be compared to messages outputs")
}
if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
@@ -162,24 +135,7 @@ func TestParse(t *testing.T) {
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{`{{- range .Messages }}
{{- if eq .Role "system" }}SYSTEM:
{{- else if eq .Role "user" }}USER:
{{- else if eq .Role "assistant" }}ASSISTANT:
{{- end }} {{ .Content }}
{{- end }}`, []string{"content", "messages", "role"}},
{`{{- if .Messages }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else -}}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
}
for _, tt := range cases {
@@ -189,8 +145,9 @@ func TestParse(t *testing.T) {
t.Fatal(err)
}
if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
vars := tmpl.Vars()
if !slices.Equal(tt.vars, vars) {
t.Errorf("expected %v, got %v", tt.vars, vars)
}
})
}
@@ -210,17 +167,12 @@ func TestExecuteWithMessages(t *testing.T) {
{
"mistral",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `[INST] {{ if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
},
Values{
@@ -235,17 +187,13 @@ func TestExecuteWithMessages(t *testing.T) {
{
"mistral system",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `[INST] {{ if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
},
Values{
@@ -256,29 +204,9 @@ func TestExecuteWithMessages(t *testing.T) {
{Role: "user", Content: "What is your name?"},
},
},
`[INST] You are a helpful assistant!
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
"mistral assistant",
[]template{
{"no response", `[INST] {{ .Prompt }}[/INST] `},
{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `
{{- range $i, $m := .Messages }}
{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
{Role: "assistant", Content: "My name is Ollama and I"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
What is your name?[/INST] `,
},
{
"chatml",
@@ -292,9 +220,12 @@ Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
{{ .Response }}<|im_end|>
`},
{"messages", `
{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>assistant
`},
},
Values{
@@ -305,12 +236,12 @@ Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
{Role: "user", Content: "What is your name?"},
},
},
`<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
`<|im_start|>user
Hello friend!<|im_end|>
<|im_start|>assistant
Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
What is your name?<|im_end|>
<|im_start|>assistant
@@ -327,11 +258,9 @@ What is your name?<|im_end|>
`},
{"messages", `
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}
{{ else if eq .Role "assistant" }}Answer: {{ .Content }}
{{ end }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
{{- end }}
{{- end }}Answer: `},
},
Values{
@@ -371,46 +300,11 @@ Answer: `,
t.Fatal(err)
}
if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
if b.String() != tt.expected {
t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
}
})
}
})
}
}
func TestExecuteWithSuffix(t *testing.T) {
tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`)
if err != nil {
t.Fatal(err)
}
cases := []struct {
name string
values Values
expect string
}{
{
"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
},
{
"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
var b bytes.Buffer
if err := tmpl.Execute(&b, tt.values); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -1,6 +1,4 @@
You are a helpful assistant.
### Instruction:
You are a helpful assistant.### Instruction:
Hello, how are you?
### Response:

View File

@@ -9,4 +9,3 @@ Source: system
I'd like to show off how chat templating works! <step> Source: assistant
Destination: user

View File

@@ -3,4 +3,3 @@ Source: user
Hello, how are you? <step> Source: assistant
Destination: user

View File

@@ -7,4 +7,3 @@ Source: user
I'd like to show off how chat templating works! <step> Source: assistant
Destination: user

View File

@@ -2,6 +2,4 @@
You are a helpful assistant.
<</SYS>>
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] <<SYS>><</SYS>>
I'd like to show off how chat templating works! [/INST]
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]

View File

@@ -1,5 +1,3 @@
[INST] <<SYS>><</SYS>>
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] <<SYS>><</SYS>>
I'd like to show off how chat templating works! [/INST]
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]

View File

@@ -1,3 +1,2 @@
[INST] You are a helpful assistant.
Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works![/INST]
[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] You are a helpful assistant.
I'd like to show off how chat templating works![/INST]

View File

@@ -1 +1 @@
GPT4 Correct System: You are a helpful assistant.<|end_of_turn|>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT4 Correct Assistant:
GPT Correct System: You are a helpful assistant.<|end_of_turn|>GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:

View File

@@ -1 +1 @@
GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant:
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant:

View File

@@ -1 +1 @@
GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT4 Correct Assistant:
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:

View File

@@ -1,4 +1,14 @@
{{ if .System }}{{ .System }}
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}USER: {{ .Content }}
{{ else if eq .Role "assistant" }}ASSISTANT: {{ .Content }}</s>
{{ end }}
{{- end }}ASSISTANT:
{{- else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}USER: {{ .Prompt }}
{{ end }}ASSISTANT: {{ .Response }}</s>
{{ end }}ASSISTANT: {{ .Response }}
{{- end }}

View File

@@ -1,6 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|system|>
{{ .System }}</s>
{{ end }}
{{- range .Messages }}<|{{ .Role }}|>
{{ .Content }}</s>
{{ end }}<|assistant|>
{{ else }}
{{ if .System }}<|system|>
{{ .System }}</s>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}</s>
{{ end }}<|assistant|>
{{ .Response }}</s>
{{- end }}