Compare commits

..

46 Commits

Author SHA1 Message Date
Roy Han
907b038ff0 reduce error footprint 2024-07-15 10:57:01 -07:00
royjhan
1f73889f34 Merge branch 'royh-batchembed' into royh-embed-parallel 2024-07-12 16:44:12 -07:00
Roy Han
7e313e5964 remove redundant error check 2024-07-12 16:37:29 -07:00
Roy Han
5a8f8e96e0 clean up 2024-07-12 16:35:25 -07:00
Roy Han
7cddd6d741 parallelized 2024-07-12 16:08:12 -07:00
Roy Han
1f3aefd323 remove function closure 2024-07-12 14:45:16 -07:00
Roy Han
2d7048f410 Revert "remove function closure"
This reverts commit 55d48c6ed1.
2024-07-12 14:40:40 -07:00
Roy Han
55d48c6ed1 remove function closure 2024-07-12 14:35:43 -07:00
Roy Han
c0b5bf0a36 testing clean up 2024-07-12 11:45:45 -07:00
Roy Han
53e9576f46 testing clean up 2024-07-11 20:20:14 -07:00
Roy Han
dbe9527305 clean up 2024-07-11 17:28:55 -07:00
Roy Han
694388db90 set context length 2024-07-10 15:21:46 -07:00
Roy Han
cdb9fe9b06 test values 2024-07-10 09:57:36 -07:00
Roy Han
8f6d0242b6 refactoring 2024-07-09 16:19:02 -07:00
Roy Han
c697eb2a9b fix hanging on single string 2024-07-09 15:51:55 -07:00
Roy Han
b686ac144c merge conflicts 2024-07-09 14:00:13 -07:00
royjhan
786848dfd3 Merge branch 'main' into royh-batchembed 2024-07-09 13:48:06 -07:00
Roy Han
fb390b8902 embedding type 64 2024-07-09 13:41:48 -07:00
Roy Han
bcb63e6e0e touches 2024-07-09 13:37:00 -07:00
Roy Han
3342e5f035 merge conflicts 2024-07-08 15:15:09 -07:00
royjhan
b7c622dd32 Merge branch 'main' into royh-batchembed 2024-07-08 15:10:52 -07:00
Roy Han
6caac01494 clear comments 2024-07-03 14:05:34 -07:00
Roy Han
17de2b4405 Refactoring of legacy and new 2024-07-03 14:02:25 -07:00
Roy Han
922b8f2584 input handling and handler testing 2024-07-03 12:48:54 -07:00
Roy Han
c0fa2236cf integration float32 2024-07-03 12:47:57 -07:00
Roy Han
a413014aaf refactoring 2024-07-03 11:21:06 -07:00
royjhan
a5f23d766e Merge branch 'main' into royh-batchembed 2024-07-03 11:20:24 -07:00
Roy Han
95e46eeedf move normalize test 2024-07-03 09:45:42 -07:00
Roy Han
3d060e0ae9 move normalize 2024-07-02 10:35:02 -07:00
Roy Han
00a4cb26ca use float32 2024-07-02 10:30:29 -07:00
Roy Han
512e0a7bde Clean up 2024-07-01 16:29:54 -07:00
Roy Han
1a0c8b363c Truncation Integration Tests 2024-07-01 16:26:30 -07:00
Roy Han
e068e7f698 Integration Test Template 2024-07-01 15:24:26 -07:00
Roy Han
aee25acb5b move normalization to go 2024-07-01 14:10:58 -07:00
Roy Han
9c32b6b9ed Truncation 2024-07-01 11:59:44 -07:00
Roy Han
1daac52651 Truncation 2024-07-01 11:55:16 -07:00
Roy Han
80c1a3f812 playing around with truncate stuff 2024-06-28 18:17:09 -07:00
Roy Han
c111d8bb51 normalization 2024-06-28 17:19:04 -07:00
Roy Han
5213c12354 clean up 2024-06-28 15:26:58 -07:00
Roy Han
b9c74df37b check normalization 2024-06-28 15:10:58 -07:00
Roy Han
49e341147d add server function 2024-06-28 15:03:53 -07:00
Roy Han
c406fa7a4c api/embed draft 2024-06-28 14:54:21 -07:00
Roy Han
22458c573a mock up notes 2024-06-28 14:21:45 -07:00
Roy Han
ff191d7cba Initial Draft 2024-06-25 13:29:47 -07:00
Roy Han
0f87628b6d Revert "Initial Batch Embedding"
This reverts commit c22d54895a.
2024-06-24 15:26:05 -07:00
Roy Han
c22d54895a Initial Batch Embedding 2024-06-18 17:34:36 -07:00
60 changed files with 1034 additions and 529 deletions

View File

@@ -147,7 +147,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

View File

@@ -169,7 +169,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

View File

