Compare commits

..

21 Commits

Author SHA1 Message Date
Jeffrey Morgan
8c13cfa4dd ml/backend/ggml: fix crash on windows paths with wide characters (#9305) 2025-02-23 19:13:53 -08:00
Jeffrey Morgan
7cfd4aee4d docs: add additional ROCm docs for building (#9066) 2025-02-22 11:22:59 -08:00
Blake Mizerany
68bac1e0a6 server: group routes by category and purpose (#9270)
The route assembly in Handler lacked clear organization making it
difficult scan for routes and their relationships to each other. This
commit aims to fix that by reordering the assembly of routes to group
them by category and purpose.

Also, be more specific about what "config" refers to (it is about CORS
if you were wondering... I was.)
2025-02-21 21:02:26 -08:00
Jesse Gross
f53f4198c3 ml: Abstract attention out of model definitions
There are two benefits to doing this:
 - Provide a library function that models can use, reducing code for
   each model implementation
 - Enables a single place to drop in optimized implementations of
   attention based on the backend or other factors. One is provided for
   GGML.

On CUDA this improves token generation rate by about 3%. It does not
have a significant effect on Metal.

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
2025-02-21 13:16:21 -08:00
Michael Yang
2192a28eed ml/backend/ggml: fix rms norm 2025-02-21 18:34:19 +00:00
Junyan Qin (Chin)
5d81c1a184 docs: add RockChinQ/LangBot to integrations list (#9272) 2025-02-21 09:36:55 -08:00
Jesse Gross
5c5535c064 models: Prune unused outputs earlier in the forward pass
Currently Rows is called as the last step in a model computation
to get the values for the output tokens. However, if we move it
earlier in the process then we can trim out computations that
never get used. This is similar to how models are defined in
llama.cpp.

Changing the model definition in this way improves token generation
performance by approximately 8%.
2025-02-20 14:49:47 -08:00
Jesse Gross
e5bcc51ae1 ggml-backend: Don't recreate the scheduler for each context
We don't need to create and destroy the GGML scheduler for every
context. This introduces extra CPU overhead for every forward
pass and extra memory for contexts that don't actually get scheduled
(for example, KV caches). We can instead just have one scheduler
for the backend and reset it each time we call Compute.

This improves token generation performance by 1-2% and removes
scheduler create/destroy from profile traces.
2025-02-20 14:49:47 -08:00
Jesse Gross
bd6a7d5e64 ollamarunner: Pass runner performance parameters to backends
Currently the following parameters are in the runner but not used:
 - numGPULayers
 - mainGPU
 - threads
 - tensorSplit

This passes them through to the backend, which is where they would
actually get used. However, the GGML backend does not yet do anything
with them.
2025-02-20 13:27:57 -08:00
Bruce MacDonald
14b5a9a150 api: document client stream behavior with a test (#8996)
Added unit tests to verify error handling behavior in the Client.stream and Client.do methods.
Tests cover various error scenarios including:
- Error responses with status codes >= 400
- Error messages with successful status codes
- Empty error messages
- Successful responses
2025-02-20 13:19:58 -08:00
Michael Yang
ba9ec3d05e ci: use clang for windows cpu builds
clang outputs are faster. we were previously building with clang via gcc
wrapper in cgo but this was missed during the build updates so there was
a drop in performance
2025-02-20 20:22:36 +00:00
frob
7c168b08c9 server: add missing function parens to debug log (#9255) 2025-02-20 12:10:15 -08:00
danielekp
3d4cc7833c docs: Add yla to community integrations 2025-02-20 11:34:24 -08:00
Lucas Hahn
351a85d9ea openai: add 'timeout' to allowable x-stainless headers (#9237) 2025-02-19 21:56:18 -08:00
Michael Yang
bda4ef6c56 reorder patches 2025-02-20 03:49:24 +00:00
Michael Yang
1e438b237c Merge pull request #9203 from ollama/mxyng/sapphirerapids
build: remove backend build for sapphirerapids
2025-02-19 21:42:00 +00:00
yuiseki
d721a02e7d test: add test cases for ListHandler (#9146) 2025-02-19 13:24:27 -08:00
zyxucp
778603a818 docs: Add AntSK to Community Integrations (#9214) 2025-02-19 13:22:48 -08:00
maninhill
3c874df46e docs: Add MaxKB to Community Integrations (#9212) 2025-02-19 13:20:09 -08:00
Jeffrey Morgan
d2eb226c91 llama: add patch to fix ggml backend reg on Linux with utf-8 characters in the path (#9159) 2025-02-18 22:46:17 -05:00
Michael Yang
5f8c03189e build: remove backend build for sapphirerapids
sapphire rapids has amx support but it ends up having a negative
performance impact.

emerald rapids also has amx support with a positive performance impact
however there's no reasonable way in ggml to differentiate between the
two. the impact is small (~6%) so disable amx entirely for simplicity
2025-02-18 14:47:58 -08:00
25 changed files with 1038 additions and 254 deletions

View File

@@ -160,6 +160,10 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:

View File

@@ -382,6 +382,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
### Cloud

View File

@@ -126,14 +126,13 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return err
}
}
return ctx.Err()
return nil
}
const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf *bytes.Buffer
var buf io.Reader
if data != nil {
bts, err := json.Marshal(data)
if err != nil {
@@ -190,7 +189,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
}
}
return ctx.Err()
return nil
}
// GenerateResponseFunc is a function that [Client.Generate] invokes every time

View File

@@ -1,6 +1,13 @@
package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
@@ -43,3 +50,206 @@ func TestClientFromEnvironment(t *testing.T) {
})
}
}
// testError represents an internal error type with status code and message
// this is used since the error response from the server is not a standard error struct
type testError struct {
message string
statusCode int
}
func (e testError) Error() string {
return e.message
}
func TestClientStream(t *testing.T) {
testCases := []struct {
name string
responses []any
wantErr string
}{
{
name: "immediate error response",
responses: []any{
testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
},
wantErr: "test error message",
},
{
name: "error after successful chunks, ok response",
responses: []any{
ChatResponse{Message: Message{Content: "partial response 1"}},
ChatResponse{Message: Message{Content: "partial response 2"}},
testError{
message: "mid-stream error",
statusCode: http.StatusOK,
},
},
wantErr: "mid-stream error",
},
{
name: "successful stream completion",
responses: []any{
ChatResponse{Message: Message{Content: "chunk 1"}},
ChatResponse{Message: Message{Content: "chunk 2"}},
ChatResponse{
Message: Message{Content: "final chunk"},
Done: true,
DoneReason: "stop",
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
t.Fatal("expected http.Flusher")
}
w.Header().Set("Content-Type", "application/x-ndjson")
for _, resp := range tc.responses {
if errResp, ok := resp.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
flusher.Flush()
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var receivedChunks []ChatResponse
err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
var resp ChatResponse
if err := json.Unmarshal(chunk, &resp); err != nil {
return fmt.Errorf("failed to unmarshal chunk: %w", err)
}
receivedChunks = append(receivedChunks, resp)
return nil
})
if tc.wantErr != "" {
if err == nil {
t.Fatal("expected error but got nil")
}
if !strings.Contains(err.Error(), tc.wantErr) {
t.Errorf("expected error containing %q, got %v", tc.wantErr, err)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
}{
{
name: "immediate error response",
response: testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
},
{
name: "server error response",
response: testError{
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
},
{
name: "successful response",
response: struct {
ID string `json:"id"`
Success bool `json:"success"`
}{
ID: "msg_123",
Success: true,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(tc.response); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var resp struct {
ID string `json:"id"`
Success bool `json:"success"`
}
err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
if tc.wantErr != "" {
if err == nil {
t.Fatalf("got nil, want error %q", tc.wantErr)
}
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
return
}
if err != nil {
t.Fatalf("got error %q, want nil", err)
}
if expectedResp, ok := tc.response.(struct {
ID string `json:"id"`
Success bool `json:"success"`
}); ok {
if resp.ID != expectedResp.ID {
t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
}
if resp.Success != expectedResp.Success {
t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
}
}
})
}
}

View File

@@ -15,11 +15,13 @@ import (
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync/atomic"
"syscall"
"time"
"github.com/containerd/console"
@@ -328,7 +330,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err
}
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
}
return info, err
@@ -857,6 +858,17 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
spinner := progress.NewSpinner("")
p.Add("", spinner)
cancelCtx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
var latest api.ChatResponse
var fullResponse strings.Builder
@@ -891,7 +903,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cmd.Context(), req, fn); err != nil {
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
return nil, err
}
@@ -931,6 +946,17 @@ func generate(cmd *cobra.Command, opts runOptions) error {
generateContext = []int{}
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
fn := func(response api.GenerateResponse) error {
@@ -966,7 +992,10 @@ func generate(cmd *cobra.Command, opts runOptions) error {
KeepAlive: opts.KeepAlive,
}
if err := client.Generate(cmd.Context(), &request, fn); err != nil {
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
return err
}
@@ -988,7 +1017,8 @@ func generate(cmd *cobra.Command, opts runOptions) error {
latest.Summary()
}
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
return nil
}

View File

@@ -10,6 +10,7 @@ import (
"os"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/spf13/cobra"
@@ -490,6 +491,96 @@ func TestPushHandler(t *testing.T) {
}
}
func TestListHandler(t *testing.T) {
tests := []struct {
name string
args []string
serverResponse []api.ListModelResponse
expectedError string
expectedOutput string
}{
{
name: "list all models",
args: []string{},
serverResponse: []api.ListModelResponse{
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)},
},
expectedOutput: "NAME ID SIZE MODIFIED \n" +
"model1 sha256:abc12 1.0 KB 24 hours ago \n" +
"model2 sha256:def45 2.0 KB 2 days ago \n",
},
{
name: "filter models by prefix",
args: []string{"model1"},
serverResponse: []api.ListModelResponse{
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)},
},
expectedOutput: "NAME ID SIZE MODIFIED \n" +
"model1 sha256:abc12 1.0 KB 24 hours ago \n",
},
{
name: "server error",
args: []string{},
expectedError: "server error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/tags" || r.Method != http.MethodGet {
t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path)
http.Error(w, "not found", http.StatusNotFound)
return
}
if tt.expectedError != "" {
http.Error(w, tt.expectedError, http.StatusInternalServerError)
return
}
response := api.ListResponse{Models: tt.serverResponse}
if err := json.NewEncoder(w).Encode(response); err != nil {
t.Fatal(err)
}
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.SetContext(context.TODO())
// Capture stdout
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
err := ListHandler(cmd, tt.args)
// Restore stdout and get output
w.Close()
os.Stdout = oldStdout
output, _ := io.ReadAll(r)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if got := string(output); got != tt.expectedOutput {
t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
}
})
}
}
func TestCreateHandler(t *testing.T) {
tests := []struct {
name string

View File

@@ -46,15 +46,6 @@ Install prerequisites:
- (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
> [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake.
> [!IMPORTANT]
> ROCm is not compatible with Visual Studio CMake generators. Use `-GNinja` when configuring the project.
> [!IMPORTANT]
> CUDA is only compatible with Visual Studio CMake generators.
Then, configure and build the project:
```shell
@@ -62,6 +53,14 @@ cmake -B build
cmake --build build --config Release
```
> [!IMPORTANT]
> Building for ROCm requires additional flags:
> ```
> cmake -B build -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
> cmake --build build --config Release
> ```
Lastly, run Ollama:
```shell

View File

@@ -53,8 +53,8 @@ func Host() *url.URL {
}
}
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) {
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
func AllowedOrigins() (origins []string) {
if s := Var("OLLAMA_ORIGINS"); s != "" {
origins = strings.Split(s, ",")
}
@@ -249,7 +249,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},

View File

@@ -134,7 +134,7 @@ func TestOrigins(t *testing.T) {
t.Run(tt.value, func(t *testing.T) {
t.Setenv("OLLAMA_ORIGINS", tt.value)
if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" {
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
}
})

View File

@@ -0,0 +1,315 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Sun, 16 Feb 2025 20:00:22 -0500
Subject: [PATCH] use std::filesystem::path instead of wstring
---
ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++--------------------
1 file changed, 58 insertions(+), 86 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 84b21dd8..e35a6936 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -66,26 +66,6 @@
#include "ggml-kompute.h"
#endif
-// disable C++17 deprecation warning for std::codecvt_utf8
-#if defined(__clang__)
-# pragma clang diagnostic push
-# pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#endif
-
-static std::wstring utf8_to_utf16(const std::string & str) {
- std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
- return converter.from_bytes(str);
-}
-
-static std::string utf16_to_utf8(const std::wstring & str) {
- std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
- return converter.to_bytes(str);
-}
-
-#if defined(__clang__)
-# pragma clang diagnostic pop
-#endif
-
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
@@ -96,7 +76,7 @@ struct dl_handle_deleter {
}
};
-static dl_handle * dl_load_library(const std::wstring & path) {
+static dl_handle * dl_load_library(const std::filesystem::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
@@ -129,8 +109,8 @@ struct dl_handle_deleter {
}
};
-static void * dl_load_library(const std::wstring & path) {
- dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
+static void * dl_load_library(const std::filesystem::path & path) {
+ dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
@@ -141,6 +121,25 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
#endif
+static std::string path_to_string(const std::filesystem::path & path)
+{
+#ifdef _WIN32
+ const std::wstring wstr = path.wstring();
+ const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr);
+ if (size_needed <= 0) {
+ return std::string();
+ }
+
+ // size_needed includes the null terminator
+ std::string str(size_needed - 1, '\0');
+ WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr);
+ return str;
+#else
+ return path.string();
+#endif
+}
+
+
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
struct ggml_backend_reg_entry {
@@ -222,11 +221,11 @@ struct ggml_backend_registry {
);
}
- ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
+ ggml_backend_reg_t load_backend(const std::filesystem::path & path, bool silent) {
dl_handle_ptr handle { dl_load_library(path) };
if (!handle) {
if (!silent) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -234,7 +233,7 @@ struct ggml_backend_registry {
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (score_fn && score_fn() == 0) {
if (!silent) {
- GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
+ GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -242,7 +241,7 @@ struct ggml_backend_registry {
auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
if (!backend_init_fn) {
if (!silent) {
- GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
+ GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -251,16 +250,16 @@ struct ggml_backend_registry {
if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
if (!silent) {
if (!reg) {
- GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
+ GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str());
} else {
GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
- __func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
+ __func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
}
}
return nullptr;
}
- GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
+ GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str());
register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
@@ -396,14 +395,14 @@ ggml_backend_t ggml_backend_init_best(void) {
// Dynamic loading
ggml_backend_reg_t ggml_backend_load(const char * path) {
- return get_reg().load_backend(utf8_to_utf16(path), false);
+ return get_reg().load_backend(path, false);
}
void ggml_backend_unload(ggml_backend_reg_t reg) {
get_reg().unload_backend(reg, true);
}
-static std::wstring get_executable_path() {
+static std::filesystem::path get_executable_path() {
#if defined(__APPLE__)
// get executable path
std::vector<char> path;
@@ -415,15 +414,9 @@ static std::wstring get_executable_path() {
}
path.resize(size);
}
- std::string base_path(path.data(), size);
- // remove executable name
- auto last_slash = base_path.find_last_of('/');
- if (last_slash != std::string::npos) {
- base_path = base_path.substr(0, last_slash);
- }
- return utf8_to_utf16(base_path + "/");
+
+ return std::filesystem::path(path.data()).parent_path();
#elif defined(__linux__) || defined(__FreeBSD__)
- std::string base_path = ".";
std::vector<char> path(1024);
while (true) {
// get executable path
@@ -436,76 +429,55 @@ static std::wstring get_executable_path() {
break;
}
if (len < (ssize_t) path.size()) {
- base_path = std::string(path.data(), len);
- // remove executable name
- auto last_slash = base_path.find_last_of('/');
- if (last_slash != std::string::npos) {
- base_path = base_path.substr(0, last_slash);
- }
- break;
+ return std::filesystem::path(path.data()).parent_path();
}
path.resize(path.size() * 2);
}
-
- return utf8_to_utf16(base_path + "/");
#elif defined(_WIN32)
std::vector<wchar_t> path(MAX_PATH);
DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
if (len == 0) {
return {};
}
- std::wstring base_path(path.data(), len);
- // remove executable name
- auto last_slash = base_path.find_last_of('\\');
- if (last_slash != std::string::npos) {
- base_path = base_path.substr(0, last_slash);
- }
- return base_path + L"\\";
-#else
- return {};
-#endif
-}
-static std::wstring backend_filename_prefix() {
-#ifdef _WIN32
- return L"ggml-";
-#else
- return L"libggml-";
+ return std::filesystem::path(path.data()).parent_path();
#endif
+ return {};
}
-static std::wstring backend_filename_suffix() {
+static std::string backend_filename_prefix() {
#ifdef _WIN32
- return L".dll";
+ return "ggml-";
#else
- return L".so";
+ return "libggml-";
#endif
}
-static std::wstring path_separator() {
+static std::string backend_filename_suffix() {
#ifdef _WIN32
- return L"\\";
+ return ".dll";
#else
- return L"/";
+ return ".so";
#endif
}
static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
// enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
// TODO: search system paths
- std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-";
- std::vector<std::wstring> search_paths;
+ namespace fs = std::filesystem;
+ std::string file_prefix = backend_filename_prefix() + name + "-";
+ std::vector<fs::path> search_paths;
+
if (user_search_path == nullptr) {
- search_paths.push_back(L"." + path_separator());
+ search_paths.push_back(fs::current_path());
search_paths.push_back(get_executable_path());
} else {
- search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator());
+ search_paths.push_back(fs::u8path(user_search_path));
}
int best_score = 0;
- std::wstring best_path;
+ fs::path best_path;
- namespace fs = std::filesystem;
for (const auto & search_path : search_paths) {
if (!fs::exists(search_path)) {
continue;
@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
for (const auto & entry : dir_it) {
try {
if (entry.is_regular_file()) {
- std::wstring filename = entry.path().filename().wstring();
- std::wstring ext = entry.path().extension().wstring();
+ std::string filename = entry.path().filename().string();
+ std::string ext = entry.path().extension().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) {
- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
int s = score_fn();
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
if (s > best_score) {
best_score = s;
- best_path = entry.path().wstring();
+ best_path = entry.path();
}
}
}
} catch (const std::exception & e) {
- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
}
}
}
@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
if (best_score == 0) {
// try to load the base backend
for (const auto & search_path : search_paths) {
- std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix();
+ fs::path path = fs::path(search_path) / (backend_filename_prefix() + name + backend_filename_suffix());
if (fs::exists(path)) {
return get_reg().load_backend(path, silent);
}

View File

@@ -0,0 +1,24 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Tue, 18 Feb 2025 14:47:21 -0800
Subject: [PATCH] remove amx
---
ggml/src/CMakeLists.txt | 4 ----
1 file changed, 4 deletions(-)
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 72b488dd..50828717 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -293,10 +293,6 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
- if (NOT MSVC)
- # MSVC doesn't support AMX
- ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
- endif()
else ()
ggml_add_cpu_backend_variant_impl("")
endif()

14
main.go
View File

@@ -2,8 +2,6 @@ package main
import (
"context"
"os"
"os/signal"
"github.com/spf13/cobra"
@@ -11,15 +9,5 @@ import (
)
func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
go func() {
<-sigChan
cancel()
}()
cobra.CheckErr(cmd.NewCLI().ExecuteContext(ctx))
cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
}

View File

@@ -26,9 +26,24 @@ type Backend interface {
SystemInfo() string
}
var backends = make(map[string]func(*os.File) (Backend, error))
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
// MainGPU is the index of the primary GPU to use
MainGPU int
// NumGPULayers is the number of layers to offload to GPUs
NumGPULayers int
// TensorSplit is the fraction of the model to offload to each GPU
TensorSplit []float32
}
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@@ -36,9 +51,9 @@ func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
backends[name] = f
}
func NewBackend(f *os.File) (Backend, error) {
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(f)
return backend(f, params)
}
return nil, fmt.Errorf("unsupported backend")
@@ -96,6 +111,26 @@ type Tensor interface {
Copy(ctx Context, t2 Tensor) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}
type number interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |

View File

@@ -82,9 +82,11 @@ type Backend struct {
meta *fs.GGML
cpus, gpus []Context
tensors map[string]*Context
sched *C.struct_ggml_backend_sched
}
func New(r *os.File) (ml.Backend, error) {
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1)
if err != nil {
return nil, err
@@ -182,10 +184,24 @@ func New(r *os.File) (ml.Backend, error) {
return nil, err
}
backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus))
bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus))
for i, c := range append(gpus, cpus...) {
backends[i] = c.backend
bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
}
return &Backend{
meta: meta,
cpus: cpus,
gpus: gpus,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
C.int(len(backends)),
C.size_t(max(8192, len(meta.Tensors().Items())*5)),
true,
),
}, nil
}
@@ -219,31 +235,23 @@ func (b *Backend) NewContext() ml.Context {
})
backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
bufts := make([]*C.struct_ggml_backend_buffer_type, len(b.gpus)+len(b.cpus))
for i, c := range append(b.gpus, b.cpus...) {
backends[i] = c.backend
bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
}
return &Context{
b: b,
ctx: c,
backend: backends[0],
nodes: nodes,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
C.int(len(backends)),
C.size_t(nodes),
true,
),
}
}
type Context struct {
b *Backend
ctx *C.struct_ggml_context
backend *C.struct_ggml_backend
sched *C.struct_ggml_backend_sched
graph *C.struct_ggml_cgraph
nodes int
}
@@ -257,12 +265,13 @@ func (c *Context) Forward(t ml.Tensor) {
}
func (c *Context) Compute(tensors ...ml.Tensor) {
C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
C.ggml_backend_sched_reset(c.b.sched)
needSync := true
sync := func() {
if needSync {
C.ggml_backend_sched_synchronize(c.sched)
C.ggml_backend_sched_synchronize(c.b.sched)
needSync = false
}
}
@@ -350,7 +359,6 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
func (c *Context) Close() {
if c != nil {
C.ggml_backend_sched_free(c.sched)
C.ggml_free(c.ctx)
}
}
@@ -477,7 +485,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
}
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
}
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -643,6 +651,21 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {
kqMask = mask.(*Tensor).t
}
kq := key.MulmatFullPrec(ctx, t)
kq = &Tensor{
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
func (b *Backend) SystemInfo() string {
var compiler string
switch C.get_compiler() {

View File

@@ -293,10 +293,6 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
if (NOT MSVC)
# MSVC doesn't support AMX
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
endif()
else ()
ggml_add_cpu_backend_variant_impl("")
endif()

View File

@@ -66,26 +66,6 @@
#include "ggml-kompute.h"
#endif
// disable C++17 deprecation warning for std::codecvt_utf8
#if defined(__clang__)
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#endif
static std::wstring utf8_to_utf16(const std::string & str) {
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.from_bytes(str);
}
static std::string utf16_to_utf8(const std::wstring & str) {
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
return converter.to_bytes(str);
}
#if defined(__clang__)
# pragma clang diagnostic pop
#endif
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
@@ -96,7 +76,7 @@ struct dl_handle_deleter {
}
};
static dl_handle * dl_load_library(const std::wstring & path) {
static dl_handle * dl_load_library(const std::filesystem::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
@@ -129,8 +109,8 @@ struct dl_handle_deleter {
}
};
static void * dl_load_library(const std::wstring & path) {
dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
static void * dl_load_library(const std::filesystem::path & path) {
dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
@@ -141,6 +121,25 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
#endif
static std::string path_to_string(const std::filesystem::path & path)
{
#ifdef _WIN32
const std::wstring wstr = path.wstring();
const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr);
if (size_needed <= 0) {
return std::string();
}
// size_needed includes the null terminator
std::string str(size_needed - 1, '\0');
WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr);
return str;
#else
return path.string();
#endif
}
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
struct ggml_backend_reg_entry {
@@ -222,11 +221,11 @@ struct ggml_backend_registry {
);
}
ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
ggml_backend_reg_t load_backend(const std::filesystem::path & path, bool silent) {
dl_handle_ptr handle { dl_load_library(path) };
if (!handle) {
if (!silent) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -234,7 +233,7 @@ struct ggml_backend_registry {
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (score_fn && score_fn() == 0) {
if (!silent) {
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -242,7 +241,7 @@ struct ggml_backend_registry {
auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
if (!backend_init_fn) {
if (!silent) {
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -251,16 +250,16 @@ struct ggml_backend_registry {
if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
if (!silent) {
if (!reg) {
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str());
} else {
GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
__func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
__func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
}
}
return nullptr;
}
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str());
register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
@@ -396,14 +395,14 @@ ggml_backend_t ggml_backend_init_best(void) {
// Dynamic loading
ggml_backend_reg_t ggml_backend_load(const char * path) {
return get_reg().load_backend(utf8_to_utf16(path), false);
return get_reg().load_backend(path, false);
}
void ggml_backend_unload(ggml_backend_reg_t reg) {
get_reg().unload_backend(reg, true);
}
static std::wstring get_executable_path() {
static std::filesystem::path get_executable_path() {
#if defined(__APPLE__)
// get executable path
std::vector<char> path;
@@ -415,15 +414,9 @@ static std::wstring get_executable_path() {
}
path.resize(size);
}
std::string base_path(path.data(), size);
// remove executable name
auto last_slash = base_path.find_last_of('/');
if (last_slash != std::string::npos) {
base_path = base_path.substr(0, last_slash);
}
return utf8_to_utf16(base_path + "/");
return std::filesystem::path(path.data()).parent_path();
#elif defined(__linux__) || defined(__FreeBSD__)
std::string base_path = ".";
std::vector<char> path(1024);
while (true) {
// get executable path
@@ -436,76 +429,55 @@ static std::wstring get_executable_path() {
break;
}
if (len < (ssize_t) path.size()) {
base_path = std::string(path.data(), len);
// remove executable name
auto last_slash = base_path.find_last_of('/');
if (last_slash != std::string::npos) {
base_path = base_path.substr(0, last_slash);
}
break;
return std::filesystem::path(path.data()).parent_path();
}
path.resize(path.size() * 2);
}
return utf8_to_utf16(base_path + "/");
#elif defined(_WIN32)
std::vector<wchar_t> path(MAX_PATH);
DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
if (len == 0) {
return {};
}
std::wstring base_path(path.data(), len);
// remove executable name
auto last_slash = base_path.find_last_of('\\');
if (last_slash != std::string::npos) {
base_path = base_path.substr(0, last_slash);
}
return base_path + L"\\";
#else
return std::filesystem::path(path.data()).parent_path();
#endif
return {};
}
static std::string backend_filename_prefix() {
#ifdef _WIN32
return "ggml-";
#else
return "libggml-";
#endif
}
static std::wstring backend_filename_prefix() {
static std::string backend_filename_suffix() {
#ifdef _WIN32
return L"ggml-";
return ".dll";
#else
return L"libggml-";
#endif
}
static std::wstring backend_filename_suffix() {
#ifdef _WIN32
return L".dll";
#else
return L".so";
#endif
}
static std::wstring path_separator() {
#ifdef _WIN32
return L"\\";
#else
return L"/";
return ".so";
#endif
}
static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
// enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
// TODO: search system paths
std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-";
std::vector<std::wstring> search_paths;
namespace fs = std::filesystem;
std::string file_prefix = backend_filename_prefix() + name + "-";
std::vector<fs::path> search_paths;
if (user_search_path == nullptr) {
search_paths.push_back(L"." + path_separator());
search_paths.push_back(fs::current_path());
search_paths.push_back(get_executable_path());
} else {
search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator());
search_paths.push_back(fs::u8path(user_search_path));
}
int best_score = 0;
std::wstring best_path;
fs::path best_path;
namespace fs = std::filesystem;
for (const auto & search_path : search_paths) {
if (!fs::exists(search_path)) {
continue;
@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
for (const auto & entry : dir_it) {
try {
if (entry.is_regular_file()) {
std::wstring filename = entry.path().filename().wstring();
std::wstring ext = entry.path().extension().wstring();
std::string filename = entry.path().filename().string();
std::string ext = entry.path().extension().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) {
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
int s = score_fn();
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
if (s > best_score) {
best_score = s;
best_path = entry.path().wstring();
best_path = entry.path();
}
}
}
} catch (const std::exception & e) {
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
}
}
}
@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
if (best_score == 0) {
// try to load the base backend
for (const auto & search_path : search_paths) {
std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix();
fs::path path = fs::path(search_path) / (backend_filename_prefix() + name + backend_filename_suffix());
if (fs::exists(path)) {
return get_reg().load_backend(path, silent);
}

59
ml/nn/attention.go Normal file
View File

@@ -0,0 +1,59 @@
package nn
import (
"fmt"
"github.com/ollama/ollama/ml"
)
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
// - mask: Optional attention mask that is added to the attention score. If
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}
if mask != nil && query.Dim(1) != mask.Dim(1) {
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
}
if key.Dim(1) != value.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
}
if mask != nil && key.Dim(1) != mask.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
} else {
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}

View File

@@ -70,14 +70,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string) (Model, error) {
func New(modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(r)
b, err := ml.NewBackend(r, params)
if err != nil {
return nil, err
}

View File

@@ -86,13 +86,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.MulmatFullPrec(ctx, q)
kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)
@@ -120,11 +115,19 @@ type Layer struct {
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
@@ -144,22 +147,26 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return hiddenState.Rows(ctx, outputs), nil
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {

View File

@@ -93,15 +93,13 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
return nil, err
}
// TODO: attention mask, cross attention mask
hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return hiddenState.Rows(ctx, outputs), nil
// TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {

View File

@@ -38,13 +38,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
@@ -74,11 +69,19 @@ type TextSelfAttentionDecoderLayer struct {
MLP *TextMLP
}
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
@@ -104,7 +107,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
var key, value ml.Tensor
var key, value, mask ml.Tensor
if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
@@ -117,19 +120,15 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
cache.Put(ctx, key, value)
} else {
key, value, _ = cache.Get(ctx)
key, value, mask = cache.Get(ctx)
}
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.Mulmat(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention)
@@ -145,7 +144,7 @@ type TextCrossAttentionDecoderLayer struct {
MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
}
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -161,14 +160,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
}
type TextDecoderLayer interface {
Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
}
type TextDecoder struct {
Layers []TextDecoderLayer
}
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
for i, layer := range d.Layers {
layerType := selfAttentionLayer
if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
@@ -179,7 +178,12 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
cache.SetLayerType(layerType)
if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
var lastLayerOutputs ml.Tensor
if i == len(d.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
}
}
@@ -205,9 +209,9 @@ type TextModel struct {
*TextModelOptions
}
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}

View File

@@ -49,29 +49,29 @@ func (p *Progress) stop() bool {
func (p *Progress) Stop() bool {
stopped := p.stop()
if stopped {
fmt.Fprintln(p.w)
fmt.Fprint(p.w, "\n")
p.w.Flush()
}
// show cursor
fmt.Fprint(p.w, "\033[?25h")
p.w.Flush()
return stopped
}
func (p *Progress) StopAndClear() bool {
defer p.w.Flush()
fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")
stopped := p.stop()
if stopped {
// clear all progress lines
for range p.pos - 1 {
fmt.Fprint(p.w, "\033[A")
for i := range p.pos {
if i > 0 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[2K\033[1G")
}
fmt.Fprint(p.w, "\033[2K", "\033[1G")
}
// show cursor
fmt.Fprint(p.w, "\033[?25h")
p.w.Flush()
return stopped
}
@@ -86,13 +86,19 @@ func (p *Progress) render() {
p.mu.Lock()
defer p.mu.Unlock()
defer p.w.Flush()
// eliminate flickering on terminals that support synchronized output
fmt.Fprint(p.w, "\033[?2026h")
defer fmt.Fprint(p.w, "\033[?2026l")
fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")
// move the cursor back to the beginning
for range p.pos - 1 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[1G")
// render progress lines
@@ -104,13 +110,10 @@ func (p *Progress) render() {
}
p.pos = len(p.states)
p.w.Flush()
}
func (p *Progress) start() {
p.ticker = time.NewTicker(100 * time.Millisecond)
// hide cursor
fmt.Fprint(p.w, "\033[?25l")
for range p.ticker.C {
p.render()
}

View File

@@ -25,6 +25,7 @@ import (
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
@@ -801,6 +802,7 @@ func (m *multiLPath) String() string {
func (s *Server) loadModel(
mpath string,
params ml.BackendParams,
lpath multiLPath,
parallel int,
kvCacheType string,
@@ -808,12 +810,12 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(mpath)
s.model, err = model.New(mpath, params)
if err != nil {
panic(err)
}
slog.Info("system", "info", s.model.Backend().SystemInfo() /* "threads", *threads */)
slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads)
// TODO(jessegross): LoRA loading
if lpath.String() != "" {
@@ -843,17 +845,17 @@ func Execute(args []string) error {
mpath := fs.String("model", "", "Path to model binary file")
parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
batchSize := fs.Int("batch-size", 512, "Batch size")
_ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
_ = fs.Int("main-gpu", 0, "Main GPU")
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
_ = fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on")
_ = fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
_ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
_ = fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
var lpaths multiLPath
@@ -890,15 +892,11 @@ func Execute(args []string) error {
}
// TODO(jessegross): Parameters that need to be implemented:
// n-gpu-layers
// main-gpu
// flash-attn
// threads
// no-mmap
// mlock
// tensor-split
/*var tensorSplitFloats []float32
var tensorSplitFloats []float32
if *tensorSplit != "" {
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
@@ -907,10 +905,17 @@ func Execute(args []string) error {
f, _ := strconv.ParseFloat(s, 32)
tensorSplitFloats = append(tensorSplitFloats, float32(f))
}
}*/
}
params := ml.BackendParams{
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
}
server.ready.Add(1)
go server.loadModel(*mpath, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)

View File

@@ -1127,54 +1127,72 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}
func (s *Server) GenerateRoutes() http.Handler {
config := cors.DefaultConfig()
config.AllowWildcard = true
config.AllowBrowserExtensions = true
config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval"}
for _, prop := range openAIProperties {
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
corsConfig := cors.DefaultConfig()
corsConfig.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true
corsConfig.AllowHeaders = []string{
"Authorization",
"Content-Type",
"User-Agent",
"Accept",
"X-Requested-With",
// OpenAI compatibility headers
"x-stainless-lang",
"x-stainless-package-version",
"x-stainless-os",
"x-stainless-arch",
"x-stainless-retry-count",
"x-stainless-runtime",
"x-stainless-runtime-version",
"x-stainless-async",
"x-stainless-helper-method",
"x-stainless-poll-helper",
"x-stainless-custom-poll-interval",
"x-stainless-timeout",
}
config.AllowOrigins = envconfig.Origins()
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
r := gin.Default()
r.Use(
cors.New(config),
cors.New(corsConfig),
allowedHostsMiddleware(s.addr),
)
// General
r.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
// Local model cache management
r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.HEAD("/api/tags", s.ListHandler)
r.GET("/api/tags", s.ListHandler)
r.POST("/api/show", s.ShowHandler)
// Create
r.POST("/api/create", s.CreateHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateHandler)
r.POST("/api/push", s.PushHandler)
r.POST("/api/copy", s.CopyHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.POST("/api/show", s.ShowHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.GET("/api/ps", s.PsHandler)
// Compatibility endpoints
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.Handle(method, "/api/tags", s.ListHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
}
return r
}

View File

@@ -179,7 +179,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
if allReliable {
// HACK
os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus)))
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners(), "gpu_count", len(gpus))
} else {
// HACK
os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus)))