Mirror of https://github.com/ollama/ollama.git (synced 2026-01-02 04:29:51 -05:00)

Compare commits: 2 commits, v0.5.12...progress-f
| Author | SHA1 | Date |
|---|---|---|
|  | fcfbb06f1b |  |
|  | e8d35d0de0 |  |
4 .github/workflows/release.yaml vendored
@@ -160,10 +160,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
@@ -382,10 +382,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal web app to run Ollama models with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & adaptable RAG chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bot platform with agents, RAG features, and multi-platform support)

### Cloud
@@ -126,13 +126,14 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return err
}
}
return nil

return ctx.Err()
}

const maxBufferSize = 512 * format.KiloByte

func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader
var buf *bytes.Buffer
if data != nil {
bts, err := json.Marshal(data)
if err != nil {
@@ -189,7 +190,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
}
}

return nil
return ctx.Err()
}

// GenerateResponseFunc is a function that [Client.Generate] invokes every time
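The hunks above replace the trailing `return nil` in both `do` and `stream` with `return ctx.Err()`, so a response loop that stops because the caller's context was cancelled reports that cancellation instead of silently returning success. A minimal sketch of the idea under that assumption (hypothetical helper, not the actual `api.Client` code):

```go
package sketch

import (
	"bufio"
	"context"
	"io"
)

// readLines drains r line by line, handing each line to fn.
// If the loop ends because ctx was cancelled, ctx.Err()
// (context.Canceled or context.DeadlineExceeded) is returned
// instead of nil, mirroring the change to do/stream above.
func readLines(ctx context.Context, r io.Reader, fn func([]byte) error) error {
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		if err := ctx.Err(); err != nil {
			return err // stop early: the caller gave up
		}
		if err := fn(scanner.Bytes()); err != nil {
			return err
		}
	}
	if err := scanner.Err(); err != nil {
		return err
	}
	return ctx.Err() // nil on normal completion, the cancellation cause otherwise
}
```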
@@ -1,13 +1,6 @@
package api

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
@@ -50,206 +43,3 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testError represents an internal error type with status code and message
|
||||
// this is used since the error response from the server is not a standard error struct
|
||||
type testError struct {
|
||||
message string
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (e testError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func TestClientStream(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
responses []any
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "immediate error response",
|
||||
responses: []any{
|
||||
testError{
|
||||
message: "test error message",
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
},
|
||||
wantErr: "test error message",
|
||||
},
|
||||
{
|
||||
name: "error after successful chunks, ok response",
|
||||
responses: []any{
|
||||
ChatResponse{Message: Message{Content: "partial response 1"}},
|
||||
ChatResponse{Message: Message{Content: "partial response 2"}},
|
||||
testError{
|
||||
message: "mid-stream error",
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
wantErr: "mid-stream error",
|
||||
},
|
||||
{
|
||||
name: "successful stream completion",
|
||||
responses: []any{
|
||||
ChatResponse{Message: Message{Content: "chunk 1"}},
|
||||
ChatResponse{Message: Message{Content: "chunk 2"}},
|
||||
ChatResponse{
|
||||
Message: Message{Content: "final chunk"},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatal("expected http.Flusher")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
|
||||
for _, resp := range tc.responses {
|
||||
if errResp, ok := resp.(testError); ok {
|
||||
w.WriteHeader(errResp.statusCode)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": errResp.message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode error response:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("failed to encode response: %v", err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
|
||||
|
||||
var receivedChunks []ChatResponse
|
||||
err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
|
||||
var resp ChatResponse
|
||||
if err := json.Unmarshal(chunk, &resp); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal chunk: %w", err)
|
||||
}
|
||||
receivedChunks = append(receivedChunks, resp)
|
||||
return nil
|
||||
})
|
||||
|
||||
if tc.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.wantErr) {
|
||||
t.Errorf("expected error containing %q, got %v", tc.wantErr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientDo(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
response any
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "immediate error response",
|
||||
response: testError{
|
||||
message: "test error message",
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
wantErr: "test error message",
|
||||
},
|
||||
{
|
||||
name: "server error response",
|
||||
response: testError{
|
||||
message: "internal error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
},
|
||||
wantErr: "internal error",
|
||||
},
|
||||
{
|
||||
name: "successful response",
|
||||
response: struct {
|
||||
ID string `json:"id"`
|
||||
Success bool `json:"success"`
|
||||
}{
|
||||
ID: "msg_123",
|
||||
Success: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if errResp, ok := tc.response.(testError); ok {
|
||||
w.WriteHeader(errResp.statusCode)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": errResp.message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode error response:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(tc.response); err != nil {
|
||||
t.Fatalf("failed to encode response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
|
||||
|
||||
var resp struct {
|
||||
ID string `json:"id"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
|
||||
|
||||
if tc.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Fatalf("got nil, want error %q", tc.wantErr)
|
||||
}
|
||||
if err.Error() != tc.wantErr {
|
||||
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("got error %q, want nil", err)
|
||||
}
|
||||
|
||||
if expectedResp, ok := tc.response.(struct {
|
||||
ID string `json:"id"`
|
||||
Success bool `json:"success"`
|
||||
}); ok {
|
||||
if resp.ID != expectedResp.ID {
|
||||
t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
|
||||
}
|
||||
if resp.Success != expectedResp.Success {
|
||||
t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
38 cmd/cmd.go
@@ -15,13 +15,11 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/containerd/console"
|
||||
@@ -330,6 +328,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
}
|
||||
return info, err
|
||||
@@ -858,17 +857,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
|
||||
go func() {
|
||||
<-sigChan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var latest api.ChatResponse
|
||||
var fullResponse strings.Builder
|
||||
@@ -903,10 +891,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
if err := client.Chat(cmd.Context(), req, fn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -946,17 +931,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
generateContext = []int{}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
|
||||
go func() {
|
||||
<-sigChan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
@@ -992,10 +966,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
KeepAlive: opts.KeepAlive,
|
||||
}
|
||||
|
||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
if err := client.Generate(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1017,8 +988,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
latest.Summary()
|
||||
}
|
||||
|
||||
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
|
||||
cmd.SetContext(ctx)
|
||||
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -491,96 +490,6 @@ func TestPushHandler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
serverResponse []api.ListModelResponse
|
||||
expectedError string
|
||||
expectedOutput string
|
||||
}{
|
||||
{
|
||||
name: "list all models",
|
||||
args: []string{},
|
||||
serverResponse: []api.ListModelResponse{
|
||||
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
|
||||
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)},
|
||||
},
|
||||
expectedOutput: "NAME ID SIZE MODIFIED \n" +
|
||||
"model1 sha256:abc12 1.0 KB 24 hours ago \n" +
|
||||
"model2 sha256:def45 2.0 KB 2 days ago \n",
|
||||
},
|
||||
{
|
||||
name: "filter models by prefix",
|
||||
args: []string{"model1"},
|
||||
serverResponse: []api.ListModelResponse{
|
||||
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
|
||||
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)},
|
||||
},
|
||||
expectedOutput: "NAME ID SIZE MODIFIED \n" +
|
||||
"model1 sha256:abc12 1.0 KB 24 hours ago \n",
|
||||
},
|
||||
{
|
||||
name: "server error",
|
||||
args: []string{},
|
||||
expectedError: "server error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/tags" || r.Method != http.MethodGet {
|
||||
t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path)
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.expectedError != "" {
|
||||
http.Error(w, tt.expectedError, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := api.ListResponse{Models: tt.serverResponse}
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.TODO())
|
||||
|
||||
// Capture stdout
|
||||
oldStdout := os.Stdout
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
|
||||
err := ListHandler(cmd, tt.args)
|
||||
|
||||
// Restore stdout and get output
|
||||
w.Close()
|
||||
os.Stdout = oldStdout
|
||||
output, _ := io.ReadAll(r)
|
||||
|
||||
if tt.expectedError == "" {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if got := string(output); got != tt.expectedOutput {
|
||||
t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
|
||||
}
|
||||
} else {
|
||||
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -46,6 +46,15 @@ Install prerequisites:
- (Optional) NVIDIA GPU support
  - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)

> [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake.

> [!IMPORTANT]
> ROCm is not compatible with Visual Studio CMake generators. Use `-GNinja` when configuring the project.

> [!IMPORTANT]
> CUDA is only compatible with Visual Studio CMake generators.

Then, configure and build the project:

```shell
@@ -53,14 +62,6 @@ cmake -B build
cmake --build build --config Release
```

> [!IMPORTANT]
> Building for ROCm requires additional flags:
> ```
> cmake -B build -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
> cmake --build build --config Release
> ```

Lastly, run Ollama:

```shell
@@ -53,8 +53,8 @@ func Host() *url.URL {
}
}

// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
func AllowedOrigins() (origins []string) {
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) {
if s := Var("OLLAMA_ORIGINS"); s != "" {
origins = strings.Split(s, ",")
}
@@ -249,7 +249,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},

@@ -134,7 +134,7 @@ func TestOrigins(t *testing.T) {
t.Run(tt.value, func(t *testing.T) {
t.Setenv("OLLAMA_ORIGINS", tt.value)

if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" {
if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
}
})
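Both sides of the rename above read `OLLAMA_ORIGINS` and split it on commas, as the function body in the hunk shows. A small self-contained sketch of that behaviour (hypothetical helper, not the `envconfig` package itself):

```go
package sketch

import (
	"os"
	"strings"
)

// originsFromEnv mimics the behaviour shown in the hunk: read
// OLLAMA_ORIGINS and split it on commas; an unset variable yields nil.
func originsFromEnv() []string {
	s := os.Getenv("OLLAMA_ORIGINS")
	if s == "" {
		return nil
	}
	return strings.Split(s, ",")
}
```

For example, `OLLAMA_ORIGINS="http://localhost:3000,app://*"` yields two entries; in this sketch whitespace around commas is preserved, matching a plain `strings.Split`.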
@@ -1,315 +0,0 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: jmorganca <jmorganca@gmail.com>
|
||||
Date: Sun, 16 Feb 2025 20:00:22 -0500
|
||||
Subject: [PATCH] use std::filesystem::path instead of wstring
|
||||
|
||||
---
|
||||
ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++--------------------
|
||||
1 file changed, 58 insertions(+), 86 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
||||
index 84b21dd8..e35a6936 100644
|
||||
--- a/ggml/src/ggml-backend-reg.cpp
|
||||
+++ b/ggml/src/ggml-backend-reg.cpp
|
||||
@@ -66,26 +66,6 @@
|
||||
#include "ggml-kompute.h"
|
||||
#endif
|
||||
|
||||
-// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
-#if defined(__clang__)
|
||||
-# pragma clang diagnostic push
|
||||
-# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
-#endif
|
||||
-
|
||||
-static std::wstring utf8_to_utf16(const std::string & str) {
|
||||
- std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
||||
- return converter.from_bytes(str);
|
||||
-}
|
||||
-
|
||||
-static std::string utf16_to_utf8(const std::wstring & str) {
|
||||
- std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
||||
- return converter.to_bytes(str);
|
||||
-}
|
||||
-
|
||||
-#if defined(__clang__)
|
||||
-# pragma clang diagnostic pop
|
||||
-#endif
|
||||
-
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
@@ -96,7 +76,7 @@ struct dl_handle_deleter {
|
||||
}
|
||||
};
|
||||
|
||||
-static dl_handle * dl_load_library(const std::wstring & path) {
|
||||
+static dl_handle * dl_load_library(const std::filesystem::path & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
@@ -129,8 +109,8 @@ struct dl_handle_deleter {
|
||||
}
|
||||
};
|
||||
|
||||
-static void * dl_load_library(const std::wstring & path) {
|
||||
- dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
+static void * dl_load_library(const std::filesystem::path & path) {
|
||||
+ dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
|
||||
return handle;
|
||||
}
|
||||
@@ -141,6 +121,25 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
|
||||
#endif
|
||||
|
||||
+static std::string path_to_string(const std::filesystem::path & path)
|
||||
+{
|
||||
+#ifdef _WIN32
|
||||
+ const std::wstring wstr = path.wstring();
|
||||
+ const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr);
|
||||
+ if (size_needed <= 0) {
|
||||
+ return std::string();
|
||||
+ }
|
||||
+
|
||||
+ // size_needed includes the null terminator
|
||||
+ std::string str(size_needed - 1, '\0');
|
||||
+ WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr);
|
||||
+ return str;
|
||||
+#else
|
||||
+ return path.string();
|
||||
+#endif
|
||||
+}
|
||||
+
|
||||
+
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
||||
struct ggml_backend_reg_entry {
|
||||
@@ -222,11 +221,11 @@ struct ggml_backend_registry {
|
||||
);
|
||||
}
|
||||
|
||||
- ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
|
||||
+ ggml_backend_reg_t load_backend(const std::filesystem::path & path, bool silent) {
|
||||
dl_handle_ptr handle { dl_load_library(path) };
|
||||
if (!handle) {
|
||||
if (!silent) {
|
||||
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
|
||||
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@@ -234,7 +233,7 @@ struct ggml_backend_registry {
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
if (score_fn && score_fn() == 0) {
|
||||
if (!silent) {
|
||||
- GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
|
||||
+ GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@@ -242,7 +241,7 @@ struct ggml_backend_registry {
|
||||
auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
|
||||
if (!backend_init_fn) {
|
||||
if (!silent) {
|
||||
- GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
|
||||
+ GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@@ -251,16 +250,16 @@ struct ggml_backend_registry {
|
||||
if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
|
||||
if (!silent) {
|
||||
if (!reg) {
|
||||
- GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
|
||||
+ GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str());
|
||||
} else {
|
||||
GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
|
||||
- __func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
|
||||
+ __func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
- GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
|
||||
+ GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str());
|
||||
|
||||
register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
|
||||
|
||||
@@ -396,14 +395,14 @@ ggml_backend_t ggml_backend_init_best(void) {
|
||||
|
||||
// Dynamic loading
|
||||
ggml_backend_reg_t ggml_backend_load(const char * path) {
|
||||
- return get_reg().load_backend(utf8_to_utf16(path), false);
|
||||
+ return get_reg().load_backend(path, false);
|
||||
}
|
||||
|
||||
void ggml_backend_unload(ggml_backend_reg_t reg) {
|
||||
get_reg().unload_backend(reg, true);
|
||||
}
|
||||
|
||||
-static std::wstring get_executable_path() {
|
||||
+static std::filesystem::path get_executable_path() {
|
||||
#if defined(__APPLE__)
|
||||
// get executable path
|
||||
std::vector<char> path;
|
||||
@@ -415,15 +414,9 @@ static std::wstring get_executable_path() {
|
||||
}
|
||||
path.resize(size);
|
||||
}
|
||||
- std::string base_path(path.data(), size);
|
||||
- // remove executable name
|
||||
- auto last_slash = base_path.find_last_of('/');
|
||||
- if (last_slash != std::string::npos) {
|
||||
- base_path = base_path.substr(0, last_slash);
|
||||
- }
|
||||
- return utf8_to_utf16(base_path + "/");
|
||||
+
|
||||
+ return std::filesystem::path(path.data()).parent_path();
|
||||
#elif defined(__linux__) || defined(__FreeBSD__)
|
||||
- std::string base_path = ".";
|
||||
std::vector<char> path(1024);
|
||||
while (true) {
|
||||
// get executable path
|
||||
@@ -436,76 +429,55 @@ static std::wstring get_executable_path() {
|
||||
break;
|
||||
}
|
||||
if (len < (ssize_t) path.size()) {
|
||||
- base_path = std::string(path.data(), len);
|
||||
- // remove executable name
|
||||
- auto last_slash = base_path.find_last_of('/');
|
||||
- if (last_slash != std::string::npos) {
|
||||
- base_path = base_path.substr(0, last_slash);
|
||||
- }
|
||||
- break;
|
||||
+ return std::filesystem::path(path.data()).parent_path();
|
||||
}
|
||||
path.resize(path.size() * 2);
|
||||
}
|
||||
-
|
||||
- return utf8_to_utf16(base_path + "/");
|
||||
#elif defined(_WIN32)
|
||||
std::vector<wchar_t> path(MAX_PATH);
|
||||
DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
|
||||
if (len == 0) {
|
||||
return {};
|
||||
}
|
||||
- std::wstring base_path(path.data(), len);
|
||||
- // remove executable name
|
||||
- auto last_slash = base_path.find_last_of('\\');
|
||||
- if (last_slash != std::string::npos) {
|
||||
- base_path = base_path.substr(0, last_slash);
|
||||
- }
|
||||
- return base_path + L"\\";
|
||||
-#else
|
||||
- return {};
|
||||
-#endif
|
||||
-}
|
||||
|
||||
-static std::wstring backend_filename_prefix() {
|
||||
-#ifdef _WIN32
|
||||
- return L"ggml-";
|
||||
-#else
|
||||
- return L"libggml-";
|
||||
+ return std::filesystem::path(path.data()).parent_path();
|
||||
#endif
|
||||
+ return {};
|
||||
}
|
||||
|
||||
-static std::wstring backend_filename_suffix() {
|
||||
+static std::string backend_filename_prefix() {
|
||||
#ifdef _WIN32
|
||||
- return L".dll";
|
||||
+ return "ggml-";
|
||||
#else
|
||||
- return L".so";
|
||||
+ return "libggml-";
|
||||
#endif
|
||||
}
|
||||
|
||||
-static std::wstring path_separator() {
|
||||
+static std::string backend_filename_suffix() {
|
||||
#ifdef _WIN32
|
||||
- return L"\\";
|
||||
+ return ".dll";
|
||||
#else
|
||||
- return L"/";
|
||||
+ return ".so";
|
||||
#endif
|
||||
}
|
||||
|
||||
static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
|
||||
// enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
|
||||
// TODO: search system paths
|
||||
- std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-";
|
||||
- std::vector<std::wstring> search_paths;
|
||||
+ namespace fs = std::filesystem;
|
||||
+ std::string file_prefix = backend_filename_prefix() + name + "-";
|
||||
+ std::vector<fs::path> search_paths;
|
||||
+
|
||||
if (user_search_path == nullptr) {
|
||||
- search_paths.push_back(L"." + path_separator());
|
||||
+ search_paths.push_back(fs::current_path());
|
||||
search_paths.push_back(get_executable_path());
|
||||
} else {
|
||||
- search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator());
|
||||
+ search_paths.push_back(fs::u8path(user_search_path));
|
||||
}
|
||||
|
||||
int best_score = 0;
|
||||
- std::wstring best_path;
|
||||
+ fs::path best_path;
|
||||
|
||||
- namespace fs = std::filesystem;
|
||||
for (const auto & search_path : search_paths) {
|
||||
if (!fs::exists(search_path)) {
|
||||
continue;
|
||||
@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
for (const auto & entry : dir_it) {
|
||||
try {
|
||||
if (entry.is_regular_file()) {
|
||||
- std::wstring filename = entry.path().filename().wstring();
|
||||
- std::wstring ext = entry.path().extension().wstring();
|
||||
+ std::string filename = entry.path().filename().string();
|
||||
+ std::string ext = entry.path().extension().string();
|
||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
|
||||
if (!handle) {
|
||||
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
if (!score_fn) {
|
||||
- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
int s = score_fn();
|
||||
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
|
||||
if (s > best_score) {
|
||||
best_score = s;
|
||||
- best_path = entry.path().wstring();
|
||||
+ best_path = entry.path();
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
|
||||
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
if (best_score == 0) {
|
||||
// try to load the base backend
|
||||
for (const auto & search_path : search_paths) {
|
||||
- std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix();
|
||||
+ fs::path path = fs::path(search_path) / (backend_filename_prefix() + name + backend_filename_suffix());
|
||||
if (fs::exists(path)) {
|
||||
return get_reg().load_backend(path, silent);
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Yang <mxyng@pm.me>
|
||||
Date: Tue, 18 Feb 2025 14:47:21 -0800
|
||||
Subject: [PATCH] remove amx
|
||||
|
||||
---
|
||||
ggml/src/CMakeLists.txt | 4 ----
|
||||
1 file changed, 4 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
||||
index 72b488dd..50828717 100644
|
||||
--- a/ggml/src/CMakeLists.txt
|
||||
+++ b/ggml/src/CMakeLists.txt
|
||||
@@ -293,10 +293,6 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
|
||||
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
||||
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
|
||||
- if (NOT MSVC)
|
||||
- # MSVC doesn't support AMX
|
||||
- ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
- endif()
|
||||
else ()
|
||||
ggml_add_cpu_backend_variant_impl("")
|
||||
endif()
|
||||
14 main.go
@@ -2,6 +2,8 @@ package main

import (
"context"
"os"
"os/signal"

"github.com/spf13/cobra"

@@ -9,5 +11,15 @@ import (
)

func main() {
cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
go func() {
<-sigChan
cancel()
}()

cobra.CheckErr(cmd.NewCLI().ExecuteContext(ctx))
}
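Taken together with the `cmd/cmd.go` hunks above, this change moves Ctrl+C handling between the individual chat/generate commands and a single root context created in `main`. A hedged sketch of the consumer side of that pattern (illustrative only, not the repository's command code):

```go
package sketch

import (
	"context"
	"errors"
	"fmt"
)

// runUntilCancelled shows how long-running work can observe the shared
// root context and treat cancellation as a clean exit, matching the
// errors.Is(err, context.Canceled) checks that appear in the chat/generate hunks.
func runUntilCancelled(ctx context.Context, step func() error) error {
	for {
		select {
		case <-ctx.Done():
			return nil // interrupted: exit quietly
		default:
		}
		if err := step(); err != nil {
			if errors.Is(err, context.Canceled) {
				return nil
			}
			return fmt.Errorf("step failed: %w", err)
		}
	}
}
```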
@@ -26,24 +26,9 @@ type Backend interface {
SystemInfo() string
}

// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
var backends = make(map[string]func(*os.File) (Backend, error))

// MainGPU is the index of the primary GPU to use
MainGPU int

// NumGPULayers is the number of layers to offload to GPUs
NumGPULayers int

// TensorSplit is the fraction of the model to offload to each GPU
TensorSplit []float32
}

var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))

func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@@ -51,9 +36,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
backends[name] = f
}

func NewBackend(f *os.File, params BackendParams) (Backend, error) {
func NewBackend(f *os.File) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(f, params)
return backend(f)
}

return nil, fmt.Errorf("unsupported backend")
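This hunk switches the registered constructor type between `func(*os.File, BackendParams) (Backend, error)` and `func(*os.File) (Backend, error)`. A minimal, self-contained sketch of the registration pattern itself, using the single-argument form and hypothetical lower-case names rather than the real `ml` package API:

```go
package sketch

import (
	"fmt"
	"os"
)

// Backend is a stand-in for the ml.Backend interface in the hunk.
type Backend interface {
	SystemInfo() string
}

// registry maps a backend name to its constructor, as in the hunk.
var registry = map[string]func(*os.File) (Backend, error){}

// register panics on duplicate names, mirroring RegisterBackend.
func register(name string, f func(*os.File) (Backend, error)) {
	if _, ok := registry[name]; ok {
		panic("backend already registered: " + name)
	}
	registry[name] = f
}

// newBackend mirrors NewBackend: look up the "ggml" constructor and call it.
func newBackend(f *os.File) (Backend, error) {
	if ctor, ok := registry["ggml"]; ok {
		return ctor(f)
	}
	return nil, fmt.Errorf("unsupported backend")
}
```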
@@ -111,26 +96,6 @@ type Tensor interface {
|
||||
Copy(ctx Context, t2 Tensor) Tensor
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention implements a fused attention
|
||||
// operation equivalent to following code on a tensor named
|
||||
// query:
|
||||
//
|
||||
// kq := key.MulmatFullPrec(ctx, query)
|
||||
//
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
//
|
||||
// if mask != nil {
|
||||
// kq = kq.Add(ctx, mask)
|
||||
// }
|
||||
//
|
||||
// kq = kq.Softmax(ctx)
|
||||
//
|
||||
// kqv := value.Mulmat(ctx, kq)
|
||||
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
type ScaledDotProductAttention interface {
|
||||
ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
|
||||
}
|
||||
|
||||
type number interface {
|
||||
~int | ~int8 | ~int16 | ~int32 | ~int64 |
|
||||
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||
|
||||
@@ -82,11 +82,9 @@ type Backend struct {
|
||||
meta *fs.GGML
|
||||
cpus, gpus []Context
|
||||
tensors map[string]*Context
|
||||
|
||||
sched *C.struct_ggml_backend_sched
|
||||
}
|
||||
|
||||
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
func New(r *os.File) (ml.Backend, error) {
|
||||
meta, n, err := fs.Decode(r, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -184,24 +182,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus))
|
||||
bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus))
|
||||
for i, c := range append(gpus, cpus...) {
|
||||
backends[i] = c.backend
|
||||
bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
|
||||
}
|
||||
|
||||
return &Backend{
|
||||
meta: meta,
|
||||
cpus: cpus,
|
||||
gpus: gpus,
|
||||
sched: C.ggml_backend_sched_new(
|
||||
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
|
||||
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
|
||||
C.int(len(backends)),
|
||||
C.size_t(max(8192, len(meta.Tensors().Items())*5)),
|
||||
true,
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -235,23 +219,31 @@ func (b *Backend) NewContext() ml.Context {
|
||||
})
|
||||
|
||||
backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
|
||||
bufts := make([]*C.struct_ggml_backend_buffer_type, len(b.gpus)+len(b.cpus))
|
||||
for i, c := range append(b.gpus, b.cpus...) {
|
||||
backends[i] = c.backend
|
||||
bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
|
||||
}
|
||||
|
||||
return &Context{
|
||||
b: b,
|
||||
ctx: c,
|
||||
backend: backends[0],
|
||||
nodes: nodes,
|
||||
sched: C.ggml_backend_sched_new(
|
||||
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
|
||||
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
|
||||
C.int(len(backends)),
|
||||
C.size_t(nodes),
|
||||
true,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
b *Backend
|
||||
ctx *C.struct_ggml_context
|
||||
backend *C.struct_ggml_backend
|
||||
|
||||
sched *C.struct_ggml_backend_sched
|
||||
graph *C.struct_ggml_cgraph
|
||||
nodes int
|
||||
}
|
||||
@@ -265,13 +257,12 @@ func (c *Context) Forward(t ml.Tensor) {
|
||||
}
|
||||
|
||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
|
||||
C.ggml_backend_sched_reset(c.b.sched)
|
||||
C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
|
||||
|
||||
needSync := true
|
||||
sync := func() {
|
||||
if needSync {
|
||||
C.ggml_backend_sched_synchronize(c.b.sched)
|
||||
C.ggml_backend_sched_synchronize(c.sched)
|
||||
needSync = false
|
||||
}
|
||||
}
|
||||
@@ -359,6 +350,7 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
|
||||
func (c *Context) Close() {
|
||||
if c != nil {
|
||||
C.ggml_backend_sched_free(c.sched)
|
||||
C.ggml_free(c.ctx)
|
||||
}
|
||||
}
|
||||
@@ -485,7 +477,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
|
||||
}
|
||||
|
||||
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
|
||||
return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
}
|
||||
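The one-line change above toggles the `RMSNorm` wrapper between `C.ggml_rms_norm` and `C.ggml_norm`. These are different normalizations; assuming the conventional definitions (not stated in this diff), layer norm centres the activations while RMS norm only rescales them:

$$\mathrm{LayerNorm}(x)_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}}\,w_i \qquad \mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{n}\sum_j x_j^2 + \varepsilon}}\,w_i$$

where μ and σ² are the mean and variance over the normalized dimension; the trailing `.Mul(ctx, w)` in the wrapper applies the weight in both cases.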
|
||||
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
@@ -651,21 +643,6 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
|
||||
var kqMask *C.struct_ggml_tensor
|
||||
if mask != nil {
|
||||
kqMask = mask.(*Tensor).t
|
||||
}
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, t)
|
||||
kq = &Tensor{
|
||||
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
||||
}
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
}
|
||||
|
||||
func (b *Backend) SystemInfo() string {
|
||||
var compiler string
|
||||
switch C.get_compiler() {
|
||||
|
||||
4 ml/backend/ggml/ggml/src/CMakeLists.txt vendored
@@ -293,6 +293,10 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
|
||||
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
||||
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
|
||||
if (NOT MSVC)
|
||||
# MSVC doesn't support AMX
|
||||
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
endif()
|
||||
else ()
|
||||
ggml_add_cpu_backend_variant_impl("")
|
||||
endif()
|
||||
|
||||
154 ml/backend/ggml/ggml/src/ggml-backend-reg.cpp vendored
@@ -66,6 +66,26 @@
|
||||
#include "ggml-kompute.h"
|
||||
#endif
|
||||
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
static std::wstring utf8_to_utf16(const std::string & str) {
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
||||
return converter.from_bytes(str);
|
||||
}
|
||||
|
||||
static std::string utf16_to_utf8(const std::wstring & str) {
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
||||
return converter.to_bytes(str);
|
||||
}
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
@@ -76,7 +96,7 @@ struct dl_handle_deleter {
|
||||
}
|
||||
};
|
||||
|
||||
static dl_handle * dl_load_library(const std::filesystem::path & path) {
|
||||
static dl_handle * dl_load_library(const std::wstring & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
@@ -109,8 +129,8 @@ struct dl_handle_deleter {
|
||||
}
|
||||
};
|
||||
|
||||
static void * dl_load_library(const std::filesystem::path & path) {
|
||||
dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
static void * dl_load_library(const std::wstring & path) {
|
||||
dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
|
||||
return handle;
|
||||
}
|
||||
@@ -121,25 +141,6 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
|
||||
#endif
|
||||
|
||||
static std::string path_to_string(const std::filesystem::path & path)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
const std::wstring wstr = path.wstring();
|
||||
const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr);
|
||||
if (size_needed <= 0) {
|
||||
return std::string();
|
||||
}
|
||||
|
||||
// size_needed includes the null terminator
|
||||
std::string str(size_needed - 1, '\0');
|
||||
WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr);
|
||||
return str;
|
||||
#else
|
||||
return path.string();
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
||||
struct ggml_backend_reg_entry {
|
||||
@@ -221,11 +222,11 @@ struct ggml_backend_registry {
|
||||
);
|
||||
}
|
||||
|
||||
ggml_backend_reg_t load_backend(const std::filesystem::path & path, bool silent) {
|
||||
ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
|
||||
dl_handle_ptr handle { dl_load_library(path) };
|
||||
if (!handle) {
|
||||
if (!silent) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str());
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@@ -233,7 +234,7 @@ struct ggml_backend_registry {
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
if (score_fn && score_fn() == 0) {
|
||||
if (!silent) {
|
||||
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str());
|
||||
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@@ -241,7 +242,7 @@ struct ggml_backend_registry {
|
||||
auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
|
||||
if (!backend_init_fn) {
|
||||
if (!silent) {
|
||||
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str());
|
||||
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@@ -250,16 +251,16 @@ struct ggml_backend_registry {
|
||||
if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
|
||||
if (!silent) {
|
||||
if (!reg) {
|
||||
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str());
|
||||
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
|
||||
} else {
|
||||
GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
|
||||
__func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
|
||||
__func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str());
|
||||
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
|
||||
|
||||
register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
|
||||
|
||||
@@ -395,14 +396,14 @@ ggml_backend_t ggml_backend_init_best(void) {
|
||||
|
||||
// Dynamic loading
|
||||
ggml_backend_reg_t ggml_backend_load(const char * path) {
|
||||
return get_reg().load_backend(path, false);
|
||||
return get_reg().load_backend(utf8_to_utf16(path), false);
|
||||
}
|
||||
|
||||
void ggml_backend_unload(ggml_backend_reg_t reg) {
|
||||
get_reg().unload_backend(reg, true);
|
||||
}
|
||||
|
||||
static std::filesystem::path get_executable_path() {
|
||||
static std::wstring get_executable_path() {
|
||||
#if defined(__APPLE__)
|
||||
// get executable path
|
||||
std::vector<char> path;
|
||||
@@ -414,9 +415,15 @@ static std::filesystem::path get_executable_path() {
|
||||
}
|
||||
path.resize(size);
|
||||
}
|
||||
|
||||
return std::filesystem::path(path.data()).parent_path();
|
||||
std::string base_path(path.data(), size);
|
||||
// remove executable name
|
||||
auto last_slash = base_path.find_last_of('/');
|
||||
if (last_slash != std::string::npos) {
|
||||
base_path = base_path.substr(0, last_slash);
|
||||
}
|
||||
return utf8_to_utf16(base_path + "/");
|
||||
#elif defined(__linux__) || defined(__FreeBSD__)
|
||||
std::string base_path = ".";
|
||||
std::vector<char> path(1024);
|
||||
while (true) {
|
||||
// get executable path
|
||||
@@ -429,55 +436,76 @@ static std::filesystem::path get_executable_path() {
|
||||
break;
|
||||
}
|
||||
if (len < (ssize_t) path.size()) {
|
||||
return std::filesystem::path(path.data()).parent_path();
|
||||
base_path = std::string(path.data(), len);
|
||||
// remove executable name
|
||||
auto last_slash = base_path.find_last_of('/');
|
||||
if (last_slash != std::string::npos) {
|
||||
base_path = base_path.substr(0, last_slash);
|
||||
}
|
||||
break;
|
||||
}
|
||||
path.resize(path.size() * 2);
|
||||
}
|
||||
|
||||
return utf8_to_utf16(base_path + "/");
|
||||
#elif defined(_WIN32)
|
||||
std::vector<wchar_t> path(MAX_PATH);
|
||||
DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
|
||||
if (len == 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
return std::filesystem::path(path.data()).parent_path();
|
||||
#endif
|
||||
std::wstring base_path(path.data(), len);
|
||||
// remove executable name
|
||||
auto last_slash = base_path.find_last_of('\\');
|
||||
if (last_slash != std::string::npos) {
|
||||
base_path = base_path.substr(0, last_slash);
|
||||
}
|
||||
return base_path + L"\\";
|
||||
#else
|
||||
return {};
|
||||
}
|
||||
|
||||
static std::string backend_filename_prefix() {
|
||||
#ifdef _WIN32
|
||||
return "ggml-";
|
||||
#else
|
||||
return "libggml-";
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::string backend_filename_suffix() {
|
||||
static std::wstring backend_filename_prefix() {
|
||||
#ifdef _WIN32
|
||||
return ".dll";
|
||||
return L"ggml-";
|
||||
#else
|
||||
return ".so";
|
||||
return L"libggml-";
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::wstring backend_filename_suffix() {
|
||||
#ifdef _WIN32
|
||||
return L".dll";
|
||||
#else
|
||||
return L".so";
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::wstring path_separator() {
|
||||
#ifdef _WIN32
|
||||
return L"\\";
|
||||
#else
|
||||
return L"/";
|
||||
#endif
|
||||
}
|
||||
|
||||
static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
|
||||
// enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
|
||||
// TODO: search system paths
|
||||
namespace fs = std::filesystem;
|
||||
std::string file_prefix = backend_filename_prefix() + name + "-";
|
||||
std::vector<fs::path> search_paths;
|
||||
|
||||
std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-";
|
||||
std::vector<std::wstring> search_paths;
|
||||
if (user_search_path == nullptr) {
|
||||
search_paths.push_back(fs::current_path());
|
||||
search_paths.push_back(L"." + path_separator());
|
||||
search_paths.push_back(get_executable_path());
|
||||
} else {
|
||||
search_paths.push_back(fs::u8path(user_search_path));
|
||||
search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator());
|
||||
}
|
||||
|
||||
int best_score = 0;
|
||||
fs::path best_path;
|
||||
std::wstring best_path;
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
for (const auto & search_path : search_paths) {
|
||||
if (!fs::exists(search_path)) {
|
||||
continue;
|
||||
@@ -486,31 +514,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
for (const auto & entry : dir_it) {
|
||||
try {
|
||||
if (entry.is_regular_file()) {
|
||||
std::string filename = entry.path().filename().string();
|
||||
std::string ext = entry.path().extension().string();
|
||||
std::wstring filename = entry.path().filename().wstring();
|
||||
std::wstring ext = entry.path().extension().wstring();
|
||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
dl_handle_ptr handle { dl_load_library(entry.path()) };
|
||||
dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
if (!handle) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
if (!score_fn) {
|
||||
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
int s = score_fn();
|
||||
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
|
||||
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
if (s > best_score) {
|
||||
best_score = s;
|
||||
best_path = entry.path();
|
||||
best_path = entry.path().wstring();
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
|
||||
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -518,7 +546,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
if (best_score == 0) {
|
||||
// try to load the base backend
|
||||
for (const auto & search_path : search_paths) {
|
||||
fs::path path = fs::path(search_path) / (backend_filename_prefix() + name + backend_filename_suffix());
|
||||
std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix();
|
||||
if (fs::exists(path)) {
|
||||
return get_reg().load_backend(path, silent);
|
||||
}
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// Attention implements scaled dot-product attention for transformer models:
|
||||
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for tensor operations
|
||||
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
|
||||
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
|
||||
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
|
||||
// - mask: Optional attention mask that is added to the attention score. If
|
||||
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
|
||||
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// Attention output with shape [d_v, heads, seq_len_q]
|
||||
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
}
|
||||
|
||||
if mask != nil && query.Dim(1) != mask.Dim(1) {
|
||||
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
|
||||
}
|
||||
|
||||
if key.Dim(1) != value.Dim(0) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
|
||||
}
|
||||
|
||||
if mask != nil && key.Dim(1) != mask.Dim(0) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
|
||||
}
|
||||
|
||||
if key.Dim(2) != value.Dim(2) {
|
||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
}
|
||||
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
|
||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
|
||||
} else {
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
|
||||
kq = kq.Scale(ctx, scale)
|
||||
if mask != nil {
|
||||
kq = kq.Add(ctx, mask)
|
||||
}
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
}
|
||||
}
|
||||
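The `nn.Attention` helper shown above documents the expected tensor shapes and the `1/√d_k` scale, and the llama/mllama hunks below call it exactly that way. A short sketch of the call pattern (a sketch under those assumptions; the import path for `nn` is inferred from the file location `ml/nn/attention.go`):

```go
package sketch

import (
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// selfAttention sketches the call pattern from the model hunks below:
// scale by 1/√d_k and let nn.Attention handle softmax(QKᵀ·scale + mask)·V
// and the output permutation, then fold heads back into the hidden dimension.
func selfAttention(ctx ml.Context, q, k, v, mask ml.Tensor, headDim, hiddenSize, batchSize int) ml.Tensor {
	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
	return kqv.Reshape(ctx, hiddenSize, batchSize)
}
```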
@@ -70,14 +70,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
|
||||
}
|
||||
|
||||
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
||||
func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
func New(modelPath string) (Model, error) {
|
||||
r, err := os.Open(modelPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
b, err := ml.NewBackend(r, params)
|
||||
b, err := ml.NewBackend(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -86,8 +86,13 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
|
||||
kq := k.MulmatFullPrec(ctx, q)
|
||||
kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
||||
kq = kq.Add(ctx, mask)
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := v.Mulmat(ctx, kq)
|
||||
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
@@ -115,19 +120,11 @@ type Layer struct {
|
||||
MLP *MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
||||
|
||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||
// we need logits for.
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
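To make the pruning comment above concrete: outputs is an index tensor naming the token positions whose logits are actually needed (typically only the last position of each sequence), and Rows gathers just those rows so the final projection does proportionally less work. A plain-Go analogue of that gather, purely for illustration and not part of the diff:

// gatherRows mirrors what hiddenState.Rows(ctx, outputs) does conceptually:
// keep only the hidden-state rows whose logits we actually need, for example
// just the final position of each sequence in the batch.
func gatherRows(hidden [][]float32, outputs []int) [][]float32 {
	pruned := make([][]float32, 0, len(outputs))
	for _, idx := range outputs {
		pruned = append(pruned, hidden[idx])
	}
	return pruned
}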

@@ -147,26 +144,22 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
return nil, err
}

hiddenState := m.TokenEmbedding.Forward(ctx, inputs)

for i, layer := range m.Layers {
m.Cache.SetLayer(i)
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
}

hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)

outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}

hiddenState := m.TokenEmbedding.Forward(ctx, inputs)

for i, layer := range m.Layers {
m.Cache.SetLayer(i)

var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}

hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}

hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
return hiddenState.Rows(ctx, outputs), nil
}

func init() {

@@ -93,13 +93,15 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
return nil, err
}

// TODO: attention mask, cross attention mask
hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))

outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}

// TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
return hiddenState.Rows(ctx, outputs), nil
}

func init() {
@@ -38,8 +38,13 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)

attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

return sa.Output.Forward(ctx, attention)

@@ -69,19 +74,11 @@ type TextSelfAttentionDecoderLayer struct {
MLP *TextMLP
}

func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState

hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)

// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}

hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
@@ -107,7 +104,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps)

var key, value, mask ml.Tensor
var key, value ml.Tensor
if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)

@@ -120,15 +117,19 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio

cache.Put(ctx, key, value)
} else {
key, value, mask = cache.Get(ctx)
key, value, _ = cache.Get(ctx)
}

query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
scores := key.Mulmat(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Softmax(ctx)

attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

return ca.Output.Forward(ctx, attention)

@@ -144,7 +145,7 @@ type TextCrossAttentionDecoderLayer struct {
MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
}

func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState

hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)

@@ -160,14 +161,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
}

type TextDecoderLayer interface {
Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
}

type TextDecoder struct {
Layers []TextDecoderLayer
}

func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
for i, layer := range d.Layers {
layerType := selfAttentionLayer
if slices.Contains(opts.crossAttentionLayers, uint32(i)) {

@@ -178,12 +179,7 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
cache.SetLayerType(layerType)

if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
var lastLayerOutputs ml.Tensor
if i == len(d.Layers)-1 {
lastLayerOutputs = outputs
}

hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
}
}

@@ -209,9 +205,9 @@ type TextModel struct {
*TextModelOptions
}

func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}
@@ -49,29 +49,29 @@ func (p *Progress) stop() bool {
func (p *Progress) Stop() bool {
stopped := p.stop()
if stopped {
fmt.Fprint(p.w, "\n")
p.w.Flush()
fmt.Fprintln(p.w)
}

// show cursor
fmt.Fprint(p.w, "\033[?25h")
p.w.Flush()
return stopped
}

func (p *Progress) StopAndClear() bool {
defer p.w.Flush()

fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")

stopped := p.stop()
if stopped {
// clear all progress lines
for i := range p.pos {
if i > 0 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[2K\033[1G")
for range p.pos - 1 {
fmt.Fprint(p.w, "\033[A")
}

fmt.Fprint(p.w, "\033[2K", "\033[1G")
}

// show cursor
fmt.Fprint(p.w, "\033[?25h")
p.w.Flush()
return stopped
}
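For reference, the terminal control sequences this file relies on, summarized here for review (the constant names are illustrative; the progress code writes the raw escape strings directly):

// ANSI control sequences used by the progress renderer above.
const (
	hideCursor = "\033[?25l"   // hide the cursor while redrawing
	showCursor = "\033[?25h"   // restore the cursor when done
	cursorUp   = "\033[A"      // move up one line
	eraseLine  = "\033[2K"     // erase the entire current line
	column1    = "\033[1G"     // move to the first column
	syncBegin  = "\033[?2026h" // begin synchronized update (reduces flicker)
	syncEnd    = "\033[?2026l" // end synchronized update
)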

@@ -86,19 +86,13 @@ func (p *Progress) render() {
p.mu.Lock()
defer p.mu.Unlock()

defer p.w.Flush()

// eliminate flickering on terminals that support synchronized output
fmt.Fprint(p.w, "\033[?2026h")
defer fmt.Fprint(p.w, "\033[?2026l")

fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")

// move the cursor back to the beginning
for range p.pos - 1 {
fmt.Fprint(p.w, "\033[A")
}

fmt.Fprint(p.w, "\033[1G")

// render progress lines

@@ -110,10 +104,13 @@ func (p *Progress) render() {
}

p.pos = len(p.states)
p.w.Flush()
}

func (p *Progress) start() {
p.ticker = time.NewTicker(100 * time.Millisecond)
// hide cursor
fmt.Fprint(p.w, "\033[?25l")
for range p.ticker.C {
p.render()
}

@@ -25,7 +25,6 @@ import (
"golang.org/x/sync/semaphore"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"

@@ -802,7 +801,6 @@ func (m *multiLPath) String() string {

func (s *Server) loadModel(
mpath string,
params ml.BackendParams,
lpath multiLPath,
parallel int,
kvCacheType string,

@@ -810,12 +808,12 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(mpath, params)
s.model, err = model.New(mpath)
if err != nil {
panic(err)
}

slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads)
slog.Info("system", "info", s.model.Backend().SystemInfo() /* "threads", *threads */)

// TODO(jessegross): LoRA loading
if lpath.String() != "" {

@@ -845,17 +843,17 @@ func Execute(args []string) error {
mpath := fs.String("model", "", "Path to model binary file")
parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
batchSize := fs.Int("batch-size", 512, "Batch size")
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
_ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
_ = fs.Int("main-gpu", 0, "Main GPU")
_ = fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on")
threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
_ = fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
_ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
_ = fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")

var lpaths multiLPath
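The flag handling above keeps the old runner's command line stable: options the new engine does not consume yet are still registered so existing callers don't fail to parse, but their values are discarded. A small standalone sketch of that pattern, with illustrative flag names and not taken from the diff:

package main

import (
	"flag"
	"fmt"
	"os"
)

func main() {
	fs := flag.NewFlagSet("runner", flag.ContinueOnError)
	parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
	// Accepted for compatibility with existing callers, but currently ignored.
	_ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
	_ = fs.Bool("flash-attn", false, "Enable flash attention")
	if err := fs.Parse(os.Args[1:]); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println("parallel =", *parallel)
}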

@@ -892,11 +890,15 @@ func Execute(args []string) error {
}

// TODO(jessegross): Parameters that need to be implemented:
// n-gpu-layers
// main-gpu
// flash-attn
// threads
// no-mmap
// mlock
// tensor-split

var tensorSplitFloats []float32
/*var tensorSplitFloats []float32
if *tensorSplit != "" {
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)

@@ -905,17 +907,10 @@ func Execute(args []string) error {
f, _ := strconv.ParseFloat(s, 32)
tensorSplitFloats = append(tensorSplitFloats, float32(f))
}
}

params := ml.BackendParams{
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
}
}*/

server.ready.Add(1)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
go server.loadModel(*mpath, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)

server.cond = sync.NewCond(&server.mu)
@@ -1127,72 +1127,54 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}

func (s *Server) GenerateRoutes() http.Handler {
corsConfig := cors.DefaultConfig()
corsConfig.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true
corsConfig.AllowHeaders = []string{
"Authorization",
"Content-Type",
"User-Agent",
"Accept",
"X-Requested-With",

// OpenAI compatibility headers
"x-stainless-lang",
"x-stainless-package-version",
"x-stainless-os",
"x-stainless-arch",
"x-stainless-retry-count",
"x-stainless-runtime",
"x-stainless-runtime-version",
"x-stainless-async",
"x-stainless-helper-method",
"x-stainless-poll-helper",
"x-stainless-custom-poll-interval",
"x-stainless-timeout",
config := cors.DefaultConfig()
config.AllowWildcard = true
config.AllowBrowserExtensions = true
config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval"}
for _, prop := range openAIProperties {
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
}
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
config.AllowOrigins = envconfig.Origins()

r := gin.Default()
r.Use(
cors.New(corsConfig),
cors.New(config),
allowedHostsMiddleware(s.addr),
)

// General
r.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })

// Local model cache management
r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.HEAD("/api/tags", s.ListHandler)
r.GET("/api/tags", s.ListHandler)
r.POST("/api/show", s.ShowHandler)

// Create
r.POST("/api/create", s.CreateHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)

// Inference
r.GET("/api/ps", s.PsHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateHandler)
r.POST("/api/push", s.PushHandler)
r.POST("/api/copy", s.CopyHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.POST("/api/show", s.ShowHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.GET("/api/ps", s.PsHandler)

// Inference (OpenAI compatibility)
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)

for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})

r.Handle(method, "/api/tags", s.ListHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
}

return r
}
@@ -179,7 +179,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
if allReliable {
// HACK
os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus)))
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners(), "gpu_count", len(gpus))
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
} else {
// HACK
os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus)))