@@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error {
return nil
}
// Embeddings generates embeddings from a model.
// Embed generates embeddings from a model.
func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
var resp EmbedResponse
if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Embeddings generates embeddings from a model. (Legacy)
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {

View File

@@ -173,6 +173,30 @@ type Runner struct {
NumThread int `json:"num_thread,omitempty"`
}
// EmbedRequest is the request passed to [Client.Embed].
type EmbedRequest struct {
// Model is the model name.
Model string `json:"model"`
// Input is the input to embed.
Input any `json:"input"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
Truncate *bool `json:"truncate,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings,omitempty"`
}
// EmbeddingRequest is the request passed to [Client.Embeddings].
type EmbeddingRequest struct {
// Model is the model name.

View File

@@ -127,10 +127,6 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
[InstallDelete]
Type: filesandordirs; Name: "{%TEMP}\ollama*"
Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
[Messages]
WizardReady=Ollama Windows Preview
ReadyLabel1=%nLet's get you up and running with your own large language models.

View File

@@ -272,4 +272,4 @@ The following server settings may be used to adjust how Ollama handles concurren
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.

View File

@@ -49,17 +49,9 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
}
func commonAMDValidateLibDir() (string, error) {
// Favor our bundled version
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
// We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// This gives users a more recovery options if versions have subtle problems at runtime
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
@@ -95,5 +87,14 @@ func commonAMDValidateLibDir() (string, error) {
}
}
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}

View File

@@ -84,8 +84,9 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
}
slog.Debug("hipDriverGetVersion", "version", version)
driverMajor = version / 10000000
driverMinor = (version - (driverMajor * 10000000)) / 100000
// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
driverMajor = version / 1000
driverMinor = (version - (driverMajor * 1000)) / 10
return driverMajor, driverMinor, nil
}

View File

@@ -22,8 +22,8 @@ const (
var (
// Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
)
func AMDGetGPUInfo() []RocmGPUInfo {
@@ -35,11 +35,12 @@ func AMDGetGPUInfo() []RocmGPUInfo {
}
defer hl.Release()
driverMajor, driverMinor, err := hl.AMDDriverVersion()
if err != nil {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug("error looking up amd driver version", "error", err)
}
// TODO - this reports incorrect version information, so omitting for now
// driverMajor, driverMinor, err := hl.AMDDriverVersion()
// if err != nil {
// // For now this is benign, but we may eventually need to fail compatibility checks
// slog.Debug("error looking up amd driver version", "error", err)
// }
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount()
@@ -131,8 +132,10 @@ func AMDGetGPUInfo() []RocmGPUInfo {
MinimumMemory: rocmMinimumMemory,
Name: name,
Compute: gfx,
DriverMajor: driverMajor,
DriverMinor: driverMinor,
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
// DriverMajor: driverMajor,
// DriverMinor: driverMinor,
},
index: i,
}

View File

@@ -274,28 +274,6 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor
// query the management library as well so we can record any skew between the two
// which represents overhead on the GPU we must set aside on subsequent updates
if cHandles.nvml != nil {
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
if memInfo.err != nil {
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
slog.Info("detected OS VRAM overhead",
"id", gpuInfo.ID,
"library", gpuInfo.Library,
"compute", gpuInfo.Compute,
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
"name", gpuInfo.Name,
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
)
}
}
}
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo)
}
@@ -360,17 +338,14 @@ func GetGPUInfo() GpuInfoList {
"before",
"total", format.HumanBytes2(cpus[0].TotalMemory),
"free", format.HumanBytes2(cpus[0].FreeMemory),
"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
),
slog.Group(
"now",
"total", format.HumanBytes2(mem.TotalMemory),
"free", format.HumanBytes2(mem.FreeMemory),
"free_swap", format.HumanBytes2(mem.FreeSwap),
),
)
cpus[0].FreeMemory = mem.FreeMemory
cpus[0].FreeSwap = mem.FreeSwap
}
var memInfo C.mem_info_t
@@ -399,14 +374,9 @@ func GetGPUInfo() GpuInfoList {
slog.Warn("error looking up nvidia GPU memory")
continue
}
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
// When using the management library update based on recorded overhead
memInfo.free -= C.uint64_t(gpu.OSOverhead)
}
slog.Debug("updating cuda memory data",
"gpu", gpu.ID,
"name", gpu.Name,
"overhead", format.HumanBytes2(gpu.OSOverhead),
slog.Group(
"before",
"total", format.HumanBytes2(gpu.TotalMemory),

View File

@@ -57,7 +57,6 @@ func GetCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: uint64(C.getFreeMemory()),
// FreeSwap omitted as Darwin uses dynamic paging
}, nil
}

View File

@@ -50,7 +50,7 @@ var OneapiMgmtName = "libze_intel_gpu.so"
func GetCPUMem() (memInfo, error) {
var mem memInfo
var total, available, free, buffers, cached, freeSwap uint64
var total, available, free, buffers, cached uint64
f, err := os.Open("/proc/meminfo")
if err != nil {
return mem, err
@@ -70,21 +70,20 @@ func GetCPUMem() (memInfo, error) {
_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
case strings.HasPrefix(line, "Cached:"):
_, err = fmt.Sscanf(line, "Cached:%d", &cached)
case strings.HasPrefix(line, "SwapFree:"):
_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
default:
continue
}
if err != nil {
return mem, err
}
if total > 0 && available > 0 {
mem.TotalMemory = total * format.KibiByte
mem.FreeMemory = available * format.KibiByte
return mem, nil
}
}
mem.TotalMemory = total * format.KibiByte
mem.FreeSwap = freeSwap * format.KibiByte
if available > 0 {
mem.FreeMemory = available * format.KibiByte
} else {
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
}
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
return mem, nil
}

View File

@@ -51,5 +51,5 @@ func GetCPUMem() (memInfo, error) {
if r1 == 0 {
return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
}
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
}

View File

@@ -10,7 +10,6 @@ import (
type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"`
FreeSwap uint64 `json:"free_swap,omitempty"`
}
// Beginning of an `ollama info` command
@@ -53,8 +52,7 @@ type CPUInfo struct {
type CudaGPUInfo struct {
GpuInfo
OSOverhead uint64 // Memory overhead between the driver library and management library
index int //nolint:unused,nolintlint
index int //nolint:unused,nolintlint
}
type CudaGPUInfoList []CudaGPUInfo

152
integration/embed_test.go Normal file
View File

@@ -0,0 +1,152 @@
//go:build integration
package integration
import (
"context"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}
if res.Embeddings[0][0] != 0.010071031 {
t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
}
}
func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"},
}
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embeddings) != 2 {
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}
if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
}
}
func TestAllMiniLmEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
truncTrue, truncFalse := true, false
type testReq struct {
Name string
Request api.EmbedRequest
}
reqs := []testReq{
{
Name: "Target Truncation",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why",
},
},
{
Name: "Default Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 1},
},
},
{
Name: "Explicit Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
},
}
res := make(map[string]*api.EmbedResponse)
for _, req := range reqs {
response, err := embedTestHelper(ctx, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
res[req.Name] = response
}
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatal("expected default request to truncate correctly")
}
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatal("expected default request and truncate true request to be the same")
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 1},
})
if err == nil {
t.Fatal("expected error, got nil")
}
}
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
response, err := client.Embed(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
}

