Compare commits

...

15 Commits

Author SHA1 Message Date
Patrick Devine
3c0d043b79 pass the template to the /api/chat endpoint 2024-07-10 14:17:39 -07:00
Daniel Hiltgen
4cfcbc328f Merge pull request #5124 from dhiltgen/amd_windows
Wire up windows AMD driver reporting
2024-07-10 12:50:23 -07:00
Daniel Hiltgen
79292ff3e0 Merge pull request #5555 from dhiltgen/msvc_deps
Bundle missing CRT libraries
2024-07-10 12:50:02 -07:00
Daniel Hiltgen
8ea500441d Merge pull request #5580 from dhiltgen/cuda_overhead
Detect CUDA OS overhead
2024-07-10 12:47:31 -07:00
Daniel Hiltgen
b50c818623 Merge pull request #5607 from dhiltgen/win_rocm_v6
Bump ROCm on windows to 6.1.2
2024-07-10 12:47:10 -07:00
Daniel Hiltgen
b99e750b62 Merge pull request #5605 from dhiltgen/merge_glitch
Remove duplicate merge glitch
2024-07-10 11:47:08 -07:00
Daniel Hiltgen
1f50356e8e Bump ROCm on windows to 6.1.2
This also adjusts our algorithm to favor our bundled ROCm.
I've confirmed VRAM reporting still doesn't work properly, so we
can't yet enable concurrency by default.
2024-07-10 11:01:22 -07:00
Daniel Hiltgen
22c81f62ec Remove duplicate merge glitch 2024-07-10 09:01:33 -07:00
Daniel Hiltgen
2d1e3c3229 Merge pull request #5503 from dhiltgen/dual_rocm
Workaround broken ROCm p2p copy
2024-07-09 15:44:16 -07:00
royjhan
4918fae535 OpenAI v1/completions: allow stop token list (#5551)
* stop token parsing fix

* add stop test
2024-07-09 14:01:26 -07:00
royjhan
0aff67877e separate request tests (#5578) 2024-07-09 13:48:31 -07:00
Daniel Hiltgen
f6f759fc5f Detect CUDA OS Overhead
This adds logic to detect skew between the driver and
management library, which can be attributed to OS overhead,
and records it so we can adjust subsequent management
library free VRAM updates and avoid OOM scenarios.
2024-07-09 12:21:50 -07:00
Daniel Hiltgen
b44320db13 Bundle missing CRT libraries
Some users are experiencing runner startup errors due
to not having these MSVC redist libraries on their host
2024-07-08 18:24:21 -07:00
Daniel Hiltgen
0bacb30007 Workaround broken ROCm p2p copy
Enable the build flag for llama.cpp to use CPU copy for multi-GPU scenarios.
2024-07-08 09:40:52 -07:00
Daniel Hiltgen
784bf88b0d Wire up windows AMD driver reporting
This seems to be the ROCm version rather than the actual driver version, but
it may be useful for toggling logic for VRAM reporting in the future
2024-06-18 16:22:47 -07:00
18 changed files with 232 additions and 216 deletions

View File

@@ -147,7 +147,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

View File

@@ -169,7 +169,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

View File

@@ -84,6 +84,9 @@ type ChatRequest struct {
// Model is the model name, as in [GenerateRequest].
Model string `json:"model"`
// Template overrides the model's default prompt template.
Template string `json:"template"`
// Messages is the messages of the chat - can be used to keep a chat memory.
Messages []Message `json:"messages"`
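For reference, a minimal sketch of how a caller could exercise the new field through the repo's Go client; the model name and template string are illustrative and not part of this changeset.

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/api"
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }
    req := &api.ChatRequest{
        Model: "llama3", // illustrative model name
        // Template overrides the model's default prompt template for this request only.
        Template: "{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}",
        Messages: []api.Message{{Role: "user", Content: "Hello"}},
    }
    // Chat streams the response; print each chunk as it arrives.
    if err := client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
        fmt.Print(resp.Message.Content)
        return nil
    }); err != nil {
        log.Fatal(err)
    }
}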

View File

@@ -947,6 +947,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req := &api.ChatRequest{
Model: opts.Model,
Template: opts.Template,
Messages: opts.Messages,
Format: opts.Format,
Options: opts.Options,

View File

@@ -18,6 +18,7 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
)
@@ -205,9 +206,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("Set system message.")
sb.Reset()
case MultilineTemplate:
opts.Template = sb.String()
fmt.Println("Set prompt template.")
mTemplate := sb.String()
sb.Reset()
_, err := template.Parse(mTemplate)
if err != nil {
multiline = MultilineNone
scanner.Prompt.UseAlt = false
fmt.Println("The template is invalid.")
continue
}
opts.Template = mTemplate
fmt.Println("Set prompt template.")
}
multiline = MultilineNone
@@ -369,9 +378,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
mTemplate := sb.String()
sb.Reset()
_, err := template.Parse(mTemplate)
if err != nil {
fmt.Println("The template is invalid.")
continue
}
opts.Template = mTemplate
fmt.Println("Set prompt template.")
}
sb.Reset()
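A minimal sketch of the validation the REPL now performs before accepting a /set template value, using the template.Parse call imported above; the template string is illustrative.

package main

import (
    "fmt"

    "github.com/ollama/ollama/template"
)

func main() {
    // The unterminated action makes this template invalid.
    if _, err := template.Parse("{{ .System }} {{ .Prompt"); err != nil {
        fmt.Println("The template is invalid.") // mirrors the interactive error path above
    }
}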

View File

@@ -272,4 +272,4 @@ The following server settings may be used to adjust how Ollama handles concurren
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs' VRAM.
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs' VRAM.
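A minimal sketch, not the server's actual implementation, of reading the two settings documented above with their documented defaults.

package main

import (
    "fmt"
    "os"
    "strconv"
)

func envInt(key string, fallback int) int {
    if v := os.Getenv(key); v != "" {
        if n, err := strconv.Atoi(v); err == nil {
            return n
        }
    }
    return fallback
}

func main() {
    numParallel := envInt("OLLAMA_NUM_PARALLEL", 4) // default auto-selects 4 or 1 based on memory
    maxQueue := envInt("OLLAMA_MAX_QUEUE", 512)
    fmt.Println(numParallel, maxQueue)
}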

View File

@@ -49,9 +49,17 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
}
func commonAMDValidateLibDir() (string, error) {
// We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// This gives users more recovery options if versions have subtle problems at runtime
// Favor our bundled version
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
@@ -87,14 +95,5 @@ func commonAMDValidateLibDir() (string, error) {
}
}
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}

View File

@@ -84,9 +84,8 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
}
slog.Debug("hipDriverGetVersion", "version", version)
// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
driverMajor = version / 1000
driverMinor = (version - (driverMajor * 1000)) / 10
driverMajor = version / 10000000
driverMinor = (version - (driverMajor * 10000000)) / 100000
return driverMajor, driverMinor, nil
}
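A worked example of the new decoding; the input value is illustrative and simply chosen so the divisors above yield a ROCm-style 6.1 result.

package main

import "fmt"

func main() {
    version := 60140093 // illustrative value as hipDriverGetVersion might report it
    driverMajor := version / 10000000                            // 6
    driverMinor := (version - (driverMajor * 10000000)) / 100000 // (60140093 - 60000000) / 100000 = 1
    fmt.Println(driverMajor, driverMinor)                        // 6 1
}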

View File

@@ -22,8 +22,8 @@ const (
var (
// Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
)
func AMDGetGPUInfo() []RocmGPUInfo {
@@ -35,12 +35,11 @@ func AMDGetGPUInfo() []RocmGPUInfo {
}
defer hl.Release()
// TODO - this reports incorrect version information, so omitting for now
// driverMajor, driverMinor, err := hl.AMDDriverVersion()
// if err != nil {
// // For now this is benign, but we may eventually need to fail compatibility checks
// slog.Debug("error looking up amd driver version", "error", err)
// }
driverMajor, driverMinor, err := hl.AMDDriverVersion()
if err != nil {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug("error looking up amd driver version", "error", err)
}
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount()
@@ -132,10 +131,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
MinimumMemory: rocmMinimumMemory,
Name: name,
Compute: gfx,
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
// DriverMajor: driverMajor,
// DriverMinor: driverMinor,
DriverMajor: driverMajor,
DriverMinor: driverMinor,
},
index: i,
}

View File

@@ -274,6 +274,28 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor
// query the management library as well so we can record any skew between the two
// which represents overhead on the GPU we must set aside on subsequent updates
if cHandles.nvml != nil {
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
if memInfo.err != nil {
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
slog.Info("detected OS VRAM overhead",
"id", gpuInfo.ID,
"library", gpuInfo.Library,
"compute", gpuInfo.Compute,
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
"name", gpuInfo.Name,
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
)
}
}
}
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo)
}
@@ -374,9 +396,14 @@ func GetGPUInfo() GpuInfoList {
slog.Warn("error looking up nvidia GPU memory")
continue
}
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
// When using the management library update based on recorded overhead
memInfo.free -= C.uint64_t(gpu.OSOverhead)
}
slog.Debug("updating cuda memory data",
"gpu", gpu.ID,
"name", gpu.Name,
"overhead", format.HumanBytes2(gpu.OSOverhead),
slog.Group(
"before",
"total", format.HumanBytes2(gpu.TotalMemory),

View File

@@ -52,7 +52,8 @@ type CPUInfo struct {
type CudaGPUInfo struct {
GpuInfo
index int //nolint:unused,nolintlint
OSOverhead uint64 // Memory overhead between the driver library and management library
index int //nolint:unused,nolintlint
}
type CudaGPUInfoList []CudaGPUInfo

View File

@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
fi
init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""

View File

@@ -6,18 +6,9 @@ function amdGPUs {
if ($env:AMDGPU_TARGETS) {
return $env:AMDGPU_TARGETS
}
# TODO - load from some common data file for linux + windows build consistency
# Currently supported rocblas list from ROCm v6.1.2 on windows
$GPU_LIST = @(
"gfx900"
"gfx906:xnack-"
"gfx908:xnack-"
"gfx90a:xnack+"
"gfx90a:xnack-"
"gfx940"
"gfx941"
"gfx942"
"gfx1010"
"gfx1012"
"gfx1030"
"gfx1100"
"gfx1101"
@@ -366,6 +357,7 @@ function build_rocm() {
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DGGML_HIPBLAS=on",
"-DLLAMA_CUDA_NO_PEER_COPY=on",
"-DHIP_PLATFORM=amd",
"-DGGML_AVX=on",
"-DGGML_AVX2=off",
@@ -394,7 +386,6 @@ function build_rocm() {
sign
install
# Assumes v5.7, may need adjustments for v6
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"

View File

@@ -254,10 +254,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--tensor-split", estimate.TensorSplit)
}
if estimate.TensorSplit != "" {
params = append(params, "--tensor-split", estimate.TensorSplit)
}
for i := range len(servers) {
dir := availableServers[servers[i]]
if dir == "" {

View File

@@ -338,12 +338,16 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
switch stop := r.Stop.(type) {
case string:
options["stop"] = []string{stop}
case []string:
options["stop"] = stop
default:
if r.Stop != nil {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
case []any:
var stops []string
for _, s := range stop {
if str, ok := s.(string); ok {
stops = append(stops, str)
} else {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
}
}
options["stop"] = stops
}
if r.MaxTokens != nil {
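A minimal sketch of why the new case matches []any rather than []string: decoding JSON into an interface{}-typed field always yields []interface{} for arrays.

package main

import (
    "encoding/json"
    "fmt"
)

func main() {
    var req struct {
        Stop any `json:"stop"`
    }
    if err := json.Unmarshal([]byte(`{"stop": ["\n", "stop"]}`), &req); err != nil {
        panic(err)
    }
    fmt.Printf("%T\n", req.Stop) // []interface {}; a []string case would never match
}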

View File

@@ -3,7 +3,6 @@ package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -16,7 +15,133 @@ import (
"github.com/stretchr/testify/assert"
)
func TestMiddleware(t *testing.T) {
func TestMiddlewareRequests(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
Handler func() gin.HandlerFunc
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *http.Request)
}
var capturedRequest *http.Request
captureRequestMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
capturedRequest = c.Request
c.Next()
}
}
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
Handler: CompletionsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var genReq api.GenerateRequest
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
t.Fatal(err)
}
if genReq.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
}
if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
}
stopTokens, ok := genReq.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(captureRequestMiddleware())
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, endpoint)
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest)
})
}
}
func TestMiddlewareResponses(t *testing.T) {
type testCase struct {
Name string
Method string
@@ -30,159 +155,7 @@ func TestMiddleware(t *testing.T) {
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
TestPath: "/api/chat",
Handler: ChatMiddleware,
Endpoint: func(c *gin.Context) {
var chatReq api.ChatRequest
if err := c.ShouldBindJSON(&chatReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
userMessage := chatReq.Messages[0].Content
var assistantMessage string
switch userMessage {
case "Hello":
assistantMessage = "Hello!"
default:
assistantMessage = "I'm not sure how to respond to that."
}
c.JSON(http.StatusOK, api.ChatResponse{
Message: api.Message{
Role: "assistant",
Content: assistantMessage,
},
})
},
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var chatResp ChatCompletion
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
t.Fatal(err)
}
if chatResp.Object != "chat.completion" {
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
}
if chatResp.Choices[0].Message.Content != "Hello!" {
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.GenerateResponse{
Response: "Hello!",
})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Hello!" {
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with params",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
var generateReq api.GenerateRequest
if err := c.ShouldBindJSON(&generateReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
temperature := generateReq.Options["temperature"].(float64)
var assistantMessage string
switch temperature {
case 1.6:
assistantMessage = "Received temperature of 1.6"
default:
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
}
c.JSON(http.StatusOK, api.GenerateResponse{
Response: assistantMessage,
})
},
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with error",
Name: "completions handler error forwarding",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",

View File

@@ -107,9 +107,12 @@ function gatherDependencies() {
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
# currently works for Win11 + MSVC 2019 + Cuda V11
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
}
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"

View File

@@ -71,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
func (s *Server) scheduleRunner(ctx context.Context, name string, mTemplate string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
}
@@ -81,6 +81,13 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
return nil, nil, nil, err
}
if mTemplate != "" {
model.Template, err = template.Parse(mTemplate)
if err != nil {
return nil, nil, nil, err
}
}
if err := model.CheckCapabilities(caps...); err != nil {
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
@@ -120,7 +127,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
caps := []Capability{CapabilityCompletion}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, "", caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
@@ -256,7 +263,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, "", []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@@ -1132,7 +1139,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
caps := []Capability{CapabilityCompletion}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, req.Template, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return