mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-23 08:10:48 -04:00
Compare commits
79 Commits
v2.8.1
...
docs_updat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b8d6a31e2 | ||
|
|
f0752be4aa | ||
|
|
bafc9effad | ||
|
|
d2934dd69f | ||
|
|
a6b540737f | ||
|
|
f82065703d | ||
|
|
b423af001d | ||
|
|
b9e77d394b | ||
|
|
57222497ec | ||
|
|
5c5f07c1e7 | ||
|
|
f895d06605 | ||
|
|
bc8f648a91 | ||
|
|
8e57f4df31 | ||
|
|
a08cc5adbb | ||
|
|
595a73fce4 | ||
|
|
dc919e08e8 | ||
|
|
5d1018495f | ||
|
|
ad6fd7a991 | ||
|
|
e022b5959e | ||
|
|
db7f4955a1 | ||
|
|
5c69dd155f | ||
|
|
504f2e8bf4 | ||
|
|
e586dc2924 | ||
|
|
333f918005 | ||
|
|
c8e29033c2 | ||
|
|
d0bd961bde | ||
|
|
006511ee25 | ||
|
|
4ab72146cd | ||
|
|
b60a3fc879 | ||
|
|
a0eeb74957 | ||
|
|
daa0b8741c | ||
|
|
939411300a | ||
|
|
1c312685aa | ||
|
|
316de82f51 | ||
|
|
9068bc5271 | ||
|
|
31a4c9c9d3 | ||
|
|
c1966af2cf | ||
|
|
c665898652 | ||
|
|
f651a660aa | ||
|
|
ba672b51da | ||
|
|
be498c5dd9 | ||
|
|
6e95beccb9 | ||
|
|
c8be839481 | ||
|
|
c7e08813a5 | ||
|
|
d21a6b33ab | ||
|
|
9112cf153e | ||
|
|
3868ac8402 | ||
|
|
3f09010227 | ||
|
|
d6cf82aba3 | ||
|
|
dfe54639b1 | ||
|
|
bc5f5aa538 | ||
|
|
05818e0425 | ||
|
|
7f72a61104 | ||
|
|
8e45d47740 | ||
|
|
71771d1e9b | ||
|
|
aa098e4d0b | ||
|
|
0135e1e3b9 | ||
|
|
ff88c390bb | ||
|
|
d825821a22 | ||
|
|
cbed6ab1bb | ||
|
|
6fc122fa1a | ||
|
|
feba38be36 | ||
|
|
ba85d0bcad | ||
|
|
ad3623dd8d | ||
|
|
8292781045 | ||
|
|
54ec6348fa | ||
|
|
255748bcba | ||
|
|
594eb468df | ||
|
|
960d314e4f | ||
|
|
ed3b50622b | ||
|
|
9f2235c208 | ||
|
|
4ec50bfc41 | ||
|
|
51b67a247a | ||
|
|
01205fd4c0 | ||
|
|
c72808f18b | ||
|
|
6b539a2972 | ||
|
|
2151d21862 | ||
|
|
fb0a4c5d9a | ||
|
|
e690bf387a |
@@ -3,3 +3,4 @@ models
|
|||||||
examples/chatbot-ui/models
|
examples/chatbot-ui/models
|
||||||
examples/rwkv/models
|
examples/rwkv/models
|
||||||
examples/**/models
|
examples/**/models
|
||||||
|
Dockerfile
|
||||||
2
.env
2
.env
@@ -18,7 +18,7 @@
|
|||||||
|
|
||||||
## Default path for models
|
## Default path for models
|
||||||
#
|
#
|
||||||
MODELS_PATH=/models
|
# MODELS_PATH=/models
|
||||||
|
|
||||||
## Enable debug mode
|
## Enable debug mode
|
||||||
# DEBUG=true
|
# DEBUG=true
|
||||||
|
|||||||
18
.github/workflows/image-pr.yml
vendored
18
.github/workflows/image-pr.yml
vendored
@@ -51,6 +51,22 @@ jobs:
|
|||||||
image-type: 'extras'
|
image-type: 'extras'
|
||||||
runs-on: 'arc-runner-set'
|
runs-on: 'arc-runner-set'
|
||||||
base-image: "ubuntu:22.04"
|
base-image: "ubuntu:22.04"
|
||||||
|
- build-type: 'hipblas'
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
tag-suffix: '-hipblas'
|
||||||
|
ffmpeg: 'false'
|
||||||
|
image-type: 'extras'
|
||||||
|
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
|
- build-type: 'sycl_f16'
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
|
||||||
|
tag-suffix: 'sycl-f16-ffmpeg'
|
||||||
|
ffmpeg: 'true'
|
||||||
|
image-type: 'extras'
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
core-image-build:
|
core-image-build:
|
||||||
uses: ./.github/workflows/image_build.yml
|
uses: ./.github/workflows/image_build.yml
|
||||||
with:
|
with:
|
||||||
@@ -97,4 +113,4 @@ jobs:
|
|||||||
ffmpeg: 'true'
|
ffmpeg: 'true'
|
||||||
image-type: 'core'
|
image-type: 'core'
|
||||||
runs-on: 'ubuntu-latest'
|
runs-on: 'ubuntu-latest'
|
||||||
base-image: "ubuntu:22.04"
|
base-image: "ubuntu:22.04"
|
||||||
105
.github/workflows/image.yml
vendored
105
.github/workflows/image.yml
vendored
@@ -13,7 +13,7 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
extras-image-build:
|
self-hosted-jobs:
|
||||||
uses: ./.github/workflows/image_build.yml
|
uses: ./.github/workflows/image_build.yml
|
||||||
with:
|
with:
|
||||||
tag-latest: ${{ matrix.tag-latest }}
|
tag-latest: ${{ matrix.tag-latest }}
|
||||||
@@ -37,6 +37,7 @@ jobs:
|
|||||||
max-parallel: ${{ github.event_name != 'pull_request' && 2 || 4 }}
|
max-parallel: ${{ github.event_name != 'pull_request' && 2 || 4 }}
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
|
# Extra images
|
||||||
- build-type: ''
|
- build-type: ''
|
||||||
#platforms: 'linux/amd64,linux/arm64'
|
#platforms: 'linux/amd64,linux/arm64'
|
||||||
platforms: 'linux/amd64'
|
platforms: 'linux/amd64'
|
||||||
@@ -103,35 +104,39 @@ jobs:
|
|||||||
image-type: 'extras'
|
image-type: 'extras'
|
||||||
base-image: "ubuntu:22.04"
|
base-image: "ubuntu:22.04"
|
||||||
runs-on: 'arc-runner-set'
|
runs-on: 'arc-runner-set'
|
||||||
core-image-build:
|
- build-type: 'hipblas'
|
||||||
uses: ./.github/workflows/image_build.yml
|
|
||||||
with:
|
|
||||||
tag-latest: ${{ matrix.tag-latest }}
|
|
||||||
tag-suffix: ${{ matrix.tag-suffix }}
|
|
||||||
ffmpeg: ${{ matrix.ffmpeg }}
|
|
||||||
image-type: ${{ matrix.image-type }}
|
|
||||||
build-type: ${{ matrix.build-type }}
|
|
||||||
cuda-major-version: ${{ matrix.cuda-major-version }}
|
|
||||||
cuda-minor-version: ${{ matrix.cuda-minor-version }}
|
|
||||||
platforms: ${{ matrix.platforms }}
|
|
||||||
runs-on: ${{ matrix.runs-on }}
|
|
||||||
base-image: ${{ matrix.base-image }}
|
|
||||||
secrets:
|
|
||||||
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
|
||||||
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
|
||||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
|
||||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
include:
|
|
||||||
- build-type: ''
|
|
||||||
platforms: 'linux/amd64'
|
platforms: 'linux/amd64'
|
||||||
tag-latest: 'false'
|
tag-latest: 'false'
|
||||||
tag-suffix: '-ffmpeg-core'
|
tag-suffix: '-hipblas-ffmpeg'
|
||||||
ffmpeg: 'true'
|
ffmpeg: 'true'
|
||||||
image-type: 'core'
|
image-type: 'extras'
|
||||||
base-image: "ubuntu:22.04"
|
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
||||||
runs-on: 'ubuntu-latest'
|
runs-on: 'arc-runner-set'
|
||||||
|
- build-type: 'hipblas'
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
tag-suffix: '-hipblas'
|
||||||
|
ffmpeg: 'false'
|
||||||
|
image-type: 'extras'
|
||||||
|
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
|
- build-type: 'sycl_f16'
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
|
||||||
|
tag-suffix: '-sycl-f16-ffmpeg'
|
||||||
|
ffmpeg: 'true'
|
||||||
|
image-type: 'extras'
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
|
- build-type: 'sycl_f32'
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
|
||||||
|
tag-suffix: '-sycl-f32-ffmpeg'
|
||||||
|
ffmpeg: 'true'
|
||||||
|
image-type: 'extras'
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
|
# Core images
|
||||||
- build-type: 'sycl_f16'
|
- build-type: 'sycl_f16'
|
||||||
platforms: 'linux/amd64'
|
platforms: 'linux/amd64'
|
||||||
tag-latest: 'false'
|
tag-latest: 'false'
|
||||||
@@ -164,6 +169,52 @@ jobs:
|
|||||||
ffmpeg: 'true'
|
ffmpeg: 'true'
|
||||||
image-type: 'core'
|
image-type: 'core'
|
||||||
runs-on: 'arc-runner-set'
|
runs-on: 'arc-runner-set'
|
||||||
|
- build-type: 'hipblas'
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
tag-suffix: '-hipblas-ffmpeg-core'
|
||||||
|
ffmpeg: 'true'
|
||||||
|
image-type: 'core'
|
||||||
|
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
|
- build-type: 'hipblas'
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
tag-suffix: '-hipblas-core'
|
||||||
|
ffmpeg: 'false'
|
||||||
|
image-type: 'core'
|
||||||
|
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
|
|
||||||
|
core-image-build:
|
||||||
|
uses: ./.github/workflows/image_build.yml
|
||||||
|
with:
|
||||||
|
tag-latest: ${{ matrix.tag-latest }}
|
||||||
|
tag-suffix: ${{ matrix.tag-suffix }}
|
||||||
|
ffmpeg: ${{ matrix.ffmpeg }}
|
||||||
|
image-type: ${{ matrix.image-type }}
|
||||||
|
build-type: ${{ matrix.build-type }}
|
||||||
|
cuda-major-version: ${{ matrix.cuda-major-version }}
|
||||||
|
cuda-minor-version: ${{ matrix.cuda-minor-version }}
|
||||||
|
platforms: ${{ matrix.platforms }}
|
||||||
|
runs-on: ${{ matrix.runs-on }}
|
||||||
|
base-image: ${{ matrix.base-image }}
|
||||||
|
secrets:
|
||||||
|
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||||
|
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||||
|
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- build-type: ''
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
tag-suffix: '-ffmpeg-core'
|
||||||
|
ffmpeg: 'true'
|
||||||
|
image-type: 'core'
|
||||||
|
base-image: "ubuntu:22.04"
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
- build-type: 'cublas'
|
- build-type: 'cublas'
|
||||||
cuda-major-version: "11"
|
cuda-major-version: "11"
|
||||||
cuda-minor-version: "7"
|
cuda-minor-version: "7"
|
||||||
|
|||||||
29
.github/workflows/release.yaml
vendored
29
.github/workflows/release.yaml
vendored
@@ -89,6 +89,35 @@ jobs:
|
|||||||
files: |
|
files: |
|
||||||
release/*
|
release/*
|
||||||
|
|
||||||
|
build-stablediffusion:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
submodules: true
|
||||||
|
- uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: '>=1.21.0'
|
||||||
|
- name: Dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get install -y --no-install-recommends libopencv-dev
|
||||||
|
sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
|
||||||
|
- name: Build stablediffusion
|
||||||
|
run: |
|
||||||
|
make backend-assets/grpc/stablediffusion
|
||||||
|
mkdir -p release && cp backend-assets/grpc/stablediffusion release
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: stablediffusion
|
||||||
|
path: release/
|
||||||
|
- name: Release
|
||||||
|
uses: softprops/action-gh-release@v1
|
||||||
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
|
with:
|
||||||
|
files: |
|
||||||
|
release/*
|
||||||
|
|
||||||
build-macOS:
|
build-macOS:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -21,6 +21,7 @@ local-ai
|
|||||||
!charts/*
|
!charts/*
|
||||||
# prevent above rules from omitting the api/localai folder
|
# prevent above rules from omitting the api/localai folder
|
||||||
!api/localai
|
!api/localai
|
||||||
|
!core/**/localai
|
||||||
|
|
||||||
# Ignore models
|
# Ignore models
|
||||||
models/*
|
models/*
|
||||||
@@ -34,6 +35,7 @@ release/
|
|||||||
.idea
|
.idea
|
||||||
|
|
||||||
# Generated during build
|
# Generated during build
|
||||||
backend-assets/
|
backend-assets/*
|
||||||
|
!backend-assets/.keep
|
||||||
prepare
|
prepare
|
||||||
/ggml-metal.metal
|
/ggml-metal.metal
|
||||||
|
|||||||
55
Dockerfile
55
Dockerfile
@@ -1,10 +1,11 @@
|
|||||||
ARG GO_VERSION=1.21
|
|
||||||
ARG IMAGE_TYPE=extras
|
ARG IMAGE_TYPE=extras
|
||||||
ARG BASE_IMAGE=ubuntu:22.04
|
ARG BASE_IMAGE=ubuntu:22.04
|
||||||
|
|
||||||
# extras or core
|
# extras or core
|
||||||
FROM ${BASE_IMAGE} as requirements-core
|
FROM ${BASE_IMAGE} as requirements-core
|
||||||
|
|
||||||
|
USER root
|
||||||
|
|
||||||
ARG GO_VERSION=1.21.7
|
ARG GO_VERSION=1.21.7
|
||||||
ARG BUILD_TYPE
|
ARG BUILD_TYPE
|
||||||
ARG CUDA_MAJOR_VERSION=11
|
ARG CUDA_MAJOR_VERSION=11
|
||||||
@@ -22,7 +23,7 @@ RUN apt-get update && \
|
|||||||
apt-get install -y ca-certificates curl patch pip cmake git && apt-get clean
|
apt-get install -y ca-certificates curl patch pip cmake git && apt-get clean
|
||||||
|
|
||||||
# Install Go
|
# Install Go
|
||||||
RUN curl -L -s https://go.dev/dl/go$GO_VERSION.linux-$TARGETARCH.tar.gz | tar -v -C /usr/local -xz
|
RUN curl -L -s https://go.dev/dl/go$GO_VERSION.linux-$TARGETARCH.tar.gz | tar -C /usr/local -xz
|
||||||
ENV PATH $PATH:/usr/local/go/bin
|
ENV PATH $PATH:/usr/local/go/bin
|
||||||
|
|
||||||
COPY --chmod=644 custom-ca-certs/* /usr/local/share/ca-certificates/
|
COPY --chmod=644 custom-ca-certs/* /usr/local/share/ca-certificates/
|
||||||
@@ -42,8 +43,12 @@ RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \
|
|||||||
apt-get install -y cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} && apt-get clean \
|
apt-get install -y cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} && apt-get clean \
|
||||||
; fi
|
; fi
|
||||||
|
|
||||||
|
# Cuda
|
||||||
ENV PATH /usr/local/cuda/bin:${PATH}
|
ENV PATH /usr/local/cuda/bin:${PATH}
|
||||||
|
|
||||||
|
# HipBLAS requirements
|
||||||
|
ENV PATH /opt/rocm/bin:${PATH}
|
||||||
|
|
||||||
# OpenBLAS requirements and stable diffusion
|
# OpenBLAS requirements and stable diffusion
|
||||||
RUN apt-get install -y \
|
RUN apt-get install -y \
|
||||||
libopenblas-dev \
|
libopenblas-dev \
|
||||||
@@ -70,10 +75,16 @@ RUN curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmo
|
|||||||
apt-get install -y conda && apt-get clean
|
apt-get install -y conda && apt-get clean
|
||||||
|
|
||||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||||
|
RUN apt-get install -y python3-pip && apt-get clean
|
||||||
RUN pip install --upgrade pip
|
RUN pip install --upgrade pip
|
||||||
|
|
||||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||||
RUN apt-get install -y espeak-ng espeak && apt-get clean
|
RUN apt-get install -y espeak-ng espeak && apt-get clean
|
||||||
|
|
||||||
|
RUN if [ ! -e /usr/bin/python ]; then \
|
||||||
|
ln -s /usr/bin/python3 /usr/bin/python \
|
||||||
|
; fi
|
||||||
|
|
||||||
###################################
|
###################################
|
||||||
###################################
|
###################################
|
||||||
|
|
||||||
@@ -94,6 +105,13 @@ COPY . .
|
|||||||
COPY .git .
|
COPY .git .
|
||||||
RUN make prepare
|
RUN make prepare
|
||||||
|
|
||||||
|
# If we are building with clblas support, we need the libraries for the builds
|
||||||
|
RUN if [ "${BUILD_TYPE}" = "clblas" ]; then \
|
||||||
|
apt-get update && \
|
||||||
|
apt-get install -y libclblast-dev && \
|
||||||
|
apt-get clean \
|
||||||
|
; fi
|
||||||
|
|
||||||
# stablediffusion does not tolerate a newer version of abseil, build it first
|
# stablediffusion does not tolerate a newer version of abseil, build it first
|
||||||
RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
|
RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
|
||||||
|
|
||||||
@@ -137,6 +155,13 @@ RUN if [ "${FFMPEG}" = "true" ]; then \
|
|||||||
apt-get install -y ffmpeg && apt-get clean \
|
apt-get install -y ffmpeg && apt-get clean \
|
||||||
; fi
|
; fi
|
||||||
|
|
||||||
|
# Add OpenCL
|
||||||
|
RUN if [ "${BUILD_TYPE}" = "clblas" ]; then \
|
||||||
|
apt-get update && \
|
||||||
|
apt-get install -y libclblast1 && \
|
||||||
|
apt-get clean \
|
||||||
|
; fi
|
||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
|
||||||
# we start fresh & re-copy all assets because `make build` does not clean up nicely after itself
|
# we start fresh & re-copy all assets because `make build` does not clean up nicely after itself
|
||||||
@@ -161,43 +186,43 @@ COPY --from=builder /build/backend-assets/grpc/stablediffusion ./backend-assets/
|
|||||||
|
|
||||||
## Duplicated from Makefile to avoid having a big layer that's hard to push
|
## Duplicated from Makefile to avoid having a big layer that's hard to push
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/autogptq \
|
make -C backend/python/autogptq \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/bark \
|
make -C backend/python/bark \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/diffusers \
|
make -C backend/python/diffusers \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/vllm \
|
make -C backend/python/vllm \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/mamba \
|
make -C backend/python/mamba \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers \
|
make -C backend/python/sentencetransformers \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/transformers \
|
make -C backend/python/transformers \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/vall-e-x \
|
make -C backend/python/vall-e-x \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/exllama \
|
make -C backend/python/exllama \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/exllama2 \
|
make -C backend/python/exllama2 \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/petals \
|
make -C backend/python/petals \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/transformers-musicgen \
|
make -C backend/python/transformers-musicgen \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/coqui \
|
make -C backend/python/coqui \
|
||||||
; fi
|
; fi
|
||||||
|
|
||||||
# Make sure the models directory exists
|
# Make sure the models directory exists
|
||||||
|
|||||||
39
Makefile
39
Makefile
@@ -8,7 +8,7 @@ GOLLAMA_VERSION?=aeba71ee842819da681ea537e78846dc75949ac0
|
|||||||
|
|
||||||
GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
|
GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
|
||||||
|
|
||||||
CPPLLAMA_VERSION?=f026f8120f97090d34a52b3dc023c82e0ede3f7d
|
CPPLLAMA_VERSION?=19885d205e768579ab090d1e99281cae58c21b54
|
||||||
|
|
||||||
# gpt4all version
|
# gpt4all version
|
||||||
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
|
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
|
||||||
@@ -16,7 +16,7 @@ GPT4ALL_VERSION?=27a8b020c36b0df8f8b82a252d261cda47cf44b8
|
|||||||
|
|
||||||
# go-rwkv version
|
# go-rwkv version
|
||||||
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
|
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
|
||||||
RWKV_VERSION?=633c5a3485c403cb2520693dc0991a25dace9f0f
|
RWKV_VERSION?=661e7ae26d442f5cfebd2a0881b44e8c55949ec6
|
||||||
|
|
||||||
# whisper.cpp version
|
# whisper.cpp version
|
||||||
WHISPER_CPP_VERSION?=37a709f6558c6d9783199e2b8cbb136e1c41d346
|
WHISPER_CPP_VERSION?=37a709f6558c6d9783199e2b8cbb136e1c41d346
|
||||||
@@ -28,7 +28,7 @@ BERT_VERSION?=6abe312cded14042f6b7c3cd8edf082713334a4d
|
|||||||
PIPER_VERSION?=d6b6275ba037dabdba4a8b65dfdf6b2a73a67f07
|
PIPER_VERSION?=d6b6275ba037dabdba4a8b65dfdf6b2a73a67f07
|
||||||
|
|
||||||
# stablediffusion version
|
# stablediffusion version
|
||||||
STABLEDIFFUSION_VERSION?=d5d2be8e7e395c2d73ceef61e6fe8d240f2cd831
|
STABLEDIFFUSION_VERSION?=362df9da29f882dbf09ade61972d16a1f53c3485
|
||||||
|
|
||||||
# tinydream version
|
# tinydream version
|
||||||
TINYDREAM_VERSION?=772a9c0d9aaf768290e63cca3c904fe69faf677a
|
TINYDREAM_VERSION?=772a9c0d9aaf768290e63cca3c904fe69faf677a
|
||||||
@@ -44,6 +44,8 @@ BUILD_ID?=git
|
|||||||
|
|
||||||
TEST_DIR=/tmp/test
|
TEST_DIR=/tmp/test
|
||||||
|
|
||||||
|
TEST_FLAKES?=5
|
||||||
|
|
||||||
RANDOM := $(shell bash -c 'echo $$RANDOM')
|
RANDOM := $(shell bash -c 'echo $$RANDOM')
|
||||||
|
|
||||||
VERSION?=$(shell git describe --always --tags || echo "dev" )
|
VERSION?=$(shell git describe --always --tags || echo "dev" )
|
||||||
@@ -97,6 +99,8 @@ endif
|
|||||||
|
|
||||||
ifeq ($(BUILD_TYPE),hipblas)
|
ifeq ($(BUILD_TYPE),hipblas)
|
||||||
ROCM_HOME ?= /opt/rocm
|
ROCM_HOME ?= /opt/rocm
|
||||||
|
ROCM_PATH ?= /opt/rocm
|
||||||
|
LD_LIBRARY_PATH ?= /opt/rocm/lib:/opt/rocm/llvm/lib
|
||||||
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
export CXX=$(ROCM_HOME)/llvm/bin/clang++
|
||||||
export CC=$(ROCM_HOME)/llvm/bin/clang
|
export CC=$(ROCM_HOME)/llvm/bin/clang
|
||||||
# llama-ggml has no hipblas support, so override it here.
|
# llama-ggml has no hipblas support, so override it here.
|
||||||
@@ -105,7 +109,7 @@ ifeq ($(BUILD_TYPE),hipblas)
|
|||||||
GPU_TARGETS ?= gfx900,gfx90a,gfx1030,gfx1031,gfx1100
|
GPU_TARGETS ?= gfx900,gfx90a,gfx1030,gfx1031,gfx1100
|
||||||
AMDGPU_TARGETS ?= "$(GPU_TARGETS)"
|
AMDGPU_TARGETS ?= "$(GPU_TARGETS)"
|
||||||
CMAKE_ARGS+=-DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS="$(AMDGPU_TARGETS)" -DGPU_TARGETS="$(GPU_TARGETS)"
|
CMAKE_ARGS+=-DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS="$(AMDGPU_TARGETS)" -DGPU_TARGETS="$(GPU_TARGETS)"
|
||||||
CGO_LDFLAGS += -O3 --rtlib=compiler-rt -unwindlib=libgcc -lhipblas -lrocblas --hip-link
|
CGO_LDFLAGS += -O3 --rtlib=compiler-rt -unwindlib=libgcc -lhipblas -lrocblas --hip-link -L${ROCM_HOME}/lib/llvm/lib
|
||||||
endif
|
endif
|
||||||
|
|
||||||
ifeq ($(BUILD_TYPE),metal)
|
ifeq ($(BUILD_TYPE),metal)
|
||||||
@@ -153,6 +157,7 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
|
|||||||
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
|
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
|
||||||
|
|
||||||
GRPC_BACKENDS?=$(ALL_GRPC_BACKENDS) $(OPTIONAL_GRPC)
|
GRPC_BACKENDS?=$(ALL_GRPC_BACKENDS) $(OPTIONAL_GRPC)
|
||||||
|
TEST_PATHS?=./api/... ./pkg/... ./core/...
|
||||||
|
|
||||||
# If empty, then we build all
|
# If empty, then we build all
|
||||||
ifeq ($(GRPC_BACKENDS),)
|
ifeq ($(GRPC_BACKENDS),)
|
||||||
@@ -248,7 +253,7 @@ sources/go-piper/libpiper_binding.a: sources/go-piper
|
|||||||
$(MAKE) -C sources/go-piper libpiper_binding.a example/main
|
$(MAKE) -C sources/go-piper libpiper_binding.a example/main
|
||||||
|
|
||||||
backend/cpp/llama/llama.cpp:
|
backend/cpp/llama/llama.cpp:
|
||||||
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama llama.cpp
|
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama llama.cpp
|
||||||
|
|
||||||
get-sources: backend/cpp/llama/llama.cpp sources/go-llama sources/go-llama-ggml sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion sources/go-tiny-dream
|
get-sources: backend/cpp/llama/llama.cpp sources/go-llama sources/go-llama-ggml sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion sources/go-tiny-dream
|
||||||
touch $@
|
touch $@
|
||||||
@@ -326,7 +331,7 @@ test-models/testmodel:
|
|||||||
cp tests/models_fixtures/* test-models
|
cp tests/models_fixtures/* test-models
|
||||||
|
|
||||||
prepare-test: grpcs
|
prepare-test: grpcs
|
||||||
cp -rf backend-assets api
|
cp -rf backend-assets core/http
|
||||||
cp tests/models_fixtures/* test-models
|
cp tests/models_fixtures/* test-models
|
||||||
|
|
||||||
test: prepare test-models/testmodel grpcs
|
test: prepare test-models/testmodel grpcs
|
||||||
@@ -334,7 +339,7 @@ test: prepare test-models/testmodel grpcs
|
|||||||
export GO_TAGS="tts stablediffusion"
|
export GO_TAGS="tts stablediffusion"
|
||||||
$(MAKE) prepare-test
|
$(MAKE) prepare-test
|
||||||
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts 5 --fail-fast -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
|
||||||
$(MAKE) test-gpt4all
|
$(MAKE) test-gpt4all
|
||||||
$(MAKE) test-llama
|
$(MAKE) test-llama
|
||||||
$(MAKE) test-llama-gguf
|
$(MAKE) test-llama-gguf
|
||||||
@@ -363,23 +368,23 @@ teardown-e2e:
|
|||||||
|
|
||||||
test-gpt4all: prepare-test
|
test-gpt4all: prepare-test
|
||||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
test-llama: prepare-test
|
test-llama: prepare-test
|
||||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
test-llama-gguf: prepare-test
|
test-llama-gguf: prepare-test
|
||||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
test-tts: prepare-test
|
test-tts: prepare-test
|
||||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
test-stablediffusion: prepare-test
|
test-stablediffusion: prepare-test
|
||||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
test-container:
|
test-container:
|
||||||
docker build --target requirements -t local-ai-test-container .
|
docker build --target requirements -t local-ai-test-container .
|
||||||
@@ -480,7 +485,7 @@ ifdef BUILD_GRPC_FOR_BACKEND_LLAMA
|
|||||||
CMAKE_ARGS="${CMAKE_ARGS} ${ADDED_CMAKE_ARGS}" LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
CMAKE_ARGS="${CMAKE_ARGS} ${ADDED_CMAKE_ARGS}" LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
||||||
else
|
else
|
||||||
echo "BUILD_GRPC_FOR_BACKEND_LLAMA is not defined."
|
echo "BUILD_GRPC_FOR_BACKEND_LLAMA is not defined."
|
||||||
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
||||||
endif
|
endif
|
||||||
## BACKEND CPP LLAMA END
|
## BACKEND CPP LLAMA END
|
||||||
|
|
||||||
@@ -514,6 +519,7 @@ backend-assets/grpc/langchain-huggingface: backend-assets/grpc
|
|||||||
|
|
||||||
backend-assets/grpc/stablediffusion: backend-assets/grpc
|
backend-assets/grpc/stablediffusion: backend-assets/grpc
|
||||||
if [ ! -f backend-assets/grpc/stablediffusion ]; then \
|
if [ ! -f backend-assets/grpc/stablediffusion ]; then \
|
||||||
|
$(MAKE) sources/go-stable-diffusion; \
|
||||||
$(MAKE) sources/go-stable-diffusion/libstablediffusion.a; \
|
$(MAKE) sources/go-stable-diffusion/libstablediffusion.a; \
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-stable-diffusion/ LIBRARY_PATH=$(CURDIR)/sources/go-stable-diffusion/ \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-stable-diffusion/ LIBRARY_PATH=$(CURDIR)/sources/go-stable-diffusion/ \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./backend/go/image/stablediffusion; \
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./backend/go/image/stablediffusion; \
|
||||||
@@ -551,3 +557,10 @@ docker-image-intel:
|
|||||||
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
|
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
|
||||||
--build-arg GO_TAGS="none" \
|
--build-arg GO_TAGS="none" \
|
||||||
--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
|
--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
|
||||||
|
|
||||||
|
docker-image-intel-xpu:
|
||||||
|
docker build \
|
||||||
|
--build-arg BASE_IMAGE=intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04 \
|
||||||
|
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
|
||||||
|
--build-arg GO_TAGS="none" \
|
||||||
|
--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
|
||||||
22
README.md
22
README.md
@@ -43,19 +43,24 @@
|
|||||||
|
|
||||||
[Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
|
[Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
|
||||||
|
|
||||||
- Intel GPU support (sycl): https://github.com/mudler/LocalAI/issues/1653
|
- Parallel function calling: https://github.com/mudler/LocalAI/pull/1726
|
||||||
|
- Upload file API: https://github.com/mudler/LocalAI/pull/1703
|
||||||
|
- Tools API support: https://github.com/mudler/LocalAI/pull/1715
|
||||||
|
- LLaVa 1.6: https://github.com/mudler/LocalAI/pull/1714
|
||||||
|
- ROCm container images: https://github.com/mudler/LocalAI/pull/1595
|
||||||
|
- Intel GPU support (sycl, transformers, diffusers): https://github.com/mudler/LocalAI/issues/1653
|
||||||
- Deprecation of old backends: https://github.com/mudler/LocalAI/issues/1651
|
- Deprecation of old backends: https://github.com/mudler/LocalAI/issues/1651
|
||||||
- Mamba support: https://github.com/mudler/LocalAI/pull/1589
|
- Mamba support: https://github.com/mudler/LocalAI/pull/1589
|
||||||
- Start and share models with config file: https://github.com/mudler/LocalAI/pull/1522
|
- Start and share models with config file: https://github.com/mudler/LocalAI/pull/1522
|
||||||
- 🐸 Coqui: https://github.com/mudler/LocalAI/pull/1489
|
- 🐸 Coqui: https://github.com/mudler/LocalAI/pull/1489
|
||||||
- Inline templates: https://github.com/mudler/LocalAI/pull/1452
|
|
||||||
- Mixtral: https://github.com/mudler/LocalAI/pull/1449
|
|
||||||
- Img2vid https://github.com/mudler/LocalAI/pull/1442
|
- Img2vid https://github.com/mudler/LocalAI/pull/1442
|
||||||
- Musicgen https://github.com/mudler/LocalAI/pull/1387
|
|
||||||
|
|
||||||
Hot topics (looking for contributors):
|
Hot topics (looking for contributors):
|
||||||
- Backends v2: https://github.com/mudler/LocalAI/issues/1126
|
- Backends v2: https://github.com/mudler/LocalAI/issues/1126
|
||||||
- Improving UX v2: https://github.com/mudler/LocalAI/issues/1373
|
- Improving UX v2: https://github.com/mudler/LocalAI/issues/1373
|
||||||
|
- Assistant API: https://github.com/mudler/LocalAI/issues/1273
|
||||||
|
- Moderation endpoint: https://github.com/mudler/LocalAI/issues/999
|
||||||
|
- Vulkan: https://github.com/mudler/LocalAI/issues/1647
|
||||||
|
|
||||||
If you want to help and contribute, issues up for grabs: https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3A%22up+for+grabs%22
|
If you want to help and contribute, issues up for grabs: https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3A%22up+for+grabs%22
|
||||||
|
|
||||||
@@ -64,7 +69,7 @@ If you want to help and contribute, issues up for grabs: https://github.com/mudl
|
|||||||
For a detailed step-by-step introduction, refer to the [Getting Started](https://localai.io/basics/getting_started/index.html) guide. For those in a hurry, here's a straightforward one-liner to launch a LocalAI instance with [phi-2](https://huggingface.co/microsoft/phi-2) using `docker`:
|
For a detailed step-by-step introduction, refer to the [Getting Started](https://localai.io/basics/getting_started/index.html) guide. For those in a hurry, here's a straightforward one-liner to launch a LocalAI instance with [phi-2](https://huggingface.co/microsoft/phi-2) using `docker`:
|
||||||
|
|
||||||
```
|
```
|
||||||
docker run -ti -p 8080:8080 localai/localai:v2.7.0-ffmpeg-core phi-2
|
docker run -ti -p 8080:8080 localai/localai:v2.9.0-ffmpeg-core phi-2
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🚀 [Features](https://localai.io/features/)
|
## 🚀 [Features](https://localai.io/features/)
|
||||||
@@ -94,10 +99,6 @@ WebUIs:
|
|||||||
|
|
||||||
Model galleries
|
Model galleries
|
||||||
- https://github.com/go-skynet/model-gallery
|
- https://github.com/go-skynet/model-gallery
|
||||||
|
|
||||||
Auto Docker / Model setup
|
|
||||||
- https://io.midori-ai.xyz/howtos/easy-localai-installer/
|
|
||||||
- https://io.midori-ai.xyz/howtos/easy-model-installer/
|
|
||||||
|
|
||||||
Other:
|
Other:
|
||||||
- Helm chart https://github.com/go-skynet/helm-charts
|
- Helm chart https://github.com/go-skynet/helm-charts
|
||||||
@@ -108,6 +109,7 @@ Other:
|
|||||||
- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
|
- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
|
||||||
- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
|
- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
|
||||||
- Examples: https://github.com/mudler/LocalAI/tree/master/examples/
|
- Examples: https://github.com/mudler/LocalAI/tree/master/examples/
|
||||||
|
|
||||||
|
|
||||||
### 🔗 Resources
|
### 🔗 Resources
|
||||||
|
|
||||||
@@ -119,6 +121,8 @@ Other:
|
|||||||
|
|
||||||
## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
|
## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
|
||||||
|
|
||||||
|
- [Run LocalAI on AWS EKS with Pulumi](https://www.pulumi.com/ai/answers/tiZMDoZzZV6TLxgDXNBnFE/deploying-helm-charts-on-aws-eks)
|
||||||
|
- [Run LocalAI on AWS](https://staleks.hashnode.dev/installing-localai-on-aws-ec2-instance)
|
||||||
- [Create a slackbot for teams and OSS projects that answer to documentation](https://mudler.pm/posts/smart-slackbot-for-teams/)
|
- [Create a slackbot for teams and OSS projects that answer to documentation](https://mudler.pm/posts/smart-slackbot-for-teams/)
|
||||||
- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
|
- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
|
||||||
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
|
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
|
||||||
|
|||||||
42
SECURITY.md
Normal file
42
SECURITY.md
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# Security Policy
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
At LocalAI, we take the security of our software seriously. We understand the importance of protecting our community from vulnerabilities and are committed to ensuring the safety and security of our users.
|
||||||
|
|
||||||
|
## Supported Versions
|
||||||
|
|
||||||
|
We provide support and updates for certain versions of our software. The following table outlines which versions are currently supported with security updates:
|
||||||
|
|
||||||
|
| Version | Supported |
|
||||||
|
| ------- | ------------------ |
|
||||||
|
| > 2.0 | :white_check_mark: |
|
||||||
|
| < 2.0 | :x: |
|
||||||
|
|
||||||
|
Please ensure that you are using a supported version to receive the latest security updates.
|
||||||
|
|
||||||
|
## Reporting a Vulnerability
|
||||||
|
|
||||||
|
We encourage the responsible disclosure of any security vulnerabilities. If you believe you've found a security issue in our software, we kindly ask you to follow the steps below to report it to us:
|
||||||
|
|
||||||
|
1. **Email Us:** Send an email to [security@localai.io](mailto:security@localai.io) with a detailed report. Please do not disclose the vulnerability publicly or to any third parties before it has been addressed by us.
|
||||||
|
|
||||||
|
2. **Expect a Response:** We aim to acknowledge receipt of vulnerability reports within 48 hours. Our security team will review your report and work closely with you to understand the impact and ensure a thorough investigation.
|
||||||
|
|
||||||
|
3. **Collaboration:** If the vulnerability is accepted, we will work with you and our community to address the issue promptly. We'll keep you informed throughout the resolution process and may request additional information or collaboration.
|
||||||
|
|
||||||
|
4. **Disclosure:** Once the vulnerability has been resolved, we encourage a coordinated disclosure. We believe in transparency and will work with you to ensure that our community is informed in a responsible manner.
|
||||||
|
|
||||||
|
## Use of Third-Party Platforms
|
||||||
|
|
||||||
|
As a Free and Open Source Software (FOSS) organization, we do not offer monetary bounties. However, researchers who wish to report vulnerabilities can also do so via [Huntr](https://huntr.dev/bounties), a platform that recognizes contributions to open source security.
|
||||||
|
|
||||||
|
## Contact
|
||||||
|
|
||||||
|
For any security-related inquiries beyond vulnerability reporting, please contact us at [security@localai.io](mailto:security@localai.io).
|
||||||
|
|
||||||
|
## Acknowledgments
|
||||||
|
|
||||||
|
We appreciate the efforts of those who contribute to the security of our project. Your responsible disclosure is invaluable to the safety and integrity of LocalAI.
|
||||||
|
|
||||||
|
Thank you for helping us keep LocalAI secure.
|
||||||
288
api/api.go
288
api/api.go
@@ -1,288 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
"github.com/go-skynet/LocalAI/api/localai"
|
|
||||||
"github.com/go-skynet/LocalAI/api/openai"
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
|
||||||
"github.com/go-skynet/LocalAI/internal"
|
|
||||||
"github.com/go-skynet/LocalAI/metrics"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/assets"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/startup"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
|
||||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
|
||||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
|
|
||||||
options := options.NewOptions(opts...)
|
|
||||||
|
|
||||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
|
||||||
if options.Debug {
|
|
||||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
|
|
||||||
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
|
||||||
|
|
||||||
startup.PreloadModelsConfigurations(options.ModelLibraryURL, options.Loader.ModelPath, options.ModelsURL...)
|
|
||||||
|
|
||||||
cl := config.NewConfigLoader()
|
|
||||||
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
|
|
||||||
log.Error().Msgf("error loading config files: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.ConfigFile != "" {
|
|
||||||
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
|
|
||||||
log.Error().Msgf("error loading config file: %s", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cl.Preload(options.Loader.ModelPath); err != nil {
|
|
||||||
log.Error().Msgf("error downloading models: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.PreloadJSONModels != "" {
|
|
||||||
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.PreloadModelsFromPath != "" {
|
|
||||||
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.Debug {
|
|
||||||
for _, v := range cl.ListConfigs() {
|
|
||||||
cfg, _ := cl.GetConfig(v)
|
|
||||||
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.AssetsDestination != "" {
|
|
||||||
// Extract files from the embedded FS
|
|
||||||
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
|
|
||||||
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// turn off any process that was started by GRPC if the context is canceled
|
|
||||||
go func() {
|
|
||||||
<-options.Context.Done()
|
|
||||||
log.Debug().Msgf("Context canceled, shutting down")
|
|
||||||
options.Loader.StopAllGRPC()
|
|
||||||
}()
|
|
||||||
|
|
||||||
if options.WatchDog {
|
|
||||||
wd := model.NewWatchDog(
|
|
||||||
options.Loader,
|
|
||||||
options.WatchDogBusyTimeout,
|
|
||||||
options.WatchDogIdleTimeout,
|
|
||||||
options.WatchDogBusy,
|
|
||||||
options.WatchDogIdle)
|
|
||||||
options.Loader.SetWatchDog(wd)
|
|
||||||
go wd.Run()
|
|
||||||
go func() {
|
|
||||||
<-options.Context.Done()
|
|
||||||
log.Debug().Msgf("Context canceled, shutting down")
|
|
||||||
wd.Shutdown()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
return options, cl, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func App(opts ...options.AppOption) (*fiber.App, error) {
|
|
||||||
|
|
||||||
options, cl, err := Startup(opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return errors as JSON responses
|
|
||||||
app := fiber.New(fiber.Config{
|
|
||||||
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
|
||||||
DisableStartupMessage: options.DisableMessage,
|
|
||||||
// Override default error handler
|
|
||||||
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
|
||||||
// Status code defaults to 500
|
|
||||||
code := fiber.StatusInternalServerError
|
|
||||||
|
|
||||||
// Retrieve the custom status code if it's a *fiber.Error
|
|
||||||
var e *fiber.Error
|
|
||||||
if errors.As(err, &e) {
|
|
||||||
code = e.Code
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send custom error page
|
|
||||||
return ctx.Status(code).JSON(
|
|
||||||
schema.ErrorResponse{
|
|
||||||
Error: &schema.APIError{Message: err.Error(), Code: code},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
if options.Debug {
|
|
||||||
app.Use(logger.New(logger.Config{
|
|
||||||
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default middleware config
|
|
||||||
app.Use(recover.New())
|
|
||||||
if options.Metrics != nil {
|
|
||||||
app.Use(metrics.APIMiddleware(options.Metrics))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
|
|
||||||
auth := func(c *fiber.Ctx) error {
|
|
||||||
if len(options.ApiKeys) == 0 {
|
|
||||||
return c.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for api_keys.json file
|
|
||||||
fileContent, err := os.ReadFile("api_keys.json")
|
|
||||||
if err == nil {
|
|
||||||
// Parse JSON content from the file
|
|
||||||
var fileKeys []string
|
|
||||||
err := json.Unmarshal(fileContent, &fileKeys)
|
|
||||||
if err != nil {
|
|
||||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add file keys to options.ApiKeys
|
|
||||||
options.ApiKeys = append(options.ApiKeys, fileKeys...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(options.ApiKeys) == 0 {
|
|
||||||
return c.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
authHeader := c.Get("Authorization")
|
|
||||||
if authHeader == "" {
|
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
|
|
||||||
}
|
|
||||||
authHeaderParts := strings.Split(authHeader, " ")
|
|
||||||
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
|
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
|
|
||||||
}
|
|
||||||
|
|
||||||
apiKey := authHeaderParts[1]
|
|
||||||
for _, key := range options.ApiKeys {
|
|
||||||
if apiKey == key {
|
|
||||||
return c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.CORS {
|
|
||||||
var c func(ctx *fiber.Ctx) error
|
|
||||||
if options.CORSAllowOrigins == "" {
|
|
||||||
c = cors.New()
|
|
||||||
} else {
|
|
||||||
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
|
|
||||||
}
|
|
||||||
|
|
||||||
app.Use(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalAI API endpoints
|
|
||||||
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
|
|
||||||
galleryService.Start(options.Context, cl)
|
|
||||||
|
|
||||||
app.Get("/version", auth, func(c *fiber.Ctx) error {
|
|
||||||
return c.JSON(struct {
|
|
||||||
Version string `json:"version"`
|
|
||||||
}{Version: internal.PrintableVersion()})
|
|
||||||
})
|
|
||||||
|
|
||||||
// Make sure directories exists
|
|
||||||
os.MkdirAll(options.ImageDir, 0755)
|
|
||||||
os.MkdirAll(options.AudioDir, 0755)
|
|
||||||
os.MkdirAll(options.Loader.ModelPath, 0755)
|
|
||||||
|
|
||||||
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
|
|
||||||
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
|
|
||||||
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
|
|
||||||
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
|
|
||||||
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
|
|
||||||
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
|
|
||||||
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
|
|
||||||
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
|
|
||||||
|
|
||||||
// openAI compatible API endpoint
|
|
||||||
|
|
||||||
// chat
|
|
||||||
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
|
|
||||||
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
|
|
||||||
|
|
||||||
// edit
|
|
||||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
|
|
||||||
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
|
|
||||||
|
|
||||||
// completion
|
|
||||||
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
|
|
||||||
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
|
|
||||||
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
|
|
||||||
|
|
||||||
// embeddings
|
|
||||||
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
|
||||||
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
|
||||||
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
|
||||||
|
|
||||||
// audio
|
|
||||||
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
|
|
||||||
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
|
|
||||||
|
|
||||||
// images
|
|
||||||
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
|
|
||||||
|
|
||||||
if options.ImageDir != "" {
|
|
||||||
app.Static("/generated-images", options.ImageDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.AudioDir != "" {
|
|
||||||
app.Static("/generated-audio", options.AudioDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
ok := func(c *fiber.Ctx) error {
|
|
||||||
return c.SendStatus(200)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Kubernetes health checks
|
|
||||||
app.Get("/healthz", ok)
|
|
||||||
app.Get("/readyz", ok)
|
|
||||||
|
|
||||||
// Experimental Backend Statistics Module
|
|
||||||
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
|
|
||||||
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
|
|
||||||
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
|
|
||||||
|
|
||||||
// models
|
|
||||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
|
|
||||||
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
|
|
||||||
|
|
||||||
app.Get("/metrics", metrics.MetricsHandler())
|
|
||||||
|
|
||||||
return app, nil
|
|
||||||
}
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
package backend
|
|
||||||
|
|
||||||
import (
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
|
|
||||||
|
|
||||||
opts := modelOpts(c, o, []model.Option{
|
|
||||||
model.WithBackendString(c.Backend),
|
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
|
||||||
model.WithThreads(uint32(c.Threads)),
|
|
||||||
model.WithContext(o.Context),
|
|
||||||
model.WithModel(c.Model),
|
|
||||||
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
|
|
||||||
CUDA: c.CUDA || c.Diffusers.CUDA,
|
|
||||||
SchedulerType: c.Diffusers.SchedulerType,
|
|
||||||
PipelineType: c.Diffusers.PipelineType,
|
|
||||||
CFGScale: c.Diffusers.CFGScale,
|
|
||||||
LoraAdapter: c.LoraAdapter,
|
|
||||||
LoraScale: c.LoraScale,
|
|
||||||
LoraBase: c.LoraBase,
|
|
||||||
IMG2IMG: c.Diffusers.IMG2IMG,
|
|
||||||
CLIPModel: c.Diffusers.ClipModel,
|
|
||||||
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
|
||||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
|
||||||
ControlNet: c.Diffusers.ControlNet,
|
|
||||||
}),
|
|
||||||
})
|
|
||||||
|
|
||||||
inferenceModel, err := loader.BackendLoader(
|
|
||||||
opts...,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
fn := func() error {
|
|
||||||
_, err := inferenceModel.GenerateImage(
|
|
||||||
o.Context,
|
|
||||||
&proto.GenerateImageRequest{
|
|
||||||
Height: int32(height),
|
|
||||||
Width: int32(width),
|
|
||||||
Mode: int32(mode),
|
|
||||||
Step: int32(step),
|
|
||||||
Seed: int32(seed),
|
|
||||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
|
||||||
PositivePrompt: positive_prompt,
|
|
||||||
NegativePrompt: negative_prompt,
|
|
||||||
Dst: dst,
|
|
||||||
Src: src,
|
|
||||||
EnableParameters: c.Diffusers.EnableParameters,
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return fn, nil
|
|
||||||
}
|
|
||||||
@@ -1,129 +0,0 @@
|
|||||||
package backend
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
)
|
|
||||||
|
|
||||||
func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option {
|
|
||||||
if o.SingleBackend {
|
|
||||||
opts = append(opts, model.WithSingleActiveBackend())
|
|
||||||
}
|
|
||||||
|
|
||||||
if o.ParallelBackendRequests {
|
|
||||||
opts = append(opts, model.EnableParallelRequests)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.GRPC.Attempts != 0 {
|
|
||||||
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.GRPC.AttemptsSleepTime != 0 {
|
|
||||||
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range o.ExternalGRPCBackends {
|
|
||||||
opts = append(opts, model.WithExternalBackend(k, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
return opts
|
|
||||||
}
|
|
||||||
|
|
||||||
func gRPCModelOpts(c config.Config) *pb.ModelOptions {
|
|
||||||
b := 512
|
|
||||||
if c.Batch != 0 {
|
|
||||||
b = c.Batch
|
|
||||||
}
|
|
||||||
|
|
||||||
return &pb.ModelOptions{
|
|
||||||
ContextSize: int32(c.ContextSize),
|
|
||||||
Seed: int32(c.Seed),
|
|
||||||
NBatch: int32(b),
|
|
||||||
NoMulMatQ: c.NoMulMatQ,
|
|
||||||
CUDA: c.CUDA, // diffusers, transformers
|
|
||||||
DraftModel: c.DraftModel,
|
|
||||||
AudioPath: c.VallE.AudioPath,
|
|
||||||
Quantization: c.Quantization,
|
|
||||||
MMProj: c.MMProj,
|
|
||||||
YarnExtFactor: c.YarnExtFactor,
|
|
||||||
YarnAttnFactor: c.YarnAttnFactor,
|
|
||||||
YarnBetaFast: c.YarnBetaFast,
|
|
||||||
YarnBetaSlow: c.YarnBetaSlow,
|
|
||||||
LoraAdapter: c.LoraAdapter,
|
|
||||||
LoraBase: c.LoraBase,
|
|
||||||
LoraScale: c.LoraScale,
|
|
||||||
NGQA: c.NGQA,
|
|
||||||
RMSNormEps: c.RMSNormEps,
|
|
||||||
F16Memory: c.F16,
|
|
||||||
MLock: c.MMlock,
|
|
||||||
RopeFreqBase: c.RopeFreqBase,
|
|
||||||
RopeScaling: c.RopeScaling,
|
|
||||||
Type: c.ModelType,
|
|
||||||
RopeFreqScale: c.RopeFreqScale,
|
|
||||||
NUMA: c.NUMA,
|
|
||||||
Embeddings: c.Embeddings,
|
|
||||||
LowVRAM: c.LowVRAM,
|
|
||||||
NGPULayers: int32(c.NGPULayers),
|
|
||||||
MMap: c.MMap,
|
|
||||||
MainGPU: c.MainGPU,
|
|
||||||
Threads: int32(c.Threads),
|
|
||||||
TensorSplit: c.TensorSplit,
|
|
||||||
// AutoGPTQ
|
|
||||||
ModelBaseName: c.AutoGPTQ.ModelBaseName,
|
|
||||||
Device: c.AutoGPTQ.Device,
|
|
||||||
UseTriton: c.AutoGPTQ.Triton,
|
|
||||||
UseFastTokenizer: c.AutoGPTQ.UseFastTokenizer,
|
|
||||||
// RWKV
|
|
||||||
Tokenizer: c.Tokenizer,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
|
|
||||||
promptCachePath := ""
|
|
||||||
if c.PromptCachePath != "" {
|
|
||||||
p := filepath.Join(modelPath, c.PromptCachePath)
|
|
||||||
os.MkdirAll(filepath.Dir(p), 0755)
|
|
||||||
promptCachePath = p
|
|
||||||
}
|
|
||||||
return &pb.PredictOptions{
|
|
||||||
Temperature: float32(c.Temperature),
|
|
||||||
TopP: float32(c.TopP),
|
|
||||||
NDraft: c.NDraft,
|
|
||||||
TopK: int32(c.TopK),
|
|
||||||
Tokens: int32(c.Maxtokens),
|
|
||||||
Threads: int32(c.Threads),
|
|
||||||
PromptCacheAll: c.PromptCacheAll,
|
|
||||||
PromptCacheRO: c.PromptCacheRO,
|
|
||||||
PromptCachePath: promptCachePath,
|
|
||||||
F16KV: c.F16,
|
|
||||||
DebugMode: c.Debug,
|
|
||||||
Grammar: c.Grammar,
|
|
||||||
NegativePromptScale: c.NegativePromptScale,
|
|
||||||
RopeFreqBase: c.RopeFreqBase,
|
|
||||||
RopeFreqScale: c.RopeFreqScale,
|
|
||||||
NegativePrompt: c.NegativePrompt,
|
|
||||||
Mirostat: int32(c.LLMConfig.Mirostat),
|
|
||||||
MirostatETA: float32(c.LLMConfig.MirostatETA),
|
|
||||||
MirostatTAU: float32(c.LLMConfig.MirostatTAU),
|
|
||||||
Debug: c.Debug,
|
|
||||||
StopPrompts: c.StopWords,
|
|
||||||
Repeat: int32(c.RepeatPenalty),
|
|
||||||
NKeep: int32(c.Keep),
|
|
||||||
Batch: int32(c.Batch),
|
|
||||||
IgnoreEOS: c.IgnoreEOS,
|
|
||||||
Seed: int32(c.Seed),
|
|
||||||
FrequencyPenalty: float32(c.FrequencyPenalty),
|
|
||||||
MLock: c.MMlock,
|
|
||||||
MMap: c.MMap,
|
|
||||||
MainGPU: c.MainGPU,
|
|
||||||
TensorSplit: c.TensorSplit,
|
|
||||||
TailFreeSamplingZ: float32(c.TFZ),
|
|
||||||
TypicalP: float32(c.TypicalP),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
package backend
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) {
|
|
||||||
|
|
||||||
opts := modelOpts(c, o, []model.Option{
|
|
||||||
model.WithBackendString(model.WhisperBackend),
|
|
||||||
model.WithModel(c.Model),
|
|
||||||
model.WithContext(o.Context),
|
|
||||||
model.WithThreads(uint32(c.Threads)),
|
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
|
||||||
})
|
|
||||||
|
|
||||||
whisperModel, err := o.Loader.BackendLoader(opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if whisperModel == nil {
|
|
||||||
return nil, fmt.Errorf("could not load whisper model")
|
|
||||||
}
|
|
||||||
|
|
||||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
|
||||||
Dst: audio,
|
|
||||||
Language: language,
|
|
||||||
Threads: uint32(c.Threads),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
package backend
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
api_config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
func generateUniqueFileName(dir, baseName, ext string) string {
|
|
||||||
counter := 1
|
|
||||||
fileName := baseName + ext
|
|
||||||
|
|
||||||
for {
|
|
||||||
filePath := filepath.Join(dir, fileName)
|
|
||||||
_, err := os.Stat(filePath)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return fileName
|
|
||||||
}
|
|
||||||
|
|
||||||
counter++
|
|
||||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option, c config.Config) (string, *proto.Result, error) {
|
|
||||||
bb := backend
|
|
||||||
if bb == "" {
|
|
||||||
bb = model.PiperBackend
|
|
||||||
}
|
|
||||||
|
|
||||||
grpcOpts := gRPCModelOpts(c)
|
|
||||||
|
|
||||||
opts := modelOpts(api_config.Config{}, o, []model.Option{
|
|
||||||
model.WithBackendString(bb),
|
|
||||||
model.WithModel(modelFile),
|
|
||||||
model.WithContext(o.Context),
|
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
|
||||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
|
||||||
})
|
|
||||||
piperModel, err := o.Loader.BackendLoader(opts...)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if piperModel == nil {
|
|
||||||
return "", nil, fmt.Errorf("could not load piper model")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
|
|
||||||
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
|
|
||||||
filePath := filepath.Join(o.AudioDir, fileName)
|
|
||||||
|
|
||||||
// If the model file is not empty, we pass it joined with the model path
|
|
||||||
modelPath := ""
|
|
||||||
if modelFile != "" {
|
|
||||||
if bb != model.TransformersMusicGen {
|
|
||||||
modelPath = filepath.Join(o.Loader.ModelPath, modelFile)
|
|
||||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
modelPath = modelFile
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
|
|
||||||
Text: text,
|
|
||||||
Model: modelPath,
|
|
||||||
Dst: filePath,
|
|
||||||
})
|
|
||||||
|
|
||||||
return filePath, res, err
|
|
||||||
}
|
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
package localai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
|
|
||||||
gopsutil "github.com/shirou/gopsutil/v3/process"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BackendMonitorRequest struct {
|
|
||||||
Model string `json:"model" yaml:"model"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BackendMonitorResponse struct {
|
|
||||||
MemoryInfo *gopsutil.MemoryInfoStat
|
|
||||||
MemoryPercent float32
|
|
||||||
CPUPercent float64
|
|
||||||
}
|
|
||||||
|
|
||||||
type BackendMonitor struct {
|
|
||||||
configLoader *config.ConfigLoader
|
|
||||||
options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor {
|
|
||||||
return BackendMonitor{
|
|
||||||
configLoader: configLoader,
|
|
||||||
options: options,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) {
|
|
||||||
config, exists := bm.configLoader.GetConfig(model)
|
|
||||||
var backend string
|
|
||||||
if exists {
|
|
||||||
backend = config.Model
|
|
||||||
} else {
|
|
||||||
// Last ditch effort: use it raw, see if a backend happens to match.
|
|
||||||
backend = model
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.HasSuffix(backend, ".bin") {
|
|
||||||
backend = fmt.Sprintf("%s.bin", backend)
|
|
||||||
}
|
|
||||||
|
|
||||||
pid, err := bm.options.Loader.GetGRPCPID(backend)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
|
|
||||||
backendProcess, err := gopsutil.NewProcess(int32(pid))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
memInfo, err := backendProcess.MemoryInfo()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
memPercent, err := backendProcess.MemoryPercent()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cpuPercent, err := backendProcess.CPUPercent()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &BackendMonitorResponse{
|
|
||||||
MemoryInfo: memInfo,
|
|
||||||
MemoryPercent: memPercent,
|
|
||||||
CPUPercent: cpuPercent,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) {
|
|
||||||
input := new(BackendMonitorRequest)
|
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
config, exists := bm.configLoader.GetConfig(input.Model)
|
|
||||||
var backendId string
|
|
||||||
if exists {
|
|
||||||
backendId = config.Model
|
|
||||||
} else {
|
|
||||||
// Last ditch effort: use it raw, see if a backend happens to match.
|
|
||||||
backendId = input.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.HasSuffix(backendId, ".bin") {
|
|
||||||
backendId = fmt.Sprintf("%s.bin", backendId)
|
|
||||||
}
|
|
||||||
|
|
||||||
return backendId, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
|
|
||||||
backendId, err := bm.getModelLoaderIDFromCtx(c)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
model := bm.options.Loader.CheckIsLoaded(backendId)
|
|
||||||
if model == "" {
|
|
||||||
return fmt.Errorf("backend %s is not currently loaded", backendId)
|
|
||||||
}
|
|
||||||
|
|
||||||
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
|
|
||||||
if rpcErr != nil {
|
|
||||||
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
|
|
||||||
val, slbErr := bm.SampleLocalBackendProcess(backendId)
|
|
||||||
if slbErr != nil {
|
|
||||||
return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
|
|
||||||
}
|
|
||||||
return c.JSON(proto.StatusResponse{
|
|
||||||
State: proto.StatusResponse_ERROR,
|
|
||||||
Memory: &proto.MemoryUsageData{
|
|
||||||
Total: val.MemoryInfo.VMS,
|
|
||||||
Breakdown: map[string]uint64{
|
|
||||||
"gopsutil-RSS": val.MemoryInfo.RSS,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.JSON(status)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
backendId, err := bm.getModelLoaderIDFromCtx(c)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return bm.options.Loader.ShutdownModel(backendId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,326 +0,0 @@
|
|||||||
package localai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
json "github.com/json-iterator/go"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
type galleryOp struct {
|
|
||||||
req gallery.GalleryModel
|
|
||||||
id string
|
|
||||||
galleries []gallery.Gallery
|
|
||||||
galleryName string
|
|
||||||
}
|
|
||||||
|
|
||||||
type galleryOpStatus struct {
|
|
||||||
FileName string `json:"file_name"`
|
|
||||||
Error error `json:"error"`
|
|
||||||
Processed bool `json:"processed"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Progress float64 `json:"progress"`
|
|
||||||
TotalFileSize string `json:"file_size"`
|
|
||||||
DownloadedFileSize string `json:"downloaded_size"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type galleryApplier struct {
|
|
||||||
modelPath string
|
|
||||||
sync.Mutex
|
|
||||||
C chan galleryOp
|
|
||||||
statuses map[string]*galleryOpStatus
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGalleryService(modelPath string) *galleryApplier {
|
|
||||||
return &galleryApplier{
|
|
||||||
modelPath: modelPath,
|
|
||||||
C: make(chan galleryOp),
|
|
||||||
statuses: make(map[string]*galleryOpStatus),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
|
|
||||||
|
|
||||||
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
|
||||||
|
|
||||||
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
|
|
||||||
g.Lock()
|
|
||||||
defer g.Unlock()
|
|
||||||
g.statuses[s] = op
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
|
|
||||||
g.Lock()
|
|
||||||
defer g.Unlock()
|
|
||||||
|
|
||||||
return g.statuses[s]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus {
|
|
||||||
g.Lock()
|
|
||||||
defer g.Unlock()
|
|
||||||
|
|
||||||
return g.statuses
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.Done():
|
|
||||||
return
|
|
||||||
case op := <-g.C:
|
|
||||||
utils.ResetDownloadTimers()
|
|
||||||
|
|
||||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
|
||||||
|
|
||||||
// updates the status with an error
|
|
||||||
updateError := func(e error) {
|
|
||||||
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
|
|
||||||
}
|
|
||||||
|
|
||||||
// displayDownload displays the download progress
|
|
||||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
|
||||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
|
||||||
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
// if the request contains a gallery name, we apply the gallery from the gallery list
|
|
||||||
if op.galleryName != "" {
|
|
||||||
if strings.Contains(op.galleryName, "@") {
|
|
||||||
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
|
||||||
} else {
|
|
||||||
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
updateError(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reload models
|
|
||||||
err = cm.LoadConfigs(g.modelPath)
|
|
||||||
if err != nil {
|
|
||||||
updateError(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = cm.Preload(g.modelPath)
|
|
||||||
if err != nil {
|
|
||||||
updateError(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
type galleryModel struct {
|
|
||||||
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
|
|
||||||
ID string `json:"id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
|
|
||||||
var err error
|
|
||||||
for _, r := range requests {
|
|
||||||
utils.ResetDownloadTimers()
|
|
||||||
if r.ID == "" {
|
|
||||||
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
|
||||||
} else {
|
|
||||||
if strings.Contains(r.ID, "@") {
|
|
||||||
err = gallery.InstallModelFromGallery(
|
|
||||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
|
||||||
} else {
|
|
||||||
err = gallery.InstallModelFromGalleryByName(
|
|
||||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
|
||||||
dat, err := os.ReadFile(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var requests []galleryModel
|
|
||||||
|
|
||||||
if err := yaml.Unmarshal(dat, &requests); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return processRequests(modelPath, s, cm, galleries, requests)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
|
||||||
var requests []galleryModel
|
|
||||||
err := json.Unmarshal([]byte(s), &requests)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return processRequests(modelPath, s, cm, galleries, requests)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Endpoint Service
|
|
||||||
|
|
||||||
type ModelGalleryService struct {
|
|
||||||
galleries []gallery.Gallery
|
|
||||||
modelPath string
|
|
||||||
galleryApplier *galleryApplier
|
|
||||||
}
|
|
||||||
|
|
||||||
type GalleryModel struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
gallery.GalleryModel
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
|
|
||||||
return ModelGalleryService{
|
|
||||||
galleries: galleries,
|
|
||||||
modelPath: modelPath,
|
|
||||||
galleryApplier: galleryApplier,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
status := mgs.galleryApplier.getStatus(c.Params("uuid"))
|
|
||||||
if status == nil {
|
|
||||||
return fmt.Errorf("could not find any status for ID")
|
|
||||||
}
|
|
||||||
return c.JSON(status)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
return c.JSON(mgs.galleryApplier.getAllStatus())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
input := new(GalleryModel)
|
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
uuid, err := uuid.NewUUID()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
mgs.galleryApplier.C <- galleryOp{
|
|
||||||
req: input.GalleryModel,
|
|
||||||
id: uuid.String(),
|
|
||||||
galleryName: input.ID,
|
|
||||||
galleries: mgs.galleries,
|
|
||||||
}
|
|
||||||
return c.JSON(struct {
|
|
||||||
ID string `json:"uuid"`
|
|
||||||
StatusURL string `json:"status"`
|
|
||||||
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
|
|
||||||
|
|
||||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Models found from galleries: %+v", models)
|
|
||||||
for _, m := range models {
|
|
||||||
log.Debug().Msgf("Model found from galleries: %+v", m)
|
|
||||||
}
|
|
||||||
dat, err := json.Marshal(models)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.Send(dat)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
|
||||||
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
|
||||||
dat, err := json.Marshal(mgs.galleries)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.Send(dat)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
input := new(gallery.Gallery)
|
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
|
||||||
return gallery.Name == input.Name
|
|
||||||
}) {
|
|
||||||
return fmt.Errorf("%s already exists", input.Name)
|
|
||||||
}
|
|
||||||
dat, err := json.Marshal(mgs.galleries)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Adding %+v to gallery list", *input)
|
|
||||||
mgs.galleries = append(mgs.galleries, *input)
|
|
||||||
return c.Send(dat)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
input := new(gallery.Gallery)
|
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
|
||||||
return gallery.Name == input.Name
|
|
||||||
}) {
|
|
||||||
return fmt.Errorf("%s is not currently registered", input.Name)
|
|
||||||
}
|
|
||||||
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
|
||||||
return gallery.Name == input.Name
|
|
||||||
})
|
|
||||||
return c.Send(nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
package localai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/go-skynet/LocalAI/api/backend"
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
fiberContext "github.com/go-skynet/LocalAI/api/ctx"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TTSRequest struct {
|
|
||||||
Model string `json:"model" yaml:"model"`
|
|
||||||
Input string `json:"input" yaml:"input"`
|
|
||||||
Backend string `json:"backend" yaml:"backend"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
|
|
||||||
input := new(TTSRequest)
|
|
||||||
|
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, false)
|
|
||||||
if err != nil {
|
|
||||||
modelFile = input.Model
|
|
||||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
|
||||||
}
|
|
||||||
cfg, err := config.Load(modelFile, o.Loader.ModelPath, cm, false, 0, 0, false)
|
|
||||||
if err != nil {
|
|
||||||
modelFile = input.Model
|
|
||||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
|
||||||
} else {
|
|
||||||
modelFile = cfg.Model
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Request for model: %s", modelFile)
|
|
||||||
|
|
||||||
if input.Backend != "" {
|
|
||||||
cfg.Backend = input.Input
|
|
||||||
}
|
|
||||||
|
|
||||||
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, o.Loader, o, *cfg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.Download(filePath)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,399 +0,0 @@
|
|||||||
package openai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/backend"
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
|
||||||
emptyMessage := ""
|
|
||||||
id := uuid.New().String()
|
|
||||||
created := int(time.Now().Unix())
|
|
||||||
|
|
||||||
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
|
||||||
initialMessage := schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
}
|
|
||||||
responses <- initialMessage
|
|
||||||
|
|
||||||
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
|
||||||
resp := schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Usage: schema.OpenAIUsage{
|
|
||||||
PromptTokens: usage.Prompt,
|
|
||||||
CompletionTokens: usage.Completion,
|
|
||||||
TotalTokens: usage.Prompt + usage.Completion,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
responses <- resp
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
close(responses)
|
|
||||||
}
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
processFunctions := false
|
|
||||||
funcs := grammar.Functions{}
|
|
||||||
modelFile, input, err := readRequest(c, o, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Configuration read: %+v", config)
|
|
||||||
|
|
||||||
// Allow the user to set custom actions via config file
|
|
||||||
// to be "embedded" in each model
|
|
||||||
noActionName := "answer"
|
|
||||||
noActionDescription := "use this action to answer without performing any action"
|
|
||||||
|
|
||||||
if config.FunctionsConfig.NoActionFunctionName != "" {
|
|
||||||
noActionName = config.FunctionsConfig.NoActionFunctionName
|
|
||||||
}
|
|
||||||
if config.FunctionsConfig.NoActionDescriptionName != "" {
|
|
||||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.ResponseFormat.Type == "json_object" {
|
|
||||||
input.Grammar = grammar.JSONBNF
|
|
||||||
}
|
|
||||||
|
|
||||||
// process functions if we have any defined or if we have a function call string
|
|
||||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
|
|
||||||
log.Debug().Msgf("Response needs to process functions")
|
|
||||||
|
|
||||||
processFunctions = true
|
|
||||||
|
|
||||||
noActionGrammar := grammar.Function{
|
|
||||||
Name: noActionName,
|
|
||||||
Description: noActionDescription,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"type": "string",
|
|
||||||
"description": "The message to reply the user with",
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append the no action function
|
|
||||||
funcs = append(funcs, input.Functions...)
|
|
||||||
if !config.FunctionsConfig.DisableNoAction {
|
|
||||||
funcs = append(funcs, noActionGrammar)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Force picking one of the functions by the request
|
|
||||||
if config.FunctionToCall() != "" {
|
|
||||||
funcs = funcs.Select(config.FunctionToCall())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update input grammar
|
|
||||||
jsStruct := funcs.ToJSONStructure()
|
|
||||||
config.Grammar = jsStruct.Grammar("")
|
|
||||||
} else if input.JSONFunctionGrammarObject != nil {
|
|
||||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
|
|
||||||
}
|
|
||||||
|
|
||||||
// functions are not supported in stream mode (yet?)
|
|
||||||
toStream := input.Stream && !processFunctions
|
|
||||||
|
|
||||||
log.Debug().Msgf("Parameters: %+v", config)
|
|
||||||
|
|
||||||
var predInput string
|
|
||||||
|
|
||||||
suppressConfigSystemPrompt := false
|
|
||||||
mess := []string{}
|
|
||||||
for messageIndex, i := range input.Messages {
|
|
||||||
var content string
|
|
||||||
role := i.Role
|
|
||||||
|
|
||||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
|
||||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
|
||||||
if i.FunctionCall != nil && i.Role == "assistant" {
|
|
||||||
roleFn := "assistant_function_call"
|
|
||||||
r := config.Roles[roleFn]
|
|
||||||
if r != "" {
|
|
||||||
role = roleFn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r := config.Roles[role]
|
|
||||||
contentExists := i.Content != nil && i.StringContent != ""
|
|
||||||
// First attempt to populate content via a chat message specific template
|
|
||||||
if config.TemplateConfig.ChatMessage != "" {
|
|
||||||
chatMessageData := model.ChatMessageTemplateData{
|
|
||||||
SystemPrompt: config.SystemPrompt,
|
|
||||||
Role: r,
|
|
||||||
RoleName: role,
|
|
||||||
Content: i.StringContent,
|
|
||||||
MessageIndex: messageIndex,
|
|
||||||
}
|
|
||||||
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
|
||||||
} else {
|
|
||||||
if templatedChatMessage == "" {
|
|
||||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
|
||||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
|
||||||
content = templatedChatMessage
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
|
||||||
if content == "" {
|
|
||||||
if r != "" {
|
|
||||||
if contentExists {
|
|
||||||
content = fmt.Sprint(r, i.StringContent)
|
|
||||||
}
|
|
||||||
if i.FunctionCall != nil {
|
|
||||||
j, err := json.Marshal(i.FunctionCall)
|
|
||||||
if err == nil {
|
|
||||||
if contentExists {
|
|
||||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
|
||||||
} else {
|
|
||||||
content = fmt.Sprint(r, " ", string(j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if contentExists {
|
|
||||||
content = fmt.Sprint(i.StringContent)
|
|
||||||
}
|
|
||||||
if i.FunctionCall != nil {
|
|
||||||
j, err := json.Marshal(i.FunctionCall)
|
|
||||||
if err == nil {
|
|
||||||
if contentExists {
|
|
||||||
content += "\n" + string(j)
|
|
||||||
} else {
|
|
||||||
content = string(j)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
|
||||||
if contentExists && role == "system" {
|
|
||||||
suppressConfigSystemPrompt = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mess = append(mess, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
predInput = strings.Join(mess, "\n")
|
|
||||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
|
||||||
|
|
||||||
if toStream {
|
|
||||||
log.Debug().Msgf("Stream request received")
|
|
||||||
c.Context().SetContentType("text/event-stream")
|
|
||||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
|
||||||
// c.Set("Content-Type", "text/event-stream")
|
|
||||||
c.Set("Cache-Control", "no-cache")
|
|
||||||
c.Set("Connection", "keep-alive")
|
|
||||||
c.Set("Transfer-Encoding", "chunked")
|
|
||||||
}
|
|
||||||
|
|
||||||
templateFile := ""
|
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
|
||||||
templateFile = config.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TemplateConfig.Chat != "" && !processFunctions {
|
|
||||||
templateFile = config.TemplateConfig.Chat
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TemplateConfig.Functions != "" && processFunctions {
|
|
||||||
templateFile = config.TemplateConfig.Functions
|
|
||||||
}
|
|
||||||
|
|
||||||
if templateFile != "" {
|
|
||||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
|
||||||
SystemPrompt: config.SystemPrompt,
|
|
||||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
|
||||||
Input: predInput,
|
|
||||||
Functions: funcs,
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
predInput = templatedInput
|
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
|
||||||
} else {
|
|
||||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
|
||||||
if processFunctions {
|
|
||||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
|
||||||
}
|
|
||||||
|
|
||||||
if toStream {
|
|
||||||
responses := make(chan schema.OpenAIResponse)
|
|
||||||
|
|
||||||
go process(predInput, input, config, o.Loader, responses)
|
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
|
||||||
|
|
||||||
usage := &schema.OpenAIUsage{}
|
|
||||||
|
|
||||||
for ev := range responses {
|
|
||||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
|
||||||
var buf bytes.Buffer
|
|
||||||
enc := json.NewEncoder(&buf)
|
|
||||||
enc.Encode(ev)
|
|
||||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
|
||||||
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
|
|
||||||
if err != nil {
|
|
||||||
log.Debug().Msgf("Sending chunk failed: %v", err)
|
|
||||||
input.Cancel()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
w.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []schema.Choice{
|
|
||||||
{
|
|
||||||
FinishReason: "stop",
|
|
||||||
Index: 0,
|
|
||||||
Delta: &schema.Message{Content: &emptyMessage},
|
|
||||||
}},
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Usage: *usage,
|
|
||||||
}
|
|
||||||
respData, _ := json.Marshal(resp)
|
|
||||||
|
|
||||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
|
||||||
w.WriteString("data: [DONE]\n\n")
|
|
||||||
w.Flush()
|
|
||||||
}))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
|
|
||||||
if processFunctions {
|
|
||||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
|
||||||
ss := map[string]interface{}{}
|
|
||||||
// This prevent newlines to break JSON parsing for clients
|
|
||||||
s = utils.EscapeNewLines(s)
|
|
||||||
json.Unmarshal([]byte(s), &ss)
|
|
||||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
|
||||||
|
|
||||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
|
||||||
func_name := ss["function"]
|
|
||||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
|
||||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
|
||||||
d, _ := json.Marshal(args)
|
|
||||||
|
|
||||||
ss["arguments"] = string(d)
|
|
||||||
ss["name"] = func_name
|
|
||||||
|
|
||||||
// if do nothing, reply with a message
|
|
||||||
if func_name == noActionName {
|
|
||||||
log.Debug().Msgf("nothing to do, computing a reply")
|
|
||||||
|
|
||||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
|
||||||
arguments := map[string]interface{}{}
|
|
||||||
json.Unmarshal([]byte(d), &arguments)
|
|
||||||
m, exists := arguments["message"]
|
|
||||||
if exists {
|
|
||||||
switch message := m.(type) {
|
|
||||||
case string:
|
|
||||||
if message != "" {
|
|
||||||
log.Debug().Msgf("Reply received from LLM: %s", message)
|
|
||||||
message = backend.Finetune(*config, predInput, message)
|
|
||||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
|
|
||||||
|
|
||||||
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
|
|
||||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
|
||||||
// Note: This costs (in term of CPU) another computation
|
|
||||||
config.Grammar = ""
|
|
||||||
images := []string{}
|
|
||||||
for _, m := range input.Messages {
|
|
||||||
images = append(images, m.StringImages...)
|
|
||||||
}
|
|
||||||
predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("inference error: %s", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
prediction, err := predFunc()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("inference error: %s", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
|
||||||
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
|
|
||||||
} else {
|
|
||||||
// otherwise reply with the function call
|
|
||||||
*c = append(*c, schema.Choice{
|
|
||||||
FinishReason: "function_call",
|
|
||||||
Message: &schema.Message{Role: "assistant", FunctionCall: ss},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
|
||||||
}, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &schema.OpenAIResponse{
|
|
||||||
ID: id,
|
|
||||||
Created: created,
|
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: result,
|
|
||||||
Object: "chat.completion",
|
|
||||||
Usage: schema.OpenAIUsage{
|
|
||||||
PromptTokens: tokenUsage.Prompt,
|
|
||||||
CompletionTokens: tokenUsage.Completion,
|
|
||||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
respData, _ := json.Marshal(resp)
|
|
||||||
log.Debug().Msgf("Response: %s", respData)
|
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -126,6 +126,11 @@ message ModelOptions {
|
|||||||
|
|
||||||
// vllm
|
// vllm
|
||||||
string Quantization = 40;
|
string Quantization = 40;
|
||||||
|
float GPUMemoryUtilization = 50;
|
||||||
|
bool TrustRemoteCode = 51;
|
||||||
|
bool EnforceEager = 52;
|
||||||
|
int32 SwapSpace = 53;
|
||||||
|
int32 MaxModelLen = 54;
|
||||||
|
|
||||||
string MMProj = 41;
|
string MMProj = 41;
|
||||||
|
|
||||||
@@ -186,6 +191,7 @@ message TTSRequest {
|
|||||||
string text = 1;
|
string text = 1;
|
||||||
string model = 2;
|
string model = 2;
|
||||||
string dst = 3;
|
string dst = 3;
|
||||||
|
string voice = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TokenizationResponse {
|
message TokenizationResponse {
|
||||||
|
|||||||
@@ -2,16 +2,20 @@
|
|||||||
## XXX: In some versions of CMake clip wasn't being built before llama.
|
## XXX: In some versions of CMake clip wasn't being built before llama.
|
||||||
## This is an hack for now, but it should be fixed in the future.
|
## This is an hack for now, but it should be fixed in the future.
|
||||||
set(TARGET myclip)
|
set(TARGET myclip)
|
||||||
add_library(${TARGET} clip.cpp clip.h)
|
add_library(${TARGET} clip.cpp clip.h llava.cpp llava.h)
|
||||||
install(TARGETS ${TARGET} LIBRARY)
|
install(TARGETS ${TARGET} LIBRARY)
|
||||||
target_link_libraries(${TARGET} PRIVATE common ggml ${CMAKE_THREAD_LIBS_INIT})
|
target_include_directories(myclip PUBLIC .)
|
||||||
|
target_include_directories(myclip PUBLIC ../..)
|
||||||
|
target_include_directories(myclip PUBLIC ../../common)
|
||||||
|
target_link_libraries(${TARGET} PRIVATE common ggml llama ${CMAKE_THREAD_LIBS_INIT})
|
||||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||||
if (NOT MSVC)
|
if (NOT MSVC)
|
||||||
target_compile_options(${TARGET} PRIVATE -Wno-cast-qual) # stb_image.h
|
target_compile_options(${TARGET} PRIVATE -Wno-cast-qual) # stb_image.h
|
||||||
endif()
|
endif()
|
||||||
|
# END CLIP hack
|
||||||
|
|
||||||
|
|
||||||
set(TARGET grpc-server)
|
set(TARGET grpc-server)
|
||||||
# END CLIP hack
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
cmake_minimum_required(VERSION 3.15)
|
cmake_minimum_required(VERSION 3.15)
|
||||||
set(TARGET grpc-server)
|
set(TARGET grpc-server)
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ ifeq ($(BUILD_TYPE),cublas)
|
|||||||
# to CMAKE_ARGS automatically
|
# to CMAKE_ARGS automatically
|
||||||
else ifeq ($(BUILD_TYPE),openblas)
|
else ifeq ($(BUILD_TYPE),openblas)
|
||||||
CMAKE_ARGS+=-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS
|
CMAKE_ARGS+=-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS
|
||||||
# If build type is clblast (openCL) we set -DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
|
# If build type is clblas (openCL) we set -DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||||
else ifeq ($(BUILD_TYPE),clblast)
|
else ifeq ($(BUILD_TYPE),clblas)
|
||||||
CMAKE_ARGS+=-DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
|
CMAKE_ARGS+=-DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||||
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
|
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
|
||||||
else ifeq ($(BUILD_TYPE),hipblas)
|
else ifeq ($(BUILD_TYPE),hipblas)
|
||||||
@@ -45,6 +45,9 @@ llama.cpp/examples/grpc-server:
|
|||||||
## XXX: In some versions of CMake clip wasn't being built before llama.
|
## XXX: In some versions of CMake clip wasn't being built before llama.
|
||||||
## This is an hack for now, but it should be fixed in the future.
|
## This is an hack for now, but it should be fixed in the future.
|
||||||
cp -rfv llama.cpp/examples/llava/clip.h llama.cpp/examples/grpc-server/clip.h
|
cp -rfv llama.cpp/examples/llava/clip.h llama.cpp/examples/grpc-server/clip.h
|
||||||
|
cp -rfv llama.cpp/examples/llava/llava.cpp llama.cpp/examples/grpc-server/llava.cpp
|
||||||
|
echo '#include "llama.h"' > llama.cpp/examples/grpc-server/llava.h
|
||||||
|
cat llama.cpp/examples/llava/llava.h >> llama.cpp/examples/grpc-server/llava.h
|
||||||
cp -rfv llama.cpp/examples/llava/clip.cpp llama.cpp/examples/grpc-server/clip.cpp
|
cp -rfv llama.cpp/examples/llava/clip.cpp llama.cpp/examples/grpc-server/clip.cpp
|
||||||
|
|
||||||
rebuild:
|
rebuild:
|
||||||
|
|||||||
@@ -11,7 +11,8 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <getopt.h>
|
#include <getopt.h>
|
||||||
#include "../llava/clip.h"
|
#include "clip.h"
|
||||||
|
#include "llava.h"
|
||||||
#include "stb_image.h"
|
#include "stb_image.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
@@ -32,6 +33,7 @@
|
|||||||
#include <grpcpp/grpcpp.h>
|
#include <grpcpp/grpcpp.h>
|
||||||
#include <grpcpp/health_check_service_interface.h>
|
#include <grpcpp/health_check_service_interface.h>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <signal.h>
|
||||||
|
|
||||||
using grpc::Server;
|
using grpc::Server;
|
||||||
using grpc::ServerBuilder;
|
using grpc::ServerBuilder;
|
||||||
@@ -51,12 +53,16 @@ struct server_params
|
|||||||
std::string hostname = "127.0.0.1";
|
std::string hostname = "127.0.0.1";
|
||||||
std::vector<std::string> api_keys;
|
std::vector<std::string> api_keys;
|
||||||
std::string public_path = "examples/server/public";
|
std::string public_path = "examples/server/public";
|
||||||
|
std::string chat_template = "";
|
||||||
int32_t port = 8080;
|
int32_t port = 8080;
|
||||||
int32_t read_timeout = 600;
|
int32_t read_timeout = 600;
|
||||||
int32_t write_timeout = 600;
|
int32_t write_timeout = 600;
|
||||||
|
bool slots_endpoint = true;
|
||||||
|
bool metrics_endpoint = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool server_verbose = false;
|
bool server_verbose = false;
|
||||||
|
bool server_log_json = true;
|
||||||
|
|
||||||
static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
|
static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
|
||||||
{
|
{
|
||||||
@@ -172,6 +178,7 @@ struct llama_client_slot
|
|||||||
int32_t n_decoded = 0;
|
int32_t n_decoded = 0;
|
||||||
int32_t n_remaining = -1;
|
int32_t n_remaining = -1;
|
||||||
int32_t i_batch = -1;
|
int32_t i_batch = -1;
|
||||||
|
int32_t n_predict = -1;
|
||||||
|
|
||||||
int32_t num_prompt_tokens = 0;
|
int32_t num_prompt_tokens = 0;
|
||||||
int32_t num_prompt_tokens_processed = 0;
|
int32_t num_prompt_tokens_processed = 0;
|
||||||
@@ -311,12 +318,76 @@ struct llama_client_slot
|
|||||||
}
|
}
|
||||||
|
|
||||||
void print_timings() const {
|
void print_timings() const {
|
||||||
LOG_TEE("\n");
|
char buffer[512];
|
||||||
LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
double t_token = t_prompt_processing / num_prompt_tokens_processed;
|
||||||
__func__, t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing / num_prompt_tokens_processed, 1e3 / t_prompt_processing * num_prompt_tokens_processed);
|
double n_tokens_second = 1e3 / t_prompt_processing * num_prompt_tokens_processed;
|
||||||
LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
|
||||||
__func__, t_token_generation, n_decoded,t_token_generation / n_decoded, 1e3 / t_token_generation * n_decoded);
|
t_prompt_processing, num_prompt_tokens_processed,
|
||||||
LOG_TEE("%s: total time = %10.2f ms\n", __func__, t_prompt_processing + t_token_generation);
|
t_token, n_tokens_second);
|
||||||
|
LOG_INFO(buffer, {
|
||||||
|
{"slot_id", id},
|
||||||
|
{"task_id", task_id},
|
||||||
|
{"t_prompt_processing", t_prompt_processing},
|
||||||
|
{"num_prompt_tokens_processed", num_prompt_tokens_processed},
|
||||||
|
{"t_token", t_token},
|
||||||
|
{"n_tokens_second", n_tokens_second},
|
||||||
|
});
|
||||||
|
|
||||||
|
t_token = t_token_generation / n_decoded;
|
||||||
|
n_tokens_second = 1e3 / t_token_generation * n_decoded;
|
||||||
|
sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
|
||||||
|
t_token_generation, n_decoded,
|
||||||
|
t_token, n_tokens_second);
|
||||||
|
LOG_INFO(buffer, {
|
||||||
|
{"slot_id", id},
|
||||||
|
{"task_id", task_id},
|
||||||
|
{"t_token_generation", t_token_generation},
|
||||||
|
{"n_decoded", n_decoded},
|
||||||
|
{"t_token", t_token},
|
||||||
|
{"n_tokens_second", n_tokens_second},
|
||||||
|
});
|
||||||
|
|
||||||
|
sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
|
||||||
|
LOG_INFO(buffer, {
|
||||||
|
{"slot_id", id},
|
||||||
|
{"task_id", task_id},
|
||||||
|
{"t_prompt_processing", t_prompt_processing},
|
||||||
|
{"t_token_generation", t_token_generation},
|
||||||
|
{"t_total", t_prompt_processing + t_token_generation},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_metrics {
|
||||||
|
uint64_t n_prompt_tokens_processed_total = 0;
|
||||||
|
uint64_t n_tokens_predicted_total = 0;
|
||||||
|
|
||||||
|
uint64_t n_prompt_tokens_processed = 0;
|
||||||
|
uint64_t t_prompt_processing = 0;
|
||||||
|
|
||||||
|
uint64_t n_tokens_predicted = 0;
|
||||||
|
uint64_t t_tokens_generation = 0;
|
||||||
|
|
||||||
|
|
||||||
|
void on_prompt_eval(const llama_client_slot &slot) {
|
||||||
|
n_prompt_tokens_processed_total += slot.num_prompt_tokens_processed;
|
||||||
|
|
||||||
|
n_prompt_tokens_processed += slot.num_prompt_tokens_processed;
|
||||||
|
t_prompt_processing += slot.t_prompt_processing;
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_prediction(const llama_client_slot &slot) {
|
||||||
|
n_tokens_predicted_total += slot.n_decoded;
|
||||||
|
|
||||||
|
n_tokens_predicted += slot.n_decoded;
|
||||||
|
t_tokens_generation += slot.t_token_generation;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset_bucket() {
|
||||||
|
n_prompt_tokens_processed = 0;
|
||||||
|
t_prompt_processing = 0;
|
||||||
|
n_tokens_predicted = 0;
|
||||||
|
t_tokens_generation = 0;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -349,10 +420,13 @@ struct llama_server_context
|
|||||||
|
|
||||||
// slots / clients
|
// slots / clients
|
||||||
std::vector<llama_client_slot> slots;
|
std::vector<llama_client_slot> slots;
|
||||||
|
json default_generation_settings_for_props;
|
||||||
|
|
||||||
llama_server_queue queue_tasks;
|
llama_server_queue queue_tasks;
|
||||||
llama_server_response queue_results;
|
llama_server_response queue_results;
|
||||||
|
|
||||||
|
llama_metrics metrics;
|
||||||
|
|
||||||
~llama_server_context()
|
~llama_server_context()
|
||||||
{
|
{
|
||||||
if (ctx)
|
if (ctx)
|
||||||
@@ -372,7 +446,7 @@ struct llama_server_context
|
|||||||
params = params_;
|
params = params_;
|
||||||
if (!params.mmproj.empty()) {
|
if (!params.mmproj.empty()) {
|
||||||
multimodal = true;
|
multimodal = true;
|
||||||
LOG_TEE("Multi Modal Mode Enabled");
|
LOG_INFO("Multi Modal Mode Enabled", {});
|
||||||
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
|
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
|
||||||
if(clp_ctx == nullptr) {
|
if(clp_ctx == nullptr) {
|
||||||
LOG_ERROR("unable to load clip model", {{"model", params.mmproj}});
|
LOG_ERROR("unable to load clip model", {{"model", params.mmproj}});
|
||||||
@@ -409,21 +483,35 @@ struct llama_server_context
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void validate_model_chat_template(server_params & sparams) {
|
||||||
|
llama_chat_message chat[] = {{"user", "test"}};
|
||||||
|
std::vector<char> buf(1);
|
||||||
|
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
|
||||||
|
if (res < 0) {
|
||||||
|
LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
||||||
|
sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void initialize() {
|
void initialize() {
|
||||||
// create slots
|
// create slots
|
||||||
all_slots_are_idle = true;
|
all_slots_are_idle = true;
|
||||||
|
|
||||||
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
|
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
|
||||||
|
|
||||||
LOG_TEE("Available slots:\n");
|
LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
|
||||||
for (int i = 0; i < params.n_parallel; i++)
|
for (int i = 0; i < params.n_parallel; i++)
|
||||||
{
|
{
|
||||||
llama_client_slot slot;
|
llama_client_slot slot;
|
||||||
|
|
||||||
slot.id = i;
|
slot.id = i;
|
||||||
slot.n_ctx = n_ctx_slot;
|
slot.n_ctx = n_ctx_slot;
|
||||||
|
slot.n_predict = params.n_predict;
|
||||||
|
|
||||||
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
|
LOG_INFO("new slot", {
|
||||||
|
{"slot_id", slot.id},
|
||||||
|
{"n_ctx_slot", slot.n_ctx}
|
||||||
|
});
|
||||||
|
|
||||||
const int ga_n = params.grp_attn_n;
|
const int ga_n = params.grp_attn_n;
|
||||||
const int ga_w = params.grp_attn_w;
|
const int ga_w = params.grp_attn_w;
|
||||||
@@ -433,7 +521,12 @@ struct llama_server_context
|
|||||||
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
|
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
|
||||||
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
|
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
|
||||||
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
|
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
|
||||||
LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
|
|
||||||
|
LOG_INFO("slot self-extend", {
|
||||||
|
{"slot_id", slot.id},
|
||||||
|
{"ga_n", ga_n},
|
||||||
|
{"ga_w", ga_w}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.ga_i = 0;
|
slot.ga_i = 0;
|
||||||
@@ -445,11 +538,10 @@ struct llama_server_context
|
|||||||
slots.push_back(slot);
|
slots.push_back(slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
batch = llama_batch_init(n_ctx, 0, params.n_parallel);
|
default_generation_settings_for_props = get_formated_generation(slots.front());
|
||||||
|
default_generation_settings_for_props["seed"] = -1;
|
||||||
|
|
||||||
// empty system prompt
|
batch = llama_batch_init(n_ctx, 0, params.n_parallel);
|
||||||
system_prompt = "";
|
|
||||||
system_tokens.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
|
std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
|
||||||
@@ -526,28 +618,40 @@ struct llama_server_context
|
|||||||
bool launch_slot_with_data(llama_client_slot* &slot, json data) {
|
bool launch_slot_with_data(llama_client_slot* &slot, json data) {
|
||||||
slot_params default_params;
|
slot_params default_params;
|
||||||
llama_sampling_params default_sparams;
|
llama_sampling_params default_sparams;
|
||||||
|
|
||||||
|
slot->params.stream = json_value(data, "stream", false);
|
||||||
|
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
|
||||||
|
slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
|
||||||
|
slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
||||||
|
slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||||
|
slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||||
|
slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
||||||
|
slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
|
||||||
|
slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
||||||
|
slot->sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
||||||
|
slot->sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
||||||
|
slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
||||||
|
slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
|
||||||
|
slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
|
||||||
|
slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
|
||||||
|
slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
|
||||||
|
slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
||||||
|
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||||
|
slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||||
|
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
|
||||||
|
slot->params.seed = json_value(data, "seed", default_params.seed);
|
||||||
|
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||||
|
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||||
|
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||||
|
|
||||||
slot->params.stream = json_value(data, "stream", false);
|
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
|
||||||
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
|
// Might be better to reject the request with a 400 ?
|
||||||
slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
|
LOG_WARNING("Max tokens to predict exceeds server configuration", {
|
||||||
slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
{"params.n_predict", slot->params.n_predict},
|
||||||
slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
{"slot.n_predict", slot->n_predict},
|
||||||
slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
});
|
||||||
slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
slot->params.n_predict = slot->n_predict;
|
||||||
slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
|
}
|
||||||
slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
|
||||||
slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
|
||||||
slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
|
|
||||||
slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
|
|
||||||
slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
|
|
||||||
slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
|
|
||||||
slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
|
||||||
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
|
||||||
slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
|
||||||
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
|
|
||||||
slot->params.seed = json_value(data, "seed", default_params.seed);
|
|
||||||
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
|
||||||
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
|
||||||
|
|
||||||
// infill
|
// infill
|
||||||
if (data.count("input_prefix") != 0)
|
if (data.count("input_prefix") != 0)
|
||||||
@@ -626,18 +730,36 @@ struct llama_server_context
|
|||||||
const int n_vocab = llama_n_vocab(model);
|
const int n_vocab = llama_n_vocab(model);
|
||||||
for (const auto &el : *logit_bias)
|
for (const auto &el : *logit_bias)
|
||||||
{
|
{
|
||||||
if (el.is_array() && el.size() == 2 && el[0].is_number_integer())
|
if (el.is_array() && el.size() == 2)
|
||||||
{
|
{
|
||||||
llama_token tok = el[0].get<llama_token>();
|
float bias;
|
||||||
if (tok >= 0 && tok < n_vocab)
|
if (el[1].is_number())
|
||||||
{
|
{
|
||||||
if (el[1].is_number())
|
bias = el[1].get<float>();
|
||||||
|
}
|
||||||
|
else if (el[1].is_boolean() && !el[1].get<bool>())
|
||||||
|
{
|
||||||
|
bias = -INFINITY;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (el[0].is_number_integer())
|
||||||
|
{
|
||||||
|
llama_token tok = el[0].get<llama_token>();
|
||||||
|
if (tok >= 0 && tok < n_vocab)
|
||||||
{
|
{
|
||||||
slot->sparams.logit_bias[tok] = el[1].get<float>();
|
slot->sparams.logit_bias[tok] = bias;
|
||||||
}
|
}
|
||||||
else if (el[1].is_boolean() && !el[1].get<bool>())
|
}
|
||||||
|
else if (el[0].is_string())
|
||||||
|
{
|
||||||
|
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
|
||||||
|
for (auto tok : toks)
|
||||||
{
|
{
|
||||||
slot->sparams.logit_bias[tok] = -INFINITY;
|
slot->sparams.logit_bias[tok] = bias;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -658,6 +780,24 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto &samplers_sequence = data.find("samplers");
|
||||||
|
if (samplers_sequence != data.end() && samplers_sequence->is_array())
|
||||||
|
{
|
||||||
|
std::vector<std::string> sampler_names;
|
||||||
|
for (const auto &sampler_name : *samplers_sequence)
|
||||||
|
{
|
||||||
|
if (sampler_name.is_string())
|
||||||
|
{
|
||||||
|
sampler_names.emplace_back(sampler_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slot->sparams.samplers_sequence = sampler_types_from_names(sampler_names, false);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
slot->sparams.samplers_sequence = default_sparams.samplers_sequence;
|
||||||
|
}
|
||||||
|
|
||||||
if (multimodal)
|
if (multimodal)
|
||||||
{
|
{
|
||||||
const auto &images_data = data.find("image_data");
|
const auto &images_data = data.find("image_data");
|
||||||
@@ -672,10 +812,16 @@ struct llama_server_context
|
|||||||
img_sl.img_data = clip_image_u8_init();
|
img_sl.img_data = clip_image_u8_init();
|
||||||
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
|
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
|
||||||
{
|
{
|
||||||
LOG_TEE("slot %i - failed to load image [id: %i]\n", slot->id, img_sl.id);
|
LOG_ERROR("failed to load image", {
|
||||||
|
{"slot_id", slot->id},
|
||||||
|
{"img_sl_id", img_sl.id}
|
||||||
|
});
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
LOG_TEE("slot %i - loaded image\n", slot->id);
|
LOG_VERBOSE("image loaded", {
|
||||||
|
{"slot_id", slot->id},
|
||||||
|
{"img_sl_id", img_sl.id}
|
||||||
|
});
|
||||||
img_sl.request_encode_image = true;
|
img_sl.request_encode_image = true;
|
||||||
slot->images.push_back(img_sl);
|
slot->images.push_back(img_sl);
|
||||||
}
|
}
|
||||||
@@ -735,7 +881,10 @@ struct llama_server_context
|
|||||||
|
|
||||||
all_slots_are_idle = false;
|
all_slots_are_idle = false;
|
||||||
|
|
||||||
LOG_TEE("slot %i is processing [task id: %i]\n", slot->id, slot->task_id);
|
LOG_INFO("slot is processing task", {
|
||||||
|
{"slot_id", slot->id},
|
||||||
|
{"task_id", slot->task_id},
|
||||||
|
});
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -747,27 +896,44 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
|
|
||||||
void update_system_prompt() {
|
void update_system_prompt() {
|
||||||
system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
|
|
||||||
|
|
||||||
llama_batch_clear(batch);
|
|
||||||
|
|
||||||
kv_cache_clear();
|
kv_cache_clear();
|
||||||
|
system_tokens.clear();
|
||||||
|
|
||||||
for (int i = 0; i < (int) system_tokens.size(); ++i)
|
if (!system_prompt.empty()) {
|
||||||
{
|
system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
|
||||||
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0)
|
llama_batch_clear(batch);
|
||||||
{
|
|
||||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// assign the system KV cache to all parallel sequences
|
for (int i = 0; i < (int)system_tokens.size(); ++i)
|
||||||
for (int32_t i = 1; i < params.n_parallel; ++i)
|
{
|
||||||
{
|
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
||||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
|
}
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
|
||||||
|
{
|
||||||
|
const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
|
||||||
|
llama_batch batch_view = {
|
||||||
|
n_tokens,
|
||||||
|
batch.token + i,
|
||||||
|
nullptr,
|
||||||
|
batch.pos + i,
|
||||||
|
batch.n_seq_id + i,
|
||||||
|
batch.seq_id + i,
|
||||||
|
batch.logits + i,
|
||||||
|
0, 0, 0, // unused
|
||||||
|
};
|
||||||
|
if (llama_decode(ctx, batch_view) != 0)
|
||||||
|
{
|
||||||
|
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign the system KV cache to all parallel sequences
|
||||||
|
for (int32_t i = 1; i < params.n_parallel; ++i)
|
||||||
|
{
|
||||||
|
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("system prompt updated\n");
|
LOG_TEE("system prompt updated\n");
|
||||||
@@ -789,10 +955,8 @@ struct llama_server_context
|
|||||||
name_user = sys_props.value("anti_prompt", "");
|
name_user = sys_props.value("anti_prompt", "");
|
||||||
name_assistant = sys_props.value("assistant_name", "");
|
name_assistant = sys_props.value("assistant_name", "");
|
||||||
|
|
||||||
if (slots.size() > 0)
|
|
||||||
{
|
notify_system_prompt_changed();
|
||||||
notify_system_prompt_changed();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
|
static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
|
||||||
@@ -950,28 +1114,12 @@ struct llama_server_context
|
|||||||
{
|
{
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
clip_image_f32 * img_res = clip_image_f32_init();
|
|
||||||
if (!clip_image_preprocess(clp_ctx, img.img_data, img_res, /*pad2square =*/ true))
|
if (!llava_image_embed_make_with_clip_img(clp_ctx, params.n_threads, img.img_data, &img.image_embedding, &img.image_tokens)) {
|
||||||
{
|
|
||||||
LOG_TEE("Error processing the given image");
|
LOG_TEE("Error processing the given image");
|
||||||
clip_free(clp_ctx);
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
img.image_tokens = clip_n_patches(clp_ctx);
|
|
||||||
img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx));
|
|
||||||
if (!img.image_embedding)
|
|
||||||
{
|
|
||||||
LOG_TEE("Unable to allocate memory for image embeddings\n");
|
|
||||||
clip_free(clp_ctx);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
LOG_TEE("slot %i - encoding image [id: %i]\n", slot.id, img.id);
|
|
||||||
if (!clip_image_encode(clp_ctx, params.n_threads, img_res, img.image_embedding))
|
|
||||||
{
|
|
||||||
LOG_TEE("Unable to encode image\n");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
clip_image_f32_free(img_res);
|
|
||||||
img.request_encode_image = false;
|
img.request_encode_image = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -990,21 +1138,25 @@ struct llama_server_context
|
|||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
json get_model_props()
|
|
||||||
{
|
|
||||||
return get_formated_generation(slots[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
json get_formated_generation(llama_client_slot &slot)
|
json get_formated_generation(llama_client_slot &slot)
|
||||||
{
|
{
|
||||||
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
||||||
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
|
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
|
||||||
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
||||||
|
std::vector<std::string> samplers_sequence;
|
||||||
|
for (const auto &sampler_type : slot.sparams.samplers_sequence)
|
||||||
|
{
|
||||||
|
samplers_sequence.emplace_back(sampler_type_to_name_string(sampler_type));
|
||||||
|
}
|
||||||
|
|
||||||
return json {
|
return json {
|
||||||
{"n_ctx", slot.n_ctx},
|
{"n_ctx", slot.n_ctx},
|
||||||
|
{"n_predict", slot.n_predict},
|
||||||
{"model", params.model_alias},
|
{"model", params.model_alias},
|
||||||
{"seed", slot.params.seed},
|
{"seed", slot.params.seed},
|
||||||
{"temperature", slot.sparams.temp},
|
{"temperature", slot.sparams.temp},
|
||||||
|
{"dynatemp_range", slot.sparams.dynatemp_range},
|
||||||
|
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
|
||||||
{"top_k", slot.sparams.top_k},
|
{"top_k", slot.sparams.top_k},
|
||||||
{"top_p", slot.sparams.top_p},
|
{"top_p", slot.sparams.top_p},
|
||||||
{"min_p", slot.sparams.min_p},
|
{"min_p", slot.sparams.min_p},
|
||||||
@@ -1027,7 +1179,9 @@ struct llama_server_context
|
|||||||
{"stream", slot.params.stream},
|
{"stream", slot.params.stream},
|
||||||
{"logit_bias", slot.sparams.logit_bias},
|
{"logit_bias", slot.sparams.logit_bias},
|
||||||
{"n_probs", slot.sparams.n_probs},
|
{"n_probs", slot.sparams.n_probs},
|
||||||
|
{"min_keep", slot.sparams.min_keep},
|
||||||
{"grammar", slot.sparams.grammar},
|
{"grammar", slot.sparams.grammar},
|
||||||
|
{"samplers", samplers_sequence}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1166,13 +1320,30 @@ struct llama_server_context
|
|||||||
task.multitask_id = multitask_id;
|
task.multitask_id = multitask_id;
|
||||||
|
|
||||||
// when a completion task's prompt array is not a singleton, we split it into multiple requests
|
// when a completion task's prompt array is not a singleton, we split it into multiple requests
|
||||||
if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
|
|
||||||
{
|
|
||||||
split_multiprompt_task(task_id, task);
|
|
||||||
}
|
|
||||||
|
|
||||||
// otherwise, it's a single-prompt task, we actually queue it
|
// otherwise, it's a single-prompt task, we actually queue it
|
||||||
queue_tasks.post(task);
|
// if there's numbers in the prompt array it will be treated as an array of tokens
|
||||||
|
if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
|
||||||
|
bool numbers = false;
|
||||||
|
for (const auto& e : task.data.at("prompt")) {
|
||||||
|
if (e.is_number()) {
|
||||||
|
numbers = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
|
||||||
|
// it will completely stall the server. I don't know where the bug for this is.
|
||||||
|
//
|
||||||
|
// if there are numbers, it needs to be treated like a single prompt,
|
||||||
|
// queue_tasks handles a mix of strings and numbers just fine.
|
||||||
|
if (numbers) {
|
||||||
|
queue_tasks.post(task);
|
||||||
|
} else {
|
||||||
|
split_multiprompt_task(task_id, task);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
queue_tasks.post(task);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// for multiple images processing
|
// for multiple images processing
|
||||||
@@ -1254,7 +1425,10 @@ struct llama_server_context
|
|||||||
void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
|
void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
|
||||||
{
|
{
|
||||||
int prompt_count = multiprompt_task.data.at("prompt").size();
|
int prompt_count = multiprompt_task.data.at("prompt").size();
|
||||||
assert(prompt_count > 1);
|
if (prompt_count <= 1) {
|
||||||
|
send_error(multiprompt_task, "error while handling multiple prompts");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// generate all the ID for subtask
|
// generate all the ID for subtask
|
||||||
std::vector<int> subtask_ids(prompt_count);
|
std::vector<int> subtask_ids(prompt_count);
|
||||||
@@ -1286,7 +1460,7 @@ struct llama_server_context
|
|||||||
if (slot == nullptr)
|
if (slot == nullptr)
|
||||||
{
|
{
|
||||||
// if no slot is available, we defer this task for processing later
|
// if no slot is available, we defer this task for processing later
|
||||||
LOG_VERBOSE("no slot is available", {});
|
LOG_VERBOSE("no slot is available", {{"task_id", task.id}});
|
||||||
queue_tasks.defer(task);
|
queue_tasks.defer(task);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -1360,7 +1534,7 @@ struct llama_server_context
|
|||||||
bool update_slots() {
|
bool update_slots() {
|
||||||
if (system_need_update)
|
if (system_need_update)
|
||||||
{
|
{
|
||||||
LOG_TEE("updating system prompt\n");
|
LOG_INFO("updating system prompt", {});
|
||||||
update_system_prompt();
|
update_system_prompt();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1370,12 +1544,13 @@ struct llama_server_context
|
|||||||
{
|
{
|
||||||
if (system_prompt.empty() && clean_kv_cache)
|
if (system_prompt.empty() && clean_kv_cache)
|
||||||
{
|
{
|
||||||
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
|
LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});
|
||||||
kv_cache_clear();
|
kv_cache_clear();
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LOG_VERBOSE("posting NEXT_RESPONSE", {});
|
||||||
task_server task;
|
task_server task;
|
||||||
task.type = TASK_TYPE_NEXT_RESPONSE;
|
task.type = TASK_TYPE_NEXT_RESPONSE;
|
||||||
task.target_id = -1;
|
task.target_id = -1;
|
||||||
@@ -1406,6 +1581,7 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
|
|
||||||
// decode any currently ongoing sequences
|
// decode any currently ongoing sequences
|
||||||
|
LOG_VERBOSE("decoding ongoing sequences", {});
|
||||||
for (auto & slot : slots)
|
for (auto & slot : slots)
|
||||||
{
|
{
|
||||||
// release the slot
|
// release the slot
|
||||||
@@ -1415,7 +1591,15 @@ struct llama_server_context
|
|||||||
slot.command = NONE;
|
slot.command = NONE;
|
||||||
slot.t_last_used = ggml_time_us();
|
slot.t_last_used = ggml_time_us();
|
||||||
|
|
||||||
LOG_TEE("slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size());
|
LOG_INFO("slot released", {
|
||||||
|
{"slot_id", slot.id},
|
||||||
|
{"task_id", slot.task_id},
|
||||||
|
{"n_ctx", n_ctx},
|
||||||
|
{"n_past", slot.n_past},
|
||||||
|
{"n_system_tokens", system_tokens.size()},
|
||||||
|
{"n_cache_tokens", slot.cache_tokens.size()},
|
||||||
|
{"truncated", slot.truncated}
|
||||||
|
});
|
||||||
queue_tasks.notify_slot_changed();
|
queue_tasks.notify_slot_changed();
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
@@ -1542,6 +1726,14 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
|
|
||||||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||||
|
|
||||||
|
// the last token of the cache is not in the KV cache until the next call to llama_decode
|
||||||
|
// (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
|
||||||
|
if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size())
|
||||||
|
{
|
||||||
|
slot.n_past -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
||||||
|
|
||||||
if (slot.ga_n != 1)
|
if (slot.ga_n != 1)
|
||||||
@@ -1563,19 +1755,23 @@ struct llama_server_context
|
|||||||
slot.ga_i = ga_i;
|
slot.ga_i = ga_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
|
LOG_INFO("slot progression", {
|
||||||
|
{ "slot_id", slot.id },
|
||||||
|
{ "task_id", slot.task_id },
|
||||||
|
{ "n_past", slot.n_past },
|
||||||
|
{ "num_prompt_tokens_processed", slot.num_prompt_tokens_processed }
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
|
|
||||||
|
|
||||||
llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
|
|
||||||
|
|
||||||
slot.cache_tokens = prompt_tokens;
|
slot.cache_tokens = prompt_tokens;
|
||||||
|
|
||||||
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
|
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
|
||||||
{
|
{
|
||||||
// we have to evaluate at least 1 token to generate logits.
|
// we have to evaluate at least 1 token to generate logits.
|
||||||
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
|
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
|
||||||
|
{ "slot_id", slot.id },
|
||||||
|
{ "task_id", slot.task_id }
|
||||||
|
});
|
||||||
slot.n_past--;
|
slot.n_past--;
|
||||||
if (slot.ga_i > 0)
|
if (slot.ga_i > 0)
|
||||||
{
|
{
|
||||||
@@ -1583,6 +1779,14 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int p0 = (int) system_tokens.size() + slot.n_past;
|
||||||
|
LOG_INFO("kv cache rm [p0, end)", {
|
||||||
|
{ "slot_id", slot.id },
|
||||||
|
{ "task_id", slot.task_id },
|
||||||
|
{ "p0", p0 }
|
||||||
|
});
|
||||||
|
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
|
||||||
|
|
||||||
LOG_VERBOSE("prompt ingested", {
|
LOG_VERBOSE("prompt ingested", {
|
||||||
{"n_past", slot.n_past},
|
{"n_past", slot.n_past},
|
||||||
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
|
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
|
||||||
@@ -1616,7 +1820,13 @@ struct llama_server_context
|
|||||||
|
|
||||||
if (has_images && !ingest_images(slot, n_batch))
|
if (has_images && !ingest_images(slot, n_batch))
|
||||||
{
|
{
|
||||||
LOG_TEE("failed processing images\n");
|
LOG_ERROR("failed processing images", {
|
||||||
|
"slot_id", slot.id,
|
||||||
|
"task_id", slot.task_id,
|
||||||
|
});
|
||||||
|
// FIXME @phymbert: to be properly tested
|
||||||
|
// early returning without changing the slot state will block the slot for ever
|
||||||
|
// no one at the moment is checking the return value
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1658,9 +1868,9 @@ struct llama_server_context
|
|||||||
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
||||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
||||||
|
|
||||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
||||||
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
||||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
|
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
|
||||||
|
|
||||||
slot.n_past_se -= bd;
|
slot.n_past_se -= bd;
|
||||||
|
|
||||||
@@ -1716,7 +1926,7 @@ struct llama_server_context
|
|||||||
send_embedding(slot);
|
send_embedding(slot);
|
||||||
slot.release();
|
slot.release();
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
return true;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
@@ -1729,6 +1939,7 @@ struct llama_server_context
|
|||||||
{
|
{
|
||||||
slot.t_start_genereration = ggml_time_us();
|
slot.t_start_genereration = ggml_time_us();
|
||||||
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
|
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
|
||||||
|
metrics.on_prompt_eval(slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
|
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
|
||||||
@@ -1751,11 +1962,14 @@ struct llama_server_context
|
|||||||
slot.release();
|
slot.release();
|
||||||
slot.print_timings();
|
slot.print_timings();
|
||||||
send_final_response(slot);
|
send_final_response(slot);
|
||||||
|
metrics.on_prediction(slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LOG_VERBOSE("slots updated", {});
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1784,18 +1998,6 @@ static json format_partial_response(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static json format_tokenizer_response(const std::vector<llama_token> &tokens)
|
|
||||||
{
|
|
||||||
return json{
|
|
||||||
{"tokens", tokens}};
|
|
||||||
}
|
|
||||||
|
|
||||||
static json format_detokenized_response(std::string content)
|
|
||||||
{
|
|
||||||
return json{
|
|
||||||
{"content", content}};
|
|
||||||
}
|
|
||||||
|
|
||||||
struct token_translator
|
struct token_translator
|
||||||
{
|
{
|
||||||
llama_context * ctx;
|
llama_context * ctx;
|
||||||
@@ -1819,6 +2021,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::function<void(int)> shutdown_handler;
|
||||||
|
inline void signal_handler(int signal) { shutdown_handler(signal); }
|
||||||
|
|
||||||
/////////////////////////////////
|
/////////////////////////////////
|
||||||
////////////////////////////////
|
////////////////////////////////
|
||||||
//////// LOCALAI code starts below here
|
//////// LOCALAI code starts below here
|
||||||
@@ -2051,9 +2256,9 @@ static void params_parse(const backend::ModelOptions* request,
|
|||||||
params.use_mmap = request->mmap();
|
params.use_mmap = request->mmap();
|
||||||
params.embedding = request->embeddings();
|
params.embedding = request->embeddings();
|
||||||
|
|
||||||
if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
|
if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
|
||||||
else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
|
else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
|
||||||
else { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
|
else { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
|
||||||
if ( request->yarnextfactor() != 0.0f ) {
|
if ( request->yarnextfactor() != 0.0f ) {
|
||||||
params.yarn_ext_factor = request->yarnextfactor();
|
params.yarn_ext_factor = request->yarnextfactor();
|
||||||
}
|
}
|
||||||
@@ -2089,7 +2294,8 @@ public:
|
|||||||
gpt_params params;
|
gpt_params params;
|
||||||
params_parse(request, params);
|
params_parse(request, params);
|
||||||
|
|
||||||
llama_backend_init(params.numa);
|
llama_backend_init();
|
||||||
|
llama_numa_init(params.numa);
|
||||||
|
|
||||||
// load the model
|
// load the model
|
||||||
if (!llama.load_model(params))
|
if (!llama.load_model(params))
|
||||||
|
|||||||
@@ -8,24 +8,24 @@ import (
|
|||||||
|
|
||||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
"github.com/go-audio/wav"
|
"github.com/go-audio/wav"
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func sh(c string) (string, error) {
|
func runCommand(command []string) (string, error) {
|
||||||
cmd := exec.Command("/bin/sh", "-c", c)
|
cmd := exec.Command(command[0], command[1:]...)
|
||||||
cmd.Env = os.Environ()
|
cmd.Env = os.Environ()
|
||||||
o, err := cmd.CombinedOutput()
|
out, err := cmd.CombinedOutput()
|
||||||
return string(o), err
|
return string(out), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// AudioToWav converts audio to wav for transcribe. It bashes out to ffmpeg
|
// AudioToWav converts audio to wav for transcribe.
|
||||||
// TODO: use https://github.com/mccoyst/ogg?
|
// TODO: use https://github.com/mccoyst/ogg?
|
||||||
func audioToWav(src, dst string) error {
|
func audioToWav(src, dst string) error {
|
||||||
out, err := sh(fmt.Sprintf("ffmpeg -i %s -format s16le -ar 16000 -ac 1 -acodec pcm_s16le %s", src, dst))
|
command := []string{"ffmpeg", "-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst}
|
||||||
|
out, err := runCommand(command)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error: %w out: %s", err, out)
|
return fmt.Errorf("error: %w out: %s", err, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ package main
|
|||||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
import (
|
import (
|
||||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
model = AutoGPTQForCausalLM.from_quantized(request.Model,
|
model = AutoGPTQForCausalLM.from_quantized(request.Model,
|
||||||
model_basename=request.ModelBaseName,
|
model_basename=request.ModelBaseName,
|
||||||
use_safetensors=True,
|
use_safetensors=True,
|
||||||
trust_remote_code=True,
|
trust_remote_code=request.TrustRemoteCode,
|
||||||
device=device,
|
device=device,
|
||||||
use_triton=request.UseTriton,
|
use_triton=request.UseTriton,
|
||||||
quantize_config=None)
|
quantize_config=None)
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ dependencies:
|
|||||||
- regex==2023.10.3
|
- regex==2023.10.3
|
||||||
- requests==2.31.0
|
- requests==2.31.0
|
||||||
- rouge==1.0.1
|
- rouge==1.0.1
|
||||||
- safetensors==0.3.3
|
- safetensors>=0.3.3
|
||||||
- six==1.16.0
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
- tokenizers==0.14.0
|
- tokenizers==0.14.0
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -4,6 +4,17 @@ ifeq ($(BUILD_TYPE), cublas)
|
|||||||
CONDA_ENV_PATH = "transformers-nvidia.yml"
|
CONDA_ENV_PATH = "transformers-nvidia.yml"
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE), hipblas)
|
||||||
|
CONDA_ENV_PATH = "transformers-rocm.yml"
|
||||||
|
endif
|
||||||
|
|
||||||
|
# Intel GPU are supposed to have dependencies installed in the main python
|
||||||
|
# environment, so we skip conda installation for SYCL builds.
|
||||||
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
|
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
|
||||||
|
export SKIP_CONDA=1
|
||||||
|
endif
|
||||||
|
|
||||||
.PHONY: transformers
|
.PHONY: transformers
|
||||||
transformers:
|
transformers:
|
||||||
@echo "Installing $(CONDA_ENV_PATH)..."
|
@echo "Installing $(CONDA_ENV_PATH)..."
|
||||||
|
|||||||
@@ -1,24 +1,38 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -ex
|
set -ex
|
||||||
|
|
||||||
|
SKIP_CONDA=${SKIP_CONDA:-0}
|
||||||
|
|
||||||
# Check if environment exist
|
# Check if environment exist
|
||||||
conda_env_exists(){
|
conda_env_exists(){
|
||||||
! conda list --name "${@}" >/dev/null 2>/dev/null
|
! conda list --name "${@}" >/dev/null 2>/dev/null
|
||||||
}
|
}
|
||||||
|
|
||||||
if conda_env_exists "transformers" ; then
|
if [ $SKIP_CONDA -eq 1 ]; then
|
||||||
echo "Creating virtual environment..."
|
echo "Skipping conda environment installation"
|
||||||
conda env create --name transformers --file $1
|
else
|
||||||
echo "Virtual environment created."
|
export PATH=$PATH:/opt/conda/bin
|
||||||
else
|
if conda_env_exists "transformers" ; then
|
||||||
echo "Virtual environment already exists."
|
echo "Creating virtual environment..."
|
||||||
|
conda env create --name transformers --file $1
|
||||||
|
echo "Virtual environment created."
|
||||||
|
else
|
||||||
|
echo "Virtual environment already exists."
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -d "/opt/intel" ]; then
|
||||||
|
# Intel GPU: If the directory exists, we assume we are using the intel image
|
||||||
|
# (no conda env)
|
||||||
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
|
pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$PIP_CACHE_PURGE" = true ] ; then
|
if [ "$PIP_CACHE_PURGE" = true ] ; then
|
||||||
export PATH=$PATH:/opt/conda/bin
|
if [ $SKIP_CONDA -eq 0 ]; then
|
||||||
|
# Activate conda environment
|
||||||
# Activate conda environment
|
source activate transformers
|
||||||
source activate transformers
|
fi
|
||||||
|
|
||||||
pip cache purge
|
pip cache purge
|
||||||
fi
|
fi
|
||||||
@@ -33,9 +33,10 @@ dependencies:
|
|||||||
- boto3==1.28.61
|
- boto3==1.28.61
|
||||||
- botocore==1.31.61
|
- botocore==1.31.61
|
||||||
- certifi==2023.7.22
|
- certifi==2023.7.22
|
||||||
|
- TTS==0.22.0
|
||||||
- charset-normalizer==3.3.0
|
- charset-normalizer==3.3.0
|
||||||
- datasets==2.14.5
|
- datasets==2.14.5
|
||||||
- sentence-transformers==2.2.2
|
- sentence-transformers==2.5.1 # Updated Version
|
||||||
- sentencepiece==0.1.99
|
- sentencepiece==0.1.99
|
||||||
- dill==0.3.7
|
- dill==0.3.7
|
||||||
- einops==0.7.0
|
- einops==0.7.0
|
||||||
@@ -80,8 +81,8 @@ dependencies:
|
|||||||
- requests==2.31.0
|
- requests==2.31.0
|
||||||
- rouge==1.0.1
|
- rouge==1.0.1
|
||||||
- s3transfer==0.7.0
|
- s3transfer==0.7.0
|
||||||
- safetensors==0.3.3
|
- safetensors>=0.4.1
|
||||||
- scipy==1.11.3
|
- scipy==1.12.0 # Updated Version
|
||||||
- six==1.16.0
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
- tokenizers
|
- tokenizers
|
||||||
@@ -112,7 +113,7 @@ dependencies:
|
|||||||
- sudachipy
|
- sudachipy
|
||||||
- sudachidict_core
|
- sudachidict_core
|
||||||
- vocos
|
- vocos
|
||||||
- vllm==0.2.7
|
- vllm==0.3.2
|
||||||
- transformers>=4.36.0 # Required for Mixtral.
|
- transformers>=4.38.2 # Updated Version
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
|||||||
109
backend/python/common-env/transformers/transformers-rocm.yml
Normal file
109
backend/python/common-env/transformers/transformers-rocm.yml
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
name: transformers
|
||||||
|
channels:
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- _libgcc_mutex=0.1=main
|
||||||
|
- _openmp_mutex=5.1=1_gnu
|
||||||
|
- bzip2=1.0.8=h7b6447c_0
|
||||||
|
- ca-certificates=2023.08.22=h06a4308_0
|
||||||
|
- ld_impl_linux-64=2.38=h1181459_1
|
||||||
|
- libffi=3.4.4=h6a678d5_0
|
||||||
|
- libgcc-ng=11.2.0=h1234567_1
|
||||||
|
- libgomp=11.2.0=h1234567_1
|
||||||
|
- libstdcxx-ng=11.2.0=h1234567_1
|
||||||
|
- libuuid=1.41.5=h5eee18b_0
|
||||||
|
- ncurses=6.4=h6a678d5_0
|
||||||
|
- openssl=3.0.11=h7f8727e_2
|
||||||
|
- pip=23.2.1=py311h06a4308_0
|
||||||
|
- python=3.11.5=h955ad1f_0
|
||||||
|
- readline=8.2=h5eee18b_0
|
||||||
|
- setuptools=68.0.0=py311h06a4308_0
|
||||||
|
- sqlite=3.41.2=h5eee18b_0
|
||||||
|
- tk=8.6.12=h1ccaba5_0
|
||||||
|
- wheel=0.41.2=py311h06a4308_0
|
||||||
|
- xz=5.4.2=h5eee18b_0
|
||||||
|
- zlib=1.2.13=h5eee18b_0
|
||||||
|
- pip:
|
||||||
|
- --pre
|
||||||
|
- --extra-index-url https://download.pytorch.org/whl/nightly/
|
||||||
|
- accelerate==0.23.0
|
||||||
|
- aiohttp==3.8.5
|
||||||
|
- aiosignal==1.3.1
|
||||||
|
- async-timeout==4.0.3
|
||||||
|
- attrs==23.1.0
|
||||||
|
- bark==0.1.5
|
||||||
|
- boto3==1.28.61
|
||||||
|
- botocore==1.31.61
|
||||||
|
- certifi==2023.7.22
|
||||||
|
- TTS==0.22.0
|
||||||
|
- charset-normalizer==3.3.0
|
||||||
|
- datasets==2.14.5
|
||||||
|
- sentence-transformers==2.5.1 # Updated Version
|
||||||
|
- sentencepiece==0.1.99
|
||||||
|
- dill==0.3.7
|
||||||
|
- einops==0.7.0
|
||||||
|
- encodec==0.1.1
|
||||||
|
- filelock==3.12.4
|
||||||
|
- frozenlist==1.4.0
|
||||||
|
- fsspec==2023.6.0
|
||||||
|
- funcy==2.0
|
||||||
|
- grpcio==1.59.0
|
||||||
|
- huggingface-hub
|
||||||
|
- idna==3.4
|
||||||
|
- jinja2==3.1.2
|
||||||
|
- jmespath==1.0.1
|
||||||
|
- markupsafe==2.1.3
|
||||||
|
- mpmath==1.3.0
|
||||||
|
- multidict==6.0.4
|
||||||
|
- multiprocess==0.70.15
|
||||||
|
- networkx
|
||||||
|
- numpy==1.26.0
|
||||||
|
- packaging==23.2
|
||||||
|
- pandas
|
||||||
|
- peft==0.5.0
|
||||||
|
- protobuf==4.24.4
|
||||||
|
- psutil==5.9.5
|
||||||
|
- pyarrow==13.0.0
|
||||||
|
- python-dateutil==2.8.2
|
||||||
|
- pytz==2023.3.post1
|
||||||
|
- pyyaml==6.0.1
|
||||||
|
- regex==2023.10.3
|
||||||
|
- requests==2.31.0
|
||||||
|
- rouge==1.0.1
|
||||||
|
- s3transfer==0.7.0
|
||||||
|
- safetensors>=0.4.1
|
||||||
|
- scipy==1.12.0 # Updated Version
|
||||||
|
- six==1.16.0
|
||||||
|
- sympy==1.12
|
||||||
|
- tokenizers
|
||||||
|
- torch
|
||||||
|
- torchaudio
|
||||||
|
- tqdm==4.66.1
|
||||||
|
- triton==2.1.0
|
||||||
|
- typing-extensions==4.8.0
|
||||||
|
- tzdata==2023.3
|
||||||
|
- auto-gptq==0.6.0
|
||||||
|
- urllib3==1.26.17
|
||||||
|
- xxhash==3.4.1
|
||||||
|
- yarl==1.9.2
|
||||||
|
- soundfile
|
||||||
|
- langid
|
||||||
|
- wget
|
||||||
|
- unidecode
|
||||||
|
- pyopenjtalk-prebuilt
|
||||||
|
- pypinyin
|
||||||
|
- inflect
|
||||||
|
- cn2an
|
||||||
|
- jieba
|
||||||
|
- eng_to_ipa
|
||||||
|
- openai-whisper
|
||||||
|
- matplotlib
|
||||||
|
- gradio==3.41.2
|
||||||
|
- nltk
|
||||||
|
- sudachipy
|
||||||
|
- sudachidict_core
|
||||||
|
- vocos
|
||||||
|
- vllm==0.3.2
|
||||||
|
- transformers>=4.38.2 # Updated Version
|
||||||
|
- xformers==0.0.23.post1
|
||||||
|
prefix: /opt/conda/envs/transformers
|
||||||
@@ -36,7 +36,7 @@ dependencies:
|
|||||||
- TTS==0.22.0
|
- TTS==0.22.0
|
||||||
- charset-normalizer==3.3.0
|
- charset-normalizer==3.3.0
|
||||||
- datasets==2.14.5
|
- datasets==2.14.5
|
||||||
- sentence-transformers==2.2.2
|
- sentence-transformers==2.5.1 # Updated Version
|
||||||
- sentencepiece==0.1.99
|
- sentencepiece==0.1.99
|
||||||
- dill==0.3.7
|
- dill==0.3.7
|
||||||
- einops==0.7.0
|
- einops==0.7.0
|
||||||
@@ -69,8 +69,8 @@ dependencies:
|
|||||||
- requests==2.31.0
|
- requests==2.31.0
|
||||||
- rouge==1.0.1
|
- rouge==1.0.1
|
||||||
- s3transfer==0.7.0
|
- s3transfer==0.7.0
|
||||||
- safetensors==0.3.3
|
- safetensors>=0.4.1
|
||||||
- scipy==1.11.3
|
- scipy==1.12.0 # Updated Version
|
||||||
- six==1.16.0
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
- tokenizers
|
- tokenizers
|
||||||
@@ -101,7 +101,7 @@ dependencies:
|
|||||||
- sudachipy
|
- sudachipy
|
||||||
- sudachidict_core
|
- sudachidict_core
|
||||||
- vocos
|
- vocos
|
||||||
- vllm==0.2.7
|
- vllm==0.3.2
|
||||||
- transformers>=4.36.0 # Required for Mixtral.
|
- transformers>=4.38.2 # Updated Version
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,4 +1,15 @@
|
|||||||
CONDA_ENV_PATH = "diffusers.yml"
|
export CONDA_ENV_PATH = "diffusers.yml"
|
||||||
|
|
||||||
|
ifeq ($(BUILD_TYPE), hipblas)
|
||||||
|
export CONDA_ENV_PATH = "diffusers-rocm.yml"
|
||||||
|
endif
|
||||||
|
|
||||||
|
# Intel GPU are supposed to have dependencies installed in the main python
|
||||||
|
# environment, so we skip conda installation for SYCL builds.
|
||||||
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
|
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
|
||||||
|
export SKIP_CONDA=1
|
||||||
|
endif
|
||||||
|
|
||||||
.PHONY: diffusers
|
.PHONY: diffusers
|
||||||
diffusers:
|
diffusers:
|
||||||
@@ -12,4 +23,4 @@ run:
|
|||||||
@echo "Diffusers run."
|
@echo "Diffusers run."
|
||||||
|
|
||||||
test:
|
test:
|
||||||
bash test.sh
|
bash test.sh
|
||||||
|
|||||||
@@ -21,14 +21,15 @@ from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipelin
|
|||||||
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
|
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion import safety_checker
|
from diffusers.pipelines.stable_diffusion import safety_checker
|
||||||
from diffusers.utils import load_image,export_to_video
|
from diffusers.utils import load_image,export_to_video
|
||||||
from compel import Compel
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
|
|
||||||
from transformers import CLIPTextModel
|
from transformers import CLIPTextModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
COMPEL=os.environ.get("COMPEL", "1") == "1"
|
COMPEL=os.environ.get("COMPEL", "0") == "1"
|
||||||
|
XPU=os.environ.get("XPU", "0") == "1"
|
||||||
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
|
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
|
||||||
SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"
|
SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"
|
||||||
CHUNK_SIZE=os.environ.get("CHUNK_SIZE", "8")
|
CHUNK_SIZE=os.environ.get("CHUNK_SIZE", "8")
|
||||||
@@ -36,6 +37,10 @@ FPS=os.environ.get("FPS", "7")
|
|||||||
DISABLE_CPU_OFFLOAD=os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
|
DISABLE_CPU_OFFLOAD=os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
|
||||||
FRAMES=os.environ.get("FRAMES", "64")
|
FRAMES=os.environ.get("FRAMES", "64")
|
||||||
|
|
||||||
|
if XPU:
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
print(ipex.xpu.get_device_name(0))
|
||||||
|
|
||||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||||
|
|
||||||
@@ -231,8 +236,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
if request.SchedulerType != "":
|
if request.SchedulerType != "":
|
||||||
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
|
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
|
||||||
|
|
||||||
if not self.img2vid:
|
if COMPEL:
|
||||||
self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
|
self.compel = Compel(
|
||||||
|
tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2 ],
|
||||||
|
text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
|
||||||
|
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
|
||||||
|
requires_pooled=[False, True]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if request.ControlNet:
|
if request.ControlNet:
|
||||||
@@ -247,6 +257,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
self.pipe.to('cuda')
|
self.pipe.to('cuda')
|
||||||
if self.controlnet:
|
if self.controlnet:
|
||||||
self.controlnet.to('cuda')
|
self.controlnet.to('cuda')
|
||||||
|
if XPU:
|
||||||
|
self.pipe = self.pipe.to("xpu")
|
||||||
# Assume directory from request.ModelFile.
|
# Assume directory from request.ModelFile.
|
||||||
# Only if request.LoraAdapter it's not an absolute path
|
# Only if request.LoraAdapter it's not an absolute path
|
||||||
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
||||||
@@ -386,8 +398,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
|
|
||||||
image = {}
|
image = {}
|
||||||
if COMPEL:
|
if COMPEL:
|
||||||
conditioning = self.compel.build_conditioning_tensor(prompt)
|
conditioning, pooled = self.compel.build_conditioning_tensor(prompt)
|
||||||
kwargs["prompt_embeds"]= conditioning
|
kwargs["prompt_embeds"] = conditioning
|
||||||
|
kwargs["pooled_prompt_embeds"] = pooled
|
||||||
# pass the kwargs dictionary to the self.pipe method
|
# pass the kwargs dictionary to the self.pipe method
|
||||||
image = self.pipe(
|
image = self.pipe(
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
64
backend/python/diffusers/diffusers-rocm.yml
Normal file
64
backend/python/diffusers/diffusers-rocm.yml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
name: diffusers
|
||||||
|
channels:
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- _libgcc_mutex=0.1=main
|
||||||
|
- _openmp_mutex=5.1=1_gnu
|
||||||
|
- bzip2=1.0.8=h7b6447c_0
|
||||||
|
- ca-certificates=2023.08.22=h06a4308_0
|
||||||
|
- ld_impl_linux-64=2.38=h1181459_1
|
||||||
|
- libffi=3.4.4=h6a678d5_0
|
||||||
|
- libgcc-ng=11.2.0=h1234567_1
|
||||||
|
- libgomp=11.2.0=h1234567_1
|
||||||
|
- libstdcxx-ng=11.2.0=h1234567_1
|
||||||
|
- libuuid=1.41.5=h5eee18b_0
|
||||||
|
- ncurses=6.4=h6a678d5_0
|
||||||
|
- openssl=3.0.11=h7f8727e_2
|
||||||
|
- pip=23.2.1=py311h06a4308_0
|
||||||
|
- python=3.11.5=h955ad1f_0
|
||||||
|
- readline=8.2=h5eee18b_0
|
||||||
|
- setuptools=68.0.0=py311h06a4308_0
|
||||||
|
- sqlite=3.41.2=h5eee18b_0
|
||||||
|
- tk=8.6.12=h1ccaba5_0
|
||||||
|
- tzdata=2023c=h04d1e81_0
|
||||||
|
- wheel=0.41.2=py311h06a4308_0
|
||||||
|
- xz=5.4.2=h5eee18b_0
|
||||||
|
- zlib=1.2.13=h5eee18b_0
|
||||||
|
- pip:
|
||||||
|
- --pre
|
||||||
|
- --extra-index-url https://download.pytorch.org/whl/nightly/
|
||||||
|
- accelerate>=0.11.0
|
||||||
|
- certifi==2023.7.22
|
||||||
|
- charset-normalizer==3.3.0
|
||||||
|
- compel==2.0.2
|
||||||
|
- diffusers==0.24.0
|
||||||
|
- filelock==3.12.4
|
||||||
|
- fsspec==2023.9.2
|
||||||
|
- grpcio==1.59.0
|
||||||
|
- huggingface-hub>=0.19.4
|
||||||
|
- idna==3.4
|
||||||
|
- importlib-metadata==6.8.0
|
||||||
|
- jinja2==3.1.2
|
||||||
|
- markupsafe==2.1.3
|
||||||
|
- mpmath==1.3.0
|
||||||
|
- networkx==3.1
|
||||||
|
- numpy==1.26.0
|
||||||
|
- omegaconf
|
||||||
|
- packaging==23.2
|
||||||
|
- pillow==10.0.1
|
||||||
|
- protobuf==4.24.4
|
||||||
|
- psutil==5.9.5
|
||||||
|
- pyparsing==3.1.1
|
||||||
|
- pyyaml==6.0.1
|
||||||
|
- regex==2023.10.3
|
||||||
|
- requests==2.31.0
|
||||||
|
- safetensors==0.4.0
|
||||||
|
- sympy==1.12
|
||||||
|
- tqdm==4.66.1
|
||||||
|
- transformers>=4.25.1
|
||||||
|
- triton==2.1.0
|
||||||
|
- typing-extensions==4.8.0
|
||||||
|
- urllib3==2.0.6
|
||||||
|
- zipp==3.17.0
|
||||||
|
- torch
|
||||||
|
prefix: /opt/conda/envs/diffusers
|
||||||
@@ -71,4 +71,4 @@ dependencies:
|
|||||||
- typing-extensions==4.8.0
|
- typing-extensions==4.8.0
|
||||||
- urllib3==2.0.6
|
- urllib3==2.0.6
|
||||||
- zipp==3.17.0
|
- zipp==3.17.0
|
||||||
prefix: /opt/conda/envs/diffusers
|
prefix: /opt/conda/envs/diffusers
|
||||||
|
|||||||
@@ -1,24 +1,50 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -ex
|
set -ex
|
||||||
|
|
||||||
|
SKIP_CONDA=${SKIP_CONDA:-0}
|
||||||
|
|
||||||
# Check if environment exist
|
# Check if environment exist
|
||||||
conda_env_exists(){
|
conda_env_exists(){
|
||||||
! conda list --name "${@}" >/dev/null 2>/dev/null
|
! conda list --name "${@}" >/dev/null 2>/dev/null
|
||||||
}
|
}
|
||||||
|
|
||||||
if conda_env_exists "diffusers" ; then
|
if [ $SKIP_CONDA -eq 1 ]; then
|
||||||
echo "Creating virtual environment..."
|
echo "Skipping conda environment installation"
|
||||||
conda env create --name diffusers --file $1
|
else
|
||||||
echo "Virtual environment created."
|
export PATH=$PATH:/opt/conda/bin
|
||||||
else
|
if conda_env_exists "diffusers" ; then
|
||||||
echo "Virtual environment already exists."
|
echo "Creating virtual environment..."
|
||||||
|
conda env create --name diffusers --file $1
|
||||||
|
echo "Virtual environment created."
|
||||||
|
else
|
||||||
|
echo "Virtual environment already exists."
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -d "/opt/intel" ]; then
|
||||||
|
# Intel GPU: If the directory exists, we assume we are using the Intel image
|
||||||
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
|
pip install torch==2.1.0a0 \
|
||||||
|
torchvision==0.16.0a0 \
|
||||||
|
torchaudio==2.1.0a0 \
|
||||||
|
intel-extension-for-pytorch==2.1.10+xpu \
|
||||||
|
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||||
|
|
||||||
|
pip install google-api-python-client \
|
||||||
|
grpcio \
|
||||||
|
grpcio-tools \
|
||||||
|
diffusers==0.24.0 \
|
||||||
|
transformers>=4.25.1 \
|
||||||
|
accelerate \
|
||||||
|
compel==2.0.2 \
|
||||||
|
Pillow
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$PIP_CACHE_PURGE" = true ] ; then
|
if [ "$PIP_CACHE_PURGE" = true ] ; then
|
||||||
export PATH=$PATH:/opt/conda/bin
|
if [ $SKIP_CONDA -ne 1 ]; then
|
||||||
|
# Activate conda environment
|
||||||
# Activate conda environment
|
source activate diffusers
|
||||||
source activate diffusers
|
fi
|
||||||
|
|
||||||
pip cache purge
|
pip cache purge
|
||||||
fi
|
fi
|
||||||
@@ -3,10 +3,15 @@
|
|||||||
##
|
##
|
||||||
## A bash script wrapper that runs the diffusers server with conda
|
## A bash script wrapper that runs the diffusers server with conda
|
||||||
|
|
||||||
export PATH=$PATH:/opt/conda/bin
|
if [ -d "/opt/intel" ]; then
|
||||||
|
# Assumes we are using the Intel oneAPI container image
|
||||||
# Activate conda environment
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
source activate diffusers
|
export XPU=1
|
||||||
|
else
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
# Activate conda environment
|
||||||
|
source activate diffusers
|
||||||
|
fi
|
||||||
|
|
||||||
# get the directory where the bash script is located
|
# get the directory where the bash script is located
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
|
export CONDA_ENV_PATH = "exllama.yml"
|
||||||
|
|
||||||
.PHONY: exllama
|
.PHONY: exllama
|
||||||
exllama:
|
exllama:
|
||||||
$(MAKE) -C ../common-env/transformers
|
bash install.sh ${CONDA_ENV_PATH}
|
||||||
bash install.sh
|
|
||||||
|
|
||||||
.PHONY: run
|
.PHONY: run
|
||||||
run:
|
run:
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,14 +1,27 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
set -ex
|
||||||
|
|
||||||
##
|
|
||||||
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
|
||||||
export PATH=$PATH:/opt/conda/bin
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
# Activate conda environment
|
if [ "$BUILD_TYPE" != "cublas" ]; then
|
||||||
source activate transformers
|
echo "[exllama] Attention!!! Nvidia GPU is required - skipping installation"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
echo $CONDA_PREFIX
|
# Check if environment exist
|
||||||
|
conda_env_exists(){
|
||||||
|
! conda list --name "${@}" >/dev/null 2>/dev/null
|
||||||
|
}
|
||||||
|
|
||||||
|
if conda_env_exists "exllama" ; then
|
||||||
|
echo "Creating virtual environment..."
|
||||||
|
conda env create --name exllama --file $1
|
||||||
|
echo "Virtual environment created."
|
||||||
|
else
|
||||||
|
echo "Virtual environment already exists."
|
||||||
|
fi
|
||||||
|
|
||||||
|
source activate exllama
|
||||||
|
|
||||||
git clone https://github.com/turboderp/exllama $CONDA_PREFIX/exllama && pushd $CONDA_PREFIX/exllama && pip install -r requirements.txt && popd
|
git clone https://github.com/turboderp/exllama $CONDA_PREFIX/exllama && pushd $CONDA_PREFIX/exllama && pip install -r requirements.txt && popd
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
##
|
##
|
||||||
## A bash script wrapper that runs the exllama server with conda
|
## A bash script wrapper that runs the exllama server with conda
|
||||||
|
|
||||||
export PATH=$PATH:/opt/conda/bin
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
# Activate conda environment
|
# Activate conda environment
|
||||||
source activate transformers
|
source activate exllama
|
||||||
|
|
||||||
# get the directory where the bash script is located
|
# get the directory where the bash script is located
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -2,10 +2,14 @@
|
|||||||
set -e
|
set -e
|
||||||
##
|
##
|
||||||
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
||||||
export PATH=$PATH:/opt/conda/bin
|
|
||||||
export SHA=c0ddebaaaf8ffd1b3529c2bb654e650bce2f790f
|
export SHA=c0ddebaaaf8ffd1b3529c2bb654e650bce2f790f
|
||||||
|
|
||||||
# Activate conda environment
|
if [ "$BUILD_TYPE" != "cublas" ]; then
|
||||||
|
echo "[exllamav2] Attention!!! Nvidia GPU is required - skipping installation"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
source activate transformers
|
source activate transformers
|
||||||
|
|
||||||
echo $CONDA_PREFIX
|
echo $CONDA_PREFIX
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -2,13 +2,14 @@
|
|||||||
set -e
|
set -e
|
||||||
##
|
##
|
||||||
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
||||||
export PATH=$PATH:/opt/conda/bin
|
|
||||||
|
|
||||||
if [ "$BUILD_TYPE" != "cublas" ]; then
|
if [ "$BUILD_TYPE" != "cublas" ]; then
|
||||||
echo "[mamba] Attention!!! nvcc is required - skipping installation"
|
echo "[mamba] Attention!!! nvcc is required - skipping installation"
|
||||||
exit 0
|
exit 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
# Activate conda environment
|
# Activate conda environment
|
||||||
source activate transformers
|
source activate transformers
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
.PHONY: petals
|
.PHONY: petals
|
||||||
petals:
|
petals:
|
||||||
@echo "Creating virtual environment..."
|
@echo "Creating virtual environment..."
|
||||||
@conda env create --name petals --file petals.yml
|
bash install.sh "petals.yml"
|
||||||
@echo "Virtual environment created."
|
@echo "Virtual environment created."
|
||||||
|
|
||||||
.PHONY: run
|
.PHONY: run
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
5
backend/python/petals/install.sh
Normal file
5
backend/python/petals/install.sh
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
|
conda env create --name petals --file $1
|
||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -3,10 +3,16 @@
|
|||||||
##
|
##
|
||||||
## A bash script wrapper that runs the transformers server with conda
|
## A bash script wrapper that runs the transformers server with conda
|
||||||
|
|
||||||
export PATH=$PATH:/opt/conda/bin
|
|
||||||
|
|
||||||
# Activate conda environment
|
if [ -d "/opt/intel" ]; then
|
||||||
source activate transformers
|
# Assumes we are using the Intel oneAPI container image
|
||||||
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
|
export XPU=1
|
||||||
|
else
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
# Activate conda environment
|
||||||
|
source activate transformers
|
||||||
|
fi
|
||||||
|
|
||||||
# get the directory where the bash script is located
|
# get the directory where the bash script is located
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|||||||
@@ -16,7 +16,15 @@ import backend_pb2_grpc
|
|||||||
import grpc
|
import grpc
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed
|
|
||||||
|
XPU=os.environ.get("XPU", "0") == "1"
|
||||||
|
if XPU:
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
|
||||||
|
from transformers import AutoTokenizer, AutoModel, set_seed
|
||||||
|
else:
|
||||||
|
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
@@ -69,12 +77,25 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
model_name = request.Model
|
model_name = request.Model
|
||||||
try:
|
try:
|
||||||
if request.Type == "AutoModelForCausalLM":
|
if request.Type == "AutoModelForCausalLM":
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
if XPU:
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode,
|
||||||
|
device_map="xpu", load_in_4bit=True)
|
||||||
|
else:
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||||
else:
|
else:
|
||||||
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
self.CUDA = False
|
self.CUDA = False
|
||||||
|
self.XPU = False
|
||||||
|
|
||||||
|
if XPU:
|
||||||
|
self.XPU = True
|
||||||
|
try:
|
||||||
|
print("Optimizing model", model_name, "to XPU.", file=sys.stderr)
|
||||||
|
self.model = ipex.optimize_transformers(self.model, inplace=True, dtype=torch.float16, device="xpu")
|
||||||
|
except Exception as err:
|
||||||
|
print("Not using XPU:", err, file=sys.stderr)
|
||||||
|
|
||||||
if request.CUDA or torch.cuda.is_available():
|
if request.CUDA or torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
@@ -139,6 +160,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
inputs = self.tokenizer(request.Prompt, return_tensors="pt").input_ids
|
inputs = self.tokenizer(request.Prompt, return_tensors="pt").input_ids
|
||||||
if self.CUDA:
|
if self.CUDA:
|
||||||
inputs = inputs.to("cuda")
|
inputs = inputs.to("cuda")
|
||||||
|
if XPU:
|
||||||
|
inputs = inputs.to("xpu")
|
||||||
|
|
||||||
outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
|
outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
|
||||||
|
export SKIP_CONDA=1
|
||||||
|
endif
|
||||||
|
|
||||||
.PHONY: ttsvalle
|
.PHONY: ttsvalle
|
||||||
ttsvalle:
|
ttsvalle:
|
||||||
$(MAKE) -C ../common-env/transformers
|
$(MAKE) -C ../common-env/transformers
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -2,13 +2,16 @@
|
|||||||
|
|
||||||
##
|
##
|
||||||
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
||||||
export PATH=$PATH:/opt/conda/bin
|
|
||||||
export SHA=3faaf8ccadb154d63b38070caf518ce9309ea0f4
|
export SHA=3faaf8ccadb154d63b38070caf518ce9309ea0f4
|
||||||
|
|
||||||
# Activate conda environment
|
SKIP_CONDA=${SKIP_CONDA:-0}
|
||||||
source activate transformers
|
|
||||||
|
|
||||||
echo $CONDA_PREFIX
|
if [ $SKIP_CONDA -ne 1 ]; then
|
||||||
|
source activate transformers
|
||||||
|
else
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
CONDA_PREFIX=$PWD
|
||||||
|
fi
|
||||||
|
|
||||||
git clone https://github.com/Plachtaa/VALL-E-X.git $CONDA_PREFIX/vall-e-x && pushd $CONDA_PREFIX/vall-e-x && git checkout -b build $SHA && popd
|
git clone https://github.com/Plachtaa/VALL-E-X.git $CONDA_PREFIX/vall-e-x && pushd $CONDA_PREFIX/vall-e-x && git checkout -b build $SHA && popd
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ dependencies:
|
|||||||
- pypinyin==0.49.0
|
- pypinyin==0.49.0
|
||||||
- python-multipart==0.0.6
|
- python-multipart==0.0.6
|
||||||
- regex==2023.10.3
|
- regex==2023.10.3
|
||||||
- safetensors==0.4.0
|
- safetensors>=0.4.0
|
||||||
- semantic-version==2.10.0
|
- semantic-version==2.10.0
|
||||||
- soundfile==0.12.1
|
- soundfile==0.12.1
|
||||||
- starlette==0.27.0
|
- starlette==0.27.0
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
import asyncio
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
import time
|
|
||||||
import argparse
|
import argparse
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
@@ -10,7 +10,10 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
from vllm import LLM, SamplingParams
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
@@ -79,16 +82,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
Returns:
|
Returns:
|
||||||
backend_pb2.Result: The load model result.
|
backend_pb2.Result: The load model result.
|
||||||
"""
|
"""
|
||||||
|
engine_args = AsyncEngineArgs(
|
||||||
|
model=request.Model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if request.Quantization != "":
|
||||||
|
engine_args.quantization = request.Quantization
|
||||||
|
if request.GPUMemoryUtilization != 0:
|
||||||
|
engine_args.gpu_memory_utilization = request.GPUMemoryUtilization
|
||||||
|
if request.TrustRemoteCode:
|
||||||
|
engine_args.trust_remote_code = request.TrustRemoteCode
|
||||||
|
if request.EnforceEager:
|
||||||
|
engine_args.enforce_eager = request.EnforceEager
|
||||||
|
if request.SwapSpace != 0:
|
||||||
|
engine_args.swap_space = request.SwapSpace
|
||||||
|
if request.MaxModelLen != 0:
|
||||||
|
engine_args.max_model_len = request.MaxModelLen
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if request.Quantization != "":
|
self.llm = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
self.llm = LLM(model=request.Model, quantization=request.Quantization)
|
|
||||||
else:
|
|
||||||
self.llm = LLM(model=request.Model)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||||
|
|
||||||
def Predict(self, request, context):
|
async def Predict(self, request, context):
|
||||||
"""
|
"""
|
||||||
Generates text based on the given prompt and sampling parameters.
|
Generates text based on the given prompt and sampling parameters.
|
||||||
|
|
||||||
@@ -99,24 +116,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
Returns:
|
Returns:
|
||||||
backend_pb2.Reply: The predict result.
|
backend_pb2.Reply: The predict result.
|
||||||
"""
|
"""
|
||||||
if request.TopP == 0:
|
gen = self._predict(request, context, streaming=False)
|
||||||
request.TopP = 0.9
|
res = await gen.__anext__()
|
||||||
|
return res
|
||||||
|
|
||||||
max_tokens = 200
|
async def PredictStream(self, request, context):
|
||||||
if request.Tokens > 0:
|
|
||||||
max_tokens = request.Tokens
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
|
|
||||||
outputs = self.llm.generate([request.Prompt], sampling_params)
|
|
||||||
|
|
||||||
generated_text = outputs[0].outputs[0].text
|
|
||||||
# Remove prompt from response if present
|
|
||||||
if request.Prompt in generated_text:
|
|
||||||
generated_text = generated_text.replace(request.Prompt, "")
|
|
||||||
|
|
||||||
return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
|
||||||
|
|
||||||
def PredictStream(self, request, context):
|
|
||||||
"""
|
"""
|
||||||
Generates text based on the given prompt and sampling parameters, and streams the results.
|
Generates text based on the given prompt and sampling parameters, and streams the results.
|
||||||
|
|
||||||
@@ -127,30 +131,84 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
Returns:
|
Returns:
|
||||||
backend_pb2.Result: The predict stream result.
|
backend_pb2.Result: The predict stream result.
|
||||||
"""
|
"""
|
||||||
yield self.Predict(request, context)
|
iterations = self._predict(request, context, streaming=True)
|
||||||
|
try:
|
||||||
|
async for iteration in iterations:
|
||||||
|
yield iteration
|
||||||
|
finally:
|
||||||
|
await iterations.aclose()
|
||||||
|
|
||||||
def serve(address):
|
async def _predict(self, request, context, streaming=False):
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
|
||||||
|
# Build sampling parameters
|
||||||
|
sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
|
||||||
|
if request.TopP != 0:
|
||||||
|
sampling_params.top_p = request.TopP
|
||||||
|
if request.Tokens > 0:
|
||||||
|
sampling_params.max_tokens = request.Tokens
|
||||||
|
if request.Temperature != 0:
|
||||||
|
sampling_params.temperature = request.Temperature
|
||||||
|
if request.TopK != 0:
|
||||||
|
sampling_params.top_k = request.TopK
|
||||||
|
if request.PresencePenalty != 0:
|
||||||
|
sampling_params.presence_penalty = request.PresencePenalty
|
||||||
|
if request.FrequencyPenalty != 0:
|
||||||
|
sampling_params.frequency_penalty = request.FrequencyPenalty
|
||||||
|
if request.StopPrompts:
|
||||||
|
sampling_params.stop = request.StopPrompts
|
||||||
|
if request.IgnoreEOS:
|
||||||
|
sampling_params.ignore_eos = request.IgnoreEOS
|
||||||
|
if request.Seed != 0:
|
||||||
|
sampling_params.seed = request.Seed
|
||||||
|
|
||||||
|
# Generate text
|
||||||
|
request_id = random_uuid()
|
||||||
|
outputs = self.llm.generate(request.Prompt, sampling_params, request_id)
|
||||||
|
|
||||||
|
# Stream the results
|
||||||
|
generated_text = ""
|
||||||
|
try:
|
||||||
|
async for request_output in outputs:
|
||||||
|
iteration_text = request_output.outputs[0].text
|
||||||
|
|
||||||
|
if streaming:
|
||||||
|
# Remove text already sent as vllm concatenates the text from previous yields
|
||||||
|
delta_iteration_text = iteration_text.removeprefix(generated_text)
|
||||||
|
# Send the partial result
|
||||||
|
yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))
|
||||||
|
|
||||||
|
# Keep track of text generated
|
||||||
|
generated_text = iteration_text
|
||||||
|
finally:
|
||||||
|
await outputs.aclose()
|
||||||
|
|
||||||
|
# If streaming, we already sent everything
|
||||||
|
if streaming:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Sending the final generated text
|
||||||
|
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
||||||
|
|
||||||
|
async def serve(address):
|
||||||
|
# Start asyncio gRPC server
|
||||||
|
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||||
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
|
# Bind the server to the address
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
|
||||||
|
# Gracefully shutdown the server on SIGTERM or SIGINT
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
|
loop.add_signal_handler(
|
||||||
|
sig, lambda: asyncio.ensure_future(server.stop(5))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start the server
|
||||||
|
await server.start()
|
||||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||||
|
# Wait for the server to be terminated
|
||||||
# Define the signal handler function
|
await server.wait_for_termination()
|
||||||
def signal_handler(sig, frame):
|
|
||||||
print("Received termination signal. Shutting down...")
|
|
||||||
server.stop(0)
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Set the signal handlers for SIGINT and SIGTERM
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
server.stop(0)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||||
@@ -159,4 +217,4 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
serve(args.addr)
|
asyncio.run(serve(args.addr))
|
||||||
0
configuration/.keep
Normal file
0
configuration/.keep
Normal file
@@ -3,36 +3,36 @@ package backend
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
|
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||||
if !c.Embeddings {
|
if !backendConfig.Embeddings {
|
||||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
|
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
|
||||||
}
|
}
|
||||||
|
|
||||||
modelFile := c.Model
|
modelFile := backendConfig.Model
|
||||||
|
|
||||||
grpcOpts := gRPCModelOpts(c)
|
grpcOpts := gRPCModelOpts(backendConfig)
|
||||||
|
|
||||||
var inferenceModel interface{}
|
var inferenceModel interface{}
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
opts := modelOpts(c, o, []model.Option{
|
opts := modelOpts(backendConfig, appConfig, []model.Option{
|
||||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||||
model.WithThreads(uint32(c.Threads)),
|
model.WithThreads(uint32(*backendConfig.Threads)),
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
model.WithModel(modelFile),
|
model.WithModel(modelFile),
|
||||||
model.WithContext(o.Context),
|
model.WithContext(appConfig.Context),
|
||||||
})
|
})
|
||||||
|
|
||||||
if c.Backend == "" {
|
if backendConfig.Backend == "" {
|
||||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||||
} else {
|
} else {
|
||||||
opts = append(opts, model.WithBackendString(c.Backend))
|
opts = append(opts, model.WithBackendString(backendConfig.Backend))
|
||||||
inferenceModel, err = loader.BackendLoader(opts...)
|
inferenceModel, err = loader.BackendLoader(opts...)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -43,7 +43,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
|||||||
switch model := inferenceModel.(type) {
|
switch model := inferenceModel.(type) {
|
||||||
case grpc.Backend:
|
case grpc.Backend:
|
||||||
fn = func() ([]float32, error) {
|
fn = func() ([]float32, error) {
|
||||||
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
|
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
|
||||||
if len(tokens) > 0 {
|
if len(tokens) > 0 {
|
||||||
embeds := []int32{}
|
embeds := []int32{}
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
|||||||
}
|
}
|
||||||
predictOptions.EmbeddingTokens = embeds
|
predictOptions.EmbeddingTokens = embeds
|
||||||
|
|
||||||
res, err := model.Embeddings(o.Context, predictOptions)
|
res, err := model.Embeddings(appConfig.Context, predictOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -61,7 +61,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
|||||||
}
|
}
|
||||||
predictOptions.Embeddings = s
|
predictOptions.Embeddings = s
|
||||||
|
|
||||||
res, err := model.Embeddings(o.Context, predictOptions)
|
res, err := model.Embeddings(appConfig.Context, predictOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
52
core/backend/image.go
Normal file
52
core/backend/image.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
|
||||||
|
threads := backendConfig.Threads
|
||||||
|
if *threads == 0 && appConfig.Threads != 0 {
|
||||||
|
threads = &appConfig.Threads
|
||||||
|
}
|
||||||
|
gRPCOpts := gRPCModelOpts(backendConfig)
|
||||||
|
opts := modelOpts(backendConfig, appConfig, []model.Option{
|
||||||
|
model.WithBackendString(backendConfig.Backend),
|
||||||
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
|
model.WithThreads(uint32(*threads)),
|
||||||
|
model.WithContext(appConfig.Context),
|
||||||
|
model.WithModel(backendConfig.Model),
|
||||||
|
model.WithLoadGRPCLoadModelOpts(gRPCOpts),
|
||||||
|
})
|
||||||
|
|
||||||
|
inferenceModel, err := loader.BackendLoader(
|
||||||
|
opts...,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fn := func() error {
|
||||||
|
_, err := inferenceModel.GenerateImage(
|
||||||
|
appConfig.Context,
|
||||||
|
&proto.GenerateImageRequest{
|
||||||
|
Height: int32(height),
|
||||||
|
Width: int32(width),
|
||||||
|
Mode: int32(mode),
|
||||||
|
Step: int32(step),
|
||||||
|
Seed: int32(seed),
|
||||||
|
CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
|
||||||
|
PositivePrompt: positive_prompt,
|
||||||
|
NegativePrompt: negative_prompt,
|
||||||
|
Dst: dst,
|
||||||
|
Src: src,
|
||||||
|
EnableParameters: backendConfig.Diffusers.EnableParameters,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fn, nil
|
||||||
|
}
|
||||||
@@ -8,8 +8,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
@@ -26,9 +26,12 @@ type TokenUsage struct {
|
|||||||
Completion int
|
Completion int
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
|
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
|
||||||
modelFile := c.Model
|
modelFile := c.Model
|
||||||
|
threads := c.Threads
|
||||||
|
if *threads == 0 && o.Threads != 0 {
|
||||||
|
threads = &o.Threads
|
||||||
|
}
|
||||||
grpcOpts := gRPCModelOpts(c)
|
grpcOpts := gRPCModelOpts(c)
|
||||||
|
|
||||||
var inferenceModel grpc.Backend
|
var inferenceModel grpc.Backend
|
||||||
@@ -36,7 +39,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode
|
|||||||
|
|
||||||
opts := modelOpts(c, o, []model.Option{
|
opts := modelOpts(c, o, []model.Option{
|
||||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||||
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
|
model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
model.WithAssetDir(o.AssetsDestination),
|
||||||
model.WithModel(modelFile),
|
model.WithModel(modelFile),
|
||||||
model.WithContext(o.Context),
|
model.WithContext(o.Context),
|
||||||
@@ -140,7 +143,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode
|
|||||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
||||||
var mu sync.Mutex = sync.Mutex{}
|
var mu sync.Mutex = sync.Mutex{}
|
||||||
|
|
||||||
func Finetune(config config.Config, input, prediction string) string {
|
func Finetune(config config.BackendConfig, input, prediction string) string {
|
||||||
if config.Echo {
|
if config.Echo {
|
||||||
prediction = input + prediction
|
prediction = input + prediction
|
||||||
}
|
}
|
||||||
141
core/backend/options.go
Normal file
141
core/backend/options.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
|
||||||
|
if so.SingleBackend {
|
||||||
|
opts = append(opts, model.WithSingleActiveBackend())
|
||||||
|
}
|
||||||
|
|
||||||
|
if so.ParallelBackendRequests {
|
||||||
|
opts = append(opts, model.EnableParallelRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.GRPC.Attempts != 0 {
|
||||||
|
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.GRPC.AttemptsSleepTime != 0 {
|
||||||
|
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range so.ExternalGRPCBackends {
|
||||||
|
opts = append(opts, model.WithExternalBackend(k, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
||||||
|
b := 512
|
||||||
|
if c.Batch != 0 {
|
||||||
|
b = c.Batch
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pb.ModelOptions{
|
||||||
|
CUDA: c.CUDA || c.Diffusers.CUDA,
|
||||||
|
SchedulerType: c.Diffusers.SchedulerType,
|
||||||
|
PipelineType: c.Diffusers.PipelineType,
|
||||||
|
CFGScale: c.Diffusers.CFGScale,
|
||||||
|
LoraAdapter: c.LoraAdapter,
|
||||||
|
LoraScale: c.LoraScale,
|
||||||
|
F16Memory: *c.F16,
|
||||||
|
LoraBase: c.LoraBase,
|
||||||
|
IMG2IMG: c.Diffusers.IMG2IMG,
|
||||||
|
CLIPModel: c.Diffusers.ClipModel,
|
||||||
|
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
||||||
|
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
||||||
|
ControlNet: c.Diffusers.ControlNet,
|
||||||
|
ContextSize: int32(*c.ContextSize),
|
||||||
|
Seed: int32(*c.Seed),
|
||||||
|
NBatch: int32(b),
|
||||||
|
NoMulMatQ: c.NoMulMatQ,
|
||||||
|
DraftModel: c.DraftModel,
|
||||||
|
AudioPath: c.VallE.AudioPath,
|
||||||
|
Quantization: c.Quantization,
|
||||||
|
GPUMemoryUtilization: c.GPUMemoryUtilization,
|
||||||
|
TrustRemoteCode: c.TrustRemoteCode,
|
||||||
|
EnforceEager: c.EnforceEager,
|
||||||
|
SwapSpace: int32(c.SwapSpace),
|
||||||
|
MaxModelLen: int32(c.MaxModelLen),
|
||||||
|
MMProj: c.MMProj,
|
||||||
|
YarnExtFactor: c.YarnExtFactor,
|
||||||
|
YarnAttnFactor: c.YarnAttnFactor,
|
||||||
|
YarnBetaFast: c.YarnBetaFast,
|
||||||
|
YarnBetaSlow: c.YarnBetaSlow,
|
||||||
|
NGQA: c.NGQA,
|
||||||
|
RMSNormEps: c.RMSNormEps,
|
||||||
|
MLock: *c.MMlock,
|
||||||
|
RopeFreqBase: c.RopeFreqBase,
|
||||||
|
RopeScaling: c.RopeScaling,
|
||||||
|
Type: c.ModelType,
|
||||||
|
RopeFreqScale: c.RopeFreqScale,
|
||||||
|
NUMA: c.NUMA,
|
||||||
|
Embeddings: c.Embeddings,
|
||||||
|
LowVRAM: *c.LowVRAM,
|
||||||
|
NGPULayers: int32(*c.NGPULayers),
|
||||||
|
MMap: *c.MMap,
|
||||||
|
MainGPU: c.MainGPU,
|
||||||
|
Threads: int32(*c.Threads),
|
||||||
|
TensorSplit: c.TensorSplit,
|
||||||
|
// AutoGPTQ
|
||||||
|
ModelBaseName: c.AutoGPTQ.ModelBaseName,
|
||||||
|
Device: c.AutoGPTQ.Device,
|
||||||
|
UseTriton: c.AutoGPTQ.Triton,
|
||||||
|
UseFastTokenizer: c.AutoGPTQ.UseFastTokenizer,
|
||||||
|
// RWKV
|
||||||
|
Tokenizer: c.Tokenizer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions {
|
||||||
|
promptCachePath := ""
|
||||||
|
if c.PromptCachePath != "" {
|
||||||
|
p := filepath.Join(modelPath, c.PromptCachePath)
|
||||||
|
os.MkdirAll(filepath.Dir(p), 0755)
|
||||||
|
promptCachePath = p
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pb.PredictOptions{
|
||||||
|
Temperature: float32(*c.Temperature),
|
||||||
|
TopP: float32(*c.TopP),
|
||||||
|
NDraft: c.NDraft,
|
||||||
|
TopK: int32(*c.TopK),
|
||||||
|
Tokens: int32(*c.Maxtokens),
|
||||||
|
Threads: int32(*c.Threads),
|
||||||
|
PromptCacheAll: c.PromptCacheAll,
|
||||||
|
PromptCacheRO: c.PromptCacheRO,
|
||||||
|
PromptCachePath: promptCachePath,
|
||||||
|
F16KV: *c.F16,
|
||||||
|
DebugMode: *c.Debug,
|
||||||
|
Grammar: c.Grammar,
|
||||||
|
NegativePromptScale: c.NegativePromptScale,
|
||||||
|
RopeFreqBase: c.RopeFreqBase,
|
||||||
|
RopeFreqScale: c.RopeFreqScale,
|
||||||
|
NegativePrompt: c.NegativePrompt,
|
||||||
|
Mirostat: int32(*c.LLMConfig.Mirostat),
|
||||||
|
MirostatETA: float32(*c.LLMConfig.MirostatETA),
|
||||||
|
MirostatTAU: float32(*c.LLMConfig.MirostatTAU),
|
||||||
|
Debug: *c.Debug,
|
||||||
|
StopPrompts: c.StopWords,
|
||||||
|
Repeat: int32(c.RepeatPenalty),
|
||||||
|
NKeep: int32(c.Keep),
|
||||||
|
Batch: int32(c.Batch),
|
||||||
|
IgnoreEOS: c.IgnoreEOS,
|
||||||
|
Seed: int32(*c.Seed),
|
||||||
|
FrequencyPenalty: float32(c.FrequencyPenalty),
|
||||||
|
MLock: *c.MMlock,
|
||||||
|
MMap: *c.MMap,
|
||||||
|
MainGPU: c.MainGPU,
|
||||||
|
TensorSplit: c.TensorSplit,
|
||||||
|
TailFreeSamplingZ: float32(c.TFZ),
|
||||||
|
TypicalP: float32(c.TypicalP),
|
||||||
|
}
|
||||||
|
}
|
||||||
38
core/backend/transcript.go
Normal file
38
core/backend/transcript.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) {
|
||||||
|
|
||||||
|
opts := modelOpts(backendConfig, appConfig, []model.Option{
|
||||||
|
model.WithBackendString(model.WhisperBackend),
|
||||||
|
model.WithModel(backendConfig.Model),
|
||||||
|
model.WithContext(appConfig.Context),
|
||||||
|
model.WithThreads(uint32(*backendConfig.Threads)),
|
||||||
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
|
})
|
||||||
|
|
||||||
|
whisperModel, err := ml.BackendLoader(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if whisperModel == nil {
|
||||||
|
return nil, fmt.Errorf("could not load whisper model")
|
||||||
|
}
|
||||||
|
|
||||||
|
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||||
|
Dst: audio,
|
||||||
|
Language: language,
|
||||||
|
Threads: uint32(*backendConfig.Threads),
|
||||||
|
})
|
||||||
|
}
|
||||||
89
core/backend/tts.go
Normal file
89
core/backend/tts.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func generateUniqueFileName(dir, baseName, ext string) string {
|
||||||
|
counter := 1
|
||||||
|
fileName := baseName + ext
|
||||||
|
|
||||||
|
for {
|
||||||
|
filePath := filepath.Join(dir, fileName)
|
||||||
|
_, err := os.Stat(filePath)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return fileName
|
||||||
|
}
|
||||||
|
|
||||||
|
counter++
|
||||||
|
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
|
||||||
|
bb := backend
|
||||||
|
if bb == "" {
|
||||||
|
bb = model.PiperBackend
|
||||||
|
}
|
||||||
|
|
||||||
|
grpcOpts := gRPCModelOpts(backendConfig)
|
||||||
|
|
||||||
|
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
|
||||||
|
model.WithBackendString(bb),
|
||||||
|
model.WithModel(modelFile),
|
||||||
|
model.WithContext(appConfig.Context),
|
||||||
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
|
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||||
|
})
|
||||||
|
ttsModel, err := loader.BackendLoader(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ttsModel == nil {
|
||||||
|
return "", nil, fmt.Errorf("could not load piper model")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(appConfig.AudioDir, 0755); err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
|
||||||
|
filePath := filepath.Join(appConfig.AudioDir, fileName)
|
||||||
|
|
||||||
|
// If the model file is not empty, we pass it joined with the model path
|
||||||
|
modelPath := ""
|
||||||
|
if modelFile != "" {
|
||||||
|
// If the model file is not empty, we pass it joined with the model path
|
||||||
|
// Checking first that it exists and is not outside ModelPath
|
||||||
|
// TODO: we should actually first check if the modelFile is looking like
|
||||||
|
// a FS path
|
||||||
|
mp := filepath.Join(loader.ModelPath, modelFile)
|
||||||
|
if _, err := os.Stat(mp); err == nil {
|
||||||
|
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
modelPath = mp
|
||||||
|
} else {
|
||||||
|
modelPath = modelFile
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
|
||||||
|
Text: text,
|
||||||
|
Model: modelPath,
|
||||||
|
Voice: voice,
|
||||||
|
Dst: filePath,
|
||||||
|
})
|
||||||
|
|
||||||
|
return filePath, res, err
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package options
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -6,27 +6,25 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/metrics"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Option struct {
|
type ApplicationConfig struct {
|
||||||
Context context.Context
|
Context context.Context
|
||||||
ConfigFile string
|
ConfigFile string
|
||||||
Loader *model.ModelLoader
|
ModelPath string
|
||||||
UploadLimitMB, Threads, ContextSize int
|
UploadLimitMB, Threads, ContextSize int
|
||||||
F16 bool
|
F16 bool
|
||||||
Debug, DisableMessage bool
|
Debug, DisableMessage bool
|
||||||
ImageDir string
|
ImageDir string
|
||||||
AudioDir string
|
AudioDir string
|
||||||
|
UploadDir string
|
||||||
CORS bool
|
CORS bool
|
||||||
PreloadJSONModels string
|
PreloadJSONModels string
|
||||||
PreloadModelsFromPath string
|
PreloadModelsFromPath string
|
||||||
CORSAllowOrigins string
|
CORSAllowOrigins string
|
||||||
ApiKeys []string
|
ApiKeys []string
|
||||||
Metrics *metrics.Metrics
|
|
||||||
|
|
||||||
ModelLibraryURL string
|
ModelLibraryURL string
|
||||||
|
|
||||||
@@ -51,10 +49,10 @@ type Option struct {
|
|||||||
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
|
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppOption func(*Option)
|
type AppOption func(*ApplicationConfig)
|
||||||
|
|
||||||
func NewOptions(o ...AppOption) *Option {
|
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||||
opt := &Option{
|
opt := &ApplicationConfig{
|
||||||
Context: context.Background(),
|
Context: context.Background(),
|
||||||
UploadLimitMB: 15,
|
UploadLimitMB: 15,
|
||||||
Threads: 1,
|
Threads: 1,
|
||||||
@@ -69,63 +67,69 @@ func NewOptions(o ...AppOption) *Option {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func WithModelsURL(urls ...string) AppOption {
|
func WithModelsURL(urls ...string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ModelsURL = urls
|
o.ModelsURL = urls
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithModelPath(path string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.ModelPath = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithCors(b bool) AppOption {
|
func WithCors(b bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.CORS = b
|
o.CORS = b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithModelLibraryURL(url string) AppOption {
|
func WithModelLibraryURL(url string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ModelLibraryURL = url
|
o.ModelLibraryURL = url
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableWatchDog = func(o *Option) {
|
var EnableWatchDog = func(o *ApplicationConfig) {
|
||||||
o.WatchDog = true
|
o.WatchDog = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableWatchDogIdleCheck = func(o *Option) {
|
var EnableWatchDogIdleCheck = func(o *ApplicationConfig) {
|
||||||
o.WatchDog = true
|
o.WatchDog = true
|
||||||
o.WatchDogIdle = true
|
o.WatchDogIdle = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableWatchDogBusyCheck = func(o *Option) {
|
var EnableWatchDogBusyCheck = func(o *ApplicationConfig) {
|
||||||
o.WatchDog = true
|
o.WatchDog = true
|
||||||
o.WatchDogBusy = true
|
o.WatchDogBusy = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
|
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.WatchDogBusyTimeout = t
|
o.WatchDogBusyTimeout = t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
|
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.WatchDogIdleTimeout = t
|
o.WatchDogIdleTimeout = t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableSingleBackend = func(o *Option) {
|
var EnableSingleBackend = func(o *ApplicationConfig) {
|
||||||
o.SingleBackend = true
|
o.SingleBackend = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableParallelBackendRequests = func(o *Option) {
|
var EnableParallelBackendRequests = func(o *ApplicationConfig) {
|
||||||
o.ParallelBackendRequests = true
|
o.ParallelBackendRequests = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableGalleriesAutoload = func(o *Option) {
|
var EnableGalleriesAutoload = func(o *ApplicationConfig) {
|
||||||
o.AutoloadGalleries = true
|
o.AutoloadGalleries = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithExternalBackend(name string, uri string) AppOption {
|
func WithExternalBackend(name string, uri string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
if o.ExternalGRPCBackends == nil {
|
if o.ExternalGRPCBackends == nil {
|
||||||
o.ExternalGRPCBackends = make(map[string]string)
|
o.ExternalGRPCBackends = make(map[string]string)
|
||||||
}
|
}
|
||||||
@@ -134,27 +138,26 @@ func WithExternalBackend(name string, uri string) AppOption {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func WithCorsAllowOrigins(b string) AppOption {
|
func WithCorsAllowOrigins(b string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.CORSAllowOrigins = b
|
o.CORSAllowOrigins = b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithBackendAssetsOutput(out string) AppOption {
|
func WithBackendAssetsOutput(out string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.AssetsDestination = out
|
o.AssetsDestination = out
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithBackendAssets(f embed.FS) AppOption {
|
func WithBackendAssets(f embed.FS) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.BackendAssets = f
|
o.BackendAssets = f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithStringGalleries(galls string) AppOption {
|
func WithStringGalleries(galls string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
if galls == "" {
|
if galls == "" {
|
||||||
log.Debug().Msgf("no galleries to load")
|
|
||||||
o.Galleries = []gallery.Gallery{}
|
o.Galleries = []gallery.Gallery{}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -167,96 +170,96 @@ func WithStringGalleries(galls string) AppOption {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.Galleries = append(o.Galleries, galleries...)
|
o.Galleries = append(o.Galleries, galleries...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithContext(ctx context.Context) AppOption {
|
func WithContext(ctx context.Context) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.Context = ctx
|
o.Context = ctx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithYAMLConfigPreload(configFile string) AppOption {
|
func WithYAMLConfigPreload(configFile string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.PreloadModelsFromPath = configFile
|
o.PreloadModelsFromPath = configFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithJSONStringPreload(configFile string) AppOption {
|
func WithJSONStringPreload(configFile string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.PreloadJSONModels = configFile
|
o.PreloadJSONModels = configFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func WithConfigFile(configFile string) AppOption {
|
func WithConfigFile(configFile string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ConfigFile = configFile
|
o.ConfigFile = configFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithModelLoader(loader *model.ModelLoader) AppOption {
|
|
||||||
return func(o *Option) {
|
|
||||||
o.Loader = loader
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithUploadLimitMB(limit int) AppOption {
|
func WithUploadLimitMB(limit int) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.UploadLimitMB = limit
|
o.UploadLimitMB = limit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithThreads(threads int) AppOption {
|
func WithThreads(threads int) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.Threads = threads
|
o.Threads = threads
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithContextSize(ctxSize int) AppOption {
|
func WithContextSize(ctxSize int) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ContextSize = ctxSize
|
o.ContextSize = ctxSize
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithF16(f16 bool) AppOption {
|
func WithF16(f16 bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.F16 = f16
|
o.F16 = f16
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithDebug(debug bool) AppOption {
|
func WithDebug(debug bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.Debug = debug
|
o.Debug = debug
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithDisableMessage(disableMessage bool) AppOption {
|
func WithDisableMessage(disableMessage bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.DisableMessage = disableMessage
|
o.DisableMessage = disableMessage
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithAudioDir(audioDir string) AppOption {
|
func WithAudioDir(audioDir string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.AudioDir = audioDir
|
o.AudioDir = audioDir
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithImageDir(imageDir string) AppOption {
|
func WithImageDir(imageDir string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ImageDir = imageDir
|
o.ImageDir = imageDir
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithUploadDir(uploadDir string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.UploadDir = uploadDir
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithApiKeys(apiKeys []string) AppOption {
|
func WithApiKeys(apiKeys []string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ApiKeys = apiKeys
|
o.ApiKeys = apiKeys
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithMetrics(meter *metrics.Metrics) AppOption {
|
// func WithMetrics(meter *metrics.Metrics) AppOption {
|
||||||
return func(o *Option) {
|
// return func(o *StartupOptions) {
|
||||||
o.Metrics = meter
|
// o.Metrics = meter
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
@@ -1,27 +1,31 @@
|
|||||||
package api_config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/glamour"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type BackendConfig struct {
|
||||||
PredictionOptions `yaml:"parameters"`
|
schema.PredictionOptions `yaml:"parameters"`
|
||||||
Name string `yaml:"name"`
|
Name string `yaml:"name"`
|
||||||
|
|
||||||
F16 bool `yaml:"f16"`
|
F16 *bool `yaml:"f16"`
|
||||||
Threads int `yaml:"threads"`
|
Threads *int `yaml:"threads"`
|
||||||
Debug bool `yaml:"debug"`
|
Debug *bool `yaml:"debug"`
|
||||||
Roles map[string]string `yaml:"roles"`
|
Roles map[string]string `yaml:"roles"`
|
||||||
Embeddings bool `yaml:"embeddings"`
|
Embeddings bool `yaml:"embeddings"`
|
||||||
Backend string `yaml:"backend"`
|
Backend string `yaml:"backend"`
|
||||||
@@ -104,29 +108,34 @@ type LLMConfig struct {
|
|||||||
PromptCachePath string `yaml:"prompt_cache_path"`
|
PromptCachePath string `yaml:"prompt_cache_path"`
|
||||||
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
||||||
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
||||||
MirostatETA float64 `yaml:"mirostat_eta"`
|
MirostatETA *float64 `yaml:"mirostat_eta"`
|
||||||
MirostatTAU float64 `yaml:"mirostat_tau"`
|
MirostatTAU *float64 `yaml:"mirostat_tau"`
|
||||||
Mirostat int `yaml:"mirostat"`
|
Mirostat *int `yaml:"mirostat"`
|
||||||
NGPULayers int `yaml:"gpu_layers"`
|
NGPULayers *int `yaml:"gpu_layers"`
|
||||||
MMap bool `yaml:"mmap"`
|
MMap *bool `yaml:"mmap"`
|
||||||
MMlock bool `yaml:"mmlock"`
|
MMlock *bool `yaml:"mmlock"`
|
||||||
LowVRAM bool `yaml:"low_vram"`
|
LowVRAM *bool `yaml:"low_vram"`
|
||||||
Grammar string `yaml:"grammar"`
|
Grammar string `yaml:"grammar"`
|
||||||
StopWords []string `yaml:"stopwords"`
|
StopWords []string `yaml:"stopwords"`
|
||||||
Cutstrings []string `yaml:"cutstrings"`
|
Cutstrings []string `yaml:"cutstrings"`
|
||||||
TrimSpace []string `yaml:"trimspace"`
|
TrimSpace []string `yaml:"trimspace"`
|
||||||
TrimSuffix []string `yaml:"trimsuffix"`
|
TrimSuffix []string `yaml:"trimsuffix"`
|
||||||
|
|
||||||
ContextSize int `yaml:"context_size"`
|
ContextSize *int `yaml:"context_size"`
|
||||||
NUMA bool `yaml:"numa"`
|
NUMA bool `yaml:"numa"`
|
||||||
LoraAdapter string `yaml:"lora_adapter"`
|
LoraAdapter string `yaml:"lora_adapter"`
|
||||||
LoraBase string `yaml:"lora_base"`
|
LoraBase string `yaml:"lora_base"`
|
||||||
LoraScale float32 `yaml:"lora_scale"`
|
LoraScale float32 `yaml:"lora_scale"`
|
||||||
NoMulMatQ bool `yaml:"no_mulmatq"`
|
NoMulMatQ bool `yaml:"no_mulmatq"`
|
||||||
DraftModel string `yaml:"draft_model"`
|
DraftModel string `yaml:"draft_model"`
|
||||||
NDraft int32 `yaml:"n_draft"`
|
NDraft int32 `yaml:"n_draft"`
|
||||||
Quantization string `yaml:"quantization"`
|
Quantization string `yaml:"quantization"`
|
||||||
MMProj string `yaml:"mmproj"`
|
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
|
||||||
|
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
|
||||||
|
EnforceEager bool `yaml:"enforce_eager"` // vLLM
|
||||||
|
SwapSpace int `yaml:"swap_space"` // vLLM
|
||||||
|
MaxModelLen int `yaml:"max_model_len"` // vLLM
|
||||||
|
MMProj string `yaml:"mmproj"`
|
||||||
|
|
||||||
RopeScaling string `yaml:"rope_scaling"`
|
RopeScaling string `yaml:"rope_scaling"`
|
||||||
ModelType string `yaml:"type"`
|
ModelType string `yaml:"type"`
|
||||||
@@ -148,6 +157,7 @@ type Functions struct {
|
|||||||
DisableNoAction bool `yaml:"disable_no_action"`
|
DisableNoAction bool `yaml:"disable_no_action"`
|
||||||
NoActionFunctionName string `yaml:"no_action_function_name"`
|
NoActionFunctionName string `yaml:"no_action_function_name"`
|
||||||
NoActionDescriptionName string `yaml:"no_action_description_name"`
|
NoActionDescriptionName string `yaml:"no_action_description_name"`
|
||||||
|
ParallelCalls bool `yaml:"parallel_calls"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TemplateConfig struct {
|
type TemplateConfig struct {
|
||||||
@@ -158,108 +168,209 @@ type TemplateConfig struct {
|
|||||||
Functions string `yaml:"function"`
|
Functions string `yaml:"function"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConfigLoader struct {
|
func (c *BackendConfig) SetFunctionCallString(s string) {
|
||||||
configs map[string]Config
|
|
||||||
sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) SetFunctionCallString(s string) {
|
|
||||||
c.functionCallString = s
|
c.functionCallString = s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) SetFunctionCallNameString(s string) {
|
func (c *BackendConfig) SetFunctionCallNameString(s string) {
|
||||||
c.functionCallNameString = s
|
c.functionCallNameString = s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) ShouldUseFunctions() bool {
|
func (c *BackendConfig) ShouldUseFunctions() bool {
|
||||||
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
|
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) ShouldCallSpecificFunction() bool {
|
func (c *BackendConfig) ShouldCallSpecificFunction() bool {
|
||||||
return len(c.functionCallNameString) > 0
|
return len(c.functionCallNameString) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) FunctionToCall() string {
|
func (c *BackendConfig) FunctionToCall() string {
|
||||||
return c.functionCallNameString
|
return c.functionCallNameString
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load a config file for a model
|
func (cfg *BackendConfig) SetDefaults(debug bool, threads, ctx int, f16 bool) {
|
||||||
func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ctx int, f16 bool) (*Config, error) {
|
defaultTopP := 0.7
|
||||||
// Load a config file if present after the model name
|
defaultTopK := 80
|
||||||
modelConfig := filepath.Join(modelPath, modelName+".yaml")
|
defaultTemp := 0.9
|
||||||
|
defaultMaxTokens := 2048
|
||||||
|
defaultMirostat := 2
|
||||||
|
defaultMirostatTAU := 5.0
|
||||||
|
defaultMirostatETA := 0.1
|
||||||
|
|
||||||
var cfg *Config
|
// Try to offload all GPU layers (if GPU is found)
|
||||||
|
defaultNGPULayers := 99999999
|
||||||
|
|
||||||
defaults := func() {
|
trueV := true
|
||||||
cfg = DefaultConfig(modelName)
|
falseV := false
|
||||||
cfg.ContextSize = ctx
|
|
||||||
cfg.Threads = threads
|
if cfg.Seed == nil {
|
||||||
cfg.F16 = f16
|
// random number generator seed
|
||||||
cfg.Debug = debug
|
defaultSeed := int(rand.Int31())
|
||||||
|
cfg.Seed = &defaultSeed
|
||||||
}
|
}
|
||||||
|
|
||||||
cfgExisting, exists := cm.GetConfig(modelName)
|
if cfg.TopK == nil {
|
||||||
if !exists {
|
cfg.TopK = &defaultTopK
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.MMap == nil {
|
||||||
|
// MMap is enabled by default
|
||||||
|
cfg.MMap = &trueV
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.MMlock == nil {
|
||||||
|
// MMlock is disabled by default
|
||||||
|
cfg.MMlock = &falseV
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.TopP == nil {
|
||||||
|
cfg.TopP = &defaultTopP
|
||||||
|
}
|
||||||
|
if cfg.Temperature == nil {
|
||||||
|
cfg.Temperature = &defaultTemp
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Maxtokens == nil {
|
||||||
|
cfg.Maxtokens = &defaultMaxTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Mirostat == nil {
|
||||||
|
cfg.Mirostat = &defaultMirostat
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.MirostatETA == nil {
|
||||||
|
cfg.MirostatETA = &defaultMirostatETA
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.MirostatTAU == nil {
|
||||||
|
cfg.MirostatTAU = &defaultMirostatTAU
|
||||||
|
}
|
||||||
|
if cfg.NGPULayers == nil {
|
||||||
|
cfg.NGPULayers = &defaultNGPULayers
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.LowVRAM == nil {
|
||||||
|
cfg.LowVRAM = &falseV
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value passed by the top level are treated as default (no implicit defaults)
|
||||||
|
// defaults are set by the user
|
||||||
|
if ctx == 0 {
|
||||||
|
ctx = 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.ContextSize == nil {
|
||||||
|
cfg.ContextSize = &ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
if threads == 0 {
|
||||||
|
// Threads can't be 0
|
||||||
|
threads = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Threads == nil {
|
||||||
|
cfg.Threads = &threads
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.F16 == nil {
|
||||||
|
cfg.F16 = &f16
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
cfg.Debug = &debug
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////// Config Loader ////////
|
||||||
|
|
||||||
|
type BackendConfigLoader struct {
|
||||||
|
configs map[string]BackendConfig
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoadOptions struct {
|
||||||
|
debug bool
|
||||||
|
threads, ctxSize int
|
||||||
|
f16 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOptionDebug(debug bool) ConfigLoaderOption {
|
||||||
|
return func(o *LoadOptions) {
|
||||||
|
o.debug = debug
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOptionThreads(threads int) ConfigLoaderOption {
|
||||||
|
return func(o *LoadOptions) {
|
||||||
|
o.threads = threads
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOptionContextSize(ctxSize int) ConfigLoaderOption {
|
||||||
|
return func(o *LoadOptions) {
|
||||||
|
o.ctxSize = ctxSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOptionF16(f16 bool) ConfigLoaderOption {
|
||||||
|
return func(o *LoadOptions) {
|
||||||
|
o.f16 = f16
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConfigLoaderOption func(*LoadOptions)
|
||||||
|
|
||||||
|
func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
|
||||||
|
for _, l := range options {
|
||||||
|
l(lo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load a config file for a model
|
||||||
|
func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
|
||||||
|
|
||||||
|
lo := &LoadOptions{}
|
||||||
|
lo.Apply(opts...)
|
||||||
|
|
||||||
|
// Load a config file if present after the model name
|
||||||
|
cfg := &BackendConfig{
|
||||||
|
PredictionOptions: schema.PredictionOptions{
|
||||||
|
Model: modelName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfgExisting, exists := cl.GetBackendConfig(modelName)
|
||||||
|
if exists {
|
||||||
|
cfg = &cfgExisting
|
||||||
|
} else {
|
||||||
|
// Try loading a model config file
|
||||||
|
modelConfig := filepath.Join(modelPath, modelName+".yaml")
|
||||||
if _, err := os.Stat(modelConfig); err == nil {
|
if _, err := os.Stat(modelConfig); err == nil {
|
||||||
if err := cm.LoadConfig(modelConfig); err != nil {
|
if err := cl.LoadBackendConfig(modelConfig); err != nil {
|
||||||
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||||
}
|
}
|
||||||
cfgExisting, exists = cm.GetConfig(modelName)
|
cfgExisting, exists = cl.GetBackendConfig(modelName)
|
||||||
if exists {
|
if exists {
|
||||||
cfg = &cfgExisting
|
cfg = &cfgExisting
|
||||||
} else {
|
|
||||||
defaults()
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
defaults()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cfg = &cfgExisting
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the parameters for the language model prediction
|
|
||||||
//updateConfig(cfg, input)
|
|
||||||
|
|
||||||
// Don't allow 0 as setting
|
|
||||||
if cfg.Threads == 0 {
|
|
||||||
if threads != 0 {
|
|
||||||
cfg.Threads = threads
|
|
||||||
} else {
|
|
||||||
cfg.Threads = 4
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce debug flag if passed from CLI
|
cfg.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16)
|
||||||
if debug {
|
|
||||||
cfg.Debug = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultPredictOptions(modelFile string) PredictionOptions {
|
func NewBackendConfigLoader() *BackendConfigLoader {
|
||||||
return PredictionOptions{
|
return &BackendConfigLoader{
|
||||||
TopP: 0.7,
|
configs: make(map[string]BackendConfig),
|
||||||
TopK: 80,
|
|
||||||
Maxtokens: 512,
|
|
||||||
Temperature: 0.9,
|
|
||||||
Model: modelFile,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
|
||||||
|
lo := &LoadOptions{}
|
||||||
|
lo.Apply(opts...)
|
||||||
|
|
||||||
func DefaultConfig(modelFile string) *Config {
|
c := &[]*BackendConfig{}
|
||||||
return &Config{
|
|
||||||
PredictionOptions: defaultPredictOptions(modelFile),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConfigLoader() *ConfigLoader {
|
|
||||||
return &ConfigLoader{
|
|
||||||
configs: make(map[string]Config),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func ReadConfigFile(file string) ([]*Config, error) {
|
|
||||||
c := &[]*Config{}
|
|
||||||
f, err := os.ReadFile(file)
|
f, err := os.ReadFile(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||||
@@ -268,11 +379,18 @@ func ReadConfigFile(file string) ([]*Config, error) {
|
|||||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, cc := range *c {
|
||||||
|
cc.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16)
|
||||||
|
}
|
||||||
|
|
||||||
return *c, nil
|
return *c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadConfig(file string) (*Config, error) {
|
func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
|
||||||
c := &Config{}
|
lo := &LoadOptions{}
|
||||||
|
lo.Apply(opts...)
|
||||||
|
|
||||||
|
c := &BackendConfig{}
|
||||||
f, err := os.ReadFile(file)
|
f, err := os.ReadFile(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||||
@@ -281,13 +399,14 @@ func ReadConfig(file string) (*Config, error) {
|
|||||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16)
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) LoadConfigFile(file string) error {
|
func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
|
||||||
cm.Lock()
|
cm.Lock()
|
||||||
defer cm.Unlock()
|
defer cm.Unlock()
|
||||||
c, err := ReadConfigFile(file)
|
c, err := ReadBackendConfigFile(file, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot load config file: %w", err)
|
return fmt.Errorf("cannot load config file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -298,49 +417,49 @@ func (cm *ConfigLoader) LoadConfigFile(file string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) LoadConfig(file string) error {
|
func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
|
||||||
cm.Lock()
|
cl.Lock()
|
||||||
defer cm.Unlock()
|
defer cl.Unlock()
|
||||||
c, err := ReadConfig(file)
|
c, err := ReadBackendConfig(file, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot read config file: %w", err)
|
return fmt.Errorf("cannot read config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.configs[c.Name] = *c
|
cl.configs[c.Name] = *c
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
|
func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
|
||||||
cm.Lock()
|
cl.Lock()
|
||||||
defer cm.Unlock()
|
defer cl.Unlock()
|
||||||
v, exists := cm.configs[m]
|
v, exists := cl.configs[m]
|
||||||
return v, exists
|
return v, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) GetAllConfigs() []Config {
|
func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
|
||||||
cm.Lock()
|
cl.Lock()
|
||||||
defer cm.Unlock()
|
defer cl.Unlock()
|
||||||
var res []Config
|
var res []BackendConfig
|
||||||
for _, v := range cm.configs {
|
for _, v := range cl.configs {
|
||||||
res = append(res, v)
|
res = append(res, v)
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) ListConfigs() []string {
|
func (cl *BackendConfigLoader) ListBackendConfigs() []string {
|
||||||
cm.Lock()
|
cl.Lock()
|
||||||
defer cm.Unlock()
|
defer cl.Unlock()
|
||||||
var res []string
|
var res []string
|
||||||
for k := range cm.configs {
|
for k := range cl.configs {
|
||||||
res = append(res, k)
|
res = append(res, k)
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
// Preload prepare models if they are not local but url or huggingface repositories
|
// Preload prepare models if they are not local but url or huggingface repositories
|
||||||
func (cm *ConfigLoader) Preload(modelPath string) error {
|
func (cl *BackendConfigLoader) Preload(modelPath string) error {
|
||||||
cm.Lock()
|
cl.Lock()
|
||||||
defer cm.Unlock()
|
defer cl.Unlock()
|
||||||
|
|
||||||
status := func(fileName, current, total string, percent float64) {
|
status := func(fileName, current, total string, percent float64) {
|
||||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||||
@@ -348,7 +467,21 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
|
|||||||
|
|
||||||
log.Info().Msgf("Preloading models from %s", modelPath)
|
log.Info().Msgf("Preloading models from %s", modelPath)
|
||||||
|
|
||||||
for i, config := range cm.configs {
|
renderMode := "dark"
|
||||||
|
if os.Getenv("COLOR") != "" {
|
||||||
|
renderMode = os.Getenv("COLOR")
|
||||||
|
}
|
||||||
|
|
||||||
|
glamText := func(t string) {
|
||||||
|
out, err := glamour.Render(t, renderMode)
|
||||||
|
if err == nil && os.Getenv("NO_COLOR") == "" {
|
||||||
|
fmt.Println(out)
|
||||||
|
} else {
|
||||||
|
fmt.Println(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, config := range cl.configs {
|
||||||
|
|
||||||
// Download files and verify their SHA
|
// Download files and verify their SHA
|
||||||
for _, file := range config.DownloadFiles {
|
for _, file := range config.DownloadFiles {
|
||||||
@@ -380,25 +513,29 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cc := cm.configs[i]
|
cc := cl.configs[i]
|
||||||
c := &cc
|
c := &cc
|
||||||
c.PredictionOptions.Model = md5Name
|
c.PredictionOptions.Model = md5Name
|
||||||
cm.configs[i] = *c
|
cl.configs[i] = *c
|
||||||
}
|
}
|
||||||
if cm.configs[i].Name != "" {
|
if cl.configs[i].Name != "" {
|
||||||
log.Info().Msgf("Model name: %s", cm.configs[i].Name)
|
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
|
||||||
}
|
}
|
||||||
if cm.configs[i].Description != "" {
|
if cl.configs[i].Description != "" {
|
||||||
log.Info().Msgf("Model description: %s", cm.configs[i].Description)
|
//glamText("**Description**")
|
||||||
|
glamText(cl.configs[i].Description)
|
||||||
}
|
}
|
||||||
if cm.configs[i].Usage != "" {
|
if cl.configs[i].Usage != "" {
|
||||||
log.Info().Msgf("Model usage: \n%s", cm.configs[i].Usage)
|
//glamText("**Usage**")
|
||||||
|
glamText(cl.configs[i].Usage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) LoadConfigs(path string) error {
|
// LoadBackendConfigsFromPath reads all the configurations of the models from a path
|
||||||
|
// (non-recursive)
|
||||||
|
func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
|
||||||
cm.Lock()
|
cm.Lock()
|
||||||
defer cm.Unlock()
|
defer cm.Unlock()
|
||||||
entries, err := os.ReadDir(path)
|
entries, err := os.ReadDir(path)
|
||||||
@@ -418,7 +555,7 @@ func (cm *ConfigLoader) LoadConfigs(path string) error {
|
|||||||
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
|
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
c, err := ReadConfig(filepath.Join(path, file.Name()))
|
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cm.configs[c.Name] = *c
|
cm.configs[c.Name] = *c
|
||||||
}
|
}
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
package api_config_test
|
package config_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
. "github.com/go-skynet/LocalAI/api/config"
|
. "github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
@@ -19,7 +18,7 @@ var _ = Describe("Test cases for config related functions", func() {
|
|||||||
Context("Test Read configuration functions", func() {
|
Context("Test Read configuration functions", func() {
|
||||||
configFile = os.Getenv("CONFIG_FILE")
|
configFile = os.Getenv("CONFIG_FILE")
|
||||||
It("Test ReadConfigFile", func() {
|
It("Test ReadConfigFile", func() {
|
||||||
config, err := ReadConfigFile(configFile)
|
config, err := ReadBackendConfigFile(configFile)
|
||||||
Expect(err).To(BeNil())
|
Expect(err).To(BeNil())
|
||||||
Expect(config).ToNot(BeNil())
|
Expect(config).ToNot(BeNil())
|
||||||
// two configs in config.yaml
|
// two configs in config.yaml
|
||||||
@@ -28,29 +27,26 @@ var _ = Describe("Test cases for config related functions", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("Test LoadConfigs", func() {
|
It("Test LoadConfigs", func() {
|
||||||
cm := NewConfigLoader()
|
cm := NewBackendConfigLoader()
|
||||||
opts := options.NewOptions()
|
opts := NewApplicationConfig()
|
||||||
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
err := cm.LoadBackendConfigsFromPath(opts.ModelPath)
|
||||||
options.WithModelLoader(modelLoader)(opts)
|
|
||||||
|
|
||||||
err := cm.LoadConfigs(opts.Loader.ModelPath)
|
|
||||||
Expect(err).To(BeNil())
|
Expect(err).To(BeNil())
|
||||||
Expect(cm.ListConfigs()).ToNot(BeNil())
|
Expect(cm.ListBackendConfigs()).ToNot(BeNil())
|
||||||
|
|
||||||
// config should includes gpt4all models's api.config
|
// config should includes gpt4all models's api.config
|
||||||
Expect(cm.ListConfigs()).To(ContainElements("gpt4all"))
|
Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all"))
|
||||||
|
|
||||||
// config should includes gpt2 models's api.config
|
// config should includes gpt2 models's api.config
|
||||||
Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2"))
|
Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all-2"))
|
||||||
|
|
||||||
// config should includes text-embedding-ada-002 models's api.config
|
// config should includes text-embedding-ada-002 models's api.config
|
||||||
Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002"))
|
Expect(cm.ListBackendConfigs()).To(ContainElements("text-embedding-ada-002"))
|
||||||
|
|
||||||
// config should includes rwkv_test models's api.config
|
// config should includes rwkv_test models's api.config
|
||||||
Expect(cm.ListConfigs()).To(ContainElements("rwkv_test"))
|
Expect(cm.ListBackendConfigs()).To(ContainElements("rwkv_test"))
|
||||||
|
|
||||||
// config should includes whisper-1 models's api.config
|
// config should includes whisper-1 models's api.config
|
||||||
Expect(cm.ListConfigs()).To(ContainElements("whisper-1"))
|
Expect(cm.ListBackendConfigs()).To(ContainElements("whisper-1"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
242
core/http/api.go
Normal file
242
core/http/api.go
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
|
||||||
|
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
|
||||||
|
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/go-skynet/LocalAI/core/services"
|
||||||
|
"github.com/go-skynet/LocalAI/internal"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||||
|
)
|
||||||
|
|
||||||
|
func readAuthHeader(c *fiber.Ctx) string {
|
||||||
|
authHeader := c.Get("Authorization")
|
||||||
|
|
||||||
|
// elevenlabs
|
||||||
|
xApiKey := c.Get("xi-api-key")
|
||||||
|
if xApiKey != "" {
|
||||||
|
authHeader = "Bearer " + xApiKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// anthropic
|
||||||
|
xApiKey = c.Get("x-api-key")
|
||||||
|
if xApiKey != "" {
|
||||||
|
authHeader = "Bearer " + xApiKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return authHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
|
||||||
|
// Return errors as JSON responses
|
||||||
|
app := fiber.New(fiber.Config{
|
||||||
|
BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||||
|
DisableStartupMessage: appConfig.DisableMessage,
|
||||||
|
// Override default error handler
|
||||||
|
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
||||||
|
// Status code defaults to 500
|
||||||
|
code := fiber.StatusInternalServerError
|
||||||
|
|
||||||
|
// Retrieve the custom status code if it's a *fiber.Error
|
||||||
|
var e *fiber.Error
|
||||||
|
if errors.As(err, &e) {
|
||||||
|
code = e.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send custom error page
|
||||||
|
return ctx.Status(code).JSON(
|
||||||
|
schema.ErrorResponse{
|
||||||
|
Error: &schema.APIError{Message: err.Error(), Code: code},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if appConfig.Debug {
|
||||||
|
app.Use(logger.New(logger.Config{
|
||||||
|
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default middleware config
|
||||||
|
|
||||||
|
if !appConfig.Debug {
|
||||||
|
app.Use(recover.New())
|
||||||
|
}
|
||||||
|
|
||||||
|
metricsService, err := services.NewLocalAIMetricsService()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if metricsService != nil {
|
||||||
|
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||||
|
app.Hooks().OnShutdown(func() error {
|
||||||
|
return metricsService.Shutdown()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
|
||||||
|
auth := func(c *fiber.Ctx) error {
|
||||||
|
if len(appConfig.ApiKeys) == 0 {
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for api_keys.json file
|
||||||
|
fileContent, err := os.ReadFile("api_keys.json")
|
||||||
|
if err == nil {
|
||||||
|
// Parse JSON content from the file
|
||||||
|
var fileKeys []string
|
||||||
|
err := json.Unmarshal(fileContent, &fileKeys)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add file keys to options.ApiKeys
|
||||||
|
appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(appConfig.ApiKeys) == 0 {
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
authHeader := readAuthHeader(c)
|
||||||
|
if authHeader == "" {
|
||||||
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's a bearer token
|
||||||
|
authHeaderParts := strings.Split(authHeader, " ")
|
||||||
|
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
|
||||||
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := authHeaderParts[1]
|
||||||
|
for _, key := range appConfig.ApiKeys {
|
||||||
|
if apiKey == key {
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if appConfig.CORS {
|
||||||
|
var c func(ctx *fiber.Ctx) error
|
||||||
|
if appConfig.CORSAllowOrigins == "" {
|
||||||
|
c = cors.New()
|
||||||
|
} else {
|
||||||
|
c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
|
||||||
|
}
|
||||||
|
|
||||||
|
app.Use(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAI API endpoints
|
||||||
|
galleryService := services.NewGalleryService(appConfig.ModelPath)
|
||||||
|
galleryService.Start(appConfig.Context, cl)
|
||||||
|
|
||||||
|
app.Get("/version", auth, func(c *fiber.Ctx) error {
|
||||||
|
return c.JSON(struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
}{Version: internal.PrintableVersion()})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Load upload json
|
||||||
|
openai.LoadUploadConfig(appConfig.UploadDir)
|
||||||
|
|
||||||
|
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
|
||||||
|
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||||
|
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
||||||
|
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||||
|
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
|
||||||
|
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
|
||||||
|
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||||
|
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||||
|
|
||||||
|
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
// Elevenlabs
|
||||||
|
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
// openAI compatible API endpoint
|
||||||
|
|
||||||
|
// chat
|
||||||
|
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
// edit
|
||||||
|
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
// files
|
||||||
|
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
||||||
|
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
||||||
|
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
|
||||||
|
app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
|
||||||
|
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
|
||||||
|
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
|
||||||
|
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
|
||||||
|
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
|
||||||
|
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
|
||||||
|
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
|
||||||
|
|
||||||
|
// completion
|
||||||
|
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
// embeddings
|
||||||
|
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
// audio
|
||||||
|
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
// images
|
||||||
|
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
if appConfig.ImageDir != "" {
|
||||||
|
app.Static("/generated-images", appConfig.ImageDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
if appConfig.AudioDir != "" {
|
||||||
|
app.Static("/generated-audio", appConfig.AudioDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := func(c *fiber.Ctx) error {
|
||||||
|
return c.SendStatus(200)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kubernetes health checks
|
||||||
|
app.Get("/healthz", ok)
|
||||||
|
app.Get("/readyz", ok)
|
||||||
|
|
||||||
|
// Experimental Backend Statistics Module
|
||||||
|
backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now
|
||||||
|
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
|
||||||
|
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
|
||||||
|
|
||||||
|
// models
|
||||||
|
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
|
||||||
|
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
|
||||||
|
|
||||||
|
app.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
||||||
|
|
||||||
|
return app, nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package api_test
|
package http_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -13,9 +13,10 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
. "github.com/go-skynet/LocalAI/api"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
. "github.com/go-skynet/LocalAI/core/http"
|
||||||
"github.com/go-skynet/LocalAI/metrics"
|
"github.com/go-skynet/LocalAI/core/startup"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/go-skynet/LocalAI/pkg/model"
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
@@ -127,25 +128,33 @@ var backendAssets embed.FS
|
|||||||
var _ = Describe("API test", func() {
|
var _ = Describe("API test", func() {
|
||||||
|
|
||||||
var app *fiber.App
|
var app *fiber.App
|
||||||
var modelLoader *model.ModelLoader
|
|
||||||
var client *openai.Client
|
var client *openai.Client
|
||||||
var client2 *openaigo.Client
|
var client2 *openaigo.Client
|
||||||
var c context.Context
|
var c context.Context
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
var tmpdir string
|
var tmpdir string
|
||||||
|
var modelDir string
|
||||||
|
var bcl *config.BackendConfigLoader
|
||||||
|
var ml *model.ModelLoader
|
||||||
|
var applicationConfig *config.ApplicationConfig
|
||||||
|
|
||||||
commonOpts := []options.AppOption{
|
commonOpts := []config.AppOption{
|
||||||
options.WithDebug(true),
|
config.WithDebug(true),
|
||||||
options.WithDisableMessage(true),
|
config.WithDisableMessage(true),
|
||||||
}
|
}
|
||||||
|
|
||||||
Context("API with ephemeral models", func() {
|
Context("API with ephemeral models", func() {
|
||||||
BeforeEach(func() {
|
|
||||||
|
BeforeEach(func(sc SpecContext) {
|
||||||
var err error
|
var err error
|
||||||
tmpdir, err = os.MkdirTemp("", "")
|
tmpdir, err = os.MkdirTemp("", "")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
modelLoader = model.NewModelLoader(tmpdir)
|
modelDir = filepath.Join(tmpdir, "models")
|
||||||
|
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
|
||||||
|
err = os.Mkdir(backendAssetsDir, 0755)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
g := []gallery.GalleryModel{
|
g := []gallery.GalleryModel{
|
||||||
@@ -172,16 +181,18 @@ var _ = Describe("API test", func() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
metricsService, err := metrics.SetupMetrics()
|
bcl, ml, applicationConfig, err = startup.Startup(
|
||||||
|
append(commonOpts,
|
||||||
|
config.WithContext(c),
|
||||||
|
config.WithGalleries(galleries),
|
||||||
|
config.WithModelPath(modelDir),
|
||||||
|
config.WithBackendAssets(backendAssets),
|
||||||
|
config.WithBackendAssetsOutput(backendAssetsDir))...)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
app, err = App(
|
app, err = App(bcl, ml, applicationConfig)
|
||||||
append(commonOpts,
|
|
||||||
options.WithMetrics(metricsService),
|
|
||||||
options.WithContext(c),
|
|
||||||
options.WithGalleries(galleries),
|
|
||||||
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
@@ -198,15 +209,21 @@ var _ = Describe("API test", func() {
|
|||||||
}, "2m").ShouldNot(HaveOccurred())
|
}, "2m").ShouldNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func(sc SpecContext) {
|
||||||
cancel()
|
cancel()
|
||||||
app.Shutdown()
|
if app != nil {
|
||||||
os.RemoveAll(tmpdir)
|
err := app.Shutdown()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
err := os.RemoveAll(tmpdir)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = os.ReadDir(tmpdir)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("Applying models", func() {
|
Context("Applying models", func() {
|
||||||
It("applies models from a gallery", func() {
|
|
||||||
|
|
||||||
|
It("applies models from a gallery", func() {
|
||||||
models := getModels("http://127.0.0.1:9090/models/available")
|
models := getModels("http://127.0.0.1:9090/models/available")
|
||||||
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
|
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
|
||||||
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
|
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
|
||||||
@@ -228,10 +245,10 @@ var _ = Describe("API test", func() {
|
|||||||
}, "360s", "10s").Should(Equal(true))
|
}, "360s", "10s").Should(Equal(true))
|
||||||
Expect(resp["message"]).ToNot(ContainSubstring("error"))
|
Expect(resp["message"]).ToNot(ContainSubstring("error"))
|
||||||
|
|
||||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert2.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml"))
|
_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]interface{}{}
|
||||||
@@ -253,6 +270,7 @@ var _ = Describe("API test", func() {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
It("overrides models", func() {
|
It("overrides models", func() {
|
||||||
|
|
||||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||||
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
||||||
Name: "bert",
|
Name: "bert",
|
||||||
@@ -270,7 +288,7 @@ var _ = Describe("API test", func() {
|
|||||||
return response["processed"].(bool)
|
return response["processed"].(bool)
|
||||||
}, "360s", "10s").Should(Equal(true))
|
}, "360s", "10s").Should(Equal(true))
|
||||||
|
|
||||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]interface{}{}
|
||||||
@@ -294,7 +312,7 @@ var _ = Describe("API test", func() {
|
|||||||
return response["processed"].(bool)
|
return response["processed"].(bool)
|
||||||
}, "360s", "10s").Should(Equal(true))
|
}, "360s", "10s").Should(Equal(true))
|
||||||
|
|
||||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]interface{}{}
|
||||||
@@ -368,7 +386,7 @@ var _ = Describe("API test", func() {
|
|||||||
var res map[string]string
|
var res map[string]string
|
||||||
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
|
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res))
|
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
|
||||||
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
|
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
|
||||||
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
|
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
|
||||||
|
|
||||||
@@ -483,8 +501,11 @@ var _ = Describe("API test", func() {
|
|||||||
var err error
|
var err error
|
||||||
tmpdir, err = os.MkdirTemp("", "")
|
tmpdir, err = os.MkdirTemp("", "")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
modelDir = filepath.Join(tmpdir, "models")
|
||||||
|
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
|
||||||
|
err = os.Mkdir(backendAssetsDir, 0755)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
modelLoader = model.NewModelLoader(tmpdir)
|
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
galleries := []gallery.Gallery{
|
galleries := []gallery.Gallery{
|
||||||
@@ -494,21 +515,20 @@ var _ = Describe("API test", func() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
metricsService, err := metrics.SetupMetrics()
|
bcl, ml, applicationConfig, err = startup.Startup(
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
app, err = App(
|
|
||||||
append(commonOpts,
|
append(commonOpts,
|
||||||
options.WithContext(c),
|
config.WithContext(c),
|
||||||
options.WithMetrics(metricsService),
|
config.WithAudioDir(tmpdir),
|
||||||
options.WithAudioDir(tmpdir),
|
config.WithImageDir(tmpdir),
|
||||||
options.WithImageDir(tmpdir),
|
config.WithGalleries(galleries),
|
||||||
options.WithGalleries(galleries),
|
config.WithModelPath(modelDir),
|
||||||
options.WithModelLoader(modelLoader),
|
config.WithBackendAssets(backendAssets),
|
||||||
options.WithBackendAssets(backendAssets),
|
config.WithBackendAssetsOutput(tmpdir))...,
|
||||||
options.WithBackendAssetsOutput(tmpdir))...,
|
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
app, err = App(bcl, ml, applicationConfig)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
@@ -527,8 +547,14 @@ var _ = Describe("API test", func() {
|
|||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
cancel()
|
cancel()
|
||||||
app.Shutdown()
|
if app != nil {
|
||||||
os.RemoveAll(tmpdir)
|
err := app.Shutdown()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
err := os.RemoveAll(tmpdir)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = os.ReadDir(tmpdir)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
It("installs and is capable to run tts", Label("tts"), func() {
|
It("installs and is capable to run tts", Label("tts"), func() {
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
@@ -599,20 +625,20 @@ var _ = Describe("API test", func() {
|
|||||||
|
|
||||||
Context("API query", func() {
|
Context("API query", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
modelPath := os.Getenv("MODELS_PATH")
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
metricsService, err := metrics.SetupMetrics()
|
var err error
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
app, err = App(
|
bcl, ml, applicationConfig, err = startup.Startup(
|
||||||
append(commonOpts,
|
append(commonOpts,
|
||||||
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
||||||
options.WithContext(c),
|
config.WithContext(c),
|
||||||
options.WithModelLoader(modelLoader),
|
config.WithModelPath(modelPath),
|
||||||
options.WithMetrics(metricsService),
|
|
||||||
)...)
|
)...)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
app, err = App(bcl, ml, applicationConfig)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
@@ -630,7 +656,10 @@ var _ = Describe("API test", func() {
|
|||||||
})
|
})
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
cancel()
|
cancel()
|
||||||
app.Shutdown()
|
if app != nil {
|
||||||
|
err := app.Shutdown()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
})
|
})
|
||||||
It("returns the models list", func() {
|
It("returns the models list", func() {
|
||||||
models, err := client.ListModels(context.TODO())
|
models, err := client.ListModels(context.TODO())
|
||||||
@@ -811,20 +840,20 @@ var _ = Describe("API test", func() {
|
|||||||
|
|
||||||
Context("Config file", func() {
|
Context("Config file", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
modelPath := os.Getenv("MODELS_PATH")
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
metricsService, err := metrics.SetupMetrics()
|
var err error
|
||||||
Expect(err).ToNot(HaveOccurred())
|
bcl, ml, applicationConfig, err = startup.Startup(
|
||||||
|
|
||||||
app, err = App(
|
|
||||||
append(commonOpts,
|
append(commonOpts,
|
||||||
options.WithContext(c),
|
config.WithContext(c),
|
||||||
options.WithMetrics(metricsService),
|
config.WithModelPath(modelPath),
|
||||||
options.WithModelLoader(modelLoader),
|
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||||
options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
app, err = App(bcl, ml, applicationConfig)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
@@ -840,7 +869,10 @@ var _ = Describe("API test", func() {
|
|||||||
})
|
})
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
cancel()
|
cancel()
|
||||||
app.Shutdown()
|
if app != nil {
|
||||||
|
err := app.Shutdown()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
})
|
})
|
||||||
It("can generate chat completions from config file (list1)", func() {
|
It("can generate chat completions from config file (list1)", func() {
|
||||||
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
|
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package api_test
|
package http_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
55
core/http/endpoints/elevenlabs/tts.go
Normal file
55
core/http/endpoints/elevenlabs/tts.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package elevenlabs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
|
input := new(schema.ElevenLabsTTSRequest)
|
||||||
|
voiceID := c.Params("voice-id")
|
||||||
|
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||||
|
config.LoadOptionDebug(appConfig.Debug),
|
||||||
|
config.LoadOptionThreads(appConfig.Threads),
|
||||||
|
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||||
|
config.LoadOptionF16(appConfig.F16),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
|
||||||
|
} else {
|
||||||
|
if input.ModelID != "" {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
} else {
|
||||||
|
modelFile = cfg.Model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||||
|
|
||||||
|
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Download(filePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
36
core/http/endpoints/localai/backend_monitor.go
Normal file
36
core/http/endpoints/localai/backend_monitor.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/go-skynet/LocalAI/core/services"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
|
input := new(schema.BackendMonitorRequest)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := bm.CheckAndSample(input.Model)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.JSON(resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
input := new(schema.BackendMonitorRequest)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return bm.ShutdownModel(input.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
146
core/http/endpoints/localai/gallery.go
Normal file
146
core/http/endpoints/localai/gallery.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/services"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelGalleryEndpointService struct {
|
||||||
|
galleries []gallery.Gallery
|
||||||
|
modelPath string
|
||||||
|
galleryApplier *services.GalleryService
|
||||||
|
}
|
||||||
|
|
||||||
|
type GalleryModel struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
gallery.GalleryModel
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||||
|
return ModelGalleryEndpointService{
|
||||||
|
galleries: galleries,
|
||||||
|
modelPath: modelPath,
|
||||||
|
galleryApplier: galleryApplier,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
|
||||||
|
if status == nil {
|
||||||
|
return fmt.Errorf("could not find any status for ID")
|
||||||
|
}
|
||||||
|
return c.JSON(status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
return c.JSON(mgs.galleryApplier.GetAllStatus())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
input := new(GalleryModel)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
uuid, err := uuid.NewUUID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mgs.galleryApplier.C <- gallery.GalleryOp{
|
||||||
|
Req: input.GalleryModel,
|
||||||
|
Id: uuid.String(),
|
||||||
|
GalleryName: input.ID,
|
||||||
|
Galleries: mgs.galleries,
|
||||||
|
}
|
||||||
|
return c.JSON(struct {
|
||||||
|
ID string `json:"uuid"`
|
||||||
|
StatusURL string `json:"status"`
|
||||||
|
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
|
||||||
|
|
||||||
|
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Models found from galleries: %+v", models)
|
||||||
|
for _, m := range models {
|
||||||
|
log.Debug().Msgf("Model found from galleries: %+v", m)
|
||||||
|
}
|
||||||
|
dat, err := json.Marshal(models)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Send(dat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||||
|
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
||||||
|
dat, err := json.Marshal(mgs.galleries)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Send(dat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
input := new(gallery.Gallery)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||||
|
return gallery.Name == input.Name
|
||||||
|
}) {
|
||||||
|
return fmt.Errorf("%s already exists", input.Name)
|
||||||
|
}
|
||||||
|
dat, err := json.Marshal(mgs.galleries)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Adding %+v to gallery list", *input)
|
||||||
|
mgs.galleries = append(mgs.galleries, *input)
|
||||||
|
return c.Send(dat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
input := new(gallery.Gallery)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||||
|
return gallery.Name == input.Name
|
||||||
|
}) {
|
||||||
|
return fmt.Errorf("%s is not currently registered", input.Name)
|
||||||
|
}
|
||||||
|
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||||
|
return gallery.Name == input.Name
|
||||||
|
})
|
||||||
|
return c.Send(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
43
core/http/endpoints/localai/metrics.go
Normal file
43
core/http/endpoints/localai/metrics.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/services"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/adaptor"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LocalAIMetricsEndpoint() fiber.Handler {
|
||||||
|
|
||||||
|
return adaptor.HTTPHandler(promhttp.Handler())
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiMiddlewareConfig struct {
|
||||||
|
Filter func(c *fiber.Ctx) bool
|
||||||
|
metricsService *services.LocalAIMetricsService
|
||||||
|
}
|
||||||
|
|
||||||
|
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
|
||||||
|
cfg := apiMiddlewareConfig{
|
||||||
|
metricsService: metrics,
|
||||||
|
Filter: func(c *fiber.Ctx) bool {
|
||||||
|
return c.Path() == "/metrics"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
if cfg.Filter != nil && cfg.Filter(c) {
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
path := c.Path()
|
||||||
|
method := c.Method()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
err := c.Next()
|
||||||
|
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||||
|
cfg.metricsService.ObserveAPICall(method, path, elapsed)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
55
core/http/endpoints/localai/tts.go
Normal file
55
core/http/endpoints/localai/tts.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
|
input := new(schema.TTSRequest)
|
||||||
|
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.Model
|
||||||
|
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||||
|
config.LoadOptionDebug(appConfig.Debug),
|
||||||
|
config.LoadOptionThreads(appConfig.Threads),
|
||||||
|
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||||
|
config.LoadOptionF16(appConfig.F16),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.Model
|
||||||
|
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||||
|
} else {
|
||||||
|
modelFile = cfg.Model
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||||
|
|
||||||
|
if input.Backend != "" {
|
||||||
|
cfg.Backend = input.Backend
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Download(filePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
609
core/http/endpoints/openai/chat.go
Normal file
609
core/http/endpoints/openai/chat.go
Normal file
@@ -0,0 +1,609 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
emptyMessage := ""
|
||||||
|
id := uuid.New().String()
|
||||||
|
created := int(time.Now().Unix())
|
||||||
|
|
||||||
|
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||||
|
initialMessage := schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
responses <- initialMessage
|
||||||
|
|
||||||
|
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||||
|
resp := schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Usage: schema.OpenAIUsage{
|
||||||
|
PromptTokens: usage.Prompt,
|
||||||
|
CompletionTokens: usage.Completion,
|
||||||
|
TotalTokens: usage.Prompt + usage.Completion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
responses <- resp
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
close(responses)
|
||||||
|
}
|
||||||
|
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||||
|
result := ""
|
||||||
|
_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||||
|
result += s
|
||||||
|
// TODO: Change generated BNF grammar to be compliant with the schema so we can
|
||||||
|
// stream the result token by token here.
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
|
||||||
|
noActionToRun := len(results) > 0 && results[0].name == noAction
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case noActionToRun:
|
||||||
|
initialMessage := schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
responses <- initialMessage
|
||||||
|
|
||||||
|
result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("error handling question: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Usage: schema.OpenAIUsage{
|
||||||
|
PromptTokens: tokenUsage.Prompt,
|
||||||
|
CompletionTokens: tokenUsage.Completion,
|
||||||
|
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
responses <- resp
|
||||||
|
|
||||||
|
default:
|
||||||
|
for i, ss := range results {
|
||||||
|
name, args := ss.name, ss.arguments
|
||||||
|
|
||||||
|
initialMessage := schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []schema.Choice{{
|
||||||
|
Delta: &schema.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []schema.ToolCall{
|
||||||
|
{
|
||||||
|
Index: i,
|
||||||
|
ID: id,
|
||||||
|
Type: "function",
|
||||||
|
FunctionCall: schema.FunctionCall{
|
||||||
|
Name: name,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
responses <- initialMessage
|
||||||
|
|
||||||
|
responses <- schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []schema.Choice{{
|
||||||
|
Delta: &schema.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []schema.ToolCall{
|
||||||
|
{
|
||||||
|
Index: i,
|
||||||
|
ID: id,
|
||||||
|
Type: "function",
|
||||||
|
FunctionCall: schema.FunctionCall{
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
close(responses)
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
processFunctions := false
|
||||||
|
funcs := grammar.Functions{}
|
||||||
|
modelFile, input, err := readRequest(c, ml, startupOptions, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Configuration read: %+v", config)
|
||||||
|
|
||||||
|
// Allow the user to set custom actions via config file
|
||||||
|
// to be "embedded" in each model
|
||||||
|
noActionName := "answer"
|
||||||
|
noActionDescription := "use this action to answer without performing any action"
|
||||||
|
|
||||||
|
if config.FunctionsConfig.NoActionFunctionName != "" {
|
||||||
|
noActionName = config.FunctionsConfig.NoActionFunctionName
|
||||||
|
}
|
||||||
|
if config.FunctionsConfig.NoActionDescriptionName != "" {
|
||||||
|
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.ResponseFormat.Type == "json_object" {
|
||||||
|
input.Grammar = grammar.JSONBNF
|
||||||
|
}
|
||||||
|
|
||||||
|
// process functions if we have any defined or if we have a function call string
|
||||||
|
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
|
||||||
|
log.Debug().Msgf("Response needs to process functions")
|
||||||
|
|
||||||
|
processFunctions = true
|
||||||
|
|
||||||
|
noActionGrammar := grammar.Function{
|
||||||
|
Name: noActionName,
|
||||||
|
Description: noActionDescription,
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"message": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to reply the user with",
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append the no action function
|
||||||
|
funcs = append(funcs, input.Functions...)
|
||||||
|
if !config.FunctionsConfig.DisableNoAction {
|
||||||
|
funcs = append(funcs, noActionGrammar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force picking one of the functions by the request
|
||||||
|
if config.FunctionToCall() != "" {
|
||||||
|
funcs = funcs.Select(config.FunctionToCall())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update input grammar
|
||||||
|
jsStruct := funcs.ToJSONStructure()
|
||||||
|
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||||
|
} else if input.JSONFunctionGrammarObject != nil {
|
||||||
|
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
// functions are not supported in stream mode (yet?)
|
||||||
|
toStream := input.Stream
|
||||||
|
|
||||||
|
log.Debug().Msgf("Parameters: %+v", config)
|
||||||
|
|
||||||
|
var predInput string
|
||||||
|
|
||||||
|
suppressConfigSystemPrompt := false
|
||||||
|
mess := []string{}
|
||||||
|
for messageIndex, i := range input.Messages {
|
||||||
|
var content string
|
||||||
|
role := i.Role
|
||||||
|
|
||||||
|
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||||
|
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||||
|
if i.FunctionCall != nil && i.Role == "assistant" {
|
||||||
|
roleFn := "assistant_function_call"
|
||||||
|
r := config.Roles[roleFn]
|
||||||
|
if r != "" {
|
||||||
|
role = roleFn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r := config.Roles[role]
|
||||||
|
contentExists := i.Content != nil && i.StringContent != ""
|
||||||
|
|
||||||
|
// First attempt to populate content via a chat message specific template
|
||||||
|
if config.TemplateConfig.ChatMessage != "" {
|
||||||
|
chatMessageData := model.ChatMessageTemplateData{
|
||||||
|
SystemPrompt: config.SystemPrompt,
|
||||||
|
Role: r,
|
||||||
|
RoleName: role,
|
||||||
|
Content: i.StringContent,
|
||||||
|
FunctionName: i.Name,
|
||||||
|
MessageIndex: messageIndex,
|
||||||
|
}
|
||||||
|
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
||||||
|
} else {
|
||||||
|
if templatedChatMessage == "" {
|
||||||
|
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||||||
|
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||||
|
content = templatedChatMessage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
||||||
|
if content == "" {
|
||||||
|
if r != "" {
|
||||||
|
if contentExists {
|
||||||
|
content = fmt.Sprint(r, i.StringContent)
|
||||||
|
}
|
||||||
|
if i.FunctionCall != nil {
|
||||||
|
j, err := json.Marshal(i.FunctionCall)
|
||||||
|
if err == nil {
|
||||||
|
if contentExists {
|
||||||
|
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||||
|
} else {
|
||||||
|
content = fmt.Sprint(r, " ", string(j))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if contentExists {
|
||||||
|
content = fmt.Sprint(i.StringContent)
|
||||||
|
}
|
||||||
|
if i.FunctionCall != nil {
|
||||||
|
j, err := json.Marshal(i.FunctionCall)
|
||||||
|
if err == nil {
|
||||||
|
if contentExists {
|
||||||
|
content += "\n" + string(j)
|
||||||
|
} else {
|
||||||
|
content = string(j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
||||||
|
if contentExists && role == "system" {
|
||||||
|
suppressConfigSystemPrompt = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mess = append(mess, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
predInput = strings.Join(mess, "\n")
|
||||||
|
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||||||
|
|
||||||
|
if toStream {
|
||||||
|
log.Debug().Msgf("Stream request received")
|
||||||
|
c.Context().SetContentType("text/event-stream")
|
||||||
|
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||||
|
// c.Set("Content-Type", "text/event-stream")
|
||||||
|
c.Set("Cache-Control", "no-cache")
|
||||||
|
c.Set("Connection", "keep-alive")
|
||||||
|
c.Set("Transfer-Encoding", "chunked")
|
||||||
|
}
|
||||||
|
|
||||||
|
templateFile := ""
|
||||||
|
|
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||||
|
templateFile = config.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.TemplateConfig.Chat != "" && !processFunctions {
|
||||||
|
templateFile = config.TemplateConfig.Chat
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.TemplateConfig.Functions != "" && processFunctions {
|
||||||
|
templateFile = config.TemplateConfig.Functions
|
||||||
|
}
|
||||||
|
|
||||||
|
if templateFile != "" {
|
||||||
|
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
|
SystemPrompt: config.SystemPrompt,
|
||||||
|
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||||
|
Input: predInput,
|
||||||
|
Functions: funcs,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
predInput = templatedInput
|
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||||
|
} else {
|
||||||
|
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||||
|
if processFunctions {
|
||||||
|
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case toStream:
|
||||||
|
responses := make(chan schema.OpenAIResponse)
|
||||||
|
|
||||||
|
if !processFunctions {
|
||||||
|
go process(predInput, input, config, ml, responses)
|
||||||
|
} else {
|
||||||
|
go processTools(noActionName, predInput, input, config, ml, responses)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||||
|
usage := &schema.OpenAIUsage{}
|
||||||
|
toolsCalled := false
|
||||||
|
for ev := range responses {
|
||||||
|
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
||||||
|
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
|
||||||
|
toolsCalled = true
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := json.NewEncoder(&buf)
|
||||||
|
enc.Encode(ev)
|
||||||
|
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||||
|
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||||
|
if err != nil {
|
||||||
|
log.Debug().Msgf("Sending chunk failed: %v", err)
|
||||||
|
input.Cancel()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
w.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
if toolsCalled {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
} else if toolsCalled && len(input.Tools) == 0 {
|
||||||
|
finishReason = "function_call"
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []schema.Choice{
|
||||||
|
{
|
||||||
|
FinishReason: finishReason,
|
||||||
|
Index: 0,
|
||||||
|
Delta: &schema.Message{Content: &emptyMessage},
|
||||||
|
}},
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Usage: *usage,
|
||||||
|
}
|
||||||
|
respData, _ := json.Marshal(resp)
|
||||||
|
|
||||||
|
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||||
|
w.WriteString("data: [DONE]\n\n")
|
||||||
|
w.Flush()
|
||||||
|
}))
|
||||||
|
return nil
|
||||||
|
|
||||||
|
// no streaming mode
|
||||||
|
default:
|
||||||
|
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
|
||||||
|
if !processFunctions {
|
||||||
|
// no function is called, just reply and use stop as finish reason
|
||||||
|
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
|
||||||
|
noActionsToRun := len(results) > 0 && results[0].name == noActionName
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case noActionsToRun:
|
||||||
|
result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("error handling question: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*c = append(*c, schema.Choice{
|
||||||
|
Message: &schema.Message{Role: "assistant", Content: &result}})
|
||||||
|
default:
|
||||||
|
toolChoice := schema.Choice{
|
||||||
|
Message: &schema.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(input.Tools) > 0 {
|
||||||
|
toolChoice.FinishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ss := range results {
|
||||||
|
name, args := ss.name, ss.arguments
|
||||||
|
if len(input.Tools) > 0 {
|
||||||
|
// If we are using tools, we condense the function calls into
|
||||||
|
// a single response choice with all the tools
|
||||||
|
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
|
||||||
|
schema.ToolCall{
|
||||||
|
ID: id,
|
||||||
|
Type: "function",
|
||||||
|
FunctionCall: schema.FunctionCall{
|
||||||
|
Name: name,
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// otherwise we return more choices directly
|
||||||
|
*c = append(*c, schema.Choice{
|
||||||
|
FinishReason: "function_call",
|
||||||
|
Message: &schema.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
FunctionCall: map[string]interface{}{
|
||||||
|
"name": name,
|
||||||
|
"arguments": args,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(input.Tools) > 0 {
|
||||||
|
// we need to append our result if we are using tools
|
||||||
|
*c = append(*c, toolChoice)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &schema.OpenAIResponse{
|
||||||
|
ID: id,
|
||||||
|
Created: created,
|
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: result,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Usage: schema.OpenAIUsage{
|
||||||
|
PromptTokens: tokenUsage.Prompt,
|
||||||
|
CompletionTokens: tokenUsage.Completion,
|
||||||
|
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
respData, _ := json.Marshal(resp)
|
||||||
|
log.Debug().Msgf("Response: %s", respData)
|
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
|
||||||
|
log.Debug().Msgf("nothing to do, computing a reply")
|
||||||
|
|
||||||
|
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||||
|
arguments := map[string]interface{}{}
|
||||||
|
json.Unmarshal([]byte(args), &arguments)
|
||||||
|
m, exists := arguments["message"]
|
||||||
|
if exists {
|
||||||
|
switch message := m.(type) {
|
||||||
|
case string:
|
||||||
|
if message != "" {
|
||||||
|
log.Debug().Msgf("Reply received from LLM: %s", message)
|
||||||
|
message = backend.Finetune(*config, prompt, message)
|
||||||
|
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
|
||||||
|
|
||||||
|
return message, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
|
||||||
|
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||||
|
// Note: This costs (in term of CPU/GPU) another computation
|
||||||
|
config.Grammar = ""
|
||||||
|
images := []string{}
|
||||||
|
for _, m := range input.Messages {
|
||||||
|
images = append(images, m.StringImages...)
|
||||||
|
}
|
||||||
|
|
||||||
|
predFunc, err := backend.ModelInference(input.Context, prompt, images, ml, *config, o, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("inference error: %s", err.Error())
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
prediction, err := predFunc()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("inference error: %s", err.Error())
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return backend.Finetune(*config, prompt, prediction.Response), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type funcCallResults struct {
|
||||||
|
name string
|
||||||
|
arguments string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
|
||||||
|
results := []funcCallResults{}
|
||||||
|
|
||||||
|
// TODO: use generics to avoid this code duplication
|
||||||
|
if multipleResults {
|
||||||
|
ss := []map[string]interface{}{}
|
||||||
|
s := utils.EscapeNewLines(llmresult)
|
||||||
|
json.Unmarshal([]byte(s), &ss)
|
||||||
|
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||||
|
|
||||||
|
for _, s := range ss {
|
||||||
|
func_name, ok := s["function"]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
args, ok := s["arguments"]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
d, _ := json.Marshal(args)
|
||||||
|
funcName, ok := func_name.(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
|
||||||
|
ss := map[string]interface{}{}
|
||||||
|
// This prevent newlines to break JSON parsing for clients
|
||||||
|
s := utils.EscapeNewLines(llmresult)
|
||||||
|
json.Unmarshal([]byte(s), &ss)
|
||||||
|
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||||
|
|
||||||
|
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||||
|
func_name, ok := ss["function"]
|
||||||
|
if !ok {
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||||
|
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||||
|
if !ok {
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
d, _ := json.Marshal(args)
|
||||||
|
funcName, ok := func_name.(string)
|
||||||
|
if !ok {
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
@@ -8,10 +8,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
@@ -21,12 +21,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/completions
|
// https://platform.openai.com/docs/api-reference/completions
|
||||||
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
created := int(time.Now().Unix())
|
created := int(time.Now().Unix())
|
||||||
|
|
||||||
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||||
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||||
resp := schema.OpenAIResponse{
|
resp := schema.OpenAIResponse{
|
||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
@@ -53,14 +53,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
}
|
}
|
||||||
|
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
modelFile, input, err := readRequest(c, o, true)
|
modelFile, input, err := readRequest(c, ml, appConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msgf("`input`: %+v", input)
|
log.Debug().Msgf("`input`: %+v", input)
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -84,7 +84,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
templateFile := ""
|
templateFile := ""
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||||
templateFile = config.Model
|
templateFile = config.Model
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
predInput := config.PromptStrings[0]
|
predInput := config.PromptStrings[0]
|
||||||
|
|
||||||
if templateFile != "" {
|
if templateFile != "" {
|
||||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
Input: predInput,
|
Input: predInput,
|
||||||
})
|
})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -111,7 +111,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
|
|
||||||
responses := make(chan schema.OpenAIResponse)
|
responses := make(chan schema.OpenAIResponse)
|
||||||
|
|
||||||
go process(predInput, input, config, o.Loader, responses)
|
go process(predInput, input, config, ml, responses)
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||||
|
|
||||||
@@ -153,7 +153,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
for k, i := range config.PromptStrings {
|
for k, i := range config.PromptStrings {
|
||||||
if templateFile != "" {
|
if templateFile != "" {
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
SystemPrompt: config.SystemPrompt,
|
SystemPrompt: config.SystemPrompt,
|
||||||
Input: i,
|
Input: i,
|
||||||
})
|
})
|
||||||
@@ -164,7 +164,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
}
|
}
|
||||||
|
|
||||||
r, tokenUsage, err := ComputeChoices(
|
r, tokenUsage, err := ComputeChoices(
|
||||||
input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
|
input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||||
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
|
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -5,10 +5,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -16,14 +16,14 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
modelFile, input, err := readRequest(c, o, true)
|
modelFile, input, err := readRequest(c, ml, appConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -33,7 +33,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
templateFile := ""
|
templateFile := ""
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||||
templateFile = config.Model
|
templateFile = config.Model
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
|
|
||||||
for _, i := range config.InputStrings {
|
for _, i := range config.InputStrings {
|
||||||
if templateFile != "" {
|
if templateFile != "" {
|
||||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
Input: i,
|
Input: i,
|
||||||
Instruction: input.Instruction,
|
Instruction: input.Instruction,
|
||||||
SystemPrompt: config.SystemPrompt,
|
SystemPrompt: config.SystemPrompt,
|
||||||
@@ -57,7 +57,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
|
r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||||
*c = append(*c, schema.Choice{Text: s})
|
*c = append(*c, schema.Choice{Text: s})
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -5,25 +5,26 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/embeddings
|
// https://platform.openai.com/docs/api-reference/embeddings
|
||||||
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
model, input, err := readRequest(c, o, true)
|
model, input, err := readRequest(c, ml, appConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -33,7 +34,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
|
|
||||||
for i, s := range config.InputToken {
|
for i, s := range config.InputToken {
|
||||||
// get the model function to call for the result
|
// get the model function to call for the result
|
||||||
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o)
|
embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -47,7 +48,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
|
|
||||||
for i, s := range config.InputStrings {
|
for i, s := range config.InputStrings {
|
||||||
// get the model function to call for the result
|
// get the model function to call for the result
|
||||||
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o)
|
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
218
core/http/endpoints/openai/files.go
Normal file
218
core/http/endpoints/openai/files.go
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
var uploadedFiles []File
|
||||||
|
|
||||||
|
const uploadedFilesFile = "uploadedFiles.json"
|
||||||
|
|
||||||
|
// File represents the structure of a file object from the OpenAI API.
|
||||||
|
type File struct {
|
||||||
|
ID string `json:"id"` // Unique identifier for the file
|
||||||
|
Object string `json:"object"` // Type of the object (e.g., "file")
|
||||||
|
Bytes int `json:"bytes"` // Size of the file in bytes
|
||||||
|
CreatedAt time.Time `json:"created_at"` // The time at which the file was created
|
||||||
|
Filename string `json:"filename"` // The name of the file
|
||||||
|
Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.)
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveUploadConfig(uploadDir string) {
|
||||||
|
file, err := json.MarshalIndent(uploadedFiles, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.WriteFile(filepath.Join(uploadDir, uploadedFilesFile), file, 0644)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to save uploadedFiles to file: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadUploadConfig(uploadPath string) {
|
||||||
|
uploadFilePath := filepath.Join(uploadPath, uploadedFilesFile)
|
||||||
|
|
||||||
|
_, err := os.Stat(uploadFilePath)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
log.Debug().Msgf("No uploadedFiles file found at %s", uploadFilePath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.ReadFile(uploadFilePath)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to read file: %s", err)
|
||||||
|
} else {
|
||||||
|
err = json.Unmarshal(file, &uploadedFiles)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
|
||||||
|
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
file, err := c.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the file size
|
||||||
|
if file.Size > int64(appConfig.UploadLimitMB*1024*1024) {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, appConfig.UploadLimitMB))
|
||||||
|
}
|
||||||
|
|
||||||
|
purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
|
||||||
|
if purpose == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize the filename to prevent directory traversal
|
||||||
|
filename := utils.SanitizeFileName(file.Filename)
|
||||||
|
|
||||||
|
savePath := filepath.Join(appConfig.UploadDir, filename)
|
||||||
|
|
||||||
|
// Check if file already exists
|
||||||
|
if _, err := os.Stat(savePath); !os.IsNotExist(err) {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("File already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.SaveFile(file, savePath)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
f := File{
|
||||||
|
ID: fmt.Sprintf("file-%d", time.Now().Unix()),
|
||||||
|
Object: "file",
|
||||||
|
Bytes: int(file.Size),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
Filename: file.Filename,
|
||||||
|
Purpose: purpose,
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadedFiles = append(uploadedFiles, f)
|
||||||
|
saveUploadConfig(appConfig.UploadDir)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
|
||||||
|
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
type ListFiles struct {
|
||||||
|
Data []File
|
||||||
|
Object string
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
var listFiles ListFiles
|
||||||
|
|
||||||
|
purpose := c.Query("purpose")
|
||||||
|
if purpose == "" {
|
||||||
|
listFiles.Data = uploadedFiles
|
||||||
|
} else {
|
||||||
|
for _, f := range uploadedFiles {
|
||||||
|
if purpose == f.Purpose {
|
||||||
|
listFiles.Data = append(listFiles.Data, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
listFiles.Object = "list"
|
||||||
|
return c.Status(fiber.StatusOK).JSON(listFiles)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFileFromRequest(c *fiber.Ctx) (*File, error) {
|
||||||
|
id := c.Params("file_id")
|
||||||
|
if id == "" {
|
||||||
|
return nil, fmt.Errorf("file_id parameter is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range uploadedFiles {
|
||||||
|
if id == f.ID {
|
||||||
|
return &f, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to find file id %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
|
||||||
|
func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
file, err := getFileFromRequest(c)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
|
||||||
|
func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
type DeleteStatus struct {
|
||||||
|
Id string
|
||||||
|
Object string
|
||||||
|
Deleted bool
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
file, err := getFileFromRequest(c)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
|
||||||
|
if err != nil {
|
||||||
|
// If the file doesn't exist then we should just continue to remove it
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove upload from list
|
||||||
|
for i, f := range uploadedFiles {
|
||||||
|
if f.ID == file.ID {
|
||||||
|
uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
saveUploadConfig(appConfig.UploadDir)
|
||||||
|
return c.JSON(DeleteStatus{
|
||||||
|
Id: file.ID,
|
||||||
|
Object: "file",
|
||||||
|
Deleted: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
|
||||||
|
func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
file, err := getFileFromRequest(c)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Send(fileContents)
|
||||||
|
}
|
||||||
|
}
|
||||||
287
core/http/endpoints/openai/files_test.go
Normal file
287
core/http/endpoints/openai/files_test.go
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
|
||||||
|
utils2 "github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ListFiles struct {
|
||||||
|
Data []File
|
||||||
|
Object string
|
||||||
|
}
|
||||||
|
|
||||||
|
func startUpApp() (app *fiber.App, option *config.ApplicationConfig, loader *config.BackendConfigLoader) {
|
||||||
|
// Preparing the mocked objects
|
||||||
|
loader = &config.BackendConfigLoader{}
|
||||||
|
|
||||||
|
option = &config.ApplicationConfig{
|
||||||
|
UploadLimitMB: 10,
|
||||||
|
UploadDir: "test_dir",
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.RemoveAll(option.UploadDir)
|
||||||
|
|
||||||
|
app = fiber.New(fiber.Config{
|
||||||
|
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a Test Server
|
||||||
|
app.Post("/files", UploadFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files", ListFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
|
||||||
|
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadFileExceedSizeLimit(t *testing.T) {
|
||||||
|
// Preparing the mocked objects
|
||||||
|
loader := &config.BackendConfigLoader{}
|
||||||
|
|
||||||
|
option := &config.ApplicationConfig{
|
||||||
|
UploadLimitMB: 10,
|
||||||
|
UploadDir: "test_dir",
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.RemoveAll(option.UploadDir)
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{
|
||||||
|
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a Test Server
|
||||||
|
app.Post("/files", UploadFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files", ListFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
|
||||||
|
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
|
||||||
|
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
|
||||||
|
|
||||||
|
t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) {
|
||||||
|
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||||
|
assert.Contains(t, bodyToString(resp, t), "exceeds upload limit")
|
||||||
|
})
|
||||||
|
t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) {
|
||||||
|
resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option)
|
||||||
|
|
||||||
|
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||||
|
assert.Contains(t, bodyToString(resp, t), "Purpose is not defined")
|
||||||
|
})
|
||||||
|
t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) {
|
||||||
|
f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
||||||
|
|
||||||
|
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
||||||
|
fmt.Println(f1)
|
||||||
|
fmt.Printf("ERror: %v", err)
|
||||||
|
|
||||||
|
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||||
|
assert.Contains(t, bodyToString(resp, t), "File already exists")
|
||||||
|
})
|
||||||
|
t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) {
|
||||||
|
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
||||||
|
|
||||||
|
// Check if file exists in the disk
|
||||||
|
filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName("test.txt"))
|
||||||
|
_, err := os.Stat(filePath)
|
||||||
|
|
||||||
|
assert.False(t, os.IsNotExist(err))
|
||||||
|
assert.Equal(t, file.Bytes, 5242880)
|
||||||
|
assert.NotEmpty(t, file.CreatedAt)
|
||||||
|
assert.Equal(t, file.Filename, "test.txt")
|
||||||
|
assert.Equal(t, file.Purpose, "fine-tune")
|
||||||
|
})
|
||||||
|
t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) {
|
||||||
|
resp, err := CallListFilesEndpoint(t, app, "")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
listFiles := responseToListFile(t, resp)
|
||||||
|
if len(listFiles.Data) != len(uploadedFiles) {
|
||||||
|
t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) {
|
||||||
|
_ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
||||||
|
|
||||||
|
resp, err := CallListFilesEndpoint(t, app, "fine-tune")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
listFiles := responseToListFile(t, resp)
|
||||||
|
if len(listFiles.Data) != 1 {
|
||||||
|
t.Errorf("Expected 1 file, got %v files", len(listFiles.Data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) {
|
||||||
|
resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
listFiles := responseToListFile(t, resp)
|
||||||
|
|
||||||
|
if len(listFiles.Data) != 0 {
|
||||||
|
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/files", nil)
|
||||||
|
resp, _ := app.Test(req)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
var listFiles ListFiles
|
||||||
|
if err := json.Unmarshal(bodyToByteArray(resp, t), &listFiles); err != nil {
|
||||||
|
t.Errorf("Failed to decode response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(listFiles.Data) != 0 {
|
||||||
|
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallListFilesEndpoint(t *testing.T, app *fiber.App, purpose string) (*http.Response, error) {
|
||||||
|
var target string
|
||||||
|
if purpose != "" {
|
||||||
|
target = fmt.Sprintf("/files?purpose=%s", purpose)
|
||||||
|
} else {
|
||||||
|
target = "/files"
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest("GET", target, nil)
|
||||||
|
return app.Test(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
|
||||||
|
request := httptest.NewRequest("GET", "/files?file_id="+fileId, nil)
|
||||||
|
return app.Test(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
|
||||||
|
// Create a file that exceeds the limit
|
||||||
|
file := createTestFile(t, fileName, fileSize, appConfig)
|
||||||
|
|
||||||
|
// Creating a new HTTP Request
|
||||||
|
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/files", body)
|
||||||
|
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
|
||||||
|
return app.Test(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File {
|
||||||
|
// Create a file that exceeds the limit
|
||||||
|
file := createTestFile(t, fileName, fileSize, appConfig)
|
||||||
|
|
||||||
|
// Creating a new HTTP Request
|
||||||
|
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/files", body)
|
||||||
|
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
f := responseToFile(t, resp)
|
||||||
|
|
||||||
|
id := f.ID
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_, err := CallFilesDeleteEndpoint(t, app, id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallFilesDeleteEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
|
||||||
|
target := fmt.Sprintf("/files/%s", fileId)
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||||
|
return app.Test(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to create multi-part file
|
||||||
|
func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipart.Writer) {
|
||||||
|
body := new(strings.Builder)
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
file, _ := os.Open(filePath)
|
||||||
|
defer file.Close()
|
||||||
|
part, _ := writer.CreateFormFile(tag, filepath.Base(filePath))
|
||||||
|
io.Copy(part, file)
|
||||||
|
|
||||||
|
if purpose != "" {
|
||||||
|
_ = writer.WriteField("purpose", purpose)
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.Close()
|
||||||
|
return strings.NewReader(body.String()), writer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to create test files
|
||||||
|
func createTestFile(t *testing.T, name string, sizeMB int, option *config.ApplicationConfig) *os.File {
|
||||||
|
err := os.MkdirAll(option.UploadDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
t.Fatalf("Error MKDIR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file, _ := os.Create(name)
|
||||||
|
file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
os.Remove(name)
|
||||||
|
os.RemoveAll(option.UploadDir)
|
||||||
|
})
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
|
||||||
|
func bodyToString(resp *http.Response, t *testing.T) string {
|
||||||
|
return string(bodyToByteArray(resp, t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func bodyToByteArray(resp *http.Response, t *testing.T) []byte {
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return bodyBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseToFile(t *testing.T, resp *http.Response) File {
|
||||||
|
var file File
|
||||||
|
responseToString := bodyToString(resp, t)
|
||||||
|
|
||||||
|
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&file)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to decode response: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseToListFile(t *testing.T, resp *http.Response) ListFiles {
|
||||||
|
var listFiles ListFiles
|
||||||
|
responseToString := bodyToString(resp, t)
|
||||||
|
|
||||||
|
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to decode response: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return listFiles
|
||||||
|
}
|
||||||
@@ -13,12 +13,12 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@@ -59,9 +59,9 @@ func downloadFile(url string) (string, error) {
|
|||||||
|
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
m, input, err := readRequest(c, o, false)
|
m, input, err := readRequest(c, ml, appConfig, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -71,7 +71,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
|
|||||||
}
|
}
|
||||||
log.Debug().Msgf("Loading model: %+v", m)
|
log.Debug().Msgf("Loading model: %+v", m)
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false)
|
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -104,7 +104,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a temporary file
|
// Create a temporary file
|
||||||
outputFile, err := os.CreateTemp(o.ImageDir, "b64")
|
outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -133,15 +133,15 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
|
|||||||
|
|
||||||
sizeParts := strings.Split(input.Size, "x")
|
sizeParts := strings.Split(input.Size, "x")
|
||||||
if len(sizeParts) != 2 {
|
if len(sizeParts) != 2 {
|
||||||
return fmt.Errorf("Invalid value for 'size'")
|
return fmt.Errorf("invalid value for 'size'")
|
||||||
}
|
}
|
||||||
width, err := strconv.Atoi(sizeParts[0])
|
width, err := strconv.Atoi(sizeParts[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Invalid value for 'size'")
|
return fmt.Errorf("invalid value for 'size'")
|
||||||
}
|
}
|
||||||
height, err := strconv.Atoi(sizeParts[1])
|
height, err := strconv.Atoi(sizeParts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Invalid value for 'size'")
|
return fmt.Errorf("invalid value for 'size'")
|
||||||
}
|
}
|
||||||
|
|
||||||
b64JSON := false
|
b64JSON := false
|
||||||
@@ -179,7 +179,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
|
|||||||
|
|
||||||
tempDir := ""
|
tempDir := ""
|
||||||
if !b64JSON {
|
if !b64JSON {
|
||||||
tempDir = o.ImageDir
|
tempDir = appConfig.ImageDir
|
||||||
}
|
}
|
||||||
// Create a temporary file
|
// Create a temporary file
|
||||||
outputFile, err := os.CreateTemp(tempDir, "b64")
|
outputFile, err := os.CreateTemp(tempDir, "b64")
|
||||||
@@ -196,7 +196,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
|
|||||||
|
|
||||||
baseURL := c.BaseURL()
|
baseURL := c.BaseURL()
|
||||||
|
|
||||||
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o)
|
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1,18 +1,18 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/go-skynet/LocalAI/api/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ComputeChoices(
|
func ComputeChoices(
|
||||||
req *schema.OpenAIRequest,
|
req *schema.OpenAIRequest,
|
||||||
predInput string,
|
predInput string,
|
||||||
config *config.Config,
|
config *config.BackendConfig,
|
||||||
o *options.Option,
|
o *config.ApplicationConfig,
|
||||||
loader *model.ModelLoader,
|
loader *model.ModelLoader,
|
||||||
cb func(string, *[]schema.Choice),
|
cb func(string, *[]schema.Choice),
|
||||||
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
|
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
|
||||||
@@ -3,15 +3,15 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error {
|
func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
models, err := loader.ListModels()
|
models, err := ml.ListModels()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -40,7 +40,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
|
|||||||
excludeConfigured := c.QueryBool("excludeConfigured", true)
|
excludeConfigured := c.QueryBool("excludeConfigured", true)
|
||||||
|
|
||||||
// Start with the known configurations
|
// Start with the known configurations
|
||||||
for _, c := range cm.GetAllConfigs() {
|
for _, c := range cl.GetAllBackendConfigs() {
|
||||||
if excludeConfigured {
|
if excludeConfigured {
|
||||||
mm[c.Model] = nil
|
mm[c.Model] = nil
|
||||||
}
|
}
|
||||||
@@ -5,24 +5,22 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
fiberContext "github.com/go-skynet/LocalAI/api/ctx"
|
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||||
options "github.com/go-skynet/LocalAI/api/options"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/go-skynet/LocalAI/api/schema"
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *schema.OpenAIRequest, error) {
|
func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
|
||||||
input := new(schema.OpenAIRequest)
|
input := new(schema.OpenAIRequest)
|
||||||
ctx, cancel := context.WithCancel(o.Context)
|
|
||||||
input.Context = ctx
|
|
||||||
input.Cancel = cancel
|
|
||||||
// Get input data from the request body
|
// Get input data from the request body
|
||||||
if err := c.BodyParser(input); err != nil {
|
if err := c.BodyParser(input); err != nil {
|
||||||
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
|
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
|
||||||
@@ -30,9 +28,13 @@ func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *sch
|
|||||||
|
|
||||||
received, _ := json.Marshal(input)
|
received, _ := json.Marshal(input)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(o.Context)
|
||||||
|
input.Context = ctx
|
||||||
|
input.Cancel = cancel
|
||||||
|
|
||||||
log.Debug().Msgf("Request received: %s", string(received))
|
log.Debug().Msgf("Request received: %s", string(received))
|
||||||
|
|
||||||
modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, firstModel)
|
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
|
||||||
|
|
||||||
return modelFile, input, err
|
return modelFile, input, err
|
||||||
}
|
}
|
||||||
@@ -49,7 +51,7 @@ func getBase64Image(s string) (string, error) {
|
|||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// read the image data into memory
|
// read the image data into memory
|
||||||
data, err := ioutil.ReadAll(resp.Body)
|
data, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -68,14 +70,14 @@ func getBase64Image(s string) (string, error) {
|
|||||||
return "", fmt.Errorf("not valid string")
|
return "", fmt.Errorf("not valid string")
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
|
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
|
||||||
if input.Echo {
|
if input.Echo {
|
||||||
config.Echo = input.Echo
|
config.Echo = input.Echo
|
||||||
}
|
}
|
||||||
if input.TopK != 0 {
|
if input.TopK != nil {
|
||||||
config.TopK = input.TopK
|
config.TopK = input.TopK
|
||||||
}
|
}
|
||||||
if input.TopP != 0 {
|
if input.TopP != nil {
|
||||||
config.TopP = input.TopP
|
config.TopP = input.TopP
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,11 +117,11 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
|
|||||||
config.Grammar = input.Grammar
|
config.Grammar = input.Grammar
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Temperature != 0 {
|
if input.Temperature != nil {
|
||||||
config.Temperature = input.Temperature
|
config.Temperature = input.Temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Maxtokens != 0 {
|
if input.Maxtokens != nil {
|
||||||
config.Maxtokens = input.Maxtokens
|
config.Maxtokens = input.Maxtokens
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +138,20 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(input.Tools) > 0 {
|
||||||
|
for _, tool := range input.Tools {
|
||||||
|
input.Functions = append(input.Functions, tool.Function)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.ToolsChoice != nil {
|
||||||
|
var toolChoice grammar.Tool
|
||||||
|
json.Unmarshal([]byte(input.ToolsChoice.(string)), &toolChoice)
|
||||||
|
input.FunctionCall = map[string]interface{}{
|
||||||
|
"name": toolChoice.Function.Name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Decode each request's message content
|
// Decode each request's message content
|
||||||
index := 0
|
index := 0
|
||||||
for i, m := range input.Messages {
|
for i, m := range input.Messages {
|
||||||
@@ -177,30 +193,14 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
|
|||||||
config.Batch = input.Batch
|
config.Batch = input.Batch
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.F16 {
|
|
||||||
config.F16 = input.F16
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.IgnoreEOS {
|
if input.IgnoreEOS {
|
||||||
config.IgnoreEOS = input.IgnoreEOS
|
config.IgnoreEOS = input.IgnoreEOS
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Seed != 0 {
|
if input.Seed != nil {
|
||||||
config.Seed = input.Seed
|
config.Seed = input.Seed
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Mirostat != 0 {
|
|
||||||
config.LLMConfig.Mirostat = input.Mirostat
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.MirostatETA != 0 {
|
|
||||||
config.LLMConfig.MirostatETA = input.MirostatETA
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.MirostatTAU != 0 {
|
|
||||||
config.LLMConfig.MirostatTAU = input.MirostatTAU
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.TypicalP != 0 {
|
if input.TypicalP != 0 {
|
||||||
config.TypicalP = input.TypicalP
|
config.TypicalP = input.TypicalP
|
||||||
}
|
}
|
||||||
@@ -255,8 +255,13 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) {
|
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
|
||||||
cfg, err := config.Load(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16)
|
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
|
||||||
|
config.LoadOptionDebug(debug),
|
||||||
|
config.LoadOptionThreads(threads),
|
||||||
|
config.LoadOptionContextSize(ctx),
|
||||||
|
config.LoadOptionF16(f16),
|
||||||
|
)
|
||||||
|
|
||||||
// Set the parameters for the language model prediction
|
// Set the parameters for the language model prediction
|
||||||
updateRequestConfig(cfg, input)
|
updateRequestConfig(cfg, input)
|
||||||
@@ -8,23 +8,23 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/audio/create
|
// https://platform.openai.com/docs/api-reference/audio/create
|
||||||
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
m, input, err := readRequest(c, o, false)
|
m, input, err := readRequest(c, ml, appConfig, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -59,7 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
|
|
||||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||||
|
|
||||||
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
|
tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
6
core/schema/elevenlabs.go
Normal file
6
core/schema/elevenlabs.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
type ElevenLabsTTSRequest struct {
|
||||||
|
Text string `json:"text" yaml:"text"`
|
||||||
|
ModelID string `json:"model_id" yaml:"model_id"`
|
||||||
|
}
|
||||||
22
core/schema/localai.go
Normal file
22
core/schema/localai.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BackendMonitorRequest struct {
|
||||||
|
Model string `json:"model" yaml:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BackendMonitorResponse struct {
|
||||||
|
MemoryInfo *gopsutil.MemoryInfoStat
|
||||||
|
MemoryPercent float32
|
||||||
|
CPUPercent float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type TTSRequest struct {
|
||||||
|
Model string `json:"model" yaml:"model"`
|
||||||
|
Input string `json:"input" yaml:"input"`
|
||||||
|
Voice string `json:"voice" yaml:"voice"`
|
||||||
|
Backend string `json:"backend" yaml:"backend"`
|
||||||
|
}
|
||||||
@@ -3,8 +3,6 @@ package schema
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,7 +47,7 @@ type OpenAIResponse struct {
|
|||||||
|
|
||||||
type Choice struct {
|
type Choice struct {
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
FinishReason string `json:"finish_reason,omitempty"`
|
FinishReason string `json:"finish_reason"`
|
||||||
Message *Message `json:"message,omitempty"`
|
Message *Message `json:"message,omitempty"`
|
||||||
Delta *Message `json:"delta,omitempty"`
|
Delta *Message `json:"delta,omitempty"`
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
@@ -68,6 +66,10 @@ type ContentURL struct {
|
|||||||
type Message struct {
|
type Message struct {
|
||||||
// The message role
|
// The message role
|
||||||
Role string `json:"role,omitempty" yaml:"role"`
|
Role string `json:"role,omitempty" yaml:"role"`
|
||||||
|
|
||||||
|
// The message name (used for tools calls)
|
||||||
|
Name string `json:"name,omitempty" yaml:"name"`
|
||||||
|
|
||||||
// The message content
|
// The message content
|
||||||
Content interface{} `json:"content" yaml:"content"`
|
Content interface{} `json:"content" yaml:"content"`
|
||||||
|
|
||||||
@@ -76,6 +78,20 @@ type Message struct {
|
|||||||
|
|
||||||
// A result of a function call
|
// A result of a function call
|
||||||
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
|
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
|
||||||
|
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty" yaml:"tool_call,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
FunctionCall FunctionCall `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIModel struct {
|
type OpenAIModel struct {
|
||||||
@@ -90,10 +106,10 @@ type ChatCompletionResponseFormat struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIRequest struct {
|
type OpenAIRequest struct {
|
||||||
config.PredictionOptions
|
PredictionOptions
|
||||||
|
|
||||||
Context context.Context
|
Context context.Context `json:"-"`
|
||||||
Cancel context.CancelFunc
|
Cancel context.CancelFunc `json:"-"`
|
||||||
|
|
||||||
// whisper
|
// whisper
|
||||||
File string `json:"file" validate:"required"`
|
File string `json:"file" validate:"required"`
|
||||||
@@ -117,6 +133,9 @@ type OpenAIRequest struct {
|
|||||||
Functions []grammar.Function `json:"functions" yaml:"functions"`
|
Functions []grammar.Function `json:"functions" yaml:"functions"`
|
||||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||||
|
|
||||||
|
Tools []grammar.Tool `json:"tools,omitempty" yaml:"tools"`
|
||||||
|
ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"`
|
||||||
|
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
|
|
||||||
// Image (not supported by OpenAI)
|
// Image (not supported by OpenAI)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package api_config
|
package schema
|
||||||
|
|
||||||
type PredictionOptions struct {
|
type PredictionOptions struct {
|
||||||
|
|
||||||
@@ -12,28 +12,23 @@ type PredictionOptions struct {
|
|||||||
N int `json:"n"`
|
N int `json:"n"`
|
||||||
|
|
||||||
// Common options between all the API calls, part of the OpenAI spec
|
// Common options between all the API calls, part of the OpenAI spec
|
||||||
TopP float64 `json:"top_p" yaml:"top_p"`
|
TopP *float64 `json:"top_p" yaml:"top_p"`
|
||||||
TopK int `json:"top_k" yaml:"top_k"`
|
TopK *int `json:"top_k" yaml:"top_k"`
|
||||||
Temperature float64 `json:"temperature" yaml:"temperature"`
|
Temperature *float64 `json:"temperature" yaml:"temperature"`
|
||||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
|
Maxtokens *int `json:"max_tokens" yaml:"max_tokens"`
|
||||||
Echo bool `json:"echo"`
|
Echo bool `json:"echo"`
|
||||||
|
|
||||||
// Custom parameters - not present in the OpenAI API
|
// Custom parameters - not present in the OpenAI API
|
||||||
Batch int `json:"batch" yaml:"batch"`
|
Batch int `json:"batch" yaml:"batch"`
|
||||||
F16 bool `json:"f16" yaml:"f16"`
|
|
||||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
|
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
|
||||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
|
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
|
||||||
Keep int `json:"n_keep" yaml:"n_keep"`
|
Keep int `json:"n_keep" yaml:"n_keep"`
|
||||||
|
|
||||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
|
|
||||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
|
|
||||||
Mirostat int `json:"mirostat" yaml:"mirostat"`
|
|
||||||
|
|
||||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
|
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
|
||||||
TFZ float64 `json:"tfz" yaml:"tfz"`
|
TFZ float64 `json:"tfz" yaml:"tfz"`
|
||||||
|
|
||||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
||||||
Seed int `json:"seed" yaml:"seed"`
|
Seed *int `json:"seed" yaml:"seed"`
|
||||||
|
|
||||||
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
|
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
|
||||||
RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"`
|
RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"`
|
||||||
140
core/services/backend_monitor.go
Normal file
140
core/services/backend_monitor.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BackendMonitor struct {
|
||||||
|
configLoader *config.BackendConfigLoader
|
||||||
|
modelLoader *model.ModelLoader
|
||||||
|
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor {
|
||||||
|
return BackendMonitor{
|
||||||
|
configLoader: configLoader,
|
||||||
|
modelLoader: modelLoader,
|
||||||
|
options: appConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
|
||||||
|
config, exists := bm.configLoader.GetBackendConfig(modelName)
|
||||||
|
var backendId string
|
||||||
|
if exists {
|
||||||
|
backendId = config.Model
|
||||||
|
} else {
|
||||||
|
// Last ditch effort: use it raw, see if a backend happens to match.
|
||||||
|
backendId = modelName
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasSuffix(backendId, ".bin") {
|
||||||
|
backendId = fmt.Sprintf("%s.bin", backendId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return backendId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
|
||||||
|
config, exists := bm.configLoader.GetBackendConfig(model)
|
||||||
|
var backend string
|
||||||
|
if exists {
|
||||||
|
backend = config.Model
|
||||||
|
} else {
|
||||||
|
// Last ditch effort: use it raw, see if a backend happens to match.
|
||||||
|
backend = model
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasSuffix(backend, ".bin") {
|
||||||
|
backend = fmt.Sprintf("%s.bin", backend)
|
||||||
|
}
|
||||||
|
|
||||||
|
pid, err := bm.modelLoader.GetGRPCPID(backend)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
|
||||||
|
backendProcess, err := gopsutil.NewProcess(int32(pid))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
memInfo, err := backendProcess.MemoryInfo()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
memPercent, err := backendProcess.MemoryPercent()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cpuPercent, err := backendProcess.CPUPercent()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &schema.BackendMonitorResponse{
|
||||||
|
MemoryInfo: memInfo,
|
||||||
|
MemoryPercent: memPercent,
|
||||||
|
CPUPercent: cpuPercent,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
|
||||||
|
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
|
||||||
|
if modelAddr == "" {
|
||||||
|
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
|
||||||
|
}
|
||||||
|
|
||||||
|
status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
|
||||||
|
if rpcErr != nil {
|
||||||
|
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
|
||||||
|
val, slbErr := bm.SampleLocalBackendProcess(backendId)
|
||||||
|
if slbErr != nil {
|
||||||
|
return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
|
||||||
|
}
|
||||||
|
return &proto.StatusResponse{
|
||||||
|
State: proto.StatusResponse_ERROR,
|
||||||
|
Memory: &proto.MemoryUsageData{
|
||||||
|
Total: val.MemoryInfo.VMS,
|
||||||
|
Breakdown: map[string]uint64{
|
||||||
|
"gopsutil-RSS": val.MemoryInfo.RSS,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bm BackendMonitor) ShutdownModel(modelName string) error {
|
||||||
|
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return bm.modelLoader.ShutdownModel(backendId)
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user