View File

@@ -3188,26 +3188,33 @@ int main(int argc, char **argv) {
prompt = "";
}
json image_data;
if (body.count("image_data") != 0) {
image_data = body["image_data"];
}
else
{
image_data = "";
if (prompt.size() == 1) {
prompt = prompt[0];
}
// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
json responses;
{
const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
// get the result
task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}
// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
responses = result.result_json.value("results", std::vector<json>{result.result_json});
json embeddings = json::array();
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json embedding_res = json{{"embedding", embeddings}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
}
});
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

View File

@@ -178,7 +178,7 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
echo "Building custom CUDA GPU"
else
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DGGML_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} -DCMAKE_LIBRARY_PATH=/usr/local/cuda/compat"
fi
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
fi
init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""

View File

@@ -6,9 +6,18 @@ function amdGPUs {
if ($env:AMDGPU_TARGETS) {
return $env:AMDGPU_TARGETS
}
# Current supported rocblas list from ROCm v6.1.2 on windows
# TODO - load from some common data file for linux + windows build consistency
$GPU_LIST = @(
"gfx900"
"gfx906:xnack-"
"gfx908:xnack-"
"gfx90a:xnack+"
"gfx90a:xnack-"
"gfx940"
"gfx941"
"gfx942"
"gfx1010"
"gfx1012"
"gfx1030"
"gfx1100"
"gfx1101"
@@ -357,7 +366,6 @@ function build_rocm() {
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DGGML_HIPBLAS=on",
"-DLLAMA_CUDA_NO_PEER_COPY=on",
"-DHIP_PLATFORM=amd",
"-DGGML_AVX=on",
"-DGGML_AVX2=off",
@@ -386,6 +394,7 @@ function build_rocm() {
sign
install
# Assumes v5.7, may need adjustments for v6
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"

View File

@@ -424,32 +424,6 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
)
case "chatglm":
fullOffload = 4 * batch * (embedding + vocab)
partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
fullOffload = max(
fullOffload,
4*batch*(2+
2*embedding+
context+
context*heads+
embeddingHeadsK*heads+
qkvBias.Shape[0]),
)
partialOffload = max(
partialOffload,
4*batch*(1+
2*embedding+
embeddingHeadsK*heads+
context+
context*heads)+
4*embeddingHeadsK*context+
4*context*embeddingHeadsK+
4*qkvBias.Shape[0],
)
}
}
return

View File

@@ -33,7 +33,7 @@ func Quantize(infile, outfile string, ftype fileType) error {
params.ftype = ftype.Value()
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
return fmt.Errorf("llama_model_quantize: %d", rc)
}
return nil

View File

@@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Embed(ctx context.Context, input []string) ([][]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@@ -88,7 +88,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
var estimate MemoryEstimate
var systemTotalMemory uint64
var systemFreeMemory uint64
var systemSwapFreeMemory uint64
systemMemInfo, err := gpu.GetCPUMem()
if err != nil {
@@ -96,8 +95,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} else {
systemTotalMemory = systemMemInfo.TotalMemory
systemFreeMemory = systemMemInfo.FreeMemory
systemSwapFreeMemory = systemMemInfo.FreeSwap
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory)
}
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
@@ -124,16 +122,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
}
}
// On linux, over-allocating CPU memory will almost always result in an error
if runtime.GOOS == "linux" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := min(systemTotalMemory, systemFreeMemory+systemSwapFreeMemory)
if systemMemoryRequired > available {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
}
}
estimate.log()
// Loop through potential servers
@@ -266,6 +254,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--tensor-split", estimate.TensorSplit)
}
if estimate.TensorSplit != "" {
params = append(params, "--tensor-split", estimate.TensorSplit)
}
for i := range len(servers) {
dir := availableServers[servers[i]]
if dir == "" {
@@ -867,15 +859,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil
}
type EmbeddingRequest struct {
Content string `json:"content"`
type EmbedRequest struct {
Content []string `json:"content"`
}
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
type EmbedResponse struct {
Embedding [][]float32 `json:"embedding"`
}
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
@@ -890,7 +882,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: prompt})
data, err := json.Marshal(EmbedRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
@@ -917,7 +909,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("%s", body)
}
var embedding EmbeddingResponse
var embedding EmbedResponse
if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}

