mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-23 16:20:01 -04:00
Compare commits
57 Commits
v2.9.0
...
docs_updat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b8d6a31e2 | ||
|
|
f0752be4aa | ||
|
|
bafc9effad | ||
|
|
d2934dd69f | ||
|
|
a6b540737f | ||
|
|
f82065703d | ||
|
|
b423af001d | ||
|
|
b9e77d394b | ||
|
|
57222497ec | ||
|
|
5c5f07c1e7 | ||
|
|
f895d06605 | ||
|
|
bc8f648a91 | ||
|
|
8e57f4df31 | ||
|
|
a08cc5adbb | ||
|
|
595a73fce4 | ||
|
|
dc919e08e8 | ||
|
|
5d1018495f | ||
|
|
ad6fd7a991 | ||
|
|
e022b5959e | ||
|
|
db7f4955a1 | ||
|
|
5c69dd155f | ||
|
|
504f2e8bf4 | ||
|
|
e586dc2924 | ||
|
|
333f918005 | ||
|
|
c8e29033c2 | ||
|
|
d0bd961bde | ||
|
|
006511ee25 | ||
|
|
4ab72146cd | ||
|
|
b60a3fc879 | ||
|
|
a0eeb74957 | ||
|
|
daa0b8741c | ||
|
|
939411300a | ||
|
|
1c312685aa | ||
|
|
316de82f51 | ||
|
|
9068bc5271 | ||
|
|
31a4c9c9d3 | ||
|
|
c1966af2cf | ||
|
|
c665898652 | ||
|
|
f651a660aa | ||
|
|
ba672b51da | ||
|
|
be498c5dd9 | ||
|
|
6e95beccb9 | ||
|
|
c8be839481 | ||
|
|
c7e08813a5 | ||
|
|
d21a6b33ab | ||
|
|
9112cf153e | ||
|
|
3868ac8402 | ||
|
|
3f09010227 | ||
|
|
d6cf82aba3 | ||
|
|
dfe54639b1 | ||
|
|
bc5f5aa538 | ||
|
|
05818e0425 | ||
|
|
7f72a61104 | ||
|
|
8e45d47740 | ||
|
|
71771d1e9b | ||
|
|
aa098e4d0b | ||
|
|
0135e1e3b9 |
@@ -3,3 +3,4 @@ models
|
|||||||
examples/chatbot-ui/models
|
examples/chatbot-ui/models
|
||||||
examples/rwkv/models
|
examples/rwkv/models
|
||||||
examples/**/models
|
examples/**/models
|
||||||
|
Dockerfile
|
||||||
2
.env
2
.env
@@ -18,7 +18,7 @@
|
|||||||
|
|
||||||
## Default path for models
|
## Default path for models
|
||||||
#
|
#
|
||||||
MODELS_PATH=/models
|
# MODELS_PATH=/models
|
||||||
|
|
||||||
## Enable debug mode
|
## Enable debug mode
|
||||||
# DEBUG=true
|
# DEBUG=true
|
||||||
|
|||||||
10
.github/workflows/image-pr.yml
vendored
10
.github/workflows/image-pr.yml
vendored
@@ -59,6 +59,14 @@ jobs:
|
|||||||
image-type: 'extras'
|
image-type: 'extras'
|
||||||
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
||||||
runs-on: 'arc-runner-set'
|
runs-on: 'arc-runner-set'
|
||||||
|
- build-type: 'sycl_f16'
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
|
||||||
|
tag-suffix: 'sycl-f16-ffmpeg'
|
||||||
|
ffmpeg: 'true'
|
||||||
|
image-type: 'extras'
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
core-image-build:
|
core-image-build:
|
||||||
uses: ./.github/workflows/image_build.yml
|
uses: ./.github/workflows/image_build.yml
|
||||||
with:
|
with:
|
||||||
@@ -105,4 +113,4 @@ jobs:
|
|||||||
ffmpeg: 'true'
|
ffmpeg: 'true'
|
||||||
image-type: 'core'
|
image-type: 'core'
|
||||||
runs-on: 'ubuntu-latest'
|
runs-on: 'ubuntu-latest'
|
||||||
base-image: "ubuntu:22.04"
|
base-image: "ubuntu:22.04"
|
||||||
97
.github/workflows/image.yml
vendored
97
.github/workflows/image.yml
vendored
@@ -13,7 +13,7 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
extras-image-build:
|
self-hosted-jobs:
|
||||||
uses: ./.github/workflows/image_build.yml
|
uses: ./.github/workflows/image_build.yml
|
||||||
with:
|
with:
|
||||||
tag-latest: ${{ matrix.tag-latest }}
|
tag-latest: ${{ matrix.tag-latest }}
|
||||||
@@ -37,6 +37,7 @@ jobs:
|
|||||||
max-parallel: ${{ github.event_name != 'pull_request' && 2 || 4 }}
|
max-parallel: ${{ github.event_name != 'pull_request' && 2 || 4 }}
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
|
# Extra images
|
||||||
- build-type: ''
|
- build-type: ''
|
||||||
#platforms: 'linux/amd64,linux/arm64'
|
#platforms: 'linux/amd64,linux/arm64'
|
||||||
platforms: 'linux/amd64'
|
platforms: 'linux/amd64'
|
||||||
@@ -119,51 +120,23 @@ jobs:
|
|||||||
image-type: 'extras'
|
image-type: 'extras'
|
||||||
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
||||||
runs-on: 'arc-runner-set'
|
runs-on: 'arc-runner-set'
|
||||||
core-image-build:
|
- build-type: 'sycl_f16'
|
||||||
uses: ./.github/workflows/image_build.yml
|
|
||||||
with:
|
|
||||||
tag-latest: ${{ matrix.tag-latest }}
|
|
||||||
tag-suffix: ${{ matrix.tag-suffix }}
|
|
||||||
ffmpeg: ${{ matrix.ffmpeg }}
|
|
||||||
image-type: ${{ matrix.image-type }}
|
|
||||||
build-type: ${{ matrix.build-type }}
|
|
||||||
cuda-major-version: ${{ matrix.cuda-major-version }}
|
|
||||||
cuda-minor-version: ${{ matrix.cuda-minor-version }}
|
|
||||||
platforms: ${{ matrix.platforms }}
|
|
||||||
runs-on: ${{ matrix.runs-on }}
|
|
||||||
base-image: ${{ matrix.base-image }}
|
|
||||||
secrets:
|
|
||||||
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
|
||||||
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
|
||||||
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
|
||||||
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
include:
|
|
||||||
- build-type: 'hipblas'
|
|
||||||
platforms: 'linux/amd64'
|
platforms: 'linux/amd64'
|
||||||
tag-latest: 'false'
|
tag-latest: 'false'
|
||||||
tag-suffix: '-hipblas-ffmpeg-core'
|
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
|
||||||
|
tag-suffix: '-sycl-f16-ffmpeg'
|
||||||
ffmpeg: 'true'
|
ffmpeg: 'true'
|
||||||
image-type: 'core'
|
image-type: 'extras'
|
||||||
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
|
||||||
runs-on: 'arc-runner-set'
|
runs-on: 'arc-runner-set'
|
||||||
- build-type: 'hipblas'
|
- build-type: 'sycl_f32'
|
||||||
platforms: 'linux/amd64'
|
platforms: 'linux/amd64'
|
||||||
tag-latest: 'false'
|
tag-latest: 'false'
|
||||||
tag-suffix: '-hipblas-core'
|
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
|
||||||
ffmpeg: 'false'
|
tag-suffix: '-sycl-f32-ffmpeg'
|
||||||
image-type: 'core'
|
|
||||||
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
|
||||||
runs-on: 'arc-runner-set'
|
|
||||||
- build-type: ''
|
|
||||||
platforms: 'linux/amd64'
|
|
||||||
tag-latest: 'false'
|
|
||||||
tag-suffix: '-ffmpeg-core'
|
|
||||||
ffmpeg: 'true'
|
ffmpeg: 'true'
|
||||||
image-type: 'core'
|
image-type: 'extras'
|
||||||
base-image: "ubuntu:22.04"
|
runs-on: 'arc-runner-set'
|
||||||
runs-on: 'ubuntu-latest'
|
# Core images
|
||||||
- build-type: 'sycl_f16'
|
- build-type: 'sycl_f16'
|
||||||
platforms: 'linux/amd64'
|
platforms: 'linux/amd64'
|
||||||
tag-latest: 'false'
|
tag-latest: 'false'
|
||||||
@@ -196,6 +169,52 @@ jobs:
|
|||||||
ffmpeg: 'true'
|
ffmpeg: 'true'
|
||||||
image-type: 'core'
|
image-type: 'core'
|
||||||
runs-on: 'arc-runner-set'
|
runs-on: 'arc-runner-set'
|
||||||
|
- build-type: 'hipblas'
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
tag-suffix: '-hipblas-ffmpeg-core'
|
||||||
|
ffmpeg: 'true'
|
||||||
|
image-type: 'core'
|
||||||
|
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
|
- build-type: 'hipblas'
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
tag-suffix: '-hipblas-core'
|
||||||
|
ffmpeg: 'false'
|
||||||
|
image-type: 'core'
|
||||||
|
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
|
|
||||||
|
core-image-build:
|
||||||
|
uses: ./.github/workflows/image_build.yml
|
||||||
|
with:
|
||||||
|
tag-latest: ${{ matrix.tag-latest }}
|
||||||
|
tag-suffix: ${{ matrix.tag-suffix }}
|
||||||
|
ffmpeg: ${{ matrix.ffmpeg }}
|
||||||
|
image-type: ${{ matrix.image-type }}
|
||||||
|
build-type: ${{ matrix.build-type }}
|
||||||
|
cuda-major-version: ${{ matrix.cuda-major-version }}
|
||||||
|
cuda-minor-version: ${{ matrix.cuda-minor-version }}
|
||||||
|
platforms: ${{ matrix.platforms }}
|
||||||
|
runs-on: ${{ matrix.runs-on }}
|
||||||
|
base-image: ${{ matrix.base-image }}
|
||||||
|
secrets:
|
||||||
|
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||||
|
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
|
||||||
|
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- build-type: ''
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'false'
|
||||||
|
tag-suffix: '-ffmpeg-core'
|
||||||
|
ffmpeg: 'true'
|
||||||
|
image-type: 'core'
|
||||||
|
base-image: "ubuntu:22.04"
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
- build-type: 'cublas'
|
- build-type: 'cublas'
|
||||||
cuda-major-version: "11"
|
cuda-major-version: "11"
|
||||||
cuda-minor-version: "7"
|
cuda-minor-version: "7"
|
||||||
|
|||||||
29
.github/workflows/release.yaml
vendored
29
.github/workflows/release.yaml
vendored
@@ -89,6 +89,35 @@ jobs:
|
|||||||
files: |
|
files: |
|
||||||
release/*
|
release/*
|
||||||
|
|
||||||
|
build-stablediffusion:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
submodules: true
|
||||||
|
- uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: '>=1.21.0'
|
||||||
|
- name: Dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get install -y --no-install-recommends libopencv-dev
|
||||||
|
sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
|
||||||
|
- name: Build stablediffusion
|
||||||
|
run: |
|
||||||
|
make backend-assets/grpc/stablediffusion
|
||||||
|
mkdir -p release && cp backend-assets/grpc/stablediffusion release
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: stablediffusion
|
||||||
|
path: release/
|
||||||
|
- name: Release
|
||||||
|
uses: softprops/action-gh-release@v1
|
||||||
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
|
with:
|
||||||
|
files: |
|
||||||
|
release/*
|
||||||
|
|
||||||
build-macOS:
|
build-macOS:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -21,6 +21,7 @@ local-ai
|
|||||||
!charts/*
|
!charts/*
|
||||||
# prevent above rules from omitting the api/localai folder
|
# prevent above rules from omitting the api/localai folder
|
||||||
!api/localai
|
!api/localai
|
||||||
|
!core/**/localai
|
||||||
|
|
||||||
# Ignore models
|
# Ignore models
|
||||||
models/*
|
models/*
|
||||||
@@ -34,6 +35,7 @@ release/
|
|||||||
.idea
|
.idea
|
||||||
|
|
||||||
# Generated during build
|
# Generated during build
|
||||||
backend-assets/
|
backend-assets/*
|
||||||
|
!backend-assets/.keep
|
||||||
prepare
|
prepare
|
||||||
/ggml-metal.metal
|
/ggml-metal.metal
|
||||||
|
|||||||
48
Dockerfile
48
Dockerfile
@@ -4,6 +4,8 @@ ARG BASE_IMAGE=ubuntu:22.04
|
|||||||
# extras or core
|
# extras or core
|
||||||
FROM ${BASE_IMAGE} as requirements-core
|
FROM ${BASE_IMAGE} as requirements-core
|
||||||
|
|
||||||
|
USER root
|
||||||
|
|
||||||
ARG GO_VERSION=1.21.7
|
ARG GO_VERSION=1.21.7
|
||||||
ARG BUILD_TYPE
|
ARG BUILD_TYPE
|
||||||
ARG CUDA_MAJOR_VERSION=11
|
ARG CUDA_MAJOR_VERSION=11
|
||||||
@@ -21,7 +23,7 @@ RUN apt-get update && \
|
|||||||
apt-get install -y ca-certificates curl patch pip cmake git && apt-get clean
|
apt-get install -y ca-certificates curl patch pip cmake git && apt-get clean
|
||||||
|
|
||||||
# Install Go
|
# Install Go
|
||||||
RUN curl -L -s https://go.dev/dl/go$GO_VERSION.linux-$TARGETARCH.tar.gz | tar -v -C /usr/local -xz
|
RUN curl -L -s https://go.dev/dl/go$GO_VERSION.linux-$TARGETARCH.tar.gz | tar -C /usr/local -xz
|
||||||
ENV PATH $PATH:/usr/local/go/bin
|
ENV PATH $PATH:/usr/local/go/bin
|
||||||
|
|
||||||
COPY --chmod=644 custom-ca-certs/* /usr/local/share/ca-certificates/
|
COPY --chmod=644 custom-ca-certs/* /usr/local/share/ca-certificates/
|
||||||
@@ -79,6 +81,10 @@ RUN pip install --upgrade pip
|
|||||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||||
RUN apt-get install -y espeak-ng espeak && apt-get clean
|
RUN apt-get install -y espeak-ng espeak && apt-get clean
|
||||||
|
|
||||||
|
RUN if [ ! -e /usr/bin/python ]; then \
|
||||||
|
ln -s /usr/bin/python3 /usr/bin/python \
|
||||||
|
; fi
|
||||||
|
|
||||||
###################################
|
###################################
|
||||||
###################################
|
###################################
|
||||||
|
|
||||||
@@ -99,6 +105,13 @@ COPY . .
|
|||||||
COPY .git .
|
COPY .git .
|
||||||
RUN make prepare
|
RUN make prepare
|
||||||
|
|
||||||
|
# If we are building with clblas support, we need the libraries for the builds
|
||||||
|
RUN if [ "${BUILD_TYPE}" = "clblas" ]; then \
|
||||||
|
apt-get update && \
|
||||||
|
apt-get install -y libclblast-dev && \
|
||||||
|
apt-get clean \
|
||||||
|
; fi
|
||||||
|
|
||||||
# stablediffusion does not tolerate a newer version of abseil, build it first
|
# stablediffusion does not tolerate a newer version of abseil, build it first
|
||||||
RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
|
RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
|
||||||
|
|
||||||
@@ -142,6 +155,13 @@ RUN if [ "${FFMPEG}" = "true" ]; then \
|
|||||||
apt-get install -y ffmpeg && apt-get clean \
|
apt-get install -y ffmpeg && apt-get clean \
|
||||||
; fi
|
; fi
|
||||||
|
|
||||||
|
# Add OpenCL
|
||||||
|
RUN if [ "${BUILD_TYPE}" = "clblas" ]; then \
|
||||||
|
apt-get update && \
|
||||||
|
apt-get install -y libclblast1 && \
|
||||||
|
apt-get clean \
|
||||||
|
; fi
|
||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
|
||||||
# we start fresh & re-copy all assets because `make build` does not clean up nicely after itself
|
# we start fresh & re-copy all assets because `make build` does not clean up nicely after itself
|
||||||
@@ -166,43 +186,43 @@ COPY --from=builder /build/backend-assets/grpc/stablediffusion ./backend-assets/
|
|||||||
|
|
||||||
## Duplicated from Makefile to avoid having a big layer that's hard to push
|
## Duplicated from Makefile to avoid having a big layer that's hard to push
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/autogptq \
|
make -C backend/python/autogptq \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/bark \
|
make -C backend/python/bark \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/diffusers \
|
make -C backend/python/diffusers \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/vllm \
|
make -C backend/python/vllm \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/mamba \
|
make -C backend/python/mamba \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers \
|
make -C backend/python/sentencetransformers \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/transformers \
|
make -C backend/python/transformers \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/vall-e-x \
|
make -C backend/python/vall-e-x \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/exllama \
|
make -C backend/python/exllama \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/exllama2 \
|
make -C backend/python/exllama2 \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/petals \
|
make -C backend/python/petals \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/transformers-musicgen \
|
make -C backend/python/transformers-musicgen \
|
||||||
; fi
|
; fi
|
||||||
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
|
||||||
PATH=$PATH:/opt/conda/bin make -C backend/python/coqui \
|
make -C backend/python/coqui \
|
||||||
; fi
|
; fi
|
||||||
|
|
||||||
# Make sure the models directory exists
|
# Make sure the models directory exists
|
||||||
|
|||||||
35
Makefile
35
Makefile
@@ -8,7 +8,7 @@ GOLLAMA_VERSION?=aeba71ee842819da681ea537e78846dc75949ac0
|
|||||||
|
|
||||||
GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
|
GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
|
||||||
|
|
||||||
CPPLLAMA_VERSION?=fd43d66f46ee3b5345fb8a74a252d86ccd34a409
|
CPPLLAMA_VERSION?=19885d205e768579ab090d1e99281cae58c21b54
|
||||||
|
|
||||||
# gpt4all version
|
# gpt4all version
|
||||||
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
|
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
|
||||||
@@ -16,7 +16,7 @@ GPT4ALL_VERSION?=27a8b020c36b0df8f8b82a252d261cda47cf44b8
|
|||||||
|
|
||||||
# go-rwkv version
|
# go-rwkv version
|
||||||
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
|
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
|
||||||
RWKV_VERSION?=633c5a3485c403cb2520693dc0991a25dace9f0f
|
RWKV_VERSION?=661e7ae26d442f5cfebd2a0881b44e8c55949ec6
|
||||||
|
|
||||||
# whisper.cpp version
|
# whisper.cpp version
|
||||||
WHISPER_CPP_VERSION?=37a709f6558c6d9783199e2b8cbb136e1c41d346
|
WHISPER_CPP_VERSION?=37a709f6558c6d9783199e2b8cbb136e1c41d346
|
||||||
@@ -28,7 +28,7 @@ BERT_VERSION?=6abe312cded14042f6b7c3cd8edf082713334a4d
|
|||||||
PIPER_VERSION?=d6b6275ba037dabdba4a8b65dfdf6b2a73a67f07
|
PIPER_VERSION?=d6b6275ba037dabdba4a8b65dfdf6b2a73a67f07
|
||||||
|
|
||||||
# stablediffusion version
|
# stablediffusion version
|
||||||
STABLEDIFFUSION_VERSION?=d5d2be8e7e395c2d73ceef61e6fe8d240f2cd831
|
STABLEDIFFUSION_VERSION?=362df9da29f882dbf09ade61972d16a1f53c3485
|
||||||
|
|
||||||
# tinydream version
|
# tinydream version
|
||||||
TINYDREAM_VERSION?=772a9c0d9aaf768290e63cca3c904fe69faf677a
|
TINYDREAM_VERSION?=772a9c0d9aaf768290e63cca3c904fe69faf677a
|
||||||
@@ -44,6 +44,8 @@ BUILD_ID?=git
|
|||||||
|
|
||||||
TEST_DIR=/tmp/test
|
TEST_DIR=/tmp/test
|
||||||
|
|
||||||
|
TEST_FLAKES?=5
|
||||||
|
|
||||||
RANDOM := $(shell bash -c 'echo $$RANDOM')
|
RANDOM := $(shell bash -c 'echo $$RANDOM')
|
||||||
|
|
||||||
VERSION?=$(shell git describe --always --tags || echo "dev" )
|
VERSION?=$(shell git describe --always --tags || echo "dev" )
|
||||||
@@ -155,6 +157,7 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
|
|||||||
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
|
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
|
||||||
|
|
||||||
GRPC_BACKENDS?=$(ALL_GRPC_BACKENDS) $(OPTIONAL_GRPC)
|
GRPC_BACKENDS?=$(ALL_GRPC_BACKENDS) $(OPTIONAL_GRPC)
|
||||||
|
TEST_PATHS?=./api/... ./pkg/... ./core/...
|
||||||
|
|
||||||
# If empty, then we build all
|
# If empty, then we build all
|
||||||
ifeq ($(GRPC_BACKENDS),)
|
ifeq ($(GRPC_BACKENDS),)
|
||||||
@@ -250,7 +253,7 @@ sources/go-piper/libpiper_binding.a: sources/go-piper
|
|||||||
$(MAKE) -C sources/go-piper libpiper_binding.a example/main
|
$(MAKE) -C sources/go-piper libpiper_binding.a example/main
|
||||||
|
|
||||||
backend/cpp/llama/llama.cpp:
|
backend/cpp/llama/llama.cpp:
|
||||||
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama llama.cpp
|
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama llama.cpp
|
||||||
|
|
||||||
get-sources: backend/cpp/llama/llama.cpp sources/go-llama sources/go-llama-ggml sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion sources/go-tiny-dream
|
get-sources: backend/cpp/llama/llama.cpp sources/go-llama sources/go-llama-ggml sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion sources/go-tiny-dream
|
||||||
touch $@
|
touch $@
|
||||||
@@ -328,7 +331,7 @@ test-models/testmodel:
|
|||||||
cp tests/models_fixtures/* test-models
|
cp tests/models_fixtures/* test-models
|
||||||
|
|
||||||
prepare-test: grpcs
|
prepare-test: grpcs
|
||||||
cp -rf backend-assets api
|
cp -rf backend-assets core/http
|
||||||
cp tests/models_fixtures/* test-models
|
cp tests/models_fixtures/* test-models
|
||||||
|
|
||||||
test: prepare test-models/testmodel grpcs
|
test: prepare test-models/testmodel grpcs
|
||||||
@@ -336,7 +339,7 @@ test: prepare test-models/testmodel grpcs
|
|||||||
export GO_TAGS="tts stablediffusion"
|
export GO_TAGS="tts stablediffusion"
|
||||||
$(MAKE) prepare-test
|
$(MAKE) prepare-test
|
||||||
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts 5 --fail-fast -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
|
||||||
$(MAKE) test-gpt4all
|
$(MAKE) test-gpt4all
|
||||||
$(MAKE) test-llama
|
$(MAKE) test-llama
|
||||||
$(MAKE) test-llama-gguf
|
$(MAKE) test-llama-gguf
|
||||||
@@ -365,23 +368,23 @@ teardown-e2e:
|
|||||||
|
|
||||||
test-gpt4all: prepare-test
|
test-gpt4all: prepare-test
|
||||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
test-llama: prepare-test
|
test-llama: prepare-test
|
||||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
test-llama-gguf: prepare-test
|
test-llama-gguf: prepare-test
|
||||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
test-tts: prepare-test
|
test-tts: prepare-test
|
||||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
test-stablediffusion: prepare-test
|
test-stablediffusion: prepare-test
|
||||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r $(TEST_PATHS)
|
||||||
|
|
||||||
test-container:
|
test-container:
|
||||||
docker build --target requirements -t local-ai-test-container .
|
docker build --target requirements -t local-ai-test-container .
|
||||||
@@ -482,7 +485,7 @@ ifdef BUILD_GRPC_FOR_BACKEND_LLAMA
|
|||||||
CMAKE_ARGS="${CMAKE_ARGS} ${ADDED_CMAKE_ARGS}" LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
CMAKE_ARGS="${CMAKE_ARGS} ${ADDED_CMAKE_ARGS}" LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
||||||
else
|
else
|
||||||
echo "BUILD_GRPC_FOR_BACKEND_LLAMA is not defined."
|
echo "BUILD_GRPC_FOR_BACKEND_LLAMA is not defined."
|
||||||
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
||||||
endif
|
endif
|
||||||
## BACKEND CPP LLAMA END
|
## BACKEND CPP LLAMA END
|
||||||
|
|
||||||
@@ -516,6 +519,7 @@ backend-assets/grpc/langchain-huggingface: backend-assets/grpc
|
|||||||
|
|
||||||
backend-assets/grpc/stablediffusion: backend-assets/grpc
|
backend-assets/grpc/stablediffusion: backend-assets/grpc
|
||||||
if [ ! -f backend-assets/grpc/stablediffusion ]; then \
|
if [ ! -f backend-assets/grpc/stablediffusion ]; then \
|
||||||
|
$(MAKE) sources/go-stable-diffusion; \
|
||||||
$(MAKE) sources/go-stable-diffusion/libstablediffusion.a; \
|
$(MAKE) sources/go-stable-diffusion/libstablediffusion.a; \
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-stable-diffusion/ LIBRARY_PATH=$(CURDIR)/sources/go-stable-diffusion/ \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-stable-diffusion/ LIBRARY_PATH=$(CURDIR)/sources/go-stable-diffusion/ \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./backend/go/image/stablediffusion; \
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./backend/go/image/stablediffusion; \
|
||||||
@@ -553,3 +557,10 @@ docker-image-intel:
|
|||||||
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
|
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
|
||||||
--build-arg GO_TAGS="none" \
|
--build-arg GO_TAGS="none" \
|
||||||
--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
|
--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
|
||||||
|
|
||||||
|
docker-image-intel-xpu:
|
||||||
|
docker build \
|
||||||
|
--build-arg BASE_IMAGE=intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04 \
|
||||||
|
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
|
||||||
|
--build-arg GO_TAGS="none" \
|
||||||
|
--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
|
||||||
14
README.md
14
README.md
@@ -48,7 +48,7 @@
|
|||||||
- Tools API support: https://github.com/mudler/LocalAI/pull/1715
|
- Tools API support: https://github.com/mudler/LocalAI/pull/1715
|
||||||
- LLaVa 1.6: https://github.com/mudler/LocalAI/pull/1714
|
- LLaVa 1.6: https://github.com/mudler/LocalAI/pull/1714
|
||||||
- ROCm container images: https://github.com/mudler/LocalAI/pull/1595
|
- ROCm container images: https://github.com/mudler/LocalAI/pull/1595
|
||||||
- Intel GPU support (sycl): https://github.com/mudler/LocalAI/issues/1653
|
- Intel GPU support (sycl, transformers, diffusers): https://github.com/mudler/LocalAI/issues/1653
|
||||||
- Deprecation of old backends: https://github.com/mudler/LocalAI/issues/1651
|
- Deprecation of old backends: https://github.com/mudler/LocalAI/issues/1651
|
||||||
- Mamba support: https://github.com/mudler/LocalAI/pull/1589
|
- Mamba support: https://github.com/mudler/LocalAI/pull/1589
|
||||||
- Start and share models with config file: https://github.com/mudler/LocalAI/pull/1522
|
- Start and share models with config file: https://github.com/mudler/LocalAI/pull/1522
|
||||||
@@ -59,7 +59,9 @@ Hot topics (looking for contributors):
|
|||||||
- Backends v2: https://github.com/mudler/LocalAI/issues/1126
|
- Backends v2: https://github.com/mudler/LocalAI/issues/1126
|
||||||
- Improving UX v2: https://github.com/mudler/LocalAI/issues/1373
|
- Improving UX v2: https://github.com/mudler/LocalAI/issues/1373
|
||||||
- Assistant API: https://github.com/mudler/LocalAI/issues/1273
|
- Assistant API: https://github.com/mudler/LocalAI/issues/1273
|
||||||
|
- Moderation endpoint: https://github.com/mudler/LocalAI/issues/999
|
||||||
|
- Vulkan: https://github.com/mudler/LocalAI/issues/1647
|
||||||
|
|
||||||
If you want to help and contribute, issues up for grabs: https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3A%22up+for+grabs%22
|
If you want to help and contribute, issues up for grabs: https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3A%22up+for+grabs%22
|
||||||
|
|
||||||
## 💻 [Getting started](https://localai.io/basics/getting_started/index.html)
|
## 💻 [Getting started](https://localai.io/basics/getting_started/index.html)
|
||||||
@@ -67,7 +69,7 @@ If you want to help and contribute, issues up for grabs: https://github.com/mudl
|
|||||||
For a detailed step-by-step introduction, refer to the [Getting Started](https://localai.io/basics/getting_started/index.html) guide. For those in a hurry, here's a straightforward one-liner to launch a LocalAI instance with [phi-2](https://huggingface.co/microsoft/phi-2) using `docker`:
|
For a detailed step-by-step introduction, refer to the [Getting Started](https://localai.io/basics/getting_started/index.html) guide. For those in a hurry, here's a straightforward one-liner to launch a LocalAI instance with [phi-2](https://huggingface.co/microsoft/phi-2) using `docker`:
|
||||||
|
|
||||||
```
|
```
|
||||||
docker run -ti -p 8080:8080 localai/localai:v2.7.0-ffmpeg-core phi-2
|
docker run -ti -p 8080:8080 localai/localai:v2.9.0-ffmpeg-core phi-2
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🚀 [Features](https://localai.io/features/)
|
## 🚀 [Features](https://localai.io/features/)
|
||||||
@@ -97,9 +99,6 @@ WebUIs:
|
|||||||
|
|
||||||
Model galleries
|
Model galleries
|
||||||
- https://github.com/go-skynet/model-gallery
|
- https://github.com/go-skynet/model-gallery
|
||||||
|
|
||||||
UI / Management Programs
|
|
||||||
- [LocalAI Manager](https://io.midori-ai.xyz/howtos/easy-model-installer/)
|
|
||||||
|
|
||||||
Other:
|
Other:
|
||||||
- Helm chart https://github.com/go-skynet/helm-charts
|
- Helm chart https://github.com/go-skynet/helm-charts
|
||||||
@@ -110,6 +109,7 @@ Other:
|
|||||||
- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
|
- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
|
||||||
- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
|
- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
|
||||||
- Examples: https://github.com/mudler/LocalAI/tree/master/examples/
|
- Examples: https://github.com/mudler/LocalAI/tree/master/examples/
|
||||||
|
|
||||||
|
|
||||||
### 🔗 Resources
|
### 🔗 Resources
|
||||||
|
|
||||||
@@ -121,6 +121,8 @@ Other:
|
|||||||
|
|
||||||
## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
|
## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
|
||||||
|
|
||||||
|
- [Run LocalAI on AWS EKS with Pulumi](https://www.pulumi.com/ai/answers/tiZMDoZzZV6TLxgDXNBnFE/deploying-helm-charts-on-aws-eks)
|
||||||
|
- [Run LocalAI on AWS](https://staleks.hashnode.dev/installing-localai-on-aws-ec2-instance)
|
||||||
- [Create a slackbot for teams and OSS projects that answer to documentation](https://mudler.pm/posts/smart-slackbot-for-teams/)
|
- [Create a slackbot for teams and OSS projects that answer to documentation](https://mudler.pm/posts/smart-slackbot-for-teams/)
|
||||||
- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
|
- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
|
||||||
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
|
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
|
||||||
|
|||||||
42
SECURITY.md
Normal file
42
SECURITY.md
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# Security Policy
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
At LocalAI, we take the security of our software seriously. We understand the importance of protecting our community from vulnerabilities and are committed to ensuring the safety and security of our users.
|
||||||
|
|
||||||
|
## Supported Versions
|
||||||
|
|
||||||
|
We provide support and updates for certain versions of our software. The following table outlines which versions are currently supported with security updates:
|
||||||
|
|
||||||
|
| Version | Supported |
|
||||||
|
| ------- | ------------------ |
|
||||||
|
| > 2.0 | :white_check_mark: |
|
||||||
|
| < 2.0 | :x: |
|
||||||
|
|
||||||
|
Please ensure that you are using a supported version to receive the latest security updates.
|
||||||
|
|
||||||
|
## Reporting a Vulnerability
|
||||||
|
|
||||||
|
We encourage the responsible disclosure of any security vulnerabilities. If you believe you've found a security issue in our software, we kindly ask you to follow the steps below to report it to us:
|
||||||
|
|
||||||
|
1. **Email Us:** Send an email to [security@localai.io](mailto:security@localai.io) with a detailed report. Please do not disclose the vulnerability publicly or to any third parties before it has been addressed by us.
|
||||||
|
|
||||||
|
2. **Expect a Response:** We aim to acknowledge receipt of vulnerability reports within 48 hours. Our security team will review your report and work closely with you to understand the impact and ensure a thorough investigation.
|
||||||
|
|
||||||
|
3. **Collaboration:** If the vulnerability is accepted, we will work with you and our community to address the issue promptly. We'll keep you informed throughout the resolution process and may request additional information or collaboration.
|
||||||
|
|
||||||
|
4. **Disclosure:** Once the vulnerability has been resolved, we encourage a coordinated disclosure. We believe in transparency and will work with you to ensure that our community is informed in a responsible manner.
|
||||||
|
|
||||||
|
## Use of Third-Party Platforms
|
||||||
|
|
||||||
|
As a Free and Open Source Software (FOSS) organization, we do not offer monetary bounties. However, researchers who wish to report vulnerabilities can also do so via [Huntr](https://huntr.dev/bounties), a platform that recognizes contributions to open source security.
|
||||||
|
|
||||||
|
## Contact
|
||||||
|
|
||||||
|
For any security-related inquiries beyond vulnerability reporting, please contact us at [security@localai.io](mailto:security@localai.io).
|
||||||
|
|
||||||
|
## Acknowledgments
|
||||||
|
|
||||||
|
We appreciate the efforts of those who contribute to the security of our project. Your responsible disclosure is invaluable to the safety and integrity of LocalAI.
|
||||||
|
|
||||||
|
Thank you for helping us keep LocalAI secure.
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
package localai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
|
|
||||||
gopsutil "github.com/shirou/gopsutil/v3/process"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BackendMonitorRequest struct {
|
|
||||||
Model string `json:"model" yaml:"model"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BackendMonitorResponse struct {
|
|
||||||
MemoryInfo *gopsutil.MemoryInfoStat
|
|
||||||
MemoryPercent float32
|
|
||||||
CPUPercent float64
|
|
||||||
}
|
|
||||||
|
|
||||||
type BackendMonitor struct {
|
|
||||||
configLoader *config.ConfigLoader
|
|
||||||
options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor {
|
|
||||||
return BackendMonitor{
|
|
||||||
configLoader: configLoader,
|
|
||||||
options: options,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) {
|
|
||||||
config, exists := bm.configLoader.GetConfig(model)
|
|
||||||
var backend string
|
|
||||||
if exists {
|
|
||||||
backend = config.Model
|
|
||||||
} else {
|
|
||||||
// Last ditch effort: use it raw, see if a backend happens to match.
|
|
||||||
backend = model
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.HasSuffix(backend, ".bin") {
|
|
||||||
backend = fmt.Sprintf("%s.bin", backend)
|
|
||||||
}
|
|
||||||
|
|
||||||
pid, err := bm.options.Loader.GetGRPCPID(backend)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
|
|
||||||
backendProcess, err := gopsutil.NewProcess(int32(pid))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
memInfo, err := backendProcess.MemoryInfo()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
memPercent, err := backendProcess.MemoryPercent()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cpuPercent, err := backendProcess.CPUPercent()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &BackendMonitorResponse{
|
|
||||||
MemoryInfo: memInfo,
|
|
||||||
MemoryPercent: memPercent,
|
|
||||||
CPUPercent: cpuPercent,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) {
|
|
||||||
input := new(BackendMonitorRequest)
|
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
config, exists := bm.configLoader.GetConfig(input.Model)
|
|
||||||
var backendId string
|
|
||||||
if exists {
|
|
||||||
backendId = config.Model
|
|
||||||
} else {
|
|
||||||
// Last ditch effort: use it raw, see if a backend happens to match.
|
|
||||||
backendId = input.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.HasSuffix(backendId, ".bin") {
|
|
||||||
backendId = fmt.Sprintf("%s.bin", backendId)
|
|
||||||
}
|
|
||||||
|
|
||||||
return backendId, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
|
|
||||||
backendId, err := bm.getModelLoaderIDFromCtx(c)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
model := bm.options.Loader.CheckIsLoaded(backendId)
|
|
||||||
if model == "" {
|
|
||||||
return fmt.Errorf("backend %s is not currently loaded", backendId)
|
|
||||||
}
|
|
||||||
|
|
||||||
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
|
|
||||||
if rpcErr != nil {
|
|
||||||
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
|
|
||||||
val, slbErr := bm.SampleLocalBackendProcess(backendId)
|
|
||||||
if slbErr != nil {
|
|
||||||
return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
|
|
||||||
}
|
|
||||||
return c.JSON(proto.StatusResponse{
|
|
||||||
State: proto.StatusResponse_ERROR,
|
|
||||||
Memory: &proto.MemoryUsageData{
|
|
||||||
Total: val.MemoryInfo.VMS,
|
|
||||||
Breakdown: map[string]uint64{
|
|
||||||
"gopsutil-RSS": val.MemoryInfo.RSS,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.JSON(status)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
backendId, err := bm.getModelLoaderIDFromCtx(c)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return bm.options.Loader.ShutdownModel(backendId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,326 +0,0 @@
|
|||||||
package localai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
json "github.com/json-iterator/go"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
type galleryOp struct {
|
|
||||||
req gallery.GalleryModel
|
|
||||||
id string
|
|
||||||
galleries []gallery.Gallery
|
|
||||||
galleryName string
|
|
||||||
}
|
|
||||||
|
|
||||||
type galleryOpStatus struct {
|
|
||||||
FileName string `json:"file_name"`
|
|
||||||
Error error `json:"error"`
|
|
||||||
Processed bool `json:"processed"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Progress float64 `json:"progress"`
|
|
||||||
TotalFileSize string `json:"file_size"`
|
|
||||||
DownloadedFileSize string `json:"downloaded_size"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type galleryApplier struct {
|
|
||||||
modelPath string
|
|
||||||
sync.Mutex
|
|
||||||
C chan galleryOp
|
|
||||||
statuses map[string]*galleryOpStatus
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGalleryService(modelPath string) *galleryApplier {
|
|
||||||
return &galleryApplier{
|
|
||||||
modelPath: modelPath,
|
|
||||||
C: make(chan galleryOp),
|
|
||||||
statuses: make(map[string]*galleryOpStatus),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
|
|
||||||
|
|
||||||
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
|
||||||
|
|
||||||
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
|
|
||||||
g.Lock()
|
|
||||||
defer g.Unlock()
|
|
||||||
g.statuses[s] = op
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
|
|
||||||
g.Lock()
|
|
||||||
defer g.Unlock()
|
|
||||||
|
|
||||||
return g.statuses[s]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus {
|
|
||||||
g.Lock()
|
|
||||||
defer g.Unlock()
|
|
||||||
|
|
||||||
return g.statuses
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.Done():
|
|
||||||
return
|
|
||||||
case op := <-g.C:
|
|
||||||
utils.ResetDownloadTimers()
|
|
||||||
|
|
||||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
|
||||||
|
|
||||||
// updates the status with an error
|
|
||||||
updateError := func(e error) {
|
|
||||||
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
|
|
||||||
}
|
|
||||||
|
|
||||||
// displayDownload displays the download progress
|
|
||||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
|
||||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
|
||||||
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
// if the request contains a gallery name, we apply the gallery from the gallery list
|
|
||||||
if op.galleryName != "" {
|
|
||||||
if strings.Contains(op.galleryName, "@") {
|
|
||||||
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
|
||||||
} else {
|
|
||||||
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
updateError(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reload models
|
|
||||||
err = cm.LoadConfigs(g.modelPath)
|
|
||||||
if err != nil {
|
|
||||||
updateError(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = cm.Preload(g.modelPath)
|
|
||||||
if err != nil {
|
|
||||||
updateError(err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
type galleryModel struct {
|
|
||||||
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
|
|
||||||
ID string `json:"id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
|
|
||||||
var err error
|
|
||||||
for _, r := range requests {
|
|
||||||
utils.ResetDownloadTimers()
|
|
||||||
if r.ID == "" {
|
|
||||||
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
|
||||||
} else {
|
|
||||||
if strings.Contains(r.ID, "@") {
|
|
||||||
err = gallery.InstallModelFromGallery(
|
|
||||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
|
||||||
} else {
|
|
||||||
err = gallery.InstallModelFromGalleryByName(
|
|
||||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
|
||||||
dat, err := os.ReadFile(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var requests []galleryModel
|
|
||||||
|
|
||||||
if err := yaml.Unmarshal(dat, &requests); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return processRequests(modelPath, s, cm, galleries, requests)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
|
||||||
var requests []galleryModel
|
|
||||||
err := json.Unmarshal([]byte(s), &requests)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return processRequests(modelPath, s, cm, galleries, requests)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Endpoint Service
|
|
||||||
|
|
||||||
type ModelGalleryService struct {
|
|
||||||
galleries []gallery.Gallery
|
|
||||||
modelPath string
|
|
||||||
galleryApplier *galleryApplier
|
|
||||||
}
|
|
||||||
|
|
||||||
type GalleryModel struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
gallery.GalleryModel
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
|
|
||||||
return ModelGalleryService{
|
|
||||||
galleries: galleries,
|
|
||||||
modelPath: modelPath,
|
|
||||||
galleryApplier: galleryApplier,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
status := mgs.galleryApplier.getStatus(c.Params("uuid"))
|
|
||||||
if status == nil {
|
|
||||||
return fmt.Errorf("could not find any status for ID")
|
|
||||||
}
|
|
||||||
return c.JSON(status)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
return c.JSON(mgs.galleryApplier.getAllStatus())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
input := new(GalleryModel)
|
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
uuid, err := uuid.NewUUID()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
mgs.galleryApplier.C <- galleryOp{
|
|
||||||
req: input.GalleryModel,
|
|
||||||
id: uuid.String(),
|
|
||||||
galleryName: input.ID,
|
|
||||||
galleries: mgs.galleries,
|
|
||||||
}
|
|
||||||
return c.JSON(struct {
|
|
||||||
ID string `json:"uuid"`
|
|
||||||
StatusURL string `json:"status"`
|
|
||||||
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
|
|
||||||
|
|
||||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Models found from galleries: %+v", models)
|
|
||||||
for _, m := range models {
|
|
||||||
log.Debug().Msgf("Model found from galleries: %+v", m)
|
|
||||||
}
|
|
||||||
dat, err := json.Marshal(models)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.Send(dat)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
|
||||||
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
|
||||||
dat, err := json.Marshal(mgs.galleries)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.Send(dat)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
input := new(gallery.Gallery)
|
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
|
||||||
return gallery.Name == input.Name
|
|
||||||
}) {
|
|
||||||
return fmt.Errorf("%s already exists", input.Name)
|
|
||||||
}
|
|
||||||
dat, err := json.Marshal(mgs.galleries)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debug().Msgf("Adding %+v to gallery list", *input)
|
|
||||||
mgs.galleries = append(mgs.galleries, *input)
|
|
||||||
return c.Send(dat)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
|
||||||
input := new(gallery.Gallery)
|
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
|
||||||
return gallery.Name == input.Name
|
|
||||||
}) {
|
|
||||||
return fmt.Errorf("%s is not currently registered", input.Name)
|
|
||||||
}
|
|
||||||
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
|
||||||
return gallery.Name == input.Name
|
|
||||||
})
|
|
||||||
return c.Send(nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -126,6 +126,11 @@ message ModelOptions {
|
|||||||
|
|
||||||
// vllm
|
// vllm
|
||||||
string Quantization = 40;
|
string Quantization = 40;
|
||||||
|
float GPUMemoryUtilization = 50;
|
||||||
|
bool TrustRemoteCode = 51;
|
||||||
|
bool EnforceEager = 52;
|
||||||
|
int32 SwapSpace = 53;
|
||||||
|
int32 MaxModelLen = 54;
|
||||||
|
|
||||||
string MMProj = 41;
|
string MMProj = 41;
|
||||||
|
|
||||||
@@ -186,6 +191,7 @@ message TTSRequest {
|
|||||||
string text = 1;
|
string text = 1;
|
||||||
string model = 2;
|
string model = 2;
|
||||||
string dst = 3;
|
string dst = 3;
|
||||||
|
string voice = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TokenizationResponse {
|
message TokenizationResponse {
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ ifeq ($(BUILD_TYPE),cublas)
|
|||||||
# to CMAKE_ARGS automatically
|
# to CMAKE_ARGS automatically
|
||||||
else ifeq ($(BUILD_TYPE),openblas)
|
else ifeq ($(BUILD_TYPE),openblas)
|
||||||
CMAKE_ARGS+=-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS
|
CMAKE_ARGS+=-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS
|
||||||
# If build type is clblast (openCL) we set -DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
|
# If build type is clblas (openCL) we set -DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||||
else ifeq ($(BUILD_TYPE),clblast)
|
else ifeq ($(BUILD_TYPE),clblas)
|
||||||
CMAKE_ARGS+=-DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
|
CMAKE_ARGS+=-DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
|
||||||
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
|
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
|
||||||
else ifeq ($(BUILD_TYPE),hipblas)
|
else ifeq ($(BUILD_TYPE),hipblas)
|
||||||
|
|||||||
@@ -58,9 +58,11 @@ struct server_params
|
|||||||
int32_t read_timeout = 600;
|
int32_t read_timeout = 600;
|
||||||
int32_t write_timeout = 600;
|
int32_t write_timeout = 600;
|
||||||
bool slots_endpoint = true;
|
bool slots_endpoint = true;
|
||||||
|
bool metrics_endpoint = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool server_verbose = false;
|
bool server_verbose = false;
|
||||||
|
bool server_log_json = true;
|
||||||
|
|
||||||
static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
|
static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
|
||||||
{
|
{
|
||||||
@@ -316,12 +318,76 @@ struct llama_client_slot
|
|||||||
}
|
}
|
||||||
|
|
||||||
void print_timings() const {
|
void print_timings() const {
|
||||||
LOG_TEE("\n");
|
char buffer[512];
|
||||||
LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
double t_token = t_prompt_processing / num_prompt_tokens_processed;
|
||||||
__func__, t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing / num_prompt_tokens_processed, 1e3 / t_prompt_processing * num_prompt_tokens_processed);
|
double n_tokens_second = 1e3 / t_prompt_processing * num_prompt_tokens_processed;
|
||||||
LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
|
||||||
__func__, t_token_generation, n_decoded,t_token_generation / n_decoded, 1e3 / t_token_generation * n_decoded);
|
t_prompt_processing, num_prompt_tokens_processed,
|
||||||
LOG_TEE("%s: total time = %10.2f ms\n", __func__, t_prompt_processing + t_token_generation);
|
t_token, n_tokens_second);
|
||||||
|
LOG_INFO(buffer, {
|
||||||
|
{"slot_id", id},
|
||||||
|
{"task_id", task_id},
|
||||||
|
{"t_prompt_processing", t_prompt_processing},
|
||||||
|
{"num_prompt_tokens_processed", num_prompt_tokens_processed},
|
||||||
|
{"t_token", t_token},
|
||||||
|
{"n_tokens_second", n_tokens_second},
|
||||||
|
});
|
||||||
|
|
||||||
|
t_token = t_token_generation / n_decoded;
|
||||||
|
n_tokens_second = 1e3 / t_token_generation * n_decoded;
|
||||||
|
sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
|
||||||
|
t_token_generation, n_decoded,
|
||||||
|
t_token, n_tokens_second);
|
||||||
|
LOG_INFO(buffer, {
|
||||||
|
{"slot_id", id},
|
||||||
|
{"task_id", task_id},
|
||||||
|
{"t_token_generation", t_token_generation},
|
||||||
|
{"n_decoded", n_decoded},
|
||||||
|
{"t_token", t_token},
|
||||||
|
{"n_tokens_second", n_tokens_second},
|
||||||
|
});
|
||||||
|
|
||||||
|
sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
|
||||||
|
LOG_INFO(buffer, {
|
||||||
|
{"slot_id", id},
|
||||||
|
{"task_id", task_id},
|
||||||
|
{"t_prompt_processing", t_prompt_processing},
|
||||||
|
{"t_token_generation", t_token_generation},
|
||||||
|
{"t_total", t_prompt_processing + t_token_generation},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_metrics {
|
||||||
|
uint64_t n_prompt_tokens_processed_total = 0;
|
||||||
|
uint64_t n_tokens_predicted_total = 0;
|
||||||
|
|
||||||
|
uint64_t n_prompt_tokens_processed = 0;
|
||||||
|
uint64_t t_prompt_processing = 0;
|
||||||
|
|
||||||
|
uint64_t n_tokens_predicted = 0;
|
||||||
|
uint64_t t_tokens_generation = 0;
|
||||||
|
|
||||||
|
|
||||||
|
void on_prompt_eval(const llama_client_slot &slot) {
|
||||||
|
n_prompt_tokens_processed_total += slot.num_prompt_tokens_processed;
|
||||||
|
|
||||||
|
n_prompt_tokens_processed += slot.num_prompt_tokens_processed;
|
||||||
|
t_prompt_processing += slot.t_prompt_processing;
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_prediction(const llama_client_slot &slot) {
|
||||||
|
n_tokens_predicted_total += slot.n_decoded;
|
||||||
|
|
||||||
|
n_tokens_predicted += slot.n_decoded;
|
||||||
|
t_tokens_generation += slot.t_token_generation;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset_bucket() {
|
||||||
|
n_prompt_tokens_processed = 0;
|
||||||
|
t_prompt_processing = 0;
|
||||||
|
n_tokens_predicted = 0;
|
||||||
|
t_tokens_generation = 0;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -359,6 +425,8 @@ struct llama_server_context
|
|||||||
llama_server_queue queue_tasks;
|
llama_server_queue queue_tasks;
|
||||||
llama_server_response queue_results;
|
llama_server_response queue_results;
|
||||||
|
|
||||||
|
llama_metrics metrics;
|
||||||
|
|
||||||
~llama_server_context()
|
~llama_server_context()
|
||||||
{
|
{
|
||||||
if (ctx)
|
if (ctx)
|
||||||
@@ -378,7 +446,7 @@ struct llama_server_context
|
|||||||
params = params_;
|
params = params_;
|
||||||
if (!params.mmproj.empty()) {
|
if (!params.mmproj.empty()) {
|
||||||
multimodal = true;
|
multimodal = true;
|
||||||
LOG_TEE("Multi Modal Mode Enabled");
|
LOG_INFO("Multi Modal Mode Enabled", {});
|
||||||
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
|
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
|
||||||
if(clp_ctx == nullptr) {
|
if(clp_ctx == nullptr) {
|
||||||
LOG_ERROR("unable to load clip model", {{"model", params.mmproj}});
|
LOG_ERROR("unable to load clip model", {{"model", params.mmproj}});
|
||||||
@@ -415,13 +483,23 @@ struct llama_server_context
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void validate_model_chat_template(server_params & sparams) {
|
||||||
|
llama_chat_message chat[] = {{"user", "test"}};
|
||||||
|
std::vector<char> buf(1);
|
||||||
|
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
|
||||||
|
if (res < 0) {
|
||||||
|
LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
||||||
|
sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void initialize() {
|
void initialize() {
|
||||||
// create slots
|
// create slots
|
||||||
all_slots_are_idle = true;
|
all_slots_are_idle = true;
|
||||||
|
|
||||||
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
|
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
|
||||||
|
|
||||||
LOG_TEE("Available slots:\n");
|
LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
|
||||||
for (int i = 0; i < params.n_parallel; i++)
|
for (int i = 0; i < params.n_parallel; i++)
|
||||||
{
|
{
|
||||||
llama_client_slot slot;
|
llama_client_slot slot;
|
||||||
@@ -430,7 +508,10 @@ struct llama_server_context
|
|||||||
slot.n_ctx = n_ctx_slot;
|
slot.n_ctx = n_ctx_slot;
|
||||||
slot.n_predict = params.n_predict;
|
slot.n_predict = params.n_predict;
|
||||||
|
|
||||||
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
|
LOG_INFO("new slot", {
|
||||||
|
{"slot_id", slot.id},
|
||||||
|
{"n_ctx_slot", slot.n_ctx}
|
||||||
|
});
|
||||||
|
|
||||||
const int ga_n = params.grp_attn_n;
|
const int ga_n = params.grp_attn_n;
|
||||||
const int ga_w = params.grp_attn_w;
|
const int ga_w = params.grp_attn_w;
|
||||||
@@ -440,7 +521,12 @@ struct llama_server_context
|
|||||||
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
|
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
|
||||||
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
|
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
|
||||||
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
|
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
|
||||||
LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
|
|
||||||
|
LOG_INFO("slot self-extend", {
|
||||||
|
{"slot_id", slot.id},
|
||||||
|
{"ga_n", ga_n},
|
||||||
|
{"ga_w", ga_w}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.ga_i = 0;
|
slot.ga_i = 0;
|
||||||
@@ -726,10 +812,16 @@ struct llama_server_context
|
|||||||
img_sl.img_data = clip_image_u8_init();
|
img_sl.img_data = clip_image_u8_init();
|
||||||
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
|
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
|
||||||
{
|
{
|
||||||
LOG_TEE("slot %i - failed to load image [id: %i]\n", slot->id, img_sl.id);
|
LOG_ERROR("failed to load image", {
|
||||||
|
{"slot_id", slot->id},
|
||||||
|
{"img_sl_id", img_sl.id}
|
||||||
|
});
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
LOG_TEE("slot %i - loaded image\n", slot->id);
|
LOG_VERBOSE("image loaded", {
|
||||||
|
{"slot_id", slot->id},
|
||||||
|
{"img_sl_id", img_sl.id}
|
||||||
|
});
|
||||||
img_sl.request_encode_image = true;
|
img_sl.request_encode_image = true;
|
||||||
slot->images.push_back(img_sl);
|
slot->images.push_back(img_sl);
|
||||||
}
|
}
|
||||||
@@ -789,7 +881,10 @@ struct llama_server_context
|
|||||||
|
|
||||||
all_slots_are_idle = false;
|
all_slots_are_idle = false;
|
||||||
|
|
||||||
LOG_TEE("slot %i is processing [task id: %i]\n", slot->id, slot->task_id);
|
LOG_INFO("slot is processing task", {
|
||||||
|
{"slot_id", slot->id},
|
||||||
|
{"task_id", slot->task_id},
|
||||||
|
});
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -814,10 +909,24 @@ struct llama_server_context
|
|||||||
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0)
|
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
|
||||||
{
|
{
|
||||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
|
||||||
return;
|
llama_batch batch_view = {
|
||||||
|
n_tokens,
|
||||||
|
batch.token + i,
|
||||||
|
nullptr,
|
||||||
|
batch.pos + i,
|
||||||
|
batch.n_seq_id + i,
|
||||||
|
batch.seq_id + i,
|
||||||
|
batch.logits + i,
|
||||||
|
0, 0, 0, // unused
|
||||||
|
};
|
||||||
|
if (llama_decode(ctx, batch_view) != 0)
|
||||||
|
{
|
||||||
|
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// assign the system KV cache to all parallel sequences
|
// assign the system KV cache to all parallel sequences
|
||||||
@@ -1351,7 +1460,7 @@ struct llama_server_context
|
|||||||
if (slot == nullptr)
|
if (slot == nullptr)
|
||||||
{
|
{
|
||||||
// if no slot is available, we defer this task for processing later
|
// if no slot is available, we defer this task for processing later
|
||||||
LOG_VERBOSE("no slot is available", {});
|
LOG_VERBOSE("no slot is available", {{"task_id", task.id}});
|
||||||
queue_tasks.defer(task);
|
queue_tasks.defer(task);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -1425,7 +1534,7 @@ struct llama_server_context
|
|||||||
bool update_slots() {
|
bool update_slots() {
|
||||||
if (system_need_update)
|
if (system_need_update)
|
||||||
{
|
{
|
||||||
LOG_TEE("updating system prompt\n");
|
LOG_INFO("updating system prompt", {});
|
||||||
update_system_prompt();
|
update_system_prompt();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1435,12 +1544,13 @@ struct llama_server_context
|
|||||||
{
|
{
|
||||||
if (system_prompt.empty() && clean_kv_cache)
|
if (system_prompt.empty() && clean_kv_cache)
|
||||||
{
|
{
|
||||||
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
|
LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});
|
||||||
kv_cache_clear();
|
kv_cache_clear();
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LOG_VERBOSE("posting NEXT_RESPONSE", {});
|
||||||
task_server task;
|
task_server task;
|
||||||
task.type = TASK_TYPE_NEXT_RESPONSE;
|
task.type = TASK_TYPE_NEXT_RESPONSE;
|
||||||
task.target_id = -1;
|
task.target_id = -1;
|
||||||
@@ -1471,6 +1581,7 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
|
|
||||||
// decode any currently ongoing sequences
|
// decode any currently ongoing sequences
|
||||||
|
LOG_VERBOSE("decoding ongoing sequences", {});
|
||||||
for (auto & slot : slots)
|
for (auto & slot : slots)
|
||||||
{
|
{
|
||||||
// release the slot
|
// release the slot
|
||||||
@@ -1480,7 +1591,15 @@ struct llama_server_context
|
|||||||
slot.command = NONE;
|
slot.command = NONE;
|
||||||
slot.t_last_used = ggml_time_us();
|
slot.t_last_used = ggml_time_us();
|
||||||
|
|
||||||
LOG_TEE("slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size());
|
LOG_INFO("slot released", {
|
||||||
|
{"slot_id", slot.id},
|
||||||
|
{"task_id", slot.task_id},
|
||||||
|
{"n_ctx", n_ctx},
|
||||||
|
{"n_past", slot.n_past},
|
||||||
|
{"n_system_tokens", system_tokens.size()},
|
||||||
|
{"n_cache_tokens", slot.cache_tokens.size()},
|
||||||
|
{"truncated", slot.truncated}
|
||||||
|
});
|
||||||
queue_tasks.notify_slot_changed();
|
queue_tasks.notify_slot_changed();
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
@@ -1607,6 +1726,14 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
|
|
||||||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||||
|
|
||||||
|
// the last token of the cache is not in the KV cache until the next call to llama_decode
|
||||||
|
// (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
|
||||||
|
if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size())
|
||||||
|
{
|
||||||
|
slot.n_past -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
||||||
|
|
||||||
if (slot.ga_n != 1)
|
if (slot.ga_n != 1)
|
||||||
@@ -1628,7 +1755,12 @@ struct llama_server_context
|
|||||||
slot.ga_i = ga_i;
|
slot.ga_i = ga_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
|
LOG_INFO("slot progression", {
|
||||||
|
{ "slot_id", slot.id },
|
||||||
|
{ "task_id", slot.task_id },
|
||||||
|
{ "n_past", slot.n_past },
|
||||||
|
{ "num_prompt_tokens_processed", slot.num_prompt_tokens_processed }
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.cache_tokens = prompt_tokens;
|
slot.cache_tokens = prompt_tokens;
|
||||||
@@ -1636,7 +1768,10 @@ struct llama_server_context
|
|||||||
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
|
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
|
||||||
{
|
{
|
||||||
// we have to evaluate at least 1 token to generate logits.
|
// we have to evaluate at least 1 token to generate logits.
|
||||||
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
|
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
|
||||||
|
{ "slot_id", slot.id },
|
||||||
|
{ "task_id", slot.task_id }
|
||||||
|
});
|
||||||
slot.n_past--;
|
slot.n_past--;
|
||||||
if (slot.ga_i > 0)
|
if (slot.ga_i > 0)
|
||||||
{
|
{
|
||||||
@@ -1644,9 +1779,13 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
|
int p0 = (int) system_tokens.size() + slot.n_past;
|
||||||
|
LOG_INFO("kv cache rm [p0, end)", {
|
||||||
llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
|
{ "slot_id", slot.id },
|
||||||
|
{ "task_id", slot.task_id },
|
||||||
|
{ "p0", p0 }
|
||||||
|
});
|
||||||
|
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
|
||||||
|
|
||||||
LOG_VERBOSE("prompt ingested", {
|
LOG_VERBOSE("prompt ingested", {
|
||||||
{"n_past", slot.n_past},
|
{"n_past", slot.n_past},
|
||||||
@@ -1681,7 +1820,13 @@ struct llama_server_context
|
|||||||
|
|
||||||
if (has_images && !ingest_images(slot, n_batch))
|
if (has_images && !ingest_images(slot, n_batch))
|
||||||
{
|
{
|
||||||
LOG_TEE("failed processing images\n");
|
LOG_ERROR("failed processing images", {
|
||||||
|
"slot_id", slot.id,
|
||||||
|
"task_id", slot.task_id,
|
||||||
|
});
|
||||||
|
// FIXME @phymbert: to be properly tested
|
||||||
|
// early returning without changing the slot state will block the slot for ever
|
||||||
|
// no one at the moment is checking the return value
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1723,9 +1868,9 @@ struct llama_server_context
|
|||||||
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
||||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
||||||
|
|
||||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
||||||
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
||||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
|
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
|
||||||
|
|
||||||
slot.n_past_se -= bd;
|
slot.n_past_se -= bd;
|
||||||
|
|
||||||
@@ -1781,7 +1926,7 @@ struct llama_server_context
|
|||||||
send_embedding(slot);
|
send_embedding(slot);
|
||||||
slot.release();
|
slot.release();
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
return true;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
@@ -1794,6 +1939,7 @@ struct llama_server_context
|
|||||||
{
|
{
|
||||||
slot.t_start_genereration = ggml_time_us();
|
slot.t_start_genereration = ggml_time_us();
|
||||||
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
|
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
|
||||||
|
metrics.on_prompt_eval(slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
|
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
|
||||||
@@ -1816,11 +1962,14 @@ struct llama_server_context
|
|||||||
slot.release();
|
slot.release();
|
||||||
slot.print_timings();
|
slot.print_timings();
|
||||||
send_final_response(slot);
|
send_final_response(slot);
|
||||||
|
metrics.on_prediction(slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LOG_VERBOSE("slots updated", {});
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1849,18 +1998,6 @@ static json format_partial_response(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static json format_tokenizer_response(const std::vector<llama_token> &tokens)
|
|
||||||
{
|
|
||||||
return json{
|
|
||||||
{"tokens", tokens}};
|
|
||||||
}
|
|
||||||
|
|
||||||
static json format_detokenized_response(std::string content)
|
|
||||||
{
|
|
||||||
return json{
|
|
||||||
{"content", content}};
|
|
||||||
}
|
|
||||||
|
|
||||||
struct token_translator
|
struct token_translator
|
||||||
{
|
{
|
||||||
llama_context * ctx;
|
llama_context * ctx;
|
||||||
@@ -2119,9 +2256,9 @@ static void params_parse(const backend::ModelOptions* request,
|
|||||||
params.use_mmap = request->mmap();
|
params.use_mmap = request->mmap();
|
||||||
params.embedding = request->embeddings();
|
params.embedding = request->embeddings();
|
||||||
|
|
||||||
if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
|
if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
|
||||||
else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
|
else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
|
||||||
else { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
|
else { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
|
||||||
if ( request->yarnextfactor() != 0.0f ) {
|
if ( request->yarnextfactor() != 0.0f ) {
|
||||||
params.yarn_ext_factor = request->yarnextfactor();
|
params.yarn_ext_factor = request->yarnextfactor();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,21 +11,21 @@ import (
|
|||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func sh(c string) (string, error) {
|
func runCommand(command []string) (string, error) {
|
||||||
cmd := exec.Command("/bin/sh", "-c", c)
|
cmd := exec.Command(command[0], command[1:]...)
|
||||||
cmd.Env = os.Environ()
|
cmd.Env = os.Environ()
|
||||||
o, err := cmd.CombinedOutput()
|
out, err := cmd.CombinedOutput()
|
||||||
return string(o), err
|
return string(out), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// AudioToWav converts audio to wav for transcribe. It bashes out to ffmpeg
|
// AudioToWav converts audio to wav for transcribe.
|
||||||
// TODO: use https://github.com/mccoyst/ogg?
|
// TODO: use https://github.com/mccoyst/ogg?
|
||||||
func audioToWav(src, dst string) error {
|
func audioToWav(src, dst string) error {
|
||||||
out, err := sh(fmt.Sprintf("ffmpeg -i %s -format s16le -ar 16000 -ac 1 -acodec pcm_s16le %s", src, dst))
|
command := []string{"ffmpeg", "-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst}
|
||||||
|
out, err := runCommand(command)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error: %w out: %s", err, out)
|
return fmt.Errorf("error: %w out: %s", err, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
model = AutoGPTQForCausalLM.from_quantized(request.Model,
|
model = AutoGPTQForCausalLM.from_quantized(request.Model,
|
||||||
model_basename=request.ModelBaseName,
|
model_basename=request.ModelBaseName,
|
||||||
use_safetensors=True,
|
use_safetensors=True,
|
||||||
trust_remote_code=True,
|
trust_remote_code=request.TrustRemoteCode,
|
||||||
device=device,
|
device=device,
|
||||||
use_triton=request.UseTriton,
|
use_triton=request.UseTriton,
|
||||||
quantize_config=None)
|
quantize_config=None)
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ dependencies:
|
|||||||
- regex==2023.10.3
|
- regex==2023.10.3
|
||||||
- requests==2.31.0
|
- requests==2.31.0
|
||||||
- rouge==1.0.1
|
- rouge==1.0.1
|
||||||
- safetensors==0.3.3
|
- safetensors>=0.3.3
|
||||||
- six==1.16.0
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
- tokenizers==0.14.0
|
- tokenizers==0.14.0
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -8,6 +8,13 @@ ifeq ($(BUILD_TYPE), hipblas)
|
|||||||
CONDA_ENV_PATH = "transformers-rocm.yml"
|
CONDA_ENV_PATH = "transformers-rocm.yml"
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
# Intel GPU are supposed to have dependencies installed in the main python
|
||||||
|
# environment, so we skip conda installation for SYCL builds.
|
||||||
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
|
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
|
||||||
|
export SKIP_CONDA=1
|
||||||
|
endif
|
||||||
|
|
||||||
.PHONY: transformers
|
.PHONY: transformers
|
||||||
transformers:
|
transformers:
|
||||||
@echo "Installing $(CONDA_ENV_PATH)..."
|
@echo "Installing $(CONDA_ENV_PATH)..."
|
||||||
|
|||||||
@@ -1,24 +1,38 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -ex
|
set -ex
|
||||||
|
|
||||||
|
SKIP_CONDA=${SKIP_CONDA:-0}
|
||||||
|
|
||||||
# Check if environment exist
|
# Check if environment exist
|
||||||
conda_env_exists(){
|
conda_env_exists(){
|
||||||
! conda list --name "${@}" >/dev/null 2>/dev/null
|
! conda list --name "${@}" >/dev/null 2>/dev/null
|
||||||
}
|
}
|
||||||
|
|
||||||
if conda_env_exists "transformers" ; then
|
if [ $SKIP_CONDA -eq 1 ]; then
|
||||||
echo "Creating virtual environment..."
|
echo "Skipping conda environment installation"
|
||||||
conda env create --name transformers --file $1
|
else
|
||||||
echo "Virtual environment created."
|
export PATH=$PATH:/opt/conda/bin
|
||||||
else
|
if conda_env_exists "transformers" ; then
|
||||||
echo "Virtual environment already exists."
|
echo "Creating virtual environment..."
|
||||||
|
conda env create --name transformers --file $1
|
||||||
|
echo "Virtual environment created."
|
||||||
|
else
|
||||||
|
echo "Virtual environment already exists."
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -d "/opt/intel" ]; then
|
||||||
|
# Intel GPU: If the directory exists, we assume we are using the intel image
|
||||||
|
# (no conda env)
|
||||||
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
|
pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$PIP_CACHE_PURGE" = true ] ; then
|
if [ "$PIP_CACHE_PURGE" = true ] ; then
|
||||||
export PATH=$PATH:/opt/conda/bin
|
if [ $SKIP_CONDA -eq 0 ]; then
|
||||||
|
# Activate conda environment
|
||||||
# Activate conda environment
|
source activate transformers
|
||||||
source activate transformers
|
fi
|
||||||
|
|
||||||
pip cache purge
|
pip cache purge
|
||||||
fi
|
fi
|
||||||
@@ -36,7 +36,7 @@ dependencies:
|
|||||||
- TTS==0.22.0
|
- TTS==0.22.0
|
||||||
- charset-normalizer==3.3.0
|
- charset-normalizer==3.3.0
|
||||||
- datasets==2.14.5
|
- datasets==2.14.5
|
||||||
- sentence-transformers==2.2.2
|
- sentence-transformers==2.5.1 # Updated Version
|
||||||
- sentencepiece==0.1.99
|
- sentencepiece==0.1.99
|
||||||
- dill==0.3.7
|
- dill==0.3.7
|
||||||
- einops==0.7.0
|
- einops==0.7.0
|
||||||
@@ -81,8 +81,8 @@ dependencies:
|
|||||||
- requests==2.31.0
|
- requests==2.31.0
|
||||||
- rouge==1.0.1
|
- rouge==1.0.1
|
||||||
- s3transfer==0.7.0
|
- s3transfer==0.7.0
|
||||||
- safetensors==0.3.3
|
- safetensors>=0.4.1
|
||||||
- scipy==1.11.3
|
- scipy==1.12.0 # Updated Version
|
||||||
- six==1.16.0
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
- tokenizers
|
- tokenizers
|
||||||
@@ -113,7 +113,7 @@ dependencies:
|
|||||||
- sudachipy
|
- sudachipy
|
||||||
- sudachidict_core
|
- sudachidict_core
|
||||||
- vocos
|
- vocos
|
||||||
- vllm==0.2.7
|
- vllm==0.3.2
|
||||||
- transformers>=4.36.0 # Required for Mixtral.
|
- transformers>=4.38.2 # Updated Version
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ dependencies:
|
|||||||
- TTS==0.22.0
|
- TTS==0.22.0
|
||||||
- charset-normalizer==3.3.0
|
- charset-normalizer==3.3.0
|
||||||
- datasets==2.14.5
|
- datasets==2.14.5
|
||||||
- sentence-transformers==2.2.2
|
- sentence-transformers==2.5.1 # Updated Version
|
||||||
- sentencepiece==0.1.99
|
- sentencepiece==0.1.99
|
||||||
- dill==0.3.7
|
- dill==0.3.7
|
||||||
- einops==0.7.0
|
- einops==0.7.0
|
||||||
@@ -71,8 +71,8 @@ dependencies:
|
|||||||
- requests==2.31.0
|
- requests==2.31.0
|
||||||
- rouge==1.0.1
|
- rouge==1.0.1
|
||||||
- s3transfer==0.7.0
|
- s3transfer==0.7.0
|
||||||
- safetensors==0.3.3
|
- safetensors>=0.4.1
|
||||||
- scipy==1.11.3
|
- scipy==1.12.0 # Updated Version
|
||||||
- six==1.16.0
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
- tokenizers
|
- tokenizers
|
||||||
@@ -103,7 +103,7 @@ dependencies:
|
|||||||
- sudachipy
|
- sudachipy
|
||||||
- sudachidict_core
|
- sudachidict_core
|
||||||
- vocos
|
- vocos
|
||||||
- vllm==0.2.7
|
- vllm==0.3.2
|
||||||
- transformers>=4.36.0 # Required for Mixtral.
|
- transformers>=4.38.2 # Updated Version
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ dependencies:
|
|||||||
- TTS==0.22.0
|
- TTS==0.22.0
|
||||||
- charset-normalizer==3.3.0
|
- charset-normalizer==3.3.0
|
||||||
- datasets==2.14.5
|
- datasets==2.14.5
|
||||||
- sentence-transformers==2.2.2
|
- sentence-transformers==2.5.1 # Updated Version
|
||||||
- sentencepiece==0.1.99
|
- sentencepiece==0.1.99
|
||||||
- dill==0.3.7
|
- dill==0.3.7
|
||||||
- einops==0.7.0
|
- einops==0.7.0
|
||||||
@@ -69,8 +69,8 @@ dependencies:
|
|||||||
- requests==2.31.0
|
- requests==2.31.0
|
||||||
- rouge==1.0.1
|
- rouge==1.0.1
|
||||||
- s3transfer==0.7.0
|
- s3transfer==0.7.0
|
||||||
- safetensors==0.3.3
|
- safetensors>=0.4.1
|
||||||
- scipy==1.11.3
|
- scipy==1.12.0 # Updated Version
|
||||||
- six==1.16.0
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
- tokenizers
|
- tokenizers
|
||||||
@@ -101,7 +101,7 @@ dependencies:
|
|||||||
- sudachipy
|
- sudachipy
|
||||||
- sudachidict_core
|
- sudachidict_core
|
||||||
- vocos
|
- vocos
|
||||||
- vllm==0.2.7
|
- vllm==0.3.2
|
||||||
- transformers>=4.36.0 # Required for Mixtral.
|
- transformers>=4.38.2 # Updated Version
|
||||||
- xformers==0.0.23.post1
|
- xformers==0.0.23.post1
|
||||||
prefix: /opt/conda/envs/transformers
|
prefix: /opt/conda/envs/transformers
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -4,6 +4,13 @@ ifeq ($(BUILD_TYPE), hipblas)
|
|||||||
export CONDA_ENV_PATH = "diffusers-rocm.yml"
|
export CONDA_ENV_PATH = "diffusers-rocm.yml"
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
# Intel GPU are supposed to have dependencies installed in the main python
|
||||||
|
# environment, so we skip conda installation for SYCL builds.
|
||||||
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
|
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
|
||||||
|
export SKIP_CONDA=1
|
||||||
|
endif
|
||||||
|
|
||||||
.PHONY: diffusers
|
.PHONY: diffusers
|
||||||
diffusers:
|
diffusers:
|
||||||
@echo "Installing $(CONDA_ENV_PATH)..."
|
@echo "Installing $(CONDA_ENV_PATH)..."
|
||||||
|
|||||||
@@ -21,14 +21,15 @@ from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipelin
|
|||||||
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
|
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion import safety_checker
|
from diffusers.pipelines.stable_diffusion import safety_checker
|
||||||
from diffusers.utils import load_image,export_to_video
|
from diffusers.utils import load_image,export_to_video
|
||||||
from compel import Compel
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
|
|
||||||
from transformers import CLIPTextModel
|
from transformers import CLIPTextModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
COMPEL=os.environ.get("COMPEL", "1") == "1"
|
COMPEL=os.environ.get("COMPEL", "0") == "1"
|
||||||
|
XPU=os.environ.get("XPU", "0") == "1"
|
||||||
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
|
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
|
||||||
SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"
|
SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"
|
||||||
CHUNK_SIZE=os.environ.get("CHUNK_SIZE", "8")
|
CHUNK_SIZE=os.environ.get("CHUNK_SIZE", "8")
|
||||||
@@ -36,6 +37,10 @@ FPS=os.environ.get("FPS", "7")
|
|||||||
DISABLE_CPU_OFFLOAD=os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
|
DISABLE_CPU_OFFLOAD=os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
|
||||||
FRAMES=os.environ.get("FRAMES", "64")
|
FRAMES=os.environ.get("FRAMES", "64")
|
||||||
|
|
||||||
|
if XPU:
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
print(ipex.xpu.get_device_name(0))
|
||||||
|
|
||||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||||
|
|
||||||
@@ -231,8 +236,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
if request.SchedulerType != "":
|
if request.SchedulerType != "":
|
||||||
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
|
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
|
||||||
|
|
||||||
if not self.img2vid:
|
if COMPEL:
|
||||||
self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
|
self.compel = Compel(
|
||||||
|
tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2 ],
|
||||||
|
text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
|
||||||
|
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
|
||||||
|
requires_pooled=[False, True]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if request.ControlNet:
|
if request.ControlNet:
|
||||||
@@ -247,6 +257,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
self.pipe.to('cuda')
|
self.pipe.to('cuda')
|
||||||
if self.controlnet:
|
if self.controlnet:
|
||||||
self.controlnet.to('cuda')
|
self.controlnet.to('cuda')
|
||||||
|
if XPU:
|
||||||
|
self.pipe = self.pipe.to("xpu")
|
||||||
# Assume directory from request.ModelFile.
|
# Assume directory from request.ModelFile.
|
||||||
# Only if request.LoraAdapter it's not an absolute path
|
# Only if request.LoraAdapter it's not an absolute path
|
||||||
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
||||||
@@ -386,8 +398,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
|
|
||||||
image = {}
|
image = {}
|
||||||
if COMPEL:
|
if COMPEL:
|
||||||
conditioning = self.compel.build_conditioning_tensor(prompt)
|
conditioning, pooled = self.compel.build_conditioning_tensor(prompt)
|
||||||
kwargs["prompt_embeds"]= conditioning
|
kwargs["prompt_embeds"] = conditioning
|
||||||
|
kwargs["pooled_prompt_embeds"] = pooled
|
||||||
# pass the kwargs dictionary to the self.pipe method
|
# pass the kwargs dictionary to the self.pipe method
|
||||||
image = self.pipe(
|
image = self.pipe(
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,24 +1,50 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -ex
|
set -ex
|
||||||
|
|
||||||
|
SKIP_CONDA=${SKIP_CONDA:-0}
|
||||||
|
|
||||||
# Check if environment exist
|
# Check if environment exist
|
||||||
conda_env_exists(){
|
conda_env_exists(){
|
||||||
! conda list --name "${@}" >/dev/null 2>/dev/null
|
! conda list --name "${@}" >/dev/null 2>/dev/null
|
||||||
}
|
}
|
||||||
|
|
||||||
if conda_env_exists "diffusers" ; then
|
if [ $SKIP_CONDA -eq 1 ]; then
|
||||||
echo "Creating virtual environment..."
|
echo "Skipping conda environment installation"
|
||||||
conda env create --name diffusers --file $1
|
else
|
||||||
echo "Virtual environment created."
|
export PATH=$PATH:/opt/conda/bin
|
||||||
else
|
if conda_env_exists "diffusers" ; then
|
||||||
echo "Virtual environment already exists."
|
echo "Creating virtual environment..."
|
||||||
|
conda env create --name diffusers --file $1
|
||||||
|
echo "Virtual environment created."
|
||||||
|
else
|
||||||
|
echo "Virtual environment already exists."
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -d "/opt/intel" ]; then
|
||||||
|
# Intel GPU: If the directory exists, we assume we are using the Intel image
|
||||||
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
|
pip install torch==2.1.0a0 \
|
||||||
|
torchvision==0.16.0a0 \
|
||||||
|
torchaudio==2.1.0a0 \
|
||||||
|
intel-extension-for-pytorch==2.1.10+xpu \
|
||||||
|
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||||
|
|
||||||
|
pip install google-api-python-client \
|
||||||
|
grpcio \
|
||||||
|
grpcio-tools \
|
||||||
|
diffusers==0.24.0 \
|
||||||
|
transformers>=4.25.1 \
|
||||||
|
accelerate \
|
||||||
|
compel==2.0.2 \
|
||||||
|
Pillow
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$PIP_CACHE_PURGE" = true ] ; then
|
if [ "$PIP_CACHE_PURGE" = true ] ; then
|
||||||
export PATH=$PATH:/opt/conda/bin
|
if [ $SKIP_CONDA -ne 1 ]; then
|
||||||
|
# Activate conda environment
|
||||||
# Activate conda environment
|
source activate diffusers
|
||||||
source activate diffusers
|
fi
|
||||||
|
|
||||||
pip cache purge
|
pip cache purge
|
||||||
fi
|
fi
|
||||||
@@ -3,10 +3,15 @@
|
|||||||
##
|
##
|
||||||
## A bash script wrapper that runs the diffusers server with conda
|
## A bash script wrapper that runs the diffusers server with conda
|
||||||
|
|
||||||
export PATH=$PATH:/opt/conda/bin
|
if [ -d "/opt/intel" ]; then
|
||||||
|
# Assumes we are using the Intel oneAPI container image
|
||||||
# Activate conda environment
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
source activate diffusers
|
export XPU=1
|
||||||
|
else
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
# Activate conda environment
|
||||||
|
source activate diffusers
|
||||||
|
fi
|
||||||
|
|
||||||
# get the directory where the bash script is located
|
# get the directory where the bash script is located
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
|
export CONDA_ENV_PATH = "exllama.yml"
|
||||||
|
|
||||||
.PHONY: exllama
|
.PHONY: exllama
|
||||||
exllama:
|
exllama:
|
||||||
$(MAKE) -C ../common-env/transformers
|
bash install.sh ${CONDA_ENV_PATH}
|
||||||
bash install.sh
|
|
||||||
|
|
||||||
.PHONY: run
|
.PHONY: run
|
||||||
run:
|
run:
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,14 +1,27 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
set -ex
|
||||||
|
|
||||||
##
|
|
||||||
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
|
||||||
export PATH=$PATH:/opt/conda/bin
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
# Activate conda environment
|
if [ "$BUILD_TYPE" != "cublas" ]; then
|
||||||
source activate transformers
|
echo "[exllama] Attention!!! Nvidia GPU is required - skipping installation"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
echo $CONDA_PREFIX
|
# Check if environment exist
|
||||||
|
conda_env_exists(){
|
||||||
|
! conda list --name "${@}" >/dev/null 2>/dev/null
|
||||||
|
}
|
||||||
|
|
||||||
|
if conda_env_exists "exllama" ; then
|
||||||
|
echo "Creating virtual environment..."
|
||||||
|
conda env create --name exllama --file $1
|
||||||
|
echo "Virtual environment created."
|
||||||
|
else
|
||||||
|
echo "Virtual environment already exists."
|
||||||
|
fi
|
||||||
|
|
||||||
|
source activate exllama
|
||||||
|
|
||||||
git clone https://github.com/turboderp/exllama $CONDA_PREFIX/exllama && pushd $CONDA_PREFIX/exllama && pip install -r requirements.txt && popd
|
git clone https://github.com/turboderp/exllama $CONDA_PREFIX/exllama && pushd $CONDA_PREFIX/exllama && pip install -r requirements.txt && popd
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
##
|
##
|
||||||
## A bash script wrapper that runs the exllama server with conda
|
## A bash script wrapper that runs the exllama server with conda
|
||||||
|
|
||||||
export PATH=$PATH:/opt/conda/bin
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
# Activate conda environment
|
# Activate conda environment
|
||||||
source activate transformers
|
source activate exllama
|
||||||
|
|
||||||
# get the directory where the bash script is located
|
# get the directory where the bash script is located
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -2,10 +2,14 @@
|
|||||||
set -e
|
set -e
|
||||||
##
|
##
|
||||||
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
||||||
export PATH=$PATH:/opt/conda/bin
|
|
||||||
export SHA=c0ddebaaaf8ffd1b3529c2bb654e650bce2f790f
|
export SHA=c0ddebaaaf8ffd1b3529c2bb654e650bce2f790f
|
||||||
|
|
||||||
# Activate conda environment
|
if [ "$BUILD_TYPE" != "cublas" ]; then
|
||||||
|
echo "[exllamav2] Attention!!! Nvidia GPU is required - skipping installation"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
source activate transformers
|
source activate transformers
|
||||||
|
|
||||||
echo $CONDA_PREFIX
|
echo $CONDA_PREFIX
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -2,13 +2,14 @@
|
|||||||
set -e
|
set -e
|
||||||
##
|
##
|
||||||
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
||||||
export PATH=$PATH:/opt/conda/bin
|
|
||||||
|
|
||||||
if [ "$BUILD_TYPE" != "cublas" ]; then
|
if [ "$BUILD_TYPE" != "cublas" ]; then
|
||||||
echo "[mamba] Attention!!! nvcc is required - skipping installation"
|
echo "[mamba] Attention!!! nvcc is required - skipping installation"
|
||||||
exit 0
|
exit 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
# Activate conda environment
|
# Activate conda environment
|
||||||
source activate transformers
|
source activate transformers
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
.PHONY: petals
|
.PHONY: petals
|
||||||
petals:
|
petals:
|
||||||
@echo "Creating virtual environment..."
|
@echo "Creating virtual environment..."
|
||||||
@conda env create --name petals --file petals.yml
|
bash install.sh "petals.yml"
|
||||||
@echo "Virtual environment created."
|
@echo "Virtual environment created."
|
||||||
|
|
||||||
.PHONY: run
|
.PHONY: run
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
5
backend/python/petals/install.sh
Normal file
5
backend/python/petals/install.sh
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
|
||||||
|
conda env create --name petals --file $1
|
||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -3,10 +3,16 @@
|
|||||||
##
|
##
|
||||||
## A bash script wrapper that runs the transformers server with conda
|
## A bash script wrapper that runs the transformers server with conda
|
||||||
|
|
||||||
export PATH=$PATH:/opt/conda/bin
|
|
||||||
|
|
||||||
# Activate conda environment
|
if [ -d "/opt/intel" ]; then
|
||||||
source activate transformers
|
# Assumes we are using the Intel oneAPI container image
|
||||||
|
# https://github.com/intel/intel-extension-for-pytorch/issues/538
|
||||||
|
export XPU=1
|
||||||
|
else
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
# Activate conda environment
|
||||||
|
source activate transformers
|
||||||
|
fi
|
||||||
|
|
||||||
# get the directory where the bash script is located
|
# get the directory where the bash script is located
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
|
|||||||
@@ -16,7 +16,15 @@ import backend_pb2_grpc
|
|||||||
import grpc
|
import grpc
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed
|
|
||||||
|
XPU=os.environ.get("XPU", "0") == "1"
|
||||||
|
if XPU:
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
|
||||||
|
from transformers import AutoTokenizer, AutoModel, set_seed
|
||||||
|
else:
|
||||||
|
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed
|
||||||
|
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
@@ -69,12 +77,25 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
model_name = request.Model
|
model_name = request.Model
|
||||||
try:
|
try:
|
||||||
if request.Type == "AutoModelForCausalLM":
|
if request.Type == "AutoModelForCausalLM":
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
if XPU:
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode,
|
||||||
|
device_map="xpu", load_in_4bit=True)
|
||||||
|
else:
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||||
else:
|
else:
|
||||||
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode)
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
self.CUDA = False
|
self.CUDA = False
|
||||||
|
self.XPU = False
|
||||||
|
|
||||||
|
if XPU:
|
||||||
|
self.XPU = True
|
||||||
|
try:
|
||||||
|
print("Optimizing model", model_name, "to XPU.", file=sys.stderr)
|
||||||
|
self.model = ipex.optimize_transformers(self.model, inplace=True, dtype=torch.float16, device="xpu")
|
||||||
|
except Exception as err:
|
||||||
|
print("Not using XPU:", err, file=sys.stderr)
|
||||||
|
|
||||||
if request.CUDA or torch.cuda.is_available():
|
if request.CUDA or torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
@@ -139,6 +160,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
inputs = self.tokenizer(request.Prompt, return_tensors="pt").input_ids
|
inputs = self.tokenizer(request.Prompt, return_tensors="pt").input_ids
|
||||||
if self.CUDA:
|
if self.CUDA:
|
||||||
inputs = inputs.to("cuda")
|
inputs = inputs.to("cuda")
|
||||||
|
if XPU:
|
||||||
|
inputs = inputs.to("xpu")
|
||||||
|
|
||||||
outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
|
outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
|
||||||
|
export SKIP_CONDA=1
|
||||||
|
endif
|
||||||
|
|
||||||
.PHONY: ttsvalle
|
.PHONY: ttsvalle
|
||||||
ttsvalle:
|
ttsvalle:
|
||||||
$(MAKE) -C ../common-env/transformers
|
$(MAKE) -C ../common-env/transformers
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -2,13 +2,16 @@
|
|||||||
|
|
||||||
##
|
##
|
||||||
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
## A bash script installs the required dependencies of VALL-E-X and prepares the environment
|
||||||
export PATH=$PATH:/opt/conda/bin
|
|
||||||
export SHA=3faaf8ccadb154d63b38070caf518ce9309ea0f4
|
export SHA=3faaf8ccadb154d63b38070caf518ce9309ea0f4
|
||||||
|
|
||||||
# Activate conda environment
|
SKIP_CONDA=${SKIP_CONDA:-0}
|
||||||
source activate transformers
|
|
||||||
|
|
||||||
echo $CONDA_PREFIX
|
if [ $SKIP_CONDA -ne 1 ]; then
|
||||||
|
source activate transformers
|
||||||
|
else
|
||||||
|
export PATH=$PATH:/opt/conda/bin
|
||||||
|
CONDA_PREFIX=$PWD
|
||||||
|
fi
|
||||||
|
|
||||||
git clone https://github.com/Plachtaa/VALL-E-X.git $CONDA_PREFIX/vall-e-x && pushd $CONDA_PREFIX/vall-e-x && git checkout -b build $SHA && popd
|
git clone https://github.com/Plachtaa/VALL-E-X.git $CONDA_PREFIX/vall-e-x && pushd $CONDA_PREFIX/vall-e-x && git checkout -b build $SHA && popd
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ dependencies:
|
|||||||
- pypinyin==0.49.0
|
- pypinyin==0.49.0
|
||||||
- python-multipart==0.0.6
|
- python-multipart==0.0.6
|
||||||
- regex==2023.10.3
|
- regex==2023.10.3
|
||||||
- safetensors==0.4.0
|
- safetensors>=0.4.0
|
||||||
- semantic-version==2.10.0
|
- semantic-version==2.10.0
|
||||||
- soundfile==0.12.1
|
- soundfile==0.12.1
|
||||||
- starlette==0.27.0
|
- starlette==0.27.0
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
import asyncio
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
import time
|
|
||||||
import argparse
|
import argparse
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
@@ -10,7 +10,10 @@ import backend_pb2
|
|||||||
import backend_pb2_grpc
|
import backend_pb2_grpc
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
from vllm import LLM, SamplingParams
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
@@ -79,16 +82,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
Returns:
|
Returns:
|
||||||
backend_pb2.Result: The load model result.
|
backend_pb2.Result: The load model result.
|
||||||
"""
|
"""
|
||||||
|
engine_args = AsyncEngineArgs(
|
||||||
|
model=request.Model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if request.Quantization != "":
|
||||||
|
engine_args.quantization = request.Quantization
|
||||||
|
if request.GPUMemoryUtilization != 0:
|
||||||
|
engine_args.gpu_memory_utilization = request.GPUMemoryUtilization
|
||||||
|
if request.TrustRemoteCode:
|
||||||
|
engine_args.trust_remote_code = request.TrustRemoteCode
|
||||||
|
if request.EnforceEager:
|
||||||
|
engine_args.enforce_eager = request.EnforceEager
|
||||||
|
if request.SwapSpace != 0:
|
||||||
|
engine_args.swap_space = request.SwapSpace
|
||||||
|
if request.MaxModelLen != 0:
|
||||||
|
engine_args.max_model_len = request.MaxModelLen
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if request.Quantization != "":
|
self.llm = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
self.llm = LLM(model=request.Model, quantization=request.Quantization)
|
|
||||||
else:
|
|
||||||
self.llm = LLM(model=request.Model)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||||
|
|
||||||
def Predict(self, request, context):
|
async def Predict(self, request, context):
|
||||||
"""
|
"""
|
||||||
Generates text based on the given prompt and sampling parameters.
|
Generates text based on the given prompt and sampling parameters.
|
||||||
|
|
||||||
@@ -99,24 +116,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
Returns:
|
Returns:
|
||||||
backend_pb2.Reply: The predict result.
|
backend_pb2.Reply: The predict result.
|
||||||
"""
|
"""
|
||||||
if request.TopP == 0:
|
gen = self._predict(request, context, streaming=False)
|
||||||
request.TopP = 0.9
|
res = await gen.__anext__()
|
||||||
|
return res
|
||||||
|
|
||||||
max_tokens = 200
|
async def PredictStream(self, request, context):
|
||||||
if request.Tokens > 0:
|
|
||||||
max_tokens = request.Tokens
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
|
|
||||||
outputs = self.llm.generate([request.Prompt], sampling_params)
|
|
||||||
|
|
||||||
generated_text = outputs[0].outputs[0].text
|
|
||||||
# Remove prompt from response if present
|
|
||||||
if request.Prompt in generated_text:
|
|
||||||
generated_text = generated_text.replace(request.Prompt, "")
|
|
||||||
|
|
||||||
return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
|
||||||
|
|
||||||
def PredictStream(self, request, context):
|
|
||||||
"""
|
"""
|
||||||
Generates text based on the given prompt and sampling parameters, and streams the results.
|
Generates text based on the given prompt and sampling parameters, and streams the results.
|
||||||
|
|
||||||
@@ -127,30 +131,84 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
Returns:
|
Returns:
|
||||||
backend_pb2.Result: The predict stream result.
|
backend_pb2.Result: The predict stream result.
|
||||||
"""
|
"""
|
||||||
yield self.Predict(request, context)
|
iterations = self._predict(request, context, streaming=True)
|
||||||
|
try:
|
||||||
|
async for iteration in iterations:
|
||||||
|
yield iteration
|
||||||
|
finally:
|
||||||
|
await iterations.aclose()
|
||||||
|
|
||||||
def serve(address):
|
async def _predict(self, request, context, streaming=False):
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
|
||||||
|
# Build sampling parameters
|
||||||
|
sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
|
||||||
|
if request.TopP != 0:
|
||||||
|
sampling_params.top_p = request.TopP
|
||||||
|
if request.Tokens > 0:
|
||||||
|
sampling_params.max_tokens = request.Tokens
|
||||||
|
if request.Temperature != 0:
|
||||||
|
sampling_params.temperature = request.Temperature
|
||||||
|
if request.TopK != 0:
|
||||||
|
sampling_params.top_k = request.TopK
|
||||||
|
if request.PresencePenalty != 0:
|
||||||
|
sampling_params.presence_penalty = request.PresencePenalty
|
||||||
|
if request.FrequencyPenalty != 0:
|
||||||
|
sampling_params.frequency_penalty = request.FrequencyPenalty
|
||||||
|
if request.StopPrompts:
|
||||||
|
sampling_params.stop = request.StopPrompts
|
||||||
|
if request.IgnoreEOS:
|
||||||
|
sampling_params.ignore_eos = request.IgnoreEOS
|
||||||
|
if request.Seed != 0:
|
||||||
|
sampling_params.seed = request.Seed
|
||||||
|
|
||||||
|
# Generate text
|
||||||
|
request_id = random_uuid()
|
||||||
|
outputs = self.llm.generate(request.Prompt, sampling_params, request_id)
|
||||||
|
|
||||||
|
# Stream the results
|
||||||
|
generated_text = ""
|
||||||
|
try:
|
||||||
|
async for request_output in outputs:
|
||||||
|
iteration_text = request_output.outputs[0].text
|
||||||
|
|
||||||
|
if streaming:
|
||||||
|
# Remove text already sent as vllm concatenates the text from previous yields
|
||||||
|
delta_iteration_text = iteration_text.removeprefix(generated_text)
|
||||||
|
# Send the partial result
|
||||||
|
yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))
|
||||||
|
|
||||||
|
# Keep track of text generated
|
||||||
|
generated_text = iteration_text
|
||||||
|
finally:
|
||||||
|
await outputs.aclose()
|
||||||
|
|
||||||
|
# If streaming, we already sent everything
|
||||||
|
if streaming:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Sending the final generated text
|
||||||
|
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
||||||
|
|
||||||
|
async def serve(address):
|
||||||
|
# Start asyncio gRPC server
|
||||||
|
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||||
|
# Add the servicer to the server
|
||||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
|
# Bind the server to the address
|
||||||
server.add_insecure_port(address)
|
server.add_insecure_port(address)
|
||||||
server.start()
|
|
||||||
|
# Gracefully shutdown the server on SIGTERM or SIGINT
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
|
loop.add_signal_handler(
|
||||||
|
sig, lambda: asyncio.ensure_future(server.stop(5))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start the server
|
||||||
|
await server.start()
|
||||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||||
|
# Wait for the server to be terminated
|
||||||
# Define the signal handler function
|
await server.wait_for_termination()
|
||||||
def signal_handler(sig, frame):
|
|
||||||
print("Received termination signal. Shutting down...")
|
|
||||||
server.stop(0)
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Set the signal handlers for SIGINT and SIGTERM
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
server.stop(0)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||||
@@ -159,4 +217,4 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
serve(args.addr)
|
asyncio.run(serve(args.addr))
|
||||||
0
configuration/.keep
Normal file
0
configuration/.keep
Normal file
@@ -3,36 +3,36 @@ package backend
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
|
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
|
||||||
if !c.Embeddings {
|
if !backendConfig.Embeddings {
|
||||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
|
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
|
||||||
}
|
}
|
||||||
|
|
||||||
modelFile := c.Model
|
modelFile := backendConfig.Model
|
||||||
|
|
||||||
grpcOpts := gRPCModelOpts(c)
|
grpcOpts := gRPCModelOpts(backendConfig)
|
||||||
|
|
||||||
var inferenceModel interface{}
|
var inferenceModel interface{}
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
opts := modelOpts(c, o, []model.Option{
|
opts := modelOpts(backendConfig, appConfig, []model.Option{
|
||||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||||
model.WithThreads(uint32(c.Threads)),
|
model.WithThreads(uint32(*backendConfig.Threads)),
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
model.WithModel(modelFile),
|
model.WithModel(modelFile),
|
||||||
model.WithContext(o.Context),
|
model.WithContext(appConfig.Context),
|
||||||
})
|
})
|
||||||
|
|
||||||
if c.Backend == "" {
|
if backendConfig.Backend == "" {
|
||||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||||
} else {
|
} else {
|
||||||
opts = append(opts, model.WithBackendString(c.Backend))
|
opts = append(opts, model.WithBackendString(backendConfig.Backend))
|
||||||
inferenceModel, err = loader.BackendLoader(opts...)
|
inferenceModel, err = loader.BackendLoader(opts...)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -43,7 +43,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
|||||||
switch model := inferenceModel.(type) {
|
switch model := inferenceModel.(type) {
|
||||||
case grpc.Backend:
|
case grpc.Backend:
|
||||||
fn = func() ([]float32, error) {
|
fn = func() ([]float32, error) {
|
||||||
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
|
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
|
||||||
if len(tokens) > 0 {
|
if len(tokens) > 0 {
|
||||||
embeds := []int32{}
|
embeds := []int32{}
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
|||||||
}
|
}
|
||||||
predictOptions.EmbeddingTokens = embeds
|
predictOptions.EmbeddingTokens = embeds
|
||||||
|
|
||||||
res, err := model.Embeddings(o.Context, predictOptions)
|
res, err := model.Embeddings(appConfig.Context, predictOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -61,7 +61,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
|||||||
}
|
}
|
||||||
predictOptions.Embeddings = s
|
predictOptions.Embeddings = s
|
||||||
|
|
||||||
res, err := model.Embeddings(o.Context, predictOptions)
|
res, err := model.Embeddings(appConfig.Context, predictOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,34 +1,25 @@
|
|||||||
package backend
|
package backend
|
||||||
|
|
||||||
import (
|
import (
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
|
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
|
||||||
|
threads := backendConfig.Threads
|
||||||
opts := modelOpts(c, o, []model.Option{
|
if *threads == 0 && appConfig.Threads != 0 {
|
||||||
model.WithBackendString(c.Backend),
|
threads = &appConfig.Threads
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
}
|
||||||
model.WithThreads(uint32(c.Threads)),
|
gRPCOpts := gRPCModelOpts(backendConfig)
|
||||||
model.WithContext(o.Context),
|
opts := modelOpts(backendConfig, appConfig, []model.Option{
|
||||||
model.WithModel(c.Model),
|
model.WithBackendString(backendConfig.Backend),
|
||||||
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
CUDA: c.CUDA || c.Diffusers.CUDA,
|
model.WithThreads(uint32(*threads)),
|
||||||
SchedulerType: c.Diffusers.SchedulerType,
|
model.WithContext(appConfig.Context),
|
||||||
PipelineType: c.Diffusers.PipelineType,
|
model.WithModel(backendConfig.Model),
|
||||||
CFGScale: c.Diffusers.CFGScale,
|
model.WithLoadGRPCLoadModelOpts(gRPCOpts),
|
||||||
LoraAdapter: c.LoraAdapter,
|
|
||||||
LoraScale: c.LoraScale,
|
|
||||||
LoraBase: c.LoraBase,
|
|
||||||
IMG2IMG: c.Diffusers.IMG2IMG,
|
|
||||||
CLIPModel: c.Diffusers.ClipModel,
|
|
||||||
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
|
||||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
|
||||||
ControlNet: c.Diffusers.ControlNet,
|
|
||||||
}),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
inferenceModel, err := loader.BackendLoader(
|
inferenceModel, err := loader.BackendLoader(
|
||||||
@@ -40,19 +31,19 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
|||||||
|
|
||||||
fn := func() error {
|
fn := func() error {
|
||||||
_, err := inferenceModel.GenerateImage(
|
_, err := inferenceModel.GenerateImage(
|
||||||
o.Context,
|
appConfig.Context,
|
||||||
&proto.GenerateImageRequest{
|
&proto.GenerateImageRequest{
|
||||||
Height: int32(height),
|
Height: int32(height),
|
||||||
Width: int32(width),
|
Width: int32(width),
|
||||||
Mode: int32(mode),
|
Mode: int32(mode),
|
||||||
Step: int32(step),
|
Step: int32(step),
|
||||||
Seed: int32(seed),
|
Seed: int32(seed),
|
||||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
|
||||||
PositivePrompt: positive_prompt,
|
PositivePrompt: positive_prompt,
|
||||||
NegativePrompt: negative_prompt,
|
NegativePrompt: negative_prompt,
|
||||||
Dst: dst,
|
Dst: dst,
|
||||||
Src: src,
|
Src: src,
|
||||||
EnableParameters: c.Diffusers.EnableParameters,
|
EnableParameters: backendConfig.Diffusers.EnableParameters,
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
@@ -26,9 +26,12 @@ type TokenUsage struct {
|
|||||||
Completion int
|
Completion int
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
|
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
|
||||||
modelFile := c.Model
|
modelFile := c.Model
|
||||||
|
threads := c.Threads
|
||||||
|
if *threads == 0 && o.Threads != 0 {
|
||||||
|
threads = &o.Threads
|
||||||
|
}
|
||||||
grpcOpts := gRPCModelOpts(c)
|
grpcOpts := gRPCModelOpts(c)
|
||||||
|
|
||||||
var inferenceModel grpc.Backend
|
var inferenceModel grpc.Backend
|
||||||
@@ -36,7 +39,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode
|
|||||||
|
|
||||||
opts := modelOpts(c, o, []model.Option{
|
opts := modelOpts(c, o, []model.Option{
|
||||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||||
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
|
model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
model.WithAssetDir(o.AssetsDestination),
|
||||||
model.WithModel(modelFile),
|
model.WithModel(modelFile),
|
||||||
model.WithContext(o.Context),
|
model.WithContext(o.Context),
|
||||||
@@ -140,7 +143,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode
|
|||||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
||||||
var mu sync.Mutex = sync.Mutex{}
|
var mu sync.Mutex = sync.Mutex{}
|
||||||
|
|
||||||
func Finetune(config config.Config, input, prediction string) string {
|
func Finetune(config config.BackendConfig, input, prediction string) string {
|
||||||
if config.Echo {
|
if config.Echo {
|
||||||
prediction = input + prediction
|
prediction = input + prediction
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,19 +4,17 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option {
|
func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
|
||||||
if o.SingleBackend {
|
if so.SingleBackend {
|
||||||
opts = append(opts, model.WithSingleActiveBackend())
|
opts = append(opts, model.WithSingleActiveBackend())
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.ParallelBackendRequests {
|
if so.ParallelBackendRequests {
|
||||||
opts = append(opts, model.EnableParallelRequests)
|
opts = append(opts, model.EnableParallelRequests)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -28,52 +26,65 @@ func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.
|
|||||||
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
|
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range o.ExternalGRPCBackends {
|
for k, v := range so.ExternalGRPCBackends {
|
||||||
opts = append(opts, model.WithExternalBackend(k, v))
|
opts = append(opts, model.WithExternalBackend(k, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
return opts
|
return opts
|
||||||
}
|
}
|
||||||
|
|
||||||
func gRPCModelOpts(c config.Config) *pb.ModelOptions {
|
func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
||||||
b := 512
|
b := 512
|
||||||
if c.Batch != 0 {
|
if c.Batch != 0 {
|
||||||
b = c.Batch
|
b = c.Batch
|
||||||
}
|
}
|
||||||
|
|
||||||
return &pb.ModelOptions{
|
return &pb.ModelOptions{
|
||||||
ContextSize: int32(c.ContextSize),
|
CUDA: c.CUDA || c.Diffusers.CUDA,
|
||||||
Seed: int32(c.Seed),
|
SchedulerType: c.Diffusers.SchedulerType,
|
||||||
NBatch: int32(b),
|
PipelineType: c.Diffusers.PipelineType,
|
||||||
NoMulMatQ: c.NoMulMatQ,
|
CFGScale: c.Diffusers.CFGScale,
|
||||||
CUDA: c.CUDA, // diffusers, transformers
|
LoraAdapter: c.LoraAdapter,
|
||||||
DraftModel: c.DraftModel,
|
LoraScale: c.LoraScale,
|
||||||
AudioPath: c.VallE.AudioPath,
|
F16Memory: *c.F16,
|
||||||
Quantization: c.Quantization,
|
LoraBase: c.LoraBase,
|
||||||
MMProj: c.MMProj,
|
IMG2IMG: c.Diffusers.IMG2IMG,
|
||||||
YarnExtFactor: c.YarnExtFactor,
|
CLIPModel: c.Diffusers.ClipModel,
|
||||||
YarnAttnFactor: c.YarnAttnFactor,
|
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
||||||
YarnBetaFast: c.YarnBetaFast,
|
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
||||||
YarnBetaSlow: c.YarnBetaSlow,
|
ControlNet: c.Diffusers.ControlNet,
|
||||||
LoraAdapter: c.LoraAdapter,
|
ContextSize: int32(*c.ContextSize),
|
||||||
LoraBase: c.LoraBase,
|
Seed: int32(*c.Seed),
|
||||||
LoraScale: c.LoraScale,
|
NBatch: int32(b),
|
||||||
NGQA: c.NGQA,
|
NoMulMatQ: c.NoMulMatQ,
|
||||||
RMSNormEps: c.RMSNormEps,
|
DraftModel: c.DraftModel,
|
||||||
F16Memory: c.F16,
|
AudioPath: c.VallE.AudioPath,
|
||||||
MLock: c.MMlock,
|
Quantization: c.Quantization,
|
||||||
RopeFreqBase: c.RopeFreqBase,
|
GPUMemoryUtilization: c.GPUMemoryUtilization,
|
||||||
RopeScaling: c.RopeScaling,
|
TrustRemoteCode: c.TrustRemoteCode,
|
||||||
Type: c.ModelType,
|
EnforceEager: c.EnforceEager,
|
||||||
RopeFreqScale: c.RopeFreqScale,
|
SwapSpace: int32(c.SwapSpace),
|
||||||
NUMA: c.NUMA,
|
MaxModelLen: int32(c.MaxModelLen),
|
||||||
Embeddings: c.Embeddings,
|
MMProj: c.MMProj,
|
||||||
LowVRAM: c.LowVRAM,
|
YarnExtFactor: c.YarnExtFactor,
|
||||||
NGPULayers: int32(c.NGPULayers),
|
YarnAttnFactor: c.YarnAttnFactor,
|
||||||
MMap: c.MMap,
|
YarnBetaFast: c.YarnBetaFast,
|
||||||
MainGPU: c.MainGPU,
|
YarnBetaSlow: c.YarnBetaSlow,
|
||||||
Threads: int32(c.Threads),
|
NGQA: c.NGQA,
|
||||||
TensorSplit: c.TensorSplit,
|
RMSNormEps: c.RMSNormEps,
|
||||||
|
MLock: *c.MMlock,
|
||||||
|
RopeFreqBase: c.RopeFreqBase,
|
||||||
|
RopeScaling: c.RopeScaling,
|
||||||
|
Type: c.ModelType,
|
||||||
|
RopeFreqScale: c.RopeFreqScale,
|
||||||
|
NUMA: c.NUMA,
|
||||||
|
Embeddings: c.Embeddings,
|
||||||
|
LowVRAM: *c.LowVRAM,
|
||||||
|
NGPULayers: int32(*c.NGPULayers),
|
||||||
|
MMap: *c.MMap,
|
||||||
|
MainGPU: c.MainGPU,
|
||||||
|
Threads: int32(*c.Threads),
|
||||||
|
TensorSplit: c.TensorSplit,
|
||||||
// AutoGPTQ
|
// AutoGPTQ
|
||||||
ModelBaseName: c.AutoGPTQ.ModelBaseName,
|
ModelBaseName: c.AutoGPTQ.ModelBaseName,
|
||||||
Device: c.AutoGPTQ.Device,
|
Device: c.AutoGPTQ.Device,
|
||||||
@@ -84,43 +95,44 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
|
func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions {
|
||||||
promptCachePath := ""
|
promptCachePath := ""
|
||||||
if c.PromptCachePath != "" {
|
if c.PromptCachePath != "" {
|
||||||
p := filepath.Join(modelPath, c.PromptCachePath)
|
p := filepath.Join(modelPath, c.PromptCachePath)
|
||||||
os.MkdirAll(filepath.Dir(p), 0755)
|
os.MkdirAll(filepath.Dir(p), 0755)
|
||||||
promptCachePath = p
|
promptCachePath = p
|
||||||
}
|
}
|
||||||
|
|
||||||
return &pb.PredictOptions{
|
return &pb.PredictOptions{
|
||||||
Temperature: float32(c.Temperature),
|
Temperature: float32(*c.Temperature),
|
||||||
TopP: float32(c.TopP),
|
TopP: float32(*c.TopP),
|
||||||
NDraft: c.NDraft,
|
NDraft: c.NDraft,
|
||||||
TopK: int32(c.TopK),
|
TopK: int32(*c.TopK),
|
||||||
Tokens: int32(c.Maxtokens),
|
Tokens: int32(*c.Maxtokens),
|
||||||
Threads: int32(c.Threads),
|
Threads: int32(*c.Threads),
|
||||||
PromptCacheAll: c.PromptCacheAll,
|
PromptCacheAll: c.PromptCacheAll,
|
||||||
PromptCacheRO: c.PromptCacheRO,
|
PromptCacheRO: c.PromptCacheRO,
|
||||||
PromptCachePath: promptCachePath,
|
PromptCachePath: promptCachePath,
|
||||||
F16KV: c.F16,
|
F16KV: *c.F16,
|
||||||
DebugMode: c.Debug,
|
DebugMode: *c.Debug,
|
||||||
Grammar: c.Grammar,
|
Grammar: c.Grammar,
|
||||||
NegativePromptScale: c.NegativePromptScale,
|
NegativePromptScale: c.NegativePromptScale,
|
||||||
RopeFreqBase: c.RopeFreqBase,
|
RopeFreqBase: c.RopeFreqBase,
|
||||||
RopeFreqScale: c.RopeFreqScale,
|
RopeFreqScale: c.RopeFreqScale,
|
||||||
NegativePrompt: c.NegativePrompt,
|
NegativePrompt: c.NegativePrompt,
|
||||||
Mirostat: int32(c.LLMConfig.Mirostat),
|
Mirostat: int32(*c.LLMConfig.Mirostat),
|
||||||
MirostatETA: float32(c.LLMConfig.MirostatETA),
|
MirostatETA: float32(*c.LLMConfig.MirostatETA),
|
||||||
MirostatTAU: float32(c.LLMConfig.MirostatTAU),
|
MirostatTAU: float32(*c.LLMConfig.MirostatTAU),
|
||||||
Debug: c.Debug,
|
Debug: *c.Debug,
|
||||||
StopPrompts: c.StopWords,
|
StopPrompts: c.StopWords,
|
||||||
Repeat: int32(c.RepeatPenalty),
|
Repeat: int32(c.RepeatPenalty),
|
||||||
NKeep: int32(c.Keep),
|
NKeep: int32(c.Keep),
|
||||||
Batch: int32(c.Batch),
|
Batch: int32(c.Batch),
|
||||||
IgnoreEOS: c.IgnoreEOS,
|
IgnoreEOS: c.IgnoreEOS,
|
||||||
Seed: int32(c.Seed),
|
Seed: int32(*c.Seed),
|
||||||
FrequencyPenalty: float32(c.FrequencyPenalty),
|
FrequencyPenalty: float32(c.FrequencyPenalty),
|
||||||
MLock: c.MMlock,
|
MLock: *c.MMlock,
|
||||||
MMap: c.MMap,
|
MMap: *c.MMap,
|
||||||
MainGPU: c.MainGPU,
|
MainGPU: c.MainGPU,
|
||||||
TensorSplit: c.TensorSplit,
|
TensorSplit: c.TensorSplit,
|
||||||
TailFreeSamplingZ: float32(c.TFZ),
|
TailFreeSamplingZ: float32(c.TFZ),
|
||||||
|
|||||||
@@ -4,25 +4,24 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) {
|
func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) {
|
||||||
|
|
||||||
opts := modelOpts(c, o, []model.Option{
|
opts := modelOpts(backendConfig, appConfig, []model.Option{
|
||||||
model.WithBackendString(model.WhisperBackend),
|
model.WithBackendString(model.WhisperBackend),
|
||||||
model.WithModel(c.Model),
|
model.WithModel(backendConfig.Model),
|
||||||
model.WithContext(o.Context),
|
model.WithContext(appConfig.Context),
|
||||||
model.WithThreads(uint32(c.Threads)),
|
model.WithThreads(uint32(*backendConfig.Threads)),
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
})
|
})
|
||||||
|
|
||||||
whisperModel, err := o.Loader.BackendLoader(opts...)
|
whisperModel, err := ml.BackendLoader(opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -34,6 +33,6 @@ func ModelTranscription(audio, language string, loader *model.ModelLoader, c con
|
|||||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||||
Dst: audio,
|
Dst: audio,
|
||||||
Language: language,
|
Language: language,
|
||||||
Threads: uint32(c.Threads),
|
Threads: uint32(*backendConfig.Threads),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
@@ -29,53 +29,59 @@ func generateUniqueFileName(dir, baseName, ext string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option, c config.Config) (string, *proto.Result, error) {
|
func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
|
||||||
bb := backend
|
bb := backend
|
||||||
if bb == "" {
|
if bb == "" {
|
||||||
bb = model.PiperBackend
|
bb = model.PiperBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
grpcOpts := gRPCModelOpts(c)
|
grpcOpts := gRPCModelOpts(backendConfig)
|
||||||
|
|
||||||
opts := modelOpts(config.Config{}, o, []model.Option{
|
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
|
||||||
model.WithBackendString(bb),
|
model.WithBackendString(bb),
|
||||||
model.WithModel(modelFile),
|
model.WithModel(modelFile),
|
||||||
model.WithContext(o.Context),
|
model.WithContext(appConfig.Context),
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
model.WithAssetDir(appConfig.AssetsDestination),
|
||||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||||
})
|
})
|
||||||
piperModel, err := o.Loader.BackendLoader(opts...)
|
ttsModel, err := loader.BackendLoader(opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if piperModel == nil {
|
if ttsModel == nil {
|
||||||
return "", nil, fmt.Errorf("could not load piper model")
|
return "", nil, fmt.Errorf("could not load piper model")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
|
if err := os.MkdirAll(appConfig.AudioDir, 0755); err != nil {
|
||||||
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
|
fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
|
||||||
filePath := filepath.Join(o.AudioDir, fileName)
|
filePath := filepath.Join(appConfig.AudioDir, fileName)
|
||||||
|
|
||||||
// If the model file is not empty, we pass it joined with the model path
|
// If the model file is not empty, we pass it joined with the model path
|
||||||
modelPath := ""
|
modelPath := ""
|
||||||
if modelFile != "" {
|
if modelFile != "" {
|
||||||
if bb != model.TransformersMusicGen {
|
// If the model file is not empty, we pass it joined with the model path
|
||||||
modelPath = filepath.Join(o.Loader.ModelPath, modelFile)
|
// Checking first that it exists and is not outside ModelPath
|
||||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
// TODO: we should actually first check if the modelFile is looking like
|
||||||
|
// a FS path
|
||||||
|
mp := filepath.Join(loader.ModelPath, modelFile)
|
||||||
|
if _, err := os.Stat(mp); err == nil {
|
||||||
|
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
modelPath = mp
|
||||||
} else {
|
} else {
|
||||||
modelPath = modelFile
|
modelPath = modelFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
|
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
|
||||||
Text: text,
|
Text: text,
|
||||||
Model: modelPath,
|
Model: modelPath,
|
||||||
|
Voice: voice,
|
||||||
Dst: filePath,
|
Dst: filePath,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package options
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -6,16 +6,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/metrics"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Option struct {
|
type ApplicationConfig struct {
|
||||||
Context context.Context
|
Context context.Context
|
||||||
ConfigFile string
|
ConfigFile string
|
||||||
Loader *model.ModelLoader
|
ModelPath string
|
||||||
UploadLimitMB, Threads, ContextSize int
|
UploadLimitMB, Threads, ContextSize int
|
||||||
F16 bool
|
F16 bool
|
||||||
Debug, DisableMessage bool
|
Debug, DisableMessage bool
|
||||||
@@ -27,7 +25,6 @@ type Option struct {
|
|||||||
PreloadModelsFromPath string
|
PreloadModelsFromPath string
|
||||||
CORSAllowOrigins string
|
CORSAllowOrigins string
|
||||||
ApiKeys []string
|
ApiKeys []string
|
||||||
Metrics *metrics.Metrics
|
|
||||||
|
|
||||||
ModelLibraryURL string
|
ModelLibraryURL string
|
||||||
|
|
||||||
@@ -52,10 +49,10 @@ type Option struct {
|
|||||||
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
|
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppOption func(*Option)
|
type AppOption func(*ApplicationConfig)
|
||||||
|
|
||||||
func NewOptions(o ...AppOption) *Option {
|
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
|
||||||
opt := &Option{
|
opt := &ApplicationConfig{
|
||||||
Context: context.Background(),
|
Context: context.Background(),
|
||||||
UploadLimitMB: 15,
|
UploadLimitMB: 15,
|
||||||
Threads: 1,
|
Threads: 1,
|
||||||
@@ -70,63 +67,69 @@ func NewOptions(o ...AppOption) *Option {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func WithModelsURL(urls ...string) AppOption {
|
func WithModelsURL(urls ...string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ModelsURL = urls
|
o.ModelsURL = urls
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithModelPath(path string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.ModelPath = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithCors(b bool) AppOption {
|
func WithCors(b bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.CORS = b
|
o.CORS = b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithModelLibraryURL(url string) AppOption {
|
func WithModelLibraryURL(url string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ModelLibraryURL = url
|
o.ModelLibraryURL = url
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableWatchDog = func(o *Option) {
|
var EnableWatchDog = func(o *ApplicationConfig) {
|
||||||
o.WatchDog = true
|
o.WatchDog = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableWatchDogIdleCheck = func(o *Option) {
|
var EnableWatchDogIdleCheck = func(o *ApplicationConfig) {
|
||||||
o.WatchDog = true
|
o.WatchDog = true
|
||||||
o.WatchDogIdle = true
|
o.WatchDogIdle = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableWatchDogBusyCheck = func(o *Option) {
|
var EnableWatchDogBusyCheck = func(o *ApplicationConfig) {
|
||||||
o.WatchDog = true
|
o.WatchDog = true
|
||||||
o.WatchDogBusy = true
|
o.WatchDogBusy = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
|
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.WatchDogBusyTimeout = t
|
o.WatchDogBusyTimeout = t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
|
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.WatchDogIdleTimeout = t
|
o.WatchDogIdleTimeout = t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableSingleBackend = func(o *Option) {
|
var EnableSingleBackend = func(o *ApplicationConfig) {
|
||||||
o.SingleBackend = true
|
o.SingleBackend = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableParallelBackendRequests = func(o *Option) {
|
var EnableParallelBackendRequests = func(o *ApplicationConfig) {
|
||||||
o.ParallelBackendRequests = true
|
o.ParallelBackendRequests = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnableGalleriesAutoload = func(o *Option) {
|
var EnableGalleriesAutoload = func(o *ApplicationConfig) {
|
||||||
o.AutoloadGalleries = true
|
o.AutoloadGalleries = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithExternalBackend(name string, uri string) AppOption {
|
func WithExternalBackend(name string, uri string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
if o.ExternalGRPCBackends == nil {
|
if o.ExternalGRPCBackends == nil {
|
||||||
o.ExternalGRPCBackends = make(map[string]string)
|
o.ExternalGRPCBackends = make(map[string]string)
|
||||||
}
|
}
|
||||||
@@ -135,27 +138,26 @@ func WithExternalBackend(name string, uri string) AppOption {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func WithCorsAllowOrigins(b string) AppOption {
|
func WithCorsAllowOrigins(b string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.CORSAllowOrigins = b
|
o.CORSAllowOrigins = b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithBackendAssetsOutput(out string) AppOption {
|
func WithBackendAssetsOutput(out string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.AssetsDestination = out
|
o.AssetsDestination = out
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithBackendAssets(f embed.FS) AppOption {
|
func WithBackendAssets(f embed.FS) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.BackendAssets = f
|
o.BackendAssets = f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithStringGalleries(galls string) AppOption {
|
func WithStringGalleries(galls string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
if galls == "" {
|
if galls == "" {
|
||||||
log.Debug().Msgf("no galleries to load")
|
|
||||||
o.Galleries = []gallery.Gallery{}
|
o.Galleries = []gallery.Gallery{}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -168,102 +170,96 @@ func WithStringGalleries(galls string) AppOption {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.Galleries = append(o.Galleries, galleries...)
|
o.Galleries = append(o.Galleries, galleries...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithContext(ctx context.Context) AppOption {
|
func WithContext(ctx context.Context) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.Context = ctx
|
o.Context = ctx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithYAMLConfigPreload(configFile string) AppOption {
|
func WithYAMLConfigPreload(configFile string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.PreloadModelsFromPath = configFile
|
o.PreloadModelsFromPath = configFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithJSONStringPreload(configFile string) AppOption {
|
func WithJSONStringPreload(configFile string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.PreloadJSONModels = configFile
|
o.PreloadJSONModels = configFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func WithConfigFile(configFile string) AppOption {
|
func WithConfigFile(configFile string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ConfigFile = configFile
|
o.ConfigFile = configFile
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithModelLoader(loader *model.ModelLoader) AppOption {
|
|
||||||
return func(o *Option) {
|
|
||||||
o.Loader = loader
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithUploadLimitMB(limit int) AppOption {
|
func WithUploadLimitMB(limit int) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.UploadLimitMB = limit
|
o.UploadLimitMB = limit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithThreads(threads int) AppOption {
|
func WithThreads(threads int) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.Threads = threads
|
o.Threads = threads
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithContextSize(ctxSize int) AppOption {
|
func WithContextSize(ctxSize int) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ContextSize = ctxSize
|
o.ContextSize = ctxSize
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithF16(f16 bool) AppOption {
|
func WithF16(f16 bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.F16 = f16
|
o.F16 = f16
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithDebug(debug bool) AppOption {
|
func WithDebug(debug bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.Debug = debug
|
o.Debug = debug
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithDisableMessage(disableMessage bool) AppOption {
|
func WithDisableMessage(disableMessage bool) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.DisableMessage = disableMessage
|
o.DisableMessage = disableMessage
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithAudioDir(audioDir string) AppOption {
|
func WithAudioDir(audioDir string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.AudioDir = audioDir
|
o.AudioDir = audioDir
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithImageDir(imageDir string) AppOption {
|
func WithImageDir(imageDir string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ImageDir = imageDir
|
o.ImageDir = imageDir
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithUploadDir(uploadDir string) AppOption {
|
func WithUploadDir(uploadDir string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.UploadDir = uploadDir
|
o.UploadDir = uploadDir
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithApiKeys(apiKeys []string) AppOption {
|
func WithApiKeys(apiKeys []string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ApiKeys = apiKeys
|
o.ApiKeys = apiKeys
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithMetrics(meter *metrics.Metrics) AppOption {
|
// func WithMetrics(meter *metrics.Metrics) AppOption {
|
||||||
return func(o *Option) {
|
// return func(o *StartupOptions) {
|
||||||
o.Metrics = meter
|
// o.Metrics = meter
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
@@ -4,24 +4,28 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/glamour"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type BackendConfig struct {
|
||||||
PredictionOptions `yaml:"parameters"`
|
schema.PredictionOptions `yaml:"parameters"`
|
||||||
Name string `yaml:"name"`
|
Name string `yaml:"name"`
|
||||||
|
|
||||||
F16 bool `yaml:"f16"`
|
F16 *bool `yaml:"f16"`
|
||||||
Threads int `yaml:"threads"`
|
Threads *int `yaml:"threads"`
|
||||||
Debug bool `yaml:"debug"`
|
Debug *bool `yaml:"debug"`
|
||||||
Roles map[string]string `yaml:"roles"`
|
Roles map[string]string `yaml:"roles"`
|
||||||
Embeddings bool `yaml:"embeddings"`
|
Embeddings bool `yaml:"embeddings"`
|
||||||
Backend string `yaml:"backend"`
|
Backend string `yaml:"backend"`
|
||||||
@@ -104,29 +108,34 @@ type LLMConfig struct {
|
|||||||
PromptCachePath string `yaml:"prompt_cache_path"`
|
PromptCachePath string `yaml:"prompt_cache_path"`
|
||||||
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
||||||
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
||||||
MirostatETA float64 `yaml:"mirostat_eta"`
|
MirostatETA *float64 `yaml:"mirostat_eta"`
|
||||||
MirostatTAU float64 `yaml:"mirostat_tau"`
|
MirostatTAU *float64 `yaml:"mirostat_tau"`
|
||||||
Mirostat int `yaml:"mirostat"`
|
Mirostat *int `yaml:"mirostat"`
|
||||||
NGPULayers int `yaml:"gpu_layers"`
|
NGPULayers *int `yaml:"gpu_layers"`
|
||||||
MMap bool `yaml:"mmap"`
|
MMap *bool `yaml:"mmap"`
|
||||||
MMlock bool `yaml:"mmlock"`
|
MMlock *bool `yaml:"mmlock"`
|
||||||
LowVRAM bool `yaml:"low_vram"`
|
LowVRAM *bool `yaml:"low_vram"`
|
||||||
Grammar string `yaml:"grammar"`
|
Grammar string `yaml:"grammar"`
|
||||||
StopWords []string `yaml:"stopwords"`
|
StopWords []string `yaml:"stopwords"`
|
||||||
Cutstrings []string `yaml:"cutstrings"`
|
Cutstrings []string `yaml:"cutstrings"`
|
||||||
TrimSpace []string `yaml:"trimspace"`
|
TrimSpace []string `yaml:"trimspace"`
|
||||||
TrimSuffix []string `yaml:"trimsuffix"`
|
TrimSuffix []string `yaml:"trimsuffix"`
|
||||||
|
|
||||||
ContextSize int `yaml:"context_size"`
|
ContextSize *int `yaml:"context_size"`
|
||||||
NUMA bool `yaml:"numa"`
|
NUMA bool `yaml:"numa"`
|
||||||
LoraAdapter string `yaml:"lora_adapter"`
|
LoraAdapter string `yaml:"lora_adapter"`
|
||||||
LoraBase string `yaml:"lora_base"`
|
LoraBase string `yaml:"lora_base"`
|
||||||
LoraScale float32 `yaml:"lora_scale"`
|
LoraScale float32 `yaml:"lora_scale"`
|
||||||
NoMulMatQ bool `yaml:"no_mulmatq"`
|
NoMulMatQ bool `yaml:"no_mulmatq"`
|
||||||
DraftModel string `yaml:"draft_model"`
|
DraftModel string `yaml:"draft_model"`
|
||||||
NDraft int32 `yaml:"n_draft"`
|
NDraft int32 `yaml:"n_draft"`
|
||||||
Quantization string `yaml:"quantization"`
|
Quantization string `yaml:"quantization"`
|
||||||
MMProj string `yaml:"mmproj"`
|
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
|
||||||
|
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
|
||||||
|
EnforceEager bool `yaml:"enforce_eager"` // vLLM
|
||||||
|
SwapSpace int `yaml:"swap_space"` // vLLM
|
||||||
|
MaxModelLen int `yaml:"max_model_len"` // vLLM
|
||||||
|
MMProj string `yaml:"mmproj"`
|
||||||
|
|
||||||
RopeScaling string `yaml:"rope_scaling"`
|
RopeScaling string `yaml:"rope_scaling"`
|
||||||
ModelType string `yaml:"type"`
|
ModelType string `yaml:"type"`
|
||||||
@@ -159,108 +168,209 @@ type TemplateConfig struct {
|
|||||||
Functions string `yaml:"function"`
|
Functions string `yaml:"function"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConfigLoader struct {
|
func (c *BackendConfig) SetFunctionCallString(s string) {
|
||||||
configs map[string]Config
|
|
||||||
sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) SetFunctionCallString(s string) {
|
|
||||||
c.functionCallString = s
|
c.functionCallString = s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) SetFunctionCallNameString(s string) {
|
func (c *BackendConfig) SetFunctionCallNameString(s string) {
|
||||||
c.functionCallNameString = s
|
c.functionCallNameString = s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) ShouldUseFunctions() bool {
|
func (c *BackendConfig) ShouldUseFunctions() bool {
|
||||||
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
|
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) ShouldCallSpecificFunction() bool {
|
func (c *BackendConfig) ShouldCallSpecificFunction() bool {
|
||||||
return len(c.functionCallNameString) > 0
|
return len(c.functionCallNameString) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) FunctionToCall() string {
|
func (c *BackendConfig) FunctionToCall() string {
|
||||||
return c.functionCallNameString
|
return c.functionCallNameString
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load a config file for a model
|
func (cfg *BackendConfig) SetDefaults(debug bool, threads, ctx int, f16 bool) {
|
||||||
func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ctx int, f16 bool) (*Config, error) {
|
defaultTopP := 0.7
|
||||||
// Load a config file if present after the model name
|
defaultTopK := 80
|
||||||
modelConfig := filepath.Join(modelPath, modelName+".yaml")
|
defaultTemp := 0.9
|
||||||
|
defaultMaxTokens := 2048
|
||||||
|
defaultMirostat := 2
|
||||||
|
defaultMirostatTAU := 5.0
|
||||||
|
defaultMirostatETA := 0.1
|
||||||
|
|
||||||
var cfg *Config
|
// Try to offload all GPU layers (if GPU is found)
|
||||||
|
defaultNGPULayers := 99999999
|
||||||
|
|
||||||
defaults := func() {
|
trueV := true
|
||||||
cfg = DefaultConfig(modelName)
|
falseV := false
|
||||||
cfg.ContextSize = ctx
|
|
||||||
cfg.Threads = threads
|
if cfg.Seed == nil {
|
||||||
cfg.F16 = f16
|
// random number generator seed
|
||||||
cfg.Debug = debug
|
defaultSeed := int(rand.Int31())
|
||||||
|
cfg.Seed = &defaultSeed
|
||||||
}
|
}
|
||||||
|
|
||||||
cfgExisting, exists := cm.GetConfig(modelName)
|
if cfg.TopK == nil {
|
||||||
if !exists {
|
cfg.TopK = &defaultTopK
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.MMap == nil {
|
||||||
|
// MMap is enabled by default
|
||||||
|
cfg.MMap = &trueV
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.MMlock == nil {
|
||||||
|
// MMlock is disabled by default
|
||||||
|
cfg.MMlock = &falseV
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.TopP == nil {
|
||||||
|
cfg.TopP = &defaultTopP
|
||||||
|
}
|
||||||
|
if cfg.Temperature == nil {
|
||||||
|
cfg.Temperature = &defaultTemp
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Maxtokens == nil {
|
||||||
|
cfg.Maxtokens = &defaultMaxTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Mirostat == nil {
|
||||||
|
cfg.Mirostat = &defaultMirostat
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.MirostatETA == nil {
|
||||||
|
cfg.MirostatETA = &defaultMirostatETA
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.MirostatTAU == nil {
|
||||||
|
cfg.MirostatTAU = &defaultMirostatTAU
|
||||||
|
}
|
||||||
|
if cfg.NGPULayers == nil {
|
||||||
|
cfg.NGPULayers = &defaultNGPULayers
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.LowVRAM == nil {
|
||||||
|
cfg.LowVRAM = &falseV
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value passed by the top level are treated as default (no implicit defaults)
|
||||||
|
// defaults are set by the user
|
||||||
|
if ctx == 0 {
|
||||||
|
ctx = 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.ContextSize == nil {
|
||||||
|
cfg.ContextSize = &ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
if threads == 0 {
|
||||||
|
// Threads can't be 0
|
||||||
|
threads = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Threads == nil {
|
||||||
|
cfg.Threads = &threads
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.F16 == nil {
|
||||||
|
cfg.F16 = &f16
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
cfg.Debug = &debug
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////// Config Loader ////////
|
||||||
|
|
||||||
|
type BackendConfigLoader struct {
|
||||||
|
configs map[string]BackendConfig
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoadOptions struct {
|
||||||
|
debug bool
|
||||||
|
threads, ctxSize int
|
||||||
|
f16 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOptionDebug(debug bool) ConfigLoaderOption {
|
||||||
|
return func(o *LoadOptions) {
|
||||||
|
o.debug = debug
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOptionThreads(threads int) ConfigLoaderOption {
|
||||||
|
return func(o *LoadOptions) {
|
||||||
|
o.threads = threads
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOptionContextSize(ctxSize int) ConfigLoaderOption {
|
||||||
|
return func(o *LoadOptions) {
|
||||||
|
o.ctxSize = ctxSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOptionF16(f16 bool) ConfigLoaderOption {
|
||||||
|
return func(o *LoadOptions) {
|
||||||
|
o.f16 = f16
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConfigLoaderOption func(*LoadOptions)
|
||||||
|
|
||||||
|
func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
|
||||||
|
for _, l := range options {
|
||||||
|
l(lo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load a config file for a model
|
||||||
|
func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
|
||||||
|
|
||||||
|
lo := &LoadOptions{}
|
||||||
|
lo.Apply(opts...)
|
||||||
|
|
||||||
|
// Load a config file if present after the model name
|
||||||
|
cfg := &BackendConfig{
|
||||||
|
PredictionOptions: schema.PredictionOptions{
|
||||||
|
Model: modelName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfgExisting, exists := cl.GetBackendConfig(modelName)
|
||||||
|
if exists {
|
||||||
|
cfg = &cfgExisting
|
||||||
|
} else {
|
||||||
|
// Try loading a model config file
|
||||||
|
modelConfig := filepath.Join(modelPath, modelName+".yaml")
|
||||||
if _, err := os.Stat(modelConfig); err == nil {
|
if _, err := os.Stat(modelConfig); err == nil {
|
||||||
if err := cm.LoadConfig(modelConfig); err != nil {
|
if err := cl.LoadBackendConfig(modelConfig); err != nil {
|
||||||
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||||
}
|
}
|
||||||
cfgExisting, exists = cm.GetConfig(modelName)
|
cfgExisting, exists = cl.GetBackendConfig(modelName)
|
||||||
if exists {
|
if exists {
|
||||||
cfg = &cfgExisting
|
cfg = &cfgExisting
|
||||||
} else {
|
|
||||||
defaults()
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
defaults()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cfg = &cfgExisting
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the parameters for the language model prediction
|
|
||||||
//updateConfig(cfg, input)
|
|
||||||
|
|
||||||
// Don't allow 0 as setting
|
|
||||||
if cfg.Threads == 0 {
|
|
||||||
if threads != 0 {
|
|
||||||
cfg.Threads = threads
|
|
||||||
} else {
|
|
||||||
cfg.Threads = 4
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce debug flag if passed from CLI
|
cfg.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16)
|
||||||
if debug {
|
|
||||||
cfg.Debug = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultPredictOptions(modelFile string) PredictionOptions {
|
func NewBackendConfigLoader() *BackendConfigLoader {
|
||||||
return PredictionOptions{
|
return &BackendConfigLoader{
|
||||||
TopP: 0.7,
|
configs: make(map[string]BackendConfig),
|
||||||
TopK: 80,
|
|
||||||
Maxtokens: 512,
|
|
||||||
Temperature: 0.9,
|
|
||||||
Model: modelFile,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
|
||||||
|
lo := &LoadOptions{}
|
||||||
|
lo.Apply(opts...)
|
||||||
|
|
||||||
func DefaultConfig(modelFile string) *Config {
|
c := &[]*BackendConfig{}
|
||||||
return &Config{
|
|
||||||
PredictionOptions: defaultPredictOptions(modelFile),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConfigLoader() *ConfigLoader {
|
|
||||||
return &ConfigLoader{
|
|
||||||
configs: make(map[string]Config),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func ReadConfigFile(file string) ([]*Config, error) {
|
|
||||||
c := &[]*Config{}
|
|
||||||
f, err := os.ReadFile(file)
|
f, err := os.ReadFile(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||||
@@ -269,11 +379,18 @@ func ReadConfigFile(file string) ([]*Config, error) {
|
|||||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, cc := range *c {
|
||||||
|
cc.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16)
|
||||||
|
}
|
||||||
|
|
||||||
return *c, nil
|
return *c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadConfig(file string) (*Config, error) {
|
func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
|
||||||
c := &Config{}
|
lo := &LoadOptions{}
|
||||||
|
lo.Apply(opts...)
|
||||||
|
|
||||||
|
c := &BackendConfig{}
|
||||||
f, err := os.ReadFile(file)
|
f, err := os.ReadFile(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||||
@@ -282,13 +399,14 @@ func ReadConfig(file string) (*Config, error) {
|
|||||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16)
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) LoadConfigFile(file string) error {
|
func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
|
||||||
cm.Lock()
|
cm.Lock()
|
||||||
defer cm.Unlock()
|
defer cm.Unlock()
|
||||||
c, err := ReadConfigFile(file)
|
c, err := ReadBackendConfigFile(file, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot load config file: %w", err)
|
return fmt.Errorf("cannot load config file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -299,49 +417,49 @@ func (cm *ConfigLoader) LoadConfigFile(file string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) LoadConfig(file string) error {
|
func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
|
||||||
cm.Lock()
|
cl.Lock()
|
||||||
defer cm.Unlock()
|
defer cl.Unlock()
|
||||||
c, err := ReadConfig(file)
|
c, err := ReadBackendConfig(file, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot read config file: %w", err)
|
return fmt.Errorf("cannot read config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.configs[c.Name] = *c
|
cl.configs[c.Name] = *c
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
|
func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
|
||||||
cm.Lock()
|
cl.Lock()
|
||||||
defer cm.Unlock()
|
defer cl.Unlock()
|
||||||
v, exists := cm.configs[m]
|
v, exists := cl.configs[m]
|
||||||
return v, exists
|
return v, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) GetAllConfigs() []Config {
|
func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
|
||||||
cm.Lock()
|
cl.Lock()
|
||||||
defer cm.Unlock()
|
defer cl.Unlock()
|
||||||
var res []Config
|
var res []BackendConfig
|
||||||
for _, v := range cm.configs {
|
for _, v := range cl.configs {
|
||||||
res = append(res, v)
|
res = append(res, v)
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) ListConfigs() []string {
|
func (cl *BackendConfigLoader) ListBackendConfigs() []string {
|
||||||
cm.Lock()
|
cl.Lock()
|
||||||
defer cm.Unlock()
|
defer cl.Unlock()
|
||||||
var res []string
|
var res []string
|
||||||
for k := range cm.configs {
|
for k := range cl.configs {
|
||||||
res = append(res, k)
|
res = append(res, k)
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
// Preload prepare models if they are not local but url or huggingface repositories
|
// Preload prepare models if they are not local but url or huggingface repositories
|
||||||
func (cm *ConfigLoader) Preload(modelPath string) error {
|
func (cl *BackendConfigLoader) Preload(modelPath string) error {
|
||||||
cm.Lock()
|
cl.Lock()
|
||||||
defer cm.Unlock()
|
defer cl.Unlock()
|
||||||
|
|
||||||
status := func(fileName, current, total string, percent float64) {
|
status := func(fileName, current, total string, percent float64) {
|
||||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||||
@@ -349,7 +467,21 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
|
|||||||
|
|
||||||
log.Info().Msgf("Preloading models from %s", modelPath)
|
log.Info().Msgf("Preloading models from %s", modelPath)
|
||||||
|
|
||||||
for i, config := range cm.configs {
|
renderMode := "dark"
|
||||||
|
if os.Getenv("COLOR") != "" {
|
||||||
|
renderMode = os.Getenv("COLOR")
|
||||||
|
}
|
||||||
|
|
||||||
|
glamText := func(t string) {
|
||||||
|
out, err := glamour.Render(t, renderMode)
|
||||||
|
if err == nil && os.Getenv("NO_COLOR") == "" {
|
||||||
|
fmt.Println(out)
|
||||||
|
} else {
|
||||||
|
fmt.Println(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, config := range cl.configs {
|
||||||
|
|
||||||
// Download files and verify their SHA
|
// Download files and verify their SHA
|
||||||
for _, file := range config.DownloadFiles {
|
for _, file := range config.DownloadFiles {
|
||||||
@@ -381,25 +513,29 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cc := cm.configs[i]
|
cc := cl.configs[i]
|
||||||
c := &cc
|
c := &cc
|
||||||
c.PredictionOptions.Model = md5Name
|
c.PredictionOptions.Model = md5Name
|
||||||
cm.configs[i] = *c
|
cl.configs[i] = *c
|
||||||
}
|
}
|
||||||
if cm.configs[i].Name != "" {
|
if cl.configs[i].Name != "" {
|
||||||
log.Info().Msgf("Model name: %s", cm.configs[i].Name)
|
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
|
||||||
}
|
}
|
||||||
if cm.configs[i].Description != "" {
|
if cl.configs[i].Description != "" {
|
||||||
log.Info().Msgf("Model description: %s", cm.configs[i].Description)
|
//glamText("**Description**")
|
||||||
|
glamText(cl.configs[i].Description)
|
||||||
}
|
}
|
||||||
if cm.configs[i].Usage != "" {
|
if cl.configs[i].Usage != "" {
|
||||||
log.Info().Msgf("Model usage: \n%s", cm.configs[i].Usage)
|
//glamText("**Usage**")
|
||||||
|
glamText(cl.configs[i].Usage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *ConfigLoader) LoadConfigs(path string) error {
|
// LoadBackendConfigsFromPath reads all the configurations of the models from a path
|
||||||
|
// (non-recursive)
|
||||||
|
func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
|
||||||
cm.Lock()
|
cm.Lock()
|
||||||
defer cm.Unlock()
|
defer cm.Unlock()
|
||||||
entries, err := os.ReadDir(path)
|
entries, err := os.ReadDir(path)
|
||||||
@@ -419,7 +555,7 @@ func (cm *ConfigLoader) LoadConfigs(path string) error {
|
|||||||
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
|
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
c, err := ReadConfig(filepath.Join(path, file.Name()))
|
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cm.configs[c.Name] = *c
|
cm.configs[c.Name] = *c
|
||||||
}
|
}
|
||||||
@@ -4,8 +4,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
|
|
||||||
. "github.com/go-skynet/LocalAI/core/config"
|
. "github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
@@ -19,7 +18,7 @@ var _ = Describe("Test cases for config related functions", func() {
|
|||||||
Context("Test Read configuration functions", func() {
|
Context("Test Read configuration functions", func() {
|
||||||
configFile = os.Getenv("CONFIG_FILE")
|
configFile = os.Getenv("CONFIG_FILE")
|
||||||
It("Test ReadConfigFile", func() {
|
It("Test ReadConfigFile", func() {
|
||||||
config, err := ReadConfigFile(configFile)
|
config, err := ReadBackendConfigFile(configFile)
|
||||||
Expect(err).To(BeNil())
|
Expect(err).To(BeNil())
|
||||||
Expect(config).ToNot(BeNil())
|
Expect(config).ToNot(BeNil())
|
||||||
// two configs in config.yaml
|
// two configs in config.yaml
|
||||||
@@ -28,29 +27,26 @@ var _ = Describe("Test cases for config related functions", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("Test LoadConfigs", func() {
|
It("Test LoadConfigs", func() {
|
||||||
cm := NewConfigLoader()
|
cm := NewBackendConfigLoader()
|
||||||
opts := options.NewOptions()
|
opts := NewApplicationConfig()
|
||||||
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
err := cm.LoadBackendConfigsFromPath(opts.ModelPath)
|
||||||
options.WithModelLoader(modelLoader)(opts)
|
|
||||||
|
|
||||||
err := cm.LoadConfigs(opts.Loader.ModelPath)
|
|
||||||
Expect(err).To(BeNil())
|
Expect(err).To(BeNil())
|
||||||
Expect(cm.ListConfigs()).ToNot(BeNil())
|
Expect(cm.ListBackendConfigs()).ToNot(BeNil())
|
||||||
|
|
||||||
// config should includes gpt4all models's api.config
|
// config should includes gpt4all models's api.config
|
||||||
Expect(cm.ListConfigs()).To(ContainElements("gpt4all"))
|
Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all"))
|
||||||
|
|
||||||
// config should includes gpt2 models's api.config
|
// config should includes gpt2 models's api.config
|
||||||
Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2"))
|
Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all-2"))
|
||||||
|
|
||||||
// config should includes text-embedding-ada-002 models's api.config
|
// config should includes text-embedding-ada-002 models's api.config
|
||||||
Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002"))
|
Expect(cm.ListBackendConfigs()).To(ContainElements("text-embedding-ada-002"))
|
||||||
|
|
||||||
// config should includes rwkv_test models's api.config
|
// config should includes rwkv_test models's api.config
|
||||||
Expect(cm.ListConfigs()).To(ContainElements("rwkv_test"))
|
Expect(cm.ListBackendConfigs()).To(ContainElements("rwkv_test"))
|
||||||
|
|
||||||
// config should includes whisper-1 models's api.config
|
// config should includes whisper-1 models's api.config
|
||||||
Expect(cm.ListConfigs()).To(ContainElements("whisper-1"))
|
Expect(cm.ListBackendConfigs()).To(ContainElements("whisper-1"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
244
core/http/api.go
244
core/http/api.go
@@ -3,122 +3,48 @@ package http
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/localai"
|
"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
|
||||||
"github.com/go-skynet/LocalAI/api/openai"
|
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/go-skynet/LocalAI/core/services"
|
||||||
"github.com/go-skynet/LocalAI/internal"
|
"github.com/go-skynet/LocalAI/internal"
|
||||||
"github.com/go-skynet/LocalAI/metrics"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/assets"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/model"
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/go-skynet/LocalAI/pkg/startup"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
|
func readAuthHeader(c *fiber.Ctx) string {
|
||||||
options := options.NewOptions(opts...)
|
authHeader := c.Get("Authorization")
|
||||||
|
|
||||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
// elevenlabs
|
||||||
if options.Debug {
|
xApiKey := c.Get("xi-api-key")
|
||||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
if xApiKey != "" {
|
||||||
|
authHeader = "Bearer " + xApiKey
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
|
// anthropic
|
||||||
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
xApiKey = c.Get("x-api-key")
|
||||||
|
if xApiKey != "" {
|
||||||
startup.PreloadModelsConfigurations(options.ModelLibraryURL, options.Loader.ModelPath, options.ModelsURL...)
|
authHeader = "Bearer " + xApiKey
|
||||||
|
|
||||||
cl := config.NewConfigLoader()
|
|
||||||
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
|
|
||||||
log.Error().Msgf("error loading config files: %s", err.Error())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.ConfigFile != "" {
|
return authHeader
|
||||||
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
|
|
||||||
log.Error().Msgf("error loading config file: %s", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cl.Preload(options.Loader.ModelPath); err != nil {
|
|
||||||
log.Error().Msgf("error downloading models: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.PreloadJSONModels != "" {
|
|
||||||
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.PreloadModelsFromPath != "" {
|
|
||||||
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.Debug {
|
|
||||||
for _, v := range cl.ListConfigs() {
|
|
||||||
cfg, _ := cl.GetConfig(v)
|
|
||||||
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.AssetsDestination != "" {
|
|
||||||
// Extract files from the embedded FS
|
|
||||||
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
|
|
||||||
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// turn off any process that was started by GRPC if the context is canceled
|
|
||||||
go func() {
|
|
||||||
<-options.Context.Done()
|
|
||||||
log.Debug().Msgf("Context canceled, shutting down")
|
|
||||||
options.Loader.StopAllGRPC()
|
|
||||||
}()
|
|
||||||
|
|
||||||
if options.WatchDog {
|
|
||||||
wd := model.NewWatchDog(
|
|
||||||
options.Loader,
|
|
||||||
options.WatchDogBusyTimeout,
|
|
||||||
options.WatchDogIdleTimeout,
|
|
||||||
options.WatchDogBusy,
|
|
||||||
options.WatchDogIdle)
|
|
||||||
options.Loader.SetWatchDog(wd)
|
|
||||||
go wd.Run()
|
|
||||||
go func() {
|
|
||||||
<-options.Context.Done()
|
|
||||||
log.Debug().Msgf("Context canceled, shutting down")
|
|
||||||
wd.Shutdown()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
return options, cl, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func App(opts ...options.AppOption) (*fiber.App, error) {
|
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
|
||||||
|
|
||||||
options, cl, err := Startup(opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return errors as JSON responses
|
// Return errors as JSON responses
|
||||||
app := fiber.New(fiber.Config{
|
app := fiber.New(fiber.Config{
|
||||||
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||||
DisableStartupMessage: options.DisableMessage,
|
DisableStartupMessage: appConfig.DisableMessage,
|
||||||
// Override default error handler
|
// Override default error handler
|
||||||
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
||||||
// Status code defaults to 500
|
// Status code defaults to 500
|
||||||
@@ -139,7 +65,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if options.Debug {
|
if appConfig.Debug {
|
||||||
app.Use(logger.New(logger.Config{
|
app.Use(logger.New(logger.Config{
|
||||||
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
||||||
}))
|
}))
|
||||||
@@ -147,17 +73,25 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||||||
|
|
||||||
// Default middleware config
|
// Default middleware config
|
||||||
|
|
||||||
if !options.Debug {
|
if !appConfig.Debug {
|
||||||
app.Use(recover.New())
|
app.Use(recover.New())
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.Metrics != nil {
|
metricsService, err := services.NewLocalAIMetricsService()
|
||||||
app.Use(metrics.APIMiddleware(options.Metrics))
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if metricsService != nil {
|
||||||
|
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||||
|
app.Hooks().OnShutdown(func() error {
|
||||||
|
return metricsService.Shutdown()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
|
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
|
||||||
auth := func(c *fiber.Ctx) error {
|
auth := func(c *fiber.Ctx) error {
|
||||||
if len(options.ApiKeys) == 0 {
|
if len(appConfig.ApiKeys) == 0 {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,47 +106,48 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add file keys to options.ApiKeys
|
// Add file keys to options.ApiKeys
|
||||||
options.ApiKeys = append(options.ApiKeys, fileKeys...)
|
appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(options.ApiKeys) == 0 {
|
if len(appConfig.ApiKeys) == 0 {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
authHeader := c.Get("Authorization")
|
authHeader := readAuthHeader(c)
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If it's a bearer token
|
||||||
authHeaderParts := strings.Split(authHeader, " ")
|
authHeaderParts := strings.Split(authHeader, " ")
|
||||||
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
|
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
|
||||||
}
|
}
|
||||||
|
|
||||||
apiKey := authHeaderParts[1]
|
apiKey := authHeaderParts[1]
|
||||||
for _, key := range options.ApiKeys {
|
for _, key := range appConfig.ApiKeys {
|
||||||
if apiKey == key {
|
if apiKey == key {
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
|
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.CORS {
|
if appConfig.CORS {
|
||||||
var c func(ctx *fiber.Ctx) error
|
var c func(ctx *fiber.Ctx) error
|
||||||
if options.CORSAllowOrigins == "" {
|
if appConfig.CORSAllowOrigins == "" {
|
||||||
c = cors.New()
|
c = cors.New()
|
||||||
} else {
|
} else {
|
||||||
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
|
c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Use(c)
|
app.Use(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAI API endpoints
|
// LocalAI API endpoints
|
||||||
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
|
galleryService := services.NewGalleryService(appConfig.ModelPath)
|
||||||
galleryService.Start(options.Context, cl)
|
galleryService.Start(appConfig.Context, cl)
|
||||||
|
|
||||||
app.Get("/version", auth, func(c *fiber.Ctx) error {
|
app.Get("/version", auth, func(c *fiber.Ctx) error {
|
||||||
return c.JSON(struct {
|
return c.JSON(struct {
|
||||||
@@ -220,69 +155,68 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||||||
}{Version: internal.PrintableVersion()})
|
}{Version: internal.PrintableVersion()})
|
||||||
})
|
})
|
||||||
|
|
||||||
// Make sure directories exists
|
|
||||||
os.MkdirAll(options.ImageDir, 0755)
|
|
||||||
os.MkdirAll(options.AudioDir, 0755)
|
|
||||||
os.MkdirAll(options.UploadDir, 0755)
|
|
||||||
os.MkdirAll(options.Loader.ModelPath, 0755)
|
|
||||||
|
|
||||||
// Load upload json
|
// Load upload json
|
||||||
openai.LoadUploadConfig(options.UploadDir)
|
openai.LoadUploadConfig(appConfig.UploadDir)
|
||||||
|
|
||||||
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
|
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
|
||||||
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
|
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||||
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
|
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
||||||
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
|
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||||
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
|
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
|
||||||
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
|
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
|
||||||
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
|
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||||
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
|
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||||
|
|
||||||
|
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
// Elevenlabs
|
||||||
|
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// openAI compatible API endpoint
|
// openAI compatible API endpoint
|
||||||
|
|
||||||
// chat
|
// chat
|
||||||
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
|
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
|
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// edit
|
// edit
|
||||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
|
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
|
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// files
|
// files
|
||||||
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options))
|
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
||||||
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options))
|
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
||||||
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, options))
|
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
|
||||||
app.Get("/files", auth, openai.ListFilesEndpoint(cl, options))
|
app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
|
||||||
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
|
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
|
||||||
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, options))
|
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
|
||||||
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
|
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
|
||||||
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options))
|
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
|
||||||
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
|
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
|
||||||
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options))
|
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
|
||||||
|
|
||||||
// completion
|
// completion
|
||||||
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
|
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
|
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
|
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// embeddings
|
// embeddings
|
||||||
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// audio
|
// audio
|
||||||
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
|
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
|
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// images
|
// images
|
||||||
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
|
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
if options.ImageDir != "" {
|
if appConfig.ImageDir != "" {
|
||||||
app.Static("/generated-images", options.ImageDir)
|
app.Static("/generated-images", appConfig.ImageDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.AudioDir != "" {
|
if appConfig.AudioDir != "" {
|
||||||
app.Static("/generated-audio", options.AudioDir)
|
app.Static("/generated-audio", appConfig.AudioDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok := func(c *fiber.Ctx) error {
|
ok := func(c *fiber.Ctx) error {
|
||||||
@@ -294,15 +228,15 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
|
|||||||
app.Get("/readyz", ok)
|
app.Get("/readyz", ok)
|
||||||
|
|
||||||
// Experimental Backend Statistics Module
|
// Experimental Backend Statistics Module
|
||||||
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
|
backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now
|
||||||
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
|
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
|
||||||
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
|
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
|
||||||
|
|
||||||
// models
|
// models
|
||||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
|
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
|
||||||
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
|
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
|
||||||
|
|
||||||
app.Get("/metrics", metrics.MetricsHandler())
|
app.Get("/metrics", localai.LocalAIMetricsEndpoint())
|
||||||
|
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,9 +13,10 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
. "github.com/go-skynet/LocalAI/core/http"
|
. "github.com/go-skynet/LocalAI/core/http"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
"github.com/go-skynet/LocalAI/core/startup"
|
||||||
"github.com/go-skynet/LocalAI/metrics"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/downloader"
|
"github.com/go-skynet/LocalAI/pkg/downloader"
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/go-skynet/LocalAI/pkg/model"
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
@@ -127,25 +128,33 @@ var backendAssets embed.FS
|
|||||||
var _ = Describe("API test", func() {
|
var _ = Describe("API test", func() {
|
||||||
|
|
||||||
var app *fiber.App
|
var app *fiber.App
|
||||||
var modelLoader *model.ModelLoader
|
|
||||||
var client *openai.Client
|
var client *openai.Client
|
||||||
var client2 *openaigo.Client
|
var client2 *openaigo.Client
|
||||||
var c context.Context
|
var c context.Context
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
var tmpdir string
|
var tmpdir string
|
||||||
|
var modelDir string
|
||||||
|
var bcl *config.BackendConfigLoader
|
||||||
|
var ml *model.ModelLoader
|
||||||
|
var applicationConfig *config.ApplicationConfig
|
||||||
|
|
||||||
commonOpts := []options.AppOption{
|
commonOpts := []config.AppOption{
|
||||||
options.WithDebug(true),
|
config.WithDebug(true),
|
||||||
options.WithDisableMessage(true),
|
config.WithDisableMessage(true),
|
||||||
}
|
}
|
||||||
|
|
||||||
Context("API with ephemeral models", func() {
|
Context("API with ephemeral models", func() {
|
||||||
BeforeEach(func() {
|
|
||||||
|
BeforeEach(func(sc SpecContext) {
|
||||||
var err error
|
var err error
|
||||||
tmpdir, err = os.MkdirTemp("", "")
|
tmpdir, err = os.MkdirTemp("", "")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
modelLoader = model.NewModelLoader(tmpdir)
|
modelDir = filepath.Join(tmpdir, "models")
|
||||||
|
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
|
||||||
|
err = os.Mkdir(backendAssetsDir, 0755)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
g := []gallery.GalleryModel{
|
g := []gallery.GalleryModel{
|
||||||
@@ -172,16 +181,18 @@ var _ = Describe("API test", func() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
metricsService, err := metrics.SetupMetrics()
|
bcl, ml, applicationConfig, err = startup.Startup(
|
||||||
|
append(commonOpts,
|
||||||
|
config.WithContext(c),
|
||||||
|
config.WithGalleries(galleries),
|
||||||
|
config.WithModelPath(modelDir),
|
||||||
|
config.WithBackendAssets(backendAssets),
|
||||||
|
config.WithBackendAssetsOutput(backendAssetsDir))...)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
app, err = App(
|
app, err = App(bcl, ml, applicationConfig)
|
||||||
append(commonOpts,
|
|
||||||
options.WithMetrics(metricsService),
|
|
||||||
options.WithContext(c),
|
|
||||||
options.WithGalleries(galleries),
|
|
||||||
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
@@ -198,15 +209,21 @@ var _ = Describe("API test", func() {
|
|||||||
}, "2m").ShouldNot(HaveOccurred())
|
}, "2m").ShouldNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func(sc SpecContext) {
|
||||||
cancel()
|
cancel()
|
||||||
app.Shutdown()
|
if app != nil {
|
||||||
os.RemoveAll(tmpdir)
|
err := app.Shutdown()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
err := os.RemoveAll(tmpdir)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = os.ReadDir(tmpdir)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("Applying models", func() {
|
Context("Applying models", func() {
|
||||||
It("applies models from a gallery", func() {
|
|
||||||
|
|
||||||
|
It("applies models from a gallery", func() {
|
||||||
models := getModels("http://127.0.0.1:9090/models/available")
|
models := getModels("http://127.0.0.1:9090/models/available")
|
||||||
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
|
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
|
||||||
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
|
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
|
||||||
@@ -228,10 +245,10 @@ var _ = Describe("API test", func() {
|
|||||||
}, "360s", "10s").Should(Equal(true))
|
}, "360s", "10s").Should(Equal(true))
|
||||||
Expect(resp["message"]).ToNot(ContainSubstring("error"))
|
Expect(resp["message"]).ToNot(ContainSubstring("error"))
|
||||||
|
|
||||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert2.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml"))
|
_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]interface{}{}
|
||||||
@@ -253,6 +270,7 @@ var _ = Describe("API test", func() {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
It("overrides models", func() {
|
It("overrides models", func() {
|
||||||
|
|
||||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||||
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
||||||
Name: "bert",
|
Name: "bert",
|
||||||
@@ -270,7 +288,7 @@ var _ = Describe("API test", func() {
|
|||||||
return response["processed"].(bool)
|
return response["processed"].(bool)
|
||||||
}, "360s", "10s").Should(Equal(true))
|
}, "360s", "10s").Should(Equal(true))
|
||||||
|
|
||||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]interface{}{}
|
||||||
@@ -294,7 +312,7 @@ var _ = Describe("API test", func() {
|
|||||||
return response["processed"].(bool)
|
return response["processed"].(bool)
|
||||||
}, "360s", "10s").Should(Equal(true))
|
}, "360s", "10s").Should(Equal(true))
|
||||||
|
|
||||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
content := map[string]interface{}{}
|
content := map[string]interface{}{}
|
||||||
@@ -368,7 +386,7 @@ var _ = Describe("API test", func() {
|
|||||||
var res map[string]string
|
var res map[string]string
|
||||||
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
|
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res))
|
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
|
||||||
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
|
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
|
||||||
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
|
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
|
||||||
|
|
||||||
@@ -483,8 +501,11 @@ var _ = Describe("API test", func() {
|
|||||||
var err error
|
var err error
|
||||||
tmpdir, err = os.MkdirTemp("", "")
|
tmpdir, err = os.MkdirTemp("", "")
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
modelDir = filepath.Join(tmpdir, "models")
|
||||||
|
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
|
||||||
|
err = os.Mkdir(backendAssetsDir, 0755)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
modelLoader = model.NewModelLoader(tmpdir)
|
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
galleries := []gallery.Gallery{
|
galleries := []gallery.Gallery{
|
||||||
@@ -494,21 +515,20 @@ var _ = Describe("API test", func() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
metricsService, err := metrics.SetupMetrics()
|
bcl, ml, applicationConfig, err = startup.Startup(
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
app, err = App(
|
|
||||||
append(commonOpts,
|
append(commonOpts,
|
||||||
options.WithContext(c),
|
config.WithContext(c),
|
||||||
options.WithMetrics(metricsService),
|
config.WithAudioDir(tmpdir),
|
||||||
options.WithAudioDir(tmpdir),
|
config.WithImageDir(tmpdir),
|
||||||
options.WithImageDir(tmpdir),
|
config.WithGalleries(galleries),
|
||||||
options.WithGalleries(galleries),
|
config.WithModelPath(modelDir),
|
||||||
options.WithModelLoader(modelLoader),
|
config.WithBackendAssets(backendAssets),
|
||||||
options.WithBackendAssets(backendAssets),
|
config.WithBackendAssetsOutput(tmpdir))...,
|
||||||
options.WithBackendAssetsOutput(tmpdir))...,
|
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
app, err = App(bcl, ml, applicationConfig)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
@@ -527,8 +547,14 @@ var _ = Describe("API test", func() {
|
|||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
cancel()
|
cancel()
|
||||||
app.Shutdown()
|
if app != nil {
|
||||||
os.RemoveAll(tmpdir)
|
err := app.Shutdown()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
err := os.RemoveAll(tmpdir)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = os.ReadDir(tmpdir)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
It("installs and is capable to run tts", Label("tts"), func() {
|
It("installs and is capable to run tts", Label("tts"), func() {
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
@@ -599,20 +625,20 @@ var _ = Describe("API test", func() {
|
|||||||
|
|
||||||
Context("API query", func() {
|
Context("API query", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
modelPath := os.Getenv("MODELS_PATH")
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
metricsService, err := metrics.SetupMetrics()
|
var err error
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
app, err = App(
|
bcl, ml, applicationConfig, err = startup.Startup(
|
||||||
append(commonOpts,
|
append(commonOpts,
|
||||||
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
||||||
options.WithContext(c),
|
config.WithContext(c),
|
||||||
options.WithModelLoader(modelLoader),
|
config.WithModelPath(modelPath),
|
||||||
options.WithMetrics(metricsService),
|
|
||||||
)...)
|
)...)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
app, err = App(bcl, ml, applicationConfig)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
@@ -630,7 +656,10 @@ var _ = Describe("API test", func() {
|
|||||||
})
|
})
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
cancel()
|
cancel()
|
||||||
app.Shutdown()
|
if app != nil {
|
||||||
|
err := app.Shutdown()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
})
|
})
|
||||||
It("returns the models list", func() {
|
It("returns the models list", func() {
|
||||||
models, err := client.ListModels(context.TODO())
|
models, err := client.ListModels(context.TODO())
|
||||||
@@ -811,20 +840,20 @@ var _ = Describe("API test", func() {
|
|||||||
|
|
||||||
Context("Config file", func() {
|
Context("Config file", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
modelPath := os.Getenv("MODELS_PATH")
|
||||||
c, cancel = context.WithCancel(context.Background())
|
c, cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
metricsService, err := metrics.SetupMetrics()
|
var err error
|
||||||
Expect(err).ToNot(HaveOccurred())
|
bcl, ml, applicationConfig, err = startup.Startup(
|
||||||
|
|
||||||
app, err = App(
|
|
||||||
append(commonOpts,
|
append(commonOpts,
|
||||||
options.WithContext(c),
|
config.WithContext(c),
|
||||||
options.WithMetrics(metricsService),
|
config.WithModelPath(modelPath),
|
||||||
options.WithModelLoader(modelLoader),
|
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||||
options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
app, err = App(bcl, ml, applicationConfig)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
go app.Listen("127.0.0.1:9090")
|
go app.Listen("127.0.0.1:9090")
|
||||||
|
|
||||||
defaultConfig := openai.DefaultConfig("")
|
defaultConfig := openai.DefaultConfig("")
|
||||||
@@ -840,7 +869,10 @@ var _ = Describe("API test", func() {
|
|||||||
})
|
})
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
cancel()
|
cancel()
|
||||||
app.Shutdown()
|
if app != nil {
|
||||||
|
err := app.Shutdown()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
})
|
})
|
||||||
It("can generate chat completions from config file (list1)", func() {
|
It("can generate chat completions from config file (list1)", func() {
|
||||||
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
|
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
|
||||||
|
|||||||
55
core/http/endpoints/elevenlabs/tts.go
Normal file
55
core/http/endpoints/elevenlabs/tts.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package elevenlabs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
|
input := new(schema.ElevenLabsTTSRequest)
|
||||||
|
voiceID := c.Params("voice-id")
|
||||||
|
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||||
|
config.LoadOptionDebug(appConfig.Debug),
|
||||||
|
config.LoadOptionThreads(appConfig.Threads),
|
||||||
|
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||||
|
config.LoadOptionF16(appConfig.F16),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
|
||||||
|
} else {
|
||||||
|
if input.ModelID != "" {
|
||||||
|
modelFile = input.ModelID
|
||||||
|
} else {
|
||||||
|
modelFile = cfg.Model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Request for model: %s", modelFile)
|
||||||
|
|
||||||
|
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Download(filePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
36
core/http/endpoints/localai/backend_monitor.go
Normal file
36
core/http/endpoints/localai/backend_monitor.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/go-skynet/LocalAI/core/services"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
|
input := new(schema.BackendMonitorRequest)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := bm.CheckAndSample(input.Model)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.JSON(resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
input := new(schema.BackendMonitorRequest)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return bm.ShutdownModel(input.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
146
core/http/endpoints/localai/gallery.go
Normal file
146
core/http/endpoints/localai/gallery.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/services"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelGalleryEndpointService struct {
|
||||||
|
galleries []gallery.Gallery
|
||||||
|
modelPath string
|
||||||
|
galleryApplier *services.GalleryService
|
||||||
|
}
|
||||||
|
|
||||||
|
type GalleryModel struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
gallery.GalleryModel
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||||
|
return ModelGalleryEndpointService{
|
||||||
|
galleries: galleries,
|
||||||
|
modelPath: modelPath,
|
||||||
|
galleryApplier: galleryApplier,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
|
||||||
|
if status == nil {
|
||||||
|
return fmt.Errorf("could not find any status for ID")
|
||||||
|
}
|
||||||
|
return c.JSON(status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
return c.JSON(mgs.galleryApplier.GetAllStatus())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
input := new(GalleryModel)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
uuid, err := uuid.NewUUID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mgs.galleryApplier.C <- gallery.GalleryOp{
|
||||||
|
Req: input.GalleryModel,
|
||||||
|
Id: uuid.String(),
|
||||||
|
GalleryName: input.ID,
|
||||||
|
Galleries: mgs.galleries,
|
||||||
|
}
|
||||||
|
return c.JSON(struct {
|
||||||
|
ID string `json:"uuid"`
|
||||||
|
StatusURL string `json:"status"`
|
||||||
|
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
|
||||||
|
|
||||||
|
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Models found from galleries: %+v", models)
|
||||||
|
for _, m := range models {
|
||||||
|
log.Debug().Msgf("Model found from galleries: %+v", m)
|
||||||
|
}
|
||||||
|
dat, err := json.Marshal(models)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Send(dat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||||
|
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
||||||
|
dat, err := json.Marshal(mgs.galleries)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Send(dat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
input := new(gallery.Gallery)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||||
|
return gallery.Name == input.Name
|
||||||
|
}) {
|
||||||
|
return fmt.Errorf("%s already exists", input.Name)
|
||||||
|
}
|
||||||
|
dat, err := json.Marshal(mgs.galleries)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Adding %+v to gallery list", *input)
|
||||||
|
mgs.galleries = append(mgs.galleries, *input)
|
||||||
|
return c.Send(dat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
input := new(gallery.Gallery)
|
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||||
|
return gallery.Name == input.Name
|
||||||
|
}) {
|
||||||
|
return fmt.Errorf("%s is not currently registered", input.Name)
|
||||||
|
}
|
||||||
|
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||||
|
return gallery.Name == input.Name
|
||||||
|
})
|
||||||
|
return c.Send(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
43
core/http/endpoints/localai/metrics.go
Normal file
43
core/http/endpoints/localai/metrics.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package localai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/services"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/adaptor"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LocalAIMetricsEndpoint() fiber.Handler {
|
||||||
|
|
||||||
|
return adaptor.HTTPHandler(promhttp.Handler())
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiMiddlewareConfig struct {
|
||||||
|
Filter func(c *fiber.Ctx) bool
|
||||||
|
metricsService *services.LocalAIMetricsService
|
||||||
|
}
|
||||||
|
|
||||||
|
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
|
||||||
|
cfg := apiMiddlewareConfig{
|
||||||
|
metricsService: metrics,
|
||||||
|
Filter: func(c *fiber.Ctx) bool {
|
||||||
|
return c.Path() == "/metrics"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
if cfg.Filter != nil && cfg.Filter(c) {
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
path := c.Path()
|
||||||
|
method := c.Method()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
err := c.Next()
|
||||||
|
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||||
|
cfg.metricsService.ObserveAPICall(method, path, elapsed)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,37 +1,39 @@
|
|||||||
package localai
|
package localai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
fiberContext "github.com/go-skynet/LocalAI/api/ctx"
|
|
||||||
"github.com/go-skynet/LocalAI/core/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/rs/zerolog/log"
|
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TTSRequest struct {
|
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
Model string `json:"model" yaml:"model"`
|
|
||||||
Input string `json:"input" yaml:"input"`
|
|
||||||
Backend string `json:"backend" yaml:"backend"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
input := new(TTSRequest)
|
input := new(schema.TTSRequest)
|
||||||
|
|
||||||
// Get input data from the request body
|
// Get input data from the request body
|
||||||
if err := c.BodyParser(input); err != nil {
|
if err := c.BodyParser(input); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, false)
|
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
modelFile = input.Model
|
modelFile = input.Model
|
||||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||||
}
|
}
|
||||||
cfg, err := config.Load(modelFile, o.Loader.ModelPath, cm, false, 0, 0, false)
|
|
||||||
|
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
|
||||||
|
config.LoadOptionDebug(appConfig.Debug),
|
||||||
|
config.LoadOptionThreads(appConfig.Threads),
|
||||||
|
config.LoadOptionContextSize(appConfig.ContextSize),
|
||||||
|
config.LoadOptionF16(appConfig.F16),
|
||||||
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
modelFile = input.Model
|
modelFile = input.Model
|
||||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||||
@@ -44,7 +46,7 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
cfg.Backend = input.Backend
|
cfg.Backend = input.Backend
|
||||||
}
|
}
|
||||||
|
|
||||||
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, o.Loader, o, *cfg)
|
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -9,8 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
@@ -21,12 +20,12 @@ import (
|
|||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
emptyMessage := ""
|
emptyMessage := ""
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
created := int(time.Now().Unix())
|
created := int(time.Now().Unix())
|
||||||
|
|
||||||
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||||
initialMessage := schema.OpenAIResponse{
|
initialMessage := schema.OpenAIResponse{
|
||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
@@ -36,7 +35,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
}
|
}
|
||||||
responses <- initialMessage
|
responses <- initialMessage
|
||||||
|
|
||||||
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||||
resp := schema.OpenAIResponse{
|
resp := schema.OpenAIResponse{
|
||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
@@ -55,9 +54,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
})
|
})
|
||||||
close(responses)
|
close(responses)
|
||||||
}
|
}
|
||||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||||
result := ""
|
result := ""
|
||||||
_, tokenUsage, _ := ComputeChoices(req, prompt, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||||
result += s
|
result += s
|
||||||
// TODO: Change generated BNF grammar to be compliant with the schema so we can
|
// TODO: Change generated BNF grammar to be compliant with the schema so we can
|
||||||
// stream the result token by token here.
|
// stream the result token by token here.
|
||||||
@@ -78,7 +77,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
}
|
}
|
||||||
responses <- initialMessage
|
responses <- initialMessage
|
||||||
|
|
||||||
result, err := handleQuestion(config, req, o, results[0].arguments, prompt)
|
result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf("error handling question: %s", err.Error())
|
log.Error().Msgf("error handling question: %s", err.Error())
|
||||||
return
|
return
|
||||||
@@ -154,12 +153,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
processFunctions := false
|
processFunctions := false
|
||||||
funcs := grammar.Functions{}
|
funcs := grammar.Functions{}
|
||||||
modelFile, input, err := readRequest(c, o, true)
|
modelFile, input, err := readRequest(c, ml, startupOptions, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -252,7 +251,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
FunctionName: i.Name,
|
FunctionName: i.Name,
|
||||||
MessageIndex: messageIndex,
|
MessageIndex: messageIndex,
|
||||||
}
|
}
|
||||||
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
||||||
} else {
|
} else {
|
||||||
@@ -320,7 +319,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
templateFile := ""
|
templateFile := ""
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||||
templateFile = config.Model
|
templateFile = config.Model
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,7 +332,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if templateFile != "" {
|
if templateFile != "" {
|
||||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
SystemPrompt: config.SystemPrompt,
|
SystemPrompt: config.SystemPrompt,
|
||||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||||
Input: predInput,
|
Input: predInput,
|
||||||
@@ -357,9 +356,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
responses := make(chan schema.OpenAIResponse)
|
responses := make(chan schema.OpenAIResponse)
|
||||||
|
|
||||||
if !processFunctions {
|
if !processFunctions {
|
||||||
go process(predInput, input, config, o.Loader, responses)
|
go process(predInput, input, config, ml, responses)
|
||||||
} else {
|
} else {
|
||||||
go processTools(noActionName, predInput, input, config, o.Loader, responses)
|
go processTools(noActionName, predInput, input, config, ml, responses)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||||
@@ -413,7 +412,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
|
|
||||||
// no streaming mode
|
// no streaming mode
|
||||||
default:
|
default:
|
||||||
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
|
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
|
||||||
if !processFunctions {
|
if !processFunctions {
|
||||||
// no function is called, just reply and use stop as finish reason
|
// no function is called, just reply and use stop as finish reason
|
||||||
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
||||||
@@ -425,7 +424,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case noActionsToRun:
|
case noActionsToRun:
|
||||||
result, err := handleQuestion(config, input, o, results[0].arguments, predInput)
|
result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf("error handling question: %s", err.Error())
|
log.Error().Msgf("error handling question: %s", err.Error())
|
||||||
return
|
return
|
||||||
@@ -506,7 +505,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *options.Option, args, prompt string) (string, error) {
|
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
|
||||||
log.Debug().Msgf("nothing to do, computing a reply")
|
log.Debug().Msgf("nothing to do, computing a reply")
|
||||||
|
|
||||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||||
@@ -535,7 +534,7 @@ func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *optio
|
|||||||
images = append(images, m.StringImages...)
|
images = append(images, m.StringImages...)
|
||||||
}
|
}
|
||||||
|
|
||||||
predFunc, err := backend.ModelInference(input.Context, prompt, images, o.Loader, *config, o, nil)
|
predFunc, err := backend.ModelInference(input.Context, prompt, images, ml, *config, o, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf("inference error: %s", err.Error())
|
log.Error().Msgf("inference error: %s", err.Error())
|
||||||
return "", err
|
return "", err
|
||||||
@@ -565,10 +564,20 @@ func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults
|
|||||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||||
|
|
||||||
for _, s := range ss {
|
for _, s := range ss {
|
||||||
func_name := s["function"]
|
func_name, ok := s["function"]
|
||||||
args := s["arguments"]
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
args, ok := s["arguments"]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
d, _ := json.Marshal(args)
|
d, _ := json.Marshal(args)
|
||||||
results = append(results, funcCallResults{name: func_name.(string), arguments: string(d)})
|
funcName, ok := func_name.(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
|
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
|
||||||
@@ -579,12 +588,21 @@ func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults
|
|||||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||||
|
|
||||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||||
func_name := ss["function"]
|
func_name, ok := ss["function"]
|
||||||
|
if !ok {
|
||||||
|
return results
|
||||||
|
}
|
||||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||||
|
if !ok {
|
||||||
|
return results
|
||||||
|
}
|
||||||
d, _ := json.Marshal(args)
|
d, _ := json.Marshal(args)
|
||||||
|
funcName, ok := func_name.(string)
|
||||||
results = append(results, funcCallResults{name: func_name.(string), arguments: string(d)})
|
if !ok {
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
|
||||||
}
|
}
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
@@ -21,12 +21,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/completions
|
// https://platform.openai.com/docs/api-reference/completions
|
||||||
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
created := int(time.Now().Unix())
|
created := int(time.Now().Unix())
|
||||||
|
|
||||||
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||||
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||||
resp := schema.OpenAIResponse{
|
resp := schema.OpenAIResponse{
|
||||||
ID: id,
|
ID: id,
|
||||||
Created: created,
|
Created: created,
|
||||||
@@ -53,14 +53,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
}
|
}
|
||||||
|
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
modelFile, input, err := readRequest(c, o, true)
|
modelFile, input, err := readRequest(c, ml, appConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msgf("`input`: %+v", input)
|
log.Debug().Msgf("`input`: %+v", input)
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -84,7 +84,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
templateFile := ""
|
templateFile := ""
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||||
templateFile = config.Model
|
templateFile = config.Model
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
predInput := config.PromptStrings[0]
|
predInput := config.PromptStrings[0]
|
||||||
|
|
||||||
if templateFile != "" {
|
if templateFile != "" {
|
||||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
Input: predInput,
|
Input: predInput,
|
||||||
})
|
})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -111,7 +111,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
|
|
||||||
responses := make(chan schema.OpenAIResponse)
|
responses := make(chan schema.OpenAIResponse)
|
||||||
|
|
||||||
go process(predInput, input, config, o.Loader, responses)
|
go process(predInput, input, config, ml, responses)
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||||
|
|
||||||
@@ -153,7 +153,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
for k, i := range config.PromptStrings {
|
for k, i := range config.PromptStrings {
|
||||||
if templateFile != "" {
|
if templateFile != "" {
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
SystemPrompt: config.SystemPrompt,
|
SystemPrompt: config.SystemPrompt,
|
||||||
Input: i,
|
Input: i,
|
||||||
})
|
})
|
||||||
@@ -164,7 +164,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
}
|
}
|
||||||
|
|
||||||
r, tokenUsage, err := ComputeChoices(
|
r, tokenUsage, err := ComputeChoices(
|
||||||
input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
|
input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||||
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
|
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
@@ -16,14 +16,14 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
modelFile, input, err := readRequest(c, o, true)
|
modelFile, input, err := readRequest(c, ml, appConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -33,7 +33,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
templateFile := ""
|
templateFile := ""
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||||
templateFile = config.Model
|
templateFile = config.Model
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
|
|
||||||
for _, i := range config.InputStrings {
|
for _, i := range config.InputStrings {
|
||||||
if templateFile != "" {
|
if templateFile != "" {
|
||||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
||||||
Input: i,
|
Input: i,
|
||||||
Instruction: input.Instruction,
|
Instruction: input.Instruction,
|
||||||
SystemPrompt: config.SystemPrompt,
|
SystemPrompt: config.SystemPrompt,
|
||||||
@@ -57,7 +57,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
|
r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||||
*c = append(*c, schema.Choice{Text: s})
|
*c = append(*c, schema.Choice{Text: s})
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -6,24 +6,25 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/embeddings
|
// https://platform.openai.com/docs/api-reference/embeddings
|
||||||
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
model, input, err := readRequest(c, o, true)
|
model, input, err := readRequest(c, ml, appConfig, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -33,7 +34,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
|
|
||||||
for i, s := range config.InputToken {
|
for i, s := range config.InputToken {
|
||||||
// get the model function to call for the result
|
// get the model function to call for the result
|
||||||
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o)
|
embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -47,7 +48,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
|
|
||||||
for i, s := range config.InputStrings {
|
for i, s := range config.InputStrings {
|
||||||
// get the model function to call for the result
|
// get the model function to call for the result
|
||||||
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o)
|
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -8,8 +8,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@@ -62,7 +62,7 @@ func LoadUploadConfig(uploadPath string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
|
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
|
||||||
func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
file, err := c.FormFile("file")
|
file, err := c.FormFile("file")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -70,8 +70,8 @@ func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check the file size
|
// Check the file size
|
||||||
if file.Size > int64(o.UploadLimitMB*1024*1024) {
|
if file.Size > int64(appConfig.UploadLimitMB*1024*1024) {
|
||||||
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, o.UploadLimitMB))
|
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, appConfig.UploadLimitMB))
|
||||||
}
|
}
|
||||||
|
|
||||||
purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
|
purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
|
||||||
@@ -82,7 +82,7 @@ func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
|
|||||||
// Sanitize the filename to prevent directory traversal
|
// Sanitize the filename to prevent directory traversal
|
||||||
filename := utils.SanitizeFileName(file.Filename)
|
filename := utils.SanitizeFileName(file.Filename)
|
||||||
|
|
||||||
savePath := filepath.Join(o.UploadDir, filename)
|
savePath := filepath.Join(appConfig.UploadDir, filename)
|
||||||
|
|
||||||
// Check if file already exists
|
// Check if file already exists
|
||||||
if _, err := os.Stat(savePath); !os.IsNotExist(err) {
|
if _, err := os.Stat(savePath); !os.IsNotExist(err) {
|
||||||
@@ -104,13 +104,13 @@ func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
|
|||||||
}
|
}
|
||||||
|
|
||||||
uploadedFiles = append(uploadedFiles, f)
|
uploadedFiles = append(uploadedFiles, f)
|
||||||
saveUploadConfig(o.UploadDir)
|
saveUploadConfig(appConfig.UploadDir)
|
||||||
return c.Status(fiber.StatusOK).JSON(f)
|
return c.Status(fiber.StatusOK).JSON(f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
|
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
|
||||||
func ListFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
type ListFiles struct {
|
type ListFiles struct {
|
||||||
Data []File
|
Data []File
|
||||||
Object string
|
Object string
|
||||||
@@ -150,7 +150,7 @@ func getFileFromRequest(c *fiber.Ctx) (*File, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
|
// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
|
||||||
func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
file, err := getFileFromRequest(c)
|
file, err := getFileFromRequest(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -162,7 +162,7 @@ func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
|
// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
|
||||||
func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
type DeleteStatus struct {
|
type DeleteStatus struct {
|
||||||
Id string
|
Id string
|
||||||
Object string
|
Object string
|
||||||
@@ -175,7 +175,7 @@ func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
|
|||||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.Remove(filepath.Join(o.UploadDir, file.Filename))
|
err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If the file doesn't exist then we should just continue to remove it
|
// If the file doesn't exist then we should just continue to remove it
|
||||||
if !errors.Is(err, os.ErrNotExist) {
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
@@ -191,7 +191,7 @@ func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
saveUploadConfig(o.UploadDir)
|
saveUploadConfig(appConfig.UploadDir)
|
||||||
return c.JSON(DeleteStatus{
|
return c.JSON(DeleteStatus{
|
||||||
Id: file.ID,
|
Id: file.ID,
|
||||||
Object: "file",
|
Object: "file",
|
||||||
@@ -201,14 +201,14 @@ func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
|
// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
|
||||||
func GetFilesContentsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
file, err := getFileFromRequest(c)
|
file, err := getFileFromRequest(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
fileContents, err := os.ReadFile(filepath.Join(o.UploadDir, file.Filename))
|
fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||||
}
|
}
|
||||||
@@ -11,8 +11,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
utils2 "github.com/go-skynet/LocalAI/pkg/utils"
|
utils2 "github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -25,11 +25,11 @@ type ListFiles struct {
|
|||||||
Object string
|
Object string
|
||||||
}
|
}
|
||||||
|
|
||||||
func startUpApp() (app *fiber.App, option *options.Option, loader *config.ConfigLoader) {
|
func startUpApp() (app *fiber.App, option *config.ApplicationConfig, loader *config.BackendConfigLoader) {
|
||||||
// Preparing the mocked objects
|
// Preparing the mocked objects
|
||||||
loader = &config.ConfigLoader{}
|
loader = &config.BackendConfigLoader{}
|
||||||
|
|
||||||
option = &options.Option{
|
option = &config.ApplicationConfig{
|
||||||
UploadLimitMB: 10,
|
UploadLimitMB: 10,
|
||||||
UploadDir: "test_dir",
|
UploadDir: "test_dir",
|
||||||
}
|
}
|
||||||
@@ -52,9 +52,9 @@ func startUpApp() (app *fiber.App, option *options.Option, loader *config.Config
|
|||||||
|
|
||||||
func TestUploadFileExceedSizeLimit(t *testing.T) {
|
func TestUploadFileExceedSizeLimit(t *testing.T) {
|
||||||
// Preparing the mocked objects
|
// Preparing the mocked objects
|
||||||
loader := &config.ConfigLoader{}
|
loader := &config.BackendConfigLoader{}
|
||||||
|
|
||||||
option := &options.Option{
|
option := &config.ApplicationConfig{
|
||||||
UploadLimitMB: 10,
|
UploadLimitMB: 10,
|
||||||
UploadDir: "test_dir",
|
UploadDir: "test_dir",
|
||||||
}
|
}
|
||||||
@@ -174,9 +174,9 @@ func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*htt
|
|||||||
return app.Test(request)
|
return app.Test(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) (*http.Response, error) {
|
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
|
||||||
// Create a file that exceeds the limit
|
// Create a file that exceeds the limit
|
||||||
file := createTestFile(t, fileName, fileSize, o)
|
file := createTestFile(t, fileName, fileSize, appConfig)
|
||||||
|
|
||||||
// Creating a new HTTP Request
|
// Creating a new HTTP Request
|
||||||
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||||
@@ -186,9 +186,9 @@ func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpos
|
|||||||
return app.Test(req)
|
return app.Test(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) File {
|
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File {
|
||||||
// Create a file that exceeds the limit
|
// Create a file that exceeds the limit
|
||||||
file := createTestFile(t, fileName, fileSize, o)
|
file := createTestFile(t, fileName, fileSize, appConfig)
|
||||||
|
|
||||||
// Creating a new HTTP Request
|
// Creating a new HTTP Request
|
||||||
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||||
@@ -233,7 +233,7 @@ func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipar
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper to create test files
|
// Helper to create test files
|
||||||
func createTestFile(t *testing.T, name string, sizeMB int, option *options.Option) *os.File {
|
func createTestFile(t *testing.T, name string, sizeMB int, option *config.ApplicationConfig) *os.File {
|
||||||
err := os.MkdirAll(option.UploadDir, 0755)
|
err := os.MkdirAll(option.UploadDir, 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
||||||
@@ -13,12 +13,12 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@@ -59,9 +59,9 @@ func downloadFile(url string) (string, error) {
|
|||||||
|
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
m, input, err := readRequest(c, o, false)
|
m, input, err := readRequest(c, ml, appConfig, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -71,7 +71,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
|
|||||||
}
|
}
|
||||||
log.Debug().Msgf("Loading model: %+v", m)
|
log.Debug().Msgf("Loading model: %+v", m)
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false)
|
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -104,7 +104,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a temporary file
|
// Create a temporary file
|
||||||
outputFile, err := os.CreateTemp(o.ImageDir, "b64")
|
outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -133,15 +133,15 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
|
|||||||
|
|
||||||
sizeParts := strings.Split(input.Size, "x")
|
sizeParts := strings.Split(input.Size, "x")
|
||||||
if len(sizeParts) != 2 {
|
if len(sizeParts) != 2 {
|
||||||
return fmt.Errorf("Invalid value for 'size'")
|
return fmt.Errorf("invalid value for 'size'")
|
||||||
}
|
}
|
||||||
width, err := strconv.Atoi(sizeParts[0])
|
width, err := strconv.Atoi(sizeParts[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Invalid value for 'size'")
|
return fmt.Errorf("invalid value for 'size'")
|
||||||
}
|
}
|
||||||
height, err := strconv.Atoi(sizeParts[1])
|
height, err := strconv.Atoi(sizeParts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Invalid value for 'size'")
|
return fmt.Errorf("invalid value for 'size'")
|
||||||
}
|
}
|
||||||
|
|
||||||
b64JSON := false
|
b64JSON := false
|
||||||
@@ -179,7 +179,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
|
|||||||
|
|
||||||
tempDir := ""
|
tempDir := ""
|
||||||
if !b64JSON {
|
if !b64JSON {
|
||||||
tempDir = o.ImageDir
|
tempDir = appConfig.ImageDir
|
||||||
}
|
}
|
||||||
// Create a temporary file
|
// Create a temporary file
|
||||||
outputFile, err := os.CreateTemp(tempDir, "b64")
|
outputFile, err := os.CreateTemp(tempDir, "b64")
|
||||||
@@ -196,7 +196,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
|
|||||||
|
|
||||||
baseURL := c.BaseURL()
|
baseURL := c.BaseURL()
|
||||||
|
|
||||||
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o)
|
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -2,8 +2,8 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/go-skynet/LocalAI/core/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
@@ -11,8 +11,8 @@ import (
|
|||||||
func ComputeChoices(
|
func ComputeChoices(
|
||||||
req *schema.OpenAIRequest,
|
req *schema.OpenAIRequest,
|
||||||
predInput string,
|
predInput string,
|
||||||
config *config.Config,
|
config *config.BackendConfig,
|
||||||
o *options.Option,
|
o *config.ApplicationConfig,
|
||||||
loader *model.ModelLoader,
|
loader *model.ModelLoader,
|
||||||
cb func(string, *[]schema.Choice),
|
cb func(string, *[]schema.Choice),
|
||||||
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
|
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
|
||||||
@@ -3,15 +3,15 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error {
|
func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
models, err := loader.ListModels()
|
models, err := ml.ListModels()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -40,7 +40,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
|
|||||||
excludeConfigured := c.QueryBool("excludeConfigured", true)
|
excludeConfigured := c.QueryBool("excludeConfigured", true)
|
||||||
|
|
||||||
// Start with the known configurations
|
// Start with the known configurations
|
||||||
for _, c := range cm.GetAllConfigs() {
|
for _, c := range cl.GetAllBackendConfigs() {
|
||||||
if excludeConfigured {
|
if excludeConfigured {
|
||||||
mm[c.Model] = nil
|
mm[c.Model] = nil
|
||||||
}
|
}
|
||||||
@@ -5,13 +5,12 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
fiberContext "github.com/go-skynet/LocalAI/api/ctx"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||||
options "github.com/go-skynet/LocalAI/core/options"
|
|
||||||
"github.com/go-skynet/LocalAI/core/schema"
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
@@ -19,11 +18,9 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *schema.OpenAIRequest, error) {
|
func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
|
||||||
input := new(schema.OpenAIRequest)
|
input := new(schema.OpenAIRequest)
|
||||||
ctx, cancel := context.WithCancel(o.Context)
|
|
||||||
input.Context = ctx
|
|
||||||
input.Cancel = cancel
|
|
||||||
// Get input data from the request body
|
// Get input data from the request body
|
||||||
if err := c.BodyParser(input); err != nil {
|
if err := c.BodyParser(input); err != nil {
|
||||||
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
|
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
|
||||||
@@ -31,9 +28,13 @@ func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *sch
|
|||||||
|
|
||||||
received, _ := json.Marshal(input)
|
received, _ := json.Marshal(input)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(o.Context)
|
||||||
|
input.Context = ctx
|
||||||
|
input.Cancel = cancel
|
||||||
|
|
||||||
log.Debug().Msgf("Request received: %s", string(received))
|
log.Debug().Msgf("Request received: %s", string(received))
|
||||||
|
|
||||||
modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, firstModel)
|
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
|
||||||
|
|
||||||
return modelFile, input, err
|
return modelFile, input, err
|
||||||
}
|
}
|
||||||
@@ -50,7 +51,7 @@ func getBase64Image(s string) (string, error) {
|
|||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// read the image data into memory
|
// read the image data into memory
|
||||||
data, err := ioutil.ReadAll(resp.Body)
|
data, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -69,14 +70,14 @@ func getBase64Image(s string) (string, error) {
|
|||||||
return "", fmt.Errorf("not valid string")
|
return "", fmt.Errorf("not valid string")
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
|
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
|
||||||
if input.Echo {
|
if input.Echo {
|
||||||
config.Echo = input.Echo
|
config.Echo = input.Echo
|
||||||
}
|
}
|
||||||
if input.TopK != 0 {
|
if input.TopK != nil {
|
||||||
config.TopK = input.TopK
|
config.TopK = input.TopK
|
||||||
}
|
}
|
||||||
if input.TopP != 0 {
|
if input.TopP != nil {
|
||||||
config.TopP = input.TopP
|
config.TopP = input.TopP
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,11 +117,11 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
|
|||||||
config.Grammar = input.Grammar
|
config.Grammar = input.Grammar
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Temperature != 0 {
|
if input.Temperature != nil {
|
||||||
config.Temperature = input.Temperature
|
config.Temperature = input.Temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Maxtokens != 0 {
|
if input.Maxtokens != nil {
|
||||||
config.Maxtokens = input.Maxtokens
|
config.Maxtokens = input.Maxtokens
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,30 +193,14 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
|
|||||||
config.Batch = input.Batch
|
config.Batch = input.Batch
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.F16 {
|
|
||||||
config.F16 = input.F16
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.IgnoreEOS {
|
if input.IgnoreEOS {
|
||||||
config.IgnoreEOS = input.IgnoreEOS
|
config.IgnoreEOS = input.IgnoreEOS
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Seed != 0 {
|
if input.Seed != nil {
|
||||||
config.Seed = input.Seed
|
config.Seed = input.Seed
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Mirostat != 0 {
|
|
||||||
config.LLMConfig.Mirostat = input.Mirostat
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.MirostatETA != 0 {
|
|
||||||
config.LLMConfig.MirostatETA = input.MirostatETA
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.MirostatTAU != 0 {
|
|
||||||
config.LLMConfig.MirostatTAU = input.MirostatTAU
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.TypicalP != 0 {
|
if input.TypicalP != 0 {
|
||||||
config.TypicalP = input.TypicalP
|
config.TypicalP = input.TypicalP
|
||||||
}
|
}
|
||||||
@@ -270,8 +255,13 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) {
|
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
|
||||||
cfg, err := config.Load(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16)
|
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
|
||||||
|
config.LoadOptionDebug(debug),
|
||||||
|
config.LoadOptionThreads(threads),
|
||||||
|
config.LoadOptionContextSize(ctx),
|
||||||
|
config.LoadOptionF16(f16),
|
||||||
|
)
|
||||||
|
|
||||||
// Set the parameters for the language model prediction
|
// Set the parameters for the language model prediction
|
||||||
updateRequestConfig(cfg, input)
|
updateRequestConfig(cfg, input)
|
||||||
@@ -9,22 +9,22 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/backend"
|
"github.com/go-skynet/LocalAI/core/backend"
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
"github.com/go-skynet/LocalAI/core/options"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/audio/create
|
// https://platform.openai.com/docs/api-reference/audio/create
|
||||||
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
m, input, err := readRequest(c, o, false)
|
m, input, err := readRequest(c, ml, appConfig, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
@@ -59,7 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
|
|
||||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||||
|
|
||||||
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
|
tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
6
core/schema/elevenlabs.go
Normal file
6
core/schema/elevenlabs.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
type ElevenLabsTTSRequest struct {
|
||||||
|
Text string `json:"text" yaml:"text"`
|
||||||
|
ModelID string `json:"model_id" yaml:"model_id"`
|
||||||
|
}
|
||||||
22
core/schema/localai.go
Normal file
22
core/schema/localai.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BackendMonitorRequest struct {
|
||||||
|
Model string `json:"model" yaml:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BackendMonitorResponse struct {
|
||||||
|
MemoryInfo *gopsutil.MemoryInfoStat
|
||||||
|
MemoryPercent float32
|
||||||
|
CPUPercent float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type TTSRequest struct {
|
||||||
|
Model string `json:"model" yaml:"model"`
|
||||||
|
Input string `json:"input" yaml:"input"`
|
||||||
|
Voice string `json:"voice" yaml:"voice"`
|
||||||
|
Backend string `json:"backend" yaml:"backend"`
|
||||||
|
}
|
||||||
@@ -3,8 +3,6 @@ package schema
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/core/config"
|
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,7 +47,7 @@ type OpenAIResponse struct {
|
|||||||
|
|
||||||
type Choice struct {
|
type Choice struct {
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
FinishReason string `json:"finish_reason,omitempty"`
|
FinishReason string `json:"finish_reason"`
|
||||||
Message *Message `json:"message,omitempty"`
|
Message *Message `json:"message,omitempty"`
|
||||||
Delta *Message `json:"delta,omitempty"`
|
Delta *Message `json:"delta,omitempty"`
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
@@ -108,10 +106,10 @@ type ChatCompletionResponseFormat struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIRequest struct {
|
type OpenAIRequest struct {
|
||||||
config.PredictionOptions
|
PredictionOptions
|
||||||
|
|
||||||
Context context.Context
|
Context context.Context `json:"-"`
|
||||||
Cancel context.CancelFunc
|
Cancel context.CancelFunc `json:"-"`
|
||||||
|
|
||||||
// whisper
|
// whisper
|
||||||
File string `json:"file" validate:"required"`
|
File string `json:"file" validate:"required"`
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package schema
|
||||||
|
|
||||||
type PredictionOptions struct {
|
type PredictionOptions struct {
|
||||||
|
|
||||||
@@ -12,28 +12,23 @@ type PredictionOptions struct {
|
|||||||
N int `json:"n"`
|
N int `json:"n"`
|
||||||
|
|
||||||
// Common options between all the API calls, part of the OpenAI spec
|
// Common options between all the API calls, part of the OpenAI spec
|
||||||
TopP float64 `json:"top_p" yaml:"top_p"`
|
TopP *float64 `json:"top_p" yaml:"top_p"`
|
||||||
TopK int `json:"top_k" yaml:"top_k"`
|
TopK *int `json:"top_k" yaml:"top_k"`
|
||||||
Temperature float64 `json:"temperature" yaml:"temperature"`
|
Temperature *float64 `json:"temperature" yaml:"temperature"`
|
||||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
|
Maxtokens *int `json:"max_tokens" yaml:"max_tokens"`
|
||||||
Echo bool `json:"echo"`
|
Echo bool `json:"echo"`
|
||||||
|
|
||||||
// Custom parameters - not present in the OpenAI API
|
// Custom parameters - not present in the OpenAI API
|
||||||
Batch int `json:"batch" yaml:"batch"`
|
Batch int `json:"batch" yaml:"batch"`
|
||||||
F16 bool `json:"f16" yaml:"f16"`
|
|
||||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
|
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
|
||||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
|
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
|
||||||
Keep int `json:"n_keep" yaml:"n_keep"`
|
Keep int `json:"n_keep" yaml:"n_keep"`
|
||||||
|
|
||||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
|
|
||||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
|
|
||||||
Mirostat int `json:"mirostat" yaml:"mirostat"`
|
|
||||||
|
|
||||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
|
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
|
||||||
TFZ float64 `json:"tfz" yaml:"tfz"`
|
TFZ float64 `json:"tfz" yaml:"tfz"`
|
||||||
|
|
||||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
||||||
Seed int `json:"seed" yaml:"seed"`
|
Seed *int `json:"seed" yaml:"seed"`
|
||||||
|
|
||||||
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
|
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
|
||||||
RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"`
|
RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"`
|
||||||
140
core/services/backend_monitor.go
Normal file
140
core/services/backend_monitor.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/core/schema"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BackendMonitor struct {
|
||||||
|
configLoader *config.BackendConfigLoader
|
||||||
|
modelLoader *model.ModelLoader
|
||||||
|
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor {
|
||||||
|
return BackendMonitor{
|
||||||
|
configLoader: configLoader,
|
||||||
|
modelLoader: modelLoader,
|
||||||
|
options: appConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
|
||||||
|
config, exists := bm.configLoader.GetBackendConfig(modelName)
|
||||||
|
var backendId string
|
||||||
|
if exists {
|
||||||
|
backendId = config.Model
|
||||||
|
} else {
|
||||||
|
// Last ditch effort: use it raw, see if a backend happens to match.
|
||||||
|
backendId = modelName
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasSuffix(backendId, ".bin") {
|
||||||
|
backendId = fmt.Sprintf("%s.bin", backendId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return backendId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
|
||||||
|
config, exists := bm.configLoader.GetBackendConfig(model)
|
||||||
|
var backend string
|
||||||
|
if exists {
|
||||||
|
backend = config.Model
|
||||||
|
} else {
|
||||||
|
// Last ditch effort: use it raw, see if a backend happens to match.
|
||||||
|
backend = model
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasSuffix(backend, ".bin") {
|
||||||
|
backend = fmt.Sprintf("%s.bin", backend)
|
||||||
|
}
|
||||||
|
|
||||||
|
pid, err := bm.modelLoader.GetGRPCPID(backend)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
|
||||||
|
backendProcess, err := gopsutil.NewProcess(int32(pid))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
memInfo, err := backendProcess.MemoryInfo()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
memPercent, err := backendProcess.MemoryPercent()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cpuPercent, err := backendProcess.CPUPercent()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &schema.BackendMonitorResponse{
|
||||||
|
MemoryInfo: memInfo,
|
||||||
|
MemoryPercent: memPercent,
|
||||||
|
CPUPercent: cpuPercent,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
|
||||||
|
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
|
||||||
|
if modelAddr == "" {
|
||||||
|
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
|
||||||
|
}
|
||||||
|
|
||||||
|
status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
|
||||||
|
if rpcErr != nil {
|
||||||
|
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
|
||||||
|
val, slbErr := bm.SampleLocalBackendProcess(backendId)
|
||||||
|
if slbErr != nil {
|
||||||
|
return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
|
||||||
|
}
|
||||||
|
return &proto.StatusResponse{
|
||||||
|
State: proto.StatusResponse_ERROR,
|
||||||
|
Memory: &proto.MemoryUsageData{
|
||||||
|
Total: val.MemoryInfo.VMS,
|
||||||
|
Breakdown: map[string]uint64{
|
||||||
|
"gopsutil-RSS": val.MemoryInfo.RSS,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bm BackendMonitor) ShutdownModel(modelName string) error {
|
||||||
|
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return bm.modelLoader.ShutdownModel(backendId)
|
||||||
|
}
|
||||||
167
core/services/gallery.go
Normal file
167
core/services/gallery.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
"gopkg.in/yaml.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GalleryService struct {
|
||||||
|
modelPath string
|
||||||
|
sync.Mutex
|
||||||
|
C chan gallery.GalleryOp
|
||||||
|
statuses map[string]*gallery.GalleryOpStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGalleryService(modelPath string) *GalleryService {
|
||||||
|
return &GalleryService{
|
||||||
|
modelPath: modelPath,
|
||||||
|
C: make(chan gallery.GalleryOp),
|
||||||
|
statuses: make(map[string]*gallery.GalleryOpStatus),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error {
|
||||||
|
|
||||||
|
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||||
|
|
||||||
|
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GalleryService) UpdateStatus(s string, op *gallery.GalleryOpStatus) {
|
||||||
|
g.Lock()
|
||||||
|
defer g.Unlock()
|
||||||
|
g.statuses[s] = op
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GalleryService) GetStatus(s string) *gallery.GalleryOpStatus {
|
||||||
|
g.Lock()
|
||||||
|
defer g.Unlock()
|
||||||
|
|
||||||
|
return g.statuses[s]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GalleryService) GetAllStatus() map[string]*gallery.GalleryOpStatus {
|
||||||
|
g.Lock()
|
||||||
|
defer g.Unlock()
|
||||||
|
|
||||||
|
return g.statuses
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader) {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Done():
|
||||||
|
return
|
||||||
|
case op := <-g.C:
|
||||||
|
utils.ResetDownloadTimers()
|
||||||
|
|
||||||
|
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", Progress: 0})
|
||||||
|
|
||||||
|
// updates the status with an error
|
||||||
|
updateError := func(e error) {
|
||||||
|
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
// displayDownload displays the download progress
|
||||||
|
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||||
|
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||||
|
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
// if the request contains a gallery name, we apply the gallery from the gallery list
|
||||||
|
if op.GalleryName != "" {
|
||||||
|
if strings.Contains(op.GalleryName, "@") {
|
||||||
|
err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
|
||||||
|
} else {
|
||||||
|
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = prepareModel(g.modelPath, op.Req, cl, progressCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
updateError(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload models
|
||||||
|
err = cl.LoadBackendConfigsFromPath(g.modelPath)
|
||||||
|
if err != nil {
|
||||||
|
updateError(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cl.Preload(g.modelPath)
|
||||||
|
if err != nil {
|
||||||
|
updateError(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
type galleryModel struct {
|
||||||
|
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
|
||||||
|
ID string `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func processRequests(modelPath, s string, cm *config.BackendConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
|
||||||
|
var err error
|
||||||
|
for _, r := range requests {
|
||||||
|
utils.ResetDownloadTimers()
|
||||||
|
if r.ID == "" {
|
||||||
|
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
||||||
|
} else {
|
||||||
|
if strings.Contains(r.ID, "@") {
|
||||||
|
err = gallery.InstallModelFromGallery(
|
||||||
|
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||||
|
} else {
|
||||||
|
err = gallery.InstallModelFromGalleryByName(
|
||||||
|
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyGalleryFromFile(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error {
|
||||||
|
dat, err := os.ReadFile(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var requests []galleryModel
|
||||||
|
|
||||||
|
if err := yaml.Unmarshal(dat, &requests); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return processRequests(modelPath, s, cl, galleries, requests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error {
|
||||||
|
var requests []galleryModel
|
||||||
|
err := json.Unmarshal([]byte(s), &requests)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return processRequests(modelPath, s, cl, galleries, requests)
|
||||||
|
}
|
||||||
54
core/services/metrics.go
Normal file
54
core/services/metrics.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/exporters/prometheus"
|
||||||
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
metricApi "go.opentelemetry.io/otel/sdk/metric"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LocalAIMetricsService struct {
|
||||||
|
Meter metric.Meter
|
||||||
|
ApiTimeMetric metric.Float64Histogram
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *LocalAIMetricsService) ObserveAPICall(method string, path string, duration float64) {
|
||||||
|
opts := metric.WithAttributes(
|
||||||
|
attribute.String("method", method),
|
||||||
|
attribute.String("path", path),
|
||||||
|
)
|
||||||
|
m.ApiTimeMetric.Record(context.Background(), duration, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupOTelSDK bootstraps the OpenTelemetry pipeline.
|
||||||
|
// If it does not return an error, make sure to call shutdown for proper cleanup.
|
||||||
|
func NewLocalAIMetricsService() (*LocalAIMetricsService, error) {
|
||||||
|
exporter, err := prometheus.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
provider := metricApi.NewMeterProvider(metricApi.WithReader(exporter))
|
||||||
|
meter := provider.Meter("github.com/go-skynet/LocalAI")
|
||||||
|
|
||||||
|
apiTimeMetric, err := meter.Float64Histogram("api_call", metric.WithDescription("api calls"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LocalAIMetricsService{
|
||||||
|
Meter: meter,
|
||||||
|
ApiTimeMetric: apiTimeMetric,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lams LocalAIMetricsService) Shutdown() error {
|
||||||
|
// TODO: Not sure how to actually do this:
|
||||||
|
//// setupOTelSDK bootstraps the OpenTelemetry pipeline.
|
||||||
|
//// If it does not return an error, make sure to call shutdown for proper cleanup.
|
||||||
|
|
||||||
|
log.Warn().Msgf("LocalAIMetricsService Shutdown called, but OTelSDK proper shutdown not yet implemented?")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
100
core/startup/config_file_watcher.go
Normal file
100
core/startup/config_file_watcher.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package startup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/imdario/mergo"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WatchConfigDirectoryCloser func() error
|
||||||
|
|
||||||
|
func ReadApiKeysJson(configDir string, appConfig *config.ApplicationConfig) error {
|
||||||
|
fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json"))
|
||||||
|
if err == nil {
|
||||||
|
// Parse JSON content from the file
|
||||||
|
var fileKeys []string
|
||||||
|
err := json.Unmarshal(fileContent, &fileKeys)
|
||||||
|
if err == nil {
|
||||||
|
appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadExternalBackendsJson(configDir string, appConfig *config.ApplicationConfig) error {
|
||||||
|
fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Parse JSON content from the file
|
||||||
|
var fileBackends map[string]string
|
||||||
|
err = json.Unmarshal(fileContent, &fileBackends)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = mergo.Merge(&appConfig.ExternalGRPCBackends, fileBackends)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var CONFIG_FILE_UPDATES = map[string]func(configDir string, appConfig *config.ApplicationConfig) error{
|
||||||
|
"api_keys.json": ReadApiKeysJson,
|
||||||
|
"external_backends.json": ReadExternalBackendsJson,
|
||||||
|
}
|
||||||
|
|
||||||
|
func WatchConfigDirectory(configDir string, appConfig *config.ApplicationConfig) (WatchConfigDirectoryCloser, error) {
|
||||||
|
if len(configDir) == 0 {
|
||||||
|
return nil, fmt.Errorf("configDir blank")
|
||||||
|
}
|
||||||
|
configWatcher, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Msgf("Unable to create a watcher for the LocalAI Configuration Directory: %+v", err)
|
||||||
|
}
|
||||||
|
ret := func() error {
|
||||||
|
configWatcher.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start listening for events.
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event, ok := <-configWatcher.Events:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if event.Has(fsnotify.Write) {
|
||||||
|
for targetName, watchFn := range CONFIG_FILE_UPDATES {
|
||||||
|
if event.Name == targetName {
|
||||||
|
err := watchFn(configDir, appConfig)
|
||||||
|
log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case _, ok := <-configWatcher.Errors:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error().Msgf("WatchConfigDirectory goroutine error: %+v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Add a path.
|
||||||
|
err = configWatcher.Add(configDir)
|
||||||
|
if err != nil {
|
||||||
|
return ret, fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
128
core/startup/startup.go
Normal file
128
core/startup/startup.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package startup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/core/services"
|
||||||
|
"github.com/go-skynet/LocalAI/internal"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/assets"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
pkgStartup "github.com/go-skynet/LocalAI/pkg/startup"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) {
|
||||||
|
options := config.NewApplicationConfig(opts...)
|
||||||
|
|
||||||
|
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||||
|
if options.Debug {
|
||||||
|
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
|
||||||
|
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
||||||
|
|
||||||
|
// Make sure directories exists
|
||||||
|
if options.ModelPath == "" {
|
||||||
|
return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty")
|
||||||
|
}
|
||||||
|
err := os.MkdirAll(options.ModelPath, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err)
|
||||||
|
}
|
||||||
|
if options.ImageDir != "" {
|
||||||
|
err := os.MkdirAll(options.ImageDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if options.AudioDir != "" {
|
||||||
|
err := os.MkdirAll(options.AudioDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if options.UploadDir != "" {
|
||||||
|
err := os.MkdirAll(options.UploadDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
pkgStartup.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...)
|
||||||
|
|
||||||
|
cl := config.NewBackendConfigLoader()
|
||||||
|
ml := model.NewModelLoader(options.ModelPath)
|
||||||
|
|
||||||
|
if err := cl.LoadBackendConfigsFromPath(options.ModelPath); err != nil {
|
||||||
|
log.Error().Msgf("error loading config files: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.ConfigFile != "" {
|
||||||
|
if err := cl.LoadBackendConfigFile(options.ConfigFile); err != nil {
|
||||||
|
log.Error().Msgf("error loading config file: %s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cl.Preload(options.ModelPath); err != nil {
|
||||||
|
log.Error().Msgf("error downloading models: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.PreloadJSONModels != "" {
|
||||||
|
if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.PreloadModelsFromPath != "" {
|
||||||
|
if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.Debug {
|
||||||
|
for _, v := range cl.ListBackendConfigs() {
|
||||||
|
cfg, _ := cl.GetBackendConfig(v)
|
||||||
|
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.AssetsDestination != "" {
|
||||||
|
// Extract files from the embedded FS
|
||||||
|
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
|
||||||
|
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// turn off any process that was started by GRPC if the context is canceled
|
||||||
|
go func() {
|
||||||
|
<-options.Context.Done()
|
||||||
|
log.Debug().Msgf("Context canceled, shutting down")
|
||||||
|
ml.StopAllGRPC()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if options.WatchDog {
|
||||||
|
wd := model.NewWatchDog(
|
||||||
|
ml,
|
||||||
|
options.WatchDogBusyTimeout,
|
||||||
|
options.WatchDogIdleTimeout,
|
||||||
|
options.WatchDogBusy,
|
||||||
|
options.WatchDogIdle)
|
||||||
|
ml.SetWatchDog(wd)
|
||||||
|
go wd.Run()
|
||||||
|
go func() {
|
||||||
|
<-options.Context.Done()
|
||||||
|
log.Debug().Msgf("Context canceled, shutting down")
|
||||||
|
wd.Shutdown()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Msg("core/startup process completed!")
|
||||||
|
return cl, ml, options, nil
|
||||||
|
}
|
||||||
@@ -2,15 +2,30 @@ version: '3.6'
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
api:
|
api:
|
||||||
image: quay.io/go-skynet/local-ai:latest
|
# See https://localai.io/basics/getting_started/#container-images for
|
||||||
|
# a list of available container images (or build your own with the provided Dockerfile)
|
||||||
|
# Available images with CUDA, ROCm, SYCL
|
||||||
|
# Image list (quay.io): https://quay.io/repository/go-skynet/local-ai?tab=tags
|
||||||
|
# Image list (dockerhub): https://hub.docker.com/r/localai/localai
|
||||||
|
image: quay.io/go-skynet/local-ai:master-ffmpeg-core
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
dockerfile: Dockerfile
|
dockerfile: Dockerfile
|
||||||
|
args:
|
||||||
|
- IMAGE_TYPE=core
|
||||||
|
- BASE_IMAGE=ubuntu:22.04
|
||||||
ports:
|
ports:
|
||||||
- 8080:8080
|
- 8080:8080
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
|
environment:
|
||||||
|
- MODELS_PATH=/models
|
||||||
|
# - DEBUG=true
|
||||||
volumes:
|
volumes:
|
||||||
- ./models:/models:cached
|
- ./models:/models:cached
|
||||||
- ./images/:/tmp/generated/images/
|
- ./images/:/tmp/generated/images/
|
||||||
command: ["/usr/bin/local-ai" ]
|
command:
|
||||||
|
# Here we can specify a list of models to run (see quickstart https://localai.io/basics/getting_started/#running-models )
|
||||||
|
# or an URL pointing to a YAML configuration file, for example:
|
||||||
|
# - https://gist.githubusercontent.com/mudler/ad601a0488b497b69ec549150d9edd18/raw/a8a8869ef1bb7e3830bf5c0bae29a0cce991ff8d/phi-2.yaml
|
||||||
|
- phi-2
|
||||||
|
|||||||
@@ -130,13 +130,14 @@ parameters:
|
|||||||
typical_p:
|
typical_p:
|
||||||
tfz:
|
tfz:
|
||||||
frequency_penalty:
|
frequency_penalty:
|
||||||
mirostat_eta:
|
|
||||||
mirostat_tau:
|
|
||||||
mirostat:
|
|
||||||
rope_freq_base:
|
rope_freq_base:
|
||||||
rope_freq_scale:
|
rope_freq_scale:
|
||||||
negative_prompt_scale:
|
negative_prompt_scale:
|
||||||
|
|
||||||
|
mirostat_eta:
|
||||||
|
mirostat_tau:
|
||||||
|
mirostat:
|
||||||
# Default context size
|
# Default context size
|
||||||
context_size: 512
|
context_size: 512
|
||||||
# Default number of threads
|
# Default number of threads
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
|
|
||||||
+++
|
+++
|
||||||
disableToc = false
|
disableToc = false
|
||||||
title = "🔥 OpenAI functions"
|
title = "🔥 OpenAI functions and tools"
|
||||||
weight = 17
|
weight = 17
|
||||||
url = "/features/openai-functions/"
|
url = "/features/openai-functions/"
|
||||||
+++
|
+++
|
||||||
|
|
||||||
LocalAI supports running OpenAI functions with `llama.cpp` compatible models.
|
LocalAI supports running OpenAI [functions and tools API](https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools) with `llama.cpp` compatible models.
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
To learn more about OpenAI functions, see the [OpenAI API blog post](https://openai.com/blog/function-calling-and-other-api-updates).
|
To learn more about OpenAI functions, see also the [OpenAI API blog post](https://openai.com/blog/function-calling-and-other-api-updates).
|
||||||
|
|
||||||
|
LocalAI is also supporting [JSON mode](https://platform.openai.com/docs/guides/text-generation/json-mode) out of the box with llama.cpp-compatible models.
|
||||||
|
|
||||||
💡 Check out also [LocalAGI](https://github.com/mudler/LocalAGI) for an example on how to use LocalAI functions.
|
💡 Check out also [LocalAGI](https://github.com/mudler/LocalAGI) for an example on how to use LocalAI functions.
|
||||||
|
|
||||||
@@ -78,6 +80,26 @@ When running the python script, be sure to:
|
|||||||
|
|
||||||
## Advanced
|
## Advanced
|
||||||
|
|
||||||
|
### Parallel tools calls
|
||||||
|
|
||||||
|
This feature is experimental and has to be configured in the YAML of the model by enabling `function.parallel_calls`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: gpt-3.5-turbo
|
||||||
|
parameters:
|
||||||
|
# Model file name
|
||||||
|
model: ggml-openllama.bin
|
||||||
|
top_p: 80
|
||||||
|
top_k: 0.9
|
||||||
|
temperature: 0.1
|
||||||
|
|
||||||
|
function:
|
||||||
|
# set to true to allow the model to call multiple functions in parallel
|
||||||
|
parallel_calls: true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Use functions with grammar
|
||||||
|
|
||||||
It is possible to also specify the full function signature (for debugging, or to use with other clients).
|
It is possible to also specify the full function signature (for debugging, or to use with other clients).
|
||||||
|
|
||||||
The chat endpoint accepts the `grammar_json_functions` additional parameter which takes a JSON schema object.
|
The chat endpoint accepts the `grammar_json_functions` additional parameter which takes a JSON schema object.
|
||||||
|
|||||||
@@ -245,8 +245,18 @@ backend: vllm
|
|||||||
parameters:
|
parameters:
|
||||||
model: "facebook/opt-125m"
|
model: "facebook/opt-125m"
|
||||||
|
|
||||||
# Decomment to specify a quantization method (optional)
|
# Uncomment to specify a quantization method (optional)
|
||||||
# quantization: "awq"
|
# quantization: "awq"
|
||||||
|
# Uncomment to limit the GPU memory utilization (vLLM default is 0.9 for 90%)
|
||||||
|
# gpu_memory_utilization: 0.5
|
||||||
|
# Uncomment to trust remote code from huggingface
|
||||||
|
# trust_remote_code: true
|
||||||
|
# Uncomment to enable eager execution
|
||||||
|
# enforce_eager: true
|
||||||
|
# Uncomment to specify the size of the CPU swap space per GPU (in GiB)
|
||||||
|
# swap_space: 2
|
||||||
|
# Uncomment to specify the maximum length of a sequence (including prompt and output)
|
||||||
|
# max_model_len: 32768
|
||||||
```
|
```
|
||||||
|
|
||||||
The backend will automatically download the required files in order to run the model.
|
The backend will automatically download the required files in order to run the model.
|
||||||
@@ -262,3 +272,56 @@ curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d
|
|||||||
"temperature": 0.1, "top_p": 0.1
|
"temperature": 0.1, "top_p": 0.1
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Transformers
|
||||||
|
|
||||||
|
[Transformers](https://huggingface.co/docs/transformers/index) is a State-of-the-art Machine Learning library for PyTorch, TensorFlow, and JAX.
|
||||||
|
|
||||||
|
LocalAI has a built-in integration with Transformers, and it can be used to run models.
|
||||||
|
|
||||||
|
This is an extra backend - in the container images (the `extra` images already contains python dependencies for Transformers) is already available and there is nothing to do for the setup.
|
||||||
|
|
||||||
|
#### Setup
|
||||||
|
|
||||||
|
Create a YAML file for the model you want to use with `transformers`.
|
||||||
|
|
||||||
|
To setup a model, you need to just specify the model name in the YAML config file:
|
||||||
|
```yaml
|
||||||
|
name: transformers
|
||||||
|
backend: transformers
|
||||||
|
parameters:
|
||||||
|
model: "facebook/opt-125m"
|
||||||
|
type: AutoModelForCausalLM
|
||||||
|
quantization: bnb_4bit # One of: bnb_8bit, bnb_4bit, xpu_4bit (optional)
|
||||||
|
```
|
||||||
|
|
||||||
|
The backend will automatically download the required files in order to run the model.
|
||||||
|
|
||||||
|
#### Parameters
|
||||||
|
|
||||||
|
##### Type
|
||||||
|
|
||||||
|
| Type | Description |
|
||||||
|
| --- | --- |
|
||||||
|
| `AutoModelForCausalLM` | `AutoModelForCausalLM` is a model that can be used to generate sequences. |
|
||||||
|
| N/A | Defaults to `AutoModel` |
|
||||||
|
|
||||||
|
|
||||||
|
##### Quantization
|
||||||
|
|
||||||
|
| Quantization | Description |
|
||||||
|
| --- | --- |
|
||||||
|
| `bnb_8bit` | 8-bit quantization |
|
||||||
|
| `bnb_4bit` | 4-bit quantization |
|
||||||
|
| `xpu_4bit` | 4-bit quantization for Intel XPUs |
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
|
||||||
|
Use the `completions` endpoint by specifying the `transformers` model:
|
||||||
|
```
|
||||||
|
curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
|
||||||
|
"model": "transformers",
|
||||||
|
"prompt": "Hello, my name is",
|
||||||
|
"temperature": 0.1, "top_p": 0.1
|
||||||
|
}'
|
||||||
|
```
|
||||||
@@ -6,7 +6,13 @@ weight = 11
|
|||||||
url = "/features/text-to-audio/"
|
url = "/features/text-to-audio/"
|
||||||
+++
|
+++
|
||||||
|
|
||||||
The `/tts` endpoint can be used to generate speech from text.
|
## API Compatibility
|
||||||
|
|
||||||
|
The LocalAI TTS API is compatible with the [OpenAI TTS API](https://platform.openai.com/docs/guides/text-to-speech) and the [Elevenlabs](https://api.elevenlabs.io/docs) API.
|
||||||
|
|
||||||
|
## LocalAI API
|
||||||
|
|
||||||
|
The `/tts` endpoint can also be used to generate speech from text.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ The list below is a list of software that integrates with LocalAI.
|
|||||||
|
|
||||||
- [AnythingLLM](https://github.com/Mintplex-Labs/anything-llm)
|
- [AnythingLLM](https://github.com/Mintplex-Labs/anything-llm)
|
||||||
- [Logseq GPT3 OpenAI plugin](https://github.com/briansunter/logseq-plugin-gpt3-openai) allows to set a base URL, and works with LocalAI.
|
- [Logseq GPT3 OpenAI plugin](https://github.com/briansunter/logseq-plugin-gpt3-openai) allows to set a base URL, and works with LocalAI.
|
||||||
|
- https://plugins.jetbrains.com/plugin/21056-codegpt allows for custom OpenAI compatible endpoints since 2.4.0
|
||||||
- https://github.com/longy2k/obsidian-bmo-chatbot
|
- https://github.com/longy2k/obsidian-bmo-chatbot
|
||||||
- https://github.com/FlowiseAI/Flowise
|
- https://github.com/FlowiseAI/Flowise
|
||||||
- https://github.com/k8sgpt-ai/k8sgpt
|
- https://github.com/k8sgpt-ai/k8sgpt
|
||||||
@@ -25,5 +26,6 @@ The list below is a list of software that integrates with LocalAI.
|
|||||||
- https://github.com/charmbracelet/mods
|
- https://github.com/charmbracelet/mods
|
||||||
- https://github.com/cedriking/spark
|
- https://github.com/cedriking/spark
|
||||||
- [Big AGI](https://github.com/enricoros/big-agi) is a powerful web interface entirely running in the browser, supporting LocalAI
|
- [Big AGI](https://github.com/enricoros/big-agi) is a powerful web interface entirely running in the browser, supporting LocalAI
|
||||||
|
- [Midori AI Subsystem Manager](https://io.midori-ai.xyz/subsystem/manager/) is a powerful docker subsystem for running all types of AI programs
|
||||||
|
|
||||||
Feel free to open up a Pull request (by clicking at the "Edit page" below) to get a page for your project made or if you see a error on one of the pages!
|
Feel free to open up a Pull request (by clicking at the "Edit page" below) to get a page for your project made or if you see a error on one of the pages!
|
||||||
|
|||||||
@@ -111,6 +111,3 @@ This is a community project, a special thanks to our contributors! 🤗
|
|||||||
<a href="https://github.com/go-skynet/LocalAI/graphs/contributors">
|
<a href="https://github.com/go-skynet/LocalAI/graphs/contributors">
|
||||||
<img src="https://contrib.rocks/image?repo=go-skynet/LocalAI" />
|
<img src="https://contrib.rocks/image?repo=go-skynet/LocalAI" />
|
||||||
</a>
|
</a>
|
||||||
<a href="https://github.com/go-skynet/LocalAI-website/graphs/contributors">
|
|
||||||
<img src="https://contrib.rocks/image?repo=go-skynet/LocalAI-website" />
|
|
||||||
</a>
|
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
{
|
{
|
||||||
"version": "v2.8.2"
|
"version": "v2.9.0"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,14 @@ name: codellama-7b-gguf
|
|||||||
backend: transformers
|
backend: transformers
|
||||||
parameters:
|
parameters:
|
||||||
model: huggingface://TheBloke/CodeLlama-7B-GGUF/codellama-7b.Q4_K_M.gguf
|
model: huggingface://TheBloke/CodeLlama-7B-GGUF/codellama-7b.Q4_K_M.gguf
|
||||||
temperature: 0.2
|
temperature: 0.5
|
||||||
top_k: 40
|
top_k: 40
|
||||||
seed: -1
|
seed: -1
|
||||||
top_p: 0.95
|
top_p: 0.95
|
||||||
|
mirostat: 2
|
||||||
|
mirostat_eta: 1.0
|
||||||
|
mirostat_tau: 1.0
|
||||||
|
|
||||||
context_size: 4096
|
context_size: 4096
|
||||||
f16: true
|
f16: true
|
||||||
gpu_layers: 90
|
gpu_layers: 90
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user