Mirror of https://github.com/mudler/LocalAI.git
Compare commits: v2.3.0 ... enable_gpu (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | a8e91345e2 | |
| | ea4ade6b60 | |
| | 796d0c99aa | |
.env (17 changed lines)

@@ -70,20 +70,3 @@ MODELS_PATH=/models
 
 ### Define the number of parallel LLAMA.cpp workers (Defaults to 1)
 # LLAMACPP_PARALLEL=1
-
-### Enable to run parallel requests
-# PARALLEL_REQUESTS=true
-
-### Watchdog settings
-###
-# Enables watchdog to kill backends that are inactive for too much time
-# WATCHDOG_IDLE=true
-#
-# Enables watchdog to kill backends that are busy for too much time
-# WATCHDOG_BUSY=true
-#
-# Time in duration format (e.g. 1h30m) after which a backend is considered idle
-# WATCHDOG_IDLE_TIMEOUT=5m
-#
-# Time in duration format (e.g. 1h30m) after which a backend is considered busy
-# WATCHDOG_BUSY_TIMEOUT=5m
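For reference, this is roughly what the parallel-request and watchdog options on the v2.3.0 side look like when they are actually enabled in a `.env`. A minimal sketch; the values below are illustrative, not project defaults:

```bash
# Hypothetical .env excerpt enabling parallel requests and both watchdogs
LLAMACPP_PARALLEL=4        # number of parallel llama.cpp workers (example value)
PARALLEL_REQUESTS=true
WATCHDOG_IDLE=true
WATCHDOG_IDLE_TIMEOUT=5m   # duration format, e.g. 1h30m
WATCHDOG_BUSY=true
WATCHDOG_BUSY_TIMEOUT=5m
```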
.github/bump_docs.sh (7 lines removed in this comparison)

@@ -1,7 +0,0 @@
-#!/bin/bash
-set -xe
-REPO=$1
-
-LATEST_TAG=$(curl -s "https://api.github.com/repos/$REPO/releases/latest" | jq -r '.name')
-
-cat <<< $(jq ".version = \"$LATEST_TAG\"" docs/data/version.json) > docs/data/version.json
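The workflow below calls this script with a repository slug taken from its matrix. A minimal sketch of the equivalent manual invocation from the repository root (the `mudler/LocalAI` value comes from that matrix):

```bash
# Fetch the latest release name and write it into docs/data/version.json,
# mirroring what .github/bump_docs.sh does.
REPO=mudler/LocalAI
LATEST_TAG=$(curl -s "https://api.github.com/repos/$REPO/releases/latest" | jq -r '.name')
cat <<< "$(jq ".version = \"$LATEST_TAG\"" docs/data/version.json)" > docs/data/version.json
```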
.github/workflows/bump_docs.yaml (31 lines removed in this comparison)

@@ -1,31 +0,0 @@
-name: Bump dependencies
-on:
-  schedule:
-    - cron: 0 20 * * *
-  workflow_dispatch:
-jobs:
-  bump:
-    strategy:
-      fail-fast: false
-      matrix:
-        include:
-          - repository: "mudler/LocalAI"
-    runs-on: ubuntu-latest
-    steps:
-    - uses: actions/checkout@v4
-    - name: Bump dependencies 🔧
-      run: |
-        bash .github/bump_docs.sh ${{ matrix.repository }}
-    - name: Create Pull Request
-      uses: peter-evans/create-pull-request@v5
-      with:
-        token: ${{ secrets.UPDATE_BOT_TOKEN }}
-        push-to-fork: ci-forks/LocalAI
-        commit-message: ':arrow_up: Update docs version ${{ matrix.repository }}'
-        title: ':arrow_up: Update docs version ${{ matrix.repository }}'
-        branch: "update/docs"
-        body: Bump of ${{ matrix.repository }} version inside docs
-        signoff: true
-
-
-
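Besides the daily schedule, the workflow exposes `workflow_dispatch`, so it can also be started by hand. A hypothetical manual trigger with the GitHub CLI (assumes an authenticated `gh` against the repository):

```bash
# Kick off the docs version bump on demand instead of waiting for the cron.
gh workflow run bump_docs.yaml
```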
.github/workflows/image.yml (210 changed lines)

@@ -14,25 +14,8 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
-  extras-image-build:
-    uses: ./.github/workflows/image_build.yml
-    with:
-      tag-latest: ${{ matrix.tag-latest }}
-      tag-suffix: ${{ matrix.tag-suffix }}
-      ffmpeg: ${{ matrix.ffmpeg }}
-      image-type: ${{ matrix.image-type }}
-      build-type: ${{ matrix.build-type }}
-      cuda-major-version: ${{ matrix.cuda-major-version }}
-      cuda-minor-version: ${{ matrix.cuda-minor-version }}
-      platforms: ${{ matrix.platforms }}
-      runs-on: ${{ matrix.runs-on }}
-    secrets:
-      dockerUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
-      dockerPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+  image-build:
     strategy:
-      # Pushing with all jobs in parallel
-      # eats the bandwidth of all the nodes
-      max-parallel: ${{ github.event_name != 'pull_request' && 2 || 4 }}
       matrix:
         include:
           - build-type: ''
@@ -41,117 +24,130 @@ jobs:
             tag-latest: 'auto'
             tag-suffix: ''
             ffmpeg: ''
-            image-type: 'extras'
-            runs-on: 'arc-runner-set'
           - build-type: ''
             platforms: 'linux/amd64'
             tag-latest: 'false'
             tag-suffix: '-ffmpeg'
             ffmpeg: 'true'
-            image-type: 'extras'
-            runs-on: 'arc-runner-set'
           - build-type: 'cublas'
-            cuda-major-version: "11"
-            cuda-minor-version: "7"
+            cuda-major-version: 11
+            cuda-minor-version: 7
             platforms: 'linux/amd64'
             tag-latest: 'false'
             tag-suffix: '-cublas-cuda11'
             ffmpeg: ''
-            image-type: 'extras'
-            runs-on: 'arc-runner-set'
           - build-type: 'cublas'
-            cuda-major-version: "12"
-            cuda-minor-version: "1"
+            cuda-major-version: 12
+            cuda-minor-version: 1
             platforms: 'linux/amd64'
             tag-latest: 'false'
             tag-suffix: '-cublas-cuda12'
             ffmpeg: ''
-            image-type: 'extras'
-            runs-on: 'arc-runner-set'
           - build-type: 'cublas'
-            cuda-major-version: "11"
-            cuda-minor-version: "7"
+            cuda-major-version: 11
+            cuda-minor-version: 7
             platforms: 'linux/amd64'
             tag-latest: 'false'
             tag-suffix: '-cublas-cuda11-ffmpeg'
             ffmpeg: 'true'
-            image-type: 'extras'
-            runs-on: 'arc-runner-set'
           - build-type: 'cublas'
-            cuda-major-version: "12"
-            cuda-minor-version: "1"
+            cuda-major-version: 12
+            cuda-minor-version: 1
             platforms: 'linux/amd64'
             tag-latest: 'false'
             tag-suffix: '-cublas-cuda12-ffmpeg'
             ffmpeg: 'true'
-            image-type: 'extras'
-            runs-on: 'arc-runner-set'
-          - build-type: ''
-            #platforms: 'linux/amd64,linux/arm64'
-            platforms: 'linux/amd64'
-            tag-latest: 'auto'
-            tag-suffix: ''
-            ffmpeg: ''
-            image-type: 'extras'
-            runs-on: 'arc-runner-set'
-  core-image-build:
-    uses: ./.github/workflows/image_build.yml
-    with:
-      tag-latest: ${{ matrix.tag-latest }}
-      tag-suffix: ${{ matrix.tag-suffix }}
-      ffmpeg: ${{ matrix.ffmpeg }}
-      image-type: ${{ matrix.image-type }}
-      build-type: ${{ matrix.build-type }}
-      cuda-major-version: ${{ matrix.cuda-major-version }}
-      cuda-minor-version: ${{ matrix.cuda-minor-version }}
-      platforms: ${{ matrix.platforms }}
-      runs-on: ${{ matrix.runs-on }}
-    secrets:
-      dockerUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
-      dockerPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
-    strategy:
-      matrix:
-        include:
-          - build-type: ''
-            platforms: 'linux/amd64'
-            tag-latest: 'false'
-            tag-suffix: '-ffmpeg-core'
-            ffmpeg: 'true'
-            image-type: 'core'
-            runs-on: 'ubuntu-latest'
-          - build-type: 'cublas'
-            cuda-major-version: "11"
-            cuda-minor-version: "7"
-            platforms: 'linux/amd64'
-            tag-latest: 'false'
-            tag-suffix: '-cublas-cuda11-core'
-            ffmpeg: ''
-            image-type: 'core'
-            runs-on: 'ubuntu-latest'
-          - build-type: 'cublas'
-            cuda-major-version: "12"
-            cuda-minor-version: "1"
-            platforms: 'linux/amd64'
-            tag-latest: 'false'
-            tag-suffix: '-cublas-cuda12-core'
-            ffmpeg: ''
-            image-type: 'core'
-            runs-on: 'ubuntu-latest'
-          - build-type: 'cublas'
-            cuda-major-version: "11"
-            cuda-minor-version: "7"
-            platforms: 'linux/amd64'
-            tag-latest: 'false'
-            tag-suffix: '-cublas-cuda11-ffmpeg-core'
-            ffmpeg: 'true'
-            image-type: 'core'
-            runs-on: 'ubuntu-latest'
-          - build-type: 'cublas'
-            cuda-major-version: "12"
-            cuda-minor-version: "1"
-            platforms: 'linux/amd64'
-            tag-latest: 'false'
-            tag-suffix: '-cublas-cuda12-ffmpeg-core'
-            ffmpeg: 'true'
-            image-type: 'core'
-            runs-on: 'ubuntu-latest'
+    runs-on: arc-runner-set
+    steps:
+      - name: Force Install GIT latest
+        run: |
+          sudo apt-get update \
+          && sudo apt-get install -y software-properties-common \
+          && sudo apt-get update \
+          && sudo add-apt-repository -y ppa:git-core/ppa \
+          && sudo apt-get update \
+          && sudo apt-get install -y git
+      - name: Checkout
+        uses: actions/checkout@v4
+      # - name: Release space from worker
+      #   run: |
+      #     echo "Listing top largest packages"
+      #     pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+      #     head -n 30 <<< "${pkgs}"
+      #     echo
+      #     df -h
+      #     echo
+      #     sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
+      #     sudo apt-get remove --auto-remove android-sdk-platform-tools || true
+      #     sudo apt-get purge --auto-remove android-sdk-platform-tools || true
+      #     sudo rm -rf /usr/local/lib/android
+      #     sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
+      #     sudo rm -rf /usr/share/dotnet
+      #     sudo apt-get remove -y '^mono-.*' || true
+      #     sudo apt-get remove -y '^ghc-.*' || true
+      #     sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
+      #     sudo apt-get remove -y 'php.*' || true
+      #     sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
+      #     sudo apt-get remove -y '^google-.*' || true
+      #     sudo apt-get remove -y azure-cli || true
+      #     sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
+      #     sudo apt-get remove -y '^gfortran-.*' || true
+      #     sudo apt-get remove -y microsoft-edge-stable || true
+      #     sudo apt-get remove -y firefox || true
+      #     sudo apt-get remove -y powershell || true
+      #     sudo apt-get remove -y r-base-core || true
+      #     sudo apt-get autoremove -y
+      #     sudo apt-get clean
+      #     echo
+      #     echo "Listing top largest packages"
+      #     pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
+      #     head -n 30 <<< "${pkgs}"
+      #     echo
+      #     sudo rm -rfv build || true
+      #     df -h
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: quay.io/go-skynet/local-ai
+          tags: |
+            type=ref,event=branch
+            type=semver,pattern={{raw}}
+            type=sha
+          flavor: |
+            latest=${{ matrix.tag-latest }}
+            suffix=${{ matrix.tag-suffix }}
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@master
+        with:
+          platforms: all
+
+      - name: Set up Docker Buildx
+        id: buildx
+        uses: docker/setup-buildx-action@master
+
+      - name: Login to DockerHub
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v3
+        with:
+          registry: quay.io
+          username: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+          password: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+
+      - name: Build and push
+        uses: docker/build-push-action@v5
+        with:
+          builder: ${{ steps.buildx.outputs.name }}
+          build-args: |
+            BUILD_TYPE=${{ matrix.build-type }}
+            CUDA_MAJOR_VERSION=${{ matrix.cuda-major-version }}
+            CUDA_MINOR_VERSION=${{ matrix.cuda-minor-version }}
+            FFMPEG=${{ matrix.ffmpeg }}
+          context: .
+          file: ./Dockerfile
+          platforms: ${{ matrix.platforms }}
+          push: ${{ github.event_name != 'pull_request' }}
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
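With the metadata-action configuration above, published tags are composed of the git ref (branch, semver tag, or SHA) plus the per-matrix suffix. A hypothetical pull of one such image; the exact tag names depend on what the action produced for a given push:

```bash
# Illustrative only: semver ref v2.3.0 combined with the '-cublas-cuda12-ffmpeg' matrix suffix.
docker pull quay.io/go-skynet/local-ai:v2.3.0-cublas-cuda12-ffmpeg
```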
.github/workflows/image_build.yml (147 lines removed in this comparison)

@@ -1,147 +0,0 @@
----
-name: 'build container images (reusable)'
-
-on:
-  workflow_call:
-    inputs:
-      build-type:
-        description: 'Build type'
-        default: ''
-        type: string
-      cuda-major-version:
-        description: 'CUDA major version'
-        default: "11"
-        type: string
-      cuda-minor-version:
-        description: 'CUDA minor version'
-        default: "7"
-        type: string
-      platforms:
-        description: 'Platforms'
-        default: ''
-        type: string
-      tag-latest:
-        description: 'Tag latest'
-        default: ''
-        type: string
-      tag-suffix:
-        description: 'Tag suffix'
-        default: ''
-        type: string
-      ffmpeg:
-        description: 'FFMPEG'
-        default: ''
-        type: string
-      image-type:
-        description: 'Image type'
-        default: ''
-        type: string
-      runs-on:
-        description: 'Runs on'
-        required: true
-        default: ''
-        type: string
-    secrets:
-      dockerUsername:
-        required: true
-      dockerPassword:
-        required: true
-jobs:
-  reusable_image-build:
-    runs-on: ${{ inputs.runs-on }}
-    steps:
-      - name: Force Install GIT latest
-        run: |
-          sudo apt-get update \
-          && sudo apt-get install -y software-properties-common \
-          && sudo apt-get update \
-          && sudo add-apt-repository -y ppa:git-core/ppa \
-          && sudo apt-get update \
-          && sudo apt-get install -y git
-      - name: Checkout
-        uses: actions/checkout@v4
-      # - name: Release space from worker
-      #   run: |
-      #     echo "Listing top largest packages"
-      #     pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
-      #     head -n 30 <<< "${pkgs}"
-      #     echo
-      #     df -h
-      #     echo
-      #     sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
-      #     sudo apt-get remove --auto-remove android-sdk-platform-tools || true
-      #     sudo apt-get purge --auto-remove android-sdk-platform-tools || true
-      #     sudo rm -rf /usr/local/lib/android
-      #     sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
-      #     sudo rm -rf /usr/share/dotnet
-      #     sudo apt-get remove -y '^mono-.*' || true
-      #     sudo apt-get remove -y '^ghc-.*' || true
-      #     sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
-      #     sudo apt-get remove -y 'php.*' || true
-      #     sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
-      #     sudo apt-get remove -y '^google-.*' || true
-      #     sudo apt-get remove -y azure-cli || true
-      #     sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
-      #     sudo apt-get remove -y '^gfortran-.*' || true
-      #     sudo apt-get remove -y microsoft-edge-stable || true
-      #     sudo apt-get remove -y firefox || true
-      #     sudo apt-get remove -y powershell || true
-      #     sudo apt-get remove -y r-base-core || true
-      #     sudo apt-get autoremove -y
-      #     sudo apt-get clean
-      #     echo
-      #     echo "Listing top largest packages"
-      #     pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
-      #     head -n 30 <<< "${pkgs}"
-      #     echo
-      #     sudo rm -rfv build || true
-      #     df -h
-      - name: Docker meta
-        id: meta
-        uses: docker/metadata-action@v5
-        with:
-          images: quay.io/go-skynet/local-ai
-          tags: |
-            type=ref,event=branch
-            type=semver,pattern={{raw}}
-            type=sha
-          flavor: |
-            latest=${{ inputs.tag-latest }}
-            suffix=${{ inputs.tag-suffix }}
-
-      - name: Set up QEMU
-        uses: docker/setup-qemu-action@master
-        with:
-          platforms: all
-
-      - name: Set up Docker Buildx
-        id: buildx
-        uses: docker/setup-buildx-action@master
-
-      - name: Login to DockerHub
-        if: github.event_name != 'pull_request'
-        uses: docker/login-action@v3
-        with:
-          registry: quay.io
-          username: ${{ secrets.dockerUsername }}
-          password: ${{ secrets.dockerPassword }}
-
-      - name: Build and push
-        uses: docker/build-push-action@v5
-        with:
-          builder: ${{ steps.buildx.outputs.name }}
-          build-args: |
-            BUILD_TYPE=${{ inputs.build-type }}
-            CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }}
-            CUDA_MINOR_VERSION=${{ inputs.cuda-minor-version }}
-            FFMPEG=${{ inputs.ffmpeg }}
-            IMAGE_TYPE=${{ inputs.image-type }}
-          context: .
-          file: ./Dockerfile
-          platforms: ${{ inputs.platforms }}
-          push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.meta.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels }}
-      - name: job summary
-        run: |
-          echo "Built image: ${{ steps.meta.outputs.labels }}" >> $GITHUB_STEP_SUMMARY
.github/workflows/release.yaml (9 changed lines)

@@ -5,10 +5,6 @@ on: push
 permissions:
   contents: write
 
-concurrency:
-  group: ci-releases-${{ github.head_ref || github.ref }}-${{ github.repository }}
-  cancel-in-progress: true
-
 jobs:
   build-linux:
     strategy:
@@ -78,7 +74,10 @@
           go-version: '>=1.21.0'
       - name: Dependencies
         run: |
-          brew install protobuf grpc
+          git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
+            cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \
+              -DgRPC_BUILD_TESTS=OFF \
+              ../.. && make -j12 install && rm -rf grpc
       - name: Build
         id: build
         env:
.github/workflows/test-extra.yml (277 lines removed in this comparison)

@@ -1,277 +0,0 @@
----
-name: 'Tests extras backends'
-
-on:
-  pull_request:
-  push:
-    branches:
-      - master
-    tags:
-      - '*'
-
-concurrency:
-  group: ci-tests-extra-${{ github.head_ref || github.ref }}-${{ github.repository }}
-  cancel-in-progress: true
-
-jobs:
-  tests-transformers:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Clone
-        uses: actions/checkout@v4
-        with:
-          submodules: true
-      - name: Dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install build-essential ffmpeg
-          curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
-          sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
-          gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
-          sudo apt-get update && \
-          sudo apt-get install -y conda
-          sudo apt-get install -y ca-certificates cmake curl patch
-          sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
-
-          sudo rm -rfv /usr/bin/conda || true
-
-      - name: Test transformers
-        run: |
-          export PATH=$PATH:/opt/conda/bin
-          make -C backend/python/transformers
-          make -C backend/python/transformers test
-
-  tests-sentencetransformers:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Clone
-        uses: actions/checkout@v4
-        with:
-          submodules: true
-      - name: Dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install build-essential ffmpeg
-          curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
-          sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
-          gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
-          sudo apt-get update && \
-          sudo apt-get install -y conda
-          sudo apt-get install -y ca-certificates cmake curl patch
-          sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
-
-          sudo rm -rfv /usr/bin/conda || true
-
-      - name: Test sentencetransformers
-        run: |
-          export PATH=$PATH:/opt/conda/bin
-          make -C backend/python/sentencetransformers
-          make -C backend/python/sentencetransformers test
-
-  tests-diffusers:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Clone
-        uses: actions/checkout@v4
-        with:
-          submodules: true
-      - name: Dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install build-essential ffmpeg
-          curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
-          sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
-          gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
-          sudo apt-get update && \
-          sudo apt-get install -y conda
-          sudo apt-get install -y ca-certificates cmake curl patch
-          sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
-
-          sudo rm -rfv /usr/bin/conda || true
-
-      - name: Test diffusers
-        run: |
-          export PATH=$PATH:/opt/conda/bin
-          make -C backend/python/diffusers
-          make -C backend/python/diffusers test
-
-
-  tests-transformers-musicgen:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Clone
-        uses: actions/checkout@v4
-        with:
-          submodules: true
-      - name: Dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install build-essential ffmpeg
-          curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
-          sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
-          gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
-          sudo apt-get update && \
-          sudo apt-get install -y conda
-          sudo apt-get install -y ca-certificates cmake curl patch
-          sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
-
-          sudo rm -rfv /usr/bin/conda || true
-
-      - name: Test transformers-musicgen
-        run: |
-          export PATH=$PATH:/opt/conda/bin
-          make -C backend/python/transformers-musicgen
-          make -C backend/python/transformers-musicgen test
-
-
-
-  tests-petals:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Clone
-        uses: actions/checkout@v4
-        with:
-          submodules: true
-      - name: Dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install build-essential ffmpeg
-          curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
-          sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
-          gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
-          sudo apt-get update && \
-          sudo apt-get install -y conda
-          sudo apt-get install -y ca-certificates cmake curl patch
-          sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
-
-          sudo rm -rfv /usr/bin/conda || true
-
-      - name: Test petals
-        run: |
-          export PATH=$PATH:/opt/conda/bin
-          make -C backend/python/petals
-          make -C backend/python/petals test
-
-
-
-  tests-bark:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Clone
-        uses: actions/checkout@v4
-        with:
-          submodules: true
-      - name: Dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install build-essential ffmpeg
-          curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
-          sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
-          gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
-          sudo apt-get update && \
-          sudo apt-get install -y conda
-          sudo apt-get install -y ca-certificates cmake curl patch
-          sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
-
-          sudo rm -rfv /usr/bin/conda || true
-
-      - name: Test bark
-        run: |
-          export PATH=$PATH:/opt/conda/bin
-          make -C backend/python/bark
-          make -C backend/python/bark test
-
-
-  # Below tests needs GPU. Commented out for now
-  # TODO: Re-enable as soon as we have GPU nodes
-  # tests-vllm:
-  #   runs-on: ubuntu-latest
-  #   steps:
-  #     - name: Clone
-  #       uses: actions/checkout@v4
-  #       with:
-  #         submodules: true
-  #     - name: Dependencies
-  #       run: |
-  #         sudo apt-get update
-  #         sudo apt-get install build-essential ffmpeg
-  #         curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
-  #         sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
-  #         gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
-  #         sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
-  #         sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
-  #         sudo apt-get update && \
-  #         sudo apt-get install -y conda
-  #         sudo apt-get install -y ca-certificates cmake curl patch
-  #         sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
-  #         sudo rm -rfv /usr/bin/conda || true
-  #     - name: Test vllm
-  #       run: |
-  #         export PATH=$PATH:/opt/conda/bin
-  #         make -C backend/python/vllm
-  #         make -C backend/python/vllm test
-  tests-vallex:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Clone
-        uses: actions/checkout@v4
-        with:
-          submodules: true
-      - name: Dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install build-essential ffmpeg
-          curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
-          sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
-          gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
-          sudo apt-get update && \
-          sudo apt-get install -y conda
-          sudo apt-get install -y ca-certificates cmake curl patch
-          sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
-          sudo rm -rfv /usr/bin/conda || true
-      - name: Test vall-e-x
-        run: |
-          export PATH=$PATH:/opt/conda/bin
-          make -C backend/python/vall-e-x
-          make -C backend/python/vall-e-x test
-
-  tests-coqui:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Clone
-        uses: actions/checkout@v4
-        with:
-          submodules: true
-      - name: Dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install build-essential ffmpeg
-          curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
-          sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
-          gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
-          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
-          sudo apt-get update && \
-          sudo apt-get install -y conda
-          sudo apt-get install -y ca-certificates cmake curl patch espeak espeak-ng
-          sudo rm -rfv /usr/bin/conda || true
-
-      - name: Test coqui
-        run: |
-          export PATH=$PATH:/opt/conda/bin
-          make -C backend/python/coqui
-          make -C backend/python/coqui test
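Every job above follows the same pattern: install conda plus the build tools, then drive the per-backend Makefile. A minimal sketch of reproducing one of these runs locally, assuming conda is already installed under /opt/conda:

```bash
# Mirror the 'Test transformers' step from the removed workflow.
export PATH=$PATH:/opt/conda/bin
make -C backend/python/transformers        # prepare the backend environment
make -C backend/python/transformers test   # run its test suite
```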
.github/workflows/test.yml (14 changed lines)

@@ -78,12 +78,13 @@
           sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
 
           sudo rm -rfv /usr/bin/conda || true
-          PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers
+          PATH=$PATH:/opt/conda/bin make -C extra/grpc/huggingface
 
           # Pre-build piper before we start tests in order to have shared libraries in place
-          make sources/go-piper && \
-          GO_TAGS="tts" make -C sources/go-piper piper.o && \
-          sudo cp -rfv sources/go-piper/piper-phonemize/pi/lib/. /usr/lib/ && \
+          make go-piper && \
+          GO_TAGS="tts" make -C go-piper piper.o && \
+          sudo cp -rfv go-piper/piper/build/pi/lib/. /usr/lib/ && \
 
           # Pre-build stable diffusion before we install a newer version of abseil (not compatible with stablediffusion-ncn)
           GO_TAGS="stablediffusion tts" GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
 
@@ -114,7 +115,10 @@
         run: go version
       - name: Dependencies
         run: |
-          brew install protobuf grpc
+          git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
+            cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \
+              -DgRPC_BUILD_TESTS=OFF \
+              ../.. && make -j12 install && rm -rf grpc
       - name: Test
         run: |
           export C_INCLUDE_PATH=/usr/local/include
.gitignore (10 changed lines)

@@ -1,9 +1,15 @@
 # go-llama build artifacts
-/sources/
+go-llama
+go-llama-stable
+/gpt4all
+go-stable-diffusion
+go-piper
+/go-bert
+go-ggllm
+/piper
 __pycache__/
 *.a
 get-sources
-prepare-sources
 /backend/cpp/llama/grpc-server
 /backend/cpp/llama/llama.cpp
 
.gitmodules (3 lines removed in this comparison)

@@ -1,3 +0,0 @@
-[submodule "docs/themes/hugo-theme-relearn"]
-    path = docs/themes/hugo-theme-relearn
-    url = https://github.com/McShelby/hugo-theme-relearn.git
Dockerfile (81 changed lines)

@@ -12,11 +12,9 @@ ARG TARGETARCH
 ARG TARGETVARIANT
 
 ENV BUILD_TYPE=${BUILD_TYPE}
-ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh"
+ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/extra/grpc/huggingface/run.sh,autogptq:/build/extra/grpc/autogptq/run.sh,bark:/build/extra/grpc/bark/run.sh,diffusers:/build/extra/grpc/diffusers/run.sh,exllama:/build/extra/grpc/exllama/run.sh,vall-e-x:/build/extra/grpc/vall-e-x/run.sh,vllm:/build/extra/grpc/vllm/run.sh"
 
 ENV GALLERIES='[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/index.yaml"}, {"url": "github:go-skynet/model-gallery/huggingface.yaml","name":"huggingface"}]'
-ARG GO_TAGS="stablediffusion tinydream tts"
+ARG GO_TAGS="stablediffusion tts"
 
 RUN apt-get update && \
     apt-get install -y ca-certificates curl patch pip cmake && apt-get clean
@@ -66,10 +64,23 @@ RUN curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmo
     apt-get update && \
     apt-get install -y conda
 
+COPY extra/requirements.txt /build/extra/requirements.txt
 ENV PATH="/root/.cargo/bin:${PATH}"
 RUN pip install --upgrade pip
 RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
-RUN apt-get install -y espeak-ng espeak
+#RUN if [ "${TARGETARCH}" = "amd64" ]; then \
+#    pip install git+https://github.com/suno-ai/bark.git diffusers invisible_watermark transformers accelerate safetensors;\
+#    fi
+#RUN if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "amd64" ]; then \
+#    pip install torch vllm && pip install auto-gptq https://github.com/jllllll/exllama/releases/download/0.0.10/exllama-0.0.10+cu${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION}-cp39-cp39-linux_x86_64.whl;\
+#    fi
+#RUN pip install -r /build/extra/requirements.txt && rm -rf /build/extra/requirements.txt
+
+# Vall-e-X
+RUN git clone https://github.com/Plachtaa/VALL-E-X.git /usr/lib/vall-e-x && cd /usr/lib/vall-e-x && pip install -r requirements.txt
+
+# \
+# ; fi
+
 ###################################
 ###################################
@@ -87,9 +98,12 @@ ENV NVIDIA_VISIBLE_DEVICES=all
 
 WORKDIR /build
 
+COPY Makefile .
+RUN make get-sources
+COPY go.mod .
+RUN make prepare
 COPY . .
 COPY .git .
-RUN make prepare
 
 # stablediffusion does not tolerate a newer version of abseil, build it first
 RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
@@ -98,15 +112,15 @@ RUN if [ "${BUILD_GRPC}" = "true" ]; then \
     git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
     cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \
       -DgRPC_BUILD_TESTS=OFF \
-      ../.. && make -j12 install \
+      ../.. && make -j12 install && rm -rf grpc \
     ; fi
 
 # Rebuild with defaults backends
 RUN make build
 
-RUN if [ ! -d "/build/sources/go-piper/piper-phonemize/pi/lib/" ]; then \
-    mkdir -p /build/sources/go-piper/piper-phonemize/pi/lib/ \
-    touch /build/sources/go-piper/piper-phonemize/pi/lib/keep \
+RUN if [ ! -d "/build/go-piper/piper/build/pi/lib/" ]; then \
+    mkdir -p /build/go-piper/piper/build/pi/lib/ \
+    touch /build/go-piper/piper/build/pi/lib/keep \
     ; fi
 
 ###################################
@@ -140,59 +154,50 @@ WORKDIR /build
 # see https://github.com/go-skynet/LocalAI/pull/658#discussion_r1241971626 and
 # https://github.com/go-skynet/LocalAI/pull/434
 COPY . .
-COPY --from=builder /build/sources ./sources/
-COPY --from=builder /build/grpc ./grpc/
-
-RUN make prepare-sources && cd /build/grpc/cmake/build && make install && rm -rf grpc
+RUN make prepare-sources
 
 # Copy the binary
 COPY --from=builder /build/local-ai ./
 
 # Copy shared libraries for piper
-COPY --from=builder /build/sources/go-piper/piper-phonemize/pi/lib/* /usr/lib/
+COPY --from=builder /build/go-piper/piper/build/pi/lib/* /usr/lib/
 
 # do not let stablediffusion rebuild (requires an older version of absl)
 COPY --from=builder /build/backend-assets/grpc/stablediffusion ./backend-assets/grpc/stablediffusion
 
 ## Duplicated from Makefile to avoid having a big layer that's hard to push
 RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/autogptq \
+    PATH=$PATH:/opt/conda/bin make -C extra/grpc/autogptq \
     ; fi
 RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/bark \
+    PATH=$PATH:/opt/conda/bin make -C extra/grpc/bark \
    ; fi
 RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/diffusers \
+    PATH=$PATH:/opt/conda/bin make -C extra/grpc/diffusers \
    ; fi
 RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/vllm \
+    PATH=$PATH:/opt/conda/bin make -C extra/grpc/vllm \
    ; fi
 RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers \
+    PATH=$PATH:/opt/conda/bin make -C extra/grpc/huggingface \
    ; fi
 RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/transformers \
+    PATH=$PATH:/opt/conda/bin make -C extra/grpc/vall-e-x \
    ; fi
 RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/vall-e-x \
-    ; fi
-RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/exllama \
-    ; fi
-RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/exllama2 \
-    ; fi
-RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/petals \
-    ; fi
-RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/transformers-musicgen \
-    ; fi
-RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
-    PATH=$PATH:/opt/conda/bin make -C backend/python/coqui \
+    PATH=$PATH:/opt/conda/bin make -C extra/grpc/exllama \
     ; fi
 
+# Copy VALLE-X as it's not a real "lib"
+RUN if [ -d /usr/lib/vall-e-x ]; then \
+    cp -rfv /usr/lib/vall-e-x/* ./ ; \
+    fi
+
+# we also copy exllama libs over to resolve exllama import error
+RUN if [ -d /usr/local/lib/python3.9/dist-packages/exllama ]; then \
+    cp -rfv /usr/local/lib/python3.9/dist-packages/exllama extra/grpc/exllama/;\
+    fi
+
 # Define the health check command
 HEALTHCHECK --interval=1m --timeout=10m --retries=10 \
   CMD curl -f $HEALTHCHECK_ENDPOINT || exit 1
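The build arguments this Dockerfile reads (BUILD_TYPE, CUDA_MAJOR_VERSION, CUDA_MINOR_VERSION, FFMPEG, IMAGE_TYPE) are the same ones the image workflows pass in. A hedged sketch of an equivalent local build; the tag name is arbitrary:

```bash
# Build a CUDA 12.1 'extras' image locally, roughly as the CI matrix does.
docker build \
  --build-arg BUILD_TYPE=cublas \
  --build-arg CUDA_MAJOR_VERSION=12 \
  --build-arg CUDA_MINOR_VERSION=1 \
  --build-arg FFMPEG=true \
  --build-arg IMAGE_TYPE=extras \
  -t local-ai:cublas-cuda12-ffmpeg \
  -f Dockerfile .
```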
Unnamed file (path not captured in this extract): a macOS property list granting network client/server entitlements, 10 lines removed in this comparison.

@@ -1,10 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
-<plist version="1.0">
-<dict>
-    <key>com.apple.security.network.client</key>
-    <true/>
-    <key>com.apple.security.network.server</key>
-    <true/>
-</dict>
-</plist>
Makefile (380 changed lines)

@@ -8,7 +8,7 @@ GOLLAMA_VERSION?=aeba71ee842819da681ea537e78846dc75949ac0
 GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
 
-CPPLLAMA_VERSION?=65e5f6dadbba4b496bba27f573e473c66b446496
+CPPLLAMA_VERSION?=a75fa576abba9d37f463580c379e4bbf1e1ad03c
 
 # gpt4all version
 GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
@@ -19,27 +19,23 @@ GOGGMLTRANSFORMERS_VERSION?=ffb09d7dd71e2cbc6c5d7d05357d230eea6f369a
 
 # go-rwkv version
 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
-RWKV_VERSION?=633c5a3485c403cb2520693dc0991a25dace9f0f
+RWKV_VERSION?=c898cd0f62df8f2a7830e53d1d513bef4f6f792b
 
 # whisper.cpp version
-WHISPER_CPP_VERSION?=37a709f6558c6d9783199e2b8cbb136e1c41d346
+WHISPER_CPP_VERSION?=85ed71aaec8e0612a84c0b67804bde75aa75a273
 
 # bert.cpp version
 BERT_VERSION?=6abe312cded14042f6b7c3cd8edf082713334a4d
 
 # go-piper version
-PIPER_VERSION?=d6b6275ba037dabdba4a8b65dfdf6b2a73a67f07
+PIPER_VERSION?=736f6fb639ab8e3397356e48eeb6bdcb9da88a78
 
 # stablediffusion version
-STABLEDIFFUSION_VERSION?=902db5f066fd137697e3b69d0fa10d4782bd2c2f
+STABLEDIFFUSION_VERSION?=d89260f598afb809279bc72aa0107b4292587632
 
-# tinydream version
-TINYDREAM_VERSION?=772a9c0d9aaf768290e63cca3c904fe69faf677a
-
 export BUILD_TYPE?=
 export STABLE_BUILD_TYPE?=$(BUILD_TYPE)
 export CMAKE_ARGS?=
 
 CGO_LDFLAGS?=
 CUDA_LIBPATH?=/usr/local/cuda/lib64/
 GO_TAGS?=
@@ -72,39 +68,29 @@ ifndef UNAME_S
 UNAME_S := $(shell uname -s)
 endif
 
-ifeq ($(OS),Darwin)
+ifeq ($(UNAME_S),Darwin)
     CGO_LDFLAGS += -lcblas -framework Accelerate
-    ifeq ($(OSX_SIGNING_IDENTITY),)
-        OSX_SIGNING_IDENTITY := $(shell security find-identity -v -p codesigning | grep '"' | head -n 1 | sed -E 's/.*"(.*)"/\1/')
-    endif
-    # on OSX, if BUILD_TYPE is blank, we should default to use Metal
-    ifeq ($(BUILD_TYPE),)
-        BUILD_TYPE=metal
-    # disable metal if on Darwin and any other value is explicitly passed.
-    else ifneq ($(BUILD_TYPE),metal)
-        CMAKE_ARGS+=-DLLAMA_METAL=OFF
-    endif
+    ifneq ($(BUILD_TYPE),metal)
+        # explicit disable metal if on Darwin and metal is disabled
+        CMAKE_ARGS+=-DLLAMA_METAL=OFF
+    endif
 endif
 
 ifeq ($(BUILD_TYPE),openblas)
     CGO_LDFLAGS+=-lopenblas
-    export WHISPER_OPENBLAS=1
 endif
 
 ifeq ($(BUILD_TYPE),cublas)
     CGO_LDFLAGS+=-lcublas -lcudart -L$(CUDA_LIBPATH)
     export LLAMA_CUBLAS=1
-    export WHISPER_CUBLAS=1
 endif
 
 ifeq ($(BUILD_TYPE),hipblas)
     ROCM_HOME ?= /opt/rocm
     export CXX=$(ROCM_HOME)/llvm/bin/clang++
    export CC=$(ROCM_HOME)/llvm/bin/clang
-    # llama-ggml has no hipblas support, so override it here.
+    # Llama-stable has no hipblas support, so override it here.
     export STABLE_BUILD_TYPE=
-    export WHISPER_HIPBLAS=1
     GPU_TARGETS ?= gfx900,gfx90a,gfx1030,gfx1031,gfx1100
     AMDGPU_TARGETS ?= "$(GPU_TARGETS)"
     CMAKE_ARGS+=-DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS="$(AMDGPU_TARGETS)" -DGPU_TARGETS="$(GPU_TARGETS)"
@@ -114,12 +100,10 @@ endif
 ifeq ($(BUILD_TYPE),metal)
     CGO_LDFLAGS+=-framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
     export LLAMA_METAL=1
-    export WHISPER_METAL=1
 endif
 
 ifeq ($(BUILD_TYPE),clblas)
     CGO_LDFLAGS+=-lOpenCL -lclblast
-    export WHISPER_CLBLAST=1
 endif
 
 # glibc-static or glibc-devel-static required
@@ -132,20 +116,15 @@ ifeq ($(findstring stablediffusion,$(GO_TAGS)),stablediffusion)
     OPTIONAL_GRPC+=backend-assets/grpc/stablediffusion
 endif
 
-ifeq ($(findstring tinydream,$(GO_TAGS)),tinydream)
-#   OPTIONAL_TARGETS+=go-tiny-dream/libtinydream.a
-    OPTIONAL_GRPC+=backend-assets/grpc/tinydream
-endif
-
 ifeq ($(findstring tts,$(GO_TAGS)),tts)
 #   OPTIONAL_TARGETS+=go-piper/libpiper_binding.a
 #   OPTIONAL_TARGETS+=backend-assets/espeak-ng-data
-    PIPER_CGO_CXXFLAGS+=-I$(shell pwd)/sources/go-piper/piper/src/cpp -I$(shell pwd)/sources/go-piper/piper/build/fi/include -I$(shell pwd)/sources/go-piper/piper/build/pi/include -I$(shell pwd)/sources/go-piper/piper/build/si/include
-    PIPER_CGO_LDFLAGS+=-L$(shell pwd)/sources/go-piper/piper/build/fi/lib -L$(shell pwd)/sources/go-piper/piper/build/pi/lib -L$(shell pwd)/sources/go-piper/piper/build/si/lib -lfmt -lspdlog -lucd
+    PIPER_CGO_CXXFLAGS+=-I$(shell pwd)/go-piper/piper/src/cpp -I$(shell pwd)/go-piper/piper/build/fi/include -I$(shell pwd)/go-piper/piper/build/pi/include -I$(shell pwd)/go-piper/piper/build/si/include
+    PIPER_CGO_LDFLAGS+=-L$(shell pwd)/go-piper/piper/build/fi/lib -L$(shell pwd)/go-piper/piper/build/pi/lib -L$(shell pwd)/go-piper/piper/build/si/lib -lfmt -lspdlog
     OPTIONAL_GRPC+=backend-assets/grpc/piper
 endif
 
-ALL_GRPC_BACKENDS=backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/llama backend-assets/grpc/llama-cpp backend-assets/grpc/llama-ggml backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC)
+ALL_GRPC_BACKENDS=backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/llama backend-assets/grpc/llama-cpp backend-assets/grpc/llama-stable backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC)
 GRPC_BACKENDS?=$(ALL_GRPC_BACKENDS) $(OPTIONAL_GRPC)
 
 # If empty, then we build all
@@ -158,127 +137,112 @@ endif
 all: help
 
 ## GPT4ALL
-sources/gpt4all:
-    git clone --recurse-submodules $(GPT4ALL_REPO) sources/gpt4all
-    cd sources/gpt4all && git checkout -b build $(GPT4ALL_VERSION) && git submodule update --init --recursive --depth 1
+gpt4all:
+    git clone --recurse-submodules $(GPT4ALL_REPO) gpt4all
+    cd gpt4all && git checkout -b build $(GPT4ALL_VERSION) && git submodule update --init --recursive --depth 1
 
 ## go-piper
-sources/go-piper:
-    git clone --recurse-submodules https://github.com/mudler/go-piper sources/go-piper
-    cd sources/go-piper && git checkout -b build $(PIPER_VERSION) && git submodule update --init --recursive --depth 1
+go-piper:
+    git clone --recurse-submodules https://github.com/mudler/go-piper go-piper
+    cd go-piper && git checkout -b build $(PIPER_VERSION) && git submodule update --init --recursive --depth 1
 
 ## BERT embeddings
-sources/go-bert:
-    git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp sources/go-bert
-    cd sources/go-bert && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1
+go-bert:
+    git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp go-bert
+    cd go-bert && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1
 
 ## stable diffusion
-sources/go-stable-diffusion:
-    git clone --recurse-submodules https://github.com/mudler/go-stable-diffusion sources/go-stable-diffusion
-    cd sources/go-stable-diffusion && git checkout -b build $(STABLEDIFFUSION_VERSION) && git submodule update --init --recursive --depth 1
+go-stable-diffusion:
+    git clone --recurse-submodules https://github.com/mudler/go-stable-diffusion go-stable-diffusion
+    cd go-stable-diffusion && git checkout -b build $(STABLEDIFFUSION_VERSION) && git submodule update --init --recursive --depth 1
 
-sources/go-stable-diffusion/libstablediffusion.a:
-    $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a
+go-stable-diffusion/libstablediffusion.a:
+    $(MAKE) -C go-stable-diffusion libstablediffusion.a
 
-## tiny-dream
-sources/go-tiny-dream:
-    git clone --recurse-submodules https://github.com/M0Rf30/go-tiny-dream sources/go-tiny-dream
-    cd sources/go-tiny-dream && git checkout -b build $(TINYDREAM_VERSION) && git submodule update --init --recursive --depth 1
-
-sources/go-tiny-dream/libtinydream.a:
-    $(MAKE) -C sources/go-tiny-dream libtinydream.a
-
 ## RWKV
-sources/go-rwkv:
-    git clone --recurse-submodules $(RWKV_REPO) sources/go-rwkv
-    cd sources/go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1
+go-rwkv:
+    git clone --recurse-submodules $(RWKV_REPO) go-rwkv
+    cd go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1
 
-sources/go-rwkv/librwkv.a: sources/go-rwkv
-    cd sources/go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a ..
+go-rwkv/librwkv.a: go-rwkv
+    cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a ..
 
-sources/go-bert/libgobert.a: sources/go-bert
-    $(MAKE) -C sources/go-bert libgobert.a
+go-bert/libgobert.a: go-bert
+    $(MAKE) -C go-bert libgobert.a
 
-backend-assets/gpt4all: sources/gpt4all/gpt4all-bindings/golang/libgpt4all.a
+backend-assets/gpt4all: gpt4all/gpt4all-bindings/golang/libgpt4all.a
     mkdir -p backend-assets/gpt4all
-    @cp sources/gpt4all/gpt4all-bindings/golang/buildllm/*.so backend-assets/gpt4all/ || true
+    @cp gpt4all/gpt4all-bindings/golang/buildllm/*.so backend-assets/gpt4all/ || true
|
||||||
@cp sources/gpt4all/gpt4all-bindings/golang/buildllm/*.dylib backend-assets/gpt4all/ || true
|
@cp gpt4all/gpt4all-bindings/golang/buildllm/*.dylib backend-assets/gpt4all/ || true
|
||||||
@cp sources/gpt4all/gpt4all-bindings/golang/buildllm/*.dll backend-assets/gpt4all/ || true
|
@cp gpt4all/gpt4all-bindings/golang/buildllm/*.dll backend-assets/gpt4all/ || true
|
||||||
|
|
||||||
backend-assets/espeak-ng-data: sources/go-piper
|
backend-assets/espeak-ng-data: go-piper
|
||||||
mkdir -p backend-assets/espeak-ng-data
|
mkdir -p backend-assets/espeak-ng-data
|
||||||
$(MAKE) -C sources/go-piper piper.o
|
$(MAKE) -C go-piper piper.o
|
||||||
@cp -rf sources/go-piper/piper-phonemize/pi/share/espeak-ng-data/. backend-assets/espeak-ng-data
|
@cp -rf go-piper/piper/build/pi/share/espeak-ng-data/. backend-assets/espeak-ng-data
|
||||||
|
|
||||||
sources/gpt4all/gpt4all-bindings/golang/libgpt4all.a: sources/gpt4all
|
gpt4all/gpt4all-bindings/golang/libgpt4all.a: gpt4all
|
||||||
$(MAKE) -C sources/gpt4all/gpt4all-bindings/golang/ libgpt4all.a
|
$(MAKE) -C gpt4all/gpt4all-bindings/golang/ libgpt4all.a
|
||||||
|
|
||||||
## CEREBRAS GPT
|
## CEREBRAS GPT
|
||||||
sources/go-ggml-transformers:
|
go-ggml-transformers:
|
||||||
git clone --recurse-submodules https://github.com/go-skynet/go-ggml-transformers.cpp sources/go-ggml-transformers
|
git clone --recurse-submodules https://github.com/go-skynet/go-ggml-transformers.cpp go-ggml-transformers
|
||||||
cd sources/go-ggml-transformers && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1
|
cd go-ggml-transformers && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1
|
||||||
|
|
||||||
sources/go-ggml-transformers/libtransformers.a: sources/go-ggml-transformers
|
go-ggml-transformers/libtransformers.a: go-ggml-transformers
|
||||||
$(MAKE) -C sources/go-ggml-transformers BUILD_TYPE=$(BUILD_TYPE) libtransformers.a
|
$(MAKE) -C go-ggml-transformers BUILD_TYPE=$(BUILD_TYPE) libtransformers.a
|
||||||
|
|
||||||
sources/whisper.cpp:
|
whisper.cpp:
|
||||||
git clone https://github.com/ggerganov/whisper.cpp.git sources/whisper.cpp
|
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||||
cd sources/whisper.cpp && git checkout -b build $(WHISPER_CPP_VERSION) && git submodule update --init --recursive --depth 1
|
cd whisper.cpp && git checkout -b build $(WHISPER_CPP_VERSION) && git submodule update --init --recursive --depth 1
|
||||||
|
|
||||||
sources/whisper.cpp/libwhisper.a: sources/whisper.cpp
|
whisper.cpp/libwhisper.a: whisper.cpp
|
||||||
cd sources/whisper.cpp && make libwhisper.a
|
cd whisper.cpp && make libwhisper.a
|
||||||
|
|
||||||
sources/go-llama:
|
go-llama:
|
||||||
git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp sources/go-llama
|
git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
|
||||||
cd sources/go-llama && git checkout -b build $(GOLLAMA_VERSION) && git submodule update --init --recursive --depth 1
|
cd go-llama && git checkout -b build $(GOLLAMA_VERSION) && git submodule update --init --recursive --depth 1
|
||||||
|
|
||||||
sources/go-llama-ggml:
|
go-llama-stable:
|
||||||
git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp sources/go-llama-ggml
|
git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama-stable
|
||||||
cd sources/go-llama-ggml && git checkout -b build $(GOLLAMA_STABLE_VERSION) && git submodule update --init --recursive --depth 1
|
cd go-llama-stable && git checkout -b build $(GOLLAMA_STABLE_VERSION) && git submodule update --init --recursive --depth 1
|
||||||
|
|
||||||
sources/go-llama/libbinding.a: sources/go-llama
|
go-llama/libbinding.a: go-llama
|
||||||
$(MAKE) -C sources/go-llama BUILD_TYPE=$(BUILD_TYPE) libbinding.a
|
$(MAKE) -C go-llama BUILD_TYPE=$(BUILD_TYPE) libbinding.a
|
||||||
|
|
||||||
sources/go-llama-ggml/libbinding.a: sources/go-llama-ggml
|
go-llama-stable/libbinding.a: go-llama-stable
|
||||||
$(MAKE) -C sources/go-llama-ggml BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a
|
$(MAKE) -C go-llama-stable BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a
|
||||||
|
|
||||||
sources/go-piper/libpiper_binding.a: sources/go-piper
|
go-piper/libpiper_binding.a: go-piper
|
||||||
$(MAKE) -C sources/go-piper libpiper_binding.a example/main
|
$(MAKE) -C go-piper libpiper_binding.a example/main
|
||||||
|
|
||||||
backend/cpp/llama/llama.cpp:
|
get-sources: go-llama go-llama-stable go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert go-stable-diffusion
|
||||||
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama llama.cpp
|
|
||||||
|
|
||||||
get-sources: backend/cpp/llama/llama.cpp sources/go-llama sources/go-llama-ggml sources/go-ggml-transformers sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion sources/go-tiny-dream
|
|
||||||
touch $@
|
touch $@
|
||||||
|
|
||||||
replace:
|
replace:
|
||||||
$(GOCMD) mod edit -replace github.com/nomic-ai/gpt4all/gpt4all-bindings/golang=$(shell pwd)/sources/gpt4all/gpt4all-bindings/golang
|
$(GOCMD) mod edit -replace github.com/nomic-ai/gpt4all/gpt4all-bindings/golang=$(shell pwd)/gpt4all/gpt4all-bindings/golang
|
||||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-ggml-transformers.cpp=$(shell pwd)/sources/go-ggml-transformers
|
$(GOCMD) mod edit -replace github.com/go-skynet/go-ggml-transformers.cpp=$(shell pwd)/go-ggml-transformers
|
||||||
$(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(shell pwd)/sources/go-rwkv
|
$(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(shell pwd)/go-rwkv
|
||||||
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(shell pwd)/sources/whisper.cpp
|
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(shell pwd)/whisper.cpp
|
||||||
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp/bindings/go=$(shell pwd)/sources/whisper.cpp/bindings/go
|
$(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(shell pwd)/go-bert
|
||||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(shell pwd)/sources/go-bert
|
$(GOCMD) mod edit -replace github.com/mudler/go-stable-diffusion=$(shell pwd)/go-stable-diffusion
|
||||||
$(GOCMD) mod edit -replace github.com/mudler/go-stable-diffusion=$(shell pwd)/sources/go-stable-diffusion
|
$(GOCMD) mod edit -replace github.com/mudler/go-piper=$(shell pwd)/go-piper
|
||||||
$(GOCMD) mod edit -replace github.com/M0Rf30/go-tiny-dream=$(shell pwd)/sources/go-tiny-dream
|
|
||||||
$(GOCMD) mod edit -replace github.com/mudler/go-piper=$(shell pwd)/sources/go-piper
|
|
||||||
|
|
||||||
prepare-sources: get-sources replace
|
prepare-sources: get-sources replace
|
||||||
$(GOCMD) mod download
|
$(GOCMD) mod download
|
||||||
touch $@
|
|
||||||
|
|
||||||
## GENERIC
|
## GENERIC
|
||||||
rebuild: ## Rebuilds the project
|
rebuild: ## Rebuilds the project
|
||||||
$(GOCMD) clean -cache
|
$(GOCMD) clean -cache
|
||||||
$(MAKE) -C sources/go-llama clean
|
$(MAKE) -C go-llama clean
|
||||||
$(MAKE) -C sources/go-llama-ggml clean
|
$(MAKE) -C go-llama-stable clean
|
||||||
$(MAKE) -C sources/gpt4all/gpt4all-bindings/golang/ clean
|
$(MAKE) -C gpt4all/gpt4all-bindings/golang/ clean
|
||||||
$(MAKE) -C sources/go-ggml-transformers clean
|
$(MAKE) -C go-ggml-transformers clean
|
||||||
$(MAKE) -C sources/go-rwkv clean
|
$(MAKE) -C go-rwkv clean
|
||||||
$(MAKE) -C sources/whisper.cpp clean
|
$(MAKE) -C whisper.cpp clean
|
||||||
$(MAKE) -C sources/go-stable-diffusion clean
|
$(MAKE) -C go-stable-diffusion clean
|
||||||
$(MAKE) -C sources/go-bert clean
|
$(MAKE) -C go-bert clean
|
||||||
$(MAKE) -C sources/go-piper clean
|
$(MAKE) -C go-piper clean
|
||||||
$(MAKE) -C sources/go-tiny-dream clean
|
|
||||||
$(MAKE) build
|
$(MAKE) build
|
||||||
|
|
||||||
prepare: prepare-sources $(OPTIONAL_TARGETS)
|
prepare: prepare-sources $(OPTIONAL_TARGETS)
|
||||||
@@ -287,7 +251,17 @@ prepare: prepare-sources $(OPTIONAL_TARGETS)
|
|||||||
clean: ## Remove build related file
|
clean: ## Remove build related file
|
||||||
$(GOCMD) clean -cache
|
$(GOCMD) clean -cache
|
||||||
rm -f prepare
|
rm -f prepare
|
||||||
rm -rf ./sources
|
rm -rf ./go-llama
|
||||||
|
rm -rf ./gpt4all
|
||||||
|
rm -rf ./go-llama-stable
|
||||||
|
rm -rf ./go-gpt2
|
||||||
|
rm -rf ./go-stable-diffusion
|
||||||
|
rm -rf ./go-ggml-transformers
|
||||||
|
rm -rf ./backend-assets
|
||||||
|
rm -rf ./go-rwkv
|
||||||
|
rm -rf ./go-bert
|
||||||
|
rm -rf ./whisper.cpp
|
||||||
|
rm -rf ./go-piper
|
||||||
rm -rf $(BINARY_NAME)
|
rm -rf $(BINARY_NAME)
|
||||||
rm -rf release/
|
rm -rf release/
|
||||||
rm -rf ./backend/cpp/grpc/grpc_repo
|
rm -rf ./backend/cpp/grpc/grpc_repo
|
||||||
@@ -309,9 +283,6 @@ dist: build
|
|||||||
mkdir -p release
|
mkdir -p release
|
||||||
cp $(BINARY_NAME) release/$(BINARY_NAME)-$(BUILD_ID)-$(OS)-$(ARCH)
|
cp $(BINARY_NAME) release/$(BINARY_NAME)-$(BUILD_ID)-$(OS)-$(ARCH)
|
||||||
|
|
||||||
osx-signed: build
|
|
||||||
codesign --deep --force --sign "$(OSX_SIGNING_IDENTITY)" --entitlements "./Entitlements.plist" "./$(BINARY_NAME)"
|
|
||||||
|
|
||||||
## Run
|
## Run
|
||||||
run: prepare ## run local-ai
|
run: prepare ## run local-ai
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
|
||||||
@@ -335,7 +306,7 @@ test: prepare test-models/testmodel grpcs
|
|||||||
@echo 'Running tests'
|
@echo 'Running tests'
|
||||||
export GO_TAGS="tts stablediffusion"
|
export GO_TAGS="tts stablediffusion"
|
||||||
$(MAKE) prepare-test
|
$(MAKE) prepare-test
|
||||||
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
HUGGINGFACE_GRPC=$(abspath ./)/extra/grpc/huggingface/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts 5 --fail-fast -v -r ./api ./pkg
|
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts 5 --fail-fast -v -r ./api ./pkg
|
||||||
$(MAKE) test-gpt4all
|
$(MAKE) test-gpt4all
|
||||||
$(MAKE) test-llama
|
$(MAKE) test-llama
|
||||||
@@ -402,57 +373,40 @@ help: ## Show this help.
|
|||||||
protogen: protogen-go protogen-python
|
protogen: protogen-go protogen-python
|
||||||
|
|
||||||
protogen-go:
|
protogen-go:
|
||||||
protoc -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
|
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
||||||
backend/backend.proto
|
pkg/grpc/proto/backend.proto
|
||||||
|
|
||||||
protogen-python:
|
protogen-python:
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/sentencetransformers/ --grpc_python_out=backend/python/sentencetransformers/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/huggingface/ --grpc_python_out=extra/grpc/huggingface/ pkg/grpc/proto/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/transformers/ --grpc_python_out=backend/python/transformers/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/autogptq/ --grpc_python_out=extra/grpc/autogptq/ pkg/grpc/proto/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/transformers-musicgen/ --grpc_python_out=backend/python/transformers-musicgen/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/exllama/ --grpc_python_out=extra/grpc/exllama/ pkg/grpc/proto/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/autogptq/ --grpc_python_out=backend/python/autogptq/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/bark/ --grpc_python_out=extra/grpc/bark/ pkg/grpc/proto/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/exllama/ --grpc_python_out=backend/python/exllama/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/diffusers/ --grpc_python_out=extra/grpc/diffusers/ pkg/grpc/proto/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/bark/ --grpc_python_out=backend/python/bark/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/vall-e-x/ --grpc_python_out=extra/grpc/vall-e-x/ pkg/grpc/proto/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/diffusers/ --grpc_python_out=backend/python/diffusers/ backend/backend.proto
|
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/vllm/ --grpc_python_out=extra/grpc/vllm/ pkg/grpc/proto/backend.proto
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/coqui/ --grpc_python_out=backend/python/coqui/ backend/backend.proto
|
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vall-e-x/ --grpc_python_out=backend/python/vall-e-x/ backend/backend.proto
|
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vllm/ --grpc_python_out=backend/python/vllm/ backend/backend.proto
|
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/petals/ --grpc_python_out=backend/python/petals/ backend/backend.proto
|
|
||||||
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/exllama2/ --grpc_python_out=backend/python/exllama2/ backend/backend.proto
|
|
||||||
|
|
||||||
## GRPC
|
## GRPC
|
||||||
# Note: it is duplicated in the Dockerfile
|
# Note: it is duplicated in the Dockerfile
|
||||||
prepare-extra-conda-environments:
|
prepare-extra-conda-environments:
|
||||||
$(MAKE) -C backend/python/autogptq
|
$(MAKE) -C extra/grpc/autogptq
|
||||||
$(MAKE) -C backend/python/bark
|
$(MAKE) -C extra/grpc/bark
|
||||||
$(MAKE) -C backend/python/coqui
|
$(MAKE) -C extra/grpc/diffusers
|
||||||
$(MAKE) -C backend/python/diffusers
|
$(MAKE) -C extra/grpc/vllm
|
||||||
$(MAKE) -C backend/python/vllm
|
$(MAKE) -C extra/grpc/huggingface
|
||||||
$(MAKE) -C backend/python/sentencetransformers
|
$(MAKE) -C extra/grpc/vall-e-x
|
||||||
$(MAKE) -C backend/python/transformers
|
$(MAKE) -C extra/grpc/exllama
|
||||||
$(MAKE) -C backend/python/transformers-musicgen
|
|
||||||
$(MAKE) -C backend/python/vall-e-x
|
|
||||||
$(MAKE) -C backend/python/exllama
|
|
||||||
$(MAKE) -C backend/python/petals
|
|
||||||
$(MAKE) -C backend/python/exllama2
|
|
||||||
|
|
||||||
prepare-test-extra:
|
|
||||||
$(MAKE) -C backend/python/transformers
|
|
||||||
$(MAKE) -C backend/python/diffusers
|
|
||||||
|
|
||||||
test-extra: prepare-test-extra
|
|
||||||
$(MAKE) -C backend/python/transformers test
|
|
||||||
$(MAKE) -C backend/python/diffusers test
|
|
||||||
|
|
||||||
backend-assets/grpc:
|
backend-assets/grpc:
|
||||||
mkdir -p backend-assets/grpc
|
mkdir -p backend-assets/grpc
|
||||||
|
|
||||||
backend-assets/grpc/llama: backend-assets/grpc sources/go-llama/libbinding.a
|
backend-assets/grpc/llama: backend-assets/grpc go-llama/libbinding.a
|
||||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/sources/go-llama
|
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-llama LIBRARY_PATH=$(shell pwd)/sources/go-llama \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./backend/go/llm/llama/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/
|
||||||
# TODO: every binary should have its own folder instead, so can have different implementations
|
# TODO: every binary should have its own folder instead, so can have different metal implementations
|
||||||
ifeq ($(BUILD_TYPE),metal)
|
ifeq ($(BUILD_TYPE),metal)
|
||||||
cp backend/cpp/llama/llama.cpp/ggml-metal.metal backend-assets/grpc/
|
cp go-llama/build/bin/ggml-metal.metal backend-assets/grpc/
|
||||||
endif
|
endif
|
||||||
|
|
||||||
## BACKEND CPP LLAMA START
|
## BACKEND CPP LLAMA START
|
||||||
@@ -471,7 +425,7 @@ ifdef BUILD_GRPC_FOR_BACKEND_LLAMA
|
|||||||
export _PROTOBUF_PROTOC=${INSTALLED_PACKAGES}/bin/proto && \
|
export _PROTOBUF_PROTOC=${INSTALLED_PACKAGES}/bin/proto && \
|
||||||
export _GRPC_CPP_PLUGIN_EXECUTABLE=${INSTALLED_PACKAGES}/bin/grpc_cpp_plugin && \
|
export _GRPC_CPP_PLUGIN_EXECUTABLE=${INSTALLED_PACKAGES}/bin/grpc_cpp_plugin && \
|
||||||
export PATH=${PATH}:${INSTALLED_PACKAGES}/bin && \
|
export PATH=${PATH}:${INSTALLED_PACKAGES}/bin && \
|
||||||
CMAKE_ARGS="${CMAKE_ARGS} ${ADDED_CMAKE_ARGS}" LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
CMAKE_ARGS="${ADDED_CMAKE_ARGS}" LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
||||||
else
|
else
|
||||||
echo "BUILD_GRPC_FOR_BACKEND_LLAMA is not defined."
|
echo "BUILD_GRPC_FOR_BACKEND_LLAMA is not defined."
|
||||||
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
|
||||||
@@ -486,75 +440,71 @@ ifeq ($(BUILD_TYPE),metal)
|
|||||||
cp backend/cpp/llama/llama.cpp/build/bin/ggml-metal.metal backend-assets/grpc/
|
cp backend/cpp/llama/llama.cpp/build/bin/ggml-metal.metal backend-assets/grpc/
|
||||||
endif
|
endif
|
||||||
|
|
||||||
backend-assets/grpc/llama-ggml: backend-assets/grpc sources/go-llama-ggml/libbinding.a
|
backend-assets/grpc/llama-stable: backend-assets/grpc go-llama-stable/libbinding.a
|
||||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/sources/go-llama-ggml
|
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama-stable
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-llama-ggml LIBRARY_PATH=$(shell pwd)/sources/go-llama-ggml \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama-stable LIBRARY_PATH=$(shell pwd)/go-llama \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama-ggml ./backend/go/llm/llama-ggml/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama-stable ./cmd/grpc/llama-stable/
|
||||||
|
|
||||||
backend-assets/grpc/gpt4all: backend-assets/grpc backend-assets/gpt4all sources/gpt4all/gpt4all-bindings/golang/libgpt4all.a
|
backend-assets/grpc/gpt4all: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-bindings/golang/libgpt4all.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/sources/gpt4all/gpt4all-bindings/golang/ \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./backend/go/llm/gpt4all/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./cmd/grpc/gpt4all/
|
||||||
|
|
||||||
backend-assets/grpc/dolly: backend-assets/grpc sources/go-ggml-transformers/libtransformers.a
|
backend-assets/grpc/dolly: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/sources/go-ggml-transformers \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/dolly ./backend/go/llm/dolly/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/dolly ./cmd/grpc/dolly/
|
||||||
|
|
||||||
backend-assets/grpc/gpt2: backend-assets/grpc sources/go-ggml-transformers/libtransformers.a
|
backend-assets/grpc/gpt2: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/sources/go-ggml-transformers \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt2 ./backend/go/llm/gpt2/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt2 ./cmd/grpc/gpt2/
|
||||||
|
|
||||||
backend-assets/grpc/gptj: backend-assets/grpc sources/go-ggml-transformers/libtransformers.a
|
backend-assets/grpc/gptj: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/sources/go-ggml-transformers \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptj ./backend/go/llm/gptj/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptj ./cmd/grpc/gptj/
|
||||||
|
|
||||||
backend-assets/grpc/gptneox: backend-assets/grpc sources/go-ggml-transformers/libtransformers.a
|
backend-assets/grpc/gptneox: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/sources/go-ggml-transformers \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptneox ./backend/go/llm/gptneox/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptneox ./cmd/grpc/gptneox/
|
||||||
|
|
||||||
backend-assets/grpc/mpt: backend-assets/grpc sources/go-ggml-transformers/libtransformers.a
|
backend-assets/grpc/mpt: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/sources/go-ggml-transformers \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/mpt ./backend/go/llm/mpt/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/mpt ./cmd/grpc/mpt/
|
||||||
|
|
||||||
backend-assets/grpc/replit: backend-assets/grpc sources/go-ggml-transformers/libtransformers.a
|
backend-assets/grpc/replit: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/sources/go-ggml-transformers \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/replit ./backend/go/llm/replit/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/replit ./cmd/grpc/replit/
|
||||||
|
|
||||||
backend-assets/grpc/falcon-ggml: backend-assets/grpc sources/go-ggml-transformers/libtransformers.a
|
backend-assets/grpc/falcon-ggml: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/sources/go-ggml-transformers \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon-ggml ./backend/go/llm/falcon-ggml/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon-ggml ./cmd/grpc/falcon-ggml/
|
||||||
|
|
||||||
backend-assets/grpc/starcoder: backend-assets/grpc sources/go-ggml-transformers/libtransformers.a
|
backend-assets/grpc/starcoder: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/sources/go-ggml-transformers \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/starcoder ./backend/go/llm/starcoder/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/starcoder ./cmd/grpc/starcoder/
|
||||||
|
|
||||||
backend-assets/grpc/rwkv: backend-assets/grpc sources/go-rwkv/librwkv.a
|
backend-assets/grpc/rwkv: backend-assets/grpc go-rwkv/librwkv.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-rwkv LIBRARY_PATH=$(shell pwd)/sources/go-rwkv \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-rwkv LIBRARY_PATH=$(shell pwd)/go-rwkv \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/rwkv ./backend/go/llm/rwkv
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/rwkv ./cmd/grpc/rwkv/
|
||||||
|
|
||||||
backend-assets/grpc/bert-embeddings: backend-assets/grpc sources/go-bert/libgobert.a
|
backend-assets/grpc/bert-embeddings: backend-assets/grpc go-bert/libgobert.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-bert LIBRARY_PATH=$(shell pwd)/sources/go-bert \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-bert LIBRARY_PATH=$(shell pwd)/go-bert \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bert-embeddings ./backend/go/llm/bert/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bert-embeddings ./cmd/grpc/bert-embeddings/
|
||||||
|
|
||||||
backend-assets/grpc/langchain-huggingface: backend-assets/grpc
|
backend-assets/grpc/langchain-huggingface: backend-assets/grpc
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/langchain-huggingface ./backend/go/llm/langchain/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/langchain-huggingface ./cmd/grpc/langchain-huggingface/
|
||||||
|
|
||||||
backend-assets/grpc/stablediffusion: backend-assets/grpc
|
backend-assets/grpc/stablediffusion: backend-assets/grpc
|
||||||
if [ ! -f backend-assets/grpc/stablediffusion ]; then \
|
if [ ! -f backend-assets/grpc/stablediffusion ]; then \
|
||||||
$(MAKE) sources/go-stable-diffusion/libstablediffusion.a; \
|
$(MAKE) go-stable-diffusion/libstablediffusion.a; \
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/sources/go-stable-diffusion/ \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/go-stable-diffusion/ \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./backend/go/image/stablediffusion; \
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./cmd/grpc/stablediffusion/; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
backend-assets/grpc/tinydream: backend-assets/grpc sources/go-tiny-dream/libtinydream.a
|
backend-assets/grpc/piper: backend-assets/grpc backend-assets/espeak-ng-data go-piper/libpiper_binding.a
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/go-tiny-dream \
|
CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/go-piper \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/tinydream ./backend/go/image/tinydream
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./cmd/grpc/piper/
|
||||||
|
|
||||||
backend-assets/grpc/piper: backend-assets/grpc backend-assets/espeak-ng-data sources/go-piper/libpiper_binding.a
|
backend-assets/grpc/whisper: backend-assets/grpc whisper.cpp/libwhisper.a
|
||||||
CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/sources/go-piper \
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/whisper.cpp LIBRARY_PATH=$(shell pwd)/whisper.cpp \
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/
|
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./cmd/grpc/whisper/
|
||||||
|
|
||||||
backend-assets/grpc/whisper: backend-assets/grpc sources/whisper.cpp/libwhisper.a
|
|
||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/whisper.cpp LIBRARY_PATH=$(shell pwd)/sources/whisper.cpp \
|
|
||||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/
|
|
||||||
|
|
||||||
grpcs: prepare $(GRPC_BACKENDS)
|
grpcs: prepare $(GRPC_BACKENDS)
|
||||||
|
|||||||
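Not part of the diff: the Makefile targets above are driven by the GO_TAGS and GRPC_BACKENDS variables. As a rough usage sketch under that assumption (the exact tag and backend names are taken from the variables shown above, the invocations themselves are illustrative):

```bash
# Build with the optional stablediffusion and tts gRPC backends compiled in
# (GO_TAGS is matched with findstring in the Makefile above).
make GO_TAGS="stablediffusion tts" build

# Override GRPC_BACKENDS to build only a subset of the Go gRPC backends
# instead of ALL_GRPC_BACKENDS.
make GRPC_BACKENDS="backend-assets/grpc/llama-cpp backend-assets/grpc/whisper" build
```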
README.md (104 changes)
@@ -22,10 +22,15 @@
 > :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
 >
-> [💻 Quickstart](https://localai.io/basics/getting_started/) [📣 News](https://localai.io/basics/news/) [ 🛫 Examples ](https://github.com/go-skynet/LocalAI/tree/master/examples/) [ 🖼️ Models ](https://localai.io/models/) [ 🚀 Roadmap ](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
+> [💻 Quickstart](https://localai.io/basics/getting_started/) [📣 News](https://localai.io/basics/news/) [ 🛫 Examples ](https://github.com/go-skynet/LocalAI/tree/master/examples/) [ 🖼️ Models ](https://localai.io/models/)


 [![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai)

+**LocalAI** is a drop-in replacement REST API that's compatible with OpenAI API specifications for local inferencing. It allows you to run LLMs (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families that are compatible with the ggml format, pytorch and more. Does not require GPU.

+<p align="center"><b>Follow LocalAI </b></p>

 <p align="center">
 <a href="https://twitter.com/LocalAI_API" target="blank">
 <img src="https://img.shields.io/twitter/follow/LocalAI_API?label=Follow: LocalAI_API&style=social" alt="Follow LocalAI_API"/>
@@ -34,25 +39,47 @@
 <img src="https://dcbadge.vercel.app/api/server/uJAeKSAGDy?style=flat-square&theme=default-inverted" alt="Join LocalAI Discord Community"/>
 </a>

-**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API that’s compatible with OpenAI API specifications for local inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families. Does not require GPU.
+<p align="center"><b>Connect with the Creator </b></p>

-## 🔥🔥 Hot topics / Roadmap
+<p align="center">
+<a href="https://twitter.com/mudler_it" target="blank">
+<img src="https://img.shields.io/twitter/follow/mudler_it?label=Follow: mudler_it&style=social" alt="Follow mudler_it"/>
+</a>
+<a href='https://github.com/mudler'>
+<img alt="Follow on Github" src="https://img.shields.io/badge/Follow-mudler-black?logo=github&link=https%3A%2F%2Fgithub.com%2Fmudler">
+</a>
+</p>

-[Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
+<p align="center"><b>Share LocalAI Repository</b></p>

-- 🐸 Coqui: https://github.com/mudler/LocalAI/pull/1489
-- Inline templates: https://github.com/mudler/LocalAI/pull/1452
-- Mixtral: https://github.com/mudler/LocalAI/pull/1449
-- Img2vid https://github.com/mudler/LocalAI/pull/1442
-- Musicgen https://github.com/mudler/LocalAI/pull/1387
+<p align="center">

-Hot topics (looking for contributors):
+<a href="https://twitter.com/intent/tweet?text=Check%20this%20GitHub%20repository%20out.%20LocalAI%20-%20Let%27s%20you%20easily%20run%20LLM%20locally.&url=https://github.com/go-skynet/LocalAI&hashtags=LocalAI,AI" target="blank">
-- Backends v2: https://github.com/mudler/LocalAI/issues/1126
+<img src="https://img.shields.io/twitter/follow/_LocalAI?label=Share Repo on Twitter&style=social" alt="Follow _LocalAI"/></a>
-- Improving UX v2: https://github.com/mudler/LocalAI/issues/1373
+<a href="https://t.me/share/url?text=Check%20this%20GitHub%20repository%20out.%20LocalAI%20-%20Let%27s%20you%20easily%20run%20LLM%20locally.&url=https://github.com/go-skynet/LocalAI" target="_blank"><img src="https://img.shields.io/twitter/url?label=Telegram&logo=Telegram&style=social&url=https://github.com/go-skynet/LocalAI" alt="Share on Telegram"/></a>
+<a href="https://api.whatsapp.com/send?text=Check%20this%20GitHub%20repository%20out.%20LocalAI%20-%20Let%27s%20you%20easily%20run%20LLM%20locally.%20https://github.com/go-skynet/LocalAI"><img src="https://img.shields.io/twitter/url?label=whatsapp&logo=whatsapp&style=social&url=https://github.com/go-skynet/LocalAI" /></a> <a href="https://www.reddit.com/submit?url=https://github.com/go-skynet/LocalAI&title=Check%20this%20GitHub%20repository%20out.%20LocalAI%20-%20Let%27s%20you%20easily%20run%20LLM%20locally.
+" target="blank">
+<img src="https://img.shields.io/twitter/url?label=Reddit&logo=Reddit&style=social&url=https://github.com/go-skynet/LocalAI" alt="Share on Reddit"/>
+</a> <a href="mailto:?subject=Check%20this%20GitHub%20repository%20out.%20LocalAI%20-%20Let%27s%20you%20easily%20run%20LLM%20locally.%3A%0Ahttps://github.com/go-skynet/LocalAI" target="_blank"><img src="https://img.shields.io/twitter/url?label=Gmail&logo=Gmail&style=social&url=https://github.com/go-skynet/LocalAI"/></a> <a href="https://www.buymeacoffee.com/mudler" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="23" width="100" style="border-radius:1px"></a>

-If you want to help and contribute, issues up for grabs: https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3A%22up+for+grabs%22
+</p>

-## 💻 [Getting started](https://localai.io/basics/getting_started/index.html)
+<hr>

+In a nutshell:

+- Local, OpenAI drop-in alternative REST API. You own your data.
+- NO GPU required. NO Internet access is required either
+- Optional, GPU Acceleration is available in `llama.cpp`-compatible LLMs. See also the [build section](https://localai.io/basics/build/index.html).
+- Supports multiple models
+- 🏃 Once loaded the first time, it keep models loaded in memory for faster inference
+- ⚡ Doesn't shell-out, but uses C++ bindings for a faster inference and better performance.

+LocalAI was created by [Ettore Di Giacinto](https://github.com/mudler/) and is a community-driven project, focused on making the AI accessible to anyone. Any contribution, feedback and PR is welcome!

+Note that this started just as a [fun weekend project](https://localai.io/#backstory) in order to try to create the necessary pieces for a full AI assistant like `ChatGPT`: the community is growing fast and we are working hard to make it better and more stable. If you want to help, please consider contributing (see below)!

+## 🔥🔥 [Hot topics / Roadmap](https://localai.io/#-hot-topics--roadmap)

 ## 🚀 [Features](https://localai.io/features/)

@@ -64,41 +91,7 @@ If you want to help and contribute, issues up for grabs: https://github.com/mudl
 - 🧠 [Embeddings generation for vector databases](https://localai.io/features/embeddings/)
 - ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
 - 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
-- 🆕 [Vision API](https://localai.io/features/gpt-vision/)

-## 💻 Usage

-Check out the [Getting started](https://localai.io/basics/getting_started/index.html) section in our documentation.

-### 🔗 Community and integrations

-Build and deploy custom containers:
-- https://github.com/sozercan/aikit

-WebUIs:
-- https://github.com/Jirubizu/localai-admin
-- https://github.com/go-skynet/LocalAI-frontend

-Model galleries
-- https://github.com/go-skynet/model-gallery

-Other:
-- Helm chart https://github.com/go-skynet/helm-charts
-- VSCode extension https://github.com/badgooooor/localai-vscode-plugin
-- Local Smart assistant https://github.com/mudler/LocalAGI
-- Home Assistant https://github.com/sammcj/homeassistant-localai / https://github.com/drndos/hass-openai-custom-conversation
-- Discord bot https://github.com/mudler/LocalAGI/tree/main/examples/discord
-- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
-- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
-- Examples: https://github.com/mudler/LocalAI/tree/master/examples/

-### 🔗 Resources

-- 🆕 New! [LLM finetuning guide](https://localai.io/advanced/fine-tuning/)
-- [How to build locally](https://localai.io/basics/build/index.html)
-- [How to install in Kubernetes](https://localai.io/basics/getting_started/index.html#run-localai-in-kubernetes)
-- [Projects integrating LocalAI](https://localai.io/integrations/)
-- [How tos section](https://localai.io/howtos/) (curated by our community)

 ## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)

@@ -107,6 +100,21 @@ Other:
 - [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
 - [Tutorial to use k8sgpt with LocalAI](https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65)

+## 💻 Usage

+Check out the [Getting started](https://localai.io/basics/getting_started/index.html) section in our documentation.

+### 💡 Example: Use Luna-AI Llama model

+See the [documentation](https://localai.io/basics/getting_started)

+### 🔗 Resources

+- [How to build locally](https://localai.io/basics/build/index.html)
+- [How to install in Kubernetes](https://localai.io/basics/getting_started/index.html#run-localai-in-kubernetes)
+- [Projects integrating LocalAI](https://localai.io/integrations/)
+- [How tos section](https://localai.io/howtos/) (curated by our community)

 ## Citation

 If you utilize this repository, data in a downstream project, please consider citing it with:
api/api.go (77 changes)
@@ -1,10 +1,8 @@
 package api

 import (
-"encoding/json"
 "errors"
 "fmt"
-"os"
 "strings"

 config "github.com/go-skynet/LocalAI/api/config"
@@ -15,7 +13,6 @@ import (
 "github.com/go-skynet/LocalAI/internal"
 "github.com/go-skynet/LocalAI/metrics"
 "github.com/go-skynet/LocalAI/pkg/assets"
-"github.com/go-skynet/LocalAI/pkg/model"

 "github.com/gofiber/fiber/v2"
 "github.com/gofiber/fiber/v2/middleware/cors"
@@ -47,10 +44,6 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader,
 }
 }

-if err := cl.Preload(options.Loader.ModelPath); err != nil {
-log.Error().Msgf("error downloading models: %s", err.Error())
-}

 if options.Debug {
 for _, v := range cl.ListConfigs() {
 cfg, _ := cl.GetConfig(v)
@@ -86,22 +79,6 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader,
 options.Loader.StopAllGRPC()
 }()

-if options.WatchDog {
-wd := model.NewWatchDog(
-options.Loader,
-options.WatchDogBusyTimeout,
-options.WatchDogIdleTimeout,
-options.WatchDogBusy,
-options.WatchDogIdle)
-options.Loader.SetWatchDog(wd)
-go wd.Run()
-go func() {
-<-options.Context.Done()
-log.Debug().Msgf("Context canceled, shutting down")
-wd.Shutdown()
-}()
-}

 return options, cl, nil
 }

@@ -150,46 +127,28 @@ func App(opts ...options.AppOption) (*fiber.App, error) {

 // Auth middleware checking if API key is valid. If no API key is set, no auth is required.
 auth := func(c *fiber.Ctx) error {
-if len(options.ApiKeys) == 0 {
+if len(options.ApiKeys) > 0 {
-return c.Next()
+authHeader := c.Get("Authorization")
-}
+if authHeader == "" {
+return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
-// Check for api_keys.json file
+}
-fileContent, err := os.ReadFile("api_keys.json")
+authHeaderParts := strings.Split(authHeader, " ")
-if err == nil {
+if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
-// Parse JSON content from the file
+return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
-var fileKeys []string
-err := json.Unmarshal(fileContent, &fileKeys)
-if err != nil {
-return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
 }

-// Add file keys to options.ApiKeys
+apiKey := authHeaderParts[1]
-options.ApiKeys = append(options.ApiKeys, fileKeys...)
+validApiKey := false
-}
+for _, key := range options.ApiKeys {
+if apiKey == key {
-if len(options.ApiKeys) == 0 {
+validApiKey = true
-return c.Next()
+}
 }
+if !validApiKey {
-authHeader := c.Get("Authorization")
+return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
-if authHeader == "" {
-return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
-}
-authHeaderParts := strings.Split(authHeader, " ")
-if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
-return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
-}

-apiKey := authHeaderParts[1]
-for _, key := range options.ApiKeys {
-if apiKey == key {
-return c.Next()
 }
 }
+return c.Next()
-return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})

 }

 if options.CORS {
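Not part of the diff: both variants of the auth middleware above expect an OpenAI-style `Authorization: Bearer <key>` header once API keys are configured. A minimal client call could look like the following sketch (port, endpoint payload and model name are placeholders, not taken from the diff):

```bash
# Assumes LocalAI is listening on its default address and that LOCALAI_API_KEY
# is one of the configured API keys checked by the middleware above.
curl http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $LOCALAI_API_KEY" \
  -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}'
```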
@@ -294,14 +294,14 @@ var _ = Describe("API test", func() {
 Expect(content["backend"]).To(Equal("bert-embeddings"))
 })

-It("runs openllama(llama-ggml backend)", Label("llama"), func() {
+It("runs openllama", Label("llama"), func() {
 if runtime.GOOS != "linux" {
 Skip("test supported only on linux")
 }
 response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
 URL: "github:go-skynet/model-gallery/openllama_3b.yaml",
 Name: "openllama_3b",
-Overrides: map[string]interface{}{"backend": "llama-ggml", "mmap": true, "f16": true, "context_size": 128},
+Overrides: map[string]interface{}{"backend": "llama-stable", "mmap": true, "f16": true, "context_size": 128},
 })

 Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
@@ -362,10 +362,9 @@ var _ = Describe("API test", func() {
 Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res))
 Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
 Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
-
 })

-It("runs openllama gguf(llama-cpp)", Label("llama-gguf"), func() {
+It("runs openllama gguf", Label("llama-gguf"), func() {
 if runtime.GOOS != "linux" {
 Skip("test supported only on linux")
 }
@@ -705,7 +704,7 @@ var _ = Describe("API test", func() {
 })

 Context("External gRPC calls", func() {
-It("calculate embeddings with sentencetransformers", func() {
+It("calculate embeddings with huggingface", func() {
 if runtime.GOOS != "linux" {
 Skip("test supported only on linux")
 }
@@ -16,7 +16,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
 model.WithContext(o.Context),
 model.WithModel(c.Model),
 model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
-CUDA: c.CUDA || c.Diffusers.CUDA,
+CUDA: c.Diffusers.CUDA,
 SchedulerType: c.Diffusers.SchedulerType,
 PipelineType: c.Diffusers.PipelineType,
 CFGScale: c.Diffusers.CFGScale,
@@ -27,7 +27,6 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
 CLIPModel: c.Diffusers.ClipModel,
 CLIPSubfolder: c.Diffusers.ClipSubFolder,
 CLIPSkip: int32(c.Diffusers.ClipSkip),
-ControlNet: c.Diffusers.ControlNet,
 }),
 })
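Not part of the diff: the `ImageGeneration` options above are what the image generation endpoint ultimately passes to the diffusers backend over gRPC. A hedged request sketch under that assumption (model name, prompt and size are placeholders):

```bash
# Illustrative request against the OpenAI-compatible image endpoint; the
# configured model must use a diffusers-capable backend for the options above
# (scheduler, CFG scale, clip settings) to take effect.
curl http://localhost:8080/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{"model": "stablediffusion", "prompt": "a cat wearing a spacesuit", "size": "512x512"}'
```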
@@ -16,10 +16,6 @@ func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.
 opts = append(opts, model.WithSingleActiveBackend())
 }

-if o.ParallelBackendRequests {
-opts = append(opts, model.EnableParallelRequests)
-}

 if c.GRPC.Attempts != 0 {
 opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
 }
@@ -46,7 +42,6 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
 Seed: int32(c.Seed),
 NBatch: int32(b),
 NoMulMatQ: c.NoMulMatQ,
-CUDA: c.CUDA, // diffusers, transformers
 DraftModel: c.DraftModel,
 AudioPath: c.VallE.AudioPath,
 Quantization: c.Quantization,
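Not part of the diff: the `CUDA` flag dropped from `gRPCModelOpts` above corresponds to the top-level `cuda` field of a model config, while the diffusers-specific flag lives under the `diffusers` block (see the `yaml` tags in the config struct changes further below). A sketch of a model definition using those fields; only the keys `cuda`, `diffusers.cuda`, `scheduler_type` and `cfg_scale` are taken from the source, everything else (file name, model, values) is illustrative:

```bash
# Hypothetical model config dropped into the models directory.
cat > models/my-diffusers-model.yaml <<'EOF'
name: my-diffusers-model
backend: diffusers
parameters:
  model: runwayml/stable-diffusion-v1-5
cuda: true          # top-level flag, present only on the v2.3.0 side of this diff
diffusers:
  cuda: true        # diffusers-specific flag, present on both sides
  scheduler_type: euler_a
  cfg_scale: 7
EOF
```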
@@ -59,13 +59,9 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt
 // If the model file is not empty, we pass it joined with the model path
 modelPath := ""
 if modelFile != "" {
-if bb != model.TransformersMusicGen {
+modelPath = filepath.Join(o.Loader.ModelPath, modelFile)
-modelPath = filepath.Join(o.Loader.ModelPath, modelFile)
+if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
-if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
+return "", nil, err
-return "", nil, err
-}
-} else {
-modelPath = modelFile
 }
 }
@@ -8,8 +8,6 @@ import (
"strings"
"sync"

-"github.com/go-skynet/LocalAI/pkg/utils"
-"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)

@@ -40,17 +38,14 @@ type Config struct {

// Diffusers
Diffusers Diffusers `yaml:"diffusers"`
-Step int `yaml:"step"`
+Step int `yaml:"step"`

// GRPC Options
GRPC GRPC `yaml:"grpc"`

// Vall-e-x
VallE VallE `yaml:"vall-e"`

-// CUDA
-// Explicitly enable CUDA or not (some backends might need it)
-CUDA bool `yaml:"cuda"`
}

type VallE struct {
@@ -70,16 +65,15 @@ type GRPC struct {
}

type Diffusers struct {
-CUDA bool `yaml:"cuda"`
PipelineType string `yaml:"pipeline_type"`
SchedulerType string `yaml:"scheduler_type"`
+CUDA bool `yaml:"cuda"`
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
ClipModel string `yaml:"clip_model"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
-ControlNet string `yaml:"control_net"`
}

type LLMConfig struct {
@@ -266,36 +260,6 @@ func (cm *ConfigLoader) ListConfigs() []string {
return res
}

-func (cm *ConfigLoader) Preload(modelPath string) error {
-cm.Lock()
-defer cm.Unlock()
-
-for i, config := range cm.configs {
-modelURL := config.PredictionOptions.Model
-modelURL = utils.ConvertURL(modelURL)
-if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") {
-// md5 of model name
-md5Name := utils.MD5(modelURL)
-
-// check if file exists
-if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist {
-err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) {
-log.Info().Msgf("Downloading %s: %s/%s (%.2f%%)", fileName, current, total, percent)
-})
-if err != nil {
-return err
-}
-}
-
-cc := cm.configs[i]
-c := &cc
-c.PredictionOptions.Model = md5Name
-cm.configs[i] = *c
-}
-}
-return nil
-}
-
func (cm *ConfigLoader) LoadConfigs(path string) error {
cm.Lock()
defer cm.Unlock()
@@ -313,7 +277,7 @@ func (cm *ConfigLoader) LoadConfigs(path string) error {
}
for _, file := range files {
// Skip templates, YAML and .keep files
-if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
+if !strings.Contains(file.Name(), ".yaml") {
continue
}
c, err := ReadConfig(filepath.Join(path, file.Name()))
@@ -123,12 +123,13 @@ func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return err
}

-model := bm.options.Loader.CheckIsLoaded(backendId)
-if model == "" {
+client := bm.options.Loader.CheckIsLoaded(backendId)
+if client == nil {
return fmt.Errorf("backend %s is not currently loaded", backendId)
}

-status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
+status, rpcErr := client.Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
@@ -81,7 +81,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
}

-if input.ResponseFormat.Type == "json_object" {
+if input.ResponseFormat == "json_object" {
input.Grammar = grammar.JSONBNF
}

@@ -219,12 +219,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
c.Set("Transfer-Encoding", "chunked")
}

-templateFile := ""
+templateFile := config.Model

-// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
-templateFile = config.Model
-}
-
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
@@ -234,19 +229,18 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
templateFile = config.TemplateConfig.Functions
}

-if templateFile != "" {
+// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
-}

log.Debug().Msgf("Prompt (after templating): %s", predInput)
@@ -65,7 +65,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
return fmt.Errorf("failed reading parameters from request:%w", err)
}

-if input.ResponseFormat.Type == "json_object" {
+if input.ResponseFormat == "json_object" {
input.Grammar = grammar.JSONBNF
}

@@ -81,12 +81,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
c.Set("Transfer-Encoding", "chunked")
}

-templateFile := ""
+templateFile := config.Model

-// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
-templateFile = config.Model
-}
-
if config.TemplateConfig.Completion != "" {
templateFile = config.TemplateConfig.Completion
@@ -99,14 +94,13 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe

predInput := config.PromptStrings[0]

-if templateFile != "" {
+// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
-}
}

responses := make(chan schema.OpenAIResponse)
@@ -151,16 +145,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
totalTokenUsage := backend.TokenUsage{}

for k, i := range config.PromptStrings {
-if templateFile != "" {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
-}
}

r, tokenUsage, err := ComputeChoices(
@@ -30,12 +30,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)

log.Debug().Msgf("Parameter Config: %+v", config)

-templateFile := ""
+templateFile := config.Model

-// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
-templateFile = config.Model
-}
-
if config.TemplateConfig.Edit != "" {
templateFile = config.TemplateConfig.Edit
@@ -45,16 +40,15 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
totalTokenUsage := backend.TokenUsage{}

for _, i := range config.InputStrings {
-if templateFile != "" {
+// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
-}
}

r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
@@ -5,8 +5,6 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
-"io"
-"net/http"
"os"
"path/filepath"
"strconv"
@@ -24,26 +22,6 @@ import (
"github.com/rs/zerolog/log"
)

-func downloadFile(url string) (string, error) {
-// Get the data
-resp, err := http.Get(url)
-if err != nil {
-return "", err
-}
-defer resp.Body.Close()
-
-// Create the file
-out, err := os.CreateTemp("", "image")
-if err != nil {
-return "", err
-}
-defer out.Close()
-
-// Write the body to file
-_, err = io.Copy(out, resp.Body)
-return out.Name(), err
-}
-
// https://platform.openai.com/docs/api-reference/images/create

/*
@@ -78,31 +56,12 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx

src := ""
if input.File != "" {
-fileData := []byte{}
-// check if input.File is an URL, if so download it and save it
-// to a temporary file
-if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
-out, err := downloadFile(input.File)
-if err != nil {
-return fmt.Errorf("failed downloading file:%w", err)
-}
-defer os.RemoveAll(out)
-
-fileData, err = os.ReadFile(out)
-if err != nil {
-return fmt.Errorf("failed reading file:%w", err)
-}
-
-} else {
-// base 64 decode the file and write it somewhere
-// that we will cleanup
-fileData, err = base64.StdEncoding.DecodeString(input.File)
-if err != nil {
-return err
-}
+//base 64 decode the file and write it somewhere
+// that we will cleanup
+decoded, err := base64.StdEncoding.DecodeString(input.File)
+if err != nil {
+return err
}

// Create a temporary file
outputFile, err := os.CreateTemp(o.ImageDir, "b64")
if err != nil {
@@ -110,7 +69,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
}
// write the base64 result
writer := bufio.NewWriter(outputFile)
-_, err = writer.Write(fileData)
+_, err = writer.Write(decoded)
if err != nil {
outputFile.Close()
return err
@@ -122,12 +81,8 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx

log.Debug().Msgf("Parameter Config: %+v", config)

-switch config.Backend {
-case "stablediffusion":
-config.Backend = model.StableDiffusionBackend
-case "tinydream":
-config.Backend = model.TinyDreamBackend
-case "":
+// XXX: Only stablediffusion is supported for now
+if config.Backend == "" {
config.Backend = model.StableDiffusionBackend
}

@@ -145,7 +100,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
}

b64JSON := false
-if input.ResponseFormat.Type == "b64_json" {
+if input.ResponseFormat == "b64_json" {
b64JSON = true
}
// src and clip_skip
@@ -4,11 +4,10 @@ import (
"context"
"embed"
"encoding/json"
-"time"

-"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model"
+"github.com/go-skynet/LocalAI/metrics"
"github.com/rs/zerolog/log"
)

@@ -37,13 +36,7 @@ type Option struct {

AutoloadGalleries bool

SingleBackend bool
-ParallelBackendRequests bool
-
-WatchDogIdle bool
-WatchDogBusy bool
-WatchDog bool
-WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
}

type AppOption func(*Option)
@@ -69,40 +62,10 @@ func WithCors(b bool) AppOption {
}
}

-var EnableWatchDog = func(o *Option) {
-o.WatchDog = true
-}
-
-var EnableWatchDogIdleCheck = func(o *Option) {
-o.WatchDog = true
-o.WatchDogIdle = true
-}
-
-var EnableWatchDogBusyCheck = func(o *Option) {
-o.WatchDog = true
-o.WatchDogBusy = true
-}
-
-func SetWatchDogBusyTimeout(t time.Duration) AppOption {
-return func(o *Option) {
-o.WatchDogBusyTimeout = t
-}
-}
-
-func SetWatchDogIdleTimeout(t time.Duration) AppOption {
-return func(o *Option) {
-o.WatchDogIdleTimeout = t
-}
-}
-
var EnableSingleBackend = func(o *Option) {
o.SingleBackend = true
}

-var EnableParallelBackendRequests = func(o *Option) {
-o.ParallelBackendRequests = true
-}
-
var EnableGalleriesAutoload = func(o *Option) {
o.AutoloadGalleries = true
}
@@ -83,12 +83,6 @@ type OpenAIModel struct {
Object string `json:"object"`
}

-type ChatCompletionResponseFormatType string
-
-type ChatCompletionResponseFormat struct {
-Type ChatCompletionResponseFormatType `json:"type,omitempty"`
-}
-
type OpenAIRequest struct {
config.PredictionOptions

@@ -98,7 +92,7 @@ type OpenAIRequest struct {
// whisper
File string `json:"file" validate:"required"`
//whisper/image
-ResponseFormat ChatCompletionResponseFormat `json:"response_format"`
+ResponseFormat string `json:"response_format"`
// image
Size string `json:"size"`
// Prompt is read only by completion/image API calls
@@ -36,7 +36,7 @@ include_directories(${Protobuf_INCLUDE_DIRS})
message(STATUS "Using protobuf version ${Protobuf_VERSION} | Protobuf_INCLUDE_DIRS: ${Protobuf_INCLUDE_DIRS} | CMAKE_CURRENT_BINARY_DIR: ${CMAKE_CURRENT_BINARY_DIR}")

# Proto file
-get_filename_component(hw_proto "../../../../../../backend/backend.proto" ABSOLUTE)
+get_filename_component(hw_proto "../../../../../../pkg/grpc/proto/backend.proto" ABSOLUTE)
get_filename_component(hw_proto_path "${hw_proto}" PATH)

# Generated sources
@@ -1,5 +1,5 @@

-LLAMA_VERSION?=
+LLAMA_VERSION?=d9b33fe95bd257b36c84ee5769cc048230067d6f

CMAKE_ARGS?=
BUILD_TYPE?=
@@ -21,9 +21,6 @@ endif

llama.cpp:
git clone --recurse-submodules https://github.com/ggerganov/llama.cpp llama.cpp
-if [ -z "$(LLAMA_VERSION)" ]; then \
-exit 1; \
-fi
cd llama.cpp && git checkout -b build $(LLAMA_VERSION) && git submodule update --init --recursive --depth 1

llama.cpp/examples/grpc-server:
@@ -40,17 +40,8 @@ using backend::HealthMessage;


///// LLAMA.CPP server code below
-#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
-using json = nlohmann::json;

-struct server_params
-{
-std::string hostname = "127.0.0.1";
-std::string public_path = "examples/server/public";
-int32_t port = 8080;
-int32_t read_timeout = 600;
-int32_t write_timeout = 600;
-};
+using json = nlohmann::json;

static bool server_verbose = false;

@@ -71,10 +62,6 @@ static bool server_verbose = false;
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)

-json oaicompat_completion_params_parse(const json &body);
-std::string format_chatml(std::vector<json> messages);
-

//
// base64 utils (TODO: move to common in the future)
//
@@ -165,23 +152,15 @@ struct task_server {
json data;
bool infill_mode = false;
bool embedding_mode = false;
-int multitask_id = -1;
};

struct task_result {
int id;
-int multitask_id = -1;
bool stop;
bool error;
json result_json;
};

-struct task_multi {
-int id;
-std::set<int> subtasks_remaining{};
-std::vector<task_result> results{};
-};
-
// TODO: can become bool if we can't find use of more states
enum slot_state
{
@@ -386,6 +365,7 @@ struct llama_client_slot

int32_t num_prompt_tokens = 0;
int32_t num_prompt_tokens_processed = 0;
+int32_t multibyte_pending = 0;

json prompt;
std::string generated_text;
@@ -401,9 +381,6 @@ struct llama_client_slot
bool stopped_word = false;
bool stopped_limit = false;

-bool oaicompat = false;
-std::string oaicompat_model;
-
std::string stopping_word;

// sampling
@@ -423,9 +400,6 @@ struct llama_client_slot
double t_prompt_processing; // ms
double t_token_generation; // ms

-// multitasks
-int multitask_id = -1;
-
void reset() {
num_prompt_tokens = 0;
generated_text = "";
@@ -434,6 +408,7 @@ struct llama_client_slot
stopped_word = false;
stopped_limit = false;
stopping_word = "";
+multibyte_pending = 0;
n_past = 0;
sent_count = 0;
sent_token_probs_index = 0;
@@ -505,7 +480,7 @@ struct llama_client_slot
};
}

-void print_timings() const {
+void print_timings() {
LOG_TEE("\n");
LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing / num_prompt_tokens_processed, 1e3 / t_prompt_processing * num_prompt_tokens_processed);
@@ -529,7 +504,6 @@ struct llama_server_context
bool multimodal = false;
bool clean_kv_cache = true;
bool all_slots_are_idle = false;
-bool add_bos_token = true;

int32_t id_gen;
int32_t n_ctx; // total context for all clients / slots
@@ -548,8 +522,7 @@ struct llama_server_context

std::vector<task_server> queue_tasks;
std::vector<task_result> queue_results;
-std::vector<task_multi> queue_multitasks;
-std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
+std::mutex mutex_tasks;
std::mutex mutex_results;

~llama_server_context()
@@ -603,8 +576,6 @@ struct llama_server_context

n_ctx = llama_n_ctx(ctx);

-add_bos_token = llama_should_add_bos_token(model);
-
return true;
}

@@ -638,11 +609,6 @@ struct llama_server_context

std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
{
-// TODO: currently, we tokenize using special tokens by default
-// this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
-// but it's better compared to completely ignoring ChatML and other chat templates
-const bool TMP_FORCE_SPECIAL = true;
-
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
std::vector<llama_token> prompt_tokens;
@@ -658,12 +624,12 @@ struct llama_server_context
std::vector<llama_token> p;
if (first)
{
-p = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
+p = ::llama_tokenize(ctx, s, add_bos);
first = false;
}
else
{
-p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
+p = ::llama_tokenize(ctx, s, false);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
}
@@ -680,7 +646,7 @@ struct llama_server_context
else
{
auto s = json_prompt.template get<std::string>();
-prompt_tokens = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
+prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
}

return prompt_tokens;
@@ -711,20 +677,11 @@ struct llama_server_context
slot_params default_params;
llama_sampling_params default_sparams;

-if (data.count("__oaicompat") != 0) {
-slot->oaicompat = true;
-slot->oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-} else {
-slot->oaicompat = false;
-slot->oaicompat_model = "";
-}
-
slot->params.stream = json_value(data, "stream", false);
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
-slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
@@ -909,7 +866,7 @@ struct llama_server_context
}

void update_system_prompt() {
-system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
+system_tokens = ::llama_tokenize(ctx, system_prompt, true);

llama_batch_clear(batch);

@@ -1000,36 +957,35 @@ struct llama_server_context
slot.generated_text += token_str;
slot.has_next_token = true;

-// check if there is incomplete UTF-8 character at the end
-bool incomplete = false;
-for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
-{
-unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-if ((c & 0xC0) == 0x80)
-{
-// continuation byte: 10xxxxxx
-continue;
-}
-if ((c & 0xE0) == 0xC0)
-{
-// 2-byte character: 110xxxxx ...
-incomplete = i < 2;
-}
-else if ((c & 0xF0) == 0xE0)
-{
-// 3-byte character: 1110xxxx ...
-incomplete = i < 3;
-}
-else if ((c & 0xF8) == 0xF0)
-{
-// 4-byte character: 11110xxx ...
-incomplete = i < 4;
-}
-// else 1-byte character or invalid byte
-break;
-}
+if (slot.multibyte_pending > 0)
+{
+slot.multibyte_pending -= token_str.size();
+}
+else if (token_str.size() == 1)
+{
+const char c = token_str[0];
+// 2-byte characters: 110xxxxx 10xxxxxx
+if ((c & 0xE0) == 0xC0)
+{
+slot.multibyte_pending = 1;
+// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
+}
+else if ((c & 0xF0) == 0xE0)
+{
+slot.multibyte_pending = 2;
+// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+}
+else if ((c & 0xF8) == 0xF0)
+{
+slot.multibyte_pending = 3;
+}
+else
+{
+slot.multibyte_pending = 0;
+}
+}

-if (!incomplete)
+if (slot.multibyte_pending == 0)
{
size_t pos = std::min(slot.sent_count, slot.generated_text.size());
const std::string str_test = slot.generated_text.substr(pos);
@@ -1064,7 +1020,7 @@ struct llama_server_context
}
}

-if (incomplete)
+if (slot.multibyte_pending > 0 && !slot.has_next_token)
{
slot.has_next_token = true;
}
@@ -1133,40 +1089,16 @@ struct llama_server_context
return slot.images.size() > 0;
}

-void send_error(task_server& task, std::string error)
+void send_error(int id, std::string error)
{
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
-res.id = task.id;
-res.multitask_id = task.multitask_id;
-res.stop = false;
+res.id = id;
res.error = true;
res.result_json = { { "content", error } };
queue_results.push_back(res);
}

-void add_multi_task(int id, std::vector<int>& sub_ids)
-{
-std::lock_guard<std::mutex> lock(mutex_tasks);
-task_multi multi;
-multi.id = id;
-std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
-queue_multitasks.push_back(multi);
-}
-
-void update_multi_task(int multitask_id, int subtask_id, task_result& result)
-{
-std::lock_guard<std::mutex> lock(mutex_tasks);
-for (auto& multitask : queue_multitasks)
-{
-if (multitask.id == multitask_id)
-{
-multitask.subtasks_remaining.erase(subtask_id);
-multitask.results.push_back(result);
-}
-}
-}
-
json get_model_props()
{
return get_formated_generation(slots[0]);
@@ -1184,7 +1116,6 @@ struct llama_server_context
{"temp", slot.sparams.temp},
{"top_k", slot.sparams.top_k},
{"top_p", slot.sparams.top_p},
-{"min_p", slot.sparams.min_p},
{"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typical_p},
{"repeat_last_n", slot.sparams.penalty_last_n},
@@ -1211,7 +1142,6 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
-res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = false;

@@ -1237,12 +1167,6 @@ struct llama_server_context
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
}

-if (slot.oaicompat)
-{
-res.result_json["oaicompat_token_ctr"] = slot.n_decoded;
-res.result_json["model"] = slot.oaicompat_model;
-}
-
queue_results.push_back(res);
}

@@ -1251,7 +1175,6 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
-res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = true;

@@ -1291,18 +1214,6 @@ struct llama_server_context
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs);
}

-if (slot.oaicompat)
-{
-res.result_json["oaicompat_token_ctr"] = slot.n_decoded;
-res.result_json["model"] = slot.oaicompat_model;
-}
-
-// parent multitask, if any, needs to be updated
-if (slot.multitask_id != -1)
-{
-update_multi_task(slot.multitask_id, slot.task_id, res);
-}
-
queue_results.push_back(res);
}

@@ -1311,7 +1222,6 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
-res.multitask_id = slot.multitask_id;
res.error = false;
res.stop = true;

@@ -1338,26 +1248,15 @@ struct llama_server_context
queue_results.push_back(res);
}

-int request_completion(json data, bool infill, bool embedding, int multitask_id)
+int request_completion(json data, bool infill, bool embedding)
{
-std::unique_lock<std::mutex> lock(mutex_tasks);
+std::lock_guard<std::mutex> lock(mutex_tasks);
task_server task;
task.id = id_gen++;
-task.target_id = 0;
-task.data = std::move(data);
+task.data = data;
task.infill_mode = infill;
task.embedding_mode = embedding;
task.type = COMPLETION_TASK;
-task.multitask_id = multitask_id;
-
-// when a completion task's prompt array is not a singleton, we split it into multiple requests
-if (task.data.at("prompt").size() > 1)
-{
-lock.unlock(); // entering new func scope
-return split_multiprompt_task(task);
-}
-
-// otherwise, it's a single-prompt task, we actually queue it
queue_tasks.push_back(task);
return task.id;
}
@@ -1376,17 +1275,8 @@ struct llama_server_context

for (int i = 0; i < (int) queue_results.size(); i++)
{
-// for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
-if (queue_results[i].multitask_id == task_id)
-{
-update_multi_task(task_id, queue_results[i].id, queue_results[i]);
-queue_results.erase(queue_results.begin() + i);
-continue;
-}
-
if (queue_results[i].id == task_id)
{
-assert(queue_results[i].multitask_id == -1);
task_result res = queue_results[i];
queue_results.erase(queue_results.begin() + i);
return res;
@@ -1476,27 +1366,6 @@ struct llama_server_context
queue_tasks.push_back(task);
}

-int split_multiprompt_task(task_server& multiprompt_task)
-{
-int prompt_count = multiprompt_task.data.at("prompt").size();
-assert(prompt_count > 1);
-
-int multitask_id = id_gen++;
-std::vector<int> subtask_ids(prompt_count);
-for (int i = 0; i < prompt_count; i++)
-{
-json subtask_data = multiprompt_task.data;
-subtask_data["prompt"] = subtask_data["prompt"][i];
-
-// subtasks inherit everything else (infill mode, embedding mode, etc.)
-subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
-}
-
-// queue up the multitask so we can track its subtask progression
-add_multi_task(multitask_id, subtask_ids);
-return multitask_id;
-}
-
void process_tasks()
{
std::lock_guard<std::mutex> lock(mutex_tasks);
@@ -1512,7 +1381,7 @@ struct llama_server_context
{
LOG_TEE("slot unavailable\n");
// send error result
-send_error(task, "slot unavailable");
+send_error(task.id, "slot unavailable");
return;
}

@@ -1526,12 +1395,11 @@ struct llama_server_context
slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode;
slot->task_id = task.id;
-slot->multitask_id = task.multitask_id;

if (!launch_slot_with_data(slot, task.data))
{
// send error result
-send_error(task, "internal_error");
+send_error(task.id, "internal_error");
break;
}
} break;
@@ -1547,38 +1415,6 @@ struct llama_server_context
} break;
}
}
-
-// remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
-auto queue_iterator = queue_multitasks.begin();
-while (queue_iterator != queue_multitasks.end())
-{
-if (queue_iterator->subtasks_remaining.empty())
-{
-// all subtasks done == multitask is done
-task_result aggregate_result;
-aggregate_result.id = queue_iterator->id;
-aggregate_result.stop = true;
-aggregate_result.error = false;
-
-// collect json results into one json result
-std::vector<json> result_jsons;
-for (auto& subres : queue_iterator->results)
-{
-result_jsons.push_back(subres.result_json);
-aggregate_result.error = aggregate_result.error && subres.error;
-}
-aggregate_result.result_json = json{ "results", result_jsons };
-
-std::lock_guard<std::mutex> lock(mutex_results);
-queue_results.push_back(aggregate_result);
-
-queue_iterator = queue_multitasks.erase(queue_iterator);
-}
-else
-{
-++queue_iterator;
-}
-}
}

bool update_slots() {
@@ -1717,40 +1553,11 @@ struct llama_server_context
}
else
{
-prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
+prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
}

slot.num_prompt_tokens = prompt_tokens.size();

-if (slot.params.n_keep < 0)
-{
-slot.params.n_keep = slot.num_prompt_tokens;
-}
-slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
-// if input prompt is too big, truncate it
-if (slot.num_prompt_tokens >= slot.n_ctx)
-{
-const int n_left = slot.n_ctx - slot.params.n_keep;
-const int n_block_size = n_left / 2;
-const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
-std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
-new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
-
-LOG_VERBOSE("input truncated", {
-{"n_ctx", slot.n_ctx},
-{"n_keep", slot.params.n_keep},
-{"n_left", n_left},
-{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-});
-slot.truncated = true;
-prompt_tokens = new_tokens;
-
-slot.num_prompt_tokens = prompt_tokens.size();
-GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
-}
-
if (!slot.params.cache_prompt)
{
llama_sampling_reset(slot.ctx_sampling);
@@ -1760,6 +1567,35 @@ struct llama_server_context
}
else
{
+if (slot.params.n_keep < 0)
+{
+slot.params.n_keep = slot.num_prompt_tokens;
+}
+slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+
+// if input prompt is too big, truncate it
+if (slot.num_prompt_tokens >= slot.n_ctx)
+{
+const int n_left = slot.n_ctx - slot.params.n_keep;
+const int n_block_size = n_left / 2;
+const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+
+std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
+new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+LOG_VERBOSE("input truncated", {
+{"n_ctx", slot.n_ctx},
+{"n_keep", slot.params.n_keep},
+{"n_left", n_left},
+{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+});
+slot.truncated = true;
+prompt_tokens = new_tokens;
+
+slot.num_prompt_tokens = prompt_tokens.size();
+GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
+}
+
// push the prompt into the sampling context (do not apply grammar)
for (auto &token : prompt_tokens)
{
@@ -1794,7 +1630,7 @@ struct llama_server_context
const bool has_images = process_images(slot);

// process the prefix of first image
-std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
+std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
{
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);
@@ -1914,231 +1750,6 @@ struct llama_server_context
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
static std::string random_string()
|
|
||||||
{
|
|
||||||
static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
|
|
||||||
|
|
||||||
std::random_device rd;
|
|
||||||
std::mt19937 generator(rd());
|
|
||||||
|
|
||||||
std::string result(32, ' ');
|
|
||||||
|
|
||||||
for (int i = 0; i < 32; ++i) {
|
|
||||||
result[i] = str[generator() % str.size()];
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string gen_chatcmplid()
|
|
||||||
{
|
|
||||||
std::stringstream chatcmplid;
|
|
||||||
chatcmplid << "chatcmpl-" << random_string();
|
|
||||||
return chatcmplid.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string format_chatml(std::vector<json> messages)
|
|
||||||
{
|
|
||||||
std::ostringstream chatml_msgs;
|
|
||||||
|
|
||||||
for (auto it = messages.begin(); it != messages.end(); ++it) {
|
|
||||||
chatml_msgs << "<|im_start|>"
|
|
||||||
<< json_value(*it, "role", std::string("user")) << '\n';
|
|
||||||
chatml_msgs << json_value(*it, "content", std::string(""))
|
|
||||||
<< "<|im_end|>\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
chatml_msgs << "<|im_start|>assistant" << '\n';
|
|
||||||
|
|
||||||
return chatml_msgs.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* llama.cpp completion api semantics */
|
|
||||||
json oaicompat_completion_params_parse(
|
|
||||||
const json &body /* openai api json semantics */)
|
|
||||||
{
|
|
||||||
json llama_params;
|
|
||||||
|
|
||||||
llama_params["__oaicompat"] = true;
|
|
||||||
|
|
||||||
// Map OpenAI parameters to llama.cpp parameters
|
|
||||||
llama_params["model"] = json_value(body, "model", std::string("uknown"));
|
|
||||||
llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
|
|
||||||
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
|
|
||||||
llama_params["temperature"] = json_value(body, "temperature", 0.8);
|
|
||||||
llama_params["top_k"] = json_value(body, "top_k", 40);
|
|
||||||
llama_params["top_p"] = json_value(body, "top_p", 0.95);
|
|
||||||
llama_params["n_predict"] = json_value(body, "max_tokens", -1);
|
|
||||||
llama_params["logit_bias"] = json_value(body, "logit_bias",json::object());
|
|
||||||
llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
|
|
||||||
llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0);
|
|
||||||
llama_params["seed"] = json_value(body, "seed", 0);
|
|
||||||
llama_params["stream"] = json_value(body, "stream", false);
|
|
||||||
llama_params["mirostat"] = json_value(body, "mirostat", false);
|
|
||||||
llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", 0.0);
|
|
||||||
llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", 0.0);
|
|
||||||
llama_params["penalize_nl"] = json_value(body, "penalize_nl", false);
|
|
||||||
llama_params["typical_p"] = json_value(body, "typical_p", 0.0);
|
|
||||||
llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", 0);
|
|
||||||
llama_params["ignore_eos"] = json_value(body, "ignore_eos", false);
|
|
||||||
llama_params["tfs_z"] = json_value(body, "tfs_z", 0.0);
|
|
||||||
|
|
||||||
if (llama_params.count("grammar") != 0) {
|
|
||||||
llama_params["grammar"] = json_value(body, "grammar", json::object());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle 'stop' field
|
|
||||||
if (body.contains("stop") && body["stop"].is_string()) {
|
|
||||||
llama_params["stop"] = json::array({body["stop"].get<std::string>()});
|
|
||||||
} else {
|
|
||||||
llama_params["stop"] = json_value(body, "stop", json::array());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure there is ChatML-specific end sequence among stop words
|
|
||||||
llama_params["stop"].push_back("<|im_end|>");
|
|
||||||
|
|
||||||
return llama_params;
|
|
||||||
}
|
|
||||||
|
|
||||||
static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
{
    json result = response.result_json;

    bool stopped_word = result.count("stopped_word") != 0;
    bool stopped_eos = json_value(result, "stopped_eos", false);
    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
    int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
    std::string content = json_value(result, "content", std::string(""));

    std::string finish_reason = "length";
    if (stopped_word || stopped_eos) {
        finish_reason = "stop";
    }

    json choices =
        streaming ? json::array({json{{"finish_reason", finish_reason},
                                      {"index", 0},
                                      {"delta", json::object()}}})
                  : json::array({json{{"finish_reason", finish_reason},
                                      {"index", 0},
                                      {"message", json{{"content", content},
                                                       {"role", "assistant"}}}}});

    std::time_t t = std::time(0);

    json res =
        json{{"choices", choices},
             {"created", t},
             {"model",
              json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
             {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
             {"usage",
              json{{"completion_tokens", num_tokens_predicted},
                   {"prompt_tokens", num_prompt_tokens},
                   {"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
             {"id", gen_chatcmplid()}};

    if (server_verbose) {
        res["__verbose"] = result;
    }

    if (result.contains("completion_probabilities")) {
        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
    }

    return res;
}

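A sketch of the non-streaming response shape this function produces (values are illustrative; the exact format of the gen_chatcmplid() identifier is an assumption):

```
# Illustrative only: example shape of the chat.completion object built above.
example_response = {
    "id": "chatcmpl-abc123",
    "object": "chat.completion",
    "created": 1700000000,
    "model": "gpt-3.5-turbo",
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "Hello!"},
        }
    ],
    "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
}
```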
// return value is vector as there is one case where we might need to generate two responses
static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
    json result = response.result_json;

    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
        return std::vector<json>({response.result_json});
    }

    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));

    bool stopped_word = json_value(result, "stopped_word", false);
    bool stopped_eos = json_value(result, "stopped_eos", false);
    bool stopped_limit = json_value(result, "stopped_limit", false);
    std::string content = json_value(result, "content", std::string(""));

    std::string finish_reason;
    if (stopped_word || stopped_eos) {
        finish_reason = "stop";
    }
    if (stopped_limit) {
        finish_reason = "length";
    }

    std::time_t t = std::time(0);

    json choices;

    if (!finish_reason.empty()) {
        choices = json::array({json{{"finish_reason", finish_reason},
                                    {"index", 0},
                                    {"delta", json::object()}}});
    } else {
        if (first) {
            if (content.empty()) {
                choices = json::array({json{{"finish_reason", nullptr},
                                            {"index", 0},
                                            {"delta", json{{"role", "assistant"}}}}});
            } else {
                // We have to send this as two updates to conform to openai behavior
                json initial_ret = json{{"choices", json::array({json{
                                            {"finish_reason", nullptr},
                                            {"index", 0},
                                            {"delta", json{
                                                {"role", "assistant"}
                                            }}}})},
                                        {"created", t},
                                        {"id", gen_chatcmplid()},
                                        {"model", modelname},
                                        {"object", "chat.completion.chunk"}};

                json second_ret = json{
                    {"choices", json::array({json{{"finish_reason", nullptr},
                                                  {"index", 0},
                                                  {"delta", json{
                                                      {"content", content}}}
                                                  }})},
                    {"created", t},
                    {"id", gen_chatcmplid()},
                    {"model", modelname},
                    {"object", "chat.completion.chunk"}};

                return std::vector<json>({initial_ret, second_ret});
            }
        } else {
            // Some idiosyncrasy in task processing logic makes several trailing calls
            // with empty content, we ignore these at the callee site.
            if (content.empty()) {
                return std::vector<json>({json::object()});
            }

            choices = json::array({json{
                {"finish_reason", nullptr},
                {"index", 0},
                {"delta",
                 json{
                     {"content", content},
                 }},
            }});
        }
    }

    json ret = json{{"choices", choices},
                    {"created", t},
                    {"id", gen_chatcmplid()},
                    {"model", modelname},
                    {"object", "chat.completion.chunk"}};

    return std::vector<json>({ret});
}

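A sketch of the chunk sequence the streaming path emits (values illustrative): the first non-empty token is split into a role delta followed by a content delta to mirror OpenAI behavior, later tokens carry only a content delta, and the final chunk carries only finish_reason.

```
# Illustrative only: the three kinds of chat.completion.chunk deltas produced above.
role_chunk    = {"choices": [{"index": 0, "finish_reason": None, "delta": {"role": "assistant"}}]}
content_chunk = {"choices": [{"index": 0, "finish_reason": None, "delta": {"content": "Hel"}}]}
final_chunk   = {"choices": [{"index": 0, "finish_reason": "stop", "delta": {}}]}
```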
static json format_partial_response(
    llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector<completion_token_output> &probs

@@ -2171,6 +1782,8 @@ static json format_detokenized_response(std::string content)
        {"content", content}};
}

struct token_translator
{
    llama_context * ctx;
@@ -2366,7 +1979,7 @@ static void params_parse(const backend::ModelOptions* request,
    // params.model_alias ??
    params.model_alias = request->modelfile();
    params.n_ctx = request->contextsize();
-   //params.memory_f16 = request->f16memory();
+   params.memory_f16 = request->f16memory();
    params.n_threads = request->threads();
    params.n_gpu_layers = request->ngpulayers();
    params.n_batch = request->nbatch();
@@ -2473,7 +2086,7 @@ public:
    }
    grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
        json data = parse_options(true, request, llama);
-       const int task_id = llama.request_completion(data, false, false, -1);
+       const int task_id = llama.request_completion(data, false, false);
        while (true)
        {
            task_result result = llama.next_result(task_id);
@@ -2509,7 +2122,7 @@ public:

    grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
        json data = parse_options(false, request, llama);
-       const int task_id = llama.request_completion(data, false, false, -1);
+       const int task_id = llama.request_completion(data, false, false);
        std::string completion_text;
        task_result result = llama.next_result(task_id);
        if (!result.error && result.stop) {
@@ -1,32 +0,0 @@
package main

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
    "github.com/go-skynet/LocalAI/pkg/grpc/base"
    pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
    "github.com/go-skynet/LocalAI/pkg/tinydream"
)

type Image struct {
    base.SingleThread
    tinydream *tinydream.TinyDream
}

func (image *Image) Load(opts *pb.ModelOptions) error {
    var err error
    // Note: the Model here is a path to a directory containing the model files
    image.tinydream, err = tinydream.New(opts.ModelFile)
    return err
}

func (image *Image) GenerateImage(opts *pb.GenerateImageRequest) error {
    return image.tinydream.GenerateImage(
        int(opts.Height),
        int(opts.Width),
        int(opts.Step),
        int(opts.Seed),
        opts.PositivePrompt,
        opts.NegativePrompt,
        opts.Dst)
}
@@ -1,21 +0,0 @@
package main

// Note: this is started internally by LocalAI and a server is allocated for each model

import (
    "flag"

    grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)

var (
    addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

func main() {
    flag.Parse()

    if err := grpc.StartServer(*addr, &LLM{}); err != nil {
        panic(err)
    }
}
@@ -1,21 +0,0 @@
package main

// Note: this is started internally by LocalAI and a server is allocated for each model

import (
    "flag"

    grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)

var (
    addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

func main() {
    flag.Parse()

    if err := grpc.StartServer(*addr, &LLM{}); err != nil {
        panic(err)
    }
}
@@ -1,21 +0,0 @@
package main

// Note: this is started internally by LocalAI and a server is allocated for each model

import (
    "flag"

    grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)

var (
    addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

func main() {
    flag.Parse()

    if err := grpc.StartServer(*addr, &Whisper{}); err != nil {
        panic(err)
    }
}
@@ -1,21 +0,0 @@
package main

// Note: this is started internally by LocalAI and a server is allocated for each model

import (
    "flag"

    grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)

var (
    addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

func main() {
    flag.Parse()

    if err := grpc.StartServer(*addr, &Piper{}); err != nil {
        panic(err)
    }
}
File diff suppressed because one or more lines are too long
@@ -1,15 +0,0 @@
.PHONY: ttsbark
ttsbark:
	$(MAKE) -C ../common-env/transformers

.PHONY: run
run:
	@echo "Running bark..."
	bash run.sh
	@echo "bark run."

.PHONY: test
test:
	@echo "Testing bark..."
	bash test.sh
	@echo "bark tested."
File diff suppressed because one or more lines are too long
@@ -1,81 +0,0 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "ttsbark.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="v2/en_speaker_4"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_tts(self):
        """
        This method tests if TTS audio is generated successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="v2/en_speaker_4"))
                self.assertTrue(response.success)
                tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
                tts_response = stub.TTS(tts_request)
                self.assertIsNotNone(tts_response)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
        finally:
            self.tearDown()
@@ -1,11 +0,0 @@
#!/bin/bash
##
## A bash script wrapper that runs the bark server with conda

# Activate conda environment
source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python -m unittest $DIR/test.py
@@ -1,10 +0,0 @@
CONDA_ENV_PATH = "transformers.yml"

ifeq ($(BUILD_TYPE), cublas)
	CONDA_ENV_PATH = "transformers-nvidia.yml"
endif

.PHONY: transformers
transformers:
	@echo "Installing $(CONDA_ENV_PATH)..."
	bash install.sh $(CONDA_ENV_PATH)
@@ -1,15 +0,0 @@
#!/bin/bash
set -ex

# Check if environment exists
conda_env_exists(){
    ! conda list --name "${@}" >/dev/null 2>/dev/null
}

if conda_env_exists "transformers" ; then
    echo "Creating virtual environment..."
    conda env create --name transformers --file $1
    echo "Virtual environment created."
else
    echo "Virtual environment already exists."
fi
@@ -1,105 +0,0 @@
name: transformers
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.08.22=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.11=h7f8727e_2
  - pip=23.2.1=py311h06a4308_0
  - python=3.11.5=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - setuptools=68.0.0=py311h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.41.2=py311h06a4308_0
  - xz=5.4.2=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
    - accelerate==0.23.0
    - aiohttp==3.8.5
    - aiosignal==1.3.1
    - async-timeout==4.0.3
    - attrs==23.1.0
    - bark==0.1.5
    - boto3==1.28.61
    - botocore==1.31.61
    - certifi==2023.7.22
    - TTS==0.22.0
    - charset-normalizer==3.3.0
    - datasets==2.14.5
    - sentence-transformers==2.2.2
    - sentencepiece==0.1.99
    - dill==0.3.7
    - einops==0.7.0
    - encodec==0.1.1
    - filelock==3.12.4
    - frozenlist==1.4.0
    - fsspec==2023.6.0
    - funcy==2.0
    - grpcio==1.59.0
    - huggingface-hub==0.16.4
    - idna==3.4
    - jinja2==3.1.2
    - jmespath==1.0.1
    - markupsafe==2.1.3
    - mpmath==1.3.0
    - multidict==6.0.4
    - multiprocess==0.70.15
    - networkx
    - numpy==1.26.0
    - packaging==23.2
    - pandas
    - peft==0.5.0
    - git+https://github.com/bigscience-workshop/petals
    - protobuf==4.24.4
    - psutil==5.9.5
    - pyarrow==13.0.0
    - python-dateutil==2.8.2
    - pytz==2023.3.post1
    - pyyaml==6.0.1
    - regex==2023.10.3
    - requests==2.31.0
    - rouge==1.0.1
    - s3transfer==0.7.0
    - safetensors==0.3.3
    - scipy==1.11.3
    - six==1.16.0
    - sympy==1.12
    - tokenizers==0.14.0
    - torch==2.1.0
    - torchaudio==2.1.0
    - tqdm==4.66.1
    - transformers==4.34.0
    - triton==2.1.0
    - typing-extensions==4.8.0
    - tzdata==2023.3
    - urllib3==1.26.17
    - xxhash==3.4.1
    - yarl==1.9.2
    - soundfile
    - langid
    - wget
    - unidecode
    - pyopenjtalk-prebuilt
    - pypinyin
    - inflect
    - cn2an
    - jieba
    - eng_to_ipa
    - openai-whisper
    - matplotlib
    - gradio==3.41.2
    - nltk
    - sudachipy
    - sudachidict_core
    - vocos
prefix: /opt/conda/envs/transformers
@@ -1,15 +0,0 @@
.PHONY: coqui
coqui:
	$(MAKE) -C ../common-env/transformers

.PHONY: run
run:
	@echo "Running coqui..."
	bash run.sh
	@echo "coqui run."

.PHONY: test
test:
	@echo "Testing coqui..."
	bash test.sh
	@echo "coqui tested."
@@ -1,11 +0,0 @@
# Creating a separate environment for the coqui project

```
make coqui
```

# Testing the gRPC server

```
make test
```
File diff suppressed because one or more lines are too long
@@ -1,97 +0,0 @@
#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for Coqui TTS
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import backend_pb2
import backend_pb2_grpc

import torch
from TTS.api import TTS

import grpc


_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
COQUI_LANGUAGE = os.environ.get('COQUI_LANGUAGE', 'en')

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service
    """
    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        # Get device
        device = "cuda" if request.CUDA else "cpu"

        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        # List available 🐸TTS models
        print(TTS().list_models())
        if os.path.isabs(request.AudioPath):
            self.AudioPath = request.AudioPath
        elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
            # get base path of modelFile
            modelFileBase = os.path.dirname(request.ModelFile)
            # modify LoraAdapter to be relative to modelFileBase
            self.AudioPath = os.path.join(modelFileBase, request.AudioPath)

        try:
            print("Preparing models, please wait", file=sys.stderr)
            self.tts = TTS(request.Model).to(device)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        # Implement your logic here for the LoadModel service
        # Replace this with your desired response
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def TTS(self, request, context):
        try:
            self.tts.tts_to_file(text=request.text, speaker_wav=self.AudioPath, language=COQUI_LANGUAGE, file_path=request.dst)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(success=True)

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
@@ -1,82 +0,0 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "coqui_server.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
                print(response)
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_tts(self):
        """
        This method tests if TTS audio is generated successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
                self.assertTrue(response.success)
                tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
                tts_response = stub.TTS(tts_request)
                self.assertIsNotNone(tts_response)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
        finally:
            self.tearDown()
File diff suppressed because one or more lines are too long
@@ -1,84 +0,0 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "backend_diffusers.py", "--addr", "localhost:50051"])

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.kill()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        time.sleep(10)
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        time.sleep(10)
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test(self):
        """
        This method tests if the backend can generate images
        """
        time.sleep(10)
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5"))
                print(response.message)
                self.assertTrue(response.success)
                image_req = backend_pb2.GenerateImageRequest(positive_prompt="cat", width=16, height=16, dst="test.jpg")
                re = stub.GenerateImage(image_req)
                self.assertTrue(re.success)
        except Exception as err:
            print(err)
            self.fail("Image gen service failed")
        finally:
            self.tearDown()
@@ -1,14 +0,0 @@
#!/bin/bash

##
## A bash script wrapper that runs the diffusers server with conda

export PATH=$PATH:/opt/conda/bin

# Activate conda environment
source activate diffusers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python -m unittest $DIR/test.py
File diff suppressed because one or more lines are too long
@@ -1,15 +0,0 @@
#!/bin/bash

##
## A bash script that installs the required dependencies of exllama and prepares the environment
export PATH=$PATH:/opt/conda/bin

# Activate conda environment
source activate exllama

echo $CONDA_PREFIX


git clone https://github.com/turboderp/exllama $CONDA_PREFIX/exllama && pushd $CONDA_PREFIX/exllama && pip install -r requirements.txt && popd

cp -rfv $CONDA_PREFIX/exllama/* ./
@@ -1,12 +0,0 @@
.PHONY: exllama2
exllama2:
	@echo "Creating virtual environment..."
	@conda env create --name exllama2 --file exllama2.yml
	@echo "Virtual environment created."
	bash install.sh

.PHONY: run
run:
	@echo "Running exllama2..."
	bash run.sh
	@echo "exllama2 run."
File diff suppressed because one or more lines are too long
@@ -1,138 +0,0 @@
#!/usr/bin/env python3
import grpc
from concurrent import futures
import time
import backend_pb2
import backend_pb2_grpc
import argparse
import signal
import sys
import os
import glob

from pathlib import Path
import torch
import torch.nn.functional as F
from torch import version as torch_version


from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler
)


from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Tokenizer,
    model_init,
)


_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        try:
            model_directory = request.ModelFile

            config = ExLlamaV2Config()
            config.model_dir = model_directory
            config.prepare()

            model = ExLlamaV2(config)

            cache = ExLlamaV2Cache(model, lazy=True)
            model.load_autosplit(cache)

            tokenizer = ExLlamaV2Tokenizer(config)

            # Initialize generator

            generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

            self.generator = generator

            generator.warmup()
            self.model = model
            self.tokenizer = tokenizer
            self.cache = cache
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def Predict(self, request, context):
        penalty = 1.15
        if request.Penalty != 0.0:
            penalty = request.Penalty

        settings = ExLlamaV2Sampler.Settings()
        settings.temperature = request.Temperature
        settings.top_k = request.TopK
        settings.top_p = request.TopP
        settings.token_repetition_penalty = penalty
        settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
        tokens = 512

        if request.Tokens != 0:
            tokens = request.Tokens
        output = self.generator.generate_simple(
            request.Prompt, settings, tokens)

        # Remove prompt from response if present
        if request.Prompt in output:
            output = output.replace(request.Prompt, "")

        return backend_pb2.Result(message=bytes(output, encoding='utf-8'))

    def PredictStream(self, request, context):
        # Implement PredictStream RPC
        # for reply in some_data_generator():
        #    yield reply
        # Not implemented yet
        return self.Predict(request, context)


def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
@@ -1,14 +0,0 @@
#!/bin/bash

##
## A bash script that installs the required dependencies of exllamav2 and prepares the environment
export PATH=$PATH:/opt/conda/bin

# Activate conda environment
source activate exllama2

echo $CONDA_PREFIX

git clone https://github.com/turboderp/exllamav2 $CONDA_PREFIX/exllamav2 && pushd $CONDA_PREFIX/exllamav2 && pip install -r requirements.txt && popd

cp -rfv $CONDA_PREFIX/exllamav2/* ./
@@ -1,14 +0,0 @@
#!/bin/bash

##
## A bash script wrapper that runs the exllama2 server with conda

export PATH=$PATH:/opt/conda/bin

# Activate conda environment
source activate exllama2

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python $DIR/exllama2_backend.py $@
@@ -1,15 +0,0 @@
.PHONY: petals
petals:
	$(MAKE) -C ../common-env/transformers

.PHONY: run
run:
	@echo "Running petals..."
	bash run.sh
	@echo "petals run."

.PHONY: test
test:
	@echo "Testing petals..."
	bash test.sh
	@echo "petals tested."
File diff suppressed because one or more lines are too long
@@ -1,140 +0,0 @@
#!/usr/bin/env python3
from concurrent import futures
import time
import argparse
import signal
import sys
import os

import backend_pb2
import backend_pb2_grpc

import grpc
import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    """
    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        Loads a language model.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(request.Model, use_fast=False, add_bos_token=False)
            self.model = AutoDistributedModelForCausalLM.from_pretrained(request.Model)
            self.cuda = False
            if request.CUDA:
                self.model = self.model.cuda()
                self.cuda = True

        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The predict result.
        """
        inputs = self.tokenizer(request.Prompt, return_tensors="pt")["input_ids"]
        if self.cuda:
            inputs = inputs.cuda()

        if request.Tokens == 0:
            # Set to the max value if tokens are not specified
            request.Tokens = 8192

        # TODO: kwargs and map all parameters
        outputs = self.model.generate(inputs, max_new_tokens=request.Tokens)

        generated_text = self.tokenizer.decode(outputs[0])
        # Remove prompt from response if present
        if request.Prompt in generated_text:
            generated_text = generated_text.replace(request.Prompt, "")

        return backend_pb2.Result(message=bytes(generated_text, encoding='utf-8'))

    def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The predict stream result.
        """
        # Implement PredictStream RPC
        #for reply in some_data_generator():
        #    yield reply
        # Not implemented yet
        return self.Predict(request, context)

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
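A minimal client sketch for the petals backend above (assumptions: the generated backend_pb2 stubs are importable, the server is listening on localhost:50051, and the model name is only an example; field names follow the request fields used in the servicer):

```
# Hypothetical client for the petals backend above; illustrative only.
import grpc
import backend_pb2
import backend_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)
    load = stub.LoadModel(backend_pb2.ModelOptions(Model="petals-team/StableBeluga2"))
    print(load.message)
    reply = stub.Predict(backend_pb2.PredictOptions(Prompt="Hello", Tokens=16))
    print(reply.message.decode("utf-8"))
```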
@@ -1,30 +0,0 @@
name: petals
channels:
  - defaults
dependencies:
  # - _libgcc_mutex=0.1=main
  # - _openmp_mutex=5.1=1_gnu
  # - bzip2=1.0.8=h7b6447c_0
  # - ca-certificates=2023.08.22=h06a4308_0
  # - ld_impl_linux-64=2.38=h1181459_1
  # - libffi=3.4.4=h6a678d5_0
  # - libgcc-ng=11.2.0=h1234567_1
  # - libgomp=11.2.0=h1234567_1
  # - libstdcxx-ng=11.2.0=h1234567_1
  # - libuuid=1.41.5=h5eee18b_0
  # - ncurses=6.4=h6a678d5_0
  # - openssl=3.0.11=h7f8727e_2
  # - pip=23.2.1=py311h06a4308_0
  - python=3.11.5=h955ad1f_0
  # - readline=8.2=h5eee18b_0
  # - setuptools=68.0.0=py311h06a4308_0
  # - sqlite=3.41.2=h5eee18b_0
  # - tk=8.6.12=h1ccaba5_0
  # - tzdata=2023c=h04d1e81_0
  # - wheel=0.41.2=py311h06a4308_0
  # - xz=5.4.2=h5eee18b_0
  # - zlib=1.2.13=h5eee18b_0
  - pip:
    - torch==2.1.0
    - git+https://github.com/bigscience-workshop/petals
prefix: /opt/conda/envs/petals
@@ -1,21 +0,0 @@
#!/bin/bash

##
## A bash script wrapper that runs the petals server with conda

export PATH=$PATH:/opt/conda/bin

# Activate conda environment
# if source is available use it, or use conda
#
if [ -f /opt/conda/bin/activate ]; then
    source activate transformers
else
    eval "$(conda shell.bash hook)"
    conda activate transformers
fi

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python $DIR/backend_petals.py $@
@@ -1,11 +0,0 @@
#!/bin/bash
##
## A bash script wrapper that runs the transformers server with conda

# Activate conda environment
source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python -m unittest $DIR/test_petals.py
@@ -1,17 +0,0 @@
.PHONY: sentencetransformers
sentencetransformers:
	$(MAKE) -C ../common-env/transformers


.PHONY: run
run:
	@echo "Running sentencetransformers..."
	bash run.sh
	@echo "sentencetransformers run."

# This does not work well from the command line; it only works with an IDE like VSCode.
.PHONY: test
test:
	@echo "Testing sentencetransformers..."
	bash test.sh
	@echo "sentencetransformers tested."
@@ -1,5 +0,0 @@
# Creating a separate environment for the sentencetransformers project

```
make sentencetransformers
```
File diff suppressed because one or more lines are too long
@@ -1,363 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

import backend_pb2 as backend__pb2


class BackendStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.Health = channel.unary_unary(
                '/backend.Backend/Health',
                request_serializer=backend__pb2.HealthMessage.SerializeToString,
                response_deserializer=backend__pb2.Reply.FromString,
                )
        self.Predict = channel.unary_unary(
                '/backend.Backend/Predict',
                request_serializer=backend__pb2.PredictOptions.SerializeToString,
                response_deserializer=backend__pb2.Reply.FromString,
                )
        self.LoadModel = channel.unary_unary(
                '/backend.Backend/LoadModel',
                request_serializer=backend__pb2.ModelOptions.SerializeToString,
                response_deserializer=backend__pb2.Result.FromString,
                )
        self.PredictStream = channel.unary_stream(
                '/backend.Backend/PredictStream',
                request_serializer=backend__pb2.PredictOptions.SerializeToString,
                response_deserializer=backend__pb2.Reply.FromString,
                )
        self.Embedding = channel.unary_unary(
                '/backend.Backend/Embedding',
                request_serializer=backend__pb2.PredictOptions.SerializeToString,
                response_deserializer=backend__pb2.EmbeddingResult.FromString,
                )
        self.GenerateImage = channel.unary_unary(
                '/backend.Backend/GenerateImage',
                request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
                response_deserializer=backend__pb2.Result.FromString,
                )
        self.AudioTranscription = channel.unary_unary(
                '/backend.Backend/AudioTranscription',
                request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
                response_deserializer=backend__pb2.TranscriptResult.FromString,
                )
        self.TTS = channel.unary_unary(
                '/backend.Backend/TTS',
                request_serializer=backend__pb2.TTSRequest.SerializeToString,
                response_deserializer=backend__pb2.Result.FromString,
                )
        self.TokenizeString = channel.unary_unary(
                '/backend.Backend/TokenizeString',
                request_serializer=backend__pb2.PredictOptions.SerializeToString,
                response_deserializer=backend__pb2.TokenizationResponse.FromString,
                )
        self.Status = channel.unary_unary(
                '/backend.Backend/Status',
                request_serializer=backend__pb2.HealthMessage.SerializeToString,
                response_deserializer=backend__pb2.StatusResponse.FromString,
                )


class BackendServicer(object):
    """Missing associated documentation comment in .proto file."""

    def Health(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Predict(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def LoadModel(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def PredictStream(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Embedding(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GenerateImage(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def AudioTranscription(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def TTS(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def TokenizeString(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Status(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_BackendServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'Health': grpc.unary_unary_rpc_method_handler(
                    servicer.Health,
                    request_deserializer=backend__pb2.HealthMessage.FromString,
                    response_serializer=backend__pb2.Reply.SerializeToString,
            ),
            'Predict': grpc.unary_unary_rpc_method_handler(
                    servicer.Predict,
                    request_deserializer=backend__pb2.PredictOptions.FromString,
                    response_serializer=backend__pb2.Reply.SerializeToString,
            ),
            'LoadModel': grpc.unary_unary_rpc_method_handler(
                    servicer.LoadModel,
                    request_deserializer=backend__pb2.ModelOptions.FromString,
                    response_serializer=backend__pb2.Result.SerializeToString,
            ),
            'PredictStream': grpc.unary_stream_rpc_method_handler(
                    servicer.PredictStream,
                    request_deserializer=backend__pb2.PredictOptions.FromString,
                    response_serializer=backend__pb2.Reply.SerializeToString,
            ),
            'Embedding': grpc.unary_unary_rpc_method_handler(
                    servicer.Embedding,
                    request_deserializer=backend__pb2.PredictOptions.FromString,
                    response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
            ),
            'GenerateImage': grpc.unary_unary_rpc_method_handler(
                    servicer.GenerateImage,
                    request_deserializer=backend__pb2.GenerateImageRequest.FromString,
                    response_serializer=backend__pb2.Result.SerializeToString,
            ),
            'AudioTranscription': grpc.unary_unary_rpc_method_handler(
                    servicer.AudioTranscription,
                    request_deserializer=backend__pb2.TranscriptRequest.FromString,
                    response_serializer=backend__pb2.TranscriptResult.SerializeToString,
            ),
            'TTS': grpc.unary_unary_rpc_method_handler(
                    servicer.TTS,
                    request_deserializer=backend__pb2.TTSRequest.FromString,
                    response_serializer=backend__pb2.Result.SerializeToString,
            ),
            'TokenizeString': grpc.unary_unary_rpc_method_handler(
                    servicer.TokenizeString,
                    request_deserializer=backend__pb2.PredictOptions.FromString,
                    response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
            ),
            'Status': grpc.unary_unary_rpc_method_handler(
                    servicer.Status,
                    request_deserializer=backend__pb2.HealthMessage.FromString,
                    response_serializer=backend__pb2.StatusResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'backend.Backend', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class Backend(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def Health(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
            backend__pb2.HealthMessage.SerializeToString,
            backend__pb2.Reply.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def Predict(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
            backend__pb2.PredictOptions.SerializeToString,
            backend__pb2.Reply.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def LoadModel(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
            backend__pb2.ModelOptions.SerializeToString,
            backend__pb2.Result.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def PredictStream(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
            backend__pb2.PredictOptions.SerializeToString,
            backend__pb2.Reply.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def Embedding(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
            backend__pb2.PredictOptions.SerializeToString,
            backend__pb2.EmbeddingResult.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def GenerateImage(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
            backend__pb2.GenerateImageRequest.SerializeToString,
            backend__pb2.Result.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def AudioTranscription(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
            backend__pb2.TranscriptRequest.SerializeToString,
            backend__pb2.TranscriptResult.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def TTS(request,
            target,
            options=(),
            channel_credentials=None,
|
|
||||||
call_credentials=None,
|
|
||||||
insecure=False,
|
|
||||||
compression=None,
|
|
||||||
wait_for_ready=None,
|
|
||||||
timeout=None,
|
|
||||||
metadata=None):
|
|
||||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
|
|
||||||
backend__pb2.TTSRequest.SerializeToString,
|
|
||||||
backend__pb2.Result.FromString,
|
|
||||||
options, channel_credentials,
|
|
||||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def TokenizeString(request,
|
|
||||||
target,
|
|
||||||
options=(),
|
|
||||||
channel_credentials=None,
|
|
||||||
call_credentials=None,
|
|
||||||
insecure=False,
|
|
||||||
compression=None,
|
|
||||||
wait_for_ready=None,
|
|
||||||
timeout=None,
|
|
||||||
metadata=None):
|
|
||||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
|
|
||||||
backend__pb2.PredictOptions.SerializeToString,
|
|
||||||
backend__pb2.TokenizationResponse.FromString,
|
|
||||||
options, channel_credentials,
|
|
||||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def Status(request,
|
|
||||||
target,
|
|
||||||
options=(),
|
|
||||||
channel_credentials=None,
|
|
||||||
call_credentials=None,
|
|
||||||
insecure=False,
|
|
||||||
compression=None,
|
|
||||||
wait_for_ready=None,
|
|
||||||
timeout=None,
|
|
||||||
metadata=None):
|
|
||||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
|
|
||||||
backend__pb2.HealthMessage.SerializeToString,
|
|
||||||
backend__pb2.StatusResponse.FromString,
|
|
||||||
options, channel_credentials,
|
|
||||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
||||||
@@ -1,14 +0,0 @@
#!/bin/bash

##
## A bash script wrapper that runs the sentencetransformers server with conda

export PATH=$PATH:/opt/conda/bin

# Activate conda environment
source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python $DIR/sentencetransformers.py $@
@@ -1,11 +0,0 @@
#!/bin/bash
##
## A bash script wrapper that runs the sentencetransformers server with conda

# Activate conda environment
source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python -m unittest $DIR/test_sentencetransformers.py
@@ -1,16 +0,0 @@

.PHONY: transformers-musicgen
transformers-musicgen:
	$(MAKE) -C ../common-env/transformers

.PHONY: run
run:
	@echo "Running transformers..."
	bash run.sh
	@echo "transformers run."

.PHONY: test
test:
	@echo "Testing transformers..."
	bash test.sh
	@echo "transformers tested."
@@ -1,5 +0,0 @@
# Creating a separate environment for the transformers project

```
make transformers-musicgen
```
File diff suppressed because one or more lines are too long
@@ -1,363 +0,0 @@
(diff of the auto-generated gRPC bindings backend_pb2_grpc.py omitted; the generated file is identical across the Python backends.)
@@ -1,16 +0,0 @@
#!/bin/bash

##
## A bash script wrapper that runs the transformers-musicgen server with conda

echo "Launching gRPC server for transformers-musicgen"

export PATH=$PATH:/opt/conda/bin

# Activate conda environment
source activate transformers-musicgen

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python $DIR/transformers_server.py $@
@@ -1,11 +0,0 @@
#!/bin/bash
##
## A bash script wrapper that runs the transformers server with conda

# Activate conda environment
source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python -m unittest $DIR/test_transformers.py
@@ -1,81 +0,0 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "transformers_server.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_tts(self):
        """
        This method tests if TTS audio is generated successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small"))
                self.assertTrue(response.success)
                tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
                tts_response = stub.TTS(tts_request)
                self.assertIsNotNone(tts_response)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
        finally:
            self.tearDown()
@@ -1,122 +0,0 @@
#!/usr/bin/env python3
"""
Extra gRPC server for MusicgenForConditionalGeneration models.
"""
from concurrent import futures

import argparse
import signal
import sys
import os

import time
import backend_pb2
import backend_pb2_grpc

import grpc

from scipy.io.wavfile import write as write_wav
from transformers import AutoProcessor, MusicgenForConditionalGeneration

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS is specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer for the backend service.

    This class implements the gRPC methods for the backend service, including Health, LoadModel, and TTS.
    """
    def Health(self, request, context):
        """
        A gRPC method that returns the health status of the backend service.

        Args:
            request: A HealthMessage object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Reply object that contains the health status of the backend service.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        A gRPC method that loads a model into memory.

        Args:
            request: A ModelOptions object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Result object that contains the result of the LoadModel operation.
        """
        model_name = request.Model
        try:
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def TTS(self, request, context):
        model_name = request.model
        if model_name == "":
            return backend_pb2.Result(success=False, message="request.model is required")
        try:
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
            inputs = self.processor(
                text=[request.text],
                padding=True,
                return_tensors="pt",
            )
            tokens = 256
            # TODO get tokens from request?
            audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
            print("[transformers-musicgen] TTS generated!", file=sys.stderr)
            sampling_rate = self.model.config.audio_encoder.sampling_rate
            write_wav(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
            print("[transformers-musicgen] TTS saved to", request.dst, file=sys.stderr)
            print("[transformers-musicgen] TTS for", file=sys.stderr)
            print(request, file=sys.stderr)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(success=True)


def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("[transformers-musicgen] Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("[transformers-musicgen] Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()
    print(f"[transformers-musicgen] startup: {args}", file=sys.stderr)
    serve(args.addr)
@@ -1,16 +0,0 @@
.PHONY: transformers
transformers:
	$(MAKE) -C ../common-env/transformers

.PHONY: run
run:
	@echo "Running transformers..."
	bash run.sh
	@echo "transformers run."

# It does not work well from the command line; it only works with an IDE like VSCode.
.PHONY: test
test:
	@echo "Testing transformers..."
	bash test.sh
	@echo "transformers tested."
@@ -1,5 +0,0 @@
# Creating a separate environment for the transformers project

```
make transformers
```
File diff suppressed because one or more lines are too long
@@ -1,363 +0,0 @@
(diff of the auto-generated gRPC bindings backend_pb2_grpc.py omitted; the generated file is identical across the Python backends.)
@@ -1,14 +0,0 @@
#!/bin/bash

##
## A bash script wrapper that runs the transformers server with conda

export PATH=$PATH:/opt/conda/bin

# Activate conda environment
source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python $DIR/transformers_server.py $@
@@ -1,11 +0,0 @@
#!/bin/bash
##
## A bash script wrapper that runs the transformers server with conda

# Activate conda environment
source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python -m unittest $DIR/test_transformers_server.py
@@ -1,84 +0,0 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "transformers_server.py", "--addr", "localhost:50051"])

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.kill()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        time.sleep(10)
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        time.sleep(10)
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_embedding(self):
        """
        This method tests if the embeddings are generated successfully
        """
        time.sleep(10)
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased"))
                print(response.message)
                self.assertTrue(response.success)
                embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
                embedding_response = stub.Embedding(embedding_request)
                self.assertIsNotNone(embedding_response.embeddings)
        except Exception as err:
            print(err)
            self.fail("Embedding service failed")
        finally:
            self.tearDown()
@@ -1,147 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
Extra gRPC server for HuggingFace AutoModel models.
"""
from concurrent import futures

import argparse
import signal
import sys
import os

import time
import backend_pb2
import backend_pb2_grpc

import grpc
import torch

from transformers import AutoTokenizer, AutoModel

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


def mean_pooling(model_output, attention_mask):
    """
    Mean pooling to get sentence embeddings. See:
    https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1
    """
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)  # Sum columns
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer for the backend service.

    This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding.
    """
    def Health(self, request, context):
        """
        A gRPC method that returns the health status of the backend service.

        Args:
            request: A HealthRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Reply object that contains the health status of the backend service.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        A gRPC method that loads a model into memory.

        Args:
            request: A LoadModelRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Result object that contains the result of the LoadModel operation.
        """
        model_name = request.Model
        try:
            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)  # trust_remote_code is needed to use the encode method with embeddings models like jinai-v2
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)

            if request.CUDA:
                try:
                    # TODO: also tensorflow, make configurable
                    import torch.cuda
                    if torch.cuda.is_available():
                        print("Loading model", model_name, "to CUDA.", file=sys.stderr)
                        self.model = self.model.to("cuda")
                except Exception as err:
                    print("Not using CUDA:", err, file=sys.stderr)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        # Implement your logic here for the LoadModel service
        # Replace this with your desired response
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def Embedding(self, request, context):
        """
        A gRPC method that calculates embeddings for a given sentence.

        Args:
            request: An EmbeddingRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            An EmbeddingResult object that contains the calculated embeddings.
        """

        # Tokenize input
        max_length = 512
        if request.Tokens != 0:
            max_length = request.Tokens
        encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

        # Create word embeddings
        model_output = self.model(**encoded_input)

        # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']).detach().numpy()
        print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
        print("Embeddings:", sentence_embeddings, file=sys.stderr)
        return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)


def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
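For orientation, here is a minimal, hypothetical client sketch against the embeddings backend above. It assumes the server is listening on localhost:50051 and that the PredictOptions message exposes the Embeddings and Tokens fields read by the servicer; the model name and input text are examples only, not prescribed by the repository.

    # Hypothetical client sketch; address, model name and text are assumptions.
    import grpc
    import backend_pb2
    import backend_pb2_grpc

    with grpc.insecure_channel("localhost:50051") as channel:
        stub = backend_pb2_grpc.BackendStub(channel)
        # Load an embeddings-capable model (example name only).
        res = stub.LoadModel(backend_pb2.ModelOptions(
            Model="sentence-transformers/paraphrase-distilroberta-base-v1"))
        assert res.success, res.message
        # Tokens caps the tokenizer max_length; 0 keeps the 512 default.
        reply = stub.Embedding(backend_pb2.PredictOptions(
            Embeddings="LocalAI is a drop-in replacement REST API", Tokens=0))
        print(len(reply.embeddings))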
@@ -1,16 +0,0 @@
.PHONY: ttsvalle
ttsvalle:
	$(MAKE) -C ../common-env/transformers
	bash install.sh

.PHONY: run
run:
	@echo "Running ttsvalle..."
	bash run.sh
	@echo "ttsvalle run."

.PHONY: test
test:
	@echo "Testing valle..."
	bash test.sh
	@echo "valle tested."
File diff suppressed because one or more lines are too long
@@ -1,363 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

import backend_pb2 as backend__pb2


class BackendStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.Health = channel.unary_unary(
                '/backend.Backend/Health',
                request_serializer=backend__pb2.HealthMessage.SerializeToString,
                response_deserializer=backend__pb2.Reply.FromString,
                )
        self.Predict = channel.unary_unary(
                '/backend.Backend/Predict',
                request_serializer=backend__pb2.PredictOptions.SerializeToString,
                response_deserializer=backend__pb2.Reply.FromString,
                )
        self.LoadModel = channel.unary_unary(
                '/backend.Backend/LoadModel',
                request_serializer=backend__pb2.ModelOptions.SerializeToString,
                response_deserializer=backend__pb2.Result.FromString,
                )
        self.PredictStream = channel.unary_stream(
                '/backend.Backend/PredictStream',
                request_serializer=backend__pb2.PredictOptions.SerializeToString,
                response_deserializer=backend__pb2.Reply.FromString,
                )
        self.Embedding = channel.unary_unary(
                '/backend.Backend/Embedding',
                request_serializer=backend__pb2.PredictOptions.SerializeToString,
                response_deserializer=backend__pb2.EmbeddingResult.FromString,
                )
        self.GenerateImage = channel.unary_unary(
                '/backend.Backend/GenerateImage',
                request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
                response_deserializer=backend__pb2.Result.FromString,
                )
        self.AudioTranscription = channel.unary_unary(
                '/backend.Backend/AudioTranscription',
                request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
                response_deserializer=backend__pb2.TranscriptResult.FromString,
                )
        self.TTS = channel.unary_unary(
                '/backend.Backend/TTS',
                request_serializer=backend__pb2.TTSRequest.SerializeToString,
                response_deserializer=backend__pb2.Result.FromString,
                )
        self.TokenizeString = channel.unary_unary(
                '/backend.Backend/TokenizeString',
                request_serializer=backend__pb2.PredictOptions.SerializeToString,
                response_deserializer=backend__pb2.TokenizationResponse.FromString,
                )
        self.Status = channel.unary_unary(
                '/backend.Backend/Status',
                request_serializer=backend__pb2.HealthMessage.SerializeToString,
                response_deserializer=backend__pb2.StatusResponse.FromString,
                )


class BackendServicer(object):
    """Missing associated documentation comment in .proto file."""

    def Health(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Predict(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def LoadModel(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def PredictStream(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Embedding(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GenerateImage(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def AudioTranscription(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def TTS(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def TokenizeString(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Status(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_BackendServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'Health': grpc.unary_unary_rpc_method_handler(
                    servicer.Health,
                    request_deserializer=backend__pb2.HealthMessage.FromString,
                    response_serializer=backend__pb2.Reply.SerializeToString,
            ),
            'Predict': grpc.unary_unary_rpc_method_handler(
                    servicer.Predict,
                    request_deserializer=backend__pb2.PredictOptions.FromString,
                    response_serializer=backend__pb2.Reply.SerializeToString,
            ),
            'LoadModel': grpc.unary_unary_rpc_method_handler(
                    servicer.LoadModel,
                    request_deserializer=backend__pb2.ModelOptions.FromString,
                    response_serializer=backend__pb2.Result.SerializeToString,
            ),
            'PredictStream': grpc.unary_stream_rpc_method_handler(
                    servicer.PredictStream,
                    request_deserializer=backend__pb2.PredictOptions.FromString,
                    response_serializer=backend__pb2.Reply.SerializeToString,
            ),
            'Embedding': grpc.unary_unary_rpc_method_handler(
                    servicer.Embedding,
                    request_deserializer=backend__pb2.PredictOptions.FromString,
                    response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
            ),
            'GenerateImage': grpc.unary_unary_rpc_method_handler(
                    servicer.GenerateImage,
                    request_deserializer=backend__pb2.GenerateImageRequest.FromString,
                    response_serializer=backend__pb2.Result.SerializeToString,
            ),
            'AudioTranscription': grpc.unary_unary_rpc_method_handler(
                    servicer.AudioTranscription,
                    request_deserializer=backend__pb2.TranscriptRequest.FromString,
                    response_serializer=backend__pb2.TranscriptResult.SerializeToString,
            ),
            'TTS': grpc.unary_unary_rpc_method_handler(
                    servicer.TTS,
                    request_deserializer=backend__pb2.TTSRequest.FromString,
                    response_serializer=backend__pb2.Result.SerializeToString,
            ),
            'TokenizeString': grpc.unary_unary_rpc_method_handler(
                    servicer.TokenizeString,
                    request_deserializer=backend__pb2.PredictOptions.FromString,
                    response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
            ),
            'Status': grpc.unary_unary_rpc_method_handler(
                    servicer.Status,
                    request_deserializer=backend__pb2.HealthMessage.FromString,
                    response_serializer=backend__pb2.StatusResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'backend.Backend', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class Backend(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def Health(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
            backend__pb2.HealthMessage.SerializeToString,
            backend__pb2.Reply.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def Predict(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
            backend__pb2.PredictOptions.SerializeToString,
            backend__pb2.Reply.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def LoadModel(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
            backend__pb2.ModelOptions.SerializeToString,
            backend__pb2.Result.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def PredictStream(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
            backend__pb2.PredictOptions.SerializeToString,
            backend__pb2.Reply.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def Embedding(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
            backend__pb2.PredictOptions.SerializeToString,
            backend__pb2.EmbeddingResult.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def GenerateImage(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
            backend__pb2.GenerateImageRequest.SerializeToString,
            backend__pb2.Result.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def AudioTranscription(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
            backend__pb2.TranscriptRequest.SerializeToString,
            backend__pb2.TranscriptResult.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def TTS(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
            backend__pb2.TTSRequest.SerializeToString,
            backend__pb2.Result.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def TokenizeString(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
            backend__pb2.PredictOptions.SerializeToString,
            backend__pb2.TokenizationResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def Status(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
            backend__pb2.HealthMessage.SerializeToString,
            backend__pb2.StatusResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
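The generated stubs above also expose module-level convenience methods through gRPC's experimental API. A one-off health probe could look like the sketch below; the target address is an assumption, and the insecure keyword requires a grpcio version that supports it.

    # Sketch: one-off call via the experimental convenience API generated above.
    import backend_pb2
    import backend_pb2_grpc

    reply = backend_pb2_grpc.Backend.Health(
        backend_pb2.HealthMessage(),
        "localhost:50051",  # assumed backend address
        insecure=True,
    )
    print(reply.message)  # b'OK' when the backend is up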
@@ -1,15 +0,0 @@
#!/bin/bash

##
## A bash script that installs the required dependencies of VALL-E-X and prepares the environment
export PATH=$PATH:/opt/conda/bin
export SHA=3faaf8ccadb154d63b38070caf518ce9309ea0f4

# Activate conda environment
source activate transformers

echo $CONDA_PREFIX

git clone https://github.com/Plachtaa/VALL-E-X.git $CONDA_PREFIX/vall-e-x && pushd $CONDA_PREFIX/vall-e-x && git checkout -b build $SHA && pip install -r requirements.txt && popd

cp -rfv $CONDA_PREFIX/vall-e-x/* ./
@@ -1,81 +0,0 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "ttsvalle.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="dingzhen"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_tts(self):
        """
        This method tests if TTS audio is generated successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="dingzhen"))
                self.assertTrue(response.success)
                tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
                tts_response = stub.TTS(tts_request)
                self.assertIsNotNone(tts_response)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
        finally:
            self.tearDown()
@@ -1,11 +0,0 @@
#!/bin/bash
##
## A bash script wrapper that runs the ttsvalle tests with conda

# Activate conda environment
source activate transformers

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python -m unittest $DIR/test.py
File diff suppressed because one or more lines are too long
@@ -1,11 +0,0 @@
#!/bin/bash
##
## A bash script wrapper that runs the vllm backend tests with conda

# Activate conda environment
source activate vllm

# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

python -m unittest $DIR/test_backend_vllm.py
@@ -1,76 +0,0 @@
import unittest
import subprocess
import time

import grpc

import backend_pb2
import backend_pb2_grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    This class contains methods to test the startup and shutdown of the gRPC service.
    """
    def setUp(self):
        self.service = subprocess.Popen(["python", "backend_vllm.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_text(self):
        """
        This method tests if text is generated successfully from a prompt
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
                self.assertTrue(response.success)
                req = backend_pb2.PredictOptions(Prompt="The capital of France is")
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("text service failed")
        finally:
            self.tearDown()
@@ -5,6 +5,7 @@ package main
 import (
 	"flag"
 
+	bert "github.com/go-skynet/LocalAI/pkg/backend/llm/bert"
 	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
 )
 
@@ -15,7 +16,7 @@ var (
 func main() {
 	flag.Parse()
 
-	if err := grpc.StartServer(*addr, &Image{}); err != nil {
+	if err := grpc.StartServer(*addr, &bert.Embeddings{}); err != nil {
 		panic(err)
 	}
 }
@@ -5,7 +5,7 @@ package main
 import (
 	"flag"
 
-	transformers "github.com/go-skynet/LocalAI/backend/go/llm/transformers"
+	transformers "github.com/go-skynet/LocalAI/pkg/backend/llm/transformers"
 
 	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
 )
@@ -5,7 +5,7 @@ package main
 import (
 	"flag"
 
-	transformers "github.com/go-skynet/LocalAI/backend/go/llm/transformers"
+	transformers "github.com/go-skynet/LocalAI/pkg/backend/llm/transformers"
 
 	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
 )
Some files were not shown because too many files have changed in this diff