Compare commits


2 Commits

Author          SHA1        Message                     Date
Daniel Hiltgen  7ccb65cd5e  ci: update release flow     2025-10-27 14:46:29 -07:00
Daniel Hiltgen  50f00b7f60  update app - as of 5fba33e  2025-10-23 11:27:41 -07:00
1059 changed files with 57733 additions and 151607 deletions

.gitattributes (vendored): 4 changed lines

@@ -15,12 +15,8 @@ ml/backend/**/*.cu linguist-vendored
ml/backend/**/*.cuh linguist-vendored
ml/backend/**/*.m linguist-vendored
ml/backend/**/*.metal linguist-vendored
ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored
app/webview linguist-vendored
llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated


@@ -13,7 +13,7 @@ body:
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
render: shell
validations:
required: false


@@ -16,15 +16,13 @@ jobs:
outputs:
GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }}
VERSION: ${{ steps.goflags.outputs.VERSION }}
vendorsha: ${{ steps.changes.outputs.vendorsha }}
steps:
- uses: actions/checkout@v4
- name: Set environment
id: goflags
run: |
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" | tee -a $GITHUB_OUTPUT
echo VERSION="${GITHUB_REF_NAME#v}" | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
echo VERSION="${GITHUB_REF_NAME#v}" >>$GITHUB_OUTPUT
darwin-build:
runs-on: macos-14-xlarge
@@ -55,9 +53,6 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- run: |
./scripts/build_darwin.sh
- name: Log build results
@@ -68,7 +63,6 @@ jobs:
name: bundles-darwin
path: |
dist/*.tgz
dist/*.tar.zst
dist/*.zip
dist/*.dmg
@@ -110,13 +104,6 @@ jobs:
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runner_dir: 'rocm'
- os: windows
arch: amd64
preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
flags: ''
runner_dir: 'vulkan'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -126,14 +113,13 @@ jobs:
run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ')
id: cache-install
uses: actions/cache/restore@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: startsWith(matrix.preset, 'CUDA ')
name: Install CUDA ${{ matrix.cuda-version }}
@@ -163,18 +149,6 @@ jobs:
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'Vulkan'
name: Install Vulkan ${{ matrix.rocm-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
}
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -185,20 +159,19 @@ jobs:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
path: ${{ github.workspace }}\.ccache
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}-${{ needs.setup-environment.outputs.vendorsha }}
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
- name: Build target "${{ matrix.preset }}"
run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env:
CMAKE_GENERATOR: Ninja
@@ -255,9 +228,6 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- name: Verify gcc is actually clang
run: |
$ErrorActionPreference='Continue'
@@ -294,26 +264,23 @@ jobs:
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
steps:
- uses: actions/checkout@v4
- uses: google-github-actions/auth@v2
with:
project_id: ollama
credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
- run: |
$ErrorActionPreference = "Stop"
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
# - uses: google-github-actions/auth@v2
# with:
# project_id: ollama
# credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
# - run: |
# $ErrorActionPreference = "Stop"
# Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
# Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
& "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
# Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
# Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
# & "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
# echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- uses: actions/download-artifact@v4
with:
pattern: depends-windows*
@@ -345,13 +312,13 @@ jobs:
include:
- os: linux
arch: amd64
target: archive
target: archive_novulkan
- os: linux
arch: amd64
target: rocm
- os: linux
arch: arm64
target: archive
target: archive_novulkan
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -372,17 +339,12 @@ jobs:
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
cache-to: type=inline
- name: Deduplicate CUDA libraries
run: |
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -397,13 +359,13 @@ jobs:
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
done
- uses: actions/upload-artifact@v4
with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tar.zst
*.tgz
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
@@ -412,12 +374,14 @@ jobs:
include:
- os: linux
arch: arm64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
- os: linux
arch: amd64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
@@ -430,6 +394,14 @@ jobs:
CGO_CXXFLAGS
GOFLAGS
FLAVOR=rocm
- os: linux
arch: amd64
suffix: '-vulkan'
target: default
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -447,6 +419,7 @@ jobs:
with:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
target: ${{ matrix.preset }}
build-args: ${{ matrix.build-args }}
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
@@ -536,7 +509,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!


@@ -22,7 +22,6 @@ jobs:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.changes.outputs.changed }}
vendorsha: ${{ steps.changes.outputs.vendorsha }}
steps:
- uses: actions/checkout@v4
with:
@@ -38,7 +37,6 @@ jobs:
}
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
linux:
needs: [changes]
@@ -85,7 +83,7 @@ jobs:
- uses: actions/cache@v4
with:
path: /github/home/.cache/ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: |
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
cmake --build --preset ${{ matrix.preset }} --parallel
@@ -174,13 +172,12 @@ jobs:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
path: ${{ github.workspace }}\.ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
@@ -203,37 +200,82 @@ jobs:
runs-on: ${{ matrix.os }}
env:
CGO_ENABLED: '1'
GOEXPERIMENT: 'synctest'
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- name: checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
- name: cache restore
uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
go-version-file: 'go.mod'
cache-dependency-path: |
go.sum
Makefile.sync
- uses: actions/setup-node@v4
# Note: unlike the other setups, this is only grabbing the mod download
# cache, rather than the whole mod directory, as the download cache
# contains zips that can be unpacked in parallel faster than they can be
# fetched and extracted by tar
path: |
~/.cache/go-build
~/go/pkg/mod/cache
~\AppData\Local\go-build
# NOTE: The -3- here should be incremented when the scheme of data to be
# cached changes (e.g. path above changes).
key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
restore-keys: |
${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}
${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-
- name: Setup Go
uses: actions/setup-go@v5
with:
node-version: '20'
- name: Install UI dependencies
working-directory: ./app/ui/app
run: npm ci
- name: Install tscriptify
# The caching strategy of setup-go is less than ideal, and wastes
# time by not saving artifacts due to small failures like the linter
# complaining, etc. This means subsequent runs have to rebuild their world
# again until all checks pass. For instance, if you misspell a word,
# you're punished until you fix it. This is more hostile than
# helpful.
cache: false
go-version-file: go.mod
# It is tempting to run this in a platform independent way, but the past
# shows this codebase will see introductions of platform specific code
# generation, and so we need to check this per platform to ensure we
# don't abuse go generate on specific platforms.
- name: check that 'go generate' is clean
if: always()
run: |
go install github.com/tkrajina/typescriptify-golang-structs/tscriptify@latest
- name: Run UI tests
if: ${{ startsWith(matrix.os, 'ubuntu') }}
working-directory: ./app/ui/app
run: npm test
- name: Run go generate
run: go generate ./...
go generate ./...
git diff --name-only --exit-code || (echo "Please run 'go generate ./...'." && exit 1)
- name: go test
if: always()
run: go test -count=1 -benchtime=1x ./...
- uses: golangci/golangci-lint-action@v9
# TODO(bmizerany): replace this heavy tool with just the
# tools/checks/binaries we want and then make them all run in parallel
# across jobs, not on a single tiny vm on Github Actions.
- uses: golangci/golangci-lint-action@v6
with:
only-new-issues: true
args: --timeout 10m0s -v
- name: cache save
# Always save the cache, even if the job fails. The artifacts produced
# during the building of test binaries are not all for naught. They can
# be used to speed up subsequent runs.
if: always()
uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
# Note: unlike the other setups, this is only grabbing the mod download
# cache, rather than the whole mod directory, as the download cache
# contains zips that can be unpacked in parallel faster than they can be
# fetched and extracted by tar
path: |
~/.cache/go-build
~/go/pkg/mod/cache
~\AppData\Local\go-build
# NOTE: The -3- here should be incremented when the scheme of data to be
# cached changes (e.g. path above changes).
key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
patches:
runs-on: ubuntu-latest
@@ -242,4 +284,4 @@ jobs:
- name: Verify patches apply cleanly and do not change files
run: |
make -f Makefile.sync clean checkout apply-patches sync
git diff --compact-summary --exit-code
git diff --compact-summary --exit-code


@@ -1,4 +1,5 @@
version: "2"
run:
timeout: 5m
linters:
enable:
- asasalint
@@ -6,46 +7,35 @@ linters:
- bodyclose
- containedctx
- gocheckcompilerdirectives
- gofmt
- gofumpt
- gosimple
- govet
- ineffassign
- intrange
- makezero
- misspell
- nilerr
- nolintlint
- nosprintfhostport
- staticcheck
- unconvert
- usetesting
- wastedassign
- whitespace
disable:
- errcheck
- usestdlibvars
settings:
govet:
disable:
- unusedresult
staticcheck:
checks:
- all
- -QF* # disable quick fix suggestions
- -SA1019
- -ST1000 # package comment format
- -ST1003 # underscores in package names
- -ST1005 # error strings should not be capitalized
- -ST1012 # error var naming (ErrFoo)
- -ST1016 # receiver name consistency
- -ST1020 # comment on exported function format
- -ST1021 # comment on exported type format
- -ST1022 # comment on exported var format
- -ST1023 # omit type from declaration
- errcheck
linters-settings:
staticcheck:
checks:
- all
- -SA1019 # omit Deprecated check
severity:
default: error
default-severity: error
rules:
- linters:
- gofmt
- goimports
- intrange
severity: info
formatters:
enable:
- gofmt
- gofumpt


@@ -2,22 +2,6 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage)
include(GNUInstallDirs)
@@ -28,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
set(CMAKE_CXX_EXTENSIONS OFF)
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -48,10 +32,9 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(GGML_CPU_ALL_VARIANTS ON)
endif()
if(APPLE)
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
set(CMAKE_BUILD_RPATH "@loader_path")
set(CMAKE_INSTALL_RPATH "@loader_path")
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
@@ -71,13 +54,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
# Define GGML version variables for shared library SOVERSION
# These are required by ggml/src/CMakeLists.txt for proper library versioning
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 0)
set(GGML_VERSION_PATCH 0)
set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
@@ -164,56 +140,14 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
# Install the Metal library for macOS arm64 (must be colocated with the binary)
# Metal backend is only built for arm64, not x86_64
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()
endif()


@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4",
"CMAKE_CUDA_FLAGS": "-t 2",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,28 +83,6 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -162,21 +140,6 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}


@@ -16,7 +16,7 @@ See the [development documentation](./docs/development.md) for instructions on h
* New features: new features (e.g. API fields, environment variables) add surface area to Ollama and make it harder to maintain in the long run as they cannot be removed without potentially breaking users in the future.
* Refactoring: large code improvements are important, but can be harder or take longer to review and merge.
* Documentation: small updates to fill in or correct missing documentation are helpful, however large documentation additions can be hard to maintain over time.
* Documentation: small updates to fill in or correct missing documentation is helpful, however large documentation additions can be hard to maintain over time.
### Issues that may not be accepted
@@ -43,7 +43,7 @@ Tips for proposals:
* Explain how the change will be tested.
Additionally, for bonus points: Provide draft documentation you would expect to
see if the changes were accepted.
see if the change were accepted.
## Pull requests
@@ -66,6 +66,7 @@ Examples:
llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clarity on good commit messages, and bad
docs: simplify manual installation with shorter curl commands
Bad Examples:


@@ -32,21 +32,21 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache
RUN yum install -y yum-utils epel-release \
&& dnf install -y clang ccache git \
&& dnf install -y clang ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
ENV CC=clang CXX=clang++
FROM base-${TARGETARCH} AS base
ARG CMAKEVERSION
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ENV LDFLAGS=-s
FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
@@ -57,8 +57,6 @@ ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
@@ -69,8 +67,6 @@ ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
@@ -82,8 +78,6 @@ ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
@@ -93,8 +87,6 @@ RUN --mount=type=cache,target=/root/.ccache \
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
@@ -126,37 +118,11 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS vulkan
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -165,23 +131,18 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
COPY . .
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -198,9 +159,34 @@ ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
# Temporary opt-out stages for Vulkan
FROM --platform=linux/amd64 scratch AS amd64_novulkan
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
FROM arm64 AS arm64_novulkan
FROM ${FLAVOR}_novulkan AS archive_novulkan
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04 AS novulkan
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& apt-get install -y ca-certificates \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive_novulkan /bin /usr/bin
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
COPY --from=archive_novulkan /lib/ollama /usr/lib/ollama
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV NVIDIA_VISIBLE_DEVICES=all
ENV OLLAMA_HOST=0.0.0.0:11434
EXPOSE 11434
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]
FROM ubuntu:24.04 AS default
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin


@@ -1 +0,0 @@
v0.4.1


@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=ec98e2002
FETCH_HEAD=7049736b2dd9011bf819e298b844ebbc4b5afdc9
.PHONY: help
help:
@@ -57,7 +57,7 @@ checkout: $(WORKDIR)
$(WORKDIR):
git clone $(UPSTREAM) $(WORKDIR)
.PHONY: format-patches
.PHONE: format-patches
format-patches: llama/patches
git -C $(WORKDIR) format-patch \
--no-signature \
@@ -66,11 +66,7 @@ format-patches: llama/patches
-o $(realpath $<) \
$(FETCH_HEAD)
.PHONY: clean
.PHONE: clean
clean: checkout
@git -C $(WORKDIR) am --abort || true
$(RM) llama/patches/.*.patched
.PHONY: print-base
print-base:
@echo $(FETCH_HEAD)


@@ -22,7 +22,7 @@ Get up and running with large language models.
curl -fsSL https://ollama.com/install.sh | sh
```
[Manual install instructions](https://docs.ollama.com/linux#manual-install)
[Manual install instructions](https://github.com/ollama/ollama/blob/main/docs/linux.md)
### Docker
@@ -48,7 +48,7 @@ ollama run gemma3
## Model library
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -110,7 +110,7 @@ Ollama supports importing GGUF models in the Modelfile:
### Import from Safetensors
See the [guide](https://docs.ollama.com/import) on importing models for more information.
See the [guide](docs/import.md) on importing models for more information.
### Customize a prompt
@@ -143,7 +143,7 @@ ollama run mario
Hello! It's your friend Mario.
```
For more information on working with a Modelfile, see the [Modelfile](https://docs.ollama.com/modelfile) documentation.
For more information on working with a Modelfile, see the [Modelfile](docs/modelfile.md) documentation.
## CLI Reference
@@ -226,18 +226,6 @@ ollama ps
ollama stop llama3.2
```
### Generate embeddings from the CLI
```shell
ollama run embeddinggemma "Your text to embed"
```
You can also pipe text for scripted workflows:
```shell
echo "Your text to embed" | ollama run embeddinggemma
```
### Start Ollama
`ollama serve` is used when you want to start ollama without running the desktop application.
@@ -260,38 +248,6 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2
```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:
```shell
go build -tags mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API
Ollama has a REST API for running and managing models.
@@ -322,17 +278,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [Hollama](https://github.com/fmaclen/hollama)
- [Lollms WebUI (Single user)](https://github.com/ParisNeo/lollms-webui)
- [Lollms (Multi users)](https://github.com/ParisNeo/lollms)
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
- [LibreChat](https://github.com/danny-avila/LibreChat)
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [AI-UI](https://github.com/bajahaw/ai-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
@@ -399,8 +352,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VS Code extension for multi-file/whole-repo coding
- [Void](https://github.com/voideditor/void) (Open source AI code editor and Cursor alternative)
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -432,7 +384,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
- [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VS Code extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VSCode extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
- [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
- [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
@@ -454,17 +406,15 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [KDeps](https://github.com/kdeps/kdeps) (Kdeps is an offline-first AI framework for building Dockerized full-stack AI applications declaratively using Apple PKL and integrates APIs with Ollama on the backend.)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
### Cloud
@@ -472,10 +422,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Fly.io](https://fly.io/docs/python/do-more/add-ollama/)
- [Koyeb](https://www.koyeb.com/deploy/ollama)
### Tutorial
- [handy-ollama](https://github.com/datawhalechina/handy-ollama) (Chinese Tutorial for Ollama by [Datawhale ](https://github.com/datawhalechina) - China's Largest Open Source AI Learning Community)
### Terminal
- [oterm](https://github.com/ggozad/oterm)
@@ -515,8 +461,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool using models. Add new tools to your shed with ease. Runs on Ollama.
- [hle-eval-ollama](https://github.com/mags0ft/hle-eval-ollama) - Runs benchmarks like "Humanity's Last Exam" (HLE) on your favorite local Ollama models and evaluates the quality of their responses
- [VT Code](https://github.com/vinhnx/vtcode) - VT Code is a Rust-based terminal coding agent with semantic code intelligence via Tree-sitter. Ollama integration for running local/cloud models with configurable endpoints.
### Apple Vision Pro
@@ -526,7 +470,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -547,7 +491,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama)
- [crewAI](https://github.com/crewAIInc/crewAI)
- [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration)
- [Strands Agents](https://github.com/strands-agents/sdk-python) (A model-driven approach to building AI agents in just a few lines of code)
- [Spring AI](https://github.com/spring-projects/spring-ai) with [reference](https://docs.spring.io/spring-ai/reference/api/chat/ollama-chat.html) and [example](https://github.com/tzolov/ollama-tools)
- [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example)
- [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java)
@@ -558,7 +501,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaFarm for Go](https://github.com/presbrey/ollamafarm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Ollama for Ruby](https://github.com/crmne/ruby_llm)
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j)
@@ -588,7 +531,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
- [Ollama for Swift](https://github.com/mattt/ollama-swift)
- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
@@ -602,7 +545,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
- [achatbot-go](https://github.com/ai-bot-pro/achatbot-go) a multimodal(text/audio/image) chatbot.
- [Ollama Bash Lib](https://github.com/attogram/ollama-bash-lib) - A Bash Library for Ollama. Run LLM prompts straight from your shell, and more
### Mobile
@@ -648,10 +590,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front-end Open WebUI service.)
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
- [AI Summary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [AI Summmary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
@@ -659,7 +602,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
@@ -669,14 +612,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)


@@ -14,7 +14,7 @@ Please include the following details in your report:
## Security best practices
While the maintainer team does its best to secure Ollama, users are encouraged to implement their own security best practices, such as:
While the maintainer team does their best to secure Ollama, users are encouraged to implement their own security best practices, such as:
- Regularly updating to the latest version of Ollama
- Securing access to hosted instances of Ollama


@@ -1,778 +0,0 @@
package anthropic
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// Error types matching Anthropic API
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ErrorResponse struct {
Type string `json:"type"` // always "error"
Error Error `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
// NewError creates a new ErrorResponse with the appropriate error type based on HTTP status code
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusUnauthorized:
etype = "authentication_error"
case http.StatusForbidden:
etype = "permission_error"
case http.StatusNotFound:
etype = "not_found_error"
case http.StatusTooManyRequests:
etype = "rate_limit_error"
case http.StatusServiceUnavailable, 529:
etype = "overloaded_error"
default:
etype = "api_error"
}
return ErrorResponse{
Type: "error",
Error: Error{Type: etype, Message: message},
RequestID: generateID("req"),
}
}
// Request types
// MessagesRequest represents an Anthropic Messages API request
type MessagesRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
Messages []MessageParam `json:"messages"`
System any `json:"system,omitempty"` // string or []ContentBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// MessageParam represents a message in the request
type MessageParam struct {
Role string `json:"role"` // "user" or "assistant"
Content any `json:"content"` // string or []ContentBlock
}
// ContentBlock represents a content block in a message.
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
// only when set, which is required for SDK streaming accumulation.
type ContentBlock struct {
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
Text *string `json:"text,omitempty"`
// For image blocks
Source *ImageSource `json:"source,omitempty"`
// For tool_use blocks
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// For tool_result blocks
ToolUseID string `json:"tool_use_id,omitempty"`
Content any `json:"content,omitempty"` // string or []ContentBlock
IsError bool `json:"is_error,omitempty"`
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
Thinking *string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ImageSource represents the source of an image
type ImageSource struct {
Type string `json:"type"` // "base64" or "url"
MediaType string `json:"media_type,omitempty"`
Data string `json:"data,omitempty"`
URL string `json:"url,omitempty"`
}
// Tool represents a tool definition
type Tool struct {
Type string `json:"type,omitempty"` // "custom" for user-defined tools
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
// ToolChoice controls how the model uses tools
type ToolChoice struct {
Type string `json:"type"` // "auto", "any", "tool", "none"
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
// ThinkingConfig controls extended thinking
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"`
}
// Metadata for the request
type Metadata struct {
UserID string `json:"user_id,omitempty"`
}
// Response types
// MessagesResponse represents an Anthropic Messages API response
type MessagesResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
Usage Usage `json:"usage"`
}
// Usage contains token usage information
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// Streaming event types
// MessageStartEvent is sent at the start of streaming
type MessageStartEvent struct {
Type string `json:"type"` // "message_start"
Message MessagesResponse `json:"message"`
}
// ContentBlockStartEvent signals the start of a content block
type ContentBlockStartEvent struct {
Type string `json:"type"` // "content_block_start"
Index int `json:"index"`
ContentBlock ContentBlock `json:"content_block"`
}
// ContentBlockDeltaEvent contains incremental content updates
type ContentBlockDeltaEvent struct {
Type string `json:"type"` // "content_block_delta"
Index int `json:"index"`
Delta Delta `json:"delta"`
}
// Delta represents an incremental update
type Delta struct {
Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta"
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
}
// ContentBlockStopEvent signals the end of a content block
type ContentBlockStopEvent struct {
Type string `json:"type"` // "content_block_stop"
Index int `json:"index"`
}
// MessageDeltaEvent contains updates to the message
type MessageDeltaEvent struct {
Type string `json:"type"` // "message_delta"
Delta MessageDelta `json:"delta"`
Usage DeltaUsage `json:"usage"`
}
// MessageDelta contains stop information
type MessageDelta struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence string `json:"stop_sequence,omitempty"`
}
// DeltaUsage contains cumulative token usage
type DeltaUsage struct {
OutputTokens int `json:"output_tokens"`
}
// MessageStopEvent signals the end of the message
type MessageStopEvent struct {
Type string `json:"type"` // "message_stop"
}
// PingEvent is a keepalive event
type PingEvent struct {
Type string `json:"type"` // "ping"
}
// StreamErrorEvent is an error during streaming
type StreamErrorEvent struct {
Type string `json:"type"` // "error"
Error Error `json:"error"`
}
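// Illustrative sketch of the SSE event order emitted for a simple streamed text
// reply (see StreamConverter.Process below); payloads shown here are examples,
// not a wire-format guarantee:
//
//	event: message_start        {"type":"message_start","message":{...}}
//	event: content_block_start  {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
//	event: content_block_delta  {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
//	event: content_block_stop   {"type":"content_block_stop","index":0}
//	event: message_delta        {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":5}}
//	event: message_stop         {"type":"message_stop"}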
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
var messages []api.Message
if r.System != nil {
switch sys := r.System.(type) {
case string:
if sys != "" {
messages = append(messages, api.Message{Role: "system", Content: sys})
}
case []any:
// System can be an array of content blocks
var content strings.Builder
for _, block := range sys {
if blockMap, ok := block.(map[string]any); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content.WriteString(text)
}
}
}
}
if content.Len() > 0 {
messages = append(messages, api.Message{Role: "system", Content: content.String()})
}
}
}
for _, msg := range r.Messages {
converted, err := convertMessage(msg)
if err != nil {
return nil, err
}
messages = append(messages, converted...)
}
options := make(map[string]any)
options["num_predict"] = r.MaxTokens
if r.Temperature != nil {
options["temperature"] = *r.Temperature
}
if r.TopP != nil {
options["top_p"] = *r.TopP
}
if r.TopK != nil {
options["top_k"] = *r.TopK
}
if len(r.StopSequences) > 0 {
options["stop"] = r.StopSequences
}
var tools api.Tools
for _, t := range r.Tools {
tool, err := convertTool(t)
if err != nil {
return nil, err
}
tools = append(tools, tool)
}
var think *api.ThinkValue
if r.Thinking != nil && r.Thinking.Type == "enabled" {
think = &api.ThinkValue{Value: true}
}
stream := r.Stream
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Options: options,
Stream: &stream,
Tools: tools,
Think: think,
}, nil
}
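// Illustrative usage (a minimal sketch; the model name is a placeholder):
//
//	req := MessagesRequest{
//	    Model:     "some-model",
//	    MaxTokens: 256,
//	    Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
//	}
//	chatReq, err := FromMessagesRequest(req)
//	if err != nil { /* handle error */ }
//	// chatReq.Options["num_predict"] == 256 and chatReq.Messages holds one user message.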
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
func convertMessage(msg MessageParam) ([]api.Message, error) {
var messages []api.Message
role := strings.ToLower(msg.Role)
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: role, Content: content})
case []any:
var textContent strings.Builder
var images []api.ImageData
var toolCalls []api.ToolCall
var thinking string
var toolResults []api.Message
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
return nil, errors.New("invalid content block format")
}
blockType, _ := blockMap["type"].(string)
switch blockType {
case "text":
if text, ok := blockMap["text"].(string); ok {
textContent.WriteString(text)
}
case "image":
source, ok := blockMap["source"].(map[string]any)
if !ok {
return nil, errors.New("invalid image source")
}
sourceType, _ := source["type"].(string)
if sourceType == "base64" {
data, _ := source["data"].(string)
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("invalid base64 image data: %w", err)
}
images = append(images, decoded)
} else {
return nil, fmt.Errorf("invalid image source type: %s; only base64 images are supported", sourceType)
}
// URL images would need to be fetched, so they are rejected for now.
case "tool_use":
id, ok := blockMap["id"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'id' field")
}
name, ok := blockMap["name"].(string)
if !ok {
return nil, errors.New("tool_use block missing required 'name' field")
}
tc := api.ToolCall{
ID: id,
Function: api.ToolCallFunction{
Name: name,
},
}
if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = mapToArgs(input)
}
toolCalls = append(toolCalls, tc)
case "tool_result":
toolUseID, _ := blockMap["tool_use_id"].(string)
var resultContent string
switch c := blockMap["content"].(type) {
case string:
resultContent = c
case []any:
for _, cb := range c {
if cbMap, ok := cb.(map[string]any); ok {
if cbMap["type"] == "text" {
if text, ok := cbMap["text"].(string); ok {
resultContent += text
}
}
}
}
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: resultContent,
ToolCallID: toolUseID,
})
case "thinking":
if t, ok := blockMap["thinking"].(string); ok {
thinking = t
}
}
}
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
m := api.Message{
Role: role,
Content: textContent.String(),
Images: images,
ToolCalls: toolCalls,
Thinking: thinking,
}
messages = append(messages, m)
}
// Add tool results as separate messages
messages = append(messages, toolResults...)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
return messages, nil
}
// convertTool converts an Anthropic Tool to an Ollama api.Tool
func convertTool(t Tool) (api.Tool, error) {
var params api.ToolFunctionParameters
if len(t.InputSchema) > 0 {
if err := json.Unmarshal(t.InputSchema, &params); err != nil {
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
}
}
return api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Parameters: params,
},
}, nil
}
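// Illustrative sketch: a tool declared with a raw JSON schema converts to the
// function-style tool Ollama expects (field names as in the types above).
//
//	t := Tool{
//	    Name:        "get_weather",
//	    Description: "Get current weather",
//	    InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
//	}
//	tool, err := convertTool(t)
//	// tool.Type == "function", tool.Function.Name == "get_weather",
//	// and tool.Function.Parameters carries the unmarshalled schema.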
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
func ToMessagesResponse(id string, r api.ChatResponse) MessagesResponse {
var content []ContentBlock
if r.Message.Thinking != "" {
content = append(content, ContentBlock{
Type: "thinking",
Thinking: ptr(r.Message.Thinking),
})
}
if r.Message.Content != "" {
content = append(content, ContentBlock{
Type: "text",
Text: ptr(r.Message.Content),
})
}
for _, tc := range r.Message.ToolCalls {
content = append(content, ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: tc.Function.Arguments,
})
}
stopReason := mapStopReason(r.DoneReason, len(r.Message.ToolCalls) > 0)
return MessagesResponse{
ID: id,
Type: "message",
Role: "assistant",
Model: r.Model,
Content: content,
StopReason: stopReason,
Usage: Usage{
InputTokens: r.Metrics.PromptEvalCount,
OutputTokens: r.Metrics.EvalCount,
},
}
}
// mapStopReason converts Ollama done_reason to Anthropic stop_reason
func mapStopReason(reason string, hasToolCalls bool) string {
if hasToolCalls {
return "tool_use"
}
switch reason {
case "stop":
return "end_turn"
case "length":
return "max_tokens"
default:
if reason != "" {
return "stop_sequence"
}
return ""
}
}
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
type StreamConverter struct {
ID string
Model string
firstWrite bool
contentIndex int
inputTokens int
outputTokens int
thinkingStarted bool
thinkingDone bool
textStarted bool
toolCallsSent map[string]bool
}
func NewStreamConverter(id, model string) *StreamConverter {
return &StreamConverter{
ID: id,
Model: model,
firstWrite: true,
toolCallsSent: make(map[string]bool),
}
}
// StreamEvent represents a streaming event to be sent to the client
type StreamEvent struct {
Event string
Data any
}
// Process converts an Ollama ChatResponse to Anthropic streaming events
func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
var events []StreamEvent
if c.firstWrite {
c.firstWrite = false
c.inputTokens = r.Metrics.PromptEvalCount
events = append(events, StreamEvent{
Event: "message_start",
Data: MessageStartEvent{
Type: "message_start",
Message: MessagesResponse{
ID: c.ID,
Type: "message",
Role: "assistant",
Model: c.Model,
Content: []ContentBlock{},
Usage: Usage{
InputTokens: c.inputTokens,
OutputTokens: 0,
},
},
},
})
}
if r.Message.Thinking != "" && !c.thinkingDone {
if !c.thinkingStarted {
c.thinkingStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "thinking_delta",
Thinking: r.Message.Thinking,
},
},
})
}
if r.Message.Content != "" {
if c.thinkingStarted && !c.thinkingDone {
c.thinkingDone = true
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
}
if !c.textStarted {
c.textStarted = true
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "text",
Text: ptr(""),
},
},
})
}
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "text_delta",
Text: r.Message.Content,
},
},
})
}
for _, tc := range r.Message.ToolCalls {
if c.toolCallsSent[tc.ID] {
continue
}
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.contentIndex++
c.textStarted = false
}
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("failed to marshal tool arguments", "error", err, "tool_id", tc.ID)
continue
}
events = append(events, StreamEvent{
Event: "content_block_start",
Data: ContentBlockStartEvent{
Type: "content_block_start",
Index: c.contentIndex,
ContentBlock: ContentBlock{
Type: "tool_use",
ID: tc.ID,
Name: tc.Function.Name,
Input: map[string]any{},
},
},
})
events = append(events, StreamEvent{
Event: "content_block_delta",
Data: ContentBlockDeltaEvent{
Type: "content_block_delta",
Index: c.contentIndex,
Delta: Delta{
Type: "input_json_delta",
PartialJSON: string(argsJSON),
},
},
})
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
c.toolCallsSent[tc.ID] = true
c.contentIndex++
}
if r.Done {
if c.textStarted {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
} else if c.thinkingStarted && !c.thinkingDone {
events = append(events, StreamEvent{
Event: "content_block_stop",
Data: ContentBlockStopEvent{
Type: "content_block_stop",
Index: c.contentIndex,
},
})
}
c.outputTokens = r.Metrics.EvalCount
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
events = append(events, StreamEvent{
Event: "message_delta",
Data: MessageDeltaEvent{
Type: "message_delta",
Delta: MessageDelta{
StopReason: stopReason,
},
Usage: DeltaUsage{
OutputTokens: c.outputTokens,
},
},
})
events = append(events, StreamEvent{
Event: "message_stop",
Data: MessageStopEvent{
Type: "message_stop",
},
})
}
return events
}
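// Illustrative driver (a sketch assuming chunks arrive as api.ChatResponse
// values from a streaming chat call; writeSSE is a hypothetical helper that
// writes "event: <name>\ndata: <json>\n\n" to the response writer):
//
//	conv := NewStreamConverter(GenerateMessageID(), req.Model)
//	for chunk := range chunks {
//	    for _, ev := range conv.Process(chunk) {
//	        writeSSE(w, ev.Event, ev.Data)
//	    }
//	}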
// generateID generates a unique ID with the given prefix using crypto/rand
func generateID(prefix string) string {
b := make([]byte, 12)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based ID if crypto/rand fails
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
}
return fmt.Sprintf("%s_%x", prefix, b)
}
// GenerateMessageID generates a unique message ID
func GenerateMessageID() string {
return generateID("msg")
}
// ptr returns a pointer to the given string value
func ptr(s string) *string {
return &s
}
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

View File

@@ -1,953 +0,0 @@
package anthropic
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
}
}
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: "You are a helpful assistant.",
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
t.Errorf("unexpected system message: %+v", result.Messages[0])
}
}
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
System: []any{
map[string]any{"type": "text", "text": "You are helpful."},
map[string]any{"type": "text", "text": " Be concise."},
},
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Content != "You are helpful. Be concise." {
t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
}
}
func TestFromMessagesRequest_WithOptions(t *testing.T) {
temp := 0.7
topP := 0.9
topK := 40
req := MessagesRequest{
Model: "test-model",
MaxTokens: 2048,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Temperature: &temp,
TopP: &topP,
TopK: &topK,
StopSequences: []string{"\n", "END"},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
}
if result.Options["top_p"] != 0.9 {
t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
}
if result.Options["top_k"] != 40 {
t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
}
if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
t.Errorf("stop sequences mismatch: %s", diff)
}
}
func TestFromMessagesRequest_WithImage(t *testing.T) {
imgData, _ := base64.StdEncoding.DecodeString(testImage)
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "What's in this image?"},
map[string]any{
"type": "image",
"source": map[string]any{
"type": "base64",
"media_type": "image/png",
"data": testImage,
},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Content != "What's in this image?" {
t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
}
if len(result.Messages[0].Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
}
if string(result.Messages[0].Images[0]) != string(imgData) {
t.Error("image data mismatch")
}
}
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": map[string]any{"location": "Paris"},
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if len(result.Messages[1].ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
}
tc := result.Messages[1].ToolCalls[0]
if tc.ID != "call_123" {
t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
}
if tc.Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
}
}
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_123",
"content": "The weather in Paris is sunny, 22°C",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
msg := result.Messages[0]
if msg.Role != "tool" {
t.Errorf("expected role 'tool', got %q", msg.Role)
}
if msg.ToolCallID != "call_123" {
t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
}
if msg.Content != "The weather in Paris is sunny, 22°C" {
t.Errorf("unexpected content: %q", msg.Content)
}
}
func TestFromMessagesRequest_WithTools(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "get_weather",
Description: "Get current weather",
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result.Tools))
}
tool := result.Tools[0]
if tool.Type != "function" {
t.Errorf("expected type 'function', got %q", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
}
if tool.Function.Description != "Get current weather" {
t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
}
}
func TestFromMessagesRequest_WithThinking(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Think == nil {
t.Fatal("expected Think to be set")
}
if v, ok := result.Think.Value.(bool); !ok || !v {
t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
}
}
// TestFromMessagesRequest_ThinkingOnlyBlock verifies that messages containing only
// a thinking block (no text, images, or tool calls) are preserved and not dropped.
func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{Role: "user", Content: "Hello"},
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "thinking",
"thinking": "Let me think about this...",
},
},
},
},
}
result, err := FromMessagesRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
assistantMsg := result.Messages[1]
if assistantMsg.Thinking != "Let me think about this..." {
t.Errorf("expected thinking content, got %q", assistantMsg.Thinking)
}
}
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"name": "get_weather",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use id")
}
if err.Error() != "tool_use block missing required 'id' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{
{
Role: "assistant",
Content: []any{
map[string]any{
"type": "tool_use",
"id": "call_123",
},
},
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for missing tool_use name")
}
if err.Error() != "tool_use block missing required 'name' field" {
t.Errorf("unexpected error message: %v", err)
}
}
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
req := MessagesRequest{
Model: "test-model",
MaxTokens: 1024,
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
Tools: []Tool{
{
Name: "bad_tool",
InputSchema: json.RawMessage(`{invalid json`),
},
},
}
_, err := FromMessagesRequest(req)
if err == nil {
t.Fatal("expected error for invalid tool schema")
}
}
func TestToMessagesResponse_Basic(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello there!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 5,
},
}
result := ToMessagesResponse("msg_123", resp)
if result.ID != "msg_123" {
t.Errorf("expected ID 'msg_123', got %q", result.ID)
}
if result.Type != "message" {
t.Errorf("expected type 'message', got %q", result.Type)
}
if result.Role != "assistant" {
t.Errorf("expected role 'assistant', got %q", result.Role)
}
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "text" || result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
t.Errorf("unexpected content: %+v", result.Content[0])
}
if result.StopReason != "end_turn" {
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
}
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
t.Errorf("unexpected usage: %+v", result.Usage)
}
}
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 1 {
t.Fatalf("expected 1 content block, got %d", len(result.Content))
}
if result.Content[0].Type != "tool_use" {
t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
}
if result.Content[0].ID != "call_123" {
t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
}
if result.Content[0].Name != "get_weather" {
t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
}
if result.StopReason != "tool_use" {
t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
}
}
func TestToMessagesResponse_WithThinking(t *testing.T) {
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "The answer is 42.",
Thinking: "Let me think about this...",
},
Done: true,
DoneReason: "stop",
}
result := ToMessagesResponse("msg_123", resp)
if len(result.Content) != 2 {
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
}
if result.Content[0].Type != "thinking" {
t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
}
if result.Content[0].Thinking == nil || *result.Content[0].Thinking != "Let me think about this..." {
t.Errorf("unexpected thinking content: %v", result.Content[0].Thinking)
}
if result.Content[1].Type != "text" {
t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
}
}
func TestMapStopReason(t *testing.T) {
tests := []struct {
reason string
hasToolCalls bool
want string
}{
{"stop", false, "end_turn"},
{"length", false, "max_tokens"},
{"stop", true, "tool_use"},
{"other", false, "stop_sequence"},
{"", false, ""},
}
for _, tt := range tests {
got := mapStopReason(tt.reason, tt.hasToolCalls)
if got != tt.want {
t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
}
}
}
func TestNewError(t *testing.T) {
tests := []struct {
code int
want string
}{
{400, "invalid_request_error"},
{401, "authentication_error"},
{403, "permission_error"},
{404, "not_found_error"},
{429, "rate_limit_error"},
{500, "api_error"},
{503, "overloaded_error"},
{529, "overloaded_error"},
}
for _, tt := range tests {
result := NewError(tt.code, "test message")
if result.Type != "error" {
t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
}
if result.Error.Type != tt.want {
t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
}
if result.Error.Message != "test message" {
t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
}
if result.RequestID == "" {
t.Errorf("NewError(%d) request_id should not be empty", tt.code)
}
}
}
func TestGenerateMessageID(t *testing.T) {
id1 := GenerateMessageID()
id2 := GenerateMessageID()
if id1 == "" {
t.Error("GenerateMessageID returned empty string")
}
if id1 == id2 {
t.Error("GenerateMessageID returned duplicate IDs")
}
if len(id1) < 10 {
t.Errorf("GenerateMessageID returned short ID: %q", id1)
}
if id1[:4] != "msg_" {
t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4])
}
}
func TestStreamConverter_Basic(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
// First chunk
resp1 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "Hello",
},
Metrics: api.Metrics{PromptEvalCount: 10},
}
events1 := conv.Process(resp1)
if len(events1) < 3 {
t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
}
// Should have message_start, content_block_start, content_block_delta
if events1[0].Event != "message_start" {
t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
}
if events1[1].Event != "content_block_start" {
t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
}
if events1[2].Event != "content_block_delta" {
t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
}
// Final chunk
resp2 := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: " world!",
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{EvalCount: 5},
}
events2 := conv.Process(resp2)
// Should have content_block_delta, content_block_stop, message_delta, message_stop
hasStop := false
for _, e := range events2 {
if e.Event == "message_stop" {
hasStop = true
}
}
if !hasStop {
t.Error("expected message_stop event in final chunk")
}
}
func TestStreamConverter_WithToolCalls(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_123",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
},
},
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
}
events := conv.Process(resp)
hasToolStart := false
hasToolDelta := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
hasToolDelta = true
}
}
}
}
if !hasToolStart {
t.Error("expected tool_use content_block_start event")
}
if !hasToolDelta {
t.Error("expected input_json_delta event")
}
}
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Test that unmarshalable arguments (like channels) are handled gracefully
// and don't cause a panic or corrupt the stream
conv := NewStreamConverter("msg_123", "test-model")
// Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
// Should not panic and should skip the unmarshalable tool call
events := conv.Process(resp)
// Verify no tool_use block was started (since marshal failed before block start)
hasToolStart := false
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
hasToolStart = true
}
}
}
}
if hasToolStart {
t.Error("expected no tool_use block when arguments cannot be marshaled")
}
}
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
// Test that valid tool calls still work when mixed with invalid ones
conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_good",
Function: api.ToolCallFunction{
Name: "good_function",
Arguments: testArgs(map[string]any{"location": "Paris"}),
},
},
{
ID: "call_bad",
Function: api.ToolCallFunction{
Name: "bad_function",
Arguments: badArgs,
},
},
},
},
Done: true,
DoneReason: "stop",
}
events := conv.Process(resp)
// Count tool_use blocks - should only have 1 (the valid one)
toolStartCount := 0
toolDeltaCount := 0
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "tool_use" {
toolStartCount++
if start.ContentBlock.Name != "good_function" {
t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
}
}
}
}
if e.Event == "content_block_delta" {
if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
if delta.Delta.Type == "input_json_delta" {
toolDeltaCount++
}
}
}
}
if toolStartCount != 1 {
t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
}
if toolDeltaCount != 1 {
t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
}
}
// TestContentBlockJSON_EmptyFieldsPresent verifies that empty text and thinking fields
// are serialized in JSON output. The Anthropic SDK requires these fields to be present
// (even when empty) in content_block_start events to properly accumulate streaming deltas.
// Without these fields, the SDK throws: "TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'"
func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
tests := []struct {
name string
block ContentBlock
wantKeys []string
}{
{
name: "text block includes empty text field",
block: ContentBlock{
Type: "text",
Text: ptr(""),
},
wantKeys: []string{"type", "text"},
},
{
name: "thinking block includes empty thinking field",
block: ContentBlock{
Type: "thinking",
Thinking: ptr(""),
},
wantKeys: []string{"type", "thinking"},
},
{
name: "text block with content",
block: ContentBlock{
Type: "text",
Text: ptr("hello"),
},
wantKeys: []string{"type", "text"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.block)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
for _, key := range tt.wantKeys {
if _, ok := result[key]; !ok {
t.Errorf("expected key %q to be present in JSON output, got: %s", key, string(data))
}
}
})
}
}
// TestStreamConverter_ContentBlockStartIncludesEmptyFields verifies that content_block_start
// events include the required empty fields for SDK compatibility.
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
t.Run("text block start includes empty text", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Content: "hello"},
}
events := conv.Process(resp)
var foundTextStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "text" {
foundTextStart = true
// Marshal and verify the text field is present
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["text"]; !ok {
t.Error("content_block_start for text should include 'text' field")
}
}
}
}
}
if !foundTextStart {
t.Error("expected text content_block_start event")
}
})
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model")
resp := api.ChatResponse{
Model: "test-model",
Message: api.Message{Role: "assistant", Thinking: "let me think..."},
}
events := conv.Process(resp)
var foundThinkingStart bool
for _, e := range events {
if e.Event == "content_block_start" {
if start, ok := e.Data.(ContentBlockStartEvent); ok {
if start.ContentBlock.Type == "thinking" {
foundThinkingStart = true
data, _ := json.Marshal(start)
var result map[string]any
json.Unmarshal(data, &result)
cb := result["content_block"].(map[string]any)
if _, ok := cb["thinking"]; !ok {
t.Error("content_block_start for thinking should include 'thinking' field")
}
}
}
}
}
if !foundThinkingStart {
t.Error("expected thinking content_block_start event")
}
})
}

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 8 * format.MegaByte
const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader
@@ -226,14 +226,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: string(bts),
}
}
return errors.New(string(bts))
return fmt.Errorf("unmarshal: %w", err)
}
if response.StatusCode == http.StatusUnauthorized {
@@ -347,7 +340,7 @@ type CreateProgressFunc func(ProgressResponse) error
// Create creates a model from a [Modelfile]. fn is a progress function that
// behaves similarly to other methods (see [Client.Pull]).
//
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
var resp ProgressResponse

View File

@@ -55,7 +55,6 @@ func TestClientFromEnvironment(t *testing.T) {
type testError struct {
message string
statusCode int
raw bool // if true, write message as-is instead of JSON encoding
}
func (e testError) Error() string {
@@ -112,20 +111,6 @@ func TestClientStream(t *testing.T) {
},
},
},
{
name: "plain text error response",
responses: []any{
"internal server error",
},
wantErr: "internal server error",
},
{
name: "HTML error page",
responses: []any{
"<html><body>404 Not Found</body></html>",
},
wantErr: "404 Not Found",
},
}
for _, tc := range testCases {
@@ -150,12 +135,6 @@ func TestClientStream(t *testing.T) {
return
}
if str, ok := resp.(string); ok {
fmt.Fprintln(w, str)
flusher.Flush()
continue
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
@@ -194,10 +173,9 @@ func TestClientStream(t *testing.T) {
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
wantStatusCode int
name string
response any
wantErr string
}{
{
name: "immediate error response",
@@ -205,8 +183,7 @@ func TestClientDo(t *testing.T) {
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
wantStatusCode: http.StatusBadRequest,
wantErr: "test error message",
},
{
name: "server error response",
@@ -214,8 +191,7 @@ func TestClientDo(t *testing.T) {
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
wantStatusCode: http.StatusInternalServerError,
wantErr: "internal error",
},
{
name: "successful response",
@@ -227,26 +203,6 @@ func TestClientDo(t *testing.T) {
Success: true,
},
},
{
name: "plain text error response",
response: testError{
message: "internal server error",
statusCode: http.StatusInternalServerError,
raw: true,
},
wantErr: "internal server error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "HTML error page",
response: testError{
message: "<html><body>404 Not Found</body></html>",
statusCode: http.StatusNotFound,
raw: true,
},
wantErr: "<html><body>404 Not Found</body></html>",
wantStatusCode: http.StatusNotFound,
},
}
for _, tc := range testCases {
@@ -254,16 +210,11 @@ func TestClientDo(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
if !errResp.raw {
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
} else {
// Write raw message (simulates non-JSON error responses)
fmt.Fprint(w, errResp.message)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
@@ -290,15 +241,6 @@ func TestClientDo(t *testing.T) {
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
if tc.wantStatusCode != 0 {
if statusErr, ok := err.(StatusError); ok {
if statusErr.StatusCode != tc.wantStatusCode {
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
}
} else {
t.Errorf("expected StatusError, got %T", err)
}
}
return
}

View File

@@ -15,19 +15,19 @@ func main() {
}
messages := []api.Message{
{
api.Message{
Role: "system",
Content: "Provide very brief, concise responses",
},
{
api.Message{
Role: "user",
Content: "Name some unusual animals",
},
{
api.Message{
Role: "assistant",
Content: "Monotreme, platypus, echidna",
},
{
api.Message{
Role: "user",
Content: "which of these is the most dangerous?",
},

View File

@@ -3,7 +3,6 @@ package api
import (
"encoding/json"
"fmt"
"iter"
"log/slog"
"math"
"os"
@@ -15,7 +14,6 @@ import (
"github.com/google/uuid"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/orderedmap"
"github.com/ollama/ollama/types/model"
)
@@ -119,28 +117,6 @@ type GenerateRequest struct {
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
// Logprobs specifies whether to return log probabilities of the output tokens.
Logprobs bool `json:"logprobs,omitempty"`
// TopLogprobs is the number of most likely tokens to return at each token position,
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Width is the width of the generated image in pixels.
// Only used for image generation models.
Width int32 `json:"width,omitempty"`
// Height is the height of the generated image in pixels.
// Only used for image generation models.
Height int32 `json:"height,omitempty"`
// Steps is the number of diffusion steps for image generation.
// Only used for image generation models.
Steps int32 `json:"steps,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -183,14 +159,6 @@ type ChatRequest struct {
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
// Logprobs specifies whether to return log probabilities of the output tokens.
Logprobs bool `json:"logprobs,omitempty"`
// TopLogprobs is the number of most likely tokens to return at each token position,
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
}
type Tools []Tool
@@ -213,11 +181,10 @@ type Message struct {
Content string `json:"content"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolName string `json:"tool_name,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Thinking string `json:"thinking,omitempty"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolName string `json:"tool_name,omitempty"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
@@ -233,7 +200,6 @@ func (m *Message) UnmarshalJSON(b []byte) error {
}
type ToolCall struct {
ID string `json:"id,omitempty"`
Function ToolCallFunction `json:"function"`
}
@@ -243,79 +209,13 @@ type ToolCallFunction struct {
Arguments ToolCallFunctionArguments `json:"arguments"`
}
// ToolCallFunctionArguments holds tool call arguments in insertion order.
type ToolCallFunctionArguments struct {
om *orderedmap.Map[string, any]
}
// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
}
// Get retrieves a value by key.
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
if t == nil || t.om == nil {
return nil, false
}
return t.om.Get(key)
}
// Set sets a key-value pair, preserving insertion order.
func (t *ToolCallFunctionArguments) Set(key string, value any) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, any]()
}
t.om.Set(key, value)
}
// Len returns the number of arguments.
func (t *ToolCallFunctionArguments) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all key-value pairs in insertion order.
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
if t == nil || t.om == nil {
return func(yield func(string, any) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolCallFunctionArguments) ToMap() map[string]any {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
type ToolCallFunctionArguments map[string]any
func (t *ToolCallFunctionArguments) String() string {
if t == nil || t.om == nil {
return "{}"
}
bts, _ := json.Marshal(t.om)
bts, _ := json.Marshal(t)
return string(bts)
}
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, any]()
return json.Unmarshal(data, t.om)
}
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("{}"), nil
}
return json.Marshal(t.om)
}
type Tool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
@@ -364,78 +264,12 @@ func (pt PropertyType) String() string {
return fmt.Sprintf("%v", []string(pt))
}
// ToolPropertiesMap holds tool properties in insertion order.
type ToolPropertiesMap struct {
om *orderedmap.Map[string, ToolProperty]
}
// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
func NewToolPropertiesMap() *ToolPropertiesMap {
return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
}
// Get retrieves a property by name.
func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
if t == nil || t.om == nil {
return ToolProperty{}, false
}
return t.om.Get(key)
}
// Set sets a property, preserving insertion order.
func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, ToolProperty]()
}
t.om.Set(key, value)
}
// Len returns the number of properties.
func (t *ToolPropertiesMap) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all properties in insertion order.
func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
if t == nil || t.om == nil {
return func(yield func(string, ToolProperty) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("null"), nil
}
return json.Marshal(t.om)
}
func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, ToolProperty]()
return json.Unmarshal(data, t.om)
}
type ToolProperty struct {
AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"`
Properties *ToolPropertiesMap `json:"properties,omitempty"`
AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"`
}
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
@@ -484,11 +318,11 @@ func mapToTypeScriptType(jsonType string) string {
}
type ToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties *ToolPropertiesMap `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]ToolProperty `json:"properties"`
}
func (t *ToolFunctionParameters) String() string {
@@ -507,27 +341,6 @@ func (t *ToolFunction) String() string {
return string(bts)
}
// TokenLogprob represents log probability information for a single token alternative.
type TokenLogprob struct {
// Token is the text representation of the token.
Token string `json:"token"`
// Logprob is the log probability of this token.
Logprob float64 `json:"logprob"`
// Bytes contains the raw byte representation of the token
Bytes []int `json:"bytes,omitempty"`
}
// Logprob contains log probability information for a generated token.
type Logprob struct {
TokenLogprob
// TopLogprobs contains the most likely tokens and their log probabilities
// at this position, if requested via TopLogprobs parameter.
TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
@@ -554,10 +367,6 @@ type ChatResponse struct {
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
Metrics
}
@@ -701,9 +510,6 @@ type CreateRequest struct {
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
// Requires is the minimum version of Ollama required by the model.
Requires string `json:"requires,omitempty"`
// Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"`
@@ -749,12 +555,11 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
Requires string `json:"requires,omitempty"`
}
// CopyRequest is the request passed to [Client.Copy].
@@ -870,24 +675,6 @@ type GenerateResponse struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Image contains a base64-encoded generated image.
// Only present for image generation models.
Image string `json:"image,omitempty"`
// Completed is the number of completed steps in image generation.
// Only present for image generation models during streaming.
Completed int64 `json:"completed,omitempty"`
// Total is the total number of steps for image generation.
// Only present for image generation models during streaming.
Total int64 `json:"total,omitempty"`
}
// ModelDetails provides details about a model.
@@ -912,19 +699,6 @@ type UserResponse struct {
Plan string `json:"plan,omitempty"`
}
type UsageResponse struct {
// Start is the time the server started tracking usage (UTC, RFC 3339).
Start time.Time `json:"start"`
Usage []ModelUsageData `json:"usage"`
}
type ModelUsageData struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
}
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`

View File

@@ -11,24 +11,6 @@ import (
"github.com/stretchr/testify/require"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
props := NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) ToolCallFunctionArguments {
args := NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestKeepAliveParsingFromJSON(t *testing.T) {
tests := []struct {
name string
@@ -316,48 +298,10 @@ func TestToolFunction_UnmarshalJSON(t *testing.T) {
}
}
func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
tests := []struct {
name string
input ToolFunctionParameters
expected string
}{
{
name: "simple object with string property",
input: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: testPropsMap(map[string]ToolProperty{
"name": {Type: PropertyType{"string"}},
}),
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
},
{
name: "no required",
input: ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]ToolProperty{
"name": {Type: PropertyType{"string"}},
}),
},
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
data, err := json.Marshal(test.input)
require.NoError(t, err)
assert.Equal(t, test.expected, string(data))
})
}
}
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
fn := ToolCallFunction{
Name: "echo",
Arguments: testArgs(map[string]any{"message": "hi"}),
Arguments: ToolCallFunctionArguments{"message": "hi"},
}
data, err := json.Marshal(fn)
@@ -522,116 +466,6 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
}
}
func TestToolPropertyNestedProperties(t *testing.T) {
tests := []struct {
name string
input string
expected ToolProperty
}{
{
name: "nested object properties",
input: `{
"type": "object",
"description": "Location details",
"properties": {
"address": {
"type": "string",
"description": "Street address"
},
"city": {
"type": "string",
"description": "City name"
}
}
}`,
expected: ToolProperty{
Type: PropertyType{"object"},
Description: "Location details",
Properties: testPropsMap(map[string]ToolProperty{
"address": {
Type: PropertyType{"string"},
Description: "Street address",
},
"city": {
Type: PropertyType{"string"},
Description: "City name",
},
}),
},
},
{
name: "deeply nested properties",
input: `{
"type": "object",
"description": "Event",
"properties": {
"location": {
"type": "object",
"description": "Location",
"properties": {
"coordinates": {
"type": "object",
"description": "GPS coordinates",
"properties": {
"lat": {"type": "number", "description": "Latitude"},
"lng": {"type": "number", "description": "Longitude"}
}
}
}
}
}
}`,
expected: ToolProperty{
Type: PropertyType{"object"},
Description: "Event",
Properties: testPropsMap(map[string]ToolProperty{
"location": {
Type: PropertyType{"object"},
Description: "Location",
Properties: testPropsMap(map[string]ToolProperty{
"coordinates": {
Type: PropertyType{"object"},
Description: "GPS coordinates",
Properties: testPropsMap(map[string]ToolProperty{
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
}),
},
}),
},
}),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var prop ToolProperty
err := json.Unmarshal([]byte(tt.input), &prop)
require.NoError(t, err)
// Compare JSON representations since pointer comparison doesn't work
expectedJSON, err := json.Marshal(tt.expected)
require.NoError(t, err)
actualJSON, err := json.Marshal(prop)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
// Round-trip test: marshal and unmarshal again
data, err := json.Marshal(prop)
require.NoError(t, err)
var prop2 ToolProperty
err = json.Unmarshal(data, &prop2)
require.NoError(t, err)
prop2JSON, err := json.Marshal(prop2)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
})
}
}
func TestToolFunctionParameters_String(t *testing.T) {
tests := []struct {
name string
@@ -643,12 +477,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
params: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: testPropsMap(map[string]ToolProperty{
Properties: map[string]ToolProperty{
"name": {
Type: PropertyType{"string"},
Description: "The name of the person",
},
}),
},
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
},
@@ -665,7 +499,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
s.Self = s
return s
}(),
Properties: testPropsMap(map[string]ToolProperty{}),
Properties: map[string]ToolProperty{},
},
expected: "",
},
@@ -678,235 +512,3 @@ func TestToolFunctionParameters_String(t *testing.T) {
})
}
}
func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("zebra", "z")
args.Set("apple", "a")
args.Set("mango", "m")
data, err := json.Marshal(args)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":1,"a":2,"m":3,"b":4}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(original), &args)
require.NoError(t, err)
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("String method returns ordered JSON", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("c", 3)
args.Set("a", 1)
args.Set("b", 2)
assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
})
t.Run("Get retrieves correct values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("key1", "value1")
args.Set("key2", 42)
v, ok := args.Get("key1")
assert.True(t, ok)
assert.Equal(t, "value1", v)
v, ok = args.Get("key2")
assert.True(t, ok)
assert.Equal(t, 42, v)
_, ok = args.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
args := NewToolCallFunctionArguments()
assert.Equal(t, 0, args.Len())
args.Set("a", 1)
assert.Equal(t, 1, args.Len())
args.Set("b", 2)
assert.Equal(t, 2, args.Len())
})
t.Run("empty args marshal to empty object", func(t *testing.T) {
args := NewToolCallFunctionArguments()
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{}`, string(data))
})
t.Run("zero value args marshal to empty object", func(t *testing.T) {
var args ToolCallFunctionArguments
assert.Equal(t, "{}", args.String())
})
}
func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
data, err := json.Marshal(props)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
assert.Equal(t, expected, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(jsonData), &props)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range props.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(original), &props)
require.NoError(t, err)
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("Get retrieves correct values", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
v, ok := props.Get("name")
assert.True(t, ok)
assert.Equal(t, "The name", v.Description)
v, ok = props.Get("age")
assert.True(t, ok)
assert.Equal(t, "The age", v.Description)
_, ok = props.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
props := NewToolPropertiesMap()
assert.Equal(t, 0, props.Len())
props.Set("a", ToolProperty{})
assert.Equal(t, 1, props.Len())
props.Set("b", ToolProperty{})
assert.Equal(t, 2, props.Len())
})
t.Run("nil props marshal to null", func(t *testing.T) {
var props *ToolPropertiesMap
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, `null`, string(data))
})
t.Run("ToMap returns regular map", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("a", ToolProperty{Type: PropertyType{"string"}})
props.Set("b", ToolProperty{Type: PropertyType{"number"}})
m := props.ToMap()
assert.Equal(t, 2, len(m))
assert.Equal(t, PropertyType{"string"}, m["a"].Type)
assert.Equal(t, PropertyType{"number"}, m["b"].Type)
})
}
func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
t.Run("nested objects preserve order", func(t *testing.T) {
jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Outer keys should be in order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"outer", "simple"}, keys)
})
t.Run("arrays as values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("items", []string{"a", "b", "c"})
args.Set("numbers", []int{1, 2, 3})
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
})
}
func TestToolPropertiesMap_NestedProperties(t *testing.T) {
t.Run("nested properties preserve order", func(t *testing.T) {
props := NewToolPropertiesMap()
nestedProps := NewToolPropertiesMap()
nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
props.Set("outer", ToolProperty{
Type: PropertyType{"object"},
Properties: nestedProps,
})
data, err := json.Marshal(props)
require.NoError(t, err)
// Both outer and inner should preserve order
expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
assert.Equal(t, expected, string(data))
})
}

View File

@@ -1,97 +1,22 @@
# Ollama for macOS and Windows
# Ollama App
## Download
## Linux
- [macOS](https://github.com/ollama/app/releases/download/latest/Ollama.dmg)
- [Windows](https://github.com/ollama/app/releases/download/latest/OllamaSetup.exe)
TODO
## Development
## MacOS
### Desktop App
TODO
```bash
go generate ./... &&
go run ./cmd/app
```
### UI Development
#### Setup
Install required tools:
```bash
go install github.com/tkrajina/typescriptify-golang-structs/tscriptify@latest
```
#### Develop UI (Development Mode)
1. Start the React development server (with hot-reload):
```bash
cd ui/app
npm install
npm run dev
```
2. In a separate terminal, run the Ollama app with the `-dev` flag:
```bash
go generate ./... &&
OLLAMA_DEBUG=1 go run ./cmd/app -dev
```
The `-dev` flag enables:
- Loading the UI from the Vite dev server at http://localhost:5173
- Fixed UI server port at http://127.0.0.1:3001 for API requests
- CORS headers for cross-origin requests
- Hot-reload support for UI development
## Build
### Windows
## Windows
If you want to build the installer, you'll need to install
- https://jrsoftware.org/isinfo.php
**Dependencies** - either build a local copy of Ollama, or use a GitHub release
In the top directory of this repo, run the following PowerShell scripts
to build the Ollama CLI, the Ollama app, and the Ollama installer.
```powershell
# Local dependencies
.\scripts\deps_local.ps1
# Release dependencies
.\scripts\deps_release.ps1 0.6.8
```
**Build**
```powershell
.\scripts\build_windows.ps1
```
### macOS
CI builds with Xcode 14.1 to maintain compatibility with macOS versions prior to 13. If you want to manually build with macOS 11+ support, download the older Xcode [here](https://developer.apple.com/services-account/download?path=/Developer_Tools/Xcode_14.1/Xcode_14.1.xip), extract it, move it with `mv ./Xcode.app /Applications/Xcode_14.1.0.app`, then activate it with:
```
export CGO_CFLAGS="-O3 -mmacosx-version-min=12.0"
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=12.0"
export CGO_LDFLAGS="-mmacosx-version-min=12.0"
export SDKROOT=/Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
export DEVELOPER_DIR=/Applications/Xcode_14.1.0.app/Contents/Developer
```
**Dependencies** - either build a local copy of Ollama, or use a GitHub release:
```sh
# Local dependencies
./scripts/deps_local.sh
# Release dependencies
./scripts/deps_release.sh 0.6.8
```
**Build**
```sh
./scripts/build_darwin.sh
powershell -ExecutionPolicy Bypass -File .\scripts\build_windows.ps1
```

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package assets
import (

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package auth
import (

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package main
import (
@@ -24,6 +22,7 @@ import (
"github.com/google/uuid"
"github.com/ollama/ollama/app/auth"
"github.com/ollama/ollama/app/logrotate"
"github.com/ollama/ollama/app/network"
"github.com/ollama/ollama/app/server"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/tools"
@@ -32,17 +31,13 @@ import (
"github.com/ollama/ollama/app/version"
)
var (
wv = &Webview{}
uiServerPort int
)
var wv = &Webview{}
var uiServerPort int
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
var (
fastStartup = false
devMode = false
)
var fastStartup = false
var devMode = false
type appMove int
@@ -75,7 +70,7 @@ func main() {
fmt.Println(version.Version)
os.Exit(0)
case "background":
// When running the process in this "background" mode, we spawn a
// When running the process in this "backgroud" mode, we spawn a
// child process for the main app. This is necessary so the
// "Allow in the Background" setting in MacOS can be unchecked
// without breaking the main app. Two copies of the app are
@@ -107,7 +102,7 @@ func main() {
logrotate.Rotate(appLogPath)
if _, err := os.Stat(filepath.Dir(appLogPath)); errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(filepath.Dir(appLogPath), 0o755); err != nil {
if err := os.MkdirAll(filepath.Dir(appLogPath), 0755); err != nil {
slog.Error(fmt.Sprintf("failed to create server log dir %v", err))
return
}
@@ -183,7 +178,7 @@ func main() {
// Check if another instance is already running
// On Windows, focus the existing instance; on other platforms, kill it
handleExistingInstance(startHidden)
handleExistingInstance()
// on macOS, offer the user to create a symlink
// from /usr/local/bin/ollama to the app bundle
@@ -263,16 +258,23 @@ func main() {
done <- osrv.Run(octx)
}()
},
Store: st,
ToolRegistry: toolRegistry,
Dev: devMode,
Logger: slog.Default(),
Store: st,
ToolRegistry: toolRegistry,
Dev: devMode,
Logger: slog.Default(),
NetworkMonitor: network.NewMonitor(),
}
uiServer.NetworkMonitor.Start(ctx)
srv := &http.Server{
Handler: uiServer.Handler(),
}
if _, err := uiServer.UserData(ctx); err != nil {
slog.Warn("failed to load user data", "error", err)
}
// Start the UI server
slog.Info("starting ui server", "port", port)
go func() {
@@ -316,17 +318,6 @@ func main() {
slog.Debug("no URL scheme request to handle")
}
go func() {
slog.Debug("waiting for ollama server to be ready")
if err := ui.WaitForServer(ctx, 10*time.Second); err != nil {
slog.Warn("ollama server not ready, continuing anyway", "error", err)
}
if _, err := uiServer.UserData(ctx); err != nil {
slog.Warn("failed to load user data", "error", err)
}
}()
osRun(cancel, hasCompletedFirstRun, startHidden)
slog.Info("shutting down desktop server")
@@ -368,7 +359,7 @@ func checkUserLoggedIn(uiServerPort int) bool {
return false
}
resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/api/me", uiServerPort), "application/json", nil)
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/me", uiServerPort))
if err != nil {
slog.Debug("failed to call local auth endpoint", "error", err)
return false
@@ -404,8 +395,8 @@ func checkUserLoggedIn(uiServerPort int) bool {
// handleConnectURLScheme fetches the connect URL and opens it in the browser
func handleConnectURLScheme() {
if checkUserLoggedIn(uiServerPort) {
slog.Info("user is already logged in, opening app instead")
showWindow(wv.webview.Window())
slog.Info("user is already logged in, opening settings instead")
sendUIRequestMessage("/")
return
}
@@ -441,30 +432,37 @@ func openInBrowser(url string) {
}
}
// parseURLScheme parses an ollama:// URL and validates it
// Supports: ollama:// (open app) and ollama://connect (OAuth)
func parseURLScheme(urlSchemeRequest string) (isConnect bool, err error) {
// parseURLScheme parses an ollama:// URL and returns whether it's a connect URL and the UI path
func parseURLScheme(urlSchemeRequest string) (isConnect bool, uiPath string, err error) {
parsedURL, err := url.Parse(urlSchemeRequest)
if err != nil {
return false, fmt.Errorf("invalid URL: %w", err)
return false, "", err
}
// Check if this is a connect URL
if parsedURL.Host == "connect" || strings.TrimPrefix(parsedURL.Path, "/") == "connect" {
return true, nil
return true, "", nil
}
// Allow bare ollama:// or ollama:/// to open the app
if (parsedURL.Host == "" && parsedURL.Path == "") || parsedURL.Path == "/" {
return false, nil
// Extract the UI path
path := "/"
if parsedURL.Path != "" && parsedURL.Path != "/" {
// For URLs like ollama:///settings, use the path directly
path = parsedURL.Path
} else if parsedURL.Host != "" {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
// This also handles ollama://settings/ where Windows adds a trailing slash
path = "/" + parsedURL.Host
}
return false, fmt.Errorf("unsupported ollama:// URL path: %s", urlSchemeRequest)
return false, path, nil
}
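For context, here is a minimal, self-contained sketch (not part of the change above) of why the host-vs-path cases described in the comments need separate handling: with `ollama://settings` the segment is parsed as the URL host, while `ollama:///settings` puts it in the path. The expected values in the comments are assumptions based on Go's `net/url` semantics, not output from the app itself.

```go
package main

import (
	"fmt"
	"net/url"
)

func main() {
	// ollama://settings  -> host="settings", path=""
	// ollama:///settings -> host="",         path="/settings"
	// ollama://connect   -> host="connect",  path=""
	for _, raw := range []string{"ollama://settings", "ollama:///settings", "ollama://connect"} {
		u, err := url.Parse(raw)
		if err != nil {
			panic(err)
		}
		fmt.Printf("%-20s host=%q path=%q\n", raw, u.Host, u.Path)
	}
}
```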
// handleURLSchemeInCurrentInstance processes URL scheme requests in the current instance
func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
isConnect, err := parseURLScheme(urlSchemeRequest)
isConnect, uiPath, err := parseURLScheme(urlSchemeRequest)
if err != nil {
slog.Error("failed to parse URL scheme request", "url", urlSchemeRequest, "error", err)
return
@@ -473,8 +471,6 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
if isConnect {
handleConnectURLScheme()
} else {
if wv.webview != nil {
showWindow(wv.webview.Window())
}
sendUIRequestMessage(uiPath)
}
}

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package main
// #cgo CFLAGS: -x objective-c
@@ -8,7 +6,6 @@ package main
// #include "../../updater/updater_darwin.h"
// typedef const char cchar_t;
import "C"
import (
"log/slog"
"os"
@@ -35,11 +32,9 @@ var ollamaPath = func() string {
return filepath.Join(pwd, "ollama")
}()
var (
isApp = updater.BundlePath != ""
appLogPath = filepath.Join(os.Getenv("HOME"), ".ollama", "logs", "app.log")
launchAgentPath = filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "com.ollama.ollama.plist")
)
var isApp = updater.BundlePath != ""
var appLogPath = filepath.Join(os.Getenv("HOME"), ".ollama", "logs", "app.log")
var launchAgentPath = filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "com.ollama.ollama.plist")
// TODO(jmorganca): pre-create the window and pass
// it to the webview instead of using the internal one
@@ -128,7 +123,7 @@ func maybeMoveAndRestart() appMove {
}
// handleExistingInstance handles existing instances on macOS
func handleExistingInstance(_ bool) {
func handleExistingInstance() {
C.killOtherInstances()
}
@@ -191,6 +186,13 @@ func LaunchNewApp() {
C.launchApp(appName)
}
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
p := C.CString(path)
defer C.free(unsafe.Pointer(p))
C.uiRequest(p)
}
func registerLaunchAgent(hasCompletedFirstRun bool) {
// Remove any stale Login Item registrations
C.unregisterSelfFromLoginItem()

View File

@@ -14,7 +14,6 @@ extern NSString *SystemWidePath;
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
@property(strong, nonatomic) NSStatusItem *statusItem;
@property(assign, nonatomic) BOOL updateAvailable;
@property(assign, nonatomic) BOOL systemShutdownInProgress;
@end
@implementation AppDelegate
@@ -25,14 +24,27 @@ bool firstTimeRun,startHidden; // Set in run before initialization
for (NSURL *url in urls) {
if ([url.scheme isEqualToString:@"ollama"]) {
NSString *path = url.path;
if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) {
if (!path || [path isEqualToString:@""]) {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
if (url.host && ![url.host isEqualToString:@""]) {
path = [@"/" stringByAppendingString:url.host];
} else {
path = @"/";
}
}
if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {
// Special case: handle connect by opening browser instead of app
handleConnectURL();
} else {
// Set app to be active and visible
[NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
[NSApp activateIgnoringOtherApps:YES];
// Open the path with the UI
[self uiRequest:path];
}
break;
@@ -41,13 +53,6 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
// Register for system shutdown/restart notification so we can allow termination
[[[NSWorkspace sharedWorkspace] notificationCenter]
addObserver:self
selector:@selector(systemWillPowerOff:)
name:NSWorkspaceWillPowerOffNotification
object:nil];
// if we're in development mode, set the app icon
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
if (![bundlePath hasSuffix:@".app"]) {
@@ -255,7 +260,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)openHelp:(id)sender {
NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"];
NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
[[NSWorkspace sharedWorkspace] openURL:url];
}
@@ -286,18 +291,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
[NSApp activateIgnoringOtherApps:YES];
}
- (void)systemWillPowerOff:(NSNotification *)notification {
// Set flag so applicationShouldTerminate: knows to allow termination.
// The system will call applicationShouldTerminate: after posting this notification.
self.systemShutdownInProgress = YES;
}
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
// Allow termination if the system is shutting down or restarting
if (self.systemShutdownInProgress) {
return NSTerminateNow;
}
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
[NSApp hide:nil];
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
return NSTerminateCancel;

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package main
import (
@@ -26,6 +24,7 @@ import (
var (
u32 = windows.NewLazySystemDLL("User32.dll")
pBringWindowToTop = u32.NewProc("BringWindowToTop")
pSetWindowLong = u32.NewProc("SetWindowLongA")
pShowWindow = u32.NewProc("ShowWindow")
pSendMessage = u32.NewProc("SendMessageA")
pGetSystemMetrics = u32.NewProc("GetSystemMetrics")
@@ -36,6 +35,7 @@ var (
pIsIconic = u32.NewProc("IsIconic")
appPath = filepath.Join(os.Getenv("LOCALAPPDATA"), "Programs", "Ollama")
appDataPath = filepath.Join(os.Getenv("LOCALAPPDATA"), "Ollama")
appLogPath = filepath.Join(os.Getenv("LOCALAPPDATA"), "Ollama", "app.log")
startupShortcut = filepath.Join(os.Getenv("APPDATA"), "Microsoft", "Windows", "Start Menu", "Programs", "Startup", "Ollama.lnk")
ollamaPath string
@@ -73,10 +73,10 @@ func maybeMoveAndRestart() appMove {
return 0
}
// handleExistingInstance checks for existing instances and optionally focuses them
func handleExistingInstance(startHidden bool) {
if wintray.CheckAndFocusExistingInstance(!startHidden) {
slog.Info("existing instance found, exiting")
// handleExistingInstance checks for existing instances and focuses them
func handleExistingInstance() {
if wintray.CheckAndFocusExistingInstance() {
slog.Info("existing instance found and focused, exiting")
os.Exit(0)
}
}
@@ -93,7 +93,6 @@ var app = &appCallbacks{}
func (ac *appCallbacks) UIRun(path string) {
wv.Run(path)
}
func (*appCallbacks) UIShow() {
if wv.webview != nil {
showWindow(wv.webview.Window())
@@ -101,21 +100,18 @@ func (*appCallbacks) UIShow() {
wv.Run("/")
}
}
func (*appCallbacks) UITerminate() {
wv.Terminate()
}
func (*appCallbacks) UIRunning() bool {
return wv.IsRunning()
}
func (app *appCallbacks) Quit() {
app.t.Quit()
wv.Terminate()
}
// TODO - reconcile with above for consistency between mac/windows
// TODO - reconcile with above for consitency between mac/windows
func quit() {
wv.Terminate()
}
@@ -138,7 +134,7 @@ func (app *appCallbacks) HandleURLScheme(urlScheme string) {
// handleURLSchemeRequest processes URL scheme requests from other instances
func handleURLSchemeRequest(urlScheme string) {
isConnect, err := parseURLScheme(urlScheme)
isConnect, uiPath, err := parseURLScheme(urlScheme)
if err != nil {
slog.Error("failed to parse URL scheme request", "url", urlScheme, "error", err)
return
@@ -147,9 +143,7 @@ func handleURLSchemeRequest(urlScheme string) {
if isConnect {
handleConnectURLScheme()
} else {
if wv.webview != nil {
showWindow(wv.webview.Window())
}
sendUIRequestMessage(uiPath)
}
}
@@ -196,23 +190,25 @@ func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
if startHidden {
startHiddenTasks()
} else {
ptr := wv.Run("/")
if !startHidden {
ptr := wv.Run("/")
// Set the window icon using the tray icon
if ptr != nil {
iconHandle := app.t.GetIconHandle()
if iconHandle != 0 {
hwnd := uintptr(ptr)
const ICON_SMALL = 0
const ICON_BIG = 1
const WM_SETICON = 0x0080
// Set the window icon using the tray icon
if ptr != nil {
iconHandle := app.t.GetIconHandle()
if iconHandle != 0 {
hwnd := uintptr(ptr)
const ICON_SMALL = 0
const ICON_BIG = 1
const WM_SETICON = 0x0080
pSendMessage.Call(hwnd, uintptr(WM_SETICON), uintptr(ICON_SMALL), uintptr(iconHandle))
pSendMessage.Call(hwnd, uintptr(WM_SETICON), uintptr(ICON_BIG), uintptr(iconHandle))
pSendMessage.Call(hwnd, uintptr(WM_SETICON), uintptr(ICON_SMALL), uintptr(iconHandle))
pSendMessage.Call(hwnd, uintptr(WM_SETICON), uintptr(ICON_BIG), uintptr(iconHandle))
}
}
}
centerWindow(ptr)
centerWindow(ptr)
}
}
if !hasCompletedFirstRun {
@@ -263,7 +259,13 @@ func createLoginShortcut() error {
return nil
}
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
wintray.SendUIRequestMessage(path)
}
func LaunchNewApp() {
}
func logStartup() {

View File

@@ -1,20 +1,15 @@
//go:build windows || darwin
package main
// #include "menu.h"
import "C"
import (
"encoding/base64"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"unsafe"
@@ -22,6 +17,8 @@ import (
"github.com/ollama/ollama/app/dialog"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/webview"
"log/slog"
)
type Webview struct {
@@ -264,96 +261,42 @@ func (w *Webview) Run(path string) unsafe.Pointer {
}()
})
// Bind selectFiles function for selecting multiple files at once
wv.Bind("selectFiles", func() {
// Bind selectFile function for React UI
// Uses callback pattern since webview bindings can't directly return Promises
// The HTML wrapper creates a Promise that resolves when this callback is called
wv.Bind("selectFile", func() {
go func() {
// Helper function to call the JavaScript callback with data or null
callCallback := func(data interface{}) {
dataJSON, _ := json.Marshal(data)
wv.Dispatch(func() {
wv.Eval(fmt.Sprintf("window.__selectFilesCallback && window.__selectFilesCallback(%s)", dataJSON))
wv.Eval(fmt.Sprintf("window.__selectFileCallback && window.__selectFileCallback(%s)", dataJSON))
})
}
// Define allowed extensions for native dialog filtering
textExts := []string{
"pdf", "docx", "txt", "md", "csv", "json", "xml", "html", "htm",
"js", "jsx", "ts", "tsx", "py", "java", "cpp", "c", "cc", "h", "cs", "php", "rb",
"go", "rs", "swift", "kt", "scala", "sh", "bat", "yaml", "yml", "toml", "ini",
"cfg", "conf", "log", "rtf",
}
imageExts := []string{"png", "jpg", "jpeg", "webp"}
allowedExts := append(textExts, imageExts...)
// Use native multiple file selection with extension filtering
filenames, err := dialog.File().
Filter("Supported Files", allowedExts...).
Title("Select Files").
LoadMultiple()
filename, err := dialog.File().Load()
if err != nil {
slog.Debug("Multiple file selection cancelled or failed", "error", err)
callCallback(nil)
return
}
if len(filenames) == 0 {
fileData, err := os.ReadFile(filename)
if err != nil {
slog.Error("failed to read file", "error", err)
callCallback(nil)
return
}
var files []map[string]string
maxFileSize := int64(10 * 1024 * 1024) // 10MB
mimeType := http.DetectContentType(fileData)
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64.StdEncoding.EncodeToString(fileData))
for _, filename := range filenames {
// Check file extension (double-check after native dialog filtering)
ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(filename), "."))
validExt := false
for _, allowedExt := range allowedExts {
if ext == allowedExt {
validExt = true
break
}
}
if !validExt {
slog.Warn("file extension not allowed, skipping", "filename", filepath.Base(filename), "extension", ext)
continue
}
// Check file size before reading (pre-filter large files)
fileStat, err := os.Stat(filename)
if err != nil {
slog.Error("failed to get file info", "error", err, "filename", filename)
continue
}
if fileStat.Size() > maxFileSize {
slog.Warn("file too large, skipping", "filename", filepath.Base(filename), "size", fileStat.Size())
continue
}
fileBytes, err := os.ReadFile(filename)
if err != nil {
slog.Error("failed to read file", "error", err, "filename", filename)
continue
}
mimeType := http.DetectContentType(fileBytes)
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64.StdEncoding.EncodeToString(fileBytes))
fileResult := map[string]string{
"filename": filepath.Base(filename),
"path": filename,
"dataURL": dataURL,
}
files = append(files, fileResult)
data := map[string]string{
"filename": filepath.Base(filename),
"path": filename,
"dataURL": dataURL,
}
if len(files) == 0 {
callCallback(nil)
} else {
callCallback(files)
}
callCallback(data)
}()
})
@@ -495,11 +438,9 @@ func (w *Webview) IsRunning() bool {
return w.webview != nil
}
var (
menuItems []C.menuItem
menuMutex sync.RWMutex
pinner runtime.Pinner
)
var menuItems []C.menuItem
var menuMutex sync.RWMutex
var pinner runtime.Pinner
//export menu_get_item_count
func menu_get_item_count() C.int {

View File

@@ -27,7 +27,6 @@ typedef struct {
char* startDir; /* directory to start in (can be nil) */
char* filename; /* default filename for dialog box (can be nil) */
int showHidden; /* show hidden files? */
int allowMultiple; /* allow multiple file selection? */
} FileDlgParams;
typedef enum {

View File

@@ -1,12 +1,5 @@
#import <Cocoa/Cocoa.h>
#include "dlg.h"
#include <string.h>
#include <sys/syslimits.h>
// Import UniformTypeIdentifiers for macOS 11+
#if __MAC_OS_X_VERSION_MAX_ALLOWED >= 110000
#import <UniformTypeIdentifiers/UniformTypeIdentifiers.h>
#endif
void* NSStr(void* buf, int len) {
return (void*)[[NSString alloc] initWithBytes:buf length:len encoding:NSUTF8StringEncoding];
@@ -114,20 +107,12 @@ DlgResult fileDlg(FileDlgParams* params) {
if(self->params->title != nil) {
[panel setTitle:[[NSString alloc] initWithUTF8String:self->params->title]];
}
// Use modern allowedContentTypes API for better file type support (especially video files)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
if(self->params->numext > 0) {
NSMutableArray *utTypes = [NSMutableArray arrayWithCapacity:self->params->numext];
NSString** exts = (NSString**)self->params->exts;
for(int i = 0; i < self->params->numext; i++) {
UTType *type = [UTType typeWithFilenameExtension:exts[i]];
if(type) {
[utTypes addObject:type];
}
}
if([utTypes count] > 0) {
[panel setAllowedContentTypes:utTypes];
}
[panel setAllowedFileTypes:[NSArray arrayWithObjects:(NSString**)self->params->exts count:self->params->numext]];
}
#pragma clang diagnostic pop
if(self->params->relaxext) {
[panel setAllowsOtherFileTypes:YES];
}
@@ -159,59 +144,13 @@ DlgResult fileDlg(FileDlgParams* params) {
[panel setCanChooseDirectories:YES];
[panel setCanChooseFiles:NO];
}
if(self->params->allowMultiple) {
[panel setAllowsMultipleSelection:YES];
}
if(![self runPanel:panel]) {
return DLG_CANCEL;
}
NSArray* urls = [panel URLs];
if([urls count] == 0) {
return DLG_CANCEL;
NSURL* url = [[panel URLs] objectAtIndex:0];
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
return DLG_URLFAIL;
}
if(self->params->allowMultiple) {
// For multiple files, we need to return all paths separated by null bytes
char* bufPtr = self->params->buf;
int remainingBuf = self->params->nbuf;
// Calculate total required buffer size first
int totalSize = 0;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
return DLG_URLFAIL;
}
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
}
totalSize += 1; // Final null terminator
if(totalSize > self->params->nbuf) {
// Not enough buffer space
return DLG_URLFAIL;
}
// Now actually copy the paths (we know we have space)
bufPtr = self->params->buf;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
int pathLen = strlen(tempBuf);
strcpy(bufPtr, tempBuf);
bufPtr += pathLen + 1;
}
*bufPtr = '\0'; // Final null terminator
} else {
// Single file/directory selection - write path to buffer
NSURL* url = [urls firstObject];
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
return DLG_URLFAIL;
}
}
return DLG_OK;
}

View File

@@ -1,6 +1,6 @@
package cocoa
// #cgo darwin LDFLAGS: -framework Cocoa -framework UniformTypeIdentifiers
// #cgo darwin LDFLAGS: -framework Cocoa
// #include <stdlib.h>
// #include <sys/syslimits.h>
// #include "dlg.h"
@@ -57,67 +57,31 @@ func ErrorDlg(msg, title string) {
a.run()
}
const (
BUFSIZE = C.PATH_MAX
MULTI_FILE_BUF_SIZE = 32768
)
const BUFSIZE = C.PATH_MAX
// MultiFileDlg opens a file dialog that allows multiple file selection
func MultiFileDlg(title string, exts []string, relaxExt bool, startDir string, showHidden bool) ([]string, error) {
return fileDlgWithOptions(C.LOADDLG, title, exts, relaxExt, startDir, "", showHidden, true)
}
// FileDlg opens a file dialog for single file selection (kept for compatibility)
func FileDlg(save bool, title string, exts []string, relaxExt bool, startDir string, filename string, showHidden bool) (string, error) {
mode := C.LOADDLG
if save {
mode = C.SAVEDLG
}
files, err := fileDlgWithOptions(mode, title, exts, relaxExt, startDir, filename, showHidden, false)
if err != nil {
return "", err
}
if len(files) == 0 {
return "", nil
}
return files[0], nil
return fileDlg(mode, title, exts, relaxExt, startDir, filename, showHidden)
}
func DirDlg(title string, startDir string, showHidden bool) (string, error) {
files, err := fileDlgWithOptions(C.DIRDLG, title, nil, false, startDir, "", showHidden, false)
if err != nil {
return "", err
}
if len(files) == 0 {
return "", nil
}
return files[0], nil
return fileDlg(C.DIRDLG, title, nil, false, startDir, "", showHidden)
}
// fileDlgWithOptions is the unified file dialog function that handles both single and multiple selection
func fileDlgWithOptions(mode int, title string, exts []string, relaxExt bool, startDir, filename string, showHidden, allowMultiple bool) ([]string, error) {
// Use larger buffer for multiple files, smaller for single
bufSize := BUFSIZE
if allowMultiple {
bufSize = MULTI_FILE_BUF_SIZE
}
func fileDlg(mode int, title string, exts []string, relaxExt bool, startDir, filename string, showHidden bool) (string, error) {
p := C.FileDlgParams{
mode: C.int(mode),
nbuf: C.int(bufSize),
}
if allowMultiple {
p.allowMultiple = C.int(1) // Enable multiple selection //nolint:structcheck
nbuf: BUFSIZE,
}
if showHidden {
p.showHidden = 1
}
p.buf = (*C.char)(C.malloc(C.size_t(bufSize)))
p.buf = (*C.char)(C.malloc(BUFSIZE))
defer C.free(unsafe.Pointer(p.buf))
buf := (*(*[MULTI_FILE_BUF_SIZE]byte)(unsafe.Pointer(p.buf)))[:bufSize]
buf := (*(*[BUFSIZE]byte)(unsafe.Pointer(p.buf)))[:]
if title != "" {
p.title = C.CString(title)
defer C.free(unsafe.Pointer(p.title))
@@ -130,7 +94,6 @@ func fileDlgWithOptions(mode int, title string, exts []string, relaxExt bool, st
p.filename = C.CString(filename)
defer C.free(unsafe.Pointer(p.filename))
}
if len(exts) > 0 {
if len(exts) > 999 {
panic("more than 999 extensions not supported")
@@ -140,6 +103,7 @@ func fileDlgWithOptions(mode int, title string, exts []string, relaxExt bool, st
defer C.free(unsafe.Pointer(p.exts))
cext := (*(*[999]unsafe.Pointer)(unsafe.Pointer(p.exts)))[:]
for i, ext := range exts {
i := i
cext[i] = nsStr(ext)
defer C.NSRelease(cext[i])
}
@@ -148,36 +112,14 @@ func fileDlgWithOptions(mode int, title string, exts []string, relaxExt bool, st
p.relaxext = 1
}
}
// Execute dialog and parse results
switch C.fileDlg(&p) {
case C.DLG_OK:
if allowMultiple {
// Parse multiple null-terminated strings from buffer
var files []string
start := 0
for i := range len(buf) - 1 {
if buf[i] == 0 {
if i > start {
files = append(files, string(buf[start:i]))
}
start = i + 1
// Check for double null (end of list)
if i+1 < len(buf) && buf[i+1] == 0 {
break
}
}
}
return files, nil
} else {
// Single file - return as array for consistency
filename := string(buf[:bytes.Index(buf, []byte{0})])
return []string{filename}, nil
}
// casting to string copies the [about-to-be-freed] bytes
return string(buf[:bytes.Index(buf, []byte{0})]), nil
case C.DLG_CANCEL:
return nil, nil
return "", nil
case C.DLG_URLFAIL:
return nil, errors.New("failed to get file-system representation for selected URL")
return "", errors.New("failed to get file-system representation for selected URL")
}
panic("unhandled case")
}

View File

@@ -1,11 +1,9 @@
//go:build windows || darwin
// Package dialog provides a simple cross-platform common dialog API.
// Eg. to prompt the user with a yes/no dialog:
//
// if dialog.MsgDlg("%s", "Do you want to continue?").YesNo() {
// // user pressed Yes
// }
// if dialog.MsgDlg("%s", "Do you want to continue?").YesNo() {
// // user pressed Yes
// }
//
// The general usage pattern is to call one of the toplevel *Dlg functions
// which return a *Builder structure. From here you can optionally call
@@ -128,13 +126,6 @@ func (b *FileBuilder) Load() (string, error) {
return b.load()
}
// LoadMultiple spawns the file selection dialog using the configured settings,
// asking the user to select multiple files. Returns ErrCancelled as the error
// if the user cancels or closes the dialog.
func (b *FileBuilder) LoadMultiple() ([]string, error) {
return b.loadMultiple()
}
// Save spawns the file selection dialog using the configured settings,
// asking the user for a filename to save as. If the chosen file exists, the
// user is prompted whether they want to overwrite the file. Returns

View File

@@ -20,10 +20,6 @@ func (b *FileBuilder) load() (string, error) {
return b.run(false)
}
func (b *FileBuilder) loadMultiple() ([]string, error) {
return b.runMultiple()
}
func (b *FileBuilder) save() (string, error) {
return b.run(true)
}
@@ -53,26 +49,6 @@ func (b *FileBuilder) run(save bool) (string, error) {
return f, err
}
func (b *FileBuilder) runMultiple() ([]string, error) {
star := false
var exts []string
for _, filt := range b.Filters {
for _, ext := range filt.Extensions {
if ext == "*" {
star = true
} else {
exts = append(exts, ext)
}
}
}
files, err := cocoa.MultiFileDlg(b.Dlg.Title, exts, star, b.StartDir, b.ShowHiddenFiles)
if len(files) == 0 && err == nil {
return nil, ErrCancelled
}
return files, err
}
func (b *DirectoryBuilder) browse() (string, error) {
f, err := cocoa.DirDlg(b.Dlg.Title, b.StartDir, b.ShowHiddenFiles)
if f == "" && err == nil {

124
app/dialog/dlgs_linux.go Normal file
View File

@@ -0,0 +1,124 @@
package dialog
// #cgo pkg-config: gtk+-3.0
// #cgo LDFLAGS: -lX11
// #include <X11/Xlib.h>
// #include <gtk/gtk.h>
// #include <stdlib.h>
// static GtkWidget* msgdlg(GtkWindow *parent, GtkDialogFlags flags, GtkMessageType type, GtkButtonsType buttons, char *msg) {
// return gtk_message_dialog_new(parent, flags, type, buttons, "%s", msg);
// }
// static GtkWidget* filedlg(char *title, GtkWindow *parent, GtkFileChooserAction action, char* acceptText) {
// return gtk_file_chooser_dialog_new(title, parent, action, "Cancel", GTK_RESPONSE_CANCEL, acceptText, GTK_RESPONSE_ACCEPT, NULL);
// }
import "C"
import "unsafe"
var initSuccess bool
func init() {
C.XInitThreads()
initSuccess = (C.gtk_init_check(nil, nil) == C.TRUE)
}
func checkStatus() {
if !initSuccess {
panic("gtk initialisation failed; presumably no X server is available")
}
}
func closeDialog(dlg *C.GtkWidget) {
C.gtk_widget_destroy(dlg)
/* The Destroy call itself isn't enough to remove the dialog from the screen; apparently
** that happens once the GTK main loop processes some further events. But if we're
** in a non-GTK app the main loop isn't running, so we empty the event queue before
** returning from the dialog functions.
** Not sure how this interacts with an actual GTK app... */
for C.gtk_events_pending() != 0 {
C.gtk_main_iteration()
}
}
func runMsgDlg(defaultTitle string, flags C.GtkDialogFlags, msgtype C.GtkMessageType, buttons C.GtkButtonsType, b *MsgBuilder) C.gint {
checkStatus()
cmsg := C.CString(b.Msg)
defer C.free(unsafe.Pointer(cmsg))
dlg := C.msgdlg(nil, flags, msgtype, buttons, cmsg)
ctitle := C.CString(firstOf(b.Dlg.Title, defaultTitle))
defer C.free(unsafe.Pointer(ctitle))
C.gtk_window_set_title((*C.GtkWindow)(unsafe.Pointer(dlg)), ctitle)
defer closeDialog(dlg)
return C.gtk_dialog_run((*C.GtkDialog)(unsafe.Pointer(dlg)))
}
func (b *MsgBuilder) yesNo() bool {
return runMsgDlg("Confirm?", 0, C.GTK_MESSAGE_QUESTION, C.GTK_BUTTONS_YES_NO, b) == C.GTK_RESPONSE_YES
}
func (b *MsgBuilder) info() {
runMsgDlg("Information", 0, C.GTK_MESSAGE_INFO, C.GTK_BUTTONS_OK, b)
}
func (b *MsgBuilder) error() {
runMsgDlg("Error", 0, C.GTK_MESSAGE_ERROR, C.GTK_BUTTONS_OK, b)
}
func (b *FileBuilder) load() (string, error) {
return chooseFile("Open File", "Open", C.GTK_FILE_CHOOSER_ACTION_OPEN, b)
}
func (b *FileBuilder) save() (string, error) {
f, err := chooseFile("Save File", "Save", C.GTK_FILE_CHOOSER_ACTION_SAVE, b)
if err != nil {
return "", err
}
return f, nil
}
func chooseFile(title string, buttonText string, action C.GtkFileChooserAction, b *FileBuilder) (string, error) {
checkStatus()
ctitle := C.CString(title)
defer C.free(unsafe.Pointer(ctitle))
cbuttonText := C.CString(buttonText)
defer C.free(unsafe.Pointer(cbuttonText))
dlg := C.filedlg(ctitle, nil, action, cbuttonText)
fdlg := (*C.GtkFileChooser)(unsafe.Pointer(dlg))
for _, filt := range b.Filters {
filter := C.gtk_file_filter_new()
cdesc := C.CString(filt.Desc)
defer C.free(unsafe.Pointer(cdesc))
C.gtk_file_filter_set_name(filter, cdesc)
for _, ext := range filt.Extensions {
cpattern := C.CString("*." + ext)
defer C.free(unsafe.Pointer(cpattern))
C.gtk_file_filter_add_pattern(filter, cpattern)
}
C.gtk_file_chooser_add_filter(fdlg, filter)
}
if b.StartDir != "" {
cdir := C.CString(b.StartDir)
defer C.free(unsafe.Pointer(cdir))
C.gtk_file_chooser_set_current_folder(fdlg, cdir)
}
if b.StartFile != "" {
cfile := C.CString(b.StartFile)
defer C.free(unsafe.Pointer(cfile))
C.gtk_file_chooser_set_current_name(fdlg, cfile)
}
if b.ShowHiddenFiles {
C.gtk_file_chooser_set_show_hidden(fdlg, C.TRUE)
}
C.gtk_file_chooser_set_do_overwrite_confirmation(fdlg, C.TRUE)
r := C.gtk_dialog_run((*C.GtkDialog)(unsafe.Pointer(dlg)))
defer closeDialog(dlg)
if r == C.GTK_RESPONSE_ACCEPT {
return C.GoString(C.gtk_file_chooser_get_filename(fdlg)), nil
}
return "", ErrCancelled
}
func (b *DirectoryBuilder) browse() (string, error) {
return chooseFile("Open Folder", "Open", C.GTK_FILE_CHOOSER_ACTION_SELECT_FOLDER, &FileBuilder{Dlg: b.Dlg, ShowHiddenFiles: b.ShowHiddenFiles})
}

View File

@@ -10,12 +10,10 @@ import (
"github.com/TheTitanrain/w32"
)
const multiFileBufferSize = w32.MAX_PATH * 10
type WinDlgError int
func (e WinDlgError) Error() string {
return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
return fmt.Sprintf("CommDlgExtendedError: %#x", e)
}
func err() error {
@@ -53,57 +51,6 @@ func (d filedlg) Filename() string {
return string(utf16.Decode(d.buf[:i]))
}
func (d filedlg) parseMultipleFilenames() []string {
var files []string
i := 0
// Find first null terminator (directory path)
for i < len(d.buf) && d.buf[i] != 0 {
i++
}
if i >= len(d.buf) {
return files
}
// Get directory path
dirPath := string(utf16.Decode(d.buf[:i]))
i++ // Skip null terminator
// Check if there are more files (multiple selection)
if i < len(d.buf) && d.buf[i] != 0 {
// Multiple files selected - parse filenames
for i < len(d.buf) {
start := i
// Find next null terminator
for i < len(d.buf) && d.buf[i] != 0 {
i++
}
if i >= len(d.buf) {
break
}
if start < i {
filename := string(utf16.Decode(d.buf[start:i]))
if dirPath != "" {
files = append(files, dirPath+"\\"+filename)
} else {
files = append(files, filename)
}
}
i++ // Skip null terminator
if i >= len(d.buf) || d.buf[i] == 0 {
break // End of list
}
}
} else {
// Single file selected
files = append(files, dirPath)
}
return files
}
func (b *FileBuilder) load() (string, error) {
d := openfile(w32.OFN_FILEMUSTEXIST|w32.OFN_NOCHANGEDIR, b)
if w32.GetOpenFileName(d.opf) {
@@ -112,18 +59,6 @@ func (b *FileBuilder) load() (string, error) {
return "", err()
}
func (b *FileBuilder) loadMultiple() ([]string, error) {
d := openfile(w32.OFN_FILEMUSTEXIST|w32.OFN_NOCHANGEDIR|w32.OFN_ALLOWMULTISELECT|w32.OFN_EXPLORER, b)
d.buf = make([]uint16, multiFileBufferSize)
d.opf.File = utf16ptr(d.buf)
d.opf.MaxFile = uint32(len(d.buf))
if w32.GetOpenFileName(d.opf) {
return d.parseMultipleFilenames(), nil
}
return nil, err()
}
func (b *FileBuilder) save() (string, error) {
d := openfile(w32.OFN_OVERWRITEPROMPT|w32.OFN_NOCHANGEDIR, b)
if w32.GetSaveFileName(d.opf) {
@@ -141,15 +76,15 @@ func utf16ptr(utf16 []uint16) *uint16 {
return (*uint16)(unsafe.Pointer(h.Data))
}
func utf16slice(ptr *uint16) []uint16 { //nolint:unused
func utf16slice(ptr *uint16) []uint16 {
hdr := reflect.SliceHeader{Data: uintptr(unsafe.Pointer(ptr)), Len: 1, Cap: 1}
slice := *((*[]uint16)(unsafe.Pointer(&hdr))) //nolint:govet
slice := *((*[]uint16)(unsafe.Pointer(&hdr)))
i := 0
for slice[len(slice)-1] != 0 {
i++
}
hdr.Len = i
slice = *((*[]uint16)(unsafe.Pointer(&hdr))) //nolint:govet
slice = *((*[]uint16)(unsafe.Pointer(&hdr)))
return slice
}

View File

@@ -1,5 +1,3 @@
//go:build windows
package dialog
func firstOf(args ...string) string {

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package format
import (

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package format
import "testing"

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
// package logrotate provides utilities for rotating logs
// TODO (jmorgan): this most likely doesn't need its own
// package and can be moved to app where log files are created

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package logrotate
import (

91
app/network/monitor.go Normal file
View File

@@ -0,0 +1,91 @@
package network
import (
"context"
"sync"
)
type ConnectivityStatus int
const (
StatusUnknown ConnectivityStatus = iota
StatusOnline
StatusOffline
)
type ConnectivityChangeHandler func(status ConnectivityStatus)
type Monitor struct {
mu sync.RWMutex
status ConnectivityStatus
handlers []ConnectivityChangeHandler
stopChan chan struct{}
}
func NewMonitor() *Monitor {
return &Monitor{
status: StatusUnknown,
handlers: make([]ConnectivityChangeHandler, 0),
}
}
func (m *Monitor) Start(ctx context.Context) {
m.mu.Lock()
if m.stopChan != nil {
m.mu.Unlock()
return
}
m.stopChan = make(chan struct{})
m.mu.Unlock()
m.startPlatformMonitor(ctx)
}
func (m *Monitor) checkConnectivity() {
online := m.checkPlatformConnectivity()
m.mu.Lock()
oldStatus := m.status
if online {
m.status = StatusOnline
} else {
m.status = StatusOffline
}
handlers := m.handlers
m.mu.Unlock()
if oldStatus != m.status {
for _, handler := range handlers {
handler(m.status)
}
}
}
func (m *Monitor) OnConnectivityChange(handler ConnectivityChangeHandler) {
m.mu.Lock()
defer m.mu.Unlock()
m.handlers = append(m.handlers, handler)
}
func (m *Monitor) IsOnline() bool {
m.mu.RLock()
defer m.mu.RUnlock()
return m.status == StatusOnline
}
// Disconnected returns a channel that receives a signal when the network goes offline
func (m *Monitor) Disconnected() <-chan struct{} {
ch := make(chan struct{})
m.OnConnectivityChange(func(status ConnectivityStatus) {
if status == StatusOffline {
select {
case ch <- struct{}{}:
default:
// Don't block if already signaled
}
}
})
return ch
}
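A rough usage sketch for the new `network.Monitor` (assuming only the API shown above; the platform-specific watchers appear in the darwin and windows files below). The registered handler fires only when the detected status actually changes, and `Disconnected()` offers a channel-based way to wait for an offline transition:

```go
package main

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/app/network"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	m := network.NewMonitor()
	m.OnConnectivityChange(func(status network.ConnectivityStatus) {
		// Called on transitions only (e.g. unknown -> online, online -> offline).
		fmt.Println("online:", status == network.StatusOnline)
	})
	m.Start(ctx) // starts the platform watcher (scutil on macOS, polling on Windows)

	<-m.Disconnected() // blocks until the monitor observes an offline transition
	fmt.Println("network went offline; IsOnline() =", m.IsOnline())
}
```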

View File

@@ -0,0 +1,96 @@
//go:build darwin
package network
import (
"bufio"
"context"
"os/exec"
"strings"
"time"
)
func (m *Monitor) startPlatformMonitor(ctx context.Context) {
go m.watchNetworkChanges(ctx)
}
func (m *Monitor) checkPlatformConnectivity() bool {
// Check if we have active network interfaces
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "scutil", "--nwi")
output, err := cmd.Output()
if err != nil {
return false
}
outputStr := string(output)
// Check for active interfaces with IP addresses
hasIPv4 := strings.Contains(outputStr, "IPv4") &&
!strings.Contains(outputStr, "IPv4 : No addresses")
hasIPv6 := strings.Contains(outputStr, "IPv6") &&
!strings.Contains(outputStr, "IPv6 : No addresses")
if !hasIPv4 && !hasIPv6 {
return false
}
// Check for active network interfaces
lines := strings.Split(outputStr, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
// Look for active ethernet (en) or VPN (utun) interfaces
if strings.HasPrefix(line, "en") || strings.HasPrefix(line, "utun") {
if strings.Contains(line, "flags") && !strings.Contains(line, "inactive") {
return true
}
}
}
return false
}
func (m *Monitor) watchNetworkChanges(ctx context.Context) {
// Use scutil to watch for network changes
cmd := exec.CommandContext(ctx, "scutil")
stdin, err := cmd.StdinPipe()
if err != nil {
return
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return
}
if err := cmd.Start(); err != nil {
return
}
defer cmd.Wait()
// Watch for network state changes
stdin.Write([]byte("n.add State:/Network/Global/IPv4\n"))
stdin.Write([]byte("n.add State:/Network/Global/IPv6\n"))
stdin.Write([]byte("n.add State:/Network/Interface\n"))
stdin.Write([]byte("n.watch\n"))
// Trigger initial check
m.checkConnectivity()
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
select {
case <-ctx.Done():
return
case <-m.stopChan:
return
default:
// Any output from scutil indicates a network change
// Trigger connectivity check
m.checkConnectivity()
}
}
}

View File

@@ -0,0 +1,93 @@
//go:build windows
package network
import (
"context"
"os/exec"
"strings"
"syscall"
"time"
"unsafe"
)
var (
wininet = syscall.NewLazyDLL("wininet.dll")
internetGetConnectedState = wininet.NewProc("InternetGetConnectedState")
)
const INTERNET_CONNECTION_OFFLINE = 0x20
func (m *Monitor) startPlatformMonitor(ctx context.Context) {
go m.watchNetworkChanges(ctx)
}
func (m *Monitor) checkPlatformConnectivity() bool {
// First check Windows Internet API
if internetGetConnectedState.Find() == nil {
var flags uint32
r, _, _ := internetGetConnectedState.Call(
uintptr(unsafe.Pointer(&flags)),
0,
)
if r == 1 && (flags&INTERNET_CONNECTION_OFFLINE) == 0 {
// Also verify with netsh that interfaces are actually connected
return m.checkWindowsInterfaces()
}
}
return false
}
func (m *Monitor) checkWindowsInterfaces() bool {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "netsh", "interface", "show", "interface")
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
output, err := cmd.Output()
if err != nil {
return false
}
for line := range strings.SplitSeq(string(output), "\n") {
line = strings.ToLower(strings.TrimSpace(line))
// Look for a "connected" interface that isn't "disconnected" or "loopback"
if strings.Contains(line, "connected") &&
!strings.Contains(line, "disconnected") &&
!strings.Contains(line, "loopback") {
return true
}
}
return false
}
func (m *Monitor) watchNetworkChanges(ctx context.Context) {
// Windows doesn't have a simple built-in tool like scutil,
// so poll frequently to detect changes
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
// Initial check
m.checkConnectivity()
var lastState bool = m.checkPlatformConnectivity()
for {
select {
case <-ctx.Done():
return
case <-m.stopChan:
return
case <-ticker.C:
currentState := m.checkPlatformConnectivity()
if currentState != lastState {
lastState = currentState
m.checkConnectivity()
}
}
}
}

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package server
import (
@@ -224,7 +222,9 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
if _, err := os.Stat(settings.Models); err == nil {
env["OLLAMA_MODELS"] = settings.Models
} else {
slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err)
slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err)
settings.Models = ""
s.store.SetSettings(settings)
}
}
if settings.ContextLength > 0 {
@@ -260,7 +260,7 @@ func openRotatingLog() (io.WriteCloser, error) {
return f, nil
}
// Attempt to retrieve inference compute information from the server
// Attempt to retrive inference compute information from the server
// log. Set ctx to timeout to control how long to wait for the logs to appear
func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
inference := []InferenceCompute{}
@@ -326,7 +326,6 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
time.Sleep(time.Second)
continue
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package server
import (
@@ -15,7 +13,12 @@ import (
)
func TestNew(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "ollama-server-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
st := &store.Store{DBPath: filepath.Join(tmpDir, "db.sqlite")}
defer st.Close() // Ensure database is closed before cleanup
s := New(st, false)
@@ -37,10 +40,14 @@ func TestServerCmd(t *testing.T) {
home, err := os.UserHomeDir()
if err == nil {
defaultModels = filepath.Join(home, ".ollama", "models")
os.MkdirAll(defaultModels, 0o755)
os.MkdirAll(defaultModels, 0755)
}
tmpModels := t.TempDir()
tmpModels, err := os.MkdirTemp("", "models")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpModels)
tests := []struct {
name string
settings store.Settings
@@ -95,7 +102,12 @@ func TestServerCmd(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "ollama-server-cmd-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
st := &store.Store{DBPath: filepath.Join(tmpDir, "db.sqlite")}
defer st.Close() // Ensure database is closed before cleanup
st.SetSettings(tt.settings)
@@ -103,7 +115,7 @@ func TestServerCmd(t *testing.T) {
store: st,
}
cmd, err := s.cmd(t.Context())
cmd, err := s.cmd(context.Background())
if err != nil {
t.Fatalf("s.cmd() error = %v", err)
}
@@ -211,13 +223,17 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", tt.name)
if err != nil {
t.Fatal(err)
}
serverLogPath = filepath.Join(tmpDir, "server.log")
err := os.WriteFile(serverLogPath, []byte(tt.log), 0o644)
defer os.RemoveAll(tmpDir)
err = os.WriteFile(serverLogPath, []byte(tt.log), 0644)
if err != nil {
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
}
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
ics, err := GetInferenceComputer(ctx)
if err != nil {
@@ -227,15 +243,21 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp
t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp)
}
})
}
}
func TestGetInferenceComputerTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "timeouttest")
if err != nil {
t.Fatal(err)
}
serverLogPath = filepath.Join(tmpDir, "server.log")
err := os.WriteFile(serverLogPath, []byte("foo\nbar\nbaz\n"), 0o644)
defer os.RemoveAll(tmpDir)
err = os.WriteFile(serverLogPath, []byte("foo\nbar\nbaz\n"), 0644)
if err != nil {
t.Fatalf("failed to write log file %s: %s", serverLogPath, err)
}

View File

@@ -1,4 +1,4 @@
//go:build darwin
//go:build !windows
package server
@@ -15,10 +15,8 @@ import (
"syscall"
)
var (
pidFile = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "Ollama", "ollama.pid")
serverLogPath = filepath.Join(os.Getenv("HOME"), ".ollama", "logs", "server.log")
)
var pidFile = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "Ollama", "ollama.pid")
var serverLogPath = filepath.Join(os.Getenv("HOME"), ".ollama", "logs", "server.log")
func commandContext(ctx context.Context, name string, arg ...string) *exec.Cmd {
return exec.CommandContext(ctx, name, arg...)
@@ -59,7 +57,7 @@ func reapServers() error {
if err != nil {
// No ollama processes found
slog.Debug("no ollama processes found")
return nil //nolint:nilerr
return nil
}
pidsStr := strings.TrimSpace(string(output))

View File

@@ -14,10 +14,8 @@ import (
"golang.org/x/sys/windows"
)
var (
pidFile = filepath.Join(os.Getenv("LOCALAPPDATA"), "Ollama", "ollama.pid")
serverLogPath = filepath.Join(os.Getenv("LOCALAPPDATA"), "Ollama", "server.log")
)
var pidFile = filepath.Join(os.Getenv("LOCALAPPDATA"), "Ollama", "ollama.pid")
var serverLogPath = filepath.Join(os.Getenv("LOCALAPPDATA"), "Ollama", "server.log")
func commandContext(ctx context.Context, name string, arg ...string) *exec.Cmd {
cmd := exec.CommandContext(ctx, name, arg...)
@@ -113,7 +111,7 @@ func reapServers() error {
if err != nil {
// No ollama processes found
slog.Debug("no ollama processes found")
return nil //nolint:nilerr
return nil
}
lines := strings.Split(string(output), "\n")

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package store
import (
@@ -256,6 +254,7 @@ func (db *database) migrate() error {
// migrateV1ToV2 adds the context_length column to the settings table
func (db *database) migrateV1ToV2() error {
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN context_length INTEGER NOT NULL DEFAULT 4096;`)
if err != nil && !duplicateColumnError(err) {
return fmt.Errorf("add context_length column: %w", err)
@@ -295,7 +294,6 @@ func (db *database) migrateV2ToV3() error {
return nil
}
func (db *database) migrateV3ToV4() error {
_, err := db.conn.Exec(`ALTER TABLE messages ADD COLUMN tool_result TEXT;`)
if err != nil && !duplicateColumnError(err) {
@@ -415,6 +413,7 @@ func (db *database) migrateV9ToV10() error {
);
UPDATE settings SET schema_version = 10;
`)
if err != nil {
return fmt.Errorf("create users table: %w", err)
}
@@ -1111,6 +1110,7 @@ func (db *database) getSettings() (Settings, error) {
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
FROM settings
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
if err != nil {
return Settings{}, fmt.Errorf("get settings: %w", err)
}
@@ -1187,6 +1187,7 @@ func (db *database) getUser() (*User, error) {
FROM users
LIMIT 1
`).Scan(&user.Name, &user.Email, &user.Plan, &user.CachedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil // No user cached yet
@@ -1206,6 +1207,7 @@ func (db *database) setUser(user User) error {
INSERT INTO users (name, email, plan, cached_at)
VALUES (?, ?, ?, ?)
`, user.Name, user.Email, user.Plan, user.CachedAt)
if err != nil {
return fmt.Errorf("set user: %w", err)
}
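The migrations above tolerate re-runs by ignoring duplicate-column failures via duplicateColumnError. As a hedged sketch only (the helper's body is not part of this diff; this assumes it simply matches SQLite's error text):

package store

import "strings"

// duplicateColumnError reports whether err came from re-adding an existing column.
// SQLite phrases that failure as "duplicate column name: <col>", so matching the
// substring keeps ALTER TABLE ... ADD COLUMN migrations idempotent.
func duplicateColumnError(err error) bool {
    return err != nil && strings.Contains(err.Error(), "duplicate column name")
}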

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package store
import (
@@ -18,7 +16,12 @@ import (
func TestSchemaMigrations(t *testing.T) {
t.Run("schema comparison after migration", func(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "migration-schema-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
migratedDBPath := filepath.Join(tmpDir, "migrated.db")
migratedDB := loadV2Schema(t, migratedDBPath)
defer migratedDB.Close()
@@ -52,7 +55,12 @@ func TestSchemaMigrations(t *testing.T) {
})
t.Run("idempotent migrations", func(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "migration-idempotent-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
dbPath := filepath.Join(tmpDir, "test.db")
db := loadV2Schema(t, dbPath)
defer db.Close()
@@ -77,7 +85,12 @@ func TestSchemaMigrations(t *testing.T) {
})
t.Run("init database has correct schema version", func(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "schema-version-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
dbPath := filepath.Join(tmpDir, "test.db")
db, err := newDatabase(dbPath)
if err != nil {
@@ -100,7 +113,12 @@ func TestSchemaMigrations(t *testing.T) {
func TestChatDeletionWithCascade(t *testing.T) {
t.Run("chat deletion cascades to related messages", func(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "cascade-delete-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
dbPath := filepath.Join(tmpDir, "test.db")
db, err := newDatabase(dbPath)
if err != nil {
@@ -196,7 +214,12 @@ func TestChatDeletionWithCascade(t *testing.T) {
})
t.Run("foreign keys are enabled", func(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "foreign-keys-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
dbPath := filepath.Join(tmpDir, "test.db")
db, err := newDatabase(dbPath)
if err != nil {
@@ -218,7 +241,12 @@ func TestChatDeletionWithCascade(t *testing.T) {
// This test is only relevant for v8 migrations, but we keep it here for now
// since it's a useful test to ensure that we don't introduce any new orphaned data
t.Run("cleanup orphaned data", func(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "orphaned-data-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
dbPath := filepath.Join(tmpDir, "test.db")
db, err := newDatabase(dbPath)
if err != nil {

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package store
import (
@@ -50,7 +48,7 @@ func (s *Store) ImgDir() string {
// ImgToFile saves image data to disk and returns ImageData reference
func (s *Store) ImgToFile(chatID string, imageBytes []byte, filename, mimeType string) (Image, error) {
baseImageDir := s.ImgDir()
if err := os.MkdirAll(baseImageDir, 0o755); err != nil {
if err := os.MkdirAll(baseImageDir, 0755); err != nil {
return Image{}, fmt.Errorf("create base image directory: %w", err)
}
@@ -63,7 +61,7 @@ func (s *Store) ImgToFile(chatID string, imageBytes []byte, filename, mimeType s
// Create chat-specific subdirectory within the root
chatDir := sanitize(chatID)
if err := root.Mkdir(chatDir, 0o755); err != nil && !os.IsExist(err) {
if err := root.Mkdir(chatDir, 0755); err != nil && !os.IsExist(err) {
return Image{}, fmt.Errorf("create chat directory: %w", err)
}

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package store
import (
@@ -11,7 +9,12 @@ import (
)
func TestConfigMigration(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "ollama-migration-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
// Create a legacy config.json
legacyConfig := legacyData{
ID: "test-device-id-12345",
@@ -24,7 +27,7 @@ func TestConfigMigration(t *testing.T) {
}
configPath := filepath.Join(tmpDir, "config.json")
if err := os.WriteFile(configPath, configData, 0o644); err != nil {
if err := os.WriteFile(configPath, configData, 0644); err != nil {
t.Fatal(err)
}
@@ -86,7 +89,12 @@ func TestConfigMigration(t *testing.T) {
}
func TestNoConfigToMigrate(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "ollama-no-migration-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
// Override the legacy config path for testing
oldLegacyConfigPath := legacyConfigPath
legacyConfigPath = filepath.Join(tmpDir, "config.json")
@@ -189,7 +197,11 @@ const (
)
func TestMigrationFromEpoc(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "ollama-migration-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
s := Store{DBPath: filepath.Join(tmpDir, "db.sqlite")}
defer s.Close()
// Open database connection

View File

@@ -1,14 +1,18 @@
//go:build windows || darwin
package store
import (
"os"
"path/filepath"
"testing"
)
func TestSchemaVersioning(t *testing.T) {
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "ollama-schema-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
// Override legacy config path to avoid migration logs
oldLegacyConfigPath := legacyConfigPath
legacyConfigPath = filepath.Join(tmpDir, "config.json")

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
// Package store provides a simple JSON file store for the desktop application
// to save and load data such as ollama server configuration, messages,
// login information and more.
@@ -155,7 +153,7 @@ type Settings struct {
// TurboEnabled indicates if Ollama Turbo features are enabled
TurboEnabled bool
// Maps gpt-oss specific frontend name' BrowserToolEnabled' to db field 'websearch_enabled'
// Maps gpt-oss specfic frontend name' BrowserToolEnabled' to db field 'websearch_enabled'
WebSearchEnabled bool
// ThinkEnabled indicates if thinking is enabled
@@ -230,7 +228,7 @@ func (s *Store) ensureDB() error {
}
// Ensure directory exists
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
return fmt.Errorf("create db directory: %w", err)
}

View File

@@ -1,8 +1,7 @@
//go:build windows || darwin
package store
import (
"os"
"path/filepath"
"testing"
)
@@ -175,7 +174,10 @@ func TestStore(t *testing.T) {
func setupTestStore(t *testing.T) (*Store, func()) {
t.Helper()
tmpDir := t.TempDir()
tmpDir, err := os.MkdirTemp("", "ollama-store-test")
if err != nil {
t.Fatal(err)
}
// Override legacy config path to ensure no migration happens
oldLegacyConfigPath := legacyConfigPath
@@ -186,6 +188,7 @@ func setupTestStore(t *testing.T) (*Store, func()) {
cleanup := func() {
s.Close()
legacyConfigPath = oldLegacyConfigPath
os.RemoveAll(tmpDir)
}
return s, cleanup

137
app/tools/bash.go Normal file
View File

@@ -0,0 +1,137 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"os/exec"
"strings"
"time"
)
// BashCommand executes non-destructive bash commands
type BashCommand struct{}
func (b *BashCommand) Name() string {
return "bash_command"
}
func (b *BashCommand) Description() string {
return "Execute non-destructive bash commands safely"
}
func (b *BashCommand) Prompt() string {
return `For bash commands:
1. Only use safe, non-destructive commands like: ls, pwd, echo, cat, grep, ps, df, du, find, which, whoami, date, uptime, uname, wc, head, tail, sort, uniq
2. For searching files and content:
- Use grep -r "keyword" . to recursively search for keywords in files
- Use find . -name "*keyword*" to search for files by name
- Use find . -type f -exec grep "keyword" {} \; to search file contents
3. Never use dangerous flags like --delete, --remove, -rf, -fr, --modify, --write, --exec
4. Commands will timeout after 30 seconds by default
5. Always check command output for errors and handle them appropriately
6. Before running any commands:
- Use ls to understand directory structure
- Use cat/head/tail to inspect file contents
- Plan your search strategy based on the context`
}
func (b *BashCommand) Schema() map[string]any {
schemaBytes := []byte(`{
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The bash command to execute"
},
"timeout_seconds": {
"type": "integer",
"description": "Maximum execution time in seconds (default: 30)",
"default": 30
}
},
"required": ["command"]
}`)
var schema map[string]any
if err := json.Unmarshal(schemaBytes, &schema); err != nil {
return nil
}
return schema
}
func (b *BashCommand) Execute(ctx context.Context, args map[string]any) (any, error) {
// Extract command
cmd, ok := args["command"].(string)
if !ok {
return nil, fmt.Errorf("command parameter is required and must be a string")
}
// Get optional timeout
timeoutSeconds := 30
if t, ok := args["timeout_seconds"].(float64); ok {
timeoutSeconds = int(t)
}
// List of allowed commands (exact matches or prefixes)
allowedCommands := []string{
"ls", "pwd", "echo", "cat", "grep",
"ps", "df", "du", "find", "which",
"whoami", "date", "uptime", "uname",
"wc", "head", "tail", "sort", "uniq",
}
// Split the command to get the base command
cmdParts := strings.Fields(cmd)
if len(cmdParts) == 0 {
return nil, fmt.Errorf("empty command")
}
baseCmd := cmdParts[0]
// Check if the command is allowed
allowed := false
for _, allowedCmd := range allowedCommands {
if baseCmd == allowedCmd {
allowed = true
break
}
}
if !allowed {
return nil, fmt.Errorf("command not in allowed list: %s", baseCmd)
}
// Additional safety checks for arguments
dangerousFlags := []string{
"--delete", "--remove", "-rf", "-fr",
"--modify", "--write", "--exec",
}
cmdLower := strings.ToLower(cmd)
for _, flag := range dangerousFlags {
if strings.Contains(cmdLower, flag) {
return nil, fmt.Errorf("command contains dangerous flag: %s", flag)
}
}
// Create command with timeout
ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
defer cancel()
// Execute command
execCmd := exec.CommandContext(ctx, "bash", "-c", cmd)
output, err := execCmd.CombinedOutput()
if ctx.Err() == context.DeadlineExceeded {
return nil, fmt.Errorf("command timed out after %d seconds", timeoutSeconds)
}
if err != nil {
return nil, fmt.Errorf("command execution failed: %w", err)
}
// Return result directly as a map
return map[string]any{
"command": cmd,
"output": string(output),
"success": true,
}, nil
}
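A minimal usage sketch for this tool, calling Execute directly rather than through the app's tool dispatcher; the import path is an assumption based on the file's location under app/tools:

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/app/tools" // assumed from the file's location
)

func main() {
    cmd := &tools.BashCommand{}
    // Arguments mirror JSON decoding: numbers arrive as float64.
    result, err := cmd.Execute(context.Background(), map[string]any{
        "command":         "ls -l",
        "timeout_seconds": float64(10),
    })
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(result.(map[string]any)["output"])
}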

185
app/tools/bash_test.go Normal file
View File

@@ -0,0 +1,185 @@
package tools
import (
"context"
"strings"
"testing"
)
func TestBashCommand_Name(t *testing.T) {
cmd := &BashCommand{}
if name := cmd.Name(); name != "bash_command" {
t.Errorf("Expected name 'bash_command', got %s", name)
}
}
func TestBashCommand_Execute(t *testing.T) {
cmd := &BashCommand{}
ctx := context.Background()
tests := []struct {
name string
input map[string]any
wantErr bool
errContains string
wantOutput string
}{
{
name: "valid echo command",
input: map[string]any{
"command": "echo 'hello world'",
},
wantErr: false,
wantOutput: "hello world\n",
},
{
name: "valid ls command",
input: map[string]any{
"command": "ls -l",
},
wantErr: false,
},
{
name: "invalid command",
input: map[string]any{
"command": "rm -rf /",
},
wantErr: true,
errContains: "command not in allowed list",
},
{
name: "dangerous flag",
input: map[string]any{
"command": "find . --delete",
},
wantErr: true,
errContains: "dangerous flag",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := cmd.Execute(ctx, tt.input)
if tt.wantErr {
if err == nil {
t.Error("Expected error but got none")
} else if !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Expected error containing '%s', got '%s'", tt.errContains, err.Error())
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
// Check result type and fields
response, ok := result.(map[string]any)
if !ok {
t.Fatal("Expected result to be map[string]any")
}
// Check required fields
success, ok := response["success"].(bool)
if !ok || !success {
t.Error("Expected success to be true")
}
command, ok := response["command"].(string)
if !ok || command == "" {
t.Error("Expected command to be non-empty string")
}
output, ok := response["output"].(string)
if !ok {
t.Error("Expected output to be string")
} else if tt.wantOutput != "" && output != tt.wantOutput {
t.Errorf("Expected output '%s', got '%s'", tt.wantOutput, output)
}
})
}
}
func TestBashCommand_InvalidInput(t *testing.T) {
cmd := &BashCommand{}
ctx := context.Background()
tests := []struct {
name string
input map[string]any
errContains string
}{
{
name: "missing command",
input: map[string]any{},
errContains: "command parameter is required",
},
{
name: "empty command",
input: map[string]any{
"command": "",
},
errContains: "empty command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := cmd.Execute(ctx, tt.input)
if err == nil {
t.Error("Expected error but got none")
} else if !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Expected error containing '%s', got '%s'", tt.errContains, err.Error())
}
})
}
}
func TestBashCommand_OutputFormat(t *testing.T) {
cmd := &BashCommand{}
ctx := context.Background()
// Test with a simple echo command
input := map[string]any{
"command": "echo 'test output'",
}
result, err := cmd.Execute(ctx, input)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Verify the result is a map[string]any
response, ok := result.(map[string]any)
if !ok {
t.Fatal("Result is not a map[string]any")
}
// Check all expected fields exist
requiredFields := []string{"command", "output", "success"}
for _, field := range requiredFields {
if _, ok := response[field]; !ok {
t.Errorf("Missing required field: %s", field)
}
}
// Verify output is plain text
output, ok := response["output"].(string)
if !ok {
t.Error("Output field is not a string")
} else {
// Output should contain 'test output' and a newline
expectedOutput := "test output\n"
if output != expectedOutput {
t.Errorf("Expected output '%s', got '%s'", expectedOutput, output)
}
// Verify output is not base64 encoded
if strings.Contains(output, "base64") ||
(len(output) > 0 && output[0] == 'e' && strings.ContainsAny(output, "+/=")) {
t.Error("Output appears to be base64 encoded")
}
}
}

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package tools
import (
@@ -178,8 +176,8 @@ func (b *BrowserSearch) Execute(ctx context.Context, args map[string]any) (any,
func (b *Browser) buildSearchResultsPageCollection(query string, results *WebSearchResponse) *responses.Page {
page := &responses.Page{
URL: "search_results_" + query,
Title: query,
URL: fmt.Sprintf("%s", "search_results_"+query),
Title: fmt.Sprintf("%s", query),
Links: make(map[int]string),
FetchedAt: time.Now(),
}
@@ -501,6 +499,7 @@ func (b *BrowserOpen) Schema() map[string]any {
}
func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, string, error) {
// Get cursor parameter first
cursor := -1
if c, ok := args["cursor"].(float64); ok {

View File

@@ -1,11 +1,16 @@
//go:build windows || darwin
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"time"
"github.com/ollama/ollama/auth"
)
// CrawlContent represents the content of a crawled page
@@ -49,7 +54,6 @@ func (g *BrowserCrawler) Name() string {
func (g *BrowserCrawler) Description() string {
return "Crawl and extract text content from web pages"
}
func (g *BrowserCrawler) Prompt() string {
return `When you need to read content from web pages, use the get_webpage tool. Simply provide the URLs you want to read and I'll fetch their content for you.
@@ -73,6 +77,11 @@ func (g *BrowserCrawler) Schema() map[string]any {
"type": "string"
},
"description": "List of URLs to crawl and extract content from"
},
"latest": {
"type": "boolean",
"description": " Needs up to date and latest information (default: false)",
"default": false
}
},
"required": ["urls"]
@@ -85,6 +94,7 @@ func (g *BrowserCrawler) Schema() map[string]any {
}
func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*CrawlResponse, error) {
// Extract and validate URLs
urlsRaw, ok := args["urls"].([]any)
if !ok {
return nil, fmt.Errorf("urls parameter is required and must be an array of strings")
@@ -101,36 +111,86 @@ func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*Cra
return nil, fmt.Errorf("at least one URL is required")
}
return g.performWebCrawl(ctx, urls)
latest, _ := args["latest"].(bool)
// Perform the web crawling
return g.performWebCrawl(ctx, urls, latest)
}
// performWebCrawl handles the actual HTTP request to ollama.com crawl API
func (g *BrowserCrawler) performWebCrawl(ctx context.Context, urls []string) (*CrawlResponse, error) {
result := &CrawlResponse{Results: make(map[string][]CrawlResult, len(urls))}
for _, targetURL := range urls {
fetchResp, err := performWebFetch(ctx, targetURL)
if err != nil {
return nil, fmt.Errorf("web_fetch failed for %q: %w", targetURL, err)
}
links := make([]CrawlLink, 0, len(fetchResp.Links))
for _, link := range fetchResp.Links {
links = append(links, CrawlLink{URL: link, Href: link})
}
snippet := truncateString(fetchResp.Content, 400)
result.Results[targetURL] = []CrawlResult{{
Title: fetchResp.Title,
URL: targetURL,
Content: CrawlContent{
Snippet: snippet,
FullText: fetchResp.Content,
},
Extras: CrawlExtras{Links: links},
}}
func (g *BrowserCrawler) performWebCrawl(ctx context.Context, urls []string, latest bool) (*CrawlResponse, error) {
// Prepare the request body matching the API format
reqBody := map[string]any{
"urls": urls,
"text": true,
"extras": map[string]any{
"links": 1,
},
"livecrawl": "fallback",
}
return result, nil
if latest {
reqBody["livecrawl"] = "always"
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
crawlURL, err := url.Parse("https://ollama.com/api/tools/webcrawl")
if err != nil {
return nil, fmt.Errorf("failed to parse crawl URL: %w", err)
}
// Add timestamp for signing
query := crawlURL.Query()
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
var signature string
crawlURL.RawQuery = query.Encode()
// Sign the request data (method + URI)
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, crawlURL.RequestURI())
signature, err = auth.Sign(ctx, data)
if err != nil {
return nil, fmt.Errorf("failed to sign request: %w", err)
}
// Create the request
req, err := http.NewRequestWithContext(ctx, "POST", crawlURL.String(), bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers
req.Header.Set("Content-Type", "application/json")
if signature != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
}
// Make the request
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute crawl request: %w", err)
}
defer resp.Body.Close()
// Read and parse response
var result CrawlResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
// Check for error response
if resp.StatusCode != http.StatusOK {
errMsg := "unknown error"
if resp.StatusCode == http.StatusServiceUnavailable {
errMsg = "crawl service unavailable - API key may not be configured"
}
return nil, fmt.Errorf("crawl API error (status %d): %s", resp.StatusCode, errMsg)
}
return &result, nil
}
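For orientation, a hedged sketch of driving the rewritten crawler directly. The request above is signed with auth.Sign, so running this requires a local Ollama key and network access to ollama.com; in the app the call goes through the tool registry rather than this direct form, and the import path is assumed:

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/app/tools" // assumed import path
)

func main() {
    crawler := &tools.BrowserCrawler{}
    // "latest": true flips livecrawl from "fallback" to "always" per the code above.
    resp, err := crawler.Execute(context.Background(), map[string]any{
        "urls":   []any{"https://example.com"},
        "latest": true,
    })
    if err != nil {
        log.Fatal(err)
    }
    for url, results := range resp.Results {
        fmt.Println(url, len(results))
    }
}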

View File

@@ -1,8 +1,7 @@
//go:build windows || darwin
package tools
import (
"context"
"strings"
"testing"
"time"
@@ -30,7 +29,7 @@ func TestBrowser_Scroll_AppendsOnlyPageStack(t *testing.T) {
bo := NewBrowserOpen(b)
// Scroll without id — should push only to PageStack
_, _, err := bo.Execute(t.Context(), map[string]any{"loc": float64(1), "num_lines": float64(1)})
_, _, err := bo.Execute(context.TODO(), map[string]any{"loc": float64(1), "num_lines": float64(1)})
if err != nil {
t.Fatalf("scroll execute failed: %v", err)
}
@@ -52,7 +51,7 @@ func TestBrowserOpen_UseCacheByURL(t *testing.T) {
initialStackLen := len(b.state.Data.PageStack)
initialMapLen := len(b.state.Data.URLToPage)
_, _, err := bo.Execute(t.Context(), map[string]any{"id": p.URL})
_, _, err := bo.Execute(context.TODO(), map[string]any{"id": p.URL})
if err != nil {
t.Fatalf("open cached execute failed: %v", err)
}
@@ -91,7 +90,7 @@ func TestBrowserOpen_LinkId_UsesCacheAndAppends(t *testing.T) {
initialMapLen := len(b.state.Data.URLToPage)
bo := NewBrowserOpen(b)
_, _, err := bo.Execute(t.Context(), map[string]any{"id": float64(0)})
_, _, err := bo.Execute(context.TODO(), map[string]any{"id": float64(0)})
if err != nil {
t.Fatalf("open by link id failed: %v", err)
}

View File

@@ -1,13 +1,16 @@
//go:build windows || darwin
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"time"
"github.com/ollama/ollama/auth"
)
// WebSearchContent represents the content of a search result
@@ -82,6 +85,7 @@ func (w *BrowserWebSearch) Schema() map[string]any {
}
func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (any, error) {
// Extract and validate queries
queriesRaw, ok := args["queries"].([]any)
if !ok {
return nil, fmt.Errorf("queries parameter is required and must be an array of strings")
@@ -98,46 +102,83 @@ func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (an
return nil, fmt.Errorf("at least one query is required")
}
// Get optional parameters
maxResults := 5
if mr, ok := args["max_results"].(int); ok {
maxResults = mr
}
// Perform the web search
return w.performWebSearch(ctx, queries, maxResults)
}
// performWebSearch handles the actual HTTP request to ollama.com search API
func (w *BrowserWebSearch) performWebSearch(ctx context.Context, queries []string, maxResults int) (*WebSearchResponse, error) {
response := &WebSearchResponse{Results: make(map[string][]WebSearchResult, len(queries))}
for _, query := range queries {
searchResp, err := performWebSearch(ctx, query, maxResults)
if err != nil {
return nil, fmt.Errorf("web_search failed for %q: %w", query, err)
}
converted := make([]WebSearchResult, 0, len(searchResp.Results))
for _, item := range searchResp.Results {
converted = append(converted, WebSearchResult{
Title: item.Title,
URL: item.URL,
Content: WebSearchContent{
Snippet: truncateString(item.Content, 400),
FullText: item.Content,
},
Metadata: WebSearchMetadata{},
})
}
response.Results[query] = converted
// Prepare the request body
reqBody := map[string]any{
"queries": queries,
"max_results": maxResults,
}
return response, nil
}
func truncateString(input string, limit int) string {
if limit <= 0 || len(input) <= limit {
return input
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
return input[:limit]
searchURL, err := url.Parse("https://ollama.com/api/tools/websearch")
if err != nil {
return nil, fmt.Errorf("failed to parse search URL: %w", err)
}
// Add timestamp for signing
query := searchURL.Query()
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
var signature string
searchURL.RawQuery = query.Encode()
// Sign the request data (method + URI)
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, searchURL.RequestURI())
signature, err = auth.Sign(ctx, data)
if err != nil {
return nil, fmt.Errorf("failed to sign request: %w", err)
}
// Create the request
req, err := http.NewRequestWithContext(ctx, "POST", searchURL.String(), bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers
req.Header.Set("Content-Type", "application/json")
if signature != "" {
req.Header.Set("Authorization", signature)
}
// Make the request
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute search request: %w", err)
}
defer resp.Body.Close()
// Read and parse response
var result WebSearchResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
// Check for error response
if resp.StatusCode != http.StatusOK {
errMsg := "unknown error"
if resp.StatusCode == http.StatusServiceUnavailable {
errMsg = "search service unavailable - API key may not be configured"
}
return nil, fmt.Errorf("search API error (status %d): %s", resp.StatusCode, errMsg)
}
// Return the results directly without caching
return &result, nil
}
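Similarly, an illustrative direct call into the rewritten search tool, with the same signing and connectivity caveats as the crawler. Note that Execute type-asserts max_results to a Go int, so JSON-decoded float64 arguments would keep the default of 5:

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/app/tools" // assumed import path
)

func main() {
    search := &tools.BrowserWebSearch{}
    result, err := search.Execute(context.Background(), map[string]any{
        "queries":     []any{"ollama vulkan support"},
        "max_results": 3, // passed as int; see the type assertion above
    })
    if err != nil {
        log.Fatal(err)
    }
    resp := result.(*tools.WebSearchResponse)
    for q, hits := range resp.Results {
        fmt.Println(q, len(hits))
    }
}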

105
app/tools/crawl_test.go Normal file
View File

@@ -0,0 +1,105 @@
package tools
import (
"context"
"strings"
"testing"
)
func TestGetWebpage_Name(t *testing.T) {
tool := &BrowserCrawler{}
if name := tool.Name(); name != "get_webpage" {
t.Errorf("Expected name 'get_webpage', got %s", name)
}
}
func TestGetWebpage_Description(t *testing.T) {
tool := &BrowserCrawler{}
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
}
func TestGetWebpage_Schema(t *testing.T) {
tool := &BrowserCrawler{}
schema := tool.Schema()
if schema == nil {
t.Error("Schema should not be nil")
}
// Check if schema has required properties
if schema["type"] != "object" {
t.Error("Schema type should be 'object'")
}
properties, ok := schema["properties"].(map[string]any)
if !ok {
t.Error("Schema should have properties")
}
// Check if urls property exists
if _, ok := properties["urls"]; !ok {
t.Error("Schema should have 'urls' property")
}
// Check if required field exists
required, ok := schema["required"].([]any)
if !ok {
t.Error("Schema should have 'required' field")
}
// Check if urls is in required
foundUrls := false
for _, req := range required {
if req == "urls" {
foundUrls = true
break
}
}
if !foundUrls {
t.Error("'urls' should be in required fields")
}
}
func TestGetWebpage_Execute_InvalidInput(t *testing.T) {
tool := &BrowserCrawler{}
ctx := context.Background()
tests := []struct {
name string
input map[string]any
errContains string
}{
{
name: "missing urls",
input: map[string]any{},
errContains: "urls parameter is required",
},
{
name: "empty urls array",
input: map[string]any{
"urls": []any{},
},
errContains: "at least one URL is required",
},
{
name: "invalid urls type",
input: map[string]any{
"urls": "not an array",
},
errContains: "urls parameter is required and must be an array",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := tool.Execute(ctx, tt.input)
if err == nil {
t.Error("Expected error but got none")
} else if !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Expected error containing '%s', got '%s'", tt.errContains, err.Error())
}
})
}
}

624
app/tools/filesystem.go Normal file
View File

@@ -0,0 +1,624 @@
package tools
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/ledongthuc/pdf"
)
// FileInfo represents information about a single file or directory
type FileInfo struct {
// BasePath string `json:"base_path"`
RelPath string `json:"rel_path"`
IsDir bool `json:"is_dir"`
}
// FileListResult represents the result of a directory listing operation
type FileListResult struct {
BasePath string `json:"base_path"`
Files []FileInfo `json:"files"`
Count int `json:"count"`
}
// FileReadResult represents the result of a file read operation
type FileReadResult struct {
Path string `json:"path"`
TotalLines int `json:"total_lines"`
LinesRead int `json:"lines_read"`
Content string `json:"content"`
}
// FileWriteResult represents the result of a file write operation
type FileWriteResult struct {
Path string `json:"path"`
Size int64 `json:"size,omitempty"`
Written int `json:"written"`
Mode string `json:"mode,omitempty"`
Modified int64 `json:"modified,omitempty"`
}
// FileReader implements the file reading functionality
type FileReader struct {
workingDir string
}
func (f *FileReader) SetWorkingDir(dir string) {
f.workingDir = dir
}
func (f *FileReader) Name() string {
return "file_read"
}
func (f *FileReader) Description() string {
return "Read the contents of a file from the file system"
}
func (f *FileReader) Prompt() string {
// TODO: read iteratively in agent mode, full in single shot - control with prompt?
return `Use the file_read tool to read the contents of a file using the path parameter. read_full is false by default and returns only the first 100 lines of the file; set read_full to true when the user needs the full contents`
}
func (f *FileReader) Schema() map[string]any {
schemaBytes := []byte(`{
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The path to the file to read"
},
"read_full": {
"type": "boolean",
"description": "returns the first 100 lines of the file when set to false (default: false)",
"default": false
}
},
"required": ["path"]
}`)
var schema map[string]any
if err := json.Unmarshal(schemaBytes, &schema); err != nil {
return nil
}
return schema
}
func (f *FileReader) Execute(ctx context.Context, args map[string]any) (any, error) {
fmt.Println("file_read tool called", args)
path, ok := args["path"].(string)
if !ok {
return nil, fmt.Errorf("path parameter is required and must be a string")
}
// If path is not absolute and working directory is set, make it relative to working directory
if !filepath.IsAbs(path) && f.workingDir != "" {
path = filepath.Join(f.workingDir, path)
}
// Security: Clean and validate the path
cleanPath := filepath.Clean(path)
if strings.Contains(cleanPath, "..") {
return nil, fmt.Errorf("path traversal not allowed")
}
// Get max size limit
maxSize := int64(1024 * 1024) // 1MB default
if ms, ok := args["max_size"]; ok {
switch v := ms.(type) {
case float64:
maxSize = int64(v)
case int:
maxSize = int64(v)
case int64:
maxSize = v
}
}
// Check if file exists and get info
info, err := os.Stat(cleanPath)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("file does not exist: %s", cleanPath)
}
return nil, fmt.Errorf("error accessing file: %w", err)
}
// Check if it's a directory
if info.IsDir() {
return nil, fmt.Errorf("path is a directory, not a file: %s", cleanPath)
}
// Check file size
if info.Size() > maxSize {
return nil, fmt.Errorf("file too large (%d bytes), maximum allowed: %d bytes", info.Size(), maxSize)
}
if strings.HasSuffix(strings.ToLower(cleanPath), ".pdf") {
return f.readPDFFile(cleanPath, args)
}
// Check read_full parameter
readFull := false // default to false
if rf, ok := args["read_full"]; ok {
readFull, _ = rf.(bool)
}
// Open and read the file
file, err := os.Open(cleanPath)
if err != nil {
return nil, fmt.Errorf("error opening file: %w", err)
}
defer file.Close()
// Read file content
scanner := bufio.NewScanner(file)
var lines []string
totalLines := 0
// Read content, keeping track of total lines but only storing up to 100 if !readFull
for scanner.Scan() {
totalLines++
if readFull || totalLines <= 100 {
lines = append(lines, scanner.Text())
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading file: %w", err)
}
content := strings.Join(lines, "\n")
return &FileReadResult{
Path: cleanPath,
LinesRead: len(lines),
TotalLines: totalLines,
Content: content,
}, nil
}
// readPDFFile extracts text from a PDF file
func (f *FileReader) readPDFFile(cleanPath string, args map[string]any) (any, error) {
// Open the PDF file
pdfFile, r, err := pdf.Open(cleanPath)
if err != nil {
return nil, fmt.Errorf("error opening PDF: %w", err)
}
defer pdfFile.Close()
// Get total number of pages
totalPages := r.NumPage()
// Check read_full parameter - for PDFs, this controls whether to read all pages
readFull := false
if rf, ok := args["read_full"]; ok {
readFull, _ = rf.(bool)
}
// Extract text from pages
var allText strings.Builder
maxPages := 10 // Default to first 10 pages if not read_full
if readFull {
maxPages = totalPages
}
linesExtracted := 0
for pageNum := 1; pageNum <= totalPages && pageNum <= maxPages; pageNum++ {
// Get page
page := r.Page(pageNum)
if page.V.IsNull() {
continue
}
// Use the built-in GetPlainText method which handles text extraction better
pageText, err := page.GetPlainText(nil)
if err != nil {
// If GetPlainText fails, fall back to manual extraction
pageText = f.extractTextFromPage(page)
}
pageText = strings.TrimSpace(pageText)
if pageText != "" {
if allText.Len() > 0 {
allText.WriteString("\n\n")
}
allText.WriteString(fmt.Sprintf("--- Page %d ---\n", pageNum))
allText.WriteString(pageText)
// Count lines for reporting
linesExtracted += strings.Count(pageText, "\n") + 1
}
}
content := strings.TrimSpace(allText.String())
// If no text was extracted, return a helpful message
if content == "" {
content = "[PDF file contains no extractable text - it may contain only images or use complex encoding]"
linesExtracted = 1
}
return &FileReadResult{
Path: cleanPath,
LinesRead: linesExtracted,
TotalLines: totalPages, // For PDFs, we report pages as "lines"
Content: content,
}, nil
}
// extractTextFromPage extracts text from a single PDF page
func (f *FileReader) extractTextFromPage(page pdf.Page) string {
var buf bytes.Buffer
// Get page contents
contents := page.Content()
// Group text elements that appear to be part of the same word/line
var currentLine strings.Builder
lastX := -1.0
for i, t := range contents.Text {
// Skip empty text
if t.S == "" {
continue
}
// Check if this text element is on a new line or far from the previous one
// If X position is significantly different or we've reset to the beginning, it's likely a new word
if lastX >= 0 && (t.X < lastX-10 || t.X > lastX+50) {
// Add the accumulated line to buffer with a space
if currentLine.Len() > 0 {
buf.WriteString(currentLine.String())
buf.WriteString(" ")
currentLine.Reset()
}
}
// Add the text without extra spaces
currentLine.WriteString(t.S)
lastX = t.X
// Check if next element exists and has significantly different Y position (new line)
if i+1 < len(contents.Text) && contents.Text[i+1].Y > t.Y+5 {
if currentLine.Len() > 0 {
buf.WriteString(currentLine.String())
buf.WriteString("\n")
currentLine.Reset()
lastX = -1.0
}
}
}
// Add any remaining text
if currentLine.Len() > 0 {
buf.WriteString(currentLine.String())
}
return strings.TrimSpace(buf.String())
}
// FileList implements the directory listing functionality
type FileList struct {
workingDir string
}
func (f *FileList) SetWorkingDir(dir string) {
f.workingDir = dir
}
func (f *FileList) Name() string {
return "file_list"
}
func (f *FileList) Description() string {
return "List the contents of a directory"
}
func (f *FileList) Prompt() string {
return `Use the file_list tool to list the contents of a directory using the path parameter`
}
func (f *FileList) Schema() map[string]any {
schemaBytes := []byte(`{
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The path to the directory to list (default: current directory)",
"default": "."
},
"show_hidden": {
"type": "boolean",
"description": "Whether to show hidden files (starting with .)",
"default": false
},
"depth": {
"type": "integer",
"description": "How many directory levels deep to list (default: 1)",
"default": 1
}
},
"required": []
}`)
var schema map[string]any
if err := json.Unmarshal(schemaBytes, &schema); err != nil {
return nil
}
return schema
}
func (f *FileList) Execute(ctx context.Context, args map[string]any) (any, error) {
path := "."
if p, ok := args["path"].(string); ok {
path = p
}
// If path is not absolute and working directory is set, make it relative to working directory
if !filepath.IsAbs(path) && f.workingDir != "" {
path = filepath.Join(f.workingDir, path)
}
// Security: Clean and validate the path
cleanPath := filepath.Clean(path)
if strings.Contains(cleanPath, "..") {
return nil, fmt.Errorf("path traversal not allowed")
}
// Get optional parameters
showHidden := false
if sh, ok := args["show_hidden"].(bool); ok {
showHidden = sh
}
maxDepth := 1
if md, ok := args["depth"].(float64); ok {
maxDepth = int(md)
}
// Check if directory exists
info, err := os.Stat(cleanPath)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("directory does not exist: %s", cleanPath)
}
return nil, fmt.Errorf("error accessing directory: %w", err)
}
if !info.IsDir() {
return nil, fmt.Errorf("path is not a directory: %s", cleanPath)
}
var files []FileInfo
files, err = f.listRecursive(cleanPath, showHidden, maxDepth, 0)
if err != nil {
return nil, err
}
return &FileListResult{
BasePath: cleanPath,
Files: files,
Count: len(files),
}, nil
}
func (f *FileList) listDirectory(path string, showHidden bool) ([]FileInfo, error) {
entries, err := os.ReadDir(path)
if err != nil {
return nil, fmt.Errorf("error reading directory: %w", err)
}
var files []FileInfo
for _, entry := range entries {
name := entry.Name()
// Skip hidden files if not requested
if !showHidden && strings.HasPrefix(name, ".") {
continue
}
fileInfo := FileInfo{
RelPath: name,
IsDir: entry.IsDir(),
}
files = append(files, fileInfo)
}
return files, nil
}
func (f *FileList) listRecursive(path string, showHidden bool, maxDepth, currentDepth int) ([]FileInfo, error) {
if currentDepth >= maxDepth {
return nil, nil
}
files, err := f.listDirectory(path, showHidden)
if err != nil {
return nil, err
}
var allFiles []FileInfo
for _, file := range files {
// For the first level, use the file name as is
// For deeper levels, join with parent directory
if currentDepth != 0 {
// Get the relative part of the path by removing the base path
rel, err := filepath.Rel(filepath.Dir(path), path)
if err == nil {
file.RelPath = filepath.Join(rel, file.RelPath)
}
}
allFiles = append(allFiles, file)
if file.IsDir {
subFiles, err := f.listRecursive(filepath.Join(path, file.RelPath), showHidden, maxDepth, currentDepth+1)
if err != nil {
continue // Skip directories we can't read
}
allFiles = append(allFiles, subFiles...)
}
}
return allFiles, nil
}
// FileWriter implements the file writing functionality
// TODO(parthsareen): max file size limit
type FileWriter struct {
workingDir string
}
func (f *FileWriter) SetWorkingDir(dir string) {
f.workingDir = dir
}
func (f *FileWriter) Name() string {
return "file_write"
}
func (f *FileWriter) Description() string {
return "Write content to a file on the file system"
}
func (f *FileWriter) Prompt() string {
return `Use the file_write tool to write content to a file using the path parameter`
}
func (f *FileWriter) Schema() map[string]any {
schemaBytes := []byte(`{
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The path to the file to write"
},
"content": {
"type": "string",
"description": "The content to write to the file"
},
"append": {
"type": "boolean",
"description": "Whether to append to the file instead of overwriting (default: false)",
"default": false
},
"create_dirs": {
"type": "boolean",
"description": "Whether to create parent directories if they don't exist (default: false)",
"default": false
},
"max_size": {
"type": "integer",
"description": "Maximum content size to write in bytes (default: 1MB)",
"default": 1024 * 1024
}
},
"required": ["path", "content"]
}`)
var schema map[string]any
if err := json.Unmarshal(schemaBytes, &schema); err != nil {
return nil
}
return schema
}
func (f *FileWriter) Execute(ctx context.Context, args map[string]any) (any, error) {
path, ok := args["path"].(string)
if !ok {
return nil, fmt.Errorf("path parameter is required and must be a string")
}
// If path is not absolute and working directory is set, make it relative to working directory
if !filepath.IsAbs(path) && f.workingDir != "" {
path = filepath.Join(f.workingDir, path)
}
// Extract required parameters
content, ok := args["content"].(string)
if !ok {
return nil, fmt.Errorf("content parameter is required and must be a string")
}
// Get optional parameters with defaults
append := true // Always append by default
if a, ok := args["append"].(bool); ok && !a {
return nil, fmt.Errorf("overwriting existing files is not allowed - must use append mode")
}
createDirs := false
if cd, ok := args["create_dirs"].(bool); ok {
createDirs = cd
}
maxSize := int64(1024 * 1024) // 1MB default
if ms, ok := args["max_size"].(float64); ok {
maxSize = int64(ms)
}
// Security: Clean and validate the path
cleanPath := filepath.Clean(path)
if strings.Contains(cleanPath, "..") {
return nil, fmt.Errorf("path traversal not allowed")
}
// Check content size
if int64(len(content)) > maxSize {
return nil, fmt.Errorf("content too large (%d bytes), maximum allowed: %d bytes", len(content), maxSize)
}
// Create parent directories if requested
if createDirs {
dir := filepath.Dir(cleanPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create parent directories: %w", err)
}
}
// Check if file exists - if it does, we must append
fileInfo, err := os.Stat(cleanPath)
if err == nil && fileInfo.Size() > 0 {
// File exists and has content
if !append {
return nil, fmt.Errorf("file %s already exists - cannot overwrite, must use append mode", cleanPath)
}
}
// Open file in append mode
flag := os.O_WRONLY | os.O_CREATE | os.O_APPEND
file, err := os.OpenFile(cleanPath, flag, 0644)
if err != nil {
return nil, fmt.Errorf("error opening file for writing: %w", err)
}
defer file.Close()
// Write content
n, err := file.WriteString(content)
if err != nil {
return nil, fmt.Errorf("error writing to file: %w", err)
}
// Get file info for response
info, err := file.Stat()
if err != nil {
// Return basic success info if we can't get file stats
return &FileWriteResult{
Path: cleanPath,
Written: n,
}, nil
}
return &FileWriteResult{
Path: cleanPath,
Size: info.Size(),
Written: n,
Mode: info.Mode().String(),
Modified: info.ModTime().Unix(),
}, nil
}
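A hedged end-to-end sketch of the three tools added here, invoked directly. In the app the working directory comes from settings and calls are routed through the tool registry; the path and import path below are illustrative assumptions:

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/app/tools" // assumed import path
)

func main() {
    ctx := context.Background()
    workDir := "/tmp/demo" // illustrative working directory

    writer := &tools.FileWriter{}
    writer.SetWorkingDir(workDir)
    // Existing files are never overwritten; content is appended, and
    // create_dirs makes the parent directory if it is missing.
    if _, err := writer.Execute(ctx, map[string]any{
        "path": "notes.txt", "content": "hello\n", "create_dirs": true,
    }); err != nil {
        log.Fatal(err)
    }

    lister := &tools.FileList{}
    lister.SetWorkingDir(workDir)
    listing, err := lister.Execute(ctx, map[string]any{"path": ".", "depth": float64(2)})
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("%d entries\n", listing.(*tools.FileListResult).Count)

    reader := &tools.FileReader{}
    reader.SetWorkingDir(workDir)
    out, err := reader.Execute(ctx, map[string]any{"path": "notes.txt"})
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(out.(*tools.FileReadResult).Content)
}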

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package tools
import (

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package tools
import (
@@ -68,24 +66,15 @@ func (w *WebFetch) Execute(ctx context.Context, args map[string]any) (any, strin
return nil, "", fmt.Errorf("url must be a non-empty string")
}
result, err := performWebFetch(ctx, urlStr)
if err != nil {
return nil, "", err
}
return result, "", nil
}
func performWebFetch(ctx context.Context, targetURL string) (*FetchResponse, error) {
reqBody := FetchRequest{URL: targetURL}
reqBody := FetchRequest{URL: urlStr}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
return nil, "", fmt.Errorf("failed to marshal request body: %w", err)
}
crawlURL, err := url.Parse("https://ollama.com/api/web_fetch")
if err != nil {
return nil, fmt.Errorf("failed to parse fetch URL: %w", err)
return nil, "", fmt.Errorf("failed to parse fetch URL: %w", err)
}
query := crawlURL.Query()
@@ -95,12 +84,12 @@ func performWebFetch(ctx context.Context, targetURL string) (*FetchResponse, err
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, crawlURL.RequestURI())
signature, err := auth.Sign(ctx, data)
if err != nil {
return nil, fmt.Errorf("failed to sign request: %w", err)
return nil, "", fmt.Errorf("failed to sign request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, crawlURL.String(), bytes.NewBuffer(jsonBody))
req, err := http.NewRequestWithContext(ctx, "POST", crawlURL.String(), bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
@@ -111,18 +100,18 @@ func performWebFetch(ctx context.Context, targetURL string) (*FetchResponse, err
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute fetch request: %w", err)
return nil, "", fmt.Errorf("failed to execute fetch request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("fetch API error (status %d)", resp.StatusCode)
return nil, "", fmt.Errorf("fetch API error (status %d)", resp.StatusCode)
}
var result FetchResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
return nil, "", fmt.Errorf("failed to decode response: %w", err)
}
return &result, nil
}
return &result, "", nil
}

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package tools
import (
@@ -69,6 +67,7 @@ func (g *WebSearch) Schema() map[string]any {
}
func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
rawQuery, ok := args["query"]
if !ok {
return nil, "", fmt.Errorf("query parameter is required")
@@ -84,25 +83,15 @@ func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, stri
maxResults = int(v)
}
result, err := performWebSearch(ctx, queryStr, maxResults)
if err != nil {
return nil, "", err
}
return result, "", nil
}
func performWebSearch(ctx context.Context, query string, maxResults int) (*SearchResponse, error) {
reqBody := SearchRequest{Query: query, MaxResults: maxResults}
reqBody := SearchRequest{Query: queryStr, MaxResults: maxResults}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
return nil, "", fmt.Errorf("failed to marshal request body: %w", err)
}
searchURL, err := url.Parse("https://ollama.com/api/web_search")
if err != nil {
return nil, fmt.Errorf("failed to parse search URL: %w", err)
return nil, "", fmt.Errorf("failed to parse search URL: %w", err)
}
q := searchURL.Query()
@@ -112,15 +101,14 @@ func performWebSearch(ctx context.Context, query string, maxResults int) (*Searc
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, searchURL.RequestURI())
signature, err := auth.Sign(ctx, data)
if err != nil {
return nil, fmt.Errorf("failed to sign request: %w", err)
return nil, "", fmt.Errorf("failed to sign request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, searchURL.String(), bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
return nil, "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Type", "application/json")
if signature != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
}
@@ -128,18 +116,18 @@ func performWebSearch(ctx context.Context, query string, maxResults int) (*Searc
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute search request: %w", err)
return nil, "", fmt.Errorf("failed to execute search request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("search API error (status %d)", resp.StatusCode)
return nil, "", fmt.Errorf("search API error (status %d)", resp.StatusCode)
}
var result SearchResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
return nil, "", fmt.Errorf("failed to decode response: %w", err)
}
return &result, nil
return &result, "", nil
}

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package not
import (

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package not
import (

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package not_test
import (

View File

@@ -1,5 +1,3 @@
//go:build windows || darwin
package ui
import (
@@ -39,6 +37,7 @@ func (s *Server) appHandler() http.Handler {
}
return
}
http.ServeContent(w, r, "index.html", time.Time{}, bytes.NewReader(data))
})
}

View File

@@ -0,0 +1,26 @@
import type { StorybookConfig } from "@storybook/react-vite";
const config: StorybookConfig = {
stories: ["../src/**/*.mdx", "../src/**/*.stories.@(js|jsx|mjs|ts|tsx)"],
addons: [
"@chromatic-com/storybook",
"@storybook/addon-docs",
"@storybook/addon-onboarding",
"@storybook/addon-a11y",
"@storybook/addon-vitest",
],
framework: {
name: "@storybook/react-vite",
options: {},
},
typescript: {
reactDocgen: "react-docgen-typescript",
reactDocgenTypescriptOptions: {
tsconfigPath: "../tsconfig.stories.json",
},
},
core: {
disableTelemetry: true,
},
};
export default config;

View File

@@ -0,0 +1,22 @@
import type { Preview } from "@storybook/react-vite";
import "../src/index.css";
const preview: Preview = {
parameters: {
controls: {
matchers: {
color: /(background|color)$/i,
date: /Date$/i,
},
},
a11y: {
// 'todo' - show a11y violations in the test UI only
// 'error' - fail CI on a11y violations
// 'off' - skip a11y checks entirely
test: "todo",
},
},
};
export default preview;

View File

@@ -0,0 +1,7 @@
import * as a11yAddonAnnotations from "@storybook/addon-a11y/preview";
import { setProjectAnnotations } from "@storybook/react-vite";
import * as projectAnnotations from "./preview";
// This is an important step to apply the right configuration when testing your stories.
// More info at: https://storybook.js.org/docs/api/portable-stories/portable-stories-vitest#setprojectannotations
setProjectAnnotations([a11yAddonAnnotations, projectAnnotations]);

View File

@@ -217,12 +217,14 @@ export class Model {
model: string;
digest?: string;
modified_at?: Time;
needs_download?: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.model = source["model"];
this.digest = source["digest"];
this.modified_at = this.convertValues(source["modified_at"], Time);
this.needs_download = source["needs_download"];
}
convertValues(a: any, classs: any, asMap: boolean = false): any {
@@ -469,24 +471,26 @@ export class HealthResponse {
}
export class User {
id: string;
email: string;
name: string;
bio?: string;
avatarurl?: string;
firstname?: string;
lastname?: string;
plan?: string;
email: string;
avatarURL: string;
plan: string;
bio: string;
firstName: string;
lastName: string;
overThreshold: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.id = source["id"];
this.email = source["email"];
this.name = source["name"];
this.bio = source["bio"];
this.avatarurl = source["avatarurl"];
this.firstname = source["firstname"];
this.lastname = source["lastname"];
this.email = source["email"];
this.avatarURL = source["avatarURL"];
this.plan = source["plan"];
this.bio = source["bio"];
this.firstName = source["firstName"];
this.lastName = source["lastName"];
this.overThreshold = source["overThreshold"];
}
}
export class Attachment {

View File

@@ -11,31 +11,18 @@
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
<script>
// Add selectFiles method if available
if (typeof window.selectFiles === "function") {
window.webview = window.webview || {};
// Single file selection (returns first file or null)
window.webview.selectFile = function () {
return new Promise((resolve) => {
window.__selectFilesCallback = (data) => {
window.__selectFilesCallback = null;
// For single file, return first file or null
resolve(data && data.length > 0 ? data[0] : null);
};
window.selectFiles();
});
};
// Multiple file selection (returns array or null)
window.webview.selectMultipleFiles = function () {
return new Promise((resolve) => {
window.__selectFilesCallback = (data) => {
window.__selectFilesCallback = null;
resolve(data); // Returns array of files or null if cancelled
};
window.selectFiles();
});
// Initialize webview API object if individual functions are available
if (typeof window.selectFile === "function") {
window.webview = {
selectFile: function () {
return new Promise((resolve) => {
window.__selectFileCallback = (data) => {
window.__selectFileCallback = null;
resolve(data); // Returns file data or null if cancelled
};
window.selectFile();
});
},
};
}

View File

File diff suppressed because it is too large

View File

@@ -26,7 +26,6 @@
"framer-motion": "^12.17.0",
"katex": "^0.16.22",
"micromark-extension-llm-math": "^3.1.0",
"ollama": "^0.6.0",
"react": "^19.1.0",
"react-dom": "^19.1.0",
"rehype-katex": "^7.0.1",
@@ -34,7 +33,6 @@
"rehype-raw": "^7.0.0",
"rehype-sanitize": "^6.0.0",
"remark-math": "^6.0.0",
"streamdown": "^1.4.0",
"unist-builder": "^4.0.0",
"unist-util-parents": "^3.0.0"
},

View File

@@ -4,6 +4,7 @@ import {
ChatEvent,
DownloadEvent,
ErrorEvent,
ModelsResponse,
InferenceCompute,
InferenceComputeResponse,
ModelCapabilitiesResponse,
@@ -13,9 +14,6 @@ import {
User,
} from "@/gotypes";
import { parseJsonlFromResponse } from "./util/jsonl-parsing";
import { ollamaClient as ollama } from "./lib/ollama-client";
import type { ModelResponse } from "ollama/browser";
import { API_BASE, OLLAMA_DOT_COM } from "./lib/config";
// Extend Model class with utility methods
declare module "@/gotypes" {
@@ -27,6 +25,9 @@ declare module "@/gotypes" {
Model.prototype.isCloud = function (): boolean {
return this.model.endsWith("cloud");
};
const API_BASE = import.meta.env.DEV ? "http://127.0.0.1:3001" : "";
// Helper function to convert Uint8Array to base64
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
@@ -41,50 +42,44 @@ function uint8ArrayToBase64(uint8Array: Uint8Array): string {
}
export async function fetchUser(): Promise<User | null> {
const response = await fetch(`${API_BASE}/api/me`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
try {
const response = await fetch(`${API_BASE}/api/v1/me`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (response.ok) {
const userData: User = await response.json();
if (userData.avatarurl && !userData.avatarurl.startsWith("http")) {
userData.avatarurl = `${OLLAMA_DOT_COM}${userData.avatarurl}`;
if (response.ok) {
const userData: User = await response.json();
return userData;
}
return userData;
}
if (response.status === 401 || response.status === 403) {
return null;
} catch (error) {
console.error("Error fetching user:", error);
return null;
}
throw new Error(`Failed to fetch user: ${response.status}`);
}
export async function fetchConnectUrl(): Promise<string> {
const response = await fetch(`${API_BASE}/api/me`, {
method: "POST",
const response = await fetch(`${API_BASE}/api/v1/connect`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (response.status === 401) {
const data = await response.json();
if (data.signin_url) {
return data.signin_url;
}
if (!response.ok) {
throw new Error("Failed to fetch connect URL");
}
throw new Error("Failed to fetch connect URL");
const data = await response.json();
return data.connect_url;
}
export async function disconnectUser(): Promise<void> {
const response = await fetch(`${API_BASE}/api/signout`, {
const response = await fetch(`${API_BASE}/api/v1/disconnect`, {
method: "POST",
headers: {
"Content-Type": "application/json",
@@ -109,86 +104,38 @@ export async function getChat(chatId: string): Promise<ChatResponse> {
}
export async function getModels(query?: string): Promise<Model[]> {
try {
const { models: modelsResponse } = await ollama.list();
let models: Model[] = modelsResponse
.filter((m: ModelResponse) => {
const families = m.details?.families;
if (!families || families.length === 0) {
return true;
}
const isBertOnly = families.every((family: string) =>
family.toLowerCase().includes("bert"),
);
return !isBertOnly;
})
.map((m: ModelResponse) => {
// Remove the latest tag from the returned model
const modelName = m.name.replace(/:latest$/, "");
return new Model({
model: modelName,
digest: m.digest,
modified_at: m.modified_at ? new Date(m.modified_at) : undefined,
});
});
// Filter by query if provided
if (query) {
const normalizedQuery = query.toLowerCase().trim();
const filteredModels = models.filter((m: Model) => {
return m.model.toLowerCase().startsWith(normalizedQuery);
});
let exactMatch = false;
for (const m of filteredModels) {
if (m.model.toLowerCase() === normalizedQuery) {
exactMatch = true;
break;
}
}
// Add query if it's in the registry and not already in the list
if (!exactMatch) {
const result = await getModelUpstreamInfo(new Model({ model: query }));
const existsUpstream = !!result.digest && !result.error;
if (existsUpstream) {
filteredModels.push(new Model({ model: query }));
}
}
models = filteredModels;
}
return models;
} catch (err) {
throw new Error(`Failed to fetch models: ${err}`);
const params = new URLSearchParams();
if (query) {
params.append("q", query);
}
const response = await fetch(
`${API_BASE}/api/v1/models?${params.toString()}`,
);
if (!response.ok) {
throw new Error(`Failed to fetch models: ${response.statusText}`);
}
const data = await response.json();
const modelsResponse = new ModelsResponse(data);
return modelsResponse.models || [];
}
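A hypothetical call site for the rewritten getModels, which now proxies to GET /api/v1/models?q=<query> per the code above (the function below is illustrative, not part of this changeset):

// Hypothetical call site; not part of this changeset.
async function logMatchingModels(prefix: string): Promise<void> {
  const models = await getModels(prefix);
  for (const m of models) {
    console.log(m.model, m.digest ?? "(not downloaded)");
  }
}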
export async function getModelCapabilities(
modelName: string,
): Promise<ModelCapabilitiesResponse> {
try {
const showResponse = await ollama.show({ model: modelName });
return new ModelCapabilitiesResponse({
capabilities: Array.isArray(showResponse.capabilities)
? showResponse.capabilities
: [],
});
} catch (error) {
// Model might not be downloaded yet, return empty capabilities
console.error(`Failed to get capabilities for ${modelName}:`, error);
return new ModelCapabilitiesResponse({ capabilities: [] });
const response = await fetch(
`${API_BASE}/api/v1/model/${encodeURIComponent(modelName)}/capabilities`,
);
if (!response.ok) {
throw new Error(
`Failed to fetch model capabilities: ${response.statusText}`,
);
}
}
const data = await response.json();
return new ModelCapabilitiesResponse(data);
}
export type ChatEventUnion = ChatEvent | DownloadEvent | ErrorEvent;
export async function* sendMessage(
@@ -209,11 +156,6 @@ export async function* sendMessage(
data: uint8ArrayToBase64(att.data),
}));
// Send think parameter when it's explicitly set (true, false, or a non-empty string).
const shouldSendThink =
think !== undefined &&
(typeof think === "boolean" || (typeof think === "string" && think !== ""));
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
method: "POST",
headers: {
@@ -231,7 +173,7 @@ export async function* sendMessage(
web_search: webSearch ?? false,
file_tools: fileTools ?? false,
...(forceUpdate !== undefined ? { forceUpdate } : {}),
...(shouldSendThink ? { think } : {}),
...(think !== undefined ? { think } : {}),
}),
),
signal,
@@ -394,8 +336,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
export async function fetchHealth(): Promise<boolean> {
try {
// Use the /api/version endpoint as a health check
const response = await fetch(`${API_BASE}/api/version`, {
const response = await fetch(`${API_BASE}/api/v1/health`, {
method: "GET",
headers: {
"Content-Type": "application/json",
@@ -404,8 +345,7 @@ export async function fetchHealth(): Promise<boolean> {
if (response.ok) {
const data = await response.json();
// If we get a version back, the server is healthy
return !!data.version;
return data.healthy || false;
}
return false;

View File

@@ -17,16 +17,11 @@ import {
} from "@/hooks/useChats";
import { useNavigate } from "@tanstack/react-router";
import { useSelectedModel } from "@/hooks/useSelectedModel";
import { useHasVisionCapability } from "@/hooks/useModelCapabilities";
import { useUser } from "@/hooks/useUser";
import { DisplayLogin } from "@/components/DisplayLogin";
import { ErrorEvent, Message } from "@/gotypes";
import { useSettings } from "@/hooks/useSettings";
import { ThinkButton } from "./ThinkButton";
import { ErrorMessage } from "./ErrorMessage";
import { processFiles } from "@/utils/fileValidation";
import type { ImageData } from "@/types/webview";
import { PlusIcon } from "@heroicons/react/24/outline";
export type ThinkingLevel = "low" | "medium" | "high";
@@ -109,14 +104,10 @@ function ChatForm({
const cancelMessage = useCancelMessage();
const isDownloading = isDownloadingModel;
const { selectedModel } = useSelectedModel();
const hasVisionCapability = useHasVisionCapability(selectedModel?.model);
const { isAuthenticated, isLoading: isLoadingUser } = useUser();
const [loginPromptFeature, setLoginPromptFeature] = useState<
"webSearch" | "turbo" | null
>(null);
const [fileUploadError, setFileUploadError] = useState<ErrorEvent | null>(
null,
);
const handleThinkingLevelDropdownToggle = (isOpen: boolean) => {
if (
@@ -168,18 +159,6 @@ function ChatForm({
const supportsThinkToggling =
selectedModel?.model.toLowerCase().startsWith("deepseek-v3.1") || false;
useEffect(() => {
if (supportsThinkToggling && thinkEnabled && webSearchEnabled) {
setSettings({ WebSearchEnabled: false });
}
}, [
selectedModel?.model,
supportsThinkToggling,
thinkEnabled,
webSearchEnabled,
setSettings,
]);
const removeFile = (index: number) => {
setMessage((prev) => ({
...prev,
@@ -200,9 +179,8 @@ function ChatForm({
files: Array<{ filename: string; data: Uint8Array; type?: string }>,
errors: Array<{ filename: string; error: string }> = [],
) => {
// Add valid files to form state
if (files.length > 0) {
setFileUploadError(null);
const newAttachments = files.map((file) => ({
id: crypto.randomUUID(),
filename: file.filename,
@@ -479,11 +457,15 @@ function ChatForm({
);
const useWebSearch = supportsWebSearch && webSearchEnabled && !airplaneMode;
const useThink = modelSupportsThinkingLevels
? thinkLevel
: supportsThinkToggling
? thinkEnabled
: undefined;
const useThink = (() => {
if (modelSupportsThinkingLevels) {
return thinkLevel;
} else if (supportsThinkToggling && thinkEnabled) {
return true;
}
return undefined;
})();
if (onSubmit) {
onSubmit(message.content, {
@@ -621,62 +603,6 @@ function ChatForm({
e.target.style.height = Math.min(e.target.scrollHeight, 24 * 8) + "px";
};
const handleFilesUpload = async () => {
try {
setFileUploadError(null);
const results = await window.webview?.selectMultipleFiles();
if (results && results.length > 0) {
// Convert native dialog results to File objects
const files = results
.map((result: ImageData) => {
if (result.dataURL) {
// Convert dataURL back to File object
const base64Data = result.dataURL.split(",")[1];
const mimeType = result.dataURL.split(";")[0].split(":")[1];
const binaryString = atob(base64Data);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
const blob = new Blob([bytes], { type: mimeType });
const file = new File([blob], result.filename, {
type: mimeType,
});
return file;
}
return null;
})
.filter(Boolean) as File[];
if (files.length > 0) {
const { validFiles, errors } = await processFiles(files, {
selectedModel,
hasVisionCapability,
});
// Send processed files and errors to the same handler as FileUpload
if (validFiles.length > 0 || errors.length > 0) {
handleFilesReceived(validFiles, errors);
}
}
}
} catch (error) {
console.error("Error selecting multiple files:", error);
const errorEvent = new ErrorEvent({
eventName: "error" as const,
error:
error instanceof Error ? error.message : "Failed to select files",
code: "file_selection_error",
details:
"An error occurred while trying to open the file selection dialog. Please try again.",
});
setFileUploadError(errorEvent);
}
};
return (
<div className={`pb-3 px-3 ${hasMessages ? "mt-auto" : "my-auto"}`}>
{chatId === "new" && <Logo />}
@@ -707,8 +633,6 @@ function ChatForm({
/>
)}
{/* File upload error message */}
{fileUploadError && <ErrorMessage error={fileUploadError} />}
<div
className={`relative mx-auto flex bg-neutral-100 w-full max-w-[768px] flex-col items-center rounded-3xl pb-2 pt-4 dark:bg-neutral-800 dark:border-neutral-700 min-h-[88px] transition-opacity duration-200 ${isDisabled ? "opacity-50" : "opacity-100"}`}
>
@@ -847,78 +771,67 @@ function ChatForm({
{/* Controls */}
<div className="flex w-full items-center justify-end gap-2 px-3 pt-2">
{/* Tool buttons - animate from underneath model picker */}
{!isDisabled && (
<div className="flex-1 flex justify-end items-center gap-2">
<div className={`flex gap-2`}>
{/* File Upload Buttons */}
<button
type="button"
onClick={handleFilesUpload}
className="flex h-9 w-9 items-center justify-center rounded-full bg-white dark:bg-neutral-700 focus:outline-none focus:ring-2 focus:ring-blue-500 cursor-pointer border border-transparent"
title="Upload multiple files"
>
<PlusIcon className="w-4.5 h-4.5 stroke-2 text-neutral-500 dark:text-neutral-400" />
</button>
{/* Thinking Level Button */}
{modelSupportsThinkingLevels && (
<>
<ThinkButton
mode="thinkingLevel"
ref={thinkingLevelButtonRef}
isVisible={modelSupportsThinkingLevels}
currentLevel={thinkLevel}
onLevelChange={setThinkingLevel}
onDropdownToggle={handleThinkingLevelDropdownToggle}
/>
</>
)}
{/* Think Button turn on and off */}
{supportsThinkToggling && !modelSupportsThinkingLevels && (
<>
<ThinkButton
mode="think"
ref={thinkButtonRef}
isVisible={
supportsThinkToggling && !modelSupportsThinkingLevels
<div className="flex-1 flex justify-end items-center gap-2">
<div className={`flex gap-2`}>
{/* Thinking Level Button */}
{modelSupportsThinkingLevels && (
<>
<ThinkButton
mode="thinkingLevel"
ref={thinkingLevelButtonRef}
isVisible={modelSupportsThinkingLevels}
currentLevel={thinkLevel}
onLevelChange={setThinkingLevel}
onDropdownToggle={handleThinkingLevelDropdownToggle}
/>
</>
)}
{/* Think Button turn on and off */}
{supportsThinkToggling && !modelSupportsThinkingLevels && (
<>
<ThinkButton
mode="think"
ref={thinkButtonRef}
isVisible={
supportsThinkToggling && !modelSupportsThinkingLevels
}
isActive={thinkEnabled}
onToggle={() => {
// DeepSeek-v3 specific - thinking and web search are mutually exclusive
if (supportsThinkToggling) {
const enable = !thinkEnabled;
setSettings({
ThinkEnabled: enable,
...(enable ? { WebSearchEnabled: false } : {}),
});
return;
}
isActive={thinkEnabled}
onToggle={() => {
// DeepSeek-v3 specific - thinking and web search are mutually exclusive
if (supportsThinkToggling) {
const enable = !thinkEnabled;
setSettings({
ThinkEnabled: enable,
...(enable ? { WebSearchEnabled: false } : {}),
});
return;
}
setSettings({ ThinkEnabled: !thinkEnabled });
}}
/>
</>
)}
<WebSearchButton
ref={webSearchButtonRef}
isVisible={supportsWebSearch && airplaneMode === false}
isActive={webSearchEnabled}
onToggle={() => {
if (!webSearchEnabled && !isAuthenticated) {
setLoginPromptFeature("webSearch");
}
const enable = !webSearchEnabled;
if (supportsThinkToggling && enable) {
setSettings({
WebSearchEnabled: true,
ThinkEnabled: false,
});
return;
}
setSettings({ WebSearchEnabled: enable });
}}
/>
</div>
setSettings({ ThinkEnabled: !thinkEnabled });
}}
/>
</>
)}
<WebSearchButton
ref={webSearchButtonRef}
isVisible={supportsWebSearch && airplaneMode === false}
isActive={webSearchEnabled}
onToggle={() => {
if (!webSearchEnabled && !isAuthenticated) {
setLoginPromptFeature("webSearch");
}
const enable = !webSearchEnabled;
if (supportsThinkToggling && enable) {
setSettings({
WebSearchEnabled: true,
ThinkEnabled: false,
});
return;
}
setSettings({ WebSearchEnabled: enable });
}}
/>
</div>
)}
</div>
{/* Model picker and submit button */}
<div className="flex items-center gap-2 relative z-20">

View File

@@ -0,0 +1,42 @@
import { type JSX } from "react";
interface FileToolsButtonProps {
enabled: boolean;
active: boolean;
onToggle: (active: boolean) => void;
}
export default function FileToolsButton({
enabled,
active,
onToggle,
}: FileToolsButtonProps): JSX.Element | null {
if (!enabled) return null;
return (
<button
type="button"
onClick={() => onToggle(!active)}
title="Toggle File Tools"
className={`flex h-9 w-9 items-center justify-center rounded-full bg-white dark:bg-neutral-700 focus:outline-none transition-all cursor-pointer border border-transparent ${
active
? "text-[rgba(0,115,255,1)]"
: "text-neutral-800 dark:text-neutral-100"
}`}
>
<svg
className="h-4 w-4"
fill="none"
stroke="currentColor"
strokeWidth="2"
viewBox="0 0 24 24"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
d="M3 7v10a2 2 0 002 2h14a2 2 0 002-2V9a2 2 0 00-2-2h-6l-2-2H5a2 2 0 00-2 2z"
/>
</svg>
</button>
);
}
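A hypothetical call site for FileToolsButton — the fileTools state name and handler are illustrative; the real wiring is not shown in this diff:

import { useState } from "react";
import FileToolsButton from "./FileToolsButton";

function Toolbar() {
  // Illustrative local state only.
  const [fileTools, setFileTools] = useState(false);
  return (
    <FileToolsButton
      enabled={true}
      active={fileTools}
      onToggle={(next) => setFileTools(next)}
    />
  );
}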

View File

@@ -7,7 +7,48 @@ import {
} from "react";
import { DocumentPlusIcon } from "@heroicons/react/24/outline";
import type { Model } from "@/gotypes";
import { processFiles as processFilesUtil } from "@/utils/fileValidation";
const TEXT_FILE_EXTENSIONS = [
"pdf",
"docx",
"txt",
"md",
"csv",
"json",
"xml",
"html",
"htm",
"js",
"jsx",
"ts",
"tsx",
"py",
"java",
"cpp",
"c",
"cc",
"h",
"cs",
"php",
"rb",
"go",
"rs",
"swift",
"kt",
"scala",
"sh",
"bat",
"yaml",
"yml",
"toml",
"ini",
"cfg",
"conf",
"log",
"rtf",
];
const IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"];
interface FileUploadProps {
children: ReactNode;
@@ -36,11 +77,30 @@ export function FileUpload({
// Prevents flickering when dragging over child elements within the component
const dragCounter = useRef(0);
const MAX_FILE_SIZE = maxFileSize * 1024 * 1024; // Convert MB to bytes
const ALLOWED_EXTENSIONS = allowedExtensions || [
...TEXT_FILE_EXTENSIONS,
...IMAGE_EXTENSIONS,
];
// Helper function to check if dragging files
const hasFiles = useCallback((dataTransfer: DataTransfer) => {
return dataTransfer.types.includes("Files");
}, []);
// Helper function to read file as Uint8Array
const readFileAsBytes = useCallback((file: File): Promise<Uint8Array> => {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = () => {
const arrayBuffer = reader.result as ArrayBuffer;
resolve(new Uint8Array(arrayBuffer));
};
reader.onerror = () => reject(reader.error);
reader.readAsArrayBuffer(file);
});
}, []);
// Helper function to read directory contents
const readDirectory = useCallback(
async (entry: FileSystemDirectoryEntry): Promise<File[]> => {
@@ -84,6 +144,12 @@ export function FileUpload({
// Main file processing function
const processFiles = useCallback(
async (dataTransfer: DataTransfer) => {
const attachments: Array<{
filename: string;
data: Uint8Array;
type?: string;
}> = [];
const errors: Array<{ filename: string; error: string }> = [];
const allFiles: File[] = [];
// Extract files from DataTransfer
@@ -105,26 +171,83 @@ export function FileUpload({
allFiles.push(...Array.from(dataTransfer.files));
}
// Use shared validation utility
const { validFiles, errors } = await processFilesUtil(allFiles, {
maxFileSize,
allowedExtensions,
hasVisionCapability,
selectedModel,
customValidator: validateFile,
});
// First pass: Check file sizes and types
const validFiles: File[] = [];
for (const file of allFiles) {
const fileExtension = file.name.toLowerCase().split(".").pop();
// Custom validation first
if (validateFile) {
const validation = validateFile(file);
if (!validation.valid) {
errors.push({
filename: file.name,
error: validation.error || "File validation failed",
});
continue;
}
}
// Default validation
if (!fileExtension) {
errors.push({
filename: file.name,
error: "File type not supported",
});
} else if (
IMAGE_EXTENSIONS.includes(fileExtension) &&
!hasVisionCapability
) {
errors.push({
filename: file.name,
error: "This model does not support images",
});
} else if (!ALLOWED_EXTENSIONS.includes(fileExtension)) {
errors.push({
filename: file.name,
error: "File type not supported",
});
} else if (file.size > MAX_FILE_SIZE) {
errors.push({
filename: file.name,
error: "File too large",
});
} else {
validFiles.push(file);
}
}
// Second pass: Process only valid files
for (const file of validFiles) {
try {
const fileBytes = await readFileAsBytes(file);
attachments.push({
filename: file.name,
data: fileBytes,
type: file.type || undefined,
});
} catch (error) {
console.error(`Error reading file ${file.name}:`, error);
errors.push({
filename: file.name,
error: "Error reading file",
});
}
}
// Send processed files and errors back to parent
if (validFiles.length > 0 || errors.length > 0) {
onFilesAdded(validFiles, errors);
if (attachments.length > 0 || errors.length > 0) {
onFilesAdded(attachments, errors);
}
},
[
readFileAsBytes,
readDirectory,
selectedModel,
hasVisionCapability,
allowedExtensions,
maxFileSize,
ALLOWED_EXTENSIONS,
MAX_FILE_SIZE,
validateFile,
onFilesAdded,
],

View File

@@ -613,7 +613,7 @@ function ToolCallDisplay({
return (
<div className="text-neutral-600 dark:text-neutral-400 relative select-text">
<svg
className="h-4 w-4 absolute top-1.5"
className="h-4 w-4 absolute top-1 left-5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"

View File

@@ -44,13 +44,7 @@ export const ModelPicker = forwardRef<
}>(null);
const checkModelStaleness = async (model: Model) => {
if (
!model ||
!model.model ||
model.digest === undefined ||
model.digest === ""
)
return;
if (!model || !model.model || model.needs_download) return;
// Check cache - only check staleness every 5 minutes per model
const now = Date.now();
@@ -323,7 +317,9 @@ export const ModelList = forwardRef(function ModelList(
) : (
models.map((model, index) => {
return (
<div key={`${model.model}-${model.digest || "no-digest"}-${index}`}>
<div
key={`${model.model}-${model.digest || "no-digest"}-${model.needs_download ? "download" : "local"}-${index}`}
>
<button
onClick={() => onModelSelect(model)}
onMouseEnter={() => setHighlightedIndex(index)}
@@ -347,7 +343,7 @@ export const ModelList = forwardRef(function ModelList(
<path d="M4.01511 14.5861H14.2304C16.9183 14.5861 19.0002 12.5509 19.0002 9.9403C19.0002 7.30491 16.8911 5.3046 14.0203 5.3046C12.9691 3.23016 11.0602 2 8.69505 2C5.62816 2 3.04822 4.32758 2.72935 7.47455C1.12954 7.95356 0.0766602 9.29431 0.0766602 10.9757C0.0766602 12.9913 1.55776 14.5861 4.01511 14.5861ZM4.02056 13.1261C2.46452 13.1261 1.53673 12.2938 1.53673 11.0161C1.53673 9.91553 2.24207 9.12934 3.51367 8.79302C3.95684 8.68258 4.11901 8.48427 4.16138 8.00729C4.39317 5.3613 6.29581 3.46007 8.69505 3.46007C10.5231 3.46007 11.955 4.48273 12.8385 6.26013C13.0338 6.65439 13.2626 6.7882 13.7488 6.7882C16.1671 6.7882 17.5337 8.19719 17.5337 9.97707C17.5337 11.7526 16.1242 13.1261 14.2852 13.1261H4.02056Z" />
</svg>
)}
{model.digest === undefined &&
{(model.needs_download || model.digest === undefined) &&
(airplaneMode || !model.isCloud()) && (
<ArrowDownTrayIcon
className="h-4 w-4 text-neutral-500 dark:text-neutral-400"

View File

@@ -299,9 +299,9 @@ export default function Settings() {
</Button>
</div>
</div>
{user?.avatarurl && (
{user?.avatarURL && (
<img
src={user.avatarurl}
src={user.avatarURL}
alt={user?.name}
className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0"
onError={(e) => {

View File

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,522 @@
import { expect, test, suite } from "vitest";
import { processStreamingMarkdown } from "@/utils/processStreamingMarkdown";
suite("common llm outputs that cause issues", () => {
test("prefix of bolded list item shouldn't make a horizontal line", () => {
// we're going to go in order of incrementally adding characters. This
// happens really commonly with LLMs that like to make lists like so:
//
// * **point 1**: explanatory text
// * **point 2**: more explanatory text
//
// Partial rendering of `*` (A), followed by `* *` (B), followed by `* **`
// (C) is a total mess. (A) renders as a single bullet point in an
// otherwise empty list, (B) renders as two nested lists (and therefore
// two bullet points, styled differently by default in html), and (C)
// renders as a horizontal line because in markdown apparently `***` or `*
// * *` horizontal rules don't have as strict whitespace rules as I
// expected them to
// these are alone (i.e., they would be the first list item)
expect(processStreamingMarkdown("*")).toBe("");
expect(processStreamingMarkdown("* *")).toBe("");
expect(processStreamingMarkdown("* **")).toBe("");
// expect(processStreamingMarkdown("* **b")).toBe("* **b**");
// with a list item before them
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"*"
].join("\n"),
),
).toBe("* abc");
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"* *"
].join("\n"),
),
).toBe("* abc");
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"* **"
].join("\n"),
),
).toBe("* abc");
});
test("bolded list items with text should be rendered properly", () => {
expect(processStreamingMarkdown("* **abc**")).toBe("* **abc**");
});
test("partially bolded list items should be autoclosed", () => {
expect(processStreamingMarkdown("* **abc")).toBe("* **abc**");
});
suite(
"partially bolded list items should be autoclosed, even if the last node isn't a text node",
() => {
test("inline code", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function `async`*"),
).toBe("* **Asynchronous Function `async`**");
});
},
);
});
suite("autoclosing bold", () => {
suite("endings with no asterisks", () => {
test("should autoclose bold", () => {
expect(processStreamingMarkdown("**abc")).toBe("**abc**");
expect(processStreamingMarkdown("abc **abc")).toBe("abc **abc**");
});
suite("should autoclose, even if the last node isn't a text node", () => {
test("inline code", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function `async`"),
).toBe("* **Asynchronous Function `async`**");
});
test("opening ** is at the end of the text", () => {
expect(processStreamingMarkdown("abc **`def` jhk [lmn](opq)")).toBe(
"abc **`def` jhk [lmn](opq)**",
);
});
test("if there's a space after the **, it should NOT be autoclosed", () => {
expect(processStreamingMarkdown("abc ** `def` jhk [lmn](opq)")).toBe(
"abc \\*\\* `def` jhk [lmn](opq)",
);
});
});
test("should autoclose bold, even if the last node isn't a text node", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function ( `async`"),
).toBe("* **Asynchronous Function ( `async`**");
});
test("whitespace fakeouts should not be modified", () => {
expect(processStreamingMarkdown("** abc")).toBe("\\*\\* abc");
});
// TODO(drifkin): arguably this should just be removed entirely, but empty
// isn't so bad
test("should handle empty bolded items", () => {
expect(processStreamingMarkdown("**")).toBe("");
});
});
suite("partially closed bolded items", () => {
test("simple partial", () => {
expect(processStreamingMarkdown("**abc*")).toBe("**abc**");
});
test("partial with non-text node at end", () => {
expect(processStreamingMarkdown("**abc`def`*")).toBe("**abc`def`**");
});
test("partial with multiply nested ending nodes", () => {
expect(processStreamingMarkdown("**abc[abc](`def`)*")).toBe(
"**abc[abc](`def`)**",
);
});
test("normal emphasis should not be affected", () => {
expect(processStreamingMarkdown("*abc*")).toBe("*abc*");
});
test("normal emphasis with nested code should not be affected", () => {
expect(processStreamingMarkdown("*`abc`*")).toBe("*`abc`*");
});
});
test.skip("shouldn't autoclose immediately if there's a space before the closing *", () => {
expect(processStreamingMarkdown("**abc *")).toBe("**abc**");
});
// skipping for now because this requires partial link completion as well
suite.skip("nested blocks that each need autoclosing", () => {
test("emph nested in link nested in strong nested in list item", () => {
expect(processStreamingMarkdown("* **[abc **def")).toBe(
"* **[abc **def**]()**",
);
});
test("* **[ab *`def`", () => {
expect(processStreamingMarkdown("* **[ab *`def`")).toBe(
"* **[ab *`def`*]()**",
);
});
});
});
suite("numbered list items", () => {
test("should remove trailing numbers", () => {
expect(processStreamingMarkdown("1. First\n2")).toBe("1. First");
});
test("should remove trailing numbers with breaks before", () => {
expect(processStreamingMarkdown("1. First \n2")).toBe("1. First");
});
test("should remove trailing numbers that form a new paragraph", () => {
expect(processStreamingMarkdown("1. First\n\n2")).toBe("1. First");
});
test("but should leave list items separated by two newlines", () => {
expect(processStreamingMarkdown("1. First\n\n2. S")).toBe(
"1. First\n\n2. S",
);
});
});
// TODO(drifkin): slop tests ahead, some are decent, but need to manually go
// through them as I implement
/*
describe("StreamingMarkdownContent - processStreamingMarkdown", () => {
describe("Ambiguous endings removal", () => {
it("should remove list markers at the end", () => {
expect(processStreamingMarkdown("Some text\n* ")).toBe("Some text");
expect(processStreamingMarkdown("Some text\n*")).toBe("Some text");
expect(processStreamingMarkdown("* Item 1\n- ")).toBe("* Item 1");
expect(processStreamingMarkdown("* Item 1\n-")).toBe("* Item 1");
expect(processStreamingMarkdown("Text\n+ ")).toBe("Text");
expect(processStreamingMarkdown("Text\n+")).toBe("Text");
expect(processStreamingMarkdown("1. First\n2. ")).toBe("1. First");
});
it("should remove heading markers at the end", () => {
expect(processStreamingMarkdown("Some text\n# ")).toBe("Some text");
expect(processStreamingMarkdown("Some text\n#")).toBe("Some text\n#"); // # without space is not removed
expect(processStreamingMarkdown("# Title\n## ")).toBe("# Title");
expect(processStreamingMarkdown("# Title\n##")).toBe("# Title\n##"); // ## without space is not removed
});
it("should remove ambiguous bold markers at the end", () => {
expect(processStreamingMarkdown("Text **")).toBe("Text ");
expect(processStreamingMarkdown("Some text\n**")).toBe("Some text");
});
it("should remove code block markers at the end", () => {
expect(processStreamingMarkdown("Text\n```")).toBe("Text");
expect(processStreamingMarkdown("```")).toBe("");
});
it("should remove single backtick at the end", () => {
expect(processStreamingMarkdown("Text `")).toBe("Text ");
expect(processStreamingMarkdown("`")).toBe("");
});
it("should remove single asterisk at the end", () => {
expect(processStreamingMarkdown("Text *")).toBe("Text ");
expect(processStreamingMarkdown("*")).toBe("");
});
it("should handle empty content", () => {
expect(processStreamingMarkdown("")).toBe("");
});
it("should handle single line removals correctly", () => {
expect(processStreamingMarkdown("* ")).toBe("");
expect(processStreamingMarkdown("# ")).toBe("");
expect(processStreamingMarkdown("**")).toBe("");
expect(processStreamingMarkdown("`")).toBe("");
});
it("shouldn't have this regexp capture group bug", () => {
expect(
processStreamingMarkdown("Here's a shopping list:\n*"),
).not.toContain("0*");
expect(processStreamingMarkdown("Here's a shopping list:\n*")).toBe(
"Here's a shopping list:",
);
});
});
describe("List markers", () => {
it("should preserve complete list items", () => {
expect(processStreamingMarkdown("* Complete item")).toBe(
"* Complete item",
);
expect(processStreamingMarkdown("- Another item")).toBe("- Another item");
expect(processStreamingMarkdown("+ Plus item")).toBe("+ Plus item");
expect(processStreamingMarkdown("1. Numbered item")).toBe(
"1. Numbered item",
);
});
it("should handle indented list markers", () => {
expect(processStreamingMarkdown(" * ")).toBe(" ");
expect(processStreamingMarkdown(" - ")).toBe(" ");
expect(processStreamingMarkdown("\t+ ")).toBe("\t");
});
});
describe("Heading markers", () => {
it("should preserve complete headings", () => {
expect(processStreamingMarkdown("# Complete Heading")).toBe(
"# Complete Heading",
);
expect(processStreamingMarkdown("## Subheading")).toBe("## Subheading");
expect(processStreamingMarkdown("### H3 Title")).toBe("### H3 Title");
});
it("should not affect # in other contexts", () => {
expect(processStreamingMarkdown("C# programming")).toBe("C# programming");
expect(processStreamingMarkdown("Issue #123")).toBe("Issue #123");
});
});
describe("Bold text", () => {
it("should close incomplete bold text", () => {
expect(processStreamingMarkdown("This is **bold text")).toBe(
"This is **bold text**",
);
expect(processStreamingMarkdown("Start **bold and more")).toBe(
"Start **bold and more**",
);
expect(processStreamingMarkdown("**just bold")).toBe("**just bold**");
});
it("should not affect complete bold text", () => {
expect(processStreamingMarkdown("**complete bold**")).toBe(
"**complete bold**",
);
expect(processStreamingMarkdown("Text **bold** more")).toBe(
"Text **bold** more",
);
});
it("should handle nested bold correctly", () => {
expect(processStreamingMarkdown("**bold** and **another")).toBe(
"**bold** and **another**",
);
});
});
describe("Italic text", () => {
it("should close incomplete italic text", () => {
expect(processStreamingMarkdown("This is *italic text")).toBe(
"This is *italic text*",
);
expect(processStreamingMarkdown("Start *italic and more")).toBe(
"Start *italic and more*",
);
});
it("should differentiate between list markers and italic", () => {
expect(processStreamingMarkdown("* Item\n* ")).toBe("* Item");
expect(processStreamingMarkdown("Some *italic text")).toBe(
"Some *italic text*",
);
expect(processStreamingMarkdown("*just italic")).toBe("*just italic*");
});
it("should not affect complete italic text", () => {
expect(processStreamingMarkdown("*complete italic*")).toBe(
"*complete italic*",
);
expect(processStreamingMarkdown("Text *italic* more")).toBe(
"Text *italic* more",
);
});
});
describe("Code blocks", () => {
it("should close incomplete code blocks", () => {
expect(processStreamingMarkdown("```javascript\nconst x = 42;")).toBe(
"```javascript\nconst x = 42;\n```",
);
expect(processStreamingMarkdown("```\ncode here")).toBe(
"```\ncode here\n```",
);
});
it("should not affect complete code blocks", () => {
expect(processStreamingMarkdown("```\ncode\n```")).toBe("```\ncode\n```");
expect(processStreamingMarkdown("```js\nconst x = 1;\n```")).toBe(
"```js\nconst x = 1;\n```",
);
});
it("should handle nested code blocks correctly", () => {
expect(processStreamingMarkdown("```\ncode\n```\n```python")).toBe(
"```\ncode\n```\n```python\n```",
);
});
it("should not process markdown inside code blocks", () => {
expect(processStreamingMarkdown("```\n* not a list\n**not bold**")).toBe(
"```\n* not a list\n**not bold**\n```",
);
});
});
describe("Inline code", () => {
it("should close incomplete inline code", () => {
expect(processStreamingMarkdown("This is `inline code")).toBe(
"This is `inline code`",
);
expect(processStreamingMarkdown("Use `console.log")).toBe(
"Use `console.log`",
);
});
it("should not affect complete inline code", () => {
expect(processStreamingMarkdown("`complete code`")).toBe(
"`complete code`",
);
expect(processStreamingMarkdown("Use `code` here")).toBe(
"Use `code` here",
);
});
it("should handle multiple inline codes correctly", () => {
expect(processStreamingMarkdown("`code` and `more")).toBe(
"`code` and `more`",
);
});
it("should not confuse inline code with code blocks", () => {
expect(processStreamingMarkdown("```\nblock\n```\n`inline")).toBe(
"```\nblock\n```\n`inline`",
);
});
});
describe("Complex streaming scenarios", () => {
it("should handle progressive streaming of a heading", () => {
const steps = [
{ input: "#", expected: "#" }, // # alone is not removed (needs space)
{ input: "# ", expected: "" },
{ input: "# H", expected: "# H" },
{ input: "# Hello", expected: "# Hello" },
];
steps.forEach(({ input, expected }) => {
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
it("should handle progressive streaming of bold text", () => {
const steps = [
{ input: "*", expected: "" },
{ input: "**", expected: "" },
{ input: "**b", expected: "**b**" },
{ input: "**bold", expected: "**bold**" },
{ input: "**bold**", expected: "**bold**" },
];
steps.forEach(({ input, expected }) => {
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
it("should handle multiline content with various patterns", () => {
const multiline = `# Title
This is a paragraph with **bold text** and *italic text*.
* Item 1
* Item 2
* `;
const expected = `# Title
This is a paragraph with **bold text** and *italic text*.
* Item 1
* Item 2`;
expect(processStreamingMarkdown(multiline)).toBe(expected);
});
it("should only fix the last line", () => {
expect(processStreamingMarkdown("# Complete\n# Another\n# ")).toBe(
"# Complete\n# Another",
);
expect(processStreamingMarkdown("* Item 1\n* Item 2\n* ")).toBe(
"* Item 1\n* Item 2",
);
});
it("should handle mixed content correctly", () => {
const input = `# Header
This has **bold** text and *italic* text.
\`\`\`js
const x = 42;
\`\`\`
Now some \`inline code\` and **unclosed bold`;
const expected = `# Header
This has **bold** text and *italic* text.
\`\`\`js
const x = 42;
\`\`\`
Now some \`inline code\` and **unclosed bold**`;
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
describe("Edge cases with escaping", () => {
it("should handle escaped asterisks (future enhancement)", () => {
// Note: Current implementation doesn't handle escaping
// This is a known limitation - escaped characters still trigger closing
expect(processStreamingMarkdown("Text \\*not italic")).toBe(
"Text \\*not italic*",
);
});
it("should handle escaped backticks (future enhancement)", () => {
// Note: Current implementation doesn't handle escaping
// This is a known limitation - escaped characters still trigger closing
expect(processStreamingMarkdown("Text \\`not code")).toBe(
"Text \\`not code`",
);
});
});
describe("Code block edge cases", () => {
it("should handle triple backticks in the middle of lines", () => {
expect(processStreamingMarkdown("Text ``` in middle")).toBe(
"Text ``` in middle\n```",
);
expect(processStreamingMarkdown("```\nText ``` in code\nmore")).toBe(
"```\nText ``` in code\nmore\n```",
);
});
it("should properly close code blocks with language specifiers", () => {
expect(processStreamingMarkdown("```typescript")).toBe(
"```typescript\n```",
);
expect(processStreamingMarkdown("```typescript\nconst x = 1")).toBe(
"```typescript\nconst x = 1\n```",
);
});
it("should remove a completely empty partial code block", () => {
expect(processStreamingMarkdown("```\n")).toBe("");
});
});
});
*/
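These tests exercise processStreamingMarkdown, whose implementation is not part of this diff. A deliberately simplified sketch of the kind of cleanup it performs — auto-closing an unbalanced ** and dropping a bare trailing list marker; illustrative only, and far from covering every case above — could look like:

// Simplified illustration only; the real processStreamingMarkdown handles many
// more cases (nested nodes, inline code, headings, code fences, escaping, ...).
function closeStreamingMarkdownSketch(markdown: string): string {
  let out = markdown;

  // Drop a bare trailing list marker such as "\n* " while streaming.
  out = out.replace(/\n(?:[*+-]|\d+\.)\s*$/, "");

  // If the number of "**" markers is odd, the last bold span is still open:
  // close it so partial output renders as bold instead of literal asterisks.
  const boldMarkers = (out.match(/\*\*/g) || []).length;
  if (boldMarkers % 2 === 1 && !/\*\*\s*$/.test(out)) {
    out += "**";
  }
  return out;
}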

View File

@@ -1,123 +1,66 @@
import React from "react";
import { Streamdown, defaultRemarkPlugins } from "streamdown";
import Markdown from "react-markdown";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
import rehypeRaw from "rehype-raw";
import rehypeSanitize, { defaultSchema } from "rehype-sanitize";
import rehypePrismPlus from "rehype-prism-plus";
import rehypeKatex from "rehype-katex";
import remarkStreamingMarkdown, {
type LastNodeInfo,
} from "@/utils/remarkStreamingMarkdown";
import type { PluggableList } from "unified";
import remarkCitationParser from "@/utils/remarkCitationParser";
import CopyButton from "./CopyButton";
import type { BundledLanguage } from "shiki";
import { highlighter } from "@/lib/highlighter";
interface StreamingMarkdownContentProps {
content: string;
isStreaming?: boolean;
size?: "sm" | "md" | "lg";
onLastNode?: (info: LastNodeInfo) => void;
browserToolResult?: any; // TODO: proper type
}
// Helper to extract text from React nodes
const extractText = (node: React.ReactNode): string => {
if (typeof node === "string") return node;
if (typeof node === "number") return String(node);
if (!node) return "";
if (React.isValidElement(node)) {
const props = node.props as any;
if (props?.children) {
return extractText(props.children as React.ReactNode);
}
}
if (Array.isArray(node)) {
return node.map(extractText).join("");
}
return "";
};
const CodeBlock = React.memo(
({ children }: React.HTMLAttributes<HTMLPreElement>) => {
// Extract code and language from children
const codeElement = children as React.ReactElement<{
className?: string;
children: React.ReactNode;
}>;
const language =
codeElement.props.className?.replace(/language-/, "") || "";
const codeText = extractText(codeElement.props.children);
({ children, className, ...props }: React.HTMLAttributes<HTMLPreElement>) => {
const extractText = React.useCallback((node: React.ReactNode): string => {
if (typeof node === "string") return node;
if (typeof node === "number") return String(node);
if (!node) return "";
// Synchronously highlight code using the pre-loaded highlighter
const tokens = React.useMemo(() => {
if (!highlighter) return null;
try {
return {
light: highlighter.codeToTokensBase(codeText, {
lang: language as BundledLanguage,
theme: "one-light" as any,
}),
dark: highlighter.codeToTokensBase(codeText, {
lang: language as BundledLanguage,
theme: "one-dark" as any,
}),
};
} catch (error) {
console.error("Failed to highlight code:", error);
return null;
if (React.isValidElement(node)) {
if (
node.props &&
typeof node.props === "object" &&
"children" in node.props
) {
return extractText(node.props.children as React.ReactNode);
}
}
}, [codeText, language]);
if (Array.isArray(node)) {
return node.map(extractText).join("");
}
return "";
}, []);
const language = className?.replace(/language-/, "") || "";
return (
<div className="relative bg-neutral-100 dark:bg-neutral-800 rounded-2xl overflow-hidden my-6">
<div className="flex select-none">
{language && (
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
{language}
</div>
)}
<div className="flex justify-between select-none">
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
{language}
</div>
<CopyButton
content={codeText}
content={extractText(children)}
showLabels={true}
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 ml-auto"
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800"
/>
</div>
{/* Light mode */}
<pre className="dark:hidden m-0 bg-neutral-100 text-sm overflow-x-auto p-4">
<code className="font-mono text-sm">
{tokens?.light
? tokens.light.map((line: any, i: number) => (
<React.Fragment key={i}>
{line.map((token: any, j: number) => (
<span
key={j}
style={{
color: token.color,
}}
>
{token.content}
</span>
))}
{i < tokens.light.length - 1 && "\n"}
</React.Fragment>
))
: codeText}
</code>
</pre>
{/* Dark mode */}
<pre className="hidden dark:block m-0 bg-neutral-800 text-sm overflow-x-auto p-4">
<code className="font-mono text-sm">
{tokens?.dark
? tokens.dark.map((line: any, i: number) => (
<React.Fragment key={i}>
{line.map((token: any, j: number) => (
<span
key={j}
style={{
color: token.color,
}}
>
{token.content}
</span>
))}
{i < tokens.dark.length - 1 && "\n"}
</React.Fragment>
))
: codeText}
</code>
<pre className={className} {...props}>
{children}
</pre>
</div>
);
@@ -125,19 +68,65 @@ const CodeBlock = React.memo(
);
const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
React.memo(({ content, isStreaming = false, size, browserToolResult }) => {
// Build the remark plugins array - keep default GFM and Math, add citations
const remarkPlugins = React.useMemo(() => {
return [
defaultRemarkPlugins.gfm,
defaultRemarkPlugins.math,
remarkCitationParser,
];
}, []);
React.memo(
({ content, isStreaming = false, size, onLastNode, browserToolResult }) => {
// Build the remark plugins array
const remarkPlugins = React.useMemo(() => {
const plugins: PluggableList = [
remarkGfm,
[remarkMath, { singleDollarTextMath: false }],
remarkCitationParser,
];
return (
<div
className={`
// Add streaming plugin when in streaming mode
if (isStreaming) {
plugins.push([remarkStreamingMarkdown, { debug: true, onLastNode }]);
}
return plugins;
}, [isStreaming, onLastNode]);
// Create a custom sanitization schema that allows math elements
const sanitizeSchema = React.useMemo(() => {
return {
...defaultSchema,
attributes: {
...defaultSchema.attributes,
span: [
...(defaultSchema.attributes?.span || []),
["className", /^katex/],
],
div: [
...(defaultSchema.attributes?.div || []),
["className", /^katex/],
],
"ol-citation": ["cursor", "start", "end"],
},
tagNames: [
...(defaultSchema.tagNames || []),
"math",
"mrow",
"mi",
"mo",
"mn",
"msup",
"msub",
"mfrac",
"mover",
"munder",
"msqrt",
"mroot",
"merror",
"mspace",
"mpadded",
"ol-citation",
],
};
}, []);
return (
<div
className={`
max-w-full
${size === "sm" ? "prose-sm" : size === "lg" ? "prose-lg" : ""}
prose
@@ -155,27 +144,7 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
prose-pre:my-0
prose-pre:max-w-full
prose-pre:pt-1
[&_table]:border-collapse
[&_table]:w-full
[&_table]:border
[&_table]:border-neutral-200
[&_table]:rounded-lg
[&_table]:overflow-hidden
[&_th]:px-3
[&_th]:py-2
[&_th]:text-left
[&_th]:font-semibold
[&_th]:border-b
[&_th]:border-r
[&_th]:border-neutral-200
[&_th:last-child]:border-r-0
[&_td]:px-3
[&_td]:py-2
[&_td]:border-r
[&_td]:border-neutral-200
[&_td:last-child]:border-r-0
[&_tbody_tr:not(:last-child)_td]:border-b
[&_code:not(pre_code)]:text-neutral-700
[&_code:not(pre_code)]:text-neutral-700
[&_code:not(pre_code)]:bg-neutral-100
[&_code:not(pre_code)]:font-normal
[&_code:not(pre_code)]:px-1.5
@@ -191,10 +160,6 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
dark:prose-strong:text-neutral-200
dark:prose-pre:text-neutral-200
dark:prose:pre:text-neutral-200
dark:[&_table]:border-neutral-700
dark:[&_thead]:bg-neutral-800
dark:[&_th]:border-neutral-700
dark:[&_td]:border-neutral-700
dark:[&_code:not(pre_code)]:text-neutral-200
dark:[&_code:not(pre_code)]:bg-neutral-800
dark:[&_code:not(pre_code)]:font-normal
@@ -202,86 +167,104 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
dark:prose-li:marker:text-neutral-300
break-words
`}
>
<StreamingMarkdownErrorBoundary
content={content}
isStreaming={isStreaming}
>
<Streamdown
parseIncompleteMarkdown={isStreaming}
isAnimating={isStreaming}
remarkPlugins={remarkPlugins}
controls={false}
components={{
pre: CodeBlock,
table: ({
children,
...props
}: React.HTMLAttributes<HTMLTableElement>) => (
<div className="overflow-x-auto max-w-full">
<table
{...props}
className="border-collapse w-full border border-neutral-200 dark:border-neutral-700 rounded-lg overflow-hidden"
>
{children}
</table>
</div>
),
// @ts-expect-error: custom citation type
"ol-citation": ({
cursor,
}: {
cursor: number;
start: number;
end: number;
}) => {
const pageStack = browserToolResult?.page_stack;
const hasValidPage = pageStack && cursor < pageStack.length;
const pageUrl = hasValidPage ? pageStack[cursor] : null;
const getPageTitle = (url: string) => {
if (url.startsWith("search_results_")) {
const searchTerm = url.substring("search_results_".length);
return `Search: ${searchTerm}`;
}
try {
const urlObj = new URL(url);
return urlObj.hostname;
} catch {
return url;
}
};
const citationElement = (
<span className="text-xs text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 rounded-full px-2 py-1 ml-1">
[{cursor}]
</span>
);
if (pageUrl && pageUrl.startsWith("http")) {
return (
<a
href={pageUrl}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center hover:opacity-80 transition-opacity no-underline"
title={getPageTitle(pageUrl)}
>
{citationElement}
</a>
);
}
return citationElement;
},
}}
<StreamingMarkdownErrorBoundary
content={content}
isStreaming={isStreaming}
>
{content}
</Streamdown>
</StreamingMarkdownErrorBoundary>
</div>
);
});
<Markdown
remarkPlugins={remarkPlugins}
rehypePlugins={
[
[rehypeRaw, { allowDangerousHtml: true }],
[rehypeSanitize, sanitizeSchema],
[rehypePrismPlus, { ignoreMissing: true }],
[
rehypeKatex,
{
errorColor: "#000000", // Black instead of red for errors
strict: false, // Be more lenient with parsing
throwOnError: false,
},
],
] as PluggableList
}
components={{
pre: CodeBlock,
table: ({
children,
...props
}: React.HTMLAttributes<HTMLTableElement>) => (
<div className="overflow-x-auto max-w-full">
<table {...props}>{children}</table>
</div>
),
// @ts-expect-error: custom type
"ol-citation": ({
cursor,
// start,
// end,
}: {
cursor: number;
start: number;
end: number;
}) => {
// Check if we have a page_stack and if the cursor is valid
const pageStack = browserToolResult?.page_stack;
const hasValidPage = pageStack && cursor < pageStack.length;
const pageUrl = hasValidPage ? pageStack[cursor] : null;
// Extract a readable title from the URL if possible
const getPageTitle = (url: string) => {
if (url.startsWith("search_results_")) {
const searchTerm = url.substring(
"search_results_".length,
);
return `Search: ${searchTerm}`;
}
// For regular URLs, try to extract domain or use full URL
try {
const urlObj = new URL(url);
return urlObj.hostname;
} catch {
// If not a valid URL, return as is
return url;
}
};
const citationElement = (
<span className="text-xs text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 rounded-full px-2 py-1 ml-1">
[{cursor}]
</span>
);
// If we have a valid page URL, wrap in a link
if (pageUrl && pageUrl.startsWith("http")) {
return (
<a
href={pageUrl}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center hover:opacity-80 transition-opacity no-underline"
title={getPageTitle(pageUrl)}
>
{citationElement}
</a>
);
}
// Otherwise, just return the citation without a link
return citationElement;
},
}}
>
{content}
</Markdown>
</StreamingMarkdownErrorBoundary>
</div>
);
},
);
interface StreamingMarkdownErrorBoundaryProps {
content: string;

View File

@@ -50,33 +50,21 @@ export default function Thinking({
// Position content to show bottom when collapsed
useEffect(() => {
if (isCollapsed && contentRef.current && wrapperRef.current) {
requestAnimationFrame(() => {
if (!contentRef.current || !wrapperRef.current) return;
const contentHeight = contentRef.current.scrollHeight;
const wrapperHeight = wrapperRef.current.clientHeight;
if (contentHeight > wrapperHeight) {
const translateY = -(contentHeight - wrapperHeight);
contentRef.current.style.transform = `translateY(${translateY}px)`;
setHasOverflow(true);
} else {
contentRef.current.style.transform = "translateY(0)";
setHasOverflow(false);
}
});
const contentHeight = contentRef.current.scrollHeight;
const wrapperHeight = wrapperRef.current.clientHeight;
if (contentHeight > wrapperHeight) {
const translateY = -(contentHeight - wrapperHeight);
contentRef.current.style.transform = `translateY(${translateY}px)`;
setHasOverflow(true);
} else {
setHasOverflow(false);
}
} else if (contentRef.current) {
contentRef.current.style.transform = "translateY(0)";
setHasOverflow(false);
}
}, [thinking, isCollapsed]);
useEffect(() => {
if (activelyThinking && wrapperRef.current && !isCollapsed) {
// When expanded and actively thinking, scroll to bottom
wrapperRef.current.scrollTop = wrapperRef.current.scrollHeight;
}
}, [thinking, activelyThinking, isCollapsed]);
const handleToggle = () => {
setIsCollapsed(!isCollapsed);
setHasUserInteracted(true);
@@ -85,9 +73,8 @@ export default function Thinking({
// Calculate max height for smooth animations
const getMaxHeight = () => {
if (isCollapsed) {
return finishedThinking ? "0px" : "12rem";
return finishedThinking ? "0px" : "12rem"; // 8rem = 128px (same as max-h-32)
}
// When expanded, use the content height or grow naturally
return contentHeight ? `${contentHeight}px` : "none";
};
@@ -144,11 +131,10 @@ export default function Thinking({
</div>
<div
ref={wrapperRef}
className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md
transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2
${isCollapsed ? "overflow-hidden" : "overflow-y-auto"}`}
className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md overflow-hidden
transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2`}
style={{
maxHeight: isCollapsed ? getMaxHeight() : undefined,
maxHeight: getMaxHeight(),
opacity: isCollapsed && finishedThinking ? 0 : 1,
}}
>

View File

@@ -0,0 +1,108 @@
import * as Headless from "@headlessui/react";
import clsx from "clsx";
import type React from "react";
import { Text } from "./text";
const sizes = {
xs: "sm:max-w-xs",
sm: "sm:max-w-sm",
md: "sm:max-w-md",
lg: "sm:max-w-lg",
xl: "sm:max-w-xl",
"2xl": "sm:max-w-2xl",
"3xl": "sm:max-w-3xl",
"4xl": "sm:max-w-4xl",
"5xl": "sm:max-w-5xl",
};
export function Alert({
size = "md",
className,
children,
...props
}: {
size?: keyof typeof sizes;
className?: string;
children: React.ReactNode;
} & Omit<Headless.DialogProps, "as" | "className">) {
return (
<Headless.Dialog {...props}>
<Headless.DialogBackdrop
transition
className="fixed inset-0 flex w-screen justify-center overflow-y-auto bg-zinc-950/15 px-2 py-2 transition duration-100 focus:outline-0 data-closed:opacity-0 data-enter:ease-out data-leave:ease-in sm:px-6 sm:py-8 lg:px-8 lg:py-16 dark:bg-zinc-950/50"
/>
<div className="fixed inset-0 w-screen overflow-y-auto pt-6 sm:pt-0">
<div className="grid min-h-full grid-rows-[1fr_auto_1fr] justify-items-center p-8 sm:grid-rows-[1fr_auto_3fr] sm:p-4">
<Headless.DialogPanel
transition
className={clsx(
className,
sizes[size],
"row-start-2 w-full rounded-2xl bg-white p-8 shadow-lg ring-1 ring-zinc-950/10 sm:rounded-2xl sm:p-6 dark:bg-zinc-900 dark:ring-white/10 forced-colors:outline",
"transition duration-100 will-change-transform data-closed:opacity-0 data-enter:ease-out data-closed:data-enter:scale-95 data-leave:ease-in",
)}
>
{children}
</Headless.DialogPanel>
</div>
</div>
</Headless.Dialog>
);
}
export function AlertTitle({
className,
...props
}: { className?: string } & Omit<
Headless.DialogTitleProps,
"as" | "className"
>) {
return (
<Headless.DialogTitle
{...props}
className={clsx(
className,
"text-center text-base/6 font-semibold text-balance text-zinc-950 sm:text-left sm:text-sm/6 sm:text-wrap dark:text-white",
)}
/>
);
}
export function AlertDescription({
className,
...props
}: { className?: string } & Omit<
Headless.DescriptionProps<typeof Text>,
"as" | "className"
>) {
return (
<Headless.Description
as={Text}
{...props}
className={clsx(className, "mt-2 text-center text-pretty sm:text-left")}
/>
);
}
export function AlertBody({
className,
...props
}: React.ComponentPropsWithoutRef<"div">) {
return <div {...props} className={clsx(className, "mt-4")} />;
}
export function AlertActions({
className,
...props
}: React.ComponentPropsWithoutRef<"div">) {
return (
<div
{...props}
className={clsx(
className,
"mt-6 flex flex-col-reverse items-center justify-end gap-3 *:w-full sm:mt-4 sm:flex-row sm:*:w-auto",
)}
/>
);
}
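A minimal usage sketch for these alert primitives; the open/setOpen state and the plain button markup are illustrative, not taken from this changeset:

import { useState } from "react";
import { Alert, AlertTitle, AlertDescription, AlertActions } from "./alert";

function DeleteConfirmation({ onConfirm }: { onConfirm: () => void }) {
  // open/onClose map directly onto the underlying Headless UI Dialog props.
  const [open, setOpen] = useState(false);
  return (
    <>
      <button onClick={() => setOpen(true)}>Delete</button>
      <Alert open={open} onClose={setOpen}>
        <AlertTitle>Delete this chat?</AlertTitle>
        <AlertDescription>This action cannot be undone.</AlertDescription>
        <AlertActions>
          <button onClick={() => setOpen(false)}>Cancel</button>
          <button onClick={onConfirm}>Delete</button>
        </AlertActions>
      </Alert>
    </>
  );
}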

View File

@@ -0,0 +1,107 @@
import * as Headless from "@headlessui/react";
import clsx from "clsx";
import React, { forwardRef } from "react";
import { TouchTarget } from "./button";
import { Link } from "./link";
type AvatarProps = {
src?: string | null;
square?: boolean;
initials?: string;
alt?: string;
className?: string;
};
export function Avatar({
src = null,
square = false,
initials,
alt = "",
className,
...props
}: AvatarProps & React.ComponentPropsWithoutRef<"span">) {
return (
<span
data-slot="avatar"
{...props}
className={clsx(
className,
// Basic layout
"inline-grid shrink-0 align-middle [--avatar-radius:20%] *:col-start-1 *:row-start-1",
"outline -outline-offset-1 outline-black/10 dark:outline-white/10",
// Border radius
square
? "rounded-(--avatar-radius) *:rounded-(--avatar-radius)"
: "rounded-full *:rounded-full",
)}
>
{initials && (
<svg
className="size-full fill-current p-[5%] text-[48px] font-medium uppercase select-none"
viewBox="0 0 100 100"
aria-hidden={alt ? undefined : "true"}
>
{alt && <title>{alt}</title>}
<text
x="50%"
y="50%"
alignmentBaseline="middle"
dominantBaseline="middle"
textAnchor="middle"
dy=".125em"
>
{initials}
</text>
</svg>
)}
{src && <img className="size-full" src={src} alt={alt} />}
</span>
);
}
export const AvatarButton = forwardRef(function AvatarButton(
{
src,
square = false,
initials,
alt,
className,
...props
}: AvatarProps &
(
| Omit<Headless.ButtonProps, "as" | "className">
| Omit<React.ComponentPropsWithoutRef<typeof Link>, "className">
),
ref: React.ForwardedRef<HTMLElement>,
) {
const classes = clsx(
className,
square ? "rounded-[20%]" : "rounded-full",
"relative inline-grid focus:not-data-focus:outline-hidden data-focus:outline-2 data-focus:outline-offset-2 data-focus:outline-blue-500",
);
return "href" in props ? (
<Link
{...(props as Omit<
React.ComponentPropsWithoutRef<typeof Link>,
"className"
>)}
className={classes}
ref={ref as React.ForwardedRef<HTMLAnchorElement>}
>
<TouchTarget>
<Avatar src={src} square={square} initials={initials} alt={alt} />
</TouchTarget>
</Link>
) : (
<Headless.Button
{...(props as Omit<Headless.ButtonProps, "as" | "className">)}
className={classes}
ref={ref as React.ForwardedRef<HTMLButtonElement>}
>
<TouchTarget>
<Avatar src={src} square={square} initials={initials} alt={alt} />
</TouchTarget>
</Headless.Button>
);
});
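For reference, a small usage sketch (prop values and the component name are placeholders, not from this changeset):

import { Avatar } from "./avatar";

function UserBadge({ avatarURL, name }: { avatarURL?: string; name: string }) {
  return avatarURL ? (
    <Avatar src={avatarURL} alt={name} className="size-10" />
  ) : (
    // Fall back to initials when no avatar image is available.
    <Avatar initials={name.slice(0, 2)} alt={name} className="size-10" />
  );
}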

View File

@@ -0,0 +1,160 @@
import * as Headless from "@headlessui/react";
import clsx from "clsx";
import type React from "react";
export function CheckboxGroup({
className,
...props
}: React.ComponentPropsWithoutRef<"div">) {
return (
<div
data-slot="control"
{...props}
className={clsx(
className,
// Basic groups
"space-y-3",
// With descriptions
"has-data-[slot=description]:space-y-6 has-data-[slot=description]:**:data-[slot=label]:font-medium",
)}
/>
);
}
export function CheckboxField({
className,
...props
}: { className?: string } & Omit<Headless.FieldProps, "as" | "className">) {
return (
<Headless.Field
data-slot="field"
{...props}
className={clsx(
className,
// Base layout
"grid grid-cols-[1.125rem_1fr] gap-x-4 gap-y-1 sm:grid-cols-[1rem_1fr]",
// Control layout
"*:data-[slot=control]:col-start-1 *:data-[slot=control]:row-start-1 *:data-[slot=control]:mt-0.75 sm:*:data-[slot=control]:mt-1",
// Label layout
"*:data-[slot=label]:col-start-2 *:data-[slot=label]:row-start-1",
// Description layout
"*:data-[slot=description]:col-start-2 *:data-[slot=description]:row-start-2",
// With description
"has-data-[slot=description]:**:data-[slot=label]:font-medium",
)}
/>
);
}
const base = [
// Basic layout
"relative isolate flex size-4.5 items-center justify-center rounded-[0.3125rem] sm:size-4",
// Background color + shadow applied to inset pseudo element, so shadow blends with border in light mode
"before:absolute before:inset-0 before:-z-10 before:rounded-[calc(0.3125rem-1px)] before:bg-white before:shadow-sm",
// Background color when checked
"group-data-checked:before:bg-(--checkbox-checked-bg)",
// Background color is moved to control and shadow is removed in dark mode so hide `before` pseudo
"dark:before:hidden",
// Background color applied to control in dark mode
"dark:bg-white/5 dark:group-data-checked:bg-(--checkbox-checked-bg)",
// Border
"border border-zinc-950/15 group-data-checked:border-transparent group-data-hover:group-data-checked:border-transparent group-data-hover:border-zinc-950/30 group-data-checked:bg-(--checkbox-checked-border)",
"dark:border-white/15 dark:group-data-checked:border-white/5 dark:group-data-hover:group-data-checked:border-white/5 dark:group-data-hover:border-white/30",
// Inner highlight shadow
"after:absolute after:inset-0 after:rounded-[calc(0.3125rem-1px)] after:shadow-[inset_0_1px_--theme(--color-white/15%)]",
"dark:after:-inset-px dark:after:hidden dark:after:rounded-[0.3125rem] dark:group-data-checked:after:block",
// Focus ring
"group-data-focus:outline-2 group-data-focus:outline-offset-2 group-data-focus:outline-blue-500",
// Disabled state
"group-data-disabled:opacity-50",
"group-data-disabled:border-zinc-950/25 group-data-disabled:bg-zinc-950/5 group-data-disabled:[--checkbox-check:var(--color-zinc-950)]/50 group-data-disabled:before:bg-transparent",
"dark:group-data-disabled:border-white/20 dark:group-data-disabled:bg-white/2.5 dark:group-data-disabled:[--checkbox-check:var(--color-white)]/50 dark:group-data-checked:group-data-disabled:after:hidden",
// Forced colors mode
"forced-colors:[--checkbox-check:HighlightText] forced-colors:[--checkbox-checked-bg:Highlight] forced-colors:group-data-disabled:[--checkbox-check:Highlight]",
"dark:forced-colors:[--checkbox-check:HighlightText] dark:forced-colors:[--checkbox-checked-bg:Highlight] dark:forced-colors:group-data-disabled:[--checkbox-check:Highlight]",
];
const colors = {
"dark/zinc": [
"[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-zinc-900)] [--checkbox-checked-border:var(--color-zinc-950)]/90",
"dark:[--checkbox-checked-bg:var(--color-zinc-600)]",
],
"dark/white": [
"[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-zinc-900)] [--checkbox-checked-border:var(--color-zinc-950)]/90",
"dark:[--checkbox-check:var(--color-zinc-900)] dark:[--checkbox-checked-bg:var(--color-white)] dark:[--checkbox-checked-border:var(--color-zinc-950)]/15",
],
white:
"[--checkbox-check:var(--color-zinc-900)] [--checkbox-checked-bg:var(--color-white)] [--checkbox-checked-border:var(--color-zinc-950)]/15",
dark: "[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-zinc-900)] [--checkbox-checked-border:var(--color-zinc-950)]/90",
zinc: "[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-zinc-600)] [--checkbox-checked-border:var(--color-zinc-700)]/90",
red: "[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-red-600)] [--checkbox-checked-border:var(--color-red-700)]/90",
orange:
"[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-orange-500)] [--checkbox-checked-border:var(--color-orange-600)]/90",
amber:
"[--checkbox-check:var(--color-amber-950)] [--checkbox-checked-bg:var(--color-amber-400)] [--checkbox-checked-border:var(--color-amber-500)]/80",
yellow:
"[--checkbox-check:var(--color-yellow-950)] [--checkbox-checked-bg:var(--color-yellow-300)] [--checkbox-checked-border:var(--color-yellow-400)]/80",
lime: "[--checkbox-check:var(--color-lime-950)] [--checkbox-checked-bg:var(--color-lime-300)] [--checkbox-checked-border:var(--color-lime-400)]/80",
green:
"[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-green-600)] [--checkbox-checked-border:var(--color-green-700)]/90",
emerald:
"[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-emerald-600)] [--checkbox-checked-border:var(--color-emerald-700)]/90",
teal: "[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-teal-600)] [--checkbox-checked-border:var(--color-teal-700)]/90",
cyan: "[--checkbox-check:var(--color-cyan-950)] [--checkbox-checked-bg:var(--color-cyan-300)] [--checkbox-checked-border:var(--color-cyan-400)]/80",
sky: "[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-sky-500)] [--checkbox-checked-border:var(--color-sky-600)]/80",
blue: "[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-blue-600)] [--checkbox-checked-border:var(--color-blue-700)]/90",
indigo:
"[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-indigo-500)] [--checkbox-checked-border:var(--color-indigo-600)]/90",
violet:
"[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-violet-500)] [--checkbox-checked-border:var(--color-violet-600)]/90",
purple:
"[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-purple-500)] [--checkbox-checked-border:var(--color-purple-600)]/90",
fuchsia:
"[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-fuchsia-500)] [--checkbox-checked-border:var(--color-fuchsia-600)]/90",
pink: "[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-pink-500)] [--checkbox-checked-border:var(--color-pink-600)]/90",
rose: "[--checkbox-check:var(--color-white)] [--checkbox-checked-bg:var(--color-rose-500)] [--checkbox-checked-border:var(--color-rose-600)]/90",
};
type Color = keyof typeof colors;
export function Checkbox({
color = "dark/zinc",
className,
...props
}: {
color?: Color;
className?: string;
} & Omit<Headless.CheckboxProps, "as" | "className">) {
return (
<Headless.Checkbox
data-slot="control"
{...props}
className={clsx(className, "group inline-flex focus:outline-hidden")}
>
<span className={clsx([base, colors[color]])}>
<svg
className="size-4 stroke-(--checkbox-check) opacity-0 group-data-checked:opacity-100 sm:h-3.5 sm:w-3.5"
viewBox="0 0 14 14"
fill="none"
>
{/* Checkmark icon */}
<path
className="opacity-100 group-data-indeterminate:opacity-0"
d="M3 8L6 11L11 3.5"
strokeWidth={2}
strokeLinecap="round"
strokeLinejoin="round"
/>
{/* Indeterminate icon */}
<path
className="opacity-0 group-data-indeterminate:opacity-100"
d="M3 7H11"
strokeWidth={2}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</span>
</Headless.Checkbox>
);
}
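
Illustrative usage sketch: a controlled checkbox. The "./checkbox" import path is an assumption (file names are not shown in this listing); checked, onChange, and aria-label pass straight through to Headless UI's Checkbox.

// Hypothetical usage of the Checkbox above; the import path is assumed.
import { useState } from "react";
import { Checkbox } from "./checkbox";

export function NotificationsToggle() {
  const [enabled, setEnabled] = useState(false);
  return (
    <Checkbox
      color="blue"
      checked={enabled}
      onChange={setEnabled}
      aria-label="Enable notifications"
    />
  );
}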

View File

@@ -0,0 +1,227 @@
"use client";
import * as Headless from "@headlessui/react";
import clsx from "clsx";
import { useState } from "react";
export function Combobox<T>({
options,
displayValue,
filter,
anchor = "bottom",
className,
placeholder,
autoFocus,
"aria-label": ariaLabel,
children,
...props
}: {
options: T[];
displayValue: (value: T | null) => string | undefined;
filter?: (value: T, query: string) => boolean;
className?: string;
placeholder?: string;
autoFocus?: boolean;
"aria-label"?: string;
children: (value: NonNullable<T>) => React.ReactElement;
} & Omit<Headless.ComboboxProps<T, false>, "as" | "multiple" | "children"> & {
anchor?: "top" | "bottom";
}) {
const [query, setQuery] = useState("");
const filteredOptions =
query === ""
? options
: options.filter((option) =>
filter
? filter(option, query)
: displayValue(option)?.toLowerCase().includes(query.toLowerCase()),
);
return (
<Headless.Combobox
{...props}
multiple={false}
virtual={{ options: filteredOptions }}
onClose={() => setQuery("")}
>
<span
data-slot="control"
className={clsx([
className,
// Basic layout
"relative block w-full",
// Background color + shadow applied to inset pseudo element, so shadow blends with border in light mode
"before:absolute before:inset-px before:rounded-[calc(var(--radius-lg)-1px)] before:bg-white before:shadow-sm",
// Background color is moved to control and shadow is removed in dark mode so hide `before` pseudo
"dark:before:hidden",
// Focus ring
"after:pointer-events-none after:absolute after:inset-0 after:rounded-lg after:ring-transparent after:ring-inset sm:focus-within:after:ring-2 sm:focus-within:after:ring-blue-500",
// Disabled state
"has-data-disabled:opacity-50 has-data-disabled:before:bg-zinc-950/5 has-data-disabled:before:shadow-none",
// Invalid state
"has-data-invalid:before:shadow-red-500/10",
])}
>
<Headless.ComboboxInput
autoFocus={autoFocus}
data-slot="control"
aria-label={ariaLabel}
displayValue={(option: T) => displayValue(option) ?? ""}
onChange={(event) => setQuery(event.target.value)}
placeholder={placeholder}
className={clsx([
className,
// Basic layout
"relative block w-full appearance-none rounded-lg py-[calc(--spacing(2.5)-1px)] sm:py-[calc(--spacing(1.5)-1px)]",
// Horizontal padding
"pr-[calc(--spacing(10)-1px)] pl-[calc(--spacing(3.5)-1px)] sm:pr-[calc(--spacing(9)-1px)] sm:pl-[calc(--spacing(3)-1px)]",
// Typography
"text-base/6 text-zinc-950 placeholder:text-zinc-500 sm:text-sm/6 dark:text-white",
// Border
"border border-zinc-950/10 data-hover:border-zinc-950/20 dark:border-white/10 dark:data-hover:border-white/20",
// Background color
"bg-transparent dark:bg-white/5",
// Hide default focus styles
"focus:outline-hidden",
// Invalid state
"data-invalid:border-red-500 data-invalid:data-hover:border-red-500 dark:data-invalid:border-red-500 dark:data-invalid:data-hover:border-red-500",
// Disabled state
"data-disabled:border-zinc-950/20 dark:data-disabled:border-white/15 dark:data-disabled:bg-white/2.5 dark:data-hover:data-disabled:border-white/15",
// System icons
"dark:scheme-dark",
])}
/>
<Headless.ComboboxButton className="group absolute inset-y-0 right-0 flex items-center px-2">
<svg
className="size-5 stroke-zinc-500 group-data-disabled:stroke-zinc-600 group-data-hover:stroke-zinc-700 sm:size-4 dark:stroke-zinc-400 dark:group-data-hover:stroke-zinc-300 forced-colors:stroke-[CanvasText]"
viewBox="0 0 16 16"
aria-hidden="true"
fill="none"
>
<path
d="M5.75 10.75L8 13L10.25 10.75"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M10.25 5.25L8 3L5.75 5.25"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</Headless.ComboboxButton>
</span>
<Headless.ComboboxOptions
transition
anchor={anchor}
className={clsx(
// Anchor positioning
"[--anchor-gap:--spacing(2)] [--anchor-padding:--spacing(4)] sm:data-[anchor~=start]:[--anchor-offset:-4px]",
          // Base styles
"isolate min-w-[calc(var(--input-width)+8px)] scroll-py-1 rounded-xl p-1 select-none empty:invisible",
// Invisible border that is only visible in `forced-colors` mode for accessibility purposes
"outline outline-transparent focus:outline-hidden",
// Handle scrolling when menu won't fit in viewport
"overflow-y-scroll overscroll-contain",
// Popover background
"bg-white/75 backdrop-blur-xl dark:bg-zinc-800/75",
// Shadows
"shadow-lg ring-1 ring-zinc-950/10 dark:ring-white/10 dark:ring-inset",
// Transitions
"transition-opacity duration-100 ease-in data-closed:data-leave:opacity-0 data-transition:pointer-events-none",
)}
>
{({ option }) => children(option)}
</Headless.ComboboxOptions>
</Headless.Combobox>
);
}
export function ComboboxOption<T>({
children,
className,
...props
}: { className?: string; children?: React.ReactNode } & Omit<
Headless.ComboboxOptionProps<"div", T>,
"as" | "className"
>) {
let sharedClasses = clsx(
// Base
"flex min-w-0 items-center",
// Icons
"*:data-[slot=icon]:size-5 *:data-[slot=icon]:shrink-0 sm:*:data-[slot=icon]:size-4",
"*:data-[slot=icon]:text-zinc-500 group-data-focus/option:*:data-[slot=icon]:text-white dark:*:data-[slot=icon]:text-zinc-400",
"forced-colors:*:data-[slot=icon]:text-[CanvasText] forced-colors:group-data-focus/option:*:data-[slot=icon]:text-[Canvas]",
// Avatars
"*:data-[slot=avatar]:-mx-0.5 *:data-[slot=avatar]:size-6 sm:*:data-[slot=avatar]:size-5",
);
return (
<Headless.ComboboxOption
{...props}
className={clsx(
// Basic layout
"group/option grid w-full cursor-default grid-cols-[1fr_--spacing(5)] items-baseline gap-x-2 rounded-lg py-2.5 pr-2 pl-3.5 sm:grid-cols-[1fr_--spacing(4)] sm:py-1.5 sm:pr-2 sm:pl-3",
// Typography
"text-base/6 text-zinc-950 sm:text-sm/6 dark:text-white forced-colors:text-[CanvasText]",
// Focus
"outline-hidden data-focus:bg-blue-500 data-focus:text-white",
// Forced colors mode
"forced-color-adjust-none forced-colors:data-focus:bg-[Highlight] forced-colors:data-focus:text-[HighlightText]",
// Disabled
"data-disabled:opacity-50",
)}
>
<span className={clsx(className, sharedClasses)}>{children}</span>
<svg
className="relative col-start-2 hidden size-5 self-center stroke-current group-data-selected/option:inline sm:size-4"
viewBox="0 0 16 16"
fill="none"
aria-hidden="true"
>
<path
d="M4 8.5l3 3L12 4"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</Headless.ComboboxOption>
);
}
export function ComboboxLabel({
className,
...props
}: React.ComponentPropsWithoutRef<"span">) {
return (
<span
{...props}
className={clsx(
className,
"ml-2.5 truncate first:ml-0 sm:ml-2 sm:first:ml-0",
)}
/>
);
}
export function ComboboxDescription({
className,
children,
...props
}: React.ComponentPropsWithoutRef<"span">) {
return (
<span
{...props}
className={clsx(
className,
"flex flex-1 overflow-hidden text-zinc-500 group-data-focus/option:text-white before:w-2 before:min-w-0 before:shrink dark:text-zinc-400",
)}
>
<span className="flex-1 truncate">{children}</span>
</span>
);
}
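
Illustrative usage sketch: because the options are virtualized, children is a render function rather than a list of static option elements. The "./combobox" path and the User shape are assumptions.

// Hypothetical usage of the Combobox above; path and data are illustrative.
import { useState } from "react";
import { Combobox, ComboboxLabel, ComboboxOption } from "./combobox";

type User = { id: number; name: string };

const users: User[] = [
  { id: 1, name: "Ada Lovelace" },
  { id: 2, name: "Grace Hopper" },
];

export function UserPicker() {
  const [selected, setSelected] = useState<User | null>(users[0]);
  return (
    <Combobox
      options={users}
      value={selected}
      onChange={setSelected}
      displayValue={(user) => user?.name}
      aria-label="Assignee"
      placeholder="Search users…"
    >
      {(user) => (
        <ComboboxOption value={user}>
          <ComboboxLabel>{user.name}</ComboboxLabel>
        </ComboboxOption>
      )}
    </Combobox>
  );
}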

View File

@@ -0,0 +1,46 @@
import clsx from "clsx";
export function DescriptionList({
className,
...props
}: React.ComponentPropsWithoutRef<"dl">) {
return (
<dl
{...props}
className={clsx(
className,
"grid grid-cols-1 text-base/6 sm:grid-cols-[min(50%,--spacing(80))_auto] sm:text-sm/6",
)}
/>
);
}
export function DescriptionTerm({
className,
...props
}: React.ComponentPropsWithoutRef<"dt">) {
return (
<dt
{...props}
className={clsx(
className,
"col-start-1 border-t border-zinc-950/5 pt-3 text-zinc-500 first:border-none sm:border-t sm:border-zinc-950/5 sm:py-3 dark:border-white/5 dark:text-zinc-400 sm:dark:border-white/5",
)}
/>
);
}
export function DescriptionDetails({
className,
...props
}: React.ComponentPropsWithoutRef<"dd">) {
return (
<dd
{...props}
className={clsx(
className,
"pt-1 pb-3 text-zinc-950 sm:border-t sm:border-zinc-950/5 sm:py-3 sm:nth-2:border-none dark:text-white dark:sm:border-white/5",
)}
/>
);
}
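
Illustrative usage sketch: terms and details alternate as direct children and the two-column grid pairs them up. The "./description-list" path and the values are assumptions.

// Hypothetical usage of the description list primitives above.
import {
  DescriptionList,
  DescriptionDetails,
  DescriptionTerm,
} from "./description-list";

export function ModelDetails() {
  return (
    <DescriptionList>
      <DescriptionTerm>Parameters</DescriptionTerm>
      <DescriptionDetails>8B</DescriptionDetails>
      <DescriptionTerm>Quantization</DescriptionTerm>
      <DescriptionDetails>Q4_K_M</DescriptionDetails>
    </DescriptionList>
  );
}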

View File

@@ -0,0 +1,108 @@
import * as Headless from "@headlessui/react";
import clsx from "clsx";
import type React from "react";
import { Text } from "./text";
const sizes = {
xs: "sm:max-w-xs",
sm: "sm:max-w-sm",
md: "sm:max-w-md",
lg: "sm:max-w-lg",
xl: "sm:max-w-xl",
"2xl": "sm:max-w-2xl",
"3xl": "sm:max-w-3xl",
"4xl": "sm:max-w-4xl",
"5xl": "sm:max-w-5xl",
};
export function Dialog({
size = "lg",
className,
children,
...props
}: {
size?: keyof typeof sizes;
className?: string;
children: React.ReactNode;
} & Omit<Headless.DialogProps, "as" | "className">) {
return (
<Headless.Dialog {...props}>
<Headless.DialogBackdrop
transition
className="fixed inset-0 flex w-screen justify-center overflow-y-auto bg-zinc-950/25 px-2 py-2 transition duration-100 focus:outline-0 data-closed:opacity-0 data-enter:ease-out data-leave:ease-in sm:px-6 sm:py-8 lg:px-8 lg:py-16 dark:bg-zinc-950/50"
/>
<div className="fixed inset-0 w-screen overflow-y-auto pt-6 sm:pt-0">
<div className="grid min-h-full grid-rows-[1fr_auto] justify-items-center sm:grid-rows-[1fr_auto_3fr] sm:p-4">
<Headless.DialogPanel
transition
className={clsx(
className,
sizes[size],
"row-start-2 w-full min-w-0 rounded-t-3xl bg-white p-(--gutter) shadow-lg ring-1 ring-zinc-950/10 [--gutter:--spacing(8)] sm:mb-auto sm:rounded-2xl dark:bg-zinc-900 dark:ring-white/10 forced-colors:outline",
"transition duration-100 will-change-transform data-closed:translate-y-12 data-closed:opacity-0 data-enter:ease-out data-leave:ease-in sm:data-closed:translate-y-0 sm:data-closed:data-enter:scale-95",
)}
>
{children}
</Headless.DialogPanel>
</div>
</div>
</Headless.Dialog>
);
}
export function DialogTitle({
className,
...props
}: { className?: string } & Omit<
Headless.DialogTitleProps,
"as" | "className"
>) {
return (
<Headless.DialogTitle
{...props}
className={clsx(
className,
"text-lg/6 font-semibold text-balance text-zinc-950 sm:text-base/6 dark:text-white",
)}
/>
);
}
export function DialogDescription({
className,
...props
}: { className?: string } & Omit<
Headless.DescriptionProps<typeof Text>,
"as" | "className"
>) {
return (
<Headless.Description
as={Text}
{...props}
className={clsx(className, "mt-2 text-pretty")}
/>
);
}
export function DialogBody({
className,
...props
}: React.ComponentPropsWithoutRef<"div">) {
return <div {...props} className={clsx(className, "mt-6")} />;
}
export function DialogActions({
className,
...props
}: React.ComponentPropsWithoutRef<"div">) {
return (
<div
{...props}
className={clsx(
className,
"mt-8 flex flex-col-reverse items-center justify-end gap-3 *:w-full sm:flex-row sm:*:w-auto",
)}
/>
);
}
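
Illustrative usage sketch: open and onClose pass through to Headless UI's Dialog, so a boolean state setter is enough to drive it. The "./dialog" and "./button" paths are assumptions, and only the Button props visible elsewhere in this diff (plain, onClick via Headless) are used.

// Hypothetical confirmation dialog built from the pieces above.
import { useState } from "react";
import { Button } from "./button";
import {
  Dialog,
  DialogActions,
  DialogDescription,
  DialogTitle,
} from "./dialog";

export function DeleteModelDialog() {
  const [open, setOpen] = useState(false);
  return (
    <>
      <Button onClick={() => setOpen(true)}>Delete…</Button>
      <Dialog open={open} onClose={setOpen} size="sm">
        <DialogTitle>Delete model</DialogTitle>
        <DialogDescription>
          This removes the model from local storage. You can pull it again later.
        </DialogDescription>
        <DialogActions>
          <Button plain onClick={() => setOpen(false)}>
            Cancel
          </Button>
          <Button onClick={() => setOpen(false)}>Delete</Button>
        </DialogActions>
      </Dialog>
    </>
  );
}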

View File

@@ -0,0 +1,20 @@
import clsx from "clsx";
export function Divider({
soft = false,
className,
...props
}: { soft?: boolean } & React.ComponentPropsWithoutRef<"hr">) {
return (
<hr
role="presentation"
{...props}
className={clsx(
className,
"w-full border-t",
soft && "border-zinc-950/5 dark:border-white/5",
!soft && "border-zinc-950/10 dark:border-white/10",
)}
/>
);
}

View File

@@ -0,0 +1,230 @@
"use client";
import * as Headless from "@headlessui/react";
import clsx from "clsx";
import type React from "react";
import { Button } from "./button";
import { Link } from "./link";
export function Dropdown(props: Headless.MenuProps) {
return <Headless.Menu {...props} />;
}
export function DropdownButton<T extends React.ElementType = typeof Button>({
as = Button,
...props
}: { className?: string } & Omit<Headless.MenuButtonProps<T>, "className">) {
return <Headless.MenuButton as={as} {...props} />;
}
export function DropdownMenu({
anchor = "bottom",
className,
...props
}: { className?: string } & Omit<Headless.MenuItemsProps, "as" | "className">) {
return (
<Headless.MenuItems
{...props}
transition
anchor={anchor}
className={clsx(
className,
// Anchor positioning
"[--anchor-gap:--spacing(2)] [--anchor-padding:--spacing(1)] data-[anchor~=end]:[--anchor-offset:6px] data-[anchor~=start]:[--anchor-offset:-6px] sm:data-[anchor~=end]:[--anchor-offset:4px] sm:data-[anchor~=start]:[--anchor-offset:-4px]",
// Base styles
"isolate w-max rounded-xl p-1",
// Invisible border that is only visible in `forced-colors` mode for accessibility purposes
"outline outline-transparent focus:outline-hidden",
// Handle scrolling when menu won't fit in viewport
"overflow-y-auto",
// Popover background
"bg-white/75 backdrop-blur-xl dark:bg-zinc-800/75",
// Shadows
"shadow-lg ring-1 ring-zinc-950/10 dark:ring-white/10 dark:ring-inset",
// Define grid at the menu level if subgrid is supported
"supports-[grid-template-columns:subgrid]:grid supports-[grid-template-columns:subgrid]:grid-cols-[auto_1fr_1.5rem_0.5rem_auto]",
// Transitions
"transition data-leave:duration-100 data-leave:ease-in data-closed:data-leave:opacity-0",
)}
/>
);
}
export function DropdownItem({
className,
...props
}: { className?: string } & (
| Omit<Headless.MenuItemProps<"button">, "as" | "className">
| Omit<Headless.MenuItemProps<typeof Link>, "as" | "className">
)) {
let classes = clsx(
className,
// Base styles
"group cursor-default rounded-lg px-3.5 py-2.5 focus:outline-hidden sm:px-3 sm:py-1.5",
// Text styles
"text-left text-base/6 text-zinc-950 sm:text-sm/6 dark:text-white forced-colors:text-[CanvasText]",
// Focus
"data-focus:bg-blue-500 data-focus:text-white",
// Disabled state
"data-disabled:opacity-50",
// Forced colors mode
"forced-color-adjust-none forced-colors:data-focus:bg-[Highlight] forced-colors:data-focus:text-[HighlightText] forced-colors:data-focus:*:data-[slot=icon]:text-[HighlightText]",
    // Use subgrid when available but fall back to an explicit grid layout if not
"col-span-full grid grid-cols-[auto_1fr_1.5rem_0.5rem_auto] items-center supports-[grid-template-columns:subgrid]:grid-cols-subgrid",
// Icons
"*:data-[slot=icon]:col-start-1 *:data-[slot=icon]:row-start-1 *:data-[slot=icon]:mr-2.5 *:data-[slot=icon]:-ml-0.5 *:data-[slot=icon]:size-5 sm:*:data-[slot=icon]:mr-2 sm:*:data-[slot=icon]:size-4",
"*:data-[slot=icon]:text-zinc-500 data-focus:*:data-[slot=icon]:text-white dark:*:data-[slot=icon]:text-zinc-400 dark:data-focus:*:data-[slot=icon]:text-white",
// Avatar
"*:data-[slot=avatar]:mr-2.5 *:data-[slot=avatar]:-ml-1 *:data-[slot=avatar]:size-6 sm:*:data-[slot=avatar]:mr-2 sm:*:data-[slot=avatar]:size-5",
);
return "href" in props ? (
<Headless.MenuItem
as={Link}
{...(props as Omit<
Headless.MenuItemProps<typeof Link>,
"as" | "className"
>)}
className={classes}
/>
) : (
<Headless.MenuItem
as="button"
{...(props as Omit<Headless.MenuItemProps<"button">, "as" | "className">)}
className={classes}
/>
);
}
export function DropdownHeader({
className,
...props
}: React.ComponentPropsWithoutRef<"div">) {
return (
<div
{...props}
className={clsx(className, "col-span-5 px-3.5 pt-2.5 pb-1 sm:px-3")}
/>
);
}
export function DropdownSection({
className,
...props
}: { className?: string } & Omit<
Headless.MenuSectionProps,
"as" | "className"
>) {
return (
<Headless.MenuSection
{...props}
className={clsx(
className,
// Define grid at the section level instead of the item level if subgrid is supported
"col-span-full supports-[grid-template-columns:subgrid]:grid supports-[grid-template-columns:subgrid]:grid-cols-[auto_1fr_1.5rem_0.5rem_auto]",
)}
/>
);
}
export function DropdownHeading({
className,
...props
}: { className?: string } & Omit<
Headless.MenuHeadingProps,
"as" | "className"
>) {
return (
<Headless.MenuHeading
{...props}
className={clsx(
className,
"col-span-full grid grid-cols-[1fr_auto] gap-x-12 px-3.5 pt-2 pb-1 text-sm/5 font-medium text-zinc-500 sm:px-3 sm:text-xs/5 dark:text-zinc-400",
)}
/>
);
}
export function DropdownDivider({
className,
...props
}: { className?: string } & Omit<
Headless.MenuSeparatorProps,
"as" | "className"
>) {
return (
<Headless.MenuSeparator
{...props}
className={clsx(
className,
"col-span-full mx-3.5 my-1 h-px border-0 bg-zinc-950/5 sm:mx-3 dark:bg-white/10 forced-colors:bg-[CanvasText]",
)}
/>
);
}
export function DropdownLabel({
className,
...props
}: React.ComponentPropsWithoutRef<"div">) {
return (
<div
{...props}
data-slot="label"
className={clsx(className, "col-start-2 row-start-1")}
/>
);
}
export function DropdownDescription({
className,
...props
}: { className?: string } & Omit<
Headless.DescriptionProps,
"as" | "className"
>) {
return (
<Headless.Description
data-slot="description"
{...props}
className={clsx(
className,
"col-span-2 col-start-2 row-start-2 text-sm/5 text-zinc-500 group-data-focus:text-white sm:text-xs/5 dark:text-zinc-400 forced-colors:group-data-focus:text-[HighlightText]",
)}
/>
);
}
export function DropdownShortcut({
keys,
className,
...props
}: { keys: string | string[]; className?: string } & Omit<
Headless.DescriptionProps<"kbd">,
"as" | "className"
>) {
return (
<Headless.Description
as="kbd"
{...props}
className={clsx(
className,
"col-start-5 row-start-1 flex justify-self-end",
)}
>
{(Array.isArray(keys) ? keys : keys.split("")).map((char, index) => (
<kbd
key={index}
className={clsx([
"min-w-[2ch] text-center font-sans text-zinc-400 capitalize group-data-focus:text-white forced-colors:group-data-focus:text-[HighlightText]",
// Make sure key names that are longer than one character (like "Tab") have extra space
index > 0 && char.length > 1 && "pl-1",
])}
>
{char}
</kbd>
))}
</Headless.Description>
);
}
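
Illustrative usage sketch: DropdownButton defaults to rendering the Button component, and DropdownItem renders a Link when given an href and a button otherwise. The "./dropdown" path and the menu entries are assumptions.

// Hypothetical menu built from the dropdown primitives above.
import {
  Dropdown,
  DropdownButton,
  DropdownDivider,
  DropdownItem,
  DropdownLabel,
  DropdownMenu,
} from "./dropdown";

export function ModelActions({ onRemove }: { onRemove: () => void }) {
  return (
    <Dropdown>
      <DropdownButton>Actions</DropdownButton>
      <DropdownMenu anchor="bottom end">
        <DropdownItem href="/settings">
          <DropdownLabel>Settings</DropdownLabel>
        </DropdownItem>
        <DropdownDivider />
        <DropdownItem onClick={onRemove}>
          <DropdownLabel>Remove model</DropdownLabel>
        </DropdownItem>
      </DropdownMenu>
    </Dropdown>
  );
}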

View File

@@ -0,0 +1,33 @@
import clsx from "clsx";
type HeadingProps = {
level?: 1 | 2 | 3 | 4 | 5 | 6;
} & React.ComponentPropsWithoutRef<"h1" | "h2" | "h3" | "h4" | "h5" | "h6">;
export function Heading({ className, level = 1, ...props }: HeadingProps) {
let Element: `h${typeof level}` = `h${level}`;
return (
<Element
{...props}
className={clsx(
className,
"text-2xl/8 font-semibold text-zinc-950 sm:text-xl/8 dark:text-white",
)}
/>
);
}
export function Subheading({ className, level = 2, ...props }: HeadingProps) {
let Element: `h${typeof level}` = `h${level}`;
return (
<Element
{...props}
className={clsx(
className,
"text-base/7 font-semibold text-zinc-950 sm:text-sm/6 dark:text-white",
)}
/>
);
}

View File

@@ -0,0 +1,217 @@
"use client";
import * as Headless from "@headlessui/react";
import clsx from "clsx";
import { Fragment } from "react";
export function Listbox<T>({
className,
placeholder,
autoFocus,
"aria-label": ariaLabel,
children: options,
...props
}: {
className?: string;
placeholder?: React.ReactNode;
autoFocus?: boolean;
"aria-label"?: string;
children?: React.ReactNode;
} & Omit<Headless.ListboxProps<typeof Fragment, T>, "as" | "multiple">) {
return (
<Headless.Listbox {...props} multiple={false}>
<Headless.ListboxButton
autoFocus={autoFocus}
data-slot="control"
aria-label={ariaLabel}
className={clsx([
className,
// Basic layout
"group relative block w-full",
// Background color + shadow applied to inset pseudo element, so shadow blends with border in light mode
"before:absolute before:inset-px before:rounded-[calc(var(--radius-lg)-1px)] before:bg-white before:shadow-sm",
// Background color is moved to control and shadow is removed in dark mode so hide `before` pseudo
"dark:before:hidden",
// Hide default focus styles
"focus:outline-hidden",
// Focus ring
"after:pointer-events-none after:absolute after:inset-0 after:rounded-lg after:ring-transparent after:ring-inset data-focus:after:ring-2 data-focus:after:ring-blue-500",
// Disabled state
"data-disabled:opacity-50 data-disabled:before:bg-zinc-950/5 data-disabled:before:shadow-none",
])}
>
<Headless.ListboxSelectedOption
as="span"
options={options}
placeholder={
placeholder && (
<span className="block truncate text-zinc-500">
{placeholder}
</span>
)
}
className={clsx([
// Basic layout
"relative block w-full appearance-none rounded-lg py-[calc(--spacing(2.5)-1px)] sm:py-[calc(--spacing(1.5)-1px)]",
// Set minimum height for when no value is selected
"min-h-11 sm:min-h-9",
// Horizontal padding
"pr-[calc(--spacing(7)-1px)] pl-[calc(--spacing(3.5)-1px)] sm:pl-[calc(--spacing(3)-1px)]",
// Typography
"text-left text-base/6 text-zinc-950 placeholder:text-zinc-500 sm:text-sm/6 dark:text-white forced-colors:text-[CanvasText]",
// Border
"border border-zinc-950/10 group-data-active:border-zinc-950/20 group-data-hover:border-zinc-950/20 dark:border-white/10 dark:group-data-active:border-white/20 dark:group-data-hover:border-white/20",
// Background color
"bg-transparent dark:bg-white/5",
// Invalid state
"group-data-invalid:border-red-500 group-data-hover:group-data-invalid:border-red-500 dark:group-data-invalid:border-red-600 dark:data-hover:group-data-invalid:border-red-600",
// Disabled state
"group-data-disabled:border-zinc-950/20 group-data-disabled:opacity-100 dark:group-data-disabled:border-white/15 dark:group-data-disabled:bg-white/2.5 dark:group-data-disabled:data-hover:border-white/15",
])}
/>
<span className="pointer-events-none absolute inset-y-0 right-0 flex items-center pr-2">
<svg
className="size-5 stroke-zinc-500 group-data-disabled:stroke-zinc-600 sm:size-4 dark:stroke-zinc-400 forced-colors:stroke-[CanvasText]"
viewBox="0 0 16 16"
aria-hidden="true"
fill="none"
>
<path
d="M5.75 10.75L8 13L10.25 10.75"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M10.25 5.25L8 3L5.75 5.25"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</span>
</Headless.ListboxButton>
<Headless.ListboxOptions
transition
anchor="selection start"
className={clsx(
// Anchor positioning
"[--anchor-offset:-1.625rem] [--anchor-padding:--spacing(4)] sm:[--anchor-offset:-1.375rem]",
// Base styles
"isolate w-max min-w-[calc(var(--button-width)+1.75rem)] scroll-py-1 rounded-xl p-1 select-none",
// Invisible border that is only visible in `forced-colors` mode for accessibility purposes
"outline outline-transparent focus:outline-hidden",
// Handle scrolling when menu won't fit in viewport
"overflow-y-scroll overscroll-contain",
// Popover background
"bg-white/75 backdrop-blur-xl dark:bg-zinc-800/75",
// Shadows
"shadow-lg ring-1 ring-zinc-950/10 dark:ring-white/10 dark:ring-inset",
// Transitions
"transition-opacity duration-100 ease-in data-closed:data-leave:opacity-0 data-transition:pointer-events-none",
)}
>
{options}
</Headless.ListboxOptions>
</Headless.Listbox>
);
}
export function ListboxOption<T>({
children,
className,
...props
}: { className?: string; children?: React.ReactNode } & Omit<
Headless.ListboxOptionProps<"div", T>,
"as" | "className"
>) {
let sharedClasses = clsx(
// Base
"flex min-w-0 items-center",
// Icons
"*:data-[slot=icon]:size-5 *:data-[slot=icon]:shrink-0 sm:*:data-[slot=icon]:size-4",
"*:data-[slot=icon]:text-zinc-500 group-data-focus/option:*:data-[slot=icon]:text-white dark:*:data-[slot=icon]:text-zinc-400",
"forced-colors:*:data-[slot=icon]:text-[CanvasText] forced-colors:group-data-focus/option:*:data-[slot=icon]:text-[Canvas]",
// Avatars
"*:data-[slot=avatar]:-mx-0.5 *:data-[slot=avatar]:size-6 sm:*:data-[slot=avatar]:size-5",
);
return (
<Headless.ListboxOption as={Fragment} {...props}>
{({ selectedOption }) => {
if (selectedOption) {
return (
<div className={clsx(className, sharedClasses)}>{children}</div>
);
}
return (
<div
className={clsx(
// Basic layout
"group/option grid cursor-default grid-cols-[--spacing(5)_1fr] items-baseline gap-x-2 rounded-lg py-2.5 pr-3.5 pl-2 sm:grid-cols-[--spacing(4)_1fr] sm:py-1.5 sm:pr-3 sm:pl-1.5",
// Typography
"text-base/6 text-zinc-950 sm:text-sm/6 dark:text-white forced-colors:text-[CanvasText]",
// Focus
"outline-hidden data-focus:bg-blue-500 data-focus:text-white",
// Forced colors mode
"forced-color-adjust-none forced-colors:data-focus:bg-[Highlight] forced-colors:data-focus:text-[HighlightText]",
// Disabled
"data-disabled:opacity-50",
)}
>
<svg
className="relative hidden size-5 self-center stroke-current group-data-selected/option:inline sm:size-4"
viewBox="0 0 16 16"
fill="none"
aria-hidden="true"
>
<path
d="M4 8.5l3 3L12 4"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
<span className={clsx(className, sharedClasses, "col-start-2")}>
{children}
</span>
</div>
);
}}
</Headless.ListboxOption>
);
}
export function ListboxLabel({
className,
...props
}: React.ComponentPropsWithoutRef<"span">) {
return (
<span
{...props}
className={clsx(
className,
"ml-2.5 truncate first:ml-0 sm:ml-2 sm:first:ml-0",
)}
/>
);
}
export function ListboxDescription({
className,
children,
...props
}: React.ComponentPropsWithoutRef<"span">) {
return (
<span
{...props}
className={clsx(
className,
"flex flex-1 overflow-hidden text-zinc-500 group-data-focus/option:text-white before:w-2 before:min-w-0 before:shrink dark:text-zinc-400",
)}
>
<span className="flex-1 truncate">{children}</span>
</span>
);
}
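
Illustrative usage sketch: unlike the Combobox, children here are regular ListboxOption elements; value and onChange pass through to Headless UI's Listbox. The "./listbox" path and option values are assumptions.

// Hypothetical select control built from the listbox primitives above.
import { useState } from "react";
import { Listbox, ListboxLabel, ListboxOption } from "./listbox";

export function ThemeSelect() {
  const [theme, setTheme] = useState("system");
  return (
    <Listbox
      value={theme}
      onChange={setTheme}
      aria-label="Theme"
      placeholder="Select a theme…"
    >
      <ListboxOption value="system">
        <ListboxLabel>System</ListboxLabel>
      </ListboxOption>
      <ListboxOption value="light">
        <ListboxLabel>Light</ListboxLabel>
      </ListboxOption>
      <ListboxOption value="dark">
        <ListboxLabel>Dark</ListboxLabel>
      </ListboxOption>
    </Listbox>
  );
}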

View File

@@ -0,0 +1,131 @@
"use client";
import * as Headless from "@headlessui/react";
import clsx from "clsx";
import { LayoutGroup, motion } from "framer-motion";
import React, { forwardRef, useId } from "react";
import { TouchTarget } from "./button";
import { Link } from "./link";
export function Navbar({
className,
...props
}: React.ComponentPropsWithoutRef<"nav">) {
return (
<nav
{...props}
className={clsx(className, "flex flex-1 items-center gap-4 py-2.5")}
/>
);
}
export function NavbarDivider({
className,
...props
}: React.ComponentPropsWithoutRef<"div">) {
return (
<div
aria-hidden="true"
{...props}
className={clsx(className, "h-6 w-px bg-zinc-950/10 dark:bg-white/10")}
/>
);
}
export function NavbarSection({
className,
...props
}: React.ComponentPropsWithoutRef<"div">) {
let id = useId();
return (
<LayoutGroup id={id}>
<div {...props} className={clsx(className, "flex items-center gap-3")} />
</LayoutGroup>
);
}
export function NavbarSpacer({
className,
...props
}: React.ComponentPropsWithoutRef<"div">) {
return (
<div
aria-hidden="true"
{...props}
className={clsx(className, "-ml-4 flex-1")}
/>
);
}
export const NavbarItem = forwardRef(function NavbarItem(
{
current,
className,
children,
...props
}: { current?: boolean; className?: string; children: React.ReactNode } & (
| Omit<Headless.ButtonProps, "as" | "className">
| Omit<React.ComponentPropsWithoutRef<typeof Link>, "className">
),
ref: React.ForwardedRef<HTMLAnchorElement | HTMLButtonElement>,
) {
let classes = clsx(
// Base
"relative flex min-w-0 items-center gap-3 rounded-lg p-2 text-left text-base/6 font-medium text-zinc-950 sm:text-sm/5",
// Leading icon/icon-only
"*:data-[slot=icon]:size-6 *:data-[slot=icon]:shrink-0 *:data-[slot=icon]:fill-zinc-500 sm:*:data-[slot=icon]:size-5",
// Trailing icon (down chevron or similar)
"*:not-nth-2:last:data-[slot=icon]:ml-auto *:not-nth-2:last:data-[slot=icon]:size-5 sm:*:not-nth-2:last:data-[slot=icon]:size-4",
// Avatar
"*:data-[slot=avatar]:-m-0.5 *:data-[slot=avatar]:size-7 *:data-[slot=avatar]:[--avatar-radius:var(--radius-md)] sm:*:data-[slot=avatar]:size-6",
// Hover
"data-hover:bg-zinc-950/5 data-hover:*:data-[slot=icon]:fill-zinc-950",
// Active
"data-active:bg-zinc-950/5 data-active:*:data-[slot=icon]:fill-zinc-950",
// Dark mode
"dark:text-white dark:*:data-[slot=icon]:fill-zinc-400",
"dark:data-hover:bg-white/5 dark:data-hover:*:data-[slot=icon]:fill-white",
"dark:data-active:bg-white/5 dark:data-active:*:data-[slot=icon]:fill-white",
);
return (
<span className={clsx(className, "relative")}>
{current && (
<motion.span
layoutId="current-indicator"
className="absolute inset-x-2 -bottom-2.5 h-0.5 rounded-full bg-zinc-950 dark:bg-white"
/>
)}
{"href" in props ? (
<Link
{...(props as Omit<
React.ComponentPropsWithoutRef<typeof Link>,
"className"
>)}
className={classes}
data-current={current ? "true" : undefined}
ref={ref as React.ForwardedRef<HTMLAnchorElement>}
>
<TouchTarget>{children}</TouchTarget>
</Link>
) : (
<Headless.Button
{...(props as Omit<Headless.ButtonProps, "as" | "className">)}
className={clsx("cursor-default", classes)}
data-current={current ? "true" : undefined}
ref={ref as React.ForwardedRef<HTMLButtonElement>}
>
<TouchTarget>{children}</TouchTarget>
</Headless.Button>
)}
</span>
);
});
export function NavbarLabel({
className,
...props
}: React.ComponentPropsWithoutRef<"span">) {
return <span {...props} className={clsx(className, "truncate")} />;
}
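
Illustrative usage sketch: items with an href render as links, and current draws the animated underline indicator. The "./navbar" path and the routes are assumptions.

// Hypothetical top navigation built from the navbar primitives above.
import {
  Navbar,
  NavbarItem,
  NavbarLabel,
  NavbarSection,
  NavbarSpacer,
} from "./navbar";

export function AppNavbar({ pathname }: { pathname: string }) {
  return (
    <Navbar>
      <NavbarSection>
        <NavbarItem href="/" current={pathname === "/"}>
          <NavbarLabel>Chat</NavbarLabel>
        </NavbarItem>
        <NavbarItem href="/models" current={pathname === "/models"}>
          <NavbarLabel>Models</NavbarLabel>
        </NavbarItem>
      </NavbarSection>
      <NavbarSpacer />
      <NavbarSection>
        <NavbarItem href="/settings">
          <NavbarLabel>Settings</NavbarLabel>
        </NavbarItem>
      </NavbarSection>
    </Navbar>
  );
}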

View File

@@ -0,0 +1,139 @@
import clsx from "clsx";
import type React from "react";
import { Button } from "./button";
export function Pagination({
"aria-label": ariaLabel = "Page navigation",
className,
...props
}: React.ComponentPropsWithoutRef<"nav">) {
return (
<nav
aria-label={ariaLabel}
{...props}
className={clsx(className, "flex gap-x-2")}
/>
);
}
export function PaginationPrevious({
href = null,
className,
children = "Previous",
}: React.PropsWithChildren<{ href?: string | null; className?: string }>) {
return (
<span className={clsx(className, "grow basis-0")}>
<Button
{...(href === null ? { disabled: true } : { href })}
plain
aria-label="Previous page"
>
<svg
className="stroke-current"
data-slot="icon"
viewBox="0 0 16 16"
fill="none"
aria-hidden="true"
>
<path
d="M2.75 8H13.25M2.75 8L5.25 5.5M2.75 8L5.25 10.5"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
{children}
</Button>
</span>
);
}
export function PaginationNext({
href = null,
className,
children = "Next",
}: React.PropsWithChildren<{ href?: string | null; className?: string }>) {
return (
<span className={clsx(className, "flex grow basis-0 justify-end")}>
<Button
{...(href === null ? { disabled: true } : { href })}
plain
aria-label="Next page"
>
{children}
<svg
className="stroke-current"
data-slot="icon"
viewBox="0 0 16 16"
fill="none"
aria-hidden="true"
>
<path
d="M13.25 8L2.75 8M13.25 8L10.75 10.5M13.25 8L10.75 5.5"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</Button>
</span>
);
}
export function PaginationList({
className,
...props
}: React.ComponentPropsWithoutRef<"span">) {
return (
<span
{...props}
className={clsx(className, "hidden items-baseline gap-x-2 sm:flex")}
/>
);
}
export function PaginationPage({
href,
className,
current = false,
children,
}: React.PropsWithChildren<{
href: string;
className?: string;
current?: boolean;
}>) {
return (
<Button
href={href}
plain
aria-label={`Page ${children}`}
aria-current={current ? "page" : undefined}
className={clsx(
className,
"min-w-9 before:absolute before:-inset-px before:rounded-lg",
current && "before:bg-zinc-950/5 dark:before:bg-white/10",
)}
>
<span className="-mx-0.5">{children}</span>
</Button>
);
}
export function PaginationGap({
className,
children = <>&hellip;</>,
...props
}: React.ComponentPropsWithoutRef<"span">) {
return (
<span
aria-hidden="true"
{...props}
className={clsx(
className,
"w-9 text-center text-sm/6 font-semibold text-zinc-950 select-none dark:text-white",
)}
>
{children}
</span>
);
}
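
Illustrative usage sketch: passing href={null} to PaginationPrevious/PaginationNext renders them disabled. The "./pagination" path and the query-string URL scheme are assumptions.

// Hypothetical pager built from the pagination primitives above.
import {
  Pagination,
  PaginationList,
  PaginationNext,
  PaginationPage,
  PaginationPrevious,
} from "./pagination";

export function ResultsPagination({ page, pages }: { page: number; pages: number }) {
  return (
    <Pagination>
      <PaginationPrevious href={page > 1 ? `?page=${page - 1}` : null} />
      <PaginationList>
        {Array.from({ length: pages }, (_, i) => i + 1).map((n) => (
          <PaginationPage key={n} href={`?page=${n}`} current={n === page}>
            {n}
          </PaginationPage>
        ))}
      </PaginationList>
      <PaginationNext href={page < pages ? `?page=${page + 1}` : null} />
    </Pagination>
  );
}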

View File

@@ -0,0 +1,148 @@
import * as Headless from "@headlessui/react";
import clsx from "clsx";
export function RadioGroup({
className,
...props
}: { className?: string } & Omit<
Headless.RadioGroupProps,
"as" | "className"
>) {
return (
<Headless.RadioGroup
data-slot="control"
{...props}
className={clsx(
className,
// Basic groups
"space-y-3 **:data-[slot=label]:font-normal",
// With descriptions
"has-data-[slot=description]:space-y-6 has-data-[slot=description]:**:data-[slot=label]:font-medium",
)}
/>
);
}
export function RadioField({
className,
...props
}: { className?: string } & Omit<Headless.FieldProps, "as" | "className">) {
return (
<Headless.Field
data-slot="field"
{...props}
className={clsx(
className,
// Base layout
"grid grid-cols-[1.125rem_1fr] gap-x-4 gap-y-1 sm:grid-cols-[1rem_1fr]",
// Control layout
"*:data-[slot=control]:col-start-1 *:data-[slot=control]:row-start-1 *:data-[slot=control]:mt-0.75 sm:*:data-[slot=control]:mt-1",
// Label layout
"*:data-[slot=label]:col-start-2 *:data-[slot=label]:row-start-1",
// Description layout
"*:data-[slot=description]:col-start-2 *:data-[slot=description]:row-start-2",
// With description
"has-data-[slot=description]:**:data-[slot=label]:font-medium",
)}
/>
);
}
const base = [
// Basic layout
"relative isolate flex size-4.75 shrink-0 rounded-full sm:size-4.25",
// Background color + shadow applied to inset pseudo element, so shadow blends with border in light mode
"before:absolute before:inset-0 before:-z-10 before:rounded-full before:bg-white before:shadow-sm",
// Background color when checked
"group-data-checked:before:bg-(--radio-checked-bg)",
// Background color is moved to control and shadow is removed in dark mode so hide `before` pseudo
"dark:before:hidden",
// Background color applied to control in dark mode
"dark:bg-white/5 dark:group-data-checked:bg-(--radio-checked-bg)",
// Border
"border border-zinc-950/15 group-data-checked:border-transparent group-data-hover:group-data-checked:border-transparent group-data-hover:border-zinc-950/30 group-data-checked:bg-(--radio-checked-border)",
"dark:border-white/15 dark:group-data-checked:border-white/5 dark:group-data-hover:group-data-checked:border-white/5 dark:group-data-hover:border-white/30",
// Inner highlight shadow
"after:absolute after:inset-0 after:rounded-full after:shadow-[inset_0_1px_--theme(--color-white/15%)]",
"dark:after:-inset-px dark:after:hidden dark:after:rounded-full dark:group-data-checked:after:block",
// Indicator color (light mode)
"[--radio-indicator:transparent] group-data-checked:[--radio-indicator:var(--radio-checked-indicator)] group-data-hover:group-data-checked:[--radio-indicator:var(--radio-checked-indicator)] group-data-hover:[--radio-indicator:var(--color-zinc-900)]/10",
// Indicator color (dark mode)
"dark:group-data-hover:group-data-checked:[--radio-indicator:var(--radio-checked-indicator)] dark:group-data-hover:[--radio-indicator:var(--color-zinc-700)]",
// Focus ring
"group-data-focus:outline group-data-focus:outline-2 group-data-focus:outline-offset-2 group-data-focus:outline-blue-500",
// Disabled state
"group-data-disabled:opacity-50",
"group-data-disabled:border-zinc-950/25 group-data-disabled:bg-zinc-950/5 group-data-disabled:[--radio-checked-indicator:var(--color-zinc-950)]/50 group-data-disabled:before:bg-transparent",
"dark:group-data-disabled:border-white/20 dark:group-data-disabled:bg-white/2.5 dark:group-data-disabled:[--radio-checked-indicator:var(--color-white)]/50 dark:group-data-checked:group-data-disabled:after:hidden",
];
const colors = {
"dark/zinc": [
"[--radio-checked-bg:var(--color-zinc-900)] [--radio-checked-border:var(--color-zinc-950)]/90 [--radio-checked-indicator:var(--color-white)]",
"dark:[--radio-checked-bg:var(--color-zinc-600)]",
],
"dark/white": [
"[--radio-checked-bg:var(--color-zinc-900)] [--radio-checked-border:var(--color-zinc-950)]/90 [--radio-checked-indicator:var(--color-white)]",
"dark:[--radio-checked-bg:var(--color-white)] dark:[--radio-checked-border:var(--color-zinc-950)]/15 dark:[--radio-checked-indicator:var(--color-zinc-900)]",
],
white:
"[--radio-checked-bg:var(--color-white)] [--radio-checked-border:var(--color-zinc-950)]/15 [--radio-checked-indicator:var(--color-zinc-900)]",
dark: "[--radio-checked-bg:var(--color-zinc-900)] [--radio-checked-border:var(--color-zinc-950)]/90 [--radio-checked-indicator:var(--color-white)]",
zinc: "[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-zinc-600)] [--radio-checked-border:var(--color-zinc-700)]/90",
red: "[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-red-600)] [--radio-checked-border:var(--color-red-700)]/90",
orange:
"[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-orange-500)] [--radio-checked-border:var(--color-orange-600)]/90",
amber:
"[--radio-checked-bg:var(--color-amber-400)] [--radio-checked-border:var(--color-amber-500)]/80 [--radio-checked-indicator:var(--color-amber-950)]",
yellow:
"[--radio-checked-bg:var(--color-yellow-300)] [--radio-checked-border:var(--color-yellow-400)]/80 [--radio-checked-indicator:var(--color-yellow-950)]",
lime: "[--radio-checked-bg:var(--color-lime-300)] [--radio-checked-border:var(--color-lime-400)]/80 [--radio-checked-indicator:var(--color-lime-950)]",
green:
"[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-green-600)] [--radio-checked-border:var(--color-green-700)]/90",
emerald:
"[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-emerald-600)] [--radio-checked-border:var(--color-emerald-700)]/90",
teal: "[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-teal-600)] [--radio-checked-border:var(--color-teal-700)]/90",
cyan: "[--radio-checked-bg:var(--color-cyan-300)] [--radio-checked-border:var(--color-cyan-400)]/80 [--radio-checked-indicator:var(--color-cyan-950)]",
sky: "[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-sky-500)] [--radio-checked-border:var(--color-sky-600)]/80",
blue: "[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-blue-600)] [--radio-checked-border:var(--color-blue-700)]/90",
indigo:
"[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-indigo-500)] [--radio-checked-border:var(--color-indigo-600)]/90",
violet:
"[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-violet-500)] [--radio-checked-border:var(--color-violet-600)]/90",
purple:
"[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-purple-500)] [--radio-checked-border:var(--color-purple-600)]/90",
fuchsia:
"[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-fuchsia-500)] [--radio-checked-border:var(--color-fuchsia-600)]/90",
pink: "[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-pink-500)] [--radio-checked-border:var(--color-pink-600)]/90",
rose: "[--radio-checked-indicator:var(--color-white)] [--radio-checked-bg:var(--color-rose-500)] [--radio-checked-border:var(--color-rose-600)]/90",
};
type Color = keyof typeof colors;
export function Radio({
color = "dark/zinc",
className,
...props
}: { color?: Color; className?: string } & Omit<
Headless.RadioProps,
"as" | "className" | "children"
>) {
return (
<Headless.Radio
data-slot="control"
{...props}
className={clsx(className, "group inline-flex focus:outline-hidden")}
>
<span className={clsx([base, colors[color]])}>
<span
className={clsx(
"size-full rounded-full border-[4.5px] border-transparent bg-(--radio-indicator) bg-clip-padding",
// Forced colors mode
"forced-colors:border-[Canvas] forced-colors:group-data-checked:border-[Highlight]",
)}
/>
</span>
</Headless.Radio>
);
}
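
Illustrative usage sketch: RadioGroup's value/onChange come straight from Headless UI, and RadioField lays out the control next to anything carrying data-slot="label". The "./radio" path is assumed, and the plain <span data-slot="label"> stands in for the Label primitive from the fieldset module, which is not shown in this hunk.

// Hypothetical radio group built from the primitives above.
import { useState } from "react";
import { Radio, RadioField, RadioGroup } from "./radio";

const channels = ["stable", "beta"] as const;

export function UpdateChannelPicker() {
  const [channel, setChannel] = useState<string>("stable");
  return (
    <RadioGroup value={channel} onChange={setChannel} aria-label="Update channel">
      {channels.map((option) => (
        <RadioField key={option}>
          <Radio value={option} color="blue" />
          <span data-slot="label">{option}</span>
        </RadioField>
      ))}
    </RadioGroup>
  );
}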

Some files were not shown because too many files have changed in this diff.