mirror of
https://github.com/ollama/ollama.git
synced 2026-01-02 04:29:51 -05:00
Compare commits
3 Commits
implement-
...
hoyyeva/co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
def9651b55 | ||
|
|
dd63aac6bd | ||
|
|
78830f6e9d |
4
.gitattributes
vendored
4
.gitattributes
vendored
@@ -15,12 +15,8 @@ ml/backend/**/*.cu linguist-vendored
|
||||
ml/backend/**/*.cuh linguist-vendored
|
||||
ml/backend/**/*.m linguist-vendored
|
||||
ml/backend/**/*.metal linguist-vendored
|
||||
ml/backend/**/*.comp linguist-vendored
|
||||
ml/backend/**/*.glsl linguist-vendored
|
||||
ml/backend/**/CMakeLists.txt linguist-vendored
|
||||
|
||||
app/webview linguist-vendored
|
||||
|
||||
llama/build-info.cpp linguist-generated
|
||||
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated
|
||||
|
||||
|
||||
18
.github/workflows/release.yaml
vendored
18
.github/workflows/release.yaml
vendored
@@ -16,15 +16,13 @@ jobs:
|
||||
outputs:
|
||||
GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }}
|
||||
VERSION: ${{ steps.goflags.outputs.VERSION }}
|
||||
vendorsha: ${{ steps.changes.outputs.vendorsha }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set environment
|
||||
id: goflags
|
||||
run: |
|
||||
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" | tee -a $GITHUB_OUTPUT
|
||||
echo VERSION="${GITHUB_REF_NAME#v}" | tee -a $GITHUB_OUTPUT
|
||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
|
||||
echo VERSION="${GITHUB_REF_NAME#v}" >>$GITHUB_OUTPUT
|
||||
|
||||
darwin-build:
|
||||
runs-on: macos-14-xlarge
|
||||
@@ -55,9 +53,6 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache-dependency-path: |
|
||||
go.sum
|
||||
Makefile.sync
|
||||
- run: |
|
||||
./scripts/build_darwin.sh
|
||||
- name: Log build results
|
||||
@@ -190,7 +185,7 @@ jobs:
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ github.workspace }}\.ccache
|
||||
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}-${{ needs.setup-environment.outputs.vendorsha }}
|
||||
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
|
||||
- name: Build target "${{ matrix.preset }}"
|
||||
run: |
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
@@ -254,9 +249,6 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache-dependency-path: |
|
||||
go.sum
|
||||
Makefile.sync
|
||||
- name: Verify gcc is actually clang
|
||||
run: |
|
||||
$ErrorActionPreference='Continue'
|
||||
@@ -310,9 +302,6 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache-dependency-path: |
|
||||
go.sum
|
||||
Makefile.sync
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: depends-windows*
|
||||
@@ -377,7 +366,6 @@ jobs:
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
|
||||
18
.github/workflows/test.yaml
vendored
18
.github/workflows/test.yaml
vendored
@@ -22,7 +22,6 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
changed: ${{ steps.changes.outputs.changed }}
|
||||
vendorsha: ${{ steps.changes.outputs.vendorsha }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -38,7 +37,6 @@ jobs:
|
||||
}
|
||||
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||
|
||||
linux:
|
||||
needs: [changes]
|
||||
@@ -85,7 +83,7 @@ jobs:
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
path: /github/home/.cache/ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
|
||||
- run: |
|
||||
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
|
||||
cmake --build --preset ${{ matrix.preset }} --parallel
|
||||
@@ -180,7 +178,7 @@ jobs:
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ github.workspace }}\.ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
|
||||
- run: |
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
@@ -208,9 +206,6 @@ jobs:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache-dependency-path: |
|
||||
go.sum
|
||||
Makefile.sync
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
@@ -231,9 +226,12 @@ jobs:
|
||||
if: always()
|
||||
run: go test -count=1 -benchtime=1x ./...
|
||||
|
||||
- uses: golangci/golangci-lint-action@v9
|
||||
# TODO(bmizerany): replace this heavy tool with just the
|
||||
# tools/checks/binaries we want and then make them all run in parallel
|
||||
# across jobs, not on a single tiny vm on Github Actions.
|
||||
- uses: golangci/golangci-lint-action@v6
|
||||
with:
|
||||
only-new-issues: true
|
||||
args: --timeout 10m0s -v
|
||||
|
||||
patches:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -242,4 +240,4 @@ jobs:
|
||||
- name: Verify patches apply cleanly and do not change files
|
||||
run: |
|
||||
make -f Makefile.sync clean checkout apply-patches sync
|
||||
git diff --compact-summary --exit-code
|
||||
git diff --compact-summary --exit-code
|
||||
@@ -1,4 +1,5 @@
|
||||
version: "2"
|
||||
run:
|
||||
timeout: 5m
|
||||
linters:
|
||||
enable:
|
||||
- asasalint
|
||||
@@ -6,46 +7,35 @@ linters:
|
||||
- bodyclose
|
||||
- containedctx
|
||||
- gocheckcompilerdirectives
|
||||
- gofmt
|
||||
- gofumpt
|
||||
- gosimple
|
||||
- govet
|
||||
- ineffassign
|
||||
- intrange
|
||||
- makezero
|
||||
- misspell
|
||||
- nilerr
|
||||
- nolintlint
|
||||
- nosprintfhostport
|
||||
- staticcheck
|
||||
- unconvert
|
||||
- usetesting
|
||||
- wastedassign
|
||||
- whitespace
|
||||
disable:
|
||||
- errcheck
|
||||
- usestdlibvars
|
||||
settings:
|
||||
govet:
|
||||
disable:
|
||||
- unusedresult
|
||||
staticcheck:
|
||||
checks:
|
||||
- all
|
||||
- -QF* # disable quick fix suggestions
|
||||
- -SA1019
|
||||
- -ST1000 # package comment format
|
||||
- -ST1003 # underscores in package names
|
||||
- -ST1005 # error strings should not be capitalized
|
||||
- -ST1012 # error var naming (ErrFoo)
|
||||
- -ST1016 # receiver name consistency
|
||||
- -ST1020 # comment on exported function format
|
||||
- -ST1021 # comment on exported type format
|
||||
- -ST1022 # comment on exported var format
|
||||
- -ST1023 # omit type from declaration
|
||||
- errcheck
|
||||
linters-settings:
|
||||
staticcheck:
|
||||
checks:
|
||||
- all
|
||||
- -SA1019 # omit Deprecated check
|
||||
severity:
|
||||
default: error
|
||||
default-severity: error
|
||||
rules:
|
||||
- linters:
|
||||
- gofmt
|
||||
- goimports
|
||||
- intrange
|
||||
severity: info
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
- gofumpt
|
||||
|
||||
@@ -54,13 +54,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
|
||||
|
||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||
|
||||
# Define GGML version variables for shared library SOVERSION
|
||||
# These are required by ggml/src/CMakeLists.txt for proper library versioning
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 0)
|
||||
set(GGML_VERSION_PATCH 0)
|
||||
set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
set(GGML_CPU ON)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
||||
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
|
||||
|
||||
@@ -16,7 +16,7 @@ See the [development documentation](./docs/development.md) for instructions on h
|
||||
|
||||
* New features: new features (e.g. API fields, environment variables) add surface area to Ollama and make it harder to maintain in the long run as they cannot be removed without potentially breaking users in the future.
|
||||
* Refactoring: large code improvements are important, but can be harder or take longer to review and merge.
|
||||
* Documentation: small updates to fill in or correct missing documentation are helpful, however large documentation additions can be hard to maintain over time.
|
||||
* Documentation: small updates to fill in or correct missing documentation is helpful, however large documentation additions can be hard to maintain over time.
|
||||
|
||||
### Issues that may not be accepted
|
||||
|
||||
@@ -43,7 +43,7 @@ Tips for proposals:
|
||||
* Explain how the change will be tested.
|
||||
|
||||
Additionally, for bonus points: Provide draft documentation you would expect to
|
||||
see if the changes were accepted.
|
||||
see if the change were accepted.
|
||||
|
||||
## Pull requests
|
||||
|
||||
@@ -66,6 +66,7 @@ Examples:
|
||||
|
||||
llm/backend/mlx: support the llama architecture
|
||||
CONTRIBUTING: provide clarity on good commit messages, and bad
|
||||
docs: simplify manual installation with shorter curl commands
|
||||
|
||||
Bad Examples:
|
||||
|
||||
|
||||
14
Dockerfile
14
Dockerfile
@@ -39,14 +39,14 @@ ENV CC=clang CXX=clang++
|
||||
FROM base-${TARGETARCH} AS base
|
||||
ARG CMAKEVERSION
|
||||
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ENV LDFLAGS=-s
|
||||
|
||||
FROM base AS cpu
|
||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CPU' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
||||
@@ -57,8 +57,6 @@ ARG CUDA11VERSION=11.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 11' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
||||
@@ -69,8 +67,6 @@ ARG CUDA12VERSION=12.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 12' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
||||
@@ -82,8 +78,6 @@ ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 13' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
||||
@@ -93,8 +87,6 @@ RUN --mount=type=cache,target=/root/.ccache \
|
||||
FROM base AS rocm-6
|
||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||
ARG PARALLEL
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'ROCm 6' \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
||||
@@ -126,8 +118,6 @@ RUN --mount=type=cache,target=/root/.ccache \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM base AS vulkan
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'Vulkan' \
|
||||
&& cmake --build --parallel --preset 'Vulkan' \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
|
||||
WORKDIR=llama/vendor
|
||||
FETCH_HEAD=ec98e2002
|
||||
FETCH_HEAD=3cfa9c3f125763305b4226bc032f1954f08990dc
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
@@ -57,7 +57,7 @@ checkout: $(WORKDIR)
|
||||
$(WORKDIR):
|
||||
git clone $(UPSTREAM) $(WORKDIR)
|
||||
|
||||
.PHONY: format-patches
|
||||
.PHONE: format-patches
|
||||
format-patches: llama/patches
|
||||
git -C $(WORKDIR) format-patch \
|
||||
--no-signature \
|
||||
@@ -66,11 +66,7 @@ format-patches: llama/patches
|
||||
-o $(realpath $<) \
|
||||
$(FETCH_HEAD)
|
||||
|
||||
.PHONY: clean
|
||||
.PHONE: clean
|
||||
clean: checkout
|
||||
@git -C $(WORKDIR) am --abort || true
|
||||
$(RM) llama/patches/.*.patched
|
||||
|
||||
.PHONY: print-base
|
||||
print-base:
|
||||
@echo $(FETCH_HEAD)
|
||||
17
README.md
17
README.md
@@ -299,7 +299,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [LibreChat](https://github.com/danny-avila/LibreChat)
|
||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
- [AI-UI](https://github.com/bajahaw/ai-ui)
|
||||
- [Saddle](https://github.com/jikkuatwork/saddle)
|
||||
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
|
||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||
@@ -366,8 +365,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
|
||||
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
|
||||
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
|
||||
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VS Code extension for multi-file/whole-repo coding
|
||||
- [Void](https://github.com/voideditor/void) (Open source AI code editor and Cursor alternative)
|
||||
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
|
||||
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
|
||||
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
|
||||
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
|
||||
@@ -399,7 +397,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
|
||||
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
|
||||
- [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
|
||||
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VS Code extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
|
||||
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VSCode extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
|
||||
- [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
|
||||
- [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
|
||||
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
|
||||
@@ -428,7 +426,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
|
||||
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
|
||||
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
|
||||
- [KDeps](https://github.com/kdeps/kdeps) (Kdeps is an offline-first AI framework for building Dockerized full-stack AI applications declaratively using Apple PKL and integrates APIs with Ollama on the backend.)
|
||||
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
|
||||
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
|
||||
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
|
||||
@@ -555,7 +552,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
|
||||
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
|
||||
- [Ollama for Swift](https://github.com/mattt/ollama-swift)
|
||||
- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
|
||||
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
|
||||
- [GoLamify](https://github.com/prasad89/golamify)
|
||||
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
|
||||
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
|
||||
@@ -618,7 +615,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
|
||||
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
|
||||
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
|
||||
- [AI Summary Helper plugin](https://github.com/philffm/ai-summary-helper)
|
||||
- [AI Summmary Helper plugin](https://github.com/philffm/ai-summary-helper)
|
||||
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
|
||||
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
|
||||
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
|
||||
@@ -626,7 +623,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
|
||||
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
|
||||
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
|
||||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama)
|
||||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
|
||||
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
|
||||
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
|
||||
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
|
||||
@@ -636,12 +633,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
### Observability
|
||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
|
||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
|
||||
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
|
||||
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
||||
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
|
||||
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
|
||||
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
|
||||
|
||||
### Security
|
||||
## Security
|
||||
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
|
||||
|
||||
@@ -14,7 +14,7 @@ Please include the following details in your report:
|
||||
|
||||
## Security best practices
|
||||
|
||||
While the maintainer team does its best to secure Ollama, users are encouraged to implement their own security best practices, such as:
|
||||
While the maintainer team does their best to secure Ollama, users are encouraged to implement their own security best practices, such as:
|
||||
|
||||
- Regularly updating to the latest version of Ollama
|
||||
- Securing access to hosted instances of Ollama
|
||||
|
||||
@@ -1,779 +0,0 @@
|
||||
// Package anthropic provides core transformation logic for compatibility with the Anthropic Messages API
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Error types matching Anthropic API
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Type string `json:"type"` // always "error"
|
||||
Error Error `json:"error"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
}
|
||||
|
||||
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
|
||||
func NewError(code int, message string) ErrorResponse {
|
||||
var etype string
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
etype = "invalid_request_error"
|
||||
case http.StatusUnauthorized:
|
||||
etype = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
etype = "permission_error"
|
||||
case http.StatusNotFound:
|
||||
etype = "not_found_error"
|
||||
case http.StatusTooManyRequests:
|
||||
etype = "rate_limit_error"
|
||||
case http.StatusServiceUnavailable, 529:
|
||||
etype = "overloaded_error"
|
||||
default:
|
||||
etype = "api_error"
|
||||
}
|
||||
|
||||
return ErrorResponse{
|
||||
Type: "error",
|
||||
Error: Error{Type: etype, Message: message},
|
||||
RequestID: generateID("req"),
|
||||
}
|
||||
}
|
||||
|
||||
// Request types
|
||||
|
||||
// MessagesRequest represents an Anthropic Messages API request
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Messages []MessageParam `json:"messages"`
|
||||
System any `json:"system,omitempty"` // string or []ContentBlock
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Metadata *Metadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MessageParam represents a message in the request
|
||||
type MessageParam struct {
|
||||
Role string `json:"role"` // "user" or "assistant"
|
||||
Content any `json:"content"` // string or []ContentBlock
|
||||
}
|
||||
|
||||
// ContentBlock represents a content block in a message
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
||||
|
||||
// For text blocks
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// For image blocks
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
|
||||
// For tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content any `json:"content,omitempty"` // string or []ContentBlock
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
|
||||
// For thinking blocks
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ImageSource represents the source of an image
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64" or "url"
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// Tool represents a tool definition
|
||||
type Tool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||
}
|
||||
|
||||
// ToolChoice controls how the model uses tools
|
||||
type ToolChoice struct {
|
||||
Type string `json:"type"` // "auto", "any", "tool", "none"
|
||||
Name string `json:"name,omitempty"`
|
||||
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// ThinkingConfig controls extended thinking
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"` // "enabled" or "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// Metadata for the request
|
||||
type Metadata struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// Response types
|
||||
|
||||
// MessagesResponse represents an Anthropic Messages API response
|
||||
type MessagesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// Usage contains token usage information
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// Streaming event types
|
||||
|
||||
// MessageStartEvent is sent at the start of streaming
|
||||
type MessageStartEvent struct {
|
||||
Type string `json:"type"` // "message_start"
|
||||
Message MessagesResponse `json:"message"`
|
||||
}
|
||||
|
||||
// ContentBlockStartEvent signals the start of a content block
|
||||
type ContentBlockStartEvent struct {
|
||||
Type string `json:"type"` // "content_block_start"
|
||||
Index int `json:"index"`
|
||||
ContentBlock ContentBlock `json:"content_block"`
|
||||
}
|
||||
|
||||
// ContentBlockDeltaEvent contains incremental content updates
|
||||
type ContentBlockDeltaEvent struct {
|
||||
Type string `json:"type"` // "content_block_delta"
|
||||
Index int `json:"index"`
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
|
||||
// Delta represents an incremental update
|
||||
type Delta struct {
|
||||
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// ContentBlockStopEvent signals the end of a content block
|
||||
type ContentBlockStopEvent struct {
|
||||
Type string `json:"type"` // "content_block_stop"
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// MessageDeltaEvent contains updates to the message
|
||||
type MessageDeltaEvent struct {
|
||||
Type string `json:"type"` // "message_delta"
|
||||
Delta MessageDelta `json:"delta"`
|
||||
Usage DeltaUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// MessageDelta contains stop information
|
||||
type MessageDelta struct {
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
// DeltaUsage contains cumulative token usage
|
||||
type DeltaUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// MessageStopEvent signals the end of the message
|
||||
type MessageStopEvent struct {
|
||||
Type string `json:"type"` // "message_stop"
|
||||
}
|
||||
|
||||
// PingEvent is a keepalive event
|
||||
type PingEvent struct {
|
||||
Type string `json:"type"` // "ping"
|
||||
}
|
||||
|
||||
// StreamErrorEvent is an error during streaming
|
||||
type StreamErrorEvent struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
|
||||
// Handle system prompt
|
||||
if r.System != nil {
|
||||
switch sys := r.System.(type) {
|
||||
case string:
|
||||
if sys != "" {
|
||||
messages = append(messages, api.Message{Role: "system", Content: sys})
|
||||
}
|
||||
case []any:
|
||||
// System can be an array of content blocks
|
||||
var content strings.Builder
|
||||
for _, block := range sys {
|
||||
if blockMap, ok := block.(map[string]any); ok {
|
||||
if blockMap["type"] == "text" {
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
content.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if content.Len() > 0 {
|
||||
messages = append(messages, api.Message{Role: "system", Content: content.String()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
for _, msg := range r.Messages {
|
||||
converted, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, converted...)
|
||||
}
|
||||
|
||||
// Build options
|
||||
options := make(map[string]any)
|
||||
|
||||
options["num_predict"] = r.MaxTokens
|
||||
|
||||
if r.Temperature != nil {
|
||||
options["temperature"] = *r.Temperature
|
||||
}
|
||||
|
||||
if r.TopP != nil {
|
||||
options["top_p"] = *r.TopP
|
||||
}
|
||||
|
||||
if r.TopK != nil {
|
||||
options["top_k"] = *r.TopK
|
||||
}
|
||||
|
||||
if len(r.StopSequences) > 0 {
|
||||
options["stop"] = r.StopSequences
|
||||
}
|
||||
|
||||
// Convert tools
|
||||
var tools api.Tools
|
||||
for _, t := range r.Tools {
|
||||
tool, err := convertTool(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
// Handle thinking
|
||||
var think *api.ThinkValue
|
||||
if r.Thinking != nil && r.Thinking.Type == "enabled" {
|
||||
think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
|
||||
stream := r.Stream
|
||||
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Options: options,
|
||||
Stream: &stream,
|
||||
Tools: tools,
|
||||
Think: think,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
role := strings.ToLower(msg.Role)
|
||||
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
messages = append(messages, api.Message{Role: role, Content: content})
|
||||
|
||||
case []any:
|
||||
// Handle array of content blocks
|
||||
var textContent strings.Builder
|
||||
var images []api.ImageData
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid content block format")
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
}
|
||||
|
||||
case "image":
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
|
||||
case "tool_use":
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = api.ToolCallFunctionArguments(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
switch c := blockMap["content"].(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
// Extract text from content blocks
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the main message
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
Content: textContent.String(),
|
||||
Images: images,
|
||||
ToolCalls: toolCalls,
|
||||
Thinking: thinking,
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
||||
func convertTool(t Tool) (api.Tool, error) {
|
||||
var params api.ToolFunctionParameters
|
||||
if len(t.InputSchema) > 0 {
|
||||
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
|
||||
var content []ContentBlock
|
||||
|
||||
// Add thinking block if present
|
||||
if r.Message.Thinking != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: r.Message.Thinking,
|
||||
})
|
||||
}
|
||||
|
||||
// Add text content if present
|
||||
if r.Message.Content != "" {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "text",
|
||||
Text: r.Message.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Add tool use blocks
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
content = append(content, ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
// Map stop reason
|
||||
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
|
||||
|
||||
return MessagesResponse{
|
||||
ID: id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: r.Model,
|
||||
Content: content,
|
||||
StopReason: stopReason,
|
||||
Usage: Usage{
|
||||
InputTokens: r.Metrics.PromptEvalCount,
|
||||
OutputTokens: r.Metrics.EvalCount,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
|
||||
func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
if hasToolCalls {
|
||||
return "tool_use"
|
||||
}
|
||||
|
||||
switch reason {
|
||||
case "stop":
|
||||
return "end_turn"
|
||||
case "length":
|
||||
return "max_tokens"
|
||||
default:
|
||||
if reason != "" {
|
||||
return "stop_sequence"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
// NewStreamConverter creates a new StreamConverter
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// StreamEvent represents a streaming event to be sent to the client
|
||||
type StreamEvent struct {
|
||||
Event string
|
||||
Data any
|
||||
}
|
||||
|
||||
// Process converts an Ollama ChatResponse to Anthropic streaming events
|
||||
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
var events []StreamEvent
|
||||
|
||||
// First write: emit message_start
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
Data: MessageStartEvent{
|
||||
Type: "message_start",
|
||||
Message: MessagesResponse{
|
||||
ID: c.ID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: c.Model,
|
||||
Content: []ContentBlock{},
|
||||
Usage: Usage{
|
||||
InputTokens: c.inputTokens,
|
||||
OutputTokens: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Handle thinking content
|
||||
if r.Message.Thinking != "" && !c.thinkingDone {
|
||||
if !c.thinkingStarted {
|
||||
c.thinkingStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "thinking_delta",
|
||||
Thinking: r.Message.Thinking,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Handle text content
|
||||
if r.Message.Content != "" {
|
||||
// Close thinking block if it was open
|
||||
if c.thinkingStarted && !c.thinkingDone {
|
||||
c.thinkingDone = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
if !c.textStarted {
|
||||
c.textStarted = true
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "text_delta",
|
||||
Text: r.Message.Content,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
for _, tc := range r.Message.ToolCalls {
|
||||
if c.toolCallsSent[tc.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Close any previous block
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
c.contentIndex++
|
||||
c.textStarted = false
|
||||
}
|
||||
|
||||
// Start tool use block
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_start",
|
||||
Data: ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: c.contentIndex,
|
||||
ContentBlock: ContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Input: map[string]any{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// Send input as JSON delta
|
||||
argsJSON, _ := json.Marshal(tc.Function.Arguments)
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: c.contentIndex,
|
||||
Delta: Delta{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: string(argsJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// Close tool use block
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
|
||||
c.toolCallsSent[tc.ID] = true
|
||||
c.contentIndex++
|
||||
}
|
||||
|
||||
// Handle done
|
||||
if r.Done {
|
||||
// Close any open block
|
||||
if c.textStarted {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
} else if c.thinkingStarted && !c.thinkingDone {
|
||||
events = append(events, StreamEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: c.contentIndex,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.outputTokens = r.Metrics.EvalCount
|
||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_delta",
|
||||
Data: MessageDeltaEvent{
|
||||
Type: "message_delta",
|
||||
Delta: MessageDelta{
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: DeltaUsage{
|
||||
OutputTokens: c.outputTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_stop",
|
||||
Data: MessageStopEvent{
|
||||
Type: "message_stop",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// generateID generates a unique ID with the given prefix using crypto/rand
|
||||
func generateID(prefix string) string {
|
||||
b := make([]byte, 12)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to time-based ID if crypto/rand fails
|
||||
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
|
||||
}
|
||||
return fmt.Sprintf("%s_%x", prefix, b)
|
||||
}
|
||||
|
||||
// GenerateMessageID generates a unique message ID
|
||||
func GenerateMessageID() string {
|
||||
return generateID("msg")
|
||||
}
|
||||
@@ -1,667 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
func TestFromMessagesRequest_Basic(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
|
||||
t.Errorf("unexpected message: %+v", result.Messages[0])
|
||||
}
|
||||
|
||||
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
|
||||
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
|
||||
t.Errorf("unexpected system message: %+v", result.Messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
System: []any{
|
||||
map[string]any{"type": "text", "text": "You are helpful."},
|
||||
map[string]any{"type": "text", "text": " Be concise."},
|
||||
},
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "You are helpful. Be concise." {
|
||||
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithOptions(t *testing.T) {
|
||||
temp := 0.7
|
||||
topP := 0.9
|
||||
topK := 40
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 2048,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
StopSequences: []string{"\n", "END"},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options["temperature"] != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
|
||||
}
|
||||
if result.Options["top_p"] != 0.9 {
|
||||
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
|
||||
}
|
||||
if result.Options["top_k"] != 40 {
|
||||
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
|
||||
}
|
||||
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
|
||||
t.Errorf("stop sequences mismatch: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithImage(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "What's in this image?"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[0].Content != "What's in this image?" {
|
||||
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
|
||||
}
|
||||
|
||||
if len(result.Messages[0].Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
|
||||
}
|
||||
|
||||
if string(result.Messages[0].Images[0]) != string(imgData) {
|
||||
t.Error("image data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{Role: "user", Content: "What's the weather in Paris?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": map[string]any{"location": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if len(result.Messages[1].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
|
||||
}
|
||||
|
||||
tc := result.Messages[1].ToolCalls[0]
|
||||
if tc.ID != "call_123" {
|
||||
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
|
||||
}
|
||||
if tc.Function.Name != "get_weather" {
|
||||
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_123",
|
||||
"content": "The weather in Paris is sunny, 22°C",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if msg.ToolCallID != "call_123" {
|
||||
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
|
||||
}
|
||||
if msg.Content != "The weather in Paris is sunny, 22°C" {
|
||||
t.Errorf("unexpected content: %q", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
|
||||
}
|
||||
|
||||
tool := result.Tools[0]
|
||||
if tool.Type != "function" {
|
||||
t.Errorf("expected type 'function', got %q", tool.Type)
|
||||
}
|
||||
if tool.Function.Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
|
||||
}
|
||||
if tool.Function.Description != "Get current weather" {
|
||||
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Think == nil {
|
||||
t.Fatal("expected Think to be set")
|
||||
}
|
||||
if v, ok := result.Think.Value.(bool); !ok || !v {
|
||||
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"name": "get_weather",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use id")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'id' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing tool_use name")
|
||||
}
|
||||
if err.Error() != "tool_use block missing required 'name' field" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Name: "bad_tool",
|
||||
InputSchema: json.RawMessage(`{invalid json`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromMessagesRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid tool schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_Basic(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello there!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 5,
|
||||
},
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if result.ID != "msg_123" {
|
||||
t.Errorf("expected ID 'msg_123', got %q", result.ID)
|
||||
}
|
||||
if result.Type != "message" {
|
||||
t.Errorf("expected type 'message', got %q", result.Type)
|
||||
}
|
||||
if result.Role != "assistant" {
|
||||
t.Errorf("expected role 'assistant', got %q", result.Role)
|
||||
}
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "text" || result.Content[0].Text != "Hello there!" {
|
||||
t.Errorf("unexpected content: %+v", result.Content[0])
|
||||
}
|
||||
if result.StopReason != "end_turn" {
|
||||
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
||||
}
|
||||
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", result.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "tool_use" {
|
||||
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].ID != "call_123" {
|
||||
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
|
||||
}
|
||||
if result.Content[0].Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
|
||||
}
|
||||
if result.StopReason != "tool_use" {
|
||||
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessagesResponse_WithThinking(t *testing.T) {
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "The answer is 42.",
|
||||
Thinking: "Let me think about this...",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
}
|
||||
|
||||
result := ToMessagesResponse("msg_123", resp)
|
||||
|
||||
if len(result.Content) != 2 {
|
||||
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
|
||||
}
|
||||
if result.Content[0].Type != "thinking" {
|
||||
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
|
||||
}
|
||||
if result.Content[0].Thinking != "Let me think about this..." {
|
||||
t.Errorf("unexpected thinking content: %q", result.Content[0].Thinking)
|
||||
}
|
||||
if result.Content[1].Type != "text" {
|
||||
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapStopReason(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason string
|
||||
hasToolCalls bool
|
||||
want string
|
||||
}{
|
||||
{"stop", false, "end_turn"},
|
||||
{"length", false, "max_tokens"},
|
||||
{"stop", true, "tool_use"},
|
||||
{"other", false, "stop_sequence"},
|
||||
{"", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := mapStopReason(tt.reason, tt.hasToolCalls)
|
||||
if got != tt.want {
|
||||
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewError(t *testing.T) {
|
||||
tests := []struct {
|
||||
code int
|
||||
want string
|
||||
}{
|
||||
{400, "invalid_request_error"},
|
||||
{401, "authentication_error"},
|
||||
{403, "permission_error"},
|
||||
{404, "not_found_error"},
|
||||
{429, "rate_limit_error"},
|
||||
{500, "api_error"},
|
||||
{503, "overloaded_error"},
|
||||
{529, "overloaded_error"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := NewError(tt.code, "test message")
|
||||
if result.Type != "error" {
|
||||
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
|
||||
}
|
||||
if result.Error.Type != tt.want {
|
||||
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
|
||||
}
|
||||
if result.Error.Message != "test message" {
|
||||
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
|
||||
}
|
||||
if result.RequestID == "" {
|
||||
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMessageID(t *testing.T) {
|
||||
id1 := GenerateMessageID()
|
||||
id2 := GenerateMessageID()
|
||||
|
||||
if id1 == "" {
|
||||
t.Error("GenerateMessageID returned empty string")
|
||||
}
|
||||
if id1 == id2 {
|
||||
t.Error("GenerateMessageID returned duplicate IDs")
|
||||
}
|
||||
if len(id1) < 10 {
|
||||
t.Errorf("GenerateMessageID returned short ID: %q", id1)
|
||||
}
|
||||
if id1[:4] != "msg_" {
|
||||
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello",
|
||||
},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10},
|
||||
}
|
||||
|
||||
events1 := conv.Process(resp1)
|
||||
if len(events1) < 3 {
|
||||
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
|
||||
}
|
||||
|
||||
// Should have message_start, content_block_start, content_block_delta
|
||||
if events1[0].Event != "message_start" {
|
||||
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
|
||||
}
|
||||
if events1[1].Event != "content_block_start" {
|
||||
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
|
||||
}
|
||||
if events1[2].Event != "content_block_delta" {
|
||||
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
|
||||
}
|
||||
|
||||
// Final chunk
|
||||
resp2 := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: " world!",
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
}
|
||||
if !hasStop {
|
||||
t.Error("expected message_stop event in final chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "Paris"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events := conv.Process(resp)
|
||||
|
||||
hasToolStart := false
|
||||
hasToolDelta := false
|
||||
for _, e := range events {
|
||||
if e.Event == "content_block_start" {
|
||||
if start, ok := e.Data.(ContentBlockStartEvent); ok {
|
||||
if start.ContentBlock.Type == "tool_use" {
|
||||
hasToolStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if e.Event == "content_block_delta" {
|
||||
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
|
||||
if delta.Delta.Type == "input_json_delta" {
|
||||
hasToolDelta = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasToolStart {
|
||||
t.Error("expected tool_use content_block_start event")
|
||||
}
|
||||
if !hasToolDelta {
|
||||
t.Error("expected input_json_delta event")
|
||||
}
|
||||
}
|
||||
@@ -226,14 +226,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
|
||||
bts := scanner.Bytes()
|
||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
ErrorMessage: string(bts),
|
||||
}
|
||||
}
|
||||
return errors.New(string(bts))
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if response.StatusCode == http.StatusUnauthorized {
|
||||
@@ -347,7 +340,7 @@ type CreateProgressFunc func(ProgressResponse) error
|
||||
// Create creates a model from a [Modelfile]. fn is a progress function that
|
||||
// behaves similarly to other methods (see [Client.Pull]).
|
||||
//
|
||||
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx
|
||||
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
|
||||
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
|
||||
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
|
||||
var resp ProgressResponse
|
||||
|
||||
@@ -55,7 +55,6 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
type testError struct {
|
||||
message string
|
||||
statusCode int
|
||||
raw bool // if true, write message as-is instead of JSON encoding
|
||||
}
|
||||
|
||||
func (e testError) Error() string {
|
||||
@@ -112,20 +111,6 @@ func TestClientStream(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "plain text error response",
|
||||
responses: []any{
|
||||
"internal server error",
|
||||
},
|
||||
wantErr: "internal server error",
|
||||
},
|
||||
{
|
||||
name: "HTML error page",
|
||||
responses: []any{
|
||||
"<html><body>404 Not Found</body></html>",
|
||||
},
|
||||
wantErr: "404 Not Found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -150,12 +135,6 @@ func TestClientStream(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if str, ok := resp.(string); ok {
|
||||
fmt.Fprintln(w, str)
|
||||
flusher.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("failed to encode response: %v", err)
|
||||
}
|
||||
@@ -194,10 +173,9 @@ func TestClientStream(t *testing.T) {
|
||||
|
||||
func TestClientDo(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
response any
|
||||
wantErr string
|
||||
wantStatusCode int
|
||||
name string
|
||||
response any
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "immediate error response",
|
||||
@@ -205,8 +183,7 @@ func TestClientDo(t *testing.T) {
|
||||
message: "test error message",
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
wantErr: "test error message",
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantErr: "test error message",
|
||||
},
|
||||
{
|
||||
name: "server error response",
|
||||
@@ -214,8 +191,7 @@ func TestClientDo(t *testing.T) {
|
||||
message: "internal error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
},
|
||||
wantErr: "internal error",
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
wantErr: "internal error",
|
||||
},
|
||||
{
|
||||
name: "successful response",
|
||||
@@ -227,26 +203,6 @@ func TestClientDo(t *testing.T) {
|
||||
Success: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "plain text error response",
|
||||
response: testError{
|
||||
message: "internal server error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
raw: true,
|
||||
},
|
||||
wantErr: "internal server error",
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "HTML error page",
|
||||
response: testError{
|
||||
message: "<html><body>404 Not Found</body></html>",
|
||||
statusCode: http.StatusNotFound,
|
||||
raw: true,
|
||||
},
|
||||
wantErr: "<html><body>404 Not Found</body></html>",
|
||||
wantStatusCode: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -254,16 +210,11 @@ func TestClientDo(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if errResp, ok := tc.response.(testError); ok {
|
||||
w.WriteHeader(errResp.statusCode)
|
||||
if !errResp.raw {
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": errResp.message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode error response:", err)
|
||||
}
|
||||
} else {
|
||||
// Write raw message (simulates non-JSON error responses)
|
||||
fmt.Fprint(w, errResp.message)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": errResp.message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode error response:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -290,15 +241,6 @@ func TestClientDo(t *testing.T) {
|
||||
if err.Error() != tc.wantErr {
|
||||
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
|
||||
}
|
||||
if tc.wantStatusCode != 0 {
|
||||
if statusErr, ok := err.(StatusError); ok {
|
||||
if statusErr.StatusCode != tc.wantStatusCode {
|
||||
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("expected StatusError, got %T", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -15,19 +15,19 @@ func main() {
|
||||
}
|
||||
|
||||
messages := []api.Message{
|
||||
{
|
||||
api.Message{
|
||||
Role: "system",
|
||||
Content: "Provide very brief, concise responses",
|
||||
},
|
||||
{
|
||||
api.Message{
|
||||
Role: "user",
|
||||
Content: "Name some unusual animals",
|
||||
},
|
||||
{
|
||||
api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Monotreme, platypus, echidna",
|
||||
},
|
||||
{
|
||||
api.Message{
|
||||
Role: "user",
|
||||
Content: "which of these is the most dangerous?",
|
||||
},
|
||||
|
||||
18
api/types.go
18
api/types.go
@@ -283,12 +283,11 @@ func (pt PropertyType) String() string {
|
||||
}
|
||||
|
||||
type ToolProperty struct {
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Properties map[string]ToolProperty `json:"properties,omitempty"`
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||
@@ -367,9 +366,6 @@ type TokenLogprob struct {
|
||||
|
||||
// Logprob is the log probability of this token.
|
||||
Logprob float64 `json:"logprob"`
|
||||
|
||||
// Bytes contains the raw byte representation of the token
|
||||
Bytes []int `json:"bytes,omitempty"`
|
||||
}
|
||||
|
||||
// Logprob contains log probability information for a generated token.
|
||||
@@ -554,9 +550,6 @@ type CreateRequest struct {
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
|
||||
// Requires is the minimum version of Ollama required by the model.
|
||||
Requires string `json:"requires,omitempty"`
|
||||
|
||||
// Info is a map of additional information for the model
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
|
||||
@@ -607,7 +600,6 @@ type ShowResponse struct {
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
Requires string `json:"requires,omitempty"`
|
||||
}
|
||||
|
||||
// CopyRequest is the request passed to [Client.Copy].
|
||||
|
||||
@@ -504,107 +504,6 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolPropertyNestedProperties(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected ToolProperty
|
||||
}{
|
||||
{
|
||||
name: "nested object properties",
|
||||
input: `{
|
||||
"type": "object",
|
||||
"description": "Location details",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "string",
|
||||
"description": "Street address"
|
||||
},
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name"
|
||||
}
|
||||
}
|
||||
}`,
|
||||
expected: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Location details",
|
||||
Properties: map[string]ToolProperty{
|
||||
"address": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "Street address",
|
||||
},
|
||||
"city": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deeply nested properties",
|
||||
input: `{
|
||||
"type": "object",
|
||||
"description": "Event",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "object",
|
||||
"description": "Location",
|
||||
"properties": {
|
||||
"coordinates": {
|
||||
"type": "object",
|
||||
"description": "GPS coordinates",
|
||||
"properties": {
|
||||
"lat": {"type": "number", "description": "Latitude"},
|
||||
"lng": {"type": "number", "description": "Longitude"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
expected: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Event",
|
||||
Properties: map[string]ToolProperty{
|
||||
"location": {
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Location",
|
||||
Properties: map[string]ToolProperty{
|
||||
"coordinates": {
|
||||
Type: PropertyType{"object"},
|
||||
Description: "GPS coordinates",
|
||||
Properties: map[string]ToolProperty{
|
||||
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
|
||||
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var prop ToolProperty
|
||||
err := json.Unmarshal([]byte(tt.input), &prop)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, prop)
|
||||
|
||||
// Round-trip test: marshal and unmarshal again
|
||||
data, err := json.Marshal(prop)
|
||||
require.NoError(t, err)
|
||||
|
||||
var prop2 ToolProperty
|
||||
err = json.Unmarshal(data, &prop2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, prop2)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFunctionParameters_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -273,6 +273,10 @@ func main() {
|
||||
Handler: uiServer.Handler(),
|
||||
}
|
||||
|
||||
if _, err := uiServer.UserData(ctx); err != nil {
|
||||
slog.Warn("failed to load user data", "error", err)
|
||||
}
|
||||
|
||||
// Start the UI server
|
||||
slog.Info("starting ui server", "port", port)
|
||||
go func() {
|
||||
@@ -316,17 +320,6 @@ func main() {
|
||||
slog.Debug("no URL scheme request to handle")
|
||||
}
|
||||
|
||||
go func() {
|
||||
slog.Debug("waiting for ollama server to be ready")
|
||||
if err := ui.WaitForServer(ctx, 10*time.Second); err != nil {
|
||||
slog.Warn("ollama server not ready, continuing anyway", "error", err)
|
||||
}
|
||||
|
||||
if _, err := uiServer.UserData(ctx); err != nil {
|
||||
slog.Warn("failed to load user data", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
osRun(cancel, hasCompletedFirstRun, startHidden)
|
||||
|
||||
slog.Info("shutting down desktop server")
|
||||
@@ -368,7 +361,7 @@ func checkUserLoggedIn(uiServerPort int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/api/me", uiServerPort), "application/json", nil)
|
||||
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/me", uiServerPort))
|
||||
if err != nil {
|
||||
slog.Debug("failed to call local auth endpoint", "error", err)
|
||||
return false
|
||||
@@ -404,8 +397,8 @@ func checkUserLoggedIn(uiServerPort int) bool {
|
||||
// handleConnectURLScheme fetches the connect URL and opens it in the browser
|
||||
func handleConnectURLScheme() {
|
||||
if checkUserLoggedIn(uiServerPort) {
|
||||
slog.Info("user is already logged in, opening app instead")
|
||||
showWindow(wv.webview.Window())
|
||||
slog.Info("user is already logged in, opening settings instead")
|
||||
sendUIRequestMessage("/")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -441,30 +434,37 @@ func openInBrowser(url string) {
|
||||
}
|
||||
}
|
||||
|
||||
// parseURLScheme parses an ollama:// URL and validates it
|
||||
// Supports: ollama:// (open app) and ollama://connect (OAuth)
|
||||
func parseURLScheme(urlSchemeRequest string) (isConnect bool, err error) {
|
||||
// parseURLScheme parses an ollama:// URL and returns whether it's a connect URL and the UI path
|
||||
func parseURLScheme(urlSchemeRequest string) (isConnect bool, uiPath string, err error) {
|
||||
parsedURL, err := url.Parse(urlSchemeRequest)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid URL: %w", err)
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
// Check if this is a connect URL
|
||||
if parsedURL.Host == "connect" || strings.TrimPrefix(parsedURL.Path, "/") == "connect" {
|
||||
return true, nil
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
// Allow bare ollama:// or ollama:/// to open the app
|
||||
if (parsedURL.Host == "" && parsedURL.Path == "") || parsedURL.Path == "/" {
|
||||
return false, nil
|
||||
// Extract the UI path
|
||||
path := "/"
|
||||
if parsedURL.Path != "" && parsedURL.Path != "/" {
|
||||
// For URLs like ollama:///settings, use the path directly
|
||||
path = parsedURL.Path
|
||||
} else if parsedURL.Host != "" {
|
||||
// For URLs like ollama://settings (without triple slash),
|
||||
// the "settings" part is parsed as the host, not the path.
|
||||
// We need to convert it to a path by prepending "/"
|
||||
// This also handles ollama://settings/ where Windows adds a trailing slash
|
||||
path = "/" + parsedURL.Host
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("unsupported ollama:// URL path: %s", urlSchemeRequest)
|
||||
return false, path, nil
|
||||
}
|
||||
|
||||
// handleURLSchemeInCurrentInstance processes URL scheme requests in the current instance
|
||||
func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
|
||||
isConnect, err := parseURLScheme(urlSchemeRequest)
|
||||
isConnect, uiPath, err := parseURLScheme(urlSchemeRequest)
|
||||
if err != nil {
|
||||
slog.Error("failed to parse URL scheme request", "url", urlSchemeRequest, "error", err)
|
||||
return
|
||||
@@ -473,8 +473,6 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
|
||||
if isConnect {
|
||||
handleConnectURLScheme()
|
||||
} else {
|
||||
if wv.webview != nil {
|
||||
showWindow(wv.webview.Window())
|
||||
}
|
||||
sendUIRequestMessage(uiPath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,6 +191,13 @@ func LaunchNewApp() {
|
||||
C.launchApp(appName)
|
||||
}
|
||||
|
||||
// Send a request to the main app thread to load a UI page
|
||||
func sendUIRequestMessage(path string) {
|
||||
p := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(p))
|
||||
C.uiRequest(p)
|
||||
}
|
||||
|
||||
func registerLaunchAgent(hasCompletedFirstRun bool) {
|
||||
// Remove any stale Login Item registrations
|
||||
C.unregisterSelfFromLoginItem()
|
||||
|
||||
@@ -24,14 +24,27 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
for (NSURL *url in urls) {
|
||||
if ([url.scheme isEqualToString:@"ollama"]) {
|
||||
NSString *path = url.path;
|
||||
|
||||
if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) {
|
||||
if (!path || [path isEqualToString:@""]) {
|
||||
// For URLs like ollama://settings (without triple slash),
|
||||
// the "settings" part is parsed as the host, not the path.
|
||||
// We need to convert it to a path by prepending "/"
|
||||
if (url.host && ![url.host isEqualToString:@""]) {
|
||||
path = [@"/" stringByAppendingString:url.host];
|
||||
} else {
|
||||
path = @"/";
|
||||
}
|
||||
}
|
||||
|
||||
if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {
|
||||
// Special case: handle connect by opening browser instead of app
|
||||
handleConnectURL();
|
||||
} else {
|
||||
// Set app to be active and visible
|
||||
[NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
|
||||
[NSApp activateIgnoringOtherApps:YES];
|
||||
|
||||
// Open the path with the UI
|
||||
[self uiRequest:path];
|
||||
}
|
||||
|
||||
break;
|
||||
@@ -247,7 +260,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
}
|
||||
|
||||
- (void)openHelp:(id)sender {
|
||||
NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"];
|
||||
NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
|
||||
[[NSWorkspace sharedWorkspace] openURL:url];
|
||||
}
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ func (app *appCallbacks) HandleURLScheme(urlScheme string) {
|
||||
|
||||
// handleURLSchemeRequest processes URL scheme requests from other instances
|
||||
func handleURLSchemeRequest(urlScheme string) {
|
||||
isConnect, err := parseURLScheme(urlScheme)
|
||||
isConnect, uiPath, err := parseURLScheme(urlScheme)
|
||||
if err != nil {
|
||||
slog.Error("failed to parse URL scheme request", "url", urlScheme, "error", err)
|
||||
return
|
||||
@@ -147,9 +147,7 @@ func handleURLSchemeRequest(urlScheme string) {
|
||||
if isConnect {
|
||||
handleConnectURLScheme()
|
||||
} else {
|
||||
if wv.webview != nil {
|
||||
showWindow(wv.webview.Window())
|
||||
}
|
||||
sendUIRequestMessage(uiPath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -263,6 +261,11 @@ func createLoginShortcut() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send a request to the main app thread to load a UI page
|
||||
func sendUIRequestMessage(path string) {
|
||||
wintray.SendUIRequestMessage(path)
|
||||
}
|
||||
|
||||
func LaunchNewApp() {
|
||||
}
|
||||
|
||||
|
||||
@@ -169,47 +169,37 @@ DlgResult fileDlg(FileDlgParams* params) {
|
||||
}
|
||||
|
||||
NSArray* urls = [panel URLs];
|
||||
if([urls count] == 0) {
|
||||
return DLG_CANCEL;
|
||||
}
|
||||
|
||||
if(self->params->allowMultiple) {
|
||||
if(self->params->allowMultiple && [urls count] >= 1) {
|
||||
// For multiple files, we need to return all paths separated by null bytes
|
||||
char* bufPtr = self->params->buf;
|
||||
int remainingBuf = self->params->nbuf;
|
||||
|
||||
// Calculate total required buffer size first
|
||||
int totalSize = 0;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
|
||||
}
|
||||
totalSize += 1; // Final null terminator
|
||||
// Calculate total required buffer size first
|
||||
int totalSize = 0;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
|
||||
}
|
||||
totalSize += 1; // Final null terminator
|
||||
|
||||
if(totalSize > self->params->nbuf) {
|
||||
// Not enough buffer space
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
if(totalSize > self->params->nbuf) {
|
||||
// Not enough buffer space
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
|
||||
// Now actually copy the paths (we know we have space)
|
||||
bufPtr = self->params->buf;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
|
||||
int pathLen = strlen(tempBuf);
|
||||
strcpy(bufPtr, tempBuf);
|
||||
bufPtr += pathLen + 1;
|
||||
}
|
||||
*bufPtr = '\0'; // Final null terminator
|
||||
} else {
|
||||
// Single file/directory selection - write path to buffer
|
||||
NSURL* url = [urls firstObject];
|
||||
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
|
||||
return DLG_URLFAIL;
|
||||
}
|
||||
// Now actually copy the paths (we know we have space)
|
||||
bufPtr = self->params->buf;
|
||||
for(NSURL* url in urls) {
|
||||
char tempBuf[PATH_MAX];
|
||||
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
|
||||
int pathLen = strlen(tempBuf);
|
||||
strcpy(bufPtr, tempBuf);
|
||||
bufPtr += pathLen + 1;
|
||||
}
|
||||
*bufPtr = '\0'; // Final null terminator
|
||||
}
|
||||
|
||||
return DLG_OK;
|
||||
|
||||
@@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10
|
||||
type WinDlgError int
|
||||
|
||||
func (e WinDlgError) Error() string {
|
||||
return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
|
||||
return fmt.Sprintf("CommDlgExtendedError: %#x", e)
|
||||
}
|
||||
|
||||
func err() error {
|
||||
|
||||
@@ -224,7 +224,9 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
|
||||
if _, err := os.Stat(settings.Models); err == nil {
|
||||
env["OLLAMA_MODELS"] = settings.Models
|
||||
} else {
|
||||
slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err)
|
||||
slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err)
|
||||
settings.Models = ""
|
||||
s.store.SetSettings(settings)
|
||||
}
|
||||
}
|
||||
if settings.ContextLength > 0 {
|
||||
|
||||
@@ -469,24 +469,26 @@ export class HealthResponse {
|
||||
}
|
||||
export class User {
|
||||
id: string;
|
||||
email: string;
|
||||
name: string;
|
||||
bio?: string;
|
||||
avatarurl?: string;
|
||||
firstname?: string;
|
||||
lastname?: string;
|
||||
plan?: string;
|
||||
email: string;
|
||||
avatarURL: string;
|
||||
plan: string;
|
||||
bio: string;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
overThreshold: boolean;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.id = source["id"];
|
||||
this.email = source["email"];
|
||||
this.name = source["name"];
|
||||
this.bio = source["bio"];
|
||||
this.avatarurl = source["avatarurl"];
|
||||
this.firstname = source["firstname"];
|
||||
this.lastname = source["lastname"];
|
||||
this.email = source["email"];
|
||||
this.avatarURL = source["avatarURL"];
|
||||
this.plan = source["plan"];
|
||||
this.bio = source["bio"];
|
||||
this.firstName = source["firstName"];
|
||||
this.lastName = source["lastName"];
|
||||
this.overThreshold = source["overThreshold"];
|
||||
}
|
||||
}
|
||||
export class Attachment {
|
||||
|
||||
12
app/ui/app/package-lock.json
generated
12
app/ui/app/package-lock.json
generated
@@ -27,7 +27,8 @@
|
||||
"remark-math": "^6.0.0",
|
||||
"streamdown": "^1.4.0",
|
||||
"unist-builder": "^4.0.0",
|
||||
"unist-util-parents": "^3.0.0"
|
||||
"unist-util-parents": "^3.0.0",
|
||||
"use-stick-to-bottom": "^1.1.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "^4.0.1",
|
||||
@@ -12739,6 +12740,15 @@
|
||||
"punycode": "^2.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/use-stick-to-bottom": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/use-stick-to-bottom/-/use-stick-to-bottom-1.1.1.tgz",
|
||||
"integrity": "sha512-JkDp0b0tSmv7HQOOpL1hT7t7QaoUBXkq045WWWOFDTlLGRzgIIyW7vyzOIJzY7L2XVIG7j1yUxeDj2LHm9Vwng==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/use-sync-external-store": {
|
||||
"version": "1.5.0",
|
||||
"resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.5.0.tgz",
|
||||
|
||||
@@ -36,7 +36,8 @@
|
||||
"remark-math": "^6.0.0",
|
||||
"streamdown": "^1.4.0",
|
||||
"unist-builder": "^4.0.0",
|
||||
"unist-util-parents": "^3.0.0"
|
||||
"unist-util-parents": "^3.0.0",
|
||||
"use-stick-to-bottom": "^1.1.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "^4.0.1",
|
||||
|
||||
@@ -15,7 +15,6 @@ import {
|
||||
import { parseJsonlFromResponse } from "./util/jsonl-parsing";
|
||||
import { ollamaClient as ollama } from "./lib/ollama-client";
|
||||
import type { ModelResponse } from "ollama/browser";
|
||||
import { API_BASE, OLLAMA_DOT_COM } from "./lib/config";
|
||||
|
||||
// Extend Model class with utility methods
|
||||
declare module "@/gotypes" {
|
||||
@@ -27,6 +26,9 @@ declare module "@/gotypes" {
|
||||
Model.prototype.isCloud = function (): boolean {
|
||||
return this.model.endsWith("cloud");
|
||||
};
|
||||
|
||||
const API_BASE = import.meta.env.DEV ? "http://127.0.0.1:3001" : "";
|
||||
|
||||
// Helper function to convert Uint8Array to base64
|
||||
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
|
||||
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
|
||||
@@ -41,50 +43,44 @@ function uint8ArrayToBase64(uint8Array: Uint8Array): string {
|
||||
}
|
||||
|
||||
export async function fetchUser(): Promise<User | null> {
|
||||
const response = await fetch(`${API_BASE}/api/me`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/api/v1/me`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const userData: User = await response.json();
|
||||
|
||||
if (userData.avatarurl && !userData.avatarurl.startsWith("http")) {
|
||||
userData.avatarurl = `${OLLAMA_DOT_COM}${userData.avatarurl}`;
|
||||
if (response.ok) {
|
||||
const userData: User = await response.json();
|
||||
return userData;
|
||||
}
|
||||
|
||||
return userData;
|
||||
}
|
||||
|
||||
if (response.status === 401 || response.status === 403) {
|
||||
return null;
|
||||
} catch (error) {
|
||||
console.error("Error fetching user:", error);
|
||||
return null;
|
||||
}
|
||||
|
||||
throw new Error(`Failed to fetch user: ${response.status}`);
|
||||
}
|
||||
|
||||
export async function fetchConnectUrl(): Promise<string> {
|
||||
const response = await fetch(`${API_BASE}/api/me`, {
|
||||
method: "POST",
|
||||
const response = await fetch(`${API_BASE}/api/v1/connect`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (response.status === 401) {
|
||||
const data = await response.json();
|
||||
if (data.signin_url) {
|
||||
return data.signin_url;
|
||||
}
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch connect URL");
|
||||
}
|
||||
|
||||
throw new Error("Failed to fetch connect URL");
|
||||
const data = await response.json();
|
||||
return data.connect_url;
|
||||
}
|
||||
|
||||
export async function disconnectUser(): Promise<void> {
|
||||
const response = await fetch(`${API_BASE}/api/signout`, {
|
||||
const response = await fetch(`${API_BASE}/api/v1/disconnect`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
@@ -209,10 +205,12 @@ export async function* sendMessage(
|
||||
data: uint8ArrayToBase64(att.data),
|
||||
}));
|
||||
|
||||
// Send think parameter when it's explicitly set (true, false, or a non-empty string).
|
||||
// Only send think parameter when actually requesting thinking
|
||||
// Don't send false as it causes issues with some providers
|
||||
const shouldSendThink =
|
||||
think !== undefined &&
|
||||
(typeof think === "boolean" || (typeof think === "string" && think !== ""));
|
||||
((typeof think === "boolean" && think) ||
|
||||
(typeof think === "string" && think !== ""));
|
||||
|
||||
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
|
||||
method: "POST",
|
||||
@@ -394,8 +392,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
|
||||
|
||||
export async function fetchHealth(): Promise<boolean> {
|
||||
try {
|
||||
// Use the /api/version endpoint as a health check
|
||||
const response = await fetch(`${API_BASE}/api/version`, {
|
||||
const response = await fetch(`${API_BASE}/api/v1/health`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
@@ -404,8 +401,7 @@ export async function fetchHealth(): Promise<boolean> {
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
// If we get a version back, the server is healthy
|
||||
return !!data.version;
|
||||
return data.healthy || false;
|
||||
}
|
||||
|
||||
return false;
|
||||
|
||||
@@ -4,6 +4,11 @@ import { FileUpload } from "./FileUpload";
|
||||
import { DisplayUpgrade } from "./DisplayUpgrade";
|
||||
import { DisplayStale } from "./DisplayStale";
|
||||
import { DisplayLogin } from "./DisplayLogin";
|
||||
import {
|
||||
ConversationContent,
|
||||
ConversationScrollButton,
|
||||
ConversationWithSpacer,
|
||||
} from "./ai-elements/conversation";
|
||||
import {
|
||||
useChat,
|
||||
useSendMessage,
|
||||
@@ -15,14 +20,7 @@ import {
|
||||
useDismissStaleModel,
|
||||
} from "@/hooks/useChats";
|
||||
import { useHealth } from "@/hooks/useHealth";
|
||||
import { useMessageAutoscroll } from "@/hooks/useMessageAutoscroll";
|
||||
import {
|
||||
useState,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useRef,
|
||||
useCallback,
|
||||
} from "react";
|
||||
import { useState, useEffect, useRef, useCallback } from "react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useNavigate } from "@tanstack/react-router";
|
||||
import { useSelectedModel } from "@/hooks/useSelectedModel";
|
||||
@@ -47,7 +45,6 @@ export default function Chat({ chatId }: { chatId: string }) {
|
||||
index: number;
|
||||
originalMessage: Message;
|
||||
} | null>(null);
|
||||
const prevChatIdRef = useRef<string>(chatId);
|
||||
|
||||
const chatFormCallbackRef = useRef<
|
||||
| ((
|
||||
@@ -102,29 +99,8 @@ export default function Chat({ chatId }: { chatId: string }) {
|
||||
|
||||
const sendMessageMutation = useSendMessage(chatId);
|
||||
|
||||
const { containerRef, handleNewUserMessage, spacerHeight } =
|
||||
useMessageAutoscroll({
|
||||
messages,
|
||||
isStreaming,
|
||||
chatId,
|
||||
});
|
||||
const latestMessageRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Scroll to bottom only when switching to a different existing chat
|
||||
useLayoutEffect(() => {
|
||||
// Only scroll if the chatId actually changed (not just messages updating)
|
||||
if (
|
||||
prevChatIdRef.current !== chatId &&
|
||||
containerRef.current &&
|
||||
messages.length > 0 &&
|
||||
chatId !== "new"
|
||||
) {
|
||||
// Always scroll to the bottom when opening a chat
|
||||
containerRef.current.scrollTop = containerRef.current.scrollHeight;
|
||||
}
|
||||
prevChatIdRef.current = chatId;
|
||||
}, [chatId, messages.length]);
|
||||
|
||||
// Simplified submit handler - ChatForm handles all the attachment logic
|
||||
const handleChatFormSubmit = (
|
||||
message: string,
|
||||
options: {
|
||||
@@ -168,7 +144,6 @@ export default function Chat({ chatId }: { chatId: string }) {
|
||||
|
||||
// Clear edit mode after submission
|
||||
setEditingMessage(null);
|
||||
handleNewUserMessage();
|
||||
};
|
||||
|
||||
const handleEditMessage = (content: string, index: number) => {
|
||||
@@ -193,8 +168,6 @@ export default function Chat({ chatId }: { chatId: string }) {
|
||||
);
|
||||
};
|
||||
|
||||
const isWindows = navigator.platform.toLowerCase().includes("win");
|
||||
|
||||
return chatId === "new" || chatQuery ? (
|
||||
<FileUpload
|
||||
onFilesAdded={handleFilesProcessed}
|
||||
@@ -219,25 +192,29 @@ export default function Chat({ chatId }: { chatId: string }) {
|
||||
</div>
|
||||
) : (
|
||||
<main className="flex h-screen w-full flex-col relative allow-context-menu select-none">
|
||||
<section
|
||||
key={chatId} // This key forces React to recreate the element when chatId changes
|
||||
ref={containerRef}
|
||||
className={`flex-1 overflow-y-auto overscroll-contain relative min-h-0 select-none ${isWindows ? "xl:pt-4" : "xl:pt-8"}`}
|
||||
<ConversationWithSpacer
|
||||
key={chatId}
|
||||
className={`flex-1 overscroll-contain select-none`}
|
||||
isStreaming={isStreaming}
|
||||
messageCount={messages.length}
|
||||
>
|
||||
<MessageList
|
||||
messages={messages}
|
||||
spacerHeight={spacerHeight}
|
||||
isWaitingForLoad={isWaitingForLoad}
|
||||
isStreaming={isStreaming}
|
||||
downloadProgress={downloadProgress}
|
||||
onEditMessage={(content: string, index: number) => {
|
||||
handleEditMessage(content, index);
|
||||
}}
|
||||
editingMessageIndex={editingMessage?.index}
|
||||
error={chatError}
|
||||
browserToolResult={browserToolResult}
|
||||
/>
|
||||
</section>
|
||||
<ConversationContent isStreaming={isStreaming}>
|
||||
<MessageList
|
||||
messages={messages}
|
||||
isWaitingForLoad={isWaitingForLoad}
|
||||
isStreaming={isStreaming}
|
||||
downloadProgress={downloadProgress}
|
||||
onEditMessage={(content: string, index: number) => {
|
||||
handleEditMessage(content, index);
|
||||
}}
|
||||
editingMessageIndex={editingMessage?.index}
|
||||
error={chatError}
|
||||
browserToolResult={browserToolResult}
|
||||
latestMessageRef={latestMessageRef}
|
||||
/>
|
||||
</ConversationContent>
|
||||
<ConversationScrollButton />
|
||||
</ConversationWithSpacer>
|
||||
|
||||
<div className="flex-shrink-0 sticky bottom-0 z-20">
|
||||
{selectedModel && shouldShowStaleDisplay && (
|
||||
@@ -248,14 +225,6 @@ export default function Chat({ chatId }: { chatId: string }) {
|
||||
dismissStaleModel(selectedModel?.model || "")
|
||||
}
|
||||
chatId={chatId}
|
||||
onScrollToBottom={() => {
|
||||
if (containerRef.current) {
|
||||
containerRef.current.scrollTo({
|
||||
top: containerRef.current.scrollHeight,
|
||||
behavior: "smooth",
|
||||
});
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -3,10 +3,10 @@ import React from "react";
|
||||
import Message from "./Message";
|
||||
import Downloading from "./Downloading";
|
||||
import { ErrorMessage } from "./ErrorMessage";
|
||||
import { ConversationSpacer } from "./ai-elements/conversation";
|
||||
|
||||
export default function MessageList({
|
||||
messages,
|
||||
spacerHeight,
|
||||
isWaitingForLoad,
|
||||
isStreaming,
|
||||
downloadProgress,
|
||||
@@ -14,9 +14,9 @@ export default function MessageList({
|
||||
editingMessageIndex,
|
||||
error,
|
||||
browserToolResult,
|
||||
latestMessageRef,
|
||||
}: {
|
||||
messages: MessageType[];
|
||||
spacerHeight: number;
|
||||
isWaitingForLoad?: boolean;
|
||||
isStreaming: boolean;
|
||||
downloadProgress?: DownloadEvent;
|
||||
@@ -24,6 +24,7 @@ export default function MessageList({
|
||||
editingMessageIndex?: number;
|
||||
error?: ErrorEvent | null;
|
||||
browserToolResult?: any;
|
||||
latestMessageRef?: React.RefObject<HTMLDivElement | null>;
|
||||
}) {
|
||||
const [showDots, setShowDots] = React.useState(false);
|
||||
const isDownloadingModel = downloadProgress && !downloadProgress.done;
|
||||
@@ -84,13 +85,19 @@ export default function MessageList({
|
||||
|
||||
return (
|
||||
<div
|
||||
className="mx-auto flex max-w-[768px] flex-1 flex-col px-6 pb-12 select-text"
|
||||
className="mx-auto flex max-w-[768px] w-full flex-1 flex-col px-6 py-12 select-text"
|
||||
data-role="message-list"
|
||||
>
|
||||
{messages.map((message, idx) => {
|
||||
const lastToolQuery = lastToolQueries[idx];
|
||||
const isLastMessage = idx === messages.length - 1;
|
||||
return (
|
||||
<div key={`${message.created_at}-${idx}`} data-message-index={idx}>
|
||||
<div
|
||||
key={`${message.created_at}-${idx}`}
|
||||
data-message-index={idx}
|
||||
data-message-role={message.role}
|
||||
ref={isLastMessage ? latestMessageRef : null}
|
||||
>
|
||||
<Message
|
||||
message={message}
|
||||
onEditMessage={onEditMessage}
|
||||
@@ -161,8 +168,8 @@ export default function MessageList({
|
||||
</section>
|
||||
)}
|
||||
|
||||
{/* Dynamic spacer to allow scrolling the last message to the top of the container */}
|
||||
<div style={{ height: `${spacerHeight}px` }} aria-hidden="true" />
|
||||
{/* Dynamic spacer */}
|
||||
<ConversationSpacer />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -299,9 +299,9 @@ export default function Settings() {
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{user?.avatarurl && (
|
||||
{user?.avatarURL && (
|
||||
<img
|
||||
src={user.avatarurl}
|
||||
src={user.avatarURL}
|
||||
alt={user?.name}
|
||||
className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0"
|
||||
onError={(e) => {
|
||||
|
||||
@@ -50,33 +50,21 @@ export default function Thinking({
|
||||
// Position content to show bottom when collapsed
|
||||
useEffect(() => {
|
||||
if (isCollapsed && contentRef.current && wrapperRef.current) {
|
||||
requestAnimationFrame(() => {
|
||||
if (!contentRef.current || !wrapperRef.current) return;
|
||||
|
||||
const contentHeight = contentRef.current.scrollHeight;
|
||||
const wrapperHeight = wrapperRef.current.clientHeight;
|
||||
if (contentHeight > wrapperHeight) {
|
||||
const translateY = -(contentHeight - wrapperHeight);
|
||||
contentRef.current.style.transform = `translateY(${translateY}px)`;
|
||||
setHasOverflow(true);
|
||||
} else {
|
||||
contentRef.current.style.transform = "translateY(0)";
|
||||
setHasOverflow(false);
|
||||
}
|
||||
});
|
||||
const contentHeight = contentRef.current.scrollHeight;
|
||||
const wrapperHeight = wrapperRef.current.clientHeight;
|
||||
if (contentHeight > wrapperHeight) {
|
||||
const translateY = -(contentHeight - wrapperHeight);
|
||||
contentRef.current.style.transform = `translateY(${translateY}px)`;
|
||||
setHasOverflow(true);
|
||||
} else {
|
||||
setHasOverflow(false);
|
||||
}
|
||||
} else if (contentRef.current) {
|
||||
contentRef.current.style.transform = "translateY(0)";
|
||||
setHasOverflow(false);
|
||||
}
|
||||
}, [thinking, isCollapsed]);
|
||||
|
||||
useEffect(() => {
|
||||
if (activelyThinking && wrapperRef.current && !isCollapsed) {
|
||||
// When expanded and actively thinking, scroll to bottom
|
||||
wrapperRef.current.scrollTop = wrapperRef.current.scrollHeight;
|
||||
}
|
||||
}, [thinking, activelyThinking, isCollapsed]);
|
||||
|
||||
const handleToggle = () => {
|
||||
setIsCollapsed(!isCollapsed);
|
||||
setHasUserInteracted(true);
|
||||
|
||||
385
app/ui/app/src/components/ai-elements/conversation.tsx
Normal file
385
app/ui/app/src/components/ai-elements/conversation.tsx
Normal file
@@ -0,0 +1,385 @@
|
||||
"use client";
|
||||
|
||||
import clsx from "clsx";
|
||||
import type { ComponentProps } from "react";
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
createContext,
|
||||
useContext,
|
||||
useState,
|
||||
} from "react";
|
||||
import { StickToBottom, useStickToBottomContext } from "use-stick-to-bottom";
|
||||
|
||||
// Create a context to share the "allow scroll" state and spacer state
|
||||
const ConversationControlContext = createContext<{
|
||||
allowScroll: () => void;
|
||||
spacerHeight: number;
|
||||
setSpacerHeight: (height: number) => void;
|
||||
} | null>(null);
|
||||
|
||||
export type ConversationProps = ComponentProps<typeof StickToBottom> & {
|
||||
isStreaming?: boolean;
|
||||
};
|
||||
|
||||
export const Conversation = ({
|
||||
className,
|
||||
isStreaming = false,
|
||||
children,
|
||||
...props
|
||||
}: ConversationProps) => {
|
||||
const shouldStopScrollRef = useRef(true);
|
||||
const [spacerHeight, setSpacerHeight] = useState(0);
|
||||
|
||||
const allowScroll = useCallback(() => {
|
||||
shouldStopScrollRef.current = false;
|
||||
setTimeout(() => {
|
||||
shouldStopScrollRef.current = true;
|
||||
}, 100);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<ConversationControlContext.Provider
|
||||
value={{ allowScroll, spacerHeight, setSpacerHeight }}
|
||||
>
|
||||
<StickToBottom
|
||||
className={clsx("relative h-full w-full overflow-y-auto", className)}
|
||||
initial="instant"
|
||||
resize="instant"
|
||||
role="log"
|
||||
{...props}
|
||||
>
|
||||
<ConversationContentInternal
|
||||
isStreaming={isStreaming}
|
||||
shouldStopScrollRef={shouldStopScrollRef}
|
||||
>
|
||||
<>{children}</>
|
||||
</ConversationContentInternal>
|
||||
</StickToBottom>
|
||||
</ConversationControlContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
// New wrapper component that includes spacer management
|
||||
export const ConversationWithSpacer = ({
|
||||
isStreaming,
|
||||
messageCount,
|
||||
children,
|
||||
...props
|
||||
}: ConversationProps & {
|
||||
messageCount: number;
|
||||
}) => {
|
||||
return (
|
||||
<Conversation isStreaming={isStreaming} {...props}>
|
||||
<SpacerController
|
||||
isStreaming={isStreaming ?? false}
|
||||
messageCount={messageCount}
|
||||
/>
|
||||
<>{children}</>
|
||||
</Conversation>
|
||||
);
|
||||
};
|
||||
|
||||
// This component manages the spacer state but doesn't render the spacer itself
|
||||
const SpacerController = ({
|
||||
isStreaming,
|
||||
messageCount,
|
||||
}: {
|
||||
isStreaming: boolean;
|
||||
messageCount: number;
|
||||
}) => {
|
||||
const context = useContext(ConversationControlContext);
|
||||
const { scrollToBottom } = useStickToBottomContext();
|
||||
const previousMessageCountRef = useRef(messageCount);
|
||||
const scrollContainerRef = useRef<HTMLElement | null>(null);
|
||||
const [isActiveInteraction, setIsActiveInteraction] = useState(false);
|
||||
|
||||
// Get reference to scroll container
|
||||
useEffect(() => {
|
||||
const container = document.querySelector('[role="log"]') as HTMLElement;
|
||||
scrollContainerRef.current = container;
|
||||
}, []);
|
||||
|
||||
// Calculate spacer height based on actual DOM elements
|
||||
const calculateSpacerHeight = useCallback(() => {
|
||||
const container = scrollContainerRef.current;
|
||||
if (!container) {
|
||||
console.log("❌ No container");
|
||||
return 0;
|
||||
}
|
||||
|
||||
const containerHeight = container.clientHeight;
|
||||
|
||||
// Find all message elements
|
||||
const messageElements = container.querySelectorAll(
|
||||
"[data-message-index]",
|
||||
) as NodeListOf<HTMLElement>;
|
||||
console.log("📝 Found", messageElements.length, "message elements");
|
||||
|
||||
if (messageElements.length === 0) return 0;
|
||||
|
||||
// Log all messages and their roles
|
||||
messageElements.forEach((el, i) => {
|
||||
const role = el.getAttribute("data-message-role");
|
||||
console.log(` Message ${i}: role="${role}"`);
|
||||
});
|
||||
|
||||
// Find the last user message
|
||||
let lastUserMessageElement: HTMLElement | null = null;
|
||||
let lastUserMessageIndex = -1;
|
||||
|
||||
for (let i = messageElements.length - 1; i >= 0; i--) {
|
||||
const el = messageElements[i];
|
||||
const role = el.getAttribute("data-message-role");
|
||||
if (role === "user") {
|
||||
lastUserMessageElement = el;
|
||||
lastUserMessageIndex = i;
|
||||
console.log("✅ Found user message at index:", i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!lastUserMessageElement) {
|
||||
console.log("❌ No user message found!");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Calculate height of content after the last user message
|
||||
let contentHeightAfter = 0;
|
||||
for (let i = lastUserMessageIndex + 1; i < messageElements.length; i++) {
|
||||
contentHeightAfter += messageElements[i].offsetHeight;
|
||||
}
|
||||
|
||||
const userMessageHeight = lastUserMessageElement.offsetHeight;
|
||||
|
||||
// Goal: Position user message at the top with some padding
|
||||
// We want the user message to start at around 10% from the top of viewport
|
||||
const targetTopPosition = containerHeight * 0.05; // 10% from top
|
||||
|
||||
// Calculate spacer: we need enough space so that when scrolled to bottom:
|
||||
// spacerHeight = containerHeight - targetTopPosition - userMessageHeight - contentAfter
|
||||
const calculatedHeight =
|
||||
containerHeight -
|
||||
targetTopPosition -
|
||||
userMessageHeight -
|
||||
contentHeightAfter;
|
||||
|
||||
const baseHeight = Math.max(0, calculatedHeight);
|
||||
|
||||
console.log(
|
||||
"📊 Container:",
|
||||
containerHeight,
|
||||
"User msg:",
|
||||
userMessageHeight,
|
||||
"Content after:",
|
||||
contentHeightAfter,
|
||||
"Target top pos:",
|
||||
targetTopPosition,
|
||||
"→ Final spacer:",
|
||||
baseHeight,
|
||||
);
|
||||
|
||||
return baseHeight;
|
||||
}, []);
|
||||
|
||||
// When a new message is submitted, set initial spacer height and scroll
|
||||
useEffect(() => {
|
||||
if (messageCount > previousMessageCountRef.current) {
|
||||
console.log("🎯 NEW MESSAGE - Setting spacer");
|
||||
// Allow scrolling by temporarily disabling stopScroll
|
||||
context?.allowScroll();
|
||||
setIsActiveInteraction(true);
|
||||
|
||||
const container = scrollContainerRef.current;
|
||||
|
||||
if (container) {
|
||||
// Wait for new message to render in DOM
|
||||
requestAnimationFrame(() => {
|
||||
requestAnimationFrame(() => {
|
||||
const spacerHeight = calculateSpacerHeight();
|
||||
console.log("📏 Calculated spacer:", spacerHeight);
|
||||
context?.setSpacerHeight(spacerHeight);
|
||||
|
||||
// Wait for spacer to be added to DOM, then scroll
|
||||
requestAnimationFrame(() => {
|
||||
requestAnimationFrame(() => {
|
||||
// Use the library's scrollToBottom method
|
||||
console.log("📜 Scrolling to bottom");
|
||||
scrollToBottom("instant");
|
||||
console.log("📜 Final scrollTop:", container.scrollTop);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
previousMessageCountRef.current = messageCount;
|
||||
}, [messageCount, context, calculateSpacerHeight, scrollToBottom]);
|
||||
|
||||
// Update active interaction state
|
||||
useEffect(() => {
|
||||
if (isStreaming) {
|
||||
setIsActiveInteraction(true);
|
||||
}
|
||||
// Don't automatically set to false when streaming stops
|
||||
// Let the ResizeObserver handle clearing the spacer naturally
|
||||
}, [isStreaming]);
|
||||
|
||||
// Use ResizeObserver to recalculate spacer as content changes
|
||||
useEffect(() => {
|
||||
const container = scrollContainerRef.current;
|
||||
if (!container || !isActiveInteraction) return;
|
||||
|
||||
const resizeObserver = new ResizeObserver(() => {
|
||||
const newHeight = calculateSpacerHeight();
|
||||
context?.setSpacerHeight(newHeight);
|
||||
|
||||
// Clear active interaction when spacer reaches 0
|
||||
if (newHeight === 0) {
|
||||
setIsActiveInteraction(false);
|
||||
}
|
||||
});
|
||||
|
||||
// Observe all message elements
|
||||
const messageElements = container.querySelectorAll("[data-message-index]");
|
||||
messageElements.forEach((element) => {
|
||||
resizeObserver.observe(element);
|
||||
});
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
}, [isActiveInteraction, calculateSpacerHeight, context]);
|
||||
|
||||
// Remove the effect that clears spacer when not streaming
|
||||
// This was causing the spacer to disappear prematurely
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
const ConversationContentInternal = ({
|
||||
isStreaming,
|
||||
shouldStopScrollRef,
|
||||
children,
|
||||
}: {
|
||||
isStreaming: boolean;
|
||||
shouldStopScrollRef: React.MutableRefObject<boolean>;
|
||||
children: React.ReactNode;
|
||||
}) => {
|
||||
const { stopScroll } = useStickToBottomContext();
|
||||
const stopScrollRef = useRef(stopScroll);
|
||||
|
||||
useEffect(() => {
|
||||
stopScrollRef.current = stopScroll;
|
||||
}, [stopScroll]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isStreaming) return;
|
||||
|
||||
const interval = setInterval(() => {
|
||||
if (shouldStopScrollRef.current) {
|
||||
stopScrollRef.current();
|
||||
}
|
||||
}, 16);
|
||||
|
||||
if (shouldStopScrollRef.current) {
|
||||
stopScrollRef.current();
|
||||
}
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [isStreaming, shouldStopScrollRef]);
|
||||
|
||||
return <>{children}</>;
|
||||
};
|
||||
|
||||
export type ConversationContentProps = ComponentProps<
|
||||
typeof StickToBottom.Content
|
||||
> & {
|
||||
isStreaming?: boolean;
|
||||
};
|
||||
|
||||
export const ConversationContent = ({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: ConversationContentProps) => {
|
||||
return (
|
||||
<StickToBottom.Content
|
||||
className={clsx("flex flex-col", className)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</StickToBottom.Content>
|
||||
);
|
||||
};
|
||||
|
||||
// Spacer component that can be placed anywhere in your content
|
||||
export const ConversationSpacer = () => {
|
||||
const context = useContext(ConversationControlContext);
|
||||
const spacerHeight = context?.spacerHeight ?? 0;
|
||||
|
||||
console.log("🎨 Spacer render - height:", spacerHeight);
|
||||
|
||||
if (spacerHeight === 0) return null;
|
||||
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
height: `${spacerHeight}px`,
|
||||
flexShrink: 0,
|
||||
backgroundColor: "rgba(255,0,0,0.1)", // Temporary for debugging
|
||||
}}
|
||||
aria-hidden="true"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export type ConversationScrollButtonProps = ComponentProps<"button">;
|
||||
|
||||
export const ConversationScrollButton = ({
|
||||
className,
|
||||
...props
|
||||
}: ConversationScrollButtonProps) => {
|
||||
const { isAtBottom, scrollToBottom } = useStickToBottomContext();
|
||||
const context = useContext(ConversationControlContext);
|
||||
|
||||
const handleScrollToBottom = useCallback(() => {
|
||||
context?.allowScroll();
|
||||
scrollToBottom();
|
||||
}, [scrollToBottom, context]);
|
||||
|
||||
// Show button if not at bottom AND spacer is not active (height is 0)
|
||||
const shouldShowButton = !isAtBottom && (context?.spacerHeight ?? 0) === 0;
|
||||
|
||||
return (
|
||||
shouldShowButton && (
|
||||
<button
|
||||
className={clsx(
|
||||
"absolute bottom-4 left-[50%] translate-x-[-50%] rounded-full z-50",
|
||||
"bg-white dark:bg-neutral-800 border border-neutral-200 dark:border-neutral-700",
|
||||
"p-1 shadow-lg hover:shadow-xl transition-all",
|
||||
"text-neutral-700 dark:text-neutral-200 hover:scale-105",
|
||||
"hover:cursor-pointer",
|
||||
className,
|
||||
)}
|
||||
onClick={handleScrollToBottom}
|
||||
type="button"
|
||||
aria-label="Scroll to bottom"
|
||||
{...props}
|
||||
>
|
||||
<svg
|
||||
className="w-5 h-5"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<polyline points="6 9 12 15 18 9" />
|
||||
</svg>
|
||||
</button>
|
||||
)
|
||||
);
|
||||
};
|
||||
8
app/ui/app/src/components/ai-elements/index.ts
Normal file
8
app/ui/app/src/components/ai-elements/index.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
export {
|
||||
Conversation,
|
||||
ConversationContent,
|
||||
ConversationScrollButton,
|
||||
type ConversationProps,
|
||||
type ConversationContentProps,
|
||||
type ConversationScrollButtonProps,
|
||||
} from "./conversation";
|
||||
@@ -7,7 +7,6 @@ import { createQueryBatcher } from "./useQueryBatcher";
|
||||
import { useRefetchModels } from "./useModels";
|
||||
import { useStreamingContext } from "@/contexts/StreamingContext";
|
||||
import { useSettings } from "./useSettings";
|
||||
import { getModelCapabilities } from "@/api";
|
||||
|
||||
export const useChats = () => {
|
||||
return useQuery({
|
||||
@@ -607,24 +606,6 @@ export const useSendMessage = (chatId: string) => {
|
||||
queryClient.setQueryData(["staleModels"], newStaleMap);
|
||||
|
||||
queryClient.invalidateQueries({ queryKey: ["models"] });
|
||||
|
||||
// Fetch fresh capabilities for the downloaded model
|
||||
getModelCapabilities(selectedModel.model)
|
||||
.then((capabilities) => {
|
||||
queryClient.setQueryData(
|
||||
["modelCapabilities", selectedModel.model],
|
||||
capabilities,
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error(
|
||||
"Failed to fetch capabilities after download:",
|
||||
error,
|
||||
);
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["modelCapabilities", selectedModel.model],
|
||||
});
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
114
app/ui/app/src/hooks/useDownloadModel.ts
Normal file
114
app/ui/app/src/hooks/useDownloadModel.ts
Normal file
@@ -0,0 +1,114 @@
|
||||
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { useState } from "react";
|
||||
import { pullModel } from "@/api";
|
||||
import { useSelectedModel } from "./useSelectedModel";
|
||||
import { useSettings } from "./useSettings";
|
||||
|
||||
interface DownloadProgress {
|
||||
status: string;
|
||||
digest?: string;
|
||||
total?: number;
|
||||
completed?: number;
|
||||
done?: boolean;
|
||||
}
|
||||
|
||||
export function useDownloadModel(chatId?: string) {
|
||||
const queryClient = useQueryClient();
|
||||
const { selectedModel } = useSelectedModel(chatId);
|
||||
const { setSettings } = useSettings();
|
||||
const [downloadProgress, setDownloadProgress] =
|
||||
useState<DownloadProgress | null>(null);
|
||||
const [abortController, setAbortController] =
|
||||
useState<AbortController | null>(null);
|
||||
const [downloadingChatIds, setDownloadingChatIds] = useState<Set<string>>(
|
||||
new Set(),
|
||||
);
|
||||
|
||||
const mutation = useMutation({
|
||||
mutationFn: async (modelName: string) => {
|
||||
const controller = new AbortController();
|
||||
setAbortController(controller);
|
||||
setDownloadProgress({ status: "Starting download..." });
|
||||
if (chatId) {
|
||||
setDownloadingChatIds((prev) => new Set(prev).add(chatId));
|
||||
}
|
||||
|
||||
try {
|
||||
for await (const progress of pullModel(modelName, controller.signal)) {
|
||||
setDownloadProgress(progress);
|
||||
|
||||
if (progress.status === "success") {
|
||||
// Update selected model to indicate it's now available locally
|
||||
if (selectedModel && selectedModel.model === modelName) {
|
||||
setSettings({ SelectedModel: modelName });
|
||||
}
|
||||
// Invalidate models query to refresh the list
|
||||
await queryClient.invalidateQueries({ queryKey: ["models"] });
|
||||
break;
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
setAbortController(null);
|
||||
if (chatId) {
|
||||
setDownloadingChatIds((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(chatId);
|
||||
return newSet;
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
onSuccess: () => {
|
||||
setDownloadProgress(null);
|
||||
if (chatId) {
|
||||
setDownloadingChatIds((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(chatId);
|
||||
return newSet;
|
||||
});
|
||||
}
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
const status =
|
||||
error.name === "AbortError" ? "Download cancelled" : "Download failed";
|
||||
setDownloadProgress({ status, done: true });
|
||||
|
||||
// Clear error message after delay
|
||||
const delay = error.name === "AbortError" ? 1500 : 3000;
|
||||
setTimeout(() => {
|
||||
setDownloadProgress(null);
|
||||
if (chatId) {
|
||||
setDownloadingChatIds((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(chatId);
|
||||
return newSet;
|
||||
});
|
||||
}
|
||||
}, delay);
|
||||
},
|
||||
});
|
||||
|
||||
const cancelDownload = () => {
|
||||
if (abortController) {
|
||||
abortController.abort();
|
||||
setAbortController(null);
|
||||
if (chatId) {
|
||||
setDownloadingChatIds((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(chatId);
|
||||
return newSet;
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
downloadModel: mutation.mutate,
|
||||
isDownloading:
|
||||
mutation.isPending && chatId ? downloadingChatIds.has(chatId) : false,
|
||||
downloadProgress:
|
||||
chatId && downloadingChatIds.has(chatId) ? downloadProgress : null,
|
||||
error: mutation.error,
|
||||
cancelDownload,
|
||||
};
|
||||
}
|
||||
@@ -1,347 +0,0 @@
|
||||
import {
|
||||
useRef,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useState,
|
||||
useMemo,
|
||||
} from "react";
|
||||
import type { Message } from "@/gotypes";
|
||||
|
||||
// warning: this file is all claude code, needs to be looked into more closely
|
||||
|
||||
interface UseMessageAutoscrollOptions {
|
||||
messages: Message[];
|
||||
isStreaming: boolean;
|
||||
chatId: string;
|
||||
}
|
||||
|
||||
interface MessageAutoscrollBehavior {
|
||||
handleNewUserMessage: () => void;
|
||||
containerRef: React.RefObject<HTMLElement | null>;
|
||||
spacerHeight: number;
|
||||
}
|
||||
|
||||
export const useMessageAutoscroll = ({
|
||||
messages,
|
||||
isStreaming,
|
||||
chatId,
|
||||
}: UseMessageAutoscrollOptions): MessageAutoscrollBehavior => {
|
||||
const containerRef = useRef<HTMLElement | null>(null);
|
||||
const pendingScrollToUserMessage = useRef(false);
|
||||
const [spacerHeight, setSpacerHeight] = useState(0);
|
||||
const lastScrollHeightRef = useRef(0);
|
||||
const lastScrollTopRef = useRef(0);
|
||||
const [isActiveInteraction, setIsActiveInteraction] = useState(false);
|
||||
const [hasSubmittedMessage, setHasSubmittedMessage] = useState(false);
|
||||
const prevChatIdRef = useRef<string>(chatId);
|
||||
|
||||
// Find the last user message index from React state
|
||||
const getLastUserMessageIndex = useCallback(() => {
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].role === "user") {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}, [messages]);
|
||||
|
||||
const scrollToMessage = useCallback((messageIndex: number) => {
|
||||
if (!containerRef.current || messageIndex < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const container = containerRef.current;
|
||||
// select the exact element by its data-message-index to avoid index mismatches
|
||||
const targetElement = container.querySelector(
|
||||
`[data-message-index="${messageIndex}"]`,
|
||||
) as HTMLElement | null;
|
||||
|
||||
if (!targetElement) return;
|
||||
|
||||
const containerHeight = container.clientHeight;
|
||||
const containerStyle = window.getComputedStyle(container);
|
||||
const paddingTop = parseFloat(containerStyle.paddingTop) || 0;
|
||||
const scrollHeight = container.scrollHeight;
|
||||
const messageHeight = targetElement.offsetHeight;
|
||||
|
||||
// Check if the message is large, which is 70% of the container height
|
||||
const isLarge = messageHeight > containerHeight * 0.7;
|
||||
|
||||
let targetPosition: number = targetElement.offsetTop - paddingTop; // default to scrolling the message to the top of the window
|
||||
|
||||
if (isLarge) {
|
||||
// when the message is large scroll to the bottom of it
|
||||
targetPosition = scrollHeight - containerHeight;
|
||||
}
|
||||
|
||||
// Ensure we don't scroll past content boundaries
|
||||
const maxScroll = scrollHeight - containerHeight;
|
||||
const finalPosition = Math.min(Math.max(0, targetPosition), maxScroll);
|
||||
|
||||
container.scrollTo({
|
||||
top: finalPosition,
|
||||
behavior: "smooth",
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Calculate and set the spacer height based on container dimensions
|
||||
const updateSpacerHeight = useCallback(() => {
|
||||
if (!containerRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
const containerHeight = containerRef.current.clientHeight;
|
||||
|
||||
// Find the last user message to calculate spacer for
|
||||
const lastUserIndex = getLastUserMessageIndex();
|
||||
|
||||
if (lastUserIndex < 0) {
|
||||
setSpacerHeight(0);
|
||||
return;
|
||||
}
|
||||
|
||||
const messageElements = containerRef.current.querySelectorAll(
|
||||
"[data-message-index]",
|
||||
) as NodeListOf<HTMLElement>;
|
||||
|
||||
if (!messageElements || messageElements.length === 0) {
|
||||
setSpacerHeight(0);
|
||||
return;
|
||||
}
|
||||
|
||||
const targetElement = containerRef.current.querySelector(
|
||||
`[data-message-index="${lastUserIndex}"]`,
|
||||
) as HTMLElement | null;
|
||||
|
||||
if (!targetElement) {
|
||||
setSpacerHeight(0);
|
||||
return;
|
||||
}
|
||||
|
||||
const elementsAfter = Array.from(messageElements).filter((el) => {
|
||||
const idx = Number(el.dataset.messageIndex);
|
||||
return Number.isFinite(idx) && idx > lastUserIndex;
|
||||
});
|
||||
|
||||
const contentHeightAfterTarget = elementsAfter.reduce(
|
||||
(sum, el) => sum + el.offsetHeight,
|
||||
0,
|
||||
);
|
||||
|
||||
// Calculate the spacer height needed to position the user message at the top
|
||||
// Add extra space for assistant response area
|
||||
const targetMessageHeight = targetElement.offsetHeight;
|
||||
|
||||
// Calculate spacer to position the last user message at the top
|
||||
// For new messages, we want them to appear at the top regardless of content after
|
||||
// For large messages, we want to preserve the scroll-to-bottom behavior
|
||||
// which shows part of the message and space for streaming response
|
||||
let baseHeight: number;
|
||||
|
||||
if (contentHeightAfterTarget === 0) {
|
||||
// No content after the user message (new message case)
|
||||
// Position it at the top with some padding
|
||||
baseHeight = Math.max(0, containerHeight - targetMessageHeight);
|
||||
} else {
|
||||
// Content exists after the user message
|
||||
// Calculate spacer to position user message at top
|
||||
baseHeight = Math.max(
|
||||
0,
|
||||
containerHeight - contentHeightAfterTarget - targetMessageHeight,
|
||||
);
|
||||
}
|
||||
|
||||
// Only apply spacer height when actively interacting (streaming or pending new message)
|
||||
// When just viewing a chat, don't add extra space
|
||||
if (!isActiveInteraction) {
|
||||
setSpacerHeight(0);
|
||||
return;
|
||||
}
|
||||
|
||||
// Add extra space for assistant response only when streaming
|
||||
const extraSpaceForAssistant = isStreaming ? containerHeight * 0.4 : 0;
|
||||
const calculatedHeight = baseHeight + extraSpaceForAssistant;
|
||||
|
||||
setSpacerHeight(calculatedHeight);
|
||||
}, [getLastUserMessageIndex, isStreaming, isActiveInteraction]);
|
||||
|
||||
// Handle new user message submission
|
||||
const handleNewUserMessage = useCallback(() => {
|
||||
// Mark that we're expecting a new message and should scroll to it
|
||||
pendingScrollToUserMessage.current = true;
|
||||
setIsActiveInteraction(true);
|
||||
setHasSubmittedMessage(true);
|
||||
}, []);
|
||||
|
||||
// Use layoutEffect to scroll immediately after DOM updates
|
||||
useLayoutEffect(() => {
|
||||
if (pendingScrollToUserMessage.current) {
|
||||
// Find the last user message from current state
|
||||
const targetUserIndex = getLastUserMessageIndex();
|
||||
|
||||
if (targetUserIndex >= 0) {
|
||||
requestAnimationFrame(() => {
|
||||
updateSpacerHeight();
|
||||
requestAnimationFrame(() => {
|
||||
scrollToMessage(targetUserIndex);
|
||||
pendingScrollToUserMessage.current = false;
|
||||
});
|
||||
});
|
||||
} else {
|
||||
pendingScrollToUserMessage.current = false;
|
||||
// Reset active interaction if no target found
|
||||
setIsActiveInteraction(isStreaming);
|
||||
}
|
||||
}
|
||||
}, [
|
||||
messages,
|
||||
getLastUserMessageIndex,
|
||||
scrollToMessage,
|
||||
updateSpacerHeight,
|
||||
isStreaming,
|
||||
]);
|
||||
|
||||
// Update active interaction state based on streaming and message submission
|
||||
useEffect(() => {
|
||||
if (
|
||||
isStreaming ||
|
||||
pendingScrollToUserMessage.current ||
|
||||
hasSubmittedMessage
|
||||
) {
|
||||
setIsActiveInteraction(true);
|
||||
} else {
|
||||
setIsActiveInteraction(false);
|
||||
}
|
||||
}, [isStreaming, hasSubmittedMessage]);
|
||||
|
||||
useEffect(() => {
|
||||
if (prevChatIdRef.current !== chatId) {
|
||||
setIsActiveInteraction(false);
|
||||
setHasSubmittedMessage(false);
|
||||
prevChatIdRef.current = chatId;
|
||||
}
|
||||
}, [chatId]);
|
||||
|
||||
// Recalculate spacer height when messages change
|
||||
useEffect(() => {
|
||||
updateSpacerHeight();
|
||||
}, [messages, updateSpacerHeight]);
|
||||
|
||||
// Use ResizeObserver to handle dynamic content changes
|
||||
useEffect(() => {
|
||||
if (!containerRef.current) return;
|
||||
|
||||
let resizeTimeout: ReturnType<typeof setTimeout>;
|
||||
let immediateUpdate = false;
|
||||
|
||||
const resizeObserver = new ResizeObserver((entries) => {
|
||||
// Check if this is a significant height change (like collapsing content)
|
||||
let hasSignificantChange = false;
|
||||
for (const entry of entries) {
|
||||
const element = entry.target as HTMLElement;
|
||||
if (
|
||||
element.dataset.messageIndex &&
|
||||
entry.contentRect.height !== element.offsetHeight
|
||||
) {
|
||||
const heightDiff = Math.abs(
|
||||
entry.contentRect.height - element.offsetHeight,
|
||||
);
|
||||
if (heightDiff > 50) {
|
||||
hasSignificantChange = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For significant changes, update immediately
|
||||
if (hasSignificantChange || immediateUpdate) {
|
||||
updateSpacerHeight();
|
||||
immediateUpdate = false;
|
||||
} else {
|
||||
// For small changes (like streaming text), debounce
|
||||
clearTimeout(resizeTimeout);
|
||||
resizeTimeout = setTimeout(() => {
|
||||
updateSpacerHeight();
|
||||
}, 100);
|
||||
}
|
||||
});
|
||||
|
||||
// Also use MutationObserver for immediate attribute changes
|
||||
const mutationObserver = new MutationObserver((mutations) => {
|
||||
// Check if any mutations are related to expanding/collapsing
|
||||
const hasToggle = mutations.some(
|
||||
(mutation) =>
|
||||
mutation.type === "attributes" &&
|
||||
(mutation.attributeName === "class" ||
|
||||
mutation.attributeName === "style" ||
|
||||
mutation.attributeName === "open" ||
|
||||
mutation.attributeName === "data-expanded"),
|
||||
);
|
||||
|
||||
if (hasToggle) {
|
||||
immediateUpdate = true;
|
||||
updateSpacerHeight();
|
||||
}
|
||||
});
|
||||
|
||||
// Observe the container and all messages
|
||||
resizeObserver.observe(containerRef.current);
|
||||
mutationObserver.observe(containerRef.current, {
|
||||
attributes: true,
|
||||
subtree: true,
|
||||
attributeFilter: ["class", "style", "open", "data-expanded"],
|
||||
});
|
||||
|
||||
// Observe all message elements for size changes
|
||||
const messageElements = containerRef.current.querySelectorAll(
|
||||
"[data-message-index]",
|
||||
);
|
||||
messageElements.forEach((element) => {
|
||||
resizeObserver.observe(element);
|
||||
});
|
||||
|
||||
return () => {
|
||||
clearTimeout(resizeTimeout);
|
||||
resizeObserver.disconnect();
|
||||
mutationObserver.disconnect();
|
||||
};
|
||||
}, [messages, updateSpacerHeight]);
|
||||
|
||||
// Track scroll position
|
||||
useEffect(() => {
|
||||
if (!containerRef.current) return;
|
||||
|
||||
const container = containerRef.current;
|
||||
const handleScroll = () => {
|
||||
lastScrollTopRef.current = container.scrollTop;
|
||||
lastScrollHeightRef.current = container.scrollHeight;
|
||||
};
|
||||
|
||||
container.addEventListener("scroll", handleScroll);
|
||||
|
||||
// Initialize scroll tracking
|
||||
lastScrollTopRef.current = container.scrollTop;
|
||||
lastScrollHeightRef.current = container.scrollHeight;
|
||||
|
||||
return () => {
|
||||
container.removeEventListener("scroll", handleScroll);
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Cleanup on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
pendingScrollToUserMessage.current = false;
|
||||
};
|
||||
}, []);
|
||||
|
||||
return useMemo(
|
||||
() => ({
|
||||
handleNewUserMessage,
|
||||
containerRef,
|
||||
spacerHeight,
|
||||
}),
|
||||
[handleNewUserMessage, containerRef, spacerHeight],
|
||||
);
|
||||
};
|
||||
@@ -1,20 +1,29 @@
|
||||
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useState } from "react";
|
||||
import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api";
|
||||
|
||||
export function useUser() {
|
||||
const queryClient = useQueryClient();
|
||||
const [initialDataLoaded, setInitialDataLoaded] = useState(false);
|
||||
|
||||
// Wait for initial data to be loaded
|
||||
useEffect(() => {
|
||||
const initialPromise = window.__initialUserDataPromise;
|
||||
if (initialPromise) {
|
||||
initialPromise.finally(() => {
|
||||
setInitialDataLoaded(true);
|
||||
});
|
||||
} else {
|
||||
setInitialDataLoaded(true);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const userQuery = useQuery({
|
||||
queryKey: ["user"],
|
||||
queryFn: async () => {
|
||||
const result = await fetchUser();
|
||||
return result;
|
||||
},
|
||||
queryFn: () => fetchUser(),
|
||||
staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes
|
||||
gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
|
||||
retry: 10,
|
||||
retryDelay: (attemptIndex) => Math.min(500 * attemptIndex, 2000),
|
||||
refetchOnMount: true, // Always fetch when component mounts
|
||||
initialData: null, // Start with null to prevent flashing
|
||||
});
|
||||
|
||||
// Mutation to refresh user data
|
||||
@@ -40,15 +49,14 @@ export function useUser() {
|
||||
},
|
||||
});
|
||||
|
||||
const isLoading = userQuery.isLoading || userQuery.isFetching;
|
||||
const isAuthenticated = Boolean(userQuery.data?.name);
|
||||
|
||||
return {
|
||||
user: userQuery.data,
|
||||
isLoading,
|
||||
isLoading:
|
||||
!initialDataLoaded ||
|
||||
(userQuery.isLoading && userQuery.data === undefined), // Show loading until initial data is loaded
|
||||
isError: userQuery.isError,
|
||||
error: userQuery.error,
|
||||
isAuthenticated,
|
||||
isAuthenticated: Boolean(userQuery.data?.name),
|
||||
refreshUser: refreshUser.mutate,
|
||||
isRefreshing: refreshUser.isPending,
|
||||
refetchUser: userQuery.refetch,
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
// API configuration
|
||||
const DEV_API_URL = "http://127.0.0.1:3001";
|
||||
|
||||
// Base URL for fetch API calls (can be relative in production)
|
||||
export const API_BASE = import.meta.env.DEV ? DEV_API_URL : "";
|
||||
|
||||
// Full host URL for Ollama client (needs full origin in production)
|
||||
export const OLLAMA_HOST = import.meta.env.DEV
|
||||
? DEV_API_URL
|
||||
: window.location.origin;
|
||||
|
||||
export const OLLAMA_DOT_COM =
|
||||
import.meta.env.VITE_OLLAMA_DOT_COM_URL || "https://ollama.com";
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Ollama } from "ollama/browser";
|
||||
import { OLLAMA_HOST } from "./config";
|
||||
|
||||
let _ollamaClient: Ollama | null = null;
|
||||
|
||||
@@ -7,7 +6,7 @@ export const ollamaClient = new Proxy({} as Ollama, {
|
||||
get(_target, prop) {
|
||||
if (!_ollamaClient) {
|
||||
_ollamaClient = new Ollama({
|
||||
host: OLLAMA_HOST,
|
||||
host: window.location.origin,
|
||||
});
|
||||
}
|
||||
const value = _ollamaClient[prop as keyof Ollama];
|
||||
|
||||
5
app/ui/app/src/lib/utils.ts
Normal file
5
app/ui/app/src/lib/utils.ts
Normal file
@@ -0,0 +1,5 @@
|
||||
import { clsx, type ClassValue } from "clsx";
|
||||
|
||||
export function cn(...inputs: ClassValue[]) {
|
||||
return clsx(inputs);
|
||||
}
|
||||
@@ -5,6 +5,13 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { routeTree } from "./routeTree.gen";
|
||||
import { fetchUser } from "./api";
|
||||
import { StreamingProvider } from "./contexts/StreamingContext";
|
||||
import { User } from "@/gotypes";
|
||||
|
||||
declare global {
|
||||
interface Window {
|
||||
__initialUserDataPromise?: Promise<User | null>;
|
||||
}
|
||||
}
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
@@ -17,11 +24,27 @@ const queryClient = new QueryClient({
|
||||
},
|
||||
});
|
||||
|
||||
fetchUser().then((userData) => {
|
||||
if (userData) {
|
||||
// Track initial user data fetch
|
||||
let initialUserDataPromise: Promise<User | null> | null = null;
|
||||
|
||||
// Initialize user data on app startup
|
||||
const initializeUserData = async () => {
|
||||
try {
|
||||
const userData = await fetchUser();
|
||||
queryClient.setQueryData(["user"], userData);
|
||||
return userData;
|
||||
} catch (error) {
|
||||
console.error("Error initializing user data:", error);
|
||||
queryClient.setQueryData(["user"], null);
|
||||
return null;
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// Start initialization immediately and track the promise
|
||||
initialUserDataPromise = initializeUserData();
|
||||
|
||||
// Export the promise so hooks can await it
|
||||
window.__initialUserDataPromise = initialUserDataPromise;
|
||||
|
||||
const router = createRouter({
|
||||
routeTree,
|
||||
|
||||
@@ -101,14 +101,15 @@ type HealthResponse struct {
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Bio string `json:"bio,omitempty"`
|
||||
AvatarURL string `json:"avatarurl,omitempty"`
|
||||
FirstName string `json:"firstname,omitempty"`
|
||||
LastName string `json:"lastname,omitempty"`
|
||||
Plan string `json:"plan,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarURL string `json:"avatarURL"`
|
||||
Plan string `json:"plan"`
|
||||
Bio string `json:"bio"`
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
OverThreshold bool `json:"overThreshold"`
|
||||
}
|
||||
|
||||
type Attachment struct {
|
||||
|
||||
241
app/ui/ui.go
241
app/ui/ui.go
@@ -12,17 +12,18 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/auth"
|
||||
"github.com/ollama/ollama/app/server"
|
||||
"github.com/ollama/ollama/app/store"
|
||||
"github.com/ollama/ollama/app/tools"
|
||||
@@ -117,66 +118,40 @@ func (s *Server) log() *slog.Logger {
|
||||
|
||||
// ollamaProxy creates a reverse proxy handler to the Ollama server
|
||||
func (s *Server) ollamaProxy() http.Handler {
|
||||
var (
|
||||
proxy http.Handler
|
||||
proxyMu sync.Mutex
|
||||
)
|
||||
ollamaHost := os.Getenv("OLLAMA_HOST")
|
||||
if ollamaHost == "" {
|
||||
ollamaHost = "http://127.0.0.1:11434"
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxyMu.Lock()
|
||||
p := proxy
|
||||
proxyMu.Unlock()
|
||||
if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
|
||||
ollamaHost = "http://" + ollamaHost
|
||||
}
|
||||
|
||||
if p == nil {
|
||||
proxyMu.Lock()
|
||||
if proxy == nil {
|
||||
var err error
|
||||
for i := range 2 {
|
||||
if i > 0 {
|
||||
s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
target, err := url.Parse(ollamaHost)
|
||||
if err != nil {
|
||||
s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
|
||||
})
|
||||
}
|
||||
|
||||
err = WaitForServer(context.Background(), 10*time.Second)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||
|
||||
if err != nil {
|
||||
proxyMu.Unlock()
|
||||
s.log().Error("ollama server not ready after retries", "error", err)
|
||||
http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
|
||||
target := envconfig.Host()
|
||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||
originalDirector := proxy.Director
|
||||
proxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = target.Host
|
||||
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
|
||||
}
|
||||
|
||||
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
|
||||
originalDirector := newProxy.Director
|
||||
newProxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = target.Host
|
||||
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
|
||||
}
|
||||
|
||||
newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
|
||||
proxy = newProxy
|
||||
p = newProxy
|
||||
} else {
|
||||
p = proxy
|
||||
}
|
||||
proxyMu.Unlock()
|
||||
}
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
})
|
||||
return proxy
|
||||
}
|
||||
|
||||
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
|
||||
@@ -289,10 +264,11 @@ func (s *Server) Handler() http.Handler {
|
||||
ollamaProxy := s.ollamaProxy()
|
||||
mux.Handle("GET /api/tags", ollamaProxy)
|
||||
mux.Handle("POST /api/show", ollamaProxy)
|
||||
mux.Handle("GET /api/version", ollamaProxy)
|
||||
mux.Handle("HEAD /api/version", ollamaProxy)
|
||||
mux.Handle("POST /api/me", ollamaProxy)
|
||||
mux.Handle("POST /api/signout", ollamaProxy)
|
||||
|
||||
mux.Handle("GET /api/v1/me", handle(s.me))
|
||||
mux.Handle("POST /api/v1/disconnect", handle(s.disconnect))
|
||||
mux.Handle("GET /api/v1/connect", handle(s.connectURL))
|
||||
mux.Handle("GET /api/v1/health", handle(s.health))
|
||||
|
||||
// React app - catch all non-API routes and serve the React app
|
||||
mux.Handle("GET /", s.appHandler())
|
||||
@@ -362,7 +338,7 @@ func (s *Server) doSelfSigned(ctx context.Context, method, path string) (*http.R
|
||||
}
|
||||
|
||||
// UserData fetches user data from ollama.com API for the current ollama key
|
||||
func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
|
||||
func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
|
||||
resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err)
|
||||
@@ -373,7 +349,7 @@ func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var user api.UserResponse
|
||||
var user responses.User
|
||||
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user response: %w", err)
|
||||
}
|
||||
@@ -392,27 +368,29 @@ func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// WaitForServer waits for the Ollama server to be ready
|
||||
func WaitForServer(ctx context.Context, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
func waitForServer(ctx context.Context) error {
|
||||
timeout := time.Now().Add(10 * time.Second)
|
||||
// TODO: this avoids an error on first load of the app
|
||||
// however we should either show a loading state or
|
||||
// wait for the Ollama server to be ready before redirecting
|
||||
for {
|
||||
c, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.Version(ctx); err == nil {
|
||||
slog.Debug("ollama server is ready")
|
||||
return nil
|
||||
break
|
||||
}
|
||||
if time.Now().After(timeout) {
|
||||
return fmt.Errorf("timeout waiting for Ollama server to be ready")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
return errors.New("timeout waiting for Ollama server to be ready")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error {
|
||||
if err := WaitForServer(r.Context(), 10*time.Second); err != nil {
|
||||
return err
|
||||
}
|
||||
waitForServer(r.Context())
|
||||
|
||||
id, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
@@ -1460,6 +1438,129 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) me(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := s.UserData(r.Context())
|
||||
if err != nil {
|
||||
// If fetching from API fails, try to return cached user data if available
|
||||
if cachedUser, cacheErr := s.Store.User(); cacheErr == nil && cachedUser != nil {
|
||||
s.log().Info("API request failed, returning cached user data", "error", err)
|
||||
responseUser := &responses.User{
|
||||
Name: cachedUser.Name,
|
||||
Email: cachedUser.Email,
|
||||
Plan: cachedUser.Plan,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return json.NewEncoder(w).Encode(responseUser)
|
||||
}
|
||||
|
||||
s.log().Error("failed to get user data", "error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return json.NewEncoder(w).Encode(responses.Error{
|
||||
Error: "failed to get user data",
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
func (s *Server) disconnect(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.Store.ClearUser(); err != nil {
|
||||
s.log().Warn("failed to clear cached user data", "error", err)
|
||||
}
|
||||
|
||||
// Get the SSH public key to encode for the delete request
|
||||
pubKey, err := ollamaAuth.GetPublicKey()
|
||||
if err != nil {
|
||||
s.log().Error("failed to get public key", "error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return json.NewEncoder(w).Encode(responses.Error{
|
||||
Error: "failed to get public key",
|
||||
})
|
||||
}
|
||||
|
||||
// Encode the key using base64 URL encoding
|
||||
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
|
||||
// Call the /api/user/keys/{encodedKey} endpoint with DELETE
|
||||
resp, err := s.doSelfSigned(r.Context(), http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey))
|
||||
if err != nil {
|
||||
s.log().Error("failed to call ollama.com/api/user/keys", "error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return json.NewEncoder(w).Encode(responses.Error{
|
||||
Error: "failed to disconnect from ollama.com",
|
||||
})
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
s.log().Error("disconnect request failed", "status", resp.StatusCode)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return json.NewEncoder(w).Encode(responses.Error{
|
||||
Error: "failed to disconnect from ollama.com",
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"})
|
||||
}
|
||||
|
||||
func (s *Server) connectURL(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
connectURL, err := auth.BuildConnectURL(OllamaDotCom)
|
||||
if err != nil {
|
||||
s.log().Error("failed to build connect URL", "error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return json.NewEncoder(w).Encode(responses.Error{
|
||||
Error: "failed to build connect URL",
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return json.NewEncoder(w).Encode(map[string]string{
|
||||
"connect_url": connectURL,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
healthy := false
|
||||
c, err := api.ClientFromEnvironment()
|
||||
if err == nil {
|
||||
if _, err := c.Version(r.Context()); err == nil {
|
||||
healthy = true
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return json.NewEncoder(w).Encode(responses.HealthResponse{
|
||||
Healthy: healthy,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
@@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
|
||||
case uint32(UI_REQUEST_MSG_ID):
|
||||
// Requests for the UI must always come from the main event thread
|
||||
l := int(wParam)
|
||||
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec
|
||||
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l)
|
||||
t.app.UIRun(path)
|
||||
case WM_COPYDATA:
|
||||
// Handle URL scheme requests from other instances
|
||||
if lParam != 0 {
|
||||
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec
|
||||
if cds.DwData == 1 { // Our identifier for URL scheme messages
|
||||
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam))
|
||||
if cds.DwData == 1 { // Our identifier for URL scheme messages
|
||||
// Convert the data back to string
|
||||
data := make([]byte, cds.CbData)
|
||||
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) //nolint:govet,gosec
|
||||
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData])
|
||||
urlScheme := string(data)
|
||||
handleURLSchemeRequest(urlScheme)
|
||||
lResult = 1 // Return non-zero to indicate success
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
Ollama Benchmark Tool
|
||||
---------------------
|
||||
|
||||
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
|
||||
|
||||
## Features
|
||||
|
||||
* Benchmark multiple models in a single run
|
||||
* Support for both text and image prompts
|
||||
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
||||
* Supports benchstat and CSV output formats
|
||||
* Detailed performance metrics (prefill, generate, load, total durations)
|
||||
|
||||
## Building from Source
|
||||
|
||||
```
|
||||
go build -o ollama-bench bench.go
|
||||
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
|
||||
```
|
||||
|
||||
Using Go Run (without building)
|
||||
|
||||
```
|
||||
go run bench.go -model gpt-oss:20b -epochs 3
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Example
|
||||
|
||||
```
|
||||
./ollama-bench -model gemma3 -epochs 6
|
||||
```
|
||||
|
||||
### Benchmark Multiple Models
|
||||
|
||||
```
|
||||
./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
|
||||
benchstat -col /name gemma.bench
|
||||
```
|
||||
|
||||
### With Image Prompt
|
||||
|
||||
```
|
||||
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
||||
```
|
||||
|
||||
### Advanced Example
|
||||
|
||||
```
|
||||
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
|
||||
```
|
||||
|
||||
## Command Line Options
|
||||
|
||||
| Option | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| -model | Comma-separated list of models to benchmark | (required) |
|
||||
| -epochs | Number of iterations per model | 1 |
|
||||
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
|
||||
| -temperature | Temperature parameter | 0.0 |
|
||||
| -seed | Random seed | 0 (random) |
|
||||
| -timeout | Timeout in seconds | 300 |
|
||||
| -p | Prompt text | "Write a long story." |
|
||||
| -image | Image file to include in prompt | |
|
||||
| -k | Keep-alive duration in seconds | 0 |
|
||||
| -format | Output format (benchstat, csv) | benchstat |
|
||||
| -output | Output file for results | "" (stdout) |
|
||||
| -v | Verbose mode | false |
|
||||
| -debug | Show debug information | false |
|
||||
|
||||
## Output Formats
|
||||
|
||||
### Markdown Format
|
||||
|
||||
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
|
||||
```
|
||||
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|
||||
|-------|------|-------|----------|------------|--------------|
|
||||
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
|
||||
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
|
||||
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
|
||||
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
|
||||
```
|
||||
|
||||
### Benchstat Format
|
||||
|
||||
Compatible with Go's benchstat tool for statistical analysis:
|
||||
|
||||
```
|
||||
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
|
||||
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
|
||||
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
|
||||
```
|
||||
|
||||
### CSV Format
|
||||
|
||||
Machine-readable comma-separated values:
|
||||
|
||||
```
|
||||
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
||||
gpt-oss:20b,prefill,128,78125.00,12800.00
|
||||
gpt-oss:20b,generate,512,19531.25,51200.00
|
||||
gpt-oss:20b,load,1,1500000000,0
|
||||
```
|
||||
|
||||
## Metrics Explained
|
||||
|
||||
The tool reports four types of metrics for each model:
|
||||
|
||||
* prefill: Time spent processing the prompt
|
||||
* generate: Time spent generating the response
|
||||
* load: Model loading time (one-time cost)
|
||||
* total: Total request duration
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type flagOptions struct {
|
||||
models *string
|
||||
epochs *int
|
||||
maxTokens *int
|
||||
temperature *float64
|
||||
seed *int
|
||||
timeout *int
|
||||
prompt *string
|
||||
imageFile *string
|
||||
keepAlive *float64
|
||||
format *string
|
||||
outputFile *string
|
||||
debug *bool
|
||||
verbose *bool
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
Model string
|
||||
Step string
|
||||
Count int
|
||||
Duration time.Duration
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
|
||||
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||
|
||||
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||
switch format {
|
||||
case "benchstat":
|
||||
if verbose {
|
||||
printHeader := func() {
|
||||
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
|
||||
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
|
||||
}
|
||||
once.Do(printHeader)
|
||||
}
|
||||
for _, m := range metrics {
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
if m.Count > 0 {
|
||||
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
|
||||
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
||||
} else {
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
|
||||
m.Model, m.Step, m.Count)
|
||||
}
|
||||
} else {
|
||||
var suffix string
|
||||
if m.Step == "load" {
|
||||
suffix = "/step=load"
|
||||
}
|
||||
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
|
||||
m.Model, suffix, m.Duration.Nanoseconds())
|
||||
}
|
||||
}
|
||||
case "csv":
|
||||
printHeader := func() {
|
||||
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
|
||||
fmt.Fprintln(w, strings.Join(headings, ","))
|
||||
}
|
||||
once.Do(printHeader)
|
||||
|
||||
for _, m := range metrics {
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
var nsPerToken float64
|
||||
var tokensPerSec float64
|
||||
if m.Count > 0 {
|
||||
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||
}
|
||||
fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
|
||||
}
|
||||
}
|
||||
case "markdown":
|
||||
printHeader := func() {
|
||||
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
|
||||
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
|
||||
}
|
||||
once.Do(printHeader)
|
||||
|
||||
for _, m := range metrics {
|
||||
var nsPerToken, tokensPerSec float64
|
||||
var nsPerTokenStr, tokensPerSecStr string
|
||||
|
||||
if m.Step == "generate" || m.Step == "prefill" {
|
||||
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
|
||||
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
|
||||
} else {
|
||||
nsPerTokenStr = "-"
|
||||
tokensPerSecStr = "-"
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
|
||||
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
|
||||
}
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkChat(fOpt flagOptions) error {
|
||||
models := strings.Split(*fOpt.models, ",")
|
||||
|
||||
// todo - add multi-image support
|
||||
var imgData api.ImageData
|
||||
var err error
|
||||
if *fOpt.imageFile != "" {
|
||||
imgData, err = readImage(*fOpt.imageFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if *fOpt.debug && imgData != nil {
|
||||
fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var out io.Writer = os.Stdout
|
||||
if fOpt.outputFile != nil && *fOpt.outputFile != "" {
|
||||
f, err := os.OpenFile(*fOpt.outputFile, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *fOpt.outputFile, err)
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
out = f
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
for range *fOpt.epochs {
|
||||
options := make(map[string]interface{})
|
||||
if *fOpt.maxTokens > 0 {
|
||||
options["num_predict"] = *fOpt.maxTokens
|
||||
}
|
||||
options["temperature"] = *fOpt.temperature
|
||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||
options["seed"] = *fOpt.seed
|
||||
}
|
||||
|
||||
var keepAliveDuration *api.Duration
|
||||
if *fOpt.keepAlive > 0 {
|
||||
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
||||
keepAliveDuration = &duration
|
||||
}
|
||||
|
||||
req := &api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: *fOpt.prompt,
|
||||
},
|
||||
},
|
||||
Options: options,
|
||||
KeepAlive: keepAliveDuration,
|
||||
}
|
||||
|
||||
if imgData != nil {
|
||||
req.Messages[0].Images = []api.ImageData{imgData}
|
||||
}
|
||||
|
||||
var responseMetrics *api.Metrics
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
|
||||
}
|
||||
|
||||
if resp.Done {
|
||||
responseMetrics = &resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if *fOpt.debug {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if responseMetrics == nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
||||
continue
|
||||
}
|
||||
|
||||
metrics := []Metrics{
|
||||
{
|
||||
Model: model,
|
||||
Step: "prefill",
|
||||
Count: responseMetrics.PromptEvalCount,
|
||||
Duration: responseMetrics.PromptEvalDuration,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "generate",
|
||||
Count: responseMetrics.EvalCount,
|
||||
Duration: responseMetrics.EvalDuration,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "load",
|
||||
Count: 1,
|
||||
Duration: responseMetrics.LoadDuration,
|
||||
},
|
||||
{
|
||||
Model: model,
|
||||
Step: "total",
|
||||
Count: 1,
|
||||
Duration: responseMetrics.TotalDuration,
|
||||
},
|
||||
}
|
||||
|
||||
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
|
||||
|
||||
if *fOpt.keepAlive > 0 {
|
||||
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readImage(filePath string) (api.ImageData, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return api.ImageData(data), nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
fOpt := flagOptions{
|
||||
models: flag.String("model", "", "Model to benchmark"),
|
||||
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
|
||||
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
|
||||
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
|
||||
seed: flag.Int("seed", 0, "Random seed"),
|
||||
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
|
||||
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
|
||||
imageFile: flag.String("image", "", "Filename for an image to include"),
|
||||
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
||||
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
|
||||
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
||||
verbose: flag.Bool("v", false, "Show system information"),
|
||||
debug: flag.Bool("debug", false, "Show debug information"),
|
||||
}
|
||||
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "Description:\n")
|
||||
fmt.Fprintf(os.Stderr, " Model benchmarking tool with configurable parameters\n\n")
|
||||
fmt.Fprintf(os.Stderr, "Options:\n")
|
||||
flag.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
||||
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(*fOpt.models) == 0 {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n")
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
BenchmarkChat(fOpt)
|
||||
}
|
||||
@@ -1,463 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func createTestFlagOptions() flagOptions {
|
||||
models := "test-model"
|
||||
format := "benchstat"
|
||||
epochs := 1
|
||||
maxTokens := 100
|
||||
temperature := 0.7
|
||||
seed := 42
|
||||
timeout := 30
|
||||
prompt := "test prompt"
|
||||
imageFile := ""
|
||||
keepAlive := 5.0
|
||||
verbose := false
|
||||
debug := false
|
||||
|
||||
return flagOptions{
|
||||
models: &models,
|
||||
format: &format,
|
||||
epochs: &epochs,
|
||||
maxTokens: &maxTokens,
|
||||
temperature: &temperature,
|
||||
seed: &seed,
|
||||
timeout: &timeout,
|
||||
prompt: &prompt,
|
||||
imageFile: &imageFile,
|
||||
keepAlive: &keepAlive,
|
||||
verbose: &verbose,
|
||||
debug: &debug,
|
||||
}
|
||||
}
|
||||
|
||||
func captureOutput(f func()) string {
|
||||
oldStdout := os.Stdout
|
||||
oldStderr := os.Stderr
|
||||
defer func() {
|
||||
os.Stdout = oldStdout
|
||||
os.Stderr = oldStderr
|
||||
}()
|
||||
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stdout = w
|
||||
os.Stderr = w
|
||||
|
||||
f()
|
||||
|
||||
w.Close()
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, r)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func createMockOllamaServer(t *testing.T, responses []api.ChatResponse) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/chat" {
|
||||
t.Errorf("Expected path /api/chat, got %s", r.URL.Path)
|
||||
http.Error(w, "Not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST method, got %s", r.Method)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
for _, resp := range responses {
|
||||
jsonData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to marshal response: %v", err)
|
||||
return
|
||||
}
|
||||
w.Write(jsonData)
|
||||
w.Write([]byte("\n"))
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond) // Simulate some delay
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func TestBenchmarkChat_Success(t *testing.T) {
|
||||
fOpt := createTestFlagOptions()
|
||||
|
||||
mockResponses := []api.ChatResponse{
|
||||
{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "test response part 1",
|
||||
},
|
||||
Done: false,
|
||||
},
|
||||
{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "test response part 2",
|
||||
},
|
||||
Done: true,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
PromptEvalDuration: 100 * time.Millisecond,
|
||||
EvalCount: 50,
|
||||
EvalDuration: 500 * time.Millisecond,
|
||||
TotalDuration: 600 * time.Millisecond,
|
||||
LoadDuration: 50 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
server := createMockOllamaServer(t, mockResponses)
|
||||
defer server.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", server.URL)
|
||||
|
||||
output := captureOutput(func() {
|
||||
err := BenchmarkChat(fOpt)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "BenchmarkModel/name=test-model/step=prefill") {
|
||||
t.Errorf("Expected output to contain prefill metrics, got: %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "BenchmarkModel/name=test-model/step=generate") {
|
||||
t.Errorf("Expected output to contain generate metrics, got: %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "ns/token") {
|
||||
t.Errorf("Expected output to contain ns/token metric, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBenchmarkChat_ServerError(t *testing.T) {
|
||||
fOpt := createTestFlagOptions()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", server.URL)
|
||||
|
||||
output := captureOutput(func() {
|
||||
err := BenchmarkChat(fOpt)
|
||||
if err != nil {
|
||||
t.Errorf("Expected error to be handled internally, got returned error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "ERROR: Couldn't chat with model") {
|
||||
t.Errorf("Expected error message about chat failure, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBenchmarkChat_Timeout(t *testing.T) {
|
||||
fOpt := createTestFlagOptions()
|
||||
shortTimeout := 1 // Very short timeout
|
||||
fOpt.timeout = &shortTimeout
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate a long delay that will cause timeout
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
response := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "test response",
|
||||
},
|
||||
Done: true,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
PromptEvalDuration: 100 * time.Millisecond,
|
||||
EvalCount: 50,
|
||||
EvalDuration: 500 * time.Millisecond,
|
||||
TotalDuration: 600 * time.Millisecond,
|
||||
LoadDuration: 50 * time.Millisecond,
|
||||
},
|
||||
}
|
||||
jsonData, _ := json.Marshal(response)
|
||||
w.Write(jsonData)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", server.URL)
|
||||
|
||||
output := captureOutput(func() {
|
||||
err := BenchmarkChat(fOpt)
|
||||
if err != nil {
|
||||
t.Errorf("Expected timeout to be handled internally, got returned error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "ERROR: Chat request timed out") {
|
||||
t.Errorf("Expected timeout error message, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBenchmarkChat_NoMetrics(t *testing.T) {
|
||||
fOpt := createTestFlagOptions()
|
||||
|
||||
mockResponses := []api.ChatResponse{
|
||||
{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "test response",
|
||||
},
|
||||
Done: false, // Never sends Done=true
|
||||
},
|
||||
}
|
||||
|
||||
server := createMockOllamaServer(t, mockResponses)
|
||||
defer server.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", server.URL)
|
||||
|
||||
output := captureOutput(func() {
|
||||
err := BenchmarkChat(fOpt)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "ERROR: No metrics received") {
|
||||
t.Errorf("Expected no metrics error message, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBenchmarkChat_MultipleModels(t *testing.T) {
|
||||
fOpt := createTestFlagOptions()
|
||||
models := "model1,model2"
|
||||
epochs := 2
|
||||
fOpt.models = &models
|
||||
fOpt.epochs = &epochs
|
||||
|
||||
callCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var req api.ChatRequest
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
json.Unmarshal(body, &req)
|
||||
|
||||
response := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "test response for " + req.Model,
|
||||
},
|
||||
Done: true,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
PromptEvalDuration: 100 * time.Millisecond,
|
||||
EvalCount: 50,
|
||||
EvalDuration: 500 * time.Millisecond,
|
||||
TotalDuration: 600 * time.Millisecond,
|
||||
LoadDuration: 50 * time.Millisecond,
|
||||
},
|
||||
}
|
||||
jsonData, _ := json.Marshal(response)
|
||||
w.Write(jsonData)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", server.URL)
|
||||
|
||||
output := captureOutput(func() {
|
||||
err := BenchmarkChat(fOpt)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Should be called 4 times (2 models × 2 epochs)
|
||||
if callCount != 4 {
|
||||
t.Errorf("Expected 4 API calls, got %d", callCount)
|
||||
}
|
||||
|
||||
if !strings.Contains(output, "BenchmarkModel/name=model1") || !strings.Contains(output, "BenchmarkModel/name=model2") {
|
||||
t.Errorf("Expected output for both models, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBenchmarkChat_WithImage(t *testing.T) {
|
||||
fOpt := createTestFlagOptions()
|
||||
|
||||
tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
content := []byte("fake image data")
|
||||
if _, err := tmpfile.Write(content); err != nil {
|
||||
t.Fatalf("Failed to write to temp file: %v", err)
|
||||
}
|
||||
tmpfile.Close()
|
||||
|
||||
tmpfileName := tmpfile.Name()
|
||||
fOpt.imageFile = &tmpfileName
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify the request contains image data
|
||||
var req api.ChatRequest
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
json.Unmarshal(body, &req)
|
||||
|
||||
if len(req.Messages) == 0 || len(req.Messages[0].Images) == 0 {
|
||||
t.Error("Expected request to contain images")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
response := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
Message: api.Message{
|
||||
Role: "assistant",
|
||||
Content: "test response with image",
|
||||
},
|
||||
Done: true,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
PromptEvalDuration: 100 * time.Millisecond,
|
||||
EvalCount: 50,
|
||||
EvalDuration: 500 * time.Millisecond,
|
||||
TotalDuration: 600 * time.Millisecond,
|
||||
LoadDuration: 50 * time.Millisecond,
|
||||
},
|
||||
}
|
||||
jsonData, _ := json.Marshal(response)
|
||||
w.Write(jsonData)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", server.URL)
|
||||
|
||||
output := captureOutput(func() {
|
||||
err := BenchmarkChat(fOpt)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "BenchmarkModel/name=test-model") {
|
||||
t.Errorf("Expected benchmark output, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBenchmarkChat_ImageError(t *testing.T) {
|
||||
randFileName := func() string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
const length = 8
|
||||
|
||||
result := make([]byte, length)
|
||||
rand.Read(result) // Fill with random bytes
|
||||
|
||||
for i := range result {
|
||||
result[i] = charset[result[i]%byte(len(charset))]
|
||||
}
|
||||
|
||||
return string(result) + ".txt"
|
||||
}
|
||||
|
||||
fOpt := createTestFlagOptions()
|
||||
imageFile := randFileName()
|
||||
fOpt.imageFile = &imageFile
|
||||
|
||||
output := captureOutput(func() {
|
||||
err := BenchmarkChat(fOpt)
|
||||
if err == nil {
|
||||
t.Error("Expected error from image reading, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
if !strings.Contains(output, "ERROR: Couldn't read image") {
|
||||
t.Errorf("Expected image read error message, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadImage_Success(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
content := []byte("fake image data")
|
||||
if _, err := tmpfile.Write(content); err != nil {
|
||||
t.Fatalf("Failed to write to temp file: %v", err)
|
||||
}
|
||||
tmpfile.Close()
|
||||
|
||||
imgData, err := readImage(tmpfile.Name())
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if imgData == nil {
|
||||
t.Error("Expected image data, got nil")
|
||||
}
|
||||
|
||||
expected := api.ImageData(content)
|
||||
if string(imgData) != string(expected) {
|
||||
t.Errorf("Expected image data %v, got %v", expected, imgData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadImage_FileNotFound(t *testing.T) {
|
||||
imgData, err := readImage("nonexistentfile.jpg")
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent file, got nil")
|
||||
}
|
||||
if imgData != nil {
|
||||
t.Error("Expected nil image data for non-existent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsMapCreation(t *testing.T) {
|
||||
fOpt := createTestFlagOptions()
|
||||
|
||||
options := make(map[string]interface{})
|
||||
if *fOpt.maxTokens > 0 {
|
||||
options["num_predict"] = *fOpt.maxTokens
|
||||
}
|
||||
options["temperature"] = *fOpt.temperature
|
||||
if fOpt.seed != nil && *fOpt.seed > 0 {
|
||||
options["seed"] = *fOpt.seed
|
||||
}
|
||||
|
||||
if options["num_predict"] != *fOpt.maxTokens {
|
||||
t.Errorf("Expected num_predict %d, got %v", *fOpt.maxTokens, options["num_predict"])
|
||||
}
|
||||
if options["temperature"] != *fOpt.temperature {
|
||||
t.Errorf("Expected temperature %f, got %v", *fOpt.temperature, options["temperature"])
|
||||
}
|
||||
if options["seed"] != *fOpt.seed {
|
||||
t.Errorf("Expected seed %d, got %v", *fOpt.seed, options["seed"])
|
||||
}
|
||||
}
|
||||
@@ -943,9 +943,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
}
|
||||
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
|
||||
if resp.Requires != "" {
|
||||
rows = append(rows, []string{"", "requires", resp.Requires})
|
||||
}
|
||||
return
|
||||
})
|
||||
|
||||
@@ -1433,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
latest.Summary()
|
||||
}
|
||||
|
||||
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
||||
return &api.Message{Role: role, Content: fullResponse.String()}, nil
|
||||
}
|
||||
|
||||
func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
@@ -291,31 +291,6 @@ Weigh anchor!
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("min version", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
if err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "test",
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
Requires: "0.14.0",
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := ` Model
|
||||
architecture test
|
||||
parameters 7B
|
||||
quantization FP16
|
||||
requires 0.14.0
|
||||
|
||||
`
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteHandler(t *testing.T) {
|
||||
|
||||
@@ -182,8 +182,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &llama4Model{}
|
||||
case "Mistral3ForConditionalGeneration":
|
||||
conv = &mistral3Model{}
|
||||
case "Ministral3ForCausalLM":
|
||||
conv = &mistral3CausalModel{}
|
||||
case "MixtralForCausalLM":
|
||||
conv = &mixtralModel{}
|
||||
case "GemmaForCausalLM":
|
||||
@@ -202,20 +200,12 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &qwen25VLModel{}
|
||||
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
|
||||
conv = &qwen3VLModel{}
|
||||
case "Olmo3ForCausalLM":
|
||||
conv = &olmoModel{}
|
||||
case "BertModel":
|
||||
conv = &bertModel{}
|
||||
case "NomicBertModel", "NomicBertMoEModel":
|
||||
conv = &nomicbertModel{}
|
||||
case "CohereForCausalLM":
|
||||
conv = &commandrModel{}
|
||||
case "GptOssForCausalLM":
|
||||
conv = &gptossModel{}
|
||||
case "DeepseekOCRForCausalLM":
|
||||
conv = &deepseekocr{}
|
||||
case "DeepseekV3ForCausalLM":
|
||||
conv = &deepseek2Model{}
|
||||
default:
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type deepseek2Model struct {
|
||||
ModelParameters // architectures, vocab_size
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
|
||||
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
|
||||
KVLoraRank uint32 `json:"kv_lora_rank"`
|
||||
QLoraRank uint32 `json:"q_lora_rank"`
|
||||
VHeadDim uint32 `json:"v_head_dim"`
|
||||
|
||||
ExpertCount uint32 `json:"n_routed_experts"`
|
||||
ExpertSharedCount uint32 `json:"n_shared_experts"`
|
||||
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
|
||||
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
|
||||
ExpertWeightsNorm bool `json:"norm_topk_prob"`
|
||||
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
|
||||
|
||||
ScoringFunc string `json:"scoring_func"`
|
||||
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
|
||||
|
||||
RopeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
Type string `json:"type"`
|
||||
MScaleAllDim float32 `json:"mscale_all_dim"`
|
||||
} `json:"rope_scaling"`
|
||||
|
||||
Architecture string
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseek2"
|
||||
kv["general.type"] = "model"
|
||||
kv["deepseek2.block_count"] = p.HiddenLayers
|
||||
|
||||
numHeads := p.NumAttentionHeads
|
||||
numKVHeads := p.NumKeyValueHeads
|
||||
|
||||
kv["deepseek2.attention.head_count"] = numHeads
|
||||
kv["deepseek2.attention.head_count_kv"] = numKVHeads
|
||||
kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
|
||||
kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
|
||||
kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
|
||||
kv["deepseek2.attention.value_length"] = p.VHeadDim
|
||||
kv["deepseek2.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["deepseek2.embedding_length"] = p.HiddenSize
|
||||
kv["deepseek2.expert_count"] = p.ExpertCount
|
||||
kv["deepseek2.expert_feed_forward_length"] = p.ExpertIntermediateSize
|
||||
kv["deepseek2.expert_shared_count"] = p.ExpertSharedCount
|
||||
|
||||
var scoringFunc uint32
|
||||
switch p.ScoringFunc {
|
||||
case "softmax":
|
||||
// not currently supported in the model, but needed for Deepseek-OCR
|
||||
scoringFunc = 1
|
||||
case "sigmoid":
|
||||
scoringFunc = 2
|
||||
}
|
||||
kv["deepseek2.expert_gating_func"] = scoringFunc
|
||||
kv["deepseek2.expert_used_count"] = p.ExpertUsedCount
|
||||
kv["deepseek2.expert_weights_norm"] = p.ExpertWeightsNorm
|
||||
kv["deepseek2.expert_weights_scale"] = p.ExpertWeightsScale
|
||||
kv["deepseek2.feed_forward_length"] = p.IntermediateSize
|
||||
kv["deepseek2.leading_dense_block_count"] = p.LeadingDenseBlockCount
|
||||
|
||||
kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
|
||||
kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
|
||||
kv["deepseek2.rope.scaling.factor"] = p.RopeScaling.Factor
|
||||
kv["deepseek2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
|
||||
kv["deepseek2.rope.scaling.type"] = p.RopeScaling.Type
|
||||
kv["deepseek2.rope.scaling.yarn_log_multiplier"] = 0.1 * p.RopeScaling.MScaleAllDim
|
||||
|
||||
kv["tokenizer.ggml.pre"] = "deepseek-v3"
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"language_model.", "",
|
||||
"model.layers", "blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
|
||||
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
|
||||
"self_attn.kv_b_proj", "attn_kv_b",
|
||||
"self_attn.q_a_proj", "attn_q_a",
|
||||
"self_attn.q_a_layernorm", "attn_q_a_norm",
|
||||
"self_attn.q_b_proj", "attn_q_b",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"mlp.shared_experts.down_proj", "ffn_down_shexp",
|
||||
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_experts.up_proj", "ffn_up_shexp",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
|
||||
"mlp.gate", "ffn_gate_inp",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
merges := make([]merge, p.HiddenLayers*3)
|
||||
for i := range p.HiddenLayers {
|
||||
merges[i*3+0] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}
|
||||
merges[i*3+1] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}
|
||||
merges[i*3+2] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}
|
||||
}
|
||||
|
||||
skipLayer := func(n string, minValue uint32) bool {
|
||||
re := regexp.MustCompile(`^blk\.(\d+)`)
|
||||
matches := re.FindStringSubmatch(n)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
blkNum, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return uint32(blkNum) >= minValue
|
||||
}
|
||||
|
||||
out, s = mergeTensors(s, merges...)
|
||||
for _, t := range s {
|
||||
// skip any additional layers (such as the Multi-Token Prediction layer)
|
||||
if skipLayer(t.Name(), p.HiddenLayers) {
|
||||
slog.Debug("skipping layer", "name", t.Name())
|
||||
continue
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type deepseekocr struct {
|
||||
ModelParameters
|
||||
LanguageConfig struct {
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
NumRoutedExperts uint32 `json:"n_routed_experts"`
|
||||
NumSharedExperts uint32 `json:"n_shared_experts"`
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
FirstKDenseReplace uint32 `json:"first_k_dense_replace"`
|
||||
} `json:"language_config"`
|
||||
|
||||
VisionConfig struct {
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
Width struct {
|
||||
Vision struct {
|
||||
Heads uint32 `json:"heads"`
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
Layers uint32 `json:"layers"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
Width uint32 `json:"width"`
|
||||
} `json:"clip-l-14-224"`
|
||||
Sam struct {
|
||||
GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
|
||||
Heads uint32 `json:"heads"`
|
||||
Layers uint32 `json:"layers"`
|
||||
Width uint32 `json:"width"`
|
||||
} `json:"sam_vit_b"`
|
||||
}
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseekocr"
|
||||
kv["block_count"] = m.LanguageConfig.HiddenLayers
|
||||
kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
|
||||
kv["embedding_length"] = m.LanguageConfig.HiddenSize
|
||||
kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
|
||||
kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
|
||||
kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
|
||||
kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
|
||||
kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
|
||||
kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace
|
||||
|
||||
kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
|
||||
kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
|
||||
kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
|
||||
kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
|
||||
kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize
|
||||
|
||||
kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
|
||||
kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
|
||||
kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
|
||||
kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||
merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
|
||||
for i := range m.LanguageConfig.HiddenLayers {
|
||||
merges[i*3+0] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}
|
||||
merges[i*3+1] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}
|
||||
merges[i*3+2] = merge{
|
||||
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}
|
||||
}
|
||||
|
||||
out, s = mergeTensors(s, merges...)
|
||||
for _, t := range s {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *deepseekocr) Replacements() []string {
|
||||
return []string{
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate", "ffn_gate_inp",
|
||||
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
|
||||
"mlp.shared_experts.up_proj", "ffn_up_shexp",
|
||||
"mlp.shared_experts.down_proj", "ffn_down_shexp",
|
||||
"model.norm", "output_norm",
|
||||
"lm_head", "output",
|
||||
|
||||
"model.vision_model", "v",
|
||||
"embeddings.patch_embedding", "patch_embd",
|
||||
"embeddings.class_embedding", "class_embd",
|
||||
"embeddings.position_embedding", "position_embd",
|
||||
"transformer.layers", "blk",
|
||||
|
||||
"model.projector", "mm",
|
||||
"model.image_newline", "mm.image_newline",
|
||||
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
|
||||
"model.view_seperator", "mm.view_seperator",
|
||||
|
||||
"model.sam_model.patch_embed.proj", "s.patch_embd",
|
||||
"model.sam_model.pos_embed", "s.position_embd",
|
||||
"model.sam_model.blocks", "s.blk",
|
||||
"model.sam_model.neck", "s.neck",
|
||||
"model.sam_model.net_", "s.net_",
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -27,26 +26,16 @@ type gemma3Model struct {
|
||||
NumChannels uint32 `json:"num_channels"` // num_channels 3
|
||||
PatchSize uint32 `json:"patch_size"` // patch_size 14
|
||||
} `json:"vision_config"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
RopeLocalTheta float32 `json:"rope_local_base_freq"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
SlidingWindowPattern *uint32 `json:"sliding_window_pattern"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
|
||||
RopeScaling *struct {
|
||||
Type string `json:"rope_type"`
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
ExtrapolationFactor float32 `json:"extrapolation_factor"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
} `json:"rope_scaling"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
RopeLocalTheta float32 `json:"rope_local_base_freq"`
|
||||
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -92,38 +81,9 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["gemma3.attention.key_length"] = p.HeadDim
|
||||
kv["gemma3.attention.value_length"] = p.HeadDim
|
||||
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
|
||||
|
||||
// The sliding window pattern is either provided as the sliding_window_pattern
|
||||
// key (an int) or as the layer_types key (a list of strings).
|
||||
if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
|
||||
kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
|
||||
for i := range numBlocks {
|
||||
var isLocal bool
|
||||
if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
|
||||
isLocal = p.LayerTypes[i] == "sliding_attention"
|
||||
} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
|
||||
isLocal = (i+1)%*p.SlidingWindowPattern != 0
|
||||
}
|
||||
if !yield(isLocal) {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
if p.FinalLogitSoftcap > 0 {
|
||||
kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
|
||||
}
|
||||
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
|
||||
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
|
||||
kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
|
||||
if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
|
||||
kv["gemma3.rope.scaling.type"] = "yarn"
|
||||
kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
|
||||
kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
|
||||
kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
|
||||
kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
|
||||
kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
|
||||
}
|
||||
|
||||
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
|
||||
kv["gemma3.embedding_length"] = p.HiddenSize
|
||||
kv["gemma3.feed_forward_length"] = p.IntermediateSize
|
||||
default:
|
||||
|
||||
@@ -110,12 +110,9 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
|
||||
for name, mxfp4 := range mxfp4s {
|
||||
dims := mxfp4.blocks.Shape()
|
||||
if !strings.HasSuffix(name, ".weight") {
|
||||
name = name + ".weight"
|
||||
}
|
||||
if strings.Contains(name, "ffn_down_exps") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Name: name + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4,
|
||||
@@ -124,12 +121,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
|
||||
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "gate_up", "gate", 1),
|
||||
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
|
||||
}, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "gate_up", "up", 1),
|
||||
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
|
||||
|
||||
@@ -29,17 +29,6 @@ type mistral3Model struct {
|
||||
SlidingWindow *uint32 `json:"sliding_window"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
RopeParameters struct {
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
Factor float32 `json:"factor"`
|
||||
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
|
||||
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
Mscale *float32 `json:"mscale"`
|
||||
MscaleAllDim *float32 `json:"mscale_all_dim"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct {
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
@@ -52,9 +41,6 @@ type mistral3Model struct {
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeParameters struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"vision_config"`
|
||||
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
@@ -75,25 +61,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
|
||||
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
|
||||
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
|
||||
kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
|
||||
kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
|
||||
kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
|
||||
kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
|
||||
|
||||
if p.TextModel.RopeParameters.Mscale != nil {
|
||||
kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
|
||||
}
|
||||
if p.TextModel.RopeParameters.MscaleAllDim != nil {
|
||||
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
|
||||
}
|
||||
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
|
||||
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
|
||||
}
|
||||
if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
|
||||
kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
|
||||
}
|
||||
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
|
||||
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
|
||||
|
||||
// Vision configuration
|
||||
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
||||
@@ -105,7 +74,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
|
||||
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
|
||||
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
|
||||
kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)
|
||||
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
|
||||
|
||||
// Multimodal configuration
|
||||
kv["mistral3.image_token_index"] = p.ImageTokenIndex
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type mistral3CausalModel struct {
|
||||
ModelParameters
|
||||
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
SlidingWindow *uint32 `json:"sliding_window"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
RopeParameters struct {
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
Factor float32 `json:"factor"`
|
||||
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
|
||||
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
Mscale *float32 `json:"mscale"`
|
||||
MscaleAllDim *float32 `json:"mscale_all_dim"`
|
||||
} `json:"rope_parameters"`
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.VocabSize
|
||||
|
||||
// Text configuration
|
||||
kv["mistral3.block_count"] = p.NumHiddenLayers
|
||||
kv["mistral3.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["mistral3.embedding_length"] = p.HiddenSize
|
||||
kv["mistral3.feed_forward_length"] = p.IntermediateSize
|
||||
kv["mistral3.attention.head_count"] = p.NumAttentionHeads
|
||||
kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
|
||||
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
kv["mistral3.attention.key_length"] = p.HeadDim
|
||||
kv["mistral3.attention.value_length"] = p.HeadDim
|
||||
kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
|
||||
kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
|
||||
kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
|
||||
kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
|
||||
kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
|
||||
kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
|
||||
|
||||
if p.RopeParameters.Mscale != nil {
|
||||
kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
|
||||
}
|
||||
|
||||
if p.RopeParameters.MscaleAllDim != nil {
|
||||
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
|
||||
}
|
||||
|
||||
if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
|
||||
kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
|
||||
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
|
||||
}
|
||||
|
||||
if p.RopeParameters.Llama4ScalingBeta != nil {
|
||||
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") {
|
||||
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
|
||||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) Replacements() []string {
|
||||
return []string{
|
||||
"model.norm", "output_norm",
|
||||
"model.", "",
|
||||
"layers", "blk",
|
||||
"transformer.layers", "blk",
|
||||
"vision_tower", "v",
|
||||
"ln_pre", "encoder_norm",
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"embed_tokens", "token_embd",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"attention.q_proj", "attn_q",
|
||||
"attention.k_proj", "attn_k",
|
||||
"attention.v_proj", "attn_v",
|
||||
"attention.o_proj", "attn_output",
|
||||
"attention_norm", "attn_norm",
|
||||
"feed_forward.gate_proj", "ffn_gate",
|
||||
"feed_forward.down_proj", "ffn_down",
|
||||
"feed_forward.up_proj", "ffn_up",
|
||||
"multi_modal_projector", "mm",
|
||||
"ffn_norm", "ffn_norm",
|
||||
"lm_head", "output",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
|
||||
var dims []int
|
||||
for _, dim := range shape {
|
||||
dims = append(dims, int(dim))
|
||||
}
|
||||
|
||||
var heads uint32
|
||||
if strings.HasSuffix(name, ".attn_q.weight") {
|
||||
heads = p.NumAttentionHeads
|
||||
} else if strings.HasSuffix(name, ".attn_k.weight") {
|
||||
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
|
||||
}
|
||||
|
||||
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Reshape(dims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(n, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
for _, t := range ts {
|
||||
f32s = append(f32s, t...)
|
||||
}
|
||||
|
||||
return f32s, nil
|
||||
}
|
||||
@@ -1,213 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type nomicbertModel struct {
|
||||
ModelParameters
|
||||
NLayers uint32 `json:"n_layers"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
LayerNormEPS float32 `json:"layer_norm_eps"`
|
||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||
RopeFreqBase float32 `json:"rope_theta"`
|
||||
normalizeEmbeddings bool
|
||||
PoolingType uint32
|
||||
|
||||
// MoE parameters (only present in v2 models)
|
||||
NumExperts uint32 `json:"num_local_experts"`
|
||||
NumExpertsUsed uint32 `json:"num_experts_per_tok"`
|
||||
MoEEveryNLayers uint32 `json:"moe_every_n_layers"`
|
||||
}
|
||||
|
||||
var (
|
||||
_ ModelConverter = (*nomicbertModel)(nil)
|
||||
_ moreParser = (*nomicbertModel)(nil)
|
||||
)
|
||||
|
||||
func (p *nomicbertModel) parseMore(fsys fs.FS) error {
|
||||
bts, err := fs.ReadFile(fsys, "modules.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var modules []struct {
|
||||
Type string `json:"type"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, &modules); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var pooling string
|
||||
for _, m := range modules {
|
||||
switch m.Type {
|
||||
case "sentence_transformers.models.Pooling":
|
||||
pooling = m.Path
|
||||
case "sentence_transformers.models.Normalize":
|
||||
p.normalizeEmbeddings = true
|
||||
}
|
||||
}
|
||||
|
||||
if pooling != "" {
|
||||
bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var pc struct {
|
||||
PoolingModeCLSToken bool `json:"pooling_mode_cls_token"`
|
||||
PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, &pc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pc.PoolingModeMeanTokens {
|
||||
p.PoolingType = 1
|
||||
} else if pc.PoolingModeCLSToken {
|
||||
p.PoolingType = 2
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
|
||||
// Determine architecture based on MoE parameters (following qwen3 pattern)
|
||||
arch := "nomic-bert"
|
||||
if p.MoEEveryNLayers > 0 {
|
||||
arch += "-moe"
|
||||
}
|
||||
|
||||
kv["general.architecture"] = arch
|
||||
kv["attention.causal"] = false
|
||||
kv["pooling_type"] = p.PoolingType
|
||||
kv["normalize_embeddings"] = p.normalizeEmbeddings
|
||||
|
||||
kv["block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers)
|
||||
|
||||
if contextLength := p.MaxPositionEmbeddings; contextLength > 0 {
|
||||
kv["context_length"] = contextLength
|
||||
}
|
||||
|
||||
if embeddingLength := p.HiddenSize; embeddingLength > 0 {
|
||||
kv["embedding_length"] = p.HiddenSize
|
||||
}
|
||||
|
||||
if feedForwardLength := p.IntermediateSize; feedForwardLength > 0 {
|
||||
kv["feed_forward_length"] = p.IntermediateSize
|
||||
}
|
||||
|
||||
if headCount := p.NumAttentionHeads; headCount > 0 {
|
||||
kv["attention.head_count"] = p.NumAttentionHeads
|
||||
}
|
||||
|
||||
if kvHeadCount := p.NumKeyValueHeads; kvHeadCount > 0 {
|
||||
kv["attention.head_count_kv"] = p.NumKeyValueHeads
|
||||
}
|
||||
|
||||
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon); layerNormEpsilon > 0 {
|
||||
kv["attention.layer_norm_epsilon"] = layerNormEpsilon
|
||||
}
|
||||
|
||||
if p.RopeFreqBase > 0 {
|
||||
kv["rope.freq_base"] = p.RopeFreqBase
|
||||
}
|
||||
|
||||
// MoE specific parameters (only if MoE is enabled)
|
||||
if p.NumExperts > 0 {
|
||||
kv["expert_count"] = p.NumExperts
|
||||
}
|
||||
|
||||
if p.NumExpertsUsed > 0 {
|
||||
kv["expert_used_count"] = p.NumExpertsUsed
|
||||
}
|
||||
|
||||
if p.MoEEveryNLayers > 0 {
|
||||
kv["moe_every_n_layers"] = p.MoEEveryNLayers
|
||||
}
|
||||
|
||||
kv["tokenizer.ggml.model"] = "bert"
|
||||
kv["tokenizer.ggml.token_type_count"] = uint32(2)
|
||||
|
||||
// convert to phantom space tokens
|
||||
for i, e := range t.Tokens {
|
||||
switch {
|
||||
case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
|
||||
// noop - keep special tokens as-is
|
||||
case strings.HasPrefix(e, "##"):
|
||||
t.Tokens[i] = e[2:]
|
||||
default:
|
||||
t.Tokens[i] = "\u2581" + e
|
||||
}
|
||||
}
|
||||
|
||||
kv["tokenizer.ggml.tokens"] = t.Tokens
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *nomicbertModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out := make([]*ggml.Tensor, 0, len(ts))
|
||||
for _, t := range ts {
|
||||
if slices.Contains([]string{
|
||||
"embeddings.position_ids",
|
||||
"pooler.dense.weight",
|
||||
"pooler.dense.bias",
|
||||
}, t.Name()) {
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (nomicbertModel) Replacements() []string {
|
||||
return []string{
|
||||
"encoder.layer", "blk",
|
||||
"encoder.layers", "blk",
|
||||
"embeddings.word_embeddings", "token_embd",
|
||||
"embeddings.token_type_embeddings", "token_types",
|
||||
"embeddings.LayerNorm", "token_embd_norm",
|
||||
|
||||
"attention.self.qkv", "attn_qkv",
|
||||
|
||||
"attention.output.dense", "attn_output",
|
||||
"attention.output.LayerNorm", "attn_output_norm",
|
||||
|
||||
"mlp.up", "ffn_up",
|
||||
"mlp.down", "ffn_down",
|
||||
|
||||
"mlp.router", "ffn_gate_inp",
|
||||
"mlp.experts.up", "ffn_up_exps",
|
||||
"mlp.experts.down", "ffn_down_exps",
|
||||
|
||||
"intermediate.dense", "ffn_up",
|
||||
"output.dense", "ffn_down",
|
||||
"output.LayerNorm", "layer_output_norm",
|
||||
}
|
||||
}
|
||||
@@ -1,117 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type ropeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
|
||||
AttentionFactor float32 `json:"attention_factor"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
RopeType string `json:"rope_type"`
|
||||
ExtrapolationFactor float32 `json:"extrapolation_factor"`
|
||||
}
|
||||
|
||||
type olmoModel struct {
|
||||
ModelParameters
|
||||
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScaling *ropeScaling `json:"rope_scaling"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*olmoModel)(nil)
|
||||
|
||||
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "olmo3"
|
||||
kv["olmo3.block_count"] = p.NumHiddenLayers
|
||||
kv["olmo3.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["olmo3.embedding_length"] = p.HiddenSize
|
||||
kv["olmo3.feed_forward_length"] = p.IntermediateSize
|
||||
kv["olmo3.attention.head_count"] = p.NumAttentionHeads
|
||||
kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||
|
||||
if p.RopeTheta > 0 {
|
||||
kv["olmo3.rope.freq_base"] = p.RopeTheta
|
||||
}
|
||||
|
||||
if p.RopeScaling != nil {
|
||||
if p.RopeScaling.Factor > 0 {
|
||||
kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor
|
||||
}
|
||||
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
|
||||
kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
|
||||
}
|
||||
if p.RopeScaling.AttentionFactor > 0 {
|
||||
kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
|
||||
}
|
||||
if p.RopeScaling.RopeType != "" {
|
||||
kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType
|
||||
}
|
||||
}
|
||||
|
||||
if p.RMSNormEPS > 0 {
|
||||
kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
}
|
||||
|
||||
if p.SlidingWindow > 0 {
|
||||
kv["olmo3.attention.sliding_window"] = p.SlidingWindow
|
||||
}
|
||||
|
||||
if len(p.LayerTypes) > 0 {
|
||||
slidingPattern := make([]bool, len(p.LayerTypes))
|
||||
for i, layerType := range p.LayerTypes {
|
||||
slidingPattern[i] = (layerType == "sliding_attention")
|
||||
}
|
||||
kv["olmo3.attention.sliding_window_pattern"] = slidingPattern
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out := make([]*ggml.Tensor, 0, len(ts))
|
||||
for _, t := range ts {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *olmoModel) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"model.norm", "output_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
}
|
||||
}
|
||||
@@ -44,10 +44,7 @@ func (t tensorBase) Kind() uint32 {
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
t.name == "v.pre_tile_position_embd.weight" ||
|
||||
t.name == "v.post_tile_position_embd.weight" ||
|
||||
t.name == "s.position_embd" ||
|
||||
strings.HasSuffix(t.name, "rel_pos_h") ||
|
||||
strings.HasSuffix(t.name, "rel_pos_w") {
|
||||
t.name == "v.post_tile_position_embd.weight" {
|
||||
// these tensors are always F32
|
||||
return tensorKindFP32
|
||||
}
|
||||
|
||||
@@ -96,10 +96,7 @@ type safetensor struct {
|
||||
|
||||
func (st safetensor) Kind() uint32 {
|
||||
kind := st.tensorBase.Kind()
|
||||
if st.dtype == "BF16" &&
|
||||
!strings.HasPrefix(st.name, "v.") &&
|
||||
!strings.HasPrefix(st.name, "s.") &&
|
||||
kind != tensorKindFP32 {
|
||||
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
|
||||
@@ -2,12 +2,10 @@ package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"io"
|
||||
"iter"
|
||||
"path"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
@@ -96,26 +94,6 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
|
||||
return matched
|
||||
})
|
||||
|
||||
slices.SortStableFunc(matched, func(a, b Tensor) int {
|
||||
x := strings.Split(a.Name(), ".")
|
||||
y := strings.Split(b.Name(), ".")
|
||||
if len(x) != len(y) {
|
||||
return cmp.Compare(len(x), len(y))
|
||||
}
|
||||
|
||||
vals := make([]int, len(x))
|
||||
for i := range x {
|
||||
vals[i] = strings.Compare(x[i], y[i])
|
||||
m, err := strconv.ParseInt(x[i], 0, 0)
|
||||
n, err2 := strconv.ParseInt(y[i], 0, 0)
|
||||
if errors.Join(err, err2) == nil {
|
||||
vals[i] = cmp.Compare(m, n)
|
||||
}
|
||||
}
|
||||
|
||||
return cmp.Or(vals...)
|
||||
})
|
||||
|
||||
if len(matched) > 0 {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: merges[i].name,
|
||||
|
||||
@@ -3,10 +3,8 @@ package convert
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -953,45 +951,3 @@ func TestMerge(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMergeOrder(t *testing.T) {
|
||||
for range 8 {
|
||||
t.Run("", func(t *testing.T) {
|
||||
tensors := make([]Tensor, 16)
|
||||
for i := range tensors {
|
||||
tensors[i] = &fakeTensor{
|
||||
name: fmt.Sprintf("layer.%d.weight", i),
|
||||
shape: []uint64{1},
|
||||
data: []float32{float32(i)},
|
||||
}
|
||||
}
|
||||
|
||||
rand.Shuffle(len(tensors), func(i, j int) {
|
||||
tensors[i], tensors[j] = tensors[j], tensors[i]
|
||||
})
|
||||
|
||||
matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
|
||||
if len(unmatched) != 0 {
|
||||
t.Error("expected no remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 1 {
|
||||
t.Error("expected 1 merged tensor, got", len(matched))
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := matched[0].WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var f32s [16]float32
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.IsSorted(f32s[:]) {
|
||||
t.Errorf("merged tensor data is not in order: %+v", f32s)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,8 +49,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
||||
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||||
|
||||
// temporary fix to handle gemma3 broken configs
|
||||
// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm
|
||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
|
||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
|
||||
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package discover
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -11,21 +10,12 @@ import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
mem, err := getCPUMem()
|
||||
if err != nil {
|
||||
return memInfo{}, err
|
||||
}
|
||||
return getCPUMemByCgroups(mem), nil
|
||||
}
|
||||
|
||||
func getCPUMem() (memInfo, error) {
|
||||
var mem memInfo
|
||||
var total, available, free, buffers, cached, freeSwap uint64
|
||||
f, err := os.Open("/proc/meminfo")
|
||||
@@ -66,32 +56,6 @@ func getCPUMem() (memInfo, error) {
|
||||
return mem, nil
|
||||
}
|
||||
|
||||
func getCPUMemByCgroups(mem memInfo) memInfo {
|
||||
total, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.max")
|
||||
if err == nil {
|
||||
mem.TotalMemory = total
|
||||
}
|
||||
used, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.current")
|
||||
if err == nil {
|
||||
mem.FreeMemory = mem.TotalMemory - used
|
||||
}
|
||||
return mem
|
||||
}
|
||||
|
||||
func getUint64ValueFromFile(path string) (uint64, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
line := s.Text()
|
||||
return strconv.ParseUint(line, 10, 64)
|
||||
}
|
||||
return 0, errors.New("empty file content")
|
||||
}
|
||||
|
||||
const CpuInfoFilename = "/proc/cpuinfo"
|
||||
|
||||
type linuxCpuInfo struct {
|
||||
@@ -110,41 +74,7 @@ func GetCPUDetails() []CPU {
|
||||
return nil
|
||||
}
|
||||
defer file.Close()
|
||||
cpus := linuxCPUDetails(file)
|
||||
return overwriteThreadCountByLinuxCgroups(cpus)
|
||||
}
|
||||
|
||||
func overwriteThreadCountByLinuxCgroups(cpus []CPU) []CPU {
|
||||
file, err := os.Open("/sys/fs/cgroup/cpu.max")
|
||||
if err != nil {
|
||||
return cpus
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if sl := strings.Split(line, " "); len(sl) == 2 {
|
||||
allowdUs, err := strconv.ParseInt(sl[0], 10, 64)
|
||||
if err != nil {
|
||||
slog.Warn("failed to parse CPU allowed micro secs", "error", err)
|
||||
return cpus
|
||||
}
|
||||
unitUs, err := strconv.ParseInt(sl[1], 10, 64)
|
||||
if err != nil {
|
||||
slog.Warn("failed to parse CPU unit micro secs", "error", err)
|
||||
return cpus
|
||||
}
|
||||
|
||||
threads := int(max(allowdUs/unitUs, 1))
|
||||
|
||||
cpu := cpus[0]
|
||||
cpu.CoreCount = threads
|
||||
cpu.ThreadCount = threads
|
||||
return []CPU{cpu}
|
||||
}
|
||||
}
|
||||
return cpus
|
||||
return linuxCPUDetails(file)
|
||||
}
|
||||
|
||||
func linuxCPUDetails(file io.Reader) []CPU {
|
||||
|
||||
@@ -65,11 +65,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
}
|
||||
|
||||
slog.Info("discovering available GPUs...")
|
||||
detectIncompatibleLibraries()
|
||||
|
||||
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
|
||||
overrideWarnings()
|
||||
|
||||
requested := envconfig.LLMLibrary()
|
||||
jetpack := cudaJetpack()
|
||||
|
||||
@@ -95,13 +90,10 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
var dirs []string
|
||||
if dir != "" {
|
||||
if requested != "" && filepath.Base(dir) != requested {
|
||||
slog.Debug("skipping available library at user's request", "requested", requested, "libDir", dir)
|
||||
slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
|
||||
continue
|
||||
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
|
||||
continue
|
||||
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
|
||||
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
|
||||
continue
|
||||
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
|
||||
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
|
||||
continue
|
||||
@@ -121,7 +113,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
// In the second pass, we more deeply initialize the GPUs to weed out devices that
|
||||
// aren't supported by a given library. We run this phase in parallel to speed up discovery.
|
||||
// Only devices that need verification are included in this pass
|
||||
slog.Debug("evaluating which, if any, devices to filter out", "initial_count", len(devices))
|
||||
slog.Debug("evluating which if any devices to filter out", "initial_count", len(devices))
|
||||
ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
@@ -129,25 +121,15 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
supportedMu := sync.Mutex{}
|
||||
supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
|
||||
for i := range devices {
|
||||
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
|
||||
if !devices[i].NeedsInitValidation() {
|
||||
// No need to validate, add to the supported map
|
||||
supportedMu.Lock()
|
||||
if _, ok := supported[devices[i].Library]; !ok {
|
||||
supported[devices[i].Library] = make(map[string]map[string]int)
|
||||
}
|
||||
if _, ok := supported[devices[i].Library][libDir]; !ok {
|
||||
supported[devices[i].Library][libDir] = make(map[string]int)
|
||||
}
|
||||
supported[devices[i].Library][libDir][devices[i].ID] = i
|
||||
supportedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
slog.Debug("verifying if device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
|
||||
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
|
||||
slog.Debug("verifying device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true)
|
||||
extraEnvs := ml.GetVisibleDevicesEnv(devices[i : i+1])
|
||||
devices[i].AddInitValidation(extraEnvs)
|
||||
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
|
||||
slog.Debug("filtering device which didn't fully initialize",
|
||||
@@ -333,8 +315,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
defer cancel()
|
||||
|
||||
// Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct
|
||||
// We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment
|
||||
devFilter := ml.GetVisibleDevicesEnv(devices, false)
|
||||
devFilter := ml.GetVisibleDevicesEnv(devices)
|
||||
|
||||
for dir := range libDirs {
|
||||
updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter)
|
||||
@@ -468,37 +449,3 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map
|
||||
|
||||
return devices
|
||||
}
|
||||
|
||||
func overrideWarnings() {
|
||||
anyFound := false
|
||||
m := envconfig.AsMap()
|
||||
for _, k := range []string{
|
||||
"CUDA_VISIBLE_DEVICES",
|
||||
"HIP_VISIBLE_DEVICES",
|
||||
"ROCR_VISIBLE_DEVICES",
|
||||
"GGML_VK_VISIBLE_DEVICES",
|
||||
"GPU_DEVICE_ORDINAL",
|
||||
"HSA_OVERRIDE_GFX_VERSION",
|
||||
} {
|
||||
if e, found := m[k]; found && e.Value != "" {
|
||||
anyFound = true
|
||||
slog.Warn("user overrode visible devices", k, e.Value)
|
||||
}
|
||||
}
|
||||
if anyFound {
|
||||
slog.Warn("if GPUs are not correctly discovered, unset and try again")
|
||||
}
|
||||
}
|
||||
|
||||
func detectIncompatibleLibraries() {
|
||||
if runtime.GOOS != "windows" {
|
||||
return
|
||||
}
|
||||
basePath, err := exec.LookPath("ggml-base.dll")
|
||||
if err != nil || basePath == "" {
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
|
||||
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
* [API Reference](https://docs.ollama.com/api)
|
||||
* [Modelfile Reference](https://docs.ollama.com/modelfile)
|
||||
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
|
||||
* [Anthropic Compatibility](./api/anthropic-compatibility.mdx)
|
||||
|
||||
### Resources
|
||||
|
||||
|
||||
10
docs/api.md
10
docs/api.md
@@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `system`: system message to (overrides what is defined in the `Modelfile`)
|
||||
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
@@ -507,7 +507,7 @@ The `message` object has the following fields:
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
|
||||
@@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
|
||||
- `template`: (optional) the prompt template for the model
|
||||
- `license`: (optional) a string or list of strings containing the license or licenses for the model
|
||||
- `system`: (optional) a string containing the system prompt for the model
|
||||
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters)
|
||||
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters)
|
||||
- `messages`: (optional) a list of message objects used to create a conversation
|
||||
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `quantize` (optional): quantize a non-quantized (e.g. float16) model
|
||||
@@ -1698,7 +1698,7 @@ Generate embeddings from a model
|
||||
Advanced parameters:
|
||||
|
||||
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
- `dimensions`: number of dimensions for the embedding
|
||||
|
||||
@@ -1817,7 +1817,7 @@ Generate embeddings from a model
|
||||
|
||||
Advanced parameters:
|
||||
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
|
||||
### Examples
|
||||
|
||||
@@ -1,339 +0,0 @@
|
||||
---
|
||||
title: Anthropic compatibility
|
||||
---
|
||||
|
||||
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment variables
|
||||
|
||||
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
|
||||
### Simple `/v1/messages` example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python basic.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama', # required but ignored
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='llama3.2:3b',
|
||||
max_tokens=1024,
|
||||
messages=[
|
||||
{'role': 'user', 'content': 'Hello, how are you?'}
|
||||
]
|
||||
)
|
||||
print(message.content[0].text)
|
||||
```
|
||||
|
||||
```javascript basic.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama", // required but ignored
|
||||
});
|
||||
|
||||
const message = await anthropic.messages.create({
|
||||
model: "llama3.2:3b",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Hello, how are you?" }],
|
||||
});
|
||||
|
||||
console.log(message.content[0].text);
|
||||
```
|
||||
|
||||
```shell basic.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-api-key: ollama" \
|
||||
-H "anthropic-version: 2023-06-01" \
|
||||
-d '{
|
||||
"model": "llama3.2:3b",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{ "role": "user", "content": "Hello, how are you?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Streaming example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python streaming.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
with client.messages.stream(
|
||||
model='llama3.2:3b',
|
||||
max_tokens=1024,
|
||||
messages=[{'role': 'user', 'content': 'Count from 1 to 10'}]
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
print(text, end='', flush=True)
|
||||
```
|
||||
|
||||
```javascript streaming.js
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
|
||||
const anthropic = new Anthropic({
|
||||
baseURL: "http://localhost:11434",
|
||||
apiKey: "ollama",
|
||||
});
|
||||
|
||||
const stream = await anthropic.messages.stream({
|
||||
model: "llama3.2:3b",
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: "user", content: "Count from 1 to 10" }],
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
if (
|
||||
event.type === "content_block_delta" &&
|
||||
event.delta.type === "text_delta"
|
||||
) {
|
||||
process.stdout.write(event.delta.text);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```shell streaming.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama3.2:3b",
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"messages": [{ "role": "user", "content": "Count from 1 to 10" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### Tool calling example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python tools.py
|
||||
import anthropic
|
||||
|
||||
client = anthropic.Anthropic(
|
||||
base_url='http://localhost:11434',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
message = client.messages.create(
|
||||
model='llama3.2:3b',
|
||||
max_tokens=1024,
|
||||
tools=[
|
||||
{
|
||||
'name': 'get_weather',
|
||||
'description': 'Get the current weather in a location',
|
||||
'input_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'location': {
|
||||
'type': 'string',
|
||||
'description': 'The city and state, e.g. San Francisco, CA'
|
||||
}
|
||||
},
|
||||
'required': ['location']
|
||||
}
|
||||
}
|
||||
],
|
||||
messages=[{'role': 'user', 'content': "What's the weather in San Francisco?"}]
|
||||
)
|
||||
|
||||
for block in message.content:
|
||||
if block.type == 'tool_use':
|
||||
print(f'Tool: {block.name}')
|
||||
print(f'Input: {block.input}')
|
||||
```
|
||||
|
||||
```shell tools.sh
|
||||
curl -X POST http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama3.2:3b",
|
||||
"max_tokens": 1024,
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a location",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"messages": [{ "role": "user", "content": "What is the weather in San Francisco?" }]
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Using with Claude Code
|
||||
|
||||
[Claude Code](https://docs.anthropic.com/en/docs/claude-code) can be configured to use Ollama as its backend:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model llama3.2:3b
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
```shell
|
||||
claude --model llama3.2:3b
|
||||
claude --model qwen3:8b
|
||||
claude --model deepseek-r1:14b
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
### `/v1/messages`
|
||||
|
||||
#### Supported features
|
||||
|
||||
- [x] Messages
|
||||
- [x] Streaming
|
||||
- [x] System prompts
|
||||
- [x] Multi-turn conversations
|
||||
- [x] Vision (images)
|
||||
- [x] Tools (function calling)
|
||||
- [x] Tool results
|
||||
- [x] Thinking/extended thinking
|
||||
|
||||
#### Supported request fields
|
||||
|
||||
- [x] `model`
|
||||
- [x] `max_tokens`
|
||||
- [x] `messages`
|
||||
- [x] Text `content`
|
||||
- [x] Image `content` (base64)
|
||||
- [x] Array of content blocks
|
||||
- [x] `tool_use` blocks
|
||||
- [x] `tool_result` blocks
|
||||
- [x] `thinking` blocks
|
||||
- [x] `system` (string or array)
|
||||
- [x] `stream`
|
||||
- [x] `temperature`
|
||||
- [x] `top_p`
|
||||
- [x] `top_k`
|
||||
- [x] `stop_sequences`
|
||||
- [x] `tools`
|
||||
- [x] `thinking`
|
||||
- [ ] `tool_choice`
|
||||
- [ ] `metadata`
|
||||
|
||||
#### Supported response fields
|
||||
|
||||
- [x] `id`
|
||||
- [x] `type`
|
||||
- [x] `role`
|
||||
- [x] `model`
|
||||
- [x] `content` (text, tool_use, thinking blocks)
|
||||
- [x] `stop_reason` (end_turn, max_tokens, tool_use)
|
||||
- [x] `usage` (input_tokens, output_tokens)
|
||||
|
||||
#### Streaming events
|
||||
|
||||
- [x] `message_start`
|
||||
- [x] `content_block_start`
|
||||
- [x] `content_block_delta` (text_delta, input_json_delta, thinking_delta)
|
||||
- [x] `content_block_stop`
|
||||
- [x] `message_delta`
|
||||
- [x] `message_stop`
|
||||
- [x] `ping`
|
||||
- [x] `error`
|
||||
|
||||
## Models
|
||||
|
||||
Before using a model, pull it locally with `ollama pull`:
|
||||
|
||||
```shell
|
||||
ollama pull llama3.2:3b
|
||||
```
|
||||
|
||||
### Default model names
|
||||
|
||||
For tooling that relies on default Anthropic model names such as `claude-3-5-sonnet`, use `ollama cp` to copy an existing model name:
|
||||
|
||||
```shell
|
||||
ollama cp llama3.2:3b claude-3-5-sonnet
|
||||
```
|
||||
|
||||
Afterwards, this new model name can be specified in the `model` field:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/v1/messages \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "claude-3-5-sonnet",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## Differences from the Anthropic API
|
||||
|
||||
### Behavior differences
|
||||
|
||||
- API key is accepted but not validated
|
||||
- `anthropic-version` header is accepted but not used
|
||||
- Token counts are approximations based on the underlying model's tokenizer
|
||||
|
||||
### Not supported
|
||||
|
||||
The following Anthropic API features are not currently supported:
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| `/v1/messages/count_tokens` | Token counting endpoint |
|
||||
| `tool_choice` | Forcing specific tool use or disabling tools |
|
||||
| `metadata` | Request metadata (user_id) |
|
||||
| Prompt caching | `cache_control` blocks for caching prefixes |
|
||||
| Batches API | `/v1/messages/batches` for async batch processing |
|
||||
| Citations | `citations` content blocks |
|
||||
| PDF support | `document` content blocks with PDF files |
|
||||
| Server-sent errors | `error` events during streaming (errors return HTTP status) |
|
||||
|
||||
### Partial support
|
||||
|
||||
| Feature | Status |
|
||||
|---------|--------|
|
||||
| Image content | Base64 images supported; URL images not supported |
|
||||
| Extended thinking | Basic support; `budget_tokens` accepted but not enforced |
|
||||
File diff suppressed because one or more lines are too long
@@ -13,23 +13,9 @@ Embeddings turn text into numeric vectors you can store in a vector database, se
|
||||
|
||||
## Generate embeddings
|
||||
|
||||
Use `/api/embed` with a single string.
|
||||
|
||||
<Tabs>
|
||||
<Tab title="CLI">
|
||||
Generate embeddings directly from the command line:
|
||||
|
||||
```shell
|
||||
ollama run embeddinggemma "Hello world"
|
||||
```
|
||||
|
||||
You can also pipe text to generate embeddings:
|
||||
|
||||
```shell
|
||||
echo "Hello world" | ollama run embeddinggemma
|
||||
```
|
||||
|
||||
Output is a JSON array.
|
||||
|
||||
</Tab>
|
||||
<Tab title="cURL">
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/embed \
|
||||
|
||||
@@ -15,7 +15,7 @@ Also known as "single-shot" tool calling.
|
||||
```shell
|
||||
curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
|
||||
"model": "qwen3",
|
||||
"messages": [{"role": "user", "content": "What is the temperature in New York?"}],
|
||||
"messages": [{"role": "user", "content": "What's the temperature in New York?"}],
|
||||
"stream": false,
|
||||
"tools": [
|
||||
{
|
||||
@@ -41,7 +41,7 @@ Also known as "single-shot" tool calling.
|
||||
curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
|
||||
"model": "qwen3",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the temperature in New York?"},
|
||||
{"role": "user", "content": "What's the temperature in New York?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
@@ -90,7 +90,7 @@ Also known as "single-shot" tool calling.
|
||||
}
|
||||
return temperatures.get(city, "Unknown")
|
||||
|
||||
messages = [{"role": "user", "content": "What is the temperature in New York?"}]
|
||||
messages = [{"role": "user", "content": "What's the temperature in New York?"}]
|
||||
|
||||
# pass functions directly as tools in the tools list or as a JSON schema
|
||||
response = chat(model="qwen3", messages=messages, tools=[get_temperature], think=True)
|
||||
@@ -146,7 +146,7 @@ Also known as "single-shot" tool calling.
|
||||
},
|
||||
]
|
||||
|
||||
const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
|
||||
const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
|
||||
|
||||
const response = await ollama.chat({
|
||||
model: 'qwen3',
|
||||
@@ -609,7 +609,7 @@ def get_temperature(city: str) -> str:
|
||||
return temperatures.get(city, 'Unknown')
|
||||
|
||||
|
||||
messages = [{'role': 'user', 'content': "What is the temperature in New York?"}]
|
||||
messages = [{'role': 'user', 'content': "What's the temperature in New York?"}]
|
||||
|
||||
while True:
|
||||
stream = chat(
|
||||
@@ -684,7 +684,7 @@ const getTemperatureTool = {
|
||||
}
|
||||
|
||||
async function agentLoop() {
|
||||
const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
|
||||
const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
|
||||
|
||||
while (true) {
|
||||
const stream = await ollama.chat({
|
||||
|
||||
@@ -9,9 +9,15 @@ sidebarTitle: Cloud
|
||||
|
||||
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
|
||||
|
||||
### Supported models
|
||||
Ollama currently supports the following cloud models, with more coming soon:
|
||||
|
||||
For a list of supported models, see Ollama's [model library](https://ollama.com/search?c=cloud).
|
||||
- `deepseek-v3.1:671b-cloud`
|
||||
- `gpt-oss:20b-cloud`
|
||||
- `gpt-oss:120b-cloud`
|
||||
- `kimi-k2:1t-cloud`
|
||||
- `qwen3-coder:480b-cloud`
|
||||
- `glm-4.6:cloud`
|
||||
- `minimax-m2:cloud`
|
||||
|
||||
### Running Cloud models
|
||||
|
||||
|
||||
@@ -49,8 +49,6 @@ Install prerequisites:
|
||||
- [Ninja](https://github.com/ninja-build/ninja/releases)
|
||||
- (Optional) NVIDIA GPU support
|
||||
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
||||
- (Optional) VULKAN GPU support
|
||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||
|
||||
Then, configure and build the project:
|
||||
|
||||
@@ -59,17 +57,6 @@ cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
> Building for Vulkan requires VULKAN_SDK environment variable:
|
||||
>
|
||||
> PowerShell
|
||||
> ```powershell
|
||||
> $env:VULKAN_SDK="C:\VulkanSDK\<version>"
|
||||
> ```
|
||||
> CMD
|
||||
> ```cmd
|
||||
> set VULKAN_SDK=C:\VulkanSDK\<version>
|
||||
> ```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Building for ROCm requires additional flags:
|
||||
> ```
|
||||
@@ -78,7 +65,6 @@ cmake --build build --config Release
|
||||
> ```
|
||||
|
||||
|
||||
|
||||
Lastly, run Ollama:
|
||||
|
||||
```shell
|
||||
@@ -98,9 +84,7 @@ Install prerequisites:
|
||||
- [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html)
|
||||
- (Optional) NVIDIA GPU support
|
||||
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads)
|
||||
- (Optional) VULKAN GPU support
|
||||
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
|
||||
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Ensure prerequisites are in `PATH` before running CMake.
|
||||
|
||||
|
||||
@@ -139,8 +139,7 @@
|
||||
"/api/streaming",
|
||||
"/api/usage",
|
||||
"/api/errors",
|
||||
"/api/openai-compatibility",
|
||||
"/api/anthropic-compatibility"
|
||||
"/api/openai-compatibility"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
17
docs/faq.mdx
17
docs/faq.mdx
@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
## How can I view the logs?
|
||||
|
||||
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
|
||||
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
|
||||
|
||||
## Is my GPU compatible with Ollama?
|
||||
|
||||
Please refer to the [GPU docs](./gpu).
|
||||
Please refer to the [GPU docs](./gpu.md).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
@@ -57,13 +57,8 @@ ollama ps
|
||||
```
|
||||
|
||||
<Info>
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
NAME ID SIZE PROCESSOR UNTIL
|
||||
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
||||
```
|
||||
**Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
|
||||
100% GPU 4 minutes from now ```
|
||||
</Info>
|
||||
|
||||
The `Processor` column will show which memory the model was loaded in to:
|
||||
@@ -228,7 +223,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
|
||||
|
||||
## How can I use Ollama in Visual Studio Code?
|
||||
|
||||
There is already a large collection of plugins available for VS Code as well as other editors that leverage Ollama. See the list of [extensions & plugins](https://github.com/ollama/ollama#extensions--plugins) at the bottom of the main repository readme.
|
||||
There is already a large collection of plugins available for VSCode as well as other editors that leverage Ollama. See the list of [extensions & plugins](https://github.com/ollama/ollama#extensions--plugins) at the bottom of the main repository readme.
|
||||
|
||||
## How do I use Ollama with GPU acceleration in Docker?
|
||||
|
||||
@@ -390,4 +385,4 @@ Ollama for Windows and macOS register as a login item during installation. You
|
||||
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
|
||||
|
||||
**MacOS**
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
|
||||
10
docs/gpu.mdx
10
docs/gpu.mdx
@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
|
||||
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
|
||||
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
|
||||
|
||||
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia)
|
||||
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
|
||||
|
||||
### GPU Selection
|
||||
|
||||
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
|
||||
|
||||
Ollama supports the following AMD GPUs via the ROCm library:
|
||||
|
||||
> **NOTE:**
|
||||
> [!NOTE]
|
||||
> Additional AMD GPU support is provided by the Vulkan Library - see below.
|
||||
|
||||
|
||||
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
|
||||
|
||||
## Vulkan GPU Support
|
||||
|
||||
> **NOTE:**
|
||||
> [!NOTE]
|
||||
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server)
|
||||
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
|
||||
|
||||
Additional GPU support on Windows and Linux is provided via
|
||||
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
|
||||
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
|
||||
|
||||
To select specific Vulkan GPU(s), you can set the environment variable
|
||||
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
|
||||
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
|
||||
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
|
||||
by setting `GGML_VK_VISIBLE_DEVICES=-1`
|
||||
@@ -1,34 +1,34 @@
|
||||
---
|
||||
title: VS Code
|
||||
title: VS Code
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
Install [VS Code](https://code.visualstudio.com/download).
|
||||
Install [VSCode](https://code.visualstudio.com/download).
|
||||
|
||||
## Usage with Ollama
|
||||
## Usage with Ollama
|
||||
|
||||
1. Open Copilot side bar found in top right window
|
||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||
<img
|
||||
src="/images/vscode-sidebar.png"
|
||||
alt="VS Code chat Sidebar"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
2. Select the model dropdown > **Manage models**
|
||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||
<img
|
||||
src="/images/vscode-models.png"
|
||||
alt="VS Code model picker"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/vscode-sidebar.png"
|
||||
alt="VSCode chat Sidebar"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
2. Select the model drowpdown > **Manage models**
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/vscode-models.png"
|
||||
alt="VSCode model picker"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
3. Enter **Ollama** under **Provider Dropdown** and select desired models (e.g `qwen3, qwen3-coder:480b-cloud`)
|
||||
<div style={{ display: "flex", justifyContent: "center" }}>
|
||||
<img
|
||||
src="/images/vscode-model-options.png"
|
||||
alt="VS Code model options dropdown"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/vscode-model-options.png"
|
||||
alt="VSCode model options dropdown"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -41,7 +41,6 @@ INSTRUCTION arguments
|
||||
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
|
||||
| [`LICENSE`](#license) | Specifies the legal license. |
|
||||
| [`MESSAGE`](#message) | Specify message history. |
|
||||
| [`REQUIRES`](#requires) | Specify the minimum version of Ollama required by the model. |
|
||||
|
||||
## Examples
|
||||
|
||||
@@ -150,6 +149,9 @@ PARAMETER <parameter> <parametervalue>
|
||||
|
||||
| Parameter | Description | Value Type | Example Usage |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
|
||||
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
|
||||
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
@@ -249,16 +251,6 @@ MESSAGE user Is Ontario in Canada?
|
||||
MESSAGE assistant yes
|
||||
```
|
||||
|
||||
### REQUIRES
|
||||
|
||||
The `REQUIRES` instruction allows you to specify the minimum version of Ollama required by the model.
|
||||
|
||||
```
|
||||
REQUIRES <version>
|
||||
```
|
||||
|
||||
The version should be a valid Ollama version (e.g. 0.14.0).
|
||||
|
||||
## Notes
|
||||
|
||||
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.
|
||||
|
||||
@@ -111,12 +111,6 @@ components:
|
||||
description: Model keep-alive duration (for example `5m` or `0` to unload immediately)
|
||||
options:
|
||||
$ref: "#/components/schemas/ModelOptions"
|
||||
logprobs:
|
||||
type: boolean
|
||||
description: Whether to return log probabilities of the output tokens
|
||||
top_logprobs:
|
||||
type: integer
|
||||
description: Number of most likely tokens to return at each token position when logprobs are enabled
|
||||
GenerateResponse:
|
||||
type: object
|
||||
properties:
|
||||
@@ -156,11 +150,6 @@ components:
|
||||
eval_duration:
|
||||
type: integer
|
||||
description: Time spent generating tokens in nanoseconds
|
||||
logprobs:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/Logprob"
|
||||
description: Log probability information for the generated tokens when logprobs are enabled
|
||||
GenerateStreamEvent:
|
||||
type: object
|
||||
properties:
|
||||
@@ -298,12 +287,6 @@ components:
|
||||
- type: string
|
||||
- type: number
|
||||
description: Model keep-alive duration (for example `5m` or `0` to unload immediately)
|
||||
logprobs:
|
||||
type: boolean
|
||||
description: Whether to return log probabilities of the output tokens
|
||||
top_logprobs:
|
||||
type: integer
|
||||
description: Number of most likely tokens to return at each token position when logprobs are enabled
|
||||
ChatResponse:
|
||||
type: object
|
||||
properties:
|
||||
@@ -361,11 +344,6 @@ components:
|
||||
eval_duration:
|
||||
type: integer
|
||||
description: Time spent generating tokens in nanoseconds
|
||||
logprobs:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/Logprob"
|
||||
description: Log probability information for the generated tokens when logprobs are enabled
|
||||
ChatStreamEvent:
|
||||
type: object
|
||||
properties:
|
||||
@@ -728,41 +706,6 @@ components:
|
||||
version:
|
||||
type: string
|
||||
description: Version of Ollama
|
||||
TokenLogprob:
|
||||
type: object
|
||||
description: Log probability information for a single token alternative
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
description: The text representation of the token
|
||||
logprob:
|
||||
type: number
|
||||
description: The log probability of this token
|
||||
bytes:
|
||||
type: array
|
||||
items:
|
||||
type: integer
|
||||
description: The raw byte representation of the token
|
||||
Logprob:
|
||||
type: object
|
||||
description: Log probability information for a generated token
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
description: The text representation of the token
|
||||
logprob:
|
||||
type: number
|
||||
description: The log probability of this token
|
||||
bytes:
|
||||
type: array
|
||||
items:
|
||||
type: integer
|
||||
description: The raw byte representation of the token
|
||||
top_logprobs:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/TokenLogprob"
|
||||
description: Most likely tokens and their log probabilities at this position
|
||||
ErrorResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
# extract-examples
|
||||
|
||||
Extracts code examples from MDX files to a temp directory so you can run them.
|
||||
|
||||
## Usage
|
||||
|
||||
```shell
|
||||
go run docs/tools/extract-examples/main.go <mdx-file>
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
```shell
|
||||
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
|
||||
|
||||
- 01_basic.py
|
||||
- 01_basic.js
|
||||
- 01_basic.sh
|
||||
- 02_responses.py
|
||||
- 02_responses.js
|
||||
- 02_responses.sh
|
||||
- 03_vision.py
|
||||
- 03_vision.js
|
||||
- 03_vision.sh
|
||||
|
||||
Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
|
||||
|
||||
To run examples:
|
||||
|
||||
cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
|
||||
npm install # for JS examples
|
||||
|
||||
then run individual files with `node file.js`, `python file.py`, `bash file.sh`
|
||||
```
|
||||
|
||||
## How it works
|
||||
|
||||
- Parses MDX files looking for fenced code blocks with filenames (e.g., ` ```python basic.py `)
|
||||
- Groups examples by their `<CodeGroup>` and prefixes filenames with `01_`, `02_`, etc.
|
||||
- Writes all extracted files to a temp directory
|
||||
@@ -1,137 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Fprintln(os.Stderr, "Usage: go run extract-examples.go <mdx-file>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
mdxFile := os.Args[1]
|
||||
|
||||
f, err := os.Open(mdxFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Create temp directory
|
||||
tempDir, err := os.MkdirTemp("", "mdx-examples-*")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating temp dir: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Extracting code examples to: %s\n\n", tempDir)
|
||||
|
||||
// Patterns
|
||||
codeBlockStart := regexp.MustCompile("^```([a-zA-Z0-9_-]+)\\s+([^\\s]+)$")
|
||||
codeGroupStart := regexp.MustCompile("^<CodeGroup")
|
||||
codeGroupEnd := regexp.MustCompile("^</CodeGroup>")
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
inCodeBlock := false
|
||||
inCodeGroup := false
|
||||
var currentFile string
|
||||
var content strings.Builder
|
||||
count := 0
|
||||
codeGroupNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Track CodeGroup boundaries
|
||||
if codeGroupStart.MatchString(line) {
|
||||
inCodeGroup = true
|
||||
codeGroupNum++
|
||||
continue
|
||||
}
|
||||
if codeGroupEnd.MatchString(line) {
|
||||
inCodeGroup = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inCodeBlock {
|
||||
if line == "```" {
|
||||
// End of code block - write file
|
||||
if currentFile != "" {
|
||||
outPath := filepath.Join(tempDir, currentFile)
|
||||
if err := os.WriteFile(outPath, []byte(content.String()), 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing %s: %v\n", currentFile, err)
|
||||
} else {
|
||||
fmt.Printf(" - %s\n", currentFile)
|
||||
count++
|
||||
}
|
||||
}
|
||||
inCodeBlock = false
|
||||
currentFile = ""
|
||||
content.Reset()
|
||||
} else {
|
||||
content.WriteString(line)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
if matches := codeBlockStart.FindStringSubmatch(line); matches != nil {
|
||||
inCodeBlock = true
|
||||
filename := matches[2]
|
||||
// Prefix with CodeGroup number if inside a CodeGroup
|
||||
if inCodeGroup {
|
||||
currentFile = fmt.Sprintf("%02d_%s", codeGroupNum, filename)
|
||||
} else {
|
||||
currentFile = filename
|
||||
}
|
||||
content.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Write package.json for JavaScript dependencies
|
||||
packageJSON := `{
|
||||
"name": "mdx-examples",
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
"openai": "^4",
|
||||
"ollama": "^0.5"
|
||||
}
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "package.json"), []byte(packageJSON), 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing package.json: %v\n", err)
|
||||
}
|
||||
|
||||
// Write pyproject.toml for Python dependencies
|
||||
pyprojectTOML := `[project]
|
||||
name = "mdx-examples"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"openai",
|
||||
"ollama",
|
||||
]
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "pyproject.toml"), []byte(pyprojectTOML), 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error writing pyproject.toml: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf("Extracted %d file(s) to %s\n", count, tempDir)
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf("To run examples:\n")
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf(" cd %s\n npm install # for JS examples\n", tempDir)
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf("then run individual files with `node file.js`, `python file.py`, `bash file.sh`\n")
|
||||
}
|
||||
@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
|
||||
|
||||
### Linux NVIDIA Troubleshooting
|
||||
|
||||
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)
|
||||
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
|
||||
|
||||
Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type GGML struct {
|
||||
@@ -241,17 +240,12 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
"deepseek2",
|
||||
"deepseekocr",
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"gptoss", "gpt-oss",
|
||||
"llama4",
|
||||
"mistral3",
|
||||
"mllama",
|
||||
"nomic-bert",
|
||||
"olmo3",
|
||||
"qwen25vl",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
@@ -553,7 +547,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
context *= uint64(numParallel)
|
||||
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
@@ -794,7 +788,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
}
|
||||
|
||||
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
|
||||
if useFlashAttention == ml.FlashAttentionEnabled {
|
||||
if useFlashAttention {
|
||||
// rough estimate of graph size with flash attention on
|
||||
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
|
||||
}
|
||||
@@ -812,14 +806,6 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
|
||||
}
|
||||
|
||||
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
|
||||
func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
|
||||
if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SupportsFlashAttention checks if the model supports flash attention
|
||||
func (f GGML) SupportsFlashAttention() bool {
|
||||
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
|
||||
@@ -840,11 +826,8 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||
// FlashAttention checks if the model should enable flash attention
|
||||
func (f GGML) FlashAttention() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
"gemma3",
|
||||
"gptoss", "gpt-oss",
|
||||
"mistral3",
|
||||
"olmo3",
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
}, f.KV().String("general.architecture"))
|
||||
|
||||
@@ -305,7 +305,7 @@ func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error
|
||||
|
||||
a.values[i] = e
|
||||
} else {
|
||||
_ = discardGGUFString(llm, r)
|
||||
discardGGUFString(llm, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -568,6 +568,7 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||
// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
|
||||
for _, t := range ts {
|
||||
t := t
|
||||
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
|
||||
g.Go(func() error {
|
||||
_, err := t.WriteTo(w)
|
||||
@@ -597,10 +598,6 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
|
||||
|
||||
var err error
|
||||
switch v := v.(type) {
|
||||
case int32:
|
||||
err = writeGGUF(ws, ggufTypeInt32, v)
|
||||
case int64:
|
||||
err = writeGGUF(ws, ggufTypeInt64, v)
|
||||
case uint32, FileType:
|
||||
err = writeGGUF(ws, ggufTypeUint32, v)
|
||||
case uint64:
|
||||
@@ -615,10 +612,6 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
|
||||
err = writeGGUFArray(ws, ggufTypeInt32, v)
|
||||
case *array[int32]:
|
||||
err = writeGGUFArray(ws, ggufTypeInt32, v.values)
|
||||
case []int64:
|
||||
err = writeGGUFArray(ws, ggufTypeInt64, v)
|
||||
case *array[int64]:
|
||||
err = writeGGUFArray(ws, ggufTypeInt64, v.values)
|
||||
case []uint32:
|
||||
err = writeGGUFArray(ws, ggufTypeUint32, v)
|
||||
case *array[uint32]:
|
||||
|
||||
@@ -42,10 +42,6 @@ func TestWriteGGUF(t *testing.T) {
|
||||
"general.architecture": "test",
|
||||
"general.alignment": uint32(16),
|
||||
"test.key": "value",
|
||||
"test.int32_key": int32(-42),
|
||||
"test.int64_key": int64(-9223372036854775808),
|
||||
"test.int32_array": []int32{-1, 0, 1, 2147483647, -2147483648},
|
||||
"test.int64_array": []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808},
|
||||
"attention.key": "value2",
|
||||
"tokenizer.key": "value3",
|
||||
"adapter.key": "value4",
|
||||
@@ -59,7 +55,7 @@ func TestWriteGGUF(t *testing.T) {
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
ff, err := Decode(r, -1)
|
||||
ff, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -69,19 +65,15 @@ func TestWriteGGUF(t *testing.T) {
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(54),
|
||||
"test.key": "value",
|
||||
"test.int32_key": int32(-42),
|
||||
"test.int64_key": int64(-9223372036854775808),
|
||||
"test.int32_array": &array[int32]{size: 5, values: []int32{-1, 0, 1, 2147483647, -2147483648}},
|
||||
"test.int64_array": &array[int64]{size: 5, values: []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808}},
|
||||
"test.attention.key": "value2",
|
||||
"tokenizer.key": "value3",
|
||||
"adapter.key": "value4",
|
||||
}, ff.KV(), cmp.AllowUnexported(array[int32]{}, array[int64]{})); diff != "" {
|
||||
}, ff.KV()); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(Tensors{
|
||||
Offset: 992,
|
||||
Offset: 800,
|
||||
items: []*Tensor{
|
||||
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
|
||||
16
go.mod
16
go.mod
@@ -15,8 +15,9 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
golang.org/x/sync v0.12.0
|
||||
golang.org/x/sys v0.36.0
|
||||
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -29,8 +30,7 @@ require (
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/tools v0.38.0
|
||||
golang.org/x/tools v0.30.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
@@ -77,11 +77,11 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.46.0 // indirect
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/text v0.30.0
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/term v0.30.0
|
||||
golang.org/x/text v0.23.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
30
go.sum
30
go.sum
@@ -224,8 +224,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@@ -255,8 +255,6 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -269,8 +267,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -280,8 +278,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -297,17 +295,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -321,8 +319,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@@ -388,9 +388,9 @@ func NewFunctionNameMap() *FunctionNameMap {
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the handler with tools, optional last message, and think value
|
||||
// Init initializes the handler with tools and optional last message
|
||||
// Implements the Parser interface
|
||||
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
// Initialize the harmony parser
|
||||
if h.HarmonyParser == nil {
|
||||
h.HarmonyParser = &HarmonyParser{
|
||||
|
||||
@@ -14,23 +14,6 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func assertBytesMatchToken(t *testing.T, label, token string, ints []int) {
|
||||
t.Helper()
|
||||
|
||||
raw := []byte(token)
|
||||
if len(ints) != len(raw) {
|
||||
t.Errorf("%s expected %d bytes for token %q, got %d (%v)", label, len(raw), token, len(ints), ints)
|
||||
return
|
||||
}
|
||||
|
||||
for i, b := range raw {
|
||||
if ints[i] != int(b) {
|
||||
t.Errorf("%s byte[%d] mismatch for token %q: got %d want %d", label, i, token, ints[i], int(b))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIGenerate(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
@@ -483,7 +466,6 @@ func TestAPIGenerateLogprobs(t *testing.T) {
|
||||
if lp.Logprob > 0 {
|
||||
t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob)
|
||||
}
|
||||
assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d]", i), lp.Token, lp.Bytes)
|
||||
|
||||
// Check top_logprobs if requested
|
||||
if test.topLogprobs > 0 {
|
||||
@@ -500,9 +482,6 @@ func TestAPIGenerateLogprobs(t *testing.T) {
|
||||
t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob)
|
||||
}
|
||||
}
|
||||
for j, top := range lp.TopLogprobs {
|
||||
assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
|
||||
}
|
||||
} else if len(lp.TopLogprobs) > 0 {
|
||||
t.Errorf("logprob[%d] has top_logprobs but none were requested", i)
|
||||
}
|
||||
@@ -565,15 +544,11 @@ func TestAPIChatLogprobs(t *testing.T) {
|
||||
if lp.Logprob > 0 {
|
||||
t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob)
|
||||
}
|
||||
assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d]", i), lp.Token, lp.Bytes)
|
||||
if len(lp.TopLogprobs) == 0 {
|
||||
t.Errorf("logprob[%d] expected top_logprobs but got none", i)
|
||||
}
|
||||
if len(lp.TopLogprobs) > 3 {
|
||||
t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs))
|
||||
}
|
||||
for j, top := range lp.TopLogprobs {
|
||||
assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,9 +4,7 @@ package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -206,8 +204,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
||||
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
|
||||
}
|
||||
|
||||
if res.PromptEvalCount != 8 {
|
||||
t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
|
||||
if res.PromptEvalCount != 6 {
|
||||
t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -253,8 +251,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
|
||||
}
|
||||
|
||||
if res.PromptEvalCount != 16 {
|
||||
t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
|
||||
if res.PromptEvalCount != 12 {
|
||||
t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -277,7 +275,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
request api.EmbedRequest
|
||||
check func(*testing.T, *api.EmbedResponse, error)
|
||||
check func(*api.EmbedResponse, error)
|
||||
}{
|
||||
{
|
||||
name: "target truncation",
|
||||
@@ -285,7 +283,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Model: "all-minilm",
|
||||
Input: "why",
|
||||
},
|
||||
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -302,11 +300,10 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Input: "why is the sky blue?",
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -320,11 +317,10 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(t *testing.T, got *api.EmbedResponse, err error) {
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -338,21 +334,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "the input length exceeds the context length" {
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "input after truncate error with context length of 1",
|
||||
name: "input after truncate error",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
},
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
@@ -366,7 +362,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 0},
|
||||
},
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
@@ -379,7 +375,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
Input: "why is the sky blue? Why is the sky blue? hi there my",
|
||||
Options: map[string]any{"num_ctx": 16},
|
||||
},
|
||||
check: func(t *testing.T, res *api.EmbedResponse, err error) {
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -389,8 +385,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
|
||||
for _, req := range cases {
|
||||
t.Run(req.name, func(t *testing.T) {
|
||||
resp, err := embedTestHelper(ctx, client, t, req.request)
|
||||
req.check(t, resp, err)
|
||||
req.check(embedTestHelper(ctx, client, t, req.request))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -414,230 +409,3 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
|
||||
|
||||
return client.Embed(ctx, &req)
|
||||
}
|
||||
|
||||
func TestEmbedTruncation(t *testing.T) {
|
||||
// Use test deadline if set, otherwise default to 2 minutes
|
||||
timeout := 2 * time.Minute
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryEmbedModels {
|
||||
model := model
|
||||
t.Run(model, func(t *testing.T) {
|
||||
// Check if we're running out of time (reserve 20s for current model)
|
||||
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
||||
t.Skip("skipping remaining tests to avoid timeout")
|
||||
}
|
||||
|
||||
// Give each model its own budget to account for first-time pulls/loads
|
||||
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
|
||||
defer mcancel()
|
||||
|
||||
t.Run("truncation batch", func(t *testing.T) {
|
||||
truncTrue := true
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 30},
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(mctx, client, t, req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 3 {
|
||||
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
|
||||
}
|
||||
|
||||
if res.PromptEvalCount > 90 {
|
||||
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("runner token count accuracy", func(t *testing.T) {
|
||||
baseline := api.EmbedRequest{Model: model, Input: "test"}
|
||||
baseRes, err := embedTestHelper(mctx, client, t, baseline)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
batch := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: []string{"test", "test", "test"},
|
||||
}
|
||||
batchRes, err := embedTestHelper(mctx, client, t, batch)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedCount := baseRes.PromptEvalCount * 3
|
||||
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
|
||||
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
|
||||
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
|
||||
func TestEmbedLargeInput(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryEmbedModels {
|
||||
model := model
|
||||
t.Run(model, func(t *testing.T) {
|
||||
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
|
||||
defer mcancel()
|
||||
|
||||
// Test with progressively larger inputs
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputWords int
|
||||
}{
|
||||
{"medium_input_256_words", 256},
|
||||
{"large_input_512_words", 512},
|
||||
{"very_large_input_800_words", 800},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
words := make([]string, tc.inputWords)
|
||||
for i := range words {
|
||||
words[i] = "word"
|
||||
}
|
||||
input := strings.Join(words, " ")
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: input,
|
||||
KeepAlive: &api.Duration{Duration: 30 * time.Second},
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(mctx, client, t, req)
|
||||
if err != nil {
|
||||
t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 1 {
|
||||
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
|
||||
}
|
||||
|
||||
if len(res.Embeddings[0]) == 0 {
|
||||
t.Fatal("expected non-empty embedding")
|
||||
}
|
||||
|
||||
t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedStatusCode tests that errors from the embedding endpoint
|
||||
// properly preserve their HTTP status codes when returned to the client.
|
||||
// This test specifically checks the error handling path in EmbedHandler
|
||||
// where api.StatusError errors should maintain their original status code.
|
||||
func TestEmbedStatusCode(t *testing.T) {
|
||||
// Use test deadline if set, otherwise default to 2 minutes
|
||||
timeout := 2 * time.Minute
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryEmbedModels {
|
||||
model := model
|
||||
t.Run(model, func(t *testing.T) {
|
||||
// Check if we're running out of time (reserve 20s for current model)
|
||||
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
|
||||
t.Skip("skipping remaining tests to avoid timeout")
|
||||
}
|
||||
|
||||
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
|
||||
defer mcancel()
|
||||
|
||||
// Pull the model if needed
|
||||
if err := PullIfMissing(mctx, client, model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("truncation error status code", func(t *testing.T) {
|
||||
truncFalse := false
|
||||
longInput := strings.Repeat("word ", 100)
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: longInput,
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 10},
|
||||
}
|
||||
|
||||
_, err := embedTestHelper(mctx, client, t, req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when truncate=false with long input")
|
||||
}
|
||||
|
||||
// Check that it's a StatusError with the correct status code
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
// The error should be a 4xx client error (likely 400 Bad Request)
|
||||
// not a 500 Internal Server Error
|
||||
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
|
||||
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
|
||||
}
|
||||
|
||||
// Verify the error message is meaningful
|
||||
if !strings.Contains(err.Error(), "context length") {
|
||||
t.Errorf("expected error message to mention context length, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("batch truncation error status code", func(t *testing.T) {
|
||||
truncFalse := false
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: []string{
|
||||
"short input",
|
||||
strings.Repeat("very long input ", 100),
|
||||
"another short input",
|
||||
},
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 10},
|
||||
}
|
||||
|
||||
_, err := embedTestHelper(mctx, client, t, req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when one input exceeds context with truncate=false")
|
||||
}
|
||||
|
||||
// Check that it's a StatusError with the correct status code
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
// The error should be a 4xx client error, not a 500 Internal Server Error
|
||||
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
|
||||
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,9 +33,6 @@ func TestVisionModels(t *testing.T) {
|
||||
// Qwen 3 VL mixture of experts
|
||||
model: "qwen3-vl:30b",
|
||||
},
|
||||
{
|
||||
model: "ministral-3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, v := range testCases {
|
||||
|
||||
@@ -30,7 +30,6 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
"mistral": 6,
|
||||
"qwen2.5": 6,
|
||||
"qwen2": 6,
|
||||
"ministral-3": 20,
|
||||
"mistral-nemo": 9,
|
||||
"mistral-small": 16,
|
||||
"mixtral:8x22b": 80,
|
||||
|
||||
@@ -38,7 +38,6 @@ var (
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"ministral-3",
|
||||
"qwen3-coder:30b",
|
||||
"gpt-oss:20b",
|
||||
"gemma3n:e2b",
|
||||
@@ -168,7 +167,6 @@ var (
|
||||
"medllama2",
|
||||
"megadolphin",
|
||||
"minicpm-v",
|
||||
"ministral-3",
|
||||
"mistral-large",
|
||||
"mistral-nemo",
|
||||
"mistral-openorca",
|
||||
@@ -272,7 +270,6 @@ var (
|
||||
"mistral",
|
||||
"qwen2.5",
|
||||
"qwen2",
|
||||
"ministral-3",
|
||||
"mistral-nemo",
|
||||
"mistral-small",
|
||||
"mixtral:8x22b",
|
||||
|
||||
@@ -3,6 +3,7 @@ package kvcache
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
@@ -39,18 +40,18 @@ type Causal struct {
|
||||
|
||||
// ** current forward pass **
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
// starting location for data storage for this batch
|
||||
curLoc int
|
||||
|
||||
// size of the current batch
|
||||
curBatchSize int
|
||||
|
||||
// locations for data storage for this batch
|
||||
curLoc ml.Tensor
|
||||
|
||||
// mask of the cache as used by this batch
|
||||
curMask ml.Tensor
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
// locations in the cache that are needed for this batch
|
||||
curCellRange cellRange
|
||||
|
||||
@@ -140,6 +141,10 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
c.config.CachePadding = 1
|
||||
}
|
||||
|
||||
if c.config.MaskBatchPadding == 0 {
|
||||
c.config.MaskBatchPadding = 1
|
||||
}
|
||||
|
||||
if c.config.MaskDType == ml.DTypeOther {
|
||||
c.config.MaskDType = ml.DTypeF32
|
||||
}
|
||||
@@ -201,47 +206,45 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
c.curPositions = batch.Positions
|
||||
c.opts.Except = nil
|
||||
|
||||
var locs []int32
|
||||
if !reserve {
|
||||
c.updateSlidingWindow()
|
||||
|
||||
var err error
|
||||
locs, err = c.findLocs()
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
if errors.Is(err, ErrKvCacheFull) {
|
||||
c.defrag()
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, pos := range batch.Positions {
|
||||
seq := batch.Sequences[i]
|
||||
loc := int(locs[i])
|
||||
|
||||
c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
seqRange = newRange()
|
||||
}
|
||||
|
||||
seqRange.min = min(seqRange.min, loc)
|
||||
c.curCellRange.min = min(c.curCellRange.min, loc)
|
||||
seqRange.min = min(seqRange.min, c.curLoc+i)
|
||||
c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
|
||||
|
||||
seqRange.max = max(seqRange.max, loc)
|
||||
c.curCellRange.max = max(c.curCellRange.max, loc)
|
||||
seqRange.max = max(seqRange.max, c.curLoc+i)
|
||||
c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
|
||||
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
} else {
|
||||
// If we are reserving memory, don't update any of the cache metadata but set the size
|
||||
// to the worst case.
|
||||
locs = make([]int32, c.curBatchSize)
|
||||
for i := range locs {
|
||||
locs[i] = int32(i)
|
||||
}
|
||||
c.curLoc = 0
|
||||
c.curCellRange.min = 0
|
||||
c.curCellRange.max = len(c.cells) - 1
|
||||
}
|
||||
|
||||
c.curLoc = ctx.Input().FromInts(locs, len(locs))
|
||||
c.curMask = c.buildMask(ctx)
|
||||
|
||||
return nil
|
||||
@@ -254,20 +257,22 @@ func newRange() cellRange {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a slice of locations where each token in the batch should be stored
|
||||
func (c *Causal) findLocs() ([]int32, error) {
|
||||
loc := make([]int32, 0, c.curBatchSize)
|
||||
|
||||
// Find the first contiguous block of at least curBatchSize
|
||||
func (c *Causal) findStartLoc() (int, error) {
|
||||
var start, count int
|
||||
for i := range c.cells {
|
||||
if len(c.cells[i].sequences) == 0 {
|
||||
loc = append(loc, int32(i))
|
||||
if len(loc) >= c.curBatchSize {
|
||||
return loc, nil
|
||||
count++
|
||||
if count >= c.curBatchSize {
|
||||
return start, nil
|
||||
}
|
||||
} else {
|
||||
start = i + 1
|
||||
count = 0
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
|
||||
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
|
||||
}
|
||||
|
||||
func (c *Causal) updateSlidingWindow() {
|
||||
@@ -360,12 +365,15 @@ func roundUp(length, pad int) int {
|
||||
// token in the history should apply. This is based on both the sequence and causality (the
|
||||
// position of the history is not ahead of the token in the batch).
|
||||
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
// Align and pad the two dimensions as required by the backend
|
||||
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||
|
||||
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
|
||||
mask := make([]float32, c.curBatchSize*length)
|
||||
mask := make([]float32, batchSize*length)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
enabled := !slices.Contains(c.opts.Except, i)
|
||||
@@ -379,7 +387,13 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
maskTensor := ctx.Input().FromFloats(mask, length, c.curBatchSize)
|
||||
// Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||
// has already been masked out because the sequence doesn't match.
|
||||
for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||
mask[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
maskTensor := ctx.Input().FromFloats(mask, length, batchSize)
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||
@@ -388,6 +402,145 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
return maskTensor
|
||||
}
|
||||
|
||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
|
||||
for i, key := range c.keys {
|
||||
if key == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
|
||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
|
||||
|
||||
value := c.values[i]
|
||||
var vSrcView, vDstView ml.Tensor
|
||||
if c.config.PermutedV {
|
||||
vHeadDim := value.Dim(1)
|
||||
elemSize := value.Stride(0)
|
||||
|
||||
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||
} else {
|
||||
vHeadDim := value.Dim(0)
|
||||
rowSize := value.Stride(2)
|
||||
|
||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
|
||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
|
||||
}
|
||||
|
||||
ctx.Forward(
|
||||
kSrcView.Copy(ctx, kDstView),
|
||||
vSrcView.Copy(ctx, vDstView),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) defrag() {
|
||||
slog.Debug("defragmenting kv cache")
|
||||
|
||||
// Defrag strategy:
|
||||
// - Search for empty holes at the beginning of the cache,
|
||||
// filling them with active data starting at the end
|
||||
// - If there are contiguous elements that need to be moved,
|
||||
// combine them into a single operation by holding new moves
|
||||
// until we see that the next one is non-contiguous
|
||||
// - Fill up the context with the maximum number of operations it
|
||||
// can hold then compute that and continue with a new context
|
||||
//
|
||||
// We could try to optimize placement by grouping blocks from
|
||||
// the same sequences together but most likely the next forward
|
||||
// pass will disrupt this anyways, so the real world benefit
|
||||
// seems limited as this time.
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
|
||||
// For every move, 6 tensors are required per layer (2 views and a
|
||||
// copy for each of k and v). We also need to refer to the original
|
||||
// k and v cache tensors - once per layer, not per move.
|
||||
layers := 0
|
||||
for _, key := range c.keys {
|
||||
if key == nil {
|
||||
continue
|
||||
}
|
||||
layers++
|
||||
}
|
||||
|
||||
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
|
||||
moves := 0
|
||||
|
||||
var pendingSrc, pendingDst, pendingLen int
|
||||
src := len(c.cells) - 1
|
||||
|
||||
for dst := 0; dst < src; dst++ {
|
||||
if len(c.cells[dst].sequences) == 0 {
|
||||
for ; src > dst; src-- {
|
||||
if len(c.cells[src].sequences) != 0 {
|
||||
c.cells[dst] = c.cells[src]
|
||||
c.cells[src] = cacheCell{}
|
||||
|
||||
if pendingLen > 0 {
|
||||
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
|
||||
pendingSrc = src
|
||||
pendingLen++
|
||||
break
|
||||
} else {
|
||||
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||
moves++
|
||||
}
|
||||
}
|
||||
|
||||
pendingSrc = src
|
||||
pendingDst = dst
|
||||
pendingLen = 1
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if moves >= maxMoves {
|
||||
ctx.Compute()
|
||||
ctx.Close()
|
||||
ctx = c.backend.NewContext()
|
||||
|
||||
moves = 0
|
||||
}
|
||||
}
|
||||
|
||||
if pendingLen > 0 {
|
||||
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||
moves++
|
||||
}
|
||||
|
||||
if moves > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
// Reset range metadata
|
||||
for seq := range c.cellRanges {
|
||||
seqRange := newRange()
|
||||
|
||||
for i, cell := range c.cells {
|
||||
if slices.Contains(cell.sequences, seq) {
|
||||
if i < seqRange.min {
|
||||
seqRange.min = i
|
||||
}
|
||||
if i > seqRange.max {
|
||||
seqRange.max = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
|
||||
c.updateSlidingWindow()
|
||||
}
|
||||
|
||||
func (c *Causal) SetLayer(layer int) {
|
||||
c.curLayer = layer
|
||||
}
|
||||
@@ -472,25 +625,18 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
}
|
||||
}
|
||||
|
||||
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
|
||||
keyCache := c.keys[c.curLayer]
|
||||
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
|
||||
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
|
||||
rowSize := c.keys[c.curLayer].Stride(2)
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
|
||||
|
||||
if c.config.PermutedV {
|
||||
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
|
||||
value = value.Permute(ctx, 2, 0, 1, 3)
|
||||
elemSize := c.values[c.curLayer].Stride(0)
|
||||
|
||||
valueCache := c.values[c.curLayer]
|
||||
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
|
||||
|
||||
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
|
||||
} else {
|
||||
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
|
||||
valueCache := c.values[c.curLayer]
|
||||
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
|
||||
rowSize := c.values[c.curLayer].Stride(2)
|
||||
|
||||
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
2
llama/build-info.cpp
generated
vendored
2
llama/build-info.cpp
generated
vendored
@@ -1,4 +1,4 @@
|
||||
int LLAMA_BUILD_NUMBER = 0;
|
||||
char const *LLAMA_COMMIT = "ec98e2002";
|
||||
char const *LLAMA_COMMIT = "3cfa9c3f125763305b4226bc032f1954f08990dc";
|
||||
char const *LLAMA_COMPILER = "";
|
||||
char const *LLAMA_BUILD_TARGET = "";
|
||||
|
||||
@@ -17,17 +17,11 @@ include /tools/mtmd/clip.cpp
|
||||
include /tools/mtmd/mtmd.cpp
|
||||
include /tools/mtmd/mtmd-audio.cpp
|
||||
include /tools/mtmd/mtmd-helper.cpp
|
||||
include /tools/mtmd/models/
|
||||
include /tools/mtmd/models/*.h
|
||||
include /tools/mtmd/models/*.cpp
|
||||
include /src/
|
||||
include /src/llama.*
|
||||
include /src/llama-*.*
|
||||
include /src/unicode-data.*
|
||||
include /src/unicode.*
|
||||
include /src/models/
|
||||
include /src/models/*.h
|
||||
include /src/models/*.cpp
|
||||
include /vendor/
|
||||
include /vendor/miniaudio/
|
||||
include /vendor/miniaudio/*.h
|
||||
|
||||
359
llama/llama.cpp/common/common.cpp
vendored
359
llama/llama.cpp/common/common.cpp
vendored
@@ -8,7 +8,6 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "sampling.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
@@ -27,6 +26,7 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
@@ -60,14 +60,6 @@
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
||||
|
||||
common_time_meas::~common_time_meas() {
|
||||
if (t_start_us >= 0) {
|
||||
t_acc += ggml_time_us() - t_start_us;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// CPU utils
|
||||
//
|
||||
@@ -363,7 +355,11 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
void common_init() {
|
||||
llama_log_set(common_log_default_callback, NULL);
|
||||
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}, NULL);
|
||||
|
||||
#ifdef NDEBUG
|
||||
const char * build_type = "";
|
||||
@@ -694,7 +690,7 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
|
||||
|
||||
// Validate if a filename is safe to use
|
||||
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
|
||||
bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
bool fs_validate_filename(const std::string & filename) {
|
||||
if (!filename.length()) {
|
||||
// Empty filename invalid
|
||||
return false;
|
||||
@@ -754,14 +750,10 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|
||||
|| c == 0xFFFD // Replacement Character (UTF-8)
|
||||
|| c == 0xFEFF // Byte Order Mark (BOM)
|
||||
|| c == ':' || c == '*' // Illegal characters
|
||||
|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
|
||||
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
|
||||
return false;
|
||||
}
|
||||
if (!allow_subdirs && (c == '/' || c == '\\')) {
|
||||
// Subdirectories not allowed, reject path separators
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
|
||||
@@ -786,29 +778,11 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
#include <iostream>
|
||||
|
||||
|
||||
#ifdef _WIN32
|
||||
static std::wstring utf8_to_wstring(const std::string & str) {
|
||||
if (str.empty()) {
|
||||
return std::wstring();
|
||||
}
|
||||
|
||||
int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
|
||||
|
||||
if (size <= 0) {
|
||||
return std::wstring();
|
||||
}
|
||||
|
||||
std::wstring wstr(size, 0);
|
||||
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
|
||||
|
||||
return wstr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// returns true if successful, false otherwise
|
||||
bool fs_create_directory_with_parents(const std::string & path) {
|
||||
#ifdef _WIN32
|
||||
std::wstring wpath = utf8_to_wstring(path);
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
std::wstring wpath = converter.from_bytes(path);
|
||||
|
||||
// if the path already exists, check whether it's a directory
|
||||
const DWORD attributes = GetFileAttributesW(wpath.c_str());
|
||||
@@ -881,11 +855,6 @@ bool fs_create_directory_with_parents(const std::string & path) {
|
||||
#endif // _WIN32
|
||||
}
|
||||
|
||||
bool fs_is_directory(const std::string & path) {
|
||||
std::filesystem::path dir(path);
|
||||
return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
|
||||
}
|
||||
|
||||
std::string fs_get_cache_directory() {
|
||||
std::string cache_directory = "";
|
||||
auto ensure_trailing_slash = [](std::string p) {
|
||||
@@ -920,8 +889,6 @@ std::string fs_get_cache_directory() {
|
||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||
#elif defined(_WIN32)
|
||||
cache_directory = std::getenv("LOCALAPPDATA");
|
||||
#elif defined(__EMSCRIPTEN__)
|
||||
GGML_ABORT("not implemented on this platform");
|
||||
#else
|
||||
# error Unknown architecture
|
||||
#endif
|
||||
@@ -941,258 +908,34 @@ std::string fs_get_cache_file(const std::string & filename) {
|
||||
return cache_directory + filename;
|
||||
}
|
||||
|
||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories) {
|
||||
std::vector<common_file_info> files;
|
||||
if (path.empty()) return files;
|
||||
|
||||
std::filesystem::path dir(path);
|
||||
if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
|
||||
return files;
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::directory_iterator(dir)) {
|
||||
try {
|
||||
// Only include regular files (skip directories)
|
||||
const auto & p = entry.path();
|
||||
if (std::filesystem::is_regular_file(p)) {
|
||||
common_file_info info;
|
||||
info.path = p.string();
|
||||
info.name = p.filename().string();
|
||||
info.is_dir = false;
|
||||
try {
|
||||
info.size = static_cast<size_t>(std::filesystem::file_size(p));
|
||||
} catch (const std::filesystem::filesystem_error &) {
|
||||
info.size = 0;
|
||||
}
|
||||
files.push_back(std::move(info));
|
||||
} else if (include_directories && std::filesystem::is_directory(p)) {
|
||||
common_file_info info;
|
||||
info.path = p.string();
|
||||
info.name = p.filename().string();
|
||||
info.size = 0; // Directories have no size
|
||||
info.is_dir = true;
|
||||
files.push_back(std::move(info));
|
||||
}
|
||||
} catch (const std::filesystem::filesystem_error &) {
|
||||
// skip entries we cannot inspect
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return files;
|
||||
}
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
|
||||
bool tty_can_use_colors() {
|
||||
// Check NO_COLOR environment variable (https://no-color.org/)
|
||||
if (const char * no_color = std::getenv("NO_COLOR")) {
|
||||
if (no_color[0] != '\0') {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check TERM environment variable
|
||||
if (const char * term = std::getenv("TERM")) {
|
||||
if (std::strcmp(term, "dumb") == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if stdout and stderr are connected to a terminal
|
||||
// We check both because log messages can go to either
|
||||
bool stdout_is_tty = isatty(fileno(stdout));
|
||||
bool stderr_is_tty = isatty(fileno(stderr));
|
||||
|
||||
return stdout_is_tty || stderr_is_tty;
|
||||
}
|
||||
|
||||
//
|
||||
// Model utils
|
||||
//
|
||||
|
||||
// TODO: move to common/sampling
|
||||
static void common_init_sampler_from_model(
|
||||
const llama_model * model,
|
||||
common_params_sampling & sparams) {
|
||||
|
||||
const uint64_t config = sparams.user_sampling_config;
|
||||
|
||||
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
|
||||
if (config & user_config) {
|
||||
return;
|
||||
}
|
||||
|
||||
char buf[64] = {0};
|
||||
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||
char * end = nullptr;
|
||||
int32_t v = strtol(buf, &end, 10);
|
||||
if (end && end != buf) {
|
||||
dst = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
|
||||
if (config & user_config) {
|
||||
return;
|
||||
}
|
||||
|
||||
char buf[128] = {0};
|
||||
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||
char * end = nullptr;
|
||||
float v = strtof(buf, &end);
|
||||
if (end && end != buf) {
|
||||
dst = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Sampling sequence
|
||||
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
|
||||
char buf[512] = {0};
|
||||
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
|
||||
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
|
||||
if (!sampler_names.empty()) {
|
||||
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
|
||||
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
|
||||
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
|
||||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
|
||||
}
|
||||
|
||||
struct common_init_result::impl {
|
||||
impl() = default;
|
||||
~impl() = default;
|
||||
|
||||
llama_model_ptr model;
|
||||
llama_context_ptr context;
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> lora;
|
||||
|
||||
std::vector<common_sampler_ptr> samplers;
|
||||
};
|
||||
|
||||
common_init_result::common_init_result(common_params & params) :
|
||||
pimpl(new impl{}) {
|
||||
struct common_init_result common_init_from_params(common_params & params) {
|
||||
common_init_result iparams;
|
||||
auto mparams = common_model_params_to_llama(params);
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
|
||||
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||
if (model == NULL) {
|
||||
return;
|
||||
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
return iparams;
|
||||
}
|
||||
|
||||
pimpl->model.reset(model);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// updates params.sampling
|
||||
// TODO: fix naming
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
//if (params.sampling.penalty_last_n == -1) {
|
||||
// LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
// params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
//}
|
||||
|
||||
//if (params.sampling.dry_penalty_last_n == -1) {
|
||||
// LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||
//}
|
||||
|
||||
pimpl->samplers.resize(cparams.n_seq_max);
|
||||
|
||||
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
|
||||
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
|
||||
}
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
return;
|
||||
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
llama_model_free(model);
|
||||
return iparams;
|
||||
}
|
||||
|
||||
pimpl->context.reset(lctx);
|
||||
}
|
||||
|
||||
llama_model * common_init_result::model() {
|
||||
return pimpl->model.get();
|
||||
}
|
||||
|
||||
llama_context * common_init_result::context() {
|
||||
return pimpl->context.get();
|
||||
}
|
||||
|
||||
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
|
||||
return pimpl->samplers[seq_id].get();
|
||||
}
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||
return pimpl->lora;
|
||||
}
|
||||
|
||||
void common_init_result::free_context() {
|
||||
pimpl->context.reset();
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
|
||||
llama_model * model = res->model();
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_context * lctx = res->context();
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
||||
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
||||
params.ctx_shift = false;
|
||||
@@ -1204,7 +947,10 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
|
||||
const auto cvec = common_control_vector_load(params.control_vectors);
|
||||
if (cvec.n_embd == -1) {
|
||||
return res;
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
}
|
||||
|
||||
int err = llama_apply_adapter_cvec(
|
||||
@@ -1215,7 +961,10 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
params.control_vector_layer_start,
|
||||
params.control_vector_layer_end);
|
||||
if (err) {
|
||||
return res;
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1239,7 +988,10 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
return res;
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1249,7 +1001,9 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
if (lora == nullptr) {
|
||||
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
return res;
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
return iparams;
|
||||
}
|
||||
|
||||
char buf[1024];
|
||||
@@ -1258,13 +1012,43 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
la.task_name = buf;
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
||||
la.prompt_prefix = buf;
|
||||
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
}
|
||||
|
||||
if (!params.lora_init_without_apply) {
|
||||
common_set_adapter_lora(lctx, params.lora_adapters);
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
if (params.sampling.penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
}
|
||||
|
||||
if (params.sampling.dry_penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
|
||||
@@ -1303,10 +1087,11 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
llama_set_warmup(lctx, false);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
iparams.model.reset(model);
|
||||
iparams.context.reset(lctx);
|
||||
|
||||
common_init_result::~common_init_result() = default;
|
||||
return iparams;
|
||||
}
|
||||
|
||||
std::string get_model_endpoint() {
|
||||
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
|
||||
@@ -1316,9 +1101,7 @@ std::string get_model_endpoint() {
|
||||
std::string model_endpoint = "https://huggingface.co/";
|
||||
if (endpoint_env) {
|
||||
model_endpoint = endpoint_env;
|
||||
if (model_endpoint.back() != '/') {
|
||||
model_endpoint += '/';
|
||||
}
|
||||
if (model_endpoint.back() != '/') model_endpoint += '/';
|
||||
}
|
||||
return model_endpoint;
|
||||
}
|
||||
|
||||
123
llama/llama.cpp/common/common.h
vendored
123
llama/llama.cpp/common/common.h
vendored
@@ -2,19 +2,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ggml-opt.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <cmath>
|
||||
|
||||
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
||||
#define _WIN32_WINNT 0x0A00
|
||||
#endif
|
||||
#include "ggml-opt.h"
|
||||
#include "llama-cpp.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#define DIRECTORY_SEPARATOR '\\'
|
||||
@@ -30,14 +28,7 @@
|
||||
fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
|
||||
} while(0)
|
||||
|
||||
struct common_time_meas {
|
||||
common_time_meas(int64_t & t_acc, bool disable = false);
|
||||
~common_time_meas();
|
||||
|
||||
const int64_t t_start_us;
|
||||
|
||||
int64_t & t_acc;
|
||||
};
|
||||
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
||||
|
||||
struct common_adapter_lora_info {
|
||||
std::string path;
|
||||
@@ -82,8 +73,7 @@ int32_t cpu_get_num_math();
|
||||
enum llama_example {
|
||||
LLAMA_EXAMPLE_COMMON,
|
||||
LLAMA_EXAMPLE_SPECULATIVE,
|
||||
LLAMA_EXAMPLE_COMPLETION,
|
||||
LLAMA_EXAMPLE_CLI,
|
||||
LLAMA_EXAMPLE_MAIN,
|
||||
LLAMA_EXAMPLE_EMBEDDING,
|
||||
LLAMA_EXAMPLE_PERPLEXITY,
|
||||
LLAMA_EXAMPLE_RETRIEVAL,
|
||||
@@ -99,7 +89,6 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_TTS,
|
||||
LLAMA_EXAMPLE_DIFFUSION,
|
||||
LLAMA_EXAMPLE_FINETUNE,
|
||||
LLAMA_EXAMPLE_FIT_PARAMS,
|
||||
|
||||
LLAMA_EXAMPLE_COUNT,
|
||||
};
|
||||
@@ -144,22 +133,6 @@ struct common_grammar_trigger {
|
||||
llama_token token = LLAMA_TOKEN_NULL;
|
||||
};
|
||||
|
||||
enum common_params_sampling_config : uint64_t {
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10,
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
|
||||
};
|
||||
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||
@@ -192,10 +165,9 @@ struct common_params_sampling {
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool timing_per_token = false;
|
||||
|
||||
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
|
||||
|
||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||
|
||||
|
||||
std::vector<enum common_sampler_type> samplers = {
|
||||
COMMON_SAMPLER_TYPE_PENALTIES,
|
||||
COMMON_SAMPLER_TYPE_DRY,
|
||||
@@ -216,10 +188,6 @@ struct common_params_sampling {
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
bool has_logit_bias() const {
|
||||
return !logit_bias.empty();
|
||||
}
|
||||
|
||||
// print the parameters into a string
|
||||
std::string print() const;
|
||||
};
|
||||
@@ -230,7 +198,6 @@ struct common_params_model {
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
std::string docker_repo = ""; // Docker repo // NOLINT
|
||||
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
||||
};
|
||||
|
||||
struct common_params_speculative {
|
||||
@@ -307,8 +274,8 @@ struct lr_opt {
|
||||
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
|
||||
|
||||
struct common_params {
|
||||
int32_t n_predict = -1; // max. number of new tokens to predict, -1 == no limit
|
||||
int32_t n_ctx = 0; // context size, 0 == context the model was trained with
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_ctx = 4096; // context size
|
||||
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
@@ -329,12 +296,9 @@ struct common_params {
|
||||
// offload params
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
|
||||
size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory
|
||||
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||
|
||||
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
||||
|
||||
@@ -380,7 +344,7 @@ struct common_params {
|
||||
|
||||
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
|
||||
|
||||
int32_t verbosity = 3; // LOG_LEVEL_INFO
|
||||
int32_t verbosity = 0;
|
||||
int32_t control_vector_layer_start = -1; // layer range for control vector
|
||||
int32_t control_vector_layer_end = -1; // layer range for control vector
|
||||
bool offline = false;
|
||||
@@ -414,7 +378,6 @@ struct common_params {
|
||||
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
||||
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool show_timings = true; // show timing information on CLI
|
||||
bool ctx_shift = false; // context shift on infinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
bool kv_unified = false; // enable unified KV cache
|
||||
@@ -443,8 +406,6 @@ struct common_params {
|
||||
bool mmproj_use_gpu = true; // use GPU for multimodal model
|
||||
bool no_mmproj = false; // explicitly disable multimodal model
|
||||
std::vector<std::string> image; // path to image file(s)
|
||||
int image_min_tokens = -1;
|
||||
int image_max_tokens = -1;
|
||||
|
||||
// finetune
|
||||
struct lr_opt lr;
|
||||
@@ -471,7 +432,7 @@ struct common_params {
|
||||
std::string public_path = ""; // NOLINT
|
||||
std::string api_prefix = ""; // NOLINT
|
||||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = true; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int reasoning_budget = -1;
|
||||
@@ -490,22 +451,14 @@ struct common_params {
|
||||
bool endpoint_props = false; // only control POST requests, not GET
|
||||
bool endpoint_metrics = false;
|
||||
|
||||
// router server configs
|
||||
std::string models_dir = ""; // directory containing models for the router server
|
||||
std::string models_preset = ""; // directory containing model presets for the router server
|
||||
int models_max = 4; // maximum number of models to load simultaneously
|
||||
bool models_autoload = true; // automatically load models when requested via the router server
|
||||
|
||||
bool log_json = false;
|
||||
|
||||
std::string slot_save_path;
|
||||
std::string media_path; // path to directory for loading media files
|
||||
|
||||
float slot_prompt_similarity = 0.1f;
|
||||
|
||||
// batched-bench params
|
||||
bool is_pp_shared = false;
|
||||
bool is_tg_separate = false;
|
||||
bool is_pp_shared = false;
|
||||
|
||||
std::vector<int32_t> n_pp;
|
||||
std::vector<int32_t> n_tg;
|
||||
@@ -552,10 +505,6 @@ struct common_params {
|
||||
// return false from callback to abort model loading or true to continue
|
||||
llama_progress_callback load_progress_callback = NULL;
|
||||
void * load_progress_callback_user_data = NULL;
|
||||
|
||||
bool has_speculative() const {
|
||||
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
@@ -650,55 +599,25 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
|
||||
// Filesystem utils
|
||||
//
|
||||
|
||||
bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
|
||||
bool fs_validate_filename(const std::string & filename);
|
||||
bool fs_create_directory_with_parents(const std::string & path);
|
||||
bool fs_is_directory(const std::string & path);
|
||||
|
||||
std::string fs_get_cache_directory();
|
||||
std::string fs_get_cache_file(const std::string & filename);
|
||||
|
||||
struct common_file_info {
|
||||
std::string path;
|
||||
std::string name;
|
||||
size_t size = 0; // in bytes
|
||||
bool is_dir = false;
|
||||
};
|
||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
|
||||
// Auto-detect if colors can be enabled based on terminal and environment
|
||||
bool tty_can_use_colors();
|
||||
|
||||
//
|
||||
// Model utils
|
||||
//
|
||||
|
||||
struct common_sampler;
|
||||
|
||||
// note: defines the model, context, samplers, ets. lifetimes
|
||||
// note: defines object's lifetime
|
||||
struct common_init_result {
|
||||
common_init_result(common_params & params);
|
||||
~common_init_result();
|
||||
llama_model_ptr model;
|
||||
llama_context_ptr context;
|
||||
|
||||
llama_model * model();
|
||||
llama_context * context();
|
||||
common_sampler * sampler(llama_seq_id seq_id);
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & lora();
|
||||
|
||||
void free_context();
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
std::vector<llama_adapter_lora_ptr> lora;
|
||||
};
|
||||
|
||||
using common_init_result_ptr = std::unique_ptr<common_init_result>;
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params);
|
||||
struct common_init_result common_init_from_params(common_params & params);
|
||||
|
||||
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user