View File

@@ -338,16 +338,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
switch stop := r.Stop.(type) {
case string:
options["stop"] = []string{stop}
case []any:
var stops []string
for _, s := range stop {
if str, ok := s.(string); ok {
stops = append(stops, str)
} else {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
}
case []string:
options["stop"] = stop
default:
if r.Stop != nil {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
}
options["stop"] = stops
}
if r.MaxTokens != nil {

View File

@@ -3,6 +3,7 @@ package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -15,133 +16,7 @@ import (
"github.com/stretchr/testify/assert"
)
func TestMiddlewareRequests(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
Handler func() gin.HandlerFunc
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *http.Request)
}
var capturedRequest *http.Request
captureRequestMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
capturedRequest = c.Request
c.Next()
}
}
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
Handler: CompletionsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var genReq api.GenerateRequest
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
t.Fatal(err)
}
if genReq.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
}
if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
}
stopTokens, ok := genReq.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(captureRequestMiddleware())
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, endpoint)
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest)
})
}
}
func TestMiddlewareResponses(t *testing.T) {
func TestMiddleware(t *testing.T) {
type testCase struct {
Name string
Method string
@@ -155,7 +30,159 @@ func TestMiddlewareResponses(t *testing.T) {
testCases := []testCase{
{
Name: "completions handler error forwarding",
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
TestPath: "/api/chat",
Handler: ChatMiddleware,
Endpoint: func(c *gin.Context) {
var chatReq api.ChatRequest
if err := c.ShouldBindJSON(&chatReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
userMessage := chatReq.Messages[0].Content
var assistantMessage string
switch userMessage {
case "Hello":
assistantMessage = "Hello!"
default:
assistantMessage = "I'm not sure how to respond to that."
}
c.JSON(http.StatusOK, api.ChatResponse{
Message: api.Message{
Role: "assistant",
Content: assistantMessage,
},
})
},
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var chatResp ChatCompletion
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
t.Fatal(err)
}
if chatResp.Object != "chat.completion" {
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
}
if chatResp.Choices[0].Message.Content != "Hello!" {
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.GenerateResponse{
Response: "Hello!",
})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Hello!" {
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with params",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
var generateReq api.GenerateRequest
if err := c.ShouldBindJSON(&generateReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
temperature := generateReq.Options["temperature"].(float64)
var assistantMessage string
switch temperature {
case 1.6:
assistantMessage = "Received temperature of 1.6"
default:
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
}
c.JSON(http.StatusOK, api.GenerateResponse{
Response: assistantMessage,
})
},
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with error",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",

View File

@@ -107,12 +107,9 @@ function gatherDependencies() {
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
# currently works for Win11 + MSVC 2019 + Cuda V11
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
}
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"

View File

@@ -161,7 +161,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
},
}

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"log/slog"
"math"
"net"
"net/http"
"net/netip"
@@ -17,6 +18,7 @@ import (
"path/filepath"
"slices"
"strings"
"sync"
"syscall"
"time"
@@ -246,6 +248,152 @@ func (s *Server) GenerateHandler(c *gin.Context) {
streamResponse(c, ch)
}
func (s *Server) EmbedHandler(c *gin.Context) {
var req api.EmbedRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Truncate == nil {
truncate := true
req.Truncate = &truncate
}
reqEmbed := []string{}
switch embeddings := req.Input.(type) {
case string:
if embeddings == "" {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
reqEmbed = []string{embeddings}
case []any:
if len(embeddings) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
for _, v := range embeddings {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
reqEmbed = append(reqEmbed, v.(string))
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
kvData, err := getKVData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
reqEmbedArray := make([]string, len(reqEmbed))
errCh := make(chan error, 1)
successCh := make(chan bool, 1)
sem := make(chan struct{}, 2)
var wg sync.WaitGroup
var mu sync.Mutex
for i, s := range reqEmbed {
wg.Add(1)
sem <- struct{}{}
go func(i int, s string) {
defer wg.Done()
defer func() { <-sem }()
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
errCh <- err
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if *req.Truncate {
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
errCh <- err
return
}
} else {
errCh <- err
return
}
}
mu.Lock()
reqEmbedArray[i] = s
mu.Unlock()
}(i, s)
}
go func() {
wg.Wait()
successCh <- true
close(errCh)
}()
select {
case err := <-errCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
case success := <-successCh:
if !success {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to process all embeddings"})
return
}
}
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
for i, e := range embeddings {
embeddings[i] = normalize(e)
}
resp := api.EmbedResponse{
Model: req.Model,
Embeddings: embeddings,
}
c.JSON(http.StatusOK, resp)
}
func normalize(vec []float32) []float32 {
var sum float32
for _, v := range vec {
sum += v * v
}
norm := float32(0.0)
if sum > 0 {
norm = float32(1.0 / math.Sqrt(float64(sum)))
}
for i := range vec {
vec[i] *= norm
}
return vec
}
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@@ -268,14 +416,24 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
embedding, err := r.Embed(c.Request.Context(), []string{req.Prompt})
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
embedding64 := make([]float64, len(embedding[0]))
for i, v := range embedding[0] {
embedding64[i] = float64(v)
}
resp := api.EmbeddingResponse{
Embedding: embedding64,
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) PullModelHandler(c *gin.Context) {
@@ -901,7 +1059,8 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/pull", s.PullModelHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy
r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler)
r.POST("/api/copy", s.CopyModelHandler)

View File

@@ -546,8 +546,8 @@ func TestCreateDetectTemplate(t *testing.T) {
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
filepath.Join(p, "blobs", "sha256-9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b7"),
filepath.Join(p, "blobs", "sha256-b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce57cdac893db8139"),
})
})

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"net/http/httptest"
"os"
@@ -272,6 +273,73 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, "library", retrieveResp.OwnedBy)
},
},
{
Name: "Embed Handler Empty Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: "",
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
var embedResp api.EmbedResponse
err = json.Unmarshal(body, &embedResp)
if err != nil {
t.Fatal(err)
}
if embedResp.Model != "t-bone" {
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
}
if embedResp.Embeddings != nil {
t.Fatalf("expected embeddings to be nil, got %v", embedResp.Embeddings)
}
},
},
{
Name: "Embed Handler Invalid Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: 2,
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
_, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
}
},
},
}
t.Setenv("OLLAMA_MODELS", t.TempDir())
@@ -420,3 +488,38 @@ func TestShow(t *testing.T) {
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
}
}
func TestNormalize(t *testing.T) {
type testCase struct {
input []float32
}
testCases := []testCase{
{input: []float32{1}},
{input: []float32{0, 1, 2, 3}},
{input: []float32{0.1, 0.2, 0.3}},
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
{input: []float32{0, 0, 0}},
}
isNormalized := func(vec []float32) (res bool) {
sum := 0.0
for _, v := range vec {
sum += float64(v * v)
}
if math.Abs(sum-1) > 1e-6 {
return sum == 0
} else {
return true
}
}
for _, tc := range testCases {
t.Run("", func(t *testing.T) {
normalized := normalize(tc.input)
if !isNormalized(normalized) {
t.Errorf("Vector %v is not normalized", tc.input)
}
})
}
}

