mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-03 03:02:38 -05:00
Compare commits
170 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a6c621ef7f | ||
|
|
328289099a | ||
|
|
22ffd5f490 | ||
|
|
81708bb1e6 | ||
|
|
c81e9d8d1f | ||
|
|
ff3ab5fcca | ||
|
|
1d1cae8e4d | ||
|
|
8c781a6a44 | ||
|
|
93a4bec06b | ||
|
|
c93f57efd6 | ||
|
|
0e4f93c5cf | ||
|
|
5b3fedebfe | ||
|
|
219751bb21 | ||
|
|
bb7772a364 | ||
|
|
3c8fc37c56 | ||
|
|
39805b09e5 | ||
|
|
63b01199fe | ||
|
|
b09bae3443 | ||
|
|
de6fb98bed | ||
|
|
433605e282 | ||
|
|
a843e64fc2 | ||
|
|
71611d2dec | ||
|
|
abf48e8a5d | ||
|
|
ac5ea0cd4d | ||
|
|
a46fcacedd | ||
|
|
df947fc933 | ||
|
|
91d49cfe9f | ||
|
|
19d15f83db | ||
|
|
cde61cc518 | ||
|
|
acd829a7a0 | ||
|
|
4aa5dac768 | ||
|
|
08b59b5cc5 | ||
|
|
6b900e28cd | ||
|
|
5ca21ee398 | ||
|
|
953e30814a | ||
|
|
a65344cf25 | ||
|
|
7fb8b4191f | ||
|
|
fc8aec7324 | ||
|
|
c309aac8f5 | ||
|
|
1e37ec727d | ||
|
|
ae36bae59d | ||
|
|
e663beebf0 | ||
|
|
9d0292e9e1 | ||
|
|
fe27bb7982 | ||
|
|
d603a9cbb5 | ||
|
|
c1fc22e746 | ||
|
|
85d3710924 | ||
|
|
a0324245f1 | ||
|
|
ce8e9dc690 | ||
|
|
32ca7efbeb | ||
|
|
27520eb169 | ||
|
|
9843adb4f1 | ||
|
|
8e8d474ae8 | ||
|
|
6151ea1c4d | ||
|
|
d969025f87 | ||
|
|
18e1cb9c92 | ||
|
|
e7ceb9e8f5 | ||
|
|
3a4675c8c3 | ||
|
|
5ce0f216cf | ||
|
|
688f150463 | ||
|
|
00ccb8d4f1 | ||
|
|
e70b91aaef | ||
|
|
8b90ac2b1a | ||
|
|
f085baa77d | ||
|
|
fa4de05c14 | ||
|
|
dde12b492b | ||
|
|
096d98c3d9 | ||
|
|
147cae9ed8 | ||
|
|
c63709014b | ||
|
|
9b307799ce | ||
|
|
78e36779cf | ||
|
|
90ae35e2e4 | ||
|
|
b96e30e66c | ||
|
|
0af0df7423 | ||
|
|
0883d324d9 | ||
|
|
77597e6a16 | ||
|
|
eae6b36d03 | ||
|
|
c4bc7c41b1 | ||
|
|
c79ddd6fc4 | ||
|
|
ae58fb8821 | ||
|
|
569c1d1163 | ||
|
|
12fe0932c4 | ||
|
|
72e3e236de | ||
|
|
ab59b238b3 | ||
|
|
bed9570e48 | ||
|
|
c6bf67f446 | ||
|
|
5ee186b8e5 | ||
|
|
94817b557c | ||
|
|
26e1496075 | ||
|
|
92fca8ae74 | ||
|
|
7fa5b8401d | ||
|
|
0eac0402e1 | ||
|
|
c71c729bc2 | ||
|
|
e459f114cd | ||
|
|
982a7e86a8 | ||
|
|
94916749c5 | ||
|
|
5ce5f87a26 | ||
|
|
1d2ae46ddc | ||
|
|
71ac331f90 | ||
|
|
47cc95fc9f | ||
|
|
3feb632eb4 | ||
|
|
236497e331 | ||
|
|
a38dc497b2 | ||
|
|
28ed52fa94 | ||
|
|
e995b95c94 | ||
|
|
8379cce209 | ||
|
|
3c6b798522 | ||
|
|
c18770a61a | ||
|
|
6352448b72 | ||
|
|
fb6cce487f | ||
|
|
3079cc4167 | ||
|
|
27ef8b1eb7 | ||
|
|
c00435d72b | ||
|
|
d0e67cce75 | ||
|
|
6ec315e540 | ||
|
|
cf4e6f909c | ||
|
|
b3a99166fd | ||
|
|
107008331e | ||
|
|
accd9f9044 | ||
|
|
17294ae5e5 | ||
|
|
3c3a9b765a | ||
|
|
526c5bcdad | ||
|
|
a1bbe75d43 | ||
|
|
572a311639 | ||
|
|
cb5d6f6e3a | ||
|
|
e3cabb555d | ||
|
|
f193f56564 | ||
|
|
c0a91ab548 | ||
|
|
26e510bf28 | ||
|
|
98e73ed67a | ||
|
|
7f3de3ca4a | ||
|
|
189cb3a7be | ||
|
|
1d0ed95a54 | ||
|
|
5dcfdbe51d | ||
|
|
f2f1d7fe72 | ||
|
|
ae533cadef | ||
|
|
58f6aab637 | ||
|
|
b816009db0 | ||
|
|
a84dee1be1 | ||
|
|
30e4ddbf10 | ||
|
|
296a5b6707 | ||
|
|
b0520dcb59 | ||
|
|
f42967ed86 | ||
|
|
966675c8e3 | ||
|
|
f68df1624b | ||
|
|
42cade808b | ||
|
|
d59211982b | ||
|
|
7aaa10680d | ||
|
|
dcf35dd25f | ||
|
|
e70322676c | ||
|
|
b3f43ab938 | ||
|
|
bbc4468908 | ||
|
|
4de7f55f2f | ||
|
|
def23e4ee2 | ||
|
|
55befe396a | ||
|
|
483fddccf9 | ||
|
|
c4495ad8f2 | ||
|
|
05aed255db | ||
|
|
0f1326b2bd | ||
|
|
1668489b00 | ||
|
|
7dd292cbb3 | ||
|
|
c0578031b5 | ||
|
|
a5b64b6a41 | ||
|
|
b722e7eb7e | ||
|
|
6d19a8bdb5 | ||
|
|
f09ddd2983 | ||
|
|
a6839fd238 | ||
|
|
f3063f98d3 | ||
|
|
70674d3c58 | ||
|
|
3829aba869 |
@@ -1,4 +1,3 @@
|
||||
.git
|
||||
.idea
|
||||
models
|
||||
examples/chatbot-ui/models
|
||||
|
||||
7
.env
7
.env
@@ -24,10 +24,13 @@ MODELS_PATH=/models
|
||||
# DEBUG=true
|
||||
|
||||
## Specify a build type. Available: cublas, openblas, clblas.
|
||||
## cuBLAS: This is a GPU-accelerated version of the complete standard BLAS (Basic Linear Algebra Subprograms) library. It's provided by Nvidia and is part of their CUDA toolkit.
|
||||
## OpenBLAS: This is an open-source implementation of the BLAS library that aims to provide highly optimized code for various platforms. It includes support for multi-threading and can be compiled to use hardware-specific features for additional performance. OpenBLAS can run on many kinds of hardware, including CPUs from Intel, AMD, and ARM.
|
||||
## clBLAS: This is an open-source implementation of the BLAS library that uses OpenCL, a framework for writing programs that execute across heterogeneous platforms consisting of CPUs, GPUs, and other processors. clBLAS is designed to take advantage of the parallel computing power of GPUs but can also run on any hardware that supports OpenCL. This includes hardware from different vendors like Nvidia, AMD, and Intel.
|
||||
# BUILD_TYPE=openblas
|
||||
|
||||
## Uncomment and set to false to disable rebuilding from source
|
||||
# REBUILD=false
|
||||
## Uncomment and set to true to enable rebuilding from source
|
||||
# REBUILD=true
|
||||
|
||||
## Enable go tags, available: stablediffusion, tts
|
||||
## stablediffusion: image generation with stablediffusion
|
||||
|
||||
1
.gitattributes
vendored
Normal file
1
.gitattributes
vendored
Normal file
@@ -0,0 +1 @@
|
||||
*.sh text eol=lf
|
||||
5
.github/FUNDING.yml
vendored
Normal file
5
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: [mudler]
|
||||
custom:
|
||||
- https://www.buymeacoffee.com/mudler
|
||||
9
.github/workflows/bump_deps.yaml
vendored
9
.github/workflows/bump_deps.yaml
vendored
@@ -30,6 +30,15 @@ jobs:
|
||||
- repository: "nomic-ai/gpt4all"
|
||||
variable: "GPT4ALL_VERSION"
|
||||
branch: "main"
|
||||
- repository: "mudler/go-ggllm.cpp"
|
||||
variable: "GOGGLLM_VERSION"
|
||||
branch: "master"
|
||||
- repository: "mudler/go-stable-diffusion"
|
||||
variable: "STABLEDIFFUSION_VERSION"
|
||||
branch: "master"
|
||||
- repository: "mudler/go-piper"
|
||||
variable: "PIPER_VERSION"
|
||||
branch: "master"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
32
.github/workflows/image.yml
vendored
32
.github/workflows/image.yml
vendored
@@ -59,6 +59,38 @@ jobs:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Release space from worker
|
||||
run: |
|
||||
echo "Listing top largest packages"
|
||||
pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
|
||||
head -n 30 <<< "${pkgs}"
|
||||
echo
|
||||
df -h
|
||||
echo
|
||||
sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
|
||||
sudo apt-get remove --auto-remove android-sdk-platform-tools || true
|
||||
sudo apt-get purge --auto-remove android-sdk-platform-tools || true
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo apt-get remove -y '^mono-.*' || true
|
||||
sudo apt-get remove -y '^ghc-.*' || true
|
||||
sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
|
||||
sudo apt-get remove -y 'php.*' || true
|
||||
sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
|
||||
sudo apt-get remove -y '^google-.*' || true
|
||||
sudo apt-get remove -y azure-cli || true
|
||||
sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
|
||||
sudo apt-get remove -y '^gfortran-.*' || true
|
||||
sudo apt-get autoremove -y
|
||||
sudo apt-get clean
|
||||
echo
|
||||
echo "Listing top largest packages"
|
||||
pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
|
||||
head -n 30 <<< "${pkgs}"
|
||||
echo
|
||||
sudo rm -rfv build || true
|
||||
df -h
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
|
||||
22
.github/workflows/test.yml
vendored
22
.github/workflows/test.yml
vendored
@@ -26,9 +26,29 @@ jobs:
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential ffmpeg
|
||||
|
||||
sudo apt-get install -y ca-certificates cmake curl patch
|
||||
sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
|
||||
sudo pip install -r extra/requirements.txt
|
||||
|
||||
sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \
|
||||
curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v1.11.0.tar.gz" | \
|
||||
tar -xzvf - && \
|
||||
mkdir -p "spdlog-1.11.0/build" && \
|
||||
cd "spdlog-1.11.0/build" && \
|
||||
cmake .. && \
|
||||
make -j8 && \
|
||||
sudo cmake --install . --prefix /usr && mkdir -p "lib/Linux-$(uname -m)" && \
|
||||
cd /build && \
|
||||
mkdir -p "lib/Linux-$(uname -m)/piper_phonemize" && \
|
||||
curl -L "https://github.com/rhasspy/piper-phonemize/releases/download/v1.0.0/libpiper_phonemize-amd64.tar.gz" | \
|
||||
tar -C "lib/Linux-$(uname -m)/piper_phonemize" -xzvf - && ls -liah /build/lib/Linux-$(uname -m)/piper_phonemize/ && \
|
||||
sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \
|
||||
sudo ln -s /usr/lib/libpiper_phonemize.so /usr/lib/libpiper_phonemize.so.1 && \
|
||||
sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. /usr/include/
|
||||
- name: Test
|
||||
run: |
|
||||
make test
|
||||
ESPEAK_DATA="/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data" GO_TAGS="tts stablediffusion" make test
|
||||
|
||||
macOS-latest:
|
||||
runs-on: macOS-latest
|
||||
|
||||
16
.gitignore
vendored
16
.gitignore
vendored
@@ -1,12 +1,20 @@
|
||||
# go-llama build artifacts
|
||||
go-llama
|
||||
gpt4all
|
||||
/gpt4all
|
||||
go-stable-diffusion
|
||||
go-piper
|
||||
/go-bert
|
||||
go-ggllm
|
||||
/piper
|
||||
__pycache__/
|
||||
*.a
|
||||
get-sources
|
||||
|
||||
go-ggml-transformers
|
||||
go-gpt2
|
||||
go-rwkv
|
||||
whisper.cpp
|
||||
bloomz
|
||||
/bloomz
|
||||
go-bert
|
||||
|
||||
# LocalAI build binary
|
||||
@@ -28,5 +36,5 @@ release/
|
||||
|
||||
# Generated during build
|
||||
backend-assets/
|
||||
|
||||
/ggml-metal.metal
|
||||
prepare
|
||||
/ggml-metal.metal
|
||||
|
||||
40
Dockerfile
40
Dockerfile
@@ -11,10 +11,16 @@ ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/extra/grpc/huggingface/huggingface.py,autogptq:/build/extra/grpc/autogptq/autogptq.py,bark:/build/extra/grpc/bark/ttsbark.py,diffusers:/build/extra/grpc/diffusers/backend_diffusers.py,exllama:/build/extra/grpc/exllama/exllama.py"
|
||||
ENV GALLERIES='[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/index.yaml"}, {"url": "github:go-skynet/model-gallery/huggingface.yaml","name":"huggingface"}]'
|
||||
ARG GO_TAGS="stablediffusion tts"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y ca-certificates cmake curl patch
|
||||
apt-get install -y ca-certificates cmake curl patch pip
|
||||
|
||||
# Use the variables in subsequent instructions
|
||||
RUN echo "Target Architecture: $TARGETARCH"
|
||||
RUN echo "Target Variant: $TARGETVARIANT"
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \
|
||||
@@ -24,10 +30,23 @@ RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \
|
||||
dpkg -i cuda-keyring_1.0-1_all.deb && \
|
||||
rm -f cuda-keyring_1.0-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
apt-get install -y cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
|
||||
; fi
|
||||
ENV PATH /usr/local/cuda/bin:${PATH}
|
||||
|
||||
# Extras requirements
|
||||
COPY extra/requirements.txt /build/extra/requirements.txt
|
||||
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
RUN pip install --upgrade pip
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
RUN if [ "${TARGETARCH}" = "amd64" ]; then \
|
||||
pip install git+https://github.com/suno-ai/bark.git diffusers invisible_watermark transformers accelerate safetensors;\
|
||||
fi
|
||||
RUN if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "amd64" ]; then \
|
||||
pip install torch && pip install auto-gptq https://github.com/jllllll/exllama/releases/download/0.0.10/exllama-0.0.10+cu${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION}-cp39-cp39-linux_x86_64.whl;\
|
||||
fi
|
||||
RUN pip install -r /build/extra/requirements.txt && rm -rf /build/extra/requirements.txt
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# OpenBLAS requirements
|
||||
@@ -37,9 +56,6 @@ RUN apt-get install -y libopenblas-dev
|
||||
RUN apt-get install -y libopencv-dev && \
|
||||
ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
|
||||
|
||||
# Use the variables in subsequent instructions
|
||||
RUN echo "Target Architecture: $TARGETARCH"
|
||||
RUN echo "Target Variant: $TARGETVARIANT"
|
||||
|
||||
# piper requirements
|
||||
# Use pre-compiled Piper phonemization library (includes onnxruntime)
|
||||
@@ -58,8 +74,8 @@ RUN curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v${SPDLOG_VERSIO
|
||||
mkdir -p "lib/Linux-$(uname -m)/piper_phonemize" && \
|
||||
curl -L "https://github.com/rhasspy/piper-phonemize/releases/download/v${PIPER_PHONEMIZE_VERSION}/libpiper_phonemize-${TARGETARCH:-$(go env GOARCH)}${TARGETVARIANT}.tar.gz" | \
|
||||
tar -C "lib/Linux-$(uname -m)/piper_phonemize" -xzvf - && ls -liah /build/lib/Linux-$(uname -m)/piper_phonemize/ && \
|
||||
cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /lib64/ && \
|
||||
cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \
|
||||
ln -s /usr/lib/libpiper_phonemize.so /usr/lib/libpiper_phonemize.so.1 && \
|
||||
cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. /usr/include/
|
||||
# \
|
||||
# ; fi
|
||||
@@ -83,6 +99,8 @@ RUN make get-sources
|
||||
COPY go.mod .
|
||||
RUN make prepare
|
||||
COPY . .
|
||||
COPY .git .
|
||||
|
||||
RUN ESPEAK_DATA=/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data make build
|
||||
|
||||
###################################
|
||||
@@ -91,8 +109,11 @@ RUN ESPEAK_DATA=/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data
|
||||
FROM requirements
|
||||
|
||||
ARG FFMPEG
|
||||
ARG BUILD_TYPE
|
||||
ARG TARGETARCH
|
||||
|
||||
ENV REBUILD=true
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ENV REBUILD=false
|
||||
ENV HEALTHCHECK_ENDPOINT=http://localhost:8080/readyz
|
||||
|
||||
# Add FFmpeg
|
||||
@@ -109,7 +130,10 @@ WORKDIR /build
|
||||
COPY . .
|
||||
RUN make prepare-sources
|
||||
COPY --from=builder /build/local-ai ./
|
||||
|
||||
# To resolve exllama import error
|
||||
RUN if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH:-$(go env GOARCH)}" = "amd64" ]; then \
|
||||
cp -rfv /usr/local/lib/python3.9/dist-packages/exllama extra/grpc/exllama/;\
|
||||
fi
|
||||
# Define the health check command
|
||||
HEALTHCHECK --interval=1m --timeout=10m --retries=10 \
|
||||
CMD curl -f $HEALTHCHECK_ENDPOINT || exit 1
|
||||
|
||||
263
Makefile
263
Makefile
@@ -3,24 +3,45 @@ GOTEST=$(GOCMD) test
|
||||
GOVET=$(GOCMD) vet
|
||||
BINARY_NAME=local-ai
|
||||
|
||||
GOLLAMA_VERSION?=42ba448383692c11ca8f04f2b87e87f3f9bdac30
|
||||
# llama.cpp versions
|
||||
GOLLAMA_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
|
||||
|
||||
# gpt4all version
|
||||
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
|
||||
GPT4ALL_VERSION?=a67f8132e1657974d2abe4abeb82df9be3d42bbd
|
||||
GOGGMLTRANSFORMERS_VERSION?=8e31841dcddca16468c11b2e7809f279fa76a832
|
||||
GPT4ALL_VERSION?=0f2bb506a8ee752afc06cbb832773bf85b97eef3
|
||||
|
||||
# go-ggml-transformers version
|
||||
GOGGMLTRANSFORMERS_VERSION?=ffb09d7dd71e2cbc6c5d7d05357d230eea6f369a
|
||||
|
||||
# go-rwkv version
|
||||
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
|
||||
RWKV_VERSION?=f5a8c45396741470583f59b916a2a7641e63bcd0
|
||||
RWKV_VERSION?=c898cd0f62df8f2a7830e53d1d513bef4f6f792b
|
||||
|
||||
# whisper.cpp version
|
||||
WHISPER_CPP_VERSION?=85ed71aaec8e0612a84c0b67804bde75aa75a273
|
||||
BERT_VERSION?=6069103f54b9969c02e789d0fb12a23bd614285f
|
||||
|
||||
# bert.cpp version
|
||||
BERT_VERSION?=6abe312cded14042f6b7c3cd8edf082713334a4d
|
||||
|
||||
# go-piper version
|
||||
PIPER_VERSION?=56b8a81b4760a6fbee1a82e62f007ae7e8f010a7
|
||||
|
||||
# go-bloomz version
|
||||
BLOOMZ_VERSION?=1834e77b83faafe912ad4092ccf7f77937349e2f
|
||||
|
||||
# stablediffusion version
|
||||
STABLEDIFFUSION_VERSION?=d89260f598afb809279bc72aa0107b4292587632
|
||||
|
||||
# Go-ggllm
|
||||
GOGGLLM_VERSION?=862477d16eefb0805261c19c9b0d053e3b2b684b
|
||||
|
||||
export BUILD_TYPE?=
|
||||
CGO_LDFLAGS?=
|
||||
CUDA_LIBPATH?=/usr/local/cuda/lib64/
|
||||
STABLEDIFFUSION_VERSION?=d89260f598afb809279bc72aa0107b4292587632
|
||||
GO_TAGS?=
|
||||
BUILD_ID?=git
|
||||
|
||||
VERSION?=$(shell git describe --always --tags --dirty || echo "dev" )
|
||||
VERSION?=$(shell git describe --always --tags || echo "dev" )
|
||||
# go tool nm ./local-ai | grep Commit
|
||||
LD_FLAGS?=
|
||||
override LD_FLAGS += -X "github.com/go-skynet/LocalAI/internal.Version=$(VERSION)"
|
||||
@@ -37,8 +58,14 @@ WHITE := $(shell tput -Txterm setaf 7)
|
||||
CYAN := $(shell tput -Txterm setaf 6)
|
||||
RESET := $(shell tput -Txterm sgr0)
|
||||
|
||||
C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz
|
||||
LIBRARY_PATH=$(shell pwd)/go-piper:$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz
|
||||
ifndef UNAME_S
|
||||
UNAME_S := $(shell uname -s)
|
||||
endif
|
||||
|
||||
# workaround for rwkv.cpp
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CGO_LDFLAGS += -lcblas -framework Accelerate
|
||||
endif
|
||||
|
||||
ifeq ($(BUILD_TYPE),openblas)
|
||||
CGO_LDFLAGS+=-lopenblas
|
||||
@@ -64,12 +91,14 @@ ifeq ($(STATIC),true)
|
||||
endif
|
||||
|
||||
ifeq ($(findstring stablediffusion,$(GO_TAGS)),stablediffusion)
|
||||
OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a
|
||||
# OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a
|
||||
OPTIONAL_GRPC+=backend-assets/grpc/stablediffusion
|
||||
endif
|
||||
|
||||
ifeq ($(findstring tts,$(GO_TAGS)),tts)
|
||||
OPTIONAL_TARGETS+=go-piper/libpiper_binding.a
|
||||
OPTIONAL_TARGETS+=backend-assets/espeak-ng-data
|
||||
# OPTIONAL_TARGETS+=go-piper/libpiper_binding.a
|
||||
# OPTIONAL_TARGETS+=backend-assets/espeak-ng-data
|
||||
OPTIONAL_GRPC+=backend-assets/grpc/piper
|
||||
endif
|
||||
|
||||
.PHONY: all test build vendor
|
||||
@@ -80,20 +109,14 @@ all: help
|
||||
gpt4all:
|
||||
git clone --recurse-submodules $(GPT4ALL_REPO) gpt4all
|
||||
cd gpt4all && git checkout -b build $(GPT4ALL_VERSION) && git submodule update --init --recursive --depth 1
|
||||
# This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
|
||||
@find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
|
||||
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
|
||||
@find ./gpt4all -type f -name "*.m" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
|
||||
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
|
||||
@find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/llama_/llama_gpt4all_/g' {} +
|
||||
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/llama_/llama_gpt4all_/g' {} +
|
||||
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/llama_/llama_gpt4all_/g' {} +
|
||||
@find ./gpt4all/gpt4all-backend -type f -name "llama_util.h" -execdir mv {} "llama_gpt4all_util.h" \;
|
||||
@find ./gpt4all -type f -name "*.cmake" -exec sed -i'' -e 's/llama_util/llama_gpt4all_util/g' {} +
|
||||
@find ./gpt4all -type f -name "*.txt" -exec sed -i'' -e 's/llama_util/llama_gpt4all_util/g' {} +
|
||||
@find ./gpt4all/gpt4all-bindings/golang -type f -name "*.cpp" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
|
||||
@find ./gpt4all/gpt4all-bindings/golang -type f -name "*.go" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
|
||||
@find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
|
||||
|
||||
## go-ggllm
|
||||
go-ggllm:
|
||||
git clone --recurse-submodules https://github.com/mudler/go-ggllm.cpp go-ggllm
|
||||
cd go-ggllm && git checkout -b build $(GOGGLLM_VERSION) && git submodule update --init --recursive --depth 1
|
||||
|
||||
go-ggllm/libggllm.a: go-ggllm
|
||||
$(MAKE) -C go-ggllm BUILD_TYPE=$(BUILD_TYPE) libggllm.a
|
||||
|
||||
## go-piper
|
||||
go-piper:
|
||||
@@ -104,9 +127,6 @@ go-piper:
|
||||
go-bert:
|
||||
git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp go-bert
|
||||
cd go-bert && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1
|
||||
@find ./go-bert -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} +
|
||||
@find ./go-bert -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} +
|
||||
@find ./go-bert -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} +
|
||||
|
||||
## stable diffusion
|
||||
go-stable-diffusion:
|
||||
@@ -120,9 +140,6 @@ go-stable-diffusion/libstablediffusion.a:
|
||||
go-rwkv:
|
||||
git clone --recurse-submodules $(RWKV_REPO) go-rwkv
|
||||
cd go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1
|
||||
@find ./go-rwkv -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} +
|
||||
@find ./go-rwkv -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} +
|
||||
@find ./go-rwkv -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} +
|
||||
|
||||
go-rwkv/librwkv.a: go-rwkv
|
||||
cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a ..
|
||||
@@ -130,13 +147,7 @@ go-rwkv/librwkv.a: go-rwkv
|
||||
## bloomz
|
||||
bloomz:
|
||||
git clone --recurse-submodules https://github.com/go-skynet/bloomz.cpp bloomz
|
||||
@find ./bloomz -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} +
|
||||
@find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} +
|
||||
@find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} +
|
||||
@find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} +
|
||||
@find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} +
|
||||
@find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_bloomz_replace/g' {} +
|
||||
@find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_bloomz_replace/g' {} +
|
||||
cd bloomz && git checkout -b build $(BLOOMZ_VERSION) && git submodule update --init --recursive --depth 1
|
||||
|
||||
bloomz/libbloomz.a: bloomz
|
||||
cd bloomz && make libbloomz.a
|
||||
@@ -155,6 +166,7 @@ backend-assets/espeak-ng-data:
|
||||
ifdef ESPEAK_DATA
|
||||
@cp -rf $(ESPEAK_DATA)/. backend-assets/espeak-ng-data
|
||||
else
|
||||
@echo "ESPEAK_DATA not set, skipping tts. Note that this will break the tts functionality."
|
||||
@touch backend-assets/espeak-ng-data/keep
|
||||
endif
|
||||
|
||||
@@ -165,27 +177,13 @@ gpt4all/gpt4all-bindings/golang/libgpt4all.a: gpt4all
|
||||
go-ggml-transformers:
|
||||
git clone --recurse-submodules https://github.com/go-skynet/go-ggml-transformers.cpp go-ggml-transformers
|
||||
cd go-ggml-transformers && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1
|
||||
# This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
|
||||
@find ./go-ggml-transformers -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
|
||||
@find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
|
||||
@find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
|
||||
@find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_print_usage/gpt2_print_usage/g' {} +
|
||||
@find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_print_usage/gpt2_print_usage/g' {} +
|
||||
@find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_params_parse/gpt2_params_parse/g' {} +
|
||||
@find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_params_parse/gpt2_params_parse/g' {} +
|
||||
@find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_random_prompt/gpt2_random_prompt/g' {} +
|
||||
@find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_random_prompt/gpt2_random_prompt/g' {} +
|
||||
@find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gpt2_/g' {} +
|
||||
|
||||
go-ggml-transformers/libtransformers.a: go-ggml-transformers
|
||||
$(MAKE) -C go-ggml-transformers libtransformers.a
|
||||
$(MAKE) -C go-ggml-transformers BUILD_TYPE=$(BUILD_TYPE) libtransformers.a
|
||||
|
||||
whisper.cpp:
|
||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||
cd whisper.cpp && git checkout -b build $(WHISPER_CPP_VERSION) && git submodule update --init --recursive --depth 1
|
||||
@find ./whisper.cpp -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} +
|
||||
@find ./whisper.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} +
|
||||
@find ./whisper.cpp -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} +
|
||||
|
||||
whisper.cpp/libwhisper.a: whisper.cpp
|
||||
cd whisper.cpp && make libwhisper.a
|
||||
@@ -200,7 +198,7 @@ go-llama/libbinding.a: go-llama
|
||||
go-piper/libpiper_binding.a:
|
||||
$(MAKE) -C go-piper libpiper_binding.a example/main
|
||||
|
||||
get-sources: go-llama go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion
|
||||
get-sources: go-llama go-ggllm go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion
|
||||
touch $@
|
||||
|
||||
replace:
|
||||
@@ -213,12 +211,14 @@ replace:
|
||||
$(GOCMD) mod edit -replace github.com/go-skynet/bloomz.cpp=$(shell pwd)/bloomz
|
||||
$(GOCMD) mod edit -replace github.com/mudler/go-stable-diffusion=$(shell pwd)/go-stable-diffusion
|
||||
$(GOCMD) mod edit -replace github.com/mudler/go-piper=$(shell pwd)/go-piper
|
||||
$(GOCMD) mod edit -replace github.com/mudler/go-ggllm.cpp=$(shell pwd)/go-ggllm
|
||||
|
||||
prepare-sources: get-sources replace
|
||||
$(GOCMD) mod download
|
||||
|
||||
## GENERIC
|
||||
rebuild: ## Rebuilds the project
|
||||
$(GOCMD) clean -cache
|
||||
$(MAKE) -C go-llama clean
|
||||
$(MAKE) -C gpt4all/gpt4all-bindings/golang/ clean
|
||||
$(MAKE) -C go-ggml-transformers clean
|
||||
@@ -228,13 +228,16 @@ rebuild: ## Rebuilds the project
|
||||
$(MAKE) -C go-bert clean
|
||||
$(MAKE) -C bloomz clean
|
||||
$(MAKE) -C go-piper clean
|
||||
$(MAKE) -C go-ggllm clean
|
||||
$(MAKE) build
|
||||
|
||||
prepare: prepare-sources backend-assets/gpt4all $(OPTIONAL_TARGETS) go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building
|
||||
prepare: prepare-sources $(OPTIONAL_TARGETS)
|
||||
touch $@
|
||||
|
||||
clean: ## Remove build related file
|
||||
rm -fr ./go-llama
|
||||
$(GOCMD) clean -cache
|
||||
rm -f prepare
|
||||
rm -rf ./go-llama
|
||||
rm -rf ./gpt4all
|
||||
rm -rf ./go-gpt2
|
||||
rm -rf ./go-stable-diffusion
|
||||
@@ -245,32 +248,27 @@ clean: ## Remove build related file
|
||||
rm -rf ./bloomz
|
||||
rm -rf ./whisper.cpp
|
||||
rm -rf ./go-piper
|
||||
rm -rf ./go-ggllm
|
||||
rm -rf $(BINARY_NAME)
|
||||
rm -rf release/
|
||||
|
||||
## Build:
|
||||
|
||||
build: prepare ## Build the project
|
||||
build: grpcs prepare ## Build the project
|
||||
$(info ${GREEN}I local-ai build info:${RESET})
|
||||
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
|
||||
$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
|
||||
$(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
|
||||
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./
|
||||
ifeq ($(BUILD_TYPE),metal)
|
||||
cp go-llama/build/bin/ggml-metal.metal .
|
||||
endif
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./
|
||||
|
||||
dist: build
|
||||
mkdir -p release
|
||||
cp $(BINARY_NAME) release/$(BINARY_NAME)-$(BUILD_ID)-$(OS)-$(ARCH)
|
||||
|
||||
generic-build: ## Build the project using generic
|
||||
BUILD_TYPE="generic" $(MAKE) build
|
||||
|
||||
## Run
|
||||
run: prepare ## run local-ai
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) run ./
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
|
||||
|
||||
test-models/testmodel:
|
||||
mkdir test-models
|
||||
@@ -283,12 +281,40 @@ test-models/testmodel:
|
||||
wget https://raw.githubusercontent.com/saharNooby/rwkv.cpp/5eb8f09c146ea8124633ab041d9ea0b1f1db4459/rwkv/20B_tokenizer.json -O test-models/rwkv.tokenizer.json
|
||||
cp tests/models_fixtures/* test-models
|
||||
|
||||
test: prepare test-models/testmodel
|
||||
cp -r backend-assets api
|
||||
prepare-test: grpcs
|
||||
cp -rf backend-assets api
|
||||
cp tests/models_fixtures/* test-models
|
||||
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg
|
||||
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg
|
||||
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg
|
||||
|
||||
test: prepare test-models/testmodel grpcs
|
||||
@echo 'Running tests'
|
||||
export GO_TAGS="tts stablediffusion"
|
||||
$(MAKE) prepare-test
|
||||
HUGGINGFACE_GRPC=$(abspath ./)/extra/grpc/huggingface/huggingface.py TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg
|
||||
$(MAKE) test-gpt4all
|
||||
$(MAKE) test-llama
|
||||
$(MAKE) test-tts
|
||||
$(MAKE) test-stablediffusion
|
||||
|
||||
test-gpt4all: prepare-test
|
||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg
|
||||
|
||||
test-llama: prepare-test
|
||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg
|
||||
|
||||
test-tts: prepare-test
|
||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r ./api ./pkg
|
||||
|
||||
test-stablediffusion: prepare-test
|
||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r ./api ./pkg
|
||||
|
||||
test-container:
|
||||
docker build --target requirements -t local-ai-test-container .
|
||||
docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container
|
||||
|
||||
## Help:
|
||||
help: ## Show this help.
|
||||
@@ -301,3 +327,98 @@ help: ## Show this help.
|
||||
if (/^[a-zA-Z_-]+:.*?##.*$$/) {printf " ${YELLOW}%-20s${GREEN}%s${RESET}\n", $$1, $$2} \
|
||||
else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \
|
||||
}' $(MAKEFILE_LIST)
|
||||
|
||||
protogen: protogen-go protogen-python
|
||||
|
||||
protogen-go:
|
||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
||||
pkg/grpc/proto/backend.proto
|
||||
|
||||
protogen-python:
|
||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/huggingface/ --grpc_python_out=extra/grpc/huggingface/ pkg/grpc/proto/backend.proto
|
||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/autogptq/ --grpc_python_out=extra/grpc/autogptq/ pkg/grpc/proto/backend.proto
|
||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/exllama/ --grpc_python_out=extra/grpc/exllama/ pkg/grpc/proto/backend.proto
|
||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/bark/ --grpc_python_out=extra/grpc/bark/ pkg/grpc/proto/backend.proto
|
||||
python3 -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/diffusers/ --grpc_python_out=extra/grpc/diffusers/ pkg/grpc/proto/backend.proto
|
||||
|
||||
## GRPC
|
||||
|
||||
backend-assets/grpc:
|
||||
mkdir -p backend-assets/grpc
|
||||
|
||||
backend-assets/grpc/falcon: backend-assets/grpc go-ggllm/libggllm.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggllm LIBRARY_PATH=$(shell pwd)/go-ggllm \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/
|
||||
|
||||
backend-assets/grpc/llama: backend-assets/grpc go-llama/libbinding.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/
|
||||
# TODO: every binary should have its own folder instead, so can have different metal implementations
|
||||
ifeq ($(BUILD_TYPE),metal)
|
||||
cp go-llama/build/bin/ggml-metal.metal backend-assets/grpc/
|
||||
endif
|
||||
|
||||
backend-assets/grpc/gpt4all: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-bindings/golang/libgpt4all.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./cmd/grpc/gpt4all/
|
||||
|
||||
backend-assets/grpc/dolly: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/dolly ./cmd/grpc/dolly/
|
||||
|
||||
backend-assets/grpc/gpt2: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt2 ./cmd/grpc/gpt2/
|
||||
|
||||
backend-assets/grpc/gptj: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptj ./cmd/grpc/gptj/
|
||||
|
||||
backend-assets/grpc/gptneox: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptneox ./cmd/grpc/gptneox/
|
||||
|
||||
backend-assets/grpc/mpt: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/mpt ./cmd/grpc/mpt/
|
||||
|
||||
backend-assets/grpc/replit: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/replit ./cmd/grpc/replit/
|
||||
|
||||
backend-assets/grpc/falcon-ggml: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon-ggml ./cmd/grpc/falcon-ggml/
|
||||
|
||||
backend-assets/grpc/starcoder: backend-assets/grpc go-ggml-transformers/libtransformers.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/starcoder ./cmd/grpc/starcoder/
|
||||
|
||||
backend-assets/grpc/rwkv: backend-assets/grpc go-rwkv/librwkv.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-rwkv LIBRARY_PATH=$(shell pwd)/go-rwkv \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/rwkv ./cmd/grpc/rwkv/
|
||||
|
||||
backend-assets/grpc/bloomz: backend-assets/grpc bloomz/libbloomz.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/bloomz LIBRARY_PATH=$(shell pwd)/bloomz \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bloomz ./cmd/grpc/bloomz/
|
||||
|
||||
backend-assets/grpc/bert-embeddings: backend-assets/grpc go-bert/libgobert.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-bert LIBRARY_PATH=$(shell pwd)/go-bert \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bert-embeddings ./cmd/grpc/bert-embeddings/
|
||||
|
||||
backend-assets/grpc/langchain-huggingface: backend-assets/grpc
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/langchain-huggingface ./cmd/grpc/langchain-huggingface/
|
||||
|
||||
backend-assets/grpc/stablediffusion: backend-assets/grpc go-stable-diffusion/libstablediffusion.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/go-stable-diffusion/ \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./cmd/grpc/stablediffusion/
|
||||
|
||||
backend-assets/grpc/piper: backend-assets/grpc backend-assets/espeak-ng-data go-piper/libpiper_binding.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/go-piper \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./cmd/grpc/piper/
|
||||
|
||||
backend-assets/grpc/whisper: backend-assets/grpc whisper.cpp/libwhisper.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/whisper.cpp LIBRARY_PATH=$(shell pwd)/whisper.cpp \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./cmd/grpc/whisper/
|
||||
|
||||
grpcs: prepare backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC)
|
||||
276
README.md
276
README.md
@@ -1,220 +1,143 @@
|
||||
<h1 align="center">
|
||||
<br>
|
||||
<img height="300" src="https://user-images.githubusercontent.com/2420543/233147843-88697415-6dbf-4368-a862-ab217f9f7342.jpeg"> <br>
|
||||
<img height="300" src="https://github.com/go-skynet/LocalAI/assets/2420543/0966aa2a-166e-4f99-a3e5-6c915fc997dd"> <br>
|
||||
LocalAI
|
||||
<br>
|
||||
</h1>
|
||||
|
||||
[](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml) [](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)
|
||||
<p align="center">
|
||||
<a href="https://github.com/go-skynet/LocalAI/fork" target="blank">
|
||||
<img src="https://img.shields.io/github/forks/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI forks"/>
|
||||
</a>
|
||||
<a href="https://github.com/go-skynet/LocalAI/stargazers" target="blank">
|
||||
<img src="https://img.shields.io/github/stars/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI stars"/>
|
||||
</a>
|
||||
<a href="https://github.com/go-skynet/LocalAI/pulls" target="blank">
|
||||
<img src="https://img.shields.io/github/issues-pr/go-skynet/LocalAI?style=for-the-badge" alt="LocalAI pull-requests"/>
|
||||
</a>
|
||||
<a href='https://github.com/go-skynet/LocalAI/releases'>
|
||||
<img src='https://img.shields.io/github/release/go-skynet/LocalAI?&label=Latest&style=for-the-badge'>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
[](https://discord.gg/uJAeKSAGDy)
|
||||
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
|
||||
>
|
||||
> [💻 Quickstart](https://localai.io/basics/getting_started/) [📣 News](https://localai.io/basics/news/) [ 🛫 Examples ](https://github.com/go-skynet/LocalAI/tree/master/examples/) [ 🖼️ Models ](https://localai.io/models/)
|
||||
|
||||
[Documentation website](https://localai.io/)
|
||||
|
||||
[](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[](https://artifacthub.io/packages/search?repo=localai)
|
||||
|
||||
**LocalAI** is a drop-in replacement REST API that's compatible with OpenAI API specifications for local inferencing. It allows you to run LLMs (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families that are compatible with the ggml format. Does not require GPU.
|
||||
|
||||
For a list of the supported model families, please see [the model compatibility table](https://localai.io/model-compatibility/index.html#model-compatibility-table).
|
||||
<p align="center"><b>Follow LocalAI </b></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://twitter.com/LocalAI_API" target="blank">
|
||||
<img src="https://img.shields.io/twitter/follow/LocalAI_API?label=Follow: LocalAI_API&style=social" alt="Follow LocalAI_API"/>
|
||||
</a>
|
||||
<a href="https://discord.gg/uJAeKSAGDy" target="blank">
|
||||
<img src="https://dcbadge.vercel.app/api/server/uJAeKSAGDy?style=flat-square&theme=default-inverted" alt="Join LocalAI Discord Community"/>
|
||||
</a>
|
||||
|
||||
<p align="center"><b>Connect with the Creator </b></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://twitter.com/mudler_it" target="blank">
|
||||
<img src="https://img.shields.io/twitter/follow/mudler_it?label=Follow: mudler_it&style=social" alt="Follow mudler_it"/>
|
||||
</a>
|
||||
<a href='https://github.com/mudler'>
|
||||
<img alt="Follow on Github" src="https://img.shields.io/badge/Follow-mudler-black?logo=github&link=https%3A%2F%2Fgithub.com%2Fmudler">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center"><b>Share LocalAI Repository</b></p>
|
||||
|
||||
<p align="center">
|
||||
|
||||
<a href="https://twitter.com/intent/tweet?text=Check%20this%20GitHub%20repository%20out.%20LocalAI%20-%20Let%27s%20you%20easily%20run%20LLM%20locally.&url=https://github.com/go-skynet/LocalAI&hashtags=LocalAI,AI" target="blank">
|
||||
<img src="https://img.shields.io/twitter/follow/_LocalAI?label=Share Repo on Twitter&style=social" alt="Follow _LocalAI"/></a>
|
||||
<a href="https://t.me/share/url?text=Check%20this%20GitHub%20repository%20out.%20LocalAI%20-%20Let%27s%20you%20easily%20run%20LLM%20locally.&url=https://github.com/go-skynet/LocalAI" target="_blank"><img src="https://img.shields.io/twitter/url?label=Telegram&logo=Telegram&style=social&url=https://github.com/go-skynet/LocalAI" alt="Share on Telegram"/></a>
|
||||
<a href="https://api.whatsapp.com/send?text=Check%20this%20GitHub%20repository%20out.%20LocalAI%20-%20Let%27s%20you%20easily%20run%20LLM%20locally.%20https://github.com/go-skynet/LocalAI"><img src="https://img.shields.io/twitter/url?label=whatsapp&logo=whatsapp&style=social&url=https://github.com/go-skynet/LocalAI" /></a> <a href="https://www.reddit.com/submit?url=https://github.com/go-skynet/LocalAI&title=Check%20this%20GitHub%20repository%20out.%20LocalAI%20-%20Let%27s%20you%20easily%20run%20LLM%20locally.
|
||||
" target="blank">
|
||||
<img src="https://img.shields.io/twitter/url?label=Reddit&logo=Reddit&style=social&url=https://github.com/go-skynet/LocalAI" alt="Share on Reddit"/>
|
||||
</a> <a href="mailto:?subject=Check%20this%20GitHub%20repository%20out.%20LocalAI%20-%20Let%27s%20you%20easily%20run%20LLM%20locally.%3A%0Ahttps://github.com/go-skynet/LocalAI" target="_blank"><img src="https://img.shields.io/twitter/url?label=Gmail&logo=Gmail&style=social&url=https://github.com/go-skynet/LocalAI"/></a> <a href="https://www.buymeacoffee.com/mudler" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/default-orange.png" alt="Buy Me A Coffee" height="23" width="100" style="border-radius:1px"></a>
|
||||
|
||||
</p>
|
||||
|
||||
<hr>
|
||||
|
||||
In a nutshell:
|
||||
|
||||
- Local, OpenAI drop-in alternative REST API. You own your data.
|
||||
- NO GPU required. NO Internet access is required either
|
||||
- Optional, GPU Acceleration is available in `llama.cpp`-compatible LLMs. See also the [build section](https://localai.io/basics/build/index.html).
|
||||
- Supports multiple models:
|
||||
- 📖 Text generation with GPTs (`llama.cpp`, `gpt4all.cpp`, ... and more)
|
||||
- 🗣 Text to Audio 🎺🆕
|
||||
- 🔈 Audio to Text (Audio transcription with `whisper.cpp`)
|
||||
- 🎨 Image generation with stable diffusion
|
||||
- Supports multiple models
|
||||
- 🏃 Once loaded the first time, it keep models loaded in memory for faster inference
|
||||
- ⚡ Doesn't shell-out, but uses C++ bindings for a faster inference and better performance.
|
||||
- ⚡ Doesn't shell-out, but uses C++ bindings for a faster inference and better performance.
|
||||
|
||||
LocalAI was created by [Ettore Di Giacinto](https://github.com/mudler/) and is a community-driven project, focused on making the AI accessible to anyone. Any contribution, feedback and PR is welcome!
|
||||
|
||||
See the [Getting started](https://localai.io/basics/getting_started/index.html) and [examples](https://github.com/go-skynet/LocalAI/tree/master/examples/) sections to learn how to use LocalAI. For a list of curated models check out the [model gallery](https://localai.io/models/).
|
||||
Note that this started just as a [fun weekend project](https://localai.io/#backstory) in order to try to create the necessary pieces for a full AI assistant like `ChatGPT`: the community is growing fast and we are working hard to make it better and more stable. If you want to help, please consider contributing (see below)!
|
||||
|
||||
## 🔥🔥 [Hot topics / Roadmap](https://localai.io/#-hot-topics--roadmap)
|
||||
|
||||
## 🚀 [Features](https://localai.io/features/)
|
||||
|
||||
- 📖 [Text generation with GPTs](https://localai.io/features/text-generation/) (`llama.cpp`, `gpt4all.cpp`, ... [:book: and more](https://localai.io/model-compatibility/index.html#model-compatibility-table))
|
||||
- 🗣 [Text to Audio](https://localai.io/features/text-to-audio/)
|
||||
- 🔈 [Audio to Text](https://localai.io/features/audio-to-text/) (Audio transcription with `whisper.cpp`)
|
||||
- 🎨 [Image generation with stable diffusion](https://localai.io/features/image-generation)
|
||||
- 🔥 [OpenAI functions](https://localai.io/features/openai-functions/) 🆕
|
||||
- 🧠 [Embeddings generation for vector databases](https://localai.io/features/embeddings/)
|
||||
- ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/)
|
||||
- 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/)
|
||||
|
||||
|
||||
| [ChatGPT OSS alternative](https://github.com/go-skynet/LocalAI/tree/master/examples/chatbot-ui) | [Image generation](https://localai.io/api-endpoints/index.html#image-generation) |
|
||||
|------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|
|
||||
|  |  |
|
||||
## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
|
||||
|
||||
| [Telegram bot](https://github.com/go-skynet/LocalAI/tree/master/examples/telegram-bot) | [Flowise](https://github.com/go-skynet/LocalAI/tree/master/examples/flowise) |
|
||||
|------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|
|
||||
 | | |
|
||||
- [Create a slackbot for teams and OSS projects that answer to documentation](https://mudler.pm/posts/smart-slackbot-for-teams/)
|
||||
- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
|
||||
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)
|
||||
- [Tutorial to use k8sgpt with LocalAI](https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65)
|
||||
|
||||
## News
|
||||
## 💻 Usage
|
||||
|
||||
- 🔥🔥🔥 28-06-2023: **v1.20.0**: Added text to audio and gallery huggingface repositories! [Release notes](https://localai.io/basics/news/index.html#-28-06-2023-__v1200__-) [Changelog](https://github.com/go-skynet/LocalAI/releases/tag/v1.20.0)
|
||||
- 🔥🔥🔥 19-06-2023: **v1.19.0**: CUDA support! [Release notes](https://localai.io/basics/news/index.html#-19-06-2023-__v1190__-) [Changelog](https://github.com/go-skynet/LocalAI/releases/tag/v1.19.0)
|
||||
- 🔥🔥🔥 06-06-2023: **v1.18.0**: Many updates, new features, and much more 🚀, check out the [Release notes](https://localai.io/basics/news/index.html#-06-06-2023-__v1180__-)!
|
||||
- 29-05-2023: LocalAI now has a website, [https://localai.io](https://localai.io)! check the news in the [dedicated section](https://localai.io/basics/news/index.html)!
|
||||
Check out the [Getting started](https://localai.io/basics/getting_started/index.html) section in our documentation.
|
||||
|
||||
For latest news, follow also on Twitter [@LocalAI_API](https://twitter.com/LocalAI_API) and [@mudler_it](https://twitter.com/mudler_it)
|
||||
### 💡 Example: Use GPT4ALL-J model
|
||||
|
||||
## Contribute and help
|
||||
See the [documentation](https://localai.io/basics/getting_started/#example-use-gpt4all-j-model-with-docker-compose)
|
||||
|
||||
To help the project you can:
|
||||
### 🔗 Resources
|
||||
|
||||
- [Hacker news post](https://news.ycombinator.com/item?id=35726934) - help us out by voting if you like this project.
|
||||
- [How to build locally](https://localai.io/basics/build/index.html)
|
||||
- [How to install in Kubernetes](https://localai.io/basics/getting_started/index.html#run-localai-in-kubernetes)
|
||||
- [Projects integrating LocalAI](https://localai.io/integrations/)
|
||||
|
||||
- If you have technological skills and want to contribute to development, have a look at the open issues. If you are new you can have a look at the [good-first-issue](https://github.com/go-skynet/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) and [help-wanted](https://github.com/go-skynet/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) labels.
|
||||
## ❤️ Sponsors
|
||||
|
||||
- If you don't have technological skills you can still help improving documentation or add examples or share your user-stories with our community, any help and contribution is welcome!
|
||||
> Do you find LocalAI useful?
|
||||
|
||||
## Usage
|
||||
Support the project by becoming [a backer or sponsor](https://github.com/sponsors/mudler). Your logo will show up here with a link to your website.
|
||||
|
||||
Check out the [Getting started](https://localai.io/basics/getting_started/index.html) section. Here below you will find generic, quick instructions to get ready and use LocalAI.
|
||||
A huge thank you to our generous sponsors who support this project:
|
||||
|
||||
The easiest way to run LocalAI is by using `docker-compose` (to build locally, see [building LocalAI](https://localai.io/basics/build/index.html)):
|
||||
|  |
|
||||
|:-----------------------------------------------:|
|
||||
| [Spectro Cloud](https://www.spectrocloud.com/) |
|
||||
| Spectro Cloud kindly supports LocalAI by providing GPU and computing resources to run tests on lamdalabs! |
|
||||
|
||||
```bash
|
||||
|
||||
git clone https://github.com/go-skynet/LocalAI
|
||||
|
||||
cd LocalAI
|
||||
|
||||
# (optional) Checkout a specific LocalAI tag
|
||||
# git checkout -b build <TAG>
|
||||
|
||||
# copy your models to models/
|
||||
cp your-model.bin models/
|
||||
|
||||
# (optional) Edit the .env file to set things like context size and threads
|
||||
# vim .env
|
||||
|
||||
# start with docker-compose
|
||||
docker-compose up -d --pull always
|
||||
# or you can build the images with:
|
||||
# docker-compose up -d --build
|
||||
|
||||
# Now API is accessible at localhost:8080
|
||||
curl http://localhost:8080/v1/models
|
||||
# {"object":"list","data":[{"id":"your-model.bin","object":"model"}]}
|
||||
|
||||
curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
|
||||
"model": "your-model.bin",
|
||||
"prompt": "A long time ago in a galaxy far, far away",
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
|
||||
### Example: Use GPT4ALL-J model
|
||||
|
||||
<details>
|
||||
|
||||
```bash
|
||||
# Clone LocalAI
|
||||
git clone https://github.com/go-skynet/LocalAI
|
||||
|
||||
cd LocalAI
|
||||
|
||||
# (optional) Checkout a specific LocalAI tag
|
||||
# git checkout -b build <TAG>
|
||||
|
||||
# Download gpt4all-j to models/
|
||||
wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j
|
||||
|
||||
# Use a template from the examples
|
||||
cp -rf prompt-templates/ggml-gpt4all-j.tmpl models/
|
||||
|
||||
# (optional) Edit the .env file to set things like context size and threads
|
||||
# vim .env
|
||||
|
||||
# start with docker-compose
|
||||
docker-compose up -d --pull always
|
||||
# or you can build the images with:
|
||||
# docker-compose up -d --build
|
||||
# Now API is accessible at localhost:8080
|
||||
curl http://localhost:8080/v1/models
|
||||
# {"object":"list","data":[{"id":"ggml-gpt4all-j","object":"model"}]}
|
||||
|
||||
curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
|
||||
"model": "ggml-gpt4all-j",
|
||||
"messages": [{"role": "user", "content": "How are you?"}],
|
||||
"temperature": 0.9
|
||||
}'
|
||||
|
||||
# {"model":"ggml-gpt4all-j","choices":[{"message":{"role":"assistant","content":"I'm doing well, thanks. How about you?"}}]}
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
### Build locally
|
||||
|
||||
<details>
|
||||
|
||||
In order to build the `LocalAI` container image locally you can use `docker`:
|
||||
|
||||
```
|
||||
# build the image
|
||||
docker build -t localai .
|
||||
docker run localai
|
||||
```
|
||||
|
||||
Or you can build the binary with `make`:
|
||||
|
||||
```
|
||||
make build
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
See the [build section](https://localai.io/basics/build/index.html) in our documentation for detailed instructions.
|
||||
|
||||
### Run LocalAI in Kubernetes
|
||||
|
||||
LocalAI can be installed inside Kubernetes with helm. See [installation instructions](https://localai.io/basics/getting_started/index.html#run-localai-in-kubernetes).
|
||||
|
||||
## Supported API endpoints
|
||||
|
||||
See the [list of the supported API endpoints](https://localai.io/api-endpoints/index.html) and how to configure image generation and audio transcription.
|
||||
|
||||
## Frequently asked questions
|
||||
|
||||
See [the FAQ](https://localai.io/faq/index.html) section for a list of common questions.
|
||||
|
||||
## Projects already using LocalAI to run local models
|
||||
|
||||
Feel free to open up a PR to get your project listed!
|
||||
|
||||
- [Kairos](https://github.com/kairos-io/kairos)
|
||||
- [k8sgpt](https://github.com/k8sgpt-ai/k8sgpt#running-local-models)
|
||||
- [Spark](https://github.com/cedriking/spark)
|
||||
- [autogpt4all](https://github.com/aorumbayev/autogpt4all)
|
||||
- [Mods](https://github.com/charmbracelet/mods)
|
||||
- [Flowise](https://github.com/FlowiseAI/Flowise)
|
||||
|
||||
## Short-term roadmap
|
||||
|
||||
- [x] Mimic OpenAI API (https://github.com/go-skynet/LocalAI/issues/10)
|
||||
- [x] Binary releases (https://github.com/go-skynet/LocalAI/issues/6)
|
||||
- [ ] Upstream our golang bindings to llama.cpp (https://github.com/ggerganov/llama.cpp/issues/351)
|
||||
- [x] Upstream [gpt4all](https://github.com/go-skynet/LocalAI/issues/85) bindings
|
||||
- [x] Multi-model support
|
||||
- [x] Have a webUI!
|
||||
- [x] Allow configuration of defaults for models.
|
||||
- [x] Support for embeddings
|
||||
- [x] Support for audio transcription with https://github.com/ggerganov/whisper.cpp
|
||||
- [x] GPU/CUDA support ( https://github.com/go-skynet/LocalAI/issues/69 )
|
||||
- [X] Enable automatic downloading of models from a curated gallery
|
||||
- [ ] Enable automatic downloading of models from HuggingFace
|
||||
- [ ] Enable gallery management directly from the webui.
|
||||
- [ ] 🔥 OpenAI functions: https://github.com/go-skynet/LocalAI/issues/588
|
||||
|
||||
## Star history
|
||||
## 🌟 Star history
|
||||
|
||||
[](https://star-history.com/#go-skynet/LocalAI&Date)
|
||||
|
||||
## License
|
||||
## 📖 License
|
||||
|
||||
LocalAI is a community-driven project created by [Ettore Di Giacinto](https://github.com/mudler/).
|
||||
|
||||
MIT
|
||||
MIT - Author Ettore Di Giacinto
|
||||
|
||||
## Author
|
||||
|
||||
Ettore Di Giacinto and others
|
||||
|
||||
## Acknowledgements
|
||||
## 🙇 Acknowledgements
|
||||
|
||||
LocalAI couldn't have been built without the help of great software already available from the community. Thank you!
|
||||
|
||||
@@ -225,9 +148,12 @@ LocalAI couldn't have been built without the help of great software already avai
|
||||
- https://github.com/EdVince/Stable-Diffusion-NCNN
|
||||
- https://github.com/ggerganov/whisper.cpp
|
||||
- https://github.com/saharNooby/rwkv.cpp
|
||||
- https://github.com/rhasspy/piper
|
||||
- https://github.com/cmp-nct/ggllm.cpp
|
||||
|
||||
## Contributors
|
||||
## 🤗 Contributors
|
||||
|
||||
This is a community project, a special thanks to our contributors! 🤗
|
||||
<a href="https://github.com/go-skynet/LocalAI/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=go-skynet/LocalAI" />
|
||||
</a>
|
||||
|
||||
144
api/api.go
144
api/api.go
@@ -2,9 +2,15 @@ package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/localai"
|
||||
"github.com/go-skynet/LocalAI/api/openai"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/internal"
|
||||
"github.com/go-skynet/LocalAI/pkg/assets"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
@@ -13,18 +19,18 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func App(opts ...AppOption) (*fiber.App, error) {
|
||||
options := newOptions(opts...)
|
||||
func App(opts ...options.AppOption) (*fiber.App, error) {
|
||||
options := options.NewOptions(opts...)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
if options.debug {
|
||||
if options.Debug {
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
|
||||
// Return errors as JSON responses
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
DisableStartupMessage: options.disableMessage,
|
||||
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
DisableStartupMessage: options.DisableMessage,
|
||||
// Override default error handler
|
||||
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
||||
// Status code defaults to 500
|
||||
@@ -38,40 +44,44 @@ func App(opts ...AppOption) (*fiber.App, error) {
|
||||
|
||||
// Send custom error page
|
||||
return ctx.Status(code).JSON(
|
||||
ErrorResponse{
|
||||
Error: &APIError{Message: err.Error(), Code: code},
|
||||
openai.ErrorResponse{
|
||||
Error: &openai.APIError{Message: err.Error(), Code: code},
|
||||
},
|
||||
)
|
||||
},
|
||||
})
|
||||
|
||||
if options.debug {
|
||||
if options.Debug {
|
||||
app.Use(logger.New(logger.Config{
|
||||
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
||||
}))
|
||||
}
|
||||
|
||||
cm := NewConfigMerger()
|
||||
if err := cm.LoadConfigs(options.loader.ModelPath); err != nil {
|
||||
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
|
||||
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
||||
|
||||
cm := config.NewConfigLoader()
|
||||
if err := cm.LoadConfigs(options.Loader.ModelPath); err != nil {
|
||||
log.Error().Msgf("error loading config files: %s", err.Error())
|
||||
}
|
||||
|
||||
if options.configFile != "" {
|
||||
if err := cm.LoadConfigFile(options.configFile); err != nil {
|
||||
if options.ConfigFile != "" {
|
||||
if err := cm.LoadConfigFile(options.ConfigFile); err != nil {
|
||||
log.Error().Msgf("error loading config file: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if options.debug {
|
||||
if options.Debug {
|
||||
for _, v := range cm.ListConfigs() {
|
||||
cfg, _ := cm.GetConfig(v)
|
||||
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
if options.assetsDestination != "" {
|
||||
if options.AssetsDestination != "" {
|
||||
// Extract files from the embedded FS
|
||||
err := assets.ExtractFiles(options.backendAssets, options.assetsDestination)
|
||||
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
|
||||
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
|
||||
if err != nil {
|
||||
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
|
||||
}
|
||||
@@ -80,75 +90,102 @@ func App(opts ...AppOption) (*fiber.App, error) {
|
||||
// Default middleware config
|
||||
app.Use(recover.New())
|
||||
|
||||
if options.preloadJSONModels != "" {
|
||||
if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm, options.galleries); err != nil {
|
||||
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
|
||||
auth := func(c *fiber.Ctx) error {
|
||||
if len(options.ApiKeys) > 0 {
|
||||
authHeader := c.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
|
||||
}
|
||||
authHeaderParts := strings.Split(authHeader, " ")
|
||||
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
|
||||
}
|
||||
|
||||
apiKey := authHeaderParts[1]
|
||||
validApiKey := false
|
||||
for _, key := range options.ApiKeys {
|
||||
if apiKey == key {
|
||||
validApiKey = true
|
||||
}
|
||||
}
|
||||
if !validApiKey {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
|
||||
}
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.preloadModelsFromPath != "" {
|
||||
if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm, options.galleries); err != nil {
|
||||
if options.PreloadModelsFromPath != "" {
|
||||
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cm, options.Galleries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.cors {
|
||||
if options.corsAllowOrigins == "" {
|
||||
app.Use(cors.New())
|
||||
if options.CORS {
|
||||
var c func(ctx *fiber.Ctx) error
|
||||
if options.CORSAllowOrigins == "" {
|
||||
c = cors.New()
|
||||
} else {
|
||||
app.Use(cors.New(cors.Config{
|
||||
AllowOrigins: options.corsAllowOrigins,
|
||||
}))
|
||||
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
|
||||
}
|
||||
|
||||
app.Use(c)
|
||||
}
|
||||
|
||||
// LocalAI API endpoints
|
||||
applier := newGalleryApplier(options.loader.ModelPath)
|
||||
applier.start(options.context, cm)
|
||||
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
|
||||
galleryService.Start(options.Context, cm)
|
||||
|
||||
app.Get("/version", func(c *fiber.Ctx) error {
|
||||
app.Get("/version", auth, func(c *fiber.Ctx) error {
|
||||
return c.JSON(struct {
|
||||
Version string `json:"version"`
|
||||
}{Version: internal.PrintableVersion()})
|
||||
})
|
||||
|
||||
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries))
|
||||
app.Get("/models/available", listModelFromGallery(options.galleries, options.loader.ModelPath))
|
||||
app.Get("/models/jobs/:uuid", getOpStatus(applier))
|
||||
app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries))
|
||||
app.Get("/models/available", auth, localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath))
|
||||
app.Get("/models/jobs/:uuid", auth, localai.GetOpStatusEndpoint(galleryService))
|
||||
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// chat
|
||||
app.Post("/v1/chat/completions", chatEndpoint(cm, options))
|
||||
app.Post("/chat/completions", chatEndpoint(cm, options))
|
||||
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cm, options))
|
||||
app.Post("/chat/completions", auth, openai.ChatEndpoint(cm, options))
|
||||
|
||||
// edit
|
||||
app.Post("/v1/edits", editEndpoint(cm, options))
|
||||
app.Post("/edits", editEndpoint(cm, options))
|
||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cm, options))
|
||||
app.Post("/edits", auth, openai.EditEndpoint(cm, options))
|
||||
|
||||
// completion
|
||||
app.Post("/v1/completions", completionEndpoint(cm, options))
|
||||
app.Post("/completions", completionEndpoint(cm, options))
|
||||
app.Post("/v1/engines/:model/completions", completionEndpoint(cm, options))
|
||||
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cm, options))
|
||||
app.Post("/completions", auth, openai.CompletionEndpoint(cm, options))
|
||||
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cm, options))
|
||||
|
||||
// embeddings
|
||||
app.Post("/v1/embeddings", embeddingsEndpoint(cm, options))
|
||||
app.Post("/embeddings", embeddingsEndpoint(cm, options))
|
||||
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options))
|
||||
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
|
||||
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
|
||||
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
|
||||
|
||||
// audio
|
||||
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options))
|
||||
app.Post("/tts", ttsEndpoint(cm, options))
|
||||
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cm, options))
|
||||
app.Post("/tts", auth, localai.TTSEndpoint(cm, options))
|
||||
|
||||
// images
|
||||
app.Post("/v1/images/generations", imageEndpoint(cm, options))
|
||||
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cm, options))
|
||||
|
||||
if options.imageDir != "" {
|
||||
app.Static("/generated-images", options.imageDir)
|
||||
if options.ImageDir != "" {
|
||||
app.Static("/generated-images", options.ImageDir)
|
||||
}
|
||||
|
||||
if options.audioDir != "" {
|
||||
app.Static("/generated-audio", options.audioDir)
|
||||
if options.AudioDir != "" {
|
||||
app.Static("/generated-audio", options.AudioDir)
|
||||
}
|
||||
|
||||
ok := func(c *fiber.Ctx) error {
|
||||
@@ -160,8 +197,15 @@ func App(opts ...AppOption) (*fiber.App, error) {
|
||||
app.Get("/readyz", ok)
|
||||
|
||||
// models
|
||||
app.Get("/v1/models", listModels(options.loader, cm))
|
||||
app.Get("/models", listModels(options.loader, cm))
|
||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cm))
|
||||
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cm))
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
log.Debug().Msgf("Context canceled, shutting down")
|
||||
options.Loader.StopGRPC()
|
||||
}()
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
339
api/api_test.go
339
api/api_test.go
@@ -5,14 +5,16 @@ import (
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/api"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
@@ -23,13 +25,14 @@ import (
|
||||
|
||||
openaigo "github.com/otiai10/openaigo"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type modelApplyRequest struct {
|
||||
ID string `json:"id"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
Overrides map[string]string `json:"overrides"`
|
||||
ID string `json:"id"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
Overrides map[string]interface{} `json:"overrides"`
|
||||
}
|
||||
|
||||
func getModelStatus(url string) (response map[string]interface{}) {
|
||||
@@ -41,7 +44,7 @@ func getModelStatus(url string) (response map[string]interface{}) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Println("Error reading response body:", err)
|
||||
return
|
||||
@@ -93,7 +96,7 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Println("Error reading response body:", err)
|
||||
return
|
||||
@@ -121,6 +124,11 @@ var _ = Describe("API test", func() {
|
||||
var cancel context.CancelFunc
|
||||
var tmpdir string
|
||||
|
||||
commonOpts := []options.AppOption{
|
||||
options.WithDebug(true),
|
||||
options.WithDisableMessage(true),
|
||||
}
|
||||
|
||||
Context("API with ephemeral models", func() {
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
@@ -139,12 +147,12 @@ var _ = Describe("API test", func() {
|
||||
Name: "bert2",
|
||||
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
||||
Overrides: map[string]interface{}{"foo": "bar"},
|
||||
AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
|
||||
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
|
||||
},
|
||||
}
|
||||
out, err := yaml.Marshal(g)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = ioutil.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644)
|
||||
err = os.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
galleries := []gallery.Gallery{
|
||||
@@ -154,9 +162,11 @@ var _ = Describe("API test", func() {
|
||||
},
|
||||
}
|
||||
|
||||
app, err = App(WithContext(c),
|
||||
WithGalleries(galleries),
|
||||
WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir))
|
||||
app, err = App(
|
||||
append(commonOpts,
|
||||
options.WithContext(c),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
@@ -201,7 +211,7 @@ var _ = Describe("API test", func() {
|
||||
fmt.Println(response)
|
||||
resp = response
|
||||
return response["processed"].(bool)
|
||||
}, "360s").Should(Equal(true))
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
Expect(resp["message"]).ToNot(ContainSubstring("error"))
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
|
||||
@@ -232,7 +242,7 @@ var _ = Describe("API test", func() {
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
||||
Name: "bert",
|
||||
Overrides: map[string]string{
|
||||
Overrides: map[string]interface{}{
|
||||
"backend": "llama",
|
||||
},
|
||||
})
|
||||
@@ -243,9 +253,8 @@ var _ = Describe("API test", func() {
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s").Should(Equal(true))
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -259,7 +268,7 @@ var _ = Describe("API test", func() {
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
||||
Name: "bert",
|
||||
Overrides: map[string]string{},
|
||||
Overrides: map[string]interface{}{},
|
||||
})
|
||||
|
||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||
@@ -268,9 +277,8 @@ var _ = Describe("API test", func() {
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s").Should(Equal(true))
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -288,7 +296,7 @@ var _ = Describe("API test", func() {
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
URL: "github:go-skynet/model-gallery/openllama_3b.yaml",
|
||||
Name: "openllama_3b",
|
||||
Overrides: map[string]string{},
|
||||
Overrides: map[string]interface{}{"backend": "llama", "mmap": true, "f16": true, "context_size": 128},
|
||||
})
|
||||
|
||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||
@@ -297,14 +305,58 @@ var _ = Describe("API test", func() {
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s").Should(Equal(true))
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
By("testing completion")
|
||||
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "openllama_3b", Prompt: "Count up to five: one, two, three, four, "})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices)).To(Equal(1))
|
||||
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
|
||||
|
||||
By("testing functions")
|
||||
resp2, err := client.CreateChatCompletion(
|
||||
context.TODO(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: "openllama_3b",
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What is the weather like in San Francisco (celsius)?",
|
||||
},
|
||||
},
|
||||
Functions: []openai.FunctionDefinition{
|
||||
openai.FunctionDefinition{
|
||||
Name: "get_current_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: jsonschema.Definition{
|
||||
Type: jsonschema.Object,
|
||||
Properties: map[string]jsonschema.Definition{
|
||||
"location": {
|
||||
Type: jsonschema.String,
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
Type: jsonschema.String,
|
||||
Enum: []string{"celcius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp2.Choices)).To(Equal(1))
|
||||
Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil())
|
||||
Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name)
|
||||
|
||||
var res map[string]string
|
||||
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res))
|
||||
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
|
||||
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
|
||||
})
|
||||
|
||||
It("runs gpt4all", Label("gpt4all"), func() {
|
||||
@@ -313,9 +365,8 @@ var _ = Describe("API test", func() {
|
||||
}
|
||||
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
URL: "github:go-skynet/model-gallery/gpt4all-j.yaml",
|
||||
Name: "gpt4all-j",
|
||||
Overrides: map[string]string{},
|
||||
URL: "github:go-skynet/model-gallery/gpt4all-j.yaml",
|
||||
Name: "gpt4all-j",
|
||||
})
|
||||
|
||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||
@@ -324,15 +375,130 @@ var _ = Describe("API test", func() {
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s").Should(Equal(true))
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-j", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "How are you?"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices)).To(Equal(1))
|
||||
Expect(resp.Choices[0].Message.Content).To(ContainSubstring("well"))
|
||||
})
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
Context("Model gallery", func() {
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tmpdir, err = os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
modelLoader = model.NewModelLoader(tmpdir)
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
galleries := []gallery.Gallery{
|
||||
{
|
||||
Name: "model-gallery",
|
||||
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/index.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
app, err = App(
|
||||
append(commonOpts,
|
||||
options.WithContext(c),
|
||||
options.WithAudioDir(tmpdir),
|
||||
options.WithImageDir(tmpdir),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(tmpdir))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
|
||||
client2 = openaigo.NewClient("")
|
||||
client2.BaseURL = defaultConfig.BaseURL
|
||||
|
||||
// Wait for API to be ready
|
||||
client = openai.NewClientWithConfig(defaultConfig)
|
||||
Eventually(func() error {
|
||||
_, err := client.ListModels(context.TODO())
|
||||
return err
|
||||
}, "2m").ShouldNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
app.Shutdown()
|
||||
os.RemoveAll(tmpdir)
|
||||
})
|
||||
It("installs and is capable to run tts", Label("tts"), func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
ID: "model-gallery@voice-en-us-kathleen-low",
|
||||
})
|
||||
|
||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||
|
||||
uuid := response["uuid"].(string)
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
// An HTTP Post to the /tts endpoint should return a wav audio file
|
||||
resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "en-us-kathleen-low.onnx"}`)))
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
|
||||
dat, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
|
||||
|
||||
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
|
||||
Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav"))
|
||||
})
|
||||
It("installs and is capable to generate images", Label("stablediffusion"), func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
ID: "model-gallery@stablediffusion",
|
||||
Overrides: map[string]interface{}{
|
||||
"parameters": map[string]interface{}{"model": "stablediffusion_assets"},
|
||||
},
|
||||
})
|
||||
|
||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||
|
||||
uuid := response["uuid"].(string)
|
||||
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
fmt.Println(response)
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
resp, err := http.Post(
|
||||
"http://127.0.0.1:9090/v1/images/generations",
|
||||
"application/json",
|
||||
bytes.NewBuffer([]byte(`{
|
||||
"prompt": "floating hair, portrait, ((loli)), ((one girl)), cute face, hidden hands, asymmetrical bangs, beautiful detailed eyes, eye shadow, hair ornament, ribbons, bowties, buttons, pleated skirt, (((masterpiece))), ((best quality)), colorful|((part of the head)), ((((mutated hands and fingers)))), deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, Octane renderer, lowres, bad anatomy, bad hands, text",
|
||||
"mode": 2, "seed":9000,
|
||||
"size": "256x256", "n":2}`)))
|
||||
// The response should contain an URL
|
||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
|
||||
dat, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred(), string(dat))
|
||||
Expect(string(dat)).To(ContainSubstring("http://127.0.0.1:9090/"), string(dat))
|
||||
Expect(string(dat)).To(ContainSubstring(".png"), string(dat))
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
@@ -342,7 +508,12 @@ var _ = Describe("API test", func() {
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
var err error
|
||||
app, err = App(WithContext(c), WithModelLoader(modelLoader))
|
||||
app, err = App(
|
||||
append(commonOpts,
|
||||
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
||||
options.WithContext(c),
|
||||
options.WithModelLoader(modelLoader),
|
||||
)...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
@@ -366,7 +537,7 @@ var _ = Describe("API test", func() {
|
||||
It("returns the models list", func() {
|
||||
models, err := client.ListModels(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(models.Models)).To(Equal(10))
|
||||
Expect(len(models.Models)).To(Equal(6)) // If "config.yaml" should be included, this should be 8?
|
||||
})
|
||||
It("can generate completions", func() {
|
||||
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"})
|
||||
@@ -397,9 +568,10 @@ var _ = Describe("API test", func() {
|
||||
})
|
||||
|
||||
It("returns errors", func() {
|
||||
backends := len(model.AutoLoadBackends) + 1 // +1 for huggingface
|
||||
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 11 errors occurred:"))
|
||||
Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends)))
|
||||
})
|
||||
It("transcribes audio", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
@@ -443,15 +615,98 @@ var _ = Describe("API test", func() {
|
||||
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
|
||||
})
|
||||
|
||||
Context("External gRPC calls", func() {
|
||||
It("calculate embeddings with huggingface", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
resp, err := client.CreateEmbeddings(
|
||||
context.Background(),
|
||||
openai.EmbeddingRequest{
|
||||
Model: openai.AdaCodeSearchCode,
|
||||
Input: []string{"sun", "cat"},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
|
||||
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
|
||||
|
||||
sunEmbedding := resp.Data[0].Embedding
|
||||
resp2, err := client.CreateEmbeddings(
|
||||
context.Background(),
|
||||
openai.EmbeddingRequest{
|
||||
Model: openai.AdaCodeSearchCode,
|
||||
Input: []string{"sun"},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
|
||||
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
|
||||
})
|
||||
})
|
||||
|
||||
Context("backends", func() {
|
||||
It("runs rwkv", func() {
|
||||
It("runs rwkv completion", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices) > 0).To(BeTrue())
|
||||
Expect(resp.Choices[0].Text).To(Equal(" five."))
|
||||
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
|
||||
|
||||
stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
|
||||
Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer stream.Close()
|
||||
|
||||
tokens := 0
|
||||
text := ""
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
text += response.Choices[0].Text
|
||||
tokens++
|
||||
}
|
||||
Expect(text).ToNot(BeEmpty())
|
||||
Expect(text).To(ContainSubstring("five"))
|
||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
||||
})
|
||||
It("runs rwkv chat completion", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
||||
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices) > 0).To(BeTrue())
|
||||
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer stream.Close()
|
||||
|
||||
tokens := 0
|
||||
text := ""
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
text += response.Choices[0].Delta.Content
|
||||
tokens++
|
||||
}
|
||||
Expect(text).ToNot(BeEmpty())
|
||||
Expect(text).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
|
||||
|
||||
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -462,7 +717,12 @@ var _ = Describe("API test", func() {
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
var err error
|
||||
app, err = App(WithContext(c), WithModelLoader(modelLoader), WithConfigFile(os.Getenv("CONFIG_FILE")))
|
||||
app, err = App(
|
||||
append(commonOpts,
|
||||
options.WithContext(c),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
@@ -481,19 +741,14 @@ var _ = Describe("API test", func() {
|
||||
cancel()
|
||||
app.Shutdown()
|
||||
})
|
||||
It("can generate chat completions from config file", func() {
|
||||
models, err := client.ListModels(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(models.Models)).To(Equal(12))
|
||||
})
|
||||
It("can generate chat completions from config file", func() {
|
||||
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
|
||||
It("can generate chat completions from config file (list1)", func() {
|
||||
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "abcdedfghikl"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices)).To(Equal(1))
|
||||
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||
})
|
||||
It("can generate chat completions from config file", func() {
|
||||
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
|
||||
It("can generate chat completions from config file (list2)", func() {
|
||||
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "abcdedfghikl"}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Choices)).To(Equal(1))
|
||||
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||
|
||||
109
api/backend/embeddings.go
Normal file
109
api/backend/embeddings.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
|
||||
if !c.Embeddings {
|
||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
|
||||
}
|
||||
|
||||
modelFile := c.Model
|
||||
|
||||
grpcOpts := gRPCModelOpts(c)
|
||||
|
||||
var inferenceModel interface{}
|
||||
var err error
|
||||
|
||||
opts := []model.Option{
|
||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithModel(modelFile),
|
||||
model.WithContext(o.Context),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||
} else {
|
||||
opts = append(opts, model.WithBackendString(c.Backend))
|
||||
inferenceModel, err = loader.BackendLoader(opts...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fn func() ([]float32, error)
|
||||
switch model := inferenceModel.(type) {
|
||||
case *grpc.Client:
|
||||
fn = func() ([]float32, error) {
|
||||
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
|
||||
if len(tokens) > 0 {
|
||||
embeds := []int32{}
|
||||
|
||||
for _, t := range tokens {
|
||||
embeds = append(embeds, int32(t))
|
||||
}
|
||||
predictOptions.EmbeddingTokens = embeds
|
||||
|
||||
res, err := model.Embeddings(o.Context, predictOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Embeddings, nil
|
||||
}
|
||||
predictOptions.Embeddings = s
|
||||
|
||||
res, err := model.Embeddings(o.Context, predictOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Embeddings, nil
|
||||
}
|
||||
default:
|
||||
fn = func() ([]float32, error) {
|
||||
return nil, fmt.Errorf("embeddings not supported by the backend")
|
||||
}
|
||||
}
|
||||
|
||||
return func() ([]float32, error) {
|
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock()
|
||||
l, ok := mutexes[modelFile]
|
||||
if !ok {
|
||||
m := &sync.Mutex{}
|
||||
mutexes[modelFile] = m
|
||||
l = m
|
||||
}
|
||||
mutexMap.Unlock()
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
embeds, err := fn()
|
||||
if err != nil {
|
||||
return embeds, err
|
||||
}
|
||||
// Remove trailing 0s
|
||||
for i := len(embeds) - 1; i >= 0; i-- {
|
||||
if embeds[i] == 0.0 {
|
||||
embeds = embeds[:i]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return embeds, nil
|
||||
}, nil
|
||||
}
|
||||
69
api/backend/image.go
Normal file
69
api/backend/image.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
|
||||
|
||||
opts := []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithContext(o.Context),
|
||||
model.WithModel(c.Model),
|
||||
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
|
||||
CUDA: c.Diffusers.CUDA,
|
||||
SchedulerType: c.Diffusers.SchedulerType,
|
||||
PipelineType: c.Diffusers.PipelineType,
|
||||
}),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
inferenceModel, err := loader.BackendLoader(
|
||||
opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fn := func() error {
|
||||
_, err := inferenceModel.GenerateImage(
|
||||
o.Context,
|
||||
&proto.GenerateImageRequest{
|
||||
Height: int32(height),
|
||||
Width: int32(width),
|
||||
Mode: int32(mode),
|
||||
Step: int32(step),
|
||||
Seed: int32(seed),
|
||||
PositivePrompt: positive_prompt,
|
||||
NegativePrompt: negative_prompt,
|
||||
Dst: dst,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
return func() error {
|
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock()
|
||||
l, ok := mutexes[c.Backend]
|
||||
if !ok {
|
||||
m := &sync.Mutex{}
|
||||
mutexes[c.Backend] = m
|
||||
l = m
|
||||
}
|
||||
mutexMap.Unlock()
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
return fn()
|
||||
}, nil
|
||||
}
|
||||
125
api/backend/llm.go
Normal file
125
api/backend/llm.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
||||
modelFile := c.Model
|
||||
|
||||
grpcOpts := gRPCModelOpts(c)
|
||||
|
||||
var inferenceModel *grpc.Client
|
||||
var err error
|
||||
|
||||
opts := []model.Option{
|
||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithModel(modelFile),
|
||||
model.WithContext(o.Context),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
if c.Backend != "" {
|
||||
opts = append(opts, model.WithBackendString(c.Backend))
|
||||
}
|
||||
|
||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||
if o.AutoloadGalleries { // experimental
|
||||
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||
} else {
|
||||
inferenceModel, err = loader.BackendLoader(opts...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
||||
fn := func() (string, error) {
|
||||
opts := gRPCPredictOpts(c, loader.ModelPath)
|
||||
opts.Prompt = s
|
||||
if tokenCallback != nil {
|
||||
ss := ""
|
||||
err := inferenceModel.PredictStream(ctx, opts, func(s []byte) {
|
||||
tokenCallback(string(s))
|
||||
ss += string(s)
|
||||
})
|
||||
return ss, err
|
||||
} else {
|
||||
reply, err := inferenceModel.Predict(ctx, opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.Message), err
|
||||
}
|
||||
}
|
||||
|
||||
return func() (string, error) {
|
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock()
|
||||
l, ok := mutexes[modelFile]
|
||||
if !ok {
|
||||
m := &sync.Mutex{}
|
||||
mutexes[modelFile] = m
|
||||
l = m
|
||||
}
|
||||
mutexMap.Unlock()
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
return fn()
|
||||
}, nil
|
||||
}
|
||||
|
||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
||||
var mu sync.Mutex = sync.Mutex{}
|
||||
|
||||
func Finetune(config config.Config, input, prediction string) string {
|
||||
if config.Echo {
|
||||
prediction = input + prediction
|
||||
}
|
||||
|
||||
for _, c := range config.Cutstrings {
|
||||
mu.Lock()
|
||||
reg, ok := cutstrings[c]
|
||||
if !ok {
|
||||
cutstrings[c] = regexp.MustCompile(c)
|
||||
reg = cutstrings[c]
|
||||
}
|
||||
mu.Unlock()
|
||||
prediction = reg.ReplaceAllString(prediction, "")
|
||||
}
|
||||
|
||||
for _, c := range config.TrimSpace {
|
||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
||||
}
|
||||
return prediction
|
||||
|
||||
}
|
||||
22
api/backend/lock.go
Normal file
22
api/backend/lock.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package backend
|
||||
|
||||
import "sync"
|
||||
|
||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
var mutexMap sync.Mutex
|
||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
|
||||
|
||||
func Lock(s string) *sync.Mutex {
|
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock()
|
||||
l, ok := mutexes[s]
|
||||
if !ok {
|
||||
m := &sync.Mutex{}
|
||||
mutexes[s] = m
|
||||
l = m
|
||||
}
|
||||
mutexMap.Unlock()
|
||||
l.Lock()
|
||||
|
||||
return l
|
||||
}
|
||||
85
api/backend/options.go
Normal file
85
api/backend/options.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
)
|
||||
|
||||
func gRPCModelOpts(c config.Config) *pb.ModelOptions {
|
||||
b := 512
|
||||
if c.Batch != 0 {
|
||||
b = c.Batch
|
||||
}
|
||||
return &pb.ModelOptions{
|
||||
ContextSize: int32(c.ContextSize),
|
||||
Seed: int32(c.Seed),
|
||||
NBatch: int32(b),
|
||||
NGQA: c.NGQA,
|
||||
|
||||
RMSNormEps: c.RMSNormEps,
|
||||
F16Memory: c.F16,
|
||||
MLock: c.MMlock,
|
||||
RopeFreqBase: c.RopeFreqBase,
|
||||
RopeFreqScale: c.RopeFreqScale,
|
||||
NUMA: c.NUMA,
|
||||
Embeddings: c.Embeddings,
|
||||
LowVRAM: c.LowVRAM,
|
||||
NGPULayers: int32(c.NGPULayers),
|
||||
MMap: c.MMap,
|
||||
MainGPU: c.MainGPU,
|
||||
Threads: int32(c.Threads),
|
||||
TensorSplit: c.TensorSplit,
|
||||
// AutoGPTQ
|
||||
ModelBaseName: c.AutoGPTQ.ModelBaseName,
|
||||
Device: c.AutoGPTQ.Device,
|
||||
UseTriton: c.AutoGPTQ.Triton,
|
||||
UseFastTokenizer: c.AutoGPTQ.UseFastTokenizer,
|
||||
}
|
||||
}
|
||||
|
||||
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
|
||||
promptCachePath := ""
|
||||
if c.PromptCachePath != "" {
|
||||
p := filepath.Join(modelPath, c.PromptCachePath)
|
||||
os.MkdirAll(filepath.Dir(p), 0755)
|
||||
promptCachePath = p
|
||||
}
|
||||
return &pb.PredictOptions{
|
||||
Temperature: float32(c.Temperature),
|
||||
TopP: float32(c.TopP),
|
||||
TopK: int32(c.TopK),
|
||||
Tokens: int32(c.Maxtokens),
|
||||
Threads: int32(c.Threads),
|
||||
PromptCacheAll: c.PromptCacheAll,
|
||||
PromptCacheRO: c.PromptCacheRO,
|
||||
PromptCachePath: promptCachePath,
|
||||
F16KV: c.F16,
|
||||
DebugMode: c.Debug,
|
||||
Grammar: c.Grammar,
|
||||
NegativePromptScale: c.NegativePromptScale,
|
||||
RopeFreqBase: c.RopeFreqBase,
|
||||
RopeFreqScale: c.RopeFreqScale,
|
||||
NegativePrompt: c.NegativePrompt,
|
||||
Mirostat: int32(c.LLMConfig.Mirostat),
|
||||
MirostatETA: float32(c.LLMConfig.MirostatETA),
|
||||
MirostatTAU: float32(c.LLMConfig.MirostatTAU),
|
||||
Debug: c.Debug,
|
||||
StopPrompts: c.StopWords,
|
||||
Repeat: int32(c.RepeatPenalty),
|
||||
NKeep: int32(c.Keep),
|
||||
Batch: int32(c.Batch),
|
||||
IgnoreEOS: c.IgnoreEOS,
|
||||
Seed: int32(c.Seed),
|
||||
FrequencyPenalty: float32(c.FrequencyPenalty),
|
||||
MLock: c.MMlock,
|
||||
MMap: c.MMap,
|
||||
MainGPU: c.MainGPU,
|
||||
TensorSplit: c.TensorSplit,
|
||||
TailFreeSamplingZ: float32(c.TFZ),
|
||||
TypicalP: float32(c.TypicalP),
|
||||
}
|
||||
}
|
||||
42
api/backend/transcript.go
Normal file
42
api/backend/transcript.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) {
|
||||
opts := []model.Option{
|
||||
model.WithBackendString(model.WhisperBackend),
|
||||
model.WithModel(c.Model),
|
||||
model.WithContext(o.Context),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
whisperModel, err := o.Loader.BackendLoader(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if whisperModel == nil {
|
||||
return nil, fmt.Errorf("could not load whisper model")
|
||||
}
|
||||
|
||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||
Dst: audio,
|
||||
Language: language,
|
||||
Threads: uint32(c.Threads),
|
||||
})
|
||||
}
|
||||
79
api/backend/tts.go
Normal file
79
api/backend/tts.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
func generateUniqueFileName(dir, baseName, ext string) string {
|
||||
counter := 1
|
||||
fileName := baseName + ext
|
||||
|
||||
for {
|
||||
filePath := filepath.Join(dir, fileName)
|
||||
_, err := os.Stat(filePath)
|
||||
if os.IsNotExist(err) {
|
||||
return fileName
|
||||
}
|
||||
|
||||
counter++
|
||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
||||
}
|
||||
}
|
||||
|
||||
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) {
|
||||
bb := backend
|
||||
if bb == "" {
|
||||
bb = model.PiperBackend
|
||||
}
|
||||
opts := []model.Option{
|
||||
model.WithBackendString(bb),
|
||||
model.WithModel(modelFile),
|
||||
model.WithContext(o.Context),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
piperModel, err := o.Loader.BackendLoader(opts...)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if piperModel == nil {
|
||||
return "", nil, fmt.Errorf("could not load piper model")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
|
||||
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||
}
|
||||
|
||||
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
|
||||
filePath := filepath.Join(o.AudioDir, fileName)
|
||||
|
||||
// If the model file is not empty, we pass it joined with the model path
|
||||
modelPath := ""
|
||||
if modelFile != "" {
|
||||
modelPath = filepath.Join(o.Loader.ModelPath, modelFile)
|
||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Dst: filePath,
|
||||
})
|
||||
|
||||
return filePath, res, err
|
||||
}
|
||||
368
api/config.go
368
api/config.go
@@ -1,368 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
OpenAIRequest `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
StopWords []string `yaml:"stopwords"`
|
||||
Cutstrings []string `yaml:"cutstrings"`
|
||||
TrimSpace []string `yaml:"trimspace"`
|
||||
ContextSize int `yaml:"context_size"`
|
||||
F16 bool `yaml:"f16"`
|
||||
NUMA bool `yaml:"numa"`
|
||||
Threads int `yaml:"threads"`
|
||||
Debug bool `yaml:"debug"`
|
||||
Roles map[string]string `yaml:"roles"`
|
||||
Embeddings bool `yaml:"embeddings"`
|
||||
Backend string `yaml:"backend"`
|
||||
TemplateConfig TemplateConfig `yaml:"template"`
|
||||
MirostatETA float64 `yaml:"mirostat_eta"`
|
||||
MirostatTAU float64 `yaml:"mirostat_tau"`
|
||||
Mirostat int `yaml:"mirostat"`
|
||||
NGPULayers int `yaml:"gpu_layers"`
|
||||
MMap bool `yaml:"mmap"`
|
||||
MMlock bool `yaml:"mmlock"`
|
||||
LowVRAM bool `yaml:"low_vram"`
|
||||
|
||||
TensorSplit string `yaml:"tensor_split"`
|
||||
MainGPU string `yaml:"main_gpu"`
|
||||
ImageGenerationAssets string `yaml:"asset_dir"`
|
||||
|
||||
PromptCachePath string `yaml:"prompt_cache_path"`
|
||||
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
||||
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
||||
|
||||
PromptStrings, InputStrings []string
|
||||
InputToken [][]int
|
||||
}
|
||||
|
||||
type TemplateConfig struct {
|
||||
Completion string `yaml:"completion"`
|
||||
Chat string `yaml:"chat"`
|
||||
Edit string `yaml:"edit"`
|
||||
}
|
||||
|
||||
type ConfigMerger struct {
|
||||
configs map[string]Config
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func defaultConfig(modelFile string) *Config {
|
||||
return &Config{
|
||||
OpenAIRequest: defaultRequest(modelFile),
|
||||
}
|
||||
}
|
||||
|
||||
func NewConfigMerger() *ConfigMerger {
|
||||
return &ConfigMerger{
|
||||
configs: make(map[string]Config),
|
||||
}
|
||||
}
|
||||
func ReadConfigFile(file string) ([]*Config, error) {
|
||||
c := &[]*Config{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||
}
|
||||
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
func ReadConfig(file string) (*Config, error) {
|
||||
c := &Config{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (cm *ConfigMerger) LoadConfigFile(file string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
c, err := ReadConfigFile(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot load config file: %w", err)
|
||||
}
|
||||
|
||||
for _, cc := range c {
|
||||
cm.configs[cc.Name] = *cc
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConfigMerger) LoadConfig(file string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
c, err := ReadConfig(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
|
||||
cm.configs[c.Name] = *c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConfigMerger) GetConfig(m string) (Config, bool) {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
v, exists := cm.configs[m]
|
||||
return v, exists
|
||||
}
|
||||
|
||||
func (cm *ConfigMerger) ListConfigs() []string {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
var res []string
|
||||
for k := range cm.configs {
|
||||
res = append(res, k)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (cm *ConfigMerger) LoadConfigs(path string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files := make([]fs.FileInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files = append(files, info)
|
||||
}
|
||||
for _, file := range files {
|
||||
// Skip templates, YAML and .keep files
|
||||
if !strings.Contains(file.Name(), ".yaml") {
|
||||
continue
|
||||
}
|
||||
c, err := ReadConfig(filepath.Join(path, file.Name()))
|
||||
if err == nil {
|
||||
cm.configs[c.Name] = *c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateConfig(config *Config, input *OpenAIRequest) {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != 0 {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != 0 {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Temperature != 0 {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != 0 {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.F16 {
|
||||
config.F16 = input.F16
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != 0 {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.Mirostat != 0 {
|
||||
config.Mirostat = input.Mirostat
|
||||
}
|
||||
|
||||
if input.MirostatETA != 0 {
|
||||
config.MirostatETA = input.MirostatETA
|
||||
}
|
||||
|
||||
if input.MirostatTAU != 0 {
|
||||
config.MirostatTAU = input.MirostatTAU
|
||||
}
|
||||
|
||||
if input.TypicalP != 0 {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []interface{}:
|
||||
tokens := []int{}
|
||||
for _, ii := range i {
|
||||
tokens = append(tokens, int(ii.(float64)))
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
|
||||
input := new(OpenAIRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
modelFile := input.Model
|
||||
|
||||
if c.Params("model") != "" {
|
||||
modelFile = c.Params("model")
|
||||
}
|
||||
|
||||
received, _ := json.Marshal(input)
|
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received))
|
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
|
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
|
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelFile == "" && !bearerExists && randomModel {
|
||||
models, _ := loader.ListModels()
|
||||
if len(models) > 0 {
|
||||
modelFile = models[0]
|
||||
log.Debug().Msgf("No model specified, using: %s", modelFile)
|
||||
} else {
|
||||
log.Debug().Msgf("No model specified, returning error")
|
||||
return "", nil, fmt.Errorf("no model specified")
|
||||
}
|
||||
}
|
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists {
|
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer)
|
||||
modelFile = bearer
|
||||
}
|
||||
return modelFile, input, nil
|
||||
}
|
||||
|
||||
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
|
||||
// Load a config file if present after the model name
|
||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
|
||||
|
||||
var config *Config
|
||||
|
||||
defaults := func() {
|
||||
config = defaultConfig(modelFile)
|
||||
config.ContextSize = ctx
|
||||
config.Threads = threads
|
||||
config.F16 = f16
|
||||
config.Debug = debug
|
||||
}
|
||||
|
||||
cfg, exists := cm.GetConfig(modelFile)
|
||||
if !exists {
|
||||
if _, err := os.Stat(modelConfig); err == nil {
|
||||
if err := cm.LoadConfig(modelConfig); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||
}
|
||||
cfg, exists = cm.GetConfig(modelFile)
|
||||
if exists {
|
||||
config = &cfg
|
||||
} else {
|
||||
defaults()
|
||||
}
|
||||
} else {
|
||||
defaults()
|
||||
}
|
||||
} else {
|
||||
config = &cfg
|
||||
}
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateConfig(config, input)
|
||||
|
||||
// Don't allow 0 as setting
|
||||
if config.Threads == 0 {
|
||||
if threads != 0 {
|
||||
config.Threads = threads
|
||||
} else {
|
||||
config.Threads = 4
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce debug flag if passed from CLI
|
||||
if debug {
|
||||
config.Debug = true
|
||||
}
|
||||
|
||||
return config, input, nil
|
||||
}
|
||||
247
api/config/config.go
Normal file
247
api/config/config.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package api_config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
PredictionOptions `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
|
||||
F16 bool `yaml:"f16"`
|
||||
Threads int `yaml:"threads"`
|
||||
Debug bool `yaml:"debug"`
|
||||
Roles map[string]string `yaml:"roles"`
|
||||
Embeddings bool `yaml:"embeddings"`
|
||||
Backend string `yaml:"backend"`
|
||||
TemplateConfig TemplateConfig `yaml:"template"`
|
||||
|
||||
PromptStrings, InputStrings []string `yaml:"-"`
|
||||
InputToken [][]int `yaml:"-"`
|
||||
functionCallString, functionCallNameString string `yaml:"-"`
|
||||
|
||||
FunctionsConfig Functions `yaml:"function"`
|
||||
|
||||
// LLM configs (GPT4ALL, Llama.cpp, ...)
|
||||
LLMConfig `yaml:",inline"`
|
||||
|
||||
// AutoGPTQ specifics
|
||||
AutoGPTQ AutoGPTQ `yaml:"autogptq"`
|
||||
|
||||
// Diffusers
|
||||
Diffusers Diffusers `yaml:"diffusers"`
|
||||
|
||||
Step int `yaml:"step"`
|
||||
}
|
||||
|
||||
type Diffusers struct {
|
||||
PipelineType string `yaml:"pipeline_type"`
|
||||
SchedulerType string `yaml:"scheduler_type"`
|
||||
CUDA bool `yaml:"cuda"`
|
||||
}
|
||||
|
||||
type LLMConfig struct {
|
||||
SystemPrompt string `yaml:"system_prompt"`
|
||||
TensorSplit string `yaml:"tensor_split"`
|
||||
MainGPU string `yaml:"main_gpu"`
|
||||
RMSNormEps float32 `yaml:"rms_norm_eps"`
|
||||
NGQA int32 `yaml:"ngqa"`
|
||||
PromptCachePath string `yaml:"prompt_cache_path"`
|
||||
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
||||
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
||||
MirostatETA float64 `yaml:"mirostat_eta"`
|
||||
MirostatTAU float64 `yaml:"mirostat_tau"`
|
||||
Mirostat int `yaml:"mirostat"`
|
||||
NGPULayers int `yaml:"gpu_layers"`
|
||||
MMap bool `yaml:"mmap"`
|
||||
MMlock bool `yaml:"mmlock"`
|
||||
LowVRAM bool `yaml:"low_vram"`
|
||||
Grammar string `yaml:"grammar"`
|
||||
StopWords []string `yaml:"stopwords"`
|
||||
Cutstrings []string `yaml:"cutstrings"`
|
||||
TrimSpace []string `yaml:"trimspace"`
|
||||
ContextSize int `yaml:"context_size"`
|
||||
NUMA bool `yaml:"numa"`
|
||||
}
|
||||
|
||||
type AutoGPTQ struct {
|
||||
ModelBaseName string `yaml:"model_base_name"`
|
||||
Device string `yaml:"device"`
|
||||
Triton bool `yaml:"triton"`
|
||||
UseFastTokenizer bool `yaml:"use_fast_tokenizer"`
|
||||
}
|
||||
|
||||
type Functions struct {
|
||||
DisableNoAction bool `yaml:"disable_no_action"`
|
||||
NoActionFunctionName string `yaml:"no_action_function_name"`
|
||||
NoActionDescriptionName string `yaml:"no_action_description_name"`
|
||||
}
|
||||
|
||||
type TemplateConfig struct {
|
||||
Chat string `yaml:"chat"`
|
||||
ChatMessage string `yaml:"chat_message"`
|
||||
Completion string `yaml:"completion"`
|
||||
Edit string `yaml:"edit"`
|
||||
Functions string `yaml:"function"`
|
||||
}
|
||||
|
||||
type ConfigLoader struct {
|
||||
configs map[string]Config
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *Config) SetFunctionCallString(s string) {
|
||||
c.functionCallString = s
|
||||
}
|
||||
|
||||
func (c *Config) SetFunctionCallNameString(s string) {
|
||||
c.functionCallNameString = s
|
||||
}
|
||||
|
||||
func (c *Config) ShouldUseFunctions() bool {
|
||||
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
|
||||
}
|
||||
|
||||
func (c *Config) ShouldCallSpecificFunction() bool {
|
||||
return len(c.functionCallNameString) > 0
|
||||
}
|
||||
|
||||
func (c *Config) FunctionToCall() string {
|
||||
return c.functionCallNameString
|
||||
}
|
||||
|
||||
func defaultPredictOptions(modelFile string) PredictionOptions {
|
||||
return PredictionOptions{
|
||||
TopP: 0.7,
|
||||
TopK: 80,
|
||||
Maxtokens: 512,
|
||||
Temperature: 0.9,
|
||||
Model: modelFile,
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultConfig(modelFile string) *Config {
|
||||
return &Config{
|
||||
PredictionOptions: defaultPredictOptions(modelFile),
|
||||
}
|
||||
}
|
||||
|
||||
func NewConfigLoader() *ConfigLoader {
|
||||
return &ConfigLoader{
|
||||
configs: make(map[string]Config),
|
||||
}
|
||||
}
|
||||
func ReadConfigFile(file string) ([]*Config, error) {
|
||||
c := &[]*Config{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||
}
|
||||
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
func ReadConfig(file string) (*Config, error) {
|
||||
c := &Config{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
if err := yaml.Unmarshal(f, c); err != nil {
|
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) LoadConfigFile(file string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
c, err := ReadConfigFile(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot load config file: %w", err)
|
||||
}
|
||||
|
||||
for _, cc := range c {
|
||||
cm.configs[cc.Name] = *cc
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) LoadConfig(file string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
c, err := ReadConfig(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
|
||||
cm.configs[c.Name] = *c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
v, exists := cm.configs[m]
|
||||
return v, exists
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) GetAllConfigs() []Config {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
var res []Config
|
||||
for _, v := range cm.configs {
|
||||
res = append(res, v)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) ListConfigs() []string {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
var res []string
|
||||
for k := range cm.configs {
|
||||
res = append(res, k)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) LoadConfigs(path string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files := make([]fs.FileInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files = append(files, info)
|
||||
}
|
||||
for _, file := range files {
|
||||
// Skip templates, YAML and .keep files
|
||||
if !strings.Contains(file.Name(), ".yaml") {
|
||||
continue
|
||||
}
|
||||
c, err := ReadConfig(filepath.Join(path, file.Name()))
|
||||
if err == nil {
|
||||
cm.configs[c.Name] = *c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package api
|
||||
package api_config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -26,29 +28,29 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||
})
|
||||
|
||||
It("Test LoadConfigs", func() {
|
||||
cm := NewConfigMerger()
|
||||
options := newOptions()
|
||||
cm := NewConfigLoader()
|
||||
opts := options.NewOptions()
|
||||
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||
WithModelLoader(modelLoader)(options)
|
||||
options.WithModelLoader(modelLoader)(opts)
|
||||
|
||||
err := cm.LoadConfigs(options.loader.ModelPath)
|
||||
err := cm.LoadConfigs(opts.Loader.ModelPath)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(cm.configs).ToNot(BeNil())
|
||||
Expect(cm.ListConfigs()).ToNot(BeNil())
|
||||
|
||||
// config should includes gpt4all models's api.config
|
||||
Expect(cm.configs).To(HaveKey("gpt4all"))
|
||||
Expect(cm.ListConfigs()).To(ContainElements("gpt4all"))
|
||||
|
||||
// config should includes gpt2 models's api.config
|
||||
Expect(cm.configs).To(HaveKey("gpt4all-2"))
|
||||
Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2"))
|
||||
|
||||
// config should includes text-embedding-ada-002 models's api.config
|
||||
Expect(cm.configs).To(HaveKey("text-embedding-ada-002"))
|
||||
Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002"))
|
||||
|
||||
// config should includes rwkv_test models's api.config
|
||||
Expect(cm.configs).To(HaveKey("rwkv_test"))
|
||||
Expect(cm.ListConfigs()).To(ContainElements("rwkv_test"))
|
||||
|
||||
// config should includes whisper-1 models's api.config
|
||||
Expect(cm.configs).To(HaveKey("whisper-1"))
|
||||
Expect(cm.ListConfigs()).To(ContainElements("whisper-1"))
|
||||
})
|
||||
})
|
||||
})
|
||||
44
api/config/prediction.go
Normal file
44
api/config/prediction.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package api_config
|
||||
|
||||
type PredictionOptions struct {
|
||||
|
||||
// Also part of the OpenAI official spec
|
||||
Model string `json:"model" yaml:"model"`
|
||||
|
||||
// Also part of the OpenAI official spec
|
||||
Language string `json:"language"`
|
||||
|
||||
// Also part of the OpenAI official spec. use it for returning multiple results
|
||||
N int `json:"n"`
|
||||
|
||||
// Common options between all the API calls, part of the OpenAI spec
|
||||
TopP float64 `json:"top_p" yaml:"top_p"`
|
||||
TopK int `json:"top_k" yaml:"top_k"`
|
||||
Temperature float64 `json:"temperature" yaml:"temperature"`
|
||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
|
||||
Echo bool `json:"echo"`
|
||||
|
||||
// Custom parameters - not present in the OpenAI API
|
||||
Batch int `json:"batch" yaml:"batch"`
|
||||
F16 bool `json:"f16" yaml:"f16"`
|
||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
|
||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
|
||||
Keep int `json:"n_keep" yaml:"n_keep"`
|
||||
|
||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
|
||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
|
||||
Mirostat int `json:"mirostat" yaml:"mirostat"`
|
||||
|
||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
|
||||
TFZ float64 `json:"tfz" yaml:"tfz"`
|
||||
|
||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
||||
Seed int `json:"seed" yaml:"seed"`
|
||||
|
||||
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
|
||||
RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"`
|
||||
RopeFreqScale float32 `json:"rope_freq_scale" yaml:"rope_freq_scale"`
|
||||
NegativePromptScale float32 `json:"negative_prompt_scale" yaml:"negative_prompt_scale"`
|
||||
// AutoGPTQ
|
||||
UseFastTokenizer bool `json:"use_fast_tokenizer" yaml:"use_fast_tokenizer"`
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/tts"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type TTSRequest struct {
|
||||
Model string `json:"model" yaml:"model"`
|
||||
Input string `json:"input" yaml:"input"`
|
||||
}
|
||||
|
||||
func generateUniqueFileName(dir, baseName, ext string) string {
|
||||
counter := 1
|
||||
fileName := baseName + ext
|
||||
|
||||
for {
|
||||
filePath := filepath.Join(dir, fileName)
|
||||
_, err := os.Stat(filePath)
|
||||
if os.IsNotExist(err) {
|
||||
return fileName
|
||||
}
|
||||
|
||||
counter++
|
||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
||||
}
|
||||
}
|
||||
|
||||
func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(TTSRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
piperModel, err := o.loader.BackendLoader(model.PiperBackend, input.Model, []llama.ModelOption{}, uint32(0), o.assetsDestination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if piperModel == nil {
|
||||
return fmt.Errorf("could not load piper model")
|
||||
}
|
||||
|
||||
w, ok := piperModel.(*tts.Piper)
|
||||
if !ok {
|
||||
return fmt.Errorf("loader returned non-piper object %+v", w)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(o.audioDir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fileName := generateUniqueFileName(o.audioDir, "piper", ".wav")
|
||||
filePath := filepath.Join(o.audioDir, fileName)
|
||||
|
||||
modelPath := filepath.Join(o.loader.ModelPath, input.Model)
|
||||
|
||||
if err := utils.VerifyPath(modelPath, o.loader.ModelPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.TTS(input.Input, modelPath, filePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Download(filePath)
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,19 @@
|
||||
package api
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/json-iterator/go"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -38,7 +42,7 @@ type galleryApplier struct {
|
||||
statuses map[string]*galleryOpStatus
|
||||
}
|
||||
|
||||
func newGalleryApplier(modelPath string) *galleryApplier {
|
||||
func NewGalleryService(modelPath string) *galleryApplier {
|
||||
return &galleryApplier{
|
||||
modelPath: modelPath,
|
||||
C: make(chan galleryOp),
|
||||
@@ -47,7 +51,7 @@ func newGalleryApplier(modelPath string) *galleryApplier {
|
||||
}
|
||||
|
||||
// prepareModel applies a
|
||||
func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error {
|
||||
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
|
||||
|
||||
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
||||
if err != nil {
|
||||
@@ -72,13 +76,15 @@ func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
|
||||
return g.statuses[s]
|
||||
}
|
||||
|
||||
func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
|
||||
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
case op := <-g.C:
|
||||
utils.ResetDownloadTimers()
|
||||
|
||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
||||
|
||||
// updates the status with an error
|
||||
@@ -89,13 +95,17 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
|
||||
// displayDownload displays the download progress
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||
displayDownload(fileName, current, total, percentage)
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
||||
}
|
||||
|
||||
var err error
|
||||
// if the request contains a gallery name, we apply the gallery from the gallery list
|
||||
if op.galleryName != "" {
|
||||
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||
if strings.Contains(op.galleryName, "@") {
|
||||
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||
}
|
||||
} else {
|
||||
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
|
||||
}
|
||||
@@ -118,63 +128,57 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
|
||||
}()
|
||||
}
|
||||
|
||||
var lastProgress time.Time = time.Now()
|
||||
var startTime time.Time = time.Now()
|
||||
type galleryModel struct {
|
||||
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func displayDownload(fileName string, current string, total string, percentage float64) {
|
||||
currentTime := time.Now()
|
||||
|
||||
if currentTime.Sub(lastProgress) >= 5*time.Second {
|
||||
|
||||
lastProgress = currentTime
|
||||
|
||||
// calculate ETA based on percentage and elapsed time
|
||||
var eta time.Duration
|
||||
if percentage > 0 {
|
||||
elapsed := currentTime.Sub(startTime)
|
||||
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
|
||||
}
|
||||
|
||||
if total != "" {
|
||||
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
|
||||
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
|
||||
var err error
|
||||
for _, r := range requests {
|
||||
utils.ResetDownloadTimers()
|
||||
if r.ID == "" {
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
log.Debug().Msgf("Downloading: %s", current)
|
||||
if strings.Contains(r.ID, "@") {
|
||||
err = gallery.InstallModelFromGallery(
|
||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGalleryByName(
|
||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type galleryModel struct {
|
||||
gallery.GalleryModel
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error {
|
||||
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
||||
dat, err := os.ReadFile(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ApplyGalleryFromString(modelPath, string(dat), cm, galleries)
|
||||
var requests []galleryModel
|
||||
|
||||
if err := yaml.Unmarshal(dat, &requests); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return processRequests(modelPath, s, cm, galleries, requests)
|
||||
}
|
||||
|
||||
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error {
|
||||
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
||||
var requests []galleryModel
|
||||
err := json.Unmarshal([]byte(s), &requests)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, r := range requests {
|
||||
if r.ID == "" {
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
return processRequests(modelPath, s, cm, galleries, requests)
|
||||
}
|
||||
|
||||
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
|
||||
/// Endpoints
|
||||
|
||||
func GetOpStatusEndpoint(g *galleryApplier) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
status := g.getStatus(c.Params("uuid"))
|
||||
@@ -191,7 +195,7 @@ type GalleryModel struct {
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error {
|
||||
func ApplyModelGalleryEndpoint(modelPath string, cm *config.ConfigLoader, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(GalleryModel)
|
||||
// Get input data from the request body
|
||||
@@ -216,7 +220,7 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, gal
|
||||
}
|
||||
}
|
||||
|
||||
func listModelFromGallery(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error {
|
||||
func ListModelFromGalleryEndpoint(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
log.Debug().Msgf("Listing models from galleries: %+v", galleries)
|
||||
|
||||
32
api/localai/localai.go
Normal file
32
api/localai/localai.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type TTSRequest struct {
|
||||
Model string `json:"model" yaml:"model"`
|
||||
Input string `json:"input" yaml:"input"`
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
}
|
||||
|
||||
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(TTSRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, o.Loader, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Download(filePath)
|
||||
}
|
||||
}
|
||||
772
api/openai.go
772
api/openai.go
@@ -1,772 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// APIError provides error information returned by the OpenAI API.
|
||||
type APIError struct {
|
||||
Code any `json:"code,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Param *string `json:"param,omitempty"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error *APIError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object,omitempty"`
|
||||
|
||||
// Images
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIResponse struct {
|
||||
Created int `json:"created,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Choices []Choice `json:"choices,omitempty"`
|
||||
Data []Item `json:"data,omitempty"`
|
||||
|
||||
Usage OpenAIUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Delta *Message `json:"delta,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role,omitempty" yaml:"role"`
|
||||
Content string `json:"content,omitempty" yaml:"content"`
|
||||
}
|
||||
|
||||
type OpenAIModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
type OpenAIRequest struct {
|
||||
Model string `json:"model" yaml:"model"`
|
||||
|
||||
// whisper
|
||||
File string `json:"file" validate:"required"`
|
||||
Language string `json:"language"`
|
||||
//whisper/image
|
||||
ResponseFormat string `json:"response_format"`
|
||||
// image
|
||||
Size string `json:"size"`
|
||||
// Prompt is read only by completion/image API calls
|
||||
Prompt interface{} `json:"prompt" yaml:"prompt"`
|
||||
|
||||
// Edit endpoint
|
||||
Instruction string `json:"instruction" yaml:"instruction"`
|
||||
Input interface{} `json:"input" yaml:"input"`
|
||||
|
||||
Stop interface{} `json:"stop" yaml:"stop"`
|
||||
|
||||
// Messages is read only by chat/completion API calls
|
||||
Messages []Message `json:"messages" yaml:"messages"`
|
||||
|
||||
Stream bool `json:"stream"`
|
||||
Echo bool `json:"echo"`
|
||||
// Common options between all the API calls
|
||||
TopP float64 `json:"top_p" yaml:"top_p"`
|
||||
TopK int `json:"top_k" yaml:"top_k"`
|
||||
Temperature float64 `json:"temperature" yaml:"temperature"`
|
||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
|
||||
|
||||
N int `json:"n"`
|
||||
|
||||
// Custom parameters - not present in the OpenAI API
|
||||
Batch int `json:"batch" yaml:"batch"`
|
||||
F16 bool `json:"f16" yaml:"f16"`
|
||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
|
||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
|
||||
Keep int `json:"n_keep" yaml:"n_keep"`
|
||||
|
||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
|
||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
|
||||
Mirostat int `json:"mirostat" yaml:"mirostat"`
|
||||
|
||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
|
||||
TFZ float64 `json:"tfz" yaml:"tfz"`
|
||||
|
||||
Seed int `json:"seed" yaml:"seed"`
|
||||
|
||||
// Image (not supported by OpenAI)
|
||||
Mode int `json:"mode"`
|
||||
Step int `json:"step"`
|
||||
|
||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
||||
}
|
||||
|
||||
func defaultRequest(modelFile string) OpenAIRequest {
|
||||
return OpenAIRequest{
|
||||
TopP: 0.7,
|
||||
TopK: 80,
|
||||
Maxtokens: 512,
|
||||
Temperature: 0.9,
|
||||
Model: modelFile,
|
||||
}
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
||||
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
|
||||
resp := OpenAIResponse{
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{
|
||||
{
|
||||
Index: 0,
|
||||
Text: s,
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
log.Debug().Msgf("Sending goroutine: %s", s)
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readInput(c, o.loader, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("`input`: %+v", input)
|
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
if input.Stream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
templateFile := config.Model
|
||||
|
||||
if config.TemplateConfig.Completion != "" {
|
||||
templateFile = config.TemplateConfig.Completion
|
||||
}
|
||||
|
||||
if input.Stream {
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when `Stream`ing")
|
||||
}
|
||||
|
||||
predInput := config.PromptStrings[0]
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
||||
Input string
|
||||
}{Input: predInput})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
}
|
||||
|
||||
responses := make(chan OpenAIResponse)
|
||||
|
||||
go process(predInput, input, config, o.loader, responses)
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
||||
for ev := range responses {
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.Encode(ev)
|
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||
fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []Choice
|
||||
for _, i := range config.PromptStrings {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
||||
Input string
|
||||
}{Input: i})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
|
||||
r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) {
|
||||
*c = append(*c, Choice{Text: s})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "text_completion",
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readInput(c, o.loader, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
items := []Item{}
|
||||
|
||||
for i, s := range config.InputToken {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := ModelEmbedding("", s, o.loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||
}
|
||||
|
||||
for i, s := range config.InputStrings {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Data: items,
|
||||
Object: "list",
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||
|
||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
||||
initialMessage := OpenAIResponse{
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Role: "assistant"}}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
|
||||
resp := OpenAIResponse{
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Content: s}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
log.Debug().Msgf("Sending goroutine: %s", s)
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readInput(c, o.loader, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
var predInput string
|
||||
|
||||
mess := []string{}
|
||||
for _, i := range input.Messages {
|
||||
var content string
|
||||
r := config.Roles[i.Role]
|
||||
if r != "" {
|
||||
content = fmt.Sprint(r, " ", i.Content)
|
||||
} else {
|
||||
content = i.Content
|
||||
}
|
||||
|
||||
mess = append(mess, content)
|
||||
}
|
||||
|
||||
predInput = strings.Join(mess, "\n")
|
||||
|
||||
if input.Stream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
templateFile := config.Model
|
||||
|
||||
if config.TemplateConfig.Chat != "" {
|
||||
templateFile = config.TemplateConfig.Chat
|
||||
}
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
||||
Input string
|
||||
}{Input: predInput})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
}
|
||||
|
||||
if input.Stream {
|
||||
responses := make(chan OpenAIResponse)
|
||||
|
||||
go process(predInput, input, config, o.loader, responses)
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
||||
for ev := range responses {
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.Encode(ev)
|
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||
fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{
|
||||
{
|
||||
FinishReason: "stop",
|
||||
Index: 0,
|
||||
Delta: &Message{},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) {
|
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "chat.completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", respData)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readInput(c, o.loader, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
templateFile := config.Model
|
||||
|
||||
if config.TemplateConfig.Edit != "" {
|
||||
templateFile = config.TemplateConfig.Edit
|
||||
}
|
||||
|
||||
var result []Choice
|
||||
for _, i := range config.InputStrings {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
||||
Input string
|
||||
Instruction string
|
||||
}{Input: i})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
|
||||
r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) {
|
||||
*c = append(*c, Choice{Text: s})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "edit",
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
|
||||
/*
|
||||
*
|
||||
|
||||
curl http://localhost:8080/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"prompt": "A cute baby sea otter",
|
||||
"n": 1,
|
||||
"size": "512x512"
|
||||
}'
|
||||
|
||||
*
|
||||
*/
|
||||
func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readInput(c, o.loader, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
if m == "" {
|
||||
m = model.StableDiffusionBackend
|
||||
}
|
||||
log.Debug().Msgf("Loading model: %+v", m)
|
||||
|
||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
// XXX: Only stablediffusion is supported for now
|
||||
if config.Backend == "" {
|
||||
config.Backend = model.StableDiffusionBackend
|
||||
}
|
||||
|
||||
sizeParts := strings.Split(input.Size, "x")
|
||||
if len(sizeParts) != 2 {
|
||||
return fmt.Errorf("Invalid value for 'size'")
|
||||
}
|
||||
width, err := strconv.Atoi(sizeParts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid value for 'size'")
|
||||
}
|
||||
height, err := strconv.Atoi(sizeParts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid value for 'size'")
|
||||
}
|
||||
|
||||
b64JSON := false
|
||||
if input.ResponseFormat == "b64_json" {
|
||||
b64JSON = true
|
||||
}
|
||||
|
||||
var result []Item
|
||||
for _, i := range config.PromptStrings {
|
||||
n := input.N
|
||||
if input.N == 0 {
|
||||
n = 1
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
prompts := strings.Split(i, "|")
|
||||
positive_prompt := prompts[0]
|
||||
negative_prompt := ""
|
||||
if len(prompts) > 1 {
|
||||
negative_prompt = prompts[1]
|
||||
}
|
||||
|
||||
mode := 0
|
||||
step := 15
|
||||
|
||||
if input.Mode != 0 {
|
||||
mode = input.Mode
|
||||
}
|
||||
|
||||
if input.Step != 0 {
|
||||
step = input.Step
|
||||
}
|
||||
|
||||
tempDir := ""
|
||||
if !b64JSON {
|
||||
tempDir = o.imageDir
|
||||
}
|
||||
// Create a temporary file
|
||||
outputFile, err := ioutil.TempFile(tempDir, "b64")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
outputFile.Close()
|
||||
output := outputFile.Name() + ".png"
|
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
baseURL := c.BaseURL()
|
||||
|
||||
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fn(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
item := &Item{}
|
||||
|
||||
if b64JSON {
|
||||
defer os.RemoveAll(output)
|
||||
data, err := os.ReadFile(output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
} else {
|
||||
base := filepath.Base(output)
|
||||
item.URL = baseURL + "/generated-images/" + base
|
||||
}
|
||||
|
||||
result = append(result, *item)
|
||||
}
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Data: result,
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readInput(c, o.loader, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
dst := filepath.Join(dir, path.Base(file.Filename))
|
||||
dstFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil {
|
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||
|
||||
whisperModel, err := o.loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads), o.assetsDestination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if whisperModel == nil {
|
||||
return fmt.Errorf("could not load whisper model")
|
||||
}
|
||||
|
||||
w, ok := whisperModel.(whisper.Model)
|
||||
if !ok {
|
||||
return fmt.Errorf("loader returned non-whisper object")
|
||||
}
|
||||
|
||||
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr)
|
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr})
|
||||
}
|
||||
}
|
||||
|
||||
func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
models, err := loader.ListModels()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var mm map[string]interface{} = map[string]interface{}{}
|
||||
|
||||
dataModels := []OpenAIModel{}
|
||||
for _, m := range models {
|
||||
mm[m] = nil
|
||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
|
||||
}
|
||||
|
||||
for _, k := range cm.ListConfigs() {
|
||||
if _, exists := mm[k]; !exists {
|
||||
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"})
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIModel `json:"data"`
|
||||
}{
|
||||
Object: "list",
|
||||
Data: dataModels,
|
||||
})
|
||||
}
|
||||
}
|
||||
115
api/openai/api.go
Normal file
115
api/openai/api.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
)
|
||||
|
||||
// APIError provides error information returned by the OpenAI API.
|
||||
type APIError struct {
|
||||
Code any `json:"code,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Param *string `json:"param,omitempty"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error *APIError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object,omitempty"`
|
||||
|
||||
// Images
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIResponse struct {
|
||||
Created int `json:"created,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Choices []Choice `json:"choices,omitempty"`
|
||||
Data []Item `json:"data,omitempty"`
|
||||
|
||||
Usage OpenAIUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Delta *Message `json:"delta,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
// The message role
|
||||
Role string `json:"role,omitempty" yaml:"role"`
|
||||
// The message content
|
||||
Content *string `json:"content" yaml:"content"`
|
||||
// A result of a function call
|
||||
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
type OpenAIRequest struct {
|
||||
config.PredictionOptions
|
||||
|
||||
Context context.Context
|
||||
Cancel context.CancelFunc
|
||||
|
||||
// whisper
|
||||
File string `json:"file" validate:"required"`
|
||||
//whisper/image
|
||||
ResponseFormat string `json:"response_format"`
|
||||
// image
|
||||
Size string `json:"size"`
|
||||
// Prompt is read only by completion/image API calls
|
||||
Prompt interface{} `json:"prompt" yaml:"prompt"`
|
||||
|
||||
// Edit endpoint
|
||||
Instruction string `json:"instruction" yaml:"instruction"`
|
||||
Input interface{} `json:"input" yaml:"input"`
|
||||
|
||||
Stop interface{} `json:"stop" yaml:"stop"`
|
||||
|
||||
// Messages is read only by chat/completion API calls
|
||||
Messages []Message `json:"messages" yaml:"messages"`
|
||||
|
||||
// A list of available functions to call
|
||||
Functions []grammar.Function `json:"functions" yaml:"functions"`
|
||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||
|
||||
Stream bool `json:"stream"`
|
||||
|
||||
// Image (not supported by OpenAI)
|
||||
Mode int `json:"mode"`
|
||||
Step int `json:"step"`
|
||||
|
||||
// A grammar to constrain the LLM output
|
||||
Grammar string `json:"grammar" yaml:"grammar"`
|
||||
|
||||
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
|
||||
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
|
||||
// AutoGPTQ
|
||||
ModelBaseName string `json:"model_base_name" yaml:"model_base_name"`
|
||||
}
|
||||
359
api/openai/chat.go
Normal file
359
api/openai/chat.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
emptyMessage := ""
|
||||
|
||||
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
||||
initialMessage := OpenAIResponse{
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Role: "assistant", Content: &emptyMessage}}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
|
||||
resp := OpenAIResponse{
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
return func(c *fiber.Ctx) error {
|
||||
processFunctions := false
|
||||
funcs := grammar.Functions{}
|
||||
modelFile, input, err := readInput(c, o, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
log.Debug().Msgf("Configuration read: %+v", config)
|
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model
|
||||
noActionName := "answer"
|
||||
noActionDescription := "use this action to answer without performing any action"
|
||||
|
||||
if config.FunctionsConfig.NoActionFunctionName != "" {
|
||||
noActionName = config.FunctionsConfig.NoActionFunctionName
|
||||
}
|
||||
if config.FunctionsConfig.NoActionDescriptionName != "" {
|
||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
||||
}
|
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
|
||||
log.Debug().Msgf("Response needs to process functions")
|
||||
|
||||
processFunctions = true
|
||||
|
||||
noActionGrammar := grammar.Function{
|
||||
Name: noActionName,
|
||||
Description: noActionDescription,
|
||||
Parameters: map[string]interface{}{
|
||||
"properties": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The message to reply the user with",
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, input.Functions...)
|
||||
if !config.FunctionsConfig.DisableNoAction {
|
||||
funcs = append(funcs, noActionGrammar)
|
||||
}
|
||||
|
||||
// Force picking one of the functions by the request
|
||||
if config.FunctionToCall() != "" {
|
||||
funcs = funcs.Select(config.FunctionToCall())
|
||||
}
|
||||
|
||||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure()
|
||||
config.Grammar = jsStruct.Grammar("")
|
||||
} else if input.JSONFunctionGrammarObject != nil {
|
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
|
||||
}
|
||||
|
||||
// functions are not supported in stream mode (yet?)
|
||||
toStream := input.Stream && !processFunctions
|
||||
|
||||
log.Debug().Msgf("Parameters: %+v", config)
|
||||
|
||||
var predInput string
|
||||
|
||||
suppressConfigSystemPrompt := false
|
||||
mess := []string{}
|
||||
for messageIndex, i := range input.Messages {
|
||||
var content string
|
||||
role := i.Role
|
||||
|
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if i.FunctionCall != nil && i.Role == "assistant" {
|
||||
roleFn := "assistant_function_call"
|
||||
r := config.Roles[roleFn]
|
||||
if r != "" {
|
||||
role = roleFn
|
||||
}
|
||||
}
|
||||
r := config.Roles[role]
|
||||
contentExists := i.Content != nil && *i.Content != ""
|
||||
// First attempt to populate content via a chat message specific template
|
||||
if config.TemplateConfig.ChatMessage != "" {
|
||||
chatMessageData := model.ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: r,
|
||||
RoleName: role,
|
||||
Content: *i.Content,
|
||||
MessageIndex: messageIndex,
|
||||
}
|
||||
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
if err != nil {
|
||||
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
||||
} else {
|
||||
if templatedChatMessage == "" {
|
||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||
}
|
||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||
content = templatedChatMessage
|
||||
}
|
||||
}
|
||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
||||
if content == "" {
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, " ", *i.Content)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||
} else {
|
||||
content = fmt.Sprint(r, " ", string(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(*i.Content)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + string(j)
|
||||
} else {
|
||||
content = string(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
||||
if contentExists && role == "system" {
|
||||
suppressConfigSystemPrompt = true
|
||||
}
|
||||
}
|
||||
|
||||
mess = append(mess, content)
|
||||
}
|
||||
|
||||
predInput = strings.Join(mess, "\n")
|
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||||
|
||||
if toStream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
templateFile := config.Model
|
||||
|
||||
if config.TemplateConfig.Chat != "" && !processFunctions {
|
||||
templateFile = config.TemplateConfig.Chat
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Functions != "" && processFunctions {
|
||||
templateFile = config.TemplateConfig.Functions
|
||||
}
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||
Input: predInput,
|
||||
Functions: funcs,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
} else {
|
||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||
if processFunctions {
|
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||
}
|
||||
|
||||
if toStream {
|
||||
responses := make(chan OpenAIResponse)
|
||||
|
||||
go process(predInput, input, config, o.Loader, responses)
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
||||
for ev := range responses {
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.Encode(ev)
|
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Sending chunk failed: %v", err)
|
||||
input.Cancel()
|
||||
break
|
||||
}
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{
|
||||
{
|
||||
FinishReason: "stop",
|
||||
Index: 0,
|
||||
Delta: &Message{Content: &emptyMessage},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
result, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) {
|
||||
if processFunctions {
|
||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||
ss := map[string]interface{}{}
|
||||
// This prevent newlines to break JSON parsing for clients
|
||||
s = utils.EscapeNewLines(s)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name := ss["function"]
|
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
d, _ := json.Marshal(args)
|
||||
|
||||
ss["arguments"] = string(d)
|
||||
ss["name"] = func_name
|
||||
|
||||
// if do nothing, reply with a message
|
||||
if func_name == noActionName {
|
||||
log.Debug().Msgf("nothing to do, computing a reply")
|
||||
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(d), &arguments)
|
||||
m, exists := arguments["message"]
|
||||
if exists {
|
||||
switch message := m.(type) {
|
||||
case string:
|
||||
if message != "" {
|
||||
log.Debug().Msgf("Reply received from LLM: %s", message)
|
||||
message = backend.Finetune(*config, predInput, message)
|
||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
|
||||
|
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
|
||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||
// Note: This costs (in term of CPU) another computation
|
||||
config.Grammar = ""
|
||||
predFunc, err := backend.ModelInference(input.Context, predInput, o.Loader, *config, o, nil)
|
||||
if err != nil {
|
||||
log.Error().Msgf("inference error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
log.Error().Msgf("inference error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
prediction = backend.Finetune(*config, predInput, prediction)
|
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}})
|
||||
} else {
|
||||
// otherwise reply with the function call
|
||||
*c = append(*c, Choice{
|
||||
FinishReason: "function_call",
|
||||
Message: &Message{Role: "assistant", FunctionCall: ss},
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
*c = append(*c, Choice{FinishReason: "stop", Index: 0, Message: &Message{Role: "assistant", Content: &s}})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "chat.completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", respData)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
156
api/openai/completion.go
Normal file
156
api/openai/completion.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
||||
ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
|
||||
resp := OpenAIResponse{
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{
|
||||
{
|
||||
Index: 0,
|
||||
Text: s,
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
log.Debug().Msgf("Sending goroutine: %s", s)
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readInput(c, o, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("`input`: %+v", input)
|
||||
|
||||
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
if input.Stream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
templateFile := config.Model
|
||||
|
||||
if config.TemplateConfig.Completion != "" {
|
||||
templateFile = config.TemplateConfig.Completion
|
||||
}
|
||||
|
||||
if input.Stream {
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
}
|
||||
|
||||
predInput := config.PromptStrings[0]
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: predInput,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
}
|
||||
|
||||
responses := make(chan OpenAIResponse)
|
||||
|
||||
go process(predInput, input, config, o.Loader, responses)
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
||||
for ev := range responses {
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.Encode(ev)
|
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||
fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []Choice
|
||||
for k, i := range config.PromptStrings {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Input: i,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
|
||||
r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
|
||||
*c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "text_completion",
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
69
api/openai/edit.go
Normal file
69
api/openai/edit.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readInput(c, o, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
templateFile := config.Model
|
||||
|
||||
if config.TemplateConfig.Edit != "" {
|
||||
templateFile = config.TemplateConfig.Edit
|
||||
}
|
||||
|
||||
var result []Choice
|
||||
for _, i := range config.InputStrings {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: i,
|
||||
Instruction: input.Instruction,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
|
||||
r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
|
||||
*c = append(*c, Choice{Text: s})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "edit",
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
70
api/openai/embeddings.go
Normal file
70
api/openai/embeddings.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readInput(c, o, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
items := []Item{}
|
||||
|
||||
for i, s := range config.InputToken {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||
}
|
||||
|
||||
for i, s := range config.InputStrings {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Data: items,
|
||||
Object: "list",
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
160
api/openai/image.go
Normal file
160
api/openai/image.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
|
||||
/*
|
||||
*
|
||||
|
||||
curl http://localhost:8080/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"prompt": "A cute baby sea otter",
|
||||
"n": 1,
|
||||
"size": "512x512"
|
||||
}'
|
||||
|
||||
*
|
||||
*/
|
||||
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readInput(c, o, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
if m == "" {
|
||||
m = model.StableDiffusionBackend
|
||||
}
|
||||
log.Debug().Msgf("Loading model: %+v", m)
|
||||
|
||||
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
// XXX: Only stablediffusion is supported for now
|
||||
if config.Backend == "" {
|
||||
config.Backend = model.StableDiffusionBackend
|
||||
}
|
||||
|
||||
sizeParts := strings.Split(input.Size, "x")
|
||||
if len(sizeParts) != 2 {
|
||||
return fmt.Errorf("Invalid value for 'size'")
|
||||
}
|
||||
width, err := strconv.Atoi(sizeParts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid value for 'size'")
|
||||
}
|
||||
height, err := strconv.Atoi(sizeParts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid value for 'size'")
|
||||
}
|
||||
|
||||
b64JSON := false
|
||||
if input.ResponseFormat == "b64_json" {
|
||||
b64JSON = true
|
||||
}
|
||||
|
||||
var result []Item
|
||||
for _, i := range config.PromptStrings {
|
||||
n := input.N
|
||||
if input.N == 0 {
|
||||
n = 1
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
prompts := strings.Split(i, "|")
|
||||
positive_prompt := prompts[0]
|
||||
negative_prompt := ""
|
||||
if len(prompts) > 1 {
|
||||
negative_prompt = prompts[1]
|
||||
}
|
||||
|
||||
mode := 0
|
||||
step := config.Step
|
||||
if step == 0 {
|
||||
step = 15
|
||||
}
|
||||
|
||||
if input.Mode != 0 {
|
||||
mode = input.Mode
|
||||
}
|
||||
|
||||
if input.Step != 0 {
|
||||
step = input.Step
|
||||
}
|
||||
|
||||
tempDir := ""
|
||||
if !b64JSON {
|
||||
tempDir = o.ImageDir
|
||||
}
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(tempDir, "b64")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
outputFile.Close()
|
||||
output := outputFile.Name() + ".png"
|
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
baseURL := c.BaseURL()
|
||||
|
||||
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.Loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fn(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
item := &Item{}
|
||||
|
||||
if b64JSON {
|
||||
defer os.RemoveAll(output)
|
||||
data, err := os.ReadFile(output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
} else {
|
||||
base := filepath.Base(output)
|
||||
item.URL = baseURL + "/generated-images/" + base
|
||||
}
|
||||
|
||||
result = append(result, *item)
|
||||
}
|
||||
}
|
||||
|
||||
resp := &OpenAIResponse{
|
||||
Data: result,
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
37
api/openai/inference.go
Normal file
37
api/openai/inference.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
|
||||
n := req.N
|
||||
result := []Choice{}
|
||||
|
||||
if n == 0 {
|
||||
n = 1
|
||||
}
|
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
prediction = backend.Finetune(*config, predInput, prediction)
|
||||
cb(prediction, &result)
|
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
68
api/openai/list.go
Normal file
68
api/openai/list.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
models, err := loader.ListModels()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var mm map[string]interface{} = map[string]interface{}{}
|
||||
|
||||
dataModels := []OpenAIModel{}
|
||||
|
||||
var filterFn func(name string) bool
|
||||
filter := c.Query("filter")
|
||||
|
||||
// If filter is not specified, do not filter the list by model name
|
||||
if filter == "" {
|
||||
filterFn = func(_ string) bool { return true }
|
||||
} else {
|
||||
// If filter _IS_ specified, we compile it to a regex which is used to create the filterFn
|
||||
rxp, err := regexp.Compile(filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filterFn = func(name string) bool {
|
||||
return rxp.MatchString(name)
|
||||
}
|
||||
}
|
||||
|
||||
// By default, exclude any loose files that are already referenced by a configuration file.
|
||||
excludeConfigured := c.QueryBool("excludeConfigured", true)
|
||||
|
||||
// Start with the known configurations
|
||||
for _, c := range cm.GetAllConfigs() {
|
||||
if excludeConfigured {
|
||||
mm[c.Model] = nil
|
||||
}
|
||||
|
||||
if filterFn(c.Name) {
|
||||
dataModels = append(dataModels, OpenAIModel{ID: c.Name, Object: "model"})
|
||||
}
|
||||
}
|
||||
|
||||
// Then iterate through the loose files:
|
||||
for _, m := range models {
|
||||
// And only adds them if they shouldn't be skipped.
|
||||
if _, exists := mm[m]; !exists && filterFn(m) {
|
||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIModel `json:"data"`
|
||||
}{
|
||||
Object: "list",
|
||||
Data: dataModels,
|
||||
})
|
||||
}
|
||||
}
|
||||
268
api/openai/request.go
Normal file
268
api/openai/request.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
options "github.com/go-skynet/LocalAI/api/options"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *OpenAIRequest, error) {
|
||||
loader := o.Loader
|
||||
input := new(OpenAIRequest)
|
||||
ctx, cancel := context.WithCancel(o.Context)
|
||||
input.Context = ctx
|
||||
input.Cancel = cancel
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
modelFile := input.Model
|
||||
|
||||
if c.Params("model") != "" {
|
||||
modelFile = c.Params("model")
|
||||
}
|
||||
|
||||
received, _ := json.Marshal(input)
|
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received))
|
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
|
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
|
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelFile == "" && !bearerExists && randomModel {
|
||||
models, _ := loader.ListModels()
|
||||
if len(models) > 0 {
|
||||
modelFile = models[0]
|
||||
log.Debug().Msgf("No model specified, using: %s", modelFile)
|
||||
} else {
|
||||
log.Debug().Msgf("No model specified, returning error")
|
||||
return "", nil, fmt.Errorf("no model specified")
|
||||
}
|
||||
}
|
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists {
|
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer)
|
||||
modelFile = bearer
|
||||
}
|
||||
return modelFile, input, nil
|
||||
}
|
||||
|
||||
func updateConfig(config *config.Config, input *OpenAIRequest) {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != 0 {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != 0 {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ModelBaseName != "" {
|
||||
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.UseFastTokenizer {
|
||||
config.UseFastTokenizer = input.UseFastTokenizer
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != 0 {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != 0 {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.F16 {
|
||||
config.F16 = input.F16
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != 0 {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.Mirostat != 0 {
|
||||
config.LLMConfig.Mirostat = input.Mirostat
|
||||
}
|
||||
|
||||
if input.MirostatETA != 0 {
|
||||
config.LLMConfig.MirostatETA = input.MirostatETA
|
||||
}
|
||||
|
||||
if input.MirostatTAU != 0 {
|
||||
config.LLMConfig.MirostatTAU = input.MirostatTAU
|
||||
}
|
||||
|
||||
if input.TypicalP != 0 {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []interface{}:
|
||||
tokens := []int{}
|
||||
for _, ii := range i {
|
||||
tokens = append(tokens, int(ii.(float64)))
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) {
|
||||
// Load a config file if present after the model name
|
||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
|
||||
|
||||
var cfg *config.Config
|
||||
|
||||
defaults := func() {
|
||||
cfg = config.DefaultConfig(modelFile)
|
||||
cfg.ContextSize = ctx
|
||||
cfg.Threads = threads
|
||||
cfg.F16 = f16
|
||||
cfg.Debug = debug
|
||||
}
|
||||
|
||||
cfgExisting, exists := cm.GetConfig(modelFile)
|
||||
if !exists {
|
||||
if _, err := os.Stat(modelConfig); err == nil {
|
||||
if err := cm.LoadConfig(modelConfig); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||
}
|
||||
cfgExisting, exists = cm.GetConfig(modelFile)
|
||||
if exists {
|
||||
cfg = &cfgExisting
|
||||
} else {
|
||||
defaults()
|
||||
}
|
||||
} else {
|
||||
defaults()
|
||||
}
|
||||
} else {
|
||||
cfg = &cfgExisting
|
||||
}
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateConfig(cfg, input)
|
||||
|
||||
// Don't allow 0 as setting
|
||||
if cfg.Threads == 0 {
|
||||
if threads != 0 {
|
||||
cfg.Threads = threads
|
||||
} else {
|
||||
cfg.Threads = 4
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce debug flag if passed from CLI
|
||||
if debug {
|
||||
cfg.Debug = true
|
||||
}
|
||||
|
||||
return cfg, input, nil
|
||||
}
|
||||
71
api/openai/transcription.go
Normal file
71
api/openai/transcription.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readInput(c, o, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
dst := filepath.Join(dir, path.Base(file.Filename))
|
||||
dstFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil {
|
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||
|
||||
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr)
|
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(tr)
|
||||
}
|
||||
}
|
||||
153
api/options.go
153
api/options.go
@@ -1,153 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
type Option struct {
|
||||
context context.Context
|
||||
configFile string
|
||||
loader *model.ModelLoader
|
||||
uploadLimitMB, threads, ctxSize int
|
||||
f16 bool
|
||||
debug, disableMessage bool
|
||||
imageDir string
|
||||
audioDir string
|
||||
cors bool
|
||||
preloadJSONModels string
|
||||
preloadModelsFromPath string
|
||||
corsAllowOrigins string
|
||||
|
||||
galleries []gallery.Gallery
|
||||
|
||||
backendAssets embed.FS
|
||||
assetsDestination string
|
||||
}
|
||||
|
||||
type AppOption func(*Option)
|
||||
|
||||
func newOptions(o ...AppOption) *Option {
|
||||
opt := &Option{
|
||||
context: context.Background(),
|
||||
uploadLimitMB: 15,
|
||||
threads: 1,
|
||||
ctxSize: 512,
|
||||
debug: true,
|
||||
disableMessage: true,
|
||||
}
|
||||
for _, oo := range o {
|
||||
oo(opt)
|
||||
}
|
||||
return opt
|
||||
}
|
||||
|
||||
func WithCors(b bool) AppOption {
|
||||
return func(o *Option) {
|
||||
o.cors = b
|
||||
}
|
||||
}
|
||||
|
||||
func WithCorsAllowOrigins(b string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.corsAllowOrigins = b
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendAssetsOutput(out string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.assetsDestination = out
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendAssets(f embed.FS) AppOption {
|
||||
return func(o *Option) {
|
||||
o.backendAssets = f
|
||||
}
|
||||
}
|
||||
|
||||
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
||||
return func(o *Option) {
|
||||
o.galleries = append(o.galleries, galleries...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithContext(ctx context.Context) AppOption {
|
||||
return func(o *Option) {
|
||||
o.context = ctx
|
||||
}
|
||||
}
|
||||
|
||||
func WithYAMLConfigPreload(configFile string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.preloadModelsFromPath = configFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithJSONStringPreload(configFile string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.preloadJSONModels = configFile
|
||||
}
|
||||
}
|
||||
func WithConfigFile(configFile string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.configFile = configFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithModelLoader(loader *model.ModelLoader) AppOption {
|
||||
return func(o *Option) {
|
||||
o.loader = loader
|
||||
}
|
||||
}
|
||||
|
||||
func WithUploadLimitMB(limit int) AppOption {
|
||||
return func(o *Option) {
|
||||
o.uploadLimitMB = limit
|
||||
}
|
||||
}
|
||||
|
||||
func WithThreads(threads int) AppOption {
|
||||
return func(o *Option) {
|
||||
o.threads = threads
|
||||
}
|
||||
}
|
||||
|
||||
func WithContextSize(ctxSize int) AppOption {
|
||||
return func(o *Option) {
|
||||
o.ctxSize = ctxSize
|
||||
}
|
||||
}
|
||||
|
||||
func WithF16(f16 bool) AppOption {
|
||||
return func(o *Option) {
|
||||
o.f16 = f16
|
||||
}
|
||||
}
|
||||
|
||||
func WithDebug(debug bool) AppOption {
|
||||
return func(o *Option) {
|
||||
o.debug = debug
|
||||
}
|
||||
}
|
||||
|
||||
func WithDisableMessage(disableMessage bool) AppOption {
|
||||
return func(o *Option) {
|
||||
o.disableMessage = disableMessage
|
||||
}
|
||||
}
|
||||
|
||||
func WithAudioDir(audioDir string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.audioDir = audioDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithImageDir(imageDir string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.imageDir = imageDir
|
||||
}
|
||||
}
|
||||
193
api/options/options.go
Normal file
193
api/options/options.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package options
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Option struct {
|
||||
Context context.Context
|
||||
ConfigFile string
|
||||
Loader *model.ModelLoader
|
||||
UploadLimitMB, Threads, ContextSize int
|
||||
F16 bool
|
||||
Debug, DisableMessage bool
|
||||
ImageDir string
|
||||
AudioDir string
|
||||
CORS bool
|
||||
PreloadJSONModels string
|
||||
PreloadModelsFromPath string
|
||||
CORSAllowOrigins string
|
||||
ApiKeys []string
|
||||
|
||||
Galleries []gallery.Gallery
|
||||
|
||||
BackendAssets embed.FS
|
||||
AssetsDestination string
|
||||
|
||||
ExternalGRPCBackends map[string]string
|
||||
|
||||
AutoloadGalleries bool
|
||||
}
|
||||
|
||||
type AppOption func(*Option)
|
||||
|
||||
func NewOptions(o ...AppOption) *Option {
|
||||
opt := &Option{
|
||||
Context: context.Background(),
|
||||
UploadLimitMB: 15,
|
||||
Threads: 1,
|
||||
ContextSize: 512,
|
||||
Debug: true,
|
||||
DisableMessage: true,
|
||||
}
|
||||
for _, oo := range o {
|
||||
oo(opt)
|
||||
}
|
||||
return opt
|
||||
}
|
||||
|
||||
func WithCors(b bool) AppOption {
|
||||
return func(o *Option) {
|
||||
o.CORS = b
|
||||
}
|
||||
}
|
||||
|
||||
var EnableGalleriesAutoload = func(o *Option) {
|
||||
o.AutoloadGalleries = true
|
||||
}
|
||||
|
||||
func WithExternalBackend(name string, uri string) AppOption {
|
||||
return func(o *Option) {
|
||||
if o.ExternalGRPCBackends == nil {
|
||||
o.ExternalGRPCBackends = make(map[string]string)
|
||||
}
|
||||
o.ExternalGRPCBackends[name] = uri
|
||||
}
|
||||
}
|
||||
|
||||
func WithCorsAllowOrigins(b string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.CORSAllowOrigins = b
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendAssetsOutput(out string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.AssetsDestination = out
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendAssets(f embed.FS) AppOption {
|
||||
return func(o *Option) {
|
||||
o.BackendAssets = f
|
||||
}
|
||||
}
|
||||
|
||||
func WithStringGalleries(galls string) AppOption {
|
||||
return func(o *Option) {
|
||||
if galls == "" {
|
||||
log.Debug().Msgf("no galleries to load")
|
||||
return
|
||||
}
|
||||
var galleries []gallery.Gallery
|
||||
if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
|
||||
log.Error().Msgf("failed loading galleries: %s", err.Error())
|
||||
}
|
||||
o.Galleries = append(o.Galleries, galleries...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
||||
return func(o *Option) {
|
||||
o.Galleries = append(o.Galleries, galleries...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithContext(ctx context.Context) AppOption {
|
||||
return func(o *Option) {
|
||||
o.Context = ctx
|
||||
}
|
||||
}
|
||||
|
||||
func WithYAMLConfigPreload(configFile string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.PreloadModelsFromPath = configFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithJSONStringPreload(configFile string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.PreloadJSONModels = configFile
|
||||
}
|
||||
}
|
||||
func WithConfigFile(configFile string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.ConfigFile = configFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithModelLoader(loader *model.ModelLoader) AppOption {
|
||||
return func(o *Option) {
|
||||
o.Loader = loader
|
||||
}
|
||||
}
|
||||
|
||||
func WithUploadLimitMB(limit int) AppOption {
|
||||
return func(o *Option) {
|
||||
o.UploadLimitMB = limit
|
||||
}
|
||||
}
|
||||
|
||||
func WithThreads(threads int) AppOption {
|
||||
return func(o *Option) {
|
||||
o.Threads = threads
|
||||
}
|
||||
}
|
||||
|
||||
func WithContextSize(ctxSize int) AppOption {
|
||||
return func(o *Option) {
|
||||
o.ContextSize = ctxSize
|
||||
}
|
||||
}
|
||||
|
||||
func WithF16(f16 bool) AppOption {
|
||||
return func(o *Option) {
|
||||
o.F16 = f16
|
||||
}
|
||||
}
|
||||
|
||||
func WithDebug(debug bool) AppOption {
|
||||
return func(o *Option) {
|
||||
o.Debug = debug
|
||||
}
|
||||
}
|
||||
|
||||
func WithDisableMessage(disableMessage bool) AppOption {
|
||||
return func(o *Option) {
|
||||
o.DisableMessage = disableMessage
|
||||
}
|
||||
}
|
||||
|
||||
func WithAudioDir(audioDir string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.AudioDir = audioDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithImageDir(imageDir string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.ImageDir = imageDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithApiKeys(apiKeys []string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.ApiKeys = apiKeys
|
||||
}
|
||||
}
|
||||
@@ -1,647 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/donomii/go-rwkv.cpp"
|
||||
"github.com/go-skynet/LocalAI/pkg/langchain"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
|
||||
"github.com/go-skynet/bloomz.cpp"
|
||||
bert "github.com/go-skynet/go-bert.cpp"
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
|
||||
)
|
||||
|
||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
var mutexMap sync.Mutex
|
||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
|
||||
|
||||
func defaultLLamaOpts(c Config) []llama.ModelOption {
|
||||
llamaOpts := []llama.ModelOption{}
|
||||
if c.ContextSize != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize))
|
||||
}
|
||||
if c.F16 {
|
||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
||||
}
|
||||
if c.Embeddings {
|
||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
|
||||
}
|
||||
|
||||
if c.NGPULayers != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers))
|
||||
}
|
||||
|
||||
llamaOpts = append(llamaOpts, llama.SetMMap(c.MMap))
|
||||
llamaOpts = append(llamaOpts, llama.SetMainGPU(c.MainGPU))
|
||||
llamaOpts = append(llamaOpts, llama.SetTensorSplit(c.TensorSplit))
|
||||
if c.Batch != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(c.Batch))
|
||||
} else {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(512))
|
||||
}
|
||||
|
||||
if c.NUMA {
|
||||
llamaOpts = append(llamaOpts, llama.EnableNUMA)
|
||||
}
|
||||
|
||||
if c.LowVRAM {
|
||||
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
|
||||
}
|
||||
|
||||
return llamaOpts
|
||||
}
|
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) {
|
||||
if c.Backend != model.StableDiffusionBackend {
|
||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
|
||||
}
|
||||
inferenceModel, err := loader.BackendLoader(c.Backend, c.ImageGenerationAssets, []llama.ModelOption{}, uint32(c.Threads), o.assetsDestination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fn func() error
|
||||
switch model := inferenceModel.(type) {
|
||||
case *stablediffusion.StableDiffusion:
|
||||
fn = func() error {
|
||||
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst)
|
||||
}
|
||||
|
||||
default:
|
||||
fn = func() error {
|
||||
return fmt.Errorf("creation of images not supported by the backend")
|
||||
}
|
||||
}
|
||||
|
||||
return func() error {
|
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock()
|
||||
l, ok := mutexes[c.Backend]
|
||||
if !ok {
|
||||
m := &sync.Mutex{}
|
||||
mutexes[c.Backend] = m
|
||||
l = m
|
||||
}
|
||||
mutexMap.Unlock()
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
return fn()
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, o *Option) (func() ([]float32, error), error) {
|
||||
if !c.Embeddings {
|
||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
|
||||
}
|
||||
|
||||
modelFile := c.Model
|
||||
|
||||
llamaOpts := defaultLLamaOpts(c)
|
||||
|
||||
var inferenceModel interface{}
|
||||
var err error
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination)
|
||||
} else {
|
||||
inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fn func() ([]float32, error)
|
||||
switch model := inferenceModel.(type) {
|
||||
case *llama.LLama:
|
||||
fn = func() ([]float32, error) {
|
||||
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
|
||||
if len(tokens) > 0 {
|
||||
return model.TokenEmbeddings(tokens, predictOptions...)
|
||||
}
|
||||
return model.Embeddings(s, predictOptions...)
|
||||
}
|
||||
// bert embeddings
|
||||
case *bert.Bert:
|
||||
fn = func() ([]float32, error) {
|
||||
if len(tokens) > 0 {
|
||||
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
|
||||
}
|
||||
return model.Embeddings(s, bert.SetThreads(c.Threads))
|
||||
}
|
||||
default:
|
||||
fn = func() ([]float32, error) {
|
||||
return nil, fmt.Errorf("embeddings not supported by the backend")
|
||||
}
|
||||
}
|
||||
|
||||
return func() ([]float32, error) {
|
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock()
|
||||
l, ok := mutexes[modelFile]
|
||||
if !ok {
|
||||
m := &sync.Mutex{}
|
||||
mutexes[modelFile] = m
|
||||
l = m
|
||||
}
|
||||
mutexMap.Unlock()
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
embeds, err := fn()
|
||||
if err != nil {
|
||||
return embeds, err
|
||||
}
|
||||
// Remove trailing 0s
|
||||
for i := len(embeds) - 1; i >= 0; i-- {
|
||||
if embeds[i] == 0.0 {
|
||||
embeds = embeds[:i]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return embeds, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []llama.PredictOption{
|
||||
llama.SetTemperature(c.Temperature),
|
||||
llama.SetTopP(c.TopP),
|
||||
llama.SetTopK(c.TopK),
|
||||
llama.SetTokens(c.Maxtokens),
|
||||
llama.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.PromptCacheAll {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
|
||||
}
|
||||
|
||||
if c.PromptCacheRO {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
|
||||
}
|
||||
|
||||
if c.PromptCachePath != "" {
|
||||
// Create parent directory
|
||||
p := filepath.Join(modelPath, c.PromptCachePath)
|
||||
os.MkdirAll(filepath.Dir(p), 0755)
|
||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(p))
|
||||
}
|
||||
|
||||
if c.Mirostat != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat))
|
||||
}
|
||||
|
||||
if c.MirostatETA != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA))
|
||||
}
|
||||
|
||||
if c.MirostatTAU != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU))
|
||||
}
|
||||
|
||||
if c.Debug {
|
||||
predictOptions = append(predictOptions, llama.Debug)
|
||||
}
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...))
|
||||
|
||||
if c.RepeatPenalty != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty))
|
||||
}
|
||||
|
||||
if c.Keep != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetNKeep(c.Keep))
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.F16 {
|
||||
predictOptions = append(predictOptions, llama.EnableF16KV)
|
||||
}
|
||||
|
||||
if c.IgnoreEOS {
|
||||
predictOptions = append(predictOptions, llama.IgnoreEOS)
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty))
|
||||
predictOptions = append(predictOptions, llama.SetMlock(c.MMlock))
|
||||
predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit))
|
||||
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ))
|
||||
predictOptions = append(predictOptions, llama.SetTypicalP(c.TypicalP))
|
||||
|
||||
return predictOptions
|
||||
}
|
||||
|
||||
func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
||||
supportStreams := false
|
||||
modelFile := c.Model
|
||||
|
||||
llamaOpts := defaultLLamaOpts(c)
|
||||
|
||||
var inferenceModel interface{}
|
||||
var err error
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination)
|
||||
} else {
|
||||
inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fn func() (string, error)
|
||||
|
||||
switch model := inferenceModel.(type) {
|
||||
case *rwkv.RwkvState:
|
||||
supportStreams = true
|
||||
|
||||
fn = func() (string, error) {
|
||||
stopWord := "\n"
|
||||
if len(c.StopWords) > 0 {
|
||||
stopWord = c.StopWords[0]
|
||||
}
|
||||
|
||||
if err := model.ProcessInput(s); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
case *transformers.GPTNeoX:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{
|
||||
transformers.SetTemperature(c.Temperature),
|
||||
transformers.SetTopP(c.TopP),
|
||||
transformers.SetTopK(c.TopK),
|
||||
transformers.SetTokens(c.Maxtokens),
|
||||
transformers.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *transformers.Replit:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{
|
||||
transformers.SetTemperature(c.Temperature),
|
||||
transformers.SetTopP(c.TopP),
|
||||
transformers.SetTopK(c.TopK),
|
||||
transformers.SetTokens(c.Maxtokens),
|
||||
transformers.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *transformers.Starcoder:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{
|
||||
transformers.SetTemperature(c.Temperature),
|
||||
transformers.SetTopP(c.TopP),
|
||||
transformers.SetTopK(c.TopK),
|
||||
transformers.SetTokens(c.Maxtokens),
|
||||
transformers.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *transformers.MPT:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{
|
||||
transformers.SetTemperature(c.Temperature),
|
||||
transformers.SetTopP(c.TopP),
|
||||
transformers.SetTopK(c.TopK),
|
||||
transformers.SetTokens(c.Maxtokens),
|
||||
transformers.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *bloomz.Bloomz:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []bloomz.PredictOption{
|
||||
bloomz.SetTemperature(c.Temperature),
|
||||
bloomz.SetTopP(c.TopP),
|
||||
bloomz.SetTopK(c.TopK),
|
||||
bloomz.SetTokens(c.Maxtokens),
|
||||
bloomz.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *transformers.Falcon:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{
|
||||
transformers.SetTemperature(c.Temperature),
|
||||
transformers.SetTopP(c.TopP),
|
||||
transformers.SetTopK(c.TopK),
|
||||
transformers.SetTokens(c.Maxtokens),
|
||||
transformers.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *transformers.GPTJ:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{
|
||||
transformers.SetTemperature(c.Temperature),
|
||||
transformers.SetTopP(c.TopP),
|
||||
transformers.SetTopK(c.TopK),
|
||||
transformers.SetTokens(c.Maxtokens),
|
||||
transformers.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *transformers.Dolly:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{
|
||||
transformers.SetTemperature(c.Temperature),
|
||||
transformers.SetTopP(c.TopP),
|
||||
transformers.SetTopK(c.TopK),
|
||||
transformers.SetTokens(c.Maxtokens),
|
||||
transformers.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *transformers.GPT2:
|
||||
fn = func() (string, error) {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{
|
||||
transformers.SetTemperature(c.Temperature),
|
||||
transformers.SetTopP(c.TopP),
|
||||
transformers.SetTopK(c.TopK),
|
||||
transformers.SetTokens(c.Maxtokens),
|
||||
transformers.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
return model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
}
|
||||
case *gpt4all.Model:
|
||||
supportStreams = true
|
||||
|
||||
fn = func() (string, error) {
|
||||
if tokenCallback != nil {
|
||||
model.SetTokenCallback(tokenCallback)
|
||||
}
|
||||
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []gpt4all.PredictOption{
|
||||
gpt4all.SetTemperature(c.Temperature),
|
||||
gpt4all.SetTopP(c.TopP),
|
||||
gpt4all.SetTopK(c.TopK),
|
||||
gpt4all.SetTokens(c.Maxtokens),
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
str, er := model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
||||
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
||||
// after a stream event has occurred
|
||||
model.SetTokenCallback(nil)
|
||||
return str, er
|
||||
}
|
||||
case *llama.LLama:
|
||||
supportStreams = true
|
||||
fn = func() (string, error) {
|
||||
|
||||
if tokenCallback != nil {
|
||||
model.SetTokenCallback(tokenCallback)
|
||||
}
|
||||
|
||||
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
|
||||
|
||||
str, er := model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
||||
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
||||
// after a stream event has occurred
|
||||
model.SetTokenCallback(nil)
|
||||
return str, er
|
||||
}
|
||||
case *langchain.HuggingFace:
|
||||
fn = func() (string, error) {
|
||||
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []langchain.PredictOption{
|
||||
langchain.SetModel(c.Model),
|
||||
langchain.SetMaxTokens(c.Maxtokens),
|
||||
langchain.SetTemperature(c.Temperature),
|
||||
langchain.SetStopWords(c.StopWords),
|
||||
}
|
||||
|
||||
pred, er := model.PredictHuggingFace(s, predictOptions...)
|
||||
if er != nil {
|
||||
return "", er
|
||||
}
|
||||
return pred.Completion, nil
|
||||
}
|
||||
}
|
||||
|
||||
return func() (string, error) {
|
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock()
|
||||
l, ok := mutexes[modelFile]
|
||||
if !ok {
|
||||
m := &sync.Mutex{}
|
||||
mutexes[modelFile] = m
|
||||
l = m
|
||||
}
|
||||
mutexMap.Unlock()
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
res, err := fn()
|
||||
if tokenCallback != nil && !supportStreams {
|
||||
tokenCallback(res)
|
||||
}
|
||||
return res, err
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, o *Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
|
||||
result := []Choice{}
|
||||
|
||||
n := input.N
|
||||
|
||||
if input.N == 0 {
|
||||
n = 1
|
||||
}
|
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := ModelInference(predInput, loader, *config, o, tokenCallback)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
prediction = Finetune(*config, predInput, prediction)
|
||||
cb(prediction, &result)
|
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
||||
var mu sync.Mutex = sync.Mutex{}
|
||||
|
||||
func Finetune(config Config, input, prediction string) string {
|
||||
if config.Echo {
|
||||
prediction = input + prediction
|
||||
}
|
||||
|
||||
for _, c := range config.Cutstrings {
|
||||
mu.Lock()
|
||||
reg, ok := cutstrings[c]
|
||||
if !ok {
|
||||
cutstrings[c] = regexp.MustCompile(c)
|
||||
reg = cutstrings[c]
|
||||
}
|
||||
mu.Unlock()
|
||||
prediction = reg.ReplaceAllString(prediction, "")
|
||||
}
|
||||
|
||||
for _, c := range config.TrimSpace {
|
||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
||||
}
|
||||
return prediction
|
||||
|
||||
}
|
||||
22
cmd/grpc/bert-embeddings/main.go
Normal file
22
cmd/grpc/bert-embeddings/main.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
bert "github.com/go-skynet/LocalAI/pkg/grpc/llm/bert"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &bert.Embeddings{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/bloomz/main.go
Normal file
23
cmd/grpc/bloomz/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
bloomz "github.com/go-skynet/LocalAI/pkg/grpc/llm/bloomz"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &bloomz.LLM{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/dolly/main.go
Normal file
23
cmd/grpc/dolly/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Dolly{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/falcon-ggml/main.go
Normal file
23
cmd/grpc/falcon-ggml/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Falcon{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
25
cmd/grpc/falcon/main.go
Normal file
25
cmd/grpc/falcon/main.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
// GRPC Falcon server
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
falcon "github.com/go-skynet/LocalAI/pkg/grpc/llm/falcon"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &falcon.LLM{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/gpt2/main.go
Normal file
23
cmd/grpc/gpt2/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.GPT2{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/gpt4all/main.go
Normal file
23
cmd/grpc/gpt4all/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
gpt4all "github.com/go-skynet/LocalAI/pkg/grpc/llm/gpt4all"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &gpt4all.LLM{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/gptj/main.go
Normal file
23
cmd/grpc/gptj/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.GPTJ{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/gptneox/main.go
Normal file
23
cmd/grpc/gptneox/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.GPTNeoX{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/langchain-huggingface/main.go
Normal file
23
cmd/grpc/langchain-huggingface/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
langchain "github.com/go-skynet/LocalAI/pkg/grpc/llm/langchain"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &langchain.LLM{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
25
cmd/grpc/llama/main.go
Normal file
25
cmd/grpc/llama/main.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
// GRPC Falcon server
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &llama.LLM{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/mpt/main.go
Normal file
23
cmd/grpc/mpt/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.MPT{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/piper/main.go
Normal file
23
cmd/grpc/piper/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
tts "github.com/go-skynet/LocalAI/pkg/grpc/tts"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &tts.Piper{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/replit/main.go
Normal file
23
cmd/grpc/replit/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Replit{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/rwkv/main.go
Normal file
23
cmd/grpc/rwkv/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
rwkv "github.com/go-skynet/LocalAI/pkg/grpc/llm/rwkv"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &rwkv.LLM{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/stablediffusion/main.go
Normal file
23
cmd/grpc/stablediffusion/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
image "github.com/go-skynet/LocalAI/pkg/grpc/image"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &image.StableDiffusion{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/starcoder/main.go
Normal file
23
cmd/grpc/starcoder/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Starcoder{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
23
cmd/grpc/whisper/main.go
Normal file
23
cmd/grpc/whisper/main.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
transcribe "github.com/go-skynet/LocalAI/pkg/grpc/transcribe"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &transcribe.Whisper{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,35 @@ cd /build
|
||||
if [ "$REBUILD" != "false" ]; then
|
||||
rm -rf ./local-ai
|
||||
ESPEAK_DATA=/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data make build -j${BUILD_PARALLELISM:-1}
|
||||
else
|
||||
echo "@@@@@"
|
||||
echo "Skipping rebuild"
|
||||
echo "@@@@@"
|
||||
echo "If you are experiencing issues with the pre-compiled builds, try setting REBUILD=true"
|
||||
echo "If you are still experiencing issues with the build, try setting CMAKE_ARGS and disable the instructions set as needed:"
|
||||
echo 'CMAKE_ARGS="-DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF"'
|
||||
echo "see the documentation at: https://localai.io/basics/build/index.html"
|
||||
echo "Note: See also https://github.com/go-skynet/LocalAI/issues/288"
|
||||
echo "@@@@@"
|
||||
echo "CPU info:"
|
||||
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||
grep -e "flags" /proc/cpuinfo | head -1
|
||||
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX found OK"
|
||||
else
|
||||
echo "CPU: no AVX found"
|
||||
fi
|
||||
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX2 found OK"
|
||||
else
|
||||
echo "CPU: no AVX2 found"
|
||||
fi
|
||||
if grep -q -e "\savx512" /proc/cpuinfo ; then
|
||||
echo "CPU: AVX512 found OK"
|
||||
else
|
||||
echo "CPU: no AVX512 found"
|
||||
fi
|
||||
echo "@@@@@"
|
||||
fi
|
||||
|
||||
./local-ai "$@"
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
# Examples
|
||||
|
||||
| [ChatGPT OSS alternative](https://github.com/go-skynet/LocalAI/tree/master/examples/chatbot-ui) | [Image generation](https://localai.io/api-endpoints/index.html#image-generation) |
|
||||
|------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|
|
||||
|  |  |
|
||||
|
||||
| [Telegram bot](https://github.com/go-skynet/LocalAI/tree/master/examples/telegram-bot) | [Flowise](https://github.com/go-skynet/LocalAI/tree/master/examples/flowise) |
|
||||
|------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|
|
||||
 | | |
|
||||
|
||||
Here is a list of projects that can easily be integrated with the LocalAI backend.
|
||||
|
||||
|
||||
### Projects
|
||||
|
||||
### AutoGPT
|
||||
@@ -64,6 +73,14 @@ A ready to use example to show e2e how to integrate LocalAI with langchain
|
||||
|
||||
[Check it out here](https://github.com/go-skynet/LocalAI/tree/master/examples/langchain-python/)
|
||||
|
||||
### LocalAI functions
|
||||
|
||||
_by [@mudler](https://github.com/mudler)_
|
||||
|
||||
A ready to use example to show how to use OpenAI functions with LocalAI
|
||||
|
||||
[Check it out here](https://github.com/go-skynet/LocalAI/tree/master/examples/functions/)
|
||||
|
||||
### LocalAI WebUI
|
||||
|
||||
_by [@dhruvgera](https://github.com/dhruvgera)_
|
||||
|
||||
9
examples/functions/.env
Normal file
9
examples/functions/.env
Normal file
@@ -0,0 +1,9 @@
|
||||
OPENAI_API_KEY=sk---anystringhere
|
||||
OPENAI_API_BASE=http://api:8080/v1
|
||||
# Models to preload at start
|
||||
# Here we configure gpt4all as gpt-3.5-turbo and bert as embeddings
|
||||
PRELOAD_MODELS=[{"url": "github:go-skynet/model-gallery/openllama-7b-open-instruct.yaml", "name": "gpt-3.5-turbo"}]
|
||||
|
||||
## Change the default number of threads
|
||||
#THREADS=14
|
||||
|
||||
5
examples/functions/Dockerfile
Normal file
5
examples/functions/Dockerfile
Normal file
@@ -0,0 +1,5 @@
|
||||
FROM python:3.10-bullseye
|
||||
COPY . /app
|
||||
WORKDIR /app
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
ENTRYPOINT [ "python", "./functions-openai.py" ];
|
||||
18
examples/functions/README.md
Normal file
18
examples/functions/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# LocalAI functions
|
||||
|
||||
Example of using LocalAI functions, see the [OpenAI](https://openai.com/blog/function-calling-and-other-api-updates) blog post.
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
# Clone LocalAI
|
||||
git clone https://github.com/go-skynet/LocalAI
|
||||
|
||||
cd LocalAI/examples/functions
|
||||
|
||||
docker-compose run --rm functions
|
||||
```
|
||||
|
||||
Note: The example automatically downloads the `openllama` model as it is under a permissive license.
|
||||
|
||||
See the `.env` configuration file to set a different model with the [model-gallery](https://github.com/go-skynet/model-gallery) by editing `PRELOAD_MODELS`.
|
||||
23
examples/functions/docker-compose.yaml
Normal file
23
examples/functions/docker-compose.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
version: "3.9"
|
||||
services:
|
||||
api:
|
||||
image: quay.io/go-skynet/local-ai:master
|
||||
ports:
|
||||
- 8080:8080
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- DEBUG=true
|
||||
- MODELS_PATH=/models
|
||||
volumes:
|
||||
- ./models:/models:cached
|
||||
command: ["/usr/bin/local-ai" ]
|
||||
functions:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
depends_on:
|
||||
api:
|
||||
condition: service_healthy
|
||||
env_file:
|
||||
- .env
|
||||
76
examples/functions/functions-openai.py
Normal file
76
examples/functions/functions-openai.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import openai
|
||||
import json
|
||||
|
||||
# Example dummy function hard coded to return the same weather
|
||||
# In production, this could be your backend API or an external API
|
||||
def get_current_weather(location, unit="fahrenheit"):
|
||||
"""Get the current weather in a given location"""
|
||||
weather_info = {
|
||||
"location": location,
|
||||
"temperature": "72",
|
||||
"unit": unit,
|
||||
"forecast": ["sunny", "windy"],
|
||||
}
|
||||
return json.dumps(weather_info)
|
||||
|
||||
|
||||
def run_conversation():
|
||||
# Step 1: send the conversation and available functions to GPT
|
||||
messages = [{"role": "user", "content": "What's the weather like in Boston?"}]
|
||||
functions = [
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
]
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
function_call="auto", # auto is default, but we'll be explicit
|
||||
)
|
||||
response_message = response["choices"][0]["message"]
|
||||
|
||||
# Step 2: check if GPT wanted to call a function
|
||||
if response_message.get("function_call"):
|
||||
# Step 3: call the function
|
||||
# Note: the JSON response may not always be valid; be sure to handle errors
|
||||
available_functions = {
|
||||
"get_current_weather": get_current_weather,
|
||||
} # only one function in this example, but you can have multiple
|
||||
function_name = response_message["function_call"]["name"]
|
||||
fuction_to_call = available_functions[function_name]
|
||||
function_args = json.loads(response_message["function_call"]["arguments"])
|
||||
function_response = fuction_to_call(
|
||||
location=function_args.get("location"),
|
||||
unit=function_args.get("unit"),
|
||||
)
|
||||
|
||||
# Step 4: send the info on the function call and function response to GPT
|
||||
messages.append(response_message) # extend conversation with assistant's reply
|
||||
messages.append(
|
||||
{
|
||||
"role": "function",
|
||||
"name": function_name,
|
||||
"content": function_response,
|
||||
}
|
||||
) # extend conversation with function response
|
||||
second_response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
) # get a new response from GPT where it can see the function response
|
||||
return second_response
|
||||
|
||||
|
||||
print(run_conversation())
|
||||
2
examples/functions/requirements.txt
Normal file
2
examples/functions/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
langchain==0.0.234
|
||||
openai==0.27.8
|
||||
@@ -38,7 +38,7 @@ helm install local-ai go-skynet/local-ai --create-namespace --namespace local-ai
|
||||
# Install k8sgpt
|
||||
helm repo add k8sgpt https://charts.k8sgpt.ai/
|
||||
helm repo update
|
||||
helm install release k8sgpt/k8sgpt-operator -n k8sgpt-operator-system --create-namespace
|
||||
helm install release k8sgpt/k8sgpt-operator -n k8sgpt-operator-system --create-namespace --version 0.0.17
|
||||
```
|
||||
|
||||
Apply the k8sgpt-operator configuration:
|
||||
@@ -55,7 +55,6 @@ spec:
|
||||
baseUrl: http://local-ai.local-ai.svc.cluster.local:8080/v1
|
||||
noCache: false
|
||||
model: gpt-3.5-turbo
|
||||
noCache: false
|
||||
version: v0.3.0
|
||||
enableAI: true
|
||||
EOF
|
||||
@@ -67,4 +66,7 @@ Apply a broken pod:
|
||||
|
||||
```
|
||||
kubectl apply -f broken-pod.yaml
|
||||
```
|
||||
```
|
||||
|
||||
## ArgoCD Deployment Example
|
||||
[Deploy K8sgpt + localai with Argocd](https://github.com/tyler-harpool/gitops/tree/main/infra/k8gpt)
|
||||
|
||||
@@ -2,12 +2,13 @@ replicaCount: 1
|
||||
|
||||
deployment:
|
||||
# https://quay.io/repository/go-skynet/local-ai?tab=tags
|
||||
image: quay.io/go-skynet/local-ai:latest
|
||||
image: quay.io/go-skynet/local-ai:v1.23.0
|
||||
env:
|
||||
threads: 4
|
||||
debug: "true"
|
||||
context_size: 512
|
||||
preload_models: '[{ "url": "github:go-skynet/model-gallery/wizard.yaml", "name": "gpt-3.5-turbo", "overrides": { "parameters": { "model": "WizardLM-7B-uncensored.ggmlv3.q5_1" }},"files": [ { "uri": "https://huggingface.co//WizardLM-7B-uncensored-GGML/resolve/main/WizardLM-7B-uncensored.ggmlv3.q5_1.bin", "sha256": "d92a509d83a8ea5e08ba4c2dbaf08f29015932dc2accd627ce0665ac72c2bb2b", "filename": "WizardLM-7B-uncensored.ggmlv3.q5_1" }]}]'
|
||||
galleries: '[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/index.yaml"}, {"url": "github:go-skynet/model-gallery/huggingface.yaml","name":"huggingface"}]'
|
||||
preload_models: '[{ "id": "huggingface@thebloke__open-llama-13b-open-instruct-ggml__open-llama-13b-open-instruct.ggmlv3.q3_k_m.bin", "name": "gpt-3.5-turbo", "overrides": { "f16": true, "mmap": true }}]'
|
||||
modelsPath: "/models"
|
||||
|
||||
resources:
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.vectorstores.base import VectorStoreRetriever
|
||||
base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1')
|
||||
|
||||
# Load and process the text
|
||||
embedding = OpenAIEmbeddings()
|
||||
embedding = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_base=base_path)
|
||||
persist_directory = 'db'
|
||||
|
||||
# Now we can load the persisted database from disk, and use it as normal.
|
||||
|
||||
@@ -18,8 +18,8 @@ texts = text_splitter.split_documents(documents)
|
||||
# Supplying a persist_directory will store the embeddings on disk
|
||||
persist_directory = 'db'
|
||||
|
||||
embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||
embedding = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_base=base_path)
|
||||
vectordb = Chroma.from_documents(documents=texts, embedding=embedding, persist_directory=persist_directory)
|
||||
|
||||
vectordb.persist()
|
||||
vectordb = None
|
||||
vectordb = None
|
||||
|
||||
@@ -22,7 +22,7 @@ services:
|
||||
- 'PRELOAD_MODELS=[{"url": "github:go-skynet/model-gallery/gpt4all-j.yaml", "name": "gpt-3.5-turbo"}, {"url": "github:go-skynet/model-gallery/stablediffusion.yaml"}, {"url": "github:go-skynet/model-gallery/whisper-base.yaml", "name": "whisper-1"}]'
|
||||
volumes:
|
||||
- ./models:/models:cached
|
||||
command: ["/usr/bin/local-ai" ]
|
||||
command: ["/usr/bin/local-ai"]
|
||||
chatgpt_telegram_bot:
|
||||
container_name: chatgpt_telegram_bot
|
||||
command: python3 bot/bot.py
|
||||
|
||||
109
extra/grpc/autogptq/autogptq.py
Executable file
109
extra/grpc/autogptq/autogptq.py
Executable file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python3
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
||||
from pathlib import Path
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import TextGenerationPipeline
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
def LoadModel(self, request, context):
|
||||
try:
|
||||
device = "cuda:0"
|
||||
if request.Device != "":
|
||||
device = request.Device
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(request.Model, use_fast=request.UseFastTokenizer)
|
||||
|
||||
model = AutoGPTQForCausalLM.from_quantized(request.Model,
|
||||
model_basename=request.ModelBaseName,
|
||||
use_safetensors=True,
|
||||
trust_remote_code=True,
|
||||
device=device,
|
||||
use_triton=request.UseTriton,
|
||||
quantize_config=None)
|
||||
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def Predict(self, request, context):
|
||||
penalty = 1.0
|
||||
if request.Penalty != 0.0:
|
||||
penalty = request.Penalty
|
||||
tokens = 512
|
||||
if request.Tokens != 0:
|
||||
tokens = request.Tokens
|
||||
top_p = 0.95
|
||||
if request.TopP != 0.0:
|
||||
top_p = request.TopP
|
||||
|
||||
# Implement Predict RPC
|
||||
pipeline = TextGenerationPipeline(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
max_new_tokens=tokens,
|
||||
temperature=request.Temperature,
|
||||
top_p=top_p,
|
||||
repetition_penalty=penalty,
|
||||
)
|
||||
t = pipeline(request.Prompt)[0]["generated_text"]
|
||||
# Remove prompt from response if present
|
||||
if request.Prompt in t:
|
||||
t = t.replace(request.Prompt, "")
|
||||
|
||||
return backend_pb2.Result(message=bytes(t, encoding='utf-8'))
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
# Implement PredictStream RPC
|
||||
#for reply in some_data_generator():
|
||||
# yield reply
|
||||
# Not implemented yet
|
||||
return self.Predict(request, context)
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
||||
49
extra/grpc/autogptq/backend_pb2.py
Normal file
49
extra/grpc/autogptq/backend_pb2.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: backend.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\x86\x06\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\x12\x14\n\x0cRopeFreqBase\x18% \x01(\x02\x12\x15\n\rRopeFreqScale\x18& \x01(\x02\x12\x1b\n\x13NegativePromptScale\x18\' \x01(\x02\x12\x16\n\x0eNegativePrompt\x18( \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\x0c\"\x9d\x04\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\x12\x14\n\x0cRopeFreqBase\x18\x11 \x01(\x02\x12\x15\n\rRopeFreqScale\x18\x12 \x01(\x02\x12\x12\n\nRMSNormEps\x18\x13 \x01(\x02\x12\x0c\n\x04NGQA\x18\x14 \x01(\x05\x12\x11\n\tModelFile\x18\x15 \x01(\t\x12\x0e\n\x06\x44\x65vice\x18\x16 \x01(\t\x12\x11\n\tUseTriton\x18\x17 \x01(\x08\x12\x15\n\rModelBaseName\x18\x18 \x01(\t\x12\x18\n\x10UseFastTokenizer\x18\x19 \x01(\x08\x12\x14\n\x0cPipelineType\x18\x1a \x01(\t\x12\x15\n\rSchedulerType\x18\x1b \x01(\t\x12\x0c\n\x04\x43UDA\x18\x1c \x01(\x08\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto'
|
||||
_globals['_HEALTHMESSAGE']._serialized_start=26
|
||||
_globals['_HEALTHMESSAGE']._serialized_end=41
|
||||
_globals['_PREDICTOPTIONS']._serialized_start=44
|
||||
_globals['_PREDICTOPTIONS']._serialized_end=818
|
||||
_globals['_REPLY']._serialized_start=820
|
||||
_globals['_REPLY']._serialized_end=844
|
||||
_globals['_MODELOPTIONS']._serialized_start=847
|
||||
_globals['_MODELOPTIONS']._serialized_end=1388
|
||||
_globals['_RESULT']._serialized_start=1390
|
||||
_globals['_RESULT']._serialized_end=1432
|
||||
_globals['_EMBEDDINGRESULT']._serialized_start=1434
|
||||
_globals['_EMBEDDINGRESULT']._serialized_end=1471
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_start=1473
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_end=1540
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_start=1542
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_end=1620
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_start=1622
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_end=1711
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_start=1714
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_end=1872
|
||||
_globals['_TTSREQUEST']._serialized_start=1874
|
||||
_globals['_TTSREQUEST']._serialized_end=1928
|
||||
_globals['_BACKEND']._serialized_start=1931
|
||||
_globals['_BACKEND']._serialized_end=2422
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
297
extra/grpc/autogptq/backend_pb2_grpc.py
Normal file
297
extra/grpc/autogptq/backend_pb2_grpc.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import backend_pb2 as backend__pb2
|
||||
|
||||
|
||||
class BackendStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Health = channel.unary_unary(
|
||||
'/backend.Backend/Health',
|
||||
request_serializer=backend__pb2.HealthMessage.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Predict = channel.unary_unary(
|
||||
'/backend.Backend/Predict',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.LoadModel = channel.unary_unary(
|
||||
'/backend.Backend/LoadModel',
|
||||
request_serializer=backend__pb2.ModelOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.PredictStream = channel.unary_stream(
|
||||
'/backend.Backend/PredictStream',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Embedding = channel.unary_unary(
|
||||
'/backend.Backend/Embedding',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.EmbeddingResult.FromString,
|
||||
)
|
||||
self.GenerateImage = channel.unary_unary(
|
||||
'/backend.Backend/GenerateImage',
|
||||
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.AudioTranscription = channel.unary_unary(
|
||||
'/backend.Backend/AudioTranscription',
|
||||
request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.TranscriptResult.FromString,
|
||||
)
|
||||
self.TTS = channel.unary_unary(
|
||||
'/backend.Backend/TTS',
|
||||
request_serializer=backend__pb2.TTSRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
|
||||
|
||||
class BackendServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def Health(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Predict(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Embedding(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GenerateImage(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def AudioTranscription(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def TTS(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_BackendServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Health': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Health,
|
||||
request_deserializer=backend__pb2.HealthMessage.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Predict': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Predict,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'LoadModel': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.LoadModel,
|
||||
request_deserializer=backend__pb2.ModelOptions.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'PredictStream': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.PredictStream,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Embedding': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Embedding,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
|
||||
),
|
||||
'GenerateImage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GenerateImage,
|
||||
request_deserializer=backend__pb2.GenerateImageRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'AudioTranscription': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.AudioTranscription,
|
||||
request_deserializer=backend__pb2.TranscriptRequest.FromString,
|
||||
response_serializer=backend__pb2.TranscriptResult.SerializeToString,
|
||||
),
|
||||
'TTS': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.TTS,
|
||||
request_deserializer=backend__pb2.TTSRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'backend.Backend', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class Backend(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def Health(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
|
||||
backend__pb2.HealthMessage.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Predict(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def LoadModel(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
|
||||
backend__pb2.ModelOptions.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def PredictStream(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Embedding(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.EmbeddingResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def GenerateImage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
|
||||
backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def AudioTranscription(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
|
||||
backend__pb2.TranscriptRequest.SerializeToString,
|
||||
backend__pb2.TranscriptResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def TTS(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
|
||||
backend__pb2.TTSRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
49
extra/grpc/bark/backend_pb2.py
Normal file
49
extra/grpc/bark/backend_pb2.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: backend.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\x86\x06\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\x12\x14\n\x0cRopeFreqBase\x18% \x01(\x02\x12\x15\n\rRopeFreqScale\x18& \x01(\x02\x12\x1b\n\x13NegativePromptScale\x18\' \x01(\x02\x12\x16\n\x0eNegativePrompt\x18( \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\x0c\"\x9d\x04\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\x12\x14\n\x0cRopeFreqBase\x18\x11 \x01(\x02\x12\x15\n\rRopeFreqScale\x18\x12 \x01(\x02\x12\x12\n\nRMSNormEps\x18\x13 \x01(\x02\x12\x0c\n\x04NGQA\x18\x14 \x01(\x05\x12\x11\n\tModelFile\x18\x15 \x01(\t\x12\x0e\n\x06\x44\x65vice\x18\x16 \x01(\t\x12\x11\n\tUseTriton\x18\x17 \x01(\x08\x12\x15\n\rModelBaseName\x18\x18 \x01(\t\x12\x18\n\x10UseFastTokenizer\x18\x19 \x01(\x08\x12\x14\n\x0cPipelineType\x18\x1a \x01(\t\x12\x15\n\rSchedulerType\x18\x1b \x01(\t\x12\x0c\n\x04\x43UDA\x18\x1c \x01(\x08\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto'
|
||||
_globals['_HEALTHMESSAGE']._serialized_start=26
|
||||
_globals['_HEALTHMESSAGE']._serialized_end=41
|
||||
_globals['_PREDICTOPTIONS']._serialized_start=44
|
||||
_globals['_PREDICTOPTIONS']._serialized_end=818
|
||||
_globals['_REPLY']._serialized_start=820
|
||||
_globals['_REPLY']._serialized_end=844
|
||||
_globals['_MODELOPTIONS']._serialized_start=847
|
||||
_globals['_MODELOPTIONS']._serialized_end=1388
|
||||
_globals['_RESULT']._serialized_start=1390
|
||||
_globals['_RESULT']._serialized_end=1432
|
||||
_globals['_EMBEDDINGRESULT']._serialized_start=1434
|
||||
_globals['_EMBEDDINGRESULT']._serialized_end=1471
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_start=1473
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_end=1540
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_start=1542
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_end=1620
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_start=1622
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_end=1711
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_start=1714
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_end=1872
|
||||
_globals['_TTSREQUEST']._serialized_start=1874
|
||||
_globals['_TTSREQUEST']._serialized_end=1928
|
||||
_globals['_BACKEND']._serialized_start=1931
|
||||
_globals['_BACKEND']._serialized_end=2422
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
297
extra/grpc/bark/backend_pb2_grpc.py
Normal file
297
extra/grpc/bark/backend_pb2_grpc.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import backend_pb2 as backend__pb2
|
||||
|
||||
|
||||
class BackendStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Health = channel.unary_unary(
|
||||
'/backend.Backend/Health',
|
||||
request_serializer=backend__pb2.HealthMessage.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Predict = channel.unary_unary(
|
||||
'/backend.Backend/Predict',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.LoadModel = channel.unary_unary(
|
||||
'/backend.Backend/LoadModel',
|
||||
request_serializer=backend__pb2.ModelOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.PredictStream = channel.unary_stream(
|
||||
'/backend.Backend/PredictStream',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Embedding = channel.unary_unary(
|
||||
'/backend.Backend/Embedding',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.EmbeddingResult.FromString,
|
||||
)
|
||||
self.GenerateImage = channel.unary_unary(
|
||||
'/backend.Backend/GenerateImage',
|
||||
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.AudioTranscription = channel.unary_unary(
|
||||
'/backend.Backend/AudioTranscription',
|
||||
request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.TranscriptResult.FromString,
|
||||
)
|
||||
self.TTS = channel.unary_unary(
|
||||
'/backend.Backend/TTS',
|
||||
request_serializer=backend__pb2.TTSRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
|
||||
|
||||
class BackendServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def Health(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Predict(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Embedding(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GenerateImage(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def AudioTranscription(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def TTS(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_BackendServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Health': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Health,
|
||||
request_deserializer=backend__pb2.HealthMessage.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Predict': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Predict,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'LoadModel': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.LoadModel,
|
||||
request_deserializer=backend__pb2.ModelOptions.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'PredictStream': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.PredictStream,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Embedding': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Embedding,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
|
||||
),
|
||||
'GenerateImage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GenerateImage,
|
||||
request_deserializer=backend__pb2.GenerateImageRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'AudioTranscription': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.AudioTranscription,
|
||||
request_deserializer=backend__pb2.TranscriptRequest.FromString,
|
||||
response_serializer=backend__pb2.TranscriptResult.SerializeToString,
|
||||
),
|
||||
'TTS': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.TTS,
|
||||
request_deserializer=backend__pb2.TTSRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'backend.Backend', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class Backend(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def Health(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
|
||||
backend__pb2.HealthMessage.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Predict(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def LoadModel(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
|
||||
backend__pb2.ModelOptions.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def PredictStream(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Embedding(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.EmbeddingResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def GenerateImage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
|
||||
backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def AudioTranscription(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
|
||||
backend__pb2.TranscriptRequest.SerializeToString,
|
||||
backend__pb2.TranscriptResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def TTS(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
|
||||
backend__pb2.TTSRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
83
extra/grpc/bark/ttsbark.py
Normal file
83
extra/grpc/bark/ttsbark.py
Normal file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python3
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
||||
from pathlib import Path
|
||||
from bark import SAMPLE_RATE, generate_audio, preload_models
|
||||
from scipy.io.wavfile import write as write_wav
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
def LoadModel(self, request, context):
|
||||
model_name = request.Model
|
||||
try:
|
||||
print("Preparing models, please wait", file=sys.stderr)
|
||||
# download and load all models
|
||||
preload_models()
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
# Implement your logic here for the LoadModel service
|
||||
# Replace this with your desired response
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def TTS(self, request, context):
|
||||
model = request.model
|
||||
print(request, file=sys.stderr)
|
||||
try:
|
||||
audio_array = None
|
||||
if model != "":
|
||||
audio_array = generate_audio(request.text, history_prompt=model)
|
||||
else:
|
||||
audio_array = generate_audio(request.text)
|
||||
print("saving to", request.dst, file=sys.stderr)
|
||||
# save audio to disk
|
||||
write_wav(request.dst, SAMPLE_RATE, audio_array)
|
||||
print("saved to", request.dst, file=sys.stderr)
|
||||
print("tts for", file=sys.stderr)
|
||||
print(request, file=sys.stderr)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
||||
114
extra/grpc/diffusers/backend_diffusers.py
Executable file
114
extra/grpc/diffusers/backend_diffusers.py
Executable file
@@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env python3
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
|
||||
# import diffusers
|
||||
import torch
|
||||
from torch import autocast
|
||||
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
def LoadModel(self, request, context):
|
||||
try:
|
||||
print(f"Loading model {request.Model}...", file=sys.stderr)
|
||||
print(f"Request {request}", file=sys.stderr)
|
||||
torchType = torch.float32
|
||||
if request.F16Memory:
|
||||
torchType = torch.float16
|
||||
|
||||
if request.PipelineType == "":
|
||||
request.PipelineType == "StableDiffusionPipeline"
|
||||
|
||||
if request.PipelineType == "StableDiffusionPipeline":
|
||||
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
|
||||
if request.PipelineType == "DiffusionPipeline":
|
||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
|
||||
if request.PipelineType == "StableDiffusionXLPipeline":
|
||||
self.pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True,
|
||||
# variant="fp16"
|
||||
)
|
||||
|
||||
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU
|
||||
# TODO: this needs to be customized
|
||||
if request.SchedulerType == "EulerAncestralDiscreteScheduler":
|
||||
self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
|
||||
if request.SchedulerType == "DPMSolverMultistepScheduler":
|
||||
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
|
||||
|
||||
if request.CUDA:
|
||||
self.pipe.to('cuda')
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
# Implement your logic here for the LoadModel service
|
||||
# Replace this with your desired response
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
def GenerateImage(self, request, context):
|
||||
|
||||
prompt = request.positive_prompt
|
||||
negative_prompt = request.negative_prompt
|
||||
|
||||
image = self.pipe(
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=request.width,
|
||||
height=request.height,
|
||||
# guidance_scale=12,
|
||||
target_size=(request.width,request.height),
|
||||
original_size=(4096,4096),
|
||||
num_inference_steps=request.step
|
||||
).images[0]
|
||||
|
||||
image.save(request.dst)
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
||||
49
extra/grpc/diffusers/backend_pb2.py
Normal file
49
extra/grpc/diffusers/backend_pb2.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: backend.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\x86\x06\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\x12\x14\n\x0cRopeFreqBase\x18% \x01(\x02\x12\x15\n\rRopeFreqScale\x18& \x01(\x02\x12\x1b\n\x13NegativePromptScale\x18\' \x01(\x02\x12\x16\n\x0eNegativePrompt\x18( \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\x0c\"\x9d\x04\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\x12\x14\n\x0cRopeFreqBase\x18\x11 \x01(\x02\x12\x15\n\rRopeFreqScale\x18\x12 \x01(\x02\x12\x12\n\nRMSNormEps\x18\x13 \x01(\x02\x12\x0c\n\x04NGQA\x18\x14 \x01(\x05\x12\x11\n\tModelFile\x18\x15 \x01(\t\x12\x0e\n\x06\x44\x65vice\x18\x16 \x01(\t\x12\x11\n\tUseTriton\x18\x17 \x01(\x08\x12\x15\n\rModelBaseName\x18\x18 \x01(\t\x12\x18\n\x10UseFastTokenizer\x18\x19 \x01(\x08\x12\x14\n\x0cPipelineType\x18\x1a \x01(\t\x12\x15\n\rSchedulerType\x18\x1b \x01(\t\x12\x0c\n\x04\x43UDA\x18\x1c \x01(\x08\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto'
|
||||
_globals['_HEALTHMESSAGE']._serialized_start=26
|
||||
_globals['_HEALTHMESSAGE']._serialized_end=41
|
||||
_globals['_PREDICTOPTIONS']._serialized_start=44
|
||||
_globals['_PREDICTOPTIONS']._serialized_end=818
|
||||
_globals['_REPLY']._serialized_start=820
|
||||
_globals['_REPLY']._serialized_end=844
|
||||
_globals['_MODELOPTIONS']._serialized_start=847
|
||||
_globals['_MODELOPTIONS']._serialized_end=1388
|
||||
_globals['_RESULT']._serialized_start=1390
|
||||
_globals['_RESULT']._serialized_end=1432
|
||||
_globals['_EMBEDDINGRESULT']._serialized_start=1434
|
||||
_globals['_EMBEDDINGRESULT']._serialized_end=1471
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_start=1473
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_end=1540
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_start=1542
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_end=1620
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_start=1622
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_end=1711
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_start=1714
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_end=1872
|
||||
_globals['_TTSREQUEST']._serialized_start=1874
|
||||
_globals['_TTSREQUEST']._serialized_end=1928
|
||||
_globals['_BACKEND']._serialized_start=1931
|
||||
_globals['_BACKEND']._serialized_end=2422
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
297
extra/grpc/diffusers/backend_pb2_grpc.py
Normal file
297
extra/grpc/diffusers/backend_pb2_grpc.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import backend_pb2 as backend__pb2
|
||||
|
||||
|
||||
class BackendStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Health = channel.unary_unary(
|
||||
'/backend.Backend/Health',
|
||||
request_serializer=backend__pb2.HealthMessage.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Predict = channel.unary_unary(
|
||||
'/backend.Backend/Predict',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.LoadModel = channel.unary_unary(
|
||||
'/backend.Backend/LoadModel',
|
||||
request_serializer=backend__pb2.ModelOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.PredictStream = channel.unary_stream(
|
||||
'/backend.Backend/PredictStream',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Embedding = channel.unary_unary(
|
||||
'/backend.Backend/Embedding',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.EmbeddingResult.FromString,
|
||||
)
|
||||
self.GenerateImage = channel.unary_unary(
|
||||
'/backend.Backend/GenerateImage',
|
||||
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.AudioTranscription = channel.unary_unary(
|
||||
'/backend.Backend/AudioTranscription',
|
||||
request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.TranscriptResult.FromString,
|
||||
)
|
||||
self.TTS = channel.unary_unary(
|
||||
'/backend.Backend/TTS',
|
||||
request_serializer=backend__pb2.TTSRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
|
||||
|
||||
class BackendServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def Health(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Predict(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Embedding(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GenerateImage(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def AudioTranscription(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def TTS(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_BackendServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Health': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Health,
|
||||
request_deserializer=backend__pb2.HealthMessage.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Predict': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Predict,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'LoadModel': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.LoadModel,
|
||||
request_deserializer=backend__pb2.ModelOptions.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'PredictStream': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.PredictStream,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Embedding': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Embedding,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
|
||||
),
|
||||
'GenerateImage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GenerateImage,
|
||||
request_deserializer=backend__pb2.GenerateImageRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'AudioTranscription': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.AudioTranscription,
|
||||
request_deserializer=backend__pb2.TranscriptRequest.FromString,
|
||||
response_serializer=backend__pb2.TranscriptResult.SerializeToString,
|
||||
),
|
||||
'TTS': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.TTS,
|
||||
request_deserializer=backend__pb2.TTSRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'backend.Backend', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class Backend(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def Health(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
|
||||
backend__pb2.HealthMessage.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Predict(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def LoadModel(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
|
||||
backend__pb2.ModelOptions.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def PredictStream(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Embedding(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.EmbeddingResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def GenerateImage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
|
||||
backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def AudioTranscription(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
|
||||
backend__pb2.TranscriptRequest.SerializeToString,
|
||||
backend__pb2.TranscriptResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def TTS(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
|
||||
backend__pb2.TTSRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
49
extra/grpc/exllama/backend_pb2.py
Normal file
49
extra/grpc/exllama/backend_pb2.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: backend.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\x86\x06\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\x12\x14\n\x0cRopeFreqBase\x18% \x01(\x02\x12\x15\n\rRopeFreqScale\x18& \x01(\x02\x12\x1b\n\x13NegativePromptScale\x18\' \x01(\x02\x12\x16\n\x0eNegativePrompt\x18( \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\x0c\"\x9d\x04\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\x12\x14\n\x0cRopeFreqBase\x18\x11 \x01(\x02\x12\x15\n\rRopeFreqScale\x18\x12 \x01(\x02\x12\x12\n\nRMSNormEps\x18\x13 \x01(\x02\x12\x0c\n\x04NGQA\x18\x14 \x01(\x05\x12\x11\n\tModelFile\x18\x15 \x01(\t\x12\x0e\n\x06\x44\x65vice\x18\x16 \x01(\t\x12\x11\n\tUseTriton\x18\x17 \x01(\x08\x12\x15\n\rModelBaseName\x18\x18 \x01(\t\x12\x18\n\x10UseFastTokenizer\x18\x19 \x01(\x08\x12\x14\n\x0cPipelineType\x18\x1a \x01(\t\x12\x15\n\rSchedulerType\x18\x1b \x01(\t\x12\x0c\n\x04\x43UDA\x18\x1c \x01(\x08\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto'
|
||||
_globals['_HEALTHMESSAGE']._serialized_start=26
|
||||
_globals['_HEALTHMESSAGE']._serialized_end=41
|
||||
_globals['_PREDICTOPTIONS']._serialized_start=44
|
||||
_globals['_PREDICTOPTIONS']._serialized_end=818
|
||||
_globals['_REPLY']._serialized_start=820
|
||||
_globals['_REPLY']._serialized_end=844
|
||||
_globals['_MODELOPTIONS']._serialized_start=847
|
||||
_globals['_MODELOPTIONS']._serialized_end=1388
|
||||
_globals['_RESULT']._serialized_start=1390
|
||||
_globals['_RESULT']._serialized_end=1432
|
||||
_globals['_EMBEDDINGRESULT']._serialized_start=1434
|
||||
_globals['_EMBEDDINGRESULT']._serialized_end=1471
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_start=1473
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_end=1540
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_start=1542
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_end=1620
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_start=1622
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_end=1711
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_start=1714
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_end=1872
|
||||
_globals['_TTSREQUEST']._serialized_start=1874
|
||||
_globals['_TTSREQUEST']._serialized_end=1928
|
||||
_globals['_BACKEND']._serialized_start=1931
|
||||
_globals['_BACKEND']._serialized_end=2422
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
297
extra/grpc/exllama/backend_pb2_grpc.py
Normal file
297
extra/grpc/exllama/backend_pb2_grpc.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import backend_pb2 as backend__pb2
|
||||
|
||||
|
||||
class BackendStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Health = channel.unary_unary(
|
||||
'/backend.Backend/Health',
|
||||
request_serializer=backend__pb2.HealthMessage.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Predict = channel.unary_unary(
|
||||
'/backend.Backend/Predict',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.LoadModel = channel.unary_unary(
|
||||
'/backend.Backend/LoadModel',
|
||||
request_serializer=backend__pb2.ModelOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.PredictStream = channel.unary_stream(
|
||||
'/backend.Backend/PredictStream',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Embedding = channel.unary_unary(
|
||||
'/backend.Backend/Embedding',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.EmbeddingResult.FromString,
|
||||
)
|
||||
self.GenerateImage = channel.unary_unary(
|
||||
'/backend.Backend/GenerateImage',
|
||||
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.AudioTranscription = channel.unary_unary(
|
||||
'/backend.Backend/AudioTranscription',
|
||||
request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.TranscriptResult.FromString,
|
||||
)
|
||||
self.TTS = channel.unary_unary(
|
||||
'/backend.Backend/TTS',
|
||||
request_serializer=backend__pb2.TTSRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
|
||||
|
||||
class BackendServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def Health(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Predict(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Embedding(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GenerateImage(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def AudioTranscription(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def TTS(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_BackendServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Health': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Health,
|
||||
request_deserializer=backend__pb2.HealthMessage.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Predict': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Predict,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'LoadModel': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.LoadModel,
|
||||
request_deserializer=backend__pb2.ModelOptions.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'PredictStream': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.PredictStream,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Embedding': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Embedding,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
|
||||
),
|
||||
'GenerateImage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GenerateImage,
|
||||
request_deserializer=backend__pb2.GenerateImageRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'AudioTranscription': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.AudioTranscription,
|
||||
request_deserializer=backend__pb2.TranscriptRequest.FromString,
|
||||
response_serializer=backend__pb2.TranscriptResult.SerializeToString,
|
||||
),
|
||||
'TTS': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.TTS,
|
||||
request_deserializer=backend__pb2.TTSRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'backend.Backend', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class Backend(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def Health(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
|
||||
backend__pb2.HealthMessage.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Predict(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def LoadModel(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
|
||||
backend__pb2.ModelOptions.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def PredictStream(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Embedding(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.EmbeddingResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def GenerateImage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
|
||||
backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def AudioTranscription(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
|
||||
backend__pb2.TranscriptRequest.SerializeToString,
|
||||
backend__pb2.TranscriptResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def TTS(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
|
||||
backend__pb2.TTSRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
142
extra/grpc/exllama/exllama.py
Executable file
142
extra/grpc/exllama/exllama.py
Executable file
@@ -0,0 +1,142 @@
|
||||
#!/usr/bin/env python3
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os, glob
|
||||
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import version as torch_version
|
||||
from exllama.generator import ExLlamaGenerator
|
||||
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
|
||||
from exllama.tokenizer import ExLlamaTokenizer
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def generate(self,prompt, max_new_tokens):
|
||||
self.generator.end_beam_search()
|
||||
|
||||
# Tokenizing the input
|
||||
ids = self.generator.tokenizer.encode(prompt)
|
||||
|
||||
self.generator.gen_begin_reuse(ids)
|
||||
initial_len = self.generator.sequence[0].shape[0]
|
||||
has_leading_space = False
|
||||
decoded_text = ''
|
||||
for i in range(max_new_tokens):
|
||||
token = self.generator.gen_single_token()
|
||||
if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
|
||||
has_leading_space = True
|
||||
|
||||
decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
|
||||
if has_leading_space:
|
||||
decoded_text = ' ' + decoded_text
|
||||
|
||||
if token.item() == self.generator.tokenizer.eos_token_id:
|
||||
break
|
||||
return decoded_text
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
def LoadModel(self, request, context):
|
||||
try:
|
||||
# https://github.com/turboderp/exllama/blob/master/example_cfg.py
|
||||
model_directory = request.ModelFile
|
||||
|
||||
# Locate files we need within that directory
|
||||
tokenizer_path = os.path.join(model_directory, "tokenizer.model")
|
||||
model_config_path = os.path.join(model_directory, "config.json")
|
||||
st_pattern = os.path.join(model_directory, "*.safetensors")
|
||||
model_path = glob.glob(st_pattern)[0]
|
||||
|
||||
# Create config, model, tokenizer and generator
|
||||
|
||||
config = ExLlamaConfig(model_config_path) # create config from config.json
|
||||
config.model_path = model_path # supply path to model weights file
|
||||
|
||||
model = ExLlama(config) # create ExLlama instance and load the weights
|
||||
tokenizer = ExLlamaTokenizer(tokenizer_path) # create tokenizer from tokenizer model file
|
||||
|
||||
cache = ExLlamaCache(model, batch_size = 2) # create cache for inference
|
||||
generator = ExLlamaGenerator(model, tokenizer, cache) # create generator
|
||||
|
||||
self.generator= generator
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.cache = cache
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def Predict(self, request, context):
|
||||
penalty = 1.15
|
||||
if request.Penalty != 0.0:
|
||||
penalty = request.Penalty
|
||||
self.generator.settings.token_repetition_penalty_max = penalty
|
||||
self.generator.settings.temperature = request.Temperature
|
||||
self.generator.settings.top_k = request.TopK
|
||||
self.generator.settings.top_p = request.TopP
|
||||
|
||||
tokens = 512
|
||||
if request.Tokens != 0:
|
||||
tokens = request.Tokens
|
||||
|
||||
if self.cache.batch_size == 1:
|
||||
del self.cache
|
||||
self.cache = ExLlamaCache(self.model, batch_size=2)
|
||||
self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
|
||||
|
||||
t = self.generate(request.Prompt, tokens)
|
||||
|
||||
# Remove prompt from response if present
|
||||
if request.Prompt in t:
|
||||
t = t.replace(request.Prompt, "")
|
||||
|
||||
return backend_pb2.Result(message=bytes(t, encoding='utf-8'))
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
# Implement PredictStream RPC
|
||||
#for reply in some_data_generator():
|
||||
# yield reply
|
||||
# Not implemented yet
|
||||
return self.Predict(request, context)
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
||||
49
extra/grpc/huggingface/backend_pb2.py
Normal file
49
extra/grpc/huggingface/backend_pb2.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: backend.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\x86\x06\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\x12\x14\n\x0cRopeFreqBase\x18% \x01(\x02\x12\x15\n\rRopeFreqScale\x18& \x01(\x02\x12\x1b\n\x13NegativePromptScale\x18\' \x01(\x02\x12\x16\n\x0eNegativePrompt\x18( \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\x0c\"\x9d\x04\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\x12\x14\n\x0cRopeFreqBase\x18\x11 \x01(\x02\x12\x15\n\rRopeFreqScale\x18\x12 \x01(\x02\x12\x12\n\nRMSNormEps\x18\x13 \x01(\x02\x12\x0c\n\x04NGQA\x18\x14 \x01(\x05\x12\x11\n\tModelFile\x18\x15 \x01(\t\x12\x0e\n\x06\x44\x65vice\x18\x16 \x01(\t\x12\x11\n\tUseTriton\x18\x17 \x01(\x08\x12\x15\n\rModelBaseName\x18\x18 \x01(\t\x12\x18\n\x10UseFastTokenizer\x18\x19 \x01(\x08\x12\x14\n\x0cPipelineType\x18\x1a \x01(\t\x12\x15\n\rSchedulerType\x18\x1b \x01(\t\x12\x0c\n\x04\x43UDA\x18\x1c \x01(\x08\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto'
|
||||
_globals['_HEALTHMESSAGE']._serialized_start=26
|
||||
_globals['_HEALTHMESSAGE']._serialized_end=41
|
||||
_globals['_PREDICTOPTIONS']._serialized_start=44
|
||||
_globals['_PREDICTOPTIONS']._serialized_end=818
|
||||
_globals['_REPLY']._serialized_start=820
|
||||
_globals['_REPLY']._serialized_end=844
|
||||
_globals['_MODELOPTIONS']._serialized_start=847
|
||||
_globals['_MODELOPTIONS']._serialized_end=1388
|
||||
_globals['_RESULT']._serialized_start=1390
|
||||
_globals['_RESULT']._serialized_end=1432
|
||||
_globals['_EMBEDDINGRESULT']._serialized_start=1434
|
||||
_globals['_EMBEDDINGRESULT']._serialized_end=1471
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_start=1473
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_end=1540
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_start=1542
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_end=1620
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_start=1622
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_end=1711
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_start=1714
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_end=1872
|
||||
_globals['_TTSREQUEST']._serialized_start=1874
|
||||
_globals['_TTSREQUEST']._serialized_end=1928
|
||||
_globals['_BACKEND']._serialized_start=1931
|
||||
_globals['_BACKEND']._serialized_end=2422
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
297
extra/grpc/huggingface/backend_pb2_grpc.py
Normal file
297
extra/grpc/huggingface/backend_pb2_grpc.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import backend_pb2 as backend__pb2
|
||||
|
||||
|
||||
class BackendStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Health = channel.unary_unary(
|
||||
'/backend.Backend/Health',
|
||||
request_serializer=backend__pb2.HealthMessage.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Predict = channel.unary_unary(
|
||||
'/backend.Backend/Predict',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.LoadModel = channel.unary_unary(
|
||||
'/backend.Backend/LoadModel',
|
||||
request_serializer=backend__pb2.ModelOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.PredictStream = channel.unary_stream(
|
||||
'/backend.Backend/PredictStream',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Embedding = channel.unary_unary(
|
||||
'/backend.Backend/Embedding',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.EmbeddingResult.FromString,
|
||||
)
|
||||
self.GenerateImage = channel.unary_unary(
|
||||
'/backend.Backend/GenerateImage',
|
||||
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.AudioTranscription = channel.unary_unary(
|
||||
'/backend.Backend/AudioTranscription',
|
||||
request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.TranscriptResult.FromString,
|
||||
)
|
||||
self.TTS = channel.unary_unary(
|
||||
'/backend.Backend/TTS',
|
||||
request_serializer=backend__pb2.TTSRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
|
||||
|
||||
class BackendServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def Health(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Predict(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Embedding(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GenerateImage(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def AudioTranscription(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def TTS(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_BackendServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Health': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Health,
|
||||
request_deserializer=backend__pb2.HealthMessage.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Predict': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Predict,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'LoadModel': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.LoadModel,
|
||||
request_deserializer=backend__pb2.ModelOptions.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'PredictStream': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.PredictStream,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Embedding': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Embedding,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
|
||||
),
|
||||
'GenerateImage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GenerateImage,
|
||||
request_deserializer=backend__pb2.GenerateImageRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'AudioTranscription': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.AudioTranscription,
|
||||
request_deserializer=backend__pb2.TranscriptRequest.FromString,
|
||||
response_serializer=backend__pb2.TranscriptResult.SerializeToString,
|
||||
),
|
||||
'TTS': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.TTS,
|
||||
request_deserializer=backend__pb2.TTSRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'backend.Backend', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class Backend(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def Health(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
|
||||
backend__pb2.HealthMessage.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Predict(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def LoadModel(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
|
||||
backend__pb2.ModelOptions.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def PredictStream(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Embedding(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.EmbeddingResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def GenerateImage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
|
||||
backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def AudioTranscription(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
|
||||
backend__pb2.TranscriptRequest.SerializeToString,
|
||||
backend__pb2.TranscriptResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def TTS(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
|
||||
backend__pb2.TTSRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
66
extra/grpc/huggingface/huggingface.py
Executable file
66
extra/grpc/huggingface/huggingface.py
Executable file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python3
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
def LoadModel(self, request, context):
|
||||
model_name = request.Model
|
||||
try:
|
||||
self.model = SentenceTransformer(model_name)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
# Implement your logic here for the LoadModel service
|
||||
# Replace this with your desired response
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
def Embedding(self, request, context):
|
||||
# Implement your logic here for the Embedding service
|
||||
# Replace this with your desired response
|
||||
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
|
||||
sentence_embeddings = self.model.encode(request.Embeddings)
|
||||
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
||||
4
extra/requirements.txt
Normal file
4
extra/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
sentence_transformers
|
||||
grpcio
|
||||
google
|
||||
protobuf
|
||||
56
go.mod
56
go.mod
@@ -3,30 +3,35 @@ module github.com/go-skynet/LocalAI
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674
|
||||
github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df
|
||||
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e
|
||||
github.com/go-audio/wav v1.1.0
|
||||
github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa
|
||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230607105116-6069103f54b9
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230626202628-8e31841dcddc
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230628194133-42ba44838369
|
||||
github.com/gofiber/fiber/v2 v2.47.0
|
||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230802220037-50cee7712066
|
||||
github.com/gofiber/fiber/v2 v2.48.0
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hpcloud/tail v1.0.0
|
||||
github.com/imdario/mergo v0.3.16
|
||||
github.com/json-iterator/go v1.1.12
|
||||
github.com/mholt/archiver/v3 v3.5.1
|
||||
github.com/mudler/go-ggllm.cpp v0.0.0-20230709223052-862477d16eef
|
||||
github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d
|
||||
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230628182915-a67f8132e165
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230811181453-4d855afe973a
|
||||
github.com/onsi/ginkgo/v2 v2.11.0
|
||||
github.com/onsi/gomega v1.27.8
|
||||
github.com/otiai10/openaigo v1.4.0
|
||||
github.com/rs/zerolog v1.29.1
|
||||
github.com/sashabaranov/go-openai v1.11.3
|
||||
github.com/swaggo/swag v1.16.1
|
||||
github.com/tmc/langchaingo v0.0.0-20230628165432-e510561c17f9
|
||||
github.com/onsi/gomega v1.27.10
|
||||
github.com/otiai10/openaigo v1.5.2
|
||||
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5
|
||||
github.com/rs/zerolog v1.30.0
|
||||
github.com/sashabaranov/go-openai v1.14.1
|
||||
github.com/tmc/langchaingo v0.0.0-20230811231558-fd8b7f099537
|
||||
github.com/urfave/cli/v2 v2.25.7
|
||||
github.com/valyala/fasthttp v1.48.0
|
||||
google.golang.org/grpc v1.57.0
|
||||
google.golang.org/protobuf v1.31.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -34,8 +39,10 @@ require (
|
||||
require (
|
||||
github.com/dlclark/regexp2 v1.8.1 // indirect
|
||||
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/golang/snappy v0.0.2 // indirect
|
||||
github.com/klauspost/pgzip v1.2.5 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/nwaples/rardecode v1.1.0 // indirect
|
||||
@@ -43,44 +50,35 @@ require (
|
||||
github.com/pkoukk/tiktoken-go v0.1.2 // indirect
|
||||
github.com/ulikunitz/xz v0.5.9 // indirect
|
||||
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
gopkg.in/fsnotify.v1 v1.4.7 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/KyleBanks/depth v1.2.1 // indirect
|
||||
github.com/PuerkitoBio/purell v1.1.1 // indirect
|
||||
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
|
||||
github.com/andybalholm/brotli v1.0.5 // indirect
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
|
||||
github.com/go-audio/audio v1.0.0 // indirect
|
||||
github.com/go-audio/riff v1.0.0 // indirect
|
||||
github.com/go-logr/logr v1.2.4 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.19.5 // indirect
|
||||
github.com/go-openapi/jsonreference v0.19.6 // indirect
|
||||
github.com/go-openapi/spec v0.20.4 // indirect
|
||||
github.com/go-openapi/swag v0.22.3 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||
github.com/hashicorp/errwrap v1.0.0 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/klauspost/compress v1.16.3 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.14 // indirect
|
||||
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760
|
||||
github.com/otiai10/mint v1.6.1 // indirect
|
||||
github.com/philhofer/fwd v1.1.2 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 // indirect
|
||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
|
||||
github.com/tinylib/msgp v1.1.8 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/tcplisten v1.0.0 // indirect
|
||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
||||
golang.org/x/net v0.10.0 // indirect
|
||||
golang.org/x/sys v0.9.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
golang.org/x/net v0.12.0 // indirect
|
||||
golang.org/x/sys v0.10.0 // indirect
|
||||
golang.org/x/text v0.11.0 // indirect
|
||||
golang.org/x/tools v0.9.3 // indirect
|
||||
)
|
||||
|
||||
254
go.sum
254
go.sum
@@ -1,9 +1,3 @@
|
||||
github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc=
|
||||
github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
|
||||
github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI=
|
||||
github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
|
||||
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M=
|
||||
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
|
||||
github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
|
||||
github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
|
||||
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
@@ -21,11 +15,14 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
|
||||
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674 h1:G70Yf/QOCEL1v24idWnGd6rJsbqiGkJAJnMaWaolzEg=
|
||||
github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM=
|
||||
github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df h1:qVcBEZlvp5A1gGWNJj02xyDtbsUI2hohlQMSB1fgER4=
|
||||
github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM=
|
||||
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L7HYpRu/0lE3e0BaElwnNO1qkNQxBY=
|
||||
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s=
|
||||
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
|
||||
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27 h1:boeMTUUBtnLU8JElZJHXrsUzROJar9/t6vGOFjkrhhI=
|
||||
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ=
|
||||
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
|
||||
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
|
||||
@@ -36,43 +33,44 @@ github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
|
||||
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
||||
github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY=
|
||||
github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
||||
github.com/go-openapi/jsonreference v0.19.6 h1:UBIxjkht+AWIgYzCDSv2GN+E/togfwXUJFRTWhl2Jjs=
|
||||
github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns=
|
||||
github.com/go-openapi/spec v0.20.4 h1:O8hJrt0UMnhHcluhIdUgCLRWyM2x7QkBXRvOs7m+O1M=
|
||||
github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I=
|
||||
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
|
||||
github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM=
|
||||
github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
|
||||
github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
|
||||
github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa h1:gxr68r/6EWroay4iI81jxqGCDbKotY4+CiwdUkBz2NQ=
|
||||
github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA=
|
||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230607105116-6069103f54b9 h1:wRGbDwNwPmSzoXVw/HLzXY4blpRvPWg7QW2OA0WKezA=
|
||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230607105116-6069103f54b9/go.mod h1:pXKCpYYXujMeAvgJHU6WoMfvYbr84563+J8+Ebkyr5U=
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230617123349-32b9223ccdb1 h1:jVGgzDSfpjD/0jl/ChpGI+O4EHSAeeU6DK7IyhH8PK8=
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230617123349-32b9223ccdb1/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A=
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230620192816-a459d2726792 h1:rozZ9gWGzq0ZhBsNCWqfLTRCebaxwTsxLMnflwe6rDU=
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230620192816-a459d2726792/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A=
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230626202628-8e31841dcddc h1:SrNxH4U8W6cqurbxpXxm9rzifeDsCgecRT73kT0BRq0=
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230626202628-8e31841dcddc/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230616223721-7ad833b67070 h1:T771FjB1yQw8j4P5x4ayFrUPNTglzxRIqDjaNkMVIME=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230616223721-7ad833b67070/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230626215901-f104111358e8 h1:Knh5QUvI/68erb/yWtrVa/3hvoQdENF2dH0hL2HNPrI=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230626215901-f104111358e8/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230627195533-582753605210 h1:9bm+vsiR3UI7xlU0G0cMU2Swq78RysoFVkSONvrujF8=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230627195533-582753605210/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230628194133-42ba44838369 h1:lSX1NWzRvRS2MlACvyvVVUnqXhKiuMAoN3DO5TbCe8M=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230628194133-42ba44838369/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
|
||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 h1:yXvc7QfGtoZ51tUW/YVjoTwAfh8HG88XU7UOrbNlz5Y=
|
||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1/go.mod h1:fYjkCDRzC+oRLHSjQoajmYK6AmeJnmEanV27CClAcDc=
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e h1:4reMY29i1eOZaRaSTMPNyXI7X8RMNxCTfDDBXYzrbr0=
|
||||
github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230709163512-6c97625cca76 h1:NRdxo2MKi8qhWZXxu6CIZOkdH+LBERFz1kk22U1FD3k=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230709163512-6c97625cca76/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230724222459-562d2b5a7119 h1:FeUSk5yMHT7J7jeCQKAOs4x5LRNSYH0SR6djM/i1jcc=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230724222459-562d2b5a7119/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230727163958-6ba16de8e965 h1:2MO/rABKpkXnnKQ3Ar90aqhnlMEejE9gnKG6bafv+ow=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230727163958-6ba16de8e965/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230729200103-8c51308e42d7 h1:1uBwholTaJ8Lva8ySJjT4jNaCDAh+MJXtsbZBbQq9lA=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230729200103-8c51308e42d7/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230802220037-50cee7712066 h1:v4Js+yEdgY9IV7n35M+5MELLxlOMp3qC5whZm5YTLjI=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230802220037-50cee7712066/go.mod h1:fiJBto+Le1XLtD/cID5SAKs8cKE7wFXJKfTT3wvPQRA=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gofiber/fiber/v2 v2.47.0 h1:EN5lHVCc+Pyqh5OEsk8fzRiifgwpbrP0rulQ4iNf3fs=
|
||||
github.com/gofiber/fiber/v2 v2.47.0/go.mod h1:mbFMVN1lQuzziTkkakgtKKdjfsXSw9BKR5lmcNksUoU=
|
||||
github.com/gofiber/fiber/v2 v2.48.0 h1:cRVMCb9aUJDsyHxGFLwz/sGzDggdailZZyptU9F9cU0=
|
||||
github.com/gofiber/fiber/v2 v2.48.0/go.mod h1:xqJgfqrc23FJuqGOW6DVgi3HyZEm2Mn9pRqUb2kHSX8=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
|
||||
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
|
||||
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
|
||||
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
|
||||
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
|
||||
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw=
|
||||
github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
@@ -85,11 +83,11 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
|
||||
github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
|
||||
@@ -99,16 +97,12 @@ github.com/klauspost/compress v1.16.3/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQs
|
||||
github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
|
||||
github.com/klauspost/pgzip v1.2.5 h1:qnWYvvKqedOF2ulHpMG72XQol4ILEJ8k2wwRl/Km8oE=
|
||||
github.com/klauspost/pgzip v1.2.5/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA=
|
||||
github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
@@ -124,32 +118,51 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/mudler/go-ggllm.cpp v0.0.0-20230709223052-862477d16eef h1:OJZtJ5vYhlkTJI0RHIl62kOkhiINQEhZgsXlwmmNDhM=
|
||||
github.com/mudler/go-ggllm.cpp v0.0.0-20230709223052-862477d16eef/go.mod h1:00giAi/vwF8LX29JBjkPQhtASsivPnGNzB6sdmk8JGE=
|
||||
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU=
|
||||
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
|
||||
github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d h1:/lAg9vPAAU+s35cDMCx1IyeMn+4OYfCBPqi08Q8vXDg=
|
||||
github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d/go.mod h1:HGGAOJhipApckwNV8ZTliRJqxctUv3xRY+zbQEwuytc=
|
||||
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks=
|
||||
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230620230702-09ae04cee90c h1:axNtjd5k6Xs4Ck7B7VRRQu6q5lQzTsjdWmaJkDADopU=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230620230702-09ae04cee90c/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230628182915-a67f8132e165 h1:zcnIdoSeLueTDxUD2A1qnyaSp8uh0Ay7OgHeBwpxSeg=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230628182915-a67f8132e165/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230714185456-cfd70b69fcf5 h1:bmQnxyKiqCu8i2y/N/Sf0coWoG2/Ed12YGQeb7lTnjo=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230714185456-cfd70b69fcf5/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230725212419-9100b2ef6fb9 h1:/oRwZhulKTU8LpPD2fXi2o2kdlTutQjYWDVMkrv14po=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230725212419-9100b2ef6fb9/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230727161923-39acbc837816 h1:hRi7hpDUuaO0dB4NZ8eyaeD2fRar6CPyNAARsO5DhzA=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230727161923-39acbc837816/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230731161838-cbdcde8b7586 h1:WVEMSZMyHFe68PN204c3Fdk5g2lZouPvbU9/2zkPpWc=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230731161838-cbdcde8b7586/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230802145814-c449b71b56de h1:E5EGczxEAcbaO8yqj074MQxU609QbtB6in3qTOW1EFo=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230802145814-c449b71b56de/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230807175413-0f2bb506a8ee h1:Y/j+GNytyncmDnAEuDZwzkYC9nzUPvXJPF+nntQG0VU=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230807175413-0f2bb506a8ee/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230811181453-4d855afe973a h1:bX26Zfwh72ug2aZTEwFISTMEJ56Wa/4KqboidD+g92A=
|
||||
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230811181453-4d855afe973a/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
|
||||
github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ=
|
||||
github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
|
||||
github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc=
|
||||
github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0=
|
||||
github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
|
||||
github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM=
|
||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
||||
github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
|
||||
github.com/onsi/gomega v1.27.8/go.mod h1:2J8vzI/s+2shY9XHRApDkdgPo1TKT7P2u6fXeJKFnNQ=
|
||||
github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks=
|
||||
github.com/otiai10/mint v1.5.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM=
|
||||
github.com/otiai10/mint v1.6.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM=
|
||||
github.com/otiai10/openaigo v1.2.0 h1:Whq+uvgqw8NdIsVdixtBKCAI6OdfCJiGPlhUnYJQ6Ag=
|
||||
github.com/otiai10/openaigo v1.2.0/go.mod h1:792bx6AWTS61weDi2EzKpHHnTF4eDMAlJ5GvAk/mgPg=
|
||||
github.com/otiai10/openaigo v1.4.0 h1:BeacKb2Q5bVejjOKHFJxL2WFYal3QxwkrKtKuoU5LNU=
|
||||
github.com/otiai10/openaigo v1.4.0/go.mod h1:kIaXc3V+Xy5JLplcBxehVyGYDtufHp3PFPy04jOwOAI=
|
||||
github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
|
||||
github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw=
|
||||
github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0=
|
||||
github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI=
|
||||
github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M=
|
||||
github.com/otiai10/mint v1.6.1 h1:kgbTJmOpp/0ce7hk3H8jiSuR0MXmpwWRfqUdKww17qg=
|
||||
github.com/otiai10/openaigo v1.5.2 h1:YnNDisZmA4syArF3IxMCIrfgZOq30PLV219gPY7n2z8=
|
||||
github.com/otiai10/openaigo v1.5.2/go.mod h1:kIaXc3V+Xy5JLplcBxehVyGYDtufHp3PFPy04jOwOAI=
|
||||
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI=
|
||||
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
|
||||
github.com/pierrec/lz4/v4 v4.1.2 h1:qvY3YFXRQE/XB8MlLzJH7mSzBs74eA2gg52YTk6jUPM=
|
||||
github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -160,37 +173,34 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc=
|
||||
github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU=
|
||||
github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c=
|
||||
github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w=
|
||||
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sashabaranov/go-openai v1.11.3 h1:bvwWF8hj4UhPlswBdL9/IfOpaHXfzGCJO8WY8ml9sGc=
|
||||
github.com/sashabaranov/go-openai v1.11.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4=
|
||||
github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94/go.mod h1:90zrgN3D/WJsDd1iXHT96alCoN2KJo6/4x1DZC3wZs8=
|
||||
github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4=
|
||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk=
|
||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g=
|
||||
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
|
||||
github.com/sashabaranov/go-openai v1.14.0 h1:D1yAB+DHElgbJFdYyjxfTWMFzhddn+PwZmkQ039L7mQ=
|
||||
github.com/sashabaranov/go-openai v1.14.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/sashabaranov/go-openai v1.14.1 h1:jqfkdj8XHnBF84oi2aNtT8Ktp3EJ0MfuVjvcMkfI0LA=
|
||||
github.com/sashabaranov/go-openai v1.14.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||
github.com/swaggo/swag v1.16.1 h1:fTNRhKstPKxcnoKsytm4sahr8FaYzUcT7i1/3nd/fBg=
|
||||
github.com/swaggo/swag v1.16.1/go.mod h1:9/LMvHycG3NFHfR6LwvikHv5iFvmPADQ359cKikGxto=
|
||||
github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw=
|
||||
github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0=
|
||||
github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw=
|
||||
github.com/tmc/langchaingo v0.0.0-20230616220619-1b3da4433944 h1:EE9fvNENTdRc/yI/1zAs7VFbmDk6JZ7EbBIFl+TsCm0=
|
||||
github.com/tmc/langchaingo v0.0.0-20230616220619-1b3da4433944/go.mod h1:6l1WoyqVDwkv7cFlY3gfcTv8yVowVyuutKv8PGlQCWI=
|
||||
github.com/tmc/langchaingo v0.0.0-20230625081011-4d9d55dbcaba h1:NpAI9C0y9T4jwP7XFShwYJKGf/ggyCgZEtL/7lLRPwE=
|
||||
github.com/tmc/langchaingo v0.0.0-20230625081011-4d9d55dbcaba/go.mod h1:tz9cjA9BW8/lWx/T5njr3ZLHK/dfPyr/0ICSMThmY2g=
|
||||
github.com/tmc/langchaingo v0.0.0-20230625234550-7ea734523e39 h1:SpOEFXx5xXLypFnwNRQj7yOC3rMvSylGA5BQW/FAwYc=
|
||||
github.com/tmc/langchaingo v0.0.0-20230625234550-7ea734523e39/go.mod h1:tz9cjA9BW8/lWx/T5njr3ZLHK/dfPyr/0ICSMThmY2g=
|
||||
github.com/tmc/langchaingo v0.0.0-20230627220614-633853b5ac3b h1:xUxtya/3KRDn1rcCVZucp2KhjdqSZat9j0hOshSVh2Q=
|
||||
github.com/tmc/langchaingo v0.0.0-20230627220614-633853b5ac3b/go.mod h1:F1k7uRBLM8jMMEPV3dVtWVNc+W91nxOBRKbJWM/LwpM=
|
||||
github.com/tmc/langchaingo v0.0.0-20230628165432-e510561c17f9 h1:BooyHg3f058lrPcTLdfC7HTfjO5OGZAgwciQJ5e85l0=
|
||||
github.com/tmc/langchaingo v0.0.0-20230628165432-e510561c17f9/go.mod h1:F1k7uRBLM8jMMEPV3dVtWVNc+W91nxOBRKbJWM/LwpM=
|
||||
github.com/tmc/langchaingo v0.0.0-20230713201705-dcf7ecdc8ac8 h1:wdJigYmmIRCuXhCkADDr53Oa1fp/WlxCPoVXR2r7GrU=
|
||||
github.com/tmc/langchaingo v0.0.0-20230713201705-dcf7ecdc8ac8/go.mod h1:mTzgQfAGwmBz2hhQELZfu2bwsbHwyKHA6IHOa+9LDFg=
|
||||
github.com/tmc/langchaingo v0.0.0-20230726025230-7d5f9fd5e90a h1:I/2JSuYXkWaVVLSZmrPfrgbvvvPR0IaulZcB0Iu8oVI=
|
||||
github.com/tmc/langchaingo v0.0.0-20230726025230-7d5f9fd5e90a/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
|
||||
github.com/tmc/langchaingo v0.0.0-20230729232647-7df4fe5fb8fe h1:+XVrCjh3rPibfISkUFG2Ck5NLKODQ9cFdmraFye1bGA=
|
||||
github.com/tmc/langchaingo v0.0.0-20230729232647-7df4fe5fb8fe/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
|
||||
github.com/tmc/langchaingo v0.0.0-20230731024823-8f101609f600 h1:SABuIthjhIXEsxnokuA16CZOxxdW9XohIHQqd/go8Nc=
|
||||
github.com/tmc/langchaingo v0.0.0-20230731024823-8f101609f600/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
|
||||
github.com/tmc/langchaingo v0.0.0-20230802030916-271e9bd7e7c5 h1:js7vYDJGzUGVSt0YlIusUc5BXYVECu3LUI/asby5Ggo=
|
||||
github.com/tmc/langchaingo v0.0.0-20230802030916-271e9bd7e7c5/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
|
||||
github.com/tmc/langchaingo v0.0.0-20230811231558-fd8b7f099537 h1:vkeNjlW+0Xiw2XizMHoQuLG8pg6AN1hU8zJuMV9GQBc=
|
||||
github.com/tmc/langchaingo v0.0.0-20230811231558-fd8b7f099537/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
|
||||
github.com/ulikunitz/xz v0.5.8/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/ulikunitz/xz v0.5.9 h1:RsKRIA2MO8x56wkkcd3LbtcE/uMszhb6DpRf+3uwa3I=
|
||||
github.com/ulikunitz/xz v0.5.9/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
@@ -207,74 +217,92 @@ github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMx
|
||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
|
||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
|
||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
|
||||
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
|
||||
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
|
||||
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ=
|
||||
golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
|
||||
golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A=
|
||||
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
|
||||
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54 h1:9NWlQfY2ePejTmfwUH1OWwmznFa+0kKcHGPDvcPza9M=
|
||||
google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54/go.mod h1:zqTuNwFlFRsw5zIts5VnzLQxSRqh+CGOTVMlYbY0Eyk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA=
|
||||
google.golang.org/grpc v1.56.2 h1:fVRFRnXvU+x6C4IlHZewvJOVHoOv1TUuQyoRsYnB4bI=
|
||||
google.golang.org/grpc v1.56.2/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s=
|
||||
google.golang.org/grpc v1.57.0 h1:kfzNeI/klCGD2YPMUlaGNT3pxvYfga7smW3Vth8Zsiw=
|
||||
google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
gopkg.in/op/go-logging.v1 v1.0.0-20160211212156-b2cb9fa56473/go.mod h1:N1eN2tsCx0Ydtgjl4cqmbRCsY4/+z4cYDeqwZTk6zog=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -6,5 +6,5 @@ var Version = ""
|
||||
var Commit = ""
|
||||
|
||||
func PrintableVersion() string {
|
||||
return fmt.Sprintf("LocalAI %s (%s)", Version, Commit)
|
||||
return fmt.Sprintf("%s (%s)", Version, Commit)
|
||||
}
|
||||
|
||||
90
main.go
90
main.go
@@ -1,14 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
api "github.com/go-skynet/LocalAI/api"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/internal"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -17,6 +18,13 @@ import (
|
||||
|
||||
func main() {
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
||||
// clean up process
|
||||
go func() {
|
||||
c := make(chan os.Signal, 1) // we need to reserve to buffer size 1, so the notifier are not blocked
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
<-c
|
||||
os.Exit(1)
|
||||
}()
|
||||
|
||||
path, err := os.Getwd()
|
||||
if err != nil {
|
||||
@@ -33,6 +41,10 @@ func main() {
|
||||
Name: "f16",
|
||||
EnvVars: []string{"F16"},
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "autoload-galleries",
|
||||
EnvVars: []string{"AUTOLOAD_GALLERIES"},
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "debug",
|
||||
EnvVars: []string{"DEBUG"},
|
||||
@@ -101,6 +113,11 @@ func main() {
|
||||
EnvVars: []string{"BACKEND_ASSETS_PATH"},
|
||||
Value: "/tmp/localai/backend_data",
|
||||
},
|
||||
&cli.StringSliceFlag{
|
||||
Name: "external-grpc-backends",
|
||||
Usage: "A list of external grpc backends",
|
||||
EnvVars: []string{"EXTERNAL_GRPC_BACKENDS"},
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "context-size",
|
||||
Usage: "Default context size of the model",
|
||||
@@ -113,6 +130,11 @@ func main() {
|
||||
EnvVars: []string{"UPLOAD_LIMIT"},
|
||||
Value: 15,
|
||||
},
|
||||
&cli.StringSliceFlag{
|
||||
Name: "api-keys",
|
||||
Usage: "List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys.",
|
||||
EnvVars: []string{"API_KEY"},
|
||||
},
|
||||
},
|
||||
Description: `
|
||||
LocalAI is a drop-in replacement OpenAI API which runs inference locally.
|
||||
@@ -126,34 +148,46 @@ Some of the models compatible are:
|
||||
- Alpaca
|
||||
- StableLM (ggml quantized)
|
||||
|
||||
It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
|
||||
For a list of compatible model, check out: https://localai.io/model-compatibility/index.html
|
||||
`,
|
||||
UsageText: `local-ai [options]`,
|
||||
Copyright: "go-skynet authors",
|
||||
Copyright: "Ettore Di Giacinto",
|
||||
Action: func(ctx *cli.Context) error {
|
||||
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
|
||||
galls := ctx.String("galleries")
|
||||
var galleries []gallery.Gallery
|
||||
err := json.Unmarshal([]byte(galls), &galleries)
|
||||
fmt.Println(err)
|
||||
app, err := api.App(
|
||||
api.WithConfigFile(ctx.String("config-file")),
|
||||
api.WithGalleries(galleries),
|
||||
api.WithJSONStringPreload(ctx.String("preload-models")),
|
||||
api.WithYAMLConfigPreload(ctx.String("preload-models-config")),
|
||||
api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
|
||||
api.WithContextSize(ctx.Int("context-size")),
|
||||
api.WithDebug(ctx.Bool("debug")),
|
||||
api.WithImageDir(ctx.String("image-path")),
|
||||
api.WithAudioDir(ctx.String("audio-path")),
|
||||
api.WithF16(ctx.Bool("f16")),
|
||||
api.WithDisableMessage(false),
|
||||
api.WithCors(ctx.Bool("cors")),
|
||||
api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
|
||||
api.WithThreads(ctx.Int("threads")),
|
||||
api.WithBackendAssets(backendAssets),
|
||||
api.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
||||
api.WithUploadLimitMB(ctx.Int("upload-limit")))
|
||||
|
||||
opts := []options.AppOption{
|
||||
options.WithConfigFile(ctx.String("config-file")),
|
||||
options.WithJSONStringPreload(ctx.String("preload-models")),
|
||||
options.WithYAMLConfigPreload(ctx.String("preload-models-config")),
|
||||
options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
|
||||
options.WithContextSize(ctx.Int("context-size")),
|
||||
options.WithDebug(ctx.Bool("debug")),
|
||||
options.WithImageDir(ctx.String("image-path")),
|
||||
options.WithAudioDir(ctx.String("audio-path")),
|
||||
options.WithF16(ctx.Bool("f16")),
|
||||
options.WithStringGalleries(ctx.String("galleries")),
|
||||
options.WithDisableMessage(false),
|
||||
options.WithCors(ctx.Bool("cors")),
|
||||
options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
|
||||
options.WithThreads(ctx.Int("threads")),
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
||||
options.WithUploadLimitMB(ctx.Int("upload-limit")),
|
||||
options.WithApiKeys(ctx.StringSlice("api-keys")),
|
||||
}
|
||||
|
||||
externalgRPC := ctx.StringSlice("external-grpc-backends")
|
||||
// split ":" to get backend name and the uri
|
||||
for _, v := range externalgRPC {
|
||||
backend := v[:strings.IndexByte(v, ':')]
|
||||
uri := v[strings.IndexByte(v, ':')+1:]
|
||||
opts = append(opts, options.WithExternalBackend(backend, uri))
|
||||
}
|
||||
|
||||
if ctx.Bool("autoload-galleries") {
|
||||
opts = append(opts, options.EnableGalleriesAutoload)
|
||||
}
|
||||
|
||||
app, err := api.App(opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/imdario/mergo"
|
||||
@@ -17,19 +18,17 @@ type Gallery struct {
|
||||
|
||||
// Installs a model from the gallery (galleryname@modelname)
|
||||
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||
models, err := AvailableGalleryModels(galleries, basePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applyModel := func(model *GalleryModel) error {
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
config, err := GetGalleryConfigFromURL(model.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
installName := model.Name
|
||||
if req.Name != "" {
|
||||
model.Name = req.Name
|
||||
installName = req.Name
|
||||
}
|
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||
@@ -40,20 +39,62 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
||||
return err
|
||||
}
|
||||
|
||||
if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil {
|
||||
if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
|
||||
return applyModel(model)
|
||||
models, err := AvailableGalleryModels(galleries, basePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
model, err := FindGallery(models, name)
|
||||
if err != nil {
|
||||
var err2 error
|
||||
model, err2 = FindGallery(models, strings.ToLower(name))
|
||||
if err2 != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("no model found with name %q", name)
|
||||
return applyModel(model)
|
||||
}
|
||||
|
||||
func FindGallery(models []*GalleryModel, name string) (*GalleryModel, error) {
|
||||
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
for _, model := range models {
|
||||
if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no gallery found with name %q", name)
|
||||
}
|
||||
|
||||
// InstallModelFromGalleryByName loads a model from the gallery by specifying only the name (first match wins)
|
||||
func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||
models, err := AvailableGalleryModels(galleries, basePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
var model *GalleryModel
|
||||
for _, m := range models {
|
||||
if name == m.Name || m.Name == strings.ToLower(name) {
|
||||
model = m
|
||||
}
|
||||
}
|
||||
|
||||
if model == nil {
|
||||
return fmt.Errorf("no model found with name %q", name)
|
||||
}
|
||||
|
||||
return InstallModelFromGallery(galleries, fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name), basePath, req, downloadStatus)
|
||||
}
|
||||
|
||||
// List available models
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package gallery_test
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
@@ -50,7 +49,7 @@ var _ = Describe("Model test", func() {
|
||||
}}
|
||||
out, err := yaml.Marshal(gallery)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = ioutil.WriteFile(filepath.Join(tempdir, "gallery_simple.yaml"), out, 0644)
|
||||
err = os.WriteFile(filepath.Join(tempdir, "gallery_simple.yaml"), out, 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
galleries := []Gallery{
|
||||
|
||||
58
pkg/grammar/functions.go
Normal file
58
pkg/grammar/functions.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Function struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
type Functions []Function
|
||||
|
||||
func (f Functions) ToJSONStructure() JSONFunctionStructure {
|
||||
js := JSONFunctionStructure{}
|
||||
for _, function := range f {
|
||||
// t := function.Parameters["type"]
|
||||
//tt := t.(string)
|
||||
|
||||
properties := function.Parameters["properties"]
|
||||
defs := function.Parameters["$defs"]
|
||||
dat, _ := json.Marshal(properties)
|
||||
dat2, _ := json.Marshal(defs)
|
||||
prop := map[string]interface{}{}
|
||||
defsD := map[string]interface{}{}
|
||||
|
||||
json.Unmarshal(dat, &prop)
|
||||
json.Unmarshal(dat2, &defsD)
|
||||
if js.Defs == nil {
|
||||
js.Defs = defsD
|
||||
}
|
||||
js.OneOf = append(js.OneOf, Item{
|
||||
Type: "object",
|
||||
Properties: Properties{
|
||||
Function: FunctionName{Const: function.Name},
|
||||
Arguments: Argument{
|
||||
Type: "object",
|
||||
Properties: prop,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
return js
|
||||
}
|
||||
|
||||
// Select returns a list of functions containing the function with the given name
|
||||
func (f Functions) Select(name string) Functions {
|
||||
var funcs Functions
|
||||
|
||||
for _, f := range f {
|
||||
if f.Name == name {
|
||||
funcs = []Function{f}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return funcs
|
||||
}
|
||||
63
pkg/grammar/functions_test.go
Normal file
63
pkg/grammar/functions_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package grammar_test
|
||||
|
||||
import (
|
||||
. "github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("LocalAI grammar functions", func() {
|
||||
Describe("ToJSONStructure()", func() {
|
||||
It("converts a list of functions to a JSON structure that can be parsed to a grammar", func() {
|
||||
var functions Functions = []Function{
|
||||
{
|
||||
Name: "create_event",
|
||||
Parameters: map[string]interface{}{
|
||||
"properties": map[string]interface{}{
|
||||
"event_name": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
"event_date": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "search",
|
||||
Parameters: map[string]interface{}{
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
js := functions.ToJSONStructure()
|
||||
Expect(len(js.OneOf)).To(Equal(2))
|
||||
Expect(js.OneOf[0].Properties.Function.Const).To(Equal("create_event"))
|
||||
Expect(js.OneOf[0].Properties.Arguments.Properties["event_name"].(map[string]interface{})["type"]).To(Equal("string"))
|
||||
Expect(js.OneOf[0].Properties.Arguments.Properties["event_date"].(map[string]interface{})["type"]).To(Equal("string"))
|
||||
Expect(js.OneOf[1].Properties.Function.Const).To(Equal("search"))
|
||||
Expect(js.OneOf[1].Properties.Arguments.Properties["query"].(map[string]interface{})["type"]).To(Equal("string"))
|
||||
})
|
||||
})
|
||||
Context("Select()", func() {
|
||||
It("selects one of the functions and returns a list containing only the selected one", func() {
|
||||
var functions Functions = []Function{
|
||||
{
|
||||
Name: "create_event",
|
||||
},
|
||||
{
|
||||
Name: "search",
|
||||
},
|
||||
}
|
||||
|
||||
functions = functions.Select("create_event")
|
||||
Expect(len(functions)).To(Equal(1))
|
||||
Expect(functions[0].Name).To(Equal("create_event"))
|
||||
})
|
||||
})
|
||||
})
|
||||
13
pkg/grammar/grammar_suite_test.go
Normal file
13
pkg/grammar/grammar_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestGrammar(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Grammar test suite")
|
||||
}
|
||||
254
pkg/grammar/json_schema.go
Normal file
254
pkg/grammar/json_schema.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package grammar
|
||||
|
||||
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
SPACE_RULE = `" "?`
|
||||
|
||||
PRIMITIVE_RULES = map[string]string{
|
||||
"boolean": `("true" | "false") space`,
|
||||
"number": `("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space`,
|
||||
"integer": `("-"? ([0-9] | [1-9] [0-9]*)) space`,
|
||||
"string": `"\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space`,
|
||||
"null": `"null" space`,
|
||||
}
|
||||
|
||||
INVALID_RULE_CHARS_RE = regexp.MustCompile(`[^a-zA-Z0-9-]+`)
|
||||
GRAMMAR_LITERAL_ESCAPE_RE = regexp.MustCompile(`[\r\n"]`)
|
||||
GRAMMAR_LITERAL_ESCAPES = map[string]string{
|
||||
"\r": `\r`,
|
||||
"\n": `\n`,
|
||||
`"`: `\"`,
|
||||
}
|
||||
)
|
||||
|
||||
type JSONSchemaConverter struct {
|
||||
propOrder map[string]int
|
||||
rules map[string]string
|
||||
}
|
||||
|
||||
func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter {
|
||||
propOrderSlice := strings.Split(propOrder, ",")
|
||||
propOrderMap := make(map[string]int)
|
||||
for idx, name := range propOrderSlice {
|
||||
propOrderMap[name] = idx
|
||||
}
|
||||
|
||||
rules := make(map[string]string)
|
||||
rules["space"] = SPACE_RULE
|
||||
|
||||
return &JSONSchemaConverter{
|
||||
propOrder: propOrderMap,
|
||||
rules: rules,
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) string {
|
||||
escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jsonString(literal), func(match string) string {
|
||||
return GRAMMAR_LITERAL_ESCAPES[match]
|
||||
})
|
||||
return fmt.Sprintf(`"%s"`, escaped)
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) addRule(name, rule string) string {
|
||||
escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")
|
||||
key := escName
|
||||
if existingRule, ok := sc.rules[escName]; ok && existingRule != rule {
|
||||
i := 0
|
||||
for {
|
||||
key = fmt.Sprintf("%s%d", escName, i)
|
||||
if _, ok := sc.rules[key]; !ok {
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
sc.rules[key] = rule
|
||||
return key
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) formatGrammar() string {
|
||||
var lines []string
|
||||
for name, rule := range sc.rules {
|
||||
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) string {
|
||||
st, existType := schema["type"]
|
||||
var schemaType string
|
||||
if existType {
|
||||
schemaType = st.(string)
|
||||
}
|
||||
ruleName := name
|
||||
if name == "" {
|
||||
ruleName = "root"
|
||||
}
|
||||
_, oneOfExists := schema["oneOf"]
|
||||
_, anyOfExists := schema["anyOf"]
|
||||
if oneOfExists || anyOfExists {
|
||||
var alternatives []string
|
||||
oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{})
|
||||
anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{})
|
||||
|
||||
if oneOfExists {
|
||||
for i, altSchema := range oneOfSchemas {
|
||||
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
|
||||
alternatives = append(alternatives, alternative)
|
||||
}
|
||||
} else if anyOfExists {
|
||||
for i, altSchema := range anyOfSchemas {
|
||||
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
|
||||
alternatives = append(alternatives, alternative)
|
||||
}
|
||||
}
|
||||
|
||||
rule := strings.Join(alternatives, " | ")
|
||||
return sc.addRule(ruleName, rule)
|
||||
} else if ref, exists := schema["$ref"].(string); exists {
|
||||
referencedSchema := sc.resolveReference(ref, rootSchema)
|
||||
return sc.visit(referencedSchema, name, rootSchema)
|
||||
} else if constVal, exists := schema["const"]; exists {
|
||||
return sc.addRule(ruleName, sc.formatLiteral(constVal))
|
||||
} else if enumVals, exists := schema["enum"].([]interface{}); exists {
|
||||
var enumRules []string
|
||||
for _, enumVal := range enumVals {
|
||||
enumRule := sc.formatLiteral(enumVal)
|
||||
enumRules = append(enumRules, enumRule)
|
||||
}
|
||||
rule := strings.Join(enumRules, " | ")
|
||||
return sc.addRule(ruleName, rule)
|
||||
} else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists {
|
||||
propOrder := sc.propOrder
|
||||
var propPairs []struct {
|
||||
propName string
|
||||
propSchema map[string]interface{}
|
||||
}
|
||||
|
||||
for propName, propSchema := range properties {
|
||||
propPairs = append(propPairs, struct {
|
||||
propName string
|
||||
propSchema map[string]interface{}
|
||||
}{propName: propName, propSchema: propSchema.(map[string]interface{})})
|
||||
}
|
||||
|
||||
sort.Slice(propPairs, func(i, j int) bool {
|
||||
iOrder := propOrder[propPairs[i].propName]
|
||||
jOrder := propOrder[propPairs[j].propName]
|
||||
if iOrder != 0 && jOrder != 0 {
|
||||
return iOrder < jOrder
|
||||
}
|
||||
return propPairs[i].propName < propPairs[j].propName
|
||||
})
|
||||
|
||||
var rule strings.Builder
|
||||
rule.WriteString(`"{" space`)
|
||||
|
||||
for i, propPair := range propPairs {
|
||||
propName := propPair.propName
|
||||
propSchema := propPair.propSchema
|
||||
propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
|
||||
|
||||
if i > 0 {
|
||||
rule.WriteString(` "," space`)
|
||||
}
|
||||
|
||||
rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, sc.formatLiteral(propName), propRuleName))
|
||||
}
|
||||
|
||||
rule.WriteString(` "}" space`)
|
||||
return sc.addRule(ruleName, rule.String())
|
||||
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
|
||||
itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
|
||||
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
|
||||
return sc.addRule(ruleName, rule)
|
||||
} else {
|
||||
primitiveRule, exists := PRIMITIVE_RULES[schemaType]
|
||||
if !exists {
|
||||
panic(fmt.Sprintf("Unrecognized schema: %v", schema))
|
||||
}
|
||||
if ruleName == "root" {
|
||||
schemaType = "root"
|
||||
}
|
||||
return sc.addRule(schemaType, primitiveRule)
|
||||
}
|
||||
}
|
||||
func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) map[string]interface{} {
|
||||
if !strings.HasPrefix(ref, "#/$defs/") {
|
||||
panic(fmt.Sprintf("Invalid reference format: %s", ref))
|
||||
}
|
||||
|
||||
defKey := strings.TrimPrefix(ref, "#/$defs/")
|
||||
definitions, exists := rootSchema["$defs"].(map[string]interface{})
|
||||
if !exists {
|
||||
fmt.Println(rootSchema)
|
||||
|
||||
panic("No definitions found in the schema")
|
||||
}
|
||||
|
||||
def, exists := definitions[defKey].(map[string]interface{})
|
||||
if !exists {
|
||||
fmt.Println(definitions)
|
||||
|
||||
panic(fmt.Sprintf("Definition not found: %s", defKey))
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string {
|
||||
sc.visit(schema, "", schema)
|
||||
return sc.formatGrammar()
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string {
|
||||
var schema map[string]interface{}
|
||||
_ = json.Unmarshal(b, &schema)
|
||||
return sc.Grammar(schema)
|
||||
}
|
||||
|
||||
func jsonString(v interface{}) string {
|
||||
b, _ := json.Marshal(v)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
type FunctionName struct {
|
||||
Const string `json:"const"`
|
||||
}
|
||||
|
||||
type Properties struct {
|
||||
Function FunctionName `json:"function"`
|
||||
Arguments Argument `json:"arguments"`
|
||||
}
|
||||
|
||||
type Argument struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]interface{} `json:"properties"`
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Type string `json:"type"`
|
||||
Properties Properties `json:"properties"`
|
||||
}
|
||||
|
||||
type JSONFunctionStructure struct {
|
||||
OneOf []Item `json:"oneOf,omitempty"`
|
||||
AnyOf []Item `json:"anyOf,omitempty"`
|
||||
Defs map[string]interface{} `json:"$defs,omitempty"`
|
||||
}
|
||||
|
||||
func (j JSONFunctionStructure) Grammar(propOrder string) string {
|
||||
dat, _ := json.Marshal(j)
|
||||
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat)
|
||||
}
|
||||
116
pkg/grammar/json_schema_test.go
Normal file
116
pkg/grammar/json_schema_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package grammar_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const (
|
||||
testInput1 = `
|
||||
{
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"function": {"const": "create_event"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"date": {"type": "string"},
|
||||
"time": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"function": {"const": "search"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
inputResult1 = `root-0-function ::= "\"create_event\""
|
||||
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space
|
||||
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
|
||||
root ::= root-0 | root-1
|
||||
space ::= " "?
|
||||
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
|
||||
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
root-1-function ::= "\"search\""`
|
||||
)
|
||||
|
||||
var _ = Describe("JSON schema grammar tests", func() {
|
||||
Context("JSON", func() {
|
||||
It("generates a valid grammar from JSON schema", func() {
|
||||
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1))
|
||||
results := strings.Split(inputResult1, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
|
||||
})
|
||||
It("generates a valid grammar from JSON Objects", func() {
|
||||
|
||||
structuredGrammar := JSONFunctionStructure{
|
||||
OneOf: []Item{
|
||||
{
|
||||
Type: "object",
|
||||
Properties: Properties{
|
||||
Function: FunctionName{
|
||||
Const: "create_event",
|
||||
},
|
||||
Arguments: Argument{ // this is OpenAI's parameter
|
||||
Type: "object",
|
||||
Properties: map[string]interface{}{
|
||||
"title": map[string]string{"type": "string"},
|
||||
"date": map[string]string{"type": "string"},
|
||||
"time": map[string]string{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "object",
|
||||
Properties: Properties{
|
||||
Function: FunctionName{
|
||||
Const: "search",
|
||||
},
|
||||
Arguments: Argument{
|
||||
Type: "object",
|
||||
Properties: map[string]interface{}{
|
||||
"query": map[string]string{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
grammar := structuredGrammar.Grammar("")
|
||||
results := strings.Split(inputResult1, "\n")
|
||||
for _, r := range results {
|
||||
if r != "" {
|
||||
Expect(grammar).To(ContainSubstring(r))
|
||||
}
|
||||
}
|
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
|
||||
})
|
||||
})
|
||||
})
|
||||
42
pkg/grpc/base/base.go
Normal file
42
pkg/grpc/base/base.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package base
|
||||
|
||||
// This is a wrapper to statisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
|
||||
)
|
||||
|
||||
type Base struct {
|
||||
}
|
||||
|
||||
func (llm *Base) Load(opts *pb.ModelOptions) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
|
||||
}
|
||||
|
||||
func (llm *Base) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return "", fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
return []float32{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (api.Result, error) {
|
||||
return api.Result{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) TTS(*pb.TTSRequest) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user