mirror of
https://github.com/ollama/ollama.git
synced 2025-12-30 19:19:41 -05:00
Compare commits
90 Commits
jmorganca/
...
parth/agen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96d69ee2b2 | ||
|
|
89f74a8b05 | ||
|
|
ca43de117f | ||
|
|
7ff2b373f4 | ||
|
|
805177c054 | ||
|
|
6f9fc4e1bf | ||
|
|
fc62078ba4 | ||
|
|
d08c33faa0 | ||
|
|
253b035b4a | ||
|
|
d4f9bd5fe5 | ||
|
|
18fdcc94e5 | ||
|
|
7ad036992f | ||
|
|
172b5924af | ||
|
|
8852220f59 | ||
|
|
7325791599 | ||
|
|
522c11a763 | ||
|
|
0fadeffaee | ||
|
|
49a9c9ba6a | ||
|
|
1c094038bc | ||
|
|
a013693f80 | ||
|
|
f6a016f49d | ||
|
|
45c4739374 | ||
|
|
2dd029de12 | ||
|
|
903b1fc97f | ||
|
|
89eb795293 | ||
|
|
7e3ea813c1 | ||
|
|
7b95087b9d | ||
|
|
971d62595a | ||
|
|
ffbe8e076d | ||
|
|
2c639431b1 | ||
|
|
aacd1cb394 | ||
|
|
e3731fb160 | ||
|
|
8dbc9e7b68 | ||
|
|
abe67acf8a | ||
|
|
4ff8a691bc | ||
|
|
1b308e1d2a | ||
|
|
bd6c1d6b49 | ||
|
|
3af5d3b738 | ||
|
|
7730895158 | ||
|
|
de9ecfd01c | ||
|
|
95fdd8d619 | ||
|
|
9f7822851c | ||
|
|
9b2035d194 | ||
|
|
93d45d7a04 | ||
|
|
709f842457 | ||
|
|
2dfb74410d | ||
|
|
1eb5e75972 | ||
|
|
3475d915cb | ||
|
|
48e78e9be1 | ||
|
|
a838421ea3 | ||
|
|
1c4e85b4df | ||
|
|
dac4f17fea | ||
|
|
56b8fb024c | ||
|
|
b95693056c | ||
|
|
c34fc64688 | ||
|
|
7cf6f18c1f | ||
|
|
bbbb6b2a01 | ||
|
|
76f88caf43 | ||
|
|
2bccf8c624 | ||
|
|
0c5e5f6630 | ||
|
|
d475d1f081 | ||
|
|
d2f334c1f7 | ||
|
|
603ceefaa6 | ||
|
|
e082d60a24 | ||
|
|
5dae738067 | ||
|
|
0c78723174 | ||
|
|
5a41d69b2a | ||
|
|
c146a138e3 | ||
|
|
31b8c6a214 | ||
|
|
9191dfaf05 | ||
|
|
1108d8b34e | ||
|
|
7837a5bc7e | ||
|
|
0a844f8e96 | ||
|
|
a03223b86f | ||
|
|
0cf7794b16 | ||
|
|
854d40edc5 | ||
|
|
84a2cedf18 | ||
|
|
3f30836734 | ||
|
|
cc9555aff0 | ||
|
|
20aee96706 | ||
|
|
18b5958d46 | ||
|
|
5317202c38 | ||
|
|
d771043e88 | ||
|
|
f8f1071818 | ||
|
|
d3e0a0dee4 | ||
|
|
554172759c | ||
|
|
5b6a8e6001 | ||
|
|
467bbc0dd5 | ||
|
|
6d9f9323c5 | ||
|
|
0c2489605d |
2
.gitattributes
vendored
2
.gitattributes
vendored
@@ -19,6 +19,8 @@ ml/backend/**/*.comp linguist-vendored
|
||||
ml/backend/**/*.glsl linguist-vendored
|
||||
ml/backend/**/CMakeLists.txt linguist-vendored
|
||||
|
||||
app/webview linguist-vendored
|
||||
|
||||
llama/build-info.cpp linguist-generated
|
||||
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated
|
||||
|
||||
|
||||
17
.github/workflows/release.yaml
vendored
17
.github/workflows/release.yaml
vendored
@@ -16,13 +16,15 @@ jobs:
|
||||
outputs:
|
||||
GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }}
|
||||
VERSION: ${{ steps.goflags.outputs.VERSION }}
|
||||
vendorsha: ${{ steps.changes.outputs.vendorsha }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set environment
|
||||
id: goflags
|
||||
run: |
|
||||
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
|
||||
echo VERSION="${GITHUB_REF_NAME#v}" >>$GITHUB_OUTPUT
|
||||
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" | tee -a $GITHUB_OUTPUT
|
||||
echo VERSION="${GITHUB_REF_NAME#v}" | tee -a $GITHUB_OUTPUT
|
||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||
|
||||
darwin-build:
|
||||
runs-on: macos-14-xlarge
|
||||
@@ -53,6 +55,9 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache-dependency-path: |
|
||||
go.sum
|
||||
Makefile.sync
|
||||
- run: |
|
||||
./scripts/build_darwin.sh
|
||||
- name: Log build results
|
||||
@@ -185,7 +190,7 @@ jobs:
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ github.workspace }}\.ccache
|
||||
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
|
||||
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}-${{ needs.setup-environment.outputs.vendorsha }}
|
||||
- name: Build target "${{ matrix.preset }}"
|
||||
run: |
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
@@ -249,6 +254,9 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache-dependency-path: |
|
||||
go.sum
|
||||
Makefile.sync
|
||||
- name: Verify gcc is actually clang
|
||||
run: |
|
||||
$ErrorActionPreference='Continue'
|
||||
@@ -302,6 +310,9 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache-dependency-path: |
|
||||
go.sum
|
||||
Makefile.sync
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: depends-windows*
|
||||
|
||||
9
.github/workflows/test.yaml
vendored
9
.github/workflows/test.yaml
vendored
@@ -22,6 +22,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
changed: ${{ steps.changes.outputs.changed }}
|
||||
vendorsha: ${{ steps.changes.outputs.vendorsha }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -37,6 +38,7 @@ jobs:
|
||||
}
|
||||
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||
|
||||
linux:
|
||||
needs: [changes]
|
||||
@@ -83,7 +85,7 @@ jobs:
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
path: /github/home/.cache/ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||
- run: |
|
||||
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
|
||||
cmake --build --preset ${{ matrix.preset }} --parallel
|
||||
@@ -178,7 +180,7 @@ jobs:
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ github.workspace }}\.ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||
- run: |
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
@@ -206,6 +208,9 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache-dependency-path: |
|
||||
go.sum
|
||||
Makefile.sync
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
@@ -1,77 +1,51 @@
|
||||
version: "2"
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- asasalint
|
||||
- bidichk
|
||||
- bodyclose
|
||||
- containedctx
|
||||
- copyloopvar
|
||||
- errcheck
|
||||
- errorlint
|
||||
- exptostd
|
||||
- gocheckcompilerdirectives
|
||||
- gocritic
|
||||
- govet
|
||||
- ineffassign
|
||||
- intrange
|
||||
- makezero
|
||||
- misspell
|
||||
- modernize
|
||||
- nilerr
|
||||
- nilnil
|
||||
- nolintlint
|
||||
- nosprintfhostport
|
||||
- perfsprint
|
||||
- prealloc
|
||||
- sloglint
|
||||
- staticcheck
|
||||
- unconvert
|
||||
- unused
|
||||
- usestdlibvars
|
||||
- usetesting
|
||||
- wastedassign
|
||||
- whitespace
|
||||
disable:
|
||||
- errcheck
|
||||
- usestdlibvars
|
||||
settings:
|
||||
errcheck:
|
||||
exclude-functions:
|
||||
- fmt.Fprintf
|
||||
perfsprint:
|
||||
strconcat: false
|
||||
concat-loop: false
|
||||
govet:
|
||||
disable:
|
||||
- unusedresult
|
||||
staticcheck:
|
||||
checks:
|
||||
- all
|
||||
# Using a deprecated function, variable, constant or field.
|
||||
# https://staticcheck.dev/docs/checks/#SA1019
|
||||
- -QF* # disable quick fix suggestions
|
||||
- -SA1019
|
||||
# Incorrect or missing package comment.
|
||||
# https://staticcheck.dev/docs/checks/#ST1000
|
||||
- -ST1000
|
||||
# Poorly chosen identifier.
|
||||
# https://staticcheck.dev/docs/checks/#ST1003
|
||||
- -ST1003
|
||||
# The documentation of an exported function should start with the function's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1020
|
||||
- -ST1020
|
||||
# The documentation of an exported type should start with type's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1021
|
||||
- -ST1021
|
||||
# The documentation of an exported variable or constant should start with variable's name.
|
||||
# https://staticcheck.dev/docs/checks/#ST1022
|
||||
- -ST1022
|
||||
usestdlibvars:
|
||||
http-method: false
|
||||
http-status-code: false
|
||||
|
||||
- -ST1000 # package comment format
|
||||
- -ST1003 # underscores in package names
|
||||
- -ST1005 # error strings should not be capitalized
|
||||
- -ST1012 # error var naming (ErrFoo)
|
||||
- -ST1016 # receiver name consistency
|
||||
- -ST1020 # comment on exported function format
|
||||
- -ST1021 # comment on exported type format
|
||||
- -ST1022 # comment on exported var format
|
||||
- -ST1023 # omit type from declaration
|
||||
severity:
|
||||
default: error
|
||||
rules:
|
||||
- linters:
|
||||
- gofmt
|
||||
- goimports
|
||||
- intrange
|
||||
severity: info
|
||||
formatters:
|
||||
enable:
|
||||
- gci
|
||||
- gofmt
|
||||
- gofumpt
|
||||
settings:
|
||||
gci:
|
||||
sections:
|
||||
- standard
|
||||
- default
|
||||
- localmodule
|
||||
|
||||
@@ -54,6 +54,13 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
|
||||
|
||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||
|
||||
# Define GGML version variables for shared library SOVERSION
|
||||
# These are required by ggml/src/CMakeLists.txt for proper library versioning
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 0)
|
||||
set(GGML_VERSION_PATCH 0)
|
||||
set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
set(GGML_CPU ON)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
||||
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
|
||||
WORKDIR=llama/vendor
|
||||
FETCH_HEAD=3cfa9c3f125763305b4226bc032f1954f08990dc
|
||||
FETCH_HEAD=ec98e2002
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
@@ -57,7 +57,7 @@ checkout: $(WORKDIR)
|
||||
$(WORKDIR):
|
||||
git clone $(UPSTREAM) $(WORKDIR)
|
||||
|
||||
.PHONE: format-patches
|
||||
.PHONY: format-patches
|
||||
format-patches: llama/patches
|
||||
git -C $(WORKDIR) format-patch \
|
||||
--no-signature \
|
||||
@@ -66,7 +66,11 @@ format-patches: llama/patches
|
||||
-o $(realpath $<) \
|
||||
$(FETCH_HEAD)
|
||||
|
||||
.PHONE: clean
|
||||
.PHONY: clean
|
||||
clean: checkout
|
||||
@git -C $(WORKDIR) am --abort || true
|
||||
$(RM) llama/patches/.*.patched
|
||||
|
||||
.PHONY: print-base
|
||||
print-base:
|
||||
@echo $(FETCH_HEAD)
|
||||
@@ -555,7 +555,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
|
||||
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
|
||||
- [Ollama for Swift](https://github.com/mattt/ollama-swift)
|
||||
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
|
||||
- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
|
||||
- [GoLamify](https://github.com/prasad89/golamify)
|
||||
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
|
||||
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
|
||||
|
||||
@@ -226,7 +226,14 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
|
||||
bts := scanner.Bytes()
|
||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
ErrorMessage: string(bts),
|
||||
}
|
||||
}
|
||||
return errors.New(string(bts))
|
||||
}
|
||||
|
||||
if response.StatusCode == http.StatusUnauthorized {
|
||||
@@ -340,7 +347,7 @@ type CreateProgressFunc func(ProgressResponse) error
|
||||
// Create creates a model from a [Modelfile]. fn is a progress function that
|
||||
// behaves similarly to other methods (see [Client.Pull]).
|
||||
//
|
||||
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
|
||||
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx
|
||||
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
|
||||
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
|
||||
var resp ProgressResponse
|
||||
|
||||
@@ -55,6 +55,7 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
type testError struct {
|
||||
message string
|
||||
statusCode int
|
||||
raw bool // if true, write message as-is instead of JSON encoding
|
||||
}
|
||||
|
||||
func (e testError) Error() string {
|
||||
@@ -111,6 +112,20 @@ func TestClientStream(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "plain text error response",
|
||||
responses: []any{
|
||||
"internal server error",
|
||||
},
|
||||
wantErr: "internal server error",
|
||||
},
|
||||
{
|
||||
name: "HTML error page",
|
||||
responses: []any{
|
||||
"<html><body>404 Not Found</body></html>",
|
||||
},
|
||||
wantErr: "404 Not Found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -135,6 +150,12 @@ func TestClientStream(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if str, ok := resp.(string); ok {
|
||||
fmt.Fprintln(w, str)
|
||||
flusher.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("failed to encode response: %v", err)
|
||||
}
|
||||
@@ -173,9 +194,10 @@ func TestClientStream(t *testing.T) {
|
||||
|
||||
func TestClientDo(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
response any
|
||||
wantErr string
|
||||
name string
|
||||
response any
|
||||
wantErr string
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "immediate error response",
|
||||
@@ -183,7 +205,8 @@ func TestClientDo(t *testing.T) {
|
||||
message: "test error message",
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
wantErr: "test error message",
|
||||
wantErr: "test error message",
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "server error response",
|
||||
@@ -191,7 +214,8 @@ func TestClientDo(t *testing.T) {
|
||||
message: "internal error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
},
|
||||
wantErr: "internal error",
|
||||
wantErr: "internal error",
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "successful response",
|
||||
@@ -203,6 +227,26 @@ func TestClientDo(t *testing.T) {
|
||||
Success: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "plain text error response",
|
||||
response: testError{
|
||||
message: "internal server error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
raw: true,
|
||||
},
|
||||
wantErr: "internal server error",
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "HTML error page",
|
||||
response: testError{
|
||||
message: "<html><body>404 Not Found</body></html>",
|
||||
statusCode: http.StatusNotFound,
|
||||
raw: true,
|
||||
},
|
||||
wantErr: "<html><body>404 Not Found</body></html>",
|
||||
wantStatusCode: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -210,11 +254,16 @@ func TestClientDo(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if errResp, ok := tc.response.(testError); ok {
|
||||
w.WriteHeader(errResp.statusCode)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": errResp.message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode error response:", err)
|
||||
if !errResp.raw {
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": errResp.message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode error response:", err)
|
||||
}
|
||||
} else {
|
||||
// Write raw message (simulates non-JSON error responses)
|
||||
fmt.Fprint(w, errResp.message)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -241,6 +290,15 @@ func TestClientDo(t *testing.T) {
|
||||
if err.Error() != tc.wantErr {
|
||||
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
|
||||
}
|
||||
if tc.wantStatusCode != 0 {
|
||||
if statusErr, ok := err.(StatusError); ok {
|
||||
if statusErr.StatusCode != tc.wantStatusCode {
|
||||
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("expected StatusError, got %T", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -15,19 +15,19 @@ func main() {
|
||||
}
|
||||
|
||||
messages := []api.Message{
|
||||
api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "Provide very brief, concise responses",
|
||||
},
|
||||
api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Name some unusual animals",
|
||||
},
|
||||
api.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Monotreme, platypus, echidna",
|
||||
},
|
||||
api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "which of these is the most dangerous?",
|
||||
},
|
||||
|
||||
37
api/types.go
37
api/types.go
@@ -17,6 +17,12 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// SkillRef is an alias for model.SkillRef representing a skill reference.
|
||||
type SkillRef = model.SkillRef
|
||||
|
||||
// MCPRef is an alias for model.MCPRef representing an MCP server reference.
|
||||
type MCPRef = model.MCPRef
|
||||
|
||||
// StatusError is an error with an HTTP status code and message.
|
||||
type StatusError struct {
|
||||
StatusCode int
|
||||
@@ -283,11 +289,12 @@ func (pt PropertyType) String() string {
|
||||
}
|
||||
|
||||
type ToolProperty struct {
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Properties map[string]ToolProperty `json:"properties,omitempty"`
|
||||
}
|
||||
|
||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||
@@ -553,6 +560,21 @@ type CreateRequest struct {
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
|
||||
// Requires is the minimum version of Ollama required by the model.
|
||||
Requires string `json:"requires,omitempty"`
|
||||
|
||||
// Skills is a list of skill references for the agent (local paths or registry refs)
|
||||
Skills []SkillRef `json:"skills,omitempty"`
|
||||
|
||||
// MCPs is a list of MCP server references for the agent
|
||||
MCPs []MCPRef `json:"mcps,omitempty"`
|
||||
|
||||
// AgentType defines the type of agent (e.g., "conversational", "task-based")
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
|
||||
// Entrypoint specifies an external command to run instead of the built-in chat loop
|
||||
Entrypoint string `json:"entrypoint,omitempty"`
|
||||
|
||||
// Info is a map of additional information for the model
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
|
||||
@@ -603,6 +625,11 @@ type ShowResponse struct {
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
Requires string `json:"requires,omitempty"`
|
||||
Skills []SkillRef `json:"skills,omitempty"`
|
||||
MCPs []MCPRef `json:"mcps,omitempty"`
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
Entrypoint string `json:"entrypoint,omitempty"`
|
||||
}
|
||||
|
||||
// CopyRequest is the request passed to [Client.Copy].
|
||||
|
||||
@@ -504,6 +504,107 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolPropertyNestedProperties(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected ToolProperty
|
||||
}{
|
||||
{
|
||||
name: "nested object properties",
|
||||
input: `{
|
||||
"type": "object",
|
||||
"description": "Location details",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "string",
|
||||
"description": "Street address"
|
||||
},
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name"
|
||||
}
|
||||
}
|
||||
}`,
|
||||
expected: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Location details",
|
||||
Properties: map[string]ToolProperty{
|
||||
"address": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "Street address",
|
||||
},
|
||||
"city": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deeply nested properties",
|
||||
input: `{
|
||||
"type": "object",
|
||||
"description": "Event",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "object",
|
||||
"description": "Location",
|
||||
"properties": {
|
||||
"coordinates": {
|
||||
"type": "object",
|
||||
"description": "GPS coordinates",
|
||||
"properties": {
|
||||
"lat": {"type": "number", "description": "Latitude"},
|
||||
"lng": {"type": "number", "description": "Longitude"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
expected: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Event",
|
||||
Properties: map[string]ToolProperty{
|
||||
"location": {
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Location",
|
||||
Properties: map[string]ToolProperty{
|
||||
"coordinates": {
|
||||
Type: PropertyType{"object"},
|
||||
Description: "GPS coordinates",
|
||||
Properties: map[string]ToolProperty{
|
||||
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
|
||||
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var prop ToolProperty
|
||||
err := json.Unmarshal([]byte(tt.input), &prop)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, prop)
|
||||
|
||||
// Round-trip test: marshal and unmarshal again
|
||||
data, err := json.Marshal(prop)
|
||||
require.NoError(t, err)
|
||||
|
||||
var prop2 ToolProperty
|
||||
err = json.Unmarshal(data, &prop2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, prop2)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFunctionParameters_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -273,10 +273,6 @@ func main() {
|
||||
Handler: uiServer.Handler(),
|
||||
}
|
||||
|
||||
if _, err := uiServer.UserData(ctx); err != nil {
|
||||
slog.Warn("failed to load user data", "error", err)
|
||||
}
|
||||
|
||||
// Start the UI server
|
||||
slog.Info("starting ui server", "port", port)
|
||||
go func() {
|
||||
@@ -320,6 +316,17 @@ func main() {
|
||||
slog.Debug("no URL scheme request to handle")
|
||||
}
|
||||
|
||||
go func() {
|
||||
slog.Debug("waiting for ollama server to be ready")
|
||||
if err := ui.WaitForServer(ctx, 10*time.Second); err != nil {
|
||||
slog.Warn("ollama server not ready, continuing anyway", "error", err)
|
||||
}
|
||||
|
||||
if _, err := uiServer.UserData(ctx); err != nil {
|
||||
slog.Warn("failed to load user data", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
osRun(cancel, hasCompletedFirstRun, startHidden)
|
||||
|
||||
slog.Info("shutting down desktop server")
|
||||
@@ -361,7 +368,7 @@ func checkUserLoggedIn(uiServerPort int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/me", uiServerPort))
|
||||
resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/api/me", uiServerPort), "application/json", nil)
|
||||
if err != nil {
|
||||
slog.Debug("failed to call local auth endpoint", "error", err)
|
||||
return false
|
||||
|
||||
@@ -191,13 +191,6 @@ func LaunchNewApp() {
|
||||
C.launchApp(appName)
|
||||
}
|
||||
|
||||
// Send a request to the main app thread to load a UI page
|
||||
func sendUIRequestMessage(path string) {
|
||||
p := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(p))
|
||||
C.uiRequest(p)
|
||||
}
|
||||
|
||||
func registerLaunchAgent(hasCompletedFirstRun bool) {
|
||||
// Remove any stale Login Item registrations
|
||||
C.unregisterSelfFromLoginItem()
|
||||
|
||||
@@ -263,11 +263,6 @@ func createLoginShortcut() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send a request to the main app thread to load a UI page
|
||||
func sendUIRequestMessage(path string) {
|
||||
wintray.SendUIRequestMessage(path)
|
||||
}
|
||||
|
||||
func LaunchNewApp() {
|
||||
}
|
||||
|
||||
|
||||
@@ -169,37 +169,47 @@ DlgResult fileDlg(FileDlgParams* params) {
|
||||
}
|
||||
|
||||
NSArray* urls = [panel URLs];
|
||||
if(self->params->allowMultiple && [urls count] >= 1) {
|
||||
if([urls count] == 0) {
|
||||
return DLG_CANCEL;
|
||||
}
|
||||
|
||||
if(self->params->allowMultiple) {
|
||||
// For multiple files, we need to return all paths separated by null bytes
|
||||
char* bufPtr = self->params->buf;
|
||||
int remainingBuf = self->params->nbuf;
|
||||
|
||||
// Calculate total required buffer size first
|
||||
int totalSize = 0;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
|
||||
}
|
||||
totalSize += 1; // Final null terminator
|
||||
// Calculate total required buffer size first
|
||||
int totalSize = 0;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
|
||||
}
|
||||
totalSize += 1; // Final null terminator
|
||||
|
||||
if(totalSize > self->params->nbuf) {
|
||||
// Not enough buffer space
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
if(totalSize > self->params->nbuf) {
|
||||
// Not enough buffer space
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
|
||||
// Now actually copy the paths (we know we have space)
|
||||
bufPtr = self->params->buf;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
|
||||
int pathLen = strlen(tempBuf);
|
||||
strcpy(bufPtr, tempBuf);
|
||||
bufPtr += pathLen + 1;
|
||||
}
|
||||
*bufPtr = '\0'; // Final null terminator
|
||||
// Now actually copy the paths (we know we have space)
|
||||
bufPtr = self->params->buf;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
|
||||
int pathLen = strlen(tempBuf);
|
||||
strcpy(bufPtr, tempBuf);
|
||||
bufPtr += pathLen + 1;
|
||||
}
|
||||
*bufPtr = '\0'; // Final null terminator
|
||||
} else {
|
||||
// Single file/directory selection - write path to buffer
|
||||
NSURL* url = [urls firstObject];
|
||||
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
}
|
||||
|
||||
return DLG_OK;
|
||||
|
||||
@@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10
|
||||
type WinDlgError int
|
||||
|
||||
func (e WinDlgError) Error() string {
|
||||
return fmt.Sprintf("CommDlgExtendedError: %#x", e)
|
||||
return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
|
||||
}
|
||||
|
||||
func err() error {
|
||||
|
||||
@@ -224,9 +224,7 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
|
||||
if _, err := os.Stat(settings.Models); err == nil {
|
||||
env["OLLAMA_MODELS"] = settings.Models
|
||||
} else {
|
||||
slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err)
|
||||
settings.Models = ""
|
||||
s.store.SetSettings(settings)
|
||||
slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err)
|
||||
}
|
||||
}
|
||||
if settings.ContextLength > 0 {
|
||||
|
||||
@@ -469,26 +469,24 @@ export class HealthResponse {
|
||||
}
|
||||
export class User {
|
||||
id: string;
|
||||
name: string;
|
||||
email: string;
|
||||
avatarURL: string;
|
||||
plan: string;
|
||||
bio: string;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
overThreshold: boolean;
|
||||
name: string;
|
||||
bio?: string;
|
||||
avatarurl?: string;
|
||||
firstname?: string;
|
||||
lastname?: string;
|
||||
plan?: string;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.id = source["id"];
|
||||
this.name = source["name"];
|
||||
this.email = source["email"];
|
||||
this.avatarURL = source["avatarURL"];
|
||||
this.plan = source["plan"];
|
||||
this.name = source["name"];
|
||||
this.bio = source["bio"];
|
||||
this.firstName = source["firstName"];
|
||||
this.lastName = source["lastName"];
|
||||
this.overThreshold = source["overThreshold"];
|
||||
this.avatarurl = source["avatarurl"];
|
||||
this.firstname = source["firstname"];
|
||||
this.lastname = source["lastname"];
|
||||
this.plan = source["plan"];
|
||||
}
|
||||
}
|
||||
export class Attachment {
|
||||
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
import { parseJsonlFromResponse } from "./util/jsonl-parsing";
|
||||
import { ollamaClient as ollama } from "./lib/ollama-client";
|
||||
import type { ModelResponse } from "ollama/browser";
|
||||
import { API_BASE } from "./lib/config";
|
||||
import { API_BASE, OLLAMA_DOT_COM } from "./lib/config";
|
||||
|
||||
// Extend Model class with utility methods
|
||||
declare module "@/gotypes" {
|
||||
@@ -27,7 +27,6 @@ declare module "@/gotypes" {
|
||||
Model.prototype.isCloud = function (): boolean {
|
||||
return this.model.endsWith("cloud");
|
||||
};
|
||||
|
||||
// Helper function to convert Uint8Array to base64
|
||||
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
|
||||
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
|
||||
@@ -42,44 +41,50 @@ function uint8ArrayToBase64(uint8Array: Uint8Array): string {
|
||||
}
|
||||
|
||||
export async function fetchUser(): Promise<User | null> {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/api/v1/me`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const userData: User = await response.json();
|
||||
return userData;
|
||||
}
|
||||
|
||||
return null;
|
||||
} catch (error) {
|
||||
console.error("Error fetching user:", error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchConnectUrl(): Promise<string> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/connect`, {
|
||||
method: "GET",
|
||||
const response = await fetch(`${API_BASE}/api/me`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch connect URL");
|
||||
if (response.ok) {
|
||||
const userData: User = await response.json();
|
||||
|
||||
if (userData.avatarurl && !userData.avatarurl.startsWith("http")) {
|
||||
userData.avatarurl = `${OLLAMA_DOT_COM}${userData.avatarurl}`;
|
||||
}
|
||||
|
||||
return userData;
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return data.connect_url;
|
||||
if (response.status === 401 || response.status === 403) {
|
||||
return null;
|
||||
}
|
||||
|
||||
throw new Error(`Failed to fetch user: ${response.status}`);
|
||||
}
|
||||
|
||||
export async function fetchConnectUrl(): Promise<string> {
|
||||
const response = await fetch(`${API_BASE}/api/me`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (response.status === 401) {
|
||||
const data = await response.json();
|
||||
if (data.signin_url) {
|
||||
return data.signin_url;
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error("Failed to fetch connect URL");
|
||||
}
|
||||
|
||||
export async function disconnectUser(): Promise<void> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/disconnect`, {
|
||||
const response = await fetch(`${API_BASE}/api/signout`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
@@ -204,12 +209,10 @@ export async function* sendMessage(
|
||||
data: uint8ArrayToBase64(att.data),
|
||||
}));
|
||||
|
||||
// Only send think parameter when actually requesting thinking
|
||||
// Don't send false as it causes issues with some providers
|
||||
// Send think parameter when it's explicitly set (true, false, or a non-empty string).
|
||||
const shouldSendThink =
|
||||
think !== undefined &&
|
||||
((typeof think === "boolean" && think) ||
|
||||
(typeof think === "string" && think !== ""));
|
||||
(typeof think === "boolean" || (typeof think === "string" && think !== ""));
|
||||
|
||||
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
|
||||
method: "POST",
|
||||
@@ -391,7 +394,8 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
|
||||
|
||||
export async function fetchHealth(): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/api/v1/health`, {
|
||||
// Use the /api/version endpoint as a health check
|
||||
const response = await fetch(`${API_BASE}/api/version`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
@@ -400,7 +404,8 @@ export async function fetchHealth(): Promise<boolean> {
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
return data.healthy || false;
|
||||
// If we get a version back, the server is healthy
|
||||
return !!data.version;
|
||||
}
|
||||
|
||||
return false;
|
||||
|
||||
@@ -299,9 +299,9 @@ export default function Settings() {
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{user?.avatarURL && (
|
||||
{user?.avatarurl && (
|
||||
<img
|
||||
src={user.avatarURL}
|
||||
src={user.avatarurl}
|
||||
alt={user?.name}
|
||||
className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0"
|
||||
onError={(e) => {
|
||||
|
||||
@@ -50,21 +50,33 @@ export default function Thinking({
|
||||
// Position content to show bottom when collapsed
|
||||
useEffect(() => {
|
||||
if (isCollapsed && contentRef.current && wrapperRef.current) {
|
||||
const contentHeight = contentRef.current.scrollHeight;
|
||||
const wrapperHeight = wrapperRef.current.clientHeight;
|
||||
if (contentHeight > wrapperHeight) {
|
||||
const translateY = -(contentHeight - wrapperHeight);
|
||||
contentRef.current.style.transform = `translateY(${translateY}px)`;
|
||||
setHasOverflow(true);
|
||||
} else {
|
||||
setHasOverflow(false);
|
||||
}
|
||||
requestAnimationFrame(() => {
|
||||
if (!contentRef.current || !wrapperRef.current) return;
|
||||
|
||||
const contentHeight = contentRef.current.scrollHeight;
|
||||
const wrapperHeight = wrapperRef.current.clientHeight;
|
||||
if (contentHeight > wrapperHeight) {
|
||||
const translateY = -(contentHeight - wrapperHeight);
|
||||
contentRef.current.style.transform = `translateY(${translateY}px)`;
|
||||
setHasOverflow(true);
|
||||
} else {
|
||||
contentRef.current.style.transform = "translateY(0)";
|
||||
setHasOverflow(false);
|
||||
}
|
||||
});
|
||||
} else if (contentRef.current) {
|
||||
contentRef.current.style.transform = "translateY(0)";
|
||||
setHasOverflow(false);
|
||||
}
|
||||
}, [thinking, isCollapsed]);
|
||||
|
||||
useEffect(() => {
|
||||
if (activelyThinking && wrapperRef.current && !isCollapsed) {
|
||||
// When expanded and actively thinking, scroll to bottom
|
||||
wrapperRef.current.scrollTop = wrapperRef.current.scrollHeight;
|
||||
}
|
||||
}, [thinking, activelyThinking, isCollapsed]);
|
||||
|
||||
const handleToggle = () => {
|
||||
setIsCollapsed(!isCollapsed);
|
||||
setHasUserInteracted(true);
|
||||
|
||||
@@ -7,6 +7,7 @@ import { createQueryBatcher } from "./useQueryBatcher";
|
||||
import { useRefetchModels } from "./useModels";
|
||||
import { useStreamingContext } from "@/contexts/StreamingContext";
|
||||
import { useSettings } from "./useSettings";
|
||||
import { getModelCapabilities } from "@/api";
|
||||
|
||||
export const useChats = () => {
|
||||
return useQuery({
|
||||
@@ -606,6 +607,24 @@ export const useSendMessage = (chatId: string) => {
|
||||
queryClient.setQueryData(["staleModels"], newStaleMap);
|
||||
|
||||
queryClient.invalidateQueries({ queryKey: ["models"] });
|
||||
|
||||
// Fetch fresh capabilities for the downloaded model
|
||||
getModelCapabilities(selectedModel.model)
|
||||
.then((capabilities) => {
|
||||
queryClient.setQueryData(
|
||||
["modelCapabilities", selectedModel.model],
|
||||
capabilities,
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
"Failed to fetch capabilities after download:",
|
||||
error,
|
||||
);
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["modelCapabilities", selectedModel.model],
|
||||
});
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { useState } from "react";
|
||||
import { pullModel } from "@/api";
|
||||
import { useSelectedModel } from "./useSelectedModel";
|
||||
import { useSettings } from "./useSettings";
|
||||
|
||||
interface DownloadProgress {
|
||||
status: string;
|
||||
digest?: string;
|
||||
total?: number;
|
||||
completed?: number;
|
||||
done?: boolean;
|
||||
}
|
||||
|
||||
export function useDownloadModel(chatId?: string) {
|
||||
const queryClient = useQueryClient();
|
||||
const { selectedModel } = useSelectedModel(chatId);
|
||||
const { setSettings } = useSettings();
|
||||
const [downloadProgress, setDownloadProgress] =
|
||||
useState<DownloadProgress | null>(null);
|
||||
const [abortController, setAbortController] =
|
||||
useState<AbortController | null>(null);
|
||||
const [downloadingChatIds, setDownloadingChatIds] = useState<Set<string>>(
|
||||
new Set(),
|
||||
);
|
||||
|
||||
const mutation = useMutation({
|
||||
mutationFn: async (modelName: string) => {
|
||||
const controller = new AbortController();
|
||||
setAbortController(controller);
|
||||
setDownloadProgress({ status: "Starting download..." });
|
||||
if (chatId) {
|
||||
setDownloadingChatIds((prev) => new Set(prev).add(chatId));
|
||||
}
|
||||
|
||||
try {
|
||||
for await (const progress of pullModel(modelName, controller.signal)) {
|
||||
setDownloadProgress(progress);
|
||||
|
||||
if (progress.status === "success") {
|
||||
// Update selected model to indicate it's now available locally
|
||||
if (selectedModel && selectedModel.model === modelName) {
|
||||
setSettings({ SelectedModel: modelName });
|
||||
}
|
||||
// Invalidate models query to refresh the list
|
||||
await queryClient.invalidateQueries({ queryKey: ["models"] });
|
||||
break;
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
setAbortController(null);
|
||||
if (chatId) {
|
||||
setDownloadingChatIds((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(chatId);
|
||||
return newSet;
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
onSuccess: () => {
|
||||
setDownloadProgress(null);
|
||||
if (chatId) {
|
||||
setDownloadingChatIds((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(chatId);
|
||||
return newSet;
|
||||
});
|
||||
}
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
const status =
|
||||
error.name === "AbortError" ? "Download cancelled" : "Download failed";
|
||||
setDownloadProgress({ status, done: true });
|
||||
|
||||
// Clear error message after delay
|
||||
const delay = error.name === "AbortError" ? 1500 : 3000;
|
||||
setTimeout(() => {
|
||||
setDownloadProgress(null);
|
||||
if (chatId) {
|
||||
setDownloadingChatIds((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(chatId);
|
||||
return newSet;
|
||||
});
|
||||
}
|
||||
}, delay);
|
||||
},
|
||||
});
|
||||
|
||||
const cancelDownload = () => {
|
||||
if (abortController) {
|
||||
abortController.abort();
|
||||
setAbortController(null);
|
||||
if (chatId) {
|
||||
setDownloadingChatIds((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(chatId);
|
||||
return newSet;
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
downloadModel: mutation.mutate,
|
||||
isDownloading:
|
||||
mutation.isPending && chatId ? downloadingChatIds.has(chatId) : false,
|
||||
downloadProgress:
|
||||
chatId && downloadingChatIds.has(chatId) ? downloadProgress : null,
|
||||
error: mutation.error,
|
||||
cancelDownload,
|
||||
};
|
||||
}
|
||||
@@ -1,29 +1,20 @@
|
||||
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useState } from "react";
|
||||
import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api";
|
||||
|
||||
export function useUser() {
|
||||
const queryClient = useQueryClient();
|
||||
const [initialDataLoaded, setInitialDataLoaded] = useState(false);
|
||||
|
||||
// Wait for initial data to be loaded
|
||||
useEffect(() => {
|
||||
const initialPromise = window.__initialUserDataPromise;
|
||||
if (initialPromise) {
|
||||
initialPromise.finally(() => {
|
||||
setInitialDataLoaded(true);
|
||||
});
|
||||
} else {
|
||||
setInitialDataLoaded(true);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const userQuery = useQuery({
|
||||
queryKey: ["user"],
|
||||
queryFn: () => fetchUser(),
|
||||
queryFn: async () => {
|
||||
const result = await fetchUser();
|
||||
return result;
|
||||
},
|
||||
staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes
|
||||
gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
|
||||
initialData: null, // Start with null to prevent flashing
|
||||
retry: 10,
|
||||
retryDelay: (attemptIndex) => Math.min(500 * attemptIndex, 2000),
|
||||
refetchOnMount: true, // Always fetch when component mounts
|
||||
});
|
||||
|
||||
// Mutation to refresh user data
|
||||
@@ -49,14 +40,15 @@ export function useUser() {
|
||||
},
|
||||
});
|
||||
|
||||
const isLoading = userQuery.isLoading || userQuery.isFetching;
|
||||
const isAuthenticated = Boolean(userQuery.data?.name);
|
||||
|
||||
return {
|
||||
user: userQuery.data,
|
||||
isLoading:
|
||||
!initialDataLoaded ||
|
||||
(userQuery.isLoading && userQuery.data === undefined), // Show loading until initial data is loaded
|
||||
isLoading,
|
||||
isError: userQuery.isError,
|
||||
error: userQuery.error,
|
||||
isAuthenticated: Boolean(userQuery.data?.name),
|
||||
isAuthenticated,
|
||||
refreshUser: refreshUser.mutate,
|
||||
isRefreshing: refreshUser.isPending,
|
||||
refetchUser: userQuery.refetch,
|
||||
|
||||
@@ -8,3 +8,6 @@ export const API_BASE = import.meta.env.DEV ? DEV_API_URL : "";
|
||||
export const OLLAMA_HOST = import.meta.env.DEV
|
||||
? DEV_API_URL
|
||||
: window.location.origin;
|
||||
|
||||
export const OLLAMA_DOT_COM =
|
||||
import.meta.env.VITE_OLLAMA_DOT_COM_URL || "https://ollama.com";
|
||||
|
||||
@@ -5,13 +5,6 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { routeTree } from "./routeTree.gen";
|
||||
import { fetchUser } from "./api";
|
||||
import { StreamingProvider } from "./contexts/StreamingContext";
|
||||
import { User } from "@/gotypes";
|
||||
|
||||
declare global {
|
||||
interface Window {
|
||||
__initialUserDataPromise?: Promise<User | null>;
|
||||
}
|
||||
}
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
@@ -24,27 +17,11 @@ const queryClient = new QueryClient({
|
||||
},
|
||||
});
|
||||
|
||||
// Track initial user data fetch
|
||||
let initialUserDataPromise: Promise<User | null> | null = null;
|
||||
|
||||
// Initialize user data on app startup
|
||||
const initializeUserData = async () => {
|
||||
try {
|
||||
const userData = await fetchUser();
|
||||
fetchUser().then((userData) => {
|
||||
if (userData) {
|
||||
queryClient.setQueryData(["user"], userData);
|
||||
return userData;
|
||||
} catch (error) {
|
||||
console.error("Error initializing user data:", error);
|
||||
queryClient.setQueryData(["user"], null);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
// Start initialization immediately and track the promise
|
||||
initialUserDataPromise = initializeUserData();
|
||||
|
||||
// Export the promise so hooks can await it
|
||||
window.__initialUserDataPromise = initialUserDataPromise;
|
||||
});
|
||||
|
||||
const router = createRouter({
|
||||
routeTree,
|
||||
|
||||
@@ -101,15 +101,14 @@ type HealthResponse struct {
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarURL string `json:"avatarURL"`
|
||||
Plan string `json:"plan"`
|
||||
Bio string `json:"bio"`
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
OverThreshold bool `json:"overThreshold"`
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Bio string `json:"bio,omitempty"`
|
||||
AvatarURL string `json:"avatarurl,omitempty"`
|
||||
FirstName string `json:"firstname,omitempty"`
|
||||
LastName string `json:"lastname,omitempty"`
|
||||
Plan string `json:"plan,omitempty"`
|
||||
}
|
||||
|
||||
type Attachment struct {
|
||||
|
||||
241
app/ui/ui.go
241
app/ui/ui.go
@@ -12,18 +12,17 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/auth"
|
||||
"github.com/ollama/ollama/app/server"
|
||||
"github.com/ollama/ollama/app/store"
|
||||
"github.com/ollama/ollama/app/tools"
|
||||
@@ -118,40 +117,66 @@ func (s *Server) log() *slog.Logger {
|
||||
|
||||
// ollamaProxy creates a reverse proxy handler to the Ollama server
|
||||
func (s *Server) ollamaProxy() http.Handler {
|
||||
ollamaHost := os.Getenv("OLLAMA_HOST")
|
||||
if ollamaHost == "" {
|
||||
ollamaHost = "http://127.0.0.1:11434"
|
||||
}
|
||||
var (
|
||||
proxy http.Handler
|
||||
proxyMu sync.Mutex
|
||||
)
|
||||
|
||||
if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
|
||||
ollamaHost = "http://" + ollamaHost
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxyMu.Lock()
|
||||
p := proxy
|
||||
proxyMu.Unlock()
|
||||
|
||||
target, err := url.Parse(ollamaHost)
|
||||
if err != nil {
|
||||
s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
|
||||
})
|
||||
}
|
||||
if p == nil {
|
||||
proxyMu.Lock()
|
||||
if proxy == nil {
|
||||
var err error
|
||||
for i := range 2 {
|
||||
if i > 0 {
|
||||
s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||
err = WaitForServer(context.Background(), 10*time.Second)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
if err != nil {
|
||||
proxyMu.Unlock()
|
||||
s.log().Error("ollama server not ready after retries", "error", err)
|
||||
http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
originalDirector := proxy.Director
|
||||
proxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = target.Host
|
||||
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
|
||||
}
|
||||
target := envconfig.Host()
|
||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||
|
||||
return proxy
|
||||
originalDirector := newProxy.Director
|
||||
newProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = target.Host
|
||||
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
|
||||
}
|
||||
|
||||
newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
|
||||
proxy = newProxy
|
||||
p = newProxy
|
||||
} else {
|
||||
p = proxy
|
||||
}
|
||||
proxyMu.Unlock()
|
||||
}
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
|
||||
@@ -264,11 +289,10 @@ func (s *Server) Handler() http.Handler {
|
||||
ollamaProxy := s.ollamaProxy()
|
||||
mux.Handle("GET /api/tags", ollamaProxy)
|
||||
mux.Handle("POST /api/show", ollamaProxy)
|
||||
|
||||
mux.Handle("GET /api/v1/me", handle(s.me))
|
||||
mux.Handle("POST /api/v1/disconnect", handle(s.disconnect))
|
||||
mux.Handle("GET /api/v1/connect", handle(s.connectURL))
|
||||
mux.Handle("GET /api/v1/health", handle(s.health))
|
||||
mux.Handle("GET /api/version", ollamaProxy)
|
||||
mux.Handle("HEAD /api/version", ollamaProxy)
|
||||
mux.Handle("POST /api/me", ollamaProxy)
|
||||
mux.Handle("POST /api/signout", ollamaProxy)
|
||||
|
||||
// React app - catch all non-API routes and serve the React app
|
||||
mux.Handle("GET /", s.appHandler())
|
||||
@@ -338,7 +362,7 @@ func (s *Server) doSelfSigned(ctx context.Context, method, path string) (*http.R
|
||||
}
|
||||
|
||||
// UserData fetches user data from ollama.com API for the current ollama key
|
||||
func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
|
||||
func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
|
||||
resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err)
|
||||
@@ -349,7 +373,7 @@ func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var user responses.User
|
||||
var user api.UserResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user response: %w", err)
|
||||
}
|
||||
@@ -368,29 +392,27 @@ func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func waitForServer(ctx context.Context) error {
|
||||
timeout := time.Now().Add(10 * time.Second)
|
||||
// TODO: this avoids an error on first load of the app
|
||||
// however we should either show a loading state or
|
||||
// wait for the Ollama server to be ready before redirecting
|
||||
for {
|
||||
// WaitForServer waits for the Ollama server to be ready
|
||||
func WaitForServer(ctx context.Context, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
c, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.Version(ctx); err == nil {
|
||||
break
|
||||
}
|
||||
if time.Now().After(timeout) {
|
||||
return fmt.Errorf("timeout waiting for Ollama server to be ready")
|
||||
slog.Debug("ollama server is ready")
|
||||
return nil
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
return nil
|
||||
return errors.New("timeout waiting for Ollama server to be ready")
|
||||
}
|
||||
|
||||
func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error {
|
||||
waitForServer(r.Context())
|
||||
if err := WaitForServer(r.Context(), 10*time.Second); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
@@ -1438,129 +1460,6 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) me(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := s.UserData(r.Context())
|
||||
if err != nil {
|
||||
// If fetching from API fails, try to return cached user data if available
|
||||
if cachedUser, cacheErr := s.Store.User(); cacheErr == nil && cachedUser != nil {
|
||||
s.log().Info("API request failed, returning cached user data", "error", err)
|
||||
responseUser := &responses.User{
|
||||
Name: cachedUser.Name,
|
||||
Email: cachedUser.Email,
|
||||
Plan: cachedUser.Plan,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return json.NewEncoder(w).Encode(responseUser)
|
||||
}
|
||||
|
||||
s.log().Error("failed to get user data", "error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return json.NewEncoder(w).Encode(responses.Error{
|
||||
Error: "failed to get user data",
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
func (s *Server) disconnect(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.Store.ClearUser(); err != nil {
|
||||
s.log().Warn("failed to clear cached user data", "error", err)
|
||||
}
|
||||
|
||||
// Get the SSH public key to encode for the delete request
|
||||
pubKey, err := ollamaAuth.GetPublicKey()
|
||||
if err != nil {
|
||||
s.log().Error("failed to get public key", "error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return json.NewEncoder(w).Encode(responses.Error{
|
||||
Error: "failed to get public key",
|
||||
})
|
||||
}
|
||||
|
||||
// Encode the key using base64 URL encoding
|
||||
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
|
||||
// Call the /api/user/keys/{encodedKey} endpoint with DELETE
|
||||
resp, err := s.doSelfSigned(r.Context(), http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey))
|
||||
if err != nil {
|
||||
s.log().Error("failed to call ollama.com/api/user/keys", "error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return json.NewEncoder(w).Encode(responses.Error{
|
||||
Error: "failed to disconnect from ollama.com",
|
||||
})
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
s.log().Error("disconnect request failed", "status", resp.StatusCode)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return json.NewEncoder(w).Encode(responses.Error{
|
||||
Error: "failed to disconnect from ollama.com",
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"})
|
||||
}
|
||||
|
||||
func (s *Server) connectURL(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
connectURL, err := auth.BuildConnectURL(OllamaDotCom)
|
||||
if err != nil {
|
||||
s.log().Error("failed to build connect URL", "error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return json.NewEncoder(w).Encode(responses.Error{
|
||||
Error: "failed to build connect URL",
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return json.NewEncoder(w).Encode(map[string]string{
|
||||
"connect_url": connectURL,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
healthy := false
|
||||
c, err := api.ClientFromEnvironment()
|
||||
if err == nil {
|
||||
if _, err := c.Version(r.Context()); err == nil {
|
||||
healthy = true
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return json.NewEncoder(w).Encode(responses.HealthResponse{
|
||||
Healthy: healthy,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
@@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
|
||||
case uint32(UI_REQUEST_MSG_ID):
|
||||
// Requests for the UI must always come from the main event thread
|
||||
l := int(wParam)
|
||||
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l)
|
||||
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec
|
||||
t.app.UIRun(path)
|
||||
case WM_COPYDATA:
|
||||
// Handle URL scheme requests from other instances
|
||||
if lParam != 0 {
|
||||
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam))
|
||||
if cds.DwData == 1 { // Our identifier for URL scheme messages
|
||||
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec
|
||||
if cds.DwData == 1 { // Our identifier for URL scheme messages
|
||||
// Convert the data back to string
|
||||
data := make([]byte, cds.CbData)
|
||||
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData])
|
||||
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) //nolint:govet,gosec
|
||||
urlScheme := string(data)
|
||||
handleURLSchemeRequest(urlScheme)
|
||||
lResult = 1 // Return non-zero to indicate success
|
||||
|
||||
@@ -15,7 +15,7 @@ A Go-based command-line tool for benchmarking Ollama models with configurable pa
|
||||
|
||||
```
|
||||
go build -o ollama-bench bench.go
|
||||
./bench -model gpt-oss:20b -epochs 6 -format csv
|
||||
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
|
||||
```
|
||||
|
||||
Using Go Run (without building)
|
||||
@@ -29,31 +29,32 @@ go run bench.go -model gpt-oss:20b -epochs 3
|
||||
### Basic Example
|
||||
|
||||
```
|
||||
./bench -model gemma3 -epochs 6
|
||||
./ollama-bench -model gemma3 -epochs 6
|
||||
```
|
||||
|
||||
### Benchmark Multiple Models
|
||||
|
||||
```
|
||||
./bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
|
||||
./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
|
||||
benchstat -col /name gemma.bench
|
||||
```
|
||||
|
||||
### With Image Prompt
|
||||
|
||||
```
|
||||
./bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
||||
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
||||
```
|
||||
|
||||
### Advanced Example
|
||||
|
||||
```
|
||||
./bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
|
||||
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
|
||||
```
|
||||
|
||||
## Command Line Options
|
||||
|
||||
| Option | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| -model | Comma-separated list of models to benchmark | (required) |
|
||||
| -epochs | Number of iterations per model | 1 |
|
||||
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
|
||||
|
||||
@@ -48,8 +48,8 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool)
|
||||
case "benchstat":
|
||||
if verbose {
|
||||
printHeader := func() {
|
||||
fmt.Printf("sysname: %s\n", runtime.GOOS)
|
||||
fmt.Printf("machine: %s\n", runtime.GOARCH)
|
||||
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
|
||||
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
|
||||
}
|
||||
once.Do(printHeader)
|
||||
}
|
||||
@@ -147,6 +147,17 @@ func BenchmarkChat(fOpt flagOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var out io.Writer = os.Stdout
|
||||
if fOpt.outputFile != nil && *fOpt.outputFile != "" {
|
||||
f, err := os.OpenFile(*fOpt.outputFile, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *fOpt.outputFile, err)
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
out = f
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
for range *fOpt.epochs {
|
||||
options := make(map[string]interface{})
|
||||
@@ -241,13 +252,14 @@ func BenchmarkChat(fOpt flagOptions) error {
|
||||
},
|
||||
}
|
||||
|
||||
OutputMetrics(os.Stdout, *fOpt.format, metrics, *fOpt.verbose)
|
||||
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
||||
|
||||
if *fOpt.keepAlive > 0 {
|
||||
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
443
cmd/cmd.go
443
cmd/cmd.go
@@ -15,6 +15,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -494,6 +495,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
opts.ParentModel = info.Details.ParentModel
|
||||
|
||||
// Check if this is an agent
|
||||
isAgent := info.AgentType != "" || len(info.Skills) > 0 || len(info.MCPs) > 0 || info.Entrypoint != ""
|
||||
if isAgent {
|
||||
opts.IsAgent = true
|
||||
opts.AgentType = info.AgentType
|
||||
opts.Skills = info.Skills
|
||||
opts.MCPs = info.MCPs
|
||||
opts.Entrypoint = info.Entrypoint
|
||||
}
|
||||
|
||||
// Check if this is an embedding model
|
||||
isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
|
||||
|
||||
@@ -517,6 +528,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
||||
}
|
||||
|
||||
// If agent has entrypoint, run it instead of chat loop
|
||||
if opts.Entrypoint != "" {
|
||||
return runEntrypoint(cmd, opts)
|
||||
}
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
@@ -545,9 +561,62 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
|
||||
// For agents, use chat API even in non-interactive mode to support tools
|
||||
if opts.IsAgent {
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: opts.Prompt})
|
||||
_, err := chat(cmd, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
// runEntrypoint executes the agent's entrypoint command instead of the built-in chat loop.
|
||||
func runEntrypoint(cmd *cobra.Command, opts runOptions) error {
|
||||
entrypoint := opts.Entrypoint
|
||||
|
||||
// Check if entrypoint contains $PROMPT placeholder
|
||||
hasPlaceholder := strings.Contains(entrypoint, "$PROMPT")
|
||||
|
||||
if hasPlaceholder && opts.Prompt != "" {
|
||||
// Replace $PROMPT with the actual prompt
|
||||
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", opts.Prompt)
|
||||
} else if hasPlaceholder {
|
||||
// No prompt provided but placeholder exists - remove placeholder
|
||||
entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", "")
|
||||
}
|
||||
|
||||
// Parse entrypoint into command and args
|
||||
parts := strings.Fields(entrypoint)
|
||||
if len(parts) == 0 {
|
||||
return fmt.Errorf("empty entrypoint")
|
||||
}
|
||||
|
||||
command := parts[0]
|
||||
args := parts[1:]
|
||||
|
||||
// If user provided a prompt and no placeholder was used, append it as argument
|
||||
if opts.Prompt != "" && !hasPlaceholder {
|
||||
args = append(args, opts.Prompt)
|
||||
}
|
||||
|
||||
// Look up command in PATH
|
||||
execPath, err := exec.LookPath(command)
|
||||
if err != nil {
|
||||
return fmt.Errorf("entrypoint command not found: %s", command)
|
||||
}
|
||||
|
||||
// Create subprocess
|
||||
proc := exec.Command(execPath, args...)
|
||||
proc.Stdin = os.Stdin
|
||||
proc.Stdout = os.Stdout
|
||||
proc.Stderr = os.Stderr
|
||||
|
||||
// Run and wait
|
||||
return proc.Run()
|
||||
}
|
||||
|
||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -907,44 +976,96 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
fmt.Fprintln(w)
|
||||
}
|
||||
|
||||
tableRender("Model", func() (rows [][]string) {
|
||||
if resp.RemoteHost != "" {
|
||||
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
|
||||
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
|
||||
}
|
||||
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
paramStr = resp.Details.ParameterSize
|
||||
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
// Only show Model section if there's actual model info (not for entrypoint-only agents)
|
||||
hasModelInfo := resp.RemoteHost != "" || resp.ModelInfo != nil || resp.Details.Family != "" || resp.Details.ParameterSize != "" || resp.Details.QuantizationLevel != ""
|
||||
if hasModelInfo {
|
||||
tableRender("Model", func() (rows [][]string) {
|
||||
if resp.RemoteHost != "" {
|
||||
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
|
||||
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
paramStr = resp.Details.ParameterSize
|
||||
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rows = append(rows, []string{"", "architecture", resp.Details.Family})
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
}
|
||||
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
|
||||
if resp.Requires != "" {
|
||||
rows = append(rows, []string{"", "requires", resp.Requires})
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
// Display agent information if this is an agent
|
||||
if resp.AgentType != "" || len(resp.Skills) > 0 || len(resp.MCPs) > 0 || resp.Entrypoint != "" {
|
||||
tableRender("Agent", func() (rows [][]string) {
|
||||
if resp.AgentType != "" {
|
||||
rows = append(rows, []string{"", "type", resp.AgentType})
|
||||
}
|
||||
if resp.Entrypoint != "" {
|
||||
rows = append(rows, []string{"", "entrypoint", resp.Entrypoint})
|
||||
}
|
||||
if len(resp.Skills) > 0 {
|
||||
for i, skill := range resp.Skills {
|
||||
label := "skill"
|
||||
if i > 0 {
|
||||
label = ""
|
||||
}
|
||||
// Show skill name or digest
|
||||
skillDisplay := skill.Name
|
||||
if skillDisplay == "" && skill.Digest != "" {
|
||||
skillDisplay = skill.Digest[:12] + "..."
|
||||
}
|
||||
rows = append(rows, []string{"", label, skillDisplay})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rows = append(rows, []string{"", "architecture", resp.Details.Family})
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
}
|
||||
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
|
||||
return
|
||||
})
|
||||
if len(resp.MCPs) > 0 {
|
||||
for i, mcp := range resp.MCPs {
|
||||
label := "mcp"
|
||||
if i > 0 {
|
||||
label = ""
|
||||
}
|
||||
// Show MCP name and command
|
||||
mcpDisplay := mcp.Name
|
||||
if mcp.Command != "" {
|
||||
cmdLine := mcp.Command
|
||||
if len(mcp.Args) > 0 {
|
||||
cmdLine += " " + strings.Join(mcp.Args, " ")
|
||||
}
|
||||
mcpDisplay += " (" + cmdLine + ")"
|
||||
}
|
||||
rows = append(rows, []string{"", label, mcpDisplay})
|
||||
}
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
if len(resp.Capabilities) > 0 {
|
||||
tableRender("Capabilities", func() (rows [][]string) {
|
||||
@@ -1186,6 +1307,11 @@ type runOptions struct {
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
ShowConnect bool
|
||||
IsAgent bool
|
||||
AgentType string
|
||||
Skills []api.SkillRef
|
||||
MCPs []api.MCPRef
|
||||
Entrypoint string
|
||||
}
|
||||
|
||||
func (r runOptions) Copy() runOptions {
|
||||
@@ -1215,6 +1341,12 @@ func (r runOptions) Copy() runOptions {
|
||||
think = &cThink
|
||||
}
|
||||
|
||||
var skills []api.SkillRef
|
||||
if r.Skills != nil {
|
||||
skills = make([]api.SkillRef, len(r.Skills))
|
||||
copy(skills, r.Skills)
|
||||
}
|
||||
|
||||
return runOptions{
|
||||
Model: r.Model,
|
||||
ParentModel: r.ParentModel,
|
||||
@@ -1230,6 +1362,9 @@ func (r runOptions) Copy() runOptions {
|
||||
Think: think,
|
||||
HideThinking: r.HideThinking,
|
||||
ShowConnect: r.ShowConnect,
|
||||
IsAgent: r.IsAgent,
|
||||
AgentType: r.AgentType,
|
||||
Skills: skills,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1313,6 +1448,65 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load skills for agents
|
||||
var skillsCatalog *skillCatalog
|
||||
if opts.IsAgent && len(opts.Skills) > 0 {
|
||||
skillsCatalog, err = loadSkillsFromRefs(opts.Skills)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load skills: %w", err)
|
||||
}
|
||||
if skillsCatalog != nil && len(skillsCatalog.Skills) > 0 {
|
||||
var skillNames []string
|
||||
for _, s := range skillsCatalog.Skills {
|
||||
skillNames = append(skillNames, s.Name)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Loaded skills: %s\n", strings.Join(skillNames, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
// Load MCP servers for agents (from opts and global config)
|
||||
var mcpMgr *mcpManager
|
||||
allMCPs := opts.MCPs
|
||||
|
||||
// Load global MCPs from ~/.ollama/mcp.json
|
||||
if globalConfig, err := loadMCPConfig(); err == nil && len(globalConfig.MCPServers) > 0 {
|
||||
for name, srv := range globalConfig.MCPServers {
|
||||
// Skip disabled MCPs
|
||||
if srv.Disabled {
|
||||
continue
|
||||
}
|
||||
// Check if already in opts.MCPs (model takes precedence)
|
||||
found := false
|
||||
for _, m := range opts.MCPs {
|
||||
if m.Name == name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
allMCPs = append(allMCPs, api.MCPRef{
|
||||
Name: name,
|
||||
Command: srv.Command,
|
||||
Args: srv.Args,
|
||||
Env: srv.Env,
|
||||
Type: srv.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(allMCPs) > 0 {
|
||||
mcpMgr = newMCPManager()
|
||||
if err := mcpMgr.loadMCPsFromRefs(allMCPs); err != nil {
|
||||
return nil, fmt.Errorf("failed to load MCP servers: %w", err)
|
||||
}
|
||||
if mcpMgr.ToolCount() > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Loaded MCP servers: %s (%d tools)\n",
|
||||
strings.Join(mcpMgr.ServerNames(), ", "), mcpMgr.ToolCount())
|
||||
}
|
||||
defer mcpMgr.Shutdown()
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.StopAndClear()
|
||||
|
||||
@@ -1336,6 +1530,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
var fullResponse strings.Builder
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
var pendingToolCalls []api.ToolCall
|
||||
|
||||
role := "assistant"
|
||||
|
||||
@@ -1376,7 +1571,13 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
if response.Message.ToolCalls != nil {
|
||||
toolCalls := response.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
fmt.Print(renderToolCalls(toolCalls, false))
|
||||
if skillsCatalog != nil || mcpMgr != nil {
|
||||
// Store tool calls for execution after response is complete
|
||||
pendingToolCalls = append(pendingToolCalls, toolCalls...)
|
||||
} else {
|
||||
// No skills catalog or MCP, just display tool calls
|
||||
fmt.Print(renderToolCalls(toolCalls, false))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1389,31 +1590,159 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
opts.Format = `"` + opts.Format + `"`
|
||||
}
|
||||
|
||||
req := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: opts.Messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
// Prepare messages with agent-specific system prompt
|
||||
messages := opts.Messages
|
||||
if skillsCatalog != nil {
|
||||
// Add skills system prompt as the first system message
|
||||
skillsPrompt := skillsCatalog.SystemPrompt()
|
||||
if skillsPrompt != "" {
|
||||
// Insert skills prompt at the beginning, or append to existing system message
|
||||
if len(messages) > 0 && messages[0].Role == "system" {
|
||||
// Append to existing system message
|
||||
messages[0].Content = messages[0].Content + "\n\n" + skillsPrompt
|
||||
} else {
|
||||
// Insert new system message at the beginning
|
||||
systemMsg := api.Message{Role: "system", Content: skillsPrompt}
|
||||
messages = append([]api.Message{systemMsg}, messages...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
// Agentic loop: continue until no more tool calls
|
||||
for {
|
||||
req := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
// this error should ideally be wrapped properly by the client
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
// Add tools for agents (combine skills and MCP tools)
|
||||
var allTools api.Tools
|
||||
if skillsCatalog != nil {
|
||||
allTools = append(allTools, skillsCatalog.Tools()...)
|
||||
}
|
||||
return nil, err
|
||||
if mcpMgr != nil {
|
||||
allTools = append(allTools, mcpMgr.Tools()...)
|
||||
}
|
||||
if len(allTools) > 0 {
|
||||
req.Tools = allTools
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// this error should ideally be wrapped properly by the client
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If no tool calls, we're done
|
||||
if len(pendingToolCalls) == 0 || (skillsCatalog == nil && mcpMgr == nil) {
|
||||
break
|
||||
}
|
||||
|
||||
// Execute tool calls and continue the conversation
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
// Add assistant's tool call message to history
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: fullResponse.String(),
|
||||
ToolCalls: pendingToolCalls,
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// Execute each tool call and collect results
|
||||
var toolResults []api.Message
|
||||
for _, call := range pendingToolCalls {
|
||||
// Show what's being executed
|
||||
switch call.Function.Name {
|
||||
case "run_skill_script":
|
||||
skill, _ := call.Function.Arguments["skill"].(string)
|
||||
command, _ := call.Function.Arguments["command"].(string)
|
||||
fmt.Fprintf(os.Stderr, "Running script in %s: %s\n", skill, command)
|
||||
case "read_skill_file":
|
||||
skill, _ := call.Function.Arguments["skill"].(string)
|
||||
path, _ := call.Function.Arguments["path"].(string)
|
||||
fmt.Fprintf(os.Stderr, "Reading file from %s: %s\n", skill, path)
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Executing: %s\n", call.Function.Name)
|
||||
}
|
||||
|
||||
var result api.Message
|
||||
var handled bool
|
||||
var err error
|
||||
|
||||
// Try skill catalog first
|
||||
if skillsCatalog != nil {
|
||||
result, handled, err = skillsCatalog.RunToolCall(call)
|
||||
}
|
||||
|
||||
// If not handled by skills, try MCP
|
||||
if !handled && mcpMgr != nil {
|
||||
result, handled, err = mcpMgr.RunToolCall(call)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
// Add error result
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Error: %v", err),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if !handled {
|
||||
fmt.Fprintf(os.Stderr, "Warning: Unknown tool %s\n", call.Function.Name)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Unknown tool: %s", call.Function.Name),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Display tool output
|
||||
if result.Content != "" {
|
||||
fmt.Fprintf(os.Stderr, "Output:\n%s\n", result.Content)
|
||||
}
|
||||
|
||||
// Add tool result to messages
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: result.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Add tool results to message history
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
// Reset state for next iteration
|
||||
fullResponse.Reset()
|
||||
thinkingContent.Reset()
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = false
|
||||
pendingToolCalls = nil
|
||||
state = &displayResponseState{}
|
||||
|
||||
// Start new progress spinner for next API call
|
||||
p = progress.NewProgress(os.Stderr)
|
||||
spinner = progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
}
|
||||
|
||||
if len(opts.Messages) > 0 {
|
||||
@@ -1430,7 +1759,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
latest.Summary()
|
||||
}
|
||||
|
||||
return &api.Message{Role: role, Content: fullResponse.String()}, nil
|
||||
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
||||
}
|
||||
|
||||
func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
@@ -1905,6 +2234,8 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
NewSkillCommand(),
|
||||
NewMCPCommand(),
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
||||
@@ -291,6 +291,31 @@ Weigh anchor!
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("min version", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
if err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "test",
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
Requires: "0.14.0",
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := ` Model
|
||||
architecture test
|
||||
parameters 7B
|
||||
quantization FP16
|
||||
requires 0.14.0
|
||||
|
||||
`
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteHandler(t *testing.T) {
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
# eval
|
||||
|
||||
Evaluation tool for testing Ollama models.
|
||||
|
||||
## Usage
|
||||
|
||||
Run all tests:
|
||||
|
||||
```bash
|
||||
go run . -model llama3.2:latest
|
||||
```
|
||||
|
||||
Run specific suite:
|
||||
|
||||
```bash
|
||||
go run . -model llama3.2:latest -suite tool-calling-basic -v
|
||||
```
|
||||
|
||||
List available suites:
|
||||
|
||||
```bash
|
||||
go run . -list
|
||||
```
|
||||
|
||||
## Adding Tests
|
||||
|
||||
Edit `suites.go` to add new test suites. Each test needs:
|
||||
|
||||
- `Name`: test identifier
|
||||
- `Prompt`: what to send to the model
|
||||
- `Check`: function to validate the response
|
||||
|
||||
Example:
|
||||
|
||||
```go
|
||||
{
|
||||
Name: "my-test",
|
||||
Prompt: "What is 2+2?",
|
||||
Check: Contains("4"),
|
||||
}
|
||||
```
|
||||
|
||||
Available check functions:
|
||||
|
||||
- `HasResponse()` - response is non-empty
|
||||
- `Contains(s)` - response contains substring
|
||||
- `CallsTool(name)` - model called specific tool
|
||||
- `NoTools()` - model called no tools
|
||||
- `MinTools(n)` - model called at least n tools
|
||||
- `All(checks...)` - all checks pass
|
||||
151
cmd/eval/eval.go
151
cmd/eval/eval.go
@@ -1,151 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Test is a single evaluation test
|
||||
type Test struct {
|
||||
Name string
|
||||
Prompt string
|
||||
System string
|
||||
Tools []api.Tool
|
||||
Think bool
|
||||
Options map[string]any
|
||||
Check func(response string, tools []api.ToolCall) bool
|
||||
}
|
||||
|
||||
// Suite is a collection of tests
|
||||
type Suite struct {
|
||||
Name string
|
||||
Tests []Test
|
||||
}
|
||||
|
||||
// Result holds test execution results
|
||||
type Result struct {
|
||||
Name string
|
||||
Passed bool
|
||||
Error error
|
||||
Duration time.Duration
|
||||
Response string
|
||||
Tools []string
|
||||
ToolCalls []api.ToolCall
|
||||
Thinking bool
|
||||
}
|
||||
|
||||
// Run executes a test against a model
|
||||
func Run(ctx context.Context, client *api.Client, model string, test Test) Result {
|
||||
result := Result{Name: test.Name}
|
||||
|
||||
req := &api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: test.Prompt},
|
||||
},
|
||||
Options: test.Options,
|
||||
}
|
||||
|
||||
if test.System != "" {
|
||||
req.Messages = append([]api.Message{
|
||||
{Role: "system", Content: test.System},
|
||||
}, req.Messages...)
|
||||
}
|
||||
|
||||
if len(test.Tools) > 0 {
|
||||
req.Tools = test.Tools
|
||||
}
|
||||
|
||||
if test.Think {
|
||||
req.Think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
|
||||
var resp strings.Builder
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
start := time.Now()
|
||||
err := client.Chat(ctx, req, func(r api.ChatResponse) error {
|
||||
resp.WriteString(r.Message.Content)
|
||||
if r.Message.Thinking != "" {
|
||||
result.Thinking = true
|
||||
}
|
||||
toolCalls = append(toolCalls, r.Message.ToolCalls...)
|
||||
return nil
|
||||
})
|
||||
result.Duration = time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
result.Error = err
|
||||
return result
|
||||
}
|
||||
|
||||
result.Response = resp.String()
|
||||
result.Tools = uniqueToolNames(toolCalls)
|
||||
result.ToolCalls = toolCalls
|
||||
result.Passed = test.Check(result.Response, toolCalls)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func uniqueToolNames(calls []api.ToolCall) []string {
|
||||
seen := make(map[string]bool)
|
||||
var names []string
|
||||
for _, c := range calls {
|
||||
if !seen[c.Function.Name] {
|
||||
seen[c.Function.Name] = true
|
||||
names = append(names, c.Function.Name)
|
||||
}
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// Check functions for common test patterns
|
||||
|
||||
func HasResponse() func(string, []api.ToolCall) bool {
|
||||
return func(resp string, _ []api.ToolCall) bool {
|
||||
return strings.TrimSpace(resp) != ""
|
||||
}
|
||||
}
|
||||
|
||||
func Contains(s string) func(string, []api.ToolCall) bool {
|
||||
return func(resp string, _ []api.ToolCall) bool {
|
||||
return strings.Contains(strings.ToLower(resp), strings.ToLower(s))
|
||||
}
|
||||
}
|
||||
|
||||
func CallsTool(name string) func(string, []api.ToolCall) bool {
|
||||
return func(_ string, tools []api.ToolCall) bool {
|
||||
for _, t := range tools {
|
||||
if t.Function.Name == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func NoTools() func(string, []api.ToolCall) bool {
|
||||
return func(_ string, tools []api.ToolCall) bool {
|
||||
return len(tools) == 0
|
||||
}
|
||||
}
|
||||
|
||||
func MinTools(n int) func(string, []api.ToolCall) bool {
|
||||
return func(_ string, tools []api.ToolCall) bool {
|
||||
return len(tools) >= n
|
||||
}
|
||||
}
|
||||
|
||||
func All(checks ...func(string, []api.ToolCall) bool) func(string, []api.ToolCall) bool {
|
||||
return func(resp string, tools []api.ToolCall) bool {
|
||||
for _, check := range checks {
|
||||
if !check(resp, tools) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
217
cmd/eval/main.go
217
cmd/eval/main.go
@@ -1,217 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
model := flag.String("model", "", "model to evaluate")
|
||||
suite := flag.String("suite", "", "comma-separated list of suites to run (empty runs all)")
|
||||
list := flag.Bool("list", false, "list available suites")
|
||||
verbose := flag.Bool("v", false, "verbose output")
|
||||
timeout := flag.Int("timeout", 60, "timeout per test in seconds")
|
||||
export := flag.String("export", "eval-results.json", "export results to file")
|
||||
flag.Parse()
|
||||
|
||||
if *list {
|
||||
for _, s := range suites {
|
||||
fmt.Printf("%s (%d tests)\n", s.Name, len(s.Tests))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if *model == "" {
|
||||
fmt.Fprintf(os.Stderr, "error: -model parameter is required\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
if err := client.Heartbeat(ctx); err != nil {
|
||||
cancel()
|
||||
fmt.Fprintf(os.Stderr, "error: cannot connect to ollama\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
|
||||
selected := suites
|
||||
if *suite != "" {
|
||||
suiteNames := strings.Split(*suite, ",")
|
||||
selected = []Suite{}
|
||||
var notFound []string
|
||||
|
||||
for _, name := range suiteNames {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, s := range suites {
|
||||
if s.Name == name {
|
||||
selected = append(selected, s)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
notFound = append(notFound, name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(notFound) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "error: suite(s) not found: %s\n", strings.Join(notFound, ", "))
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
var results []Result
|
||||
for _, s := range selected {
|
||||
if *verbose {
|
||||
fmt.Printf("\n%s (%d tests)\n", s.Name, len(s.Tests))
|
||||
}
|
||||
for i, test := range s.Tests {
|
||||
if test.Options == nil {
|
||||
test.Options = map[string]any{"temperature": 0.1}
|
||||
}
|
||||
if test.Check == nil {
|
||||
test.Check = HasResponse()
|
||||
}
|
||||
|
||||
if *verbose {
|
||||
fmt.Printf(" [%d/%d] %s... ", i+1, len(s.Tests), test.Name)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*timeout)*time.Second)
|
||||
result := Run(ctx, client, *model, test)
|
||||
cancel()
|
||||
|
||||
results = append(results, result)
|
||||
|
||||
if *verbose {
|
||||
if result.Error != nil {
|
||||
fmt.Printf("ERROR: %v\n", result.Error)
|
||||
} else if result.Passed {
|
||||
fmt.Printf("PASS (%.2fs)", result.Duration.Seconds())
|
||||
if len(result.Tools) > 0 || result.Thinking {
|
||||
fmt.Printf(" [")
|
||||
if len(result.Tools) > 0 {
|
||||
fmt.Printf("tools: %s", strings.Join(result.Tools, ","))
|
||||
}
|
||||
if result.Thinking {
|
||||
if len(result.Tools) > 0 {
|
||||
fmt.Printf(", ")
|
||||
}
|
||||
fmt.Printf("thinking")
|
||||
}
|
||||
fmt.Printf("]")
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
// Print tool calls with details
|
||||
if len(result.ToolCalls) > 0 {
|
||||
fmt.Printf(" Tool Calls:\n")
|
||||
for _, tc := range result.ToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Function.Arguments)
|
||||
fmt.Printf(" - %s: %s\n", tc.Function.Name, string(argsJSON))
|
||||
}
|
||||
}
|
||||
|
||||
// Print response if there is one
|
||||
if result.Response != "" {
|
||||
fmt.Printf(" Response: %s\n", result.Response)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("FAIL (%.2fs)\n", result.Duration.Seconds())
|
||||
|
||||
// Print tool calls with details even on failure
|
||||
if len(result.ToolCalls) > 0 {
|
||||
fmt.Printf(" Tool Calls:\n")
|
||||
for _, tc := range result.ToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Function.Arguments)
|
||||
fmt.Printf(" - %s: %s\n", tc.Function.Name, string(argsJSON))
|
||||
}
|
||||
}
|
||||
|
||||
// Print response even on failure
|
||||
if result.Response != "" {
|
||||
fmt.Printf(" Response: %s\n", result.Response)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
printSummary(results)
|
||||
|
||||
if *export != "" {
|
||||
if err := writeJSON(*export, results); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: export failed: %v\n", err)
|
||||
} else if *verbose {
|
||||
fmt.Printf("\nResults: %s\n", *export)
|
||||
}
|
||||
}
|
||||
|
||||
if anyFailed(results) {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func printSummary(results []Result) {
|
||||
var passed, failed, errors int
|
||||
for _, r := range results {
|
||||
if r.Error != nil {
|
||||
errors++
|
||||
} else if r.Passed {
|
||||
passed++
|
||||
} else {
|
||||
failed++
|
||||
}
|
||||
}
|
||||
|
||||
total := len(results)
|
||||
rate := 0.0
|
||||
if total > 0 {
|
||||
rate = float64(passed) / float64(total) * 100
|
||||
}
|
||||
|
||||
fmt.Printf("\n%d/%d passed (%.1f%%)", passed, total, rate)
|
||||
if errors > 0 {
|
||||
fmt.Printf(", %d errors", errors)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func anyFailed(results []Result) bool {
|
||||
for _, r := range results {
|
||||
if !r.Passed || r.Error != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeJSON(path string, results []Result) error {
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
return enc.Encode(results)
|
||||
}
|
||||
@@ -1,178 +0,0 @@
|
||||
package main
|
||||
|
||||
import "github.com/ollama/ollama/api"
|
||||
|
||||
var suites = []Suite{
|
||||
{
|
||||
Name: "basic-qa",
|
||||
Tests: []Test{
|
||||
{
|
||||
Name: "simple-math",
|
||||
Prompt: "What is 2+2? Reply with just the number.",
|
||||
Check: Contains("4"),
|
||||
},
|
||||
{
|
||||
Name: "capital-city",
|
||||
Prompt: "What is the capital of France? Reply with just the city name.",
|
||||
Check: Contains("Paris"),
|
||||
},
|
||||
{
|
||||
Name: "greeting",
|
||||
Prompt: "Say hello",
|
||||
Check: HasResponse(),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "reasoning",
|
||||
Tests: []Test{
|
||||
{
|
||||
Name: "logic-puzzle",
|
||||
Prompt: "If all roses are flowers and some flowers fade quickly, can we conclude that some roses fade quickly? Answer yes or no.",
|
||||
Check: Contains("no"),
|
||||
},
|
||||
{
|
||||
Name: "counting",
|
||||
Prompt: "How many letters are in the word 'HELLO'?",
|
||||
Check: Contains("5"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "instruction-following",
|
||||
Tests: []Test{
|
||||
{
|
||||
Name: "json-output",
|
||||
Prompt: "Reply with a JSON object containing a 'status' field set to 'ok'.",
|
||||
Check: All(Contains("status"), Contains("ok")),
|
||||
},
|
||||
{
|
||||
Name: "system-prompt",
|
||||
Prompt: "What is your name?",
|
||||
System: "You are a helpful assistant named TestBot. When asked your name, always respond with 'TestBot'.",
|
||||
Check: Contains("TestBot"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "tool-calling-basic",
|
||||
Tests: []Test{
|
||||
{
|
||||
Name: "single-tool",
|
||||
Prompt: "What's the weather like in San Francisco?",
|
||||
Tools: []api.Tool{weatherTool},
|
||||
Check: CallsTool("get_weather"),
|
||||
},
|
||||
{
|
||||
Name: "tool-selection",
|
||||
Prompt: "What time is it in Tokyo?",
|
||||
Tools: []api.Tool{weatherTool, timeTool},
|
||||
Check: CallsTool("get_time"),
|
||||
},
|
||||
{
|
||||
Name: "no-tool-needed",
|
||||
Prompt: "What is 2+2?",
|
||||
Tools: []api.Tool{weatherTool, timeTool},
|
||||
Check: NoTools(),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "tool-calling-advanced",
|
||||
Tests: []Test{
|
||||
{
|
||||
Name: "parallel-calls",
|
||||
Prompt: "Get the weather in both New York and Los Angeles.",
|
||||
Tools: []api.Tool{weatherTool},
|
||||
Check: All(CallsTool("get_weather"), MinTools(2)),
|
||||
},
|
||||
{
|
||||
Name: "multi-param",
|
||||
Prompt: "Search for Italian restaurants with prices between $20 and $40.",
|
||||
Tools: []api.Tool{restaurantTool},
|
||||
Check: CallsTool("search_restaurants"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "tool-calling-thinking",
|
||||
Tests: []Test{
|
||||
{
|
||||
Name: "thinking-before-tool",
|
||||
Prompt: "I need to know the weather in Paris before I decide what to pack.",
|
||||
Tools: []api.Tool{weatherTool},
|
||||
Think: true,
|
||||
Check: CallsTool("get_weather"),
|
||||
},
|
||||
{
|
||||
Name: "thinking-multi-tool",
|
||||
Prompt: "I'm planning a trip to London. I need to know what time it is there and what the weather is like.",
|
||||
Tools: []api.Tool{weatherTool, timeTool},
|
||||
Think: true,
|
||||
Check: MinTools(1),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var weatherTool = api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var timeTool = api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_time",
|
||||
Description: "Get the current time in a timezone",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"timezone"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"timezone": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The timezone name",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var restaurantTool = api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "search_restaurants",
|
||||
Description: "Search for restaurants",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"cuisine"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"cuisine": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "Type of cuisine",
|
||||
},
|
||||
"min_price": {
|
||||
Type: api.PropertyType{"number"},
|
||||
Description: "Minimum price",
|
||||
},
|
||||
"max_price": {
|
||||
Type: api.PropertyType{"number"},
|
||||
Description: "Maximum price",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -34,6 +34,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /skills Show available skills")
|
||||
fmt.Fprintln(os.Stderr, " /skill Add or remove skills dynamically")
|
||||
fmt.Fprintln(os.Stderr, " /mcp Show/add/remove MCP servers")
|
||||
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
||||
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context")
|
||||
@@ -443,6 +446,411 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
} else {
|
||||
usageShow()
|
||||
}
|
||||
case strings.HasPrefix(line, "/skill "):
|
||||
args := strings.Fields(line)
|
||||
if len(args) < 2 {
|
||||
fmt.Fprintln(os.Stderr, "Usage:")
|
||||
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
|
||||
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
|
||||
fmt.Fprintln(os.Stderr, " /skill list List current skills")
|
||||
continue
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "add":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /skill add <path>")
|
||||
continue
|
||||
}
|
||||
skillPath := args[2]
|
||||
|
||||
// Expand ~ to home directory
|
||||
if strings.HasPrefix(skillPath, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
fmt.Printf("Error expanding path: %v\n", err)
|
||||
continue
|
||||
}
|
||||
skillPath = filepath.Join(home, skillPath[1:])
|
||||
}
|
||||
|
||||
// Make absolute
|
||||
absPath, err := filepath.Abs(skillPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error resolving path: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify SKILL.md exists
|
||||
skillMdPath := filepath.Join(absPath, "SKILL.md")
|
||||
if _, err := os.Stat(skillMdPath); err != nil {
|
||||
fmt.Printf("Error: %s does not contain SKILL.md\n", skillPath)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract skill name from SKILL.md
|
||||
content, err := os.ReadFile(skillMdPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error reading SKILL.md: %v\n", err)
|
||||
continue
|
||||
}
|
||||
skillName, _ := extractSkillMetadata(string(content))
|
||||
if skillName == "" {
|
||||
skillName = filepath.Base(absPath)
|
||||
}
|
||||
|
||||
// Check if already added
|
||||
for _, s := range opts.Skills {
|
||||
if s.Name == skillName {
|
||||
fmt.Printf("Skill '%s' is already loaded\n", skillName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Add to skills (using path as Name, no digest for local skills)
|
||||
opts.Skills = append(opts.Skills, api.SkillRef{Name: absPath})
|
||||
opts.IsAgent = true // Enable agent mode if not already
|
||||
fmt.Printf("Added skill '%s' from %s\n", skillName, skillPath)
|
||||
|
||||
case "remove", "rm":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /skill remove <name>")
|
||||
continue
|
||||
}
|
||||
skillName := args[2]
|
||||
|
||||
found := false
|
||||
newSkills := make([]api.SkillRef, 0, len(opts.Skills))
|
||||
for _, s := range opts.Skills {
|
||||
// Match by name or by path basename
|
||||
name := s.Name
|
||||
if strings.Contains(name, string(os.PathSeparator)) {
|
||||
name = filepath.Base(name)
|
||||
}
|
||||
if name == skillName || s.Name == skillName {
|
||||
found = true
|
||||
fmt.Printf("Removed skill '%s'\n", skillName)
|
||||
} else {
|
||||
newSkills = append(newSkills, s)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
fmt.Printf("Skill '%s' not found\n", skillName)
|
||||
} else {
|
||||
opts.Skills = newSkills
|
||||
}
|
||||
|
||||
case "list", "ls":
|
||||
if len(opts.Skills) == 0 {
|
||||
fmt.Println("No skills loaded in this session.")
|
||||
} else {
|
||||
fmt.Println("Skills loaded in this session:")
|
||||
for _, skill := range opts.Skills {
|
||||
if skill.Digest != "" {
|
||||
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
|
||||
} else {
|
||||
// For local paths, show basename
|
||||
name := skill.Name
|
||||
if strings.Contains(name, string(os.PathSeparator)) {
|
||||
name = filepath.Base(name) + " (local: " + skill.Name + ")"
|
||||
}
|
||||
fmt.Printf(" %s\n", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
default:
|
||||
fmt.Printf("Unknown skill command '%s'. Use /skill add, /skill remove, or /skill list\n", args[1])
|
||||
}
|
||||
continue
|
||||
|
||||
case strings.HasPrefix(line, "/skills"):
|
||||
// Show skills from model (bundled) + session skills
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: opts.Model,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model info")
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine model skills with session skills
|
||||
allSkills := make([]api.SkillRef, 0)
|
||||
allSkills = append(allSkills, resp.Skills...)
|
||||
|
||||
// Add session skills that aren't already in model skills
|
||||
for _, sessionSkill := range opts.Skills {
|
||||
found := false
|
||||
for _, modelSkill := range resp.Skills {
|
||||
if modelSkill.Name == sessionSkill.Name || modelSkill.Digest == sessionSkill.Digest {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
allSkills = append(allSkills, sessionSkill)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allSkills) == 0 {
|
||||
fmt.Println("No skills available.")
|
||||
} else {
|
||||
fmt.Println("Available Skills:")
|
||||
for _, skill := range allSkills {
|
||||
if skill.Digest != "" {
|
||||
fmt.Printf(" %s (%s)\n", skill.Name, skill.Digest[:19])
|
||||
} else {
|
||||
name := skill.Name
|
||||
if strings.Contains(name, string(os.PathSeparator)) {
|
||||
name = filepath.Base(name) + " (session)"
|
||||
}
|
||||
fmt.Printf(" %s\n", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
continue
|
||||
|
||||
case strings.HasPrefix(line, "/mcp"):
|
||||
args := strings.Fields(line)
|
||||
|
||||
// If just "/mcp" with no args, show all MCP servers
|
||||
if len(args) == 1 {
|
||||
// Show MCPs from model (bundled) + global config
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
req := &api.ShowRequest{
|
||||
Name: opts.Model,
|
||||
}
|
||||
resp, err := client.Show(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get model info")
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine model MCPs with global config MCPs
|
||||
allMCPs := make([]api.MCPRef, 0)
|
||||
allMCPs = append(allMCPs, resp.MCPs...)
|
||||
|
||||
// Load global config
|
||||
globalConfig, _ := loadMCPConfig()
|
||||
globalMCPNames := make(map[string]bool)
|
||||
|
||||
if globalConfig != nil {
|
||||
for name, srv := range globalConfig.MCPServers {
|
||||
// Check if already in model MCPs
|
||||
found := false
|
||||
for _, modelMCP := range resp.MCPs {
|
||||
if modelMCP.Name == name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
allMCPs = append(allMCPs, api.MCPRef{
|
||||
Name: name,
|
||||
Command: srv.Command,
|
||||
Args: srv.Args,
|
||||
Env: srv.Env,
|
||||
Type: srv.Type,
|
||||
})
|
||||
}
|
||||
globalMCPNames[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(allMCPs) == 0 {
|
||||
fmt.Println("No MCP servers available.")
|
||||
fmt.Println("Use '/mcp add <name> <command> [args...]' to add one.")
|
||||
} else {
|
||||
fmt.Println("Available MCP Servers:")
|
||||
for _, mcp := range allMCPs {
|
||||
cmdLine := mcp.Command
|
||||
if len(mcp.Args) > 0 {
|
||||
cmdLine += " " + strings.Join(mcp.Args, " ")
|
||||
}
|
||||
source := ""
|
||||
disabled := ""
|
||||
// Check if it's from model or global config
|
||||
isFromModel := false
|
||||
for _, modelMCP := range resp.MCPs {
|
||||
if modelMCP.Name == mcp.Name {
|
||||
isFromModel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFromModel {
|
||||
source = " (model)"
|
||||
} else if globalMCPNames[mcp.Name] {
|
||||
source = " (global)"
|
||||
// Check if disabled
|
||||
if srv, ok := globalConfig.MCPServers[mcp.Name]; ok && srv.Disabled {
|
||||
disabled = " [disabled]"
|
||||
}
|
||||
}
|
||||
fmt.Printf(" %s: %s%s%s\n", mcp.Name, cmdLine, source, disabled)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
continue
|
||||
}
|
||||
|
||||
switch args[1] {
|
||||
case "add":
|
||||
if len(args) < 4 {
|
||||
fmt.Println("Usage: /mcp add <name> <command> [args...]")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
mcpCommand := args[3]
|
||||
mcpArgs := args[4:]
|
||||
|
||||
// Load global config
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if already exists
|
||||
if _, exists := config.MCPServers[mcpName]; exists {
|
||||
fmt.Printf("Warning: overwriting existing MCP server '%s'\n", mcpName)
|
||||
}
|
||||
|
||||
// Add to global config
|
||||
config.MCPServers[mcpName] = MCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: mcpCommand,
|
||||
Args: mcpArgs,
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
cmdLine := mcpCommand
|
||||
if len(mcpArgs) > 0 {
|
||||
cmdLine += " " + strings.Join(mcpArgs, " ")
|
||||
}
|
||||
fmt.Printf("Added MCP server '%s' (%s) to %s\n", mcpName, cmdLine, getMCPConfigPath())
|
||||
fmt.Println("Note: MCP server will be started on next message.")
|
||||
|
||||
case "remove", "rm":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /mcp remove <name>")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
|
||||
// Load global config
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := config.MCPServers[mcpName]; !exists {
|
||||
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
delete(config.MCPServers, mcpName)
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Removed MCP server '%s' from %s\n", mcpName, getMCPConfigPath())
|
||||
fmt.Println("Note: Changes will take effect on next message.")
|
||||
|
||||
case "disable":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /mcp disable <name>")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
srv, exists := config.MCPServers[mcpName]
|
||||
if !exists {
|
||||
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
if srv.Disabled {
|
||||
fmt.Printf("MCP server '%s' is already disabled\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Disabled = true
|
||||
config.MCPServers[mcpName] = srv
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Disabled MCP server '%s'\n", mcpName)
|
||||
fmt.Println("Note: Changes will take effect on next message.")
|
||||
|
||||
case "enable":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /mcp enable <name>")
|
||||
continue
|
||||
}
|
||||
mcpName := args[2]
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
srv, exists := config.MCPServers[mcpName]
|
||||
if !exists {
|
||||
fmt.Printf("MCP server '%s' not found in global config\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
if !srv.Disabled {
|
||||
fmt.Printf("MCP server '%s' is already enabled\n", mcpName)
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Disabled = false
|
||||
config.MCPServers[mcpName] = srv
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
fmt.Printf("Error saving MCP config: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Enabled MCP server '%s'\n", mcpName)
|
||||
fmt.Println("Note: Changes will take effect on next message.")
|
||||
|
||||
default:
|
||||
fmt.Printf("Unknown mcp command '%s'. Use /mcp, /mcp add, /mcp remove, /mcp disable, or /mcp enable\n", args[1])
|
||||
}
|
||||
continue
|
||||
|
||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
@@ -451,6 +859,20 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
usageSet()
|
||||
case "show", "/show":
|
||||
usageShow()
|
||||
case "skill", "/skill":
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /skill add <path> Add a skill from local path")
|
||||
fmt.Fprintln(os.Stderr, " /skill remove <name> Remove a skill by name")
|
||||
fmt.Fprintln(os.Stderr, " /skill list List current session skills")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
case "mcp", "/mcp":
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /mcp Show all MCP servers")
|
||||
fmt.Fprintln(os.Stderr, " /mcp add <name> <command> [args...] Add an MCP server to global config")
|
||||
fmt.Fprintln(os.Stderr, " /mcp remove <name> Remove an MCP server from global config")
|
||||
fmt.Fprintln(os.Stderr, " /mcp disable <name> Disable an MCP server (keep in config)")
|
||||
fmt.Fprintln(os.Stderr, " /mcp enable <name> Re-enable a disabled MCP server")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
case "shortcut", "shortcuts":
|
||||
usageShortcuts()
|
||||
}
|
||||
|
||||
545
cmd/mcp.go
Normal file
545
cmd/mcp.go
Normal file
@@ -0,0 +1,545 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
mcpInitTimeout = 30 * time.Second
|
||||
mcpCallTimeout = 60 * time.Second
|
||||
mcpShutdownTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// JSON-RPC types
|
||||
type jsonrpcRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type jsonrpcResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int `json:"id"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *jsonrpcError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type jsonrpcError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// MCP protocol types
|
||||
type mcpInitializeParams struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities map[string]any `json:"capabilities"`
|
||||
ClientInfo mcpClientInfo `json:"clientInfo"`
|
||||
}
|
||||
|
||||
type mcpClientInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type mcpInitializeResult struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities mcpCapabilities `json:"capabilities"`
|
||||
ServerInfo mcpServerInfo `json:"serverInfo"`
|
||||
}
|
||||
|
||||
type mcpCapabilities struct {
|
||||
Tools *mcpToolsCapability `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type mcpToolsCapability struct {
|
||||
ListChanged bool `json:"listChanged,omitempty"`
|
||||
}
|
||||
|
||||
type mcpServerInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type mcpTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema mcpToolInputSchema `json:"inputSchema"`
|
||||
}
|
||||
|
||||
type mcpToolInputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]any `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type mcpToolsListResult struct {
|
||||
Tools []mcpTool `json:"tools"`
|
||||
}
|
||||
|
||||
type mcpToolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type mcpToolCallResult struct {
|
||||
Content []mcpContent `json:"content"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
type mcpContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// mcpServer represents a running MCP server process
|
||||
type mcpServer struct {
|
||||
ref api.MCPRef
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
stdout *bufio.Reader
|
||||
stderr io.ReadCloser
|
||||
tools []mcpTool
|
||||
mu sync.Mutex
|
||||
nextID int
|
||||
started bool
|
||||
}
|
||||
|
||||
// mcpManager manages multiple MCP servers for an agent session
|
||||
type mcpManager struct {
|
||||
servers map[string]*mcpServer
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// newMCPManager creates a new MCP manager
|
||||
func newMCPManager() *mcpManager {
|
||||
return &mcpManager{
|
||||
servers: make(map[string]*mcpServer),
|
||||
}
|
||||
}
|
||||
|
||||
// loadMCPsFromRefs initializes MCP servers from refs
|
||||
func (m *mcpManager) loadMCPsFromRefs(refs []api.MCPRef) error {
|
||||
if len(refs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ref := range refs {
|
||||
if err := m.addServer(ref); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to initialize MCP server %q: %v\n", ref.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addServer adds and starts an MCP server
|
||||
func (m *mcpManager) addServer(ref api.MCPRef) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.servers[ref.Name]; exists {
|
||||
return fmt.Errorf("MCP server %q already exists", ref.Name)
|
||||
}
|
||||
|
||||
srv := &mcpServer{
|
||||
ref: ref,
|
||||
nextID: 1,
|
||||
}
|
||||
|
||||
if err := srv.start(); err != nil {
|
||||
return fmt.Errorf("starting MCP server: %w", err)
|
||||
}
|
||||
|
||||
m.servers[ref.Name] = srv
|
||||
return nil
|
||||
}
|
||||
|
||||
// start starts the MCP server process
|
||||
func (s *mcpServer) start() error {
|
||||
s.mu.Lock()
|
||||
|
||||
if s.started {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
s.cmd = exec.Command(s.ref.Command, s.ref.Args...)
|
||||
|
||||
// Set environment
|
||||
s.cmd.Env = os.Environ()
|
||||
for k, v := range s.ref.Env {
|
||||
s.cmd.Env = append(s.cmd.Env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
|
||||
var err error
|
||||
s.stdin, err = s.cmd.StdinPipe()
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("creating stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := s.cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("creating stdout pipe: %w", err)
|
||||
}
|
||||
s.stdout = bufio.NewReader(stdout)
|
||||
|
||||
s.stderr, err = s.cmd.StderrPipe()
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("creating stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// Start stderr reader goroutine (discard stderr for now)
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(s.stderr)
|
||||
for scanner.Scan() {
|
||||
_ = scanner.Text()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := s.cmd.Start(); err != nil {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("starting process: %w", err)
|
||||
}
|
||||
|
||||
s.started = true
|
||||
s.mu.Unlock() // Release lock before calling initialize/listTools which use the mutex
|
||||
|
||||
// Initialize the server
|
||||
if err := s.initialize(); err != nil {
|
||||
s.stop()
|
||||
return fmt.Errorf("initializing MCP server: %w", err)
|
||||
}
|
||||
|
||||
// Get available tools
|
||||
if err := s.listTools(); err != nil {
|
||||
s.stop()
|
||||
return fmt.Errorf("listing tools: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initialize sends the MCP initialize request
|
||||
func (s *mcpServer) initialize() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mcpInitTimeout)
|
||||
defer cancel()
|
||||
|
||||
params := mcpInitializeParams{
|
||||
ProtocolVersion: "2024-11-05",
|
||||
Capabilities: map[string]any{},
|
||||
ClientInfo: mcpClientInfo{
|
||||
Name: "ollama",
|
||||
Version: "0.1.0",
|
||||
},
|
||||
}
|
||||
|
||||
var result mcpInitializeResult
|
||||
if err := s.call(ctx, "initialize", params, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send initialized notification
|
||||
return s.notify("notifications/initialized", nil)
|
||||
}
|
||||
|
||||
// listTools fetches the available tools from the MCP server
|
||||
func (s *mcpServer) listTools() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mcpInitTimeout)
|
||||
defer cancel()
|
||||
|
||||
var result mcpToolsListResult
|
||||
if err := s.call(ctx, "tools/list", nil, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.tools = result.Tools
|
||||
return nil
|
||||
}
|
||||
|
||||
// call sends a JSON-RPC request and waits for the response
|
||||
func (s *mcpServer) call(ctx context.Context, method string, params any, result any) error {
|
||||
s.mu.Lock()
|
||||
id := s.nextID
|
||||
s.nextID++
|
||||
s.mu.Unlock()
|
||||
|
||||
req := jsonrpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
ID: id,
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
|
||||
reqBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling request: %w", err)
|
||||
}
|
||||
|
||||
// Send request
|
||||
s.mu.Lock()
|
||||
_, err = s.stdin.Write(append(reqBytes, '\n'))
|
||||
s.mu.Unlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing request: %w", err)
|
||||
}
|
||||
|
||||
// Read response with timeout
|
||||
respCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
s.mu.Lock()
|
||||
line, err := s.stdout.ReadBytes('\n')
|
||||
s.mu.Unlock()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
respCh <- line
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-errCh:
|
||||
return fmt.Errorf("reading response: %w", err)
|
||||
case line := <-respCh:
|
||||
var resp jsonrpcResponse
|
||||
if err := json.Unmarshal(line, &resp); err != nil {
|
||||
return fmt.Errorf("unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("MCP error %d: %s", resp.Error.Code, resp.Error.Message)
|
||||
}
|
||||
|
||||
if result != nil && len(resp.Result) > 0 {
|
||||
if err := json.Unmarshal(resp.Result, result); err != nil {
|
||||
return fmt.Errorf("unmarshaling result: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// notify sends a JSON-RPC notification (no response expected)
|
||||
func (s *mcpServer) notify(method string, params any) error {
|
||||
req := jsonrpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
|
||||
reqBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling notification: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, err := s.stdin.Write(append(reqBytes, '\n')); err != nil {
|
||||
return fmt.Errorf("writing notification: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callTool executes a tool call on the MCP server
|
||||
func (s *mcpServer) callTool(ctx context.Context, name string, arguments map[string]any) (string, error) {
|
||||
params := mcpToolCallParams{
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
}
|
||||
|
||||
var result mcpToolCallResult
|
||||
if err := s.call(ctx, "tools/call", params, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Concatenate text content
|
||||
var sb strings.Builder
|
||||
for _, content := range result.Content {
|
||||
if content.Type == "text" {
|
||||
sb.WriteString(content.Text)
|
||||
}
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
return sb.String(), errors.New(sb.String())
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// stop shuts down the MCP server
|
||||
func (s *mcpServer) stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close stdin to signal shutdown
|
||||
if s.stdin != nil {
|
||||
s.stdin.Close()
|
||||
}
|
||||
|
||||
// Wait for process with timeout
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- s.cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(mcpShutdownTimeout):
|
||||
s.cmd.Process.Kill()
|
||||
case <-done:
|
||||
}
|
||||
|
||||
s.started = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tools returns all tools from all MCP servers as api.Tools
|
||||
func (m *mcpManager) Tools() api.Tools {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var tools api.Tools
|
||||
|
||||
for serverName, srv := range m.servers {
|
||||
for _, t := range srv.tools {
|
||||
// Namespace tool names: mcp_{servername}_{toolname}
|
||||
namespacedName := fmt.Sprintf("mcp_%s_%s", serverName, t.Name)
|
||||
|
||||
tool := api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: namespacedName,
|
||||
Description: t.Description,
|
||||
Parameters: convertMCPSchema(t.InputSchema),
|
||||
},
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
// convertMCPSchema converts MCP input schema to api.ToolFunctionParameters
|
||||
func convertMCPSchema(schema mcpToolInputSchema) api.ToolFunctionParameters {
|
||||
params := api.ToolFunctionParameters{
|
||||
Type: schema.Type,
|
||||
Required: schema.Required,
|
||||
Properties: make(map[string]api.ToolProperty),
|
||||
}
|
||||
|
||||
for name, prop := range schema.Properties {
|
||||
if propMap, ok := prop.(map[string]any); ok {
|
||||
tp := api.ToolProperty{}
|
||||
if t, ok := propMap["type"].(string); ok {
|
||||
tp.Type = api.PropertyType{t}
|
||||
}
|
||||
if d, ok := propMap["description"].(string); ok {
|
||||
tp.Description = d
|
||||
}
|
||||
params.Properties[name] = tp
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// RunToolCall routes a tool call to the appropriate MCP server
|
||||
func (m *mcpManager) RunToolCall(call api.ToolCall) (api.Message, bool, error) {
|
||||
name := call.Function.Name
|
||||
|
||||
// Check if this is an MCP tool (mcp_servername_toolname)
|
||||
if !strings.HasPrefix(name, "mcp_") {
|
||||
return api.Message{}, false, nil
|
||||
}
|
||||
|
||||
// Parse server name and tool name
|
||||
rest := strings.TrimPrefix(name, "mcp_")
|
||||
idx := strings.Index(rest, "_")
|
||||
if idx == -1 {
|
||||
return toolMessage(call, fmt.Sprintf("invalid MCP tool name: %s", name)), true, nil
|
||||
}
|
||||
|
||||
serverName := rest[:idx]
|
||||
toolName := rest[idx+1:]
|
||||
|
||||
m.mu.RLock()
|
||||
srv, ok := m.servers[serverName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return toolMessage(call, fmt.Sprintf("MCP server %q not found", serverName)), true, nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mcpCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := srv.callTool(ctx, toolName, call.Function.Arguments)
|
||||
if err != nil {
|
||||
return toolMessage(call, fmt.Sprintf("error: %v", err)), true, nil
|
||||
}
|
||||
|
||||
return toolMessage(call, result), true, nil
|
||||
}
|
||||
|
||||
// Shutdown stops all MCP servers
|
||||
func (m *mcpManager) Shutdown() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, srv := range m.servers {
|
||||
srv.stop()
|
||||
}
|
||||
|
||||
m.servers = make(map[string]*mcpServer)
|
||||
}
|
||||
|
||||
// ServerNames returns the names of all running MCP servers
|
||||
func (m *mcpManager) ServerNames() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(m.servers))
|
||||
for name := range m.servers {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// ToolCount returns the total number of tools across all servers
|
||||
func (m *mcpManager) ToolCount() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for _, srv := range m.servers {
|
||||
count += len(srv.tools)
|
||||
}
|
||||
return count
|
||||
}
|
||||
898
cmd/mcp_cmd.go
Normal file
898
cmd/mcp_cmd.go
Normal file
@@ -0,0 +1,898 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// MCPConfigFile represents the global MCP configuration file structure.
|
||||
type MCPConfigFile struct {
|
||||
MCPServers map[string]MCPServerConfig `json:"mcpServers"`
|
||||
}
|
||||
|
||||
// MCPServerConfig represents a single MCP server configuration.
|
||||
type MCPServerConfig struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
}
|
||||
|
||||
// getMCPConfigPath returns the path to the global MCP config file.
|
||||
func getMCPConfigPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "mcp.json")
|
||||
}
|
||||
|
||||
// loadMCPConfig loads the global MCP configuration file.
|
||||
func loadMCPConfig() (*MCPConfigFile, error) {
|
||||
configPath := getMCPConfigPath()
|
||||
if configPath == "" {
|
||||
return nil, fmt.Errorf("could not determine home directory")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Return empty config if file doesn't exist
|
||||
return &MCPConfigFile{
|
||||
MCPServers: make(map[string]MCPServerConfig),
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("reading config: %w", err)
|
||||
}
|
||||
|
||||
var config MCPConfigFile
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("parsing config: %w", err)
|
||||
}
|
||||
|
||||
if config.MCPServers == nil {
|
||||
config.MCPServers = make(map[string]MCPServerConfig)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// saveMCPConfig saves the global MCP configuration file.
|
||||
func saveMCPConfig(config *MCPConfigFile) error {
|
||||
configPath := getMCPConfigPath()
|
||||
if configPath == "" {
|
||||
return fmt.Errorf("could not determine home directory")
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return fmt.Errorf("creating config directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(configPath, data, 0o644); err != nil {
|
||||
return fmt.Errorf("writing config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPAddHandler handles the mcp add command.
|
||||
func MCPAddHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("usage: ollama mcp add NAME COMMAND [ARGS...]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
command := args[1]
|
||||
cmdArgs := args[2:]
|
||||
|
||||
// Load existing config
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
// Check if already exists
|
||||
if _, exists := config.MCPServers[name]; exists {
|
||||
fmt.Fprintf(os.Stderr, "Warning: overwriting existing MCP server '%s'\n", name)
|
||||
}
|
||||
|
||||
// Add the new server
|
||||
config.MCPServers[name] = MCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: command,
|
||||
Args: cmdArgs,
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
return fmt.Errorf("saving config: %w", err)
|
||||
}
|
||||
|
||||
configPath := getMCPConfigPath()
|
||||
fmt.Fprintf(os.Stderr, "Added MCP server '%s' to %s\n", name, configPath)
|
||||
fmt.Fprintf(os.Stderr, " Command: %s %s\n", command, strings.Join(cmdArgs, " "))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPRemoveGlobalHandler handles removing an MCP from global config.
|
||||
func MCPRemoveGlobalHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama mcp remove-global NAME [NAME...]")
|
||||
}
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
if _, exists := config.MCPServers[name]; !exists {
|
||||
fmt.Fprintf(os.Stderr, "MCP server '%s' not found in global config\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
delete(config.MCPServers, name)
|
||||
fmt.Fprintf(os.Stderr, "Removed MCP server '%s' from global config\n", name)
|
||||
}
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
return fmt.Errorf("saving config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPListGlobalHandler handles listing global MCP servers.
|
||||
func MCPListGlobalHandler(cmd *cobra.Command, args []string) error {
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
if len(config.MCPServers) == 0 {
|
||||
fmt.Println("No global MCP servers configured")
|
||||
fmt.Printf("Add one with: ollama mcp add NAME COMMAND [ARGS...]\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Global MCP servers (%s):\n\n", getMCPConfigPath())
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
||||
fmt.Fprintln(w, "NAME\tCOMMAND\tSTATUS")
|
||||
|
||||
for name, srv := range config.MCPServers {
|
||||
cmdLine := srv.Command
|
||||
if len(srv.Args) > 0 {
|
||||
cmdLine += " " + strings.Join(srv.Args, " ")
|
||||
}
|
||||
status := "enabled"
|
||||
if srv.Disabled {
|
||||
status = "disabled"
|
||||
}
|
||||
fmt.Fprintf(w, "%s\t%s\t%s\n", name, cmdLine, status)
|
||||
}
|
||||
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
// MCPDisableHandler handles disabling an MCP server in global config.
|
||||
func MCPDisableHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama mcp disable NAME [NAME...]")
|
||||
}
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
srv, exists := config.MCPServers[name]
|
||||
if !exists {
|
||||
fmt.Fprintf(os.Stderr, "MCP server '%s' not found in global config\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
if srv.Disabled {
|
||||
fmt.Fprintf(os.Stderr, "MCP server '%s' is already disabled\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Disabled = true
|
||||
config.MCPServers[name] = srv
|
||||
fmt.Fprintf(os.Stderr, "Disabled MCP server '%s'\n", name)
|
||||
}
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
return fmt.Errorf("saving config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPEnableHandler handles enabling an MCP server in global config.
|
||||
func MCPEnableHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama mcp enable NAME [NAME...]")
|
||||
}
|
||||
|
||||
config, err := loadMCPConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
srv, exists := config.MCPServers[name]
|
||||
if !exists {
|
||||
fmt.Fprintf(os.Stderr, "MCP server '%s' not found in global config\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
if !srv.Disabled {
|
||||
fmt.Fprintf(os.Stderr, "MCP server '%s' is already enabled\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
srv.Disabled = false
|
||||
config.MCPServers[name] = srv
|
||||
fmt.Fprintf(os.Stderr, "Enabled MCP server '%s'\n", name)
|
||||
}
|
||||
|
||||
if err := saveMCPConfig(config); err != nil {
|
||||
return fmt.Errorf("saving config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPPushHandler handles the mcp push command.
|
||||
func MCPPushHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return fmt.Errorf("usage: ollama mcp push NAME[:TAG] PATH")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
path := args[1]
|
||||
|
||||
// Expand path
|
||||
if strings.HasPrefix(path, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("expanding home directory: %w", err)
|
||||
}
|
||||
path = filepath.Join(home, path[1:])
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving path: %w", err)
|
||||
}
|
||||
|
||||
// Validate MCP directory - check for mcp.json, package.json, or any config file
|
||||
validFiles := []string{"mcp.json", "package.json", "server.py", "server.js", "main.py", "index.js"}
|
||||
found := false
|
||||
for _, vf := range validFiles {
|
||||
if _, err := os.Stat(filepath.Join(absPath, vf)); err == nil {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("MCP directory should contain one of: %s", strings.Join(validFiles, ", "))
|
||||
}
|
||||
|
||||
// Parse MCP name (will set Kind="mcp")
|
||||
n := server.ParseMCPName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid MCP name: %s", name)
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
// Create MCP layer
|
||||
displayName := n.DisplayShortest()
|
||||
status := fmt.Sprintf("Creating MCP layer for %s", displayName)
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
|
||||
layer, err := server.CreateMCPLayer(absPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating MCP layer: %w", err)
|
||||
}
|
||||
|
||||
spinner.Stop()
|
||||
|
||||
// Create MCP manifest
|
||||
manifest, configLayer, err := createMCPManifest(absPath, layer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating MCP manifest: %w", err)
|
||||
}
|
||||
|
||||
// Write manifest locally
|
||||
manifestPath, err := server.GetMCPManifestPath(n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting manifest path: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(manifestPath, manifestJSON, 0o644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "MCP %s created locally\n", displayName)
|
||||
fmt.Fprintf(os.Stderr, " Config: %s (%s)\n", configLayer.Digest, format.HumanBytes(configLayer.Size))
|
||||
fmt.Fprintf(os.Stderr, " Layer: %s (%s)\n", layer.Digest, format.HumanBytes(layer.Size))
|
||||
|
||||
// Push to registry
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
}
|
||||
|
||||
insecure, _ := cmd.Flags().GetBool("insecure")
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nPushing to registry...\n")
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
|
||||
p.Add(resp.Digest, bar)
|
||||
} else if resp.Status != "" {
|
||||
spinner := progress.NewSpinner(resp.Status)
|
||||
p.Add(resp.Status, spinner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &api.PushRequest{
|
||||
Model: displayName,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
if err := client.Push(context.Background(), req, fn); err != nil {
|
||||
// If push fails, still show success for local creation
|
||||
fmt.Fprintf(os.Stderr, "\nNote: Local MCP created but push failed: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "You can try pushing later with: ollama mcp push %s\n", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Successfully pushed %s\n", displayName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPPullHandler handles the mcp pull command.
|
||||
func MCPPullHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("usage: ollama mcp pull NAME[:TAG]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
n := server.ParseMCPName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid MCP name: %s", name)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
}
|
||||
|
||||
insecure, _ := cmd.Flags().GetBool("insecure")
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
|
||||
p.Add(resp.Digest, bar)
|
||||
} else if resp.Status != "" {
|
||||
spinner := progress.NewSpinner(resp.Status)
|
||||
p.Add(resp.Status, spinner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
req := &api.PullRequest{
|
||||
Model: displayName,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
if err := client.Pull(context.Background(), req, fn); err != nil {
|
||||
return fmt.Errorf("pulling MCP: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Successfully pulled %s\n", displayName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPListHandler handles the mcp list command.
|
||||
func MCPListHandler(cmd *cobra.Command, args []string) error {
|
||||
mcps, err := listLocalMCPs()
|
||||
if err != nil {
|
||||
return fmt.Errorf("listing MCPs: %w", err)
|
||||
}
|
||||
|
||||
if len(mcps) == 0 {
|
||||
fmt.Println("No MCPs installed")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
||||
fmt.Fprintln(w, "NAME\tTAG\tSIZE\tMODIFIED")
|
||||
|
||||
for _, mcp := range mcps {
|
||||
fmt.Fprintf(w, "%s/%s\t%s\t%s\t%s\n",
|
||||
mcp.Namespace,
|
||||
mcp.Name,
|
||||
mcp.Tag,
|
||||
format.HumanBytes(mcp.Size),
|
||||
format.HumanTime(mcp.ModifiedAt, "Never"),
|
||||
)
|
||||
}
|
||||
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
// MCPRemoveHandler handles the mcp rm command.
|
||||
func MCPRemoveHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama mcp rm NAME[:TAG] [NAME[:TAG]...]")
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
n := server.ParseMCPName(name)
|
||||
if n.Model == "" {
|
||||
fmt.Fprintf(os.Stderr, "Invalid MCP name: %s\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
manifestPath, err := server.GetMCPManifestPath(n)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error getting manifest path for %s: %v\n", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := os.Stat(manifestPath); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "MCP not found: %s\n", displayName)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.Remove(manifestPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error removing %s: %v\n", displayName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean up empty parent directories
|
||||
dir := filepath.Dir(manifestPath)
|
||||
for dir != filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests") {
|
||||
entries, _ := os.ReadDir(dir)
|
||||
if len(entries) == 0 {
|
||||
os.Remove(dir)
|
||||
dir = filepath.Dir(dir)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Deleted '%s'\n", displayName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPShowHandler handles the mcp show command.
|
||||
func MCPShowHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("usage: ollama mcp show NAME[:TAG]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
n := server.ParseMCPName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid MCP name: %s", name)
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
manifestPath, err := server.GetMCPManifestPath(n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting manifest path: %w", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("MCP not found: %s", displayName)
|
||||
}
|
||||
return fmt.Errorf("reading manifest: %w", err)
|
||||
}
|
||||
|
||||
var manifest server.Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return fmt.Errorf("parsing manifest: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("MCP: %s\n\n", displayName)
|
||||
|
||||
fmt.Println("Layers:")
|
||||
for _, layer := range manifest.Layers {
|
||||
fmt.Printf(" %s %s %s\n", layer.MediaType, layer.Digest[:19], format.HumanBytes(layer.Size))
|
||||
}
|
||||
|
||||
// Try to read and display mcp.json or package.json content
|
||||
if len(manifest.Layers) > 0 {
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.MediaType == server.MediaTypeMCP {
|
||||
mcpPath, err := server.GetMCPsPath(layer.Digest)
|
||||
if err == nil {
|
||||
// Try mcp.json first
|
||||
mcpJSONPath := filepath.Join(mcpPath, "mcp.json")
|
||||
if content, err := os.ReadFile(mcpJSONPath); err == nil {
|
||||
fmt.Println("\nConfig (mcp.json):")
|
||||
fmt.Println(string(content))
|
||||
} else {
|
||||
// Try package.json
|
||||
pkgJSONPath := filepath.Join(mcpPath, "package.json")
|
||||
if content, err := os.ReadFile(pkgJSONPath); err == nil {
|
||||
fmt.Println("\nConfig (package.json):")
|
||||
fmt.Println(string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// List files in the MCP
|
||||
fmt.Println("\nFiles:")
|
||||
filepath.Walk(mcpPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
relPath, _ := filepath.Rel(mcpPath, path)
|
||||
if relPath == "." {
|
||||
return nil
|
||||
}
|
||||
if info.IsDir() {
|
||||
fmt.Printf(" %s/\n", relPath)
|
||||
} else {
|
||||
fmt.Printf(" %s (%s)\n", relPath, format.HumanBytes(info.Size()))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPInfo represents information about an installed MCP.
|
||||
type MCPInfo struct {
|
||||
Namespace string
|
||||
Name string
|
||||
Tag string
|
||||
Size int64
|
||||
ModifiedAt time.Time
|
||||
}
|
||||
|
||||
// listLocalMCPs returns a list of locally installed MCPs.
|
||||
// MCPs are stored with 5-part paths: host/namespace/kind/model/tag
|
||||
// where kind is "mcp".
|
||||
func listLocalMCPs() ([]MCPInfo, error) {
|
||||
manifestsPath := filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests")
|
||||
|
||||
var mcps []MCPInfo
|
||||
|
||||
// Walk through all registries
|
||||
registries, err := os.ReadDir(manifestsPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return mcps, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, registry := range registries {
|
||||
if !registry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk namespaces
|
||||
namespaces, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, namespace := range namespaces {
|
||||
if !namespace.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk kinds looking for "mcp"
|
||||
kinds, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, kind := range kinds {
|
||||
if !kind.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only process mcp kind
|
||||
if kind.Name() != server.MCPNamespace {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk MCP names (model names)
|
||||
mcpNames, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, mcpName := range mcpNames {
|
||||
if !mcpName.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk tags
|
||||
tags, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), mcpName.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
manifestPath := filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), mcpName.Name(), tag.Name())
|
||||
fi, err := os.Stat(manifestPath)
|
||||
if err != nil || fi.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read manifest to get size
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var manifest server.Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Layers {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
|
||||
// Build display name using model.Name
|
||||
n := model.Name{
|
||||
Host: registry.Name(),
|
||||
Namespace: namespace.Name(),
|
||||
Kind: kind.Name(),
|
||||
Model: mcpName.Name(),
|
||||
Tag: tag.Name(),
|
||||
}
|
||||
|
||||
mcps = append(mcps, MCPInfo{
|
||||
Namespace: n.Namespace + "/" + n.Kind,
|
||||
Name: n.Model,
|
||||
Tag: n.Tag,
|
||||
Size: totalSize,
|
||||
ModifiedAt: fi.ModTime(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mcps, nil
|
||||
}
|
||||
|
||||
// createMCPManifest creates a manifest for a standalone MCP.
|
||||
func createMCPManifest(mcpDir string, layer server.Layer) (*server.Manifest, *server.Layer, error) {
|
||||
// Try to read mcp.json or package.json to extract metadata
|
||||
name, description := extractMCPMetadata(mcpDir)
|
||||
if name == "" {
|
||||
// Use directory name as fallback
|
||||
name = filepath.Base(mcpDir)
|
||||
}
|
||||
|
||||
// Create config
|
||||
config := map[string]any{
|
||||
"name": name,
|
||||
"description": description,
|
||||
"architecture": "amd64",
|
||||
"os": "linux",
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshaling config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer
|
||||
configLayer, err := server.NewLayer(strings.NewReader(string(configJSON)), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating config layer: %w", err)
|
||||
}
|
||||
|
||||
manifest := &server.Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Config: configLayer,
|
||||
Layers: []server.Layer{layer},
|
||||
}
|
||||
|
||||
return manifest, &configLayer, nil
|
||||
}
|
||||
|
||||
// extractMCPMetadata extracts name and description from mcp.json or package.json.
|
||||
func extractMCPMetadata(mcpDir string) (name, description string) {
|
||||
// Try mcp.json first
|
||||
mcpJSONPath := filepath.Join(mcpDir, "mcp.json")
|
||||
if data, err := os.ReadFile(mcpJSONPath); err == nil {
|
||||
var config map[string]any
|
||||
if err := json.Unmarshal(data, &config); err == nil {
|
||||
if n, ok := config["name"].(string); ok {
|
||||
name = n
|
||||
}
|
||||
if d, ok := config["description"].(string); ok {
|
||||
description = d
|
||||
}
|
||||
return name, description
|
||||
}
|
||||
}
|
||||
|
||||
// Try package.json
|
||||
pkgJSONPath := filepath.Join(mcpDir, "package.json")
|
||||
if data, err := os.ReadFile(pkgJSONPath); err == nil {
|
||||
var config map[string]any
|
||||
if err := json.Unmarshal(data, &config); err == nil {
|
||||
if n, ok := config["name"].(string); ok {
|
||||
name = n
|
||||
}
|
||||
if d, ok := config["description"].(string); ok {
|
||||
description = d
|
||||
}
|
||||
return name, description
|
||||
}
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// NewMCPCommand creates the mcp parent command with subcommands.
|
||||
func NewMCPCommand() *cobra.Command {
|
||||
mcpCmd := &cobra.Command{
|
||||
Use: "mcp",
|
||||
Short: "Manage MCP servers",
|
||||
Long: "Commands for managing MCP (Model Context Protocol) servers (add, push, pull, list, rm, show)",
|
||||
}
|
||||
|
||||
// Global config commands
|
||||
addCmd := &cobra.Command{
|
||||
Use: "add NAME COMMAND [ARGS...]",
|
||||
Short: "Add an MCP server to global config",
|
||||
Long: `Add an MCP server to the global config (~/.ollama/mcp.json).
|
||||
Global MCP servers are available to all agents.
|
||||
|
||||
Examples:
|
||||
ollama mcp add web-search uv run ./mcp-server.py
|
||||
ollama mcp add calculator python3 /path/to/calc.py`,
|
||||
Args: cobra.MinimumNArgs(2),
|
||||
RunE: MCPAddHandler,
|
||||
DisableFlagParsing: true, // Allow args with dashes
|
||||
}
|
||||
|
||||
removeGlobalCmd := &cobra.Command{
|
||||
Use: "remove-global NAME [NAME...]",
|
||||
Aliases: []string{"rm-global"},
|
||||
Short: "Remove an MCP server from global config",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: MCPRemoveGlobalHandler,
|
||||
}
|
||||
|
||||
listGlobalCmd := &cobra.Command{
|
||||
Use: "list-global",
|
||||
Short: "List global MCP servers",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: MCPListGlobalHandler,
|
||||
}
|
||||
|
||||
// Registry commands
|
||||
pushCmd := &cobra.Command{
|
||||
Use: "push NAME[:TAG] PATH",
|
||||
Short: "Push an MCP server to a registry",
|
||||
Long: "Package a local MCP server directory and push it to a registry",
|
||||
Args: cobra.ExactArgs(2),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: MCPPushHandler,
|
||||
}
|
||||
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
pullCmd := &cobra.Command{
|
||||
Use: "pull NAME[:TAG]",
|
||||
Short: "Pull an MCP server from a registry",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: MCPPullHandler,
|
||||
}
|
||||
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List installed MCP servers (from registry)",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: MCPListHandler,
|
||||
}
|
||||
|
||||
rmCmd := &cobra.Command{
|
||||
Use: "rm NAME[:TAG] [NAME[:TAG]...]",
|
||||
Aliases: []string{"remove", "delete"},
|
||||
Short: "Remove an MCP server (from registry)",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: MCPRemoveHandler,
|
||||
}
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show NAME[:TAG]",
|
||||
Short: "Show MCP server details",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: MCPShowHandler,
|
||||
}
|
||||
|
||||
disableCmd := &cobra.Command{
|
||||
Use: "disable NAME [NAME...]",
|
||||
Short: "Disable an MCP server (keep in config)",
|
||||
Long: `Disable an MCP server without removing it from config.
|
||||
Disabled servers will not be started when running agents.
|
||||
Use 'ollama mcp enable' to re-enable.`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: MCPDisableHandler,
|
||||
}
|
||||
|
||||
enableCmd := &cobra.Command{
|
||||
Use: "enable NAME [NAME...]",
|
||||
Short: "Enable a disabled MCP server",
|
||||
Long: `Re-enable a previously disabled MCP server.`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: MCPEnableHandler,
|
||||
}
|
||||
|
||||
mcpCmd.AddCommand(addCmd, removeGlobalCmd, listGlobalCmd, disableCmd, enableCmd, pushCmd, pullCmd, listCmd, rmCmd, showCmd)
|
||||
|
||||
return mcpCmd
|
||||
}
|
||||
570
cmd/skill_cmd.go
Normal file
570
cmd/skill_cmd.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// SkillPushHandler handles the skill push command.
|
||||
func SkillPushHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return fmt.Errorf("usage: ollama skill push NAME[:TAG] PATH")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
path := args[1]
|
||||
|
||||
// Expand path
|
||||
if strings.HasPrefix(path, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("expanding home directory: %w", err)
|
||||
}
|
||||
path = filepath.Join(home, path[1:])
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving path: %w", err)
|
||||
}
|
||||
|
||||
// Validate skill directory
|
||||
skillMdPath := filepath.Join(absPath, "SKILL.md")
|
||||
if _, err := os.Stat(skillMdPath); err != nil {
|
||||
return fmt.Errorf("skill directory must contain SKILL.md: %w", err)
|
||||
}
|
||||
|
||||
// Parse skill name (will set Kind="skill")
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid skill name: %s", name)
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
// Create skill layer
|
||||
displayName := n.DisplayShortest()
|
||||
status := fmt.Sprintf("Creating skill layer for %s", displayName)
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
|
||||
layer, err := server.CreateSkillLayer(absPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating skill layer: %w", err)
|
||||
}
|
||||
|
||||
spinner.Stop()
|
||||
|
||||
// Create skill manifest
|
||||
manifest, configLayer, err := createSkillManifest(absPath, layer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating skill manifest: %w", err)
|
||||
}
|
||||
|
||||
// Write manifest locally
|
||||
manifestPath, err := server.GetSkillManifestPath(n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting manifest path: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(manifestPath, manifestJSON, 0o644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Skill %s created locally\n", displayName)
|
||||
fmt.Fprintf(os.Stderr, " Config: %s (%s)\n", configLayer.Digest, format.HumanBytes(configLayer.Size))
|
||||
fmt.Fprintf(os.Stderr, " Layer: %s (%s)\n", layer.Digest, format.HumanBytes(layer.Size))
|
||||
|
||||
// Push to registry
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
}
|
||||
|
||||
insecure, _ := cmd.Flags().GetBool("insecure")
|
||||
|
||||
// For now, we'll use the existing push mechanism
|
||||
fmt.Fprintf(os.Stderr, "\nPushing to registry...\n")
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
|
||||
p.Add(resp.Digest, bar)
|
||||
} else if resp.Status != "" {
|
||||
spinner := progress.NewSpinner(resp.Status)
|
||||
p.Add(resp.Status, spinner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &api.PushRequest{
|
||||
Model: displayName,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
if err := client.Push(context.Background(), req, fn); err != nil {
|
||||
// If push fails, still show success for local creation
|
||||
fmt.Fprintf(os.Stderr, "\nNote: Local skill created but push failed: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "You can try pushing later with: ollama skill push %s\n", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Successfully pushed %s\n", displayName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillPullHandler handles the skill pull command.
|
||||
func SkillPullHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("usage: ollama skill pull NAME[:TAG]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid skill name: %s", name)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
}
|
||||
|
||||
insecure, _ := cmd.Flags().GetBool("insecure")
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar := progress.NewBar(resp.Status, resp.Total, resp.Completed)
|
||||
p.Add(resp.Digest, bar)
|
||||
} else if resp.Status != "" {
|
||||
spinner := progress.NewSpinner(resp.Status)
|
||||
p.Add(resp.Status, spinner)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
req := &api.PullRequest{
|
||||
Model: displayName,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
if err := client.Pull(context.Background(), req, fn); err != nil {
|
||||
return fmt.Errorf("pulling skill: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Successfully pulled %s\n", displayName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillListHandler handles the skill list command.
|
||||
func SkillListHandler(cmd *cobra.Command, args []string) error {
|
||||
skills, err := listLocalSkills()
|
||||
if err != nil {
|
||||
return fmt.Errorf("listing skills: %w", err)
|
||||
}
|
||||
|
||||
if len(skills) == 0 {
|
||||
fmt.Println("No skills installed")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
||||
fmt.Fprintln(w, "NAME\tTAG\tSIZE\tMODIFIED")
|
||||
|
||||
for _, skill := range skills {
|
||||
fmt.Fprintf(w, "%s/%s\t%s\t%s\t%s\n",
|
||||
skill.Namespace,
|
||||
skill.Name,
|
||||
skill.Tag,
|
||||
format.HumanBytes(skill.Size),
|
||||
format.HumanTime(skill.ModifiedAt, "Never"),
|
||||
)
|
||||
}
|
||||
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
// SkillRemoveHandler handles the skill rm command.
|
||||
func SkillRemoveHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("usage: ollama skill rm NAME[:TAG] [NAME[:TAG]...]")
|
||||
}
|
||||
|
||||
for _, name := range args {
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
fmt.Fprintf(os.Stderr, "Invalid skill name: %s\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
manifestPath, err := server.GetSkillManifestPath(n)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error getting manifest path for %s: %v\n", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := os.Stat(manifestPath); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "Skill not found: %s\n", displayName)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.Remove(manifestPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error removing %s: %v\n", displayName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean up empty parent directories
|
||||
dir := filepath.Dir(manifestPath)
|
||||
for dir != filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests") {
|
||||
entries, _ := os.ReadDir(dir)
|
||||
if len(entries) == 0 {
|
||||
os.Remove(dir)
|
||||
dir = filepath.Dir(dir)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Deleted '%s'\n", displayName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillShowHandler handles the skill show command.
|
||||
func SkillShowHandler(cmd *cobra.Command, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("usage: ollama skill show NAME[:TAG]")
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
n := server.ParseSkillName(name)
|
||||
if n.Model == "" {
|
||||
return fmt.Errorf("invalid skill name: %s", name)
|
||||
}
|
||||
|
||||
displayName := n.DisplayShortest()
|
||||
manifestPath, err := server.GetSkillManifestPath(n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting manifest path: %w", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("skill not found: %s", displayName)
|
||||
}
|
||||
return fmt.Errorf("reading manifest: %w", err)
|
||||
}
|
||||
|
||||
var manifest server.Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return fmt.Errorf("parsing manifest: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Skill: %s\n\n", displayName)
|
||||
|
||||
fmt.Println("Layers:")
|
||||
for _, layer := range manifest.Layers {
|
||||
fmt.Printf(" %s %s %s\n", layer.MediaType, layer.Digest[:19], format.HumanBytes(layer.Size))
|
||||
}
|
||||
|
||||
// Try to read and display SKILL.md content
|
||||
if len(manifest.Layers) > 0 {
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.MediaType == server.MediaTypeSkill {
|
||||
skillPath, err := server.GetSkillsPath(layer.Digest)
|
||||
if err == nil {
|
||||
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
||||
if content, err := os.ReadFile(skillMdPath); err == nil {
|
||||
fmt.Println("\nContent:")
|
||||
fmt.Println(string(content))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SkillInfo represents information about an installed skill.
|
||||
type SkillInfo struct {
|
||||
Namespace string
|
||||
Name string
|
||||
Tag string
|
||||
Size int64
|
||||
ModifiedAt time.Time
|
||||
}
|
||||
|
||||
// listLocalSkills returns a list of locally installed skills.
|
||||
// Skills are stored with 5-part paths: host/namespace/kind/model/tag
|
||||
// where kind is "skill".
|
||||
func listLocalSkills() ([]SkillInfo, error) {
|
||||
manifestsPath := filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests")
|
||||
|
||||
var skills []SkillInfo
|
||||
|
||||
// Walk through all registries
|
||||
registries, err := os.ReadDir(manifestsPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return skills, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, registry := range registries {
|
||||
if !registry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk namespaces
|
||||
namespaces, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, namespace := range namespaces {
|
||||
if !namespace.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk kinds looking for "skill"
|
||||
kinds, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, kind := range kinds {
|
||||
if !kind.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only process skill kind
|
||||
if kind.Name() != server.SkillNamespace {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk skill names (model names)
|
||||
skillNames, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, skillName := range skillNames {
|
||||
if !skillName.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk tags
|
||||
tags, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
manifestPath := filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), skillName.Name(), tag.Name())
|
||||
fi, err := os.Stat(manifestPath)
|
||||
if err != nil || fi.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read manifest to get size
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var manifest server.Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Layers {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
|
||||
// Build display name using model.Name
|
||||
n := model.Name{
|
||||
Host: registry.Name(),
|
||||
Namespace: namespace.Name(),
|
||||
Kind: kind.Name(),
|
||||
Model: skillName.Name(),
|
||||
Tag: tag.Name(),
|
||||
}
|
||||
|
||||
skills = append(skills, SkillInfo{
|
||||
Namespace: n.Namespace + "/" + n.Kind,
|
||||
Name: n.Model,
|
||||
Tag: n.Tag,
|
||||
Size: totalSize,
|
||||
ModifiedAt: fi.ModTime(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return skills, nil
|
||||
}
|
||||
|
||||
// createSkillManifest creates a manifest for a standalone skill.
|
||||
func createSkillManifest(skillDir string, layer server.Layer) (*server.Manifest, *server.Layer, error) {
|
||||
// Read SKILL.md to extract metadata
|
||||
skillMdPath := filepath.Join(skillDir, "SKILL.md")
|
||||
content, err := os.ReadFile(skillMdPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading SKILL.md: %w", err)
|
||||
}
|
||||
|
||||
// Extract name and description from frontmatter
|
||||
name, description := extractSkillMetadata(string(content))
|
||||
if name == "" {
|
||||
return nil, nil, errors.New("skill name not found in SKILL.md frontmatter")
|
||||
}
|
||||
|
||||
// Create config
|
||||
config := map[string]any{
|
||||
"name": name,
|
||||
"description": description,
|
||||
"architecture": "amd64",
|
||||
"os": "linux",
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshaling config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer
|
||||
configLayer, err := server.NewLayer(strings.NewReader(string(configJSON)), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating config layer: %w", err)
|
||||
}
|
||||
|
||||
manifest := &server.Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Config: configLayer,
|
||||
Layers: []server.Layer{layer},
|
||||
}
|
||||
|
||||
return manifest, &configLayer, nil
|
||||
}
|
||||
|
||||
// extractSkillMetadata extracts name and description from SKILL.md frontmatter.
|
||||
func extractSkillMetadata(content string) (name, description string) {
|
||||
lines := strings.Split(content, "\n")
|
||||
|
||||
inFrontmatter := false
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
if trimmed == "---" {
|
||||
if !inFrontmatter {
|
||||
inFrontmatter = true
|
||||
continue
|
||||
} else {
|
||||
break // End of frontmatter
|
||||
}
|
||||
}
|
||||
|
||||
if inFrontmatter {
|
||||
if strings.HasPrefix(trimmed, "name:") {
|
||||
name = strings.TrimSpace(strings.TrimPrefix(trimmed, "name:"))
|
||||
} else if strings.HasPrefix(trimmed, "description:") {
|
||||
description = strings.TrimSpace(strings.TrimPrefix(trimmed, "description:"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return name, description
|
||||
}
|
||||
|
||||
// NewSkillCommand creates the skill parent command with subcommands.
|
||||
func NewSkillCommand() *cobra.Command {
|
||||
skillCmd := &cobra.Command{
|
||||
Use: "skill",
|
||||
Short: "Manage skills",
|
||||
Long: "Commands for managing agent skills (push, pull, list, rm, show)",
|
||||
}
|
||||
|
||||
pushCmd := &cobra.Command{
|
||||
Use: "push NAME[:TAG] PATH",
|
||||
Short: "Push a skill to a registry",
|
||||
Long: "Package a local skill directory and push it to a registry",
|
||||
Args: cobra.ExactArgs(2),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SkillPushHandler,
|
||||
}
|
||||
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
pullCmd := &cobra.Command{
|
||||
Use: "pull NAME[:TAG]",
|
||||
Short: "Pull a skill from a registry",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SkillPullHandler,
|
||||
}
|
||||
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List installed skills",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: SkillListHandler,
|
||||
}
|
||||
|
||||
rmCmd := &cobra.Command{
|
||||
Use: "rm NAME[:TAG] [NAME[:TAG]...]",
|
||||
Aliases: []string{"remove", "delete"},
|
||||
Short: "Remove a skill",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: SkillRemoveHandler,
|
||||
}
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show NAME[:TAG]",
|
||||
Short: "Show skill details",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: SkillShowHandler,
|
||||
}
|
||||
|
||||
skillCmd.AddCommand(pushCmd, pullCmd, listCmd, rmCmd, showCmd)
|
||||
|
||||
return skillCmd
|
||||
}
|
||||
589
cmd/skills.go
Normal file
589
cmd/skills.go
Normal file
@@ -0,0 +1,589 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/server"
|
||||
)
|
||||
|
||||
const (
|
||||
skillFileName = "SKILL.md"
|
||||
maxSkillDescription = 1024
|
||||
maxSkillNameLength = 64
|
||||
)
|
||||
|
||||
var skillNamePattern = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
|
||||
|
||||
type skillMetadata struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
}
|
||||
|
||||
type skillDefinition struct {
|
||||
Name string
|
||||
Description string
|
||||
Content string // Full SKILL.md content (without frontmatter)
|
||||
Dir string
|
||||
SkillPath string
|
||||
}
|
||||
|
||||
type skillCatalog struct {
|
||||
Skills []skillDefinition
|
||||
byName map[string]skillDefinition
|
||||
}
|
||||
|
||||
func loadSkills(paths []string) (*skillCatalog, error) {
|
||||
if len(paths) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var skills []skillDefinition
|
||||
byName := make(map[string]skillDefinition)
|
||||
for _, root := range paths {
|
||||
info, err := os.Stat(root)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skills directory %q: %w", root, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, fmt.Errorf("skills path %q is not a directory", root)
|
||||
}
|
||||
|
||||
err = filepath.WalkDir(root, func(path string, entry fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if entry.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if entry.Name() != skillFileName {
|
||||
return nil
|
||||
}
|
||||
|
||||
skillDir := filepath.Dir(path)
|
||||
skill, err := parseSkillFile(path, skillDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, exists := byName[skill.Name]; exists {
|
||||
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
|
||||
return nil
|
||||
}
|
||||
|
||||
byName[skill.Name] = skill
|
||||
skills = append(skills, skill)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(skills) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sort.Slice(skills, func(i, j int) bool {
|
||||
return skills[i].Name < skills[j].Name
|
||||
})
|
||||
|
||||
return &skillCatalog{Skills: skills, byName: byName}, nil
|
||||
}
|
||||
|
||||
// loadSkillsFromRefs loads skills from a list of SkillRef objects.
|
||||
// Skills can be referenced by:
|
||||
// - Digest: loaded from the extracted skill cache (for bundled/pulled skills)
|
||||
// - Name (local path): loaded from the filesystem (for development)
|
||||
func loadSkillsFromRefs(refs []api.SkillRef) (*skillCatalog, error) {
|
||||
if len(refs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var skills []skillDefinition
|
||||
byName := make(map[string]skillDefinition)
|
||||
|
||||
for _, ref := range refs {
|
||||
var skillDir string
|
||||
|
||||
if ref.Digest != "" {
|
||||
// Load from extracted skill cache
|
||||
path, err := server.GetSkillsPath(ref.Digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting skill path for %s: %w", ref.Digest, err)
|
||||
}
|
||||
|
||||
// Check if skill is already extracted
|
||||
skillMdPath := filepath.Join(path, skillFileName)
|
||||
if _, err := os.Stat(skillMdPath); os.IsNotExist(err) {
|
||||
// Try to extract the skill blob
|
||||
path, err = server.ExtractSkillBlob(ref.Digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("extracting skill %s: %w", ref.Digest, err)
|
||||
}
|
||||
}
|
||||
|
||||
skillDir = path
|
||||
} else if ref.Name != "" {
|
||||
// Check if this is a local path or a registry reference
|
||||
if !server.IsLocalSkillPath(ref.Name) {
|
||||
// Registry reference without a digest - skill needs to be pulled first
|
||||
// This happens when an agent references a skill that hasn't been bundled
|
||||
return nil, fmt.Errorf("skill %q is a registry reference but has no digest - the agent may need to be recreated or the skill pulled separately", ref.Name)
|
||||
}
|
||||
|
||||
// Local path - resolve it
|
||||
skillPath := ref.Name
|
||||
if strings.HasPrefix(skillPath, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expanding home directory: %w", err)
|
||||
}
|
||||
skillPath = filepath.Join(home, skillPath[1:])
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(skillPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving skill path %q: %w", ref.Name, err)
|
||||
}
|
||||
|
||||
// Check if this is a directory containing skills or a single skill
|
||||
info, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skill path %q: %w", ref.Name, err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// Check if it's a skill directory (has SKILL.md) or a parent of skill directories
|
||||
skillMdPath := filepath.Join(absPath, skillFileName)
|
||||
if _, err := os.Stat(skillMdPath); err == nil {
|
||||
// Direct skill directory
|
||||
skillDir = absPath
|
||||
} else {
|
||||
// Parent directory - walk to find skill subdirectories
|
||||
err := filepath.WalkDir(absPath, func(path string, entry fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if entry.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if entry.Name() != skillFileName {
|
||||
return nil
|
||||
}
|
||||
|
||||
skillSubDir := filepath.Dir(path)
|
||||
skill, err := parseSkillFile(path, skillSubDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: skipping skill at %s: %v\n", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, exists := byName[skill.Name]; exists {
|
||||
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q at %s\n", skill.Name, path)
|
||||
return nil
|
||||
}
|
||||
|
||||
byName[skill.Name] = skill
|
||||
skills = append(skills, skill)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("skill path %q is not a directory", ref.Name)
|
||||
}
|
||||
} else {
|
||||
// Both empty - skip
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse the skill from skillDir if set
|
||||
if skillDir != "" {
|
||||
skillMdPath := filepath.Join(skillDir, skillFileName)
|
||||
skill, err := parseSkillFile(skillMdPath, skillDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing skill at %s: %w", skillDir, err)
|
||||
}
|
||||
|
||||
if _, exists := byName[skill.Name]; exists {
|
||||
fmt.Fprintf(os.Stderr, "Warning: duplicate skill name %q\n", skill.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
byName[skill.Name] = skill
|
||||
skills = append(skills, skill)
|
||||
}
|
||||
}
|
||||
|
||||
if len(skills) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sort.Slice(skills, func(i, j int) bool {
|
||||
return skills[i].Name < skills[j].Name
|
||||
})
|
||||
|
||||
return &skillCatalog{Skills: skills, byName: byName}, nil
|
||||
}
|
||||
|
||||
func parseSkillFile(path, skillDir string) (skillDefinition, error) {
|
||||
rawContent, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
frontmatter, bodyContent, err := extractFrontmatterAndContent(string(rawContent))
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
var meta skillMetadata
|
||||
if err := yaml.Unmarshal([]byte(frontmatter), &meta); err != nil {
|
||||
return skillDefinition{}, fmt.Errorf("invalid frontmatter: %w", err)
|
||||
}
|
||||
|
||||
if err := validateSkillMetadata(meta, skillDir); err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
absDir, err := filepath.Abs(skillDir)
|
||||
if err != nil {
|
||||
return skillDefinition{}, err
|
||||
}
|
||||
|
||||
return skillDefinition{
|
||||
Name: meta.Name,
|
||||
Description: meta.Description,
|
||||
Content: bodyContent,
|
||||
Dir: absDir,
|
||||
SkillPath: absPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractFrontmatterAndContent(content string) (frontmatter string, body string, err error) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(content))
|
||||
if !scanner.Scan() {
|
||||
return "", "", errors.New("empty SKILL.md")
|
||||
}
|
||||
if strings.TrimSpace(scanner.Text()) != "---" {
|
||||
return "", "", errors.New("missing YAML frontmatter")
|
||||
}
|
||||
|
||||
var fmLines []string
|
||||
foundEnd := false
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) == "---" {
|
||||
foundEnd = true
|
||||
break
|
||||
}
|
||||
fmLines = append(fmLines, line)
|
||||
}
|
||||
if !foundEnd {
|
||||
return "", "", errors.New("frontmatter not terminated")
|
||||
}
|
||||
|
||||
// Collect remaining content as body
|
||||
var bodyLines []string
|
||||
for scanner.Scan() {
|
||||
bodyLines = append(bodyLines, scanner.Text())
|
||||
}
|
||||
|
||||
return strings.Join(fmLines, "\n"), strings.TrimSpace(strings.Join(bodyLines, "\n")), nil
|
||||
}
|
||||
|
||||
func validateSkillMetadata(meta skillMetadata, skillDir string) error {
|
||||
name := strings.TrimSpace(meta.Name)
|
||||
description := strings.TrimSpace(meta.Description)
|
||||
|
||||
switch {
|
||||
case name == "":
|
||||
return errors.New("missing skill name")
|
||||
case len(name) > maxSkillNameLength:
|
||||
return fmt.Errorf("skill name exceeds %d characters", maxSkillNameLength)
|
||||
case !skillNamePattern.MatchString(name):
|
||||
return fmt.Errorf("invalid skill name %q", name)
|
||||
}
|
||||
|
||||
if description == "" {
|
||||
return errors.New("missing skill description")
|
||||
}
|
||||
if len(description) > maxSkillDescription {
|
||||
return fmt.Errorf("skill description exceeds %d characters", maxSkillDescription)
|
||||
}
|
||||
|
||||
// Skip directory name check for digest-based paths (extracted from blobs)
|
||||
dirName := filepath.Base(skillDir)
|
||||
if !strings.HasPrefix(dirName, "sha256-") && dirName != name {
|
||||
return fmt.Errorf("skill directory %q does not match name %q", dirName, name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *skillCatalog) SystemPrompt() string {
|
||||
if c == nil || len(c.Skills) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("# Skills\n\n")
|
||||
b.WriteString("You have the following skills loaded. Each skill provides instructions and may include executable scripts.\n\n")
|
||||
b.WriteString("## Available Tools\n\n")
|
||||
b.WriteString("- `run_skill_script`: Execute a script bundled with a skill. Use this when the skill instructions tell you to run a script.\n")
|
||||
b.WriteString("- `read_skill_file`: Read additional files from a skill directory.\n\n")
|
||||
|
||||
for _, skill := range c.Skills {
|
||||
fmt.Fprintf(&b, "## Skill: %s\n\n", skill.Name)
|
||||
fmt.Fprintf(&b, "%s\n\n", skill.Content)
|
||||
b.WriteString("---\n\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (c *skillCatalog) Tools() api.Tools {
|
||||
if c == nil || len(c.Skills) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return api.Tools{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "run_skill_script",
|
||||
Description: "Execute a script or command within a skill's directory. Use this to run Python scripts, shell scripts, or other executables bundled with a skill.",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"skill", "command"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"skill": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The name of the skill containing the script",
|
||||
},
|
||||
"command": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The command to execute (e.g., 'python scripts/calculate.py 25 4' or './scripts/run.sh')",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "read_skill_file",
|
||||
Description: "Read a file from a skill's directory. Use this to read additional documentation, reference files, or data files bundled with a skill.",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"skill", "path"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"skill": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The name of the skill containing the file",
|
||||
},
|
||||
"path": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The relative path to the file within the skill directory",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *skillCatalog) RunToolCall(call api.ToolCall) (api.Message, bool, error) {
|
||||
switch call.Function.Name {
|
||||
case "read_skill_file":
|
||||
skillName, err := requireStringArg(call.Function.Arguments, "skill")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
relPath, err := requireStringArg(call.Function.Arguments, "path")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
skill, ok := c.byName[skillName]
|
||||
if !ok {
|
||||
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
|
||||
}
|
||||
content, err := readSkillFile(skill.Dir, relPath)
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
return toolMessage(call, content), true, nil
|
||||
|
||||
case "run_skill_script":
|
||||
skillName, err := requireStringArg(call.Function.Arguments, "skill")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
command, err := requireStringArg(call.Function.Arguments, "command")
|
||||
if err != nil {
|
||||
return toolMessage(call, err.Error()), true, nil
|
||||
}
|
||||
skill, ok := c.byName[skillName]
|
||||
if !ok {
|
||||
return toolMessage(call, fmt.Sprintf("unknown skill %q", skillName)), true, nil
|
||||
}
|
||||
output, err := runSkillScript(skill.Dir, command)
|
||||
if err != nil {
|
||||
return toolMessage(call, fmt.Sprintf("error: %v\noutput: %s", err, output)), true, nil
|
||||
}
|
||||
return toolMessage(call, output), true, nil
|
||||
|
||||
default:
|
||||
return api.Message{}, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// runSkillScript executes a shell command within a skill's directory.
|
||||
//
|
||||
// SECURITY LIMITATIONS (TODO):
|
||||
// - No sandboxing: commands run with full user permissions
|
||||
// - No path validation: model can run any command, not just scripts in skill dir
|
||||
// - Shell injection risk: sh -c is used, malicious input could be crafted
|
||||
// - No executable allowlist: any program can be called (curl, rm, etc.)
|
||||
// - No environment isolation: scripts inherit full environment variables
|
||||
//
|
||||
// POTENTIAL IMPROVEMENTS:
|
||||
// - Restrict commands to only reference files within skill directory
|
||||
// - Allowlist specific executables (python3, node, bash)
|
||||
// - Use sandboxing (Docker, nsjail, seccomp)
|
||||
// - Require explicit script registration in SKILL.md frontmatter
|
||||
// - Add per-skill configurable timeouts
|
||||
func runSkillScript(skillDir, command string) (string, error) {
|
||||
// Validate the skill directory exists
|
||||
absSkillDir, err := filepath.Abs(skillDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := os.Stat(absSkillDir); err != nil {
|
||||
return "", fmt.Errorf("skill directory not found: %w", err)
|
||||
}
|
||||
|
||||
// Create command with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", command)
|
||||
cmd.Dir = absSkillDir
|
||||
|
||||
// Inject the current working directory (where ollama run was called from)
|
||||
// as an environment variable so scripts can reference files in that directory
|
||||
workingDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get working directory: %w", err)
|
||||
}
|
||||
cmd.Env = append(os.Environ(), "OLLAMA_WORKING_DIR="+workingDir)
|
||||
|
||||
// Capture both stdout and stderr
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err = cmd.Run()
|
||||
|
||||
// Combine output
|
||||
output := stdout.String()
|
||||
if stderr.Len() > 0 {
|
||||
if output != "" {
|
||||
output += "\n"
|
||||
}
|
||||
output += stderr.String()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return output, fmt.Errorf("command timed out after 30 seconds")
|
||||
}
|
||||
return output, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func readSkillFile(skillDir, relPath string) (string, error) {
|
||||
relPath = filepath.Clean(strings.TrimSpace(relPath))
|
||||
if relPath == "" {
|
||||
return "", errors.New("path is required")
|
||||
}
|
||||
if filepath.IsAbs(relPath) {
|
||||
return "", errors.New("path must be relative to the skill directory")
|
||||
}
|
||||
|
||||
target := filepath.Join(skillDir, relPath)
|
||||
absTarget, err := filepath.Abs(target)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
absSkillDir, err := filepath.Abs(skillDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rel, err := filepath.Rel(absSkillDir, absTarget)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.HasPrefix(rel, "..") {
|
||||
return "", errors.New("path escapes the skill directory")
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absTarget)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read %q: %w", relPath, err)
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
func requireStringArg(args api.ToolCallFunctionArguments, name string) (string, error) {
|
||||
value, ok := args[name]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing required argument %q", name)
|
||||
}
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("argument %q must be a string", name)
|
||||
}
|
||||
if strings.TrimSpace(str) == "" {
|
||||
return "", fmt.Errorf("argument %q cannot be empty", name)
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func toolMessage(call api.ToolCall, content string) api.Message {
|
||||
msg := api.Message{
|
||||
Role: "tool",
|
||||
Content: content,
|
||||
ToolName: call.Function.Name,
|
||||
}
|
||||
if call.ID != "" {
|
||||
msg.ToolCallID = call.ID
|
||||
}
|
||||
return msg
|
||||
}
|
||||
@@ -182,6 +182,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &llama4Model{}
|
||||
case "Mistral3ForConditionalGeneration":
|
||||
conv = &mistral3Model{}
|
||||
case "Ministral3ForCausalLM":
|
||||
conv = &mistral3CausalModel{}
|
||||
case "MixtralForCausalLM":
|
||||
conv = &mixtralModel{}
|
||||
case "GemmaForCausalLM":
|
||||
@@ -200,14 +202,20 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &qwen25VLModel{}
|
||||
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
|
||||
conv = &qwen3VLModel{}
|
||||
case "Olmo3ForCausalLM":
|
||||
conv = &olmoModel{}
|
||||
case "BertModel":
|
||||
conv = &bertModel{}
|
||||
case "NomicBertModel", "NomicBertMoEModel":
|
||||
conv = &nomicbertModel{}
|
||||
case "CohereForCausalLM":
|
||||
conv = &commandrModel{}
|
||||
case "GptOssForCausalLM":
|
||||
conv = &gptossModel{}
|
||||
case "DeepseekOCRForCausalLM":
|
||||
conv = &deepseekocr{}
|
||||
case "DeepseekV3ForCausalLM":
|
||||
conv = &deepseek2Model{}
|
||||
default:
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
173
convert/convert_deepseek2.go
Normal file
173
convert/convert_deepseek2.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type deepseek2Model struct {
|
||||
ModelParameters // architectures, vocab_size
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
|
||||
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
|
||||
KVLoraRank uint32 `json:"kv_lora_rank"`
|
||||
QLoraRank uint32 `json:"q_lora_rank"`
|
||||
VHeadDim uint32 `json:"v_head_dim"`
|
||||
|
||||
ExpertCount uint32 `json:"n_routed_experts"`
|
||||
ExpertSharedCount uint32 `json:"n_shared_experts"`
|
||||
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
|
||||
ExpertWeightsNorm bool `json:"norm_topk_prob"`
|
||||
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
|
||||
|
||||
ScoringFunc string `json:"scoring_func"`
|
||||
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
|
||||
|
||||
RopeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
Type string `json:"type"`
|
||||
MScaleAllDim float32 `json:"mscale_all_dim"`
|
||||
} `json:"rope_scaling"`
|
||||
|
||||
Architecture string
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseek2"
|
||||
kv["general.type"] = "model"
|
||||
kv["deepseek2.block_count"] = p.HiddenLayers
|
||||
|
||||
numHeads := p.NumAttentionHeads
|
||||
numKVHeads := p.NumKeyValueHeads
|
||||
|
||||
kv["deepseek2.attention.head_count"] = numHeads
|
||||
kv["deepseek2.attention.head_count_kv"] = numKVHeads
|
||||
kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
|
||||
kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
|
||||
kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
|
||||
kv["deepseek2.attention.value_length"] = p.VHeadDim
|
||||
kv["deepseek2.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["deepseek2.embedding_length"] = p.HiddenSize
|
||||
kv["deepseek2.expert_count"] = p.ExpertCount
|
||||
kv["deepseek2.expert_feed_forward_length"] = p.ExpertIntermediateSize
|
||||
kv["deepseek2.expert_shared_count"] = p.ExpertSharedCount
|
||||
|
||||
var scoringFunc uint32
|
||||
switch p.ScoringFunc {
|
||||
case "softmax":
|
||||
// not currently supported in the model, but needed for Deepseek-OCR
|
||||
scoringFunc = 1
|
||||
case "sigmoid":
|
||||
scoringFunc = 2
|
||||
}
|
||||
kv["deepseek2.expert_gating_func"] = scoringFunc
|
||||
kv["deepseek2.expert_used_count"] = p.ExpertUsedCount
|
||||
kv["deepseek2.expert_weights_norm"] = p.ExpertWeightsNorm
|
||||
kv["deepseek2.expert_weights_scale"] = p.ExpertWeightsScale
|
||||
kv["deepseek2.feed_forward_length"] = p.IntermediateSize
|
||||
kv["deepseek2.leading_dense_block_count"] = p.LeadingDenseBlockCount
|
||||
|
||||
kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
|
||||
kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
|
||||
kv["deepseek2.rope.scaling.factor"] = p.RopeScaling.Factor
|
||||
kv["deepseek2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
|
||||
kv["deepseek2.rope.scaling.type"] = p.RopeScaling.Type
|
||||
kv["deepseek2.rope.scaling.yarn_log_multiplier"] = 0.1 * p.RopeScaling.MScaleAllDim
|
||||
|
||||
kv["tokenizer.ggml.pre"] = "deepseek-v3"
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"language_model.", "",
|
||||
"model.layers", "blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
|
||||
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
|
||||
"self_attn.kv_b_proj", "attn_kv_b",
|
||||
"self_attn.q_a_proj", "attn_q_a",
|
||||
"self_attn.q_a_layernorm", "attn_q_a_norm",
|
||||
"self_attn.q_b_proj", "attn_q_b",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"mlp.shared_experts.down_proj", "ffn_down_shexp",
|
||||
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_experts.up_proj", "ffn_up_shexp",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
|
||||
"mlp.gate", "ffn_gate_inp",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
merges := make([]merge, p.HiddenLayers*3)
|
||||
for i := range p.HiddenLayers {
|
||||
merges[i*3+0] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}
|
||||
merges[i*3+1] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}
|
||||
merges[i*3+2] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}
|
||||
}
|
||||
|
||||
skipLayer := func(n string, minValue uint32) bool {
|
||||
re := regexp.MustCompile(`^blk\.(\d+)`)
|
||||
matches := re.FindStringSubmatch(n)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
blkNum, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return uint32(blkNum) >= minValue
|
||||
}
|
||||
|
||||
out, s = mergeTensors(s, merges...)
|
||||
for _, t := range s {
|
||||
// skip any additional layers (such as the Multi-Token Prediction layer)
|
||||
if skipLayer(t.Name(), p.HiddenLayers) {
|
||||
slog.Debug("skipping layer", "name", t.Name())
|
||||
continue
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -26,16 +27,26 @@ type gemma3Model struct {
|
||||
NumChannels uint32 `json:"num_channels"` // num_channels 3
|
||||
PatchSize uint32 `json:"patch_size"` // patch_size 14
|
||||
} `json:"vision_config"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
RopeLocalTheta float32 `json:"rope_local_base_freq"`
|
||||
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
RopeLocalTheta float32 `json:"rope_local_base_freq"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
SlidingWindowPattern *uint32 `json:"sliding_window_pattern"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
|
||||
RopeScaling *struct {
|
||||
Type string `json:"rope_type"`
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
ExtrapolationFactor float32 `json:"extrapolation_factor"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
} `json:"rope_scaling"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -81,9 +92,38 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["gemma3.attention.key_length"] = p.HeadDim
|
||||
kv["gemma3.attention.value_length"] = p.HeadDim
|
||||
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
|
||||
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
|
||||
|
||||
// The sliding window pattern is either provided as the sliding_window_pattern
|
||||
// key (an int) or as the layer_types key (a list of strings).
|
||||
if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
|
||||
kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
|
||||
for i := range numBlocks {
|
||||
var isLocal bool
|
||||
if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
|
||||
isLocal = p.LayerTypes[i] == "sliding_attention"
|
||||
} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
|
||||
isLocal = (i+1)%*p.SlidingWindowPattern != 0
|
||||
}
|
||||
if !yield(isLocal) {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
if p.FinalLogitSoftcap > 0 {
|
||||
kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
|
||||
}
|
||||
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
|
||||
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
|
||||
kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
|
||||
if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
|
||||
kv["gemma3.rope.scaling.type"] = "yarn"
|
||||
kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
|
||||
kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
|
||||
kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
|
||||
kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
|
||||
kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
|
||||
}
|
||||
|
||||
kv["gemma3.embedding_length"] = p.HiddenSize
|
||||
kv["gemma3.feed_forward_length"] = p.IntermediateSize
|
||||
default:
|
||||
|
||||
@@ -29,6 +29,17 @@ type mistral3Model struct {
|
||||
SlidingWindow *uint32 `json:"sliding_window"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
RopeParameters struct {
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
Factor float32 `json:"factor"`
|
||||
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
|
||||
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
Mscale *float32 `json:"mscale"`
|
||||
MscaleAllDim *float32 `json:"mscale_all_dim"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct {
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
@@ -41,6 +52,9 @@ type mistral3Model struct {
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeParameters struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"vision_config"`
|
||||
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
@@ -61,8 +75,25 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
|
||||
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
|
||||
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
|
||||
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
|
||||
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
|
||||
kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
|
||||
kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
|
||||
kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
|
||||
kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
|
||||
|
||||
if p.TextModel.RopeParameters.Mscale != nil {
|
||||
kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
|
||||
}
|
||||
if p.TextModel.RopeParameters.MscaleAllDim != nil {
|
||||
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
|
||||
}
|
||||
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
|
||||
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
|
||||
}
|
||||
if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
|
||||
kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
|
||||
}
|
||||
|
||||
// Vision configuration
|
||||
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
||||
@@ -74,7 +105,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
|
||||
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
|
||||
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
|
||||
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
|
||||
kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)
|
||||
|
||||
// Multimodal configuration
|
||||
kv["mistral3.image_token_index"] = p.ImageTokenIndex
|
||||
|
||||
181
convert/convert_mistral_causal.go
Normal file
181
convert/convert_mistral_causal.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type mistral3CausalModel struct {
|
||||
ModelParameters
|
||||
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
SlidingWindow *uint32 `json:"sliding_window"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
RopeParameters struct {
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
Factor float32 `json:"factor"`
|
||||
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
|
||||
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
Mscale *float32 `json:"mscale"`
|
||||
MscaleAllDim *float32 `json:"mscale_all_dim"`
|
||||
} `json:"rope_parameters"`
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.VocabSize
|
||||
|
||||
// Text configuration
|
||||
kv["mistral3.block_count"] = p.NumHiddenLayers
|
||||
kv["mistral3.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["mistral3.embedding_length"] = p.HiddenSize
|
||||
kv["mistral3.feed_forward_length"] = p.IntermediateSize
|
||||
kv["mistral3.attention.head_count"] = p.NumAttentionHeads
|
||||
kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
|
||||
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
kv["mistral3.attention.key_length"] = p.HeadDim
|
||||
kv["mistral3.attention.value_length"] = p.HeadDim
|
||||
kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
|
||||
kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
|
||||
kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
|
||||
kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
|
||||
kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
|
||||
kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
|
||||
|
||||
if p.RopeParameters.Mscale != nil {
|
||||
kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
|
||||
}
|
||||
|
||||
if p.RopeParameters.MscaleAllDim != nil {
|
||||
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
|
||||
}
|
||||
|
||||
if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
|
||||
kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
|
||||
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
|
||||
}
|
||||
|
||||
if p.RopeParameters.Llama4ScalingBeta != nil {
|
||||
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") {
|
||||
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
|
||||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) Replacements() []string {
|
||||
return []string{
|
||||
"model.norm", "output_norm",
|
||||
"model.", "",
|
||||
"layers", "blk",
|
||||
"transformer.layers", "blk",
|
||||
"vision_tower", "v",
|
||||
"ln_pre", "encoder_norm",
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"embed_tokens", "token_embd",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"attention.q_proj", "attn_q",
|
||||
"attention.k_proj", "attn_k",
|
||||
"attention.v_proj", "attn_v",
|
||||
"attention.o_proj", "attn_output",
|
||||
"attention_norm", "attn_norm",
|
||||
"feed_forward.gate_proj", "ffn_gate",
|
||||
"feed_forward.down_proj", "ffn_down",
|
||||
"feed_forward.up_proj", "ffn_up",
|
||||
"multi_modal_projector", "mm",
|
||||
"ffn_norm", "ffn_norm",
|
||||
"lm_head", "output",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
|
||||
var dims []int
|
||||
for _, dim := range shape {
|
||||
dims = append(dims, int(dim))
|
||||
}
|
||||
|
||||
var heads uint32
|
||||
if strings.HasSuffix(name, ".attn_q.weight") {
|
||||
heads = p.NumAttentionHeads
|
||||
} else if strings.HasSuffix(name, ".attn_k.weight") {
|
||||
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
|
||||
}
|
||||
|
||||
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Reshape(dims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(n, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
for _, t := range ts {
|
||||
f32s = append(f32s, t...)
|
||||
}
|
||||
|
||||
return f32s, nil
|
||||
}
|
||||
213
convert/convert_nomicbert.go
Normal file
213
convert/convert_nomicbert.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type nomicbertModel struct {
|
||||
ModelParameters
|
||||
NLayers uint32 `json:"n_layers"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
LayerNormEPS float32 `json:"layer_norm_eps"`
|
||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||
RopeFreqBase float32 `json:"rope_theta"`
|
||||
normalizeEmbeddings bool
|
||||
PoolingType uint32
|
||||
|
||||
// MoE parameters (only present in v2 models)
|
||||
NumExperts uint32 `json:"num_local_experts"`
|
||||
NumExpertsUsed uint32 `json:"num_experts_per_tok"`
|
||||
MoEEveryNLayers uint32 `json:"moe_every_n_layers"`
|
||||
}
|
||||
|
||||
var (
|
||||
_ ModelConverter = (*nomicbertModel)(nil)
|
||||
_ moreParser = (*nomicbertModel)(nil)
|
||||
)
|
||||
|
||||
func (p *nomicbertModel) parseMore(fsys fs.FS) error {
|
||||
bts, err := fs.ReadFile(fsys, "modules.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var modules []struct {
|
||||
Type string `json:"type"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, &modules); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var pooling string
|
||||
for _, m := range modules {
|
||||
switch m.Type {
|
||||
case "sentence_transformers.models.Pooling":
|
||||
pooling = m.Path
|
||||
case "sentence_transformers.models.Normalize":
|
||||
p.normalizeEmbeddings = true
|
||||
}
|
||||
}
|
||||
|
||||
if pooling != "" {
|
||||
bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var pc struct {
|
||||
PoolingModeCLSToken bool `json:"pooling_mode_cls_token"`
|
||||
PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, &pc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pc.PoolingModeMeanTokens {
|
||||
p.PoolingType = 1
|
||||
} else if pc.PoolingModeCLSToken {
|
||||
p.PoolingType = 2
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
|
||||
// Determine architecture based on MoE parameters (following qwen3 pattern)
|
||||
arch := "nomic-bert"
|
||||
if p.MoEEveryNLayers > 0 {
|
||||
arch += "-moe"
|
||||
}
|
||||
|
||||
kv["general.architecture"] = arch
|
||||
kv["attention.causal"] = false
|
||||
kv["pooling_type"] = p.PoolingType
|
||||
kv["normalize_embeddings"] = p.normalizeEmbeddings
|
||||
|
||||
kv["block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers)
|
||||
|
||||
if contextLength := p.MaxPositionEmbeddings; contextLength > 0 {
|
||||
kv["context_length"] = contextLength
|
||||
}
|
||||
|
||||
if embeddingLength := p.HiddenSize; embeddingLength > 0 {
|
||||
kv["embedding_length"] = p.HiddenSize
|
||||
}
|
||||
|
||||
if feedForwardLength := p.IntermediateSize; feedForwardLength > 0 {
|
||||
kv["feed_forward_length"] = p.IntermediateSize
|
||||
}
|
||||
|
||||
if headCount := p.NumAttentionHeads; headCount > 0 {
|
||||
kv["attention.head_count"] = p.NumAttentionHeads
|
||||
}
|
||||
|
||||
if kvHeadCount := p.NumKeyValueHeads; kvHeadCount > 0 {
|
||||
kv["attention.head_count_kv"] = p.NumKeyValueHeads
|
||||
}
|
||||
|
||||
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon); layerNormEpsilon > 0 {
|
||||
kv["attention.layer_norm_epsilon"] = layerNormEpsilon
|
||||
}
|
||||
|
||||
if p.RopeFreqBase > 0 {
|
||||
kv["rope.freq_base"] = p.RopeFreqBase
|
||||
}
|
||||
|
||||
// MoE specific parameters (only if MoE is enabled)
|
||||
if p.NumExperts > 0 {
|
||||
kv["expert_count"] = p.NumExperts
|
||||
}
|
||||
|
||||
if p.NumExpertsUsed > 0 {
|
||||
kv["expert_used_count"] = p.NumExpertsUsed
|
||||
}
|
||||
|
||||
if p.MoEEveryNLayers > 0 {
|
||||
kv["moe_every_n_layers"] = p.MoEEveryNLayers
|
||||
}
|
||||
|
||||
kv["tokenizer.ggml.model"] = "bert"
|
||||
kv["tokenizer.ggml.token_type_count"] = uint32(2)
|
||||
|
||||
// convert to phantom space tokens
|
||||
for i, e := range t.Tokens {
|
||||
switch {
|
||||
case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
|
||||
// noop - keep special tokens as-is
|
||||
case strings.HasPrefix(e, "##"):
|
||||
t.Tokens[i] = e[2:]
|
||||
default:
|
||||
t.Tokens[i] = "\u2581" + e
|
||||
}
|
||||
}
|
||||
|
||||
kv["tokenizer.ggml.tokens"] = t.Tokens
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *nomicbertModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out := make([]*ggml.Tensor, 0, len(ts))
|
||||
for _, t := range ts {
|
||||
if slices.Contains([]string{
|
||||
"embeddings.position_ids",
|
||||
"pooler.dense.weight",
|
||||
"pooler.dense.bias",
|
||||
}, t.Name()) {
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (nomicbertModel) Replacements() []string {
|
||||
return []string{
|
||||
"encoder.layer", "blk",
|
||||
"encoder.layers", "blk",
|
||||
"embeddings.word_embeddings", "token_embd",
|
||||
"embeddings.token_type_embeddings", "token_types",
|
||||
"embeddings.LayerNorm", "token_embd_norm",
|
||||
|
||||
"attention.self.qkv", "attn_qkv",
|
||||
|
||||
"attention.output.dense", "attn_output",
|
||||
"attention.output.LayerNorm", "attn_output_norm",
|
||||
|
||||
"mlp.up", "ffn_up",
|
||||
"mlp.down", "ffn_down",
|
||||
|
||||
"mlp.router", "ffn_gate_inp",
|
||||
"mlp.experts.up", "ffn_up_exps",
|
||||
"mlp.experts.down", "ffn_down_exps",
|
||||
|
||||
"intermediate.dense", "ffn_up",
|
||||
"output.dense", "ffn_down",
|
||||
"output.LayerNorm", "layer_output_norm",
|
||||
}
|
||||
}
|
||||
117
convert/convert_olmo.go
Normal file
117
convert/convert_olmo.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type ropeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
|
||||
AttentionFactor float32 `json:"attention_factor"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
RopeType string `json:"rope_type"`
|
||||
ExtrapolationFactor float32 `json:"extrapolation_factor"`
|
||||
}
|
||||
|
||||
type olmoModel struct {
|
||||
ModelParameters
|
||||
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScaling *ropeScaling `json:"rope_scaling"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*olmoModel)(nil)
|
||||
|
||||
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "olmo3"
|
||||
kv["olmo3.block_count"] = p.NumHiddenLayers
|
||||
kv["olmo3.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["olmo3.embedding_length"] = p.HiddenSize
|
||||
kv["olmo3.feed_forward_length"] = p.IntermediateSize
|
||||
kv["olmo3.attention.head_count"] = p.NumAttentionHeads
|
||||
kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||
|
||||
if p.RopeTheta > 0 {
|
||||
kv["olmo3.rope.freq_base"] = p.RopeTheta
|
||||
}
|
||||
|
||||
if p.RopeScaling != nil {
|
||||
if p.RopeScaling.Factor > 0 {
|
||||
kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor
|
||||
}
|
||||
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
|
||||
kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
|
||||
}
|
||||
if p.RopeScaling.AttentionFactor > 0 {
|
||||
kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
|
||||
}
|
||||
if p.RopeScaling.RopeType != "" {
|
||||
kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType
|
||||
}
|
||||
}
|
||||
|
||||
if p.RMSNormEPS > 0 {
|
||||
kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
}
|
||||
|
||||
if p.SlidingWindow > 0 {
|
||||
kv["olmo3.attention.sliding_window"] = p.SlidingWindow
|
||||
}
|
||||
|
||||
if len(p.LayerTypes) > 0 {
|
||||
slidingPattern := make([]bool, len(p.LayerTypes))
|
||||
for i, layerType := range p.LayerTypes {
|
||||
slidingPattern[i] = (layerType == "sliding_attention")
|
||||
}
|
||||
kv["olmo3.attention.sliding_window_pattern"] = slidingPattern
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out := make([]*ggml.Tensor, 0, len(ts))
|
||||
for _, t := range ts {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *olmoModel) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"model.norm", "output_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
}
|
||||
}
|
||||
@@ -49,7 +49,8 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
||||
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||||
|
||||
// temporary fix to handle gemma3 broken configs
|
||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
|
||||
// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm
|
||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
|
||||
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
||||
}
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
}
|
||||
|
||||
slog.Info("discovering available GPUs...")
|
||||
detectIncompatibleLibraries()
|
||||
|
||||
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
|
||||
overrideWarnings()
|
||||
@@ -98,6 +99,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
continue
|
||||
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
|
||||
continue
|
||||
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
|
||||
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
|
||||
continue
|
||||
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
|
||||
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
|
||||
continue
|
||||
@@ -143,7 +147,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
extraEnvs := ml.GetVisibleDevicesEnv(devices[i : i+1])
|
||||
extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true)
|
||||
devices[i].AddInitValidation(extraEnvs)
|
||||
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
|
||||
slog.Debug("filtering device which didn't fully initialize",
|
||||
@@ -329,7 +333,8 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
defer cancel()
|
||||
|
||||
// Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct
|
||||
devFilter := ml.GetVisibleDevicesEnv(devices)
|
||||
// We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment
|
||||
devFilter := ml.GetVisibleDevicesEnv(devices, false)
|
||||
|
||||
for dir := range libDirs {
|
||||
updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter)
|
||||
@@ -484,3 +489,16 @@ func overrideWarnings() {
|
||||
slog.Warn("if GPUs are not correctly discovered, unset and try again")
|
||||
}
|
||||
}
|
||||
|
||||
func detectIncompatibleLibraries() {
|
||||
if runtime.GOOS != "windows" {
|
||||
return
|
||||
}
|
||||
basePath, err := exec.LookPath("ggml-base.dll")
|
||||
if err != nil || basePath == "" {
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
|
||||
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
|
||||
}
|
||||
}
|
||||
|
||||
211
docs/ENTRYPOINT_FEATURE.md
Normal file
211
docs/ENTRYPOINT_FEATURE.md
Normal file
@@ -0,0 +1,211 @@
|
||||
# ENTRYPOINT Feature for Ollama Agents
|
||||
|
||||
## Overview
|
||||
|
||||
The ENTRYPOINT command allows agents to specify an external program to run instead of the built-in Ollama chat loop. This makes Ollama a packaging/distribution mechanism for agents with custom runtimes.
|
||||
|
||||
## Status: Implemented ✓
|
||||
|
||||
## What Was Done
|
||||
|
||||
### 1. Types & API
|
||||
|
||||
**`types/model/config.go`**
|
||||
- Added `Entrypoint string` field to `ConfigV2` struct
|
||||
|
||||
**`api/types.go`**
|
||||
- Added `Entrypoint string` to `CreateRequest` (line ~576)
|
||||
- Added `Entrypoint string` to `ShowResponse` (line ~632)
|
||||
|
||||
### 2. Parser
|
||||
|
||||
**`parser/parser.go`**
|
||||
- Added "entrypoint" to `isValidCommand()` switch
|
||||
- Added case in `CreateRequest()` to set `req.Entrypoint = c.Args`
|
||||
- Updated `ParseFile()` to allow ENTRYPOINT without FROM (entrypoint-only agents)
|
||||
- Added entrypoint serialization in `Command.String()`
|
||||
|
||||
### 3. Server
|
||||
|
||||
**`server/create.go`**
|
||||
- Added `config.Entrypoint = r.Entrypoint` to store entrypoint in config
|
||||
- Made FROM optional when ENTRYPOINT is specified:
|
||||
```go
|
||||
} else if r.Entrypoint != "" {
|
||||
// Entrypoint-only agent: no base model needed
|
||||
slog.Debug("create entrypoint-only agent", "entrypoint", r.Entrypoint)
|
||||
}
|
||||
```
|
||||
|
||||
**`server/routes.go`**
|
||||
- Added `Entrypoint: m.Config.Entrypoint` to ShowResponse in `GetModelInfo()`
|
||||
|
||||
**`server/images.go`**
|
||||
- Added entrypoint serialization in `Model.String()`:
|
||||
```go
|
||||
if m.Config.Entrypoint != "" {
|
||||
modelfile.Commands = append(modelfile.Commands, parser.Command{
|
||||
Name: "entrypoint",
|
||||
Args: m.Config.Entrypoint,
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
### 4. CLI
|
||||
|
||||
**`cmd/cmd.go`**
|
||||
- Added `Entrypoint string` to `runOptions` struct
|
||||
- Updated agent detection to include Entrypoint check
|
||||
- Added entrypoint check before interactive mode:
|
||||
```go
|
||||
if opts.Entrypoint != "" {
|
||||
return runEntrypoint(cmd, opts)
|
||||
}
|
||||
```
|
||||
- Implemented `runEntrypoint()` function:
|
||||
- Parses entrypoint into command and args
|
||||
- Appends user prompt as additional argument if provided
|
||||
- Looks up command in PATH
|
||||
- Creates subprocess with stdin/stdout/stderr connected
|
||||
- Runs and waits for completion
|
||||
- Updated `showInfo()` to display entrypoint in Agent section
|
||||
- Updated `showInfo()` to hide Model section for entrypoint-only agents (no blank fields)
|
||||
- Added `$PROMPT` placeholder support in `runEntrypoint()`:
|
||||
- If entrypoint contains `$PROMPT`, it's replaced with the user's prompt
|
||||
- If no placeholder, prompt is appended as positional argument (backwards compatible)
|
||||
- If no prompt provided, `$PROMPT` is removed from the command
|
||||
|
||||
## Usage
|
||||
|
||||
### Agentfile
|
||||
```dockerfile
|
||||
# Minimal entrypoint agent (no model required)
|
||||
ENTRYPOINT ducky
|
||||
|
||||
# Or with full path
|
||||
ENTRYPOINT /usr/local/bin/ducky
|
||||
|
||||
# Or with arguments
|
||||
ENTRYPOINT ducky --verbose
|
||||
|
||||
# Use $PROMPT placeholder to control where prompt is inserted
|
||||
ENTRYPOINT ducky -p $PROMPT
|
||||
|
||||
# Without placeholder, prompt is appended as positional argument
|
||||
ENTRYPOINT echo "Hello" # becomes: echo "Hello" <prompt>
|
||||
|
||||
# Can still bundle skills/MCPs with entrypoint agents
|
||||
SKILL ./my-skill
|
||||
MCP calculator python3 ./calc.py
|
||||
ENTRYPOINT my-custom-runtime
|
||||
```
|
||||
|
||||
### CLI
|
||||
```bash
|
||||
# Create the agent
|
||||
ollama create ducky -f ducky.Agentfile
|
||||
|
||||
# Run it - starts the entrypoint (e.g., REPL)
|
||||
ollama run ducky
|
||||
|
||||
# With prompt (passed as argument to entrypoint)
|
||||
ollama run ducky "hello"
|
||||
|
||||
# Show agent info
|
||||
ollama show ducky
|
||||
# Agent
|
||||
# entrypoint ducky
|
||||
```
|
||||
|
||||
## Testing Done
|
||||
|
||||
1. **Basic entrypoint execution**: ✓
|
||||
```bash
|
||||
# Agentfile: ENTRYPOINT echo "Hello from entrypoint"
|
||||
ollama run test-entry # Output: "Hello from entrypoint"
|
||||
```
|
||||
|
||||
2. **Prompt passing (positional)**: ✓
|
||||
```bash
|
||||
# Agentfile: ENTRYPOINT echo "Args:"
|
||||
ollama run echo-test "hello world" # Output: "Args:" hello world
|
||||
```
|
||||
|
||||
3. **Prompt passing ($PROMPT placeholder)**: ✓
|
||||
```bash
|
||||
# Agentfile: ENTRYPOINT echo "Prompt was:" $PROMPT "end"
|
||||
ollama run echo-placeholder "hello world" # Output: "Prompt was:" hello world "end"
|
||||
ollama run echo-placeholder # Output: "Prompt was:" "end"
|
||||
```
|
||||
|
||||
4. **Show command**: ✓
|
||||
```bash
|
||||
ollama show ducky
|
||||
# Agent
|
||||
# entrypoint ducky
|
||||
# (Model section hidden for entrypoint-only agents)
|
||||
```
|
||||
|
||||
5. **List command**: ✓
|
||||
- Entrypoint-only agents show with small sizes (~200 bytes)
|
||||
|
||||
## Left Over / Future Enhancements
|
||||
|
||||
### 1. Context Passing via Environment Variables
|
||||
Pass agent context to entrypoint via env vars:
|
||||
- `OLLAMA_AGENT_NAME` - Name of the agent
|
||||
- `OLLAMA_SKILLS_PATH` - Path to bundled skills
|
||||
- `OLLAMA_MCPS` - JSON of MCP configurations
|
||||
|
||||
### ~~2. Arguments Placeholder~~ ✓ DONE
|
||||
~~Support placeholder syntax for more control:~~
|
||||
```dockerfile
|
||||
# Now supported!
|
||||
ENTRYPOINT ducky -p $PROMPT
|
||||
```
|
||||
|
||||
### 3. Working Directory
|
||||
Set working directory for entrypoint:
|
||||
```dockerfile
|
||||
WORKDIR /app
|
||||
ENTRYPOINT ./run.sh
|
||||
```
|
||||
|
||||
### 4. Interactive Mode Detection
|
||||
Different behavior for REPL vs single-shot:
|
||||
- Detect if stdin is a TTY
|
||||
- Pass different flags based on mode
|
||||
|
||||
### 5. Signal Handling
|
||||
Improved signal forwarding to subprocess:
|
||||
- Forward SIGINT, SIGTERM gracefully
|
||||
- Handle cleanup on parent exit
|
||||
|
||||
### 6. Entrypoint with Model
|
||||
Allow both model and entrypoint:
|
||||
```dockerfile
|
||||
FROM llama3.2
|
||||
ENTRYPOINT my-custom-ui
|
||||
```
|
||||
The entrypoint could then use the model via Ollama API.
|
||||
|
||||
### 7. Pull/Push for Entrypoint Agents
|
||||
- Currently entrypoint agents can be created locally
|
||||
- Need to test/verify push/pull to registry works correctly
|
||||
- May need to handle entrypoint binaries (or just reference system commands)
|
||||
|
||||
### 8. Error Handling
|
||||
- Better error messages when entrypoint command not found
|
||||
- Validation of entrypoint during create (optional, warn if not found)
|
||||
|
||||
## Design Decisions
|
||||
|
||||
1. **Subprocess mode (not exec)**: Ollama stays as parent process to handle signals and cleanup
|
||||
|
||||
2. **No context passing initially**: Keep it simple, entrypoint handles its own config
|
||||
|
||||
3. **Skills/MCPs allowed**: Enables packaging assets with the agent even if entrypoint manages execution
|
||||
|
||||
4. **FROM optional**: Entrypoint agents don't need a model, just the runtime
|
||||
|
||||
5. **Prompt as argument**: User prompt is appended as argument to entrypoint command (simplest approach)
|
||||
332
docs/agent-skills-changes.md
Normal file
332
docs/agent-skills-changes.md
Normal file
@@ -0,0 +1,332 @@
|
||||
# Agent Skills Feature - Implementation Summary
|
||||
|
||||
This document summarizes all changes made to implement agent skills in Ollama, enabling `ollama run <agent>` with skill-based capabilities.
|
||||
|
||||
## Overview
|
||||
|
||||
Agents are models with attached skills. Skills are directories containing a `SKILL.md` file with instructions and optional executable scripts. When an agent runs, skills are loaded and injected into the system prompt, and the model can execute scripts via tool calls.
|
||||
|
||||
## Files Changed
|
||||
|
||||
### 1. `cmd/skills.go` (NEW FILE)
|
||||
|
||||
Core skills implementation:
|
||||
|
||||
```go
|
||||
// Key types
|
||||
type skillMetadata struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
}
|
||||
|
||||
type skillDefinition struct {
|
||||
Name string
|
||||
Description string
|
||||
Content string // SKILL.md body content
|
||||
Dir string // Absolute path to skill directory
|
||||
SkillPath string // Absolute path to SKILL.md
|
||||
}
|
||||
|
||||
type skillCatalog struct {
|
||||
Skills []skillDefinition
|
||||
byName map[string]skillDefinition
|
||||
}
|
||||
```
|
||||
|
||||
**Key functions:**
|
||||
- `loadSkills(paths []string)` - Walks skill directories, parses SKILL.md files
|
||||
- `parseSkillFile(path, skillDir)` - Extracts YAML frontmatter and body content
|
||||
- `SystemPrompt()` - Generates system prompt with skill instructions
|
||||
- `Tools()` - Returns `run_skill_script` and `read_skill_file` tools
|
||||
- `RunToolCall(call)` - Executes tool calls from the model
|
||||
- `runSkillScript(skillDir, command)` - Executes shell commands in skill directory
|
||||
|
||||
**Tools provided to model:**
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `run_skill_script` | Execute a script in a skill's directory |
|
||||
| `read_skill_file` | Read a file from a skill's directory |
|
||||
|
||||
**Security note:** `runSkillScript` has documented limitations (no sandboxing, no path validation). See the function's doc comment for details.
|
||||
|
||||
---
|
||||
|
||||
### 2. `cmd/cmd.go`
|
||||
|
||||
**Changes to `runOptions` struct:**
|
||||
```go
|
||||
type runOptions struct {
|
||||
// ... existing fields ...
|
||||
IsAgent bool
|
||||
AgentType string
|
||||
Skills []string
|
||||
}
|
||||
```
|
||||
|
||||
**Agent detection in `RunHandler`** (~line 497-503):
|
||||
```go
|
||||
// Check if this is an agent
|
||||
isAgent := info.AgentType != "" || len(info.Skills) > 0
|
||||
if isAgent {
|
||||
opts.IsAgent = true
|
||||
opts.AgentType = info.AgentType
|
||||
opts.Skills = info.Skills
|
||||
}
|
||||
```
|
||||
|
||||
**Route agents to chat API** (~line 557-562):
|
||||
```go
|
||||
// For agents, use chat API even in non-interactive mode to support tools
|
||||
if opts.IsAgent {
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: opts.Prompt})
|
||||
_, err := chat(cmd, opts)
|
||||
return err
|
||||
}
|
||||
```
|
||||
|
||||
**Skills loading in `chat` function** (~line 1347-1361):
|
||||
```go
|
||||
var skillsCatalog *skillCatalog
|
||||
if opts.IsAgent && len(opts.Skills) > 0 {
|
||||
skillsCatalog, err = loadSkills(opts.Skills)
|
||||
// ... error handling ...
|
||||
// Print loaded skills
|
||||
fmt.Fprintf(os.Stderr, "Loaded skills: %s\n", strings.Join(skillNames, ", "))
|
||||
}
|
||||
```
|
||||
|
||||
**System prompt injection** (~line 1448-1455):
|
||||
- Skills system prompt is prepended to messages
|
||||
|
||||
**Tool execution** (~line 1497-1533):
|
||||
- Executes pending tool calls via `skillsCatalog.RunToolCall()`
|
||||
- Displays script execution and output to terminal
|
||||
|
||||
---
|
||||
|
||||
### 3. `parser/parser.go`
|
||||
|
||||
**New valid commands** in `isValidCommand()`:
|
||||
```go
|
||||
case "from", "license", "template", "system", "adapter", "renderer",
|
||||
"parser", "parameter", "message", "requires", "skill", "agent_type":
|
||||
```
|
||||
|
||||
**Command handling in `CreateRequest()`**:
|
||||
```go
|
||||
case "skill":
|
||||
skills = append(skills, c.Args)
|
||||
case "agent_type":
|
||||
req.AgentType = c.Args
|
||||
```
|
||||
|
||||
**Underscore support in command names** (~line 545):
|
||||
```go
|
||||
case isAlpha(r), r == '_':
|
||||
return stateName, r, nil
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. `api/types.go`
|
||||
|
||||
**CreateRequest additions** (~line 560-564):
|
||||
```go
|
||||
// Skills is a list of skill directories for the agent
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
|
||||
// AgentType defines the type of agent (e.g., "conversational", "task-based")
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
```
|
||||
|
||||
**ShowResponse additions** (~line 633-637):
|
||||
```go
|
||||
// Skills loaded for this agent
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
|
||||
// AgentType for this agent
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 5. `types/model/config.go`
|
||||
|
||||
**ConfigV2 additions**:
|
||||
```go
|
||||
type ConfigV2 struct {
|
||||
// ... existing fields ...
|
||||
|
||||
// Agent-specific fields
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
AgentType string `json:"agent_type,omitempty"`
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 6. `server/create.go`
|
||||
|
||||
**Store agent fields** (~line 65-66):
|
||||
```go
|
||||
config.Skills = r.Skills
|
||||
config.AgentType = r.AgentType
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 7. `server/routes.go`
|
||||
|
||||
**Return agent fields in ShowResponse** (~line 1107):
|
||||
```go
|
||||
resp := &api.ShowResponse{
|
||||
// ... existing fields ...
|
||||
Skills: m.Config.Skills,
|
||||
AgentType: m.Config.AgentType,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 8. `envconfig/config.go`
|
||||
|
||||
**Environment variable support**:
|
||||
```go
|
||||
func Skills() []string {
|
||||
raw := strings.TrimSpace(Var("OLLAMA_SKILLS"))
|
||||
if raw == "" {
|
||||
return []string{}
|
||||
}
|
||||
return strings.Split(raw, ",")
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Agentfile Format
|
||||
|
||||
Agentfiles use the same syntax as Modelfiles with additional commands:
|
||||
|
||||
```dockerfile
|
||||
FROM gpt-oss:20b
|
||||
|
||||
AGENT_TYPE conversational
|
||||
SKILL /path/to/skills/directory
|
||||
|
||||
SYSTEM You are a helpful assistant.
|
||||
|
||||
PARAMETER temperature 0.3
|
||||
PARAMETER top_p 0.9
|
||||
```
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `SKILL` | Path to a directory containing skill subdirectories |
|
||||
| `AGENT_TYPE` | Type of agent (e.g., "conversational") |
|
||||
|
||||
---
|
||||
|
||||
## SKILL.md Format
|
||||
|
||||
Each skill is a directory with a `SKILL.md` file:
|
||||
|
||||
```
|
||||
calculator-skill/
|
||||
├── SKILL.md
|
||||
└── scripts/
|
||||
└── calculate.py
|
||||
```
|
||||
|
||||
**SKILL.md structure:**
|
||||
```markdown
|
||||
---
|
||||
name: calculator-skill
|
||||
description: A skill for performing calculations.
|
||||
---
|
||||
|
||||
# Calculator Skill
|
||||
|
||||
## Instructions
|
||||
|
||||
1. Use `run_skill_script` to execute calculations
|
||||
2. Call: `python3 scripts/calculate.py '<expression>'`
|
||||
|
||||
## Examples
|
||||
|
||||
For "What is 25 * 4?":
|
||||
- Call: run_skill_script with skill="calculator-skill" and command="python3 scripts/calculate.py '25 * 4'"
|
||||
```
|
||||
|
||||
**Requirements:**
|
||||
- `name` must match directory name
|
||||
- `name` must be lowercase alphanumeric with hyphens only
|
||||
- `name` max 64 characters
|
||||
- `description` required, max 1024 characters
|
||||
|
||||
---
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Create an agent
|
||||
ollama create math-agent -f math-agent.Agentfile
|
||||
|
||||
# Run the agent
|
||||
ollama run math-agent "What is 25 * 4?"
|
||||
|
||||
# Output:
|
||||
# Loaded skills: calculator-skill
|
||||
# Running script in calculator-skill: python3 scripts/calculate.py '25 * 4'
|
||||
# Output:
|
||||
# 25 * 4 = 100
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Flow Diagram
|
||||
|
||||
```
|
||||
1. ollama run math-agent "query"
|
||||
│
|
||||
▼
|
||||
2. RunHandler detects agent (AgentType or Skills present)
|
||||
│
|
||||
▼
|
||||
3. Routes to chat() instead of generate()
|
||||
│
|
||||
▼
|
||||
4. loadSkills() parses SKILL.md files
|
||||
│
|
||||
▼
|
||||
5. SystemPrompt() injects skill instructions
|
||||
│
|
||||
▼
|
||||
6. Tools() provides run_skill_script, read_skill_file
|
||||
│
|
||||
▼
|
||||
7. Model generates response (may include tool calls)
|
||||
│
|
||||
▼
|
||||
8. RunToolCall() executes scripts, returns output
|
||||
│
|
||||
▼
|
||||
9. Display results to user
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Security Considerations
|
||||
|
||||
The `runSkillScript` function has known limitations documented in the code:
|
||||
|
||||
- No sandboxing (commands run with user permissions)
|
||||
- No path validation (model can run any command)
|
||||
- Shell injection risk (`sh -c` is used)
|
||||
- No executable allowlist
|
||||
- No environment isolation
|
||||
|
||||
**Potential improvements** (documented as TODOs):
|
||||
- Restrict to skill directory paths only
|
||||
- Allowlist executables (python3, node, bash)
|
||||
- Use sandboxing (Docker, nsjail, seccomp)
|
||||
- Require explicit script registration in SKILL.md
|
||||
10
docs/api.md
10
docs/api.md
@@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `system`: system message to (overrides what is defined in the `Modelfile`)
|
||||
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
@@ -507,7 +507,7 @@ The `message` object has the following fields:
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
|
||||
@@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
|
||||
- `template`: (optional) the prompt template for the model
|
||||
- `license`: (optional) a string or list of strings containing the license or licenses for the model
|
||||
- `system`: (optional) a string containing the system prompt for the model
|
||||
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters)
|
||||
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters)
|
||||
- `messages`: (optional) a list of message objects used to create a conversation
|
||||
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `quantize` (optional): quantize a non-quantized (e.g. float16) model
|
||||
@@ -1698,7 +1698,7 @@ Generate embeddings from a model
|
||||
Advanced parameters:
|
||||
|
||||
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
- `dimensions`: number of dimensions for the embedding
|
||||
|
||||
@@ -1817,7 +1817,7 @@ Generate embeddings from a model
|
||||
|
||||
Advanced parameters:
|
||||
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
|
||||
### Examples
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -15,7 +15,7 @@ Also known as "single-shot" tool calling.
|
||||
```shell
|
||||
curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
|
||||
"model": "qwen3",
|
||||
"messages": [{"role": "user", "content": "What's the temperature in New York?"}],
|
||||
"messages": [{"role": "user", "content": "What is the temperature in New York?"}],
|
||||
"stream": false,
|
||||
"tools": [
|
||||
{
|
||||
@@ -41,7 +41,7 @@ Also known as "single-shot" tool calling.
|
||||
curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
|
||||
"model": "qwen3",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the temperature in New York?"},
|
||||
{"role": "user", "content": "What is the temperature in New York?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
@@ -90,7 +90,7 @@ Also known as "single-shot" tool calling.
|
||||
}
|
||||
return temperatures.get(city, "Unknown")
|
||||
|
||||
messages = [{"role": "user", "content": "What's the temperature in New York?"}]
|
||||
messages = [{"role": "user", "content": "What is the temperature in New York?"}]
|
||||
|
||||
# pass functions directly as tools in the tools list or as a JSON schema
|
||||
response = chat(model="qwen3", messages=messages, tools=[get_temperature], think=True)
|
||||
@@ -146,7 +146,7 @@ Also known as "single-shot" tool calling.
|
||||
},
|
||||
]
|
||||
|
||||
const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
|
||||
const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
|
||||
|
||||
const response = await ollama.chat({
|
||||
model: 'qwen3',
|
||||
@@ -609,7 +609,7 @@ def get_temperature(city: str) -> str:
|
||||
return temperatures.get(city, 'Unknown')
|
||||
|
||||
|
||||
messages = [{'role': 'user', 'content': "What's the temperature in New York?"}]
|
||||
messages = [{'role': 'user', 'content': "What is the temperature in New York?"}]
|
||||
|
||||
while True:
|
||||
stream = chat(
|
||||
@@ -684,7 +684,7 @@ const getTemperatureTool = {
|
||||
}
|
||||
|
||||
async function agentLoop() {
|
||||
const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
|
||||
const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
|
||||
|
||||
while (true) {
|
||||
const stream = await ollama.chat({
|
||||
|
||||
@@ -49,6 +49,8 @@ Install prerequisites:
|
||||
- [Ninja](https://github.com/ninja-build/ninja/releases)
|
||||
- (Optional) NVIDIA GPU support
|
||||
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
||||
- (Optional) VULKAN GPU support
|
||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||
|
||||
Then, configure and build the project:
|
||||
|
||||
@@ -57,6 +59,17 @@ cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
> Building for Vulkan requires VULKAN_SDK environment variable:
|
||||
>
|
||||
> PowerShell
|
||||
> ```powershell
|
||||
> $env:VULKAN_SDK="C:\VulkanSDK\<version>"
|
||||
> ```
|
||||
> CMD
|
||||
> ```cmd
|
||||
> set VULKAN_SDK=C:\VulkanSDK\<version>
|
||||
> ```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Building for ROCm requires additional flags:
|
||||
> ```
|
||||
@@ -65,6 +78,7 @@ cmake --build build --config Release
|
||||
> ```
|
||||
|
||||
|
||||
|
||||
Lastly, run Ollama:
|
||||
|
||||
```shell
|
||||
@@ -84,7 +98,9 @@ Install prerequisites:
|
||||
- [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html)
|
||||
- (Optional) NVIDIA GPU support
|
||||
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
- (Optional) VULKAN GPU support
|
||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
||||
> [!IMPORTANT]
|
||||
> Ensure prerequisites are in `PATH` before running CMake.
|
||||
|
||||
|
||||
15
docs/faq.mdx
15
docs/faq.mdx
@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
## How can I view the logs?
|
||||
|
||||
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
|
||||
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
|
||||
|
||||
## Is my GPU compatible with Ollama?
|
||||
|
||||
Please refer to the [GPU docs](./gpu.md).
|
||||
Please refer to the [GPU docs](./gpu).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
@@ -57,8 +57,13 @@ ollama ps
|
||||
```
|
||||
|
||||
<Info>
|
||||
**Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
|
||||
100% GPU 4 minutes from now ```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
NAME ID SIZE PROCESSOR UNTIL
|
||||
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
||||
```
|
||||
</Info>
|
||||
|
||||
The `Processor` column will show which memory the model was loaded in to:
|
||||
@@ -385,4 +390,4 @@ Ollama for Windows and macOS register as a login item during installation. You
|
||||
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
|
||||
|
||||
**MacOS**
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
|
||||
|
||||
10
docs/gpu.mdx
10
docs/gpu.mdx
@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
|
||||
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
|
||||
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
|
||||
|
||||
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
|
||||
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia)
|
||||
|
||||
### GPU Selection
|
||||
|
||||
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
|
||||
|
||||
Ollama supports the following AMD GPUs via the ROCm library:
|
||||
|
||||
> [!NOTE]
|
||||
> **NOTE:**
|
||||
> Additional AMD GPU support is provided by the Vulkan Library - see below.
|
||||
|
||||
|
||||
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
|
||||
|
||||
## Vulkan GPU Support
|
||||
|
||||
> [!NOTE]
|
||||
> **NOTE:**
|
||||
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
|
||||
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server)
|
||||
|
||||
Additional GPU support on Windows and Linux is provided via
|
||||
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
|
||||
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
|
||||
|
||||
To select specific Vulkan GPU(s), you can set the environment variable
|
||||
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
|
||||
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
|
||||
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
|
||||
by setting `GGML_VK_VISIBLE_DEVICES=-1`
|
||||
265
docs/mcp-integration.md
Normal file
265
docs/mcp-integration.md
Normal file
@@ -0,0 +1,265 @@
|
||||
# MCP (Model Context Protocol) Integration
|
||||
|
||||
This document describes the MCP integration for Ollama agents, enabling agents to use external tools via the Model Context Protocol.
|
||||
|
||||
## Overview
|
||||
|
||||
MCP allows Ollama agents to communicate with external tool servers over JSON-RPC 2.0 via stdio. This enables agents to access capabilities like web search, file operations, databases, and more through standardized tool interfaces.
|
||||
|
||||
## Status
|
||||
|
||||
| Phase | Description | Status |
|
||||
|-------|-------------|--------|
|
||||
| Phase 1 | Types & Parser | ✅ Complete |
|
||||
| Phase 2 | Layer Handling | ✅ Complete |
|
||||
| Phase 3 | Runtime Manager | ✅ Complete |
|
||||
| Phase 4 | CLI Commands | ✅ Complete |
|
||||
|
||||
## Agentfile Syntax
|
||||
|
||||
### Simple Command Format
|
||||
```dockerfile
|
||||
MCP <name> <command> [args...]
|
||||
```
|
||||
|
||||
Example:
|
||||
```dockerfile
|
||||
FROM llama3.2
|
||||
AGENT TYPE conversational
|
||||
SYSTEM You are a helpful assistant with MCP tools.
|
||||
MCP calculator python3 ./mcp-server.py
|
||||
MCP websearch node ./search-server.js
|
||||
```
|
||||
|
||||
### JSON Format
|
||||
```dockerfile
|
||||
MCP {"name": "custom", "command": "uv", "args": ["run", "server.py"], "env": {"API_KEY": "xxx"}}
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Type Definitions
|
||||
|
||||
**MCPRef** (`types/model/config.go`):
|
||||
```go
|
||||
type MCPRef struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Digest string `json:"digest,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
Type string `json:"type,omitempty"` // "stdio"
|
||||
}
|
||||
```
|
||||
|
||||
### Tool Namespacing
|
||||
|
||||
MCP tools are namespaced to avoid conflicts:
|
||||
- Format: `mcp_{serverName}_{toolName}`
|
||||
- Example: Server "calculator" with tool "add" → `mcp_calculator_add`
|
||||
|
||||
### Runtime Flow
|
||||
|
||||
1. Agent starts → MCP servers spawn as subprocesses
|
||||
2. Initialize via JSON-RPC: `initialize` → `notifications/initialized`
|
||||
3. Discover tools: `tools/list`
|
||||
4. During chat, model calls tools → routed via `tools/call`
|
||||
5. On shutdown, MCP servers are gracefully terminated
|
||||
|
||||
## Files
|
||||
|
||||
### Created
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `cmd/mcp.go` | Runtime MCP manager with JSON-RPC protocol |
|
||||
| `cmd/mcp_cmd.go` | CLI commands for managing MCPs (push, pull, list, etc.) |
|
||||
| `server/mcp.go` | MCP layer utilities (extraction, creation) |
|
||||
|
||||
### Modified
|
||||
|
||||
| File | Changes |
|
||||
|------|---------|
|
||||
| `types/model/config.go` | Added `MCPRef` type, `MCPs` field to `ConfigV2` |
|
||||
| `types/model/name.go` | Added `"mcp"` to `ValidKinds` for 5-part name parsing |
|
||||
| `api/types.go` | Added `MCPRef` alias, `MCPs` to `CreateRequest`/`ShowResponse` |
|
||||
| `parser/parser.go` | Added `MCP` command parsing with JSON and simple formats |
|
||||
| `server/create.go` | Added `setMCPLayers()` for MCP config handling |
|
||||
| `server/routes.go` | Added `MCPs` to show response |
|
||||
| `cmd/cmd.go` | MCP integration in `chat()` function |
|
||||
| `cmd/interactive.go` | Added `/mcp` and `/mcps` REPL commands |
|
||||
|
||||
## Usage Example
|
||||
|
||||
### 1. Create an MCP Server
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
# mcp-server.py
|
||||
import json
|
||||
import sys
|
||||
|
||||
def handle_request(req):
|
||||
method = req.get("method", "")
|
||||
|
||||
if method == "initialize":
|
||||
return {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "example", "version": "1.0"}
|
||||
}
|
||||
elif method == "tools/list":
|
||||
return {
|
||||
"tools": [{
|
||||
"name": "add",
|
||||
"description": "Adds two numbers",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"},
|
||||
"b": {"type": "number"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
}]
|
||||
}
|
||||
elif method == "tools/call":
|
||||
args = req["params"]["arguments"]
|
||||
return {"content": [{"type": "text", "text": f"{args['a'] + args['b']}"}]}
|
||||
return {}
|
||||
|
||||
for line in sys.stdin:
|
||||
req = json.loads(line)
|
||||
if "id" in req:
|
||||
result = handle_request(req)
|
||||
print(json.dumps({"jsonrpc": "2.0", "id": req["id"], "result": result}), flush=True)
|
||||
```
|
||||
|
||||
### 2. Create an Agent
|
||||
|
||||
```dockerfile
|
||||
# my-agent.Agentfile
|
||||
FROM gpt-oss:20b
|
||||
AGENT TYPE conversational
|
||||
SYSTEM You have access to a calculator. Use the add tool when asked to add numbers.
|
||||
MCP calculator python3 ./mcp-server.py
|
||||
```
|
||||
|
||||
### 3. Build and Run
|
||||
|
||||
```bash
|
||||
ollama create my-agent -f my-agent.Agentfile
|
||||
ollama run my-agent "What is 15 + 27?"
|
||||
```
|
||||
|
||||
Output:
|
||||
```
|
||||
Loaded MCP servers: calculator (1 tools)
|
||||
Executing: mcp_calculator_add
|
||||
Output: 42
|
||||
The result is 42.
|
||||
```
|
||||
|
||||
## CLI Commands
|
||||
|
||||
The `ollama mcp` command provides utilities for managing MCP servers:
|
||||
|
||||
### Global Config Commands
|
||||
|
||||
Add an MCP server to the global config (`~/.ollama/mcp.json`):
|
||||
```bash
|
||||
# Add MCP to global config (available to all agents)
|
||||
ollama mcp add web-search uv run ./mcp-server.py
|
||||
ollama mcp add calculator python3 /path/to/calc.py
|
||||
|
||||
# List global MCP servers (shows enabled/disabled status)
|
||||
ollama mcp list-global
|
||||
|
||||
# Disable an MCP server (keeps in config but won't be loaded)
|
||||
ollama mcp disable web-search
|
||||
|
||||
# Re-enable a disabled MCP server
|
||||
ollama mcp enable web-search
|
||||
|
||||
# Remove from global config
|
||||
ollama mcp remove-global web-search
|
||||
```
|
||||
|
||||
### Registry Commands
|
||||
|
||||
Package and push MCPs to a registry:
|
||||
```bash
|
||||
# Push MCP to registry (creates locally first)
|
||||
ollama mcp push mcp/websearch:1.0 ./my-mcp-server/
|
||||
|
||||
# Pull MCP from registry
|
||||
ollama mcp pull mcp/websearch:1.0
|
||||
|
||||
# List installed MCPs (from registry)
|
||||
ollama mcp list
|
||||
|
||||
# Show MCP details
|
||||
ollama mcp show mcp/websearch:1.0
|
||||
|
||||
# Remove MCP
|
||||
ollama mcp rm mcp/websearch:1.0
|
||||
```
|
||||
|
||||
## REPL Commands
|
||||
|
||||
Inside `ollama run`, you can manage MCP servers dynamically:
|
||||
|
||||
```
|
||||
>>> /mcp # Show all MCP servers (model + global)
|
||||
>>> /mcp add calc python3 ./calc-server.py # Add MCP server to global config
|
||||
>>> /mcp remove calc # Remove MCP server from global config
|
||||
>>> /mcp disable calc # Disable an MCP server (keep in config)
|
||||
>>> /mcp enable calc # Re-enable a disabled MCP server
|
||||
>>> /? mcp # Get help for MCP commands
|
||||
```
|
||||
|
||||
The `/mcp` command shows all available MCP servers (both bundled with the model and from global config). Disabled servers are shown with a `[disabled]` marker. Use `/mcp add` and `/mcp remove` to manage MCPs in `~/.ollama/mcp.json`. Changes take effect on the next message.
|
||||
|
||||
## Global Config
|
||||
|
||||
MCPs can be configured globally in `~/.ollama/mcp.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"web-search": {
|
||||
"type": "stdio",
|
||||
"command": "uv",
|
||||
"args": ["run", "./mcp-server.py"]
|
||||
},
|
||||
"calculator": {
|
||||
"type": "stdio",
|
||||
"command": "python3",
|
||||
"args": ["/path/to/calc.py"],
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The `disabled` field is optional. When set to `true`, the MCP server will not be loaded when running agents.
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
1. **Remote Registry Push/Pull**: Full support for pushing/pulling MCPs to/from remote registries
|
||||
2. **Use go-sdk**: Consider using `github.com/modelcontextprotocol/go-sdk` for protocol handling
|
||||
3. **Resource Support**: Add MCP resources (not just tools)
|
||||
4. **Prompt Support**: Add MCP prompts
|
||||
|
||||
## Protocol Reference
|
||||
|
||||
MCP uses JSON-RPC 2.0 over stdio with these key methods:
|
||||
|
||||
| Method | Direction | Purpose |
|
||||
|--------|-----------|---------|
|
||||
| `initialize` | Client→Server | Handshake with capabilities |
|
||||
| `notifications/initialized` | Client→Server | Confirm initialization |
|
||||
| `tools/list` | Client→Server | Discover available tools |
|
||||
| `tools/call` | Client→Server | Execute a tool |
|
||||
|
||||
See [MCP Specification](https://modelcontextprotocol.io/docs) for full details.
|
||||
@@ -41,6 +41,7 @@ INSTRUCTION arguments
|
||||
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
|
||||
| [`LICENSE`](#license) | Specifies the legal license. |
|
||||
| [`MESSAGE`](#message) | Specify message history. |
|
||||
| [`REQUIRES`](#requires) | Specify the minimum version of Ollama required by the model. |
|
||||
|
||||
## Examples
|
||||
|
||||
@@ -248,6 +249,16 @@ MESSAGE user Is Ontario in Canada?
|
||||
MESSAGE assistant yes
|
||||
```
|
||||
|
||||
### REQUIRES
|
||||
|
||||
The `REQUIRES` instruction allows you to specify the minimum version of Ollama required by the model.
|
||||
|
||||
```
|
||||
REQUIRES <version>
|
||||
```
|
||||
|
||||
The version should be a valid Ollama version (e.g. 0.14.0).
|
||||
|
||||
## Notes
|
||||
|
||||
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.
|
||||
|
||||
362
docs/skill-registry-design.md
Normal file
362
docs/skill-registry-design.md
Normal file
@@ -0,0 +1,362 @@
|
||||
# Skill Registry Design
|
||||
|
||||
## Overview
|
||||
|
||||
Skills are distributable capability packages for Ollama agents. They can be:
|
||||
- Bundled with agents at creation time (local paths)
|
||||
- Pulled from the registry (skill references)
|
||||
- Pushed to the registry for sharing
|
||||
|
||||
## User Experience
|
||||
|
||||
### Push a Skill
|
||||
|
||||
```bash
|
||||
# Push a local skill directory to the registry
|
||||
ollama skill push myname/calculator:1.0.0 ./skills/calculator-skill
|
||||
|
||||
# Output:
|
||||
# Creating skill layer for skill/myname/calculator:1.0.0
|
||||
# pushing sha256:abc123... 1.2KB
|
||||
# pushing sha256:def456... 220B
|
||||
# pushing manifest
|
||||
# Successfully pushed skill/myname/calculator:1.0.0
|
||||
```
|
||||
|
||||
### Pull a Skill
|
||||
|
||||
```bash
|
||||
# Pull a skill from the registry
|
||||
ollama skill pull calculator:1.0.0
|
||||
|
||||
# Output:
|
||||
# pulling manifest
|
||||
# pulling sha256:abc123... 1.2KB
|
||||
# extracting skill...
|
||||
# Successfully pulled skill/calculator:1.0.0
|
||||
```
|
||||
|
||||
### List Installed Skills
|
||||
|
||||
```bash
|
||||
ollama skill list
|
||||
|
||||
# Output:
|
||||
# NAME TAG SIZE MODIFIED
|
||||
# skill/calculator 1.0.0 1.2 KB 2 hours ago
|
||||
# skill/myname/hello latest 0.8 KB 1 day ago
|
||||
```
|
||||
|
||||
### Remove a Skill
|
||||
|
||||
```bash
|
||||
ollama skill rm calculator:1.0.0
|
||||
# Deleted 'skill/calculator:1.0.0'
|
||||
```
|
||||
|
||||
### Use Skills in Agentfile
|
||||
|
||||
```dockerfile
|
||||
FROM llama3.2:3b
|
||||
|
||||
AGENT_TYPE conversational
|
||||
SKILL skill/calculator:1.0.0 # Registry reference
|
||||
SKILL ./local-skill # Local path (for development)
|
||||
|
||||
SYSTEM You are a helpful assistant.
|
||||
```
|
||||
|
||||
## Technical Implementation
|
||||
|
||||
### Skill Manifest Format
|
||||
|
||||
```json
|
||||
{
|
||||
"schemaVersion": 2,
|
||||
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
"config": {
|
||||
"mediaType": "application/vnd.docker.container.image.v1+json",
|
||||
"digest": "sha256:config...",
|
||||
"size": 220
|
||||
},
|
||||
"layers": [
|
||||
{
|
||||
"mediaType": "application/vnd.ollama.image.skill",
|
||||
"digest": "sha256:skill...",
|
||||
"size": 1234
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Skill Config Format
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "calculator",
|
||||
"description": "A skill for performing calculations",
|
||||
"architecture": "amd64",
|
||||
"os": "linux"
|
||||
}
|
||||
```
|
||||
|
||||
### Storage Layout
|
||||
|
||||
Skills use a 5-part manifest structure: `host/namespace/kind/model/tag`
|
||||
|
||||
```
|
||||
~/.ollama/models/
|
||||
├── blobs/
|
||||
│ └── sha256-<skill-digest> # Skill tar.gz blob
|
||||
├── manifests/
|
||||
│ └── registry.ollama.ai/
|
||||
│ └── library/
|
||||
│ └── skill/ # Kind = skill
|
||||
│ └── calculator/
|
||||
│ └── 1.0.0
|
||||
│ └── myname/
|
||||
│ └── skill/ # User skills
|
||||
│ └── my-skill/
|
||||
│ └── latest
|
||||
└── skills/
|
||||
└── sha256-<digest>/ # Extracted skill cache
|
||||
├── SKILL.md
|
||||
└── scripts/
|
||||
```
|
||||
|
||||
### Name Structure
|
||||
|
||||
Skills use a 5-part name structure with `kind` to distinguish from models:
|
||||
|
||||
| Skill Reference | Namespace | Kind | Model | Tag |
|
||||
|-----------------|-----------|------|-------|-----|
|
||||
| `skill/calculator:1.0.0` | library | skill | calculator | 1.0.0 |
|
||||
| `myname/skill/calc:latest` | myname | skill | calc | latest |
|
||||
|
||||
### Media Type
|
||||
|
||||
```go
|
||||
const MediaTypeSkill = "application/vnd.ollama.image.skill"
|
||||
```
|
||||
|
||||
### Key Types
|
||||
|
||||
```go
|
||||
// SkillRef represents a skill reference in agent config
|
||||
type SkillRef struct {
|
||||
Name string `json:"name,omitempty"` // "calculator-skill" or "myname/skill/calc:1.0.0"
|
||||
Digest string `json:"digest,omitempty"` // "sha256:abc..." (set when bundled)
|
||||
}
|
||||
|
||||
// model.Name represents a parsed 5-part name
|
||||
type Name struct {
|
||||
Host string // "registry.ollama.ai"
|
||||
Namespace string // "library" or "myname"
|
||||
Kind string // "skill" or "agent" or "" for models
|
||||
Model string // "calculator"
|
||||
Tag string // "1.0.0"
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation Files
|
||||
|
||||
### Client (ollama)
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `server/skill.go` | Skill blob handling, path parsing, extraction |
|
||||
| `cmd/skill_cmd.go` | CLI commands (push, pull, list, rm, show) |
|
||||
| `cmd/skills.go` | Skill loading and catalog management |
|
||||
| `server/create.go` | Skill layer creation during agent create |
|
||||
| `server/images.go` | Skill extraction during pull |
|
||||
| `types/model/config.go` | SkillRef type definition |
|
||||
|
||||
### Registry (ollama.com)
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `ollamadotcom/registry/store.go` | MediaTypeSkill constant |
|
||||
| `ollamadotcom/store/store.go` | RecordPush handles skill layers |
|
||||
|
||||
## Registry Integration
|
||||
|
||||
### What Works
|
||||
|
||||
- Blob uploads (content-addressable, no auth required)
|
||||
- Layer indexing (skill layers stored with mediatype)
|
||||
- Manifest structure (4-part path compatible)
|
||||
|
||||
### What's Needed
|
||||
|
||||
1. **Namespace Configuration**: The `skill` namespace needs to be configured with:
|
||||
- Public read access
|
||||
- Authenticated write access
|
||||
|
||||
2. **Permission Model**: Decide who can push to `skill/` namespace:
|
||||
- Only Ollama team (curated library)
|
||||
- Verified publishers
|
||||
- Anyone (open registry)
|
||||
|
||||
## Pull Flow
|
||||
|
||||
### Agent with Bundled Skills
|
||||
|
||||
```
|
||||
ollama pull my-agent
|
||||
→ GET manifest (includes skill layers)
|
||||
→ Download all blobs (model + skills)
|
||||
→ Extract skill blobs to ~/.ollama/models/skills/
|
||||
→ Ready to run
|
||||
```
|
||||
|
||||
### Standalone Skill
|
||||
|
||||
```
|
||||
ollama skill pull calculator:1.0.0
|
||||
→ Parse as skill/calculator:1.0.0
|
||||
→ Convert to model.Name{Namespace: "skill", Model: "calculator", Tag: "1.0.0"}
|
||||
→ GET manifest from registry
|
||||
→ Download skill blob
|
||||
→ Extract to ~/.ollama/models/skills/sha256-<digest>/
|
||||
→ Available for agents to reference
|
||||
```
|
||||
|
||||
## Push Flow
|
||||
|
||||
```
|
||||
ollama skill push myname/calculator:1.0.0 ./my-skill
|
||||
→ Validate SKILL.md exists
|
||||
→ Create tar.gz of skill directory
|
||||
→ Compute SHA256 digest
|
||||
→ Store blob locally
|
||||
→ Create skill manifest with config layer
|
||||
→ Store manifest locally
|
||||
→ Push blobs to registry
|
||||
→ Push manifest to registry
|
||||
```
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
- Old agents with `Skills: []string` (paths) continue to work
|
||||
- New agents use `Skills: []SkillRef` with name and digest
|
||||
- Parser detects format and handles both
|
||||
|
||||
## Local Registry Testing
|
||||
|
||||
To test push/pull locally, you need MinIO and the Docker registry running:
|
||||
|
||||
```bash
|
||||
# 1. Start MinIO (for blob storage)
|
||||
minio server ~/.minio-data --console-address ':9001' &
|
||||
|
||||
# 2. Create the ollama-dev bucket (first time only)
|
||||
mc config host add local http://localhost:9000 minioadmin minioadmin
|
||||
mc mb local/ollama-dev
|
||||
|
||||
# 3. Start the registry (from ollama.com repo)
|
||||
cd /path/to/ollama.com/registry
|
||||
go run cmd/registry/main.go serve config-dev.yml &
|
||||
|
||||
# 4. Verify registry is running
|
||||
curl http://localhost:6000/v2/
|
||||
```
|
||||
|
||||
**Important:** The `config-dev.yml` must have matching ports:
|
||||
```yaml
|
||||
http:
|
||||
addr: :6000
|
||||
host: http://localhost:6000 # Must match addr!
|
||||
```
|
||||
|
||||
### Test Commands
|
||||
|
||||
```bash
|
||||
# Push skill from local folder
|
||||
ollama skill push localhost:6000/testuser/skill/calculator:1.0.0 ./skills/calculator-skill --insecure
|
||||
|
||||
# Pull skill from registry
|
||||
ollama skill pull localhost:6000/testuser/skill/calculator:1.0.0 --insecure
|
||||
|
||||
# List skills
|
||||
ollama skill list
|
||||
|
||||
# Show skill
|
||||
ollama skill show localhost:6000/testuser/skill/calculator:1.0.0
|
||||
```
|
||||
|
||||
## Architecture Diagram
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Skill Naming Structure"
|
||||
A["skill/calculator:1.0.0"] --> B["host: registry.ollama.ai"]
|
||||
A --> C["namespace: library"]
|
||||
A --> D["kind: skill"]
|
||||
A --> E["model: calculator"]
|
||||
A --> F["tag: 1.0.0"]
|
||||
end
|
||||
|
||||
subgraph "Storage Layout"
|
||||
G["~/.ollama/models/"]
|
||||
G --> H["blobs/"]
|
||||
H --> I["sha256-<skill-digest>"]
|
||||
G --> J["manifests/"]
|
||||
J --> K["registry.ollama.ai/"]
|
||||
K --> L["library/skill/calculator/1.0.0"]
|
||||
K --> M["myname/skill/my-skill/latest"]
|
||||
G --> N["skills/"]
|
||||
N --> O["sha256-<digest>/"]
|
||||
O --> P["SKILL.md"]
|
||||
O --> Q["scripts/"]
|
||||
end
|
||||
|
||||
subgraph "Push Flow"
|
||||
R["User Command: ollama skill push"]
|
||||
R --> S["Validate SKILL.md"]
|
||||
S --> T["Create tar.gz of skill dir"]
|
||||
T --> U["Compute SHA256 digest"]
|
||||
U --> V["Store blob locally"]
|
||||
V --> W["Create skill manifest"]
|
||||
W --> X["Store manifest locally"]
|
||||
X --> Y["Push blobs to registry"]
|
||||
Y --> Z["Push manifest to registry"]
|
||||
end
|
||||
|
||||
subgraph "Pull Flow - Standalone Skill"
|
||||
AA["User Command: ollama skill pull"]
|
||||
AA --> AB["Parse name structure"]
|
||||
AB --> AC["GET manifest from registry"]
|
||||
AC --> AD["Download skill blob"]
|
||||
AD --> AE["Extract to skills/ directory"]
|
||||
AE --> AF["Available for agents"]
|
||||
end
|
||||
|
||||
subgraph "Pull Flow - Agent with Skills"
|
||||
AG["Pull Agent: ollama pull my-agent"]
|
||||
AG --> AH["GET manifest (includes skill layers)"]
|
||||
AH --> AI["Download all blobs (model + skills)"]
|
||||
AI --> AJ["Extract skill blobs"]
|
||||
AJ --> AK["Ready to run"]
|
||||
end
|
||||
|
||||
subgraph "Agentfile Integration"
|
||||
AL["Agentfile"]
|
||||
AL --> AM["FROM llama3.2:3b"]
|
||||
AL --> AN["SKILL skill/calculator:1.0.0"]
|
||||
AL --> AO["SKILL ./local-skill"]
|
||||
AO --> AP["Local path (development)"]
|
||||
AN --> AQ["Registry reference"]
|
||||
end
|
||||
|
||||
subgraph "Registry Components"
|
||||
AR["Registry Server"]
|
||||
AR --> AS["Blob Storage (MinIO)"]
|
||||
AR --> AT["Layer Indexing"]
|
||||
AR --> AU["Manifest Storage"]
|
||||
AR --> AV["Namespace Config"]
|
||||
end
|
||||
|
||||
Z --> AR
|
||||
AC --> AR
|
||||
AH --> AR
|
||||
```
|
||||
548
docs/skills.md
Normal file
548
docs/skills.md
Normal file
@@ -0,0 +1,548 @@
|
||||
# Ollama Skills
|
||||
|
||||
Skills are reusable capability packages that extend what agents can do. They bundle instructions, scripts, and data that teach an agent how to perform specific tasks.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Creating a Skill
|
||||
|
||||
Create a directory with a `SKILL.md` file:
|
||||
|
||||
```
|
||||
my-skill/
|
||||
├── SKILL.md # Required: Instructions for the agent
|
||||
└── scripts/ # Optional: Executable scripts
|
||||
└── run.py
|
||||
```
|
||||
|
||||
The `SKILL.md` file must have YAML frontmatter:
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: my-skill
|
||||
description: A brief description of what this skill does
|
||||
---
|
||||
|
||||
# My Skill
|
||||
|
||||
## Purpose
|
||||
Explain what this skill does and when to use it.
|
||||
|
||||
## Instructions
|
||||
Step-by-step instructions for the agent on how to use this skill.
|
||||
|
||||
## Examples
|
||||
Show example inputs and expected outputs.
|
||||
```
|
||||
|
||||
### Using Skills in an Agent
|
||||
|
||||
Reference skills in your Agentfile:
|
||||
|
||||
```dockerfile
|
||||
FROM llama3.2:3b
|
||||
AGENT_TYPE conversational
|
||||
|
||||
# Local skill (bundled with agent)
|
||||
SKILL ./path/to/my-skill
|
||||
|
||||
# Registry skill (pulled from ollama.com)
|
||||
SKILL library/skill/calculator:1.0.0
|
||||
|
||||
# User skill from registry
|
||||
SKILL myname/skill/calculator:1.0.0
|
||||
|
||||
SYSTEM You are a helpful assistant.
|
||||
```
|
||||
|
||||
### Managing Skills
|
||||
|
||||
```bash
|
||||
# Push a skill to the registry (uses your namespace)
|
||||
ollama skill push myname/skill/calculator:1.0.0 ./my-skill
|
||||
|
||||
# Pull a skill from the official library
|
||||
ollama skill pull skill/calculator:1.0.0
|
||||
|
||||
# Pull a skill from a user's namespace
|
||||
ollama skill pull myname/skill/calculator:1.0.0
|
||||
|
||||
# List installed skills
|
||||
ollama skill list
|
||||
|
||||
# Show skill details
|
||||
ollama skill show skill/calculator:1.0.0
|
||||
|
||||
# Remove a skill
|
||||
ollama skill rm skill/calculator:1.0.0
|
||||
```
|
||||
|
||||
### Dynamic Skills in Chat
|
||||
|
||||
You can add and remove skills dynamically during an interactive chat session:
|
||||
|
||||
```
|
||||
>>> /skills
|
||||
Available Skills:
|
||||
calculator (sha256:abc123def456...)
|
||||
|
||||
>>> /skill add ./my-local-skill
|
||||
Added skill 'my-skill' from ./my-local-skill
|
||||
|
||||
>>> /skill list
|
||||
Skills loaded in this session:
|
||||
my-skill (local: /path/to/my-local-skill)
|
||||
|
||||
>>> /skill remove my-skill
|
||||
Removed skill 'my-skill'
|
||||
```
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/skills` | Show all available skills (model + session) |
|
||||
| `/skill add <path>` | Add a skill from a local path |
|
||||
| `/skill remove <name>` | Remove a skill by name |
|
||||
| `/skill list` | List skills loaded in this session |
|
||||
|
||||
Dynamic skills take effect on the next message. This is useful for:
|
||||
- Testing skills during development
|
||||
- Temporarily adding capabilities to a model
|
||||
- Experimenting with skill combinations
|
||||
|
||||
## Skill Reference Formats
|
||||
|
||||
Skills use a 5-part name structure: `host/namespace/kind/model:tag`
|
||||
|
||||
| Format | Example | Description |
|
||||
|--------|---------|-------------|
|
||||
| Local path | `./skills/calc` | Bundled with agent at create time |
|
||||
| Library skill | `skill/calculator:1.0.0` | From the official skill library (library/skill/calculator) |
|
||||
| User skill | `alice/skill/calc:1.0.0` | From a user's namespace |
|
||||
| Full path | `registry.ollama.ai/alice/skill/calc:1.0.0` | Fully qualified with host |
|
||||
|
||||
The `kind` field distinguishes skills from models:
|
||||
- `skill` - Skill packages
|
||||
- `agent` - Agent packages (future)
|
||||
- (empty) - Regular models
|
||||
|
||||
## SKILL.md Structure
|
||||
|
||||
### Required Frontmatter
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: skill-name # Must match directory name
|
||||
description: Brief description of the skill
|
||||
---
|
||||
```
|
||||
|
||||
### Recommended Sections
|
||||
|
||||
1. **Purpose**: What the skill does and when to use it
|
||||
2. **When to use**: Trigger conditions for the agent
|
||||
3. **Instructions**: Step-by-step usage guide
|
||||
4. **Examples**: Input/output examples
|
||||
5. **Scripts**: Documentation for any bundled scripts
|
||||
|
||||
### Example: Calculator Skill
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: calculator
|
||||
description: Performs mathematical calculations using Python
|
||||
---
|
||||
|
||||
# Calculator Skill
|
||||
|
||||
## Purpose
|
||||
This skill performs mathematical calculations using a bundled Python script.
|
||||
|
||||
## When to use
|
||||
- User asks to calculate something
|
||||
- User wants to do math operations
|
||||
- Any arithmetic is needed
|
||||
|
||||
## Instructions
|
||||
1. When calculation is needed, use the `run_skill_script` tool
|
||||
2. Call: `python3 scripts/calculate.py "<expression>"`
|
||||
3. Return the result to the user
|
||||
|
||||
## Examples
|
||||
|
||||
**Input**: "What is 25 * 4?"
|
||||
**Action**: `run_skill_script` with command `python3 scripts/calculate.py '25 * 4'`
|
||||
**Output**: "25 * 4 = 100"
|
||||
```
|
||||
|
||||
## Storage Layout
|
||||
|
||||
```
|
||||
~/.ollama/models/
|
||||
├── blobs/
|
||||
│ └── sha256-<digest> # Skill tar.gz blob
|
||||
├── manifests/
|
||||
│ └── registry.ollama.ai/
|
||||
│ └── skill/ # Library skills
|
||||
│ └── calculator/
|
||||
│ └── 1.0.0
|
||||
│ └── skill-username/ # User skills
|
||||
│ └── my-skill/
|
||||
│ └── latest
|
||||
└── skills/
|
||||
└── sha256-<digest>/ # Extracted skill cache
|
||||
├── SKILL.md
|
||||
└── scripts/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
# Security Considerations
|
||||
|
||||
## Current State (Development)
|
||||
|
||||
The current implementation has several security considerations that need to be addressed before production use.
|
||||
|
||||
### 1. Script Execution
|
||||
|
||||
**Risk**: Skills can bundle arbitrary scripts that execute on the host system.
|
||||
|
||||
**Current behavior**:
|
||||
- Scripts run with the same permissions as the Ollama process
|
||||
- No sandboxing or isolation
|
||||
- Full filesystem access
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Sandbox script execution (containers, seccomp, etc.)
|
||||
- [ ] Resource limits (CPU, memory, time)
|
||||
- [ ] Filesystem isolation (read-only mounts, restricted paths)
|
||||
- [ ] Network policy controls
|
||||
- [ ] Capability dropping
|
||||
|
||||
### 2. Skill Provenance
|
||||
|
||||
**Risk**: Malicious skills could be pushed to the registry.
|
||||
|
||||
**Current behavior**:
|
||||
- No code signing or verification
|
||||
- No malware scanning
|
||||
- Trust based on namespace ownership
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Skill signing with author keys
|
||||
- [ ] Registry-side malware scanning
|
||||
- [ ] Content policy enforcement
|
||||
- [ ] Reputation system for skill authors
|
||||
|
||||
### 3. Namespace Squatting
|
||||
|
||||
**Risk**: Malicious actors could register skill names that impersonate official tools.
|
||||
|
||||
**Current behavior**:
|
||||
- First-come-first-served namespace registration
|
||||
- No verification of skill names
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Reserved namespace list (official tools, common names)
|
||||
- [ ] Trademark/name verification for popular skills
|
||||
- [ ] Clear namespacing conventions
|
||||
|
||||
### 4. Supply Chain Attacks
|
||||
|
||||
**Risk**: Compromised skills could inject malicious code into agents.
|
||||
|
||||
**Current behavior**:
|
||||
- Skills pulled without integrity verification beyond digest
|
||||
- No dependency tracking
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] SBOM (Software Bill of Materials) for skills
|
||||
- [ ] Dependency vulnerability scanning
|
||||
- [ ] Pinned versions in Agentfiles
|
||||
- [ ] Audit logging of skill usage
|
||||
|
||||
### 5. Data Exfiltration
|
||||
|
||||
**Risk**: Skills could exfiltrate sensitive data from conversations or the host.
|
||||
|
||||
**Current behavior**:
|
||||
- Skills have access to conversation context
|
||||
- Scripts can make network requests
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Network egress controls
|
||||
- [ ] Sensitive data detection/masking
|
||||
- [ ] Audit logging of script network activity
|
||||
- [ ] User consent for data access
|
||||
|
||||
### 6. Privilege Escalation
|
||||
|
||||
**Risk**: Skills could escalate privileges through script execution.
|
||||
|
||||
**Current behavior**:
|
||||
- Scripts inherit Ollama process privileges
|
||||
- No capability restrictions
|
||||
|
||||
**Mitigations needed**:
|
||||
- [ ] Run scripts as unprivileged user
|
||||
- [ ] Drop all capabilities
|
||||
- [ ] Mandatory access controls (SELinux/AppArmor)
|
||||
|
||||
## Recommended Security Model
|
||||
|
||||
### Skill Trust Levels
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Level 0: Untrusted (default) │
|
||||
│ - No script execution │
|
||||
│ - Instructions only │
|
||||
│ - Safe for any skill │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ Level 1: Sandboxed │
|
||||
│ - Scripts run in isolated container │
|
||||
│ - No network access │
|
||||
│ - Read-only filesystem │
|
||||
│ - Resource limits enforced │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ Level 2: Trusted │
|
||||
│ - Scripts run with network access │
|
||||
│ - Can write to designated directories │
|
||||
│ - Requires explicit user approval │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ Level 3: Privileged (admin only) │
|
||||
│ - Full host access │
|
||||
│ - System administration skills │
|
||||
│ - Requires admin approval │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Skill Manifest Security Fields (Future)
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: my-skill
|
||||
description: A skill description
|
||||
security:
|
||||
trust_level: sandboxed
|
||||
permissions:
|
||||
- network:read # Can make HTTP GET requests
|
||||
- filesystem:read:/data # Can read from /data
|
||||
resource_limits:
|
||||
max_memory: 256MB
|
||||
max_cpu_time: 30s
|
||||
max_disk: 100MB
|
||||
signature: sha256:abc... # Author signature
|
||||
---
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
# Future Considerations
|
||||
|
||||
## Feature Roadmap
|
||||
|
||||
### Phase 1: Foundation (Current)
|
||||
- [x] Skill bundling with agents
|
||||
- [x] Local skill development
|
||||
- [x] Basic CLI commands (push, pull, list, rm, show)
|
||||
- [x] Registry blob storage
|
||||
- [ ] Registry namespace configuration
|
||||
|
||||
### Phase 2: Security
|
||||
- [ ] Script sandboxing
|
||||
- [ ] Permission model
|
||||
- [ ] Skill signing
|
||||
- [ ] Audit logging
|
||||
|
||||
### Phase 3: Discovery
|
||||
- [ ] Skill search on ollama.com
|
||||
- [ ] Skill ratings and reviews
|
||||
- [ ] Usage analytics
|
||||
- [ ] Featured/trending skills
|
||||
|
||||
### Phase 4: Advanced Features
|
||||
- [ ] Skill dependencies
|
||||
- [ ] Skill versioning constraints
|
||||
- [ ] Skill composition (skills using skills)
|
||||
- [ ] Skill testing framework
|
||||
|
||||
## Open Questions
|
||||
|
||||
### 1. Skill Execution Model
|
||||
|
||||
**Question**: How should skills execute scripts?
|
||||
|
||||
Options:
|
||||
- **A) In-process**: Fast but unsafe
|
||||
- **B) Subprocess**: Current approach, moderate isolation
|
||||
- **C) Container**: Good isolation, requires container runtime
|
||||
- **D) WASM**: Portable and safe, limited capabilities
|
||||
- **E) Remote execution**: Offload to secure service
|
||||
|
||||
### 2. Skill Versioning
|
||||
|
||||
**Question**: How strict should version pinning be?
|
||||
|
||||
Options:
|
||||
- **A) Always latest**: Simple but risky
|
||||
- **B) Semantic versioning**: `^1.0.0` allows minor updates
|
||||
- **C) Exact pinning**: `=1.0.0` requires explicit updates
|
||||
- **D) Digest pinning**: `@sha256:abc` immutable reference
|
||||
|
||||
### 3. Skill Permissions
|
||||
|
||||
**Question**: How should users grant permissions to skills?
|
||||
|
||||
Options:
|
||||
- **A) All or nothing**: Accept all permissions or don't use
|
||||
- **B) Granular consent**: Approve each permission individually
|
||||
- **C) Trust levels**: Pre-defined permission bundles
|
||||
- **D) Runtime prompts**: Ask when permission is first used
|
||||
|
||||
### 4. Skill Discovery
|
||||
|
||||
**Question**: How should users find skills?
|
||||
|
||||
Options:
|
||||
- **A) Central registry only**: ollama.com/skills
|
||||
- **B) Federated registries**: Multiple skill sources
|
||||
- **C) Git repositories**: Pull from GitHub, etc.
|
||||
- **D) All of the above**: Multiple discovery mechanisms
|
||||
|
||||
### 5. Skill Monetization
|
||||
|
||||
**Question**: Should skill authors be able to monetize?
|
||||
|
||||
Options:
|
||||
- **A) Free only**: All skills are free and open
|
||||
- **B) Paid skills**: Authors can charge for skills
|
||||
- **C) Freemium**: Free tier with paid features
|
||||
- **D) Donations**: Voluntary support for authors
|
||||
|
||||
### 6. Skill Updates
|
||||
|
||||
**Question**: How should skill updates be handled?
|
||||
|
||||
Options:
|
||||
- **A) Manual**: User explicitly updates
|
||||
- **B) Auto-update**: Always use latest
|
||||
- **C) Notify**: Alert user to available updates
|
||||
- **D) Policy-based**: Organization controls update policy
|
||||
|
||||
## API Considerations
|
||||
|
||||
### Skill Metadata API
|
||||
|
||||
```
|
||||
GET /api/skills
|
||||
GET /api/skills/:namespace/:name
|
||||
GET /api/skills/:namespace/:name/versions
|
||||
GET /api/skills/:namespace/:name/readme
|
||||
```
|
||||
|
||||
### Skill Execution API
|
||||
|
||||
```
|
||||
POST /api/skills/:namespace/:name/execute
|
||||
{
|
||||
"command": "python3 scripts/run.py",
|
||||
"args": ["--input", "data"],
|
||||
"timeout": 30
|
||||
}
|
||||
```
|
||||
|
||||
### Skill Permissions API
|
||||
|
||||
```
|
||||
GET /api/skills/:namespace/:name/permissions
|
||||
POST /api/skills/:namespace/:name/permissions/grant
|
||||
DELETE /api/skills/:namespace/:name/permissions/revoke
|
||||
```
|
||||
|
||||
## Testing Considerations
|
||||
|
||||
### Skill Testing Framework
|
||||
|
||||
```bash
|
||||
# Run skill tests
|
||||
ollama skill test ./my-skill
|
||||
|
||||
# Test with specific model
|
||||
ollama skill test ./my-skill --model llama3.2:3b
|
||||
|
||||
# Generate test report
|
||||
ollama skill test ./my-skill --report
|
||||
```
|
||||
|
||||
### Test File Format
|
||||
|
||||
```yaml
|
||||
# my-skill/tests/test.yaml
|
||||
tests:
|
||||
- name: "basic calculation"
|
||||
input: "What is 2 + 2?"
|
||||
expect:
|
||||
contains: "4"
|
||||
tool_called: "run_skill_script"
|
||||
|
||||
- name: "complex expression"
|
||||
input: "Calculate 15% of 200"
|
||||
expect:
|
||||
contains: "30"
|
||||
```
|
||||
|
||||
## Compatibility Considerations
|
||||
|
||||
### Minimum Ollama Version
|
||||
|
||||
Skills should declare minimum Ollama version:
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: my-skill
|
||||
requires:
|
||||
ollama: ">=0.4.0"
|
||||
---
|
||||
```
|
||||
|
||||
### Model Compatibility
|
||||
|
||||
Skills may require specific model capabilities:
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: vision-skill
|
||||
requires:
|
||||
capabilities:
|
||||
- vision
|
||||
- tools
|
||||
---
|
||||
```
|
||||
|
||||
## Migration Path
|
||||
|
||||
### From Local to Registry
|
||||
|
||||
```bash
|
||||
# Develop locally
|
||||
SKILL ./my-skill
|
||||
|
||||
# Push when ready
|
||||
ollama skill push myname/my-skill:1.0.0 ./my-skill
|
||||
|
||||
# Update Agentfile
|
||||
SKILL skill/myname/my-skill:1.0.0
|
||||
```
|
||||
|
||||
### Version Upgrades
|
||||
|
||||
```bash
|
||||
# Check for updates
|
||||
ollama skill outdated
|
||||
|
||||
# Update specific skill
|
||||
ollama skill update calculator:1.0.0
|
||||
|
||||
# Update all skills
|
||||
ollama skill update --all
|
||||
```
|
||||
46
docs/tools/extract-examples/README.md
Normal file
46
docs/tools/extract-examples/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# extract-examples
|
||||
|
||||
Extracts code examples from MDX files to a temp directory so you can run them.
|
||||
|
||||
## Usage
|
||||
|
||||
```shell
|
||||
go run docs/tools/extract-examples/main.go <mdx-file>
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
```shell
|
||||
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
|
||||
|
||||
- 01_basic.py
|
||||
- 01_basic.js
|
||||
- 01_basic.sh
|
||||
- 02_responses.py
|
||||
- 02_responses.js
|
||||
- 02_responses.sh
|
||||
- 03_vision.py
|
||||
- 03_vision.js
|
||||
- 03_vision.sh
|
||||
|
||||
Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
|
||||
|
||||
To run examples:
|
||||
|
||||
cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
|
||||
npm install # for JS examples
|
||||
|
||||
then run individual files with `node file.js`, `python file.py`, `bash file.sh`
|
||||
```
|
||||
|
||||
## How it works
|
||||
|
||||
- Parses MDX files looking for fenced code blocks with filenames (e.g., ` ```python basic.py `)
|
||||
- Groups examples by their `<CodeGroup>` and prefixes filenames with `01_`, `02_`, etc.
|
||||
- Writes all extracted files to a temp directory
|
||||
137
docs/tools/extract-examples/main.go
Normal file
137
docs/tools/extract-examples/main.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Fprintln(os.Stderr, "Usage: go run extract-examples.go <mdx-file>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
mdxFile := os.Args[1]
|
||||
|
||||
f, err := os.Open(mdxFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Create temp directory
|
||||
tempDir, err := os.MkdirTemp("", "mdx-examples-*")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating temp dir: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Extracting code examples to: %s\n\n", tempDir)
|
||||
|
||||
// Patterns
|
||||
codeBlockStart := regexp.MustCompile("^```([a-zA-Z0-9_-]+)\\s+([^\\s]+)$")
|
||||
codeGroupStart := regexp.MustCompile("^<CodeGroup")
|
||||
codeGroupEnd := regexp.MustCompile("^</CodeGroup>")
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
inCodeBlock := false
|
||||
inCodeGroup := false
|
||||
var currentFile string
|
||||
var content strings.Builder
|
||||
count := 0
|
||||
codeGroupNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Track CodeGroup boundaries
|
||||
if codeGroupStart.MatchString(line) {
|
||||
inCodeGroup = true
|
||||
codeGroupNum++
|
||||
continue
|
||||
}
|
||||
if codeGroupEnd.MatchString(line) {
|
||||
inCodeGroup = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inCodeBlock {
|
||||
if line == "```" {
|
||||
// End of code block - write file
|
||||
if currentFile != "" {
|
||||
outPath := filepath.Join(tempDir, currentFile)
|
||||
if err := os.WriteFile(outPath, []byte(content.String()), 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing %s: %v\n", currentFile, err)
|
||||
} else {
|
||||
fmt.Printf(" - %s\n", currentFile)
|
||||
count++
|
||||
}
|
||||
}
|
||||
inCodeBlock = false
|
||||
currentFile = ""
|
||||
content.Reset()
|
||||
} else {
|
||||
content.WriteString(line)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
if matches := codeBlockStart.FindStringSubmatch(line); matches != nil {
|
||||
inCodeBlock = true
|
||||
filename := matches[2]
|
||||
// Prefix with CodeGroup number if inside a CodeGroup
|
||||
if inCodeGroup {
|
||||
currentFile = fmt.Sprintf("%02d_%s", codeGroupNum, filename)
|
||||
} else {
|
||||
currentFile = filename
|
||||
}
|
||||
content.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Write package.json for JavaScript dependencies
|
||||
packageJSON := `{
|
||||
"name": "mdx-examples",
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
"openai": "^4",
|
||||
"ollama": "^0.5"
|
||||
}
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "package.json"), []byte(packageJSON), 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing package.json: %v\n", err)
|
||||
}
|
||||
|
||||
// Write pyproject.toml for Python dependencies
|
||||
pyprojectTOML := `[project]
|
||||
name = "mdx-examples"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"openai",
|
||||
"ollama",
|
||||
]
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "pyproject.toml"), []byte(pyprojectTOML), 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing pyproject.toml: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf("Extracted %d file(s) to %s\n", count, tempDir)
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf("To run examples:\n")
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf(" cd %s\n npm install # for JS examples\n", tempDir)
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf("then run individual files with `node file.js`, `python file.py`, `bash file.sh`\n")
|
||||
}
|
||||
@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
|
||||
|
||||
### Linux NVIDIA Troubleshooting
|
||||
|
||||
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
|
||||
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)
|
||||
|
||||
Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
|
||||
|
||||
|
||||
3
ducky.Agentfile
Normal file
3
ducky.Agentfile
Normal file
@@ -0,0 +1,3 @@
|
||||
SKILL ./skills/calculator-skill
|
||||
ENTRYPOINT ducky
|
||||
|
||||
@@ -148,6 +148,16 @@ func Remotes() []string {
|
||||
return r
|
||||
}
|
||||
|
||||
// Skills returns the list of skill directories. Skills directories can be configured via the OLLAMA_SKILLS environment variable.
|
||||
// Returns empty slice if not configured.
|
||||
func Skills() []string {
|
||||
raw := strings.TrimSpace(Var("OLLAMA_SKILLS"))
|
||||
if raw == "" {
|
||||
return []string{}
|
||||
}
|
||||
return strings.Split(raw, ",")
|
||||
}
|
||||
|
||||
func BoolWithDefault(k string) func(defaultValue bool) bool {
|
||||
return func(defaultValue bool) bool {
|
||||
if s := Var(k); s != "" {
|
||||
@@ -317,6 +327,9 @@ func AsMap() map[string]EnvVar {
|
||||
ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
|
||||
}
|
||||
|
||||
// Skills configuration would go here when added
|
||||
ret["OLLAMA_SKILLS"] = EnvVar{"OLLAMA_SKILLS", Skills(), "Comma-separated list of skill directories"}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type GGML struct {
|
||||
@@ -240,18 +241,20 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
"deepseek2",
|
||||
"deepseekocr",
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"gptoss", "gpt-oss",
|
||||
"llama4",
|
||||
"mistral3",
|
||||
"mllama",
|
||||
"nomic-bert",
|
||||
"olmo3",
|
||||
"qwen25vl",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
"deepseekocr",
|
||||
"deepseek2",
|
||||
"nomic-bert",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
@@ -550,7 +553,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
context *= uint64(numParallel)
|
||||
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
@@ -791,7 +794,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
}
|
||||
|
||||
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
|
||||
if useFlashAttention {
|
||||
if useFlashAttention == ml.FlashAttentionEnabled {
|
||||
// rough estimate of graph size with flash attention on
|
||||
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
|
||||
}
|
||||
@@ -809,6 +812,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
|
||||
}
|
||||
|
||||
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
|
||||
func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
|
||||
if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SupportsFlashAttention checks if the model supports flash attention
|
||||
func (f GGML) SupportsFlashAttention() bool {
|
||||
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
|
||||
@@ -829,8 +840,11 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||
// FlashAttention checks if the model should enable flash attention
|
||||
func (f GGML) FlashAttention() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
"gemma3",
|
||||
"gptoss", "gpt-oss",
|
||||
"mistral3",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, f.KV().String("general.architecture"))
|
||||
|
||||
@@ -597,6 +597,10 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
|
||||
|
||||
var err error
|
||||
switch v := v.(type) {
|
||||
case int32:
|
||||
err = writeGGUF(ws, ggufTypeInt32, v)
|
||||
case int64:
|
||||
err = writeGGUF(ws, ggufTypeInt64, v)
|
||||
case uint32, FileType:
|
||||
err = writeGGUF(ws, ggufTypeUint32, v)
|
||||
case uint64:
|
||||
@@ -611,6 +615,10 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
|
||||
err = writeGGUFArray(ws, ggufTypeInt32, v)
|
||||
case *array[int32]:
|
||||
err = writeGGUFArray(ws, ggufTypeInt32, v.values)
|
||||
case []int64:
|
||||
err = writeGGUFArray(ws, ggufTypeInt64, v)
|
||||
case *array[int64]:
|
||||
err = writeGGUFArray(ws, ggufTypeInt64, v.values)
|
||||
case []uint32:
|
||||
err = writeGGUFArray(ws, ggufTypeUint32, v)
|
||||
case *array[uint32]:
|
||||
|
||||
@@ -42,6 +42,10 @@ func TestWriteGGUF(t *testing.T) {
|
||||
"general.architecture": "test",
|
||||
"general.alignment": uint32(16),
|
||||
"test.key": "value",
|
||||
"test.int32_key": int32(-42),
|
||||
"test.int64_key": int64(-9223372036854775808),
|
||||
"test.int32_array": []int32{-1, 0, 1, 2147483647, -2147483648},
|
||||
"test.int64_array": []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808},
|
||||
"attention.key": "value2",
|
||||
"tokenizer.key": "value3",
|
||||
"adapter.key": "value4",
|
||||
@@ -55,7 +59,7 @@ func TestWriteGGUF(t *testing.T) {
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
ff, err := Decode(r, 0)
|
||||
ff, err := Decode(r, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -65,15 +69,19 @@ func TestWriteGGUF(t *testing.T) {
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(54),
|
||||
"test.key": "value",
|
||||
"test.int32_key": int32(-42),
|
||||
"test.int64_key": int64(-9223372036854775808),
|
||||
"test.int32_array": &array[int32]{size: 5, values: []int32{-1, 0, 1, 2147483647, -2147483648}},
|
||||
"test.int64_array": &array[int64]{size: 5, values: []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808}},
|
||||
"test.attention.key": "value2",
|
||||
"tokenizer.key": "value3",
|
||||
"adapter.key": "value4",
|
||||
}, ff.KV()); diff != "" {
|
||||
}, ff.KV(), cmp.AllowUnexported(array[int32]{}, array[int64]{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(Tensors{
|
||||
Offset: 800,
|
||||
Offset: 992,
|
||||
items: []*Tensor{
|
||||
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
|
||||
17
go.mod
17
go.mod
@@ -15,8 +15,8 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.12.0
|
||||
golang.org/x/sys v0.36.0
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -29,7 +29,8 @@ require (
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/tools v0.30.0
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/tools v0.38.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
@@ -76,11 +77,11 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/term v0.30.0
|
||||
golang.org/x/text v0.23.0
|
||||
golang.org/x/net v0.46.0 // indirect
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/text v0.30.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
30
go.sum
30
go.sum
@@ -224,8 +224,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@@ -255,6 +255,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -267,8 +269,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -278,8 +280,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -295,17 +297,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -319,8 +321,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@@ -4,7 +4,9 @@ package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -204,8 +206,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
||||
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
|
||||
}
|
||||
|
||||
if res.PromptEvalCount != 6 {
|
||||
t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount)
|
||||
if res.PromptEvalCount != 8 {
|
||||
t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,8 +253,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
|
||||
}
|
||||
|
||||
if res.PromptEvalCount != 12 {
|
||||
t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount)
|
||||
if res.PromptEvalCount != 16 {
|
||||
t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,7 +277,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
request api.EmbedRequest
|
||||
check func(*api.EmbedResponse, error)
|
||||
check func(*testing.T, *api.EmbedResponse, error)
|
||||
}{
|
||||
{
|
||||
name: "target truncation",
|
||||
@@ -283,7 +285,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Model: "all-minilm",
|
||||
Input: "why",
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -300,10 +302,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Input: "why is the sky blue?",
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -317,10 +320,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -334,21 +338,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input exceeds maximum context length" {
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "the input length exceeds the context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "input after truncate error",
|
||||
name: "input after truncate error with context length of 1",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
@@ -362,7 +366,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 0},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
@@ -375,7 +379,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Input: "why is the sky blue? Why is the sky blue? hi there my",
|
||||
Options: map[string]any{"num_ctx": 16},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -385,7 +389,8 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
|
||||
for _, req := range cases {
|
||||
t.Run(req.name, func(t *testing.T) {
|
||||
req.check(embedTestHelper(ctx, client, t, req.request))
|
||||
resp, err := embedTestHelper(ctx, client, t, req.request)
|
||||
req.check(t, resp, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -409,3 +414,230 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
|
||||
|
||||
return client.Embed(ctx, &req)
|
||||
}
|
||||
|
||||
func TestEmbedTruncation(t *testing.T) {
|
||||
// Use test deadline if set, otherwise default to 2 minutes
|
||||
timeout := 2 * time.Minute
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryEmbedModels {
|
||||
model := model
|
||||
t.Run(model, func(t *testing.T) {
|
||||
// Check if we're running out of time (reserve 20s for current model)
|
||||
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
||||
t.Skip("skipping remaining tests to avoid timeout")
|
||||
}
|
||||
|
||||
// Give each model its own budget to account for first-time pulls/loads
|
||||
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
|
||||
defer mcancel()
|
||||
|
||||
t.Run("truncation batch", func(t *testing.T) {
|
||||
truncTrue := true
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 30},
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(mctx, client, t, req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 3 {
|
||||
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
|
||||
}
|
||||
|
||||
if res.PromptEvalCount > 90 {
|
||||
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("runner token count accuracy", func(t *testing.T) {
|
||||
baseline := api.EmbedRequest{Model: model, Input: "test"}
|
||||
baseRes, err := embedTestHelper(mctx, client, t, baseline)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
batch := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: []string{"test", "test", "test"},
|
||||
}
|
||||
batchRes, err := embedTestHelper(mctx, client, t, batch)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedCount := baseRes.PromptEvalCount * 3
|
||||
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
|
||||
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
|
||||
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
|
||||
func TestEmbedLargeInput(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryEmbedModels {
|
||||
model := model
|
||||
t.Run(model, func(t *testing.T) {
|
||||
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
|
||||
defer mcancel()
|
||||
|
||||
// Test with progressively larger inputs
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputWords int
|
||||
}{
|
||||
{"medium_input_256_words", 256},
|
||||
{"large_input_512_words", 512},
|
||||
{"very_large_input_800_words", 800},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
words := make([]string, tc.inputWords)
|
||||
for i := range words {
|
||||
words[i] = "word"
|
||||
}
|
||||
input := strings.Join(words, " ")
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: input,
|
||||
KeepAlive: &api.Duration{Duration: 30 * time.Second},
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(mctx, client, t, req)
|
||||
if err != nil {
|
||||
t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 1 {
|
||||
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
|
||||
}
|
||||
|
||||
if len(res.Embeddings[0]) == 0 {
|
||||
t.Fatal("expected non-empty embedding")
|
||||
}
|
||||
|
||||
t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedStatusCode tests that errors from the embedding endpoint
|
||||
// properly preserve their HTTP status codes when returned to the client.
|
||||
// This test specifically checks the error handling path in EmbedHandler
|
||||
// where api.StatusError errors should maintain their original status code.
|
||||
func TestEmbedStatusCode(t *testing.T) {
|
||||
// Use test deadline if set, otherwise default to 2 minutes
|
||||
timeout := 2 * time.Minute
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryEmbedModels {
|
||||
model := model
|
||||
t.Run(model, func(t *testing.T) {
|
||||
// Check if we're running out of time (reserve 20s for current model)
|
||||
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
||||
t.Skip("skipping remaining tests to avoid timeout")
|
||||
}
|
||||
|
||||
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
|
||||
defer mcancel()
|
||||
|
||||
// Pull the model if needed
|
||||
if err := PullIfMissing(mctx, client, model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("truncation error status code", func(t *testing.T) {
|
||||
truncFalse := false
|
||||
longInput := strings.Repeat("word ", 100)
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: longInput,
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 10},
|
||||
}
|
||||
|
||||
_, err := embedTestHelper(mctx, client, t, req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when truncate=false with long input")
|
||||
}
|
||||
|
||||
// Check that it's a StatusError with the correct status code
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
// The error should be a 4xx client error (likely 400 Bad Request)
|
||||
// not a 500 Internal Server Error
|
||||
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
|
||||
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
|
||||
}
|
||||
|
||||
// Verify the error message is meaningful
|
||||
if !strings.Contains(err.Error(), "context length") {
|
||||
t.Errorf("expected error message to mention context length, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("batch truncation error status code", func(t *testing.T) {
|
||||
truncFalse := false
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: []string{
|
||||
"short input",
|
||||
strings.Repeat("very long input ", 100),
|
||||
"another short input",
|
||||
},
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 10},
|
||||
}
|
||||
|
||||
_, err := embedTestHelper(mctx, client, t, req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when one input exceeds context with truncate=false")
|
||||
}
|
||||
|
||||
// Check that it's a StatusError with the correct status code
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
// The error should be a 4xx client error, not a 500 Internal Server Error
|
||||
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
|
||||
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,9 @@ func TestVisionModels(t *testing.T) {
|
||||
// Qwen 3 VL mixture of experts
|
||||
model: "qwen3-vl:30b",
|
||||
},
|
||||
{
|
||||
model: "ministral-3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, v := range testCases {
|
||||
|
||||
@@ -30,6 +30,7 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
"mistral": 6,
|
||||
"qwen2.5": 6,
|
||||
"qwen2": 6,
|
||||
"ministral-3": 20,
|
||||
"mistral-nemo": 9,
|
||||
"mistral-small": 16,
|
||||
"mixtral:8x22b": 80,
|
||||
|
||||
@@ -38,6 +38,7 @@ var (
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"ministral-3",
|
||||
"qwen3-coder:30b",
|
||||
"gpt-oss:20b",
|
||||
"gemma3n:e2b",
|
||||
@@ -167,6 +168,7 @@ var (
|
||||
"medllama2",
|
||||
"megadolphin",
|
||||
"minicpm-v",
|
||||
"ministral-3",
|
||||
"mistral-large",
|
||||
"mistral-nemo",
|
||||
"mistral-openorca",
|
||||
@@ -270,6 +272,7 @@ var (
|
||||
"mistral",
|
||||
"qwen2.5",
|
||||
"qwen2",
|
||||
"ministral-3",
|
||||
"mistral-nemo",
|
||||
"mistral-small",
|
||||
"mixtral:8x22b",
|
||||
|
||||
@@ -140,10 +140,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
c.config.CachePadding = 1
|
||||
}
|
||||
|
||||
if c.config.MaskBatchPadding == 0 {
|
||||
c.config.MaskBatchPadding = 1
|
||||
}
|
||||
|
||||
if c.config.MaskDType == ml.DTypeOther {
|
||||
c.config.MaskDType = ml.DTypeF32
|
||||
}
|
||||
@@ -364,15 +360,12 @@ func roundUp(length, pad int) int {
|
||||
// token in the history should apply. This is based on both the sequence and causality (the
|
||||
// position of the history is not ahead of the token in the batch).
|
||||
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
// Align and pad the two dimensions as required by the backend
|
||||
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||
|
||||
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
|
||||
mask := make([]float32, batchSize*length)
|
||||
mask := make([]float32, c.curBatchSize*length)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
enabled := !slices.Contains(c.opts.Except, i)
|
||||
@@ -386,13 +379,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
// Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||
// has already been masked out because the sequence doesn't match.
|
||||
for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||
mask[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
maskTensor := ctx.Input().FromFloats(mask, length, batchSize)
|
||||
maskTensor := ctx.Input().FromFloats(mask, length, c.curBatchSize)
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||
|
||||
2
llama/build-info.cpp
generated
vendored
2
llama/build-info.cpp
generated
vendored
@@ -1,4 +1,4 @@
|
||||
int LLAMA_BUILD_NUMBER = 0;
|
||||
char const *LLAMA_COMMIT = "3cfa9c3f125763305b4226bc032f1954f08990dc";
|
||||
char const *LLAMA_COMMIT = "ec98e2002";
|
||||
char const *LLAMA_COMPILER = "";
|
||||
char const *LLAMA_BUILD_TARGET = "";
|
||||
|
||||
@@ -17,11 +17,17 @@ include /tools/mtmd/clip.cpp
|
||||
include /tools/mtmd/mtmd.cpp
|
||||
include /tools/mtmd/mtmd-audio.cpp
|
||||
include /tools/mtmd/mtmd-helper.cpp
|
||||
include /tools/mtmd/models/
|
||||
include /tools/mtmd/models/*.h
|
||||
include /tools/mtmd/models/*.cpp
|
||||
include /src/
|
||||
include /src/llama.*
|
||||
include /src/llama-*.*
|
||||
include /src/unicode-data.*
|
||||
include /src/unicode.*
|
||||
include /src/models/
|
||||
include /src/models/*.h
|
||||
include /src/models/*.cpp
|
||||
include /vendor/
|
||||
include /vendor/miniaudio/
|
||||
include /vendor/miniaudio/*.h
|
||||
|
||||
359
llama/llama.cpp/common/common.cpp
vendored
359
llama/llama.cpp/common/common.cpp
vendored
@@ -8,6 +8,7 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "sampling.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
@@ -26,7 +27,6 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
@@ -60,6 +60,14 @@
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
||||
|
||||
common_time_meas::~common_time_meas() {
|
||||
if (t_start_us >= 0) {
|
||||
t_acc += ggml_time_us() - t_start_us;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// CPU utils
|
||||
//
|
||||
@@ -355,11 +363,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
void common_init() {
|
||||
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}, NULL);
|
||||
llama_log_set(common_log_default_callback, NULL);
|
||||
|
||||
#ifdef NDEBUG
|
||||
const char * build_type = "";
|
||||
@@ -690,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
|
||||
|
||||
// Validate if a filename is safe to use
|
||||
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
|
||||
bool fs_validate_filename(const std::string & filename) {
|
||||
bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
if (!filename.length()) {
|
||||
// Empty filename invalid
|
||||
return false;
|
||||
@@ -750,10 +754,14 @@ bool fs_validate_filename(const std::string & filename) {
|
||||
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|
||||
|| c == 0xFFFD // Replacement Character (UTF-8)
|
||||
|| c == 0xFEFF // Byte Order Mark (BOM)
|
||||
|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
|
||||
|| c == ':' || c == '*' // Illegal characters
|
||||
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
|
||||
return false;
|
||||
}
|
||||
if (!allow_subdirs && (c == '/' || c == '\\')) {
|
||||
// Subdirectories not allowed, reject path separators
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
|
||||
@@ -778,11 +786,29 @@ bool fs_validate_filename(const std::string & filename) {
|
||||
#include <iostream>
|
||||
|
||||
|
||||
#ifdef _WIN32
|
||||
static std::wstring utf8_to_wstring(const std::string & str) {
|
||||
if (str.empty()) {
|
||||
return std::wstring();
|
||||
}
|
||||
|
||||
int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
|
||||
|
||||
if (size <= 0) {
|
||||
return std::wstring();
|
||||
}
|
||||
|
||||
std::wstring wstr(size, 0);
|
||||
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
|
||||
|
||||
return wstr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// returns true if successful, false otherwise
|
||||
bool fs_create_directory_with_parents(const std::string & path) {
|
||||
#ifdef _WIN32
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
std::wstring wpath = converter.from_bytes(path);
|
||||
std::wstring wpath = utf8_to_wstring(path);
|
||||
|
||||
// if the path already exists, check whether it's a directory
|
||||
const DWORD attributes = GetFileAttributesW(wpath.c_str());
|
||||
@@ -855,6 +881,11 @@ bool fs_create_directory_with_parents(const std::string & path) {
|
||||
#endif // _WIN32
|
||||
}
|
||||
|
||||
bool fs_is_directory(const std::string & path) {
|
||||
std::filesystem::path dir(path);
|
||||
return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
|
||||
}
|
||||
|
||||
std::string fs_get_cache_directory() {
|
||||
std::string cache_directory = "";
|
||||
auto ensure_trailing_slash = [](std::string p) {
|
||||
@@ -889,6 +920,8 @@ std::string fs_get_cache_directory() {
|
||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||
#elif defined(_WIN32)
|
||||
cache_directory = std::getenv("LOCALAPPDATA");
|
||||
#elif defined(__EMSCRIPTEN__)
|
||||
GGML_ABORT("not implemented on this platform");
|
||||
#else
|
||||
# error Unknown architecture
|
||||
#endif
|
||||
@@ -908,34 +941,258 @@ std::string fs_get_cache_file(const std::string & filename) {
|
||||
return cache_directory + filename;
|
||||
}
|
||||
|
||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories) {
|
||||
std::vector<common_file_info> files;
|
||||
if (path.empty()) return files;
|
||||
|
||||
std::filesystem::path dir(path);
|
||||
if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
|
||||
return files;
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::directory_iterator(dir)) {
|
||||
try {
|
||||
// Only include regular files (skip directories)
|
||||
const auto & p = entry.path();
|
||||
if (std::filesystem::is_regular_file(p)) {
|
||||
common_file_info info;
|
||||
info.path = p.string();
|
||||
info.name = p.filename().string();
|
||||
info.is_dir = false;
|
||||
try {
|
||||
info.size = static_cast<size_t>(std::filesystem::file_size(p));
|
||||
} catch (const std::filesystem::filesystem_error &) {
|
||||
info.size = 0;
|
||||
}
|
||||
files.push_back(std::move(info));
|
||||
} else if (include_directories && std::filesystem::is_directory(p)) {
|
||||
common_file_info info;
|
||||
info.path = p.string();
|
||||
info.name = p.filename().string();
|
||||
info.size = 0; // Directories have no size
|
||||
info.is_dir = true;
|
||||
files.push_back(std::move(info));
|
||||
}
|
||||
} catch (const std::filesystem::filesystem_error &) {
|
||||
// skip entries we cannot inspect
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return files;
|
||||
}
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
|
||||
bool tty_can_use_colors() {
|
||||
// Check NO_COLOR environment variable (https://no-color.org/)
|
||||
if (const char * no_color = std::getenv("NO_COLOR")) {
|
||||
if (no_color[0] != '\0') {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check TERM environment variable
|
||||
if (const char * term = std::getenv("TERM")) {
|
||||
if (std::strcmp(term, "dumb") == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if stdout and stderr are connected to a terminal
|
||||
// We check both because log messages can go to either
|
||||
bool stdout_is_tty = isatty(fileno(stdout));
|
||||
bool stderr_is_tty = isatty(fileno(stderr));
|
||||
|
||||
return stdout_is_tty || stderr_is_tty;
|
||||
}
|
||||
|
||||
//
|
||||
// Model utils
|
||||
//
|
||||
|
||||
struct common_init_result common_init_from_params(common_params & params) {
|
||||
common_init_result iparams;
|
||||
// TODO: move to common/sampling
|
||||
static void common_init_sampler_from_model(
|
||||
const llama_model * model,
|
||||
common_params_sampling & sparams) {
|
||||
|
||||
const uint64_t config = sparams.user_sampling_config;
|
||||
|
||||
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
|
||||
if (config & user_config) {
|
||||
return;
|
||||
}
|
||||
|
||||
char buf[64] = {0};
|
||||
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||
char * end = nullptr;
|
||||
int32_t v = strtol(buf, &end, 10);
|
||||
if (end && end != buf) {
|
||||
dst = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
|
||||
if (config & user_config) {
|
||||
return;
|
||||
}
|
||||
|
||||
char buf[128] = {0};
|
||||
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||
char * end = nullptr;
|
||||
float v = strtof(buf, &end);
|
||||
if (end && end != buf) {
|
||||
dst = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Sampling sequence
|
||||
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
|
||||
char buf[512] = {0};
|
||||
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
|
||||
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
|
||||
if (!sampler_names.empty()) {
|
||||
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
|
||||
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
|
||||
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
|
||||
}
|
||||
|
||||
struct common_init_result::impl {
|
||||
impl() = default;
|
||||
~impl() = default;
|
||||
|
||||
llama_model_ptr model;
|
||||
llama_context_ptr context;
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> lora;
|
||||
|
||||
std::vector<common_sampler_ptr> samplers;
|
||||
};
|
||||
|
||||
common_init_result::common_init_result(common_params & params) :
|
||||
pimpl(new impl{}) {
|
||||
auto mparams = common_model_params_to_llama(params);
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
|
||||
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
return iparams;
|
||||
return;
|
||||
}
|
||||
|
||||
pimpl->model.reset(model);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
// updates params.sampling
|
||||
// TODO: fix naming
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
//if (params.sampling.penalty_last_n == -1) {
|
||||
// LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
// params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
//}
|
||||
|
||||
//if (params.sampling.dry_penalty_last_n == -1) {
|
||||
// LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||
//}
|
||||
|
||||
pimpl->samplers.resize(cparams.n_seq_max);
|
||||
|
||||
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
|
||||
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
|
||||
}
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
llama_model_free(model);
|
||||
return iparams;
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
pimpl->context.reset(lctx);
|
||||
}
|
||||
|
||||
llama_model * common_init_result::model() {
|
||||
return pimpl->model.get();
|
||||
}
|
||||
|
||||
llama_context * common_init_result::context() {
|
||||
return pimpl->context.get();
|
||||
}
|
||||
|
||||
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
|
||||
return pimpl->samplers[seq_id].get();
|
||||
}
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||
return pimpl->lora;
|
||||
}
|
||||
|
||||
void common_init_result::free_context() {
|
||||
pimpl->context.reset();
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
|
||||
llama_model * model = res->model();
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_context * lctx = res->context();
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
||||
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
||||
params.ctx_shift = false;
|
||||
@@ -947,10 +1204,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
|
||||
const auto cvec = common_control_vector_load(params.control_vectors);
|
||||
if (cvec.n_embd == -1) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
|
||||
int err = llama_apply_adapter_cvec(
|
||||
@@ -961,10 +1215,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
params.control_vector_layer_start,
|
||||
params.control_vector_layer_end);
|
||||
if (err) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -988,10 +1239,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1001,9 +1249,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
if (lora == nullptr) {
|
||||
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
|
||||
char buf[1024];
|
||||
@@ -1012,43 +1258,13 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
la.task_name = buf;
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
||||
la.prompt_prefix = buf;
|
||||
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
}
|
||||
|
||||
if (!params.lora_init_without_apply) {
|
||||
common_set_adapter_lora(lctx, params.lora_adapters);
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
if (params.sampling.penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
}
|
||||
|
||||
if (params.sampling.dry_penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
|
||||
@@ -1087,12 +1303,11 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
llama_set_warmup(lctx, false);
|
||||
}
|
||||
|
||||
iparams.model.reset(model);
|
||||
iparams.context.reset(lctx);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
|
||||
common_init_result::~common_init_result() = default;
|
||||
|
||||
std::string get_model_endpoint() {
|
||||
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
|
||||
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
|
||||
@@ -1101,7 +1316,9 @@ std::string get_model_endpoint() {
|
||||
std::string model_endpoint = "https://huggingface.co/";
|
||||
if (endpoint_env) {
|
||||
model_endpoint = endpoint_env;
|
||||
if (model_endpoint.back() != '/') model_endpoint += '/';
|
||||
if (model_endpoint.back() != '/') {
|
||||
model_endpoint += '/';
|
||||
}
|
||||
}
|
||||
return model_endpoint;
|
||||
}
|
||||
|
||||
125
llama/llama.cpp/common/common.h
vendored
125
llama/llama.cpp/common/common.h
vendored
@@ -2,17 +2,19 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <cmath>
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "llama-cpp.h"
|
||||
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
||||
#define _WIN32_WINNT 0x0A00
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
#define DIRECTORY_SEPARATOR '\\'
|
||||
@@ -28,7 +30,14 @@
|
||||
fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
|
||||
} while(0)
|
||||
|
||||
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
||||
struct common_time_meas {
|
||||
common_time_meas(int64_t & t_acc, bool disable = false);
|
||||
~common_time_meas();
|
||||
|
||||
const int64_t t_start_us;
|
||||
|
||||
int64_t & t_acc;
|
||||
};
|
||||
|
||||
struct common_adapter_lora_info {
|
||||
std::string path;
|
||||
@@ -73,7 +82,8 @@ int32_t cpu_get_num_math();
|
||||
enum llama_example {
|
||||
LLAMA_EXAMPLE_COMMON,
|
||||
LLAMA_EXAMPLE_SPECULATIVE,
|
||||
LLAMA_EXAMPLE_MAIN,
|
||||
LLAMA_EXAMPLE_COMPLETION,
|
||||
LLAMA_EXAMPLE_CLI,
|
||||
LLAMA_EXAMPLE_EMBEDDING,
|
||||
LLAMA_EXAMPLE_PERPLEXITY,
|
||||
LLAMA_EXAMPLE_RETRIEVAL,
|
||||
@@ -89,6 +99,7 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_TTS,
|
||||
LLAMA_EXAMPLE_DIFFUSION,
|
||||
LLAMA_EXAMPLE_FINETUNE,
|
||||
LLAMA_EXAMPLE_FIT_PARAMS,
|
||||
|
||||
LLAMA_EXAMPLE_COUNT,
|
||||
};
|
||||
@@ -133,6 +144,22 @@ struct common_grammar_trigger {
|
||||
llama_token token = LLAMA_TOKEN_NULL;
|
||||
};
|
||||
|
||||
enum common_params_sampling_config : uint64_t {
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
|
||||
};
|
||||
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||
@@ -165,8 +192,9 @@ struct common_params_sampling {
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool timing_per_token = false;
|
||||
|
||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
|
||||
|
||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||
|
||||
std::vector<enum common_sampler_type> samplers = {
|
||||
COMMON_SAMPLER_TYPE_PENALTIES,
|
||||
@@ -188,6 +216,10 @@ struct common_params_sampling {
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
bool has_logit_bias() const {
|
||||
return !logit_bias.empty();
|
||||
}
|
||||
|
||||
// print the parameters into a string
|
||||
std::string print() const;
|
||||
};
|
||||
@@ -198,6 +230,7 @@ struct common_params_model {
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
std::string docker_repo = ""; // Docker repo // NOLINT
|
||||
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
||||
};
|
||||
|
||||
struct common_params_speculative {
|
||||
@@ -274,8 +307,8 @@ struct lr_opt {
|
||||
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
|
||||
|
||||
struct common_params {
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_ctx = 4096; // context size
|
||||
int32_t n_predict = -1; // max. number of new tokens to predict, -1 == no limit
|
||||
int32_t n_ctx = 0; // context size, 0 == context the model was trained with
|
||||
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
@@ -296,9 +329,12 @@ struct common_params {
|
||||
// offload params
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
|
||||
size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory
|
||||
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
|
||||
|
||||
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
||||
|
||||
@@ -344,7 +380,7 @@ struct common_params {
|
||||
|
||||
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
|
||||
|
||||
int32_t verbosity = 0;
|
||||
int32_t verbosity = 3; // LOG_LEVEL_INFO
|
||||
int32_t control_vector_layer_start = -1; // layer range for control vector
|
||||
int32_t control_vector_layer_end = -1; // layer range for control vector
|
||||
bool offline = false;
|
||||
@@ -378,6 +414,7 @@ struct common_params {
|
||||
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
||||
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool show_timings = true; // show timing information on CLI
|
||||
bool ctx_shift = false; // context shift on infinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
bool kv_unified = false; // enable unified KV cache
|
||||
@@ -406,6 +443,8 @@ struct common_params {
|
||||
bool mmproj_use_gpu = true; // use GPU for multimodal model
|
||||
bool no_mmproj = false; // explicitly disable multimodal model
|
||||
std::vector<std::string> image; // path to image file(s)
|
||||
int image_min_tokens = -1;
|
||||
int image_max_tokens = -1;
|
||||
|
||||
// finetune
|
||||
struct lr_opt lr;
|
||||
@@ -432,7 +471,7 @@ struct common_params {
|
||||
std::string public_path = ""; // NOLINT
|
||||
std::string api_prefix = ""; // NOLINT
|
||||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool use_jinja = true; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int reasoning_budget = -1;
|
||||
@@ -451,14 +490,22 @@ struct common_params {
|
||||
bool endpoint_props = false; // only control POST requests, not GET
|
||||
bool endpoint_metrics = false;
|
||||
|
||||
// router server configs
|
||||
std::string models_dir = ""; // directory containing models for the router server
|
||||
std::string models_preset = ""; // directory containing model presets for the router server
|
||||
int models_max = 4; // maximum number of models to load simultaneously
|
||||
bool models_autoload = true; // automatically load models when requested via the router server
|
||||
|
||||
bool log_json = false;
|
||||
|
||||
std::string slot_save_path;
|
||||
std::string media_path; // path to directory for loading media files
|
||||
|
||||
float slot_prompt_similarity = 0.1f;
|
||||
|
||||
// batched-bench params
|
||||
bool is_pp_shared = false;
|
||||
bool is_pp_shared = false;
|
||||
bool is_tg_separate = false;
|
||||
|
||||
std::vector<int32_t> n_pp;
|
||||
std::vector<int32_t> n_tg;
|
||||
@@ -505,6 +552,10 @@ struct common_params {
|
||||
// return false from callback to abort model loading or true to continue
|
||||
llama_progress_callback load_progress_callback = NULL;
|
||||
void * load_progress_callback_user_data = NULL;
|
||||
|
||||
bool has_speculative() const {
|
||||
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
@@ -599,25 +650,55 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
|
||||
// Filesystem utils
|
||||
//
|
||||
|
||||
bool fs_validate_filename(const std::string & filename);
|
||||
bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
|
||||
bool fs_create_directory_with_parents(const std::string & path);
|
||||
bool fs_is_directory(const std::string & path);
|
||||
|
||||
std::string fs_get_cache_directory();
|
||||
std::string fs_get_cache_file(const std::string & filename);
|
||||
|
||||
struct common_file_info {
|
||||
std::string path;
|
||||
std::string name;
|
||||
size_t size = 0; // in bytes
|
||||
bool is_dir = false;
|
||||
};
|
||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
|
||||
// Auto-detect if colors can be enabled based on terminal and environment
|
||||
bool tty_can_use_colors();
|
||||
|
||||
//
|
||||
// Model utils
|
||||
//
|
||||
|
||||
// note: defines object's lifetime
|
||||
struct common_init_result {
|
||||
llama_model_ptr model;
|
||||
llama_context_ptr context;
|
||||
struct common_sampler;
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> lora;
|
||||
// note: defines the model, context, samplers, ets. lifetimes
|
||||
struct common_init_result {
|
||||
common_init_result(common_params & params);
|
||||
~common_init_result();
|
||||
|
||||
llama_model * model();
|
||||
llama_context * context();
|
||||
common_sampler * sampler(llama_seq_id seq_id);
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & lora();
|
||||
|
||||
void free_context();
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
|
||||
struct common_init_result common_init_from_params(common_params & params);
|
||||
using common_init_result_ptr = std::unique_ptr<common_init_result>;
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params);
|
||||
|
||||
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
||||
|
||||
165
llama/llama.cpp/common/json-schema-to-grammar.cpp
vendored
165
llama/llama.cpp/common/json-schema-to-grammar.cpp
vendored
@@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
|
||||
}
|
||||
|
||||
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
|
||||
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
|
||||
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
|
||||
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
|
||||
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
|
||||
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
|
||||
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
|
||||
};
|
||||
|
||||
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
|
||||
@@ -303,8 +303,11 @@ static std::string format_literal(const std::string & literal) {
|
||||
return "\"" + escaped + "\"";
|
||||
}
|
||||
|
||||
class SchemaConverter {
|
||||
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
|
||||
|
||||
class common_schema_converter {
|
||||
private:
|
||||
friend class common_schema_info;
|
||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
||||
std::function<json(const std::string &)> _fetch_json;
|
||||
bool _dotall;
|
||||
@@ -601,7 +604,10 @@ private:
|
||||
}
|
||||
|
||||
std::string _resolve_ref(const std::string & ref) {
|
||||
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
|
||||
auto it = ref.find('#');
|
||||
std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
|
||||
static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
|
||||
std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
|
||||
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
|
||||
_refs_being_resolved.insert(ref);
|
||||
json resolved = _refs[ref];
|
||||
@@ -724,7 +730,7 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
SchemaConverter(
|
||||
common_schema_converter(
|
||||
const std::function<json(const std::string &)> & fetch_json,
|
||||
bool dotall)
|
||||
: _fetch_json(fetch_json), _dotall(dotall)
|
||||
@@ -774,11 +780,24 @@ public:
|
||||
std::vector<std::string> tokens = string_split(pointer, "/");
|
||||
for (size_t i = 1; i < tokens.size(); ++i) {
|
||||
std::string sel = tokens[i];
|
||||
if (target.is_null() || !target.contains(sel)) {
|
||||
if (target.is_object() && target.contains(sel)) {
|
||||
target = target[sel];
|
||||
} else if (target.is_array()) {
|
||||
size_t sel_index;
|
||||
try {
|
||||
sel_index = std::stoul(sel);
|
||||
} catch (const std::invalid_argument & e) {
|
||||
sel_index = target.size();
|
||||
}
|
||||
if (sel_index >= target.size()) {
|
||||
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
||||
return;
|
||||
}
|
||||
target = target[sel_index];
|
||||
} else {
|
||||
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
||||
return;
|
||||
}
|
||||
target = target[sel];
|
||||
}
|
||||
_refs[ref] = target;
|
||||
}
|
||||
@@ -956,7 +975,7 @@ public:
|
||||
|
||||
void check_errors() {
|
||||
if (!_errors.empty()) {
|
||||
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
|
||||
throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
|
||||
}
|
||||
if (!_warnings.empty()) {
|
||||
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
|
||||
@@ -972,6 +991,134 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// common_schema_info implementation (pimpl)
|
||||
|
||||
common_schema_info::common_schema_info()
|
||||
: impl_(std::make_unique<common_schema_converter>(
|
||||
[](const std::string &) { return json(); },
|
||||
false)) {}
|
||||
|
||||
common_schema_info::~common_schema_info() = default;
|
||||
|
||||
common_schema_info::common_schema_info(common_schema_info &&) noexcept = default;
|
||||
common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default;
|
||||
|
||||
void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) {
|
||||
impl_->resolve_refs(schema, "");
|
||||
}
|
||||
|
||||
// Determines if a JSON schema can resolve to a string type through any path.
|
||||
// Some models emit raw string values rather than JSON-encoded strings for string parameters.
|
||||
// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns
|
||||
// true, allowing callers to handle the value as a raw string for simplicity.
|
||||
bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) {
|
||||
std::unordered_set<std::string> visited_refs;
|
||||
|
||||
std::function<bool(const json &)> check = [&](const json & s) -> bool {
|
||||
if (!s.is_object()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle $ref
|
||||
if (s.contains("$ref")) {
|
||||
const std::string & ref = s["$ref"];
|
||||
if (visited_refs.find(ref) != visited_refs.end()) {
|
||||
// Circular reference, assume not a string to be safe
|
||||
return false;
|
||||
}
|
||||
visited_refs.insert(ref);
|
||||
auto it = impl_->_refs.find(ref);
|
||||
if (it != impl_->_refs.end()) {
|
||||
return check(it->second);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check type field
|
||||
if (s.contains("type")) {
|
||||
const json & schema_type = s["type"];
|
||||
if (schema_type.is_string()) {
|
||||
if (schema_type == "string") {
|
||||
return true;
|
||||
}
|
||||
} else if (schema_type.is_array()) {
|
||||
// Type can be an array like ["string", "null"]
|
||||
for (const auto & t : schema_type) {
|
||||
if (t == "string") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check oneOf/anyOf - if any alternative can be a string
|
||||
if (s.contains("oneOf")) {
|
||||
for (const auto & alt : s["oneOf"]) {
|
||||
if (check(alt)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (s.contains("anyOf")) {
|
||||
for (const auto & alt : s["anyOf"]) {
|
||||
if (check(alt)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check allOf - all components must be compatible with string type
|
||||
if (s.contains("allOf")) {
|
||||
bool all_string = true;
|
||||
for (const auto & component : s["allOf"]) {
|
||||
if (!check(component)) {
|
||||
all_string = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (all_string) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Check const - if the constant value is a string
|
||||
if (s.contains("const")) {
|
||||
if (s["const"].is_string()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Check enum - if any enum value is a string
|
||||
if (s.contains("enum")) {
|
||||
for (const auto & val : s["enum"]) {
|
||||
if (val.is_string()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// String-specific keywords imply string type
|
||||
if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check format - many formats imply string
|
||||
if (s.contains("format")) {
|
||||
const std::string & fmt = s["format"];
|
||||
if (fmt == "date" || fmt == "time" || fmt == "date-time" ||
|
||||
fmt == "uri" || fmt == "email" || fmt == "hostname" ||
|
||||
fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" ||
|
||||
fmt.find("uuid") == 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
return check(schema);
|
||||
}
|
||||
|
||||
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
if (!force_gbnf) {
|
||||
@@ -988,7 +1135,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
||||
}
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
|
||||
common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall);
|
||||
common_grammar_builder builder {
|
||||
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
|
||||
return converter._add_rule(name, rule);
|
||||
|
||||
22
llama/llama.cpp/common/json-schema-to-grammar.h
vendored
22
llama/llama.cpp/common/json-schema-to-grammar.h
vendored
@@ -3,11 +3,31 @@
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
||||
bool force_gbnf = false);
|
||||
|
||||
class common_schema_converter;
|
||||
|
||||
// Probes a JSON schema to extract information about its structure and type constraints.
|
||||
class common_schema_info {
|
||||
std::unique_ptr<common_schema_converter> impl_;
|
||||
|
||||
public:
|
||||
common_schema_info();
|
||||
~common_schema_info();
|
||||
|
||||
common_schema_info(const common_schema_info &) = delete;
|
||||
common_schema_info & operator=(const common_schema_info &) = delete;
|
||||
common_schema_info(common_schema_info &&) noexcept;
|
||||
common_schema_info & operator=(common_schema_info &&) noexcept;
|
||||
|
||||
void resolve_refs(nlohmann::ordered_json & schema);
|
||||
bool resolves_to_string(const nlohmann::ordered_json & schema);
|
||||
};
|
||||
|
||||
struct common_grammar_builder {
|
||||
std::function<std::string(const std::string &, const std::string &)> add_rule;
|
||||
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
|
||||
@@ -18,4 +38,6 @@ struct common_grammar_options {
|
||||
bool dotall = false;
|
||||
};
|
||||
|
||||
std::string gbnf_format_literal(const std::string & literal);
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
||||
|
||||
54
llama/llama.cpp/common/log.cpp
vendored
54
llama/llama.cpp/common/log.cpp
vendored
@@ -1,3 +1,4 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <chrono>
|
||||
@@ -26,30 +27,6 @@ void common_log_set_verbosity_thold(int verbosity) {
|
||||
common_log_verbosity_thold = verbosity;
|
||||
}
|
||||
|
||||
// Auto-detect if colors should be enabled based on terminal and environment
|
||||
static bool common_log_should_use_colors_auto() {
|
||||
// Check NO_COLOR environment variable (https://no-color.org/)
|
||||
if (const char * no_color = std::getenv("NO_COLOR")) {
|
||||
if (no_color[0] != '\0') {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check TERM environment variable
|
||||
if (const char * term = std::getenv("TERM")) {
|
||||
if (std::strcmp(term, "dumb") == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if stdout and stderr are connected to a terminal
|
||||
// We check both because log messages can go to either
|
||||
bool stdout_is_tty = isatty(fileno(stdout));
|
||||
bool stderr_is_tty = isatty(fileno(stderr));
|
||||
|
||||
return stdout_is_tty || stderr_is_tty;
|
||||
}
|
||||
|
||||
static int64_t t_us() {
|
||||
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
}
|
||||
@@ -391,7 +368,7 @@ struct common_log * common_log_main() {
|
||||
static std::once_flag init_flag;
|
||||
std::call_once(init_flag, [&]() {
|
||||
// Set default to auto-detect colors
|
||||
log.set_colors(common_log_should_use_colors_auto());
|
||||
log.set_colors(tty_can_use_colors());
|
||||
});
|
||||
|
||||
return &log;
|
||||
@@ -422,7 +399,7 @@ void common_log_set_file(struct common_log * log, const char * file) {
|
||||
|
||||
void common_log_set_colors(struct common_log * log, log_colors colors) {
|
||||
if (colors == LOG_COLORS_AUTO) {
|
||||
log->set_colors(common_log_should_use_colors_auto());
|
||||
log->set_colors(tty_can_use_colors());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -442,3 +419,28 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
|
||||
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
|
||||
log->set_timestamps(timestamps);
|
||||
}
|
||||
|
||||
void common_log_flush(struct common_log * log) {
|
||||
log->pause();
|
||||
log->resume();
|
||||
}
|
||||
|
||||
static int common_get_verbosity(enum ggml_log_level level) {
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
|
||||
case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO;
|
||||
case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN;
|
||||
case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR;
|
||||
case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO
|
||||
case GGML_LOG_LEVEL_NONE:
|
||||
default:
|
||||
return LOG_LEVEL_OUTPUT;
|
||||
}
|
||||
}
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
auto verbosity = common_get_verbosity(level);
|
||||
if (verbosity <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}
|
||||
|
||||
34
llama/llama.cpp/common/log.h
vendored
34
llama/llama.cpp/common/log.h
vendored
@@ -21,8 +21,14 @@
|
||||
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
#endif
|
||||
|
||||
#define LOG_DEFAULT_DEBUG 1
|
||||
#define LOG_DEFAULT_LLAMA 0
|
||||
#define LOG_LEVEL_DEBUG 4
|
||||
#define LOG_LEVEL_INFO 3
|
||||
#define LOG_LEVEL_WARN 2
|
||||
#define LOG_LEVEL_ERROR 1
|
||||
#define LOG_LEVEL_OUTPUT 0 // output data from tools
|
||||
|
||||
#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG
|
||||
#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO
|
||||
|
||||
enum log_colors {
|
||||
LOG_COLORS_AUTO = -1,
|
||||
@@ -36,6 +42,8 @@ extern int common_log_verbosity_thold;
|
||||
|
||||
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
|
||||
|
||||
// the common_log uses an internal worker thread to print/write log messages
|
||||
// when the worker thread is paused, incoming log messages are discarded
|
||||
struct common_log;
|
||||
@@ -65,16 +73,18 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch
|
||||
// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
|
||||
// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
|
||||
//
|
||||
// I - info (stdout, V = 0)
|
||||
// W - warning (stderr, V = 0)
|
||||
// E - error (stderr, V = 0)
|
||||
// D - debug (stderr, V = LOG_DEFAULT_DEBUG)
|
||||
// I - info (stdout, V = LOG_DEFAULT_INFO)
|
||||
// W - warning (stderr, V = LOG_DEFAULT_WARN)
|
||||
// E - error (stderr, V = LOG_DEFAULT_ERROR)
|
||||
// O - output (stdout, V = LOG_DEFAULT_OUTPUT)
|
||||
//
|
||||
|
||||
void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
|
||||
void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe
|
||||
void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log
|
||||
void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix
|
||||
void common_log_flush (struct common_log * log); // flush all pending log messages
|
||||
|
||||
// helper macros for logging
|
||||
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
|
||||
@@ -93,14 +103,14 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps); // w
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__)
|
||||
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
|
||||
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__)
|
||||
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
|
||||
|
||||
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__)
|
||||
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__)
|
||||
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
|
||||
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
|
||||
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__)
|
||||
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO
|
||||
|
||||
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
|
||||
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
|
||||
|
||||
249
llama/llama.cpp/common/sampling.cpp
vendored
249
llama/llama.cpp/common/sampling.cpp
vendored
@@ -3,9 +3,10 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
|
||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
// TODO: deduplicate with llama-impl.h
|
||||
@@ -103,15 +104,22 @@ struct ring_buffer {
|
||||
struct common_sampler {
|
||||
common_params_sampling params;
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
struct llama_sampler * chain;
|
||||
|
||||
bool grammar;
|
||||
|
||||
ring_buffer<llama_token> prev;
|
||||
|
||||
std::vector<llama_token_data> cur;
|
||||
|
||||
llama_token_data_array cur_p;
|
||||
|
||||
void reset() {
|
||||
prev.clear();
|
||||
|
||||
llama_sampler_reset(chain);
|
||||
}
|
||||
|
||||
void set_logits(struct llama_context * ctx, int idx) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
|
||||
@@ -128,6 +136,12 @@ struct common_sampler {
|
||||
|
||||
cur_p = { cur.data(), cur.size(), -1, false };
|
||||
}
|
||||
|
||||
common_time_meas tm() {
|
||||
return common_time_meas(t_total_us, params.no_perf);
|
||||
}
|
||||
|
||||
mutable int64_t t_total_us = 0;
|
||||
};
|
||||
|
||||
std::string common_params_sampling::print() const {
|
||||
@@ -153,10 +167,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
|
||||
lparams.no_perf = params.no_perf;
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
llama_sampler * chain = llama_sampler_chain_init(lparams);
|
||||
|
||||
bool grammar = false;
|
||||
std::vector<llama_sampler *> samplers;
|
||||
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||
samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
|
||||
grammar = true;
|
||||
#else
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
@@ -203,30 +222,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
trigger_patterns_c.push_back(regex.c_str());
|
||||
}
|
||||
|
||||
grmr = params.grammar_lazy
|
||||
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size())
|
||||
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
if (!grmr) {
|
||||
return nullptr;
|
||||
if (!params.grammar.empty()) {
|
||||
if (params.grammar_lazy) {
|
||||
samplers.push_back(
|
||||
llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size()));
|
||||
} else {
|
||||
samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
|
||||
}
|
||||
|
||||
grammar = true;
|
||||
}
|
||||
}
|
||||
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
/* .chain = */ llama_sampler_chain_init(lparams),
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
/* .cur_p = */ {},
|
||||
};
|
||||
|
||||
llama_sampler_chain_add(result->chain,
|
||||
llama_sampler_init_logit_bias(
|
||||
llama_vocab_n_tokens(vocab),
|
||||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
if (params.has_logit_bias()) {
|
||||
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
|
||||
}
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
@@ -239,58 +251,70 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
samplers.push_back(llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
||||
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
||||
samplers.push_back(llama_sampler_init_infill (vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
|
||||
samplers.push_back(llama_sampler_init_dist(params.seed));
|
||||
} else if (params.mirostat == 1) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
||||
samplers.push_back(llama_sampler_init_temp(params.temp));
|
||||
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
||||
} else if (params.mirostat == 2) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
||||
samplers.push_back(llama_sampler_init_temp(params.temp));
|
||||
samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
||||
} else {
|
||||
GGML_ASSERT(false && "unknown mirostat version");
|
||||
}
|
||||
|
||||
for (auto * smpl : samplers) {
|
||||
llama_sampler_chain_add(chain, smpl);
|
||||
}
|
||||
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .chain = */ chain,
|
||||
/* .grammar = */ grammar,
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
/* .cur_p = */ {},
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
if (gsmpl) {
|
||||
llama_sampler_free(gsmpl->grmr);
|
||||
|
||||
llama_sampler_free(gsmpl->chain);
|
||||
|
||||
delete gsmpl;
|
||||
@@ -298,91 +322,117 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
}
|
||||
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
if (accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
}
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
llama_sampler_accept(gsmpl->chain, token);
|
||||
if (gsmpl->grammar) {
|
||||
const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
|
||||
|
||||
for (int i = 0; i < n_smpl; i++) {
|
||||
auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
||||
|
||||
// the grammar sampler is always the first one
|
||||
if (i == 0) {
|
||||
if (accept_grammar) {
|
||||
llama_sampler_accept(smpl, token);
|
||||
}
|
||||
} else {
|
||||
llama_sampler_accept(smpl, token);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
llama_sampler_accept(gsmpl->chain, token);
|
||||
}
|
||||
|
||||
gsmpl->prev.push_back(token);
|
||||
}
|
||||
|
||||
void common_sampler_reset(struct common_sampler * gsmpl) {
|
||||
llama_sampler_reset(gsmpl->grmr);
|
||||
|
||||
llama_sampler_reset(gsmpl->chain);
|
||||
gsmpl->reset();
|
||||
}
|
||||
|
||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
return new common_sampler {
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
/* .cur_p = */ gsmpl->cur_p,
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .grammar = */ gsmpl->grammar,
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
/* .cur_p = */ gsmpl->cur_p,
|
||||
};
|
||||
}
|
||||
|
||||
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
|
||||
// TODO: measure grammar performance
|
||||
|
||||
const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;
|
||||
|
||||
llama_perf_sampler_data data_smpl;
|
||||
llama_perf_context_data data_ctx;
|
||||
|
||||
memset(&data_smpl, 0, sizeof(data_smpl));
|
||||
memset(&data_ctx, 0, sizeof(data_ctx));
|
||||
|
||||
if (gsmpl) {
|
||||
llama_perf_sampler_print(gsmpl->chain);
|
||||
auto & data = data_smpl;
|
||||
|
||||
data = llama_perf_sampler(gsmpl->chain);
|
||||
|
||||
// note: the sampling time includes the samplers time + extra time spent in common/sampling
|
||||
LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms);
|
||||
LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
llama_perf_context_print(ctx);
|
||||
auto & data = data_ctx;
|
||||
|
||||
data = llama_perf_context(ctx);
|
||||
|
||||
const double t_end_ms = 1e-3 * ggml_time_us();
|
||||
|
||||
const double t_total_ms = t_end_ms - data.t_start_ms;
|
||||
const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
|
||||
const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;
|
||||
|
||||
LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
|
||||
LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
|
||||
LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
|
||||
LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
|
||||
LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
|
||||
LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused);
|
||||
|
||||
llama_memory_breakdown_print(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
|
||||
return gsmpl->chain;
|
||||
}
|
||||
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
llama_token id = LLAMA_TOKEN_NULL;
|
||||
|
||||
auto & grmr = gsmpl->grmr;
|
||||
auto & chain = gsmpl->chain;
|
||||
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
||||
|
||||
if (grammar_first) {
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
}
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
llama_sampler_apply(chain, &cur_p);
|
||||
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||
|
||||
const llama_token id = cur_p.data[cur_p.selected].id;
|
||||
id = cur_p.data[cur_p.selected].id;
|
||||
|
||||
if (grammar_first) {
|
||||
return id;
|
||||
}
|
||||
|
||||
// check if it the sampled token fits the grammar
|
||||
{
|
||||
llama_token_data single_token_data = { id, 1.0f, 0.0f };
|
||||
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
|
||||
|
||||
llama_sampler_apply(grmr, &single_token_data_array);
|
||||
|
||||
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
||||
if (is_valid) {
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
// resampling:
|
||||
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
llama_sampler_apply(chain, &cur_p);
|
||||
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
|
||||
|
||||
return cur_p.data[cur_p.selected].id;
|
||||
return id;
|
||||
}
|
||||
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
|
||||
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
||||
|
||||
std::vector<llama_token> result;
|
||||
@@ -390,7 +440,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
||||
|
||||
size_t i = 0;
|
||||
for (; i < draft.size(); i++) {
|
||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
|
||||
|
||||
common_sampler_accept(gsmpl, id, true);
|
||||
|
||||
@@ -402,7 +452,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
||||
}
|
||||
|
||||
if (i == draft.size()) {
|
||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
|
||||
|
||||
common_sampler_accept(gsmpl, id, true);
|
||||
|
||||
@@ -412,13 +462,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
|
||||
std::vector<int> idxs(draft.size() + 1);
|
||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||
idxs[i] = i;
|
||||
}
|
||||
|
||||
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
|
||||
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
|
||||
}
|
||||
|
||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||
@@ -428,6 +478,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||
// helpers
|
||||
|
||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
auto * res = &gsmpl->cur_p;
|
||||
|
||||
if (do_sort && !res->sorted) {
|
||||
@@ -461,7 +513,8 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) {
|
||||
|
||||
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
||||
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
||||
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
|
||||
result += std::string("-> ");
|
||||
result += std::string(llama_sampler_name(smpl)) + " ";
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
17
llama/llama.cpp/common/sampling.h
vendored
17
llama/llama.cpp/common/sampling.h
vendored
@@ -48,6 +48,8 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
||||
// arguments can be nullptr to skip printing
|
||||
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
|
||||
|
||||
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
|
||||
|
||||
// extended sampling implementation:
|
||||
//
|
||||
// - set logits
|
||||
@@ -55,10 +57,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
|
||||
// - check if the token fits the grammar (if any)
|
||||
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
|
||||
//
|
||||
// if grammar_first is true, the grammar is applied before the samplers (slower)
|
||||
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
|
||||
//
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
|
||||
|
||||
// generalized version of common_sampler_sample
|
||||
//
|
||||
@@ -76,10 +75,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
//
|
||||
// returns at least 1 token, up to idxs.size()
|
||||
//
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
|
||||
|
||||
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
|
||||
|
||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
||||
|
||||
@@ -107,3 +106,9 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
|
||||
|
||||
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
||||
const char * grammar_kind, const char * grammar_data);
|
||||
|
||||
struct common_sampler_deleter {
|
||||
void operator()(common_sampler * s) { common_sampler_free(s); }
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;
|
||||
|
||||
47
llama/llama.cpp/include/llama.h
vendored
47
llama/llama.cpp/include/llama.h
vendored
@@ -83,6 +83,7 @@ extern "C" {
|
||||
LLAMA_ROPE_TYPE_NORM = 0,
|
||||
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
|
||||
LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE,
|
||||
LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE,
|
||||
LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
|
||||
};
|
||||
|
||||
@@ -245,6 +246,21 @@ extern "C" {
|
||||
LLAMA_KV_OVERRIDE_TYPE_STR,
|
||||
};
|
||||
|
||||
enum llama_model_meta_key {
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_TOP_K,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_TOP_P,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_MIN_P,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_TEMP,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU,
|
||||
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA,
|
||||
};
|
||||
|
||||
struct llama_model_kv_override {
|
||||
enum llama_model_kv_override_type tag;
|
||||
|
||||
@@ -297,6 +313,7 @@ extern "C" {
|
||||
bool check_tensors; // validate model tensor data
|
||||
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
|
||||
bool no_host; // bypass host buffer allowing extra buffers to be used
|
||||
bool no_alloc; // only load metadata and simulate memory allocations
|
||||
};
|
||||
|
||||
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
||||
@@ -450,17 +467,35 @@ extern "C" {
|
||||
// Frees all allocated memory
|
||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||
|
||||
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
|
||||
// returns true if the parameters could be successfully modified to fit device memory
|
||||
// this function is NOT thread safe because it modifies the global llama logger state
|
||||
LLAMA_API bool llama_params_fit(
|
||||
const char * path_model,
|
||||
struct llama_model_params * mparams,
|
||||
struct llama_context_params * cparams,
|
||||
float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
|
||||
struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
|
||||
size_t margin, // margin of memory to leave per device in bytes
|
||||
uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
|
||||
enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
|
||||
|
||||
LLAMA_API int64_t llama_time_us(void);
|
||||
|
||||
LLAMA_API size_t llama_max_devices(void);
|
||||
LLAMA_API size_t llama_max_parallel_sequences(void);
|
||||
LLAMA_API size_t llama_max_tensor_buft_overrides(void);
|
||||
|
||||
LLAMA_API bool llama_supports_mmap (void);
|
||||
LLAMA_API bool llama_supports_mlock (void);
|
||||
LLAMA_API bool llama_supports_gpu_offload(void);
|
||||
LLAMA_API bool llama_supports_rpc (void);
|
||||
|
||||
// NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
|
||||
// In some cases the requested values via llama_context_params may differ from the actual values used by the context
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
|
||||
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
||||
@@ -481,6 +516,7 @@ extern "C" {
|
||||
|
||||
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
@@ -512,6 +548,9 @@ extern "C" {
|
||||
// Get the number of metadata key/value pairs
|
||||
LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
|
||||
|
||||
// Get sampling metadata key name. Returns nullptr if the key is invalid
|
||||
LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key);
|
||||
|
||||
// Get metadata key name by index
|
||||
LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
|
||||
|
||||
@@ -584,7 +623,7 @@ extern "C" {
|
||||
LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
|
||||
|
||||
// Manually free a LoRA adapter
|
||||
// Note: loaded adapters will be free when the associated model is deleted
|
||||
// NOTE: loaded adapters will be free when the associated model is deleted
|
||||
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
|
||||
|
||||
// Get the invocation tokens if the current lora is an alora
|
||||
@@ -1110,8 +1149,6 @@ extern "C" {
|
||||
// // sample from the logits of the last token in the batch
|
||||
// const llama_token id = llama_sampler_sample(smpl, ctx, -1);
|
||||
//
|
||||
// // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
|
||||
// llama_sampler_accept(smpl, id);
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
@@ -1332,7 +1369,9 @@ extern "C" {
|
||||
|
||||
// Set callback for all future logging events.
|
||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
// The logger state is global so these functions are NOT thread safe.
|
||||
LLAMA_API void llama_log_get(ggml_log_callback * log_callback, void ** user_data);
|
||||
LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
//
|
||||
// Performance utils
|
||||
|
||||
4039
llama/llama.cpp/src/llama-arch.cpp
vendored
4039
llama/llama.cpp/src/llama-arch.cpp
vendored
File diff suppressed because it is too large
Load Diff
42
llama/llama.cpp/src/llama-arch.h
vendored
42
llama/llama.cpp/src/llama-arch.h
vendored
@@ -3,6 +3,7 @@
|
||||
#include "ggml.h" // ggml_op
|
||||
|
||||
#include <string>
|
||||
#include <set>
|
||||
|
||||
//
|
||||
// gguf constants (sync with gguf.py)
|
||||
@@ -36,6 +37,9 @@ enum llm_arch {
|
||||
LLM_ARCH_QWEN2VL,
|
||||
LLM_ARCH_QWEN3,
|
||||
LLM_ARCH_QWEN3MOE,
|
||||
LLM_ARCH_QWEN3NEXT,
|
||||
LLM_ARCH_QWEN3VL,
|
||||
LLM_ARCH_QWEN3VLMOE,
|
||||
LLM_ARCH_PHI2,
|
||||
LLM_ARCH_PHI3,
|
||||
LLM_ARCH_PHIMOE,
|
||||
@@ -76,6 +80,7 @@ enum llm_arch {
|
||||
LLM_ARCH_JAIS,
|
||||
LLM_ARCH_NEMOTRON,
|
||||
LLM_ARCH_NEMOTRON_H,
|
||||
LLM_ARCH_NEMOTRON_H_MOE,
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_EXAONE4,
|
||||
LLM_ARCH_RWKV6,
|
||||
@@ -93,6 +98,7 @@ enum llm_arch {
|
||||
LLM_ARCH_BAILINGMOE2,
|
||||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_ARCEE,
|
||||
LLM_ARCH_AFMOE,
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
LLM_ARCH_ERNIE4_5_MOE,
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
@@ -108,6 +114,11 @@ enum llm_arch {
|
||||
LLM_ARCH_SEED_OSS,
|
||||
LLM_ARCH_GROVEMOE,
|
||||
LLM_ARCH_APERTUS,
|
||||
LLM_ARCH_MINIMAX_M2,
|
||||
LLM_ARCH_COGVLM,
|
||||
LLM_ARCH_RND1,
|
||||
LLM_ARCH_PANGU_EMBED,
|
||||
LLM_ARCH_MISTRAL3,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -117,6 +128,18 @@ enum llm_kv {
|
||||
LLM_KV_GENERAL_QUANTIZATION_VERSION,
|
||||
LLM_KV_GENERAL_ALIGNMENT,
|
||||
LLM_KV_GENERAL_FILE_TYPE,
|
||||
LLM_KV_GENERAL_SAMPLING_SEQUENCE,
|
||||
LLM_KV_GENERAL_SAMPLING_TOP_K,
|
||||
LLM_KV_GENERAL_SAMPLING_TOP_P,
|
||||
LLM_KV_GENERAL_SAMPLING_MIN_P,
|
||||
LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY,
|
||||
LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD,
|
||||
LLM_KV_GENERAL_SAMPLING_TEMP,
|
||||
LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N,
|
||||
LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT,
|
||||
LLM_KV_GENERAL_SAMPLING_MIROSTAT,
|
||||
LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU,
|
||||
LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA,
|
||||
LLM_KV_GENERAL_NAME,
|
||||
LLM_KV_GENERAL_AUTHOR,
|
||||
LLM_KV_GENERAL_VERSION,
|
||||
@@ -150,6 +173,7 @@ enum llm_kv {
|
||||
LLM_KV_EXPERTS_PER_GROUP,
|
||||
LLM_KV_MOE_EVERY_N_LAYERS,
|
||||
LLM_KV_NEXTN_PREDICT_LAYERS,
|
||||
LLM_KV_NUM_DEEPSTACK_LAYERS,
|
||||
LLM_KV_POOLING_TYPE,
|
||||
LLM_KV_LOGIT_SCALE,
|
||||
LLM_KV_DECODER_START_TOKEN_ID,
|
||||
@@ -188,6 +212,7 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
|
||||
LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
@@ -294,6 +319,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_DENSE_3_OUT,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT_NORM_LFM2, // fix for wrong tensor name
|
||||
LLM_TENSOR_ROPE_FREQS,
|
||||
LLM_TENSOR_ROPE_FACTORS_LONG,
|
||||
LLM_TENSOR_ROPE_FACTORS_SHORT,
|
||||
@@ -308,6 +334,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_ATTN_ROT_EMBD,
|
||||
LLM_TENSOR_ATTN_SINKS,
|
||||
LLM_TENSOR_ATTN_GATE,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
@@ -357,11 +384,13 @@ enum llm_tensor {
|
||||
LLM_TENSOR_SSM_DT,
|
||||
LLM_TENSOR_SSM_DT_NORM,
|
||||
LLM_TENSOR_SSM_A,
|
||||
LLM_TENSOR_SSM_A_NOSCAN, // qwen3next special case with MUL instead of SSM_SCAN
|
||||
LLM_TENSOR_SSM_B_NORM,
|
||||
LLM_TENSOR_SSM_C_NORM,
|
||||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_NORM,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
|
||||
LLM_TENSOR_TIME_MIX_W0,
|
||||
LLM_TENSOR_TIME_MIX_W1,
|
||||
LLM_TENSOR_TIME_MIX_W2,
|
||||
@@ -458,6 +487,11 @@ enum llm_tensor {
|
||||
LLM_TENSOR_SHORTCONV_CONV,
|
||||
LLM_TENSOR_SHORTCONV_INPROJ,
|
||||
LLM_TENSOR_SHORTCONV_OUTPROJ,
|
||||
LLM_TENSOR_VISEXP_ATTN_QKV,
|
||||
LLM_TENSOR_VISEXP_ATTN_OUT,
|
||||
LLM_TENSOR_VISEXP_FFN_GATE,
|
||||
LLM_TENSOR_VISEXP_FFN_DOWN,
|
||||
LLM_TENSOR_VISEXP_FFN_UP,
|
||||
LLM_TENSOR_NEXTN_EH_PROJ,
|
||||
LLM_TENSOR_NEXTN_EMBED_TOKENS,
|
||||
LLM_TENSOR_NEXTN_ENORM,
|
||||
@@ -497,6 +531,10 @@ struct LLM_TN_IMPL {
|
||||
const int bid;
|
||||
const int xid;
|
||||
|
||||
const std::set<llm_tensor> model_tensors;
|
||||
|
||||
LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid);
|
||||
|
||||
std::string str() const;
|
||||
|
||||
operator std::string() const {
|
||||
@@ -518,11 +556,11 @@ struct LLM_TN {
|
||||
llm_arch arch;
|
||||
|
||||
LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
|
||||
return { arch, tensor, suffix, bid, xid };
|
||||
return LLM_TN_IMPL(arch, tensor, suffix, bid, xid);
|
||||
}
|
||||
|
||||
LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const {
|
||||
return { arch, tensor, nullptr, bid, xid };
|
||||
return LLM_TN_IMPL(arch, tensor, nullptr, bid, xid);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
110
llama/llama.cpp/src/llama-batch.cpp
vendored
110
llama/llama.cpp/src/llama-batch.cpp
vendored
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
|
||||
/*.n_seq_tokens =*/ (uint32_t) 1,
|
||||
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
|
||||
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
|
||||
/*.n_pos =*/ n_pos_per_embd,
|
||||
/*.token =*/ batch.token,
|
||||
/*.embd =*/ batch.embd,
|
||||
/*.pos =*/ batch.pos,
|
||||
@@ -251,46 +252,72 @@ bool llama_batch_allocr::init(
|
||||
// consistency checks
|
||||
//
|
||||
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
if (n_pos_per_embd > 1) {
|
||||
// M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
||||
|
||||
if (p0 >= 0) {
|
||||
bool ok = true;
|
||||
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
||||
|
||||
if (batch.token) {
|
||||
if (p0 >= 0 && p0 >= seq_pos_min(s)) {
|
||||
LLAMA_LOG_ERROR(
|
||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
" for M-RoPE, it is required that the position satisfies: X < Y\n",
|
||||
__func__, s, s, p0, s, seq_pos_min(s));
|
||||
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// embedding inputs can have overlapping positions
|
||||
if (p0 >= 0 && p0 > seq_pos_min(s)) {
|
||||
LLAMA_LOG_ERROR(
|
||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
" for M-RoPE, it is required that the position satisfies: X <= Y\n",
|
||||
__func__, s, s, p0, s, seq_pos_min(s));
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
||||
|
||||
if (p0 >= 0) {
|
||||
bool ok = true;
|
||||
|
||||
if (seq_pos_min(s) != p0 + 1) {
|
||||
ok = false;
|
||||
}
|
||||
} else {
|
||||
assert(batch.embd);
|
||||
|
||||
// for embeddings (typically used as vision input), we allow them to have repeating positions
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
|
||||
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
|
||||
ok = false;
|
||||
if (!ok) {
|
||||
LLAMA_LOG_ERROR(
|
||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
||||
__func__, s, s, p0, s, seq_pos_min(s));
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
LLAMA_LOG_ERROR(
|
||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
||||
__func__, s, s, p0, s, seq_pos_min(s));
|
||||
|
||||
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
||||
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
||||
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (memory) {
|
||||
@@ -389,6 +416,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
|
||||
/*.n_seq_tokens =*/ n_seq_tokens,
|
||||
/*.n_seqs =*/ n_seqs,
|
||||
/*.n_seqs_unq =*/ n_seqs,
|
||||
/*.n_pos =*/ n_pos_per_embd,
|
||||
|
||||
/*.token =*/ udata->token.data(),
|
||||
/*.embd =*/ nullptr,
|
||||
@@ -655,10 +683,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
|
||||
auto udata = std::make_shared<llama_ubatch::data_t>();
|
||||
|
||||
const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
|
||||
|
||||
const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
|
||||
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
|
||||
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd;
|
||||
|
||||
udata->token .resize(n_tokens);
|
||||
udata->embd .resize(n_embd_all);
|
||||
@@ -669,6 +695,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
|
||||
udata->output .resize(n_tokens);
|
||||
|
||||
udata->seq_id_data.reserve(n_tokens);
|
||||
|
||||
seq_set_t seq_set_unq;
|
||||
|
||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||
@@ -680,16 +708,23 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
|
||||
}
|
||||
|
||||
for (int j = 0; j < n_pos_cur; ++j) {
|
||||
udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
|
||||
for (size_t j = 0; j < (size_t)n_pos_per_embd; ++j) {
|
||||
// if we are using M-RoPE
|
||||
// if the current batch is text, we need to broadcast the same position across all RoPE sections
|
||||
// otherwise, the input batch is image embeddings, we copy the positions as-is
|
||||
// if we are not using M-RoPE, there is only one position per token (this loop runs only once)
|
||||
size_t src_off = batch.token ? 0 : j*batch.n_tokens;
|
||||
udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]];
|
||||
}
|
||||
|
||||
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
|
||||
udata->seq_id[i] = batch.seq_id[idxs[i]];
|
||||
udata->output[i] = batch.logits[idxs[i]];
|
||||
|
||||
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
|
||||
seq_set_unq.set(udata->seq_id[i][s]);
|
||||
const llama_seq_id seq_id = batch.seq_id[idxs[i]][s];
|
||||
|
||||
udata->seq_id_data.push_back(seq_id);
|
||||
seq_set_unq.set(seq_id);
|
||||
}
|
||||
|
||||
if (udata->output[i]) {
|
||||
@@ -697,6 +732,12 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
}
|
||||
}
|
||||
|
||||
llama_seq_id * seq_id_ptr = udata->seq_id_data.data();
|
||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||
udata->seq_id[i] = seq_id_ptr;
|
||||
seq_id_ptr += udata->n_seq_id[i];
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
udata->seq_idx[s] = udata->seq_id_unq.size();
|
||||
@@ -710,6 +751,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
/*.n_seq_tokens =*/ n_tokens/n_seqs,
|
||||
/*.n_seqs =*/ n_seqs,
|
||||
/*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
|
||||
/*.n_pos =*/ n_pos_per_embd,
|
||||
|
||||
/*.token =*/ batch.token ? udata->token.data() : nullptr,
|
||||
/*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
|
||||
|
||||
19
llama/llama.cpp/src/llama-batch.h
vendored
19
llama/llama.cpp/src/llama-batch.h
vendored
@@ -17,6 +17,16 @@ struct llama_ubatch {
|
||||
return b_equal_seqs != 0;
|
||||
}
|
||||
|
||||
// typical for M-RoPE cases:
|
||||
// 0 - sequantial position of the tokens/embeddings in the sequence
|
||||
// 1 - y position in the image
|
||||
// 2 - x position in the image
|
||||
// 3 - other
|
||||
bool is_pos_2d() const {
|
||||
// TODO @ngxson : we may need to check for model arch when more models use >1 positions
|
||||
return n_pos >= 3;
|
||||
}
|
||||
|
||||
uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
|
||||
// otherwise address sanitizer complains
|
||||
// TODO: whole_seqs for embeddings?
|
||||
@@ -25,6 +35,7 @@ struct llama_ubatch {
|
||||
uint32_t n_seq_tokens; // tokens per sequence set
|
||||
uint32_t n_seqs; // sequence sets in the ubatch
|
||||
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
|
||||
uint32_t n_pos; // number of position inputs for each token/embedding
|
||||
|
||||
// seq_id_unq: unique sequence ids in the ubatch
|
||||
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
|
||||
@@ -33,7 +44,7 @@ struct llama_ubatch {
|
||||
// // size | idx | val
|
||||
llama_token * token; // [n_tokens] | i | id, token
|
||||
float * embd; // [n_embd, n_tokens] | i | embd
|
||||
llama_pos * pos; // [n_tokens] | i | pos
|
||||
llama_pos * pos; // [n_tokens*n_pos] | i | pos
|
||||
int32_t * n_seq_id; // [n_tokens] | i | -
|
||||
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
|
||||
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
|
||||
@@ -45,13 +56,15 @@ struct llama_ubatch {
|
||||
std::vector<float> embd;
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<llama_seq_id *> seq_id; // these point into the seq_id_data below
|
||||
std::vector<llama_seq_id> seq_id_unq;
|
||||
std::vector<int32_t> seq_idx;
|
||||
std::vector<int8_t> output;
|
||||
|
||||
std::vector<llama_seq_id> seq_id_data;
|
||||
};
|
||||
|
||||
// the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
|
||||
// the llama_ubatch pointers above point to this data if set. otherwise - point to external non-owning data
|
||||
std::shared_ptr<data_t> data;
|
||||
};
|
||||
|
||||
|
||||
32
llama/llama.cpp/src/llama-chat.cpp
vendored
32
llama/llama.cpp/src/llama-chat.cpp
vendored
@@ -73,6 +73,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
|
||||
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
|
||||
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
|
||||
{ "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
@@ -213,6 +214,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_SEED_OSS;
|
||||
} else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
|
||||
return LLM_CHAT_TEMPLATE_GROK_2;
|
||||
} else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) {
|
||||
return LLM_CHAT_TEMPLATE_PANGU_EMBED;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
@@ -813,6 +816,35 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "Assistant:";
|
||||
}
|
||||
}else if (tmpl == LLM_CHAT_TEMPLATE_PANGU_EMBED) {
|
||||
// [unused9]系统:xxx[unused10]
|
||||
// [unused9]用户:xxx[unused10]
|
||||
// [unused9]助手:xxx[unused10]
|
||||
// ...
|
||||
for (size_t i = 0; i < chat.size(); ++i) {
|
||||
const auto & msg = chat[i];
|
||||
const std::string & role = msg->role;
|
||||
const std::string & content = msg->content;
|
||||
|
||||
if (i == 0 && role != "system") {
|
||||
ss << "[unused9]系统:[unused10]";
|
||||
}
|
||||
|
||||
if (role == "system") {
|
||||
ss << "[unused9]系统:" << content << "[unused10]";
|
||||
} else if (role == "user") {
|
||||
ss << "[unused9]用户:" << content << "[unused10]";
|
||||
} else if (role == "assistant") {
|
||||
ss << "[unused9]助手:" << content << "[unused10]";
|
||||
} else if (role == "tool") {
|
||||
ss << "[unused9]工具:" << content << "[unused10]";
|
||||
} else if (role == "function") {
|
||||
ss << "[unused9]方法:" << content << "[unused10]";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "[unused9]助手:";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
||||
1
llama/llama.cpp/src/llama-chat.h
vendored
1
llama/llama.cpp/src/llama-chat.h
vendored
@@ -53,6 +53,7 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_KIMI_K2,
|
||||
LLM_CHAT_TEMPLATE_SEED_OSS,
|
||||
LLM_CHAT_TEMPLATE_GROK_2,
|
||||
LLM_CHAT_TEMPLATE_PANGU_EMBED,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
194
llama/llama.cpp/src/llama-context.cpp
vendored
194
llama/llama.cpp/src/llama-context.cpp
vendored
@@ -1,5 +1,6 @@
|
||||
#include "llama-context.h"
|
||||
|
||||
#include "llama-arch.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-io.h"
|
||||
@@ -8,6 +9,7 @@
|
||||
#include "llama-model.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <stdexcept>
|
||||
@@ -21,6 +23,8 @@ llama_context::llama_context(
|
||||
llama_context_params params) :
|
||||
model(model),
|
||||
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
|
||||
// TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
|
||||
// may need to be backend-dependent
|
||||
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
||||
|
||||
t_start_us = model.t_start_us;
|
||||
@@ -69,6 +73,43 @@ llama_context::llama_context(
|
||||
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
if (cparams.yarn_ext_factor != 0) {
|
||||
static auto get_mscale = [](float scale, float mscale) {
|
||||
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
|
||||
};
|
||||
|
||||
const float factor = 1.0f / cparams.rope_freq_scale;
|
||||
|
||||
// ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
|
||||
if (hparams.rope_yarn_log_mul != 0.0f) {
|
||||
// note: here we assume `mscale == 1.0f`
|
||||
// TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
|
||||
float mscale = 1.0f;
|
||||
const float mscale_all_dims = hparams.rope_yarn_log_mul;
|
||||
|
||||
// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||
// special-case DEEPSEEK v2:
|
||||
// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
|
||||
if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
|
||||
mscale = mscale_all_dims;
|
||||
}
|
||||
|
||||
cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
|
||||
|
||||
LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
|
||||
__func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
|
||||
} else {
|
||||
cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
|
||||
}
|
||||
|
||||
// when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
|
||||
// https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
|
||||
//
|
||||
// ref: https://github.com/ggml-org/llama.cpp/discussions/7416
|
||||
// https://github.com/ggml-org/llama.cpp/pull/17945
|
||||
cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
|
||||
}
|
||||
|
||||
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
|
||||
|
||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
||||
@@ -90,14 +131,6 @@ llama_context::llama_context(
|
||||
// with causal attention, the batch size is limited by the context size
|
||||
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
||||
|
||||
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
||||
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
||||
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
|
||||
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
||||
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
||||
cparams.n_batch = GGML_KQ_MASK_PAD;
|
||||
}
|
||||
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
||||
|
||||
cparams.op_offload = params.op_offload;
|
||||
@@ -112,11 +145,28 @@ llama_context::llama_context(
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
|
||||
|
||||
if (cparams.kv_unified) {
|
||||
cparams.n_ctx_seq = cparams.n_ctx;
|
||||
} else {
|
||||
cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
|
||||
|
||||
if (cparams.n_ctx_seq == 0) {
|
||||
throw std::runtime_error("n_ctx_seq == 0");
|
||||
}
|
||||
|
||||
if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
|
||||
cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
|
||||
LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
|
||||
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
||||
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
|
||||
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
|
||||
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
||||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
|
||||
@@ -125,14 +175,14 @@ llama_context::llama_context(
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
if (n_ctx_per_seq < hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
||||
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
||||
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
||||
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
if (n_ctx_per_seq > hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
||||
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
||||
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
||||
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
if (!hparams.vocab_only) {
|
||||
@@ -208,6 +258,7 @@ llama_context::llama_context(
|
||||
|
||||
backend_buft.clear();
|
||||
backend_ptrs.clear();
|
||||
backend_buf_exp_size.clear();
|
||||
|
||||
for (auto & backend : backends) {
|
||||
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
|
||||
@@ -224,11 +275,15 @@ llama_context::llama_context(
|
||||
|
||||
backend_buft.push_back(buft);
|
||||
backend_ptrs.push_back(backend.get());
|
||||
backend_buf_exp_size.push_back(0);
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
|
||||
|
||||
const size_t max_nodes = this->graph_max_nodes();
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
const size_t max_nodes = this->graph_max_nodes(n_tokens);
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
||||
|
||||
@@ -268,9 +323,7 @@ llama_context::llama_context(
|
||||
if (pipeline_parallel) {
|
||||
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
|
||||
}
|
||||
}
|
||||
|
||||
if (!hparams.vocab_only) {
|
||||
llama_memory_context_ptr mctx;
|
||||
if (memory) {
|
||||
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
|
||||
@@ -282,9 +335,6 @@ llama_context::llama_context(
|
||||
|
||||
cross.v_embd.clear();
|
||||
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
// avoid reserving graphs with zero outputs - assume one output per sequence
|
||||
n_outputs = n_seqs;
|
||||
|
||||
@@ -341,9 +391,17 @@ llama_context::llama_context(
|
||||
|
||||
// reserve pp (prompt processing) graph first so that buffers are only allocated once
|
||||
{
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
|
||||
model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
if (pipeline_parallel) {
|
||||
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
|
||||
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
|
||||
gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
}
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
}
|
||||
|
||||
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
|
||||
@@ -352,7 +410,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve with tg (token generation) graph to get the number of splits and nodes
|
||||
{
|
||||
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
|
||||
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||
}
|
||||
@@ -367,7 +425,7 @@ llama_context::llama_context(
|
||||
//
|
||||
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
|
||||
//
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
@@ -376,11 +434,13 @@ llama_context::llama_context(
|
||||
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
||||
ggml_backend_t backend = backend_ptrs[i];
|
||||
ggml_backend_buffer_type_t buft = backend_buft[i];
|
||||
size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
||||
if (size > 1) {
|
||||
if (!model.hparams.no_alloc) {
|
||||
backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
||||
}
|
||||
if (backend_buf_exp_size[i] > 1) {
|
||||
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
||||
ggml_backend_buft_name(buft),
|
||||
size / 1024.0 / 1024.0);
|
||||
backend_buf_exp_size[i] / 1024.0 / 1024.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -399,6 +459,23 @@ llama_context::llama_context(
|
||||
}
|
||||
|
||||
llama_context::~llama_context() {
|
||||
// FIXME this currently results in a use-after-free bug if the model is freed before the context
|
||||
// if (!model.hparams.no_alloc) {
|
||||
// for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
||||
// ggml_backend_t backend = backend_ptrs[i];
|
||||
// ggml_backend_buffer_type_t buft = backend_buft[i];
|
||||
|
||||
// const size_t size_exp = backend_buf_exp_size[i];
|
||||
// const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
||||
// if (size_exp == size_act) {
|
||||
// LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
|
||||
// __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
||||
// } else {
|
||||
// LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
|
||||
// __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
ggml_opt_free(opt_ctx);
|
||||
}
|
||||
|
||||
@@ -448,8 +525,8 @@ uint32_t llama_context::n_ctx() const {
|
||||
return cparams.n_ctx;
|
||||
}
|
||||
|
||||
uint32_t llama_context::n_ctx_per_seq() const {
|
||||
return cparams.n_ctx / cparams.n_seq_max;
|
||||
uint32_t llama_context::n_ctx_seq() const {
|
||||
return cparams.n_ctx_seq;
|
||||
}
|
||||
|
||||
uint32_t llama_context::n_batch() const {
|
||||
@@ -518,7 +595,7 @@ bool llama_context::memory_update(bool optimize) {
|
||||
throw std::runtime_error("failed to initialize memory context");
|
||||
}
|
||||
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
@@ -803,7 +880,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_embd = hparams.n_embd_inp();
|
||||
const int64_t n_vocab = model.vocab.n_tokens();
|
||||
|
||||
// note: during encode, we always pass the full sequence starting from pos = 0
|
||||
@@ -972,7 +1049,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int64_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_embd = hparams.n_embd_inp();
|
||||
|
||||
const bool output_all = false;
|
||||
|
||||
@@ -1223,7 +1300,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
|
||||
// make the outputs have the same order they had in the user-provided batch
|
||||
// note: this is mostly relevant for recurrent models atm
|
||||
if (!sorted_output) {
|
||||
if (!sorted_output && n_outputs > 1) {
|
||||
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
||||
|
||||
// TODO: is there something more efficient which also minimizes swaps?
|
||||
@@ -1300,6 +1377,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
|
||||
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
||||
#endif
|
||||
synchronize();
|
||||
buf_output = nullptr;
|
||||
logits = nullptr;
|
||||
embd = nullptr;
|
||||
@@ -1360,7 +1438,10 @@ void llama_context::output_reorder() {
|
||||
// graph
|
||||
//
|
||||
|
||||
uint32_t llama_context::graph_max_nodes() const {
|
||||
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
|
||||
if (model.arch == LLM_ARCH_QWEN3NEXT) {
|
||||
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
|
||||
}
|
||||
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
|
||||
}
|
||||
|
||||
@@ -1368,7 +1449,8 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
|
||||
return static_cast<llm_graph_result *>(gf_res_reserve.get());
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
|
||||
ggml_cgraph * llama_context::graph_reserve(
|
||||
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) {
|
||||
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
GGML_ASSERT(n_outputs >= 1);
|
||||
|
||||
@@ -1405,8 +1487,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
||||
|
||||
// initialize scheduler with the specified graph
|
||||
if (split_only) {
|
||||
ggml_backend_sched_split_graph(sched.get(), gf);
|
||||
if (sizes) {
|
||||
ggml_backend_sched_reserve_size(sched.get(), gf, sizes);
|
||||
} else {
|
||||
ggml_backend_sched_split_graph(sched.get(), gf);
|
||||
}
|
||||
} else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
GGML_ASSERT(!sizes);
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -2028,15 +2115,26 @@ void llama_context::perf_reset() {
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
|
||||
std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
|
||||
for (const auto & buft_size : model.memory_breakdown()) {
|
||||
ret[buft_size.first].model += buft_size.second;
|
||||
for (const auto & [buft, size] : model.memory_breakdown()) {
|
||||
ret[buft].model += size;
|
||||
}
|
||||
for (const auto & buft_size : memory->memory_breakdown()) {
|
||||
ret[buft_size.first].context += buft_size.second;
|
||||
if (memory) {
|
||||
for (const auto & [buft, size] : memory->memory_breakdown()) {
|
||||
ret[buft].context += size;
|
||||
}
|
||||
}
|
||||
for (const auto & backend_ptr : backends) {
|
||||
ggml_backend_t backend = backend_ptr.get();
|
||||
ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
||||
if (model.hparams.no_alloc) {
|
||||
for (size_t i = 0; i < backends.size(); ++i) {
|
||||
ggml_backend_t backend = backends[i].get();
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
|
||||
ret[buft].compute += backend_buf_exp_size[i];
|
||||
}
|
||||
} else {
|
||||
for (const auto & backend_ptr : backends) {
|
||||
ggml_backend_t backend = backend_ptr.get();
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
|
||||
ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@@ -2129,7 +2227,7 @@ void llama_context::opt_epoch_iter(
|
||||
batch.logits [pos_batch] = true;
|
||||
}
|
||||
|
||||
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
||||
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return;
|
||||
}
|
||||
@@ -2377,6 +2475,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
|
||||
return ctx->n_ctx();
|
||||
}
|
||||
|
||||
uint32_t llama_n_ctx_seq(const llama_context * ctx) {
|
||||
return ctx->n_ctx_seq();
|
||||
}
|
||||
|
||||
uint32_t llama_n_batch(const llama_context * ctx) {
|
||||
return ctx->n_batch();
|
||||
}
|
||||
|
||||
22
llama/llama.cpp/src/llama-context.h
vendored
22
llama/llama.cpp/src/llama-context.h
vendored
@@ -26,6 +26,10 @@ struct llama_memory_breakdown_data {
|
||||
size_t model = 0; // memory allocated for the model
|
||||
size_t context = 0; // memory allocated for the context
|
||||
size_t compute = 0; // memory allocated for temporary compute buffers
|
||||
|
||||
size_t total() const {
|
||||
return model + context + compute;
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_context {
|
||||
@@ -43,11 +47,11 @@ struct llama_context {
|
||||
|
||||
ggml_backend_sched_t get_sched() const;
|
||||
|
||||
uint32_t n_ctx() const;
|
||||
uint32_t n_ctx_per_seq() const;
|
||||
uint32_t n_batch() const;
|
||||
uint32_t n_ubatch() const;
|
||||
uint32_t n_seq_max() const;
|
||||
uint32_t n_ctx() const;
|
||||
uint32_t n_ctx_seq() const;
|
||||
uint32_t n_batch() const;
|
||||
uint32_t n_ubatch() const;
|
||||
uint32_t n_seq_max() const;
|
||||
|
||||
uint32_t n_threads() const;
|
||||
uint32_t n_threads_batch() const;
|
||||
@@ -197,7 +201,7 @@ private:
|
||||
//
|
||||
|
||||
public:
|
||||
uint32_t graph_max_nodes() const;
|
||||
uint32_t graph_max_nodes(uint32_t n_tokens) const;
|
||||
|
||||
// can reuse the llm_graph_result instance of the context (for example to update a memory module)
|
||||
llm_graph_result * get_gf_res_reserve() const;
|
||||
@@ -206,7 +210,8 @@ public:
|
||||
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
||||
|
||||
// reserve a graph with a dummy ubatch of the specified size
|
||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
|
||||
ggml_cgraph * graph_reserve(
|
||||
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);
|
||||
|
||||
private:
|
||||
llm_graph_params graph_params(
|
||||
@@ -281,9 +286,10 @@ private:
|
||||
|
||||
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
||||
|
||||
// buffer types used for the compute buffer of each backend
|
||||
// pointers and buffer types used for the compute buffer of each backend
|
||||
std::vector<ggml_backend_t> backend_ptrs;
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
std::vector<size_t> backend_buf_exp_size; // expected buffer sizes
|
||||
|
||||
llm_graph_result_ptr gf_res_prev;
|
||||
llm_graph_result_ptr gf_res_reserve;
|
||||
|
||||
1
llama/llama.cpp/src/llama-cparams.h
vendored
1
llama/llama.cpp/src/llama-cparams.h
vendored
@@ -8,6 +8,7 @@
|
||||
|
||||
struct llama_cparams {
|
||||
uint32_t n_ctx; // context size used during inference
|
||||
uint32_t n_ctx_seq; // context for a single sequence
|
||||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_seq_max;
|
||||
|
||||
291
llama/llama.cpp/src/llama-grammar.cpp
vendored
291
llama/llama.cpp/src/llama-grammar.cpp
vendored
@@ -6,8 +6,10 @@
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
|
||||
#define MAX_REPETITION_THRESHOLD 2000
|
||||
//
|
||||
// helpers
|
||||
//
|
||||
@@ -179,6 +181,52 @@ static std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
|
||||
static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
|
||||
const char * pos = src;
|
||||
if (*pos != '<') {
|
||||
throw std::runtime_error(std::string("expecting '<' at ") + pos);
|
||||
}
|
||||
pos++;
|
||||
|
||||
// Parse <[id]>
|
||||
if (*pos == '[') {
|
||||
pos++;
|
||||
const char * int_end = parse_int(pos);
|
||||
uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
|
||||
pos = int_end;
|
||||
if (*pos != ']') {
|
||||
throw std::runtime_error(std::string("expecting ']' at ") + pos);
|
||||
}
|
||||
pos++;
|
||||
if (*pos != '>') {
|
||||
throw std::runtime_error(std::string("expecting '>' at ") + pos);
|
||||
}
|
||||
pos++;
|
||||
return std::make_pair(token_id, pos);
|
||||
}
|
||||
|
||||
if (vocab == nullptr) {
|
||||
throw std::runtime_error(std::string("no vocab to parse token at ") + src);
|
||||
}
|
||||
|
||||
// Parse <token> and tokenize to obtain the token id
|
||||
while (*pos != 0 && *pos != '>') {
|
||||
pos++;
|
||||
}
|
||||
if (*pos != '>') {
|
||||
throw std::runtime_error(std::string("expecting '>' at ") + pos);
|
||||
}
|
||||
pos++;
|
||||
|
||||
llama_token tokens[2];
|
||||
int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
|
||||
if (n_tokens != 1) {
|
||||
// must tokenize to exactly 1 token
|
||||
throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
|
||||
}
|
||||
return std::make_pair(tokens[0], pos);
|
||||
}
|
||||
|
||||
static void print_grammar_char(FILE * file, uint32_t c) {
|
||||
if (0x20 <= c && c <= 0x7f) {
|
||||
fprintf(file, "%c", static_cast<char>(c));
|
||||
@@ -210,6 +258,8 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
|
||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
|
||||
case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break;
|
||||
case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break;
|
||||
}
|
||||
switch (elem.type) {
|
||||
case LLAMA_GRETYPE_END:
|
||||
@@ -226,6 +276,17 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
|
||||
print_grammar_char(file, elem.value);
|
||||
fprintf(file, "\") ");
|
||||
break;
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
fprintf(file, "<[");
|
||||
fprintf(file, "%u", elem.value);
|
||||
fprintf(file, "]> ");
|
||||
break;
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
fprintf(file, "!");
|
||||
fprintf(file, "<[");
|
||||
fprintf(file, "%u", elem.value);
|
||||
fprintf(file, "]> ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
@@ -282,6 +343,17 @@ static void print_rule(
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
fprintf(file, ".");
|
||||
break;
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
fprintf(file, "<[");
|
||||
fprintf(file, "%u", elem.value);
|
||||
fprintf(file, "]> ");
|
||||
break;
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
fprintf(file, "!");
|
||||
fprintf(file, "<[");
|
||||
fprintf(file, "%u", elem.value);
|
||||
fprintf(file, "]> ");
|
||||
break;
|
||||
}
|
||||
if (is_char_element(elem)) {
|
||||
switch (rule[i + 1].type) {
|
||||
@@ -345,8 +417,10 @@ const char * llama_grammar_parser::parse_sequence(
|
||||
size_t last_sym_start = rule.size();
|
||||
const char * pos = src;
|
||||
|
||||
auto handle_repetitions = [&](int min_times, int max_times) {
|
||||
|
||||
// use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
|
||||
// (though it's technically the same as -1 now)
|
||||
auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
|
||||
bool no_max = max_times == UINT64_MAX;
|
||||
if (last_sym_start == rule.size()) {
|
||||
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
||||
}
|
||||
@@ -373,20 +447,20 @@ const char * llama_grammar_parser::parse_sequence(
|
||||
rule.resize(last_sym_start);
|
||||
} else {
|
||||
// Repeat the previous elements (min_times - 1) times
|
||||
for (int i = 1; i < min_times; i++) {
|
||||
for (uint64_t i = 1; i < min_times; i++) {
|
||||
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t last_rec_rule_id = 0;
|
||||
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
|
||||
auto n_opt = no_max ? 1 : max_times - min_times;
|
||||
|
||||
llama_grammar_rule rec_rule(prev_rule);
|
||||
for (int i = 0; i < n_opt; i++) {
|
||||
for (uint64_t i = 0; i < n_opt; i++) {
|
||||
rec_rule.resize(prev_rule.size());
|
||||
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
||||
if (i > 0 || max_times < 0) {
|
||||
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
|
||||
if (i > 0 || no_max) {
|
||||
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
|
||||
}
|
||||
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||
@@ -440,6 +514,17 @@ const char * llama_grammar_parser::parse_sequence(
|
||||
}
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '<' || *pos == '!') { // token
|
||||
auto type = LLAMA_GRETYPE_TOKEN;
|
||||
if (*pos == '!') { // token inverse
|
||||
type = LLAMA_GRETYPE_TOKEN_NOT;
|
||||
pos++;
|
||||
}
|
||||
auto token_pair = parse_token(vocab, pos);
|
||||
const char * token_end = token_pair.second;
|
||||
last_sym_start = rule.size();
|
||||
rule.push_back({type, token_pair.first});
|
||||
pos = parse_space(token_end, is_nested);
|
||||
} else if (is_word_char(*pos)) { // rule reference
|
||||
const char * name_end = parse_name(pos);
|
||||
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
||||
@@ -478,10 +563,10 @@ const char * llama_grammar_parser::parse_sequence(
|
||||
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
||||
}
|
||||
const char * int_end = parse_int(pos);
|
||||
int min_times = std::stoul(std::string(pos, int_end - pos));
|
||||
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
|
||||
int max_times = -1;
|
||||
uint64_t max_times = UINT64_MAX; // default: no max limit
|
||||
|
||||
if (*pos == '}') {
|
||||
max_times = min_times;
|
||||
@@ -502,6 +587,10 @@ const char * llama_grammar_parser::parse_sequence(
|
||||
} else {
|
||||
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
||||
}
|
||||
bool has_max = max_times != UINT64_MAX;
|
||||
if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
|
||||
throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
|
||||
}
|
||||
handle_repetitions(min_times, max_times);
|
||||
} else {
|
||||
break;
|
||||
@@ -683,6 +772,21 @@ static bool llama_grammar_match_partial_char(
|
||||
return !is_positive_char;
|
||||
}
|
||||
|
||||
// returns true iff token matches the rule at pos (regular or inverse)
|
||||
// asserts that pos is pointing to a token element
|
||||
static bool llama_grammar_match_token(
|
||||
const llama_grammar_element * pos,
|
||||
const llama_token token) {
|
||||
GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
|
||||
if (pos->type == LLAMA_GRETYPE_TOKEN) {
|
||||
return pos->value == static_cast<uint32_t>(token);
|
||||
}
|
||||
if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||
return pos->value != static_cast<uint32_t>(token);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// transforms a grammar pushdown stack into N possible stacks, all ending
|
||||
// at a character range (terminal element)
|
||||
static void llama_grammar_advance_stack(
|
||||
@@ -730,6 +834,8 @@ static void llama_grammar_advance_stack(
|
||||
case LLAMA_GRETYPE_CHAR:
|
||||
case LLAMA_GRETYPE_CHAR_NOT:
|
||||
case LLAMA_GRETYPE_CHAR_ANY:
|
||||
case LLAMA_GRETYPE_TOKEN:
|
||||
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||
// only add the stack if it's not a duplicate of one we already have
|
||||
new_stacks.emplace_back(stack);
|
||||
@@ -823,26 +929,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
|
||||
return grammar->stacks;
|
||||
}
|
||||
|
||||
static void llama_grammar_accept_chr(
|
||||
struct llama_grammar & grammar,
|
||||
const llama_grammar_stack & stack,
|
||||
uint32_t chr,
|
||||
llama_grammar_stacks & new_stacks) {
|
||||
if (stack.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_grammar_element * pos = stack.back();
|
||||
|
||||
// ignore if this turns into a token
|
||||
if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto match = llama_grammar_match_char(pos, chr);
|
||||
if (match.first) {
|
||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(match.second)) {
|
||||
new_stack.push_back(match.second);
|
||||
}
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
|
||||
llama_grammar_stacks stacks_new;
|
||||
stacks_new.reserve(grammar->stacks.size());
|
||||
|
||||
for (const auto & stack : grammar->stacks) {
|
||||
if (stack.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto match = llama_grammar_match_char(stack.back(), chr);
|
||||
if (match.first) {
|
||||
const llama_grammar_element * pos = match.second;
|
||||
|
||||
// update top of stack to next element, if any
|
||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||
new_stack.push_back(pos);
|
||||
}
|
||||
llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
|
||||
}
|
||||
llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
|
||||
}
|
||||
|
||||
grammar->stacks = std::move(stacks_new);
|
||||
@@ -867,6 +985,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||
|
||||
const llama_grammar_element * stack_pos = stack.back();
|
||||
|
||||
// if the top of the stack is a token rule, then we only need to check the token id
|
||||
if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||
for (const auto & tok : candidates) {
|
||||
if (*tok.code_points == 0) {
|
||||
// reached the end of a token consumed by char rules, reject iff it ended
|
||||
// in a partial response
|
||||
if (tok.partial_utf8.n_remain != 0) {
|
||||
rejects.push_back(tok);
|
||||
}
|
||||
} else if (!llama_grammar_match_token(stack_pos, tok.id)) {
|
||||
rejects.push_back(tok);
|
||||
}
|
||||
}
|
||||
return rejects;
|
||||
}
|
||||
|
||||
llama_grammar_candidates next_candidates;
|
||||
next_candidates.reserve(candidates.size());
|
||||
|
||||
@@ -879,7 +1013,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||
rejects.push_back(tok);
|
||||
}
|
||||
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
|
||||
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
|
||||
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
|
||||
} else {
|
||||
rejects.push_back(tok);
|
||||
}
|
||||
@@ -897,7 +1031,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||
|
||||
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
||||
for (const auto & tok : next_rejects) {
|
||||
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
|
||||
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
|
||||
}
|
||||
|
||||
return rejects;
|
||||
@@ -966,12 +1100,13 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
ollama_vocab,
|
||||
std::move(vec_rules),
|
||||
std::move(stacks),
|
||||
/* .partial_utf8 = */ {},
|
||||
/* .lazy =*/ false,
|
||||
/* .awaiting_trigger = */ false,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .trigger_tokens = */ {},
|
||||
/* .trigger_patterns = */ {},
|
||||
/* .partial_utf8 = */ {},
|
||||
/* .lazy = */ false,
|
||||
/* .awaiting_trigger = */ false,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .trigger_buffer_positions = */ {},
|
||||
/* .trigger_tokens = */ {},
|
||||
/* .trigger_patterns = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -985,7 +1120,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
size_t num_trigger_patterns,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens) {
|
||||
llama_grammar_parser parser;
|
||||
llama_grammar_parser parser(vocab);
|
||||
|
||||
// if there is a grammar, parse it
|
||||
// rules will be empty (default) if there are parse errors
|
||||
@@ -1073,10 +1208,11 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
ollama_vocab,
|
||||
std::move(vec_rules),
|
||||
std::move(stacks),
|
||||
/* .partial_utf8 = */ {},
|
||||
/* .lazy = */ lazy,
|
||||
/* .awaiting_trigger = */ lazy,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .partial_utf8 = */ {},
|
||||
/* .lazy = */ lazy,
|
||||
/* .awaiting_trigger = */ lazy,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .trigger_buffer_positions = */ {},
|
||||
std::move(vec_trigger_tokens),
|
||||
std::move(vec_trigger_patterns),
|
||||
};
|
||||
@@ -1100,6 +1236,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
||||
grammar.lazy,
|
||||
grammar.awaiting_trigger,
|
||||
grammar.trigger_buffer,
|
||||
grammar.trigger_buffer_positions,
|
||||
grammar.trigger_tokens,
|
||||
grammar.trigger_patterns,
|
||||
};
|
||||
@@ -1156,7 +1293,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
} else {
|
||||
candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
|
||||
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
|
||||
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1176,10 +1313,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||
grammar.awaiting_trigger = false;
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
llama_grammar_accept_token(grammar, token, piece);
|
||||
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
|
||||
return;
|
||||
} else {
|
||||
auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
|
||||
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
|
||||
grammar.trigger_buffer += piece;
|
||||
|
||||
std::smatch match;
|
||||
@@ -1197,10 +1336,23 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
if (start == std::string::npos) {
|
||||
start = match.position(0);
|
||||
}
|
||||
|
||||
// replay tokens that overlap with [start, end)
|
||||
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
|
||||
auto [tok_start, tok_end] = tok_pos;
|
||||
if (tok_end <= start) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
|
||||
size_t piece_len = tok_end - piece_start;
|
||||
auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
|
||||
llama_grammar_accept_token(grammar, tok, tok_piece);
|
||||
}
|
||||
|
||||
auto constrained_str = grammar.trigger_buffer.substr(start);
|
||||
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, constrained_str);
|
||||
grammar.trigger_buffer_positions.clear();
|
||||
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
|
||||
return;
|
||||
}
|
||||
@@ -1220,7 +1372,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
|
||||
}
|
||||
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
llama_grammar_accept_token(grammar, token, piece);
|
||||
}
|
||||
|
||||
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
|
||||
@@ -1238,6 +1390,61 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
|
||||
}
|
||||
}
|
||||
|
||||
void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
|
||||
// Note terminating 0 in decoded string
|
||||
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
||||
const auto & code_points = decoded.first;
|
||||
|
||||
llama_grammar_stacks stacks_new;
|
||||
stacks_new.reserve(grammar.stacks.size());
|
||||
|
||||
for (const auto & stack : grammar.stacks) {
|
||||
if (stack.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_grammar_element * pos = stack.back();
|
||||
|
||||
if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||
if (llama_grammar_match_token(pos, token)) {
|
||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||
new_stack.push_back(pos + 1);
|
||||
}
|
||||
llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
|
||||
}
|
||||
} else {
|
||||
llama_grammar_stacks current_stacks = {stack};
|
||||
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
llama_grammar_stacks next_stacks;
|
||||
|
||||
for (const auto & cur_stack : current_stacks) {
|
||||
llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
|
||||
}
|
||||
|
||||
current_stacks = std::move(next_stacks);
|
||||
if (current_stacks.empty()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto & surviving_stack : current_stacks) {
|
||||
if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
|
||||
stacks_new.emplace_back(surviving_stack);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
grammar.stacks = std::move(stacks_new);
|
||||
grammar.partial_utf8 = decoded.second;
|
||||
|
||||
if (grammar.stacks.empty()) {
|
||||
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
|
||||
try {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user