View File

@@ -135,6 +135,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
for {
cpus := s.getCpuFn()
var systemMem gpu.GpuInfo
if len(cpus) > 0 {
systemMem = cpus[0]
}
var runnerToExpire *runnerRef
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
@@ -188,6 +193,38 @@ func (s *Scheduler) processPending(ctx context.Context) {
break
}
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
maxSize := systemMem.FreeMemory
// Add available GPU memory to the total pool
// macOS hardware has unified memory so don't double count
if runtime.GOOS != "darwin" {
for _, gpu := range gpus {
if gpu.Library == "cpu" {
continue
}
if loadedCount == 0 {
// If no other models are loaded, set the limit based on what's available
maxSize += gpu.FreeMemory
} else {
// Other models could be unloaded, favor total memory for limit
maxSize += gpu.TotalMemory
}
}
}
// Block attempting to load a model larger than system memory + GPU memory
if estimate.TotalSize > maxSize {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
// Linux will crash if over-allocating memory - return an error to the user.
// TODO (jmorganca): add reasonable upper limits for darwin and windows as well
if runtime.GOOS == "linux" {
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
break
}
}
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
if len(gpus) == 1 && gpus[0].Library == "cpu" {
// simplifying assumption of defaultParallel when in CPU mode

View File

@@ -642,8 +642,8 @@ type mockLlm struct {
pingResp error
waitResp error
completionResp error
embeddingResp []float64
embeddingRespErr error
embedResp [][]float32
embedRespErr error
tokenizeResp []int
tokenizeRespErr error
detokenizeResp string
@@ -660,8 +660,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
return s.embeddingResp, s.embeddingRespErr
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
return s.embedResp, s.embedRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.tokenizeResp, s.tokenizeRespErr

View File

@@ -1 +1,8 @@
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
{{- if .Messages }}
{{- if .System }}<start_system>{{ .System }}<end_message>
{{- end }}
{{- range .Messages }}<start_{{ .Role }}>{{ .Content }}<end_message>
{{- end }}<start_assistant>
{{- else }}
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
{{- end }}

View File

@@ -1,3 +1,14 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{- end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- else if eq .Role "assistant" }}### Response:
{{- end }}
{{ .Content }}
{{ end }}### Response:
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
@@ -5,4 +16,4 @@
{{ end }}### Response:
{{ .Response }}
{{- end }}

View File

@@ -1,6 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end }}

