Compare commits


1 Commit

Author: Ettore Di Giacinto
SHA1: 47743b74ab
Message: Revert "Revert #1963 (#2056)" (This reverts commit af9e5a2d05.)
Date: 2024-04-17 23:36:17 +02:00
210 changed files with 4004 additions and 10341 deletions


@@ -5,7 +5,4 @@ models
examples/chatbot-ui/models
examples/rwkv/models
examples/**/models
Dockerfile*
# SonarQube
.scannerwork
Dockerfile*

.env

@@ -10,7 +10,7 @@
#
## Define galleries.
## models will to install will be visible in `/models/available`
# LOCALAI_GALLERIES=[{"name":"localai", "url":"github:mudler/LocalAI/gallery/index.yaml@master"}]
# LOCALAI_GALLERIES=[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/index.yaml"}]
## CORS settings
# LOCALAI_CORS=true
@@ -86,4 +86,4 @@
# LOCALAI_WATCHDOG_BUSY=true
#
# Time in duration format (e.g. 1h30m) after which a backend is considered busy
# LOCALAI_WATCHDOG_BUSY_TIMEOUT=5m
# LOCALAI_WATCHDOG_BUSY_TIMEOUT=5m


@@ -2,6 +2,6 @@
set -xe
REPO=$1
LATEST_TAG=$(curl -s "https://api.github.com/repos/$REPO/releases/latest" | jq -r '.tag_name')
LATEST_TAG=$(curl -s "https://api.github.com/repos/$REPO/releases/latest" | jq -r '.name')
cat <<< $(jq ".version = \"$LATEST_TAG\"" docs/data/version.json) > docs/data/version.json

.github/labeler.yml

@@ -8,11 +8,6 @@ kind/documentation:
- changed-files:
- any-glob-to-any-file: '*.md'
area/ai-model:
- any:
- changed-files:
- any-glob-to-any-file: 'gallery/*'
examples:
- any:
- changed-files:
@@ -21,4 +16,4 @@ examples:
ci:
- any:
- changed-files:
- any-glob-to-any-file: '.github/*'
- any-glob-to-any-file: '.github/*'


@@ -14,7 +14,7 @@ jobs:
steps:
- name: Dependabot metadata
id: metadata
uses: dependabot/fetch-metadata@v2.1.0
uses: dependabot/fetch-metadata@v2.0.0
with:
github-token: "${{ secrets.GITHUB_TOKEN }}"
skip-commit-verification: true


@@ -1,94 +0,0 @@
name: 'generate and publish GRPC docker caches'
on:
workflow_dispatch:
push:
branches:
- master
concurrency:
group: grpc-cache-${{ github.head_ref || github.ref }}-${{ github.repository }}
cancel-in-progress: true
jobs:
generate_caches:
strategy:
matrix:
include:
- grpc-base-image: ubuntu:22.04
runs-on: 'ubuntu-latest'
platforms: 'linux/amd64'
runs-on: ${{matrix.runs-on}}
steps:
- name: Release space from worker
if: matrix.runs-on == 'ubuntu-latest'
run: |
echo "Listing top largest packages"
pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
head -n 30 <<< "${pkgs}"
echo
df -h
echo
sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
sudo apt-get remove --auto-remove android-sdk-platform-tools || true
sudo apt-get purge --auto-remove android-sdk-platform-tools || true
sudo rm -rf /usr/local/lib/android
sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
sudo rm -rf /usr/share/dotnet
sudo apt-get remove -y '^mono-.*' || true
sudo apt-get remove -y '^ghc-.*' || true
sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
sudo apt-get remove -y 'php.*' || true
sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
sudo apt-get remove -y '^google-.*' || true
sudo apt-get remove -y azure-cli || true
sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
sudo apt-get remove -y '^gfortran-.*' || true
sudo apt-get remove -y microsoft-edge-stable || true
sudo apt-get remove -y firefox || true
sudo apt-get remove -y powershell || true
sudo apt-get remove -y r-base-core || true
sudo apt-get autoremove -y
sudo apt-get clean
echo
echo "Listing top largest packages"
pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
head -n 30 <<< "${pkgs}"
echo
sudo rm -rfv build || true
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /opt/ghc || true
sudo rm -rf "/usr/local/share/boost" || true
sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true
df -h
- name: Set up QEMU
uses: docker/setup-qemu-action@master
with:
platforms: all
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@master
- name: Checkout
uses: actions/checkout@v4
- name: Cache GRPC
uses: docker/build-push-action@v5
with:
builder: ${{ steps.buildx.outputs.name }}
# The build-args MUST be an EXACT match between the image cache and other workflow steps that want to use that cache.
# This means that even the MAKEFLAGS have to be an EXACT match.
# If the build-args are not an EXACT match, it will result in a cache miss, which will require GRPC to be built from scratch.
build-args: |
GRPC_BASE_IMAGE=${{ matrix.grpc-base-image }}
GRPC_MAKEFLAGS=--jobs=4 --output-sync=target
GRPC_VERSION=v1.63.0
context: .
file: ./Dockerfile
cache-to: type=gha,ignore-error=true
cache-from: type=gha
target: grpc
platforms: ${{ matrix.platforms }}
push: false


@@ -22,7 +22,6 @@ jobs:
platforms: ${{ matrix.platforms }}
runs-on: ${{ matrix.runs-on }}
base-image: ${{ matrix.base-image }}
grpc-base-image: ${{ matrix.grpc-base-image }}
makeflags: ${{ matrix.makeflags }}
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -62,14 +61,12 @@ jobs:
ffmpeg: 'false'
image-type: 'extras'
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
grpc-base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'sycl_f16'
platforms: 'linux/amd64'
tag-latest: 'false'
base-image: "intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04"
grpc-base-image: "ubuntu:22.04"
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: 'sycl-f16-ffmpeg'
ffmpeg: 'true'
image-type: 'extras'
@@ -88,7 +85,6 @@ jobs:
platforms: ${{ matrix.platforms }}
runs-on: ${{ matrix.runs-on }}
base-image: ${{ matrix.base-image }}
grpc-base-image: ${{ matrix.grpc-base-image }}
makeflags: ${{ matrix.makeflags }}
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -106,12 +102,11 @@ jobs:
image-type: 'core'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
makeflags: "--jobs=4 --output-sync=target"
makeflags: "--jobs=5 --output-sync=target"
- build-type: 'sycl_f16'
platforms: 'linux/amd64'
tag-latest: 'false'
base-image: "intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04"
grpc-base-image: "ubuntu:22.04"
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: 'sycl-f16-ffmpeg-core'
ffmpeg: 'true'
image-type: 'core'
@@ -127,4 +122,4 @@ jobs:
image-type: 'core'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
makeflags: "--jobs=4 --output-sync=target"
makeflags: "--jobs=5 --output-sync=target"


@@ -26,7 +26,6 @@ jobs:
platforms: ${{ matrix.platforms }}
runs-on: ${{ matrix.runs-on }}
base-image: ${{ matrix.base-image }}
grpc-base-image: ${{ matrix.grpc-base-image }}
aio: ${{ matrix.aio }}
makeflags: ${{ matrix.makeflags }}
latest-image: ${{ matrix.latest-image }}
@@ -130,7 +129,6 @@ jobs:
image-type: 'extras'
aio: "-aio-gpu-hipblas"
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
grpc-base-image: "ubuntu:22.04"
latest-image: 'latest-gpu-hipblas'
latest-image-aio: 'latest-aio-gpu-hipblas'
runs-on: 'arc-runner-set'
@@ -142,14 +140,12 @@ jobs:
ffmpeg: 'false'
image-type: 'extras'
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
grpc-base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'sycl_f16'
platforms: 'linux/amd64'
tag-latest: 'auto'
base-image: "intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04"
grpc-base-image: "ubuntu:22.04"
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: '-sycl-f16-ffmpeg'
ffmpeg: 'true'
image-type: 'extras'
@@ -161,8 +157,7 @@ jobs:
- build-type: 'sycl_f32'
platforms: 'linux/amd64'
tag-latest: 'auto'
base-image: "intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04"
grpc-base-image: "ubuntu:22.04"
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: '-sycl-f32-ffmpeg'
ffmpeg: 'true'
image-type: 'extras'
@@ -175,8 +170,7 @@ jobs:
- build-type: 'sycl_f16'
platforms: 'linux/amd64'
tag-latest: 'false'
base-image: "intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04"
grpc-base-image: "ubuntu:22.04"
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: '-sycl-f16-core'
ffmpeg: 'false'
image-type: 'core'
@@ -185,8 +179,7 @@ jobs:
- build-type: 'sycl_f32'
platforms: 'linux/amd64'
tag-latest: 'false'
base-image: "intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04"
grpc-base-image: "ubuntu:22.04"
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: '-sycl-f32-core'
ffmpeg: 'false'
image-type: 'core'
@@ -195,8 +188,7 @@ jobs:
- build-type: 'sycl_f16'
platforms: 'linux/amd64'
tag-latest: 'false'
base-image: "intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04"
grpc-base-image: "ubuntu:22.04"
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: '-sycl-f16-ffmpeg-core'
ffmpeg: 'true'
image-type: 'core'
@@ -205,8 +197,7 @@ jobs:
- build-type: 'sycl_f32'
platforms: 'linux/amd64'
tag-latest: 'false'
base-image: "intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04"
grpc-base-image: "ubuntu:22.04"
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: '-sycl-f32-ffmpeg-core'
ffmpeg: 'true'
image-type: 'core'
@@ -219,7 +210,6 @@ jobs:
ffmpeg: 'true'
image-type: 'core'
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
grpc-base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'hipblas'
@@ -229,7 +219,6 @@ jobs:
ffmpeg: 'false'
image-type: 'core'
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
grpc-base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
@@ -247,7 +236,6 @@ jobs:
runs-on: ${{ matrix.runs-on }}
aio: ${{ matrix.aio }}
base-image: ${{ matrix.base-image }}
grpc-base-image: ${{ matrix.grpc-base-image }}
makeflags: ${{ matrix.makeflags }}
latest-image: ${{ matrix.latest-image }}
latest-image-aio: ${{ matrix.latest-image-aio }}
@@ -270,7 +258,7 @@ jobs:
aio: "-aio-cpu"
latest-image: 'latest-cpu'
latest-image-aio: 'latest-aio-cpu'
makeflags: "--jobs=4 --output-sync=target"
makeflags: "--jobs=5 --output-sync=target"
- build-type: 'cublas'
cuda-major-version: "11"
cuda-minor-version: "7"
@@ -281,7 +269,7 @@ jobs:
image-type: 'core'
base-image: "ubuntu:22.04"
runs-on: 'ubuntu-latest'
makeflags: "--jobs=4 --output-sync=target"
makeflags: "--jobs=5 --output-sync=target"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "1"
@@ -292,7 +280,7 @@ jobs:
image-type: 'core'
base-image: "ubuntu:22.04"
runs-on: 'ubuntu-latest'
makeflags: "--jobs=4 --output-sync=target"
makeflags: "--jobs=5 --output-sync=target"
- build-type: 'cublas'
cuda-major-version: "11"
cuda-minor-version: "7"
@@ -303,7 +291,7 @@ jobs:
image-type: 'core'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
makeflags: "--jobs=4 --output-sync=target"
makeflags: "--jobs=5 --output-sync=target"
- build-type: 'cublas'
cuda-major-version: "12"
cuda-minor-version: "1"
@@ -314,4 +302,4 @@ jobs:
image-type: 'core'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
makeflags: "--jobs=4 --output-sync=target"
makeflags: "--jobs=5 --output-sync=target"


@@ -6,10 +6,6 @@ on:
inputs:
base-image:
description: 'Base image'
required: true
type: string
grpc-base-image:
description: 'GRPC Base image, must be a compatible image with base-image'
required: false
default: ''
type: string
@@ -61,7 +57,7 @@ on:
makeflags:
description: 'Make Flags'
required: false
default: '--jobs=4 --output-sync=target'
default: '--jobs=3 --output-sync=target'
type: string
aio:
description: 'AIO Image Name'
@@ -201,14 +197,29 @@ jobs:
username: ${{ secrets.quayUsername }}
password: ${{ secrets.quayPassword }}
- name: Cache GRPC
uses: docker/build-push-action@v5
with:
builder: ${{ steps.buildx.outputs.name }}
build-args: |
IMAGE_TYPE=${{ inputs.image-type }}
BASE_IMAGE=${{ inputs.base-image }}
MAKEFLAGS=${{ inputs.makeflags }}
GRPC_VERSION=v1.58.0
context: .
file: ./Dockerfile
cache-from: type=gha
cache-to: type=gha,ignore-error=true
target: grpc
platforms: ${{ inputs.platforms }}
push: false
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Build and push
uses: docker/build-push-action@v5
with:
builder: ${{ steps.buildx.outputs.name }}
# The build-args MUST be an EXACT match between the image cache and other workflow steps that want to use that cache.
# This means that even the MAKEFLAGS have to be an EXACT match.
# If the build-args are not an EXACT match, it will result in a cache miss, which will require GRPC to be built from scratch.
# This is why some build args like GRPC_VERSION and MAKEFLAGS are hardcoded
build-args: |
BUILD_TYPE=${{ inputs.build-type }}
CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }}
@@ -216,9 +227,6 @@ jobs:
FFMPEG=${{ inputs.ffmpeg }}
IMAGE_TYPE=${{ inputs.image-type }}
BASE_IMAGE=${{ inputs.base-image }}
GRPC_BASE_IMAGE=${{ inputs.grpc-base-image || inputs.base-image }}
GRPC_MAKEFLAGS=--jobs=4 --output-sync=target
GRPC_VERSION=v1.63.0
MAKEFLAGS=${{ inputs.makeflags }}
context: .
file: ./Dockerfile
@@ -228,6 +236,14 @@ jobs:
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Inspect image
if: github.event_name != 'pull_request'
run: |
docker pull localai/localai:${{ steps.meta.outputs.version }}
docker image inspect localai/localai:${{ steps.meta.outputs.version }}
docker pull quay.io/go-skynet/local-ai:${{ steps.meta.outputs.version }}
docker image inspect quay.io/go-skynet/local-ai:${{ steps.meta.outputs.version }}
- name: Build and push AIO image
if: inputs.aio != ''
uses: docker/build-push-action@v5


@@ -5,7 +5,7 @@ on:
- pull_request
env:
GRPC_VERSION: v1.63.0
GRPC_VERSION: v1.58.0
permissions:
contents: write


@@ -34,7 +34,7 @@ jobs:
sudo apt-get install -y conda
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
sudo apt-get install -y libopencv-dev
pip install --user grpcio-tools==1.63.0
pip install --user grpcio-tools
sudo rm -rfv /usr/bin/conda || true
@@ -64,7 +64,7 @@ jobs:
sudo apt-get install -y conda
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
sudo apt-get install -y libopencv-dev
pip install --user grpcio-tools==1.63.0
pip install --user grpcio-tools
sudo rm -rfv /usr/bin/conda || true
@@ -74,37 +74,6 @@ jobs:
make --jobs=5 --output-sync=target -C backend/python/sentencetransformers
make --jobs=5 --output-sync=target -C backend/python/sentencetransformers test
tests-rerankers:
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v4
with:
submodules: true
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential ffmpeg
curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
sudo apt-get update && \
sudo apt-get install -y conda
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
sudo apt-get install -y libopencv-dev
pip install --user grpcio-tools==1.63.0
sudo rm -rfv /usr/bin/conda || true
- name: Test rerankers
run: |
export PATH=$PATH:/opt/conda/bin
make --jobs=5 --output-sync=target -C backend/python/rerankers
make --jobs=5 --output-sync=target -C backend/python/rerankers test
tests-diffusers:
runs-on: ubuntu-latest
steps:
@@ -125,7 +94,7 @@ jobs:
sudo apt-get install -y conda
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
sudo apt-get install -y libopencv-dev
pip install --user grpcio-tools==1.63.0
pip install --user grpcio-tools
sudo rm -rfv /usr/bin/conda || true
@@ -155,7 +124,7 @@ jobs:
sudo apt-get install -y conda
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
sudo apt-get install -y libopencv-dev
pip install --user grpcio-tools==1.63.0
pip install --user grpcio-tools
sudo rm -rfv /usr/bin/conda || true
@@ -185,7 +154,7 @@ jobs:
sudo apt-get install -y conda
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
sudo apt-get install -y libopencv-dev
pip install --user grpcio-tools==1.63.0
pip install --user grpcio-tools
sudo rm -rfv /usr/bin/conda || true
@@ -217,7 +186,7 @@ jobs:
# sudo apt-get install -y conda
# sudo apt-get install -y ca-certificates cmake curl patch python3-pip
# sudo apt-get install -y libopencv-dev
# pip install --user grpcio-tools==1.63.0
# pip install --user grpcio-tools
# sudo rm -rfv /usr/bin/conda || true
@@ -289,7 +258,7 @@ jobs:
# sudo apt-get install -y conda
# sudo apt-get install -y ca-certificates cmake curl patch python3-pip
# sudo apt-get install -y libopencv-dev
# pip install --user grpcio-tools==1.63.0
# pip install --user grpcio-tools
# sudo rm -rfv /usr/bin/conda || true
@@ -322,7 +291,7 @@ jobs:
# sudo apt-get install -y conda
# sudo apt-get install -y ca-certificates cmake curl patch python3-pip
# sudo apt-get install -y libopencv-dev
# pip install --user grpcio-tools==1.63.0
# pip install --user grpcio-tools
# sudo rm -rfv /usr/bin/conda || true
# - name: Test vllm
# run: |
@@ -349,7 +318,7 @@ jobs:
sudo apt-get install -y conda
sudo apt-get install -y ca-certificates cmake curl patch python3-pip
sudo apt-get install -y libopencv-dev
pip install --user grpcio-tools==1.63.0
pip install --user grpcio-tools
sudo rm -rfv /usr/bin/conda || true
- name: Test vall-e-x
run: |
@@ -376,7 +345,7 @@ jobs:
sudo apt-get update && \
sudo apt-get install -y conda
sudo apt-get install -y ca-certificates cmake curl patch espeak espeak-ng python3-pip
pip install --user grpcio-tools==1.63.0
pip install --user grpcio-tools
sudo rm -rfv /usr/bin/conda || true
- name: Test coqui


@@ -10,7 +10,7 @@ on:
- '*'
env:
GRPC_VERSION: v1.63.0
GRPC_VERSION: v1.58.0
concurrency:
group: ci-tests-${{ github.head_ref || github.ref }}-${{ github.repository }}
@@ -123,9 +123,7 @@ jobs:
if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.18
with:
detached: true
connect-timeout-seconds: 180
limit-access-to-actor: true
tests-aio-container:
runs-on: ubuntu-latest
@@ -178,9 +176,7 @@ jobs:
if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.18
with:
detached: true
connect-timeout-seconds: 180
limit-access-to-actor: true
tests-apple:
runs-on: macOS-14
@@ -203,7 +199,7 @@ jobs:
- name: Dependencies
run: |
brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc
pip install --user grpcio-tools==1.63.0
pip install --user grpcio-tools
- name: Test
run: |
export C_INCLUDE_PATH=/usr/local/include
@@ -215,6 +211,4 @@ jobs:
if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.18
with:
detached: true
connect-timeout-seconds: 180
limit-access-to-actor: true
connect-timeout-seconds: 180


@@ -1,31 +0,0 @@
name: Update swagger
on:
schedule:
- cron: 0 20 * * *
workflow_dispatch:
jobs:
swagger:
strategy:
fail-fast: false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: 'stable'
- run: |
go install github.com/swaggo/swag/cmd/swag@latest
- name: Bump swagger 🔧
run: |
make swagger
- name: Create Pull Request
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.UPDATE_BOT_TOKEN }}
push-to-fork: ci-forks/LocalAI
commit-message: 'feat(swagger): update swagger'
title: 'feat(swagger): update swagger'
branch: "update/swagger"
body: Update swagger
signoff: true


@@ -1,18 +0,0 @@
name: 'Yamllint GitHub Actions'
on:
- pull_request
jobs:
yamllint:
name: 'Yamllint'
runs-on: ubuntu-latest
steps:
- name: 'Checkout'
uses: actions/checkout@master
- name: 'Yamllint'
uses: karancode/yamllint-github-action@master
with:
yamllint_file_or_dir: 'gallery'
yamllint_strict: false
yamllint_comment: true
env:
GITHUB_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.gitignore

@@ -44,6 +44,3 @@ prepare
*.pb.go
*pb2.py
*pb2_grpc.py
# SonarQube
.scannerwork


@@ -1,4 +0,0 @@
extends: default
rules:
line-length: disable


@@ -1,43 +1,41 @@
ARG IMAGE_TYPE=extras
ARG BASE_IMAGE=ubuntu:22.04
ARG GRPC_BASE_IMAGE=${BASE_IMAGE}
# The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it.
FROM ${BASE_IMAGE} AS requirements-core
# extras or core
FROM ${BASE_IMAGE} as requirements-core
USER root
ARG GO_VERSION=1.21.7
ARG BUILD_TYPE
ARG CUDA_MAJOR_VERSION=11
ARG CUDA_MINOR_VERSION=7
ARG TARGETARCH
ARG TARGETVARIANT
ENV BUILD_TYPE=${BUILD_TYPE}
ENV DEBIAN_FRONTEND=noninteractive
ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,rerankers:/build/backend/python/rerankers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh,parler-tts:/build/backend/python/parler-tts/run.sh"
ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh,parler-tts:/build/backend/python/parler-tts/run.sh"
ARG GO_TAGS="stablediffusion tinydream tts"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
cmake \
curl \
git \
python3-pip \
python-is-python3 \
unzip && \
apt-get clean && \
rm -rf /var/lib/apt/lists/* && \
pip install --upgrade pip
apt-get install -y ca-certificates curl python3-pip unzip && apt-get clean
# Install Go
RUN curl -L -s https://go.dev/dl/go${GO_VERSION}.linux-${TARGETARCH}.tar.gz | tar -C /usr/local -xz
ENV PATH $PATH:/root/go/bin:/usr/local/go/bin
RUN curl -L -s https://go.dev/dl/go$GO_VERSION.linux-$TARGETARCH.tar.gz | tar -C /usr/local -xz
ENV PATH $PATH:/usr/local/go/bin
# Install grpc compilers
ENV PATH $PATH:/root/go/bin
RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@latest && \
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
# Install protobuf (the version in 22.04 is too old)
RUN curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
rm protoc.zip
# Install grpcio-tools (the version in 22.04 is too old)
RUN pip install --user grpcio-tools
@@ -48,6 +46,16 @@ RUN update-ca-certificates
RUN echo "Target Architecture: $TARGETARCH"
RUN echo "Target Variant: $TARGETVARIANT"
# CuBLAS requirements
RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \
apt-get install -y software-properties-common && \
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
dpkg -i cuda-keyring_1.1-1_all.deb && \
rm -f cuda-keyring_1.1-1_all.deb && \
apt-get update && \
apt-get install -y cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} && apt-get clean \
; fi
# Cuda
ENV PATH /usr/local/cuda/bin:${PATH}
@@ -55,12 +63,10 @@ ENV PATH /usr/local/cuda/bin:${PATH}
ENV PATH /opt/rocm/bin:${PATH}
# OpenBLAS requirements and stable diffusion
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libopenblas-dev \
libopencv-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN apt-get install -y \
libopenblas-dev \
libopencv-dev \
&& apt-get clean
# Set up OpenCV
RUN ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
@@ -73,114 +79,57 @@ RUN test -n "$TARGETARCH" \
###################################
###################################
# The requirements-extras target is for any builds with IMAGE_TYPE=extras. It should not be placed in this target unless every IMAGE_TYPE=extras build will use it
FROM requirements-core AS requirements-extras
FROM requirements-core as requirements-extras
RUN apt-get update && \
apt-get install -y --no-install-recommends gpg && \
RUN apt install -y gpg && \
curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list && \
echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list && \
apt-get update && \
apt-get install -y --no-install-recommends \
conda && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
apt-get install -y conda && apt-get clean
ENV PATH="/root/.cargo/bin:${PATH}"
RUN apt-get install -y python3-pip && apt-get clean
RUN pip install --upgrade pip
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
RUN apt-get update && \
apt-get install -y --no-install-recommends \
espeak-ng \
espeak && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN apt-get install -y espeak-ng espeak && apt-get clean
###################################
###################################
# The requirements-drivers target is for BUILD_TYPE specific items. If you need to install something specific to CUDA, or specific to ROCM, it goes here.
# This target will be built on top of requirements-core or requirements-extras as retermined by the IMAGE_TYPE build-arg
FROM requirements-${IMAGE_TYPE} AS requirements-drivers
ARG BUILD_TYPE
ARG CUDA_MAJOR_VERSION=11
ARG CUDA_MINOR_VERSION=7
ENV BUILD_TYPE=${BUILD_TYPE}
# CuBLAS requirements
RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \
apt-get update && \
apt-get install -y --no-install-recommends \
software-properties-common && \
curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
dpkg -i cuda-keyring_1.1-1_all.deb && \
rm -f cuda-keyring_1.1-1_all.deb && \
apt-get update && \
apt-get install -y --no-install-recommends \
cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} && \
apt-get clean && \
rm -rf /var/lib/apt/lists/* \
; fi
# If we are building with clblas support, we need the libraries for the builds
RUN if [ "${BUILD_TYPE}" = "clblas" ]; then \
apt-get update && \
apt-get install -y --no-install-recommends \
libclblast-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists/* \
RUN if [ ! -e /usr/bin/python ]; then \
ln -s /usr/bin/python3 /usr/bin/python \
; fi
###################################
###################################
# The grpc target does one thing, it builds and installs GRPC. This is in it's own layer so that it can be effectively cached by CI.
# You probably don't need to change anything here, and if you do, make sure that CI is adjusted so that the cache continues to work.
FROM ${GRPC_BASE_IMAGE} AS grpc
FROM ${BASE_IMAGE} as grpc
# This is a bit of a hack, but it's required in order to be able to effectively cache this layer in CI
ARG GRPC_MAKEFLAGS="-j4 -Otarget"
ARG MAKEFLAGS
ARG GRPC_VERSION=v1.58.0
ENV MAKEFLAGS=${GRPC_MAKEFLAGS}
ENV MAKEFLAGS=${MAKEFLAGS}
WORKDIR /build
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ca-certificates \
build-essential \
cmake \
git && \
apt-get install -y build-essential cmake git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# We install GRPC to a different prefix here so that we can copy in only the build artifacts later
# saves several hundred MB on the final docker image size vs copying in the entire GRPC source tree
# and running make install in the target container
RUN git clone --recurse-submodules --jobs 4 -b ${GRPC_VERSION} --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
mkdir -p /build/grpc/cmake/build && \
cd /build/grpc/cmake/build && \
cmake -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_INSTALL_PREFIX:PATH=/opt/grpc ../.. && \
make && \
make install && \
rm -rf /build
RUN git clone --recurse-submodules --jobs 4 -b ${GRPC_VERSION} --depth 1 --shallow-submodules https://github.com/grpc/grpc
RUN cd grpc && \
mkdir -p cmake/build && \
cd cmake/build && \
cmake -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF ../.. && \
make
###################################
###################################
# The builder target compiles LocalAI. This target is not the target that will be uploaded to the registry.
# Adjustments to the build process should likely be made here.
FROM requirements-drivers AS builder
FROM requirements-${IMAGE_TYPE} as builder
ARG GO_TAGS="stablediffusion tts"
ARG GRPC_BACKENDS
@@ -199,36 +148,39 @@ COPY . .
COPY .git .
RUN echo "GO_TAGS: $GO_TAGS"
RUN apt-get update && \
apt-get install -y build-essential cmake git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN make prepare
# We need protoc installed, and the version in 22.04 is too old. We will create one as part installing the GRPC build below
# but that will also being in a newer version of absl which stablediffusion cannot compile with. This version of protoc is only
# here so that we can generate the grpc code for the stablediffusion build
RUN curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
rm protoc.zip
# If we are building with clblas support, we need the libraries for the builds
RUN if [ "${BUILD_TYPE}" = "clblas" ]; then \
apt-get update && \
apt-get install -y libclblast-dev && \
apt-get clean \
; fi
# stablediffusion does not tolerate a newer version of abseil, build it first
RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
# Install the pre-built GRPC
COPY --from=grpc /opt/grpc /usr/local
COPY --from=grpc /build/grpc ./grpc/
RUN cd /build/grpc/cmake/build && make install
# Rebuild with defaults backends
WORKDIR /build
RUN make build
RUN if [ ! -d "/build/sources/go-piper/piper-phonemize/pi/lib/" ]; then \
mkdir -p /build/sources/go-piper/piper-phonemize/pi/lib/ \
touch /build/sources/go-piper/piper-phonemize/pi/lib/keep \
mkdir -p /build/sources/go-piper/piper-phonemize/pi/lib/ \
touch /build/sources/go-piper/piper-phonemize/pi/lib/keep \
; fi
###################################
###################################
# This is the final target. The result of this target will be the image uploaded to the registry.
# If you cannot find a more suitable place for an addition, this layer is a suitable place for it.
FROM requirements-drivers
FROM requirements-${IMAGE_TYPE}
ARG FFMPEG
ARG BUILD_TYPE
@@ -249,13 +201,21 @@ ENV PIP_CACHE_PURGE=true
# Add FFmpeg
RUN if [ "${FFMPEG}" = "true" ]; then \
apt-get update && \
apt-get install -y --no-install-recommends \
ffmpeg && \
apt-get clean && \
rm -rf /var/lib/apt/lists/* \
apt-get install -y ffmpeg && apt-get clean \
; fi
# Add OpenCL
RUN if [ "${BUILD_TYPE}" = "clblas" ]; then \
apt-get update && \
apt-get install -y libclblast1 && \
apt-get clean \
; fi
RUN apt-get update && \
apt-get install -y cmake git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
WORKDIR /build
# we start fresh & re-copy all assets because `make build` does not clean up nicely after itself
@@ -265,9 +225,9 @@ WORKDIR /build
COPY . .
COPY --from=builder /build/sources ./sources/
COPY --from=grpc /opt/grpc /usr/local
COPY --from=grpc /build/grpc ./grpc/
RUN make prepare-sources
RUN make prepare-sources && cd /build/grpc/cmake/build && make install && rm -rf /build/grpc
# Copy the binary
COPY --from=builder /build/local-ai ./
@@ -297,9 +257,6 @@ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
make -C backend/python/sentencetransformers \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
make -C backend/python/rerankers \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
make -C backend/python/transformers \
; fi
@@ -330,7 +287,7 @@ RUN mkdir -p /build/models
# Define the health check command
HEALTHCHECK --interval=1m --timeout=10m --retries=10 \
CMD curl -f ${HEALTHCHECK_ENDPOINT} || exit 1
CMD curl -f $HEALTHCHECK_ENDPOINT || exit 1
VOLUME /build/models
EXPOSE 8080

Makefile

@@ -5,7 +5,7 @@ BINARY_NAME=local-ai
# llama.cpp versions
GOLLAMA_STABLE_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
CPPLLAMA_VERSION?=6ecf3189e00a1e8e737a78b6d10e1d7006e050a2
CPPLLAMA_VERSION?=7593639ce335e8d7f89aa9a54d616951f273af60
# gpt4all version
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
@@ -16,7 +16,7 @@ RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=661e7ae26d442f5cfebd2a0881b44e8c55949ec6
# whisper.cpp version
WHISPER_CPP_VERSION?=8fac6455ffeb0a0950a84e790ddb74f7290d33c4
WHISPER_CPP_VERSION?=b0c3cbf2e851cf232e432b590dcc514a689ec028
# bert.cpp version
BERT_VERSION?=6abe312cded14042f6b7c3cd8edf082713334a4d
@@ -25,10 +25,10 @@ BERT_VERSION?=6abe312cded14042f6b7c3cd8edf082713334a4d
PIPER_VERSION?=9d0100873a7dbb0824dfea40e8cec70a1b110759
# stablediffusion version
STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
STABLEDIFFUSION_VERSION?=362df9da29f882dbf09ade61972d16a1f53c3485
# tinydream version
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057
TINYDREAM_VERSION?=22a12a4bc0ac5455856f28f3b771331a551a4293
export BUILD_TYPE?=
export STABLE_BUILD_TYPE?=$(BUILD_TYPE)
@@ -99,7 +99,7 @@ endif
ifeq ($(BUILD_TYPE),cublas)
CGO_LDFLAGS+=-lcublas -lcudart -L$(CUDA_LIBPATH)
export LLAMA_CUBLAS=1
export WHISPER_CUDA=1
export WHISPER_CUBLAS=1
CGO_LDFLAGS_WHISPER+=-L$(CUDA_LIBPATH)/stubs/ -lcuda
endif
@@ -179,20 +179,20 @@ endif
all: help
## BERT embeddings
sources/go-bert.cpp:
git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp sources/go-bert.cpp
cd sources/go-bert.cpp && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1
sources/go-bert:
git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp sources/go-bert
cd sources/go-bert && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1
sources/go-bert.cpp/libgobert.a: sources/go-bert.cpp
$(MAKE) -C sources/go-bert.cpp libgobert.a
sources/go-bert/libgobert.a: sources/go-bert
$(MAKE) -C sources/go-bert libgobert.a
## go-llama.cpp
sources/go-llama.cpp:
git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp sources/go-llama.cpp
cd sources/go-llama.cpp && git checkout -b build $(GOLLAMA_STABLE_VERSION) && git submodule update --init --recursive --depth 1
## go-llama-ggml
sources/go-llama-ggml:
git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp sources/go-llama-ggml
cd sources/go-llama-ggml && git checkout -b build $(GOLLAMA_STABLE_VERSION) && git submodule update --init --recursive --depth 1
sources/go-llama.cpp/libbinding.a: sources/go-llama.cpp
$(MAKE) -C sources/go-llama.cpp BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a
sources/go-llama-ggml/libbinding.a: sources/go-llama-ggml
$(MAKE) -C sources/go-llama-ggml BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a
## go-piper
sources/go-piper:
@@ -211,12 +211,12 @@ sources/gpt4all/gpt4all-bindings/golang/libgpt4all.a: sources/gpt4all
$(MAKE) -C sources/gpt4all/gpt4all-bindings/golang/ libgpt4all.a
## RWKV
sources/go-rwkv.cpp:
git clone --recurse-submodules $(RWKV_REPO) sources/go-rwkv.cpp
cd sources/go-rwkv.cpp && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1
sources/go-rwkv:
git clone --recurse-submodules $(RWKV_REPO) sources/go-rwkv
cd sources/go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1
sources/go-rwkv.cpp/librwkv.a: sources/go-rwkv.cpp
cd sources/go-rwkv.cpp && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a ..
sources/go-rwkv/librwkv.a: sources/go-rwkv
cd sources/go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a ..
## stable diffusion
sources/go-stable-diffusion:
@@ -236,24 +236,23 @@ sources/go-tiny-dream/libtinydream.a: sources/go-tiny-dream
## whisper
sources/whisper.cpp:
git clone https://github.com/ggerganov/whisper.cpp sources/whisper.cpp
git clone https://github.com/ggerganov/whisper.cpp.git sources/whisper.cpp
cd sources/whisper.cpp && git checkout -b build $(WHISPER_CPP_VERSION) && git submodule update --init --recursive --depth 1
sources/whisper.cpp/libwhisper.a: sources/whisper.cpp
cd sources/whisper.cpp && $(MAKE) libwhisper.a
cd sources/whisper.cpp && make libwhisper.a
get-sources: sources/go-llama.cpp sources/gpt4all sources/go-piper sources/go-rwkv.cpp sources/whisper.cpp sources/go-bert.cpp sources/go-stable-diffusion sources/go-tiny-dream
get-sources: sources/go-llama-ggml sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion sources/go-tiny-dream
replace:
$(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(CURDIR)/sources/go-rwkv.cpp
$(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(CURDIR)/sources/go-rwkv
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp/bindings/go=$(CURDIR)/sources/whisper.cpp/bindings/go
$(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(CURDIR)/sources/go-bert.cpp
$(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(CURDIR)/sources/go-bert
$(GOCMD) mod edit -replace github.com/M0Rf30/go-tiny-dream=$(CURDIR)/sources/go-tiny-dream
$(GOCMD) mod edit -replace github.com/mudler/go-piper=$(CURDIR)/sources/go-piper
$(GOCMD) mod edit -replace github.com/mudler/go-stable-diffusion=$(CURDIR)/sources/go-stable-diffusion
$(GOCMD) mod edit -replace github.com/nomic-ai/gpt4all/gpt4all-bindings/golang=$(CURDIR)/sources/gpt4all/gpt4all-bindings/golang
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(CURDIR)/sources/go-llama.cpp
dropreplace:
$(GOCMD) mod edit -dropreplace github.com/donomii/go-rwkv.cpp
@@ -272,12 +271,12 @@ prepare-sources: get-sources replace
## GENERIC
rebuild: ## Rebuilds the project
$(GOCMD) clean -cache
$(MAKE) -C sources/go-llama.cpp clean
$(MAKE) -C sources/go-llama-ggml clean
$(MAKE) -C sources/gpt4all/gpt4all-bindings/golang/ clean
$(MAKE) -C sources/go-rwkv.cpp clean
$(MAKE) -C sources/go-rwkv clean
$(MAKE) -C sources/whisper.cpp clean
$(MAKE) -C sources/go-stable-diffusion clean
$(MAKE) -C sources/go-bert.cpp clean
$(MAKE) -C sources/go-bert clean
$(MAKE) -C sources/go-piper clean
$(MAKE) -C sources/go-tiny-dream clean
$(MAKE) build
@@ -302,6 +301,9 @@ clean-tests:
rm -rf test-dir
rm -rf core/http/backend-assets
halt-backends: ## Used to clean up stray backends sometimes left running when debugging manually
ps | grep 'backend-assets/grpc/' | awk '{print $$1}' | xargs -I {} kill -9 {}
## Build:
build: prepare backend-assets grpcs ## Build the project
$(info ${GREEN}I local-ai build info:${RESET})
@@ -366,13 +368,13 @@ run-e2e-image:
run-e2e-aio:
@echo 'Running e2e AIO tests'
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e-aio
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
test-e2e:
@echo 'Running e2e tests'
BUILD_TYPE=$(BUILD_TYPE) \
LOCALAI_API=http://$(E2E_BRIDGE_IP):5390/v1 \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
teardown-e2e:
rm -rf $(TEST_DIR) || true
@@ -380,15 +382,15 @@ teardown-e2e:
test-gpt4all: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r $(TEST_PATHS)
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS)
test-llama: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r $(TEST_PATHS)
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS)
test-llama-gguf: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r $(TEST_PATHS)
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS)
test-tts: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
@@ -437,10 +439,10 @@ protogen-go-clean:
$(RM) bin/*
.PHONY: protogen-python
protogen-python: autogptq-protogen bark-protogen coqui-protogen diffusers-protogen exllama-protogen exllama2-protogen mamba-protogen petals-protogen rerankers-protogen sentencetransformers-protogen transformers-protogen parler-tts-protogen transformers-musicgen-protogen vall-e-x-protogen vllm-protogen
protogen-python: autogptq-protogen bark-protogen coqui-protogen diffusers-protogen exllama-protogen exllama2-protogen mamba-protogen petals-protogen sentencetransformers-protogen transformers-protogen parler-tts-protogen transformers-musicgen-protogen vall-e-x-protogen vllm-protogen
.PHONY: protogen-python-clean
protogen-python-clean: autogptq-protogen-clean bark-protogen-clean coqui-protogen-clean diffusers-protogen-clean exllama-protogen-clean exllama2-protogen-clean mamba-protogen-clean petals-protogen-clean sentencetransformers-protogen-clean rerankers-protogen-clean transformers-protogen-clean transformers-musicgen-protogen-clean parler-tts-protogen-clean vall-e-x-protogen-clean vllm-protogen-clean
protogen-python-clean: autogptq-protogen-clean bark-protogen-clean coqui-protogen-clean diffusers-protogen-clean exllama-protogen-clean exllama2-protogen-clean mamba-protogen-clean petals-protogen-clean sentencetransformers-protogen-clean transformers-protogen-clean transformers-musicgen-protogen-clean parler-tts-protogen-clean vall-e-x-protogen-clean vllm-protogen-clean
.PHONY: autogptq-protogen
autogptq-protogen:
@@ -506,14 +508,6 @@ petals-protogen:
petals-protogen-clean:
$(MAKE) -C backend/python/petals protogen-clean
.PHONY: rerankers-protogen
rerankers-protogen:
$(MAKE) -C backend/python/rerankers protogen
.PHONY: rerankers-protogen-clean
rerankers-protogen-clean:
$(MAKE) -C backend/python/rerankers protogen-clean
.PHONY: sentencetransformers-protogen
sentencetransformers-protogen:
$(MAKE) -C backend/python/sentencetransformers protogen
@@ -572,7 +566,6 @@ prepare-extra-conda-environments: protogen-python
$(MAKE) -C backend/python/vllm
$(MAKE) -C backend/python/mamba
$(MAKE) -C backend/python/sentencetransformers
$(MAKE) -C backend/python/rerankers
$(MAKE) -C backend/python/transformers
$(MAKE) -C backend/python/transformers-musicgen
$(MAKE) -C backend/python/parler-tts
@@ -608,8 +601,8 @@ backend-assets/gpt4all: sources/gpt4all sources/gpt4all/gpt4all-bindings/golang/
backend-assets/grpc: protogen-go replace
mkdir -p backend-assets/grpc
backend-assets/grpc/bert-embeddings: sources/go-bert.cpp sources/go-bert.cpp/libgobert.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-bert.cpp LIBRARY_PATH=$(CURDIR)/sources/go-bert.cpp \
backend-assets/grpc/bert-embeddings: sources/go-bert sources/go-bert/libgobert.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-bert LIBRARY_PATH=$(CURDIR)/sources/go-bert \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bert-embeddings ./backend/go/llm/bert/
backend-assets/grpc/gpt4all: sources/gpt4all sources/gpt4all/gpt4all-bindings/golang/libgpt4all.a backend-assets/gpt4all backend-assets/grpc
@@ -651,16 +644,20 @@ ifeq ($(BUILD_TYPE),metal)
cp backend/cpp/llama/llama.cpp/build/bin/default.metallib backend-assets/grpc/
endif
backend-assets/grpc/llama-ggml: sources/go-llama.cpp sources/go-llama.cpp/libbinding.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-llama.cpp LIBRARY_PATH=$(CURDIR)/sources/go-llama.cpp \
backend-assets/grpc/llama-ggml: sources/go-llama-ggml sources/go-llama-ggml/libbinding.a backend-assets/grpc
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(CURDIR)/sources/go-llama-ggml
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-llama-ggml LIBRARY_PATH=$(CURDIR)/sources/go-llama-ggml \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama-ggml ./backend/go/llm/llama-ggml/
# EXPERIMENTAL:
ifeq ($(BUILD_TYPE),metal)
cp $(CURDIR)/sources/go-llama-ggml/llama.cpp/ggml-metal.metal backend-assets/grpc/
endif
backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data
CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/
backend-assets/grpc/rwkv: sources/go-rwkv.cpp sources/go-rwkv.cpp/librwkv.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-rwkv.cpp LIBRARY_PATH=$(CURDIR)/sources/go-rwkv.cpp \
backend-assets/grpc/rwkv: sources/go-rwkv sources/go-rwkv/librwkv.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-rwkv LIBRARY_PATH=$(CURDIR)/sources/go-rwkv \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/rwkv ./backend/go/llm/rwkv
backend-assets/grpc/stablediffusion: sources/go-stable-diffusion sources/go-stable-diffusion/libstablediffusion.a backend-assets/grpc
@@ -707,7 +704,7 @@ docker-aio-all:
docker-image-intel:
docker build \
--build-arg BASE_IMAGE=intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04 \
--build-arg BASE_IMAGE=intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04 \
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
--build-arg GO_TAGS="none" \
--build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
@@ -715,7 +712,7 @@ docker-image-intel:
docker-image-intel-xpu:
docker build \
--build-arg BASE_IMAGE=intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04 \
--build-arg BASE_IMAGE=intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04 \
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
--build-arg GO_TAGS="none" \
--build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \
@@ -723,4 +720,4 @@ docker-image-intel-xpu:
.PHONY: swagger
swagger:
swag init -g core/http/app.go --output swagger
swag init -g core/http/api.go --output swagger


@@ -44,23 +44,20 @@
[![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai)
**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API thats compatible with OpenAI (Elevenlabs, Anthropic... ) API specifications for local AI inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families. Does not require GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler).
**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API thats compatible with OpenAI (Elevenlabs, Anthropic... ) API specifications for local AI inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families. Does not require GPU.
## 🔥🔥 Hot topics / Roadmap
[Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
- Reranker API: https://github.com/mudler/LocalAI/pull/2121
- Gallery WebUI: https://github.com/mudler/LocalAI/pull/2104
- llama3: https://github.com/mudler/LocalAI/discussions/2076
- Parler-TTS: https://github.com/mudler/LocalAI/pull/2027
- Landing page: https://github.com/mudler/LocalAI/pull/1922
- Openvino support: https://github.com/mudler/LocalAI/pull/1892
- Vector store: https://github.com/mudler/LocalAI/pull/1795
- All-in-one container image: https://github.com/mudler/LocalAI/issues/1855
- Parallel function calling: https://github.com/mudler/LocalAI/pull/1726 / Tools API support: https://github.com/mudler/LocalAI/pull/1715
Hot topics (looking for contributors):
- WebUI improvements: https://github.com/mudler/LocalAI/issues/2156
- Backends v2: https://github.com/mudler/LocalAI/issues/1126
- Improving UX v2: https://github.com/mudler/LocalAI/issues/1373
- Assistant API: https://github.com/mudler/LocalAI/issues/1273
@@ -91,8 +88,7 @@ docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-cpu
- 🧠 [Embeddings generation for vector databases](https://localai.io/features/embeddings/)
- ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
- 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
- 🥽 [Vision API](https://localai.io/features/gpt-vision/)
- 🆕 [Reranker API](https://localai.io/features/reranker/)
- 🆕 [Vision API](https://localai.io/features/gpt-vision/)
## 💻 Usage


@@ -1,27 +0,0 @@
name: jina-reranker-v1-base-en
backend: rerankers
parameters:
model: cross-encoder
usage: |
You can test this model with curl like this:
curl http://localhost:8080/v1/rerank \
-H "Content-Type: application/json" \
-d '{
"model": "jina-reranker-v1-base-en",
"query": "Organic skincare products for sensitive skin",
"documents": [
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
"Organic cotton baby clothes for sensitive skin",
"Natural organic skincare range for sensitive skin",
"Tech gadgets for smart homes: 2024 edition",
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"All-natural pet food for dogs with allergies",
"Yoga mats made from recycled materials"
],
"top_n": 3
}'


@@ -1,27 +1,20 @@
name: gpt-4
mmap: true
parameters:
model: huggingface://NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf
model: huggingface://NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/Hermes-2-Pro-Mistral-7B.Q2_K.gguf
template:
chat_message: |
<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
{{- if .FunctionCall }}
<tool_call>
{{- else if eq .RoleName "tool" }}
<tool_response>
{{- end }}
{{- if .FunctionCall }}<tool_call>{{end}}
{{- if eq .RoleName "tool" }}<tool_result>{{end }}
{{- if .Content}}
{{.Content }}
{{.Content}}
{{- end }}
{{- if .FunctionCall}}
{{toJson .FunctionCall}}
{{- end }}
{{- if .FunctionCall }}
</tool_call>
{{- else if eq .RoleName "tool" }}
</tool_response>
{{- end }}<|im_end|>
{{- if .FunctionCall}}{{toJson .FunctionCall}}{{end }}
{{- if .FunctionCall }}</tool_call>{{end }}
{{- if eq .RoleName "tool" }}</tool_result>{{end }}
<|im_end|>
# https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF#prompt-format-for-function-calling
function: |
<|im_start|>system
@@ -36,7 +29,8 @@ template:
For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{'arguments': <args-dict>, 'name': <function-name>}
</tool_call><|im_end|>
</tool_call>
<|im_end|>
{{.Input -}}
<|im_start|>assistant
<tool_call>


@@ -129,7 +129,7 @@ detect_gpu
detect_gpu_size
PROFILE="${PROFILE:-$GPU_SIZE}" # default to cpu
export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
check_vars


@@ -1,27 +0,0 @@
name: jina-reranker-v1-base-en
backend: rerankers
parameters:
model: cross-encoder
usage: |
You can test this model with curl like this:
curl http://localhost:8080/v1/rerank \
-H "Content-Type: application/json" \
-d '{
"model": "jina-reranker-v1-base-en",
"query": "Organic skincare products for sensitive skin",
"documents": [
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
"Organic cotton baby clothes for sensitive skin",
"Natural organic skincare range for sensitive skin",
"Tech gadgets for smart homes: 2024 edition",
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"All-natural pet food for dogs with allergies",
"Yoga mats made from recycled materials"
],
"top_n": 3
}'


@@ -1,27 +1,20 @@
name: gpt-4
mmap: true
parameters:
model: huggingface://NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf
model: huggingface://NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/Hermes-2-Pro-Mistral-7B.Q6_K.gguf
template:
chat_message: |
<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
{{- if .FunctionCall }}
<tool_call>
{{- else if eq .RoleName "tool" }}
<tool_response>
{{- end }}
{{- if .FunctionCall }}<tool_call>{{end}}
{{- if eq .RoleName "tool" }}<tool_result>{{end }}
{{- if .Content}}
{{.Content }}
{{.Content}}
{{- end }}
{{- if .FunctionCall}}
{{toJson .FunctionCall}}
{{- end }}
{{- if .FunctionCall }}
</tool_call>
{{- else if eq .RoleName "tool" }}
</tool_response>
{{- end }}<|im_end|>
{{- if .FunctionCall}}{{toJson .FunctionCall}}{{end }}
{{- if .FunctionCall }}</tool_call>{{end }}
{{- if eq .RoleName "tool" }}</tool_result>{{end }}
<|im_end|>
# https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF#prompt-format-for-function-calling
function: |
<|im_start|>system
@@ -36,7 +29,8 @@ template:
For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{'arguments': <args-dict>, 'name': <function-name>}
</tool_call><|im_end|>
</tool_call>
<|im_end|>
{{.Input -}}
<|im_start|>assistant
<tool_call>

View File

@@ -1,27 +0,0 @@
name: jina-reranker-v1-base-en
backend: rerankers
parameters:
model: cross-encoder
usage: |
You can test this model with curl like this:
curl http://localhost:8080/v1/rerank \
-H "Content-Type: application/json" \
-d '{
"model": "jina-reranker-v1-base-en",
"query": "Organic skincare products for sensitive skin",
"documents": [
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
"Organic cotton baby clothes for sensitive skin",
"Natural organic skincare range for sensitive skin",
"Tech gadgets for smart homes: 2024 edition",
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"All-natural pet food for dogs with allergies",
"Yoga mats made from recycled materials"
],
"top_n": 3
}'

View File

@@ -2,27 +2,20 @@ name: gpt-4
mmap: false
f16: false
parameters:
model: huggingface://NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf
model: huggingface://NousResearch/Hermes-2-Pro-Mistral-7B-GGUF/Hermes-2-Pro-Mistral-7B.Q6_K.gguf
template:
chat_message: |
<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
{{- if .FunctionCall }}
<tool_call>
{{- else if eq .RoleName "tool" }}
<tool_response>
{{- end }}
{{- if .FunctionCall }}<tool_call>{{end}}
{{- if eq .RoleName "tool" }}<tool_result>{{end }}
{{- if .Content}}
{{.Content }}
{{.Content}}
{{- end }}
{{- if .FunctionCall}}
{{toJson .FunctionCall}}
{{- end }}
{{- if .FunctionCall }}
</tool_call>
{{- else if eq .RoleName "tool" }}
</tool_response>
{{- end }}<|im_end|>
{{- if .FunctionCall}}{{toJson .FunctionCall}}{{end }}
{{- if .FunctionCall }}</tool_call>{{end }}
{{- if eq .RoleName "tool" }}</tool_result>{{end }}
<|im_end|>
# https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B-GGUF#prompt-format-for-function-calling
function: |
<|im_start|>system
@@ -37,7 +30,8 @@ template:
For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{'arguments': <args-dict>, 'name': <function-name>}
</tool_call><|im_end|>
</tool_call>
<|im_end|>
{{.Input -}}
<|im_start|>assistant
<tool_call>

View File

@@ -23,30 +23,6 @@ service Backend {
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {}
rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
rpc Rerank(RerankRequest) returns (RerankResult) {}
}
message RerankRequest {
string query = 1;
repeated string documents = 2;
int32 top_n = 3;
}
message RerankResult {
Usage usage = 1;
repeated DocumentResult results = 2;
}
message Usage {
int32 total_tokens = 1;
int32 prompt_tokens = 2;
}
message DocumentResult {
int32 index = 1;
string text = 2;
float relevance_score = 3;
}
message StoresKey {
@@ -201,7 +177,6 @@ message ModelOptions {
bool EnforceEager = 52;
int32 SwapSpace = 53;
int32 MaxModelLen = 54;
int32 TensorParallelSize = 55;
string MMProj = 41;
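To make the removed Rerank RPC above concrete, here is a hedged sketch of a Go client calling it. It assumes the stubs generated from backend.proto follow the usual protoc-gen-go-grpc naming (`NewBackendClient`, `RerankRequest` with `Query`/`Documents`/`TopN` fields), and that a reranker backend is listening on the address shown — both are assumptions, not facts from this diff:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" // assumed import path for the generated stubs
)

func main() {
	// Address of a running reranker backend; illustrative only.
	conn, err := grpc.Dial("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	client := pb.NewBackendClient(conn)
	res, err := client.Rerank(context.Background(), &pb.RerankRequest{
		Query:     "Organic skincare products for sensitive skin",
		Documents: []string{"Eco-friendly kitchenware", "Natural organic skincare range for sensitive skin"},
		TopN:      1,
	})
	if err != nil {
		log.Fatal(err)
	}
	// Each DocumentResult carries the original index, the text, and a relevance score.
	for _, r := range res.Results {
		fmt.Printf("doc %d (%.3f): %s\n", r.Index, r.RelevanceScore, r.Text)
	}
}
```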

View File

@@ -11,8 +11,8 @@ import (
"github.com/go-skynet/LocalAI/core/schema"
)
func ffmpegCommand(args []string) (string, error) {
cmd := exec.Command("ffmpeg", args...) // Constrain this to ffmpeg to permit security scanner to see that the command is safe.
func runCommand(command []string) (string, error) {
cmd := exec.Command(command[0], command[1:]...)
cmd.Env = os.Environ()
out, err := cmd.CombinedOutput()
return string(out), err
@@ -21,8 +21,8 @@ func ffmpegCommand(args []string) (string, error) {
// AudioToWav converts audio to wav for transcribe.
// TODO: use https://github.com/mccoyst/ogg?
func audioToWav(src, dst string) error {
commandArgs := []string{"-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst}
out, err := ffmpegCommand(commandArgs)
command := []string{"ffmpeg", "-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst}
out, err := runCommand(command)
if err != nil {
return fmt.Errorf("error: %w out: %s", err, out)
}
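A minimal, hypothetical usage sketch of the ffmpeg wrapper shown above — the input and output paths are made up, and the helper is reproduced inline so the snippet compiles on its own:

```go
package main

import (
	"fmt"
	"os"
	"os/exec"
)

// ffmpegCommand mirrors the helper in the diff above: the binary name is fixed
// to "ffmpeg" so only the arguments vary.
func ffmpegCommand(args []string) (string, error) {
	cmd := exec.Command("ffmpeg", args...)
	cmd.Env = os.Environ()
	out, err := cmd.CombinedOutput()
	return string(out), err
}

func main() {
	// Convert an (assumed) input file to 16 kHz mono PCM WAV, the format the
	// transcription backend expects.
	args := []string{"-i", "input.ogg", "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", "out.wav"}
	if out, err := ffmpegCommand(args); err != nil {
		fmt.Println("ffmpeg failed:", err, out)
	}
}
```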

View File

@@ -41,7 +41,7 @@ dependencies:
- filelock==3.12.4
- frozenlist==1.4.0
- fsspec==2023.6.0
- grpcio==1.63.0
- grpcio==1.59.0
- huggingface-hub==0.16.4
- idna==3.4
- jinja2==3.1.2

View File

@@ -26,7 +26,7 @@ if [ -d "/opt/intel" ]; then
# Intel GPU: If the directory exists, we assume we are using the intel image
# (no conda env)
# https://github.com/intel/intel-extension-for-pytorch/issues/538
pip install torch==2.1.0.post0 torchvision==0.16.0.post0 torchaudio==2.1.0.post0 intel-extension-for-pytorch==2.1.20+xpu oneccl_bind_pt==2.1.200+xpu intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed optimum[openvino] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed optimum[openvino]
fi
# If we didn't skip conda, activate the environment

View File

@@ -47,7 +47,7 @@ dependencies:
- frozenlist==1.4.0
- fsspec==2023.6.0
- funcy==2.0
- grpcio==1.63.0
- grpcio==1.59.0
- huggingface-hub
- idna==3.4
- jinja2==3.1.2
@@ -120,6 +120,4 @@ dependencies:
- transformers>=4.38.2 # Updated Version
- transformers_stream_generator==0.0.5
- xformers==0.0.23.post1
- rerankers[transformers]
- pydantic
prefix: /opt/conda/envs/transformers

View File

@@ -48,7 +48,7 @@ dependencies:
- frozenlist==1.4.0
- fsspec==2023.6.0
- funcy==2.0
- grpcio==1.63.0
- grpcio==1.59.0
- huggingface-hub
- idna==3.4
- jinja2==3.1.2
@@ -108,6 +108,4 @@ dependencies:
- transformers>=4.38.2 # Updated Version
- transformers_stream_generator==0.0.5
- xformers==0.0.23.post1
- rerankers[transformers]
- pydantic
prefix: /opt/conda/envs/transformers

View File

@@ -47,7 +47,7 @@ dependencies:
- frozenlist==1.4.0
- fsspec==2023.6.0
- funcy==2.0
- grpcio==1.63.0
- grpcio==1.59.0
- huggingface-hub
- humanfriendly==10.0
- idna==3.4
@@ -60,10 +60,9 @@ dependencies:
- networkx
- numpy==1.26.0
- onnx==1.15.0
- openvino==2024.1.0
- openvino-telemetry==2024.1.0
- optimum[openvino]==1.19.1
- optimum-intel==1.16.1
- openvino==2024.0.0
- openvino-telemetry==2023.2.1
- optimum[openvino]==1.17.1
- packaging==23.2
- pandas
- peft==0.5.0
@@ -112,7 +111,5 @@ dependencies:
- vllm>=0.4.0
- transformers>=4.38.2 # Updated Version
- transformers_stream_generator==0.0.5
- xformers==0.0.23.post1
- rerankers[transformers]
- pydantic
- xformers==0.0.23.post1
prefix: /opt/conda/envs/transformers

View File

@@ -34,7 +34,7 @@ dependencies:
- diffusers==0.24.0
- filelock==3.12.4
- fsspec==2023.9.2
- grpcio==1.63.0
- grpcio==1.59.0
- huggingface-hub>=0.19.4
- idna==3.4
- importlib-metadata==6.8.0
@@ -61,5 +61,4 @@ dependencies:
- urllib3==2.0.6
- zipp==3.17.0
- torch
- opencv-python
prefix: /opt/conda/envs/diffusers

View File

@@ -32,7 +32,7 @@ dependencies:
- diffusers==0.24.0
- filelock==3.12.4
- fsspec==2023.9.2
- grpcio==1.63.0
- grpcio==1.59.0
- huggingface-hub>=0.19.4
- idna==3.4
- importlib-metadata==6.8.0
@@ -71,5 +71,4 @@ dependencies:
- typing-extensions==4.8.0
- urllib3==2.0.6
- zipp==3.17.0
- opencv-python
prefix: /opt/conda/envs/diffusers

View File

@@ -31,8 +31,8 @@ if [ -d "/opt/intel" ]; then
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install google-api-python-client \
grpcio==1.63.0 \
grpcio-tools==1.63.0 \
grpcio \
grpcio-tools \
diffusers==0.24.0 \
transformers>=4.25.1 \
accelerate \

View File

@@ -27,7 +27,7 @@ dependencies:
- pip:
- filelock==3.12.4
- fsspec==2023.9.2
- grpcio==1.63.0
- grpcio==1.59.0
- jinja2==3.1.2
- markupsafe==2.1.3
- mpmath==1.3.0

View File

@@ -27,7 +27,7 @@ dependencies:
- pip:
- filelock==3.12.4
- fsspec==2023.9.2
- grpcio==1.63.0
- grpcio==1.59.0
- markupsafe==2.1.3
- mpmath==1.3.0
- networkx==3.1

View File

@@ -26,7 +26,7 @@ dependencies:
- zlib=1.2.13=h5eee18b_0
- pip:
- accelerate>=0.11.0
- grpcio==1.63.0
- grpcio==1.59.0
- numpy==1.26.0
- nvidia-cublas-cu12==12.1.3.1
- nvidia-cuda-cupti-cu12==12.1.105

View File

@@ -27,7 +27,7 @@ dependencies:
- pip:
- accelerate>=0.11.0
- numpy==1.26.0
- grpcio==1.63.0
- grpcio==1.59.0
- torch==2.1.0
- transformers>=4.34.0
- descript-audio-codec

View File

@@ -1,27 +0,0 @@
.PHONY: rerankers
rerankers: protogen
$(MAKE) -C ../common-env/transformers
.PHONY: run
run: protogen
@echo "Running rerankers..."
bash run.sh
@echo "rerankers run."
# It does not work well from the command line; it only works with an IDE like VSCode.

.PHONY: test
test: protogen
@echo "Testing rerankers..."
bash test.sh
@echo "rerankers tested."
.PHONY: protogen
protogen: backend_pb2_grpc.py backend_pb2.py
.PHONY: protogen-clean
protogen-clean:
$(RM) backend_pb2_grpc.py backend_pb2.py
backend_pb2_grpc.py backend_pb2.py:
python3 -m grpc_tools.protoc -I../.. --python_out=. --grpc_python_out=. backend.proto

View File

@@ -1,5 +0,0 @@
# Creating a separate environment for the reranker project
```
make reranker
```

View File

@@ -1,123 +0,0 @@
#!/usr/bin/env python3
"""
Extra gRPC server for Rerankers models.
"""
from concurrent import futures
import argparse
import signal
import sys
import os
import time
import backend_pb2
import backend_pb2_grpc
import grpc
from rerankers import Reranker
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
A gRPC servicer for the backend service.
This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
"""
def Health(self, request, context):
"""
A gRPC method that returns the health status of the backend service.
Args:
request: A HealthRequest object that contains the request parameters.
context: A grpc.ServicerContext object that provides information about the RPC.
Returns:
A Reply object that contains the health status of the backend service.
"""
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context):
"""
A gRPC method that loads a model into memory.
Args:
request: A LoadModelRequest object that contains the request parameters.
context: A grpc.ServicerContext object that provides information about the RPC.
Returns:
A Result object that contains the result of the LoadModel operation.
"""
model_name = request.Model
try:
kwargs = {}
if request.Type != "":
kwargs['model_type'] = request.Type
if request.PipelineType != "": # Reuse the PipelineType field for language
kwargs['lang'] = request.PipelineType
self.model_name = model_name
self.model = Reranker(model_name, **kwargs)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
# Implement your logic here for the LoadModel service
# Replace this with your desired response
return backend_pb2.Result(message="Model loaded successfully", success=True)
def Rerank(self, request, context):
documents = []
for idx, doc in enumerate(request.documents):
documents.append(doc)
ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents))))
# Prepare results to return
results = [
backend_pb2.DocumentResult(
index=res.doc_id,
text=res.text,
relevance_score=res.score
) for res in ranked_results.results
]
# Calculate the usage and total tokens
# TODO: Implement the usage calculation with reranker
total_tokens = sum(len(doc.split()) for doc in request.documents) + len(request.query.split())
prompt_tokens = len(request.query.split())
usage = backend_pb2.Usage(total_tokens=total_tokens, prompt_tokens=prompt_tokens)
return backend_pb2.RerankResult(usage=usage, results=results)
def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print("Server started. Listening on: " + address, file=sys.stderr)
# Define the signal handler function
def signal_handler(sig, frame):
print("Received termination signal. Shutting down...")
server.stop(0)
sys.exit(0)
# Set the signal handlers for SIGINT and SIGTERM
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
parser.add_argument(
"--addr", default="localhost:50051", help="The address to bind the server to."
)
args = parser.parse_args()
serve(args.addr)

View File

@@ -1,14 +0,0 @@
#!/bin/bash
##
## A bash script wrapper that runs the reranker server with conda
export PATH=$PATH:/opt/conda/bin
# Activate conda environment
source activate transformers
# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
python $DIR/reranker.py $@

View File

@@ -1,11 +0,0 @@
#!/bin/bash
##
## A bash script wrapper that runs the reranker server with conda
# Activate conda environment
source activate transformers
# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
python -m unittest $DIR/test_reranker.py

View File

@@ -1,90 +0,0 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc
import grpc
class TestBackendServicer(unittest.TestCase):
"""
TestBackendServicer is the class that tests the gRPC service
"""
def setUp(self):
"""
This method sets up the gRPC service by starting the server
"""
self.service = subprocess.Popen(["python3", "reranker.py", "--addr", "localhost:50051"])
time.sleep(10)
def tearDown(self) -> None:
"""
This method tears down the gRPC service by terminating the server
"""
self.service.kill()
self.service.wait()
def test_server_startup(self):
"""
This method tests if the server starts up successfully
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.Health(backend_pb2.HealthMessage())
self.assertEqual(response.message, b'OK')
except Exception as err:
print(err)
self.fail("Server failed to start")
finally:
self.tearDown()
def test_load_model(self):
"""
This method tests if the model is loaded successfully
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
self.assertTrue(response.success)
self.assertEqual(response.message, "Model loaded successfully")
except Exception as err:
print(err)
self.fail("LoadModel service failed")
finally:
self.tearDown()
def test_rerank(self):
"""
This method tests if the embeddings are generated successfully
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
request = backend_pb2.RerankRequest(
query="I love you",
documents=["I hate you", "I really like you"],
top_n=2
)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
self.assertTrue(response.success)
rerank_response = stub.Rerank(request)
print(rerank_response.results[0])
self.assertIsNotNone(rerank_response.results)
self.assertEqual(len(rerank_response.results), 2)
self.assertEqual(rerank_response.results[0].text, "I really like you")
self.assertEqual(rerank_response.results[1].text, "I hate you")
except Exception as err:
print(err)
self.fail("Reranker service failed")
finally:
self.tearDown()

View File

@@ -89,8 +89,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
quantization = None
if self.CUDA:
if request.MainGPU:
device_map=request.MainGPU
if request.Device:
device_map=request.Device
else:
device_map="cuda:0"
if request.Quantization == "bnb_4bit":
@@ -143,37 +143,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
from optimum.intel.openvino import OVModelForCausalLM
from openvino.runtime import Core
if request.MainGPU:
device_map=request.MainGPU
if "GPU" in Core().available_devices:
device_map="GPU"
else:
device_map="AUTO"
devices = Core().available_devices
if "GPU" in " ".join(devices):
device_map="AUTO:GPU"
device_map="CPU"
self.model = OVModelForCausalLM.from_pretrained(model_name,
compile=True,
trust_remote_code=request.TrustRemoteCode,
ov_config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"},
device=device_map)
self.OV = True
elif request.Type == "OVModelForFeatureExtraction":
from optimum.intel.openvino import OVModelForFeatureExtraction
from openvino.runtime import Core
if request.MainGPU:
device_map=request.MainGPU
else:
device_map="AUTO"
devices = Core().available_devices
if "GPU" in " ".join(devices):
device_map="AUTO:GPU"
self.model = OVModelForFeatureExtraction.from_pretrained(model_name,
compile=True,
trust_remote_code=request.TrustRemoteCode,
ov_config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT", "GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"},
export=True,
compile=True,
device=device_map)
self.OV = True
else:
@@ -183,11 +158,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
quantization_config=quantization,
device_map=device_map,
torch_dtype=compute)
if request.ContextSize > 0:
self.max_tokens = request.ContextSize
else:
self.max_tokens = self.model.config.max_position_embeddings
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
self.XPU = False
@@ -242,27 +212,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
set_seed(request.Seed)
if request.TopP == 0:
request.TopP = 0.9
if request.TopK == 0:
request.TopK = 40
prompt = request.Prompt
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
eos_token_id = self.tokenizer.eos_token_id
if request.StopPrompts:
eos_token_id = []
for word in request.StopPrompts:
eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word))
inputs = self.tokenizer(prompt, return_tensors="pt")
max_tokens = 200
if request.Tokens > 0:
max_tokens = request.Tokens
else:
max_tokens = self.max_tokens - inputs["input_ids"].size()[inputs["input_ids"].dim()-1]
inputs = self.tokenizer(request.Prompt, return_tensors="pt")
if self.CUDA:
inputs = inputs.to("cuda")
if XPU and self.OV == False:
@@ -280,7 +235,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
top_k=request.TopK,
do_sample=True,
attention_mask=inputs["attention_mask"],
eos_token_id=eos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.eos_token_id,
streamer=streamer)
thread=Thread(target=self.model.generate, kwargs=config)
@@ -309,7 +264,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
top_k=request.TopK,
do_sample=True,
attention_mask=inputs["attention_mask"],
eos_token_id=eos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.eos_token_id)
generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
@@ -379,4 +334,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
asyncio.run(serve(args.addr))
asyncio.run(serve(args.addr))

View File

@@ -42,7 +42,7 @@ dependencies:
- future==0.18.3
- gradio==3.47.1
- gradio-client==0.6.0
- grpcio==1.63.0
- grpcio==1.59.0
- h11==0.14.0
- httpcore==0.18.0
- httpx==0.25.0

View File

@@ -95,8 +95,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
engine_args.trust_remote_code = request.TrustRemoteCode
if request.EnforceEager:
engine_args.enforce_eager = request.EnforceEager
if request.TensorParallelSize:
engine_args.tensor_parallel_size = request.TensorParallelSize
if request.SwapSpace != 0:
engine_args.swap_space = request.SwapSpace
if request.MaxModelLen != 0:

View File

@@ -2,14 +2,100 @@ package backend
import (
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
)
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
type EmbeddingsBackendService struct {
ml *model.ModelLoader
bcl *config.BackendConfigLoader
appConfig *config.ApplicationConfig
}
func NewEmbeddingsBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *EmbeddingsBackendService {
return &EmbeddingsBackendService{
ml: ml,
bcl: bcl,
appConfig: appConfig,
}
}
func (ebs *EmbeddingsBackendService) Embeddings(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] {
resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
go func(request *schema.OpenAIRequest) {
if request.Model == "" {
request.Model = model.StableDiffusionBackend
}
bc, request, err := ebs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, ebs.appConfig)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
items := []schema.Item{}
for i, s := range bc.InputToken {
// get the model function to call for the result
embedFn, err := modelEmbedding("", s, ebs.ml, bc, ebs.appConfig)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
embeddings, err := embedFn()
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range bc.InputStrings {
// get the model function to call for the result
embedFn, err := modelEmbedding(s, []int{}, ebs.ml, bc, ebs.appConfig)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
embeddings, err := embedFn()
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",
}
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp}
close(resultChannel)
}(request)
return resultChannel
}
func modelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
modelFile := backendConfig.Model
grpcOpts := gRPCModelOpts(backendConfig)
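The refactored services in this changeset no longer return values directly; they hand results back over a channel of `concurrency.ErrorOr` values. A minimal sketch of how a caller might drain such a channel — the package name and helper function are hypothetical, only the `Embeddings` signature comes from the diff above:

```go
package handlers

import (
	"fmt"

	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/schema"
)

// respondEmbeddings is a hypothetical helper showing how an HTTP endpoint
// could consume the channel returned by EmbeddingsBackendService.Embeddings.
func respondEmbeddings(ebs *backend.EmbeddingsBackendService, req *schema.OpenAIRequest) (*schema.OpenAIResponse, error) {
	for result := range ebs.Embeddings(req) {
		if result.Error != nil {
			// The backend reported a failure; surface it to the API caller.
			return nil, result.Error
		}
		// The first (and only) value carries the fully built OpenAI-style response.
		return result.Value, nil
	}
	return nil, fmt.Errorf("no response received from embeddings backend")
}
```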

View File

@@ -1,18 +1,252 @@
package backend
import (
"github.com/go-skynet/LocalAI/core/config"
"bufio"
"encoding/base64"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
type ImageGenerationBackendService struct {
ml *model.ModelLoader
bcl *config.BackendConfigLoader
appConfig *config.ApplicationConfig
BaseUrlForGeneratedImages string
}
func NewImageGenerationBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ImageGenerationBackendService {
return &ImageGenerationBackendService{
ml: ml,
bcl: bcl,
appConfig: appConfig,
}
}
func (igbs *ImageGenerationBackendService) GenerateImage(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] {
resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse])
go func(request *schema.OpenAIRequest) {
bc, request, err := igbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, igbs.appConfig)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
src := ""
if request.File != "" {
var fileData []byte
// check if input.File is an URL, if so download it and save it
// to a temporary file
if strings.HasPrefix(request.File, "http://") || strings.HasPrefix(request.File, "https://") {
out, err := downloadFile(request.File)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed downloading file:%w", err)}
close(resultChannel)
return
}
defer os.RemoveAll(out)
fileData, err = os.ReadFile(out)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed reading file:%w", err)}
close(resultChannel)
return
}
} else {
// base 64 decode the file and write it somewhere
// that we will cleanup
fileData, err = base64.StdEncoding.DecodeString(request.File)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
}
// Create a temporary file
outputFile, err := os.CreateTemp(igbs.appConfig.ImageDir, "b64")
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
// write the base64 result
writer := bufio.NewWriter(outputFile)
_, err = writer.Write(fileData)
if err != nil {
outputFile.Close()
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
outputFile.Close()
src = outputFile.Name()
defer os.RemoveAll(src)
}
log.Debug().Msgf("Parameter Config: %+v", bc)
switch bc.Backend {
case "stablediffusion":
bc.Backend = model.StableDiffusionBackend
case "tinydream":
bc.Backend = model.TinyDreamBackend
case "":
bc.Backend = model.StableDiffusionBackend
if bc.Model == "" {
bc.Model = "stablediffusion_assets" // TODO: check?
}
}
sizeParts := strings.Split(request.Size, "x")
if len(sizeParts) != 2 {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")}
close(resultChannel)
return
}
width, err := strconv.Atoi(sizeParts[0])
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")}
close(resultChannel)
return
}
height, err := strconv.Atoi(sizeParts[1])
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")}
close(resultChannel)
return
}
b64JSON := false
if request.ResponseFormat.Type == "b64_json" {
b64JSON = true
}
// src and clip_skip
var result []schema.Item
for _, i := range bc.PromptStrings {
n := request.N
if request.N == 0 {
n = 1
}
for j := 0; j < n; j++ {
prompts := strings.Split(i, "|")
positive_prompt := prompts[0]
negative_prompt := ""
if len(prompts) > 1 {
negative_prompt = prompts[1]
}
mode := 0
step := bc.Step
if step == 0 {
step = 15
}
if request.Mode != 0 {
mode = request.Mode
}
if request.Step != 0 {
step = request.Step
}
tempDir := ""
if !b64JSON {
tempDir = igbs.appConfig.ImageDir
}
// Create a temporary file
outputFile, err := os.CreateTemp(tempDir, "b64")
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
outputFile.Close()
output := outputFile.Name() + ".png"
// Rename the temporary file
err = os.Rename(outputFile.Name(), output)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
if request.Seed == nil {
zVal := 0 // Idiomatic way to do this? Actually needed?
request.Seed = &zVal
}
fn, err := imageGeneration(height, width, mode, step, *request.Seed, positive_prompt, negative_prompt, src, output, igbs.ml, bc, igbs.appConfig)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
if err := fn(); err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
item := &schema.Item{}
if b64JSON {
defer os.RemoveAll(output)
data, err := os.ReadFile(output)
if err != nil {
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err}
close(resultChannel)
return
}
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = igbs.BaseUrlForGeneratedImages + base
}
result = append(result, *item)
}
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Data: result,
}
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp}
close(resultChannel)
}(request)
return resultChannel
}
func imageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
threads := backendConfig.Threads
if *threads == 0 && appConfig.Threads != 0 {
threads = &appConfig.Threads
}
gRPCOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(backendConfig.Backend),
model.WithAssetDir(appConfig.AssetsDestination),
@@ -50,3 +284,24 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
return fn, nil
}
// TODO: Replace this function with pkg/downloader - no reason to have a (crappier) bespoke download file fn here, but get things working before that change.
func downloadFile(url string) (string, error) {
// Get the data
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Create the file
out, err := os.CreateTemp("", "image")
if err != nil {
return "", err
}
defer out.Close()
// Write the body to file
_, err = io.Copy(out, resp.Body)
return out.Name(), err
}

View File

@@ -11,17 +11,22 @@ import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/rs/zerolog/log"
"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
type LLMRequest struct {
Id int // TODO Remove if not used.
Text string
Images []string
RawMessages []schema.Message
// TODO: Other Modalities?
}
type TokenUsage struct {
@@ -29,57 +34,94 @@ type TokenUsage struct {
Completion int
}
func ModelInference(ctx context.Context, s string, messages []schema.Message, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
threads := c.Threads
if *threads == 0 && o.Threads != 0 {
threads = &o.Threads
type LLMResponse struct {
Request *LLMRequest
Response string // should this be []byte?
Usage TokenUsage
}
// TODO: Does this belong here or in core/services/openai.go?
type LLMResponseBundle struct {
Request *schema.OpenAIRequest
Response []schema.Choice
Usage TokenUsage
}
type LLMBackendService struct {
bcl *config.BackendConfigLoader
ml *model.ModelLoader
appConfig *config.ApplicationConfig
ftMutex sync.Mutex
cutstrings map[string]*regexp.Regexp
}
func NewLLMBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *LLMBackendService {
return &LLMBackendService{
bcl: bcl,
ml: ml,
appConfig: appConfig,
ftMutex: sync.Mutex{},
cutstrings: make(map[string]*regexp.Regexp),
}
grpcOpts := gRPCModelOpts(c)
}
// TODO: Should ctx param be removed and replaced with hardcoded req.Context?
func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest, bc *config.BackendConfig, enableTokenChannel bool) (
resultChannel <-chan concurrency.ErrorOr[*LLMResponse], tokenChannel <-chan concurrency.ErrorOr[*LLMResponse], err error) {
threads := bc.Threads
if (threads == nil || *threads == 0) && llmbs.appConfig.Threads != 0 {
threads = &llmbs.appConfig.Threads
}
grpcOpts := gRPCModelOpts(bc)
var inferenceModel grpc.Backend
var err error
opts := modelOpts(c, o, []model.Option{
opts := modelOpts(bc, llmbs.appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithAssetDir(llmbs.appConfig.AssetsDestination),
model.WithModel(bc.Model),
model.WithContext(llmbs.appConfig.Context),
})
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
if bc.Backend != "" {
opts = append(opts, model.WithBackendString(bc.Backend))
}
// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
// Check if bc.Model exists, if it doesn't try to load it from the gallery
if llmbs.appConfig.AutoloadGalleries { // experimental
if _, err := os.Stat(bc.Model); os.IsNotExist(err) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
err := gallery.InstallModelFromGalleryByName(llmbs.appConfig.Galleries, bc.Model, llmbs.appConfig.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
if err != nil {
return nil, err
return nil, nil, err
}
}
}
if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
if bc.Backend == "" {
log.Debug().Msgf("backend not known for %q, falling back to greedy loader to find it", bc.Model)
inferenceModel, err = llmbs.ml.GreedyLoader(opts...)
} else {
inferenceModel, err = loader.BackendLoader(opts...)
inferenceModel, err = llmbs.ml.BackendLoader(opts...)
}
if err != nil {
return nil, err
log.Error().Err(err).Msg("[llmbs.Inference] failed to load a backend")
return
}
var protoMessages []*proto.Message
// if we are using the tokenizer template, we need to convert the messages to proto messages
// unless the prompt has already been tokenized (non-chat endpoints + functions)
if c.TemplateConfig.UseTokenizerTemplate && s == "" {
protoMessages = make([]*proto.Message, len(messages), len(messages))
for i, message := range messages {
grpcPredOpts := gRPCPredictOpts(bc, llmbs.appConfig.ModelPath)
grpcPredOpts.Prompt = req.Text
grpcPredOpts.Images = req.Images
if bc.TemplateConfig.UseTokenizerTemplate && req.Text == "" {
grpcPredOpts.UseTokenizerTemplate = true
protoMessages := make([]*proto.Message, len(req.RawMessages), len(req.RawMessages))
for i, message := range req.RawMessages {
protoMessages[i] = &proto.Message{
Role: message.Role,
}
@@ -87,47 +129,32 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
case string:
protoMessages[i].Content = ct
default:
return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct)
err = fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
return
}
}
}
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
opts.Messages = protoMessages
opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate
opts.Images = images
tokenUsage := TokenUsage{}
tokenUsage := TokenUsage{}
promptInfo, pErr := inferenceModel.TokenizeString(ctx, grpcPredOpts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}
// check the per-model feature flag for usage, since tokenCallback may have a cost.
// Defaults to off as for now it is still experimental
if c.FeatureFlag.Enabled("usage") {
userTokenCallback := tokenCallback
if userTokenCallback == nil {
userTokenCallback = func(token string, usage TokenUsage) bool {
return true
}
}
rawResultChannel := make(chan concurrency.ErrorOr[*LLMResponse])
// TODO this next line is the biggest argument for taking named return values _back_ out!!!
var rawTokenChannel chan concurrency.ErrorOr[*LLMResponse]
promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}
if enableTokenChannel {
rawTokenChannel = make(chan concurrency.ErrorOr[*LLMResponse])
tokenCallback = func(token string, usage TokenUsage) bool {
tokenUsage.Completion++
return userTokenCallback(token, tokenUsage)
}
}
if tokenCallback != nil {
ss := ""
// TODO Needs better name
ss := ""
go func() {
var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
err := inferenceModel.PredictStream(ctx, grpcPredOpts, func(chars []byte) {
partialRune = append(partialRune, chars...)
for len(partialRune) > 0 {
@@ -137,54 +164,126 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
break
}
tokenCallback(string(r), tokenUsage)
tokenUsage.Completion++
rawTokenChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{
Response: string(r),
Usage: tokenUsage,
}}
ss += string(r)
partialRune = partialRune[size:]
}
})
return LLMResponse{
Response: ss,
Usage: tokenUsage,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
reply, err := inferenceModel.Predict(ctx, opts)
close(rawTokenChannel)
if err != nil {
return LLMResponse{}, err
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err}
} else {
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{
Response: ss,
Usage: tokenUsage,
}}
}
close(rawResultChannel)
}()
} else {
go func() {
reply, err := inferenceModel.Predict(ctx, grpcPredOpts)
if tokenUsage.Prompt == 0 {
tokenUsage.Prompt = int(reply.PromptTokens)
}
if tokenUsage.Completion == 0 {
tokenUsage.Completion = int(reply.Tokens)
}
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}, err
}
if err != nil {
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err}
close(rawResultChannel)
} else {
rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}}
close(rawResultChannel)
}
}()
}
return fn, nil
resultChannel = rawResultChannel
tokenChannel = rawTokenChannel
return
}
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
// TODO: Should predInput be a separate param still, or should this fn handle extracting it from the request?
func (llmbs *LLMBackendService) GenerateText(predInput string, request *schema.OpenAIRequest, bc *config.BackendConfig,
mappingFn func(*LLMResponse) schema.Choice, enableCompletionChannels bool, enableTokenChannels bool) (
// Returns:
resultChannel <-chan concurrency.ErrorOr[*LLMResponseBundle], completionChannels []<-chan concurrency.ErrorOr[*LLMResponse], tokenChannels []<-chan concurrency.ErrorOr[*LLMResponse], err error) {
func Finetune(config config.BackendConfig, input, prediction string) string {
rawChannel := make(chan concurrency.ErrorOr[*LLMResponseBundle])
resultChannel = rawChannel
if request.N == 0 { // number of completions to return
request.N = 1
}
images := []string{}
for _, m := range request.Messages {
images = append(images, m.StringImages...)
}
for i := 0; i < request.N; i++ {
individualResultChannel, tokenChannel, infErr := llmbs.Inference(request.Context, &LLMRequest{
Text: predInput,
Images: images,
RawMessages: request.Messages,
}, bc, enableTokenChannels)
if infErr != nil {
err = infErr // Avoids complaints about redeclaring err but looks dumb
return
}
completionChannels = append(completionChannels, individualResultChannel)
tokenChannels = append(tokenChannels, tokenChannel)
}
go func() {
initialBundle := LLMResponseBundle{
Request: request,
Response: []schema.Choice{},
Usage: TokenUsage{},
}
wg := concurrency.SliceOfChannelsReducer(completionChannels, rawChannel, func(iv concurrency.ErrorOr[*LLMResponse], ov concurrency.ErrorOr[*LLMResponseBundle]) concurrency.ErrorOr[*LLMResponseBundle] {
if iv.Error != nil {
ov.Error = iv.Error
// TODO: Decide if we should wipe partials or not?
return ov
}
ov.Value.Usage.Prompt += iv.Value.Usage.Prompt
ov.Value.Usage.Completion += iv.Value.Usage.Completion
ov.Value.Response = append(ov.Value.Response, mappingFn(iv.Value))
return ov
}, concurrency.ErrorOr[*LLMResponseBundle]{Value: &initialBundle}, true)
wg.Wait()
}()
return
}
func (llmbs *LLMBackendService) Finetune(config config.BackendConfig, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}
for _, c := range config.Cutstrings {
mu.Lock()
reg, ok := cutstrings[c]
llmbs.ftMutex.Lock()
reg, ok := llmbs.cutstrings[c]
if !ok {
cutstrings[c] = regexp.MustCompile(c)
reg = cutstrings[c]
llmbs.cutstrings[c] = regexp.MustCompile(c)
reg = llmbs.cutstrings[c]
}
mu.Unlock()
llmbs.ftMutex.Unlock()
prediction = reg.ReplaceAllString(prediction, "")
}

View File

@@ -7,11 +7,10 @@ import (
"github.com/go-skynet/LocalAI/core/config"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
func modelOpts(bc *config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
if so.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend())
}
@@ -20,12 +19,12 @@ func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []mode
opts = append(opts, model.EnableParallelRequests)
}
if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
if bc.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(bc.GRPC.Attempts))
}
if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
if bc.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(bc.GRPC.AttemptsSleepTime))
}
for k, v := range so.ExternalGRPCBackends {
@@ -35,7 +34,7 @@ func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []mode
return opts
}
func getSeed(c config.BackendConfig) int32 {
func getSeed(c *config.BackendConfig) int32 {
seed := int32(*c.Seed)
if seed == config.RAND_SEED {
seed = rand.Int31()
@@ -44,7 +43,7 @@ func getSeed(c config.BackendConfig) int32 {
return seed
}
func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
func gRPCModelOpts(c *config.BackendConfig) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
@@ -75,7 +74,6 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
EnforceEager: c.EnforceEager,
SwapSpace: int32(c.SwapSpace),
MaxModelLen: int32(c.MaxModelLen),
TensorParallelSize: int32(c.TensorParallelSize),
MMProj: c.MMProj,
YarnExtFactor: c.YarnExtFactor,
YarnAttnFactor: c.YarnAttnFactor,
@@ -106,51 +104,47 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
}
}
func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions {
func gRPCPredictOpts(bc *config.BackendConfig, modelPath string) *pb.PredictOptions {
promptCachePath := ""
if c.PromptCachePath != "" {
p := filepath.Join(modelPath, c.PromptCachePath)
err := os.MkdirAll(filepath.Dir(p), 0750)
if err == nil {
promptCachePath = p
} else {
log.Error().Err(err).Str("promptCachePath", promptCachePath).Msg("error creating prompt cache folder")
}
if bc.PromptCachePath != "" {
p := filepath.Join(modelPath, bc.PromptCachePath)
os.MkdirAll(filepath.Dir(p), 0755)
promptCachePath = p
}
return &pb.PredictOptions{
Temperature: float32(*c.Temperature),
TopP: float32(*c.TopP),
NDraft: c.NDraft,
TopK: int32(*c.TopK),
Tokens: int32(*c.Maxtokens),
Threads: int32(*c.Threads),
PromptCacheAll: c.PromptCacheAll,
PromptCacheRO: c.PromptCacheRO,
Temperature: float32(*bc.Temperature),
TopP: float32(*bc.TopP),
NDraft: bc.NDraft,
TopK: int32(*bc.TopK),
Tokens: int32(*bc.Maxtokens),
Threads: int32(*bc.Threads),
PromptCacheAll: bc.PromptCacheAll,
PromptCacheRO: bc.PromptCacheRO,
PromptCachePath: promptCachePath,
F16KV: *c.F16,
DebugMode: *c.Debug,
Grammar: c.Grammar,
NegativePromptScale: c.NegativePromptScale,
RopeFreqBase: c.RopeFreqBase,
RopeFreqScale: c.RopeFreqScale,
NegativePrompt: c.NegativePrompt,
Mirostat: int32(*c.LLMConfig.Mirostat),
MirostatETA: float32(*c.LLMConfig.MirostatETA),
MirostatTAU: float32(*c.LLMConfig.MirostatTAU),
Debug: *c.Debug,
StopPrompts: c.StopWords,
Repeat: int32(c.RepeatPenalty),
NKeep: int32(c.Keep),
Batch: int32(c.Batch),
IgnoreEOS: c.IgnoreEOS,
Seed: getSeed(c),
FrequencyPenalty: float32(c.FrequencyPenalty),
MLock: *c.MMlock,
MMap: *c.MMap,
MainGPU: c.MainGPU,
TensorSplit: c.TensorSplit,
TailFreeSamplingZ: float32(*c.TFZ),
TypicalP: float32(*c.TypicalP),
F16KV: *bc.F16,
DebugMode: *bc.Debug,
Grammar: bc.Grammar,
NegativePromptScale: bc.NegativePromptScale,
RopeFreqBase: bc.RopeFreqBase,
RopeFreqScale: bc.RopeFreqScale,
NegativePrompt: bc.NegativePrompt,
Mirostat: int32(*bc.LLMConfig.Mirostat),
MirostatETA: float32(*bc.LLMConfig.MirostatETA),
MirostatTAU: float32(*bc.LLMConfig.MirostatTAU),
Debug: *bc.Debug,
StopPrompts: bc.StopWords,
Repeat: int32(bc.RepeatPenalty),
NKeep: int32(bc.Keep),
Batch: int32(bc.Batch),
IgnoreEOS: bc.IgnoreEOS,
Seed: getSeed(bc),
FrequencyPenalty: float32(bc.FrequencyPenalty),
MLock: *bc.MMlock,
MMap: *bc.MMap,
MainGPU: bc.MainGPU,
TensorSplit: bc.TensorSplit,
TailFreeSamplingZ: float32(*bc.TFZ),
TypicalP: float32(*bc.TypicalP),
}
}

View File

@@ -1,39 +0,0 @@
package backend
import (
"context"
"fmt"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func Rerank(backend, modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
bb := backend
if bb == "" {
return nil, fmt.Errorf("backend is required")
}
grpcOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
rerankModel, err := loader.BackendLoader(opts...)
if err != nil {
return nil, err
}
if rerankModel == nil {
return nil, fmt.Errorf("could not load rerank model")
}
res, err := rerankModel.Rerank(context.Background(), request)
return res, err
}

View File

@@ -7,11 +7,48 @@ import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
)
func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
type TranscriptionBackendService struct {
ml *model.ModelLoader
bcl *config.BackendConfigLoader
appConfig *config.ApplicationConfig
}
func NewTranscriptionBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TranscriptionBackendService {
return &TranscriptionBackendService{
ml: ml,
bcl: bcl,
appConfig: appConfig,
}
}
func (tbs *TranscriptionBackendService) Transcribe(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.TranscriptionResult] {
responseChannel := make(chan concurrency.ErrorOr[*schema.TranscriptionResult])
go func(request *schema.OpenAIRequest) {
bc, request, err := tbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, tbs.appConfig)
if err != nil {
responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Error: fmt.Errorf("failed reading parameters from request:%w", err)}
close(responseChannel)
return
}
tr, err := modelTranscription(request.File, request.Language, tbs.ml, bc, tbs.appConfig)
if err != nil {
responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Error: err}
close(responseChannel)
return
}
responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Value: tr}
close(responseChannel)
}(request)
return responseChannel
}
func modelTranscription(audio, language string, ml *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend),

View File

@@ -7,29 +7,60 @@ import (
"path/filepath"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
type TextToSpeechBackendService struct {
ml *model.ModelLoader
bcl *config.BackendConfigLoader
appConfig *config.ApplicationConfig
}
for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}
counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
func NewTextToSpeechBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TextToSpeechBackendService {
return &TextToSpeechBackendService{
ml: ml,
bcl: bcl,
appConfig: appConfig,
}
}
func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
func (ttsbs *TextToSpeechBackendService) TextToAudioFile(request *schema.TTSRequest) <-chan concurrency.ErrorOr[*string] {
responseChannel := make(chan concurrency.ErrorOr[*string])
go func(request *schema.TTSRequest) {
cfg, err := ttsbs.bcl.LoadBackendConfigFileByName(request.Model, ttsbs.appConfig.ModelPath,
config.LoadOptionDebug(ttsbs.appConfig.Debug),
config.LoadOptionThreads(ttsbs.appConfig.Threads),
config.LoadOptionContextSize(ttsbs.appConfig.ContextSize),
config.LoadOptionF16(ttsbs.appConfig.F16),
)
if err != nil {
responseChannel <- concurrency.ErrorOr[*string]{Error: err}
close(responseChannel)
return
}
if request.Backend != "" {
cfg.Backend = request.Backend
}
outFile, _, err := modelTTS(cfg.Backend, request.Input, cfg.Model, request.Voice, ttsbs.ml, ttsbs.appConfig, cfg)
if err != nil {
responseChannel <- concurrency.ErrorOr[*string]{Error: err}
close(responseChannel)
return
}
responseChannel <- concurrency.ErrorOr[*string]{Value: &outFile}
close(responseChannel)
}(request)
return responseChannel
}
func modelTTS(backend, text, modelFile string, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig *config.BackendConfig) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
@@ -37,7 +68,7 @@ func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader,
grpcOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
opts := modelOpts(&config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
@@ -53,7 +84,7 @@ func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader,
return "", nil, fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
if err := os.MkdirAll(appConfig.AudioDir, 0755); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}
@@ -87,3 +118,19 @@ func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader,
return filePath, res, err
}
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}
counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}

View File

@@ -4,7 +4,7 @@ import "embed"
type Context struct {
Debug bool `env:"LOCALAI_DEBUG,DEBUG" default:"false" hidden:"" help:"DEPRECATED, use --log-level=debug instead. Enable debug logging"`
LogLevel *string `env:"LOCALAI_LOG_LEVEL" enum:"error,warn,info,debug,trace" help:"Set the level of logs to output [${enum}]"`
LogLevel *string `env:"LOCALAI_LOG_LEVEL" enum:"error,warn,info,debug" help:"Set the level of logs to output [${enum}]"`
// This field is not a command line argument/flag, the struct tag excludes it from the parsed CLI
BackendAssets embed.FS `kong:"-"`

View File

@@ -25,7 +25,7 @@ type ModelsInstall struct {
}
type ModelsCMD struct {
List ModelsList `cmd:"" help:"List the models available in your galleries" default:"withargs"`
List ModelsList `cmd:"" help:"List the models avaiable in your galleries" default:"withargs"`
Install ModelsInstall `cmd:"" help:"Install a model from the gallery"`
}
@@ -64,11 +64,7 @@ func (mi *ModelsInstall) Run(ctx *Context) error {
progressbar.OptionClearOnFinish(),
)
progressCallback := func(fileName string, current string, total string, percentage float64) {
v := int(percentage * 10)
err := progressBar.Set(v)
if err != nil {
log.Error().Err(err).Str("filename", fileName).Int("value", v).Msg("error while updating progress bar")
}
progressBar.Set(int(percentage * 10))
}
err := gallery.InstallModelFromGallery(galleries, modelName, mi.ModelsPath, gallery.GalleryModel{}, progressCallback)
if err != nil {

View File

@@ -2,31 +2,30 @@ package cli
import (
"fmt"
"os"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/startup"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
type RunCMD struct {
ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
ImagePath string `env:"LOCALAI_IMAGE_PATH,IMAGE_PATH" type:"path" default:"/tmp/generated/images" help:"Location for images generated by backends (e.g. stablediffusion)" group:"storage"`
AudioPath string `env:"LOCALAI_AUDIO_PATH,AUDIO_PATH" type:"path" default:"/tmp/generated/audio" help:"Location for audio generated by backends (e.g. piper)" group:"storage"`
UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"`
ConfigPath string `env:"LOCALAI_CONFIG_PATH,CONFIG_PATH" default:"/tmp/localai/config" group:"storage"`
LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
ImagePath string `env:"LOCALAI_IMAGE_PATH,IMAGE_PATH" type:"path" default:"/tmp/generated/images" help:"Location for images generated by backends (e.g. stablediffusion)" group:"storage"`
AudioPath string `env:"LOCALAI_AUDIO_PATH,AUDIO_PATH" type:"path" default:"/tmp/generated/audio" help:"Location for audio generated by backends (e.g. piper)" group:"storage"`
UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"`
ConfigPath string `env:"LOCALAI_CONFIG_PATH,CONFIG_PATH" default:"/tmp/localai/config" group:"storage"`
LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
// The alias on this option is there to preserve functionality with the old `--config-file` parameter
ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"`
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"`
Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models"`
AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models"`
RemoteLibrary string `env:"LOCALAI_REMOTE_LIBRARY,REMOTE_LIBRARY" default:"${remoteLibraryURL}" help:"A LocalAI remote library URL" group:"models"`
PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" group:"models"`
@@ -42,7 +41,7 @@ type RunCMD struct {
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"`
DisableWelcome bool `env:"LOCALAI_DISABLE_WELCOME,DISABLE_WELCOME" default:"false" help:"Disable welcome pages" group:"api"`
ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"`
@@ -61,16 +60,15 @@ func (r *RunCMD) Run(ctx *Context) error {
config.WithYAMLConfigPreload(r.PreloadModelsConfig),
config.WithModelPath(r.ModelsPath),
config.WithContextSize(r.ContextSize),
config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel),
config.WithDebug(*ctx.LogLevel == "debug"),
config.WithImageDir(r.ImagePath),
config.WithAudioDir(r.AudioPath),
config.WithUploadDir(r.UploadPath),
config.WithConfigsDir(r.ConfigPath),
config.WithDynamicConfigDir(r.LocalaiConfigDir),
config.WithDynamicConfigDirPollInterval(r.LocalaiConfigDirPollInterval),
config.WithF16(r.F16),
config.WithStringGalleries(r.Galleries),
config.WithModelLibraryURL(r.RemoteLibrary),
config.WithDisableMessage(false),
config.WithCors(r.CORS),
config.WithCorsAllowOrigins(r.CORSAllowOrigins),
config.WithThreads(r.Threads),
@@ -84,8 +82,8 @@ func (r *RunCMD) Run(ctx *Context) error {
idleWatchDog := r.EnableWatchdogIdle
busyWatchDog := r.EnableWatchdogBusy
if r.DisableWebUI {
opts = append(opts, config.DisableWebUI)
if r.DisableWelcome {
opts = append(opts, config.DisableWelcomePage)
}
if idleWatchDog || busyWatchDog {
@@ -126,16 +124,28 @@ func (r *RunCMD) Run(ctx *Context) error {
}
if r.PreloadBackendOnly {
_, _, _, err := startup.Startup(opts...)
_, err := startup.Startup(opts...)
return err
}
cl, ml, options, err := startup.Startup(opts...)
application, err := startup.Startup(opts...)
if err != nil {
return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}
appHTTP, err := http.App(cl, ml, options)
// Watch the configuration directory
// If the directory does not exist, we don't watch it
if _, err := os.Stat(r.LocalaiConfigDir); err == nil {
closeConfigWatcherFn, err := startup.WatchConfigDirectory(r.LocalaiConfigDir, application.ApplicationConfig)
defer closeConfigWatcherFn()
if err != nil {
return fmt.Errorf("failed while watching configuration directory %s", r.LocalaiConfigDir)
}
}
appHTTP, err := http.App(application)
if err != nil {
log.Error().Err(err).Msg("error during HTTP App construction")
return err
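
The new startup path only watches LocalaiConfigDir when the directory exists. WatchConfigDirectory itself is not shown in this diff; a generic directory watch with fsnotify (not the project's actual implementation) looks like this:

package main

import (
	"log"

	"github.com/fsnotify/fsnotify"
)

func main() {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		log.Fatal(err)
	}
	defer watcher.Close()

	// Illustrative path; the real directory comes from LOCALAI_CONFIG_DIR.
	if err := watcher.Add("/tmp/localai/configuration"); err != nil {
		log.Fatal(err)
	}

	for {
		select {
		case event, ok := <-watcher.Events:
			if !ok {
				return
			}
			if event.Op&fsnotify.Write == fsnotify.Write {
				log.Println("configuration file changed:", event.Name)
				// reload api_keys.json / external_backends.json here
			}
		case err, ok := <-watcher.Errors:
			if !ok {
				return
			}
			log.Println("watch error:", err)
		}
	}
}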

View File

@@ -7,8 +7,8 @@ import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
type TranscriptCMD struct {
@@ -42,18 +42,23 @@ func (t *TranscriptCMD) Run(ctx *Context) error {
c.Threads = &t.Threads
defer func() {
err := ml.StopAllGRPC()
if err != nil {
log.Error().Err(err).Msg("unable to stop all grpc processes")
}
}()
defer ml.StopAllGRPC()
tr, err := backend.ModelTranscription(t.Filename, t.Language, ml, c, opts)
if err != nil {
return err
tbs := backend.NewTranscriptionBackendService(ml, cl, opts)
resultChannel := tbs.Transcribe(&schema.OpenAIRequest{
PredictionOptions: schema.PredictionOptions{
Language: t.Language,
},
File: t.Filename,
})
r := <-resultChannel
if r.Error != nil {
return r.Error
}
for _, segment := range tr.Segments {
for _, segment := range r.Value.Segments {
fmt.Println(segment.Start.String(), "-", segment.Text)
}
return nil
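
The transcription CLI now goes through a backend service that returns its result over a channel instead of a direct function call. The general shape of that pattern, with illustrative types rather than the project's actual definitions:

package main

import (
	"errors"
	"fmt"
)

// Result pairs a value with an error so a service can report either
// outcome over a single channel.
type Result[T any] struct {
	Value T
	Error error
}

// transcribe pretends to run a transcription and publishes the outcome.
func transcribe(file string) <-chan Result[string] {
	out := make(chan Result[string], 1)
	go func() {
		defer close(out)
		if file == "" {
			out <- Result[string]{Error: errors.New("no input file")}
			return
		}
		out <- Result[string]{Value: "0s - hello world"}
	}()
	return out
}

func main() {
	r := <-transcribe("audio.wav")
	if r.Error != nil {
		fmt.Println("error:", r.Error)
		return
	}
	fmt.Println(r.Value)
}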

View File

@@ -9,8 +9,8 @@ import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
type TTSCMD struct {
@@ -41,27 +41,31 @@ func (t *TTSCMD) Run(ctx *Context) error {
}
ml := model.NewModelLoader(opts.ModelPath)
defer func() {
err := ml.StopAllGRPC()
if err != nil {
log.Error().Err(err).Msg("unable to stop all grpc processes")
}
}()
defer ml.StopAllGRPC()
options := config.BackendConfig{}
options.SetDefaults()
ttsbs := backend.NewTextToSpeechBackendService(ml, config.NewBackendConfigLoader(), opts)
filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, ml, opts, options)
if err != nil {
return err
request := &schema.TTSRequest{
Model: t.Model,
Input: text,
Backend: t.Backend,
Voice: t.Voice,
}
resultsChannel := ttsbs.TextToAudioFile(request)
rawResult := <-resultsChannel
if rawResult.Error != nil {
return rawResult.Error
}
if outputFile != "" {
if err := os.Rename(filePath, outputFile); err != nil {
if err := os.Rename(*rawResult.Value, outputFile); err != nil {
return err
}
fmt.Printf("Generate file %s\n", outputFile)
fmt.Printf("Generated file %q\n", outputFile)
} else {
fmt.Printf("Generate file %s\n", filePath)
fmt.Printf("Generated file %q\n", *rawResult.Value)
}
return nil
}

View File

@@ -15,15 +15,13 @@ type ApplicationConfig struct {
ConfigFile string
ModelPath string
UploadLimitMB, Threads, ContextSize int
DisableWebUI bool
DisableWelcomePage bool
F16 bool
Debug bool
Debug, DisableMessage bool
ImageDir string
AudioDir string
UploadDir string
ConfigsDir string
DynamicConfigsDir string
DynamicConfigsDirPollInterval time.Duration
CORS bool
PreloadJSONModels string
PreloadModelsFromPath string
@@ -57,11 +55,12 @@ type AppOption func(*ApplicationConfig)
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
opt := &ApplicationConfig{
Context: context.Background(),
UploadLimitMB: 15,
Threads: 1,
ContextSize: 512,
Debug: true,
Context: context.Background(),
UploadLimitMB: 15,
Threads: 1,
ContextSize: 512,
Debug: true,
DisableMessage: true,
}
for _, oo := range o {
oo(opt)
@@ -107,8 +106,8 @@ var EnableWatchDogBusyCheck = func(o *ApplicationConfig) {
o.WatchDogBusy = true
}
var DisableWebUI = func(o *ApplicationConfig) {
o.DisableWebUI = true
var DisableWelcomePage = func(o *ApplicationConfig) {
o.DisableWelcomePage = true
}
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
@@ -235,6 +234,12 @@ func WithDebug(debug bool) AppOption {
}
}
func WithDisableMessage(disableMessage bool) AppOption {
return func(o *ApplicationConfig) {
o.DisableMessage = disableMessage
}
}
func WithAudioDir(audioDir string) AppOption {
return func(o *ApplicationConfig) {
o.AudioDir = audioDir
@@ -259,18 +264,6 @@ func WithConfigsDir(configsDir string) AppOption {
}
}
func WithDynamicConfigDir(dynamicConfigsDir string) AppOption {
return func(o *ApplicationConfig) {
o.DynamicConfigsDir = dynamicConfigsDir
}
}
func WithDynamicConfigDirPollInterval(interval time.Duration) AppOption {
return func(o *ApplicationConfig) {
o.DynamicConfigsDirPollInterval = interval
}
}
func WithApiKeys(apiKeys []string) AppOption {
return func(o *ApplicationConfig) {
o.ApiKeys = apiKeys
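
ApplicationConfig is assembled through functional options: AppOption closures applied over a struct pre-filled with defaults. A condensed, standalone sketch of the pattern with illustrative fields:

package main

import "fmt"

type appConfig struct {
	uploadLimitMB int
	debug         bool
}

type appOption func(*appConfig)

func withUploadLimitMB(mb int) appOption {
	return func(c *appConfig) { c.uploadLimitMB = mb }
}

func withDebug(debug bool) appOption {
	return func(c *appConfig) { c.debug = debug }
}

// newAppConfig sets defaults first, then lets each option override them.
func newAppConfig(opts ...appOption) *appConfig {
	c := &appConfig{uploadLimitMB: 15}
	for _, o := range opts {
		o(c)
	}
	return c
}

func main() {
	cfg := newAppConfig(withUploadLimitMB(100), withDebug(true))
	fmt.Printf("%+v\n", cfg)
}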

View File

@@ -1,12 +1,7 @@
package config
import (
"os"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/functions"
"github.com/go-skynet/LocalAI/pkg/utils"
)
const (
@@ -29,7 +24,7 @@ type BackendConfig struct {
InputToken [][]int `yaml:"-"`
functionCallString, functionCallNameString string `yaml:"-"`
FunctionsConfig functions.FunctionsConfig `yaml:"function"`
FunctionsConfig Functions `yaml:"function"`
FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
@@ -129,7 +124,6 @@ type LLMConfig struct {
EnforceEager bool `yaml:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
MMProj string `yaml:"mmproj"`
RopeScaling string `yaml:"rope_scaling"`
@@ -148,6 +142,13 @@ type AutoGPTQ struct {
UseFastTokenizer bool `yaml:"use_fast_tokenizer"`
}
type Functions struct {
DisableNoAction bool `yaml:"disable_no_action"`
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"`
ParallelCalls bool `yaml:"parallel_calls"`
}
type TemplateConfig struct {
Chat string `yaml:"chat"`
ChatMessage string `yaml:"chat_message"`
@@ -173,36 +174,6 @@ func (c *BackendConfig) ShouldCallSpecificFunction() bool {
return len(c.functionCallNameString) > 0
}
// MMProjFileName returns the filename of the MMProj file
// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) MMProjFileName() string {
modelURL := downloader.ConvertURL(c.MMProj)
if downloader.LooksLikeURL(modelURL) {
return utils.MD5(modelURL)
}
return c.MMProj
}
func (c *BackendConfig) IsMMProjURL() bool {
return downloader.LooksLikeURL(downloader.ConvertURL(c.MMProj))
}
func (c *BackendConfig) IsModelURL() bool {
return downloader.LooksLikeURL(downloader.ConvertURL(c.Model))
}
// ModelFileName returns the filename of the model
// If the model is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) ModelFileName() string {
modelURL := downloader.ConvertURL(c.Model)
if downloader.LooksLikeURL(modelURL) {
return utils.MD5(modelURL)
}
return c.Model
}
func (c *BackendConfig) FunctionToCall() string {
if c.functionCallNameString != "" &&
c.functionCallNameString != "none" && c.functionCallNameString != "auto" {
@@ -213,7 +184,7 @@ func (c *BackendConfig) FunctionToCall() string {
}
func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
lo := &LoadOptions{}
lo := &ConfigLoaderOptions{}
lo.Apply(opts...)
ctx := lo.ctxSize
@@ -224,15 +195,15 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
defaultTopP := 0.95
defaultTopK := 40
defaultTemp := 0.9
defaultMaxTokens := 2048
defaultMirostat := 2
defaultMirostatTAU := 5.0
defaultMirostatETA := 0.1
defaultTypicalP := 1.0
defaultTFZ := 1.0
defaultZero := 0
// Try to offload all GPU layers (if GPU is found)
defaultHigh := 99999999
defaultNGPULayers := 99999999
trueV := true
falseV := false
@@ -257,13 +228,7 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
if cfg.MMap == nil {
// MMap is enabled by default
// Only exception is for Intel GPUs
if os.Getenv("XPU") != "" {
cfg.MMap = &falseV
} else {
cfg.MMap = &trueV
}
cfg.MMap = &trueV
}
if cfg.MMlock == nil {
@@ -279,7 +244,7 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
}
if cfg.Maxtokens == nil {
cfg.Maxtokens = &defaultZero
cfg.Maxtokens = &defaultMaxTokens
}
if cfg.Mirostat == nil {
@@ -294,7 +259,7 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
cfg.MirostatTAU = &defaultMirostatTAU
}
if cfg.NGPULayers == nil {
cfg.NGPULayers = &defaultHigh
cfg.NGPULayers = &defaultNGPULayers
}
if cfg.LowVRAM == nil {
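
SetDefaults can tell "not set" apart from an explicit zero or false because the YAML fields are pointers: only nil pointers receive a default. A minimal illustration of that idea, using a small subset of fields chosen for the example:

package main

import (
	"fmt"

	"gopkg.in/yaml.v2"
)

type cfg struct {
	TopK *int  `yaml:"top_k"`
	MMap *bool `yaml:"mmap"`
}

func (c *cfg) setDefaults() {
	defaultTopK := 40
	trueV := true
	if c.TopK == nil {
		c.TopK = &defaultTopK
	}
	if c.MMap == nil {
		c.MMap = &trueV
	}
}

func main() {
	var c cfg
	// mmap is explicitly disabled, top_k is left unset.
	if err := yaml.Unmarshal([]byte("mmap: false\n"), &c); err != nil {
		panic(err)
	}
	c.setDefaults()
	fmt.Println(*c.TopK, *c.MMap) // prints "40 false": the explicit false survives
}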

View File

@@ -1,6 +1,7 @@
package config
import (
"encoding/json"
"errors"
"fmt"
"io/fs"
@@ -13,9 +14,10 @@ import (
"github.com/charmbracelet/glamour"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/grammar"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
"gopkg.in/yaml.v2"
)
type BackendConfigLoader struct {
@@ -23,179 +25,96 @@ type BackendConfigLoader struct {
sync.Mutex
}
type LoadOptions struct {
type ConfigLoaderOptions struct {
debug bool
threads, ctxSize int
f16 bool
}
func LoadOptionDebug(debug bool) ConfigLoaderOption {
return func(o *LoadOptions) {
return func(o *ConfigLoaderOptions) {
o.debug = debug
}
}
func LoadOptionThreads(threads int) ConfigLoaderOption {
return func(o *LoadOptions) {
return func(o *ConfigLoaderOptions) {
o.threads = threads
}
}
func LoadOptionContextSize(ctxSize int) ConfigLoaderOption {
return func(o *LoadOptions) {
return func(o *ConfigLoaderOptions) {
o.ctxSize = ctxSize
}
}
func LoadOptionF16(f16 bool) ConfigLoaderOption {
return func(o *LoadOptions) {
return func(o *ConfigLoaderOptions) {
o.f16 = f16
}
}
type ConfigLoaderOption func(*LoadOptions)
type ConfigLoaderOption func(*ConfigLoaderOptions)
func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
func (lo *ConfigLoaderOptions) Apply(options ...ConfigLoaderOption) {
for _, l := range options {
l(lo)
}
}
// Load a config file for a model
func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
Model: modelName,
},
}
cfgExisting, exists := cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
// Try loading a model config file
modelConfig := filepath.Join(modelPath, modelName+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := cl.LoadBackendConfig(
modelConfig, opts...,
); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
}
}
}
cfg.SetDefaults(opts...)
return cfg, nil
}
func NewBackendConfigLoader() *BackendConfigLoader {
return &BackendConfigLoader{
configs: make(map[string]BackendConfig),
}
}
func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
for _, cc := range *c {
cc.SetDefaults(opts...)
}
return *c, nil
}
func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
lo := &LoadOptions{}
lo.Apply(opts...)
c := &BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
c.SetDefaults(opts...)
return c, nil
}
func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadBackendConfigFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
cm.configs[cc.Name] = *cc
}
return nil
}
func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
cl.Lock()
defer cl.Unlock()
c, err := ReadBackendConfig(file, opts...)
func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
c, err := readBackendConfig(file, opts...)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cl.configs[c.Name] = *c
bcl.configs[c.Name] = *c
return nil
}
func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
cl.Lock()
defer cl.Unlock()
v, exists := cl.configs[m]
func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
bcl.Lock()
defer bcl.Unlock()
v, exists := bcl.configs[m]
return v, exists
}
func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
cl.Lock()
defer cl.Unlock()
func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
bcl.Lock()
defer bcl.Unlock()
var res []BackendConfig
for _, v := range cl.configs {
for _, v := range bcl.configs {
res = append(res, v)
}
sort.SliceStable(res, func(i, j int) bool {
return res[i].Name < res[j].Name
})
return res
}
func (cl *BackendConfigLoader) ListBackendConfigs() []string {
cl.Lock()
defer cl.Unlock()
func (bcl *BackendConfigLoader) ListBackendConfigs() []string {
bcl.Lock()
defer bcl.Unlock()
var res []string
for k := range cl.configs {
for k := range bcl.configs {
res = append(res, k)
}
return res
}
// Preload prepare models if they are not local but url or huggingface repositories
func (cl *BackendConfigLoader) Preload(modelPath string) error {
cl.Lock()
defer cl.Unlock()
func (bcl *BackendConfigLoader) Preload(modelPath string) error {
bcl.Lock()
defer bcl.Unlock()
status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
@@ -217,10 +136,10 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
}
}
for i, config := range cl.configs {
for i, config := range bcl.configs {
// Download files and verify their SHA
for i, file := range config.DownloadFiles {
for _, file := range config.DownloadFiles {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
@@ -229,66 +148,49 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
// Create file path
filePath := filepath.Join(modelPath, file.Filename)
if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil {
if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
return err
}
}
// If the model is an URL, expand it, and download the file
if config.IsModelURL() {
modelFileName := config.ModelFileName()
modelURL := downloader.ConvertURL(config.Model)
modelURL := config.PredictionOptions.Model
modelURL = downloader.ConvertURL(modelURL)
if downloader.LooksLikeURL(modelURL) {
// md5 of model name
md5Name := utils.MD5(modelURL)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
if err != nil {
return err
}
}
cc := cl.configs[i]
cc := bcl.configs[i]
c := &cc
c.PredictionOptions.Model = modelFileName
cl.configs[i] = *c
c.PredictionOptions.Model = md5Name
bcl.configs[i] = *c
}
if config.IsMMProjURL() {
modelFileName := config.MMProjFileName()
modelURL := downloader.ConvertURL(config.MMProj)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil {
return err
}
}
cc := cl.configs[i]
c := &cc
c.MMProj = modelFileName
cl.configs[i] = *c
if bcl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", bcl.configs[i].Name))
}
if cl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
}
if cl.configs[i].Description != "" {
if bcl.configs[i].Description != "" {
//glamText("**Description**")
glamText(cl.configs[i].Description)
glamText(bcl.configs[i].Description)
}
if cl.configs[i].Usage != "" {
if bcl.configs[i].Usage != "" {
//glamText("**Usage**")
glamText(cl.configs[i].Usage)
glamText(bcl.configs[i].Usage)
}
}
return nil
}
// LoadBackendConfigsFromPath reads all the configurations of the models from a path
// (non-recursive)
func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return err
@@ -303,15 +205,305 @@ func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...C
}
for _, file := range files {
// Skip templates, YAML and .keep files
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") ||
strings.HasPrefix(file.Name(), ".") {
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
continue
}
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)
c, err := readBackendConfig(filepath.Join(path, file.Name()), opts...)
if err == nil {
cm.configs[c.Name] = *c
bcl.configs[c.Name] = *c
}
}
return nil
}
func (bcl *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
c, err := readBackendConfigFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
bcl.configs[cc.Name] = *cc
}
return nil
}
//////////
// Load a config file for a model
func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName string, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
Model: modelName,
},
}
cfgExisting, exists := bcl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
// Load a config file if present after the model name
modelConfig := filepath.Join(modelPath, modelName+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := bcl.LoadBackendConfig(modelConfig); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = bcl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
}
}
}
cfg.SetDefaults(opts...)
return cfg, nil
}
func readBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
for _, cc := range *c {
cc.SetDefaults(opts...)
}
return *c, nil
}
func readBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
c := &BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
c.SetDefaults(opts...)
return c, nil
}
func (bcl *BackendConfigLoader) LoadBackendConfigForModelAndOpenAIRequest(modelFile string, input *schema.OpenAIRequest, appConfig *ApplicationConfig) (*BackendConfig, *schema.OpenAIRequest, error) {
cfg, err := bcl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
LoadOptionContextSize(appConfig.ContextSize),
LoadOptionDebug(appConfig.Debug),
LoadOptionF16(appConfig.F16),
LoadOptionThreads(appConfig.Threads),
)
// Set the parameters for the language model prediction
updateBackendConfigFromOpenAIRequest(cfg, input)
return cfg, input, err
}
func updateBackendConfigFromOpenAIRequest(bc *BackendConfig, request *schema.OpenAIRequest) {
if request.Echo {
bc.Echo = request.Echo
}
if request.TopK != nil && *request.TopK != 0 {
bc.TopK = request.TopK
}
if request.TopP != nil && *request.TopP != 0 {
bc.TopP = request.TopP
}
if request.Backend != "" {
bc.Backend = request.Backend
}
if request.ClipSkip != 0 {
bc.Diffusers.ClipSkip = request.ClipSkip
}
if request.ModelBaseName != "" {
bc.AutoGPTQ.ModelBaseName = request.ModelBaseName
}
if request.NegativePromptScale != 0 {
bc.NegativePromptScale = request.NegativePromptScale
}
if request.UseFastTokenizer {
bc.UseFastTokenizer = request.UseFastTokenizer
}
if request.NegativePrompt != "" {
bc.NegativePrompt = request.NegativePrompt
}
if request.RopeFreqBase != 0 {
bc.RopeFreqBase = request.RopeFreqBase
}
if request.RopeFreqScale != 0 {
bc.RopeFreqScale = request.RopeFreqScale
}
if request.Grammar != "" {
bc.Grammar = request.Grammar
}
if request.Temperature != nil && *request.Temperature != 0 {
bc.Temperature = request.Temperature
}
if request.Maxtokens != nil && *request.Maxtokens != 0 {
bc.Maxtokens = request.Maxtokens
}
switch stop := request.Stop.(type) {
case string:
if stop != "" {
bc.StopWords = append(bc.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
bc.StopWords = append(bc.StopWords, s)
}
}
}
if len(request.Tools) > 0 {
for _, tool := range request.Tools {
request.Functions = append(request.Functions, tool.Function)
}
}
if request.ToolsChoice != nil {
var toolChoice grammar.Tool
switch content := request.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
request.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
index := 0
for i, m := range request.Messages {
switch content := m.Content.(type) {
case string:
request.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
request.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL)
if err == nil {
request.Messages[i].StringImages = append(request.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
request.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + request.Messages[i].StringContent
index++
} else {
fmt.Print("Failed encoding image", err)
}
}
}
}
}
if request.RepeatPenalty != 0 {
bc.RepeatPenalty = request.RepeatPenalty
}
if request.FrequencyPenalty != 0 {
bc.FrequencyPenalty = request.FrequencyPenalty
}
if request.PresencePenalty != 0 {
bc.PresencePenalty = request.PresencePenalty
}
if request.Keep != 0 {
bc.Keep = request.Keep
}
if request.Batch != 0 {
bc.Batch = request.Batch
}
if request.IgnoreEOS {
bc.IgnoreEOS = request.IgnoreEOS
}
if request.Seed != nil {
bc.Seed = request.Seed
}
if request.TypicalP != nil {
bc.TypicalP = request.TypicalP
}
switch inputs := request.Input.(type) {
case string:
if inputs != "" {
bc.InputStrings = append(bc.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
bc.InputStrings = append(bc.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
bc.InputToken = append(bc.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := request.FunctionCall.(type) {
case string:
if fnc != "" {
bc.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
bc.SetFunctionCallNameString(name)
}
switch p := request.Prompt.(type) {
case string:
bc.PromptStrings = append(bc.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
bc.PromptStrings = append(bc.PromptStrings, s)
}
}
}
}
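
updateBackendConfigFromOpenAIRequest repeatedly normalizes request fields that may be either a single string or a list (stop words, prompts, inputs). The same switch-on-type idea in isolation, using only the standard library:

package main

import (
	"encoding/json"
	"fmt"
)

// stopWords normalizes a "stop" field that may arrive either as a
// single string or as an array of strings in the JSON request body.
func stopWords(raw interface{}) []string {
	var out []string
	switch stop := raw.(type) {
	case string:
		if stop != "" {
			out = append(out, stop)
		}
	case []interface{}:
		for _, v := range stop {
			if s, ok := v.(string); ok {
				out = append(out, s)
			}
		}
	}
	return out
}

func main() {
	var a, b interface{}
	_ = json.Unmarshal([]byte(`"###"`), &a)
	_ = json.Unmarshal([]byte(`["</s>", "User:"]`), &b)
	fmt.Println(stopWords(a), stopWords(b))
}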

View File

@@ -0,0 +1,6 @@
package config
// This file re-exports private functions to be used directly in unit tests.
// Since this file's name ends in _test.go, theoretically these should not be exposed past the tests.
var ReadBackendConfigFile = readBackendConfigFile
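
Because this file's name ends in _test.go, the alias is compiled only during go test, so readBackendConfigFile stays unexported in normal builds while remaining reachable from an external config_test package. A sketch of how such a test could consume it (the test file and fixture path are hypothetical):

// config/config_test.go (hypothetical)
package config_test

import (
	"testing"

	"github.com/go-skynet/LocalAI/core/config"
)

func TestReadBackendConfigFile(t *testing.T) {
	// "testdata/models.yaml" is an assumed fixture; the exported alias
	// resolves to the unexported readBackendConfigFile at test time.
	cfgs, err := config.ReadBackendConfigFile("testdata/models.yaml")
	if err != nil {
		t.Fatalf("failed to read config file: %v", err)
	}
	if len(cfgs) == 0 {
		t.Fatal("expected at least one backend config")
	}
}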

278
core/http/api.go Normal file
View File

@@ -0,0 +1,278 @@
package http
import (
"errors"
"strings"
"github.com/go-skynet/LocalAI/core"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/gofiber/swagger" // swagger handler
"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
)
func readAuthHeader(c *fiber.Ctx) string {
authHeader := c.Get("Authorization")
// elevenlabs
xApiKey := c.Get("xi-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
// anthropic
xApiKey = c.Get("x-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
return authHeader
}
// @title LocalAI API
// @version 2.0.0
// @description The LocalAI Rest API.
// @termsOfService
// @contact.name LocalAI
// @contact.url https://localai.io
// @license.name MIT
// @license.url https://raw.githubusercontent.com/mudler/LocalAI/master/LICENSE
// @BasePath /
// @securityDefinitions.apikey BearerAuth
// @in header
// @name Authorization
func App(application *core.Application) (*fiber.App, error) {
// Return errors as JSON responses
app := fiber.New(fiber.Config{
Views: renderEngine(),
BodyLimit: application.ApplicationConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: application.ApplicationConfig.DisableMessage,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
},
})
if application.ApplicationConfig.Debug {
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
}
// Default middleware config
if !application.ApplicationConfig.Debug {
app.Use(recover.New())
}
metricsService, err := services.NewLocalAIMetricsService()
if err != nil {
return nil, err
}
if metricsService != nil {
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
app.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
})
}
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(application.ApplicationConfig.ApiKeys) == 0 {
return c.Next()
}
authHeader := readAuthHeader(c)
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
// If it's a bearer token
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
for _, key := range application.ApplicationConfig.ApiKeys {
if apiKey == key {
return c.Next()
}
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if application.ApplicationConfig.CORS {
var c func(ctx *fiber.Ctx) error
if application.ApplicationConfig.CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig.CORSAllowOrigins})
}
app.Use(c)
}
fiberContextExtractor := fiberContext.NewFiberContextExtractor(application.ModelLoader, application.ApplicationConfig)
// LocalAI API endpoints
galleryService := services.NewGalleryService(application.ApplicationConfig.ModelPath)
galleryService.Start(application.ApplicationConfig.Context, application.BackendConfigLoader)
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
app.Get("/swagger/*", swagger.HandlerDefault) // default
welcomeRoute(
app,
application.BackendConfigLoader,
application.ModelLoader,
application.ApplicationConfig,
auth,
)
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(application.ApplicationConfig.Galleries, application.ApplicationConfig.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
// Stores
storeLoader := model.NewModelLoader("") // TODO: Investigate if this should be migrated to application and reused. Should the path be configurable? Merging for now.
app.Post("/stores/set", auth, localai.StoresSetEndpoint(storeLoader, application.ApplicationConfig))
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(storeLoader, application.ApplicationConfig))
app.Post("/stores/get", auth, localai.StoresGetEndpoint(storeLoader, application.ApplicationConfig))
app.Post("/stores/find", auth, localai.StoresFindEndpoint(storeLoader, application.ApplicationConfig))
// openAI compatible API endpoints
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(fiberContextExtractor, application.OpenAIService))
app.Post("/chat/completions", auth, openai.ChatEndpoint(fiberContextExtractor, application.OpenAIService))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(fiberContextExtractor, application.OpenAIService))
app.Post("/edits", auth, openai.EditEndpoint(fiberContextExtractor, application.OpenAIService))
// assistant
// TODO: Refactor this to the new style eventually
app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
// files
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Post("/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService))
app.Post("/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(fiberContextExtractor, application.TranscriptionBackendService))
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(fiberContextExtractor, application.ImageGenerationBackendService))
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService))
// LocalAI TTS?
app.Post("/tts", auth, localai.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService))
if application.ApplicationConfig.ImageDir != "" {
app.Static("/generated-images", application.ApplicationConfig.ImageDir)
}
if application.ApplicationConfig.AudioDir != "" {
app.Static("/generated-audio", application.ApplicationConfig.AudioDir)
}
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
// Kubernetes health checks
app.Get("/healthz", ok)
app.Get("/readyz", ok)
// Experimental Backend Statistics Module
app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(application.BackendMonitorService))
app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(application.BackendMonitorService))
// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(application.ListModelsService))
app.Get("/models", auth, openai.ListModelsEndpoint(application.ListModelsService))
app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint())
// Define a custom 404 handler
// Note: keep this at the bottom!
app.Use(notFoundHandler)
return app, nil
}
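
Putting the new App signature together with the Startup call from the CLI, a minimal embedding of LocalAI's HTTP server would look roughly like this. It is a sketch under the signatures shown in this diff; the model path and listen address are illustrative:

package main

import (
	"github.com/go-skynet/LocalAI/core/config"
	"github.com/go-skynet/LocalAI/core/http"
	"github.com/go-skynet/LocalAI/core/startup"
	"github.com/rs/zerolog/log"
)

func main() {
	// Assemble the application (config loader, model loader, services).
	application, err := startup.Startup(
		config.WithModelPath("./models"),
		config.WithContextSize(512),
	)
	if err != nil {
		log.Fatal().Err(err).Msg("startup failed")
	}

	// Build the Fiber app from the assembled application and serve it.
	app, err := http.App(application)
	if err != nil {
		log.Fatal().Err(err).Msg("error during HTTP App construction")
	}
	if err := app.Listen("127.0.0.1:8080"); err != nil {
		log.Fatal().Err(err).Msg("server exited")
	}
}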

View File

@@ -12,7 +12,9 @@ import (
"os"
"path/filepath"
"runtime"
"strings"
"github.com/go-skynet/LocalAI/core"
"github.com/go-skynet/LocalAI/core/config"
. "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/schema"
@@ -205,12 +207,11 @@ var _ = Describe("API test", func() {
var cancel context.CancelFunc
var tmpdir string
var modelDir string
var bcl *config.BackendConfigLoader
var ml *model.ModelLoader
var applicationConfig *config.ApplicationConfig
var application *core.Application
commonOpts := []config.AppOption{
config.WithDebug(true),
config.WithDisableMessage(true),
}
Context("API with ephemeral models", func() {
@@ -222,7 +223,7 @@ var _ = Describe("API test", func() {
modelDir = filepath.Join(tmpdir, "models")
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
err = os.Mkdir(backendAssetsDir, 0750)
err = os.Mkdir(backendAssetsDir, 0755)
Expect(err).ToNot(HaveOccurred())
c, cancel = context.WithCancel(context.Background())
@@ -241,7 +242,7 @@ var _ = Describe("API test", func() {
}
out, err := yaml.Marshal(g)
Expect(err).ToNot(HaveOccurred())
err = os.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0600)
err = os.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644)
Expect(err).ToNot(HaveOccurred())
galleries := []gallery.Gallery{
@@ -251,7 +252,7 @@ var _ = Describe("API test", func() {
},
}
bcl, ml, applicationConfig, err = startup.Startup(
application, err = startup.Startup(
append(commonOpts,
config.WithContext(c),
config.WithGalleries(galleries),
@@ -260,7 +261,7 @@ var _ = Describe("API test", func() {
config.WithBackendAssetsOutput(backendAssetsDir))...)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = App(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@@ -473,11 +474,11 @@ var _ = Describe("API test", func() {
})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp2.Choices)).To(Equal(1))
Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil())
Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name)
Expect(resp2.Choices[0].Message.ToolCalls[0].Function).ToNot(BeNil())
Expect(resp2.Choices[0].Message.ToolCalls[0].Function.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.ToolCalls[0].Function.Name)
var res map[string]string
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
err = json.Unmarshal([]byte(resp2.Choices[0].Message.ToolCalls[0].Function.Arguments), &res)
Expect(err).ToNot(HaveOccurred())
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
@@ -486,13 +487,14 @@ var _ = Describe("API test", func() {
})
It("runs openllama gguf(llama-cpp)", Label("llama-gguf"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
modelName := "hermes-2-pro-mistral"
// if runtime.GOOS != "linux" {
// Skip("test supported only on linux")
// }
modelName := "codellama"
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ConfigURL: "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/models/hermes-2-pro-mistral.yaml",
URL: "github:go-skynet/model-gallery/codellama-7b-instruct.yaml",
Name: modelName,
Overrides: map[string]interface{}{"backend": "llama", "mmap": true, "f16": true, "context_size": 128},
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
@@ -502,7 +504,7 @@ var _ = Describe("API test", func() {
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
}, "480s", "10s").Should(Equal(true))
By("testing chat")
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: modelName, Messages: []openai.ChatCompletionMessage{
@@ -549,13 +551,15 @@ var _ = Describe("API test", func() {
})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp2.Choices)).To(Equal(1))
Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil())
Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name)
fmt.Printf("\n--- %+v\n\n", resp2.Choices[0].Message)
Expect(resp2.Choices[0].Message.ToolCalls).ToNot(BeNil())
Expect(resp2.Choices[0].Message.ToolCalls[0]).ToNot(BeNil())
Expect(resp2.Choices[0].Message.ToolCalls[0].Function.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.ToolCalls[0].Function.Name)
var res map[string]string
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
err = json.Unmarshal([]byte(resp2.Choices[0].Message.ToolCalls[0].Function.Arguments), &res)
Expect(err).ToNot(HaveOccurred())
Expect(res["location"]).To(ContainSubstring("San Francisco"), fmt.Sprint(res))
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
})
@@ -595,7 +599,7 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred())
modelDir = filepath.Join(tmpdir, "models")
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
err = os.Mkdir(backendAssetsDir, 0750)
err = os.Mkdir(backendAssetsDir, 0755)
Expect(err).ToNot(HaveOccurred())
c, cancel = context.WithCancel(context.Background())
@@ -607,7 +611,7 @@ var _ = Describe("API test", func() {
},
}
bcl, ml, applicationConfig, err = startup.Startup(
application, err = startup.Startup(
append(commonOpts,
config.WithContext(c),
config.WithAudioDir(tmpdir),
@@ -618,7 +622,7 @@ var _ = Describe("API test", func() {
config.WithBackendAssetsOutput(tmpdir))...,
)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = App(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@@ -722,14 +726,14 @@ var _ = Describe("API test", func() {
var err error
bcl, ml, applicationConfig, err = startup.Startup(
application, err = startup.Startup(
append(commonOpts,
config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
config.WithContext(c),
config.WithModelPath(modelPath),
)...)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = App(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@@ -759,6 +763,11 @@ var _ = Describe("API test", func() {
Expect(len(models.Models)).To(Equal(6)) // If "config.yaml" should be included, this should be 8?
})
It("can generate completions via ggml", func() {
bt, ok := os.LookupEnv("BUILD_TYPE")
if ok && strings.ToLower(bt) == "metal" {
Skip("GGML + Metal is known flaky, skip test temporarily")
}
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel.ggml", Prompt: testPrompt})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
@@ -766,6 +775,11 @@ var _ = Describe("API test", func() {
})
It("can generate chat completions via ggml", func() {
bt, ok := os.LookupEnv("BUILD_TYPE")
if ok && strings.ToLower(bt) == "metal" {
Skip("GGML + Metal is known flaky, skip test temporarily")
}
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel.ggml", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
@@ -773,6 +787,11 @@ var _ = Describe("API test", func() {
})
It("can generate completions from model configs", func() {
bt, ok := os.LookupEnv("BUILD_TYPE")
if ok && strings.ToLower(bt) == "metal" {
Skip("GGML + Metal is known flaky, skip test temporarily")
}
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "gpt4all", Prompt: testPrompt})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
@@ -780,6 +799,11 @@ var _ = Describe("API test", func() {
})
It("can generate chat completions from model configs", func() {
bt, ok := os.LookupEnv("BUILD_TYPE")
if ok && strings.ToLower(bt) == "metal" {
Skip("GGML + Metal is known flaky, skip test temporarily")
}
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
@@ -866,9 +890,9 @@ var _ = Describe("API test", func() {
Context("backends", func() {
It("runs rwkv completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
// if runtime.GOOS != "linux" {
// Skip("test supported only on linux")
// }
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
@@ -889,17 +913,20 @@ var _ = Describe("API test", func() {
}
Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Text
tokens++
if len(response.Choices) > 0 {
text += response.Choices[0].Text
tokens++
}
}
Expect(text).ToNot(BeEmpty())
Expect(text).To(ContainSubstring("five"))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
It("runs rwkv chat completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
// if runtime.GOOS != "linux" {
// Skip("test supported only on linux")
// }
resp, err := client.CreateChatCompletion(context.TODO(),
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
@@ -1008,14 +1035,14 @@ var _ = Describe("API test", func() {
c, cancel = context.WithCancel(context.Background())
var err error
bcl, ml, applicationConfig, err = startup.Startup(
application, err = startup.Startup(
append(commonOpts,
config.WithContext(c),
config.WithModelPath(modelPath),
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = App(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@@ -1039,18 +1066,33 @@ var _ = Describe("API test", func() {
}
})
It("can generate chat completions from config file (list1)", func() {
bt, ok := os.LookupEnv("BUILD_TYPE")
if ok && strings.ToLower(bt) == "metal" {
Skip("GGML + Metal is known flaky, skip test temporarily")
}
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("can generate chat completions from config file (list2)", func() {
bt, ok := os.LookupEnv("BUILD_TYPE")
if ok && strings.ToLower(bt) == "metal" {
Skip("GGML + Metal is known flaky, skip test temporarily")
}
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("can generate edit completions from config file", func() {
bt, ok := os.LookupEnv("BUILD_TYPE")
if ok && strings.ToLower(bt) == "metal" {
Skip("GGML + Metal is known flaky, skip test temporarily")
}
request := openaigo.EditCreateRequestBody{
Model: "list2",
Instruction: "foo",

View File

@@ -1,196 +0,0 @@
package http
import (
"embed"
"errors"
"net/http"
"strings"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
"github.com/go-skynet/LocalAI/core/http/routes"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/contrib/fiberzerolog"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/recover"
// swagger handler
"github.com/rs/zerolog/log"
)
func readAuthHeader(c *fiber.Ctx) string {
authHeader := c.Get("Authorization")
// elevenlabs
xApiKey := c.Get("xi-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
// anthropic
xApiKey = c.Get("x-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
return authHeader
}
// Embed a directory
//
//go:embed static/*
var embedDirStatic embed.FS
// @title LocalAI API
// @version 2.0.0
// @description The LocalAI Rest API.
// @termsOfService
// @contact.name LocalAI
// @contact.url https://localai.io
// @license.name MIT
// @license.url https://raw.githubusercontent.com/mudler/LocalAI/master/LICENSE
// @BasePath /
// @securityDefinitions.apikey BearerAuth
// @in header
// @name Authorization
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
// Return errors as JSON responses
app := fiber.New(fiber.Config{
Views: renderEngine(),
BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: true,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
},
})
app.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"
if listenData.TLS {
scheme = "https"
}
log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.")
return nil
})
// Have Fiber use zerolog like the rest of the application rather than its built-in logger
logger := log.Logger
app.Use(fiberzerolog.New(fiberzerolog.Config{
Logger: &logger,
}))
// Default middleware config
if !appConfig.Debug {
app.Use(recover.New())
}
metricsService, err := services.NewLocalAIMetricsService()
if err != nil {
return nil, err
}
if metricsService != nil {
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
app.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
})
}
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(appConfig.ApiKeys) == 0 {
return c.Next()
}
authHeader := readAuthHeader(c)
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
// If it's a bearer token
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
for _, key := range appConfig.ApiKeys {
if apiKey == key {
return c.Next()
}
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if appConfig.CORS {
var c func(ctx *fiber.Ctx) error
if appConfig.CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
}
app.Use(c)
}
// Load config jsons
utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
galleryService := services.NewGalleryService(appConfig.ModelPath)
galleryService.Start(appConfig.Context, cl)
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig, auth)
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService, auth)
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth)
if !appConfig.DisableWebUI {
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth)
}
routes.RegisterJINARoutes(app, cl, ml, appConfig, auth)
app.Use("/static", filesystem.New(filesystem.Config{
Root: http.FS(embedDirStatic),
PathPrefix: "static",
Browse: true,
}))
// Define a custom 404 handler
// Note: keep this at the bottom!
app.Use(notFoundHandler)
return app, nil
}
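Note: readAuthHeader above maps the ElevenLabs-style xi-api-key and the Anthropic-style x-api-key headers onto a plain Bearer token before the auth middleware compares it against appConfig.ApiKeys. A minimal client-side sketch of the three equivalent ways to authenticate follows; the base URL, API key, and the /v1/models route used as an example are placeholders and assumptions of this note, not values taken from this diff.

	package main

	import (
		"fmt"
		"net/http"
	)

	// Placeholder values, not taken from this PR: adjust to your deployment.
	const baseURL = "http://localhost:8080"
	const apiKey = "my-api-key"

	func call(header, value string) error {
		// /v1/models is only an example of an authenticated route.
		req, err := http.NewRequest(http.MethodGet, baseURL+"/v1/models", nil)
		if err != nil {
			return err
		}
		req.Header.Set(header, value)
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return err
		}
		defer resp.Body.Close()
		fmt.Println(header, "->", resp.Status)
		return nil
	}

	func main() {
		// All three forms are normalized to a Bearer token by readAuthHeader.
		call("Authorization", "Bearer "+apiKey)
		call("xi-api-key", apiKey) // ElevenLabs-style clients
		call("x-api-key", apiKey)  // Anthropic-style clients
	}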

View File

@@ -1,43 +1,88 @@
package fiberContext
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
type FiberContextExtractor struct {
ml *model.ModelLoader
appConfig *config.ApplicationConfig
}
func NewFiberContextExtractor(ml *model.ModelLoader, appConfig *config.ApplicationConfig) *FiberContextExtractor {
return &FiberContextExtractor{
ml: ml,
appConfig: appConfig,
}
}
// ModelFromContext returns the model from the context
// If no model is specified, it will take the first available
// Takes a model string as input which should be the one received from the user request.
// It returns the model name resolved from the context and an error if any.
func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
if ctx.Params("model") != "" {
modelInput = ctx.Params("model")
func (fce *FiberContextExtractor) ModelFromContext(ctx *fiber.Ctx, modelInput string, firstModel bool) (string, error) {
ctxPM := ctx.Params("model")
if ctxPM != "" {
log.Debug().Msgf("[FCE] Overriding param modelInput %q with ctx.Params value %q", modelInput, ctxPM)
modelInput = ctxPM
}
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
bearer := strings.TrimPrefix(ctx.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && fce.ml.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelInput == "" && !bearerExists && firstModel {
models, _ := loader.ListModels()
models, _ := fce.ml.ListModels()
if len(models) > 0 {
modelInput = models[0]
log.Debug().Msgf("No model specified, using: %s", modelInput)
log.Debug().Msgf("[FCE] No model specified, using first available: %s", modelInput)
} else {
log.Debug().Msgf("No model specified, returning error")
return "", fmt.Errorf("no model specified")
log.Warn().Msgf("[FCE] No model specified, none available")
return "", fmt.Errorf("[fce] no model specified, none available")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
log.Debug().Msgf("[FCE] Using model from bearer token: %s", bearer)
modelInput = bearer
}
if modelInput == "" {
log.Warn().Msg("[FCE] modelInput is empty")
}
return modelInput, nil
}
// TODO: Do we still need the first return value?
func (fce *FiberContextExtractor) OpenAIRequestFromContext(c *fiber.Ctx, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
}
received, _ := json.Marshal(input)
ctx, cancel := context.WithCancel(fce.appConfig.Context)
input.Context = ctx
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
var err error
input.Model, err = fce.ModelFromContext(c, input.Model, firstModel)
return input.Model, input, err
}
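Note: the refactored ModelFromContext keeps the old resolution rules but now lives on FiberContextExtractor. The standalone sketch below restates that precedence outside of Fiber for readability; it is an illustration only, with invented names, not code from this PR.

	package main

	import (
		"errors"
		"fmt"
	)

	// resolveModel mirrors the precedence implemented by ModelFromContext:
	//  1. a ":model" path parameter overrides the value from the request body;
	//  2. a bearer token naming a model present on disk overrides both;
	//  3. if nothing was specified and firstModel is true, the first available model is used;
	//  4. otherwise the (possibly empty) value is passed through, or an error is
	//     returned when nothing is available at all.
	func resolveModel(pathParam, bodyModel, bearer string, available []string, firstModel bool) (string, error) {
		model := bodyModel
		if pathParam != "" {
			model = pathParam
		}
		bearerExists := bearer != "" && contains(available, bearer)
		if model == "" && !bearerExists && firstModel {
			if len(available) == 0 {
				return "", errors.New("no model specified, none available")
			}
			model = available[0]
		}
		if bearerExists {
			model = bearer
		}
		return model, nil
	}

	func contains(list []string, s string) bool {
		for _, v := range list {
			if v == s {
				return true
			}
		}
		return false
	}

	func main() {
		m, _ := resolveModel("", "", "gpt-3.5", []string{"phi-2", "gpt-3.5"}, true)
		fmt.Println(m) // "gpt-3.5": a bearer token naming an on-disk model wins
	}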

View File

@@ -1,285 +0,0 @@
package elements
import (
"fmt"
"github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/xsync"
)
const (
NoImage = "https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg"
)
func DoneProgress(uid, text string) string {
return elem.Div(
attrs.Props{},
elem.H3(
attrs.Props{
"role": "status",
"id": "pblabel",
"tabindex": "-1",
"autofocus": "",
},
elem.Text(text),
),
).Render()
}
func ErrorProgress(err string) string {
return elem.Div(
attrs.Props{},
elem.H3(
attrs.Props{
"role": "status",
"id": "pblabel",
"tabindex": "-1",
"autofocus": "",
},
elem.Text("Error"+err),
),
).Render()
}
func ProgressBar(progress string) string {
return elem.Div(attrs.Props{
"class": "progress",
"role": "progressbar",
"aria-valuemin": "0",
"aria-valuemax": "100",
"aria-valuenow": "0",
"aria-labelledby": "pblabel",
},
elem.Div(attrs.Props{
"id": "pb",
"class": "progress-bar",
"style": "width:" + progress + "%",
}),
).Render()
}
func StartProgressBar(uid, progress, text string) string {
if progress == "" {
progress = "0"
}
return elem.Div(attrs.Props{
"hx-trigger": "done",
"hx-get": "/browse/job/" + uid,
"hx-swap": "outerHTML",
"hx-target": "this",
},
elem.H3(
attrs.Props{
"role": "status",
"id": "pblabel",
"tabindex": "-1",
"autofocus": "",
},
elem.Text(text),
// This is a simple example of how to use the htmx library to create a progress bar that updates every 600ms.
elem.Div(attrs.Props{
"hx-get": "/browse/job/progress/" + uid,
"hx-trigger": "every 600ms",
"hx-target": "this",
"hx-swap": "innerHTML",
},
elem.Raw(ProgressBar(progress)),
),
),
).Render()
}
func cardSpan(text, icon string) elem.Node {
return elem.Span(
attrs.Props{
"class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2",
},
elem.I(attrs.Props{
"class": icon + " pr-2",
}),
elem.Text(text),
)
}
func ListModels(models []*gallery.GalleryModel, installing *xsync.SyncedMap[string, string]) string {
//StartProgressBar(uid, "0")
modelsElements := []elem.Node{}
// span := func(s string) elem.Node {
// return elem.Span(
// attrs.Props{
// "class": "float-right inline-block bg-green-500 text-white py-1 px-3 rounded-full text-xs",
// },
// elem.Text(s),
// )
// }
deleteButton := func(m *gallery.GalleryModel) elem.Node {
return elem.Button(
attrs.Props{
"data-twe-ripple-init": "",
"data-twe-ripple-color": "light",
"class": "float-right inline-block rounded bg-red-800 px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-red-accent-300 hover:shadow-red-2 focus:bg-red-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-red-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "/browse/delete/model/" + m.Name,
},
elem.I(
attrs.Props{
"class": "fa-solid fa-cancel pr-2",
},
),
elem.Text("Delete"),
)
}
installButton := func(m *gallery.GalleryModel) elem.Node {
return elem.Button(
attrs.Props{
"data-twe-ripple-init": "",
"data-twe-ripple-color": "light",
"class": "float-right inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "/browse/install/model/" + fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name),
},
elem.I(
attrs.Props{
"class": "fa-solid fa-download pr-2",
},
),
elem.Text("Install"),
)
}
descriptionDiv := func(m *gallery.GalleryModel) elem.Node {
return elem.Div(
attrs.Props{
"class": "p-6 text-surface dark:text-white",
},
elem.H5(
attrs.Props{
"class": "mb-2 text-xl font-medium leading-tight",
},
elem.Text(m.Name),
),
elem.P(
attrs.Props{
"class": "mb-4 text-base",
},
elem.Text(m.Description),
),
)
}
actionDiv := func(m *gallery.GalleryModel) elem.Node {
galleryID := fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name)
currentlyInstalling := installing.Exists(galleryID)
nodes := []elem.Node{
cardSpan("Repository: "+m.Gallery.Name, "fa-brands fa-git-alt"),
}
if m.License != "" {
nodes = append(nodes,
cardSpan("License: "+m.License, "fas fa-book"),
)
}
for _, tag := range m.Tags {
nodes = append(nodes,
cardSpan(tag, "fas fa-tag"),
)
}
for i, url := range m.URLs {
nodes = append(nodes,
elem.A(
attrs.Props{
"class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2",
"href": url,
"target": "_blank",
},
elem.I(attrs.Props{
"class": "fas fa-link pr-2",
}),
elem.Text("Link #"+fmt.Sprintf("%d", i+1)),
))
}
return elem.Div(
attrs.Props{
"class": "px-6 pt-4 pb-2",
},
elem.P(
attrs.Props{
"class": "mb-4 text-base",
},
nodes...,
),
elem.If(
currentlyInstalling,
elem.Node( // If currently installing, show progress bar
elem.Raw(StartProgressBar(installing.Get(galleryID), "0", "Installing")),
), // Otherwise, show install button (if not installed) or display "Installed"
elem.If(m.Installed,
//elem.Node(elem.Div(
// attrs.Props{},
// span("Installed"), deleteButton(m),
// )),
deleteButton(m),
installButton(m),
),
),
)
}
for _, m := range models {
elems := []elem.Node{}
if m.Icon == "" {
m.Icon = NoImage
}
elems = append(elems,
elem.Div(attrs.Props{
"class": "flex justify-center items-center",
},
elem.A(attrs.Props{
"href": "#!",
// "class": "justify-center items-center",
},
elem.Img(attrs.Props{
// "class": "rounded-t-lg object-fit object-center h-96",
"class": "rounded-t-lg max-h-48 max-w-96 object-cover mt-3",
"src": m.Icon,
}),
),
))
elems = append(elems, descriptionDiv(m), actionDiv(m))
modelsElements = append(modelsElements,
elem.Div(
attrs.Props{
"class": " me-4 mb-2 block rounded-lg bg-white shadow-secondary-1 dark:bg-gray-800 dark:bg-surface-dark dark:text-white text-surface pb-2",
},
elem.Div(
attrs.Props{
// "class": "p-6",
},
elems...,
),
),
)
}
wrapper := elem.Div(attrs.Props{
"class": "dark grid grid-cols-1 grid-rows-1 md:grid-cols-3 block rounded-lg shadow-secondary-1 dark:bg-surface-dark",
//"class": "block rounded-lg bg-white shadow-secondary-1 dark:bg-surface-dark",
}, modelsElements...)
return wrapper.Render()
}
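Note: this removed file built the model-gallery UI with chasefleming/elem-go instead of HTML templates. For readers unfamiliar with that library, a minimal sketch of the pattern used throughout ListModels (attrs.Props plus nested nodes, rendered to an HTML string) is below; the card contents and CSS classes are placeholders, not values from this diff.

	package main

	import (
		"fmt"

		"github.com/chasefleming/elem-go"
		"github.com/chasefleming/elem-go/attrs"
	)

	// card shows the element-tree style used by ListModels above:
	// nodes are plain Go values and Render() turns the tree into HTML.
	func card(title, body string) string {
		return elem.Div(
			attrs.Props{"class": "rounded-lg shadow p-6"}, // placeholder classes
			elem.H5(attrs.Props{"class": "text-xl"}, elem.Text(title)),
			elem.P(attrs.Props{"class": "text-base"}, elem.Text(body)),
		).Render()
	}

	func main() {
		fmt.Println(card("my-model", "A short description."))
	}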

View File

@@ -2,9 +2,7 @@ package elevenlabs
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
@@ -17,7 +15,7 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/text-to-speech/{voice-id} [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func TTSEndpoint(fce *fiberContext.FiberContextExtractor, ttsbs *backend.TextToSpeechBackendService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.ElevenLabsTTSRequest)
@@ -28,34 +26,21 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
var err error
input.ModelID, err = fce.ModelFromContext(c, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
responseChannel := ttsbs.TextToAudioFile(&schema.TTSRequest{
Model: input.ModelID,
Voice: voiceID,
Input: input.Text,
})
rawValue := <-responseChannel
if rawValue.Error != nil {
return rawValue.Error
}
log.Debug().Msgf("Request for model: %s", modelFile)
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg)
if err != nil {
return err
}
return c.Download(filePath)
return c.Download(*rawValue.Value)
}
}
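Note: the endpoint no longer loads the backend config itself; it hands the request to a TextToSpeechBackendService and receives exactly one result on a channel, where the result carries either an Error or a Value (for TTS, the generated file path that c.Download serves). The generic sketch below illustrates that calling convention with hypothetical types of my own; it is not the PR's actual service code.

	package main

	import (
		"errors"
		"fmt"
	)

	// result mimics the Error/Value pair the backend services in this PR
	// deliver over their response channels (the types here are hypothetical).
	type result[T any] struct {
		Value *T
		Error error
	}

	// textToAudioFile stands in for TextToSpeechBackendService.TextToAudioFile:
	// it returns immediately and sends exactly one result on the channel.
	func textToAudioFile(text string) <-chan result[string] {
		out := make(chan result[string], 1)
		go func() {
			defer close(out)
			if text == "" {
				out <- result[string]{Error: errors.New("empty input")}
				return
			}
			path := "/tmp/audio.wav" // placeholder output path
			out <- result[string]{Value: &path}
		}()
		return out
	}

	func main() {
		raw := <-textToAudioFile("hello world")
		if raw.Error != nil {
			fmt.Println("error:", raw.Error)
			return
		}
		fmt.Println("file to download:", *raw.Value) // the handler would c.Download(*raw.Value)
	}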

View File

@@ -1,84 +0,0 @@
package jina
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
req := new(schema.JINARerankRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "Cannot parse JSON",
})
}
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
}
request := &proto.RerankRequest{
Query: req.Query,
TopN: int32(req.TopN),
Documents: req.Documents,
}
results, err := backend.Rerank(cfg.Backend, modelFile, request, ml, appConfig, *cfg)
if err != nil {
return err
}
response := &schema.JINARerankResponse{
Model: req.Model,
}
for _, r := range results.Results {
response.Results = append(response.Results, schema.JINADocumentResult{
Index: int(r.Index),
Document: schema.JINAText{Text: r.Text},
RelevanceScore: float64(r.RelevanceScore),
})
}
response.Usage.TotalTokens = int(results.Usage.TotalTokens)
response.Usage.PromptTokens = int(results.Usage.PromptTokens)
return c.Status(fiber.StatusOK).JSON(response)
}
}
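Note: JINARerankEndpoint parses a schema.JINARerankRequest (query, top_n, documents) and answers with per-document relevance scores plus token usage. A request sketch follows; the JSON field names and the /v1/rerank path are assumptions based on the JINA-style API this endpoint emulates, not values visible in this hunk.

	package main

	import (
		"bytes"
		"encoding/json"
		"fmt"
		"net/http"
	)

	func main() {
		// Field names and the /v1/rerank path are assumed, see the note above.
		body := map[string]any{
			"model": "my-reranker",
			"query": "What is LocalAI?",
			"top_n": 2,
			"documents": []string{
				"LocalAI is a drop-in OpenAI-compatible API.",
				"Bananas are yellow.",
			},
		}
		payload, _ := json.Marshal(body)
		resp, err := http.Post("http://localhost:8080/v1/rerank", "application/json", bytes.NewReader(payload))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()
		var out map[string]any
		json.NewDecoder(resp.Body).Decode(&out)
		fmt.Printf("%+v\n", out) // expect results with index, document text and relevance score
	}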

View File

@@ -74,27 +74,6 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fibe
}
}
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")
mgs.galleryApplier.C <- gallery.GalleryOp{
Delete: true,
GalleryName: modelName,
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)

View File

@@ -2,9 +2,7 @@ package localai
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
@@ -16,45 +14,26 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/audio/speech [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func TTSEndpoint(fce *fiberContext.FiberContextExtractor, ttsbs *backend.TextToSpeechBackendService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
var err error
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
if err = c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
input.Model, err = fce.ModelFromContext(c, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
responseChannel := ttsbs.TextToAudioFile(input)
rawValue := <-responseChannel
if rawValue.Error != nil {
return rawValue.Error
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
}
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg)
if err != nil {
return err
}
return c.Download(filePath)
return c.Download(*rawValue.Value)
}
}
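Note: the OpenAI-compatible TTS route keeps its /v1/audio/speech path (see the @Router annotation above) but now delegates to TextToSpeechBackendService and downloads whatever file path comes back. A small client sketch follows; the JSON field names mirror schema.TTSRequest (model, input, voice, optional backend) but are assumed tags, and the base URL and model name are placeholders.

	package main

	import (
		"bytes"
		"encoding/json"
		"io"
		"net/http"
		"os"
	)

	func main() {
		// Assumed JSON tags for schema.TTSRequest; use a TTS model you have installed.
		body := map[string]string{
			"model": "my-tts-model",
			"input": "Hello from LocalAI",
			"voice": "amy",
		}
		payload, _ := json.Marshal(body)
		resp, err := http.Post("http://localhost:8080/v1/audio/speech", "application/json", bytes.NewReader(payload))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		out, err := os.Create("speech.wav")
		if err != nil {
			panic(err)
		}
		defer out.Close()
		io.Copy(out, resp.Body) // the endpoint streams the generated audio file back
	}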

View File

@@ -1,32 +0,0 @@
package localai
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
cl *config.BackendConfigLoader, ml *model.ModelLoader) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, _ := ml.ListModels()
backendConfigs := cl.GetAllBackendConfigs()
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"Models": models,
"ModelsConfig": backendConfigs,
"ApplicationConfig": appConfig,
}
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
// The client expects a JSON response
return c.Status(fiber.StatusOK).JSON(summary)
} else {
// Render index
return c.Render("views/index", summary)
}
}
}

View File

@@ -339,7 +339,7 @@ func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model
}
}
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find "))
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistantID %q", assistantID))
}
}
@@ -455,19 +455,21 @@ func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model
for i, assistant := range Assistants {
if assistant.ID == assistantID {
for j, fileId := range assistant.FileIDs {
Assistants[i].FileIDs = append(Assistants[i].FileIDs[:j], Assistants[i].FileIDs[j+1:]...)
if fileId == fileId {
Assistants[i].FileIDs = append(Assistants[i].FileIDs[:j], Assistants[i].FileIDs[j+1:]...)
// Check if the file exists in the assistantFiles slice
for i, assistantFile := range AssistantFiles {
if assistantFile.ID == fileId {
// Remove the file from the assistantFiles slice
AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...)
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
return c.Status(fiber.StatusOK).JSON(DeleteAssistantFileResponse{
ID: fileId,
Object: "assistant.file.deleted",
Deleted: true,
})
// Check if the file exists in the assistantFiles slice
for i, assistantFile := range AssistantFiles {
if assistantFile.ID == fileId {
// Remove the file from the assistantFiles slice
AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...)
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
return c.Status(fiber.StatusOK).JSON(DeleteAssistantFileResponse{
ID: fileId,
Object: "assistant.file.deleted",
Deleted: true,
})
}
}
}
}

View File

@@ -3,6 +3,10 @@ package openai
import (
"encoding/json"
"fmt"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"io"
"io/ioutil"
"net/http"
@@ -12,11 +16,6 @@ import (
"strings"
"testing"
"time"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
)
var configsDir string = "/tmp/localai/configs"
@@ -50,8 +49,8 @@ func TestAssistantEndpoints(t *testing.T) {
}
_ = os.RemoveAll(appConfig.ConfigsDir)
_ = os.MkdirAll(appConfig.ConfigsDir, 0750)
_ = os.MkdirAll(modelPath, 0750)
_ = os.MkdirAll(appConfig.ConfigsDir, 0755)
_ = os.MkdirAll(modelPath, 0755)
os.Create(filepath.Join(modelPath, "ggml-gpt4all-j"))
app := fiber.New(fiber.Config{

View File

@@ -5,16 +5,11 @@ import (
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/services"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
@@ -24,418 +19,82 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/chat/completions [post]
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
emptyMessage := ""
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
responses <- resp
return true
})
close(responses)
}
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
result := ""
_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s
// TODO: Change generated BNF grammar to be compliant with the schema so we can
// stream the result token by token here.
return true
})
results := functions.ParseFunctionCall(result, config.FunctionsConfig)
noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0
switch {
case noActionToRun:
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
result, err := handleQuestion(config, req, ml, startupOptions, results, prompt)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
}
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
responses <- resp
default:
for i, ss := range results {
name, args := ss.Name, ss.Arguments
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
},
},
},
}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
responses <- schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Arguments: args,
},
},
},
}}},
Object: "chat.completion.chunk",
}
}
}
close(responses)
}
func ChatEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, ml, startupOptions, true)
_, request, err := fce.OpenAIRequestFromContext(c, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
return fmt.Errorf("failed reading parameters from request: %w", err)
}
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
traceID, finalResultChannel, _, tokenChannel, err := oais.Chat(request, false, request.Stream)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
funcs := input.Functions
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
// Allow the user to set custom actions via config file
// to be "embedded" in each model
noActionName := "answer"
noActionDescription := "use this action to answer without performing any action"
if config.FunctionsConfig.NoActionFunctionName != "" {
noActionName = config.FunctionsConfig.NoActionFunctionName
}
if config.FunctionsConfig.NoActionDescriptionName != "" {
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
return err
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = functions.JSONBNF
}
if request.Stream {
config.Grammar = input.Grammar
log.Debug().Msgf("Chat Stream request received")
if shouldUseFn {
log.Debug().Msgf("Response needs to process functions")
}
switch {
case !config.FunctionsConfig.NoGrammar && shouldUseFn:
noActionGrammar := functions.Function{
Name: noActionName,
Description: noActionDescription,
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "The message to reply the user with",
}},
},
}
// Append the no action function
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
case input.JSONFunctionGrammarObject != nil:
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
default:
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
}
// process functions if we have any defined or if we have a function call string
// functions are not supported in stream mode (yet?)
toStream := input.Stream
log.Debug().Msgf("Parameters: %+v", config)
var predInput string
// If we are using the tokenizer template, we don't need to process the messages
// unless we are processing functions
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
fcall := i.FunctionCall
if len(i.ToolCalls) > 0 {
fcall = i.ToolCalls
}
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
FunctionCall: fcall,
FunctionName: i.Name,
LastMessage: messageIndex == (len(input.Messages) - 1),
Function: config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)),
MessageIndex: messageIndex,
}
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping")
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
marshalAnyRole := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
marshalAny := func(f any) {
j, err := json.Marshal(f)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
marshalAnyRole(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAnyRole(i.ToolCalls)
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
marshalAny(i.FunctionCall)
}
if i.ToolCalls != nil {
marshalAny(i.ToolCalls)
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check separately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !shouldUseFn {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && shouldUseFn {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if shouldUseFn && config.Grammar != "" {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
}
switch {
case toStream:
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
//
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
responses := make(chan schema.OpenAIResponse)
if !shouldUseFn {
go process(predInput, input, config, ml, responses)
} else {
go processTools(noActionName, predInput, input, config, ml, responses)
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
toolsCalled := false
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
for ev := range tokenChannel {
if ev.Error != nil {
log.Debug().Err(ev.Error).Msg("chat streaming responseChannel error")
request.Cancel()
break
}
usage = &ev.Value.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Value.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
if ev.Error != nil {
log.Debug().Err(ev.Error).Msg("[ChatEndpoint] error to debug during tokenChannel handler")
enc.Encode(ev.Error)
} else {
enc.Encode(ev.Value)
}
log.Debug().Msgf("chat streaming sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
log.Debug().Err(err).Msgf("Sending chunk failed")
request.Cancel()
break
}
err = w.Flush()
if err != nil {
log.Debug().Msg("error while flushing, closing connection")
request.Cancel()
break
}
w.Flush()
}
finishReason := "stop"
if toolsCalled {
finishReason = "tool_calls"
} else if toolsCalled && len(input.Tools) == 0 {
} else if toolsCalled && len(request.Tools) == 0 {
finishReason = "function_call"
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
ID: traceID.ID,
Created: traceID.Created,
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: finishReason,
Index: 0,
Delta: &schema.Message{Content: &emptyMessage},
Delta: &schema.Message{Content: ""},
}},
Object: "chat.completion.chunk",
Usage: *usage,
@@ -446,146 +105,21 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
// no streaming mode
default:
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if !shouldUseFn {
// no function is called, just reply and use stop as finish reason
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}
results := functions.ParseFunctionCall(s, config.FunctionsConfig)
noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
switch {
case noActionsToRun:
result, err := handleQuestion(config, input, ml, startupOptions, results, predInput)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
}
*c = append(*c, schema.Choice{
Message: &schema.Message{Role: "assistant", Content: &result}})
default:
toolChoice := schema.Choice{
Message: &schema.Message{
Role: "assistant",
},
}
if len(input.Tools) > 0 {
toolChoice.FinishReason = "tool_calls"
}
for _, ss := range results {
name, args := ss.Name, ss.Arguments
if len(input.Tools) > 0 {
// If we are using tools, we condense the function calls into
// a single response choice with all the tools
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
)
} else {
// otherwise we return more choices directly
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &schema.Message{
Role: "assistant",
FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
},
})
}
}
if len(input.Tools) > 0 {
// we need to append our result if we are using tools
*c = append(*c, toolChoice)
}
}
}, nil)
if err != nil {
return err
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
respData, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", respData)
// Return the prediction in the response body
return c.JSON(resp)
}
// TODO is this proper to have exclusive from Stream, or do we need to issue both responses?
rawResponse := <-finalResultChannel
if rawResponse.Error != nil {
return rawResponse.Error
}
jsonResult, _ := json.Marshal(rawResponse.Value)
log.Debug().Str("jsonResult", string(jsonResult)).Msg("Chat Final Response")
// Return the prediction in the response body
return c.JSON(rawResponse.Value)
}
}
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, prompt string) (string, error) {
log.Debug().Msgf("nothing to do, computing a reply")
arg := ""
if len(funcResults) > 0 {
arg = funcResults[0].Arguments
}
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
if err := json.Unmarshal([]byte(arg), &arguments); err != nil {
log.Debug().Msg("handleQuestion: function result did not contain a valid JSON object")
}
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
case string:
if message != "" {
log.Debug().Msgf("Reply received from LLM: %s", message)
message = backend.Finetune(*config, prompt, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
return message, nil
}
}
}
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in term of CPU/GPU) another computation
config.Grammar = ""
images := []string{}
for _, m := range input.Messages {
images = append(images, m.StringImages...)
}
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, ml, *config, o, nil)
if err != nil {
log.Error().Err(err).Msg("model inference failed")
return "", err
}
prediction, err := predFunc()
if err != nil {
log.Error().Err(err).Msg("prediction failed")
return "", err
}
return backend.Finetune(*config, prompt, prediction.Response), nil
}
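Note: when request.Stream is set, the handler writes each chunk as a server-sent-events "data:" line and terminates the stream with "data: [DONE]", and with this change errors are JSON-encoded into the stream as well instead of being dropped. A minimal Go client sketch for reading that stream is below; the base URL and model name are placeholders, while the endpoint path comes from the @Router annotation above.

	package main

	import (
		"bufio"
		"bytes"
		"encoding/json"
		"fmt"
		"net/http"
		"strings"
	)

	func main() {
		payload := []byte(`{"model":"my-model","stream":true,"messages":[{"role":"user","content":"Hello"}]}`)
		resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", bytes.NewReader(payload))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		scanner := bufio.NewScanner(resp.Body)
		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") {
				continue // skip blank keep-alive lines
			}
			data := strings.TrimPrefix(line, "data: ")
			if data == "[DONE]" {
				break
			}
			var chunk map[string]any
			if err := json.Unmarshal([]byte(data), &chunk); err != nil {
				fmt.Println("unparsable chunk:", data)
				continue
			}
			fmt.Printf("chunk: %v\n", chunk["choices"])
		}
	}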

View File

@@ -4,18 +4,13 @@ import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
@@ -25,116 +20,50 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
Text: s,
},
},
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
log.Debug().Msgf("Sending goroutine: %s", s)
responses <- resp
return true
})
close(responses)
}
func CompletionEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, ml, appConfig, true)
_, request, err := fce.OpenAIRequestFromContext(c, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("`input`: %+v", input)
log.Debug().Msgf("`OpenAIRequest`: %+v", request)
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
traceID, finalResultChannel, _, _, tokenChannel, err := oais.Completion(request, false, request.Stream)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
return err
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = functions.JSONBNF
}
if request.Stream {
log.Debug().Msgf("Completion Stream request received")
config.Grammar = input.Grammar
log.Debug().Msgf("Parameter Config: %+v", config)
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
//c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Completion != "" {
templateFile = config.TemplateConfig.Completion
}
if input.Stream {
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
}
predInput := config.PromptStrings[0]
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
}
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, ml, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
for ev := range responses {
for ev := range tokenChannel {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
if ev.Error != nil {
log.Debug().Msgf("[CompletionEndpoint] error to debug during tokenChannel handler: %q", ev.Error)
enc.Encode(ev.Error)
} else {
enc.Encode(ev.Value)
}
log.Debug().Msgf("Sending chunk: %s", buf.String())
log.Debug().Msgf("completion streaming sending chunk: %s", buf.String())
fmt.Fprintf(w, "data: %v\n", buf.String())
w.Flush()
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
ID: traceID.ID,
Created: traceID.Created,
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
@@ -151,55 +80,15 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
}))
return nil
}
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
for k, i := range config.PromptStrings {
if templateFile != "" {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
}
r, tokenUsage, err := ComputeChoices(
input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
// TODO is this proper to have exclusive from Stream, or do we need to issue both responses?
rawResponse := <-finalResultChannel
if rawResponse.Error != nil {
return rawResponse.Error
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "text_completion",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}
jsonResult, _ := json.Marshal(resp)
jsonResult, _ := json.Marshal(rawResponse.Value)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(rawResponse.Value)
}
}
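Note: the non-streaming path now simply blocks on finalResultChannel and serializes whatever value arrives, so the wire format of /v1/completions (path from the @Router annotation above) is unchanged. For completeness, a minimal non-streaming client sketch; the base URL and model name are placeholders, and the response fields follow the OpenAI-style schema used here.

	package main

	import (
		"bytes"
		"encoding/json"
		"fmt"
		"net/http"
	)

	func main() {
		payload := []byte(`{"model":"my-model","prompt":"Once upon a time","max_tokens":32}`)
		resp, err := http.Post("http://localhost:8080/v1/completions", "application/json", bytes.NewReader(payload))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		var out struct {
			Choices []struct {
				Text string `json:"text"`
			} `json:"choices"`
			Usage map[string]int `json:"usage"`
		}
		if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
			panic(err)
		}
		for _, c := range out.Choices {
			fmt.Println(c.Text)
		}
		fmt.Println("usage:", out.Usage)
	}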

View File

@@ -3,92 +3,36 @@ package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func EditEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, ml, appConfig, true)
_, request, err := fce.OpenAIRequestFromContext(c, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
_, finalResultChannel, _, _, _, err := oais.Edit(request, false, request.Stream)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
return err
}
log.Debug().Msgf("Parameter Config: %+v", config)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
rawResponse := <-finalResultChannel
if rawResponse.Error != nil {
return rawResponse.Error
}
if config.TemplateConfig.Edit != "" {
templateFile = config.TemplateConfig.Edit
}
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}
for _, i := range config.InputStrings {
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
}
r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
}, nil)
if err != nil {
return err
}
totalTokenUsage.Prompt += tokenUsage.Prompt
totalTokenUsage.Completion += tokenUsage.Completion
result = append(result, r...)
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "edit",
Usage: schema.OpenAIUsage{
PromptTokens: totalTokenUsage.Prompt,
CompletionTokens: totalTokenUsage.Completion,
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
},
}
jsonResult, _ := json.Marshal(resp)
jsonResult, _ := json.Marshal(rawResponse.Value)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(rawResponse.Value)
}
}

View File

@@ -3,14 +3,9 @@ package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
@@ -21,63 +16,25 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func EmbeddingsEndpoint(fce *fiberContext.FiberContextExtractor, ebs *backend.EmbeddingsBackendService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readRequest(c, ml, appConfig, true)
_, input, err := fce.OpenAIRequestFromContext(c, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
responseChannel := ebs.Embeddings(input)
rawResponse := <-responseChannel
if rawResponse.Error != nil {
return rawResponse.Error
}
log.Debug().Msgf("Parameter Config: %+v", config)
items := []schema.Item{}
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",
}
jsonResult, _ := json.Marshal(resp)
jsonResult, _ := json.Marshal(rawResponse.Value)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(rawResponse.Value)
}
}
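Note: the embeddings handler now forwards the parsed request to EmbeddingsBackendService and returns its value untouched, so the OpenAI-style wire format ("data" items carrying "embedding" vectors) is unchanged. A minimal client sketch for /v1/embeddings (path from the @Router annotation above); base URL and model name are placeholders.

	package main

	import (
		"bytes"
		"encoding/json"
		"fmt"
		"net/http"
	)

	func main() {
		payload := []byte(`{"model":"my-embedding-model","input":"LocalAI can compute embeddings locally"}`)
		resp, err := http.Post("http://localhost:8080/v1/embeddings", "application/json", bytes.NewReader(payload))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		var out struct {
			Data []struct {
				Index     int       `json:"index"`
				Embedding []float32 `json:"embedding"`
			} `json:"data"`
		}
		if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
			panic(err)
		}
		for _, item := range out.Data {
			fmt.Printf("item %d: %d dimensions\n", item.Index, len(item.Embedding))
		}
	}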

View File

@@ -251,7 +251,7 @@ func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipar
// Helper to create test files
func createTestFile(t *testing.T, name string, sizeMB int, option *config.ApplicationConfig) *os.File {
err := os.MkdirAll(option.UploadDir, 0750)
err := os.MkdirAll(option.UploadDir, 0755)
if err != nil {
t.Fatalf("Error MKDIR: %v", err)

View File

@@ -1,50 +1,18 @@
package openai
import (
"bufio"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/backend"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func downloadFile(url string) (string, error) {
// Get the data
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
// Create the file
out, err := os.CreateTemp("", "image")
if err != nil {
return "", err
}
defer out.Close()
// Write the body to file
_, err = io.Copy(out, resp.Body)
return out.Name(), err
}
//
// https://platform.openai.com/docs/api-reference/images/create
/*
*
@@ -59,186 +27,36 @@ func downloadFile(url string) (string, error) {
*
*/
// ImageEndpoint is the OpenAI Image generation API endpoint https://platform.openai.com/docs/api-reference/images/create
// @Summary Creates an image given a prompt.
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func ImageEndpoint(fce *fiberContext.FiberContextExtractor, igbs *backend.ImageGenerationBackendService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, ml, appConfig, false)
// TODO: Somewhat a hack. Is there a better place to assign this?
if igbs.BaseUrlForGeneratedImages == "" {
igbs.BaseUrlForGeneratedImages = c.BaseURL() + "/generated-images/"
}
_, request, err := fce.OpenAIRequestFromContext(c, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
if m == "" {
m = model.StableDiffusionBackend
}
log.Debug().Msgf("Loading model: %+v", m)
responseChannel := igbs.GenerateImage(request)
rawResponse := <-responseChannel
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
if rawResponse.Error != nil {
return rawResponse.Error
}
jsonResult, err := json.Marshal(rawResponse.Value)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
return err
}
src := ""
if input.File != "" {
fileData := []byte{}
// check if input.File is a URL; if so, download it and save it
// to a temporary file
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
out, err := downloadFile(input.File)
if err != nil {
return fmt.Errorf("failed downloading file:%w", err)
}
defer os.RemoveAll(out)
fileData, err = os.ReadFile(out)
if err != nil {
return fmt.Errorf("failed reading file:%w", err)
}
} else {
// base 64 decode the file and write it somewhere
// that we will cleanup
fileData, err = base64.StdEncoding.DecodeString(input.File)
if err != nil {
return err
}
}
// Create a temporary file
outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
if err != nil {
return err
}
// write the base64 result
writer := bufio.NewWriter(outputFile)
_, err = writer.Write(fileData)
if err != nil {
outputFile.Close()
return err
}
outputFile.Close()
src = outputFile.Name()
defer os.RemoveAll(src)
}
log.Debug().Msgf("Parameter Config: %+v", config)
switch config.Backend {
case "stablediffusion":
config.Backend = model.StableDiffusionBackend
case "tinydream":
config.Backend = model.TinyDreamBackend
case "":
config.Backend = model.StableDiffusionBackend
}
sizeParts := strings.Split(input.Size, "x")
if len(sizeParts) != 2 {
return fmt.Errorf("invalid value for 'size'")
}
width, err := strconv.Atoi(sizeParts[0])
if err != nil {
return fmt.Errorf("invalid value for 'size'")
}
height, err := strconv.Atoi(sizeParts[1])
if err != nil {
return fmt.Errorf("invalid value for 'size'")
}
b64JSON := false
if input.ResponseFormat.Type == "b64_json" {
b64JSON = true
}
// src and clip_skip
var result []schema.Item
for _, i := range config.PromptStrings {
n := input.N
if input.N == 0 {
n = 1
}
for j := 0; j < n; j++ {
prompts := strings.Split(i, "|")
positive_prompt := prompts[0]
negative_prompt := ""
if len(prompts) > 1 {
negative_prompt = prompts[1]
}
mode := 0
step := config.Step
if step == 0 {
step = 15
}
if input.Mode != 0 {
mode = input.Mode
}
if input.Step != 0 {
step = input.Step
}
tempDir := ""
if !b64JSON {
tempDir = appConfig.ImageDir
}
// Create a temporary file
outputFile, err := os.CreateTemp(tempDir, "b64")
if err != nil {
return err
}
outputFile.Close()
output := outputFile.Name() + ".png"
// Rename the temporary file
err = os.Rename(outputFile.Name(), output)
if err != nil {
return err
}
baseURL := c.BaseURL()
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
if err != nil {
return err
}
if err := fn(); err != nil {
return err
}
item := &schema.Item{}
if b64JSON {
defer os.RemoveAll(output)
data, err := os.ReadFile(output)
if err != nil {
return err
}
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL = baseURL + "/generated-images/" + base
}
result = append(result, *item)
}
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Data: result,
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(rawResponse.Value)
}
}
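For orientation, a hedged sketch of the request body the image endpoint above parses — note the positive|negative prompt split on "|" and the WIDTHxHEIGHT size string split on "x"; the server address, model name, and exact field set are illustrative assumptions, not taken from this diff:
package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// The handler above splits the prompt on "|" into positive/negative parts and
	// the size on "x" into width/height; it also reads optional step/mode values.
	payload := []byte(`{
		"model": "stablediffusion",
		"prompt": "a cozy cabin in the woods|blurry, low quality",
		"size": "512x512",
		"n": 1
	}`)
	resp, err := http.Post("http://localhost:8080/v1/images/generations", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}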

View File

@@ -1,55 +0,0 @@
package openai
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ComputeChoices(
req *schema.OpenAIRequest,
predInput string,
config *config.BackendConfig,
o *config.ApplicationConfig,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
n := req.N // number of completions to return
result := []schema.Choice{}
if n == 0 {
n = 1
}
images := []string{}
for _, m := range req.Messages {
images = append(images, m.StringImages...)
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, loader, *config, o, tokenCallback)
if err != nil {
return result, backend.TokenUsage{}, err
}
tokenUsage := backend.TokenUsage{}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return result, backend.TokenUsage{}, err
}
tokenUsage.Prompt += prediction.Usage.Prompt
tokenUsage.Completion += prediction.Usage.Completion
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
//result = append(result, Choice{Text: prediction})
}
return result, tokenUsage, err
}

View File

@@ -10,7 +10,6 @@ func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) er
return func(c *fiber.Ctx) error {
// If blank, no filter is applied.
filter := c.Query("filter")
// By default, exclude any loose files that are already referenced by a configuration file.
excludeConfigured := c.QueryBool("excludeConfigured", true)
@@ -18,6 +17,7 @@ func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) er
if err != nil {
return err
}
return c.JSON(struct {
Object string `json:"object"`
Data []schema.OpenAIModel `json:"data"`

View File

@@ -1,285 +0,0 @@
package openai
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
}
received, _ := json.Marshal(input)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
return modelFile, input, err
}
// this function checks if the string is a URL; if it is, it downloads the image into memory,
// encodes it in base64, and returns the base64 string
func getBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
// read the image data into memory
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}
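// Illustrative usage of getBase64Image (not part of this file): remote URLs and
// data-URI inputs both normalize to a bare base64 payload; anything else errors.
//
//	b64, err := getBase64Image("https://example.com/picture.jpg") // downloaded, then base64-encoded
//	b64, err = getBase64Image("data:image/jpeg;base64,AAAA")      // returns "AAAA" unchanged
//	_, err = getBase64Image("plain text")                         // returns an error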
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != nil {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
index := 0
for i, m := range input.Messages {
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
input.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is a URL; if it is, download the image and encode it in base64:
base64, err := getBase64Image(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
index++
} else {
fmt.Print("Failed encoding image", err)
}
}
}
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
}
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
config.LoadOptionDebug(debug),
config.LoadOptionThreads(threads),
config.LoadOptionContextSize(ctx),
config.LoadOptionF16(f16),
)
// Set the parameters for the language model prediction
updateRequestConfig(cfg, input)
return cfg, input, err
}
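The string-or-array handling in updateRequestConfig above (for stop, input, and prompt) is a plain type-switch over the decoded JSON value; a self-contained sketch of the same pattern, independent of LocalAI's own types:
package main

import (
	"encoding/json"
	"fmt"
)

type request struct {
	Stop interface{} `json:"stop"`
}

// stopWords accepts either a single string or an array of strings in the "stop" field,
// mirroring the switch used in updateRequestConfig.
func stopWords(raw []byte) []string {
	var req request
	_ = json.Unmarshal(raw, &req)
	var words []string
	switch stop := req.Stop.(type) {
	case string:
		if stop != "" {
			words = append(words, stop)
		}
	case []interface{}:
		for _, s := range stop {
			if w, ok := s.(string); ok {
				words = append(words, w)
			}
		}
	}
	return words
}

func main() {
	fmt.Println(stopWords([]byte(`{"stop": "###"}`)))           // [###]
	fmt.Println(stopWords([]byte(`{"stop": ["###", "STOP"]}`))) // [### STOP]
}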

View File

@@ -9,8 +9,7 @@ import (
"path/filepath"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
model "github.com/go-skynet/LocalAI/pkg/model"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
@@ -23,17 +22,15 @@ import (
// @Param file formData file true "file"
// @Success 200 {object} map[string]string "Response"
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func TranscriptEndpoint(fce *fiberContext.FiberContextExtractor, tbs *backend.TranscriptionBackendService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, ml, appConfig, false)
_, request, err := fce.OpenAIRequestFromContext(c, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
// TODO: Investigate this file copy stuff later - potentially belongs in service.
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {
@@ -65,13 +62,16 @@ func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
log.Debug().Msgf("Audio file copied to: %+v", dst)
tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig)
if err != nil {
return err
}
request.File = dst
log.Debug().Msgf("Trascribed: %+v", tr)
responseChannel := tbs.Transcribe(request)
rawResponse := <-responseChannel
if rawResponse.Error != nil {
return rawResponse.Error
}
log.Debug().Msgf("Transcribed: %+v", rawResponse.Value)
// TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(tr)
return c.Status(http.StatusOK).JSON(rawResponse.Value)
}
}
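To make the multipart flow above concrete, a minimal client sketch (not part of this diff) that uploads an audio file to the transcription endpoint; the server address, file name, and model name are placeholders:
package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
)

func main() {
	// Assumes a local instance and an existing sample.wav; adjust to your setup.
	var buf bytes.Buffer
	w := multipart.NewWriter(&buf)

	part, err := w.CreateFormFile("file", "sample.wav") // "file" matches c.FormFile("file") above
	if err != nil {
		panic(err)
	}
	f, err := os.Open("sample.wav")
	if err != nil {
		panic(err)
	}
	defer f.Close()
	if _, err := io.Copy(part, f); err != nil {
		panic(err)
	}
	_ = w.WriteField("model", "whisper-1") // placeholder model name
	w.Close()

	resp, err := http.Post("http://localhost:8080/v1/audio/transcriptions", w.FormDataContentType(), &buf)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body))
}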

View File

@@ -7,10 +7,12 @@ import (
"net/http"
"github.com/Masterminds/sprig/v3"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
fiberhtml "github.com/gofiber/template/html/v2"
"github.com/microcosm-cc/bluemonday"
"github.com/russross/blackfriday"
)
@@ -31,6 +33,40 @@ func notFoundHandler(c *fiber.Ctx) error {
return nil
}
func welcomeRoute(
app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error,
) {
if appConfig.DisableWelcomePage {
return
}
models, _ := ml.ListModels()
backendConfigs := cl.GetAllBackendConfigs()
app.Get("/", auth, func(c *fiber.Ctx) error {
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"Models": models,
"ModelsConfig": backendConfigs,
"ApplicationConfig": appConfig,
}
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
// The client expects a JSON response
return c.Status(fiber.StatusOK).JSON(summary)
} else {
// Render index
return c.Render("views/index", summary)
}
})
}
func renderEngine() *fiberhtml.Engine {
engine := fiberhtml.NewFileSystem(http.FS(viewsfs), ".html")
engine.AddFuncMap(sprig.FuncMap())
@@ -40,5 +76,5 @@ func renderEngine() *fiberhtml.Engine {
func markDowner(args ...interface{}) template.HTML {
s := blackfriday.MarkdownCommon([]byte(fmt.Sprintf("%s", args...)))
return template.HTML(bluemonday.UGCPolicy().Sanitize(string(s)))
return template.HTML(s)
}

View File

@@ -1,19 +0,0 @@
package routes
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func RegisterElevenLabsRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
}

View File

@@ -1,19 +0,0 @@
package routes
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/http/endpoints/jina"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func RegisterJINARoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {
// POST endpoint to mimic the reranking
app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))
}

View File

@@ -1,65 +0,0 @@
package routes
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/swagger"
)
func RegisterLocalAIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService,
auth func(*fiber.Ctx) error) {
app.Get("/swagger/*", swagger.HandlerDefault) // default
// LocalAI API endpoints
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Post("/models/delete/:name", auth, modelGalleryEndpointService.DeleteModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
// Stores
sl := model.NewModelLoader("")
app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))
// Kubernetes health checks
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
app.Get("/healthz", ok)
app.Get("/readyz", ok)
app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint())
// Experimental Backend Statistics Module
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitorService))
app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitorService))
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
}
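As a small usage sketch for the health and version routes registered above (assuming a local instance on the default port; not part of this diff):
package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// /healthz and /readyz return 200 when the server is up; /version returns a JSON body.
	for _, path := range []string{"/healthz", "/readyz", "/version"} {
		resp, err := http.Get("http://localhost:8080" + path)
		if err != nil {
			panic(err)
		}
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		fmt.Println(path, resp.StatusCode, string(body))
	}
}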

View File

@@ -1,88 +0,0 @@
package routes
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func RegisterOpenAIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
// assistant
app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
// files
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
if appConfig.ImageDir != "" {
app.Static("/generated-images", appConfig.ImageDir)
}
if appConfig.AudioDir != "" {
app.Static("/generated-audio", appConfig.AudioDir)
}
// models
tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance.
app.Get("/v1/models", auth, openai.ListModelsEndpoint(tmpLMS))
app.Get("/models", auth, openai.ListModelsEndpoint(tmpLMS))
}

View File

@@ -1,273 +0,0 @@
package routes
import (
"fmt"
"html/template"
"strings"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/http/elements"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/xsync"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
)
func RegisterUIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService,
auth func(*fiber.Ctx) error) {
app.Get("/", auth, localai.WelcomeEndpoint(appConfig, cl, ml))
// keeps the state of models that are being installed from the UI
var installingModels = xsync.NewSyncedMap[string, string]()
// Show the Models page (all models)
app.Get("/browse", auth, func(c *fiber.Ctx) error {
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
summary := fiber.Map{
"Title": "LocalAI - Models",
"Version": internal.PrintableVersion(),
"Models": template.HTML(elements.ListModels(models, installingModels)),
"Repositories": appConfig.Galleries,
// "ApplicationConfig": appConfig,
}
// Render index
return c.Render("views/models", summary)
})
// Show the models, filtered from the user input
// https://htmx.org/examples/active-search/
app.Post("/browse/search/models", auth, func(c *fiber.Ctx) error {
form := struct {
Search string `form:"search"`
}{}
if err := c.BodyParser(&form); err != nil {
return c.Status(fiber.StatusBadRequest).SendString(err.Error())
}
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
filteredModels := []*gallery.GalleryModel{}
for _, m := range models {
if strings.Contains(m.Name, form.Search) ||
strings.Contains(m.Description, form.Search) ||
strings.Contains(m.Gallery.Name, form.Search) ||
strings.Contains(strings.Join(m.Tags, ","), form.Search) {
filteredModels = append(filteredModels, m)
}
}
return c.SendString(elements.ListModels(filteredModels, installingModels))
})
/*
Install routes
*/
// This route is used when the "Install" button is pressed; here we submit a new job to the gallery service
// https://htmx.org/examples/progress-bar/
app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
id, err := uuid.NewUUID()
if err != nil {
return err
}
uid := id.String()
installingModels.Set(galleryID, uid)
op := gallery.GalleryOp{
Id: uid,
GalleryName: galleryID,
Galleries: appConfig.Galleries,
}
go func() {
galleryService.C <- op
}()
return c.SendString(elements.StartProgressBar(uid, "0", "Installation"))
})
// This route is used when the "Delete" button is pressed; here we submit a new deletion job to the gallery service
// https://htmx.org/examples/progress-bar/
app.Post("/browse/delete/model/:id", auth, func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
id, err := uuid.NewUUID()
if err != nil {
return err
}
uid := id.String()
installingModels.Set(galleryID, uid)
op := gallery.GalleryOp{
Id: uid,
Delete: true,
GalleryName: galleryID,
}
go func() {
galleryService.C <- op
}()
return c.SendString(elements.StartProgressBar(uid, "0", "Deletion"))
})
// Display the job current progress status
// If the job is done, we trigger the /browse/job/:uid route
// https://htmx.org/examples/progress-bar/
app.Get("/browse/job/progress/:uid", auth, func(c *fiber.Ctx) error {
jobUID := c.Params("uid")
status := galleryService.GetStatus(jobUID)
if status == nil {
//fmt.Errorf("could not find any status for ID")
return c.SendString(elements.ProgressBar("0"))
}
if status.Progress == 100 {
c.Set("HX-Trigger", "done") // this triggers /browse/job/:uid (which is when the job is done)
return c.SendString(elements.ProgressBar("100"))
}
if status.Error != nil {
return c.SendString(elements.ErrorProgress(status.Error.Error()))
}
return c.SendString(elements.ProgressBar(fmt.Sprint(status.Progress)))
})
// this route is hit when the job is done, and we display the
// final state (for now just displays "Installation completed")
app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error {
status := galleryService.GetStatus(c.Params("uid"))
for _, k := range installingModels.Keys() {
if installingModels.Get(k) == c.Params("uid") {
installingModels.Delete(k)
}
}
displayText := "Installation completed"
if status.Deletion {
displayText = "Deletion completed"
}
return c.SendString(elements.DoneProgress(c.Params("uid"), displayText))
})
// Show the Chat page
app.Get("/chat/:model", auth, func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
summary := fiber.Map{
"Title": "LocalAI - Chat with " + c.Params("model"),
"ModelsConfig": backendConfigs,
"Model": c.Params("model"),
"Version": internal.PrintableVersion(),
}
// Render index
return c.Render("views/chat", summary)
})
app.Get("/chat/", auth, func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
if len(backendConfigs) == 0 {
return c.SendString("No models available")
}
summary := fiber.Map{
"Title": "LocalAI - Chat with " + backendConfigs[0].Name,
"ModelsConfig": backendConfigs,
"Model": backendConfigs[0].Name,
"Version": internal.PrintableVersion(),
}
// Render index
return c.Render("views/chat", summary)
})
app.Get("/text2image/:model", auth, func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
summary := fiber.Map{
"Title": "LocalAI - Generate images with " + c.Params("model"),
"ModelsConfig": backendConfigs,
"Model": c.Params("model"),
"Version": internal.PrintableVersion(),
}
// Render index
return c.Render("views/text2image", summary)
})
app.Get("/text2image/", auth, func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
if len(backendConfigs) == 0 {
return c.SendString("No models available")
}
summary := fiber.Map{
"Title": "LocalAI - Generate images with " + backendConfigs[0].Name,
"ModelsConfig": backendConfigs,
"Model": backendConfigs[0].Name,
"Version": internal.PrintableVersion(),
}
// Render index
return c.Render("views/text2image", summary)
})
app.Get("/tts/:model", auth, func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
summary := fiber.Map{
"Title": "LocalAI - Generate images with " + c.Params("model"),
"ModelsConfig": backendConfigs,
"Model": c.Params("model"),
"Version": internal.PrintableVersion(),
}
// Render index
return c.Render("views/tts", summary)
})
app.Get("/tts/", auth, func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
if len(backendConfigs) == 0 {
return c.SendString("No models available")
}
summary := fiber.Map{
"Title": "LocalAI - Generate audio with " + backendConfigs[0].Name,
"ModelsConfig": backendConfigs,
"Model": backendConfigs[0].Name,
"Version": internal.PrintableVersion(),
}
// Render index
return c.Render("views/tts", summary)
})
}

View File

@@ -1,137 +0,0 @@
/*
https://github.com/david-haerer/chatapi
MIT License
Copyright (c) 2023 David Härer
Copyright (c) 2024 Ettore Di Giacinto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
function submitKey(event) {
event.preventDefault();
localStorage.setItem("key", document.getElementById("apiKey").value);
document.getElementById("apiKey").blur();
}
function submitPrompt(event) {
event.preventDefault();
const input = document.getElementById("input").value;
Alpine.store("chat").add("user", input);
document.getElementById("input").value = "";
const key = localStorage.getItem("key");
promptGPT(key, input);
}
async function promptGPT(key, input) {
const model = document.getElementById("chat-model").value;
// Set class "loader" to the element with "loader" id
//document.getElementById("loader").classList.add("loader");
// Make the "loader" visible
document.getElementById("loader").style.display = "block";
document.getElementById("input").disabled = true;
document.getElementById('messages').scrollIntoView(false)
// Source: https://stackoverflow.com/a/75751803/11386095
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: {
Authorization: `Bearer ${key}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
model: model,
messages: Alpine.store("chat").messages(),
stream: true,
}),
});
if (!response.ok) {
Alpine.store("chat").add(
"assistant",
`<span class='error'>Error: POST /v1/chat/completions ${response.status}</span>`,
);
return;
}
const reader = response.body
?.pipeThrough(new TextDecoderStream())
.getReader();
if (!reader) {
Alpine.store("chat").add(
"assistant",
`<span class='error'>Error: Failed to decode API response</span>`,
);
return;
}
while (true) {
const { value, done } = await reader.read();
if (done) break;
let dataDone = false;
const arr = value.split("\n");
arr.forEach((data) => {
if (data.length === 0) return;
if (data.startsWith(":")) return;
if (data === "data: [DONE]") {
dataDone = true;
return;
}
const token = JSON.parse(data.substring(6)).choices[0].delta.content;
if (!token) {
return;
}
hljs.highlightAll();
Alpine.store("chat").add("assistant", token);
document.getElementById('messages').scrollIntoView(false)
});
hljs.highlightAll();
if (dataDone) break;
}
// Remove class "loader" from the element with "loader" id
//document.getElementById("loader").classList.remove("loader");
document.getElementById("loader").style.display = "none";
// enable input
document.getElementById("input").disabled = false;
// scroll to the bottom of the chat
document.getElementById('messages').scrollIntoView(false)
// set focus to the input
document.getElementById("input").focus();
}
document.getElementById("key").addEventListener("submit", submitKey);
document.getElementById("prompt").addEventListener("submit", submitPrompt);
document.getElementById("input").focus();
const storeKey = localStorage.getItem("key");
if (storeKey) {
document.getElementById("apiKey").value = storeKey;
}
marked.setOptions({
highlight: function (code) {
return hljs.highlightAuto(code).value;
},
});

View File

@@ -1,93 +0,0 @@
body {
font-family: 'Inter', sans-serif;
}
.chat-container { height: 90vh; display: flex; flex-direction: column; }
.chat-messages { overflow-y: auto; flex-grow: 1; }
.htmx-indicator{
opacity:0;
transition: opacity 10ms ease-in;
}
.htmx-request .htmx-indicator{
opacity:1
}
/* Loader (https://cssloaders.github.io/) */
.loader {
width: 12px;
height: 12px;
border-radius: 50%;
display: block;
margin:15px auto;
position: relative;
color: #FFF;
box-sizing: border-box;
animation: animloader 2s linear infinite;
}
@keyframes animloader {
0% { box-shadow: 14px 0 0 -2px, 38px 0 0 -2px, -14px 0 0 -2px, -38px 0 0 -2px; }
25% { box-shadow: 14px 0 0 -2px, 38px 0 0 -2px, -14px 0 0 -2px, -38px 0 0 2px; }
50% { box-shadow: 14px 0 0 -2px, 38px 0 0 -2px, -14px 0 0 2px, -38px 0 0 -2px; }
75% { box-shadow: 14px 0 0 2px, 38px 0 0 -2px, -14px 0 0 -2px, -38px 0 0 -2px; }
100% { box-shadow: 14px 0 0 -2px, 38px 0 0 2px, -14px 0 0 -2px, -38px 0 0 -2px; }
}
.progress {
height: 20px;
margin-bottom: 20px;
overflow: hidden;
background-color: #f5f5f5;
border-radius: 4px;
box-shadow: inset 0 1px 2px rgba(0,0,0,.1);
}
.progress-bar {
float: left;
width: 0%;
height: 100%;
font-size: 12px;
line-height: 20px;
color: #fff;
text-align: center;
background-color: #337ab7;
-webkit-box-shadow: inset 0 -1px 0 rgba(0,0,0,.15);
box-shadow: inset 0 -1px 0 rgba(0,0,0,.15);
-webkit-transition: width .6s ease;
-o-transition: width .6s ease;
transition: width .6s ease;
}
.user {
background-color: #007bff;
}
.assistant {
background-color: #28a745;
}
.message {
display: flex;
align-items: center;
}
.user, .assistant {
flex-grow: 1;
margin: 0.5rem;
}
ul {
list-style-type: disc; /* Adds bullet points */
padding-left: 1.25rem; /* Indents the list from the left margin */
margin-top: 1rem; /* Space above the list */
}
li {
font-size: 0.875rem; /* Small text size */
color: #4a5568; /* Dark gray text */
background-color: #f7fafc; /* Very light gray background */
border-radius: 0.375rem; /* Rounded corners */
padding: 0.5rem; /* Padding inside each list item */
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1), 0 1px 2px 0 rgba(0, 0, 0, 0.06); /* Subtle shadow */
margin-bottom: 0.5rem; /* Vertical space between list items */
}
li:last-child {
margin-bottom: 0; /* Removes bottom margin from the last item */
}

View File

@@ -1,96 +0,0 @@
/*
https://github.com/david-haerer/chatapi
MIT License
Copyright (c) 2023 David Härer
Copyright (c) 2024 Ettore Di Giacinto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
function submitKey(event) {
event.preventDefault();
localStorage.setItem("key", document.getElementById("apiKey").value);
document.getElementById("apiKey").blur();
}
function genImage(event) {
event.preventDefault();
const input = document.getElementById("input").value;
const key = localStorage.getItem("key");
promptDallE(key, input);
}
async function promptDallE(key, input) {
document.getElementById("loader").style.display = "block";
document.getElementById("input").value = "";
document.getElementById("input").disabled = true;
const model = document.getElementById("image-model").value;
const response = await fetch("/v1/images/generations", {
method: "POST",
headers: {
Authorization: `Bearer ${key}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
model: model,
steps: 10,
prompt: input,
n: 1,
size: "512x512",
}),
});
const json = await response.json();
if (json.error) {
// Display error if there is one
var div = document.getElementById('result'); // Get the div by its ID
div.innerHTML = '<p style="color:red;">' + json.error.message + '</p>';
return;
}
const url = json.data[0].url;
var div = document.getElementById('result'); // Get the div by its ID
var img = document.createElement('img'); // Create a new img element
img.src = url; // Set the source of the image
img.alt = 'Generated image'; // Set the alt text of the image
div.innerHTML = ''; // Clear the existing content of the div
div.appendChild(img); // Add the new img element to the div
document.getElementById("loader").style.display = "none";
document.getElementById("input").disabled = false;
document.getElementById("input").focus();
}
document.getElementById("key").addEventListener("submit", submitKey);
document.getElementById("input").focus();
document.getElementById("genimage").addEventListener("submit", genImage);
document.getElementById("loader").style.display = "none";
const storeKey = localStorage.getItem("key");
if (storeKey) {
document.getElementById("apiKey").value = storeKey;
}

View File

@@ -1,64 +0,0 @@
function submitKey(event) {
event.preventDefault();
localStorage.setItem("key", document.getElementById("apiKey").value);
document.getElementById("apiKey").blur();
}
function genAudio(event) {
event.preventDefault();
const input = document.getElementById("input").value;
const key = localStorage.getItem("key");
tts(key, input);
}
async function tts(key, input) {
document.getElementById("loader").style.display = "block";
document.getElementById("input").value = "";
document.getElementById("input").disabled = true;
const model = document.getElementById("tts-model").value;
const response = await fetch("/tts", {
method: "POST",
headers: {
Authorization: `Bearer ${key}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
model: model,
input: input,
}),
});
if (!response.ok) {
const jsonData = await response.json(); // Now safely parse JSON
var div = document.getElementById('result');
div.innerHTML = '<p style="color:red;">Error: ' +jsonData.error.message + '</p>';
return;
}
var div = document.getElementById('result'); // Get the div by its ID
var link=document.createElement('a');
link.className = "m-2 float-right inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong";
link.innerHTML = "<i class='fa-solid fa-download'></i> Download result";
const blob = await response.blob();
link.href=window.URL.createObjectURL(blob);
div.innerHTML = ''; // Clear the existing content of the div
div.appendChild(link); // Add the new img element to the div
console.log(link)
document.getElementById("loader").style.display = "none";
document.getElementById("input").disabled = false;
document.getElementById("input").focus();
}
document.getElementById("key").addEventListener("submit", submitKey);
document.getElementById("input").focus();
document.getElementById("tts").addEventListener("submit", genAudio);
document.getElementById("loader").style.display = "none";
const storeKey = localStorage.getItem("key");
if (storeKey) {
document.getElementById("apiKey").value = storeKey;
}

View File

@@ -1,202 +0,0 @@
<!--
Part of this page is based on the OpenAI Chatbot example by David Härer:
https://github.com/david-haerer/chatapi
MIT License Copyright (c) 2023 David Härer
Copyright (c) 2024 Ettore Di Giacinto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-->
<!doctype html>
<html lang="en">
{{template "views/partials/head" .}}
<script defer src="/static/chat.js"></script>
<style>
body {
overflow: hidden;
}
</style>
<body class="bg-gray-900 text-gray-200" x-data="{ key: $store.chat.key }">
<div class="flex flex-col min-h-screen">
{{template "views/partials/navbar"}}
<div class="chat-container mt-2 mr-2 ml-2 mb-2 bg-gray-800 shadow-lg rounded-lg" >
<!-- Chat Header -->
<div class="border-b border-gray-700 p-4" x-data="{ component: 'menu' }">
<div class="flex items-center justify-between">
<h1 class="text-lg font-semibold"> <i class="fa-solid fa-comments"></i> Chat with {{.Model}} <a href="https://localai.io/features/text-generation/" target="_blank" >
<i class="fas fa-circle-info pr-2"></i>
</a></h1>
<div x-show="component === 'menu'" id="menu">
<button
@click="$store.chat.clear()"
id="clear"
title="Clear chat history"
data-twe-ripple-init
data-twe-ripple-color="light"
class="m-2 float-right inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong"
>
Clear chat 🔥
</button>
<button @click="component = 'key'" title="Update API key"
class="m-2 float-right inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong"
>Set API Key🔑</button>
</div>
<form x-show="component === 'key'" id="key">
<input
type="password"
id="apiKey"
name="apiKey"
placeholder="OpenAI API Key"
x-model.lazy="key"
/>
<button @click="component = 'menu'" type="submit" title="Save API key">
🔒
</button>
</form>
<select x-data="{ link : '' }" x-model="link" x-init="$watch('link', value => window.location = link)"
class="bg-gray-800 text-white border border-gray-600 focus:border-blue-500 focus:ring focus:ring-blue-500 focus:ring-opacity-50 rounded-md shadow-sm p-2 appearance-none"
>
<!-- Options -->
<option value="" disabled class="text-gray-400" >Select a model</option>
{{ $model:=.Model}}
{{ range .ModelsConfig }}
{{ if eq .Name $model }}
<option value="/chat/{{.Name}}" selected class="bg-gray-700 text-white">{{.Name}}</option>
{{ else }}
<option value="/chat/{{.Name}}" class="bg-gray-700 text-white">{{.Name}}</option>
{{ end }}
{{ end }}
</select>
</div>
</div>
<div class="chat-messages p-4" id="chat" x-data="{history: $store.chat.history}">
<p id="usage" x-show="history.length === 0">
Start chatting with the AI by typing a prompt in the input field below.
</p>
<div id="messages">
<template x-for="message in history">
<div class="message flex items-start space-x-2 my-2" >
<!--<img :src="message.role === 'user' ? '/path/to/user-icon.png' : '/path/to/bot-icon.png'" alt="" class="h-6 w-6">-->
<i class="fa-solid h-8 w-8" :class="message.role === 'user' ? 'fa-user' : 'fa-robot'" ></i>
<div class="flex flex-col flex-1">
<span class="text-xs font-semibold text-gray-600" x-text="message.role === 'user' ? 'User' : 'Assistant ({{.Model}})'"></span>
<template x-if="message.role === 'user'">
<div class="p-2 flex-1 rounded" :class="message.role" x-html="message.html"></div>
</template>
<template x-if="message.role === 'assistant'">
<div class="p-2 flex-1 rounded" :class="message.role" x-html="message.html"></div>
</template>
</div>
</div>
</template>
</div>
</div>
<div class="p-4 border-t border-gray-700" x-data="{ inputValue: '', shiftPressed: false }">
<div id="loader" class="my-2 loader" style="display: none;"></div>
<input id="chat-model" type="hidden" value="{{.Model}}">
<form id="prompt" action="/chat/{{.Model}}" method="get" @submit.prevent="submitPrompt">
<div class="relative w-full">
<textarea
id="input"
name="input"
x-model="inputValue"
placeholder="Send a message..."
class="p-2 pl-2 border rounded w-full bg-gray-600 text-white placeholder-gray-300"
required
@keydown.shift="shiftPressed = true"
@keyup.shift="shiftPressed = false"
@keydown.enter="if (!shiftPressed) { submitPrompt($event); }"
style="padding-right: 4rem;"
></textarea>
<button type=submit><i class="fa-solid fa-circle-up text-gray-300 absolute right-2 top-3 text-lg p-2 ml-2"></i></button>
</div>
</form>
</div>
<script>
document.addEventListener("alpine:init", () => {
Alpine.store("chat", {
history: [],
languages: [undefined],
clear() {
this.history.length = 0;
},
add(role, content) {
const N = this.history.length - 1;
if (this.history.length && this.history[N].role === role) {
this.history[N].content += content;
str = this.history[N].content;
this.history[N].html = DOMPurify.sanitize(
marked.parse(this.history[N].content),
);
} else {
c = ""
// split content into lines by newline
const lines = content.split("\n");
// for each line, do DOMPurify.sanitize(marked.parse(line)) and add it to c
lines.forEach((line) => {
c += DOMPurify.sanitize(marked.parse(line));
});
this.history.push({
role: role,
content: content,
html: c,
});
}
const parser = new DOMParser();
const html = parser.parseFromString(
this.history[this.history.length - 1].html,
"text/html",
);
const code = html.querySelectorAll("pre code");
if (!code.length) return;
code.forEach((el) => {
const language = el.className.split("language-")[1];
if (this.languages.includes(language)) return;
const script = document.createElement("script");
script.src = `https://cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.8.0/build/languages/${language}.min.js`;
document.head.appendChild(script);
this.languages.push(language);
});
},
messages() {
return this.history.map((message) => {
return {
role: message.role,
content: message.content,
};
});
},
});
});
</script>
</div>
</body>
</html>

View File

@@ -1,34 +0,0 @@
<!DOCTYPE html>
<html lang="en">
{{template "views/partials/head" .}}
<body class="bg-gray-900 text-gray-200">
<div class="flex flex-col min-h-screen">
{{template "views/partials/navbar" .}}
<div class="container mx-auto px-4 flex-grow">
<div class="models mt-12">
<h2 class="text-center text-3xl font-semibold text-gray-100">
🖼️ Available models from <i>{{ len .Repositories }}</i> repositories <a href="https://localai.io/models/" target="_blank" >
<i class="fas fa-circle-info pr-2"></i>
</a></h2>
<span class="htmx-indicator loader"></span>
<input class="form-control appearance-none block w-full px-3 py-2 text-base font-normal text-gray-300 pb-2 mb-5 bg-gray-800 bg-clip-padding border border-solid border-gray-600 rounded transition ease-in-out m-0 focus:text-gray-300 focus:bg-gray-900 focus:border-blue-500 focus:outline-none" type="search"
name="search" placeholder="Begin Typing To Search models..."
hx-post="/browse/search/models"
hx-trigger="input changed delay:500ms, search"
hx-target="#search-results"
hx-indicator=".htmx-indicator">
<div id="search-results">{{.Models}}</div>
</div>
</div>
{{template "views/partials/footer" .}}
</div>
</body>
</html>

Some files were not shown because too many files have changed in this diff.