Mirror of https://github.com/ollama/ollama.git (synced 2026-02-07 06:03:39 -05:00)

Compare commits: 33 commits (hoyyeva/de... to parth/rend...)
Commits (SHA1): 92af238208, 7461faf651, 554172759c, 5b6a8e6001, 467bbc0dd5, 6d9f9323c5, 0c2489605d, 8b1b89a984, 47e272c35a, 417a81fda3, dba62ff3a5, d70e935526, 5c1063df7f, cb485b2019, b2af50960f, eac5b8bfbd, 604e43b28d, 53985b3c4d, b6e02cbbd2, 91935631ac, 8de30b568a, 485da9fd35, 0796d79d19, 92981ae3f2, 8ed1adf3db, 440a3823a6, 718961de68, 330f62a7fa, 584e2d646f, 1fd4cb87b2, 4aba2e8b72, 2f36d769aa, 399eacf486
.gitattributes (vendored): 4 changes
@@ -15,8 +15,12 @@ ml/backend/**/*.cu linguist-vendored
ml/backend/**/*.cuh linguist-vendored
ml/backend/**/*.m linguist-vendored
ml/backend/**/*.metal linguist-vendored
ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored

app/webview linguist-vendored

llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated
.github/workflows/release.yaml (vendored): 1 change
@@ -366,6 +366,7 @@ jobs:
            bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
            lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
            lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
            lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
            lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
            lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
            lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
.github/workflows/test.yaml (vendored): 9 changes
@@ -226,12 +226,9 @@ jobs:
        if: always()
        run: go test -count=1 -benchtime=1x ./...

      # TODO(bmizerany): replace this heavy tool with just the
      # tools/checks/binaries we want and then make them all run in parallel
      # across jobs, not on a single tiny vm on Github Actions.
      - uses: golangci/golangci-lint-action@v6
      - uses: golangci/golangci-lint-action@v9
        with:
          args: --timeout 10m0s -v
          only-new-issues: true

  patches:
    runs-on: ubuntu-latest
@@ -240,4 +237,4 @@ jobs:
      - name: Verify patches apply cleanly and do not change files
        run: |
          make -f Makefile.sync clean checkout apply-patches sync
          git diff --compact-summary --exit-code
          git diff --compact-summary --exit-code
@@ -1,41 +1,77 @@
run:
  timeout: 5m
version: "2"
linters:
  default: none
  enable:
    - asasalint
    - bidichk
    - bodyclose
    - containedctx
    - copyloopvar
    - errcheck
    - errorlint
    - exptostd
    - gocheckcompilerdirectives
    - gofmt
    - gofumpt
    - gosimple
    - gocritic
    - govet
    - ineffassign
    - intrange
    - makezero
    - misspell
    - modernize
    - nilerr
    - nilnil
    - nolintlint
    - nosprintfhostport
    - perfsprint
    - prealloc
    - sloglint
    - staticcheck
    - unconvert
    - unused
    - usestdlibvars
    - usetesting
    - wastedassign
    - whitespace
  disable:
    - usestdlibvars
    - errcheck
linters-settings:
  staticcheck:
    checks:
      - all
      - -SA1019 # omit Deprecated check
severity:
  default-severity: error
  rules:
    - linters:
        - gofmt
        - goimports
        - intrange
      severity: info
  settings:
    errcheck:
      exclude-functions:
        - fmt.Fprintf
    perfsprint:
      strconcat: false
      concat-loop: false
    staticcheck:
      checks:
        - all
        # Using a deprecated function, variable, constant or field.
        # https://staticcheck.dev/docs/checks/#SA1019
        - -SA1019
        # Incorrect or missing package comment.
        # https://staticcheck.dev/docs/checks/#ST1000
        - -ST1000
        # Poorly chosen identifier.
        # https://staticcheck.dev/docs/checks/#ST1003
        - -ST1003
        # The documentation of an exported function should start with the function's name.
        # https://staticcheck.dev/docs/checks/#ST1020
        - -ST1020
        # The documentation of an exported type should start with type's name.
        # https://staticcheck.dev/docs/checks/#ST1021
        - -ST1021
        # The documentation of an exported variable or constant should start with variable's name.
        # https://staticcheck.dev/docs/checks/#ST1022
        - -ST1022
    usestdlibvars:
      http-method: false
      http-status-code: false

formatters:
  enable:
    - gci
    - gofmt
    - gofumpt
  settings:
    gci:
      sections:
        - standard
        - default
        - localmodule
Dockerfile: 14 changes
@@ -39,14 +39,14 @@ ENV CC=clang CXX=clang++
FROM base-${TARGETARCH} AS base
ARG CMAKEVERSION
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ENV LDFLAGS=-s

FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
    cmake --preset 'CPU' \
    && cmake --build --parallel ${PARALLEL} --preset 'CPU' \
@@ -57,6 +57,8 @@ ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
    cmake --preset 'CUDA 11' \
    && cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
@@ -67,6 +69,8 @@ ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
    cmake --preset 'CUDA 12' \
    && cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
@@ -78,6 +82,8 @@ ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
    cmake --preset 'CUDA 13' \
    && cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
@@ -87,6 +93,8 @@ RUN --mount=type=cache,target=/root/.ccache \
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
    cmake --preset 'ROCm 6' \
    && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
@@ -118,6 +126,8 @@ RUN --mount=type=cache,target=/root/.ccache \
    && cmake --install build --component CUDA --strip --parallel ${PARALLEL}

FROM base AS vulkan
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
    cmake --preset 'Vulkan' \
    && cmake --build --parallel --preset 'Vulkan' \
@@ -367,6 +367,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev, a VS Code extension for multi-file/whole-repo coding
- [Void](https://github.com/voideditor/void) (Open source AI code editor and Cursor alternative)
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -226,7 +226,14 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f

		bts := scanner.Bytes()
		if err := json.Unmarshal(bts, &errorResponse); err != nil {
			return fmt.Errorf("unmarshal: %w", err)
			if response.StatusCode >= http.StatusBadRequest {
				return StatusError{
					StatusCode:   response.StatusCode,
					Status:       response.Status,
					ErrorMessage: string(bts),
				}
			}
			return errors.New(string(bts))
		}

		if response.StatusCode == http.StatusUnauthorized {
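With this change, error bodies that are not JSON (plain text, HTML error pages) surface verbatim through StatusError instead of failing with an unmarshal error. A minimal caller-side sketch of how that can be used; the model name, prompt, and handler are illustrative, not from the diff, and the usual imports are assumed:

```go
client, err := api.ClientFromEnvironment()
if err != nil {
	log.Fatal(err)
}

req := &api.GenerateRequest{Model: "llama3", Prompt: "hello"}
err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error { return nil })

var statusErr api.StatusError
if errors.As(err, &statusErr) {
	// ErrorMessage now carries the raw response body, e.g. an HTML
	// 404 page from a proxy, together with the HTTP status code.
	log.Printf("server returned %d: %s", statusErr.StatusCode, statusErr.ErrorMessage)
}
```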
@@ -55,6 +55,7 @@ func TestClientFromEnvironment(t *testing.T) {
type testError struct {
	message    string
	statusCode int
	raw        bool // if true, write message as-is instead of JSON encoding
}

func (e testError) Error() string {
@@ -111,6 +112,20 @@ func TestClientStream(t *testing.T) {
			},
		},
	},
	{
		name: "plain text error response",
		responses: []any{
			"internal server error",
		},
		wantErr: "internal server error",
	},
	{
		name: "HTML error page",
		responses: []any{
			"<html><body>404 Not Found</body></html>",
		},
		wantErr: "404 Not Found",
	},
}

for _, tc := range testCases {
@@ -135,6 +150,12 @@ func TestClientStream(t *testing.T) {
				return
			}

			if str, ok := resp.(string); ok {
				fmt.Fprintln(w, str)
				flusher.Flush()
				continue
			}

			if err := json.NewEncoder(w).Encode(resp); err != nil {
				t.Fatalf("failed to encode response: %v", err)
			}
@@ -173,9 +194,10 @@ func TestClientStream(t *testing.T) {

func TestClientDo(t *testing.T) {
	testCases := []struct {
		name     string
		response any
		wantErr  string
		name           string
		response       any
		wantErr        string
		wantStatusCode int
	}{
		{
			name: "immediate error response",
@@ -183,7 +205,8 @@ func TestClientDo(t *testing.T) {
				message:    "test error message",
				statusCode: http.StatusBadRequest,
			},
			wantErr: "test error message",
			wantErr:        "test error message",
			wantStatusCode: http.StatusBadRequest,
		},
		{
			name: "server error response",
@@ -191,7 +214,8 @@ func TestClientDo(t *testing.T) {
				message:    "internal error",
				statusCode: http.StatusInternalServerError,
			},
			wantErr: "internal error",
			wantErr:        "internal error",
			wantStatusCode: http.StatusInternalServerError,
		},
		{
			name: "successful response",
@@ -203,6 +227,26 @@ func TestClientDo(t *testing.T) {
				Success: true,
			},
		},
		{
			name: "plain text error response",
			response: testError{
				message:    "internal server error",
				statusCode: http.StatusInternalServerError,
				raw:        true,
			},
			wantErr:        "internal server error",
			wantStatusCode: http.StatusInternalServerError,
		},
		{
			name: "HTML error page",
			response: testError{
				message:    "<html><body>404 Not Found</body></html>",
				statusCode: http.StatusNotFound,
				raw:        true,
			},
			wantErr:        "<html><body>404 Not Found</body></html>",
			wantStatusCode: http.StatusNotFound,
		},
	}

	for _, tc := range testCases {
@@ -210,11 +254,16 @@ func TestClientDo(t *testing.T) {
			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if errResp, ok := tc.response.(testError); ok {
					w.WriteHeader(errResp.statusCode)
					err := json.NewEncoder(w).Encode(map[string]string{
						"error": errResp.message,
					})
					if err != nil {
						t.Fatal("failed to encode error response:", err)
					if !errResp.raw {
						err := json.NewEncoder(w).Encode(map[string]string{
							"error": errResp.message,
						})
						if err != nil {
							t.Fatal("failed to encode error response:", err)
						}
					} else {
						// Write raw message (simulates non-JSON error responses)
						fmt.Fprint(w, errResp.message)
					}
					return
				}
@@ -241,6 +290,15 @@ func TestClientDo(t *testing.T) {
			if err.Error() != tc.wantErr {
				t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
			}
			if tc.wantStatusCode != 0 {
				if statusErr, ok := err.(StatusError); ok {
					if statusErr.StatusCode != tc.wantStatusCode {
						t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
					}
				} else {
					t.Errorf("expected StatusError, got %T", err)
				}
			}
			return
		}
@@ -397,8 +397,8 @@ func checkUserLoggedIn(uiServerPort int) bool {
// handleConnectURLScheme fetches the connect URL and opens it in the browser
func handleConnectURLScheme() {
	if checkUserLoggedIn(uiServerPort) {
		slog.Info("user is already logged in, opening settings instead")
		sendUIRequestMessage("/")
		slog.Info("user is already logged in, opening app instead")
		showWindow(wv.webview.Window())
		return
	}

@@ -434,37 +434,30 @@ func openInBrowser(url string) {
	}
}

// parseURLScheme parses an ollama:// URL and returns whether it's a connect URL and the UI path
func parseURLScheme(urlSchemeRequest string) (isConnect bool, uiPath string, err error) {
// parseURLScheme parses an ollama:// URL and validates it
// Supports: ollama:// (open app) and ollama://connect (OAuth)
func parseURLScheme(urlSchemeRequest string) (isConnect bool, err error) {
	parsedURL, err := url.Parse(urlSchemeRequest)
	if err != nil {
		return false, "", err
		return false, fmt.Errorf("invalid URL: %w", err)
	}

	// Check if this is a connect URL
	if parsedURL.Host == "connect" || strings.TrimPrefix(parsedURL.Path, "/") == "connect" {
		return true, "", nil
		return true, nil
	}

	// Extract the UI path
	path := "/"
	if parsedURL.Path != "" && parsedURL.Path != "/" {
		// For URLs like ollama:///settings, use the path directly
		path = parsedURL.Path
	} else if parsedURL.Host != "" {
		// For URLs like ollama://settings (without triple slash),
		// the "settings" part is parsed as the host, not the path.
		// We need to convert it to a path by prepending "/"
		// This also handles ollama://settings/ where Windows adds a trailing slash
		path = "/" + parsedURL.Host
	// Allow bare ollama:// or ollama:/// to open the app
	if (parsedURL.Host == "" && parsedURL.Path == "") || parsedURL.Path == "/" {
		return false, nil
	}

	return false, path, nil
	return false, fmt.Errorf("unsupported ollama:// URL path: %s", urlSchemeRequest)
}

// handleURLSchemeInCurrentInstance processes URL scheme requests in the current instance
func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
	isConnect, uiPath, err := parseURLScheme(urlSchemeRequest)
	isConnect, err := parseURLScheme(urlSchemeRequest)
	if err != nil {
		slog.Error("failed to parse URL scheme request", "url", urlSchemeRequest, "error", err)
		return
@@ -473,6 +466,8 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
	if isConnect {
		handleConnectURLScheme()
	} else {
		sendUIRequestMessage(uiPath)
		if wv.webview != nil {
			showWindow(wv.webview.Window())
		}
	}
}
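An illustrative sketch of the narrowed behavior (hypothetical driver code; the outcomes follow from the cases in the diff above):

```go
for _, u := range []string{"ollama://", "ollama:///", "ollama://connect", "ollama://settings"} {
	isConnect, err := parseURLScheme(u)
	fmt.Println(u, isConnect, err)
}
// ollama://         false <nil>  (opens the app)
// ollama:///        false <nil>  (opens the app)
// ollama://connect  true  <nil>  (OAuth connect flow)
// ollama://settings false unsupported ollama:// URL path: ollama://settings
```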
@@ -24,27 +24,14 @@ bool firstTimeRun,startHidden; // Set in run before initialization
	for (NSURL *url in urls) {
		if ([url.scheme isEqualToString:@"ollama"]) {
			NSString *path = url.path;
			if (!path || [path isEqualToString:@""]) {
				// For URLs like ollama://settings (without triple slash),
				// the "settings" part is parsed as the host, not the path.
				// We need to convert it to a path by prepending "/"
				if (url.host && ![url.host isEqualToString:@""]) {
					path = [@"/" stringByAppendingString:url.host];
				} else {
					path = @"/";
				}
			}

			if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {

			if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) {
				// Special case: handle connect by opening browser instead of app
				handleConnectURL();
			} else {
				// Set app to be active and visible
				[NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
				[NSApp activateIgnoringOtherApps:YES];

				// Open the path with the UI
				[self uiRequest:path];
			}

			break;
@@ -260,7 +247,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}

- (void)openHelp:(id)sender {
	NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
	NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"];
	[[NSWorkspace sharedWorkspace] openURL:url];
}
@@ -138,7 +138,7 @@ func (app *appCallbacks) HandleURLScheme(urlScheme string) {

// handleURLSchemeRequest processes URL scheme requests from other instances
func handleURLSchemeRequest(urlScheme string) {
	isConnect, uiPath, err := parseURLScheme(urlScheme)
	isConnect, err := parseURLScheme(urlScheme)
	if err != nil {
		slog.Error("failed to parse URL scheme request", "url", urlScheme, "error", err)
		return
@@ -147,7 +147,9 @@ func handleURLSchemeRequest(urlScheme string) {
	if isConnect {
		handleConnectURLScheme()
	} else {
		sendUIRequestMessage(uiPath)
		if wv.webview != nil {
			showWindow(wv.webview.Window())
		}
	}
}
cmd/chat_template/chat_template.py: 625 lines (new file)
@@ -0,0 +1,625 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "transformers>=4.57.0",
#     "jinja2",
#     "fastapi",
#     "uvicorn",
#     "pydantic",
#     "requests",
# ]
# ///
"""
Chat Template Testing Tool

Test HuggingFace chat templates against Ollama renderers.

Usage:
    # Run predefined test cases against a HuggingFace model
    uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3

    # Compare HuggingFace output with Ollama renderer
    uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --ollama-model intellect3

    # Start server for manual curl testing
    uv run cmd/chat_template/chat_template.py --serve

    # Show chat template for a model
    uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --show-template
"""

import argparse
import json
import sys
from typing import Any

from transformers import AutoTokenizer


TEST_CASES = [
    {
        "name": "basic_user_message",
        "messages": [{"role": "user", "content": "Hello!"}],
        "tools": None,
    },
    {
        "name": "with_system_message",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "tools": None,
    },
    {
        "name": "multi_turn_conversation",
        "messages": [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ],
        "tools": None,
    },
    {
        "name": "with_tools",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the weather?"},
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather",
                    "parameters": {
                        "type": "object",
                        "required": ["location"],
                        "properties": {
                            "location": {"type": "string", "description": "The city"}
                        },
                    },
                },
            }
        ],
    },
    {
        "name": "tool_call_and_response",
        "messages": [
            {"role": "user", "content": "What is the weather in SF?"},
            {
                "role": "assistant",
                "content": "Let me check the weather.",
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {"location": "San Francisco"},
                        },
                    }
                ],
            },
            {"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather",
                    "parameters": {
                        "type": "object",
                        "required": ["location"],
                        "properties": {
                            "location": {"type": "string", "description": "The city"}
                        },
                    },
                },
            }
        ],
    },
    {
        "name": "parallel_tool_calls",
        "messages": [
            {"role": "user", "content": "Get weather in SF and NYC"},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {"location": "San Francisco"},
                        },
                    },
                    {
                        "id": "call_2",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {"location": "New York"},
                        },
                    },
                ],
            },
            {"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
            {"role": "tool", "content": '{"temperature": 55}', "tool_call_id": "call_2"},
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                    },
                },
            }
        ],
    },
    # Thinking tests
    {
        "name": "assistant_with_thinking",
        "messages": [
            {"role": "user", "content": "What is 2+2?"},
            {
                "role": "assistant",
                "content": "The answer is 4.",
                "thinking": "Let me calculate: 2 + 2 = 4. This is basic arithmetic.",
            },
            {"role": "user", "content": "And 3+3?"},
        ],
        "tools": None,
    },
    {
        "name": "thinking_with_tool_call",
        "messages": [
            {"role": "user", "content": "What's the weather in Paris?"},
            {
                "role": "assistant",
                "content": "I'll check the weather for you.",
                "thinking": "The user wants to know the weather in Paris. I should call the get_weather function.",
                "tool_calls": [
                    {
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {"location": "Paris"},
                        },
                    }
                ],
            },
            {"role": "tool", "content": '{"temperature": 18, "condition": "cloudy"}', "tool_call_id": "call_1"},
        ],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get current weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                    },
                },
            }
        ],
    },
    {
        "name": "thinking_only_no_content",
        "messages": [
            {"role": "user", "content": "Think about this silently."},
            {
                "role": "assistant",
                "content": "",  # HuggingFace requires content field
                "thinking": "I'm thinking about this but won't respond with visible content.",
            },
            {"role": "user", "content": "What did you think?"},
        ],
        "tools": None,
    },
]

# Cache for tokenizers
_tokenizer_cache: dict[str, Any] = {}


def get_tokenizer(model_name: str):
    """Get or create tokenizer for the given model."""
    if model_name not in _tokenizer_cache:
        print(f"Loading tokenizer for {model_name}...", file=sys.stderr)
        _tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(model_name)
    return _tokenizer_cache[model_name]


def apply_template(
    model: str,
    messages: list[dict],
    tools: list[dict] | None = None,
) -> str:
    """Apply HuggingFace chat template to messages."""
    tokenizer = get_tokenizer(model)

    if tools:
        return tokenizer.apply_chat_template(
            messages,
            tools=tools,
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )


def get_ollama_prompt(
    ollama_model: str,
    messages: list[dict],
    tools: list[dict] | None = None,
    ollama_host: str = "http://localhost:11434",
) -> str | None:
    """Get rendered prompt from Ollama using debug_render_only."""
    import requests

    # Convert messages to Ollama format
    ollama_messages = []
    for msg in messages:
        ollama_msg = {"role": msg["role"]}
        if "content" in msg:
            ollama_msg["content"] = msg["content"]
        if "thinking" in msg:
            ollama_msg["thinking"] = msg["thinking"]
        if "tool_calls" in msg:
            # Convert tool_calls to Ollama format
            tool_calls = []
            for tc in msg["tool_calls"]:
                tool_call = {
                    "function": {
                        "name": tc["function"]["name"],
                        "arguments": tc["function"]["arguments"],
                    }
                }
                if "id" in tc:
                    tool_call["id"] = tc["id"]
                tool_calls.append(tool_call)
            ollama_msg["tool_calls"] = tool_calls
        if "tool_call_id" in msg:
            ollama_msg["tool_call_id"] = msg["tool_call_id"]
        ollama_messages.append(ollama_msg)

    payload = {
        "model": ollama_model,
        "messages": ollama_messages,
        "stream": False,
        "_debug_render_only": True,
    }

    if tools:
        payload["tools"] = tools

    try:
        resp = requests.post(f"{ollama_host}/api/chat", json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()
        # Field name is _debug_info with underscore prefix
        if "_debug_info" in data and "rendered_template" in data["_debug_info"]:
            return data["_debug_info"]["rendered_template"]
        return None
    except requests.exceptions.ConnectionError:
        print(f" [ERROR] Cannot connect to Ollama at {ollama_host}", file=sys.stderr)
        return None
    except Exception as e:
        print(f" [ERROR] Ollama request failed: {e}", file=sys.stderr)
        return None


def compute_diff(hf_prompt: str, ollama_prompt: str) -> str:
    """Compute a unified diff between HuggingFace and Ollama prompts."""
    import difflib

    hf_lines = hf_prompt.splitlines(keepends=True)
    ollama_lines = ollama_prompt.splitlines(keepends=True)

    diff = difflib.unified_diff(
        ollama_lines,
        hf_lines,
        fromfile="Ollama",
        tofile="HuggingFace",
        lineterm="",
    )
    return "".join(diff)


def print_test_output(
    name: str,
    messages: list[dict],
    tools: list[dict] | None,
    hf_prompt: str,
    ollama_prompt: str | None = None,
    as_repr: bool = False,
):
    """Print test output in a format suitable for Go test creation and LLM diffing."""
    print(f"\n{'='*60}")
    print(f"Test: {name}")
    print("=" * 60)
    print("\n--- Input Messages ---")
    print(json.dumps(messages, indent=2))
    if tools:
        print("\n--- Tools ---")
        print(json.dumps(tools, indent=2))

    if ollama_prompt is not None:
        # Comparison mode
        if hf_prompt == ollama_prompt:
            print("\n--- Result: MATCH ---")
            print("\n--- Prompt (both identical) ---")
            if as_repr:
                print(repr(hf_prompt))
            else:
                print(hf_prompt)
        else:
            print("\n--- Result: MISMATCH ---")
            print("\n--- HuggingFace Prompt ---")
            if as_repr:
                print(repr(hf_prompt))
            else:
                print(hf_prompt)
            print("\n--- Ollama Prompt ---")
            if as_repr:
                print(repr(ollama_prompt))
            else:
                print(ollama_prompt)
            print("\n--- Diff (Ollama -> HuggingFace) ---")
            diff = compute_diff(hf_prompt, ollama_prompt)
            if diff:
                print(diff)
            else:
                print("(no line-level diff, check whitespace)")
    else:
        # HuggingFace only mode
        print("\n--- HuggingFace Prompt ---")
        if as_repr:
            print(repr(hf_prompt))
        else:
            print(hf_prompt)

    print("=" * 60)


def run_tests(
    model: str,
    as_repr: bool = False,
    test_filter: str | None = None,
    ollama_model: str | None = None,
    ollama_host: str = "http://localhost:11434",
):
    """Run all predefined test cases against a model."""
    if ollama_model:
        print(f"\nComparing HuggingFace ({model}) vs Ollama ({ollama_model})\n")
    else:
        print(f"\nRunning tests against: {model}\n")

    matches = 0
    mismatches = 0
    errors = 0

    for test_case in TEST_CASES:
        name = test_case["name"]
        messages = test_case["messages"]
        tools = test_case["tools"]

        # Filter tests if specified
        if test_filter and test_filter.lower() not in name.lower():
            continue

        try:
            hf_prompt = apply_template(model, messages, tools)

            ollama_prompt = None
            if ollama_model:
                ollama_prompt = get_ollama_prompt(
                    ollama_model, messages, tools, ollama_host
                )
                if ollama_prompt is None:
                    errors += 1
                elif hf_prompt == ollama_prompt:
                    matches += 1
                else:
                    mismatches += 1

            print_test_output(
                name, messages, tools, hf_prompt, ollama_prompt, as_repr=as_repr
            )
        except Exception as e:
            errors += 1
            print(f"\n{'='*60}")
            print(f"Test: {name} - FAILED")
            print(f"--- Input Messages ---")
            print(json.dumps(messages, indent=2))
            if tools:
                print(f"--- Tools ---")
                print(json.dumps(tools, indent=2))
            print(f"--- Error ---")
            print(f"{e}")
            print("=" * 60)

    # Print summary if comparing
    if ollama_model:
        total = matches + mismatches + errors
        print(f"\n{'='*60}")
        print("SUMMARY")
        print("=" * 60)
        print(f"  Total:      {total}")
        print(f"  Matches:    {matches}")
        print(f"  Mismatches: {mismatches}")
        print(f"  Errors:     {errors}")
        print("=" * 60)


def show_template(model: str):
    """Show the chat template for a model."""
    tokenizer = get_tokenizer(model)
    print(f"\nChat template for {model}:\n")
    print("-" * 60)
    print(tokenizer.chat_template)
    print("-" * 60)


def start_server(host: str = "0.0.0.0", port: int = 8000):
    """Start the FastAPI server for manual testing."""
    from typing import Optional, List, Dict, Any as TypingAny

    from fastapi import FastAPI, HTTPException
    from pydantic import BaseModel
    import uvicorn

    class Message(BaseModel):
        role: str
        content: Optional[str] = None
        tool_calls: Optional[List[Dict[str, TypingAny]]] = None
        tool_call_id: Optional[str] = None

    class GeneratePromptRequest(BaseModel):
        messages: List[Message]
        model: str = "PrimeIntellect/INTELLECT-3"
        tools: Optional[List[Dict[str, TypingAny]]] = None
        inject_tools_as_functions: bool = False

    class GeneratePromptResponse(BaseModel):
        prompt: str
        model: str

    app = FastAPI(title="HuggingFace Prompt Generator", version="1.0.0")

    @app.post("/generate-prompt", response_model=GeneratePromptResponse)
    async def generate_prompt(request: GeneratePromptRequest):
        try:
            messages = []
            for msg in request.messages:
                message_dict = {"role": msg.role}
                if msg.content is not None:
                    message_dict["content"] = msg.content
                if msg.tool_calls is not None:
                    tool_calls = []
                    for tc in msg.tool_calls:
                        tc_copy = tc.copy()
                        if "function" in tc_copy and "arguments" in tc_copy["function"]:
                            args = tc_copy["function"]["arguments"]
                            if isinstance(args, str):
                                try:
                                    tc_copy["function"]["arguments"] = json.loads(args)
                                except json.JSONDecodeError:
                                    pass
                        tool_calls.append(tc_copy)
                    message_dict["tool_calls"] = tool_calls
                if msg.tool_call_id is not None:
                    message_dict["tool_call_id"] = msg.tool_call_id
                messages.append(message_dict)

            prompt = apply_template(request.model, messages, request.tools)
            return GeneratePromptResponse(prompt=prompt, model=request.model)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.get("/health")
    async def health_check():
        return {"status": "healthy"}

    print(f"Starting server on http://{host}:{port}")
    print("Endpoints:")
    print("  POST /generate-prompt - Generate prompt from messages")
    print("  GET  /health - Health check")
    uvicorn.run(app, host=host, port=port)


def main():
    parser = argparse.ArgumentParser(
        description="HuggingFace Prompt Testing Tool",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        help="HuggingFace model name (e.g., PrimeIntellect/INTELLECT-3)",
    )
    parser.add_argument(
        "--ollama-model",
        "-o",
        type=str,
        help="Ollama model name to compare against (e.g., qwen3-coder)",
    )
    parser.add_argument(
        "--ollama-host",
        type=str,
        default="http://localhost:11434",
        help="Ollama server URL (default: http://localhost:11434)",
    )
    parser.add_argument(
        "--serve",
        "-s",
        action="store_true",
        help="Start FastAPI server for manual curl testing",
    )
    parser.add_argument(
        "--port",
        "-p",
        type=int,
        default=8000,
        help="Server port (default: 8000)",
    )
    parser.add_argument(
        "--show-template",
        "-t",
        action="store_true",
        help="Show the chat template for the model",
    )
    parser.add_argument(
        "--repr",
        "-r",
        action="store_true",
        help="Output prompts as Python repr (shows escape sequences)",
    )
    parser.add_argument(
        "--filter",
        "-f",
        type=str,
        help="Filter tests by name (substring match)",
    )

    args = parser.parse_args()

    if args.serve:
        start_server(port=args.port)
    elif args.model:
        if args.show_template:
            show_template(args.model)
        else:
            run_tests(
                args.model,
                as_repr=args.repr,
                test_filter=args.filter,
                ollama_model=args.ollama_model,
                ollama_host=args.ollama_host,
            )
    else:
        parser.print_help()
        print("\nExample usage:")
        print("  uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3")
        print("  uv run cmd/chat_template/chat_template.py --model Qwen/Qwen3-Coder-480B-A35B-Instruct --ollama-model qwen3-coder")
        print("  uv run cmd/chat_template/chat_template.py --serve")
        sys.exit(1)


if __name__ == "__main__":
    main()
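For reference, a small Go client against the --serve endpoint above (a sketch; it assumes the server is running locally on the default port 8000, and the request/response shapes come from the GeneratePromptRequest/GeneratePromptResponse models in the file):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	body, _ := json.Marshal(map[string]any{
		"model": "PrimeIntellect/INTELLECT-3",
		"messages": []map[string]any{
			{"role": "user", "content": "Hello!"},
		},
	})
	resp, err := http.Post("http://localhost:8000/generate-prompt", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out struct {
		Prompt string `json:"prompt"`
		Model  string `json:"model"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println(out.Prompt) // the rendered HuggingFace chat template
}
```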
@@ -206,6 +206,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
		conv = &commandrModel{}
	case "GptOssForCausalLM":
		conv = &gptossModel{}
	case "DeepseekOCRForCausalLM":
		conv = &deepseekocr{}
	default:
		return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
	}
convert/convert_deepseekocr.go: 136 lines (new file)
@@ -0,0 +1,136 @@
package convert

import (
	"fmt"

	"github.com/ollama/ollama/fs/ggml"
)

type deepseekocr struct {
	ModelParameters
	LanguageConfig struct {
		MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
		HiddenSize            uint32 `json:"hidden_size"`
		HiddenLayers          uint32 `json:"num_hidden_layers"`
		IntermediateSize      uint32 `json:"intermediate_size"`
		NumAttentionHeads     uint32 `json:"num_attention_heads"`
		NumKeyValueHeads      uint32 `json:"num_key_value_heads"`
		NumRoutedExperts      uint32 `json:"n_routed_experts"`
		NumSharedExperts      uint32 `json:"n_shared_experts"`
		NumExpertsPerToken    uint32 `json:"num_experts_per_tok"`
		FirstKDenseReplace    uint32 `json:"first_k_dense_replace"`
	} `json:"language_config"`

	VisionConfig struct {
		ImageSize uint32 `json:"image_size"`
		Width     struct {
			Vision struct {
				Heads     uint32 `json:"heads"`
				ImageSize uint32 `json:"image_size"`
				Layers    uint32 `json:"layers"`
				PatchSize uint32 `json:"patch_size"`
				Width     uint32 `json:"width"`
			} `json:"clip-l-14-224"`
			Sam struct {
				GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
				Heads                  uint32  `json:"heads"`
				Layers                 uint32  `json:"layers"`
				Width                  uint32  `json:"width"`
			} `json:"sam_vit_b"`
		}
	} `json:"vision_config"`
}

func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
	kv := m.ModelParameters.KV(t)
	kv["general.architecture"] = "deepseekocr"
	kv["block_count"] = m.LanguageConfig.HiddenLayers
	kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
	kv["embedding_length"] = m.LanguageConfig.HiddenSize
	kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
	kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
	kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
	kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
	kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
	kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace

	kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
	kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
	kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
	kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
	kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize

	kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
	kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
	kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
	kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
	return kv
}

func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
	merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
	for i := range m.LanguageConfig.HiddenLayers {
		merges[i*3+0] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
		}
		merges[i*3+1] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
		}
		merges[i*3+2] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
		}
	}

	out, s = mergeTensors(s, merges...)
	for _, t := range s {
		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}
	return out
}

func (m *deepseekocr) Replacements() []string {
	return []string{
		"model.embed_tokens", "token_embd",
		"model.layers", "blk",
		"input_layernorm", "attn_norm",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"post_attention_layernorm", "ffn_norm",
		"mlp.gate_proj", "ffn_gate",
		"mlp.up_proj", "ffn_up",
		"mlp.down_proj", "ffn_down",
		"mlp.gate", "ffn_gate_inp",
		"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
		"mlp.shared_experts.up_proj", "ffn_up_shexp",
		"mlp.shared_experts.down_proj", "ffn_down_shexp",
		"model.norm", "output_norm",
		"lm_head", "output",

		"model.vision_model", "v",
		"embeddings.patch_embedding", "patch_embd",
		"embeddings.class_embedding", "class_embd",
		"embeddings.position_embedding", "position_embd",
		"transformer.layers", "blk",

		"model.projector", "mm",
		"model.image_newline", "mm.image_newline",
		//nolint:misspell // this misspelling is upstream. fixing it breaks the model
		"model.view_seperator", "mm.view_seperator",

		"model.sam_model.patch_embed.proj", "s.patch_embd",
		"model.sam_model.pos_embed", "s.position_embd",
		"model.sam_model.blocks", "s.blk",
		"model.sam_model.neck", "s.neck",
		"model.sam_model.net_", "s.net_",
	}
}
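To make the merge table concrete, here is an illustrative expansion of the layer-0 entries built by the Sprintf patterns above (not additional code from the diff):

```go
// blk.0.mlp.experts.*.gate_proj.weight -> blk.0.ffn_gate_exps.weight
// blk.0.mlp.experts.*.up_proj.weight   -> blk.0.ffn_up_exps.weight
// blk.0.mlp.experts.*.down_proj.weight -> blk.0.ffn_down_exps.weight
```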
@@ -44,7 +44,10 @@ func (t tensorBase) Kind() uint32 {
		t.name == "v.positional_embedding_vlm" ||
		t.name == "v.tile_position_embd.weight" ||
		t.name == "v.pre_tile_position_embd.weight" ||
		t.name == "v.post_tile_position_embd.weight" {
		t.name == "v.post_tile_position_embd.weight" ||
		t.name == "s.position_embd" ||
		strings.HasSuffix(t.name, "rel_pos_h") ||
		strings.HasSuffix(t.name, "rel_pos_w") {
		// these tensors are always F32
		return tensorKindFP32
	}
@@ -96,7 +96,10 @@ type safetensor struct {

func (st safetensor) Kind() uint32 {
	kind := st.tensorBase.Kind()
	if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
	if st.dtype == "BF16" &&
		!strings.HasPrefix(st.name, "v.") &&
		!strings.HasPrefix(st.name, "s.") &&
		kind != tensorKindFP32 {
		kind = tensorKindBF16
	}
@@ -2,6 +2,7 @@ package discover

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"log/slog"
@@ -10,12 +11,21 @@ import (
	"reflect"
	"regexp"
	"sort"
	"strconv"
	"strings"

	"github.com/ollama/ollama/format"
)

func GetCPUMem() (memInfo, error) {
	mem, err := getCPUMem()
	if err != nil {
		return memInfo{}, err
	}
	return getCPUMemByCgroups(mem), nil
}

func getCPUMem() (memInfo, error) {
	var mem memInfo
	var total, available, free, buffers, cached, freeSwap uint64
	f, err := os.Open("/proc/meminfo")
@@ -56,6 +66,32 @@ func GetCPUMem() (memInfo, error) {
	return mem, nil
}

func getCPUMemByCgroups(mem memInfo) memInfo {
	total, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.max")
	if err == nil {
		mem.TotalMemory = total
	}
	used, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.current")
	if err == nil {
		mem.FreeMemory = mem.TotalMemory - used
	}
	return mem
}

func getUint64ValueFromFile(path string) (uint64, error) {
	f, err := os.Open(path)
	if err != nil {
		return 0, err
	}
	defer f.Close()
	s := bufio.NewScanner(f)
	for s.Scan() {
		line := s.Text()
		return strconv.ParseUint(line, 10, 64)
	}
	return 0, errors.New("empty file content")
}

const CpuInfoFilename = "/proc/cpuinfo"

type linuxCpuInfo struct {
@@ -74,7 +110,41 @@ func GetCPUDetails() []CPU {
		return nil
	}
	defer file.Close()
	return linuxCPUDetails(file)
	cpus := linuxCPUDetails(file)
	return overwriteThreadCountByLinuxCgroups(cpus)
}

func overwriteThreadCountByLinuxCgroups(cpus []CPU) []CPU {
	file, err := os.Open("/sys/fs/cgroup/cpu.max")
	if err != nil {
		return cpus
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := scanner.Text()
		if sl := strings.Split(line, " "); len(sl) == 2 {
			allowdUs, err := strconv.ParseInt(sl[0], 10, 64)
			if err != nil {
				slog.Warn("failed to parse CPU allowed micro secs", "error", err)
				return cpus
			}
			unitUs, err := strconv.ParseInt(sl[1], 10, 64)
			if err != nil {
				slog.Warn("failed to parse CPU unit micro secs", "error", err)
				return cpus
			}

			threads := int(max(allowdUs/unitUs, 1))

			cpu := cpus[0]
			cpu.CoreCount = threads
			cpu.ThreadCount = threads
			return []CPU{cpu}
		}
	}
	return cpus
}

func linuxCPUDetails(file io.Reader) []CPU {
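A worked example of the cpu.max arithmetic above, with hypothetical values (cgroup v2 writes "<allowed_us> <period_us>"; an unlimited "max <period_us>" value fails ParseInt and falls back to the host counts, as the code intends):

```go
// A container limited to 4 CPUs typically has /sys/fs/cgroup/cpu.max = "400000 100000".
allowdUs, unitUs := int64(400000), int64(100000)
threads := int(max(allowdUs/unitUs, 1)) // 400000/100000 = 4 usable threads
fmt.Println(threads)                    // 4; a "250000 100000" quota rounds down to 2
```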
@@ -65,6 +65,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
	}

	slog.Info("discovering available GPUs...")
	detectIncompatibleLibraries()

	// Warn if any user-overrides are set which could lead to incorrect GPU discovery
	overrideWarnings()
@@ -98,6 +99,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
			continue
		} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
			continue
		} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
			slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
			continue
		} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
			slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
			continue
@@ -125,10 +129,20 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
	supportedMu := sync.Mutex{}
	supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
	for i := range devices {
		libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
		if !devices[i].NeedsInitValidation() {
			// No need to validate, add to the supported map
			supportedMu.Lock()
			if _, ok := supported[devices[i].Library]; !ok {
				supported[devices[i].Library] = make(map[string]map[string]int)
			}
			if _, ok := supported[devices[i].Library][libDir]; !ok {
				supported[devices[i].Library][libDir] = make(map[string]int)
			}
			supported[devices[i].Library][libDir][devices[i].ID] = i
			supportedMu.Unlock()
			continue
		}
		libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
		slog.Debug("verifying if device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
		wg.Add(1)
		go func(i int) {
@@ -474,3 +488,16 @@ func overrideWarnings() {
		slog.Warn("if GPUs are not correctly discovered, unset and try again")
	}
}

func detectIncompatibleLibraries() {
	if runtime.GOOS != "windows" {
		return
	}
	basePath, err := exec.LookPath("ggml-base.dll")
	if err != nil || basePath == "" {
		return
	}
	if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
		slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
	}
}
docs/faq.mdx: 11 changes
@@ -57,8 +57,13 @@ ollama ps
```

<Info>
  **Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
  100% GPU 4 minutes from now ```

  **Output**:

  ```
  NAME         ID            SIZE    PROCESSOR    UNTIL
  llama3:70b   bcfb190ca3a7  42 GB   100% GPU     4 minutes from now
  ```
</Info>

The `Processor` column will show which memory the model was loaded in to:
@@ -385,4 +390,4 @@ Ollama for Windows and macOS register as a login item during installation. You
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`

**MacOS**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
@@ -1,34 +1,34 @@
---
title: VS Code
title: VS Code
---

## Install

Install [VS Code](https://code.visualstudio.com/download).
Install [VS Code](https://code.visualstudio.com/download).

## Usage with Ollama
## Usage with Ollama

1. Open Copilot side bar found in top right window
   <div style={{ display: 'flex', justifyContent: 'center' }}>
     <img
       src="/images/vscode-sidebar.png"
       alt="VS Code chat Sidebar"
       width="75%"
     />
   </div>
2. Select the model drowpdown > **Manage models**
   <div style={{ display: 'flex', justifyContent: 'center' }}>
     <img
       src="/images/vscode-models.png"
       alt="VS Code model picker"
       width="75%"
     />
   </div>
   <div style={{ display: "flex", justifyContent: "center" }}>
     <img
       src="/images/vscode-sidebar.png"
       alt="VS Code chat Sidebar"
       width="75%"
     />
   </div>
2. Select the model dropdown > **Manage models**
   <div style={{ display: "flex", justifyContent: "center" }}>
     <img
       src="/images/vscode-models.png"
       alt="VS Code model picker"
       width="75%"
     />
   </div>
3. Enter **Ollama** under **Provider Dropdown** and select desired models (e.g `qwen3, qwen3-coder:480b-cloud`)
   <div style={{ display: 'flex', justifyContent: 'center' }}>
     <img
       src="/images/vscode-model-options.png"
       alt="VS Code model options dropdown"
       width="75%"
     />
   </div>
   <div style={{ display: "flex", justifyContent: "center" }}>
     <img
       src="/images/vscode-model-options.png"
       alt="VS Code model options dropdown"
       width="75%"
     />
   </div>
@@ -149,9 +149,6 @@ PARAMETER <parameter> <parametervalue>

| Parameter      | Description                                                                                                                                                                         | Value Type | Example Usage      |
| -------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | ------------------ |
| mirostat       | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)                                                                     | int        | mirostat 0         |
| mirostat_eta   | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float      | mirostat_eta 0.1   |
| mirostat_tau   | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0)                                    | float      | mirostat_tau 5.0   |
| num_ctx        | Sets the size of the context window used to generate the next token. (Default: 2048)                                                                                               | int        | num_ctx 4096       |
| repeat_last_n  | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)                                                                      | int        | repeat_last_n 64   |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float      | repeat_penalty 1.1 |
@@ -249,6 +249,9 @@ func (kv KV) OllamaEngineRequired() bool {
		"qwen25vl",
		"qwen3", "qwen3moe",
		"qwen3vl", "qwen3vlmoe",
		"deepseekocr",
		"deepseek2",
		"nomic-bert",
	}, kv.Architecture())
}
@@ -305,7 +305,7 @@ func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error

			a.values[i] = e
		} else {
			discardGGUFString(llm, r)
			_ = discardGGUFString(llm, r)
		}
	}
@@ -568,7 +568,6 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
	g.SetLimit(runtime.GOMAXPROCS(0))
	// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
	for _, t := range ts {
		t := t
		w := io.NewOffsetWriter(f, offset+int64(t.Offset))
		g.Go(func() error {
			_, err := t.WriteTo(w)
go.mod: 1 change
@@ -17,7 +17,6 @@ require (
	github.com/x448/float16 v0.8.4
	golang.org/x/sync v0.12.0
	golang.org/x/sys v0.36.0

)

require (
@@ -388,9 +388,9 @@ func NewFunctionNameMap() *FunctionNameMap {
	}
}

// Init initializes the handler with tools and optional last message
// Init initializes the handler with tools, optional last message, and think value
// Implements the Parser interface
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	// Initialize the harmony parser
	if h.HarmonyParser == nil {
		h.HarmonyParser = &HarmonyParser{
@@ -3,7 +3,6 @@ package kvcache
import (
"errors"
"fmt"
"log/slog"
"math"
"slices"
@@ -40,18 +39,18 @@ type Causal struct {

// ** current forward pass **

// the active layer for Get and Put
curLayer int

// starting location for data storage for this batch
curLoc int

// size of the current batch
curBatchSize int

// locations for data storage for this batch
curLoc ml.Tensor

// mask of the cache as used by this batch
curMask ml.Tensor

// the active layer for Get and Put
curLayer int

// locations in the cache that are needed for this batch
curCellRange cellRange
@@ -206,45 +205,47 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curPositions = batch.Positions
c.opts.Except = nil

var locs []int32
if !reserve {
c.updateSlidingWindow()

var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
locs, err = c.findLocs()
if err != nil {
return err
}

for i, pos := range batch.Positions {
seq := batch.Sequences[i]
loc := int(locs[i])

c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}

seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}

seqRange.min = min(seqRange.min, c.curLoc+i)
c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
seqRange.min = min(seqRange.min, loc)
c.curCellRange.min = min(c.curCellRange.min, loc)

seqRange.max = max(seqRange.max, c.curLoc+i)
c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
seqRange.max = max(seqRange.max, loc)
c.curCellRange.max = max(c.curCellRange.max, loc)

c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
c.curLoc = 0
locs = make([]int32, c.curBatchSize)
for i := range locs {
locs[i] = int32(i)
}
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}

c.curLoc = ctx.Input().FromInts(locs, len(locs))
c.curMask = c.buildMask(ctx)

return nil
@@ -257,22 +258,20 @@ func newRange() cellRange {
}
}

// Find the first contiguous block of at least curBatchSize
func (c *Causal) findStartLoc() (int, error) {
var start, count int
// Returns a slice of locations where each token in the batch should be stored
func (c *Causal) findLocs() ([]int32, error) {
loc := make([]int32, 0, c.curBatchSize)

for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
count++
if count >= c.curBatchSize {
return start, nil
loc = append(loc, int32(i))
if len(loc) >= c.curBatchSize {
return loc, nil
}
} else {
start = i + 1
count = 0
}
}

return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
}

func (c *Causal) updateSlidingWindow() {
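The findLocs change above drops the old requirement that a whole batch land in one contiguous free block, so a fragmented cache no longer forces a defrag pass. A minimal standalone sketch of the scan, with a plain []bool standing in for the real cell metadata (illustrative, not the actual types):

// findFreeLocs returns the indices of up to n free cells, in scan order.
// used[i] reports whether cache cell i is currently occupied.
func findFreeLocs(used []bool, n int) ([]int32, bool) {
	locs := make([]int32, 0, n)
	for i, u := range used {
		if !u {
			locs = append(locs, int32(i))
			if len(locs) == n {
				return locs, true
			}
		}
	}
	return nil, false // cache is effectively full
}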
@@ -402,145 +401,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
return maskTensor
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys {
if key == nil {
continue
}

kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)

kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)

value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)

vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)

vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
}

ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}
func (c *Causal) defrag() {
slog.Debug("defragmenting kv cache")

// Defrag strategy:
// - Search for empty holes at the beginning of the cache,
// filling them with active data starting at the end
// - If there are contiguous elements that need to be moved,
// combine them into a single operation by holding new moves
// until we see that the next one is non-contiguous
// - Fill up the context with the maximum number of operations it
// can hold then compute that and continue with a new context
//
// We could try to optimize placement by grouping blocks from
// the same sequences together but most likely the next forward
// pass will disrupt this anyways, so the real world benefit
// seems limited at this time.

ctx := c.backend.NewContext()

// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
layers := 0
for _, key := range c.keys {
if key == nil {
continue
}
layers++
}

maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
moves := 0

var pendingSrc, pendingDst, pendingLen int
src := len(c.cells) - 1

for dst := 0; dst < src; dst++ {
if len(c.cells[dst].sequences) == 0 {
for ; src > dst; src-- {
if len(c.cells[src].sequences) != 0 {
c.cells[dst] = c.cells[src]
c.cells[src] = cacheCell{}

if pendingLen > 0 {
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
pendingSrc = src
pendingLen++
break
} else {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}

pendingSrc = src
pendingDst = dst
pendingLen = 1

break
}
}
}

if moves >= maxMoves {
ctx.Compute()
ctx.Close()
ctx = c.backend.NewContext()

moves = 0
}
}

if pendingLen > 0 {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}

if moves > 0 {
ctx.Compute()
}
ctx.Close()

// Reset range metadata
for seq := range c.cellRanges {
seqRange := newRange()

for i, cell := range c.cells {
if slices.Contains(cell.sequences, seq) {
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}

c.cellRanges[seq] = seqRange
}

c.updateSlidingWindow()
}

func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
@@ -625,18 +485,25 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
}

rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
keyCache := c.keys[c.curLayer]
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))

if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
value = value.Permute(ctx, 2, 0, 1, 3)

value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)

ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
} else {
rowSize := c.values[c.curLayer].Stride(2)
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))

ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
}
}
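The Put rewrite above swaps offset-based View/Copy writes for a row scatter: the cache tensor is reshaped to one row per cell and SetRows places each batch row at the index carried in c.curLoc, which is what lets a batch occupy non-contiguous cells. The same pattern as a standalone helper against the ml.Tensor interface (the helper itself is illustrative):

// scatterPut writes src (shape [rowDim, batch]) into the rows of cache
// (shape [rowDim, numCells]) selected by locs, one index per batch row.
func scatterPut(ctx ml.Context, cache, src, locs ml.Tensor, rowDim, numCells int) {
	cache = cache.Reshape(ctx, rowDim, numCells)
	ctx.Forward(cache.SetRows(ctx, src, locs))
}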
File diff suppressed because it is too large

@@ -20,10 +20,10 @@ fix vulkan PCI ID and ID handling
ggml/src/ggml-cuda/vendors/hip.h | 3 +
ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 2 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 209 +++++++++++--
ggml/src/mem_hip.cpp | 452 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 +++++++++++++
9 files changed, 926 insertions(+), 30 deletions(-)
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 209 +++++++++--
ggml/src/mem_hip.cpp | 529 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 +++++++++++
9 files changed, 1003 insertions(+), 30 deletions(-)
create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp
@@ -58,7 +58,7 @@ index f9a6587f1..03f359ae9 100644

target_include_directories(ggml-base PRIVATE .)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c9333689f..41b00af83 100644
index c9333689f..f1a20e7fe 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@@ -111,7 +111,7 @@ index c9333689f..41b00af83 100644
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
+ return;
+ }
@@ -243,7 +243,7 @@ index 05ff6a5a6..032dee76d 100644
/* .async = */ true,
/* .host_buffer = */ false,
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3a6bbe564..d2c278a35 100644
index 3a6bbe564..ca02ea079 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -229,6 +229,7 @@ class vk_memory_logger;
@@ -337,7 +337,7 @@ index 3a6bbe564..d2c278a35 100644
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
+ return;
+ }
@@ -548,11 +548,12 @@ index 3a6bbe564..d2c278a35 100644
}
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
new file mode 100644
index 000000000..5a7f5d465
index 000000000..c1949b899
--- /dev/null
+++ b/ggml/src/mem_hip.cpp
@@ -0,0 +1,452 @@
@@ -0,0 +1,529 @@
+#include "ggml.h"
+#include "ggml-impl.h"
+
+#ifdef _WIN32
+// AMD Device Library eXtra (ADLX)
@@ -570,7 +571,6 @@ index 000000000..5a7f5d465
+// Unused function parameters are commented out to avoid unnecessary type
+// definitions.
+
+#include "ggml-impl.h"
+#include <filesystem>
+#include <mutex>
+
@@ -990,15 +990,92 @@ index 000000000..5a7f5d465
+
+#else // #ifdef _WIN32
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <vector>
+#include <filesystem>
+
+#include <sys/stat.h>
+#include <dirent.h>
+#include <unistd.h>
+#include <glob.h>
+namespace fs = std::filesystem;
+
+extern "C" {
+
+// TODO Linux implementation of accurate VRAM reporting
+int ggml_hip_mgmt_init() {
+ return -1;
+ return 0;
+}
+void ggml_hip_mgmt_release() {}
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+ return -1;
+ GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
+ const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
+ const std::string drmTotalMemoryFile = "mem_info_vram_total";
+ const std::string drmUsedMemoryFile = "mem_info_vram_used";
+ const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
+
+ glob_t glob_result;
+ glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
+
+ for (size_t i = 0; i < glob_result.gl_pathc; ++i) {
+ const char* device_file = glob_result.gl_pathv[i];
+ std::ifstream file(device_file);
+ if (!file.is_open()) {
+ std::cerr << "Failed to open sysfs node" << std::endl;
+ globfree(&glob_result);
+ return 1;
+ }
+
+ std::string line;
+ while (std::getline(file, line)) {
+ // Check for PCI_SLOT_NAME label
+ if (line.find(drmUeventPCISlotLabel) == 0) {
+ std::istringstream iss(line.substr(drmUeventPCISlotLabel.size()));
+ std::string pciSlot;
+ iss >> pciSlot;
+ if (pciSlot == std::string(id)) {
+ std::string dir = fs::path(device_file).parent_path().string();
+
+ std::string totalFile = dir + "/" + drmTotalMemoryFile;
+ std::ifstream totalFileStream(totalFile.c_str());
+ if (!totalFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+
+ uint64_t memory;
+ totalFileStream >> memory;
+ *total = memory;
+
+ std::string usedFile = dir + "/" + drmUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str());
+ if (!usedFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+
+ uint64_t memoryUsed;
+ usedFileStream >> memoryUsed;
+ *free = memory - memoryUsed;
+
+ file.close();
+ globfree(&glob_result);
+ return 0;
+ }
+ }
+ }
+
+ file.close();
+ }
+ GGML_LOG_DEBUG("%s unable to find matching device\n", __func__);
+ globfree(&glob_result);
+ return 1;
+}
+
+} // extern "C"
@@ -38,7 +38,7 @@ index 44ae76d66..639d551a2 100644
#ifdef __cplusplus
}
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index d2c278a35..221e29509 100644
index ca02ea079..c12b069e5 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -73,6 +73,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
@@ -11,7 +11,7 @@ vidmem optimization.
1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 221e29509..18b7cbccf 100644
index c12b069e5..76c78c2ea 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5654,14 +5654,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
@@ -50,7 +50,7 @@ Subject: [PATCH] Vulkan MMQ Integer Dot Refactor and K-Quant support (#16536)
create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 18b7cbccf..53b57c179 100644
index 76c78c2ea..7669ed206 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -488,6 +488,7 @@ struct vk_device_struct {
@@ -58,7 +58,7 @@ index 639d551a2..e5c446d1d 100644
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 53b57c179..b2855b078 100644
index 7669ed206..63a762ec2 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -387,12 +387,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
@@ -31,7 +31,7 @@ Add new backend tests.
6 files changed, 371 insertions(+), 117 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index b2855b078..aaf4334b5 100644
index 63a762ec2..db92a7901 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -458,6 +458,11 @@ static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
@@ -9,7 +9,7 @@ Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)
2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index aaf4334b5..3604ceb04 100644
index db92a7901..e959674d1 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {
@@ -20,7 +20,7 @@ Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not
1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3604ceb04..80185d9f0 100644
index e959674d1..903050b0b 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
25
llama/patches/0036-ggml-cuda-skip-large-batches.patch
Normal file

@@ -0,0 +1,25 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Tue, 18 Nov 2025 11:13:04 -0800
Subject: [PATCH] ggml-cuda: skip large batches

cuda panics on batches larger than 1024 so mark it as unsupported to
fallback to cpu
---
ggml/src/ggml-cuda/ggml-cuda.cu | 3 +++
1 file changed, 3 insertions(+)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f1a20e7fe..1a71e07c9 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3677,6 +3677,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}
+ if (op->op == GGML_OP_MUL_MAT && b->ne[2] * b->ne[3] > 1024) {
+ return false;
+ }
#ifdef GGML_USE_MUSA
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
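The guard multiplies the two batch dimensions of src1: for example, a mul_mat whose b tensor has ne[2] = 32 and ne[3] = 64 carries a combined batch of 2048, so it now reports unsupported and runs on the CPU instead of panicking.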
28
llama/patches/0036-win-exit-instead-of-abort.patch
Normal file

@@ -0,0 +1,28 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Tue, 18 Nov 2025 09:58:23 -0800
Subject: [PATCH] win: exit instead of abort

---
ggml/src/ggml.c | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 9be35c1be..923c33d05 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -229,8 +229,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
fprintf(stderr, "%s\n", message);
ggml_print_backtrace();
}
-
+#if defined(_WIN32)
+ fflush(stderr);
+ fflush(stdout);
+ exit(1);
+#else
abort();
+#endif
}

// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp
@@ -173,6 +173,7 @@ type Tensor interface {
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
QuickGELU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor
@@ -193,6 +194,7 @@ type Tensor interface {
Repeat(ctx Context, dim, n int) Tensor
Concat(ctx Context, t2 Tensor, dim int) Tensor
Rows(ctx Context, t2 Tensor) Tensor
SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
Copy(ctx Context, t2 Tensor) Tensor
Duplicate(ctx Context) Tensor

@@ -207,6 +209,8 @@ type Tensor interface {
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor

Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}

// ScaledDotProductAttention implements a fused attention
@@ -230,7 +234,7 @@ type Tensor interface {
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, scale float64) Tensor
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
}

type number interface {
@@ -372,3 +376,10 @@ const (
DTypeI32
DTypeMXFP4
)

type SamplingMode int

const (
SamplingModeNearest SamplingMode = iota
SamplingModeBilinear
)
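A short sketch of how the new Interpolate and SamplingMode surface composes (the helper function is illustrative, not part of the diff):

// upsample2x doubles the spatial dims of a [w, h, c, n] tensor using the
// bilinear mode added above.
func upsample2x(ctx ml.Context, t ml.Tensor) ml.Tensor {
	return t.Interpolate(ctx, [4]int{t.Dim(0) * 2, t.Dim(1) * 2, t.Dim(2), t.Dim(3)}, ml.SamplingModeBilinear)
}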
@@ -314,7 +314,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
"altup_proj", "altup_unembd_proj",
"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
createTensor(tensor{source: t}, output.bts, blocks)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm.") || strings.HasPrefix(t.Name, "s."):
// TODO: assign vision tensors to the gpu if possible
createTensor(tensor{source: t}, output.bts, blocks)
case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
@@ -499,7 +499,6 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range b.meta.Tensors().Items() {
t := t
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name])))
for i := range tts {
@@ -1339,6 +1338,13 @@ func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}

func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_set_rows(ctx.(*Context).ctx, t.t, src.(*Tensor).t, idxs.(*Tensor).t),
}
}

func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1379,6 +1385,10 @@ func inferShape(t *Tensor, shape []int) {
}

func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
if !C.ggml_is_contiguous(t.t) {
return t.Contiguous(ctx, shape...)
}

if slices.Contains(shape, -1) {
inferShape(t, shape)
}
@@ -1568,6 +1578,16 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
}
}

func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
var tt *C.struct_ggml_tensor
if len(t2) > 0 {
tt = C.ggml_geglu_quick_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t)
} else {
tt = C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t)
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
@@ -1625,7 +1645,7 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
}
}

func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {
kqMask = mask.(*Tensor).t
@@ -1642,6 +1662,16 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
}
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)

if vmla != nil {
var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = vmla.Mulmat(ctx, cur)
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = cur.Contiguous(ctx)
kqv = cur.(*Tensor).t
}

return &Tensor{b: t.b, t: kqv}
} else {
kq := key.MulmatFullPrec(ctx, query)
@@ -1654,6 +1684,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
}

kqv := value.Mulmat(ctx, kq)
if vmla != nil {
kqv = vmla.Mulmat(ctx, kqv)
}

return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}
@@ -1711,6 +1745,23 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
}
}

func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
var mode C.uint32_t
switch samplingMode {
case ml.SamplingModeNearest:
mode = C.GGML_SCALE_MODE_NEAREST
case ml.SamplingModeBilinear:
mode = C.GGML_SCALE_MODE_BILINEAR
default:
panic("unsupported interpolate mode")
}

return &Tensor{
b: t.b,
t: C.ggml_interpolate(ctx.(*Context).ctx, t.t, C.int64_t(dims[0]), C.int64_t(dims[1]), C.int64_t(dims[2]), C.int64_t(dims[3]), mode),
}
}

// Slice returns a view of the tensor sliced along dim from low to high in step steps.
// Slice panics if the dimension is invalid or the slice parameters are out of range.
// If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.
@@ -3513,7 +3513,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
ggml_hip_mgmt_release();
return;
}
@@ -3677,6 +3677,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}
if (op->op == GGML_OP_MUL_MAT && b->ne[2] * b->ne[3] > 1024) {
return false;
}
#ifdef GGML_USE_MUSA
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
@@ -13212,7 +13212,7 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
ggml_hip_mgmt_release();
return;
}
7
ml/backend/ggml/ggml/src/ggml.c
vendored
@@ -229,8 +229,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
fprintf(stderr, "%s\n", message);
ggml_print_backtrace();
}

#if defined(_WIN32)
fflush(stderr);
fflush(stdout);
exit(1);
#else
abort();
#endif
}

// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp
85
ml/backend/ggml/ggml/src/mem_hip.cpp
vendored
@@ -1,4 +1,5 @@
#include "ggml.h"
#include "ggml-impl.h"

#ifdef _WIN32
// AMD Device Library eXtra (ADLX)
@@ -16,7 +17,6 @@
// Unused function parameters are commented out to avoid unnecessary type
// definitions.

#include "ggml-impl.h"
#include <filesystem>
#include <mutex>

@@ -436,15 +436,92 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {

#else // #ifdef _WIN32

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <filesystem>

#include <sys/stat.h>
#include <dirent.h>
#include <unistd.h>
#include <glob.h>
namespace fs = std::filesystem;

extern "C" {

// TODO Linux implementation of accurate VRAM reporting
int ggml_hip_mgmt_init() {
return -1;
return 0;
}
void ggml_hip_mgmt_release() {}
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
return -1;
GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
const std::string drmTotalMemoryFile = "mem_info_vram_total";
const std::string drmUsedMemoryFile = "mem_info_vram_used";
const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";

glob_t glob_result;
glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);

for (size_t i = 0; i < glob_result.gl_pathc; ++i) {
const char* device_file = glob_result.gl_pathv[i];
std::ifstream file(device_file);
if (!file.is_open()) {
std::cerr << "Failed to open sysfs node" << std::endl;
globfree(&glob_result);
return 1;
}

std::string line;
while (std::getline(file, line)) {
// Check for PCI_SLOT_NAME label
if (line.find(drmUeventPCISlotLabel) == 0) {
std::istringstream iss(line.substr(drmUeventPCISlotLabel.size()));
std::string pciSlot;
iss >> pciSlot;
if (pciSlot == std::string(id)) {
std::string dir = fs::path(device_file).parent_path().string();

std::string totalFile = dir + "/" + drmTotalMemoryFile;
std::ifstream totalFileStream(totalFile.c_str());
if (!totalFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}

uint64_t memory;
totalFileStream >> memory;
*total = memory;

std::string usedFile = dir + "/" + drmUsedMemoryFile;
std::ifstream usedFileStream(usedFile.c_str());
if (!usedFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}

uint64_t memoryUsed;
usedFileStream >> memoryUsed;
*free = memory - memoryUsed;

file.close();
globfree(&glob_result);
return 0;
}
}
}

file.close();
}
GGML_LOG_DEBUG("%s unable to find matching device\n", __func__);
globfree(&glob_result);
return 1;
}

} // extern "C"
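For orientation, the sysfs probe above translates directly to Go; a minimal sketch with error handling trimmed (paths exactly as in the C++, package and helper names illustrative):

package vraminfo

import (
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"
)

// vramInfo reads free/total VRAM for the AMD GPU whose PCI slot (e.g.
// "0000:03:00.0") appears in a DRM uevent, mirroring the sysfs reads above.
func vramInfo(pciSlot string) (uint64, uint64, error) {
	cards, _ := filepath.Glob("/sys/class/drm/card*/device/uevent")
	for _, uevent := range cards {
		data, err := os.ReadFile(uevent)
		if err != nil || !strings.Contains(string(data), "PCI_SLOT_NAME="+pciSlot) {
			continue
		}
		dir := filepath.Dir(uevent)
		total := readUint(filepath.Join(dir, "mem_info_vram_total"))
		used := readUint(filepath.Join(dir, "mem_info_vram_used"))
		return total - used, total, nil
	}
	return 0, 0, fmt.Errorf("no DRM device with PCI slot %s", pciSlot)
}

func readUint(path string) uint64 {
	b, _ := os.ReadFile(path)
	n, _ := strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64)
	return n
}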
@@ -22,10 +22,14 @@ import (
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithSinks(ctx, query, key, value, nil, scale, cache)
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}

func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}

func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
@@ -56,7 +60,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, scale)
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
@@ -71,6 +75,11 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
kq = kq.Softmax(ctx)

kqv := value.Mulmat(ctx, kq)

if vmla != nil {
kqv = vmla.Mulmat(ctx, kqv)
}

return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}
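Call sites see this as three layers over one implementation; a usage sketch (tensor variables illustrative, vb mirroring the deepseek2 change later in this diff):

// Existing models: unchanged call.
out := nn.Attention(ctx, q, k, v, scale, cache)

// MLA-style models: attention runs on the compressed latents and vb.Weight
// re-projects the values afterwards.
out = nn.AttentionWithVMLA(ctx, q, k, v, nil, vb.Weight, scale, cache)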
@@ -237,7 +237,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
}
}

if addSpecial && len(ids) > 0 {
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}
@@ -25,12 +25,15 @@ const (

// Composite returns an image with the alpha channel removed by drawing over a white background.
func Composite(img image.Image) image.Image {
dst := image.NewRGBA(img.Bounds())

white := color.RGBA{255, 255, 255, 255}
draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src)
draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
return CompositeColor(img, white)
}

// CompositeColor returns an image with the alpha channel removed by drawing over a white background.
func CompositeColor(img image.Image, color color.Color) image.Image {
dst := image.NewRGBA(img.Bounds())
draw.Draw(dst, dst.Bounds(), &image.Uniform{color}, image.Point{}, draw.Src)
draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
return dst
}

@@ -55,6 +58,31 @@ func Resize(img image.Image, newSize image.Point, method int) image.Image {
return dst
}

// Pad returns an image which has been resized to fit within a new size, preserving aspect ratio, and padded with a color.
func Pad(img image.Image, newSize image.Point, color color.Color, kernel draw.Interpolator) image.Image {
dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))
draw.Draw(dst, dst.Bounds(), &image.Uniform{color}, image.Point{}, draw.Src)

var minPoint, maxPoint image.Point
if img.Bounds().Dx() > img.Bounds().Dy() {
// landscape
height := newSize.X * img.Bounds().Dy() / img.Bounds().Dx()
minPoint = image.Point{0, (newSize.Y - height) / 2}
maxPoint = image.Point{newSize.X, height + minPoint.Y}
} else {
// portrait
width := newSize.Y * img.Bounds().Dx() / img.Bounds().Dy()
minPoint = image.Point{(newSize.X - width) / 2, 0}
maxPoint = image.Point{minPoint.X + width, newSize.Y}
}

kernel.Scale(dst, image.Rectangle{
Min: minPoint,
Max: maxPoint,
}, img, img.Bounds(), draw.Over, nil)
return dst
}
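Worked example for Pad: fitting a 1280×720 image into a 1024×1024 square scales it to 1024×576 (1024*720/1280), so it is drawn into rows 224 through 800 with 224-pixel bands of the pad color above and below.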
// Normalize returns a slice of float32 containing each of the r, g, b values for an image normalized around a value.
func Normalize(img image.Image, mean, std [3]float32, rescale bool, channelFirst bool) []float32 {
var pixelVals []float32
@@ -156,6 +156,7 @@ func New(c fs.Config) (model.Model, error) {
)),
},
},
true,
)
default:
return nil, model.ErrUnsupportedTokenizer
@@ -3,6 +3,7 @@ package deepseek2
// uses deepseek 2 architecture but written based on deepseek 3 model

import (
"cmp"
"math"

"github.com/ollama/ollama/fs"
@@ -16,6 +17,7 @@ import (
)

type Options struct {
isMLA bool
numExpertsUsed int
numExperts int
normTopKProb bool
@@ -32,8 +34,6 @@ type Options struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength,
originalContextLength int

eps,
@@ -62,6 +62,9 @@ type Attention struct {
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`

KB *nn.Linear `gguf:"attn_k_b"`
VB *nn.Linear `gguf:"attn_v_b"`

Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}

@@ -69,7 +72,7 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
seqLength := hiddenStates.Dim(1)

var query ml.Tensor
if opts.qLoraRank == 0 { // nil {
if opts.qLoraRank == 0 {
query = attn.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
@@ -88,21 +91,35 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
compressedKV.Stride(1), compressedKV.Dim(1),
)

kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)

kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)

qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)

kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
var attention ml.Tensor

query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
if !opts.isMLA { // v3
kPass = attn.KVB.Forward(ctx, kPass)

kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)

kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
} else { // v3.1
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
qPassAbsorb := attn.KB.Forward(ctx, qPass)
qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)

query = qRot.Concat(ctx, qPassAbsorb, 0)
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
key := kRot.Concat(ctx, kPass, 0)
value := kPass

attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
}

attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
}
@@ -233,6 +250,34 @@ func New(c fs.Config) (model.Model, error) {
mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))

isMLA := c.Uint("attention.key_length_mla") != 0 && c.Uint("attention.value_length_mla") != 0
keyLength := int(cmp.Or(c.Uint("attention.key_length_mla"), c.Uint("attention.key_length")))
valueLength := int(cmp.Or(c.Uint("attention.value_length_mla"), c.Uint("attention.value_length")))

var pre []string
switch c.String("tokenizer.ggml.pre") {
case "deepseek-v3":
pre = []string{
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
}
case "deepseek-llm":
// TODO: these models haven't been vetted so skip for now
// pre = []string{
// "[\r\n]",
// "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
// "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
// "\\s+$",
// "[一-龥ࠀ-一가-]+",
// "[0-9]",
// }
fallthrough
default:
return nil, model.ErrUnsupportedTokenizer
}

m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
@@ -247,18 +292,14 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
pre...,
),
Layers: layers,
Options: &Options{
isMLA: isMLA,
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
@@ -266,13 +307,13 @@ func New(c fs.Config) (model.Model, error) {
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("expert_weights_norm", true),

qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
qLoraRank: int(c.Uint("attention.q_lora_rank")),
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
qkHeadDim: int(c.Uint("attention.key_length")),
vHeadDim: int(c.Uint("attention.value_length")),
qkHeadDim: keyLength,
vHeadDim: valueLength,
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),

routedScalingFactor: c.Float("expert_weights_scale"),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
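Net effect of the v3.1 branch above: attention stays in the compressed kvLoraRank space. KB absorbs the key-side up-projection into the query, the cache holds only the rank-sized latent per token (kPass doubles as both key payload and value), and VB.Weight is applied after attention via AttentionWithVMLA instead of expanding the values up front.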
83
model/models/deepseekocr/imageprocessor.go
Normal file
@@ -0,0 +1,83 @@
package deepseekocr

import (
"bytes"
"image"
"image/color"
"math"
"slices"

"golang.org/x/image/draw"

"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)

type ratio struct {
x, y int
}

func ProcessImage(ctx ml.Context, bts []byte) (ml.Tensor, ml.Tensor, []int, error) {
img, _, err := image.Decode(bytes.NewReader(bts))
if err != nil {
return nil, nil, nil, err
}

minNum, maxNum, imageSize, baseSize := 2, 9, 640, 1024
var targetRatios []ratio
for n := minNum; n <= maxNum; n++ {
for i := 1; i <= n; i++ {
for j := 1; j <= n; j++ {
if i*j <= maxNum && i*j >= minNum && !slices.Contains(targetRatios, ratio{i, j}) {
targetRatios = append(targetRatios, ratio{i, j})
}
}
}
}

targetRatio := findBestAspectRatio(targetRatios, img.Bounds().Dx(), img.Bounds().Dy(), imageSize)
targetWidth, targetHeight := imageSize*targetRatio.x, imageSize*targetRatio.y
blocks := targetRatio.x * targetRatio.y

mean := imageproc.ImageNetStandardMean
std := imageproc.ImageNetStandardSTD

var patches []float32
resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
for i := range blocks {
patch := image.NewRGBA(image.Rect(0, 0, imageSize, imageSize))
draw.Draw(patch, patch.Bounds(), resized, image.Point{
X: i % (targetWidth / imageSize) * imageSize,
Y: i / (targetWidth / imageSize) * imageSize,
}, draw.Over)

patches = append(patches, imageproc.Normalize(patch, mean, std, true, true)...)
}

img = imageproc.CompositeColor(img, color.Gray{})
img = imageproc.Pad(img, image.Point{X: baseSize, Y: baseSize}, color.Gray{127}, draw.BiLinear)

return ctx.Input().FromFloats(patches, imageSize, imageSize, 3, blocks),
ctx.Input().FromFloats(imageproc.Normalize(img, mean, std, true, true), baseSize, baseSize, 3),
[]int{targetRatio.x, targetRatio.y},
nil
}

func findBestAspectRatio(targetRatios []ratio, width, height, imageSize int) ratio {
bestDiff := math.MaxFloat64
best := ratio{1, 1}
realRatio := float64(width) / float64(height)
for _, target := range targetRatios {
targetRatio := float64(target.x) / float64(target.y)
diff := math.Abs(realRatio - targetRatio)
if diff < bestDiff {
bestDiff = diff
best = target
} else if diff == bestDiff {
if float64(width*height) > 0.5*float64(imageSize*imageSize*best.x*best.y) {
best = target
}
}
}
return best
}
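Worked example: a 1280×960 input has a real ratio of about 1.33; among the allowed grids (x*y between 2 and 9) the closest is 3×2 (ratio 1.5), so the image is resized to 1920×1280 and cut into 6 tiles of 640×640, alongside the single 1024×1024 padded global view.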
192
model/models/deepseekocr/model.go
Normal file
@@ -0,0 +1,192 @@
package deepseekocr

import (
"math"
"slices"

"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)

type Model struct {
model.Base
model.TextProcessor

Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
Text *textModel

ImageNewline ml.Tensor `gguf:"mm.image_newline"`
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
ViewSeperator ml.Tensor `gguf:"mm.view_seperator"`

Projector *nn.Linear `gguf:"mm.layers"`
}

func (m *Model) EncodeMultimodal(ctx ml.Context, bts []byte) ([]input.Multimodal, error) {
patches, original, crop, err := ProcessImage(ctx, bts)
if err != nil {
return nil, err
}

var outputs []ml.Tensor
if true { // TODO: local features if sum(patches) != 0
samOutputs := m.Sam.Forward(ctx, patches)
visionOutputs := m.Vision.Forward(ctx, patches, samOutputs)

samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
localOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
localOutputs = m.Projector.Forward(ctx, localOutputs)

hw := int(math.Sqrt(float64(localOutputs.Dim(1))))
localOutputs = localOutputs.Reshape(ctx, -1, hw, crop[0], crop[1])
localOutputs = localOutputs.Permute(ctx, 0, 2, 1, 3)
localOutputs = localOutputs.Contiguous(ctx, -1, crop[0]*hw, crop[1]*hw)
localOutputs = localOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, localOutputs.Dim(2)), 1)
localOutputs = localOutputs.Reshape(ctx, localOutputs.Dim(0), -1)

outputs = append(outputs, localOutputs)
}

samOutputs := m.Sam.Forward(ctx, original)
visionOutputs := m.Vision.Forward(ctx, original, samOutputs)

samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
globalOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
globalOutputs = m.Projector.Forward(ctx, globalOutputs)

hw := int(math.Sqrt(float64(globalOutputs.Dim(1))))
globalOutputs = globalOutputs.Reshape(ctx, -1, hw, hw)
globalOutputs = globalOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, globalOutputs.Dim(2)), 1)
globalOutputs = globalOutputs.Reshape(ctx, globalOutputs.Dim(0), -1)

outputs = append(outputs, globalOutputs, m.ViewSeperator)
return []input.Multimodal{
{Tensor: outputs[0].Stack(ctx, 1, outputs[1:]...)},
}, nil
}

func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
outputs := make([]*input.Input, 0, len(inputs))
for i := range inputs {
if inputs[i].Multimodal == nil {
outputs = append(outputs, inputs[i])
continue
}

t := inputs[i].Multimodal[0].Tensor
outputs = append(outputs, &input.Input{
Token: 128815,
Multimodal: inputs[i].Multimodal,
MultimodalHash: inputs[i].MultimodalHash,
SameBatch: t.Dim(1) - 1,
})

outputs = slices.Grow(outputs, t.Dim(1)-1)
outputs = append(outputs, slices.Repeat([]*input.Input{{Token: 128815}}, t.Dim(1)-1)...)
}
return outputs, nil
}
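The placeholder scheme in PostTokenize: for an image whose stacked embedding has t.Dim(1) columns, it emits t.Dim(1) copies of token 128815; the first carries the tensor and sets SameBatch to t.Dim(1)-1 so the whole run is scheduled in one batch, and the View/Copy in Forward below overwrites those embedding slots in place.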
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
inputsEmbeds := m.Text.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
for _, mm := range batch.Multimodal {
|
||||
t := mm.Multimodal[0].Tensor
|
||||
ctx.Forward(t.Copy(ctx, inputsEmbeds.View(ctx, mm.Index*inputsEmbeds.Stride(1), t.Dim(0)*t.Dim(1))))
|
||||
}
|
||||
|
||||
hiddenStates := inputsEmbeds
|
||||
for i, block := range m.Text.Blocks {
|
||||
if m.Cache != nil {
|
||||
m.Cache.SetLayer(i)
|
||||
}
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Text.Blocks)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Text.Options)
|
||||
}
|
||||
|
||||
hiddenStates = m.Text.OutputNorm.Forward(ctx, hiddenStates, m.Text.Options.eps)
|
||||
return m.Text.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("deepseekocr", func(c fs.Config) (model.Model, error) {
|
||||
textBlocks := make([]textBlock, c.Uint("block_count"))
|
||||
leadingDenseBlockCount := int(c.Uint("leading_dense_block_count", 1))
|
||||
for i := range textBlocks {
|
||||
if i >= leadingDenseBlockCount {
|
||||
textBlocks[i].FeedForward = &textMoe{}
|
||||
} else {
|
||||
textBlocks[i].FeedForward = &textMLP{}
|
||||
}
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
// Split regex into multiple parts (according to DeepSeek3's regex)
|
||||
"\\p{N}{1,3}",
|
||||
`[一-龥-ゟ゠-ヿ]+`,
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||
),
|
||||
Text: &textModel{
|
||||
Blocks: textBlocks,
|
||||
Options: textOptions{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
ropeBase: c.Float("rope.freq_base", 10_000),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-6),
|
||||
},
|
||||
},
|
||||
Vision: &visionModel{
|
||||
Blocks: make([]visionBlock, c.Uint("vision.block_count")),
|
||||
Options: visionOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length")),
|
||||
numHeads: int(c.Uint("vision.head_count")),
|
||||
imageSize: int(c.Uint("vision.image_size", 224)),
|
||||
patchSize: int(c.Uint("vision.patch_size", 14)),
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
|
||||
},
|
||||
},
|
||||
Sam: &samModel{
|
||||
Blocks: make([]samBlock, c.Uint("sam.block_count")),
|
||||
Options: samOptions{
|
||||
hiddenSize: int(c.Uint("sam.embedding_length")),
|
||||
numHeads: int(c.Uint("sam.head_count")),
|
||||
eps: c.Float("sam.attention.layer_norm_epsilon", 1e-6),
|
||||
globalAttentionLayers: c.Ints("sam.global_attention_indexes"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.Text.Shift)
|
||||
return &m, nil
|
||||
})
|
||||
}
|
||||
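The expansion above reserves one token slot per image embedding column: a carrier input holding the multimodal payload with SameBatch = n-1, followed by n-1 bare placeholder tokens, so the later Forward copy lands on n contiguous embedding rows. A standalone sketch of that slot arithmetic (the helper and values are hypothetical, not part of the package):

// Hypothetical illustration: an image contributing n embedding columns
// reserves n token slots -- one carrier slot whose SameBatch field is n-1,
// plus n-1 bare placeholder tokens (128815).
package main

import "fmt"

const imagePlaceholder = 128815

func reserveSlots(n int) (carrierSameBatch int, tokens []int32) {
	tokens = []int32{imagePlaceholder} // carrier input
	for i := 1; i < n; i++ {
		tokens = append(tokens, imagePlaceholder)
	}
	return n - 1, tokens
}

func main() {
	sameBatch, tokens := reserveSlots(4)
	fmt.Println(sameBatch, len(tokens)) // 3 4: four contiguous slots for four columns
}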
225 model/models/deepseekocr/model_sam.go Normal file
@@ -0,0 +1,225 @@
package deepseekocr

import (
	"math"
	"slices"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type samModel struct {
	PatchEmbedding    *nn.Conv2D `gguf:"patch_embd"`
	PositionEmbedding ml.Tensor  `gguf:"position_embd"`

	Blocks []samBlock `gguf:"blk"`

	Neck *samNeck   `gguf:"neck"`
	Net2 *nn.Conv2D `gguf:"net_2"`
	Net3 *nn.Conv2D `gguf:"net_3"`

	Options samOptions
}

func (m *samModel) absolutePositionEmbedding(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
	source := m.PositionEmbedding.Dim(1)
	target := hiddenStates.Dim(2)
	if source != target {
		positionEmbed := m.PositionEmbedding.Permute(ctx, 2, 0, 1, 3)
		positionEmbed = positionEmbed.Interpolate(ctx, [4]int{target, target, hiddenStates.Dim(0), 1}, ml.SamplingModeBilinear)
		return positionEmbed.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
	}

	return m.PositionEmbedding
}

func (m *samModel) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
	hiddenStates := m.PatchEmbedding.Forward(ctx, t, 16, 16, 0, 0, 1, 1)
	hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	if m.PositionEmbedding != nil {
		hiddenStates = hiddenStates.Add(ctx, m.absolutePositionEmbedding(ctx, hiddenStates))
	}

	for i, block := range m.Blocks {
		var windowSize int
		if !slices.Contains(m.Options.globalAttentionLayers, int32(i)) {
			windowSize = 14
		}

		hiddenStates = block.Forward(ctx, hiddenStates, windowSize, m.Options)
	}

	hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
	hiddenStates = m.Neck.Forward(ctx, hiddenStates, m.Options)
	hiddenStates = m.Net2.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
	hiddenStates = m.Net3.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
	return hiddenStates
}

type samOptions struct {
	hiddenSize,
	numHeads int
	eps                   float32
	globalAttentionLayers []int32
}

func (o samOptions) headDim() int {
	return o.hiddenSize / o.numHeads
}

type samBlock struct {
	Norm1       *nn.LayerNorm `gguf:"norm1"`
	Attention   *samAttention `gguf:"attn"`
	Norm2       *nn.LayerNorm `gguf:"norm2"`
	FeedForward *samMLP       `gguf:"mlp"`
}

func (m *samBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, windowSize int, opts samOptions) ml.Tensor {
	c, w, h := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)

	residual := hiddenStates
	hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)

	var pw, ph int
	if windowSize > 0 {
		pw = (windowSize - hiddenStates.Dim(1)%windowSize) % windowSize
		ph = (windowSize - hiddenStates.Dim(2)%windowSize) % windowSize
		if pw > 0 || ph > 0 {
			hiddenStates = hiddenStates.Pad(ctx, 0, pw, ph, 0)
		}

		hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, (w+pw)/windowSize, windowSize, -1)
		hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, c, windowSize, windowSize, -1)
	}

	hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)

	if windowSize > 0 {
		hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, windowSize, (w+pw)/windowSize, -1)
		hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3)
		hiddenStates = hiddenStates.Contiguous(ctx, c, w+pw, h+ph, -1)
		hiddenStates = hiddenStates.Pad(ctx, 0, -pw, -ph, 0)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)

	residual = hiddenStates
	hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
	return hiddenStates.Add(ctx, residual)
}

type samAttention struct {
	QKV    *nn.Linear `gguf:"qkv"`
	Output *nn.Linear `gguf:"proj"`

	RelativePosition *struct {
		Height ml.Tensor `gguf:"h"`
		Width  ml.Tensor `gguf:"w"`
	} `gguf:",pre:rel_pos_"`
}

func relativeCoordinates(ctx ml.Context, qn, kn int) ml.Tensor {
	s := make([]int32, qn*kn)
	for i := range qn {
		for j := range kn {
			q := i * max(kn/qn, 1)
			k := j * max(qn/kn, 1)
			s[i*kn+j] = int32(q - k + (kn-1)*max(qn/kn, 1))
		}
	}
	return ctx.Input().FromInts(s, qn*kn)
}

func relativePositions(ctx ml.Context, positions ml.Tensor, qn, kn int) ml.Tensor {
	maxRelativeDistance := 2*max(qn, kn) - 1
	if positions.Dim(1) != maxRelativeDistance {
		// linear interpolation kernel not available so approx. with bilinear interpolation
		positions = positions.Interpolate(ctx, [4]int{positions.Dim(0), maxRelativeDistance, 1, 1}, ml.SamplingModeBilinear)
	}

	rc := relativeCoordinates(ctx, qn, kn)
	return positions.Rows(ctx, rc).Reshape(ctx, positions.Dim(0), kn, qn)
}

func (m *samAttention) decomposedRelativePositions(ctx ml.Context, query ml.Tensor, qn, kn []int) (ml.Tensor, ml.Tensor) {
	qh, qw := qn[0], qn[1]
	kh, kw := kn[0], kn[1]

	rh := relativePositions(ctx, m.RelativePosition.Height, qh, kh)
	rw := relativePositions(ctx, m.RelativePosition.Width, qw, kw)

	query = query.Contiguous(ctx, query.Dim(0), qw, qh, -1)
	rh = rh.Mulmat(ctx, query).Reshape(ctx, 1, kh, qh*qw, -1)
	rw = rw.Mulmat(ctx, query.Permute(ctx, 0, 2, 1, 3)).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, kw, 1, qh*qw, -1)
	return rh, rw
}

func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
	w, h, b := hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)

	qkv := m.QKV.Forward(ctx, hiddenStates)
	qkv = qkv.Reshape(ctx, opts.headDim(), -1, w*h, b)
	chunks := qkv.Chunk(ctx, 1, opts.numHeads)
	query, key, value := chunks[0], chunks[1], chunks[2]

	ctx.Forward(query, key, value)

	query = query.Permute(ctx, 0, 2, 1, 3)
	rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
	mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
	mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)

	key = key.Permute(ctx, 0, 2, 1, 3)
	scores := key.MulmatFullPrec(ctx, query)
	scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))

	scores = scores.Add(ctx, mask)
	scores = scores.Softmax(ctx)

	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
	attention := value.Mulmat(ctx, scores)
	attention = attention.Permute(ctx, 0, 2, 1, 3)
	attention = attention.Contiguous(ctx, -1, w, h, b)
	return m.Output.Forward(ctx, attention)
}

type samMLP struct {
	Lin1 *nn.Linear `gguf:"lin1"`
	Lin2 *nn.Linear `gguf:"lin2"`
}

func (m *samMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
	return m.Lin2.Forward(ctx, m.Lin1.Forward(ctx, hiddenStates).GELU(ctx))
}

type LayerNorm2D struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

func (ln *LayerNorm2D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
	x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
	u := x.Mean(ctx)
	d := x.Sub(ctx, u)
	s := d.Sqr(ctx).Mean(ctx)
	x = d.Div(ctx, s.Add(ctx, ctx.Input().FromFloats([]float32{eps}, 1)).Sqrt(ctx))
	x = x.Mul(ctx, ln.Weight).Add(ctx, ln.Bias)
	return x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
}

type samNeck struct {
	C1  *nn.Conv2D   `gguf:"0"`
	LN1 *LayerNorm2D `gguf:"1"`
	C2  *nn.Conv2D   `gguf:"2"`
	LN2 *LayerNorm2D `gguf:"3"`
}

func (m *samNeck) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
	hiddenStates = m.C1.Forward(ctx, hiddenStates, 1, 1, 0, 0, 1, 1)
	hiddenStates = m.LN1.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.C2.Forward(ctx, hiddenStates, 1, 1, 1, 1, 1, 1)
	hiddenStates = m.LN2.Forward(ctx, hiddenStates, opts.eps)
	return hiddenStates
}
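samBlock.Forward pads each spatial dimension up to a multiple of the window size before partitioning into windows, then crops the padding back off after attention. A small standalone sketch of that padding arithmetic (the numbers here are hypothetical):

// Hypothetical illustration of the window padding math: a 70x50 feature map
// with windowSize 14 needs pw = 0, ph = 6 padding, giving a 70x56 map that
// tiles exactly into 5x4 = 20 windows of 14x14.
package main

import "fmt"

func pad(dim, window int) int {
	return (window - dim%window) % window
}

func main() {
	w, h, window := 70, 50, 14
	pw, ph := pad(w, window), pad(h, window)
	fmt.Println(pw, ph)                                  // 0 6
	fmt.Println((w + pw) / window * ((h + ph) / window)) // 20 windows
}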
140 model/models/deepseekocr/model_text.go Normal file
@@ -0,0 +1,140 @@
package deepseekocr

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
)

type textModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Blocks         []textBlock   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output"`

	Options textOptions
}

func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
}

type textOptions struct {
	hiddenSize,
	numHeads,
	numKVHeads,
	numExperts,
	numExpertsUsed int
	ropeBase,
	ropeScale,
	eps float32
}

func (o textOptions) headDim() int {
	return o.hiddenSize / o.numHeads
}

func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
	return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
}

type textBlock struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	Attention     *textAttention
	MLPNNorm      *nn.RMSNorm `gguf:"ffn_norm"`
	FeedForward   textFeedForward
}

func (m *textBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
	residual := hiddenStates
	hiddenStates = m.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
	if outputs != nil {
		hiddenStates = hiddenStates.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)

	residual = hiddenStates
	hiddenStates = m.MLPNNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
	return hiddenStates.Add(ctx, residual)
}

type textAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
	query := m.Query.Forward(ctx, hiddenStates)
	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, -1)

	key := m.Key.Forward(ctx, hiddenStates)
	key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)

	value := m.Value.Forward(ctx, hiddenStates)
	value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)

	query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
	key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
	attention = attention.Reshape(ctx, -1, attention.Dim(2))
	return m.Output.Forward(ctx, attention)
}

type textFeedForward interface {
	Forward(ml.Context, ml.Tensor, textOptions) ml.Tensor
}

type textMoe struct {
	Router        *nn.Linear      `gguf:"ffn_gate_inp"`
	Gate          *nn.LinearBatch `gguf:"ffn_gate_exps"`
	Up            *nn.LinearBatch `gguf:"ffn_up_exps"`
	Down          *nn.LinearBatch `gguf:"ffn_down_exps"`
	SharedExperts *textMLP        `gguf:",suf:_shexp"`
}

func (m *textMoe) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts textOptions) ml.Tensor {
	scores := m.Router.Forward(ctx, hiddenStates).Softmax(ctx)
	indices := scores.TopK(ctx, opts.numExpertsUsed)
	weights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, indices)

	experts := hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
	experts = m.Gate.Forward(ctx, experts, indices).SILU(ctx, m.Up.Forward(ctx, experts, indices))
	experts = m.Down.Forward(ctx, experts, indices)
	experts = experts.Mul(ctx, weights)

	expert := func(i int) ml.Tensor {
		return experts.View(
			ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2),
		)
	}

	routedStates := expert(0)
	for i := 1; i < opts.numExpertsUsed; i++ {
		routedStates = routedStates.Add(ctx, expert(i))
	}

	sharedStates := m.SharedExperts.Forward(ctx, hiddenStates, opts)
	return routedStates.Add(ctx, sharedStates)
}

type textMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (m *textMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ textOptions) ml.Tensor {
	hiddenStates = m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
	return m.Down.Forward(ctx, hiddenStates)
}
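textMoe.Forward picks the top-k experts per token from a softmax over the router logits and mixes the expert outputs by those probabilities before adding the shared-expert path. A scalar sketch of the routing decision (hypothetical logits, no tensor library):

// Hypothetical scalar illustration of MoE routing: softmax the router logits,
// keep the k most probable experts, and those experts' outputs are later
// weighted by the same probabilities.
package main

import (
	"fmt"
	"math"
	"sort"
)

func softmax(logits []float64) []float64 {
	var sum float64
	out := make([]float64, len(logits))
	for i, l := range logits {
		out[i] = math.Exp(l)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

func topK(p []float64, k int) []int {
	idx := make([]int, len(p))
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return p[idx[a]] > p[idx[b]] })
	return idx[:k]
}

func main() {
	probs := softmax([]float64{2.0, 0.5, 1.0, -1.0})
	fmt.Println(topK(probs, 2)) // [0 2]: this token is routed to experts 0 and 2
}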
117 model/models/deepseekocr/model_vision.go Normal file
@@ -0,0 +1,117 @@
package deepseekocr

import (
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type visionModel struct {
	PatchEmbedding    *nn.Conv2D    `gguf:"patch_embd"`
	ClassEmbedding    ml.Tensor     `gguf:"class_embd"`
	PositionEmbedding *nn.Embedding `gguf:"position_embd"`

	PreLayerNorm *nn.LayerNorm `gguf:"pre_layrnorm"`
	Blocks       []visionBlock `gguf:"blk"`

	Options visionOptions
}

func (m *visionModel) absolutePositionEmbedding(ctx ml.Context, embeds ml.Tensor) ml.Tensor {
	numPatches := m.Options.imageSize / m.Options.patchSize * m.Options.imageSize / m.Options.patchSize
	positions := ctx.Arange(0, float32(numPatches+1), 1, ml.DTypeI32)
	positionEmbeds := m.PositionEmbedding.Forward(ctx, positions)

	source := int(math.Sqrt(float64(positionEmbeds.Dim(1) - 1)))
	target := int(math.Sqrt(float64(embeds.Dim(1) - 1)))
	if source != target {
		newPositionEmbeds := positionEmbeds.Slice(ctx, 1, 1, positionEmbeds.Dim(1), 1)
		newPositionEmbeds = newPositionEmbeds.Reshape(ctx, -1, source, source)
		newPositionEmbeds = newPositionEmbeds.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
		newPositionEmbeds = newPositionEmbeds.Interpolate(ctx, [4]int{target, target, embeds.Dim(0), 1}, ml.SamplingModeBilinear)
		newPositionEmbeds = newPositionEmbeds.Permute(ctx, 1, 2, 0, 3)
		newPositionEmbeds = newPositionEmbeds.Contiguous(ctx, -1, target*target)

		positionEmbeds = positionEmbeds.Slice(ctx, 1, 0, 1, 1).Concat(ctx, newPositionEmbeds, 1)
	}

	return positionEmbeds
}

func (m *visionModel) Forward(ctx ml.Context, pixelValues, patchEmbeds ml.Tensor) ml.Tensor {
	if patchEmbeds == nil {
		patchEmbeds = m.PatchEmbedding.Forward(ctx, pixelValues, m.Options.patchSize, m.Options.patchSize, 0, 0, 1, 1)
	}

	patchEmbeds = patchEmbeds.Reshape(ctx, -1, patchEmbeds.Dim(2), patchEmbeds.Dim(3))
	patchEmbeds = patchEmbeds.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	classEmbeds := m.ClassEmbedding.Repeat(ctx, 2, patchEmbeds.Dim(2))
	embeds := classEmbeds.Concat(ctx, patchEmbeds, 1)
	embeds = embeds.Add(ctx, m.absolutePositionEmbedding(ctx, embeds))

	hiddenStates := m.PreLayerNorm.Forward(ctx, embeds, m.Options.eps)
	for _, block := range m.Blocks {
		hiddenStates = block.Forward(ctx, hiddenStates, m.Options)
	}

	return hiddenStates
}

type visionOptions struct {
	hiddenSize,
	numHeads int
	eps float32

	imageSize, patchSize int
}

func (o visionOptions) headDim() int {
	return o.hiddenSize / o.numHeads
}

type visionBlock struct {
	Norm1       *nn.LayerNorm    `gguf:"layer_norm1"`
	Attention   *visionAttention `gguf:"self_attn"`
	Norm2       *nn.LayerNorm    `gguf:"layer_norm2"`
	FeedForward *visionMLP       `gguf:"mlp"`
}

func (m *visionBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts visionOptions) ml.Tensor {
	residual := hiddenStates
	hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	residual = hiddenStates
	hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates)
	hiddenStates = hiddenStates.Add(ctx, residual)
	return hiddenStates
}

type visionAttention struct {
	QKV    *nn.Linear `gguf:"qkv_proj"`
	Output *nn.Linear `gguf:"out_proj"`
}

func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOptions) ml.Tensor {
	qkv := m.QKV.Forward(ctx, t)
	qkv = qkv.Reshape(ctx, opts.headDim(), -1, qkv.Dim(1), qkv.Dim(2))
	chunks := qkv.Chunk(ctx, 1, opts.numHeads)
	query, key, value := chunks[0], chunks[1], chunks[2]

	attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
	attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
	return m.Output.Forward(ctx, attention)
}

type visionMLP struct {
	FC1 *nn.Linear `gguf:"fc1"`
	FC2 *nn.Linear `gguf:"fc2"`
}

func (m *visionMLP) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
	return m.FC2.Forward(ctx, m.FC1.Forward(ctx, t).QuickGELU(ctx))
}
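absolutePositionEmbedding treats the learned patch positions as a square grid (the class-token position is sliced off first) and bilinearly resizes that grid when the runtime patch count differs from the trained one. A standalone sketch of the grid-side arithmetic (the token counts here are made up):

// Hypothetical illustration: sqrt(tokens - 1) recovers the square grid side
// after dropping the class token, e.g. 257 tokens -> 16x16 learned positions,
// 577 tokens -> a 24x24 runtime grid to interpolate to.
package main

import (
	"fmt"
	"math"
)

func gridSide(tokens int) int {
	return int(math.Sqrt(float64(tokens - 1))) // exclude the class token
}

func main() {
	fmt.Println(gridSide(257), gridSide(577)) // 16 24
}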
@@ -3,6 +3,7 @@ package models
import (
	_ "github.com/ollama/ollama/model/models/bert"
	_ "github.com/ollama/ollama/model/models/deepseek2"
+	_ "github.com/ollama/ollama/model/models/deepseekocr"
	_ "github.com/ollama/ollama/model/models/gemma2"
	_ "github.com/ollama/ollama/model/models/gemma3"
	_ "github.com/ollama/ollama/model/models/gemma3n"
@@ -11,6 +12,7 @@ import (
	_ "github.com/ollama/ollama/model/models/llama4"
	_ "github.com/ollama/ollama/model/models/mistral3"
	_ "github.com/ollama/ollama/model/models/mllama"
+	_ "github.com/ollama/ollama/model/models/nomicbert"
	_ "github.com/ollama/ollama/model/models/qwen2"
	_ "github.com/ollama/ollama/model/models/qwen25vl"
	_ "github.com/ollama/ollama/model/models/qwen3"
170 model/models/nomicbert/model.go Normal file
@@ -0,0 +1,170 @@
package nomicbert

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/pooling"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Model struct {
	model.Base
	model.TextProcessor

	TokenEmbedding     *nn.Embedding `gguf:"token_embd"`
	TypeEmbedding      *nn.Embedding `gguf:"token_types"`
	TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`

	Layers []EncoderLayer `gguf:"blk"`

	Options
}

type Options struct {
	hiddenSize   int
	numHeads     int
	headDim      int
	eps          float32
	poolingType  pooling.Type
	normalize    bool
	ropeFreqBase float32
}

// Single encoder layer
type EncoderLayer struct {
	*Attention

	AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`

	*MLP

	MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
}

type Attention struct {
	QKV    *nn.Linear `gguf:"attn_qkv"`
	Output *nn.Linear `gguf:"attn_output"`
}

type MLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)

	typeEmbed := m.TypeEmbedding.Weight.Slice(ctx, 1, 0, 1, 1)
	hiddenStates = hiddenStates.Add(ctx, typeEmbed)

	hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)

	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

	for _, layer := range m.Layers {
		hiddenStates = layer.Forward(ctx, hiddenStates, positions, &m.Options)
	}

	hiddenStates = m.poolingType.Forward(ctx, hiddenStates)

	if m.normalize {
		hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
	}

	return hiddenStates, nil
}

func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml.Tensor, opts *Options) ml.Tensor {
	residual := hiddenStates
	hiddenStates = e.Attention.Forward(ctx, hiddenStates, positions, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)
	hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)

	residual = hiddenStates
	hiddenStates = e.MLP.Forward(ctx, hiddenStates)
	hiddenStates = hiddenStates.Add(ctx, residual)
	hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)

	return hiddenStates
}

func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml.Tensor, opts *Options) ml.Tensor {
	batchSize := hiddenStates.Dim(1)

	qkv := a.QKV.Forward(ctx, hiddenStates)

	qkv = qkv.Reshape(ctx, opts.headDim, opts.numHeads*3, batchSize)
	chunks := qkv.Chunk(ctx, 1, opts.numHeads)
	query, key, value := chunks[0], chunks[1], chunks[2]

	query = fast.RoPE(ctx, query, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
	key = fast.RoPE(ctx, key, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())

	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)

	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

	return a.Output.Forward(ctx, attention)
}

func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
	hidden := m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))

	return m.Down.Forward(ctx, hidden)
}

func New(c fs.Config) (model.Model, error) {
	hiddenSize := int(c.Uint("embedding_length"))
	numHeads := int(c.Uint("attention.head_count"))
	headDim := hiddenSize / numHeads

	processor := model.NewWordPiece(
		&model.Vocabulary{
			Values: c.Strings("tokenizer.ggml.tokens"),
			Scores: c.Floats("tokenizer.ggml.scores"),
			Types:  c.Ints("tokenizer.ggml.token_type"),
			AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
			BOS: []int32{
				int32(cmp.Or(
					c.Uint("tokenizer.ggml.cls_token_id"),
					c.Uint("tokenizer.ggml.bos_token_id"),
				)),
			},
			AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
			EOS: []int32{
				int32(cmp.Or(
					c.Uint("tokenizer.ggml.separator_token_id"),
					c.Uint("tokenizer.ggml.eos_token_id"),
				)),
			},
		},
		false,
	)

	return &Model{
		TextProcessor: processor,
		Layers:        make([]EncoderLayer, c.Uint("block_count")),
		Options: Options{
			hiddenSize:   hiddenSize,
			numHeads:     numHeads,
			headDim:      headDim,
			eps:          c.Float("attention.layer_norm_epsilon"),
			poolingType:  pooling.Type(c.Uint("pooling_type")),
			normalize:    c.Bool("normalize_embeddings", false),
			ropeFreqBase: c.Float("rope.freq_base", 1000.0),
		},
	}, nil
}

func init() {
	model.Register("nomic-bert", New)
	model.Register("nomic-bert_embed", New)
}
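New above uses cmp.Or to prefer the BERT-style CLS/SEP token ids and fall back to the generic BOS/EOS ids when they are absent (an unset config key reads as zero). A tiny sketch of that fallback (the ids are made up):

// cmp.Or returns the first non-zero value, so a present cls_token_id wins and
// the bos_token_id is used otherwise. Hypothetical values for illustration.
package main

import (
	"cmp"
	"fmt"
)

func main() {
	var clsID, bosID uint32 = 0, 101 // no cls_token_id in this made-up config
	fmt.Println(cmp.Or(clsID, bosID)) // 101
}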
319 model/parsers/cogito.go Normal file
@@ -0,0 +1,319 @@
package parsers

import (
	"encoding/json"
	"errors"
	"log/slog"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
)

type CogitoParserState int

const (
	CogitoCollectingThinking CogitoParserState = iota
	CogitoCollectingContent
	CogitoCollectingToolCalls
	CogitoCollectingToolOutput
)

const (
	cogitoThinkingCloseTag    = "</think>"
	cogitoToolCallsBeginTag   = "<|tool▁calls▁begin|>"
	cogitoToolCallsEndTag     = "<|tool▁calls▁end|>"
	cogitoToolCallBeginTag    = "<|tool▁call▁begin|>"
	cogitoToolCallEndTag      = "<|tool▁call▁end|>"
	cogitoToolSepTag          = "<|tool▁sep|>"
	cogitoToolOutputBeginTag  = "<|tool▁output▁begin|>"
	cogitoToolOutputEndTag    = "<|tool▁output▁end|>"
	cogitoToolOutputsBeginTag = "<|tool▁outputs▁begin|>"
	cogitoToolOutputsEndTag   = "<|tool▁outputs▁end|>"
)

type CogitoParser struct {
	state  CogitoParserState
	buffer strings.Builder
}

func (p *CogitoParser) HasToolSupport() bool {
	return true
}

func (p *CogitoParser) HasThinkingSupport() bool {
	return true
}

func (p *CogitoParser) setInitialState(lastMessage *api.Message, tools []api.Tool, thinkValue *api.ThinkValue) {
	prefill := lastMessage != nil && lastMessage.Role == "assistant"

	// Check both model capability AND request preference
	thinkingEnabled := thinkValue != nil && thinkValue.Bool()

	if !thinkingEnabled {
		p.state = CogitoCollectingContent
		return
	}

	if prefill && lastMessage.Content != "" {
		p.state = CogitoCollectingContent
		return
	}

	// Note: for cogito, if there are tools, then we don't want to be thinking
	if len(tools) > 0 {
		p.state = CogitoCollectingContent
		return
	}

	p.state = CogitoCollectingThinking
}

func (p *CogitoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.setInitialState(lastMessage, tools, thinkValue)
	return tools
}

type cogitoEvent interface {
	isCogitoEvent()
}

type cogitoEventThinkingContent struct {
	content string
}

type cogitoEventContent struct {
	content string
}

type cogitoEventToolCall struct {
	toolCall api.ToolCall
}

func (cogitoEventThinkingContent) isCogitoEvent() {}
func (cogitoEventContent) isCogitoEvent()         {}
func (cogitoEventToolCall) isCogitoEvent()        {}

func (p *CogitoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder
	for _, event := range events {
		switch event := event.(type) {
		case cogitoEventToolCall:
			toolCalls = append(toolCalls, event.toolCall)
		case cogitoEventThinkingContent:
			thinkingSb.WriteString(event.content)
		case cogitoEventContent:
			contentSb.WriteString(event.content)
		}
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *CogitoParser) parseEvents() []cogitoEvent {
	var all []cogitoEvent

	keepLooping := true
	for keepLooping {
		var events []cogitoEvent
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	return all
}

func (p *CogitoParser) eat() ([]cogitoEvent, bool) {
	var events []cogitoEvent
	bufStr := p.buffer.String()
	if bufStr == "" {
		return events, false
	}

	switch p.state {
	case CogitoCollectingThinking:
		if strings.Contains(bufStr, cogitoThinkingCloseTag) { // thinking[</think>] -> content
			split := strings.SplitN(bufStr, cogitoThinkingCloseTag, 2)
			thinking := split[0]
			thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)

			remaining := split[1]
			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = CogitoCollectingContent

			if len(thinking) > 0 {
				events = append(events, cogitoEventThinkingContent{content: thinking})
			}
			return events, true
		} else if overlapLen := overlap(bufStr, cogitoThinkingCloseTag); overlapLen > 0 { // partial </think>
			beforePartialTag := bufStr[:len(bufStr)-overlapLen]
			trailingLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingLen

			unambiguous := bufStr[:ambiguousStart]
			ambiguous := bufStr[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, cogitoEventThinkingContent{content: unambiguous})
			}
			return events, false
		} else { // otherwise it's thinking content
			whitespaceLen := trailingWhitespaceLen(bufStr)
			ambiguousStart := len(bufStr) - whitespaceLen

			unambiguous := bufStr[:ambiguousStart]
			ambiguous := bufStr[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, cogitoEventThinkingContent{content: unambiguous})
			}
			return events, false
		}

	case CogitoCollectingContent:
		switch {
		case strings.Contains(bufStr, cogitoToolCallsBeginTag): // content[<|tool▁calls▁begin|>] -> tool calls
			split := strings.SplitN(bufStr, cogitoToolCallsBeginTag, 2)
			contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
			remaining := split[1]

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = CogitoCollectingToolCalls

			if len(contentBefore) > 0 {
				events = append(events, cogitoEventContent{content: contentBefore})
			}
			return events, true
		case strings.Contains(bufStr, cogitoToolOutputsBeginTag): // content[<|tool▁outputs▁begin|>] -> tool outputs
			split := strings.SplitN(bufStr, cogitoToolOutputsBeginTag, 2)
			contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
			remaining := split[1]

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = CogitoCollectingToolOutput

			if len(contentBefore) > 0 {
				events = append(events, cogitoEventContent{content: contentBefore})
			}
			return events, true
		default: // otherwise it's content
			p.buffer.Reset()
			if len(bufStr) > 0 {
				events = append(events, cogitoEventContent{content: bufStr})
			}
			return events, false
		}
	case CogitoCollectingToolCalls:
		if idx := strings.Index(bufStr, cogitoToolCallBeginTag); idx != -1 {
			startIdx := idx + len(cogitoToolCallBeginTag)
			if endIdx := strings.Index(bufStr[startIdx:], cogitoToolCallEndTag); endIdx != -1 {
				toolCallContent := bufStr[startIdx : startIdx+endIdx]

				if toolCall, err := p.parseToolCallContent(toolCallContent); err == nil {
					remaining := bufStr[startIdx+endIdx+len(cogitoToolCallEndTag):]
					remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

					p.buffer.Reset()
					p.buffer.WriteString(remaining)

					events = append(events, cogitoEventToolCall{toolCall: toolCall})
					return events, true
				} else {
					slog.Warn("cogito tool call parsing failed", "error", err)
				}
			}
		}

		if idx := strings.Index(bufStr, cogitoToolCallsEndTag); idx != -1 {
			remaining := bufStr[idx+len(cogitoToolCallsEndTag):]
			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = CogitoCollectingContent

			return events, true
		}

		return events, false

	case CogitoCollectingToolOutput:
		if idx := strings.Index(bufStr, cogitoToolOutputBeginTag); idx != -1 {
			startIdx := idx + len(cogitoToolOutputBeginTag)
			if endIdx := strings.Index(bufStr[startIdx:], cogitoToolOutputEndTag); endIdx != -1 {
				remaining := bufStr[startIdx+endIdx+len(cogitoToolOutputEndTag):]
				remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

				p.buffer.Reset()
				p.buffer.WriteString(remaining)

				return events, true
			}
		}

		if idx := strings.Index(bufStr, cogitoToolOutputsEndTag); idx != -1 {
			remaining := bufStr[idx+len(cogitoToolOutputsEndTag):]
			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)

			p.buffer.Reset()
			p.buffer.WriteString(remaining)
			p.state = CogitoCollectingContent

			return events, true
		}

		return events, false
	}

	return events, false
}

func (p *CogitoParser) parseToolCallContent(content string) (api.ToolCall, error) {
	// Expected format: function<|tool▁sep|>tool_name\n```json\n{args}\n```
	parts := strings.SplitN(content, cogitoToolSepTag, 2)
	if len(parts) < 2 {
		return api.ToolCall{}, errors.New("invalid format")
	}
	nameAndArgs := parts[1]

	jsonStart := strings.Index(nameAndArgs, "\n```json\n")
	if jsonStart == -1 {
		return api.ToolCall{}, errors.New("invalid format")
	}
	toolName := strings.TrimSpace(nameAndArgs[:jsonStart])
	jsonContent := nameAndArgs[jsonStart+len("\n```json\n"):]

	jsonEnd := strings.Index(jsonContent, "\n```")
	if jsonEnd == -1 {
		return api.ToolCall{}, errors.New("invalid format")
	}
	argsJSON := jsonContent[:jsonEnd]

	var args api.ToolCallFunctionArguments
	if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
		return api.ToolCall{}, err
	}

	return api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      toolName,
			Arguments: args,
		},
	}, nil
}
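As a concrete reference for the format comment in parseToolCallContent, here is a hypothetical payload and the first split the parser performs on it (the tool name and arguments are made up):

// Standalone sketch: the payload names the function before the separator tag,
// then carries the tool name and a fenced JSON arguments block.
package main

import (
	"fmt"
	"strings"
)

func main() {
	payload := "function<|tool▁sep|>get_weather\n```json\n{\"location\":\"Paris\"}\n```"
	parts := strings.SplitN(payload, "<|tool▁sep|>", 2)
	fmt.Println(parts[1][:strings.Index(parts[1], "\n")]) // get_weather
}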
565 model/parsers/cogito_test.go Normal file
@@ -0,0 +1,565 @@
package parsers

import (
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
)

func TestCogitoParser(t *testing.T) {
	tests := []struct {
		name              string
		input             string
		expectedContent   string
		expectedThinking  string
		expectedToolCalls []api.ToolCall
		tools             []api.Tool
		lastMessage       *api.Message
	}{
		{
			name:             "simple_content",
			input:            "This is a simple response.",
			expectedContent:  "This is a simple response.",
			expectedThinking: "",
		},
		{
			name:             "thinking_only",
			input:            "This is thinking content.</think>This is response content.",
			expectedContent:  "This is response content.",
			expectedThinking: "This is thinking content.",
		},
		{
			name: "tool_call_simple",
			input: `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: api.ToolCallFunctionArguments{
							"location": "Paris",
						},
					},
				},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name: "get_weather",
						Parameters: api.ToolFunctionParameters{
							Properties: map[string]api.ToolProperty{
								"location": {Type: api.PropertyType{"string"}},
							},
						},
					},
				},
			},
		},
		{
			name: "thinking_with_tool_call",
			input: `I need to check the weather.</think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|>`,
			expectedContent:  "I need to check the weather.</think>",
			expectedThinking: "", // No thinking when tools are present (Cogito-specific behavior)
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: api.ToolCallFunctionArguments{
							"location": "Paris",
						},
					},
				},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name: "get_weather",
						Parameters: api.ToolFunctionParameters{
							Properties: map[string]api.ToolProperty{
								"location": {Type: api.PropertyType{"string"}},
							},
						},
					},
				},
			},
		},
		{
			name: "multiple_tool_calls",
			input: `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<|tool▁call▁end|>
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
` + "```json\n" + `{"location":"London"}
` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: api.ToolCallFunctionArguments{
							"location": "Paris",
						},
					},
				},
				{
					Function: api.ToolCallFunction{
						Name: "get_weather",
						Arguments: api.ToolCallFunctionArguments{
							"location": "London",
						},
					},
				},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name: "get_weather",
						Parameters: api.ToolFunctionParameters{
							Properties: map[string]api.ToolProperty{
								"location": {Type: api.PropertyType{"string"}},
							},
						},
					},
				},
			},
		},
		{
			name: "complex_tool_arguments",
			input: `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>process_data
` + "```json\n" + `{"items":["item1","item2"],"config":{"enabled":true,"threshold":0.95},"count":42}
` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|>`,
			expectedToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name: "process_data",
						Arguments: api.ToolCallFunctionArguments{
							"items":  []any{"item1", "item2"},
							"config": map[string]any{"enabled": true, "threshold": 0.95},
							"count":  42.0,
						},
					},
				},
			},
		},
		{
			name:             "tool_output_parsing",
			input:            `<|tool▁outputs▁begin|><|tool▁output▁begin|>{"temperature": 22, "condition": "sunny"}<|tool▁output▁end|><|tool▁outputs▁end|>`,
			expectedContent:  "",
			expectedThinking: "",
		},
		{
			name: "thinking_with_multiline_content",
			input: `This is line 1
This is line 2
This is line 3</think>Final response here.`,
			expectedContent:  "Final response here.",
			expectedThinking: "This is line 1\nThis is line 2\nThis is line 3",
		},
		{
			name:             "no_thinking_simple",
			input:            "This is content.",
			expectedContent:  "This is content.",
			expectedThinking: "",
		},
		{
			name:            "prefill_content_only",
			input:           "Continuing from previous content.",
			expectedContent: "Continuing from previous content.",
			lastMessage: &api.Message{
				Role:    "assistant",
				Content: "Previous content",
			},
		},
		{
			name:             "prefill_with_thinking",
			input:            "Continuing thinking</think>Continuing content.",
			expectedContent:  "Continuing content.",
			expectedThinking: "Continuing thinking",
			lastMessage: &api.Message{
				Role: "assistant",
			},
		},
		// Edge cases
		{
			name:             "nested_think_tags_in_thinking",
			input:            "I'm thinking <think>nested</think> more thinking</think>Final content.",
			expectedContent:  "more thinking</think>Final content.",
			expectedThinking: "I'm thinking <think>nested",
		},
		{
			name:             "multiple_think_close_tags",
			input:            "First thinking</think>Content</think>More content.",
			expectedContent:  "Content</think>More content.",
			expectedThinking: "First thinking",
		},
		{
			name:             "empty_thinking_content",
			input:            "</think>Just content here.",
			expectedContent:  "</think>Just content here.",
			expectedThinking: "",
		},
		{
			name:             "thinking_disabled_with_think_tags",
			input:            "Content with </think> tags should be treated as content.",
			expectedContent:  "Content with </think> tags should be treated as content.",
			expectedThinking: "",
			lastMessage: &api.Message{
				Role:    "assistant",
				Content: "existing", // Forces non-thinking mode
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Enable thinking only for cases that expect thinking output
			hasThinking := tt.expectedThinking != ""
			parser := &CogitoParser{} // the parser itself always reports thinking support
			parser.Init(tt.tools, tt.lastMessage, &api.ThinkValue{Value: hasThinking}) // the request's think value decides the initial state

			content, thinking, toolCalls, err := parser.Add(tt.input, true)
			if err != nil {
				t.Fatalf("Add() error = %v", err)
			}

			if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
				t.Errorf("content mismatch (-want +got):\n%s", diff)
			}

			if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
				t.Errorf("thinking mismatch (-want +got):\n%s", diff)
			}

			if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" {
				t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

func TestCogitoParser_Streaming(t *testing.T) {
	parser := &CogitoParser{}
	parser.Init(nil, nil, &api.ThinkValue{Value: true})

	chunks := []string{
		"This is ",
		"thinking content",
		".</think>This is ",
		"content.<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test_tool\n```json\n{\"arg\":\"value\"}\n```<|tool▁call▁end|><|tool▁calls▁end|>",
	}

	var finalContent, finalThinking strings.Builder
	var finalToolCalls []api.ToolCall

	for i, chunk := range chunks {
		done := i == len(chunks)-1
		content, thinking, toolCalls, err := parser.Add(chunk, done)
		if err != nil {
			t.Fatalf("Add() error on chunk %d: %v", i, err)
		}

		finalContent.WriteString(content)
		finalThinking.WriteString(thinking)
		finalToolCalls = append(finalToolCalls, toolCalls...)
	}

	expectedContent := "This is content."
	expectedThinking := "This is thinking content."
	expectedToolCalls := []api.ToolCall{
		{
			Function: api.ToolCallFunction{
				Name: "test_tool",
				Arguments: api.ToolCallFunctionArguments{
					"arg": "value",
				},
			},
		},
	}

	if finalContent.String() != expectedContent {
		t.Errorf("expected content %q, got %q", expectedContent, finalContent.String())
	}

	if finalThinking.String() != expectedThinking {
		t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
	}

	if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" {
		t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
	}
}

func TestCogitoParser_StreamingEdgeCases(t *testing.T) {
	tests := []struct {
		name               string
		chunks             []string
		expectedContent    string
		expectedThinking   string
		expectedToolCalls  []api.ToolCall
		hasThinkingSupport bool
	}{
		{
			name: "split_thinking_tag",
			chunks: []string{
				"This is thinking content</thi",
				"nk>This is content.",
			},
			expectedContent:    "This is content.",
			expectedThinking:   "This is thinking content",
			hasThinkingSupport: true,
		},
		{
			name: "split_tool_calls_begin_tag_conservative_parsing",
			chunks: []string{
				"Content before<|tool▁calls▁beg",
				"in|><|tool▁call▁begin|>function<|tool▁sep|>test\n```json\n{}\n```<|tool▁call▁end|><|tool▁calls▁end|>",
			},
			// Parser is conservative - treats incomplete tags as content
			expectedContent:    "Content before<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test\n```json\n{}\n```<|tool▁call▁end|><|tool▁calls▁end|>",
			expectedToolCalls:  nil,
			hasThinkingSupport: false,
		},
		{
			name: "thinking_disabled_with_split_tags",
			chunks: []string{
				"Content with </thi",
				"nk> should be treated as content.",
			},
			expectedContent:    "Content with </think> should be treated as content.",
			expectedThinking:   "",
			hasThinkingSupport: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parser := &CogitoParser{}
			parser.Init(nil, nil, &api.ThinkValue{Value: tt.hasThinkingSupport})

			var finalContent, finalThinking strings.Builder
			var finalToolCalls []api.ToolCall

			for i, chunk := range tt.chunks {
				done := i == len(tt.chunks)-1
				content, thinking, toolCalls, err := parser.Add(chunk, done)
				if err != nil {
					t.Fatalf("Add() error on chunk %d: %v", i, err)
				}

				finalContent.WriteString(content)
				finalThinking.WriteString(thinking)
				finalToolCalls = append(finalToolCalls, toolCalls...)
			}

			if finalContent.String() != tt.expectedContent {
				t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String())
			}

			if finalThinking.String() != tt.expectedThinking {
				t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
			}

			if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" {
				t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

func TestCogitoParser_HasToolSupport(t *testing.T) {
	parser := &CogitoParser{}
	if !parser.HasToolSupport() {
		t.Error("CogitoParser should support tools")
	}
}

func TestCogitoParser_Init(t *testing.T) {
	parser := &CogitoParser{}

	tools := []api.Tool{
		{Function: api.ToolFunction{Name: "test_tool"}},
	}

	lastMessage := &api.Message{Role: "assistant", Content: "previous"}

	returnedTools := parser.Init(tools, lastMessage, nil)

	if len(returnedTools) != len(tools) {
		t.Errorf("expected %d tools returned, got %d", len(tools), len(returnedTools))
	}
}

func TestCogitoParser_parseToolCallContent(t *testing.T) {
	tests := []struct {
		name        string
		content     string
		expected    api.ToolCall
		expectError bool
	}{
		{
			name: "valid_tool_call_standard_format",
			content: `function<|tool▁sep|>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```",
			expected: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "get_weather",
					Arguments: api.ToolCallFunctionArguments{
						"location": "Paris",
					},
				},
			},
			expectError: false,
		},
		{
			name: "valid_tool_call_complex_args",
			content: `function<|tool▁sep|>process_data
` + "```json\n" + `{"items":["item1","item2"],"config":{"enabled":true},"count":42}
` + "```",
			expected: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "process_data",
					Arguments: api.ToolCallFunctionArguments{
						"items":  []any{"item1", "item2"},
						"config": map[string]any{"enabled": true},
						"count":  42.0,
					},
				},
			},
			expectError: false,
		},
		{
			name: "valid_tool_call_empty_args",
			content: `function<|tool▁sep|>no_args_tool
` + "```json\n" + `{}
` + "```",
			expected: api.ToolCall{
				Function: api.ToolCallFunction{
					Name:      "no_args_tool",
					Arguments: api.ToolCallFunctionArguments{},
				},
			},
			expectError: false,
		},
		{
			name:        "missing_separator",
			content:     `functionget_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
			expected:    api.ToolCall{},
			expectError: true,
		},
		{
			name:        "invalid_function_type",
			content:     `not_function<|tool▁sep|>get_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
			expected:    api.ToolCall{},
			expectError: true,
		},
		{
			name:        "missing_json_block_start",
			content:     `function<|tool▁sep|>get_weather{"location":"Paris"}` + "```",
			expected:    api.ToolCall{},
			expectError: true,
		},
		{
			name:        "missing_json_block_end",
			content:     `function<|tool▁sep|>get_weather` + "```json\n" + `{"location":"Paris"}`,
			expected:    api.ToolCall{},
			expectError: true,
		},
		{
			name:        "invalid_json",
			content:     `function<|tool▁sep|>get_weather` + "```json\n" + `{location:Paris}` + "\n```",
			expected:    api.ToolCall{},
			expectError: true,
		},
		{
			name:        "empty_function_type",
			content:     `<|tool▁sep|>get_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
			expected:    api.ToolCall{},
			expectError: true,
		},
		{
			name: "tool_with_spaces_in_name",
			content: `function<|tool▁sep|> get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```",
			expected: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "get_weather",
					Arguments: api.ToolCallFunctionArguments{
						"location": "Paris",
					},
				},
			},
			expectError: false,
		},
		{
			name: "tool_with_multiline_json",
			content: `function<|tool▁sep|>get_weather
` + "```json\n" + `{
  "location": "Paris",
  "units": "metric"
}
` + "```",
			expected: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "get_weather",
					Arguments: api.ToolCallFunctionArguments{
						"location": "Paris",
						"units":    "metric",
					},
				},
			},
			expectError: false,
		},
		{
			name: "tool_with_nested_objects",
			content: `function<|tool▁sep|>complex_tool
` + "```json\n" + `{"nested":{"deep":{"value":123}}}
` + "```",
			expected: api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "complex_tool",
					Arguments: api.ToolCallFunctionArguments{
						"nested": map[string]any{
							"deep": map[string]any{
								"value": 123.0,
							},
						},
					},
				},
			},
			expectError: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parser := &CogitoParser{}

			result, err := parser.parseToolCallContent(tt.content)

			if tt.expectError {
				if err == nil {
					t.Errorf("expected error but got none")
				}
				return
			}

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if diff := cmp.Diff(tt.expected, result); diff != "" {
				t.Errorf("tool call mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
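The split-tag streaming cases above depend on holding back any buffer suffix that might be the start of a closing tag. A minimal sketch of that check, assuming the package's overlap helper computes roughly this (the longest proper prefix of the tag that is a suffix of the buffer; the exact implementation may differ):

package main

import (
	"fmt"
	"strings"
)

// overlap reports how many trailing bytes of s could be the beginning of tag.
func overlap(s, tag string) int {
	for n := min(len(s), len(tag)-1); n > 0; n-- {
		if strings.HasSuffix(s, tag[:n]) {
			return n
		}
	}
	return 0
}

func main() {
	fmt.Println(overlap("thinking content</thi", "</think>")) // 5: hold back "</thi"
}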
@@ -6,9 +6,9 @@ import (
|
||||
)
|
||||
|
||||
type Parser interface {
|
||||
// Init initializes the parser with tools and optional last message for chat prefill
|
||||
// Init initializes the parser with tools, optional last message for chat prefill, and think value
|
||||
// Returns processed tools if the parser needs to modify them (e.g., harmony renames them)
|
||||
Init(tools []api.Tool, lastMessage *api.Message) []api.Tool
|
||||
Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool
|
||||
// Add processes streamed content and returns parsed content, thinking, and tool calls
|
||||
// The done flag indicates if this is the last chunk (used for draining accumulators)
|
||||
Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error)
|
||||
@@ -52,6 +52,8 @@ func ParserForName(name string) Parser {
|
||||
return &PassthroughParser{}
|
||||
case "harmony":
|
||||
return harmony.NewHarmonyMessageHandler()
|
||||
case "cogito":
|
||||
return &CogitoParser{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -59,7 +61,7 @@ func ParserForName(name string) Parser {
|
||||
|
||||
type PassthroughParser struct{}
|
||||
|
||||
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
return tools // passthrough doesn't modify tools
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ type mockParser struct {
 	name string
 }
 
-func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
+func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
 	return tools
 }
 
@@ -43,7 +43,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
 	return false
 }
 
-func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
+func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
 	p.tools = tools
 	return tools // Qwen doesn't modify tools
 }
@@ -432,7 +432,7 @@ func transformToXML(raw string) string {
 		groups := qwenTagRegex.FindStringSubmatch(match)
 		tag := groups[1]
 		var escapedValue strings.Builder
-		xml.EscapeText(&escapedValue, []byte(groups[2]))
+		_ = xml.EscapeText(&escapedValue, []byte(groups[2])) // error is always nil for strings.Builder
 		return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
 	})
 
@@ -54,7 +54,7 @@ func (p *Qwen3VLParser) setInitialState(lastMessage *api.Message) {
 		p.state = CollectingThinkingContent
 	}
 
-func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
+func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
 	p.tools = tools
 	p.setInitialState(lastMessage)
 	return tools
@@ -198,7 +198,7 @@ func TestQwen3VLNonThinkingParserStreaming(t *testing.T) {
 
 		t.Run(tc.desc, func(t *testing.T) {
 			parser := Qwen3VLParser{hasThinkingSupport: false}
-			parser.Init([]api.Tool{}, nil)
+			parser.Init([]api.Tool{}, nil, nil)
 
 			for i, step := range tc.steps {
 				parser.buffer.WriteString(step.input)
@@ -515,7 +515,7 @@ func TestQwenOldParserStreaming(t *testing.T) {
 
 		t.Run(tc.desc, func(t *testing.T) {
 			parser := Qwen3VLParser{hasThinkingSupport: false}
-			parser.Init([]api.Tool{}, nil)
+			parser.Init([]api.Tool{}, nil, nil)
 
 			for i, step := range tc.steps {
 				parser.buffer.WriteString(step.input)
@@ -822,7 +822,7 @@ func TestQwen3VLNonThinkingToolCallWhitespaceHandling(t *testing.T) {
 
 		t.Run(tc.desc, func(t *testing.T) {
 			parser := Qwen3VLParser{hasThinkingSupport: false}
-			parser.Init([]api.Tool{}, nil)
+			parser.Init([]api.Tool{}, nil, nil)
 
 			for i, step := range tc.steps {
 				parser.buffer.WriteString(step.input)

@@ -205,7 +205,7 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
 
 		t.Run(tc.desc, func(t *testing.T) {
 			parser := Qwen3VLParser{hasThinkingSupport: true}
-			parser.Init([]api.Tool{}, nil)
+			parser.Init([]api.Tool{}, nil, nil)
 			// parser.state = CollectingThinkingContent
 
 			for i, step := range tc.steps {
@@ -386,7 +386,7 @@ func TestQwen3VLParserState(t *testing.T) {
 
 	for _, tc := range cases {
 		parser := Qwen3VLParser{hasThinkingSupport: tc.hasThinking}
-		parser.Init(nil, tc.last)
+		parser.Init(nil, tc.last, nil)
 		if parser.state != tc.wantState {
 			t.Errorf("%s: got state %v, want %v", tc.desc, parser.state, tc.wantState)
 		}
@@ -437,7 +437,7 @@ func TestQwen3VLThinkingParserWithThinkingPrefill(t *testing.T) {
 	for _, tc := range cases {
 		t.Run(tc.desc, func(t *testing.T) {
 			parser := Qwen3VLParser{hasThinkingSupport: true}
-			parser.Init([]api.Tool{}, last)
+			parser.Init([]api.Tool{}, last, nil)
 
 			for i, step := range tc.steps {
 				parser.buffer.WriteString(step.input)
@@ -500,7 +500,7 @@ func TestQwen3VLThinkingParserWithNonThinkingPrefill(t *testing.T) {
 	for _, tc := range cases {
 		t.Run(tc.desc, func(t *testing.T) {
 			parser := Qwen3VLParser{hasThinkingSupport: true}
-			parser.Init([]api.Tool{}, last)
+			parser.Init([]api.Tool{}, last, nil)
 
 			for i, step := range tc.steps {
 				parser.buffer.WriteString(step.input)
@@ -523,7 +523,7 @@ func TestQwen3VLThinkingParserStreamingAssistantPrefillContent(t *testing.T) {
 	// last message is assistant with content ⇒ start in CollectingContent
 	last := &api.Message{Role: "assistant", Content: "has content"}
 	parser := Qwen3VLParser{hasThinkingSupport: true}
-	parser.Init([]api.Tool{}, last)
+	parser.Init([]api.Tool{}, last, nil)
 
 	type step struct {
 		input string
@@ -750,7 +750,7 @@ func TestQwen3VLThinkingWhitespaceHandling(t *testing.T) {
 
 		t.Run(tc.desc, func(t *testing.T) {
 			parser := Qwen3VLParser{hasThinkingSupport: true}
-			parser.Init([]api.Tool{}, nil)
+			parser.Init([]api.Tool{}, nil, nil)
 
 			for i, step := range tc.steps {
 				parser.buffer.WriteString(step.input)
@@ -859,7 +859,7 @@ func TestQwen3VLToolCallWhitespaceHandling(t *testing.T) {
 
 		t.Run(tc.desc, func(t *testing.T) {
 			parser := Qwen3VLParser{hasThinkingSupport: true}
-			parser.Init([]api.Tool{}, tc.prefillMsg)
+			parser.Init([]api.Tool{}, tc.prefillMsg, nil)
 
 			for i, step := range tc.steps {
 				parser.buffer.WriteString(step.input)

129
model/renderers/cogito.go
Normal file
@@ -0,0 +1,129 @@
package renderers

import (
	"encoding/json"
	"strings"

	"github.com/ollama/ollama/api"
)

type CogitoRenderer struct {
	isThinking bool
}

func (r *CogitoRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	defaultPrompt := "You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco."

	// thinking is enabled: model must support it AND user must request it (true)
	enableThinking := r.isThinking && (thinkValue != nil && thinkValue.Bool())

	var systemPrompt string
	var conversationMessages []api.Message

	if len(messages) > 0 && messages[0].Role == "system" {
		systemPrompt = messages[0].Content
		conversationMessages = messages[1:]
	} else {
		conversationMessages = messages
	}

	var finalSystemPrompt string
	if enableThinking {
		finalSystemPrompt = "Enable deep thinking subroutine.\n\n" + defaultPrompt
		if systemPrompt != "" {
			finalSystemPrompt += "\n\n" + systemPrompt + "\n\n"
		}
	} else {
		finalSystemPrompt = defaultPrompt
		if systemPrompt != "" {
			finalSystemPrompt += "\n\n" + systemPrompt
		}
	}

	if len(tools) > 0 {
		if finalSystemPrompt != "" {
			finalSystemPrompt += "\nYou have the following functions available:\n"
		} else {
			finalSystemPrompt = "You have the following functions available:\n"
		}

		for _, tool := range tools {
			toolJSON, _ := json.MarshalIndent(tool, "", "  ") // TODO(gguo): double check json format
			finalSystemPrompt += "```json\n" + string(toolJSON) + "\n```\n"
		}
	}

	sb.WriteString("<|begin▁of▁sentence|>" + finalSystemPrompt)

	outputsOpen := false
	isLastUser := false

	for i, message := range conversationMessages {
		switch message.Role {
		case "user":
			isLastUser = true
			sb.WriteString("<|User|>" + message.Content + "<|Assistant|>")

		case "assistant":
			isLastUser = false

			if len(message.ToolCalls) > 0 {
				if message.Content != "" {
					sb.WriteString(message.Content)
				}

				sb.WriteString("<|tool▁calls▁begin|>")

				for j, toolCall := range message.ToolCalls {
					sb.WriteString("<|tool▁call▁begin|>function<|tool▁sep|>" + toolCall.Function.Name)

					argsJSON, _ := json.Marshal(toolCall.Function.Arguments)
					sb.WriteString("\n```json\n" + string(argsJSON) + "\n```")
					sb.WriteString("<|tool▁call▁end|>")

					if j < len(message.ToolCalls)-1 {
						sb.WriteString("\n")
					}
				}

				sb.WriteString("<|tool▁calls▁end|><|end▁of▁sentence|>")
			} else {
				sb.WriteString(message.Content + "<|end▁of▁sentence|>")
			}

		case "tool":
			isLastUser = false

			if !outputsOpen {
				sb.WriteString("<|tool▁outputs▁begin|>")
				outputsOpen = true
			}

			sb.WriteString("<|tool▁output▁begin|>" + message.Content + "<|tool▁output▁end|>")

			hasNextTool := i+1 < len(conversationMessages) && conversationMessages[i+1].Role == "tool"
			if hasNextTool {
				sb.WriteString("\n")
			} else {
				sb.WriteString("<|tool▁outputs▁end|>")
				outputsOpen = false
			}
		}
	}

	if outputsOpen {
		sb.WriteString("<|tool▁outputs▁end|>")
	}

	if !isLastUser {
		sb.WriteString("<|Assistant|>")
	}

	if enableThinking {
		sb.WriteString("<think>\n")
	}

	return sb.String(), nil
}
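A short usage sketch for the renderer above (in-package; the conversation is invented for illustration). With isThinking set and a true think value, the prompt opens with the deep-thinking preamble and ends with an open <think> tag for the model to continue:

	r := &CogitoRenderer{isThinking: true}
	prompt, err := r.Render([]api.Message{
		{Role: "system", Content: "Be brief."},
		{Role: "user", Content: "Hi"},
	}, nil, &api.ThinkValue{Value: true})
	if err != nil {
		panic(err)
	}
	fmt.Print(prompt) // ends with "<|Assistant|><think>\n"
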
491
model/renderers/cogito_test.go
Normal file
@@ -0,0 +1,491 @@
package renderers

import (
	"testing"

	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
)

func TestCogitoRenderer(t *testing.T) {
	tests := []struct {
		name       string
		messages   []api.Message
		tools      []api.Tool
		thinkValue *api.ThinkValue
		expected   string
	}{
		{
			name: "basic user message",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Hello, how are you?<|Assistant|>`,
		},
		{
			name: "basic with system message",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.

You are a helpful assistant.<|User|>Hello, how are you?<|Assistant|>`,
		},
		{
			name: "conversation with assistant response",
			messages: []api.Message{
				{Role: "user", Content: "What is the capital of France?"},
				{Role: "assistant", Content: "The capital of France is Paris."},
				{Role: "user", Content: "Fantastic!"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>What is the capital of France?<|Assistant|>The capital of France is Paris.<|end▁of▁sentence|><|User|>Fantastic!<|Assistant|>`,
		},
		{
			name: "thinking enabled without system",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `<|begin▁of▁sentence|>Enable deep thinking subroutine.

You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Hello, how are you?<|Assistant|><think>
`,
		},
		{
			name: "thinking enabled with system",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `<|begin▁of▁sentence|>Enable deep thinking subroutine.

You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.

You are a helpful assistant.

<|User|>Hello, how are you?<|Assistant|><think>
`,
		},
		{
			name: "thinking disabled",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Hello, how are you?<|Assistant|>`,
		},
		{
			name: "with tools",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather like?"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get current weather",
						Parameters: api.ToolFunctionParameters{
							Type: "object",
							Properties: map[string]api.ToolProperty{
								"location": {
									Type:        api.PropertyType{"string"},
									Description: "City name",
								},
							},
							Required: []string{"location"},
						},
					},
				},
			},
			expected: `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You have the following functions available:
` + "```json\n" + `{
  "type": "function",
  "function": {
    "name": "get_weather",
    "description": "Get current weather",
    "parameters": {
      "type": "object",
      "required": [
        "location"
      ],
      "properties": {
        "location": {
          "type": "string",
          "description": "City name"
        }
      }
    }
  }
}
` + "```\n" + `<|User|>What's the weather like?<|Assistant|>`,
		},
		{
			name: "assistant with tool calls",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather in Paris?"},
				{
					Role:    "assistant",
					Content: "I'll check the weather in Paris for you.",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: api.ToolCallFunctionArguments{
									"location": "Paris",
								},
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>What's the weather in Paris?<|Assistant|>I'll check the weather in Paris for you.<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|Assistant|>`,
		},
		{
			name: "tool response",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather in Paris?"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: api.ToolCallFunctionArguments{
									"location": "Paris",
								},
							},
						},
					},
				},
				{Role: "tool", Content: "Temperature: 22°C, Sunny"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>What's the weather in Paris?<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>Temperature: 22°C, Sunny<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>`,
		},
		{
			name: "multiple tool responses",
			messages: []api.Message{
				{Role: "user", Content: "Get weather for Paris and London"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: api.ToolCallFunctionArguments{
									"location": "Paris",
								},
							},
						},
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: api.ToolCallFunctionArguments{
									"location": "London",
								},
							},
						},
					},
				},
				{Role: "tool", Content: "Paris: 22°C, Sunny"},
				{Role: "tool", Content: "London: 18°C, Cloudy"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Get weather for Paris and London<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<|tool▁call▁end|>
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
` + "```json\n" + `{"location":"London"}
` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>Paris: 22°C, Sunny<|tool▁output▁end|>
<|tool▁output▁begin|>London: 18°C, Cloudy<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>`,
		},
		{
			name: "thinking with tools",
			messages: []api.Message{
				{Role: "user", Content: "What's the weather like?"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get current weather",
						Parameters: api.ToolFunctionParameters{
							Type: "object",
							Properties: map[string]api.ToolProperty{
								"location": {
									Type:        api.PropertyType{"string"},
									Description: "City name",
								},
							},
							Required: []string{"location"},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `<|begin▁of▁sentence|>Enable deep thinking subroutine.

You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You have the following functions available:
` + "```json\n" + `{
  "type": "function",
  "function": {
    "name": "get_weather",
    "description": "Get current weather",
    "parameters": {
      "type": "object",
      "required": [
        "location"
      ],
      "properties": {
        "location": {
          "type": "string",
          "description": "City name"
        }
      }
    }
  }
}
` + "```\n" + `<|User|>What's the weather like?<|Assistant|><think>
`,
		},
		// test cases based on cogito
		{
			name: "single_turn_thinking_false",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Hello<|Assistant|>`,
		},
		{
			name: "single_turn_thinking_true",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `<|begin▁of▁sentence|>Enable deep thinking subroutine.

You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Hello<|Assistant|><think>
`,
		},
		{
			name: "multi_turn_thinking_false",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
				{Role: "assistant", Content: "Hi there!"},
				{Role: "user", Content: "How are you?"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Hello<|Assistant|>Hi there!<|end▁of▁sentence|><|User|>How are you?<|Assistant|>`,
		},
		{
			name: "multi_turn_thinking_true",
			messages: []api.Message{
				{Role: "user", Content: "Hello"},
				{Role: "assistant", Content: "Hi there!"},
				{Role: "user", Content: "How are you?"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `<|begin▁of▁sentence|>Enable deep thinking subroutine.

You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Hello<|Assistant|>Hi there!<|end▁of▁sentence|><|User|>How are you?<|Assistant|><think>
`,
		},
		{
			name: "multi_with_system_thinking_false",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant"},
				{Role: "user", Content: "Start"},
				{Role: "assistant", Content: "Okay"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.

You are a helpful assistant<|User|>Start<|Assistant|>Okay<|end▁of▁sentence|><|Assistant|>`,
		},
		{
			name: "multi_with_system_thinking_true",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant"},
				{Role: "user", Content: "Start"},
				{Role: "assistant", Content: "Okay"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `<|begin▁of▁sentence|>Enable deep thinking subroutine.

You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.

You are a helpful assistant

<|User|>Start<|Assistant|>Okay<|end▁of▁sentence|><|Assistant|><think>
`,
		},
		{
			name: "multi_with_system2_thinking_false",
			messages: []api.Message{
				{Role: "system", Content: "You are a pirate chatbot who always responds in pirate speak!"},
				{Role: "user", Content: "Give me a short introduction to LLMs."},
				{Role: "assistant", Content: "Arrr! I'm a pirate"},
				{Role: "user", Content: "Tell me more about LLMs."},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.

You are a pirate chatbot who always responds in pirate speak!<|User|>Give me a short introduction to LLMs.<|Assistant|>Arrr! I'm a pirate<|end▁of▁sentence|><|User|>Tell me more about LLMs.<|Assistant|>`,
		},
		{
			name: "multi_with_system2_thinking_true",
			messages: []api.Message{
				{Role: "system", Content: "You are a pirate chatbot who always responds in pirate speak!"},
				{Role: "user", Content: "Give me a short introduction to LLMs."},
				{Role: "assistant", Content: "Arrr! I'm a pirate"},
				{Role: "user", Content: "Tell me more about LLMs."},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `<|begin▁of▁sentence|>Enable deep thinking subroutine.

You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.

You are a pirate chatbot who always responds in pirate speak!

<|User|>Give me a short introduction to LLMs.<|Assistant|>Arrr! I'm a pirate<|end▁of▁sentence|><|User|>Tell me more about LLMs.<|Assistant|><think>
`,
		},
		// tools
		{
			name: "tool_calls_only_no_content",
			messages: []api.Message{
				{Role: "user", Content: "Get weather for Paris"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: api.ToolCallFunctionArguments{
									"location": "Paris",
								},
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Get weather for Paris<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|Assistant|>`,
		},
		{
			name: "complex_tool_arguments",
			messages: []api.Message{
				{Role: "user", Content: "Process complex data"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "process_data",
								Arguments: api.ToolCallFunctionArguments{
									"items": []any{"item1", "item2", "item3"},
									"config": map[string]any{
										"enabled":   true,
										"threshold": 0.95,
										"tags":      []string{"important", "urgent"},
									},
								},
							},
						},
					},
				},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Process complex data<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>process_data
` + "```json\n" + `{"config":{"enabled":true,"tags":["important","urgent"],"threshold":0.95},"items":["item1","item2","item3"]}
` + "```" + `<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|Assistant|>`,
		},
		{
			name: "empty_messages",
			messages: []api.Message{
				{Role: "system", Content: ""},
				{Role: "user", Content: "Hello"},
				{Role: "assistant", Content: ""},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Hello<|Assistant|><|end▁of▁sentence|><|Assistant|>`,
		},
		{
			name: "thinking_with_empty_assistant_content",
			messages: []api.Message{
				{Role: "user", Content: "Think about this"},
				{Role: "assistant", Content: ""},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `<|begin▁of▁sentence|>Enable deep thinking subroutine.

You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Think about this<|Assistant|><|end▁of▁sentence|><|Assistant|><think>
`,
		},
		{
			name: "multiple_system_messages",
			messages: []api.Message{
				{Role: "system", Content: "First instruction"},
				{Role: "system", Content: "Second instruction"},
				{Role: "user", Content: "Hello"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.

First instruction<|User|>Hello<|Assistant|>`,
		},
		{
			name: "special_characters_in_content",
			messages: []api.Message{
				{Role: "user", Content: "What about <|special|> tokens and \"quotes\"?"},
				{Role: "assistant", Content: "They're handled normally in content."},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>What about <|special|> tokens and "quotes"?<|Assistant|>They're handled normally in content.<|end▁of▁sentence|><|Assistant|>`,
		},
		{
			name: "long_conversation_multiple_rounds",
			messages: []api.Message{
				{Role: "user", Content: "Hi"},
				{Role: "assistant", Content: "Hello!"},
				{Role: "user", Content: "How are you?"},
				{Role: "assistant", Content: "Good, thanks!"},
				{Role: "user", Content: "What's the weather?"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected:   `<|begin▁of▁sentence|>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<|User|>Hi<|Assistant|>Hello!<|end▁of▁sentence|><|User|>How are you?<|Assistant|>Good, thanks!<|end▁of▁sentence|><|User|>What's the weather?<|Assistant|>`,
		},
	}

	renderer := &CogitoRenderer{isThinking: true}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
			if err != nil {
				t.Fatalf("Render() error = %v", err)
			}
			if diff := cmp.Diff(tt.expected, rendered); diff != "" {
				t.Errorf("Render() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
@@ -56,6 +56,9 @@ func rendererForName(name string) Renderer {
 	case "qwen3-vl-thinking":
 		renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
 		return renderer
+	case "cogito":
+		renderer := &CogitoRenderer{isThinking: true}
+		return renderer
 	default:
 		return nil
 	}

@@ -181,7 +181,7 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
 		}
 	}
 
-	if addSpecial && len(ids) > 0 {
+	if addSpecial {
 		ids = spm.vocab.addSpecials(ids)
 	}
 
@@ -45,7 +45,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
 
 func (v *Vocabulary) addSpecials(ids []int32) []int32 {
 	if v.AddBOS && len(v.BOS) > 0 {
-		if slices.Contains(v.BOS, ids[0]) {
+		if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) {
 			slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
 		}
 
@@ -54,7 +54,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
 	}
 
 	if v.AddEOS && len(v.EOS) > 0 {
-		if slices.Contains(v.BOS, ids[len(ids)-1]) {
+		if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) {
 			slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
 		}
 
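These guards pair with the callers' change from "if addSpecial && len(ids) > 0" to plain "if addSpecial": addSpecials can now be reached with an empty slice, and without the guards the ids[0] and ids[len(ids)-1] lookups would panic. An in-package sketch of the fixed edge case (addSpecials is unexported, so this only compiles inside the model package):

	v := &Vocabulary{BOS: []int32{0}, EOS: []int32{1}, AddBOS: true}
	ids := v.addSpecials(nil) // no longer indexes into an empty slice; yields [0]
	_ = ids
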
@@ -1,8 +1,12 @@
 package model
 
-import "testing"
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
 
-func TestVocabulary_SpecialVocabulary(t *testing.T) {
+func TestSpecialVocabulary(t *testing.T) {
 	vocab := &Vocabulary{
 		Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
 		Types:  []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
@@ -14,3 +18,90 @@ func TestVocabulary_SpecialVocabulary(t *testing.T) {
 		t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
 	}
 }
+
+func TestAddSpecialVocabulary(t *testing.T) {
+	cases := []struct {
+		name  string
+		vocab *Vocabulary
+		input []int32
+		want  []int32
+	}{
+		{
+			name: "add bos",
+			vocab: &Vocabulary{
+				BOS:    []int32{0},
+				EOS:    []int32{1},
+				AddBOS: true,
+				AddEOS: false,
+			},
+			input: []int32{2, 3, 4},
+			want:  []int32{0, 2, 3, 4},
+		},
+		{
+			// TODO(mxyng): this is to match previous behaviour
+			name: "add bos when already present",
+			vocab: &Vocabulary{
+				BOS:    []int32{0},
+				EOS:    []int32{1},
+				AddBOS: true,
+				AddEOS: false,
+			},
+			input: []int32{0, 2, 3, 4},
+			want:  []int32{0, 0, 2, 3, 4},
+		},
+		{
+			name: "add eos",
+			vocab: &Vocabulary{
+				BOS:    []int32{0},
+				EOS:    []int32{1},
+				AddBOS: false,
+				AddEOS: true,
+			},
+			input: []int32{2, 3, 4},
+			want:  []int32{2, 3, 4, 1},
+		},
+		{
+			// TODO(mxyng): this is to match previous behaviour
+			name: "add eos when already present",
+			vocab: &Vocabulary{
+				BOS:    []int32{0},
+				EOS:    []int32{1},
+				AddBOS: false,
+				AddEOS: true,
+			},
+			input: []int32{2, 3, 4, 1},
+			want:  []int32{2, 3, 4, 1, 1},
+		},
+		{
+			name: "add both",
+			vocab: &Vocabulary{
+				BOS:    []int32{0},
+				EOS:    []int32{1},
+				AddBOS: true,
+				AddEOS: true,
+			},
+			input: []int32{2, 3, 4},
+			want:  []int32{0, 2, 3, 4, 1},
+		},
+		{
+			name: "add bos to empty inputs",
+			vocab: &Vocabulary{
+				BOS:    []int32{0},
+				EOS:    []int32{1},
+				AddBOS: true,
+				AddEOS: false,
+			},
+			input: []int32{},
+			want:  []int32{0},
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			got := tt.vocab.addSpecials(tt.input)
+			if diff := cmp.Diff(tt.want, got); diff != "" {
+				t.Errorf("no match (-want +got):\n%s", diff)
+			}
+		})
+	}
+}

@@ -10,7 +10,8 @@ import (
 )
 
 type WordPiece struct {
-	vocab *Vocabulary
+	vocab     *Vocabulary
+	lowercase bool
 }
 
 // ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
@@ -114,8 +115,10 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
 				subword = ggmlPrefix + subword
 			}
 
-			// TODO: some models might not want [ToLower]
-			piece = wpm.vocab.Encode(strings.ToLower(subword))
+			if wpm.lowercase {
+				subword = strings.ToLower(subword)
+			}
+			piece = wpm.vocab.Encode(subword)
 			if piece >= 0 {
 				break
 			}
@@ -140,7 +143,7 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
 		}
 	}
 
-	if addSpecial && len(ids) > 0 {
+	if addSpecial {
 		ids = wpm.vocab.addSpecials(ids)
 	}
 
@@ -160,8 +163,9 @@ func (wpm WordPiece) Vocabulary() *Vocabulary {
 
 var _ TextProcessor = (*WordPiece)(nil)
 
-func NewWordPiece(vocab *Vocabulary) WordPiece {
+func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
 	return WordPiece{
-		vocab: vocab,
+		vocab:     vocab,
+		lowercase: lowercase,
 	}
 }

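With the new flag, lowercasing becomes opt-in per model rather than unconditional (the removed TODO noted that some models do not want ToLower). A sketch of constructing the processor both ways (in-package; the vocab value is assumed to exist):

	cased := NewWordPiece(vocab, false)  // preserve case, e.g. for cased BERT-style vocabularies
	uncased := NewWordPiece(vocab, true) // fold input to lowercase before vocabulary lookup
	_, _ = cased, uncased
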
@@ -15,7 +15,9 @@ func TestWordPiece(t *testing.T) {
 		AddEOS: true,
 		BOS:    []int32{1},
 		EOS:    []int32{2},
-	})
+	},
+		true, // lowercase
+	)
 
 	ids, err := wpm.Encode("Hello world!", true)
 	if err != nil {

@@ -175,7 +175,6 @@ func quantize(in, out *os.File, orig *fsggml.GGML, newFileType fsggml.FileType,
 	origTensors := orig.Tensors().Items()
 	outputTensors := make([]*fsggml.Tensor, len(origTensors))
 	for i, tensor := range origTensors {
-		tensor := tensor
 		newType := newType(tensor, kv, qs, newFileType)
 		newTensor := &fsggml.Tensor{
 			Name: tensor.Name,

@@ -340,7 +340,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		builtinParser = parsers.ParserForName(m.Config.Parser)
 		if builtinParser != nil {
 			// no tools or last message for generate endpoint
-			builtinParser.Init(nil, nil)
+			builtinParser.Init(nil, nil, req.Think)
 		}
 	}
 
@@ -2051,7 +2051,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			lastMessage = &msgs[len(msgs)-1]
 		}
 		// Initialize parser and get processed tools
-		processedTools = builtinParser.Init(req.Tools, lastMessage)
+		processedTools = builtinParser.Init(req.Tools, lastMessage, req.Think)
 	}
 }

@@ -219,7 +219,6 @@ func TestParse(t *testing.T) {
 	}
 
 	for _, tt := range validCases {
-		tt := tt
 		t.Run(tt.name, func(t *testing.T) {
 			t.Parallel()
 