View File

@@ -1,6 +1,17 @@
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{- else if eq .Role "assistant" }}Assistant:
{{- end }} {{ .Content }}
{{ end }}Assistant:
{{- else }}
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
{{- end }}

View File

@@ -1,10 +1,19 @@
{{ if .System }}Source: system
{{- if .Messages }}
{{- if .System }}Source: system
{{ .System }} <step> {{ end }}Source: user
{{ .System }} <step> {{ end }}
{{- range .Messages }}Source: {{ .Role }}
{{ .Content }} <step> {{ end }}Source: assistant
Destination: user
{{ else }}
{{ if .System }} Source: system
{{ .System }} <step>{{ end }} Source: user
{{ .Prompt }} <step> Source: assistant
{{- if not .Response }}
Destination: user
{{- end }}
{{ .Response }} <step>
{{ .Response }}<step>
{{- end }}

View File

@@ -1,5 +1,13 @@
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User:
{{ .Prompt }}
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{ else if eq .Role "assistant" }}Falcon:
{{ end }}{{ .Content }}
{{ end }}Falcon:
{{ .Response }}
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}
{{- end }}

View File

@@ -1,5 +1,16 @@
{{- if .Messages }}
{{- range $index, $_ := .Messages }}<start_of_turn>
{{- if eq .Role "user" }}user
{{- if and $.System (eq $index 0) }}
{{ $.System }}
{{- end }}
{{- else if eq .Role "assistant" }}model
{{- end }}
{{ .Content }}<end_of_turn>
{{ end }}<start_of_turn>model
{{ else }}
<start_of_turn>user
{{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}<end_of_turn>
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>
{{- end }}

View File

@@ -1,4 +1,18 @@
{{ if .System }}System:
{{- if .Messages }}
{{- if .System }}System:
{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}Question:
{{- else if eq .Role "assistant" }}Answer:
{{- end }}
{{ .Content }}
{{ end }}Answer:
{{ else }}
{{ if .System }}
System:
{{ .System }}
{{ end }}{{ if .Prompt }}Question:
@@ -6,4 +20,4 @@
{{ end }}Answer:
{{ .Response }}
{{- end }}

View File

@@ -1,6 +1,16 @@
[INST] <<SYS>>
{{- if .System }}
{{ .System }}
{{- if .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if eq $index 0 }}<<SYS>>
{{- if $.System }}
{{ $.System }}
{{ end }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }}</s><s>
{{ end }}{{ .Content }}
{{- else }} [/INST] {{ .Content }}</s><s>
{{- end }}
{{- end }} [/INST]
{{- else }}
[INST] <<SYS>>{{ .System }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }}
{{- end }}

View File

@@ -1,7 +1,19 @@
{{- if .Messages }}
{{- if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ .Content }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>
{{ else }}
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>
{{ .Response }}<|eot_id|>
{{- end }}

View File

@@ -1,3 +1,15 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}@@ Instruction
{{- else if eq .Role "assistant" }}@@ Response
{{- end }}
{{ .Content }}
{{ end }}@@ Response
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}@@ Instruction
@@ -5,4 +17,4 @@
{{ end }}@@ Response
{{ .Response }}
{{- end }}

View File

@@ -1,3 +1,9 @@
[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}</s>
{{- if .Messages }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}
{{ end }}{{ .Content }}
{{- else if eq .Role "assistant" }}[/INST] {{ .Content }}</s>
{{- end }}
{{- end }}[/INST]
{{- else }}[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}
{{- end }}

View File

