mirror of
https://github.com/ollama/ollama.git
synced 2026-01-02 20:48:32 -05:00
Compare commits
118 Commits
scratch
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a859f037da | ||
|
|
27aa2d4a19 | ||
|
|
b9f91a0b36 | ||
|
|
b538dc3858 | ||
|
|
f0e9496c85 | ||
|
|
09a6f76f4c | ||
|
|
e135167484 | ||
|
|
38296ab352 | ||
|
|
f43dea68d1 | ||
|
|
e1f50377f4 | ||
|
|
7913104527 | ||
|
|
bfbf2f7cf7 | ||
|
|
fe3cbd014f | ||
|
|
3d6f48507a | ||
|
|
f3761405c8 | ||
|
|
e49dc9f3d8 | ||
|
|
d125510b4b | ||
|
|
1ca386aa9e | ||
|
|
fb56988014 | ||
|
|
d046bee790 | ||
|
|
f11bf0740b | ||
|
|
8450bf66e6 | ||
|
|
b4e11be8ef | ||
|
|
a896079705 | ||
|
|
583950c828 | ||
|
|
8ac08a0eec | ||
|
|
60f47be64c | ||
|
|
6e56077ada | ||
|
|
98ae9467bb | ||
|
|
b7a24af083 | ||
|
|
c8b1f2369e | ||
|
|
72b12c3be7 | ||
|
|
0632dff3f8 | ||
|
|
509e2dec8a | ||
|
|
78a48de804 | ||
|
|
e7dbb00331 | ||
|
|
c3f9538636 | ||
|
|
2e06ed01d5 | ||
|
|
4072b5879b | ||
|
|
15562e887d | ||
|
|
f2245c7c77 | ||
|
|
e4b9b72f2a | ||
|
|
311f8e0c3f | ||
|
|
f07f8b7a9e | ||
|
|
4c4c730a0a | ||
|
|
e02ecfb6c8 | ||
|
|
c8059b4dcf | ||
|
|
59d87127f5 | ||
|
|
b5cf31b460 | ||
|
|
cc4915e262 | ||
|
|
667a2ba18a | ||
|
|
e054ebe059 | ||
|
|
9d3dcfd0ec | ||
|
|
6e0ea5ecc8 | ||
|
|
a47d8b2557 | ||
|
|
30c43c285c | ||
|
|
23a7ea593b | ||
|
|
75c44aa319 | ||
|
|
9d7b5d6c91 | ||
|
|
5d9c4a5f5a | ||
|
|
197e420a97 | ||
|
|
a34e1ad3cf | ||
|
|
2ae0556292 | ||
|
|
5be9bdd444 | ||
|
|
b706794905 | ||
|
|
a8c5413d06 | ||
|
|
5580de4571 | ||
|
|
946431d5b0 | ||
|
|
0610126049 | ||
|
|
3ebd6a83fc | ||
|
|
a64570dcae | ||
|
|
7c40a67841 | ||
|
|
e64b5b07a2 | ||
|
|
9e1e295cdc | ||
|
|
6eb3cddcb6 | ||
|
|
a4564232a4 | ||
|
|
a643823f86 | ||
|
|
8e5d359a03 | ||
|
|
a170888dd4 | ||
|
|
cd22855ef8 | ||
|
|
013fd07139 | ||
|
|
f63dc2db5c | ||
|
|
eaa5a396d9 | ||
|
|
8ed22f5d72 | ||
|
|
987c16b2f7 | ||
|
|
950f636d64 | ||
|
|
4458efb73a | ||
|
|
ceea599494 | ||
|
|
3005ec74b3 | ||
|
|
0759d8996e | ||
|
|
0f5b843319 | ||
|
|
ffaf52e1e9 | ||
|
|
940b10b036 | ||
|
|
3bc28736cd | ||
|
|
93a756266c | ||
|
|
a0a829bf7a | ||
|
|
730dcfcc7a | ||
|
|
27a2d5af54 | ||
|
|
5f81a33f43 | ||
|
|
6225fde046 | ||
|
|
069184562b | ||
|
|
5576bb2348 | ||
|
|
2738837786 | ||
|
|
ec3764538d | ||
|
|
df54c723ae | ||
|
|
fa8c990e58 | ||
|
|
da72235ebf | ||
|
|
89c4aee29e | ||
|
|
a447a083f2 | ||
|
|
f32ea81b21 | ||
|
|
681a914990 | ||
|
|
4c54f0ddeb | ||
|
|
c08dfaa23d | ||
|
|
3b76e736ae | ||
|
|
552db98bf1 | ||
|
|
fdcdfef620 | ||
|
|
6a042438af | ||
|
|
27331ae3a8 |
98
.github/workflows/test.yaml
vendored
98
.github/workflows/test.yaml
vendored
@@ -23,29 +23,72 @@ jobs:
|
||||
with:
|
||||
go-version: '1.21'
|
||||
cache: true
|
||||
- if: ${{ startsWith(matrix.os, 'windows-') }}
|
||||
shell: pwsh
|
||||
run: |
|
||||
$path = vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath
|
||||
if ($path) {
|
||||
$path = join-path $path 'Common7\Tools\vsdevcmd.bat'
|
||||
if (test-path $path) {
|
||||
cmd /s /c """$path"" $args && set" | where { $_ -match '(\w+)=(.*)' } | foreach {
|
||||
echo "$($Matches[1])=$($Matches[2])" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
echo "C:\Program Files\Git\usr\bin" | Out-File -FilePath $Env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- run: go get ./...
|
||||
- run: go generate -x ./...
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
|
||||
path: |
|
||||
llm/llama.cpp/build/**/lib/*
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
generate-cuda:
|
||||
strategy:
|
||||
matrix:
|
||||
cuda-version:
|
||||
- '11.8.0'
|
||||
runs-on: ubuntu-latest
|
||||
container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04
|
||||
steps:
|
||||
- run: |
|
||||
apt-get update && apt-get install -y git build-essential curl
|
||||
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
|
||||
| tar -zx -C /usr --strip-components 1
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.21'
|
||||
cache: true
|
||||
- run: go get ./...
|
||||
- run: |
|
||||
git config --global --add safe.directory /__w/ollama/ollama
|
||||
go generate -x ./...
|
||||
env:
|
||||
OLLAMA_SKIP_CPU_GENERATE: '1'
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: cuda-${{ matrix.cuda-version }}-libraries
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
generate-rocm:
|
||||
strategy:
|
||||
matrix:
|
||||
rocm-version:
|
||||
- '5.7.1'
|
||||
- '6.0'
|
||||
runs-on: ubuntu-latest
|
||||
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
|
||||
steps:
|
||||
- run: |
|
||||
apt-get update && apt-get install -y git build-essential curl rocm-libs
|
||||
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
|
||||
| tar -zx -C /usr --strip-components 1
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.21'
|
||||
cache: true
|
||||
- run: go get ./...
|
||||
- run: |
|
||||
git config --global --add safe.directory /__w/ollama/ollama
|
||||
go generate -x ./...
|
||||
env:
|
||||
OLLAMA_SKIP_CPU_GENERATE: '1'
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: rocm-${{ matrix.rocm-version }}-libraries
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
lint:
|
||||
needs: generate
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
@@ -69,10 +112,19 @@ jobs:
|
||||
with:
|
||||
go-version: '1.21'
|
||||
cache: false
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
|
||||
path: llm/llama.cpp/build
|
||||
- run: |
|
||||
mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/
|
||||
touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so
|
||||
if: ${{ startsWith(matrix.os, 'ubuntu-') }}
|
||||
- run: |
|
||||
mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/
|
||||
touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib
|
||||
touch llm/llama.cpp/ggml-metal.metal
|
||||
if: ${{ startsWith(matrix.os, 'macos-') }}
|
||||
- run: |
|
||||
mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/
|
||||
touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll
|
||||
if: ${{ startsWith(matrix.os, 'windows-') }}
|
||||
- uses: golangci/golangci-lint-action@v3
|
||||
test:
|
||||
needs: generate
|
||||
@@ -104,3 +156,7 @@ jobs:
|
||||
path: llm/llama.cpp/build
|
||||
- run: go build
|
||||
- run: go test -v ./...
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.os }}-binaries
|
||||
path: ollama
|
||||
|
||||
138
Dockerfile
138
Dockerfile
@@ -1,27 +1,135 @@
|
||||
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
ARG GOLANG_VERSION=1.21.3
|
||||
ARG CMAKE_VERSION=3.22.1
|
||||
ARG CUDA_VERSION=11.3.1
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
# Copy the minimal context we need to run the generate scripts
|
||||
FROM scratch AS llm-code
|
||||
COPY .git .git
|
||||
COPY .gitmodules .gitmodules
|
||||
COPY llm llm
|
||||
|
||||
FROM --platform=linux/amd64 nvidia/cuda:$CUDA_VERSION-devel-centos7 AS cuda-build-amd64
|
||||
ARG CMAKE_VERSION
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
ARG CGO_CFLAGS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
|
||||
ARG CMAKE_VERSION
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
ARG CGO_CFLAGS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete AS rocm-5-build-amd64
|
||||
ARG CMAKE_VERSION
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
ENV LIBRARY_PATH /opt/amdgpu/lib64
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
ARG CGO_CFLAGS
|
||||
ARG AMDGPU_TARGETS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete AS rocm-6-build-amd64
|
||||
ARG CMAKE_VERSION
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
ENV LIBRARY_PATH /opt/amdgpu/lib64
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
ARG CGO_CFLAGS
|
||||
ARG AMDGPU_TARGETS
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/amd64 centos:7 AS cpu-builder-amd64
|
||||
ARG CMAKE_VERSION
|
||||
ARG GOLANG_VERSION
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
ARG OLLAMA_CUSTOM_CPU_DEFS
|
||||
ARG CGO_CFLAGS
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/arm64 centos:7 AS cpu-build-arm64
|
||||
ARG CMAKE_VERSION
|
||||
ARG GOLANG_VERSION
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
# Note, we only build the "base" CPU variant on arm since avx/avx2 are x86 features
|
||||
ARG OLLAMA_CUSTOM_CPU_DEFS
|
||||
ARG CGO_CFLAGS
|
||||
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
|
||||
# Intermediate stage used for ./scripts/build_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64
|
||||
ENV CGO_ENABLED 1
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama
|
||||
RUN apt-get update && apt-get install -y git build-essential cmake
|
||||
ADD https://dl.google.com/go/go1.21.3.linux-$TARGETARCH.tar.gz /tmp/go1.21.3.tar.gz
|
||||
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.3.tar.gz
|
||||
|
||||
COPY . .
|
||||
ENV GOARCH=$TARGETARCH
|
||||
ENV GOFLAGS=$GOFLAGS
|
||||
RUN /usr/local/go/bin/go generate ./... \
|
||||
&& /usr/local/go/bin/go build .
|
||||
COPY --from=cpu_avx-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=cpu_avx2-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=cuda-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=rocm-5-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=rocm-6-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
ARG GOFLAGS
|
||||
ARG CGO_CFLAGS
|
||||
RUN go build .
|
||||
|
||||
FROM ubuntu:22.04
|
||||
# Intermediate stage used for ./scripts/build_linux.sh
|
||||
FROM --platform=linux/arm64 cpu-build-arm64 AS build-arm64
|
||||
ENV CGO_ENABLED 1
|
||||
ARG GOLANG_VERSION
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama
|
||||
COPY . .
|
||||
COPY --from=cuda-build-arm64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
ARG GOFLAGS
|
||||
ARG CGO_CFLAGS
|
||||
RUN go build .
|
||||
|
||||
# Runtime stages
|
||||
FROM --platform=linux/amd64 ubuntu:22.04 as runtime-amd64
|
||||
RUN apt-get update && apt-get install -y ca-certificates
|
||||
COPY --from=0 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
|
||||
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
|
||||
FROM --platform=linux/arm64 ubuntu:22.04 as runtime-arm64
|
||||
RUN apt-get update && apt-get install -y ca-certificates
|
||||
COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
|
||||
|
||||
# Radeon images are much larger so we keep it distinct from the CPU/CUDA image
|
||||
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete as runtime-rocm
|
||||
RUN update-pciids
|
||||
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
|
||||
EXPOSE 11434
|
||||
ENV OLLAMA_HOST 0.0.0.0
|
||||
|
||||
# set some environment variable for better NVIDIA compatibility
|
||||
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
ENTRYPOINT ["/bin/ollama"]
|
||||
CMD ["serve"]
|
||||
|
||||
FROM runtime-$TARGETARCH
|
||||
EXPOSE 11434
|
||||
ENV OLLAMA_HOST 0.0.0.0
|
||||
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
||||
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
ARG GOLANG_VERSION=1.21.3
|
||||
ARG CMAKE_VERSION=3.22.1
|
||||
ARG CUDA_VERSION=11.3.1
|
||||
|
||||
# Copy the minimal context we need to run the generate scripts
|
||||
FROM scratch AS llm-code
|
||||
COPY .git .git
|
||||
COPY .gitmodules .gitmodules
|
||||
COPY llm llm
|
||||
|
||||
FROM --platform=linux/amd64 nvidia/cuda:$CUDA_VERSION-devel-centos7 AS cuda-build-amd64
|
||||
ARG CMAKE_VERSION
|
||||
ARG CGO_CFLAGS
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
|
||||
ARG CMAKE_VERSION
|
||||
ARG CGO_CFLAGS
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete AS rocm-5-build-amd64
|
||||
ARG CMAKE_VERSION
|
||||
ARG CGO_CFLAGS
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
ENV LIBRARY_PATH /opt/amdgpu/lib64
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete AS rocm-6-build-amd64
|
||||
ARG CMAKE_VERSION
|
||||
ARG CGO_CFLAGS
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
ENV LIBRARY_PATH /opt/amdgpu/lib64
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/amd64 centos:7 AS cpu-build-amd64
|
||||
ARG CMAKE_VERSION
|
||||
ARG GOLANG_VERSION
|
||||
ARG OLLAMA_CUSTOM_CPU_DEFS
|
||||
ARG CGO_CFLAGS
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
RUN sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/arm64 centos:7 AS cpu-build-arm64
|
||||
ARG CMAKE_VERSION
|
||||
ARG GOLANG_VERSION
|
||||
ARG OLLAMA_CUSTOM_CPU_DEFS
|
||||
ARG CGO_CFLAGS
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
|
||||
RUN sh gen_linux.sh
|
||||
|
||||
|
||||
FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64
|
||||
ENV CGO_ENABLED 1
|
||||
ARG GOFLAGS
|
||||
ARG CGO_CFLAGS
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama
|
||||
COPY . .
|
||||
COPY --from=cuda-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=rocm-5-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=rocm-6-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
RUN go build .
|
||||
|
||||
FROM --platform=linux/arm64 cpu-build-arm64 AS build-arm64
|
||||
ENV CGO_ENABLED 1
|
||||
ARG GOLANG_VERSION
|
||||
ARG GOFLAGS
|
||||
ARG CGO_CFLAGS
|
||||
WORKDIR /go/src/github.com/jmorganca/ollama
|
||||
COPY . .
|
||||
COPY --from=cuda-build-arm64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
RUN go build .
|
||||
|
||||
FROM build-$TARGETARCH
|
||||
29
README.md
29
README.md
@@ -1,8 +1,5 @@
|
||||
<div align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" height="200px" srcset="https://github.com/jmorganca/ollama/assets/3325447/56ea1849-1284-4645-8970-956de6e51c3c">
|
||||
<img alt="logo" height="200px" src="https://github.com/jmorganca/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||
</picture>
|
||||
<img alt="ollama" height="200px" src="https://github.com/jmorganca/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||
</div>
|
||||
|
||||
# Ollama
|
||||
@@ -31,6 +28,11 @@ curl https://ollama.ai/install.sh | sh
|
||||
|
||||
The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `ollama/ollama` is available on Docker Hub.
|
||||
|
||||
### Libraries
|
||||
|
||||
- [ollama-python](https://github.com/ollama/ollama-python)
|
||||
- [ollama-js](https://github.com/ollama/ollama-js)
|
||||
|
||||
## Quickstart
|
||||
|
||||
To run and chat with [Llama 2](https://ollama.ai/library/llama2):
|
||||
@@ -198,18 +200,21 @@ brew install cmake go
|
||||
```
|
||||
|
||||
Then generate dependencies:
|
||||
|
||||
```
|
||||
go generate ./...
|
||||
```
|
||||
|
||||
Then build the binary:
|
||||
|
||||
```
|
||||
go build .
|
||||
```
|
||||
|
||||
More detailed instructions can be found in the [developer guide](https://github.com/jmorganca/ollama/blob/main/docs/development.md)
|
||||
|
||||
|
||||
### Running local builds
|
||||
|
||||
Next, start the server:
|
||||
|
||||
```
|
||||
@@ -248,13 +253,10 @@ curl http://localhost:11434/api/chat -d '{
|
||||
|
||||
See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
## Integrations
|
||||
|
||||
- [ollama-python](https://github.com/jmorganca/ollama-python)
|
||||
|
||||
## Community Integrations
|
||||
|
||||
### Web & Desktop
|
||||
|
||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||
@@ -267,7 +269,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Amica](https://github.com/semperai/amica)
|
||||
- [chatd](https://github.com/BruceMacD/chatd)
|
||||
- [Ollama-SwiftUI](https://github.com/kghandour/Ollama-SwiftUI)
|
||||
|
||||
- [MindMac](https://mindmac.app)
|
||||
|
||||
### Terminal
|
||||
|
||||
@@ -280,6 +282,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [gptel Emacs client](https://github.com/karthink/gptel)
|
||||
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
|
||||
- [cmdh](https://github.com/pgibler/cmdh)
|
||||
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
|
||||
|
||||
### Database
|
||||
|
||||
@@ -306,7 +309,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [LangChainDart](https://github.com/davidmigloz/langchain_dart)
|
||||
- [Semantic Kernel - Python](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/ai/ollama)
|
||||
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
|
||||
|
||||
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)
|
||||
|
||||
### Mobile
|
||||
|
||||
@@ -327,4 +330,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
|
||||
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
|
||||
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
|
||||
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
|
||||
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
|
||||
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
|
||||
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
|
||||
|
||||
129
api/types.go
129
api/types.go
@@ -34,24 +34,26 @@ func (e StatusError) Error() string {
|
||||
type ImageData []byte
|
||||
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
System string `json:"system"`
|
||||
Template string `json:"template"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Format string `json:"format"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
System string `json:"system"`
|
||||
Template string `json:"template"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Format string `json:"format"`
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Format string `json:"format"`
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Format string `json:"format"`
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
@@ -126,8 +128,9 @@ type Runner struct {
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
@@ -171,6 +174,7 @@ type ShowResponse struct {
|
||||
Template string `json:"template,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
type CopyRequest struct {
|
||||
@@ -236,6 +240,7 @@ type GenerateResponse struct {
|
||||
}
|
||||
|
||||
type ModelDetails struct {
|
||||
ParentModel string `json:"parent_model"`
|
||||
Format string `json:"format"`
|
||||
Family string `json:"family"`
|
||||
Families []string `json:"families"`
|
||||
@@ -274,85 +279,20 @@ func (m *Metrics) Summary() {
|
||||
var ErrInvalidOpts = fmt.Errorf("invalid options")
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// build map of json struct tags to their types
|
||||
jsonOpts := make(map[string]reflect.StructField)
|
||||
for _, field := range reflect.VisibleFields(typeOpts) {
|
||||
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
|
||||
if jsonTag != "" {
|
||||
jsonOpts[jsonTag] = field
|
||||
err = json.Unmarshal(data, opts)
|
||||
if err != nil {
|
||||
// Custom error handling
|
||||
if jsonErr, ok := err.(*json.UnmarshalTypeError); ok {
|
||||
return fmt.Errorf("invalid type for option '%v': expected %v, got %v", jsonErr.Field, jsonErr.Type, jsonErr.Value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
invalidOpts := []string{}
|
||||
for key, val := range m {
|
||||
if opt, ok := jsonOpts[key]; ok {
|
||||
field := valueOpts.FieldByName(opt.Name)
|
||||
if field.IsValid() && field.CanSet() {
|
||||
if val == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.Int:
|
||||
switch t := val.(type) {
|
||||
case int64:
|
||||
field.SetInt(t)
|
||||
case float64:
|
||||
// when JSON unmarshals numbers, it uses float64, not int
|
||||
field.SetInt(int64(t))
|
||||
default:
|
||||
return fmt.Errorf("option %q must be of type integer", key)
|
||||
}
|
||||
case reflect.Bool:
|
||||
val, ok := val.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type boolean", key)
|
||||
}
|
||||
field.SetBool(val)
|
||||
case reflect.Float32:
|
||||
// JSON unmarshals to float64
|
||||
val, ok := val.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type float32", key)
|
||||
}
|
||||
field.SetFloat(val)
|
||||
case reflect.String:
|
||||
val, ok := val.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type string", key)
|
||||
}
|
||||
field.SetString(val)
|
||||
case reflect.Slice:
|
||||
// JSON unmarshals to []interface{}, not []string
|
||||
val, ok := val.([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type array", key)
|
||||
}
|
||||
// convert []interface{} to []string
|
||||
slice := make([]string, len(val))
|
||||
for i, item := range val {
|
||||
str, ok := item.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of an array of strings", key)
|
||||
}
|
||||
slice[i] = str
|
||||
}
|
||||
field.Set(reflect.ValueOf(slice))
|
||||
default:
|
||||
return fmt.Errorf("unknown type loading config params: %v", field.Kind())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
invalidOpts = append(invalidOpts, key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(invalidOpts) > 0 {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidOpts, strings.Join(invalidOpts, ", "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -411,14 +351,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
case float64:
|
||||
if t < 0 {
|
||||
t = math.MaxFloat64
|
||||
d.Duration = time.Duration(t)
|
||||
} else {
|
||||
d.Duration = time.Duration(t * float64(time.Second))
|
||||
}
|
||||
|
||||
d.Duration = time.Duration(t)
|
||||
case string:
|
||||
d.Duration, err = time.ParseDuration(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.Duration < 0 {
|
||||
mf := math.MaxFloat64
|
||||
d.Duration = time.Duration(mf)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
127
cmd/cmd.go
127
cmd/cmd.go
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
@@ -146,19 +147,68 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
|
||||
// check if the model exists on the server
|
||||
_, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
var statusError api.StatusError
|
||||
switch {
|
||||
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
|
||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
return RunGenerate(cmd, args)
|
||||
interactive := true
|
||||
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]interface{}{},
|
||||
MultiModal: slices.Contains(show.Details.Families, "clip"),
|
||||
ParentModel: show.Details.ParentModel,
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Format = format
|
||||
|
||||
prompts := args[1:]
|
||||
// prepend stdin to the prompt if provided
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
in, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prompts = append([]string{string(in)}, prompts...)
|
||||
opts.WordWrap = false
|
||||
interactive = false
|
||||
}
|
||||
opts.Prompt = strings.Join(prompts, " ")
|
||||
if len(prompts) > 0 {
|
||||
interactive = false
|
||||
}
|
||||
|
||||
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.WordWrap = !nowrap
|
||||
|
||||
if !interactive {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
@@ -410,63 +460,20 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunGenerate(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]interface{}{},
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Format = format
|
||||
|
||||
prompts := args[1:]
|
||||
// prepend stdin to the prompt if provided
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
in, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prompts = append([]string{string(in)}, prompts...)
|
||||
opts.WordWrap = false
|
||||
interactive = false
|
||||
}
|
||||
opts.Prompt = strings.Join(prompts, " ")
|
||||
if len(prompts) > 0 {
|
||||
interactive = false
|
||||
}
|
||||
|
||||
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.WordWrap = !nowrap
|
||||
|
||||
if !interactive {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
|
||||
type generateContextKey string
|
||||
|
||||
type runOptions struct {
|
||||
Model string
|
||||
Prompt string
|
||||
Messages []api.Message
|
||||
WordWrap bool
|
||||
Format string
|
||||
System string
|
||||
Template string
|
||||
Images []api.ImageData
|
||||
Options map[string]interface{}
|
||||
Model string
|
||||
ParentModel string
|
||||
Prompt string
|
||||
Messages []api.Message
|
||||
WordWrap bool
|
||||
Format string
|
||||
System string
|
||||
Template string
|
||||
Images []api.ImageData
|
||||
Options map[string]interface{}
|
||||
MultiModal bool
|
||||
}
|
||||
|
||||
type displayResponseState struct {
|
||||
@@ -628,10 +635,18 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if opts.MultiModal {
|
||||
opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
request := api.GenerateRequest{
|
||||
Model: opts.Model,
|
||||
Prompt: opts.Prompt,
|
||||
Context: generateContext,
|
||||
Images: opts.Images,
|
||||
Format: opts.Format,
|
||||
System: opts.System,
|
||||
Template: opts.Template,
|
||||
|
||||
@@ -6,13 +6,16 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/progress"
|
||||
"github.com/jmorganca/ollama/readline"
|
||||
)
|
||||
|
||||
@@ -25,45 +28,82 @@ const (
|
||||
MultilineTemplate
|
||||
)
|
||||
|
||||
func modelIsMultiModal(cmd *cobra.Command, name string) bool {
|
||||
// get model details
|
||||
func loadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return false
|
||||
return err
|
||||
}
|
||||
|
||||
req := api.ShowRequest{Name: name}
|
||||
resp, err := client.Show(cmd.Context(), &req)
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.StopAndClear()
|
||||
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
showReq := api.ShowRequest{Name: opts.Model}
|
||||
showResp, err := client.Show(cmd.Context(), &showReq)
|
||||
if err != nil {
|
||||
return false
|
||||
return err
|
||||
}
|
||||
opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
|
||||
opts.ParentModel = showResp.Details.ParentModel
|
||||
|
||||
if len(showResp.Messages) > 0 {
|
||||
opts.Messages = append(opts.Messages, showResp.Messages...)
|
||||
}
|
||||
|
||||
return slices.Contains(resp.Details.Families, "clip")
|
||||
chatReq := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: []api.Message{},
|
||||
}
|
||||
err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
|
||||
p.StopAndClear()
|
||||
if len(opts.Messages) > 0 {
|
||||
for _, msg := range opts.Messages {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
fmt.Printf(">>> %s\n", msg.Content)
|
||||
case "assistant":
|
||||
state := &displayResponseState{}
|
||||
displayResponse(msg.Content, opts.WordWrap, state)
|
||||
fmt.Println()
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
multiModal := modelIsMultiModal(cmd, opts.Model)
|
||||
opts.Messages = make([]api.Message, 0)
|
||||
|
||||
// load the model
|
||||
loadOpts := runOptions{
|
||||
Model: opts.Model,
|
||||
Prompt: "",
|
||||
Messages: []api.Message{},
|
||||
}
|
||||
if _, err := chat(cmd, loadOpts); err != nil {
|
||||
err := loadModel(cmd, &opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usage := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
fmt.Fprintln(os.Stderr, " /set Set session variables")
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
||||
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
if opts.MultiModal {
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
@@ -140,7 +180,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
var sb strings.Builder
|
||||
var multiline MultilineState
|
||||
opts.Messages = make([]api.Message, 0)
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -174,6 +213,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
switch multiline {
|
||||
case MultilineSystem:
|
||||
opts.System = sb.String()
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
case MultilineTemplate:
|
||||
@@ -193,7 +233,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(&sb)
|
||||
multiline = MultilinePrompt
|
||||
scanner.Prompt.UseAlt = true
|
||||
break
|
||||
}
|
||||
case scanner.Pasting:
|
||||
fmt.Fprintln(&sb, line)
|
||||
@@ -203,6 +242,44 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
if err := ListHandler(cmd, args[1:]); err != nil {
|
||||
return err
|
||||
}
|
||||
case strings.HasPrefix(line, "/load"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage:\n /load <modelname>")
|
||||
continue
|
||||
}
|
||||
opts.Model = args[1]
|
||||
opts.Messages = []api.Message{}
|
||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||
if err := loadModel(cmd, &opts); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/save"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) != 2 {
|
||||
fmt.Println("Usage:\n /save <modelname>")
|
||||
continue
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
|
||||
req := &api.CreateRequest{
|
||||
Name: args[1],
|
||||
Modelfile: buildModelfile(opts),
|
||||
}
|
||||
fn := func(resp api.ProgressResponse) error { return nil }
|
||||
err = client.Create(cmd.Context(), req, fn)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't save model")
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Created new model '%s'\n", args[1])
|
||||
continue
|
||||
case strings.HasPrefix(line, "/set"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
@@ -278,10 +355,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
if args[1] == "system" {
|
||||
opts.System = sb.String()
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
} else if args[1] == "template" {
|
||||
opts.Template = sb.String()
|
||||
fmt.Println("Set prompt template.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
@@ -389,7 +469,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
args := strings.Fields(line)
|
||||
isFile := false
|
||||
|
||||
if multiModal {
|
||||
if opts.MultiModal {
|
||||
for _, f := range extractFileNames(line) {
|
||||
if strings.HasPrefix(f, args[0]) {
|
||||
isFile = true
|
||||
@@ -411,34 +491,23 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
if sb.Len() > 0 && multiline == MultilineNone {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
|
||||
if multiModal {
|
||||
if opts.MultiModal {
|
||||
msg, images, err := extractFileData(sb.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newMessage.Content = msg
|
||||
|
||||
// reset the context if we find another image
|
||||
// clear all previous images for better responses
|
||||
if len(images) > 0 {
|
||||
newMessage.Images = append(newMessage.Images, images...)
|
||||
// reset the context for the new image
|
||||
opts.Messages = []api.Message{}
|
||||
} else {
|
||||
if len(opts.Messages) > 1 {
|
||||
newMessage.Images = append(newMessage.Images, opts.Messages[len(opts.Messages)-2].Images...)
|
||||
for i := range opts.Messages {
|
||||
opts.Messages[i].Images = nil
|
||||
}
|
||||
}
|
||||
if len(newMessage.Images) == 0 {
|
||||
fmt.Println("This model requires you to add a jpeg, png, or svg image.")
|
||||
fmt.Println()
|
||||
sb.Reset()
|
||||
continue
|
||||
}
|
||||
|
||||
newMessage.Content = msg
|
||||
newMessage.Images = images
|
||||
}
|
||||
|
||||
if opts.System != "" {
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||
}
|
||||
opts.Messages = append(opts.Messages, newMessage)
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
@@ -454,6 +523,38 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
}
|
||||
|
||||
func buildModelfile(opts runOptions) string {
|
||||
var mf strings.Builder
|
||||
model := opts.ParentModel
|
||||
if model == "" {
|
||||
model = opts.Model
|
||||
}
|
||||
fmt.Fprintf(&mf, "FROM %s\n", model)
|
||||
if opts.System != "" {
|
||||
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
|
||||
}
|
||||
|
||||
if opts.Template != "" {
|
||||
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
|
||||
}
|
||||
|
||||
keys := make([]string, 0)
|
||||
for k := range opts.Options {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k])
|
||||
}
|
||||
fmt.Fprintln(&mf)
|
||||
|
||||
for _, msg := range opts.Messages {
|
||||
fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content)
|
||||
}
|
||||
|
||||
return mf.String()
|
||||
}
|
||||
|
||||
func normalizeFilePath(fp string) string {
|
||||
// Define a map of escaped characters and their replacements
|
||||
replacements := map[string]string{
|
||||
@@ -500,10 +601,10 @@ func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Couldn't process image: %q\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
|
||||
return "", imgs, err
|
||||
}
|
||||
fmt.Printf("Added image '%s'\n", nfp)
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
func TestExtractFilenames(t *testing.T) {
|
||||
@@ -49,3 +53,64 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
|
||||
assert.Contains(t, res[9], "ten.svg")
|
||||
assert.Contains(t, res[9], "E:")
|
||||
}
|
||||
|
||||
func TestModelfileBuilder(t *testing.T) {
|
||||
opts := runOptions{
|
||||
Model: "hork",
|
||||
System: "You are part horse and part shark, but all hork. Do horklike things",
|
||||
Template: "This is a template.",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hey there hork!"},
|
||||
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
|
||||
},
|
||||
Options: map[string]interface{}{},
|
||||
}
|
||||
|
||||
opts.Options["temperature"] = 0.9
|
||||
opts.Options["seed"] = 42
|
||||
opts.Options["penalize_newline"] = false
|
||||
opts.Options["stop"] = []string{"hi", "there"}
|
||||
|
||||
mf := buildModelfile(opts)
|
||||
expectedModelfile := `FROM {{.Model}}
|
||||
SYSTEM """{{.System}}"""
|
||||
TEMPLATE """{{.Template}}"""
|
||||
PARAMETER penalize_newline false
|
||||
PARAMETER seed 42
|
||||
PARAMETER stop [hi there]
|
||||
PARAMETER temperature 0.9
|
||||
|
||||
MESSAGE user """Hey there hork!"""
|
||||
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
|
||||
`
|
||||
|
||||
tmpl, err := template.New("").Parse(expectedModelfile)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, opts)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, buf.String(), mf)
|
||||
|
||||
opts.ParentModel = "horseshark"
|
||||
mf = buildModelfile(opts)
|
||||
expectedModelfile = `FROM {{.ParentModel}}
|
||||
SYSTEM """{{.System}}"""
|
||||
TEMPLATE """{{.Template}}"""
|
||||
PARAMETER penalize_newline false
|
||||
PARAMETER seed 42
|
||||
PARAMETER stop [hi there]
|
||||
PARAMETER temperature 0.9
|
||||
|
||||
MESSAGE user """Hey there hork!"""
|
||||
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
|
||||
`
|
||||
|
||||
tmpl, err = template.New("").Parse(expectedModelfile)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var parentBuf bytes.Buffer
|
||||
err = tmpl.Execute(&parentBuf, opts)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, parentBuf.String(), mf)
|
||||
}
|
||||
|
||||
@@ -542,7 +542,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
"role": "user",
|
||||
"content": "what is in this image?",
|
||||
"images": ["iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"]
|
||||
},
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -50,7 +50,8 @@ development and runtime packages.
|
||||
Typically the build scripts will auto-detect CUDA, however, if your Linux distro
|
||||
or installation approach uses unusual paths, you can specify the location by
|
||||
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
|
||||
libraries, and `CUDACXX` to the location of the nvcc compiler.
|
||||
libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
|
||||
set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
|
||||
|
||||
Then generate dependencies:
|
||||
|
||||
@@ -74,7 +75,8 @@ Typically the build scripts will auto-detect ROCm, however, if your Linux distro
|
||||
or installation approach uses unusual paths, you can specify the location by
|
||||
specifying an environment variable `ROCM_PATH` to the location of the ROCm
|
||||
install (typically `/opt/rocm`), and `CLBlast_DIR` to the location of the
|
||||
CLBlast install (typically `/usr/lib/cmake/CLBlast`).
|
||||
CLBlast install (typically `/usr/lib/cmake/CLBlast`). You can also customize
|
||||
the AMD GPU targets by setting AMDGPU_TARGETS (e.g. `AMDGPU_TARGETS="gfx1101;gfx1102"`)
|
||||
|
||||
```
|
||||
go generate ./...
|
||||
|
||||
59
docs/faq.md
59
docs/faq.md
@@ -8,35 +8,38 @@ To upgrade Ollama, run the installation process again. On the Mac, click the Oll
|
||||
|
||||
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
|
||||
|
||||
## How do I use Ollama server environment variables on Mac
|
||||
## How do I configure Ollama server?
|
||||
|
||||
On macOS, Ollama runs in the background and is managed by the menubar app. If adding environment variables, Ollama will need to be run manually.
|
||||
Ollama server can be configured with environment variables.
|
||||
|
||||
1. Click the menubar icon for Ollama and choose **Quit Ollama**.
|
||||
2. Open a new terminal window and run the following command (this example uses `OLLAMA_HOST` with an IP address of `123.1.1.1`):
|
||||
### Setting environment variables on Mac
|
||||
|
||||
```bash
|
||||
OLLAMA_HOST=123.1.1.1 ollama serve
|
||||
```
|
||||
If Ollama is run as a macOS application, environment variables should be set using `launchctl`:
|
||||
|
||||
## How do I use Ollama server environment variables on Linux?
|
||||
1. For each environment variable, call `launchctl setenv`.
|
||||
|
||||
If Ollama is installed with the install script, a systemd service was created, running as the Ollama user. To add an environment variable, such as OLLAMA_HOST, follow these steps:
|
||||
```bash
|
||||
launchctl setenv OLLAMA_HOST "0.0.0.0"
|
||||
```
|
||||
|
||||
1. Create a `systemd` drop-in directory and add a config file. This is only needed once.
|
||||
2. Restart Ollama application.
|
||||
|
||||
```bash
|
||||
mkdir -p /etc/systemd/system/ollama.service.d
|
||||
echo '[Service]' >>/etc/systemd/system/ollama.service.d/environment.conf
|
||||
```
|
||||
### Setting environment variables on Linux
|
||||
|
||||
2. For each environment variable, add it to the config file:
|
||||
If Ollama is run as a systemd service, environment variables should be set using `systemctl`:
|
||||
|
||||
```bash
|
||||
echo 'Environment="OLLAMA_HOST=0.0.0.0:11434"' >>/etc/systemd/system/ollama.service.d/environment.conf
|
||||
```
|
||||
1. Edit the systemd service by calling `systemctl edit ollama.service`. This will open an editor.
|
||||
|
||||
3. Reload `systemd` and restart Ollama:
|
||||
2. For each environment variable, add a line `Environment` under section `[Service]`:
|
||||
|
||||
```ini
|
||||
[Service]
|
||||
Environment="OLLAMA_HOST=0.0.0.0"
|
||||
```
|
||||
|
||||
3. Save and exit.
|
||||
|
||||
4. Reload `systemd` and restart Ollama:
|
||||
|
||||
```bash
|
||||
systemctl daemon-reload
|
||||
@@ -45,26 +48,26 @@ If Ollama is installed with the install script, a systemd service was created, r
|
||||
|
||||
## How can I expose Ollama on my network?
|
||||
|
||||
Ollama binds to 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable. Refer to the section above for how to use environment variables on your platform.
|
||||
Ollama binds 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable.
|
||||
|
||||
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
|
||||
|
||||
## How can I allow additional web origins to access Ollama?
|
||||
|
||||
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Add additional origins with the `OLLAMA_ORIGINS` environment variable. For example, to add all ports on 192.168.1.1 and https://example.com, use:
|
||||
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
|
||||
|
||||
```shell
|
||||
OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com
|
||||
```
|
||||
|
||||
Refer to the section above for how to use environment variables on your platform.
|
||||
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
|
||||
|
||||
## Where are models stored?
|
||||
|
||||
- macOS: `~/.ollama/models`.
|
||||
- Linux: `/usr/share/ollama/.ollama/models`
|
||||
|
||||
## How do I set them to a different location?
|
||||
### How do I set them to a different location?
|
||||
|
||||
If a different directory needs to be used, set the environment variable `OLLAMA_MODELS` to the chosen directory. Refer to the section above for how to use environment variables on your platform.
|
||||
If a different directory needs to be used, set the environment variable `OLLAMA_MODELS` to the chosen directory.
|
||||
|
||||
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
|
||||
|
||||
## Does Ollama send my prompts and answers back to Ollama.ai to use in any way?
|
||||
|
||||
|
||||
112
docs/import.md
112
docs/import.md
@@ -15,7 +15,7 @@ FROM ./mistral-7b-v0.1.Q4_0.gguf
|
||||
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
|
||||
|
||||
```
|
||||
FROM ./q4_0.bin
|
||||
FROM ./mistral-7b-v0.1.Q4_0.gguf
|
||||
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
|
||||
```
|
||||
|
||||
@@ -37,55 +37,69 @@ ollama run example "What is your favourite condiment?"
|
||||
|
||||
## Importing (PyTorch & Safetensors)
|
||||
|
||||
### Supported models
|
||||
> Importing from PyTorch and Safetensors is a longer process than importing from GGUF. Improvements that make it easier are a work in progress.
|
||||
|
||||
Ollama supports a set of model architectures, with support for more coming soon:
|
||||
### Setup
|
||||
|
||||
- Llama & Mistral
|
||||
- Falcon & RW
|
||||
- BigCode
|
||||
First, clone the `ollama/ollama` repo:
|
||||
|
||||
To view a model's architecture, check the `config.json` file in its HuggingFace repo. You should see an entry under `architectures` (e.g. `LlamaForCausalLM`).
|
||||
```
|
||||
git clone git@github.com:ollama/ollama.git ollama
|
||||
cd ollama
|
||||
```
|
||||
|
||||
### Step 1: Clone the HuggingFace repository (optional)
|
||||
and then fetch its `llama.cpp` submodule:
|
||||
|
||||
```shell
|
||||
git submodule init
|
||||
git submodule update llm/llama.cpp
|
||||
```
|
||||
|
||||
Next, install the Python dependencies:
|
||||
|
||||
```
|
||||
python3 -m venv llm/llama.cpp/.venv
|
||||
source llm/llama.cpp/.venv/bin/activate
|
||||
pip install -r llm/llama.cpp/requirements.txt
|
||||
```
|
||||
|
||||
Then build the `quantize` tool:
|
||||
|
||||
```
|
||||
make -C llm/llama.cpp quantize
|
||||
```
|
||||
|
||||
### Clone the HuggingFace repository (optional)
|
||||
|
||||
If the model is currently hosted in a HuggingFace repository, first clone that repository to download the raw model.
|
||||
|
||||
Install [Git LFS](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage), verify it's installed, and then clone the model's repository:
|
||||
|
||||
```
|
||||
git lfs install
|
||||
git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
|
||||
cd Mistral-7B-Instruct-v0.1
|
||||
git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 model
|
||||
```
|
||||
|
||||
### Step 2: Convert and quantize to a `.bin` file (optional, for PyTorch and Safetensors)
|
||||
### Convert the model
|
||||
|
||||
If the model is in PyTorch or Safetensors format, a [Docker image](https://hub.docker.com/r/ollama/quantize) with the tooling required to convert and quantize models is available.
|
||||
|
||||
First, Install [Docker](https://www.docker.com/get-started/).
|
||||
|
||||
Next, to convert and quantize your model, run:
|
||||
> Note: some model architectures require using specific convert scripts. For example, Qwen models require running `convert-hf-to-gguf.py` instead of `convert.py`
|
||||
|
||||
```
|
||||
docker run --rm -v .:/model ollama/quantize -q q4_0 /model
|
||||
python llm/llama.cpp/convert.py ./model --outtype f16 --outfile converted.bin
|
||||
```
|
||||
|
||||
This will output two files into the directory:
|
||||
### Quantize the model
|
||||
|
||||
- `f16.bin`: the model converted to GGUF
|
||||
- `q4_0.bin` the model quantized to a 4-bit quantization (Ollama will use this file to create the Ollama model)
|
||||
```
|
||||
llm/llama.cpp/quantize converted.bin quantized.bin q4_0
|
||||
```
|
||||
|
||||
### Step 3: Write a `Modelfile`
|
||||
|
||||
Next, create a `Modelfile` for your model:
|
||||
|
||||
```
|
||||
FROM ./q4_0.bin
|
||||
```
|
||||
|
||||
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
|
||||
|
||||
```
|
||||
FROM ./q4_0.bin
|
||||
FROM quantized.bin
|
||||
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
|
||||
```
|
||||
|
||||
@@ -149,47 +163,3 @@ The quantization options are as follow (from highest highest to lowest levels of
|
||||
- `q6_K`
|
||||
- `q8_0`
|
||||
- `f16`
|
||||
|
||||
## Manually converting & quantizing models
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Start by cloning the `llama.cpp` repo to your machine in another directory:
|
||||
|
||||
```
|
||||
git clone https://github.com/ggerganov/llama.cpp.git
|
||||
cd llama.cpp
|
||||
```
|
||||
|
||||
Next, install the Python dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Finally, build the `quantize` tool:
|
||||
|
||||
```
|
||||
make quantize
|
||||
```
|
||||
|
||||
### Convert the model
|
||||
|
||||
Run the correct conversion script for your model architecture:
|
||||
|
||||
```shell
|
||||
# LlamaForCausalLM or MistralForCausalLM
|
||||
python convert.py <path to model directory>
|
||||
|
||||
# FalconForCausalLM
|
||||
python convert-falcon-hf-to-gguf.py <path to model directory>
|
||||
|
||||
# GPTBigCodeForCausalLM
|
||||
python convert-starcoder-hf-to-gguf.py <path to model directory>
|
||||
```
|
||||
|
||||
### Quantize the model
|
||||
|
||||
```
|
||||
quantize <path to model dir>/ggml-model-f32.bin <path to model dir>/q4_0.bin q4_0
|
||||
```
|
||||
|
||||
@@ -19,6 +19,7 @@ A model file is the blueprint to create and share models with Ollama.
|
||||
- [SYSTEM](#system)
|
||||
- [ADAPTER](#adapter)
|
||||
- [LICENSE](#license)
|
||||
- [MESSAGE](#message)
|
||||
- [Notes](#notes)
|
||||
|
||||
## Format
|
||||
@@ -38,6 +39,7 @@ INSTRUCTION arguments
|
||||
| [`SYSTEM`](#system) | Specifies the system message that will be set in the template. |
|
||||
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
|
||||
| [`LICENSE`](#license) | Specifies the legal license. |
|
||||
| [`MESSAGE`](#message) | Specify message history. |
|
||||
|
||||
## Examples
|
||||
|
||||
@@ -205,6 +207,19 @@ LICENSE """
|
||||
"""
|
||||
```
|
||||
|
||||
### MESSAGE
|
||||
|
||||
The `MESSAGE` instruction allows you to specify a message history for the model to use when responding:
|
||||
|
||||
```modelfile
|
||||
MESSAGE user Is Toronto in Canada?
|
||||
MESSAGE assistant yes
|
||||
MESSAGE user Is Sacramento in Canada?
|
||||
MESSAGE assistant no
|
||||
MESSAGE user Is Ontario in Canada?
|
||||
MESSAGE assistant yes
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.
|
||||
|
||||
@@ -12,6 +12,13 @@ On Linux systems with systemd, the logs can be found with this command:
|
||||
journalctl -u ollama
|
||||
```
|
||||
|
||||
When you run Ollama in a container, the logs go to stdout/stderr in the container:
|
||||
|
||||
```shell
|
||||
docker logs <container-name>
|
||||
```
|
||||
(Use `docker ps` to find the container name)
|
||||
|
||||
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
|
||||
|
||||
Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
|
||||
|
||||
68
gpu/gpu.go
68
gpu/gpu.go
@@ -16,6 +16,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
@@ -29,8 +30,8 @@ type handles struct {
|
||||
var gpuMutex sync.Mutex
|
||||
var gpuHandles *handles = nil
|
||||
|
||||
// With our current CUDA compile flags, 5.2 and older will not work properly
|
||||
const CudaComputeMajorMin = 6
|
||||
// With our current CUDA compile flags, older than 5.0 will not work properly
|
||||
var CudaComputeMin = [2]C.int{5, 0}
|
||||
|
||||
// Possible locations for the nvidia-ml library
|
||||
var CudaLinuxGlobs = []string{
|
||||
@@ -38,12 +39,15 @@ var CudaLinuxGlobs = []string{
|
||||
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
|
||||
"/usr/lib/wsl/lib/libnvidia-ml.so*",
|
||||
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
|
||||
"/opt/cuda/lib64/libnvidia-ml.so*",
|
||||
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
|
||||
"/usr/lib*/libnvidia-ml.so*",
|
||||
"/usr/local/lib*/libnvidia-ml.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
|
||||
|
||||
// TODO: are these stubs ever valid?
|
||||
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
|
||||
}
|
||||
|
||||
var CudaWindowsGlobs = []string{
|
||||
@@ -118,33 +122,60 @@ func GetGPUInfo() GpuInfo {
|
||||
initGPUHandles()
|
||||
}
|
||||
|
||||
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
|
||||
cpuVariant := GetCPUVariant()
|
||||
if cpuVariant == "" && runtime.GOARCH == "amd64" {
|
||||
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
|
||||
}
|
||||
|
||||
var memInfo C.mem_info_t
|
||||
resp := GpuInfo{}
|
||||
if gpuHandles.cuda != nil {
|
||||
if gpuHandles.cuda != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
||||
C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
|
||||
if memInfo.err != nil {
|
||||
slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
} else {
|
||||
} else if memInfo.count > 0 {
|
||||
// Verify minimum compute capability
|
||||
var cc C.cuda_compute_capability_t
|
||||
C.cuda_compute_capability(*gpuHandles.cuda, &cc)
|
||||
if cc.err != nil {
|
||||
slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
|
||||
C.free(unsafe.Pointer(cc.err))
|
||||
} else if cc.major >= CudaComputeMajorMin {
|
||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
||||
slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
resp.Library = "cuda"
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
}
|
||||
}
|
||||
} else if gpuHandles.rocm != nil {
|
||||
} else if gpuHandles.rocm != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
||||
C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
|
||||
if memInfo.err != nil {
|
||||
slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
} else {
|
||||
} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
|
||||
// Only one GPU detected and it appears to be an integrated GPU - skip it
|
||||
slog.Info("ROCm unsupported integrated GPU detected")
|
||||
} else if memInfo.count > 0 {
|
||||
if memInfo.igpu_index >= 0 {
|
||||
// We have multiple GPUs reported, and one of them is an integrated GPU
|
||||
// so we have to set the env var to bypass it
|
||||
// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
|
||||
val := os.Getenv("ROCR_VISIBLE_DEVICES")
|
||||
if val == "" {
|
||||
devices := []string{}
|
||||
for i := 0; i < int(memInfo.count); i++ {
|
||||
if i == int(memInfo.igpu_index) {
|
||||
continue
|
||||
}
|
||||
devices = append(devices, strconv.Itoa(i))
|
||||
}
|
||||
val = strings.Join(devices, ",")
|
||||
os.Setenv("ROCR_VISIBLE_DEVICES", val)
|
||||
}
|
||||
slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
|
||||
}
|
||||
resp.Library = "rocm"
|
||||
var version C.rocm_version_resp_t
|
||||
C.rocm_get_version(*gpuHandles.rocm, &version)
|
||||
@@ -160,7 +191,7 @@ func GetGPUInfo() GpuInfo {
|
||||
if resp.Library == "" {
|
||||
C.cpu_check_ram(&memInfo)
|
||||
resp.Library = "cpu"
|
||||
resp.Variant = GetCPUVariant()
|
||||
resp.Variant = cpuVariant
|
||||
}
|
||||
if memInfo.err != nil {
|
||||
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
|
||||
@@ -190,13 +221,15 @@ func getCPUMem() (memInfo, error) {
|
||||
func CheckVRAM() (int64, error) {
|
||||
gpuInfo := GetGPUInfo()
|
||||
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
|
||||
// leave 10% or 512MiB of VRAM free per GPU to handle unaccounted for overhead
|
||||
// leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead
|
||||
overhead := gpuInfo.FreeMemory / 10
|
||||
gpus := uint64(gpuInfo.DeviceCount)
|
||||
if overhead < gpus*512*1024*1024 {
|
||||
overhead = gpus * 512 * 1024 * 1024
|
||||
if overhead < gpus*1024*1024*1024 {
|
||||
overhead = gpus * 1024 * 1024 * 1024
|
||||
}
|
||||
return int64(gpuInfo.FreeMemory - overhead), nil
|
||||
avail := int64(gpuInfo.FreeMemory - overhead)
|
||||
slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
|
||||
return avail, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
|
||||
@@ -258,6 +291,7 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||
|
||||
func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
|
||||
var resp C.cuda_init_resp_t
|
||||
resp.ch.verbose = getVerboseState()
|
||||
for _, libPath := range cudaLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
@@ -274,6 +308,7 @@ func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
|
||||
|
||||
func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
|
||||
var resp C.rocm_init_resp_t
|
||||
resp.rh.verbose = getVerboseState()
|
||||
for _, libPath := range rocmLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
@@ -287,3 +322,10 @@ func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getVerboseState() C.uint16_t {
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
return C.uint16_t(1)
|
||||
}
|
||||
return C.uint16_t(0)
|
||||
}
|
||||
|
||||
@@ -27,6 +27,13 @@
|
||||
|
||||
#endif
|
||||
|
||||
#define LOG(verbose, ...) \
|
||||
do { \
|
||||
if (verbose) { \
|
||||
fprintf(stderr, __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -35,6 +42,7 @@ typedef struct mem_info {
|
||||
uint64_t total;
|
||||
uint64_t free;
|
||||
unsigned int count;
|
||||
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
|
||||
char *err; // If non-nill, caller responsible for freeing
|
||||
} mem_info_t;
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#define CUDA_LOOKUP_SIZE 6
|
||||
|
||||
void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
|
||||
nvmlReturn_t ret;
|
||||
resp->err = NULL;
|
||||
@@ -16,18 +14,26 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[CUDA_LOOKUP_SIZE] = {
|
||||
{"nvmlInit_v2", (void *)&resp->ch.initFn},
|
||||
{"nvmlShutdown", (void *)&resp->ch.shutdownFn},
|
||||
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.getHandle},
|
||||
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.getMemInfo},
|
||||
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.getCount},
|
||||
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.getComputeCapability},
|
||||
} l[] = {
|
||||
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
|
||||
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
|
||||
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
|
||||
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
|
||||
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
|
||||
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
|
||||
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
|
||||
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
|
||||
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
|
||||
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
|
||||
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
|
||||
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
resp->ch.handle = LOAD_LIBRARY(cuda_lib_path, RTLD_LAZY);
|
||||
if (!resp->ch.handle) {
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "library %s load err: %s\n", cuda_lib_path, msg);
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Nvidia GPUs: %s",
|
||||
cuda_lib_path, msg);
|
||||
@@ -36,12 +42,19 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; i < CUDA_LOOKUP_SIZE; i++) { // TODO - fix this to use a null terminated list
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", cuda_lib_path);
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
||||
|
||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||
if (!l[i].p) {
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
||||
msg);
|
||||
free(msg);
|
||||
@@ -50,15 +63,23 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*resp->ch.initFn)();
|
||||
ret = (*resp->ch.nvmlInit_v2)();
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
return;
|
||||
// Report driver version if we're in verbose mode, ignore errors
|
||||
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
|
||||
@@ -75,7 +96,7 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.getCount)(&resp->count);
|
||||
ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
@@ -85,19 +106,57 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
|
||||
resp->total = 0;
|
||||
resp->free = 0;
|
||||
for (i = 0; i < resp->count; i++) {
|
||||
ret = (*h.getHandle)(i, &device);
|
||||
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.getMemInfo)(device, &memInfo);
|
||||
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
if (h.verbose) {
|
||||
nvmlBrandType_t brand = 0;
|
||||
// When in verbose mode, report more information about
|
||||
// the card we discover, but don't fail on error
|
||||
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
|
||||
}
|
||||
}
|
||||
|
||||
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
|
||||
LOG(h.verbose, "[%d] CUDA usedMem %ld\n", i, memInfo.free);
|
||||
|
||||
resp->total += memInfo.total;
|
||||
resp->free += memInfo.free;
|
||||
@@ -122,7 +181,7 @@ void cuda_compute_capability(cuda_handle_t h, cuda_compute_capability_t *resp) {
|
||||
}
|
||||
|
||||
unsigned int devices;
|
||||
ret = (*h.getCount)(&devices);
|
||||
ret = (*h.nvmlDeviceGetCount_v2)(&devices);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
@@ -130,14 +189,14 @@ void cuda_compute_capability(cuda_handle_t h, cuda_compute_capability_t *resp) {
|
||||
}
|
||||
|
||||
for (i = 0; i < devices; i++) {
|
||||
ret = (*h.getHandle)(i, &device);
|
||||
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.getComputeCapability)(device, &major, &minor);
|
||||
ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
|
||||
resp->err = strdup(buf);
|
||||
|
||||
@@ -15,14 +15,26 @@ typedef struct nvmlMemory_st {
|
||||
unsigned long long used;
|
||||
} nvmlMemory_t;
|
||||
|
||||
typedef enum nvmlBrandType_enum
|
||||
{
|
||||
NVML_BRAND_UNKNOWN = 0,
|
||||
} nvmlBrandType_t;
|
||||
|
||||
typedef struct cuda_handle {
|
||||
void *handle;
|
||||
nvmlReturn_t (*initFn)(void);
|
||||
nvmlReturn_t (*shutdownFn)(void);
|
||||
nvmlReturn_t (*getHandle)(unsigned int, nvmlDevice_t *);
|
||||
nvmlReturn_t (*getMemInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
nvmlReturn_t (*getCount)(unsigned int *);
|
||||
nvmlReturn_t (*getComputeCapability)(nvmlDevice_t, int* major, int* minor);
|
||||
uint16_t verbose;
|
||||
nvmlReturn_t (*nvmlInit_v2)(void);
|
||||
nvmlReturn_t (*nvmlShutdown)(void);
|
||||
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
|
||||
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
|
||||
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
|
||||
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
|
||||
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
|
||||
} cuda_handle_t;
|
||||
|
||||
typedef struct cuda_init_resp {
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#define ROCM_LOOKUP_SIZE 5
|
||||
|
||||
void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
|
||||
rsmi_status_t ret;
|
||||
resp->err = NULL;
|
||||
@@ -15,13 +13,22 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[ROCM_LOOKUP_SIZE] = {
|
||||
{"rsmi_init", (void *)&resp->rh.initFn},
|
||||
{"rsmi_shut_down", (void *)&resp->rh.shutdownFn},
|
||||
{"rsmi_dev_memory_total_get", (void *)&resp->rh.totalMemFn},
|
||||
{"rsmi_dev_memory_usage_get", (void *)&resp->rh.usageMemFn},
|
||||
{"rsmi_version_get", (void *)&resp->rh.versionGetFn},
|
||||
// { "rsmi_dev_id_get", (void*)&resp->rh.getHandle },
|
||||
} l[] = {
|
||||
{"rsmi_init", (void *)&resp->rh.rsmi_init},
|
||||
{"rsmi_shut_down", (void *)&resp->rh.rsmi_shut_down},
|
||||
{"rsmi_dev_memory_total_get", (void *)&resp->rh.rsmi_dev_memory_total_get},
|
||||
{"rsmi_dev_memory_usage_get", (void *)&resp->rh.rsmi_dev_memory_usage_get},
|
||||
{"rsmi_version_get", (void *)&resp->rh.rsmi_version_get},
|
||||
{"rsmi_num_monitor_devices", (void*)&resp->rh.rsmi_num_monitor_devices},
|
||||
{"rsmi_dev_id_get", (void*)&resp->rh.rsmi_dev_id_get},
|
||||
{"rsmi_dev_name_get", (void *)&resp->rh.rsmi_dev_name_get},
|
||||
{"rsmi_dev_brand_get", (void *)&resp->rh.rsmi_dev_brand_get},
|
||||
{"rsmi_dev_vendor_name_get", (void *)&resp->rh.rsmi_dev_vendor_name_get},
|
||||
{"rsmi_dev_vram_vendor_get", (void *)&resp->rh.rsmi_dev_vram_vendor_get},
|
||||
{"rsmi_dev_serial_number_get", (void *)&resp->rh.rsmi_dev_serial_number_get},
|
||||
{"rsmi_dev_subsystem_name_get", (void *)&resp->rh.rsmi_dev_subsystem_name_get},
|
||||
{"rsmi_dev_vbios_version_get", (void *)&resp->rh.rsmi_dev_vbios_version_get},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
resp->rh.handle = LOAD_LIBRARY(rocm_lib_path, RTLD_LAZY);
|
||||
@@ -35,12 +42,19 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; i < ROCM_LOOKUP_SIZE; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->rh.verbose, "wiring rocm management library functions in %s\n", rocm_lib_path);
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->rh.verbose, "dlsym: %s\n", l[i].s);
|
||||
|
||||
*l[i].p = LOAD_SYMBOL(resp->rh.handle, l[i].s);
|
||||
if (!l[i].p) {
|
||||
UNLOAD_LIBRARY(resp->rh.handle);
|
||||
resp->rh.handle = NULL;
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->rh.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->rh.handle);
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
||||
msg);
|
||||
free(msg);
|
||||
@@ -49,8 +63,9 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*resp->rh.initFn)(0);
|
||||
ret = (*resp->rh.rsmi_init)(0);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(resp->rh.verbose, "rsmi_init err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->rh.handle);
|
||||
resp->rh.handle = NULL;
|
||||
snprintf(buf, buflen, "rocm vram init failure: %d", ret);
|
||||
@@ -62,8 +77,7 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
|
||||
|
||||
void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
|
||||
resp->err = NULL;
|
||||
// uint32_t num_devices;
|
||||
// uint16_t device;
|
||||
resp->igpu_index = -1;
|
||||
uint64_t totalMem = 0;
|
||||
uint64_t usedMem = 0;
|
||||
rsmi_status_t ret;
|
||||
@@ -76,47 +90,101 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO - iterate through devices... ret =
|
||||
// rsmi_num_monitor_devices(&num_devices);
|
||||
|
||||
// ret = (*h.getHandle)(0, &device);
|
||||
// if (ret != RSMI_STATUS_SUCCESS) {
|
||||
// printf("rocm vram device lookup failure: %d\n", ret);
|
||||
// return -1;
|
||||
// }
|
||||
|
||||
// Get total memory - used memory for available memory
|
||||
ret = (*h.totalMemFn)(0, RSMI_MEM_TYPE_VRAM, &totalMem);
|
||||
ret = (*h.rsmi_num_monitor_devices)(&resp->count);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
snprintf(buf, buflen, "rocm total mem lookup failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
ret = (*h.usageMemFn)(0, RSMI_MEM_TYPE_VRAM, &usedMem);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
snprintf(buf, buflen, "rocm usage mem lookup failure: %d", ret);
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
LOG(h.verbose, "discovered %d ROCm GPU Devices\n", resp->count);
|
||||
|
||||
// TODO: set this to the actual number of devices
|
||||
resp->count = 1;
|
||||
resp->total = totalMem;
|
||||
resp->free = totalMem - usedMem;
|
||||
return;
|
||||
resp->total = 0;
|
||||
resp->free = 0;
|
||||
for (i = 0; i < resp->count; i++) {
|
||||
if (h.verbose) {
|
||||
// When in verbose mode, report more information about
|
||||
// the card we discover, but don't fail on error
|
||||
ret = (*h.rsmi_dev_name_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_name_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm device name: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_brand_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_brand_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm brand: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_vendor_name_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_vendor_name_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm vendor: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_vram_vendor_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_vram_vendor_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm VRAM vendor: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_serial_number_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_serial_number_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm S/N: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_subsystem_name_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_subsystem_name_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm subsystem name: %s\n", i, buf);
|
||||
}
|
||||
ret = (*h.rsmi_dev_vbios_version_get)(i, buf, buflen);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
LOG(h.verbose, "rsmi_dev_vbios_version_get failed: %d\n", ret);
|
||||
} else {
|
||||
LOG(h.verbose, "[%d] ROCm vbios version: %s\n", i, buf);
|
||||
}
|
||||
}
|
||||
|
||||
// Get total memory - used memory for available memory
|
||||
ret = (*h.rsmi_dev_memory_total_get)(i, RSMI_MEM_TYPE_VRAM, &totalMem);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
snprintf(buf, buflen, "rocm total mem lookup failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
ret = (*h.rsmi_dev_memory_usage_get)(i, RSMI_MEM_TYPE_VRAM, &usedMem);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
snprintf(buf, buflen, "rocm usage mem lookup failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem);
|
||||
LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem);
|
||||
if (totalMem < 1024 * 1024 * 1024) {
|
||||
// Do not add up integrated GPU memory capacity, it's a bogus 512M, and actually uses system memory
|
||||
LOG(h.verbose, "[%d] ROCm integrated GPU\n", i);
|
||||
resp->igpu_index = i;
|
||||
} else {
|
||||
resp->total += totalMem;
|
||||
resp->free += totalMem - usedMem;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
if (h.handle == NULL) {
|
||||
resp->str = strdup("nvml handle not initialized");
|
||||
resp->str = strdup("rocm handle not initialized");
|
||||
resp->status = 1;
|
||||
return;
|
||||
}
|
||||
rsmi_version_t ver;
|
||||
rsmi_status_t ret;
|
||||
ret = h.versionGetFn(&ver);
|
||||
ret = h.rsmi_version_get(&ver);
|
||||
if (ret != RSMI_STATUS_SUCCESS) {
|
||||
snprintf(buf, buflen, "unexpected response on version lookup %d", ret);
|
||||
resp->status = 1;
|
||||
@@ -127,4 +195,4 @@ void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
|
||||
resp->str = strdup(buf);
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
#endif // __APPLE__
|
||||
|
||||
@@ -24,12 +24,21 @@ typedef enum rsmi_memory_type {
|
||||
|
||||
typedef struct rocm_handle {
|
||||
void *handle;
|
||||
rsmi_status_t (*initFn)(uint64_t);
|
||||
rsmi_status_t (*shutdownFn)(void);
|
||||
rsmi_status_t (*totalMemFn)(uint32_t, rsmi_memory_type_t, uint64_t *);
|
||||
rsmi_status_t (*usageMemFn)(uint32_t, rsmi_memory_type_t, uint64_t *);
|
||||
rsmi_status_t (*versionGetFn) (rsmi_version_t *version);
|
||||
// rsmi_status_t (*getHandle)(uint32_t, uint16_t *);
|
||||
uint16_t verbose;
|
||||
rsmi_status_t (*rsmi_init)(uint64_t);
|
||||
rsmi_status_t (*rsmi_shut_down)(void);
|
||||
rsmi_status_t (*rsmi_dev_memory_total_get)(uint32_t, rsmi_memory_type_t, uint64_t *);
|
||||
rsmi_status_t (*rsmi_dev_memory_usage_get)(uint32_t, rsmi_memory_type_t, uint64_t *);
|
||||
rsmi_status_t (*rsmi_version_get) (rsmi_version_t *version);
|
||||
rsmi_status_t (*rsmi_num_monitor_devices) (uint32_t *);
|
||||
rsmi_status_t (*rsmi_dev_id_get)(uint32_t, uint16_t *);
|
||||
rsmi_status_t (*rsmi_dev_name_get) (uint32_t,char *,size_t);
|
||||
rsmi_status_t (*rsmi_dev_brand_get) (uint32_t, char *, uint32_t);
|
||||
rsmi_status_t (*rsmi_dev_vendor_name_get) (uint32_t, char *, uint32_t);
|
||||
rsmi_status_t (*rsmi_dev_vram_vendor_get) (uint32_t, char *, uint32_t);
|
||||
rsmi_status_t (*rsmi_dev_serial_number_get) (uint32_t, char *, uint32_t);
|
||||
rsmi_status_t (*rsmi_dev_subsystem_name_get) (uint32_t, char *, uint32_t);
|
||||
rsmi_status_t (*rsmi_dev_vbios_version_get) (uint32_t, char *, uint32_t);
|
||||
} rocm_handle_t;
|
||||
|
||||
typedef struct rocm_init_resp {
|
||||
|
||||
@@ -59,7 +59,7 @@ void dyn_init(const char *libPath, struct dynamic_llama_server *s,
|
||||
};
|
||||
|
||||
printf("loading library %s\n", libPath);
|
||||
s->handle = LOAD_LIBRARY(libPath, RTLD_GLOBAL|RTLD_NOW);
|
||||
s->handle = LOAD_LIBRARY(libPath, RTLD_LOCAL|RTLD_NOW);
|
||||
if (!s->handle) {
|
||||
err->id = -1;
|
||||
char *msg = LOAD_ERR();
|
||||
|
||||
@@ -4,7 +4,7 @@ package llm
|
||||
#cgo CFLAGS: -I${SRCDIR}/ext_server -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/common -I${SRCDIR}/llama.cpp/examples/server
|
||||
#cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
|
||||
#cgo CFLAGS: -Wmissing-noreturn -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
|
||||
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations -Wno-unused-but-set-variable
|
||||
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations
|
||||
#cgo darwin CFLAGS: -D_DARWIN_C_SOURCE
|
||||
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE
|
||||
#cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
|
||||
@@ -136,12 +136,21 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
|
||||
|
||||
sparams.n_threads = C.uint(opts.NumThread)
|
||||
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
sparams.verbose_logging = C.bool(true)
|
||||
} else {
|
||||
sparams.verbose_logging = C.bool(false)
|
||||
}
|
||||
|
||||
slog.Info("Initializing llama server")
|
||||
initResp := newExtServerResp(128)
|
||||
defer freeExtServerResp(initResp)
|
||||
C.dyn_llama_server_init(llm.s, &sparams, &initResp)
|
||||
if initResp.id < 0 {
|
||||
return nil, extServerResponseToErr(initResp)
|
||||
mutex.Unlock()
|
||||
err := extServerResponseToErr(initResp)
|
||||
slog.Debug(fmt.Sprintf("failure during initialization: %s", err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Info("Starting llama main loop")
|
||||
@@ -152,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
|
||||
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
var imageData []ImageData
|
||||
|
||||
if len(predict.Images) > 0 {
|
||||
for cnt, i := range predict.Images {
|
||||
imageData = append(imageData, ImageData{Data: i, ID: cnt})
|
||||
}
|
||||
slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
|
||||
}
|
||||
slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
|
||||
|
||||
request := map[string]any{
|
||||
"prompt": predict.Prompt,
|
||||
@@ -180,7 +186,8 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
|
||||
"penalize_nl": predict.Options.PenalizeNewline,
|
||||
"seed": predict.Options.Seed,
|
||||
"stop": predict.Options.Stop,
|
||||
"image_data": imageData,
|
||||
"image_data": predict.Images,
|
||||
"cache_prompt": true,
|
||||
}
|
||||
|
||||
if predict.Format == "json" {
|
||||
|
||||
@@ -3,22 +3,44 @@
|
||||
// Necessary evil since the server types are not defined in a header
|
||||
#include "server.cpp"
|
||||
|
||||
// Low level API access to verify GPU access
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// for rocblas_initialize()
|
||||
#include "rocblas/rocblas.h"
|
||||
#endif // __HIP_PLATFORM_AMD__
|
||||
#define cudaGetDevice hipGetDevice
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaSuccess hipSuccess
|
||||
#define cudaGetErrorString hipGetErrorString
|
||||
#else
|
||||
#include <cuda_runtime.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_fp16.h>
|
||||
#endif // defined(GGML_USE_HIPBLAS)
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
// Expose the llama server as a callable extern "C" API
|
||||
llama_server_context *llama = NULL;
|
||||
std::atomic<bool> ext_server_running(false);
|
||||
std::thread ext_server_thread;
|
||||
|
||||
void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
|
||||
#if SERVER_VERBOSE != 1
|
||||
log_disable();
|
||||
#endif
|
||||
LOG_TEE("system info: %s", llama_print_system_info());
|
||||
assert(err != NULL && sparams != NULL);
|
||||
log_set_target(stderr);
|
||||
if (!sparams->verbose_logging) {
|
||||
server_verbose = true;
|
||||
log_disable();
|
||||
}
|
||||
|
||||
LOG_TEE("system info: %s\n", llama_print_system_info());
|
||||
err->id = 0;
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
llama = new llama_server_context;
|
||||
log_set_target(stdout);
|
||||
gpt_params params;
|
||||
params.n_ctx = sparams->n_ctx;
|
||||
params.n_batch = sparams->n_batch;
|
||||
@@ -60,6 +82,18 @@ void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
|
||||
params.mmproj = std::string(sparams->mmproj);
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
// Before attempting to init the backend which will assert on error, verify the CUDA/ROCM GPU is accessible
|
||||
LOG_TEE("Performing pre-initialization of GPU\n");
|
||||
int id;
|
||||
cudaError_t cudaErr = cudaGetDevice(&id);
|
||||
if (cudaErr != cudaSuccess) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "Unable to init GPU: %s", cudaGetErrorString(cudaErr));
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
|
||||
// load the model
|
||||
@@ -88,18 +122,23 @@ void llama_server_start() {
|
||||
assert(llama != NULL);
|
||||
// TODO mutex to protect thread creation
|
||||
ext_server_thread = std::thread([&]() {
|
||||
ext_server_running = true;
|
||||
try {
|
||||
LOG_TEE("llama server main loop starting\n");
|
||||
ggml_time_init();
|
||||
while (ext_server_running.load()) {
|
||||
if (!llama->update_slots()) {
|
||||
LOG_TEE(
|
||||
"unexpected error in llama server update_slots - exiting main "
|
||||
"loop\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
llama->queue_tasks.on_new_task(std::bind(
|
||||
&llama_server_context::process_single_task, llama, std::placeholders::_1));
|
||||
llama->queue_tasks.on_finish_multitask(std::bind(
|
||||
&llama_server_context::on_finish_multitask, llama, std::placeholders::_1));
|
||||
llama->queue_tasks.on_all_tasks_finished(std::bind(
|
||||
&llama_server_context::run_on_all_tasks_finished, llama));
|
||||
llama->queue_results.on_multitask_update(std::bind(
|
||||
&llama_server_queue::update_multitask,
|
||||
&llama->queue_tasks,
|
||||
std::placeholders::_1,
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
llama->queue_tasks.start_loop();
|
||||
} catch (std::exception &e) {
|
||||
LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
|
||||
} catch (...) {
|
||||
@@ -112,13 +151,10 @@ void llama_server_start() {
|
||||
|
||||
void llama_server_stop() {
|
||||
assert(llama != NULL);
|
||||
// TODO - too verbose, remove once things are solid
|
||||
LOG_TEE("requesting llama server shutdown\n");
|
||||
ext_server_running = false;
|
||||
|
||||
// unblocks the update_slots() loop so it can clean up and exit
|
||||
llama->request_cancel(0);
|
||||
|
||||
LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
|
||||
// This may take a while for any pending tasks to drain
|
||||
// TODO - consider a timeout to cancel tasks if it's taking too long
|
||||
llama->queue_tasks.terminate();
|
||||
ext_server_thread.join();
|
||||
delete llama;
|
||||
llama = NULL;
|
||||
@@ -131,7 +167,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
|
||||
resp->msg[0] = '\0';
|
||||
try {
|
||||
json data = json::parse(json_req);
|
||||
resp->id = llama->request_completion(data, false, false, -1);
|
||||
resp->id = llama->queue_tasks.get_new_id();
|
||||
llama->queue_results.add_waiting_task_id(resp->id);
|
||||
llama->request_completion(resp->id, data, false, false, -1);
|
||||
} catch (std::exception &e) {
|
||||
snprintf(resp->msg, resp->msg_len, "exception %s", e.what());
|
||||
} catch (...) {
|
||||
@@ -149,16 +187,22 @@ void llama_server_completion_next_result(const int task_id,
|
||||
resp->json_resp = NULL;
|
||||
std::string result_json;
|
||||
try {
|
||||
task_result result = llama->next_result(task_id);
|
||||
task_result result = llama->queue_results.recv(task_id);
|
||||
result_json =
|
||||
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
resp->id = result.id;
|
||||
resp->stop = result.stop;
|
||||
resp->error = result.error;
|
||||
if (result.error) {
|
||||
LOG_TEE("next result cancel on error\n");
|
||||
llama->request_cancel(task_id);
|
||||
LOG_TEE("next result removing waiting tak ID: %d\n", task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} else if (result.stop) {
|
||||
LOG_TEE("next result cancel on stop\n");
|
||||
llama->request_cancel(task_id);
|
||||
LOG_TEE("next result removing waiting task ID: %d\n", task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
resp->error = true;
|
||||
@@ -189,6 +233,7 @@ void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err) {
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
llama->request_cancel(task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
@@ -273,13 +318,15 @@ void llama_server_embedding(const char *json_req, char **json_resp,
|
||||
} else {
|
||||
prompt = "";
|
||||
}
|
||||
const int task_id = llama->request_completion(
|
||||
{{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
|
||||
task_result result = llama->next_result(task_id);
|
||||
const int task_id = llama->queue_tasks.get_new_id();
|
||||
llama->queue_results.add_waiting_task_id(task_id);
|
||||
llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
|
||||
task_result result = llama->queue_results.recv(task_id);
|
||||
std::string result_json = result.result_json.dump();
|
||||
const std::string::size_type size = result_json.size() + 1;
|
||||
*json_resp = new char[size];
|
||||
snprintf(*json_resp, size, "%s", result_json.c_str());
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
|
||||
@@ -45,6 +45,7 @@ typedef struct ext_server_params {
|
||||
bool embedding; // get only sentence embedding
|
||||
ext_server_lora_adapter_t *lora_adapters;
|
||||
char *mmproj;
|
||||
bool verbose_logging; // Enable verbose logging of the server
|
||||
} ext_server_params_t;
|
||||
|
||||
typedef struct ext_server_task_result {
|
||||
|
||||
@@ -39,6 +39,9 @@ init_vars() {
|
||||
*)
|
||||
;;
|
||||
esac
|
||||
if [ -z "${CMAKE_CUDA_ARCHITECTURES}" ] ; then
|
||||
CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
|
||||
fi
|
||||
}
|
||||
|
||||
git_module_setup() {
|
||||
@@ -61,6 +64,19 @@ apply_patches() {
|
||||
if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then
|
||||
echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt
|
||||
fi
|
||||
|
||||
if [ -n "$(ls -A ../patches/*.diff)" ]; then
|
||||
# apply temporary patches until fix is upstream
|
||||
for patch in ../patches/*.diff; do
|
||||
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
|
||||
(cd ${LLAMACPP_DIR}; git checkout ${file})
|
||||
done
|
||||
done
|
||||
for patch in ../patches/*.diff; do
|
||||
(cd ${LLAMACPP_DIR} && git apply ${patch})
|
||||
done
|
||||
fi
|
||||
|
||||
# Avoid duplicate main symbols when we link into the cgo binary
|
||||
sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp &&
|
||||
mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp
|
||||
@@ -83,8 +99,9 @@ build() {
|
||||
compress_libs() {
|
||||
echo "Compressing payloads to reduce overall binary size..."
|
||||
pids=""
|
||||
rm -rf ${BUILD_DIR}/lib/*.${LIB_EXT}*.gz
|
||||
for lib in ${BUILD_DIR}/lib/*.${LIB_EXT}* ; do
|
||||
gzip --best ${lib} &
|
||||
gzip --best -f ${lib} &
|
||||
pids+=" $!"
|
||||
done
|
||||
echo
|
||||
@@ -97,4 +114,12 @@ compress_libs() {
|
||||
# Keep the local tree clean after we're done with the build
|
||||
cleanup() {
|
||||
(cd ${LLAMACPP_DIR}/examples/server/ && git checkout CMakeLists.txt server.cpp)
|
||||
|
||||
if [ -n "$(ls -A ../patches/*.diff)" ]; then
|
||||
for patch in ../patches/*.diff; do
|
||||
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
|
||||
(cd ${LLAMACPP_DIR}; git checkout ${file})
|
||||
done
|
||||
done
|
||||
fi
|
||||
}
|
||||
|
||||
@@ -12,7 +12,13 @@ init_vars
|
||||
git_module_setup
|
||||
apply_patches
|
||||
|
||||
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_ACCELERATE=off"
|
||||
sign() {
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp --deep --options=runtime --sign "$APPLE_IDENTITY" --identifier ai.ollama.ollama $1
|
||||
fi
|
||||
}
|
||||
|
||||
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin"
|
||||
|
||||
case "${GOARCH}" in
|
||||
"amd64")
|
||||
@@ -21,10 +27,11 @@ case "${GOARCH}" in
|
||||
#
|
||||
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
|
||||
#
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu"
|
||||
echo "Building LCD CPU"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu/lib/libext_server.dylib
|
||||
compress_libs
|
||||
|
||||
#
|
||||
@@ -32,10 +39,11 @@ case "${GOARCH}" in
|
||||
# Approximately 400% faster than LCD on same CPU
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx"
|
||||
echo "Building AVX CPU"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx/lib/libext_server.dylib
|
||||
compress_libs
|
||||
|
||||
#
|
||||
@@ -43,17 +51,20 @@ case "${GOARCH}" in
|
||||
# Approximately 10% faster than AVX on same CPU
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2"
|
||||
echo "Building AVX2 CPU"
|
||||
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2/lib/libext_server.dylib
|
||||
compress_libs
|
||||
;;
|
||||
"arm64")
|
||||
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/metal"
|
||||
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/metal/lib/libext_server.dylib
|
||||
compress_libs
|
||||
;;
|
||||
*)
|
||||
|
||||
@@ -16,6 +16,10 @@ set -o pipefail
|
||||
|
||||
# See https://llvm.org/docs/AMDGPUUsage.html#processors for reference
|
||||
amdGPUs() {
|
||||
if [ -n "${AMDGPU_TARGETS}" ]; then
|
||||
echo "${AMDGPU_TARGETS}"
|
||||
return
|
||||
fi
|
||||
GPU_LIST=(
|
||||
"gfx803"
|
||||
"gfx900"
|
||||
@@ -73,36 +77,42 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||
# -DLLAMA_AVX512_VNNI -- 2021 Intel Alder Lake
|
||||
|
||||
COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off"
|
||||
#
|
||||
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
|
||||
#
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
|
||||
echo "Building LCD CPU"
|
||||
build
|
||||
compress_libs
|
||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then
|
||||
#
|
||||
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
|
||||
#
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
|
||||
echo "Building LCD CPU"
|
||||
build
|
||||
compress_libs
|
||||
fi
|
||||
|
||||
#
|
||||
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
|
||||
# Approximately 400% faster than LCD on same CPU
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx"
|
||||
echo "Building AVX CPU"
|
||||
build
|
||||
compress_libs
|
||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu_avx" ]; then
|
||||
#
|
||||
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
|
||||
# Approximately 400% faster than LCD on same CPU
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx"
|
||||
echo "Building AVX CPU"
|
||||
build
|
||||
compress_libs
|
||||
fi
|
||||
|
||||
#
|
||||
# ~2013 CPU Dynamic library
|
||||
# Approximately 10% faster than AVX on same CPU
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx2"
|
||||
echo "Building AVX2 CPU"
|
||||
build
|
||||
compress_libs
|
||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu_avx2" ]; then
|
||||
#
|
||||
# ~2013 CPU Dynamic library
|
||||
# Approximately 10% faster than AVX on same CPU
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx2"
|
||||
echo "Building AVX2 CPU"
|
||||
build
|
||||
compress_libs
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "Skipping CPU generation step as requested"
|
||||
@@ -118,6 +128,11 @@ if [ -z "${CUDA_LIB_DIR}" ] && [ -d /opt/cuda/targets/x86_64-linux/lib ]; then
|
||||
CUDA_LIB_DIR=/opt/cuda/targets/x86_64-linux/lib
|
||||
fi
|
||||
|
||||
# Allow override in case libcudart is in the wrong place
|
||||
if [ -z "${CUDART_LIB_DIR}" ]; then
|
||||
CUDART_LIB_DIR="${CUDA_LIB_DIR}"
|
||||
fi
|
||||
|
||||
if [ -d "${CUDA_LIB_DIR}" ]; then
|
||||
echo "CUDA libraries detected - building dynamic CUDA library"
|
||||
init_vars
|
||||
@@ -125,7 +140,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
|
||||
if [ -n "${CUDA_MAJOR}" ]; then
|
||||
CUDA_VARIANT=_v${CUDA_MAJOR}
|
||||
fi
|
||||
CMAKE_DEFS="-DLLAMA_CUBLAS=on ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
||||
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
|
||||
build
|
||||
@@ -141,6 +156,8 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
|
||||
cp "${CUDA_LIB_DIR}/${DEP}" "${BUILD_DIR}/lib/"
|
||||
elif [ -e "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" ]; then
|
||||
cp "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" "${BUILD_DIR}/lib/"
|
||||
elif [ -e "${CUDART_LIB_DIR}/${lib}" ]; then
|
||||
cp -d ${CUDART_LIB_DIR}/${lib}* "${BUILD_DIR}/lib/"
|
||||
else
|
||||
cp -d "${CUDA_LIB_DIR}/${lib}*" "${BUILD_DIR}/lib/"
|
||||
fi
|
||||
|
||||
@@ -25,6 +25,11 @@ function init_vars {
|
||||
}
|
||||
$script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
|
||||
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
|
||||
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
|
||||
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
|
||||
} else {
|
||||
$script:CMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES
|
||||
}
|
||||
}
|
||||
|
||||
function git_module_setup {
|
||||
@@ -40,6 +45,29 @@ function apply_patches {
|
||||
if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 'ollama')) {
|
||||
Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama'
|
||||
}
|
||||
|
||||
# Apply temporary patches until fix is upstream
|
||||
$patches = Get-ChildItem "../patches/*.diff"
|
||||
foreach ($patch in $patches) {
|
||||
# Extract file paths from the patch file
|
||||
$filePaths = Get-Content $patch.FullName | Where-Object { $_ -match '^\+\+\+ ' } | ForEach-Object {
|
||||
$parts = $_ -split ' '
|
||||
($parts[1] -split '/', 2)[1]
|
||||
}
|
||||
|
||||
# Checkout each file
|
||||
foreach ($file in $filePaths) {
|
||||
Set-Location -Path ${script:llamacppDir}
|
||||
git checkout $file
|
||||
}
|
||||
}
|
||||
|
||||
# Apply each patch
|
||||
foreach ($patch in $patches) {
|
||||
Set-Location -Path ${script:llamacppDir}
|
||||
git apply $patch.FullName
|
||||
}
|
||||
|
||||
# Avoid duplicate main symbols when we link into the cgo binary
|
||||
$content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp"
|
||||
$content = $content -replace 'int main\(', 'int __main('
|
||||
@@ -76,7 +104,7 @@ function compress_libs {
|
||||
write-host "Compressing dlls..."
|
||||
$libs = dir "${script:buildDir}/lib/*.dll"
|
||||
foreach ($file in $libs) {
|
||||
& "$script:GZIP" --best $file
|
||||
& "$script:GZIP" --best -f $file
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,7 +156,7 @@ if ($null -ne $script:CUDA_LIB_DIR) {
|
||||
}
|
||||
init_vars
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
|
||||
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on")
|
||||
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
|
||||
build
|
||||
install
|
||||
cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib"
|
||||
|
||||
115
llm/gguf.go
115
llm/gguf.go
@@ -69,12 +69,65 @@ type tensor struct {
|
||||
name string
|
||||
kind uint32
|
||||
offset uint64
|
||||
size uint64
|
||||
|
||||
// shape is the number of elements in each dimension
|
||||
shape [4]uint64
|
||||
}
|
||||
|
||||
func (t tensor) blockSize() uint64 {
|
||||
switch {
|
||||
case t.kind < 2:
|
||||
return 1
|
||||
case t.kind < 10:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
func (t tensor) typeSize() uint64 {
|
||||
blockSize := t.blockSize()
|
||||
|
||||
switch t.kind {
|
||||
case 0: // FP32
|
||||
return 4
|
||||
case 1: // FP16
|
||||
return 2
|
||||
case 2: // Q4_0
|
||||
return 2 + blockSize/2
|
||||
case 3: // Q4_1
|
||||
return 2 + 2 + blockSize/2
|
||||
case 6: // Q5_0
|
||||
return 2 + 4 + blockSize/2
|
||||
case 7: // Q5_1
|
||||
return 2 + 2 + 4 + blockSize/2
|
||||
case 8: // Q8_0
|
||||
return 2 + blockSize
|
||||
case 9: // Q8_1
|
||||
return 4 + 4 + blockSize
|
||||
case 10: // Q2_K
|
||||
return blockSize/16 + blockSize/4 + 2 + 2
|
||||
case 11: // Q3_K
|
||||
return blockSize/8 + blockSize/4 + 12 + 2
|
||||
case 12: // Q4_K
|
||||
return 2 + 2 + 12 + blockSize/2
|
||||
case 13: // Q5_K
|
||||
return 2 + 2 + 12 + blockSize/8 + blockSize/2
|
||||
case 14: // Q6_K
|
||||
return blockSize/2 + blockSize/4 + blockSize/16 + 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (t tensor) parameters() uint64 {
|
||||
return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3]
|
||||
}
|
||||
|
||||
func (t tensor) size() uint64 {
|
||||
return t.parameters() * t.typeSize() / t.blockSize()
|
||||
}
|
||||
|
||||
type ggufModel struct {
|
||||
*containerGGUF
|
||||
|
||||
@@ -201,61 +254,15 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
|
||||
shape[i] = llm.readU64(rso)
|
||||
}
|
||||
|
||||
kind := llm.readU32(rso)
|
||||
offset := llm.readU64(rso)
|
||||
|
||||
var blockSize uint64
|
||||
switch {
|
||||
case kind < 2:
|
||||
blockSize = 1
|
||||
case kind < 10:
|
||||
blockSize = 32
|
||||
default:
|
||||
blockSize = 256
|
||||
}
|
||||
|
||||
var typeSize uint64
|
||||
switch kind {
|
||||
case 0: // FP32
|
||||
typeSize = 4
|
||||
case 1: // FP16
|
||||
typeSize = 2
|
||||
case 2: // Q4_0
|
||||
typeSize = 2 + blockSize/2
|
||||
case 3: // Q4_1
|
||||
typeSize = 2 + 2 + blockSize/2
|
||||
case 6: // Q5_0
|
||||
typeSize = 2 + 4 + blockSize/2
|
||||
case 7: // Q5_1
|
||||
typeSize = 2 + 2 + 4 + blockSize/2
|
||||
case 8: // Q8_0
|
||||
typeSize = 2 + blockSize
|
||||
case 9: // Q8_1
|
||||
typeSize = 4 + 4 + blockSize
|
||||
case 10: // Q2_K
|
||||
typeSize = blockSize/16 + blockSize/4 + 2 + 2
|
||||
case 11: // Q3_K
|
||||
typeSize = blockSize/8 + blockSize/4 + 12 + 2
|
||||
case 12: // Q4_K
|
||||
typeSize = 2 + 2 + 12 + blockSize/2
|
||||
case 13: // Q5_K
|
||||
typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
|
||||
case 14: // Q6_K
|
||||
typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
|
||||
}
|
||||
|
||||
parameters := shape[0] * shape[1] * shape[2] * shape[3]
|
||||
size := parameters * typeSize / blockSize
|
||||
|
||||
llm.tensors = append(llm.tensors, tensor{
|
||||
tensor := tensor{
|
||||
name: name,
|
||||
kind: kind,
|
||||
offset: offset,
|
||||
size: size,
|
||||
kind: llm.readU32(rso),
|
||||
offset: llm.readU64(rso),
|
||||
shape: shape,
|
||||
})
|
||||
}
|
||||
|
||||
llm.parameters += parameters
|
||||
llm.tensors = append(llm.tensors, tensor)
|
||||
llm.parameters += tensor.parameters()
|
||||
}
|
||||
|
||||
alignment, ok := llm.kv["general.alignment"].(uint32)
|
||||
@@ -265,7 +272,7 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
|
||||
|
||||
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
|
||||
for _, tensor := range llm.tensors {
|
||||
padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
|
||||
padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
|
||||
rso.Seek(padded, io.SeekCurrent)
|
||||
}
|
||||
|
||||
|
||||
Submodule llm/llama.cpp updated: 584d674be6...d2f650cb5b
@@ -62,7 +62,7 @@ const maxRetries = 3
|
||||
type PredictOpts struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []api.ImageData
|
||||
Images []ImageData
|
||||
Options api.Options
|
||||
}
|
||||
|
||||
|
||||
@@ -70,7 +70,8 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
|
||||
break
|
||||
}
|
||||
|
||||
opts.NumGPU = 1
|
||||
// TODO: implement layer splitting on macOS
|
||||
opts.NumGPU = 999
|
||||
default:
|
||||
if info.Library == "cpu" {
|
||||
slog.Info("GPU not available, falling back to CPU")
|
||||
|
||||
30
llm/patches/01-cache.diff
Normal file
30
llm/patches/01-cache.diff
Normal file
@@ -0,0 +1,30 @@
|
||||
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
|
||||
index a48582ad..9fffffd8 100644
|
||||
--- a/examples/server/server.cpp
|
||||
+++ b/examples/server/server.cpp
|
||||
@@ -1564,12 +1564,6 @@ struct llama_server_context
|
||||
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
|
||||
}
|
||||
|
||||
- LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
|
||||
-
|
||||
- llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
|
||||
-
|
||||
- slot.cache_tokens = prompt_tokens;
|
||||
-
|
||||
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
|
||||
{
|
||||
// we have to evaluate at least 1 token to generate logits.
|
||||
@@ -1581,6 +1575,12 @@ struct llama_server_context
|
||||
}
|
||||
}
|
||||
|
||||
+ LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
|
||||
+
|
||||
+ llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
|
||||
+
|
||||
+ slot.cache_tokens = prompt_tokens;
|
||||
+
|
||||
LOG_VERBOSE("prompt ingested", {
|
||||
{"n_past", slot.n_past},
|
||||
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
|
||||
90
llm/patches/02-shutdown.diff
Normal file
90
llm/patches/02-shutdown.diff
Normal file
@@ -0,0 +1,90 @@
|
||||
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
|
||||
index 11dd82c3..311495a8 100644
|
||||
--- a/examples/server/server.cpp
|
||||
+++ b/examples/server/server.cpp
|
||||
@@ -28,6 +28,7 @@
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <atomic>
|
||||
+#include <signal.h>
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
@@ -2394,6 +2395,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
|
||||
}
|
||||
}
|
||||
|
||||
+std::function<void(int)> shutdown_handler;
|
||||
+inline void signal_handler(int signal) { shutdown_handler(signal); }
|
||||
+
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
#if SERVER_VERBOSE != 1
|
||||
@@ -3014,8 +3018,14 @@ int main(int argc, char **argv)
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
- llama.queue_tasks.start_loop();
|
||||
|
||||
+ shutdown_handler = [&](int) {
|
||||
+ llama.queue_tasks.terminate();
|
||||
+ };
|
||||
+ signal(SIGTERM, signal_handler);
|
||||
+ signal(SIGINT, signal_handler);
|
||||
+ llama.queue_tasks.start_loop();
|
||||
+ svr.stop();
|
||||
t.join();
|
||||
|
||||
llama_backend_free();
|
||||
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
|
||||
index 70cce072..2acb1eab 100644
|
||||
--- a/examples/server/utils.hpp
|
||||
+++ b/examples/server/utils.hpp
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <unordered_map>
|
||||
+#include <atomic>
|
||||
|
||||
#include "json.hpp"
|
||||
|
||||
@@ -190,6 +191,7 @@ inline std::string format_chatml(std::vector<json> messages)
|
||||
struct llama_server_queue {
|
||||
int id = 0;
|
||||
std::mutex mutex_tasks;
|
||||
+ std::atomic<bool> running;
|
||||
// queues
|
||||
std::vector<task_server> queue_tasks;
|
||||
std::vector<task_server> queue_tasks_deferred;
|
||||
@@ -248,9 +250,15 @@ struct llama_server_queue {
|
||||
queue_tasks_deferred.clear();
|
||||
}
|
||||
|
||||
- // Start the main loop. This call is blocking
|
||||
- [[noreturn]]
|
||||
+ // end the start_loop routine
|
||||
+ void terminate() {
|
||||
+ running = false;
|
||||
+ condition_tasks.notify_all();
|
||||
+ }
|
||||
+
|
||||
+ // Start the main loop.
|
||||
void start_loop() {
|
||||
+ running = true;
|
||||
while (true) {
|
||||
// new task arrived
|
||||
LOG_VERBOSE("have new task", {});
|
||||
@@ -294,8 +302,12 @@ struct llama_server_queue {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (queue_tasks.empty()) {
|
||||
+ if (!running.load()) {
|
||||
+ LOG_VERBOSE("ending start_loop", {});
|
||||
+ return;
|
||||
+ }
|
||||
condition_tasks.wait(lock, [&]{
|
||||
- return !queue_tasks.empty();
|
||||
+ return (!queue_tasks.empty() || !running.load());
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
)
|
||||
|
||||
type Command struct {
|
||||
@@ -56,6 +57,16 @@ func Parse(reader io.Reader) ([]Command, error) {
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
case "EMBED":
|
||||
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
|
||||
case "MESSAGE":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
||||
if len(fields) < 2 {
|
||||
return nil, fmt.Errorf("should be in the format <role> <message>")
|
||||
}
|
||||
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
|
||||
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
}
|
||||
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
|
||||
default:
|
||||
if !bytes.HasPrefix(fields[0], []byte("#")) {
|
||||
// log a warning for unknown commands
|
||||
|
||||
@@ -61,3 +61,38 @@ PARAMETER param1
|
||||
assert.ErrorContains(t, err, "missing value for [param1]")
|
||||
|
||||
}
|
||||
|
||||
func Test_Parser_Messages(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM foo
|
||||
MESSAGE system You are a Parser. Always Parse things.
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
commands, err := Parse(reader)
|
||||
assert.Nil(t, err)
|
||||
|
||||
expectedCommands := []Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||
{Name: "message", Args: "user: Hey there!"},
|
||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedCommands, commands)
|
||||
}
|
||||
|
||||
func Test_Parser_Messages_BadRole(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM foo
|
||||
MESSAGE badguy I'm a bad guy!
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
}
|
||||
|
||||
@@ -133,13 +133,6 @@ func (b *Buffer) Size() int {
|
||||
return b.Buf.Size()
|
||||
}
|
||||
|
||||
func min(n, m int) int {
|
||||
if n > m {
|
||||
return m
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (b *Buffer) Add(r rune) {
|
||||
if b.Pos == b.Buf.Size() {
|
||||
fmt.Printf("%c", r)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
set -e
|
||||
|
||||
export VERSION=${VERSION:-0.0.0}
|
||||
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
|
||||
|
||||
mkdir -p dist
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
set -eu
|
||||
|
||||
export VERSION=${VERSION:-0.0.0}
|
||||
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
|
||||
|
||||
docker build \
|
||||
@@ -13,3 +13,13 @@ docker build \
|
||||
-f Dockerfile \
|
||||
-t ollama/ollama:$VERSION \
|
||||
.
|
||||
|
||||
docker build \
|
||||
--load \
|
||||
--platform=linux/amd64 \
|
||||
--build-arg=VERSION \
|
||||
--build-arg=GOFLAGS \
|
||||
--target runtime-rocm \
|
||||
-f Dockerfile \
|
||||
-t ollama/ollama:$VERSION-rocm \
|
||||
.
|
||||
|
||||
@@ -2,14 +2,24 @@
|
||||
|
||||
set -eu
|
||||
|
||||
export VERSION=${VERSION:-0.0.0}
|
||||
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
|
||||
|
||||
BUILD_ARCH=${BUILD_ARCH:-"amd64 arm64"}
|
||||
export AMDGPU_TARGETS=${AMDGPU_TARGETS:=""}
|
||||
mkdir -p dist
|
||||
|
||||
for TARGETARCH in ${BUILD_ARCH}; do
|
||||
docker build --platform=linux/$TARGETARCH --build-arg=GOFLAGS --build-arg=CGO_CFLAGS --build-arg=OLLAMA_CUSTOM_CPU_DEFS -f Dockerfile.build -t builder:$TARGETARCH .
|
||||
docker build \
|
||||
--platform=linux/$TARGETARCH \
|
||||
--build-arg=GOFLAGS \
|
||||
--build-arg=CGO_CFLAGS \
|
||||
--build-arg=OLLAMA_CUSTOM_CPU_DEFS \
|
||||
--build-arg=AMDGPU_TARGETS \
|
||||
--target build-$TARGETARCH \
|
||||
-f Dockerfile \
|
||||
-t builder:$TARGETARCH \
|
||||
.
|
||||
docker create --platform linux/$TARGETARCH --name builder-$TARGETARCH builder:$TARGETARCH
|
||||
docker cp builder-$TARGETARCH:/go/src/github.com/jmorganca/ollama/ollama ./dist/ollama-linux-$TARGETARCH
|
||||
docker rm builder-$TARGETARCH
|
||||
|
||||
@@ -147,12 +147,7 @@ func (s SignatureData) Bytes() []byte {
|
||||
|
||||
// SignData takes a SignatureData object and signs it with a raw private key
|
||||
func (s SignatureData) Sign(rawKey []byte) (string, error) {
|
||||
privateKey, err := ssh.ParseRawPrivateKey(rawKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signer, err := ssh.NewSignerFromKey(privateKey)
|
||||
signer, err := ssh.ParsePrivateKey(rawKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -25,6 +25,11 @@ import (
|
||||
"github.com/jmorganca/ollama/format"
|
||||
)
|
||||
|
||||
const maxRetries = 6
|
||||
|
||||
var errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
var errPartStalled = errors.New("part stalled")
|
||||
|
||||
var blobDownloadManager sync.Map
|
||||
|
||||
type blobDownload struct {
|
||||
@@ -44,10 +49,11 @@ type blobDownload struct {
|
||||
}
|
||||
|
||||
type blobDownloadPart struct {
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
Completed int64
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
Completed int64
|
||||
lastUpdated time.Time
|
||||
|
||||
*blobDownload `json:"-"`
|
||||
}
|
||||
@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
|
||||
return p.Offset + p.Size
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
p.blobDownload.Completed.Add(int64(n))
|
||||
p.lastUpdated = time.Now()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
|
||||
if err != nil {
|
||||
@@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
||||
// return immediately if the context is canceled or the device is out of space
|
||||
return err
|
||||
case errors.Is(err, errPartStalled):
|
||||
try--
|
||||
continue
|
||||
case err != nil:
|
||||
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
|
||||
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
|
||||
@@ -195,28 +211,54 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
}
|
||||
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
n, err := io.Copy(w, io.TeeReader(resp.Body, b))
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
}
|
||||
n, err := io.Copy(w, io.TeeReader(resp.Body, part))
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
}
|
||||
|
||||
part.Completed += n
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
}
|
||||
part.Completed += n
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// return nil or context.Canceled or UnexpectedEOF (resumable)
|
||||
return err
|
||||
// return nil or context.Canceled or UnexpectedEOF (resumable)
|
||||
return err
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if part.Completed >= part.Size {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
|
||||
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
|
||||
// reset last updated
|
||||
part.lastUpdated = time.Time{}
|
||||
return errPartStalled
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (b *blobDownload) newPart(offset, size int64) error {
|
||||
@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
|
||||
return json.NewEncoder(partFile).Encode(part)
|
||||
}
|
||||
|
||||
func (b *blobDownload) Write(p []byte) (n int, err error) {
|
||||
n = len(p)
|
||||
b.Completed.Add(int64(n))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) acquire() {
|
||||
b.references.Add(1)
|
||||
}
|
||||
@@ -279,20 +315,19 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
||||
Digest: b.Digest,
|
||||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
})
|
||||
|
||||
if b.done || b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
||||
Digest: b.Digest,
|
||||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
})
|
||||
|
||||
if b.done || b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,10 +338,6 @@ type downloadOpts struct {
|
||||
fn func(api.ProgressResponse)
|
||||
}
|
||||
|
||||
const maxRetries = 6
|
||||
|
||||
var errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) error {
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
|
||||
132
server/images.go
132
server/images.go
@@ -41,7 +41,7 @@ type Model struct {
|
||||
Config ConfigV2
|
||||
ShortName string
|
||||
ModelPath string
|
||||
OriginalModel string
|
||||
ParentModel string
|
||||
AdapterPaths []string
|
||||
ProjectorPaths []string
|
||||
Template string
|
||||
@@ -50,6 +50,12 @@ type Model struct {
|
||||
Digest string
|
||||
Size int64
|
||||
Options map[string]interface{}
|
||||
Messages []Message
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type PromptVars struct {
|
||||
@@ -57,6 +63,7 @@ type PromptVars struct {
|
||||
Prompt string
|
||||
Response string
|
||||
First bool
|
||||
Images []llm.ImageData
|
||||
}
|
||||
|
||||
// extractParts extracts the parts of the template before and after the {{.Response}} node.
|
||||
@@ -113,10 +120,6 @@ func Prompt(promptTemplate string, p PromptVars) (string, error) {
|
||||
|
||||
// PreResponsePrompt returns the prompt before the response tag
|
||||
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
|
||||
if p.System == "" {
|
||||
// use the default system prompt for this model if one is not specified
|
||||
p.System = m.System
|
||||
}
|
||||
pre, _, err := extractParts(m.Template)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -144,62 +147,68 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
|
||||
return Prompt(post, p)
|
||||
}
|
||||
|
||||
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
|
||||
type ChatHistory struct {
|
||||
Prompts []PromptVars
|
||||
LastSystem string
|
||||
}
|
||||
|
||||
// ChatPrompts returns a list of formatted chat prompts from a list of messages
|
||||
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
|
||||
// build the prompt from the list of messages
|
||||
var prompt strings.Builder
|
||||
var currentImages []api.ImageData
|
||||
lastSystem := m.System
|
||||
currentVars := PromptVars{
|
||||
First: true,
|
||||
System: m.System,
|
||||
}
|
||||
|
||||
writePrompt := func() error {
|
||||
p, err := Prompt(m.Template, currentVars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prompt.WriteString(p)
|
||||
currentVars = PromptVars{}
|
||||
return nil
|
||||
}
|
||||
prompts := []PromptVars{}
|
||||
var images []llm.ImageData
|
||||
|
||||
for _, msg := range msgs {
|
||||
switch strings.ToLower(msg.Role) {
|
||||
case "system":
|
||||
if currentVars.System != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
// if this is the first message it overrides the system prompt in the modelfile
|
||||
if !currentVars.First && currentVars.System != "" {
|
||||
prompts = append(prompts, currentVars)
|
||||
currentVars = PromptVars{}
|
||||
}
|
||||
currentVars.System = msg.Content
|
||||
lastSystem = msg.Content
|
||||
case "user":
|
||||
if currentVars.Prompt != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
prompts = append(prompts, currentVars)
|
||||
currentVars = PromptVars{}
|
||||
}
|
||||
|
||||
currentVars.Prompt = msg.Content
|
||||
currentImages = msg.Images
|
||||
for i := range msg.Images {
|
||||
id := len(images) + i
|
||||
currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
|
||||
currentVars.Images = append(currentVars.Images, llm.ImageData{
|
||||
ID: id,
|
||||
Data: msg.Images[i],
|
||||
})
|
||||
}
|
||||
|
||||
images = append(images, currentVars.Images...)
|
||||
case "assistant":
|
||||
currentVars.Response = msg.Content
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
prompts = append(prompts, currentVars)
|
||||
currentVars = PromptVars{}
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// Append the last set of vars if they are non-empty
|
||||
if currentVars.Prompt != "" || currentVars.System != "" {
|
||||
p, err := m.PreResponsePrompt(currentVars)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("pre-response template: %w", err)
|
||||
}
|
||||
prompt.WriteString(p)
|
||||
prompts = append(prompts, currentVars)
|
||||
}
|
||||
|
||||
return prompt.String(), currentImages, nil
|
||||
return &ChatHistory{
|
||||
Prompts: prompts,
|
||||
LastSystem: lastSystem,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ManifestV2 struct {
|
||||
@@ -333,7 +342,7 @@ func GetModel(name string) (*Model, error) {
|
||||
switch layer.MediaType {
|
||||
case "application/vnd.ollama.image.model":
|
||||
model.ModelPath = filename
|
||||
model.OriginalModel = layer.From
|
||||
model.ParentModel = layer.From
|
||||
case "application/vnd.ollama.image.embed":
|
||||
// Deprecated in versions > 0.1.2
|
||||
// TODO: remove this warning in a future version
|
||||
@@ -374,6 +383,16 @@ func GetModel(name string) (*Model, error) {
|
||||
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.messages":
|
||||
msgs, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer msgs.Close()
|
||||
|
||||
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "application/vnd.ollama.image.license":
|
||||
bts, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
@@ -412,6 +431,13 @@ func realpath(mfDir, from string) string {
|
||||
}
|
||||
|
||||
func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
|
||||
deleteMap := make(map[string]struct{})
|
||||
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
|
||||
for _, layer := range append(manifest.Layers, manifest.Config) {
|
||||
deleteMap[layer.Digest] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
config := ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
@@ -420,15 +446,13 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
},
|
||||
}
|
||||
|
||||
deleteMap := make(map[string]struct{})
|
||||
|
||||
var layers Layers
|
||||
messages := []string{}
|
||||
|
||||
params := make(map[string][]string)
|
||||
fromParams := make(map[string]any)
|
||||
|
||||
for _, c := range commands {
|
||||
slog.Info(fmt.Sprintf("[%s] - %s", c.Name, c.Args))
|
||||
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
||||
|
||||
switch c.Name {
|
||||
@@ -602,11 +626,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
}
|
||||
|
||||
layers.Replace(layer)
|
||||
case "message":
|
||||
messages = append(messages, c.Args)
|
||||
default:
|
||||
params[c.Name] = append(params[c.Name], c.Args)
|
||||
}
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
fn(api.ProgressResponse{Status: "creating parameters layer"})
|
||||
|
||||
msgs := make([]api.Message, 0)
|
||||
|
||||
for _, m := range messages {
|
||||
// todo: handle images
|
||||
msg := strings.SplitN(m, ": ", 2)
|
||||
msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(msgs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layers.Replace(layer)
|
||||
}
|
||||
|
||||
if len(params) > 0 {
|
||||
fn(api.ProgressResponse{Status: "creating parameters layer"})
|
||||
|
||||
@@ -903,8 +953,8 @@ func ShowModelfile(model *Model) (string, error) {
|
||||
mt.Model = model
|
||||
mt.From = model.ModelPath
|
||||
|
||||
if model.OriginalModel != "" {
|
||||
mt.From = model.OriginalModel
|
||||
if model.ParentModel != "" {
|
||||
mt.From = model.ParentModel
|
||||
}
|
||||
|
||||
modelFile := `# Modelfile generated by "ollama show"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -233,17 +234,58 @@ func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func chatHistoryEqual(a, b ChatHistory) bool {
|
||||
if len(a.Prompts) != len(b.Prompts) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a.Prompts {
|
||||
|
||||
if v.First != b.Prompts[i].First {
|
||||
return false
|
||||
}
|
||||
|
||||
if v.Response != b.Prompts[i].Response {
|
||||
return false
|
||||
}
|
||||
|
||||
if v.Prompt != b.Prompts[i].Prompt {
|
||||
return false
|
||||
}
|
||||
|
||||
if v.System != b.Prompts[i].System {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(v.Images) != len(b.Prompts[i].Images) {
|
||||
return false
|
||||
}
|
||||
|
||||
for j, img := range v.Images {
|
||||
if img.ID != b.Prompts[i].Images[j].ID {
|
||||
return false
|
||||
}
|
||||
|
||||
if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return a.LastSystem == b.LastSystem
|
||||
}
|
||||
|
||||
func TestChat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
msgs []api.Message
|
||||
want string
|
||||
wantErr string
|
||||
name string
|
||||
model Model
|
||||
msgs []api.Message
|
||||
want ChatHistory
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "Single Message",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
name: "Single Message",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
@@ -254,34 +296,22 @@ func TestChat(t *testing.T) {
|
||||
Content: "What are the potion ingredients?",
|
||||
},
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "First Message",
|
||||
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a Wizard.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What are the potion ingredients?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "eye of newt",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Anything else?",
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Message History",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
name: "Message History",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
@@ -300,18 +330,85 @@ func TestChat(t *testing.T) {
|
||||
Content: "Anything else?",
|
||||
},
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
Response: "sugar",
|
||||
First: true,
|
||||
},
|
||||
{
|
||||
Prompt: "Anything else?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Assistant Only",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
name: "Assistant Only",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "everything nice",
|
||||
},
|
||||
},
|
||||
want: "[INST] [/INST]everything nice",
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Response: "everything nice",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Last system message is preserved from modelfile",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
System: "You are Mojo Jojo.",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hi",
|
||||
},
|
||||
},
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are Mojo Jojo.",
|
||||
Prompt: "hi",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are Mojo Jojo.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Last system message is preserved from messages",
|
||||
model: Model{
|
||||
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
System: "You are Mojo Jojo.",
|
||||
},
|
||||
msgs: []api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are Professor Utonium.",
|
||||
},
|
||||
},
|
||||
want: ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are Professor Utonium.",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are Professor Utonium.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid Role",
|
||||
@@ -326,11 +423,8 @@ func TestChat(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
m := Model{
|
||||
Template: tt.template,
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, _, err := m.ChatPrompt(tt.msgs)
|
||||
got, err := tt.model.ChatPrompts(tt.msgs)
|
||||
if tt.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Errorf("ChatPrompt() expected error, got nil")
|
||||
@@ -338,9 +432,10 @@ func TestChat(t *testing.T) {
|
||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
|
||||
if !chatHistoryEqual(*got, tt.want) {
|
||||
t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
225
server/routes.go
225
server/routes.go
@@ -178,15 +178,17 @@ func GenerateHandler(c *gin.Context) {
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
if err != nil {
|
||||
if errors.Is(err, api.ErrInvalidOpts) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
sessionDuration := defaultSessionDuration
|
||||
var sessionDuration time.Duration
|
||||
if req.KeepAlive == nil {
|
||||
sessionDuration = defaultSessionDuration
|
||||
} else {
|
||||
sessionDuration = req.KeepAlive.Duration
|
||||
}
|
||||
|
||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -233,6 +235,15 @@ func GenerateHandler(c *gin.Context) {
|
||||
Prompt: req.Prompt,
|
||||
First: len(req.Context) == 0,
|
||||
}
|
||||
|
||||
if promptVars.System == "" {
|
||||
promptVars.System = model.System
|
||||
}
|
||||
|
||||
for i := range req.Images {
|
||||
promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
|
||||
}
|
||||
|
||||
p, err := model.PreResponsePrompt(promptVars)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -242,6 +253,8 @@ func GenerateHandler(c *gin.Context) {
|
||||
prompt = rebuild.String()
|
||||
}
|
||||
|
||||
slog.Debug("generate handler", "prompt", prompt)
|
||||
|
||||
ch := make(chan any)
|
||||
var generated strings.Builder
|
||||
go func() {
|
||||
@@ -295,11 +308,19 @@ func GenerateHandler(c *gin.Context) {
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
var images []llm.ImageData
|
||||
for i := range req.Images {
|
||||
images = append(images, llm.ImageData{
|
||||
ID: i,
|
||||
Data: req.Images[i],
|
||||
})
|
||||
}
|
||||
|
||||
// Start prediction
|
||||
predictReq := llm.PredictOpts{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: req.Images,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
@@ -371,14 +392,17 @@ func EmbeddingHandler(c *gin.Context) {
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
if err != nil {
|
||||
if errors.Is(err, api.ErrInvalidOpts) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
sessionDuration := defaultSessionDuration
|
||||
|
||||
var sessionDuration time.Duration
|
||||
if req.KeepAlive == nil {
|
||||
sessionDuration = defaultSessionDuration
|
||||
} else {
|
||||
sessionDuration = req.KeepAlive.Duration
|
||||
}
|
||||
|
||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -659,6 +683,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
|
||||
modelDetails := api.ModelDetails{
|
||||
ParentModel: model.ParentModel,
|
||||
Format: model.Config.ModelFormat,
|
||||
Family: model.Config.ModelFamily,
|
||||
Families: model.Config.ModelFamilies,
|
||||
@@ -674,11 +699,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
model.Template = req.Template
|
||||
}
|
||||
|
||||
msgs := make([]api.Message, 0)
|
||||
for _, msg := range model.Messages {
|
||||
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
|
||||
}
|
||||
|
||||
resp := &api.ShowResponse{
|
||||
License: strings.Join(model.License, "\n"),
|
||||
System: model.System,
|
||||
Template: model.Template,
|
||||
Details: modelDetails,
|
||||
Messages: msgs,
|
||||
}
|
||||
|
||||
var params []string
|
||||
@@ -911,13 +942,26 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
}
|
||||
|
||||
func Serve(ln net.Listener) error {
|
||||
level := slog.LevelInfo
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
var programLevel = new(slog.LevelVar)
|
||||
h := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: programLevel, AddSource: true})
|
||||
slog.SetDefault(slog.New(h))
|
||||
programLevel.Set(slog.LevelDebug)
|
||||
slog.Debug("Debug logging enabled")
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
AddSource: true,
|
||||
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
|
||||
if attr.Key == slog.SourceKey {
|
||||
source := attr.Value.Any().(*slog.Source)
|
||||
source.File = filepath.Base(source.File)
|
||||
}
|
||||
|
||||
return attr
|
||||
},
|
||||
})
|
||||
|
||||
slog.SetDefault(slog.New(handler))
|
||||
|
||||
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
|
||||
// clean up unused layers and manifests
|
||||
if err := PruneLayers(); err != nil {
|
||||
@@ -1060,14 +1104,17 @@ func ChatHandler(c *gin.Context) {
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
if err != nil {
|
||||
if errors.Is(err, api.ErrInvalidOpts) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
sessionDuration := defaultSessionDuration
|
||||
|
||||
var sessionDuration time.Duration
|
||||
if req.KeepAlive == nil {
|
||||
sessionDuration = defaultSessionDuration
|
||||
} else {
|
||||
sessionDuration = req.KeepAlive.Duration
|
||||
}
|
||||
|
||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -1075,18 +1122,32 @@ func ChatHandler(c *gin.Context) {
|
||||
|
||||
// an empty request loads the model
|
||||
if len(req.Messages) == 0 {
|
||||
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}})
|
||||
resp := api.ChatResponse{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Model: req.Model,
|
||||
Done: true,
|
||||
Message: api.Message{Role: "assistant"},
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
prompt, images, err := model.ChatPrompt(req.Messages)
|
||||
chat, err := model.ChatPrompts(req.Messages)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("chat handler", "prompt", prompt)
|
||||
|
||||
ch := make(chan any)
|
||||
|
||||
go func() {
|
||||
@@ -1160,3 +1221,115 @@ func ChatHandler(c *gin.Context) {
|
||||
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
|
||||
type promptInfo struct {
|
||||
vars PromptVars
|
||||
tokenLen int
|
||||
}
|
||||
|
||||
// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
|
||||
// while preserving the most recent system message.
|
||||
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
|
||||
if len(chat.Prompts) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
var promptsToAdd []promptInfo
|
||||
var totalTokenLength int
|
||||
var systemPromptIncluded bool
|
||||
|
||||
var images []llm.ImageData
|
||||
// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
|
||||
for i := len(chat.Prompts) - 1; i >= 0; i-- {
|
||||
prompt := chat.Prompts[i]
|
||||
promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
encodedTokens, err := loaded.runner.Encode(ctx, promptText)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
|
||||
break // reached max context length, stop adding more prompts
|
||||
}
|
||||
|
||||
for j := range prompt.Images {
|
||||
if totalTokenLength+768 > loaded.NumCtx {
|
||||
// this decreases the token length but overestimating is fine
|
||||
prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
|
||||
continue
|
||||
}
|
||||
|
||||
totalTokenLength += 768
|
||||
images = append(images, prompt.Images[j])
|
||||
}
|
||||
|
||||
totalTokenLength += len(encodedTokens)
|
||||
systemPromptIncluded = systemPromptIncluded || prompt.System != ""
|
||||
promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
|
||||
}
|
||||
|
||||
// ensure the system prompt is included, if not already
|
||||
if chat.LastSystem != "" && !systemPromptIncluded {
|
||||
var err error
|
||||
promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
promptsToAdd[len(promptsToAdd)-1].vars.First = true
|
||||
|
||||
// construct the final prompt string from the prompts which fit within the context window
|
||||
var result string
|
||||
for i, prompt := range promptsToAdd {
|
||||
promptText, err := promptString(model, prompt.vars, i == 0)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
result = promptText + result
|
||||
}
|
||||
|
||||
return result, images, nil
|
||||
}
|
||||
|
||||
// promptString applies the model template to the prompt
|
||||
func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
|
||||
if isMostRecent {
|
||||
p, err := model.PreResponsePrompt(vars)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("pre-response template: %w", err)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
p, err := Prompt(model.Template, vars)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// includeSystemPrompt adjusts the prompts to include the system prompt.
|
||||
func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
|
||||
systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := len(promptsToAdd) - 1; i >= 0; i-- {
|
||||
if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
|
||||
promptsToAdd[i].vars.System = systemPrompt
|
||||
return promptsToAdd[:i+1], nil
|
||||
}
|
||||
totalTokenLength -= promptsToAdd[i].tokenLen
|
||||
}
|
||||
|
||||
// if got here, system did not fit anywhere, so return the most recent prompt with the system message set
|
||||
recent := promptsToAdd[len(promptsToAdd)-1]
|
||||
recent.vars.System = systemPrompt
|
||||
return []promptInfo{recent}, nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
"github.com/jmorganca/ollama/parser"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
@@ -239,3 +240,258 @@ func Test_Routes(t *testing.T) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ChatPrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
chat *ChatHistory
|
||||
numCtx int
|
||||
runner MockLLM
|
||||
want string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "Single Message",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
numCtx: 1,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1}, // fit the ctxLen
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "First Message",
|
||||
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
Response: "eye of newt",
|
||||
First: true,
|
||||
},
|
||||
{
|
||||
Prompt: "Anything else?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
numCtx: 2,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1}, // fit the ctxLen
|
||||
},
|
||||
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Message History",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
System: "You are a Wizard.",
|
||||
Prompt: "What are the potion ingredients?",
|
||||
Response: "sugar",
|
||||
First: true,
|
||||
},
|
||||
{
|
||||
Prompt: "Anything else?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a Wizard.",
|
||||
},
|
||||
numCtx: 4,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1}, // fit the ctxLen, 1 for each message
|
||||
},
|
||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Assistant Only",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Response: "everything nice",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
numCtx: 1,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1},
|
||||
},
|
||||
want: "[INST] [/INST]everything nice",
|
||||
},
|
||||
{
|
||||
name: "Message History Truncated, No System",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Prompt: "What are the potion ingredients?",
|
||||
Response: "sugar",
|
||||
First: true,
|
||||
},
|
||||
{
|
||||
Prompt: "Anything else?",
|
||||
Response: "spice",
|
||||
},
|
||||
{
|
||||
Prompt: "... and?",
|
||||
},
|
||||
},
|
||||
},
|
||||
numCtx: 2, // only 1 message from history and most recent message
|
||||
runner: MockLLM{
|
||||
encoding: []int{1},
|
||||
},
|
||||
want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "System is Preserved when Truncated",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Prompt: "What are the magic words?",
|
||||
Response: "abracadabra",
|
||||
},
|
||||
{
|
||||
Prompt: "What is the spell for invisibility?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a wizard.",
|
||||
},
|
||||
numCtx: 2,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1},
|
||||
},
|
||||
want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "System is Preserved when Length Exceeded",
|
||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Prompt: "What are the magic words?",
|
||||
Response: "abracadabra",
|
||||
},
|
||||
{
|
||||
Prompt: "What is the spell for invisibility?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a wizard.",
|
||||
},
|
||||
numCtx: 1,
|
||||
runner: MockLLM{
|
||||
encoding: []int{1},
|
||||
},
|
||||
want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "First is Preserved when Truncated",
|
||||
template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
|
||||
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
// first message omitted for test
|
||||
{
|
||||
Prompt: "Do you have a magic hat?",
|
||||
Response: "Of course.",
|
||||
},
|
||||
{
|
||||
Prompt: "What is the spell for invisibility?",
|
||||
},
|
||||
},
|
||||
LastSystem: "You are a wizard.",
|
||||
},
|
||||
numCtx: 3, // two most recent messages and room for system message
|
||||
runner: MockLLM{
|
||||
encoding: []int{1},
|
||||
},
|
||||
want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
|
||||
},
|
||||
{
|
||||
name: "Most recent message is returned when longer than ctxLen",
|
||||
template: "[INST] {{ .Prompt }} [/INST]",
|
||||
|
||||
chat: &ChatHistory{
|
||||
Prompts: []PromptVars{
|
||||
{
|
||||
Prompt: "What is the spell for invisibility?",
|
||||
First: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
numCtx: 1, // two most recent messages
|
||||
runner: MockLLM{
|
||||
encoding: []int{1, 2},
|
||||
},
|
||||
want: "[INST] What is the spell for invisibility? [/INST]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range tests {
|
||||
tt := testCase
|
||||
m := &Model{
|
||||
Template: tt.template,
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
loaded.runner = &tt.runner
|
||||
loaded.Options = &api.Options{
|
||||
Runner: api.Runner{
|
||||
NumCtx: tt.numCtx,
|
||||
},
|
||||
}
|
||||
// TODO: add tests for trimming images
|
||||
got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
|
||||
if tt.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Errorf("ChatPrompt() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MockLLM struct {
|
||||
encoding []int
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||
return llm.encoding, nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
|
||||
return []float64{}, nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Close() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user