@@ -1 +1,11 @@
{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>{{ end }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
{{- if .Messages }}
{{- if .System }}GPT Correct System: {{ .System }}<|end_of_turn|>
{{- end }}
{{- range .Messages }}GPT Correct
{{- if eq .Role "user" }} User:
{{- else if eq .Role "assistant" }} Assistant:
{{- end }} {{ .Content }}<|end_of_turn|>
{{- end }}GPT Correct Assistant:
{{- else }}
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
{{- end }}

View File

@@ -1,6 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}
{{- range .Messages }}<|{{ .Role }}|>
{{ .Content }}<|end|>
{{ end }}<|assistant|>
{{ else }}
{{ if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}<|end|>
{{ end }}<|assistant|>
{{ .Response }}<|end|>
{{- end }}

View File

@@ -1,3 +1,16 @@
{{- if .Messages }}
{{- if .System }}### System:
{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### User:
{{ .Content }}
{{ else if eq .Role "assistant" }}### Assistant:
{{ .Content }}</s>
{{ end }}
{{ end }}### Assistant:
{{ else }}
{{ if .System }}### System:
{{ .System }}
@@ -5,5 +18,5 @@
{{ .Prompt }}
{{ end }}### Assistant:
{{ .Response }}</s>
{{ .Response }}
{{- end }}

View File

@@ -1,8 +1,24 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction
{{ .Content }}
{{ else if eq .Role "assistant" }}### Response
{{ .Content }}<|endoftext|>
{{ end }}
{{- end }}### Response
{{ else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction
{{ .Prompt }}
{{ end }}### Response
{{ .Response }}<|endoftext|>
{{- end }}

View File

@@ -143,14 +143,11 @@ func (t *Template) Vars() []string {
type Values struct {
Messages []api.Message
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
}
func (t *Template) Execute(w io.Writer, v Values) error {
system, collated := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
if slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": collated,
@@ -160,46 +157,39 @@ func (t *Template) Execute(w io.Writer, v Values) error {
var b bytes.Buffer
var prompt, response string
for i, m := range collated {
switch m.Role {
case "system":
system = m.Content
case "user":
if m.Role == "user" {
prompt = m.Content
case "assistant":
} else {
response = m.Content
}
if i != len(collated)-1 && prompt != "" && response != "" {
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"System": "",
"Prompt": prompt,
"Response": response,
}); err != nil {
return err
}
system = ""
prompt = ""
response = ""
}
}
var cut bool
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
switch t := n.(type) {
case *parse.ActionNode:
case *parse.FieldNode:
if slices.Contains(t.Ident, "Response") {
cut = true
}
tree := t.Template.Copy()
// for the last message, cut everything after "{{ .Response }}"
tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
if slices.Contains(parseNode(n), "Response") {
cut = true
}
return cut
})
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": "",
if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
}); err != nil {
return err
@@ -209,16 +199,25 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err
}
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also collects and returns all system messages.
// collate mutates message content adding image tags ([img-%d]) as needed
func collate(msgs []api.Message) (string, []*api.Message) {
var n int
type messages []*api.Message
var system []string
var collated []*api.Message
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed
func collate(msgs []api.Message) (system string, collated messages) {
var n int
for i := range msgs {
msg := msgs[i]
if msg.Role == "system" {
if system != "" {
system += "\n\n"
}
system += msg.Content
continue
}
for range msg.Images {
imageTag := fmt.Sprintf("[img-%d]", n)
if !strings.Contains(msg.Content, "[img]") {
@@ -229,10 +228,6 @@ func collate(msgs []api.Message) (string, []*api.Message) {
n++
}
if msg.Role == "system" {
system = append(system, msg.Content)
}
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
collated[len(collated)-1].Content += "\n\n" + msg.Content
} else {
@@ -240,7 +235,7 @@ func collate(msgs []api.Message) (string, []*api.Message) {
}
}
return strings.Join(system, "\n\n"), collated
return
}
func parseNode(n parse.Node) []string {
@@ -291,72 +286,3 @@ func parseNode(n parse.Node) []string {
return nil
}
// deleteNode walks the node list and deletes nodes that match the predicate
// this is currently to remove the {{ .Response }} node from templates
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
var walk func(n parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
return nil
}
switch t := n.(type) {
case *parse.ListNode:
var nodes []parse.Node
for _, c := range t.Nodes {
if n := walk(c); n != nil {
nodes = append(nodes, n)
}
}
t.Nodes = nodes
return t
case *parse.IfNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.WithNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.RangeNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.BranchNode:
t.List = walk(t.List).(*parse.ListNode)
if t.ElseList != nil {
t.ElseList = walk(t.ElseList).(*parse.ListNode)
}
case *parse.ActionNode:
n := walk(t.Pipe)
if n == nil {
return nil
}
t.Pipe = n.(*parse.PipeNode)
case *parse.PipeNode:
var commands []*parse.CommandNode
for _, c := range t.Cmds {
var args []parse.Node
for _, a := range c.Args {
if n := walk(a); n != nil {
args = append(args, n)
}
}
if len(args) == 0 {
return nil
}
c.Args = args
commands = append(commands, c)
}
if len(commands) == 0 {
return nil
}
t.Cmds = commands
}
return n
}
return walk(n)
}

View File

@@ -105,8 +105,8 @@ func TestTemplate(t *testing.T) {
}
for n, tt := range cases {
var actual bytes.Buffer
t.Run(n, func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
t.Fatal(err)
}
@@ -116,34 +116,7 @@ func TestTemplate(t *testing.T) {
t.Fatal(err)
}
bts := actual.Bytes()
if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
t.Log("removing trailing space from output")
bts = bts[:len(bts)-1]
}
if diff := cmp.Diff(bts, expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("legacy", func(t *testing.T) {
t.Skip("legacy outputs are currently default outputs")
var legacy bytes.Buffer
if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
t.Fatal(err)
}
legacyBytes := legacy.Bytes()
if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
t.Log("removing trailing space from legacy output")
legacyBytes = legacyBytes[:len(legacyBytes)-1]
} else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
t.Skip("legacy outputs cannot be compared to messages outputs")
}
if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
@@ -162,24 +135,7 @@ func TestParse(t *testing.T) {
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{`{{- range .Messages }}
{{- if eq .Role "system" }}SYSTEM:
{{- else if eq .Role "user" }}USER:
{{- else if eq .Role "assistant" }}ASSISTANT:
{{- end }} {{ .Content }}
{{- end }}`, []string{"content", "messages", "role"}},
{`{{- if .Messages }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else -}}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
}
for _, tt := range cases {
@@ -189,8 +145,9 @@ func TestParse(t *testing.T) {
t.Fatal(err)
}
if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
vars := tmpl.Vars()
if !slices.Equal(tt.vars, vars) {
t.Errorf("expected %v, got %v", tt.vars, vars)
}
})
}
@@ -210,17 +167,12 @@ func TestExecuteWithMessages(t *testing.T) {
{
"mistral",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `[INST] {{ if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
},
Values{
@@ -235,17 +187,13 @@ func TestExecuteWithMessages(t *testing.T) {
{
"mistral system",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `[INST] {{ if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
{{- end }}
{{- end }}`},
},
Values{
@@ -256,9 +204,9 @@ func TestExecuteWithMessages(t *testing.T) {
{Role: "user", Content: "What is your name?"},
},
},
`[INST] You are a helpful assistant!
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
What is your name?[/INST] `,
},
{
"chatml",
@@ -272,9 +220,12 @@ Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
{{ .Response }}<|im_end|>
`},
{"messages", `
{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{- range $index, $_ := .Messages }}
{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
{{ $.System }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>{{ "\n" }}
{{- end }}<|im_start|>assistant
`},
},
Values{
@@ -285,12 +236,12 @@ Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
{Role: "user", Content: "What is your name?"},
},
},
`<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
`<|im_start|>user
Hello friend!<|im_end|>
<|im_start|>assistant
Hello human!<|im_end|>
<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
What is your name?<|im_end|>
<|im_start|>assistant
@@ -307,11 +258,9 @@ What is your name?<|im_end|>
`},
{"messages", `
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}
{{ else if eq .Role "assistant" }}Answer: {{ .Content }}
{{ end }}
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
{{- end }}
{{- end }}Answer: `},
},
Values{
@@ -351,8 +300,8 @@ Answer: `,
t.Fatal(err)
}
if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
if b.String() != tt.expected {
t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
}
})
}

View File

@@ -1,6 +1,4 @@
You are a helpful assistant.
### Instruction:
You are a helpful assistant.### Instruction:
Hello, how are you?
### Response:

View File

@@ -9,4 +9,3 @@ Source: system
I'd like to show off how chat templating works! <step> Source: assistant
Destination: user

View File

@@ -3,4 +3,3 @@ Source: user
Hello, how are you? <step> Source: assistant
Destination: user

View File

@@ -7,4 +7,3 @@ Source: user
I'd like to show off how chat templating works! <step> Source: assistant
Destination: user

View File

@@ -2,6 +2,4 @@
You are a helpful assistant.
<</SYS>>
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] <<SYS>><</SYS>>
I'd like to show off how chat templating works! [/INST]
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]

View File

@@ -1,5 +1,3 @@
[INST] <<SYS>><</SYS>>
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] <<SYS>><</SYS>>
I'd like to show off how chat templating works! [/INST]
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]

View File

@@ -1,3 +1,2 @@
[INST] You are a helpful assistant.
Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works![/INST]
[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] You are a helpful assistant.
I'd like to show off how chat templating works![/INST]

View File

@@ -1 +1 @@
GPT4 Correct System: You are a helpful assistant.<|end_of_turn|>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT4 Correct Assistant:
GPT Correct System: You are a helpful assistant.<|end_of_turn|>GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:

View File

@@ -1 +1 @@
GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant:
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant:

View File

@@ -1 +1 @@
GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT4 Correct Assistant:
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:

View File

@@ -1,4 +1,14 @@
{{ if .System }}{{ .System }}
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}USER: {{ .Content }}
{{ else if eq .Role "assistant" }}ASSISTANT: {{ .Content }}</s>
{{ end }}
{{- end }}ASSISTANT:
{{- else }}
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}USER: {{ .Prompt }}
{{ end }}ASSISTANT: {{ .Response }}</s>
{{ end }}ASSISTANT: {{ .Response }}
{{- end }}

View File

@@ -1,6 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|system|>
{{ .System }}</s>
{{ end }}
{{- range .Messages }}<|{{ .Role }}|>
{{ .Content }}</s>
{{ end }}<|assistant|>
{{ else }}
{{ if .System }}<|system|>
{{ .System }}</s>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}</s>
{{ end }}<|assistant|>
{{ .Response }}</s>
{{- end }}