Compare commits

..

2 Commits

Author SHA1 Message Date
Ettore Di Giacinto
63c5d843b6 chore(gosec): fix CI
downgrade to latest known version of the gosec action

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-09-13 19:17:27 +02:00
Ettore Di Giacinto
a9b0e264f2 chore(exllama): drop exllama backend
For polishing and cleaning up it makes now sense to drop exllama which
is completely unmaintained, and was only supporting the llamav1
architecture (nowadays it's superseded by llamav1) .

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-09-13 19:09:43 +02:00
143 changed files with 1079 additions and 3412 deletions

View File

@@ -9,7 +9,6 @@
# Param 2: email
#
config_user() {
echo "Configuring git for $1 <$2>"
local gcn=$(git config --global user.name)
if [ -z "${gcn}" ]; then
echo "Setting up git user / remote"
@@ -25,7 +24,6 @@ config_user() {
# Param 2: remote url
#
config_remote() {
echo "Adding git remote and fetching $2 as $1"
local gr=$(git remote -v | grep $1)
if [ -z "${gr}" ]; then
git remote add $1 $2

View File

@@ -29,14 +29,9 @@ def calculate_sha256(file_path):
def manual_safety_check_hf(repo_id):
scanResponse = requests.get('https://huggingface.co/api/models/' + repo_id + "/scan")
scan = scanResponse.json()
# Check if 'hasUnsafeFile' exists in the response
if 'hasUnsafeFile' in scan:
if scan['hasUnsafeFile']:
return scan
else:
return None
else:
return None
if scan['hasUnsafeFile']:
return scan
return None
download_type, repo_id_or_url = parse_uri(uri)

View File

@@ -6,7 +6,6 @@ import (
"io/ioutil"
"os"
"github.com/microcosm-cc/bluemonday"
"gopkg.in/yaml.v3"
)
@@ -280,12 +279,6 @@ func main() {
return
}
// Ensure that all arbitrary text content is sanitized before display
for i, m := range models {
models[i].Name = bluemonday.StrictPolicy().Sanitize(m.Name)
models[i].Description = bluemonday.StrictPolicy().Sanitize(m.Description)
}
// render the template
data := struct {
Models []*GalleryModel

View File

@@ -33,7 +33,7 @@ jobs:
run: |
CGO_ENABLED=0 make build-api
- name: rm
uses: appleboy/ssh-action@v1.1.0
uses: appleboy/ssh-action@v1.0.3
with:
host: ${{ secrets.EXPLORER_SSH_HOST }}
username: ${{ secrets.EXPLORER_SSH_USERNAME }}
@@ -53,7 +53,7 @@ jobs:
rm: true
target: ./local-ai
- name: restarting
uses: appleboy/ssh-action@v1.1.0
uses: appleboy/ssh-action@v1.0.3
with:
host: ${{ secrets.EXPLORER_SSH_HOST }}
username: ${{ secrets.EXPLORER_SSH_USERNAME }}

View File

@@ -13,78 +13,6 @@ concurrency:
cancel-in-progress: true
jobs:
hipblas-jobs:
uses: ./.github/workflows/image_build.yml
with:
tag-latest: ${{ matrix.tag-latest }}
tag-suffix: ${{ matrix.tag-suffix }}
ffmpeg: ${{ matrix.ffmpeg }}
image-type: ${{ matrix.image-type }}
build-type: ${{ matrix.build-type }}
cuda-major-version: ${{ matrix.cuda-major-version }}
cuda-minor-version: ${{ matrix.cuda-minor-version }}
platforms: ${{ matrix.platforms }}
runs-on: ${{ matrix.runs-on }}
base-image: ${{ matrix.base-image }}
grpc-base-image: ${{ matrix.grpc-base-image }}
aio: ${{ matrix.aio }}
makeflags: ${{ matrix.makeflags }}
latest-image: ${{ matrix.latest-image }}
latest-image-aio: ${{ matrix.latest-image-aio }}
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
strategy:
# Pushing with all jobs in parallel
# eats the bandwidth of all the nodes
max-parallel: 2
matrix:
include:
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-hipblas-ffmpeg'
ffmpeg: 'true'
image-type: 'extras'
aio: "-aio-gpu-hipblas"
base-image: "rocm/dev-ubuntu-22.04:6.1"
grpc-base-image: "ubuntu:22.04"
latest-image: 'latest-gpu-hipblas'
latest-image-aio: 'latest-aio-gpu-hipblas'
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-hipblas'
ffmpeg: 'false'
image-type: 'extras'
base-image: "rocm/dev-ubuntu-22.04:6.1"
grpc-base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-hipblas-ffmpeg-core'
ffmpeg: 'true'
image-type: 'core'
base-image: "rocm/dev-ubuntu-22.04:6.1"
grpc-base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-hipblas-core'
ffmpeg: 'false'
image-type: 'core'
base-image: "rocm/dev-ubuntu-22.04:6.1"
grpc-base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
self-hosted-jobs:
uses: ./.github/workflows/image_build.yml
with:
@@ -111,7 +39,7 @@ jobs:
strategy:
# Pushing with all jobs in parallel
# eats the bandwidth of all the nodes
max-parallel: ${{ github.event_name != 'pull_request' && 5 || 8 }}
max-parallel: ${{ github.event_name != 'pull_request' && 6 || 10 }}
matrix:
include:
# Extra images
@@ -194,6 +122,29 @@ jobs:
base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: '-hipblas-ffmpeg'
ffmpeg: 'true'
image-type: 'extras'
aio: "-aio-gpu-hipblas"
base-image: "rocm/dev-ubuntu-22.04:6.1"
grpc-base-image: "ubuntu:22.04"
latest-image: 'latest-gpu-hipblas'
latest-image-aio: 'latest-aio-gpu-hipblas'
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-hipblas'
ffmpeg: 'false'
image-type: 'extras'
base-image: "rocm/dev-ubuntu-22.04:6.1"
grpc-base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'sycl_f16'
platforms: 'linux/amd64'
tag-latest: 'auto'
@@ -261,6 +212,26 @@ jobs:
image-type: 'core'
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-hipblas-ffmpeg-core'
ffmpeg: 'true'
image-type: 'core'
base-image: "rocm/dev-ubuntu-22.04:6.1"
grpc-base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-hipblas-core'
ffmpeg: 'false'
image-type: 'core'
base-image: "rocm/dev-ubuntu-22.04:6.1"
grpc-base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
makeflags: "--jobs=3 --output-sync=target"
core-image-build:
uses: ./.github/workflows/image_build.yml

View File

@@ -18,7 +18,7 @@ jobs:
if: ${{ github.actor != 'dependabot[bot]' }}
- name: Run Gosec Security Scanner
if: ${{ github.actor != 'dependabot[bot]' }}
uses: securego/gosec@v2.21.4
uses: securego/gosec@v2.21.0
with:
# we let the report trigger content trigger a failure using the GitHub Security features.
args: '-no-fail -fmt sarif -out results.sarif ./...'

View File

@@ -178,22 +178,13 @@ jobs:
uses: actions/checkout@v4
with:
submodules: true
- name: Dependencies
run: |
# Install protoc
curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \
unzip -j -d /usr/local/bin protoc.zip bin/protoc && \
rm protoc.zip
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af
PATH="$PATH:$HOME/go/bin" make protogen-go
- name: Build images
run: |
docker build --build-arg FFMPEG=true --build-arg IMAGE_TYPE=extras --build-arg EXTRA_BACKENDS=rerankers --build-arg MAKEFLAGS="--jobs=5 --output-sync=target" -t local-ai:tests -f Dockerfile .
BASE_IMAGE=local-ai:tests DOCKER_AIO_IMAGE=local-ai-aio:test make docker-aio
- name: Test
run: |
PATH="$PATH:$HOME/go/bin" LOCALAI_MODELS_DIR=$PWD/models LOCALAI_IMAGE_TAG=test LOCALAI_IMAGE=local-ai-aio \
LOCALAI_MODELS_DIR=$PWD/models LOCALAI_IMAGE_TAG=test LOCALAI_IMAGE=local-ai-aio \
make run-e2e-aio
- name: Setup tmate session if tests fail
if: ${{ failure() }}

View File

@@ -15,6 +15,8 @@ Thank you for your interest in contributing to LocalAI! We appreciate your time
- [Documentation](#documentation)
- [Community and Communication](#community-and-communication)
## Getting Started
### Prerequisites
@@ -52,7 +54,7 @@ If you find a bug, have a feature request, or encounter any issues, please check
## Coding Guidelines
- No specific coding guidelines at the moment. Please make sure the code can be tested. The most popular lint tools like [`golangci-lint`](https://golangci-lint.run) can help you here.
- No specific coding guidelines at the moment. Please make sure the code can be tested. The most popular lint tools like []`golangci-lint`](https://golangci-lint.run) can help you here.
## Testing
@@ -82,3 +84,5 @@ We are welcome the contribution of the documents, please open new PR or create a
- You can reach out via the Github issue tracker.
- Open a new discussion at [Discussion](https://github.com/go-skynet/LocalAI/discussions)
- Join the Discord channel [Discord](https://discord.gg/uJAeKSAGDy)
---

View File

@@ -9,8 +9,6 @@ FROM ${BASE_IMAGE} AS requirements-core
USER root
ARG GO_VERSION=1.22.6
ARG CMAKE_VERSION=3.26.4
ARG CMAKE_FROM_SOURCE=false
ARG TARGETARCH
ARG TARGETVARIANT
@@ -23,25 +21,13 @@ RUN apt-get update && \
build-essential \
ccache \
ca-certificates \
curl libssl-dev \
cmake \
curl \
git \
unzip upx-ucl && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Install CMake (the version in 22.04 is too old)
RUN <<EOT bash
if [ "${CMAKE_FROM_SOURCE}}" = "true" ]; then
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
else
apt-get update && \
apt-get install -y \
cmake && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
fi
EOT
# Install Go
RUN curl -L -s https://go.dev/dl/go${GO_VERSION}.linux-${TARGETARCH}.tar.gz | tar -C /usr/local -xz
ENV PATH=$PATH:/root/go/bin:/usr/local/go/bin
@@ -202,8 +188,6 @@ FROM ${GRPC_BASE_IMAGE} AS grpc
# This is a bit of a hack, but it's required in order to be able to effectively cache this layer in CI
ARG GRPC_MAKEFLAGS="-j4 -Otarget"
ARG GRPC_VERSION=v1.65.0
ARG CMAKE_FROM_SOURCE=false
ARG CMAKE_VERSION=3.26.4
ENV MAKEFLAGS=${GRPC_MAKEFLAGS}
@@ -212,24 +196,12 @@ WORKDIR /build
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ca-certificates \
build-essential curl libssl-dev \
build-essential \
cmake \
git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Install CMake (the version in 22.04 is too old)
RUN <<EOT bash
if [ "${CMAKE_FROM_SOURCE}}" = "true" ]; then
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
else
apt-get update && \
apt-get install -y \
cmake && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
fi
EOT
# We install GRPC to a different prefix here so that we can copy in only the build artifacts later
# saves several hundred MB on the final docker image size vs copying in the entire GRPC source tree
# and running make install in the target container
@@ -325,10 +297,10 @@ COPY .git .
RUN make prepare
## Build the binary
## If it's CUDA or hipblas, we want to skip some of the llama-compat backends to save space
## We only leave the most CPU-optimized variant and the fallback for the cublas/hipblas build
## (both will use CUDA or hipblas for the actual computation)
RUN if [ "${BUILD_TYPE}" = "cublas" ] || [ "${BUILD_TYPE}" = "hipblas" ]; then \
## If it's CUDA, we want to skip some of the llama-compat backends to save space
## We only leave the most CPU-optimized variant and the fallback for the cublas build
## (both will use CUDA for the actual computation)
RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \
SKIP_GRPC_BACKEND="backend-assets/grpc/llama-cpp-avx backend-assets/grpc/llama-cpp-avx2" make build; \
else \
make build; \
@@ -366,8 +338,9 @@ RUN if [ "${FFMPEG}" = "true" ]; then \
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ssh less wget
# For the devcontainer, leave apt functional in case additional devtools are needed at runtime.
ssh less && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN go install github.com/go-delve/delve/cmd/dlv@latest

View File

@@ -8,7 +8,7 @@ DETECT_LIBS?=true
# llama.cpp versions
GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp
GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
CPPLLAMA_VERSION?=96776405a17034dcfd53d3ddf5d142d34bdbb657
CPPLLAMA_VERSION?=e6b7801bd189d102d901d3e72035611a25456ef1
# go-rwkv version
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
@@ -16,7 +16,7 @@ RWKV_VERSION?=661e7ae26d442f5cfebd2a0881b44e8c55949ec6
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp
WHISPER_CPP_VERSION?=fdbfb460ed546452a5d53611bba66d10d842e719
WHISPER_CPP_VERSION?=a551933542d956ae84634937acd2942eb40efaaf
# bert.cpp version
BERT_REPO?=https://github.com/go-skynet/go-bert.cpp
@@ -359,9 +359,6 @@ clean-tests:
rm -rf test-dir
rm -rf core/http/backend-assets
clean-dc: clean
cp -r /build/backend-assets /workspace/backend-assets
## Build:
build: prepare backend-assets grpcs ## Build the project
$(info ${GREEN}I local-ai build info:${RESET})
@@ -468,15 +465,15 @@ run-e2e-image:
ls -liah $(abspath ./tests/e2e-fixtures)
docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --gpus all --name e2e-tests-$(RANDOM) localai-tests
run-e2e-aio: protogen-go
run-e2e-aio:
@echo 'Running e2e AIO tests'
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e-aio
test-e2e:
@echo 'Running e2e tests'
BUILD_TYPE=$(BUILD_TYPE) \
LOCALAI_API=http://$(E2E_BRIDGE_IP):5390/v1 \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e
teardown-e2e:
rm -rf $(TEST_DIR) || true
@@ -484,24 +481,24 @@ teardown-e2e:
test-llama: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS)
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r $(TEST_PATHS)
test-llama-gguf: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS)
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r $(TEST_PATHS)
test-tts: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS)
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r $(TEST_PATHS)
test-stablediffusion: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS)
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r $(TEST_PATHS)
test-stores: backend-assets/grpc/local-store
mkdir -p tests/integration/backend-assets/grpc
cp -f backend-assets/grpc/local-store tests/integration/backend-assets/grpc/
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts $(TEST_FLAKES) -v -r tests/integration
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts 1 -v -r tests/integration
test-container:
docker build --target requirements -t local-ai-test-container .

View File

@@ -68,7 +68,9 @@ docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-cpu
[💻 Getting started](https://localai.io/basics/getting_started/index.html)
## 📰 Latest project news
## 🔥🔥 Hot topics / Roadmap
[Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
- Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io)
- July 2024: 🔥🔥 🆕 P2P Dashboard, LocalAI Federated mode and AI Swarms: https://github.com/mudler/LocalAI/pull/2723
@@ -81,12 +83,8 @@ docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-cpu
- May 2024: Chat, TTS, and Image generation in the WebUI: https://github.com/mudler/LocalAI/pull/2222
- April 2024: Reranker API: https://github.com/mudler/LocalAI/pull/2121
Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
Hot topics (looking for contributors):
## 🔥🔥 Hot topics (looking for help):
- Multimodal with vLLM and Video understanding: https://github.com/mudler/LocalAI/pull/3729
- Realtime API https://github.com/mudler/LocalAI/issues/3714
- 🔥🔥 Distributed, P2P Global community pools: https://github.com/mudler/LocalAI/issues/3113
- WebUI improvements: https://github.com/mudler/LocalAI/issues/2156
- Backends v2: https://github.com/mudler/LocalAI/issues/1126

View File

@@ -2,7 +2,7 @@ backend: llama-cpp
context_size: 4096
f16: true
mmap: true
name: gpt-4o
name: gpt-4-vision-preview
roles:
user: "USER:"

View File

@@ -2,7 +2,7 @@ backend: llama-cpp
context_size: 4096
f16: true
mmap: true
name: gpt-4o
name: gpt-4-vision-preview
roles:
user: "USER:"

View File

@@ -2,7 +2,7 @@ backend: llama-cpp
context_size: 4096
mmap: false
f16: false
name: gpt-4o
name: gpt-4-vision-preview
roles:
user: "USER:"

View File

@@ -26,19 +26,6 @@ service Backend {
rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
rpc Rerank(RerankRequest) returns (RerankResult) {}
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
}
// Define the empty request
message MetricsRequest {}
message MetricsResponse {
int32 slot_id = 1;
string prompt_json_for_slot = 2; // Stores the prompt as a JSON string.
float tokens_per_second = 3;
int32 tokens_generated = 4;
int32 prompt_tokens_processed = 5;
}
message RerankRequest {
@@ -147,9 +134,6 @@ message PredictOptions {
repeated string Images = 42;
bool UseTokenizerTemplate = 43;
repeated Message Messages = 44;
repeated string Videos = 45;
repeated string Audios = 46;
string CorrelationId = 47;
}
// The response message containing the result

View File

@@ -13,7 +13,6 @@
#include <getopt.h>
#include "clip.h"
#include "llava.h"
#include "log.h"
#include "stb_image.h"
#include "common.h"
#include "json.hpp"
@@ -113,7 +112,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
std::string ret;
for (; begin != end; ++begin)
{
ret += common_token_to_piece(ctx, *begin);
ret += llama_token_to_piece(ctx, *begin);
}
return ret;
}
@@ -121,7 +120,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
// format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
{
std::string out = token == -1 ? "" : common_token_to_piece(ctx, token);
std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
// if the size is 1 and first bit is 1, meaning it's a partial character
// (size > 1 meaning it's already a known token)
if (out.size() == 1 && (out[0] & 0x80) == 0x80)
@@ -203,8 +202,8 @@ struct llama_client_slot
std::string stopping_word;
// sampling
struct common_sampler_params sparams;
common_sampler *ctx_sampling = nullptr;
struct gpt_sampler_params sparams;
gpt_sampler *ctx_sampling = nullptr;
int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1; // group-attention factor
@@ -257,7 +256,7 @@ struct llama_client_slot
images.clear();
}
bool has_budget(common_params &global_params) {
bool has_budget(gpt_params &global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1)
{
return true; // limitless
@@ -398,7 +397,7 @@ struct llama_server_context
clip_ctx *clp_ctx = nullptr;
common_params params;
gpt_params params;
llama_batch batch;
@@ -441,7 +440,7 @@ struct llama_server_context
}
}
bool load_model(const common_params &params_)
bool load_model(const gpt_params &params_)
{
params = params_;
if (!params.mmproj.empty()) {
@@ -449,7 +448,7 @@ struct llama_server_context
LOG_INFO("Multi Modal Mode Enabled", {});
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
if(clp_ctx == nullptr) {
LOG_ERR("unable to load clip model: %s", params.mmproj.c_str());
LOG_ERROR("unable to load clip model", {{"model", params.mmproj}});
return false;
}
@@ -458,12 +457,12 @@ struct llama_server_context
}
}
common_init_result common_init = common_init_from_params(params);
model = common_init.model;
ctx = common_init.context;
llama_init_result llama_init = llama_init_from_gpt_params(params);
model = llama_init.model;
ctx = llama_init.context;
if (model == nullptr)
{
LOG_ERR("unable to load model: %s", params.model.c_str());
LOG_ERROR("unable to load model", {{"model", params.model}});
return false;
}
@@ -471,7 +470,7 @@ struct llama_server_context
const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
const int n_embd_llm = llama_n_embd(model);
if (n_embd_clip != n_embd_llm) {
LOG("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm);
LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm);
llama_free(ctx);
llama_free_model(model);
return false;
@@ -490,21 +489,11 @@ struct llama_server_context
std::vector<char> buf(1);
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
if (res < 0) {
LOG_ERR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", __func__);
LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
}
}
llama_client_slot* get_active_slot() {
for (llama_client_slot& slot : slots) {
// Check if the slot is currently processing
if (slot.is_processing()) {
return &slot; // Return the active slot
}
}
return nullptr; // No active slot found
}
void initialize() {
// create slots
all_slots_are_idle = true;
@@ -578,12 +567,12 @@ struct llama_server_context
std::vector<llama_token> p;
if (first)
{
p = common_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
p = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
first = false;
}
else
{
p = common_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
}
@@ -600,7 +589,7 @@ struct llama_server_context
else
{
auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
prompt_tokens = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
}
return prompt_tokens;
@@ -629,7 +618,7 @@ struct llama_server_context
bool launch_slot_with_data(llama_client_slot* &slot, json data) {
slot_params default_params;
common_sampler_params default_sparams;
gpt_sampler_params default_sparams;
slot->params.stream = json_value(data, "stream", false);
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
@@ -769,7 +758,7 @@ struct llama_server_context
}
else if (el[0].is_string())
{
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks)
{
slot->sparams.logit_bias.push_back({tok, bias});
@@ -801,7 +790,7 @@ struct llama_server_context
sampler_names.emplace_back(name);
}
}
slot->sparams.samplers = common_sampler_types_from_names(sampler_names, false);
slot->sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
}
else
{
@@ -823,11 +812,10 @@ struct llama_server_context
img_sl.img_data = clip_image_u8_init();
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
{
LOG_ERR("%s: failed to load image, slot_id: %d, img_sl_id: %d",
__func__,
slot->id,
img_sl.id
);
LOG_ERROR("failed to load image", {
{"slot_id", slot->id},
{"img_sl_id", img_sl.id}
});
return false;
}
LOG_VERBOSE("image loaded", {
@@ -865,12 +853,12 @@ struct llama_server_context
}
}
if (!found) {
LOG("ERROR: Image with id: %i, not found.\n", img_id);
LOG_TEE("ERROR: Image with id: %i, not found.\n", img_id);
slot->images.clear();
return false;
}
} catch (const std::invalid_argument& e) {
LOG("Invalid image number id in prompt\n");
LOG_TEE("Invalid image number id in prompt\n");
slot->images.clear();
return false;
}
@@ -885,9 +873,9 @@ struct llama_server_context
if (slot->ctx_sampling != nullptr)
{
common_sampler_free(slot->ctx_sampling);
gpt_sampler_free(slot->ctx_sampling);
}
slot->ctx_sampling = common_sampler_init(model, slot->sparams);
slot->ctx_sampling = gpt_sampler_init(model, slot->sparams);
//llama_set_rng_seed(ctx, slot->params.seed);
slot->command = LOAD_PROMPT;
@@ -898,7 +886,7 @@ struct llama_server_context
{"task_id", slot->task_id},
});
// LOG("sampling: \n%s\n", llama_sampling_print(slot->sparams).c_str());
// LOG_TEE("sampling: \n%s\n", llama_sampling_print(slot->sparams).c_str());
return true;
}
@@ -914,13 +902,13 @@ struct llama_server_context
system_tokens.clear();
if (!system_prompt.empty()) {
system_tokens = common_tokenize(ctx, system_prompt, add_bos_token);
system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
common_batch_clear(batch);
llama_batch_clear(batch);
for (int i = 0; i < (int)system_tokens.size(); ++i)
{
common_batch_add(batch, system_tokens[i], i, { 0 }, false);
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
}
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
@@ -938,7 +926,7 @@ struct llama_server_context
};
if (llama_decode(ctx, batch_view) != 0)
{
LOG("%s: llama_decode() failed\n", __func__);
LOG_TEE("%s: llama_decode() failed\n", __func__);
return;
}
}
@@ -950,7 +938,7 @@ struct llama_server_context
}
}
LOG("system prompt updated\n");
LOG_TEE("system prompt updated\n");
system_need_update = false;
}
@@ -1009,7 +997,7 @@ struct llama_server_context
bool process_token(completion_token_output &result, llama_client_slot &slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = common_token_to_piece(ctx, result.tok);
const std::string token_str = llama_token_to_piece(ctx, result.tok);
slot.sampled = result.tok;
// search stop word and delete it
@@ -1132,7 +1120,7 @@ struct llama_server_context
}
if (!llava_image_embed_make_with_clip_img(clp_ctx, params.cpuparams.n_threads, img.img_data, &img.image_embedding, &img.image_tokens)) {
LOG("Error processing the given image");
LOG_TEE("Error processing the given image");
return false;
}
@@ -1144,7 +1132,7 @@ struct llama_server_context
void send_error(task_server& task, const std::string &error)
{
LOG("task %i - error: %s\n", task.id, error.c_str());
LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
task_result res;
res.id = task.id;
res.multitask_id = task.multitask_id;
@@ -1160,7 +1148,7 @@ struct llama_server_context
samplers.reserve(slot.sparams.samplers.size());
for (const auto & sampler : slot.sparams.samplers)
{
samplers.emplace_back(common_sampler_type_to_str(sampler));
samplers.emplace_back(gpt_sampler_type_to_str(sampler));
}
return json {
@@ -1216,7 +1204,7 @@ struct llama_server_context
if (slot.sparams.n_probs > 0)
{
std::vector<completion_token_output> probs_output = {};
const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size());
size_t probs_stop_pos = std::min(slot.sent_token_probs_index + to_send_toks.size(), slot.generated_token_probs.size());
if (probs_pos < probs_stop_pos)
@@ -1268,7 +1256,7 @@ struct llama_server_context
std::vector<completion_token_output> probs = {};
if (!slot.params.stream && slot.stopped_word)
{
const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size());
}
else
@@ -1383,7 +1371,7 @@ struct llama_server_context
};
if (llama_decode(ctx, batch_view))
{
LOG("%s : failed to eval\n", __func__);
LOG_TEE("%s : failed to eval\n", __func__);
return false;
}
}
@@ -1401,14 +1389,14 @@ struct llama_server_context
llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
if (llama_decode(ctx, batch_img))
{
LOG("%s : failed to eval image\n", __func__);
LOG_TEE("%s : failed to eval image\n", __func__);
return false;
}
slot.n_past += n_eval;
}
image_idx++;
common_batch_clear(batch);
llama_batch_clear(batch);
// append prefix of next image
const auto json_prompt = (image_idx >= (int) slot.images.size()) ?
@@ -1418,7 +1406,7 @@ struct llama_server_context
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
for (int i = 0; i < (int) append_tokens.size(); ++i)
{
common_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
slot.n_past += 1;
}
}
@@ -1550,7 +1538,7 @@ struct llama_server_context
update_system_prompt();
}
common_batch_clear(batch);
llama_batch_clear(batch);
if (all_slots_are_idle)
{
@@ -1584,7 +1572,7 @@ struct llama_server_context
slot.n_past = 0;
slot.truncated = false;
slot.has_next_token = true;
LOG("Context exhausted. Slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size());
LOG_TEE("Context exhausted. Slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size());
continue;
// END LOCALAI changes
@@ -1628,7 +1616,7 @@ struct llama_server_context
// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
common_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
slot.n_past += 1;
}
@@ -1722,7 +1710,7 @@ struct llama_server_context
if (!slot.params.cache_prompt)
{
common_sampler_reset(slot.ctx_sampling);
gpt_sampler_reset(slot.ctx_sampling);
slot.n_past = 0;
slot.n_past_se = 0;
@@ -1734,7 +1722,7 @@ struct llama_server_context
// push the prompt into the sampling context (do not apply grammar)
for (auto &token : prompt_tokens)
{
common_sampler_accept(slot.ctx_sampling, token, false);
gpt_sampler_accept(slot.ctx_sampling, token, false);
}
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
@@ -1826,17 +1814,16 @@ struct llama_server_context
ga_i += ga_w/ga_n;
}
}
common_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
slot_npast++;
}
if (has_images && !ingest_images(slot, n_batch))
{
LOG_ERR("%s: failed processing images Slot id : %d, Task id: %d",
__func__,
slot.id,
slot.task_id
);
LOG_ERROR("failed processing images", {
"slot_id", slot.id,
"task_id", slot.task_id,
});
// FIXME @phymbert: to be properly tested
// early returning without changing the slot state will block the slot for ever
// no one at the moment is checking the return value
@@ -1876,10 +1863,10 @@ struct llama_server_context
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
LOG("\n");
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
LOG_TEE("\n");
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
@@ -1889,7 +1876,7 @@ struct llama_server_context
slot.ga_i += slot.ga_w / slot.ga_n;
LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
}
slot.n_past_se += n_tokens;
}
@@ -1914,11 +1901,11 @@ struct llama_server_context
if (n_batch == 1 || ret < 0)
{
// if you get here, it means the KV cache is full - try increasing it via the context size
LOG("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
return false;
}
LOG("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
@@ -1943,9 +1930,9 @@ struct llama_server_context
}
completion_token_output result;
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
const llama_token id = gpt_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
common_sampler_accept(slot.ctx_sampling, id, true);
gpt_sampler_accept(slot.ctx_sampling, id, true);
slot.n_decoded += 1;
if (slot.n_decoded == 1)
@@ -1956,7 +1943,7 @@ struct llama_server_context
}
result.tok = id;
const auto * cur_p = common_sampler_get_candidates(slot.ctx_sampling);
const auto * cur_p = gpt_sampler_get_candidates(slot.ctx_sampling);
for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
result.probs.push_back({
@@ -2009,7 +1996,7 @@ static json format_partial_response(
struct token_translator
{
llama_context * ctx;
std::string operator()(llama_token tok) const { return common_token_to_piece(ctx, tok); }
std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); }
std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); }
};
@@ -2116,9 +2103,6 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
data["ignore_eos"] = predict->ignoreeos();
data["embeddings"] = predict->embeddings();
// Add the correlationid to json data
data["correlation_id"] = predict->correlationid();
// for each image in the request, add the image data
//
for (int i = 0; i < predict->images_size(); i++) {
@@ -2203,7 +2187,7 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
// }
static void params_parse(const backend::ModelOptions* request,
common_params & params) {
gpt_params & params) {
// this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
@@ -2311,7 +2295,7 @@ public:
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) {
// Implement LoadModel RPC
common_params params;
gpt_params params;
params_parse(request, params);
llama_backend_init();
@@ -2357,11 +2341,6 @@ public:
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
// Log Request Correlation Id
LOG_VERBOSE("correlation:", {
{ "id", data["correlation_id"] }
});
// Send the reply
writer->Write(reply);
@@ -2385,12 +2364,6 @@ public:
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop) {
// Log Request Correlation Id
LOG_VERBOSE("correlation:", {
{ "id", data["correlation_id"] }
});
completion_text = result.result_json.value("content", "");
int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
@@ -2430,31 +2403,6 @@ public:
return grpc::Status::OK;
}
grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
llama_client_slot* active_slot = llama.get_active_slot();
if (active_slot != nullptr) {
// Calculate the tokens per second using existing logic
double tokens_per_second = 1e3 / active_slot->t_token_generation * active_slot->n_decoded;
// Populate the response with metrics
response->set_slot_id(active_slot->id);
response->set_prompt_json_for_slot(active_slot->prompt.dump());
response->set_tokens_per_second(tokens_per_second);
response->set_tokens_generated(active_slot->n_decoded);
response->set_prompt_tokens_processed(active_slot->num_prompt_tokens_processed);
} else {
// Handle case when no active slot exists
response->set_slot_id(0);
response->set_prompt_json_for_slot("");
response->set_tokens_per_second(0);
response->set_tokens_generated(0);
response->set_prompt_tokens_processed(0);
}
return grpc::Status::OK;
}
};
void RunServer(const std::string& server_address) {

View File

@@ -2,4 +2,4 @@
intel-extension-for-pytorch
torch
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools==72.1.0 # https://github.com/mudler/LocalAI/issues/2406

View File

@@ -1,6 +1,6 @@
accelerate
auto-gptq==0.7.1
grpcio==1.66.2
grpcio==1.66.1
protobuf
certifi
transformers

View File

@@ -3,6 +3,6 @@ intel-extension-for-pytorch
torch
torchaudio
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools==70.3.0 # https://github.com/mudler/LocalAI/issues/2406
transformers
accelerate

View File

@@ -1,4 +1,4 @@
bark==0.1.5
grpcio==1.66.2
grpcio==1.66.1
protobuf
certifi

View File

@@ -1,2 +1,2 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf

View File

@@ -3,6 +3,6 @@ intel-extension-for-pytorch
torch
torchaudio
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools==72.1.0 # https://github.com/mudler/LocalAI/issues/2406
transformers
accelerate

View File

@@ -1,4 +1,4 @@
coqui-tts
grpcio==1.66.2
TTS==0.22.0
grpcio==1.66.1
protobuf
certifi

View File

@@ -19,7 +19,7 @@ class TestBackendServicer(unittest.TestCase):
This method sets up the gRPC service by starting the server
"""
self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
time.sleep(30)
time.sleep(10)
def tearDown(self) -> None:
"""

View File

@@ -3,7 +3,7 @@ intel-extension-for-pytorch
torch
torchvision
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools==70.3.0 # https://github.com/mudler/LocalAI/issues/2406
diffusers
opencv-python
transformers

View File

@@ -1,5 +1,5 @@
setuptools
grpcio==1.66.2
grpcio==1.66.1
pillow
protobuf
certifi

View File

@@ -1,4 +1,4 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf
certifi
wheel

View File

@@ -1,3 +1,3 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf
certifi

View File

@@ -2,7 +2,7 @@
intel-extension-for-pytorch
torch
optimum[openvino]
grpcio==1.66.2
grpcio==1.66.1
protobuf
librosa==0.9.1
faster-whisper==1.0.3
@@ -18,6 +18,6 @@ python-dotenv
pypinyin==0.50.0
cn2an==0.5.22
jieba==0.42.1
gradio==4.44.1
gradio==4.38.1
langid==1.1.6
git+https://github.com/myshell-ai/MeloTTS.git

View File

@@ -1,4 +1,4 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf
librosa
faster-whisper

View File

@@ -19,7 +19,7 @@ class TestBackendServicer(unittest.TestCase):
This method sets up the gRPC service by starting the server
"""
self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
time.sleep(30)
time.sleep(10)
def tearDown(self) -> None:
"""

View File

@@ -15,12 +15,5 @@ installRequirements
# https://github.com/descriptinc/audiotools/issues/101
# incompatible protobuf versions.
PYDIR=python3.10
pyenv="${MY_DIR}/venv/lib/${PYDIR}/site-packages/google/protobuf/internal/"
if [ ! -d ${pyenv} ]; then
echo "(parler-tts/install.sh): Error: ${pyenv} does not exist"
exit 1
fi
curl -L https://raw.githubusercontent.com/protocolbuffers/protobuf/main/python/google/protobuf/internal/builder.py -o ${pyenv}/builder.py
PYDIR=$(ls ${MY_DIR}/venv/lib)
curl -L https://raw.githubusercontent.com/protocolbuffers/protobuf/main/python/google/protobuf/internal/builder.py -o ${MY_DIR}/venv/lib/${PYDIR}/site-packages/google/protobuf/internal/builder.py

View File

@@ -3,6 +3,6 @@ intel-extension-for-pytorch
torch
torchaudio
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools==72.1.0 # https://github.com/mudler/LocalAI/issues/2406
transformers
accelerate

View File

@@ -1,4 +1,4 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf
certifi
llvmlite==0.43.0

View File

@@ -5,4 +5,4 @@ accelerate
torch
rerankers[transformers]
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools==72.1.0 # https://github.com/mudler/LocalAI/issues/2406

View File

@@ -1,3 +1,3 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf
certifi

View File

@@ -55,7 +55,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
model_name = request.Model
try:
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
self.model = SentenceTransformer(model_name)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

View File

@@ -2,5 +2,5 @@ torch
accelerate
transformers
bitsandbytes
sentence-transformers==3.1.1
sentence-transformers==3.0.1
transformers

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch
accelerate
sentence-transformers==3.1.1
sentence-transformers==3.0.1
transformers

View File

@@ -1,4 +1,4 @@
torch
accelerate
sentence-transformers==3.1.1
sentence-transformers==3.0.1
transformers

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0
torch
accelerate
sentence-transformers==3.1.1
sentence-transformers==3.0.1
transformers

View File

@@ -4,5 +4,5 @@ torch
optimum[openvino]
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
accelerate
sentence-transformers==3.1.1
sentence-transformers==3.0.1
transformers

View File

@@ -1,5 +1,3 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf
certifi
datasets
einops
certifi

View File

@@ -4,4 +4,4 @@ transformers
accelerate
torch
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406

View File

@@ -1,4 +1,4 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf
scipy==1.14.0
certifi

View File

@@ -72,12 +72,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns:
A Result object that contains the result of the LoadModel operation.
"""
model_name = request.Model
# Check to see if the Model exists in the filesystem already.
if os.path.exists(request.ModelFile):
model_name = request.ModelFile
compute = torch.float16
if request.F16Memory == True:

View File

@@ -1,4 +1,4 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf
certifi
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406

View File

@@ -4,4 +4,4 @@ accelerate
torch
torchaudio
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
setuptools==72.1.0 # https://github.com/mudler/LocalAI/issues/2406

View File

@@ -1,3 +1,3 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf
certifi

View File

@@ -5,8 +5,6 @@ import argparse
import signal
import sys
import os
from typing import List
from PIL import Image
import backend_pb2
import backend_pb2_grpc
@@ -17,8 +15,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.multimodal.utils import fetch_image
from vllm.assets.video import VideoAsset
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -109,7 +105,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
try:
self.llm = AsyncLLMEngine.from_engine_args(engine_args)
except Exception as err:
print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
try:
@@ -122,7 +117,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
print("Model loaded successfully", file=sys.stderr)
return backend_pb2.Result(message="Model loaded successfully", success=True)
async def Predict(self, request, context):
@@ -201,33 +196,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.Seed != 0:
sampling_params.seed = request.Seed
# Extract image paths and process images
prompt = request.Prompt
image_paths = request.Images
image_data = [self.load_image(img_path) for img_path in image_paths]
videos_path = request.Videos
video_data = [self.load_video(video_path) for video_path in videos_path]
# If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
# If tokenizer template is enabled and messages are provided instead of prompt apply the tokenizer template
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
# Generate text using the LLM engine
# Generate text
request_id = random_uuid()
print(f"Generating text with request_id: {request_id}", file=sys.stderr)
outputs = self.llm.generate(
{
"prompt": prompt,
"multi_modal_data": {
"image": image_data if image_data else None,
"video": video_data if video_data else None,
} if image_data or video_data else None,
},
sampling_params=sampling_params,
request_id=request_id,
)
outputs = self.llm.generate(prompt, sampling_params, request_id)
# Stream the results
generated_text = ""
@@ -250,49 +227,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if streaming:
return
# Remove the image files from /tmp folder
for img_path in image_paths:
try:
os.remove(img_path)
except Exception as e:
print(f"Error removing image file: {img_path}, {e}", file=sys.stderr)
# Sending the final generated text
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
def load_image(self, image_path: str):
"""
Load an image from the given file path.
Args:
image_path (str): The path to the image file.
Returns:
Image: The loaded image.
"""
try:
return Image.open(image_path)
except Exception as e:
print(f"Error loading image {image_path}: {e}", file=sys.stderr)
return self.load_video(image_path)
def load_video(self, video_path: str):
"""
Load a video from the given file path.
Args:
video_path (str): The path to the image file.
Returns:
Video: The loaded video.
"""
try:
video = VideoAsset(name=video_path).np_ndarrays
return video
except Exception as e:
print(f"Error loading video {image_path}: {e}", file=sys.stderr)
return None
async def serve(address):
# Start asyncio gRPC server
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))

View File

@@ -13,20 +13,4 @@ if [ "x${BUILD_PROFILE}" == "xintel" ]; then
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi
# We don't embed this into the images as it is a large dependency and not always needed.
# Besides, the speed inference are not actually usable in the current state for production use-cases.
if [ "x${BUILD_TYPE}" == "x" ] && [ "x${FROM_SOURCE}" == "xtrue" ]; then
ensureVenv
# https://docs.vllm.ai/en/v0.6.1/getting_started/cpu-installation.html
if [ ! -d vllm ]; then
git clone https://github.com/vllm-project/vllm
fi
pushd vllm
uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.66.2 protobuf bitsandbytes
uv pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
VLLM_TARGET_DEVICE=cpu python setup.py install
popd
rm -rf vllm
else
installRequirements
fi
installRequirements

View File

@@ -1,5 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cu118
accelerate
torch
transformers
bitsandbytes
transformers

View File

@@ -1,4 +1,3 @@
accelerate
torch
transformers
bitsandbytes
transformers

View File

@@ -1,5 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0
accelerate
torch
transformers
bitsandbytes
transformers

View File

@@ -4,5 +4,4 @@ accelerate
torch
transformers
optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
bitsandbytes
setuptools==70.3.0 # https://github.com/mudler/LocalAI/issues/2406

View File

@@ -1,4 +1,4 @@
grpcio==1.66.2
grpcio==1.66.1
protobuf
certifi
setuptools

View File

@@ -10,11 +10,20 @@ import (
)
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
modelFile := backendConfig.Model
grpcOpts := gRPCModelOpts(backendConfig)
var inferenceModel interface{}
var err error
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(*backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
})
if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)

View File

@@ -8,8 +8,19 @@ import (
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
threads := backendConfig.Threads
if *threads == 0 && appConfig.Threads != 0 {
threads = &appConfig.Threads
}
gRPCOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(backendConfig.Backend),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithThreads(uint32(*threads)),
model.WithContext(appConfig.Context),
model.WithModel(backendConfig.Model),
model.WithLoadGRPCLoadModelOpts(gRPCOpts),
})
inferenceModel, err := loader.BackendLoader(
opts...,

View File

@@ -31,13 +31,24 @@ type TokenUsage struct {
Completion int
}
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, messages []schema.Message, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
threads := c.Threads
if *threads == 0 && o.Threads != 0 {
threads = &o.Threads
}
grpcOpts := gRPCModelOpts(c)
var inferenceModel grpc.Backend
var err error
opts := ModelOptions(c, o, []model.Option{})
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
})
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
@@ -90,8 +101,6 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
opts.Messages = protoMessages
opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate
opts.Images = images
opts.Videos = videos
opts.Audios = audios
tokenUsage := TokenUsage{}

View File

@@ -11,65 +11,32 @@ import (
"github.com/rs/zerolog/log"
)
func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
name := c.Name
if name == "" {
name = c.Model
}
defOpts := []model.Option{
model.WithBackendString(c.Backend),
model.WithModel(c.Model),
model.WithAssetDir(so.AssetsDestination),
model.WithContext(so.Context),
model.WithModelID(name),
}
threads := 1
if c.Threads != nil {
threads = *c.Threads
}
if so.Threads != 0 {
threads = so.Threads
}
c.Threads = &threads
grpcOpts := grpcModelOpts(c)
defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
if so.SingleBackend {
defOpts = append(defOpts, model.WithSingleActiveBackend())
opts = append(opts, model.WithSingleActiveBackend())
}
if so.ParallelBackendRequests {
defOpts = append(defOpts, model.EnableParallelRequests)
opts = append(opts, model.EnableParallelRequests)
}
if c.GRPC.Attempts != 0 {
defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts))
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
}
if c.GRPC.AttemptsSleepTime != 0 {
defOpts = append(defOpts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}
for k, v := range so.ExternalGRPCBackends {
defOpts = append(defOpts, model.WithExternalBackend(k, v))
opts = append(opts, model.WithExternalBackend(k, v))
}
return append(defOpts, opts...)
return opts
}
func getSeed(c config.BackendConfig) int32 {
var seed int32 = config.RAND_SEED
if c.Seed != nil {
seed = int32(*c.Seed)
}
seed := int32(*c.Seed)
if seed == config.RAND_SEED {
seed = rand.Int31()
}
@@ -77,47 +44,11 @@ func getSeed(c config.BackendConfig) int32 {
return seed
}
func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
}
f16 := false
if c.F16 != nil {
f16 = *c.F16
}
embeddings := false
if c.Embeddings != nil {
embeddings = *c.Embeddings
}
lowVRAM := false
if c.LowVRAM != nil {
lowVRAM = *c.LowVRAM
}
mmap := false
if c.MMap != nil {
mmap = *c.MMap
}
ctxSize := 1024
if c.ContextSize != nil {
ctxSize = *c.ContextSize
}
mmlock := false
if c.MMlock != nil {
mmlock = *c.MMlock
}
nGPULayers := 9999999
if c.NGPULayers != nil {
nGPULayers = *c.NGPULayers
}
return &pb.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
@@ -125,14 +56,14 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
F16Memory: f16,
F16Memory: *c.F16,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
ContextSize: int32(ctxSize),
ContextSize: int32(*c.ContextSize),
Seed: getSeed(c),
NBatch: int32(b),
NoMulMatQ: c.NoMulMatQ,
@@ -154,16 +85,16 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
YarnBetaSlow: c.YarnBetaSlow,
NGQA: c.NGQA,
RMSNormEps: c.RMSNormEps,
MLock: mmlock,
MLock: *c.MMlock,
RopeFreqBase: c.RopeFreqBase,
RopeScaling: c.RopeScaling,
Type: c.ModelType,
RopeFreqScale: c.RopeFreqScale,
NUMA: c.NUMA,
Embeddings: embeddings,
LowVRAM: lowVRAM,
NGPULayers: int32(nGPULayers),
MMap: mmap,
Embeddings: *c.Embeddings,
LowVRAM: *c.LowVRAM,
NGPULayers: int32(*c.NGPULayers),
MMap: *c.MMap,
MainGPU: c.MainGPU,
Threads: int32(*c.Threads),
TensorSplit: c.TensorSplit,

View File

@@ -9,9 +9,21 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
func Rerank(backend, modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
bb := backend
if bb == "" {
return nil, fmt.Errorf("backend is required")
}
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
grpcOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
rerankModel, err := loader.BackendLoader(opts...)
if err != nil {
return nil, err

View File

@@ -13,6 +13,7 @@ import (
)
func SoundGeneration(
backend string,
modelFile string,
text string,
duration *float32,
@@ -24,8 +25,18 @@ func SoundGeneration(
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
if backend == "" {
return "", nil, fmt.Errorf("backend is a required parameter")
}
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
grpcOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(backend),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
soundGenModel, err := loader.BackendLoader(opts...)
if err != nil {

View File

@@ -1,33 +0,0 @@
package backend
import (
"context"
"fmt"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
)
func TokenMetrics(
modelFile string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig) (*proto.MetricsResponse, error) {
opts := ModelOptions(backendConfig, appConfig, []model.Option{
model.WithModel(modelFile),
})
model, err := loader.BackendLoader(opts...)
if err != nil {
return nil, err
}
if model == nil {
return nil, fmt.Errorf("could not loadmodel model")
}
res, err := model.GetTokenMetrics(context.Background(), &proto.MetricsRequest{})
return res, err
}

View File

@@ -1,44 +0,0 @@
package backend
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc"
model "github.com/mudler/LocalAI/pkg/model"
)
func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
modelFile := backendConfig.Model
var inferenceModel grpc.Backend
var err error
opts := ModelOptions(backendConfig, appConfig, []model.Option{
model.WithModel(modelFile),
})
if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
return schema.TokenizeResponse{}, err
}
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
predictOptions.Prompt = s
// tokenize the string
resp, err := inferenceModel.TokenizeString(appConfig.Context, predictOptions)
if err != nil {
return schema.TokenizeResponse{}, err
}
return schema.TokenizeResponse{
Tokens: resp.Tokens,
}, nil
}

View File

@@ -14,11 +14,13 @@ import (
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
if backendConfig.Backend == "" {
backendConfig.Backend = model.WhisperBackend
}
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(backendConfig.Model),
model.WithContext(appConfig.Context),
model.WithThreads(uint32(*backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
})
transcriptionModel, err := ml.BackendLoader(opts...)
if err != nil {

View File

@@ -28,9 +28,14 @@ func ModelTTS(
bb = model.PiperBackend
}
opts := ModelOptions(config.BackendConfig{}, appConfig, []model.Option{
grpcOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
ttsModel, err := loader.BackendLoader(opts...)
if err != nil {

View File

@@ -41,35 +41,31 @@ type RunCMD struct {
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" default:"512" help:"Default context size for models" group:"performance"`
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
LibraryPath string `env:"LOCALAI_LIBRARY_PATH,LIBRARY_PATH" help:"Path to the library directory (for e.g. external libraries used by backends)" default:"/usr/share/local-ai/libs" group:"backends"`
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"`
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"`
DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"`
HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/browse/?$,^/talk/?$,^/p2p/?$,^/chat/?$,^/text2image/?$,^/tts/?$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"`
Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"`
Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"`
Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"`
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"`
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"`
PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"`
ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"`
EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"`
WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"`
EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"`
WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
LibraryPath string `env:"LOCALAI_LIBRARY_PATH,LIBRARY_PATH" help:"Path to the library directory (for e.g. external libraries used by backends)" default:"/usr/share/local-ai/libs" group:"backends"`
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"`
DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"`
Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"`
Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"`
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2ptoken" help:"Token for P2P mode (optional)" group:"p2p"`
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"`
PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"`
ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"`
EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"`
WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"`
EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"`
WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
}
func (r *RunCMD) Run(ctx *cliContext.Context) error {
@@ -101,11 +97,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithModelsURL(append(r.Models, r.ModelArgs...)...),
config.WithOpaqueErrors(r.OpaqueErrors),
config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan),
config.WithSubtleKeyComparison(r.UseSubtleKeyComparison),
config.WithDisableApiKeyRequirementForHttpGet(r.DisableApiKeyRequirementForHttpGet),
config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints),
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
config.WithLoadToMemory(r.LoadToMemory),
}
token := ""

View File

@@ -85,14 +85,13 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
options := config.BackendConfig{}
options.SetDefaults()
options.Backend = t.Backend
var inputFile *string
if t.InputFile != "" {
inputFile = &t.InputFile
}
filePath, _, err := backend.SoundGeneration(t.Model, text,
filePath, _, err := backend.SoundGeneration(t.Backend, t.Model, text,
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)

View File

@@ -15,9 +15,8 @@ import (
)
type UtilCMD struct {
GGUFInfo GGUFInfoCMD `cmd:"" name:"gguf-info" help:"Get information about a GGUF file"`
HFScan HFScanCMD `cmd:"" name:"hf-scan" help:"Checks installed models for known security issues. WARNING: this is a best-effort feature and may not catch everything!"`
UsecaseHeuristic UsecaseHeuristicCMD `cmd:"" name:"usecase-heuristic" help:"Checks a specific model config and prints what usecase LocalAI will offer for it."`
GGUFInfo GGUFInfoCMD `cmd:"" name:"gguf-info" help:"Get information about a GGUF file"`
HFScan HFScanCMD `cmd:"" name:"hf-scan" help:"Checks installed models for known security issues. WARNING: this is a best-effort feature and may not catch everything!"`
}
type GGUFInfoCMD struct {
@@ -31,11 +30,6 @@ type HFScanCMD struct {
ToScan []string `arg:""`
}
type UsecaseHeuristicCMD struct {
ConfigName string `name:"The config file to check"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
}
func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error {
if u.Args == nil || len(u.Args) == 0 {
return fmt.Errorf("no GGUF file provided")
@@ -105,31 +99,3 @@ func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error {
return nil
}
}
func (uhcmd *UsecaseHeuristicCMD) Run(ctx *cliContext.Context) error {
if len(uhcmd.ConfigName) == 0 {
log.Error().Msg("ConfigName is a required parameter")
return fmt.Errorf("config name is a required parameter")
}
if len(uhcmd.ModelsPath) == 0 {
log.Error().Msg("ModelsPath is a required parameter")
return fmt.Errorf("model path is a required parameter")
}
bcl := config.NewBackendConfigLoader(uhcmd.ModelsPath)
err := bcl.LoadBackendConfig(uhcmd.ConfigName)
if err != nil {
log.Error().Err(err).Str("ConfigName", uhcmd.ConfigName).Msg("error while loading backend")
return err
}
bc, exists := bcl.GetBackendConfig(uhcmd.ConfigName)
if !exists {
log.Error().Str("ConfigName", uhcmd.ConfigName).Msg("ConfigName not found")
}
for name, uc := range config.GetAllBackendConfigUsecases() {
if bc.HasUsecases(uc) {
log.Info().Str("Usecase", name)
}
}
log.Info().Msg("---")
return nil
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"embed"
"encoding/json"
"regexp"
"time"
"github.com/mudler/LocalAI/pkg/xsysinfo"
@@ -17,6 +16,7 @@ type ApplicationConfig struct {
ModelPath string
LibPath string
UploadLimitMB, Threads, ContextSize int
DisableWebUI bool
F16 bool
Debug bool
ImageDir string
@@ -31,18 +31,11 @@ type ApplicationConfig struct {
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string
EnforcePredownloadScans bool
OpaqueErrors bool
P2PToken string
P2PNetworkID string
DisableWebUI bool
EnforcePredownloadScans bool
OpaqueErrors bool
UseSubtleKeyComparison bool
DisableApiKeyRequirementForHttpGet bool
HttpGetExemptedEndpoints []*regexp.Regexp
DisableGalleryEndpoint bool
LoadToMemory []string
ModelLibraryURL string
Galleries []Gallery
@@ -64,6 +57,8 @@ type ApplicationConfig struct {
ModelsURL []string
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
DisableGalleryEndpoint bool
}
type AppOption func(*ApplicationConfig)
@@ -332,38 +327,6 @@ func WithOpaqueErrors(opaque bool) AppOption {
}
}
func WithLoadToMemory(models []string) AppOption {
return func(o *ApplicationConfig) {
o.LoadToMemory = models
}
}
func WithSubtleKeyComparison(subtle bool) AppOption {
return func(o *ApplicationConfig) {
o.UseSubtleKeyComparison = subtle
}
}
func WithDisableApiKeyRequirementForHttpGet(required bool) AppOption {
return func(o *ApplicationConfig) {
o.DisableApiKeyRequirementForHttpGet = required
}
}
func WithHttpGetExemptedEndpoints(endpoints []string) AppOption {
return func(o *ApplicationConfig) {
o.HttpGetExemptedEndpoints = []*regexp.Regexp{}
for _, epr := range endpoints {
r, err := regexp.Compile(epr)
if err == nil && r != nil {
o.HttpGetExemptedEndpoints = append(o.HttpGetExemptedEndpoints, r)
} else {
log.Warn().Err(err).Str("regex", epr).Msg("Error while compiling HTTP Get Exemption regex, skipping this entry.")
}
}
}
}
// ToConfigLoaderOptions returns a slice of ConfigLoader Option.
// Some options defined at the application level are going to be passed as defaults for
// all the configuration for the models.

View File

@@ -3,13 +3,11 @@ package config
import (
"os"
"regexp"
"slices"
"strings"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/functions"
"gopkg.in/yaml.v3"
)
const (
@@ -29,15 +27,13 @@ type BackendConfig struct {
schema.PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
F16 *bool `yaml:"f16"`
Threads *int `yaml:"threads"`
Debug *bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings *bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
KnownUsecaseStrings []string `yaml:"known_usecases"`
KnownUsecases *BackendConfigUsecases `yaml:"-"`
F16 *bool `yaml:"f16"`
Threads *int `yaml:"threads"`
Debug *bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings *bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
PromptStrings, InputStrings []string `yaml:"-"`
InputToken [][]int `yaml:"-"`
@@ -196,21 +192,6 @@ type TemplateConfig struct {
// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
// It defaults to \n
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
Video string `yaml:"video"`
Image string `yaml:"image"`
Audio string `yaml:"audio"`
}
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
type BCAlias BackendConfig
var aux BCAlias
if err := value.Decode(&aux); err != nil {
return err
}
*c = BackendConfig(aux)
c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
return nil
}
func (c *BackendConfig) SetFunctionCallString(s string) {
@@ -429,121 +410,3 @@ func (c *BackendConfig) Validate() bool {
func (c *BackendConfig) HasTemplate() bool {
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
}
type BackendConfigUsecases int
const (
FLAG_ANY BackendConfigUsecases = 0b000000000
FLAG_CHAT BackendConfigUsecases = 0b000000001
FLAG_COMPLETION BackendConfigUsecases = 0b000000010
FLAG_EDIT BackendConfigUsecases = 0b000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b000001000
FLAG_RERANK BackendConfigUsecases = 0b000010000
FLAG_IMAGE BackendConfigUsecases = 0b000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b001000000
FLAG_TTS BackendConfigUsecases = 0b010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000
// Common Subsets
FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
)
func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
return map[string]BackendConfigUsecases{
"FLAG_ANY": FLAG_ANY,
"FLAG_CHAT": FLAG_CHAT,
"FLAG_COMPLETION": FLAG_COMPLETION,
"FLAG_EDIT": FLAG_EDIT,
"FLAG_EMBEDDINGS": FLAG_EMBEDDINGS,
"FLAG_RERANK": FLAG_RERANK,
"FLAG_IMAGE": FLAG_IMAGE,
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
"FLAG_TTS": FLAG_TTS,
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
"FLAG_LLM": FLAG_LLM,
}
}
func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
if len(input) == 0 {
return nil
}
result := FLAG_ANY
flags := GetAllBackendConfigUsecases()
for _, str := range input {
flag, exists := flags["FLAG_"+strings.ToUpper(str)]
if exists {
result |= flag
}
}
return &result
}
// HasUsecases examines a BackendConfig and determines which endpoints have a chance of success.
func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool {
if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
return true
}
return c.GuessUsecases(u)
}
// GuessUsecases is a **heuristic based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful at.
// In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half.
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
if (u & FLAG_CHAT) == FLAG_CHAT {
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" {
return false
}
}
if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
if c.TemplateConfig.Completion == "" {
return false
}
}
if (u & FLAG_EDIT) == FLAG_EDIT {
if c.TemplateConfig.Edit == "" {
return false
}
}
if (u & FLAG_EMBEDDINGS) == FLAG_EMBEDDINGS {
if c.Embeddings == nil || !*c.Embeddings {
return false
}
}
if (u & FLAG_IMAGE) == FLAG_IMAGE {
imageBackends := []string{"diffusers", "tinydream", "stablediffusion"}
if !slices.Contains(imageBackends, c.Backend) {
return false
}
if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
return false
}
}
if (u & FLAG_RERANK) == FLAG_RERANK {
if c.Backend != "rerankers" {
return false
}
}
if (u & FLAG_TRANSCRIPT) == FLAG_TRANSCRIPT {
if c.Backend != "whisper" {
return false
}
}
if (u & FLAG_TTS) == FLAG_TTS {
ttsBackends := []string{"piper", "transformers-musicgen", "parler-tts"}
if !slices.Contains(ttsBackends, c.Backend) {
return false
}
}
if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
if c.Backend != "transformers-musicgen" {
return false
}
}
return true
}

View File

@@ -1,35 +0,0 @@
package config
import "regexp"
type BackendConfigFilterFn func(string, *BackendConfig) bool
func NoFilterFn(_ string, _ *BackendConfig) bool { return true }
func BuildNameFilterFn(filter string) (BackendConfigFilterFn, error) {
if filter == "" {
return NoFilterFn, nil
}
rxp, err := regexp.Compile(filter)
if err != nil {
return nil, err
}
return func(name string, config *BackendConfig) bool {
if config != nil {
return rxp.MatchString(config.Name)
}
return rxp.MatchString(name)
}, nil
}
func BuildUsecaseFilterFn(usecases BackendConfigUsecases) BackendConfigFilterFn {
if usecases == FLAG_ANY {
return NoFilterFn
}
return func(name string, config *BackendConfig) bool {
if config == nil {
return false // TODO: Potentially make this a param, for now, no known usecase to include
}
return config.HasUsecases(usecases)
}
}

View File

@@ -201,26 +201,6 @@ func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
return res
}
func (bcl *BackendConfigLoader) GetBackendConfigsByFilter(filter BackendConfigFilterFn) []BackendConfig {
bcl.Lock()
defer bcl.Unlock()
var res []BackendConfig
if filter == nil {
filter = NoFilterFn
}
for n, v := range bcl.configs {
if filter(n, &v) {
res = append(res, v)
}
}
// TODO: I don't think this one needs to Sort on name... but we'll see what breaks.
return res
}
func (bcl *BackendConfigLoader) RemoveBackendConfig(m string) {
bcl.Lock()
defer bcl.Unlock()

View File

@@ -19,17 +19,12 @@ var _ = Describe("Test cases for config related functions", func() {
`backend: "../foo-bar"
name: "foo"
parameters:
model: "foo-bar"
known_usecases:
- chat
- COMPLETION
`)
model: "foo-bar"`)
Expect(err).ToNot(HaveOccurred())
config, err := readBackendConfigFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
Expect(config.Validate()).To(BeFalse())
Expect(config.KnownUsecases).ToNot(BeNil())
})
It("Test Validate", func() {
tmp, err := os.CreateTemp("", "config.yaml")
@@ -66,99 +61,4 @@ parameters:
Expect(config.Validate()).To(BeTrue())
})
})
It("Properly handles backend usecase matching", func() {
a := BackendConfig{
Name: "a",
}
Expect(a.HasUsecases(FLAG_ANY)).To(BeTrue()) // FLAG_ANY just means the config _exists_ essentially.
b := BackendConfig{
Name: "b",
Backend: "stablediffusion",
}
Expect(b.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(b.HasUsecases(FLAG_IMAGE)).To(BeTrue())
Expect(b.HasUsecases(FLAG_CHAT)).To(BeFalse())
c := BackendConfig{
Name: "c",
Backend: "llama-cpp",
TemplateConfig: TemplateConfig{
Chat: "chat",
},
}
Expect(c.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(c.HasUsecases(FLAG_IMAGE)).To(BeFalse())
Expect(c.HasUsecases(FLAG_COMPLETION)).To(BeFalse())
Expect(c.HasUsecases(FLAG_CHAT)).To(BeTrue())
d := BackendConfig{
Name: "d",
Backend: "llama-cpp",
TemplateConfig: TemplateConfig{
Chat: "chat",
Completion: "completion",
},
}
Expect(d.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(d.HasUsecases(FLAG_IMAGE)).To(BeFalse())
Expect(d.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
Expect(d.HasUsecases(FLAG_CHAT)).To(BeTrue())
trueValue := true
e := BackendConfig{
Name: "e",
Backend: "llama-cpp",
TemplateConfig: TemplateConfig{
Completion: "completion",
},
Embeddings: &trueValue,
}
Expect(e.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(e.HasUsecases(FLAG_IMAGE)).To(BeFalse())
Expect(e.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
Expect(e.HasUsecases(FLAG_CHAT)).To(BeFalse())
Expect(e.HasUsecases(FLAG_EMBEDDINGS)).To(BeTrue())
f := BackendConfig{
Name: "f",
Backend: "piper",
}
Expect(f.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(f.HasUsecases(FLAG_TTS)).To(BeTrue())
Expect(f.HasUsecases(FLAG_CHAT)).To(BeFalse())
g := BackendConfig{
Name: "g",
Backend: "whisper",
}
Expect(g.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(g.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
Expect(g.HasUsecases(FLAG_TTS)).To(BeFalse())
h := BackendConfig{
Name: "h",
Backend: "transformers-musicgen",
}
Expect(h.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(h.HasUsecases(FLAG_TRANSCRIPT)).To(BeFalse())
Expect(h.HasUsecases(FLAG_TTS)).To(BeTrue())
Expect(h.HasUsecases(FLAG_SOUND_GENERATION)).To(BeTrue())
knownUsecases := FLAG_CHAT | FLAG_COMPLETION
i := BackendConfig{
Name: "i",
Backend: "whisper",
// Earlier test checks parsing, this just needs to set final values
KnownUsecases: &knownUsecases,
}
Expect(i.HasUsecases(FLAG_ANY)).To(BeTrue())
Expect(i.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue())
Expect(i.HasUsecases(FLAG_TTS)).To(BeFalse())
Expect(i.HasUsecases(FLAG_COMPLETION)).To(BeTrue())
Expect(i.HasUsecases(FLAG_CHAT)).To(BeTrue())
})
})

View File

@@ -132,7 +132,7 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*Gal
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
var refFile string
uri := downloader.URI(url)
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
refFile = string(d)
if len(refFile) == 0 {
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
@@ -156,7 +156,7 @@ func getGalleryModels(gallery config.Gallery, basePath string) ([]*GalleryModel,
}
uri := downloader.URI(gallery.URL)
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &models)
})
if err != nil {

View File

@@ -69,7 +69,7 @@ type PromptTemplate struct {
func GetGalleryConfigFromURL(url string, basePath string) (Config, error) {
var config Config
uri := downloader.URI(url)
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {

View File

@@ -3,15 +3,13 @@ package http
import (
"embed"
"errors"
"fmt"
"net/http"
"strings"
"github.com/dave-gray101/v2keyauth"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/endpoints/openai"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
"github.com/mudler/LocalAI/core/config"
@@ -31,6 +29,24 @@ import (
"github.com/rs/zerolog/log"
)
func readAuthHeader(c *fiber.Ctx) string {
authHeader := c.Get("Authorization")
// elevenlabs
xApiKey := c.Get("xi-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
// anthropic
xApiKey = c.Get("x-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
return authHeader
}
// Embed a directory
//
//go:embed static/*
@@ -121,17 +137,37 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
})
}
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(app)
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(appConfig.ApiKeys) == 0 {
return c.Next()
}
kaConfig, err := middleware.GetKeyAuthConfig(appConfig)
if err != nil || kaConfig == nil {
return nil, fmt.Errorf("failed to create key auth config: %w", err)
if len(appConfig.ApiKeys) == 0 {
return c.Next()
}
authHeader := readAuthHeader(c)
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
// If it's a bearer token
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
for _, key := range appConfig.ApiKeys {
if apiKey == key {
return c.Next()
}
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
app.Use(v2keyauth.New(*kaConfig))
if appConfig.CORS {
var c func(ctx *fiber.Ctx) error
if appConfig.CORSAllowOrigins == "" {
@@ -156,13 +192,13 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
galleryService := services.NewGalleryService(appConfig)
galleryService.Start(appConfig.Context, cl)
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig)
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService)
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig)
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig, auth)
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService, auth)
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth)
if !appConfig.DisableWebUI {
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService)
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth)
}
routes.RegisterJINARoutes(app, cl, ml, appConfig)
routes.RegisterJINARoutes(app, cl, ml, appConfig, auth)
httpFS := http.FS(embedDirStatic)

View File

@@ -12,7 +12,6 @@ import (
"os"
"path/filepath"
"runtime"
"strings"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http"
@@ -32,9 +31,6 @@ import (
"github.com/sashabaranov/go-openai/jsonschema"
)
const apiKey = "joshua"
const bearerKey = "Bearer " + apiKey
const testPrompt = `### System:
You are an AI assistant that follows instruction extremely well. Help as much as you can.
@@ -54,19 +50,11 @@ type modelApplyRequest struct {
func getModelStatus(url string) (response map[string]interface{}) {
// Create the HTTP request
req, err := http.NewRequest("GET", url, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
resp, err := http.Get(url)
if err != nil {
fmt.Println("Error creating request:", err)
return
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
fmt.Println("Error sending request:", err)
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
@@ -84,15 +72,14 @@ func getModelStatus(url string) (response map[string]interface{}) {
return
}
func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
func getModels(url string) (response []gallery.GalleryModel) {
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
return response, err
return
}
func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
@@ -114,7 +101,6 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
// Make the request
client := &http.Client{}
@@ -154,7 +140,6 @@ func postRequestJSON[B any](url string, bodyJson *B) error {
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
client := &http.Client{}
resp, err := client.Do(req)
@@ -190,7 +175,6 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
client := &http.Client{}
resp, err := client.Do(req)
@@ -211,35 +195,6 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
return json.Unmarshal(body, respJson)
}
func postInvalidRequest(url string) (error, int) {
req, err := http.NewRequest("POST", url, bytes.NewBufferString("invalid request"))
if err != nil {
return err, -1
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err, -1
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err, -1
}
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)), resp.StatusCode
}
return nil, resp.StatusCode
}
//go:embed backend-assets/*
var backendAssets embed.FS
@@ -305,7 +260,6 @@ var _ = Describe("API test", func() {
config.WithContext(c),
config.WithGalleries(galleries),
config.WithModelPath(modelDir),
config.WithApiKeys([]string{apiKey}),
config.WithBackendAssets(backendAssets),
config.WithBackendAssetsOutput(backendAssetsDir))...)
Expect(err).ToNot(HaveOccurred())
@@ -315,7 +269,7 @@ var _ = Describe("API test", func() {
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
@@ -341,19 +295,10 @@ var _ = Describe("API test", func() {
Expect(err).To(HaveOccurred())
})
Context("Auth Tests", func() {
It("Should fail if the api key is missing", func() {
err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available")
Expect(err).ToNot(BeNil())
Expect(sc).To(Equal(403))
})
})
Context("Applying models", func() {
It("applies models from a gallery", func() {
models, err := getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
models := getModels("http://127.0.0.1:9090/models/available")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
@@ -386,8 +331,7 @@ var _ = Describe("API test", func() {
Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar"))
models, err = getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
models = getModels("http://127.0.0.1:9090/models/available")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))
@@ -951,7 +895,7 @@ var _ = Describe("API test", func() {
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(strings.ToLower(resp.Choices[0].Message.Content)).To(Or(ContainSubstring("sure"), ContainSubstring("five")))
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
@@ -970,7 +914,7 @@ var _ = Describe("API test", func() {
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(strings.ToLower(text)).To(Or(ContainSubstring("sure"), ContainSubstring("five")))
Expect(text).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})

View File

@@ -19,16 +19,14 @@ func ModelFromContext(ctx *fiber.Ctx, cl *config.BackendConfigLoader, loader *mo
if ctx.Params("model") != "" {
modelInput = ctx.Params("model")
}
if ctx.Query("model") != "" {
modelInput = ctx.Query("model")
}
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // Reduced duplicate characters of Bearer
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelInput == "" && !bearerExists && firstModel {
models, _ := services.ListModels(cl, loader, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
models, _ := services.ListModels(cl, loader, "", true)
if len(models) > 0 {
modelInput = models[0]
log.Debug().Msgf("No model specified, using: %s", modelInput)

View File

@@ -6,7 +6,6 @@ import (
"github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/services"
@@ -42,7 +41,7 @@ func DoneProgress(galleryID, text string, showDelete bool) string {
"tabindex": "-1",
"autofocus": "",
},
elem.Text(bluemonday.StrictPolicy().Sanitize(text)),
elem.Text(text),
),
elem.If(showDelete, deleteButton(galleryID, modelName), reInstallButton(galleryID)),
).Render()
@@ -58,7 +57,7 @@ func ErrorProgress(err, galleryName string) string {
"tabindex": "-1",
"autofocus": "",
},
elem.Text("Error "+bluemonday.StrictPolicy().Sanitize(err)),
elem.Text("Error "+err),
),
installButton(galleryName),
).Render()
@@ -171,7 +170,7 @@ func P2PNodeBoxes(nodes []p2p.NodeData) string {
attrs.Props{
"class": "text-gray-200 font-semibold ml-2 mr-1",
},
elem.Text(bluemonday.StrictPolicy().Sanitize(n.ID)),
elem.Text(n.ID),
),
elem.Text("Status: "),
elem.If(
@@ -228,7 +227,7 @@ func StartProgressBar(uid, progress, text string) string {
"tabindex": "-1",
"autofocus": "",
},
elem.Text(bluemonday.StrictPolicy().Sanitize(text)), //Perhaps overly defensive
elem.Text(text),
elem.Div(attrs.Props{
"hx-get": "/browse/job/progress/" + uid,
"hx-trigger": "every 600ms",
@@ -250,7 +249,9 @@ func cardSpan(text, icon string) elem.Node {
"class": icon + " pr-2",
}),
elem.Text(bluemonday.StrictPolicy().Sanitize(text)),
elem.Text(text),
//elem.Text(text),
)
}
@@ -284,9 +285,11 @@ func searchableElement(text, icon string) elem.Node {
elem.I(attrs.Props{
"class": icon + " pr-2",
}),
elem.Text(bluemonday.StrictPolicy().Sanitize(text)),
elem.Text(text),
),
),
//elem.Text(text),
)
}
@@ -300,7 +303,7 @@ func link(text, url string) elem.Node {
elem.I(attrs.Props{
"class": "fas fa-link pr-2",
}),
elem.Text(bluemonday.StrictPolicy().Sanitize(text)),
elem.Text(text),
)
}
func installButton(galleryName string) elem.Node {
@@ -384,13 +387,13 @@ func ListModels(models []*gallery.GalleryModel, processTracker ProcessTracker, g
attrs.Props{
"class": "mb-2 text-xl font-bold leading-tight",
},
elem.Text(bluemonday.StrictPolicy().Sanitize(m.Name)),
elem.Text(m.Name),
),
elem.P(
attrs.Props{
"class": "mb-4 text-sm [&:not(:hover)]:truncate text-base",
},
elem.Text(bluemonday.StrictPolicy().Sanitize(m.Description)),
elem.Text(m.Description),
),
)
}

View File

@@ -55,7 +55,7 @@ func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
}
// TODO: Support uploading files?
filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
filePath, _, err := backend.SoundGeneration(cfg.Backend, modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -45,13 +45,13 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
@@ -64,7 +64,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
Documents: req.Documents,
}
results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg)
results, err := backend.Rerank(cfg.Backend, modelFile, request, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -1,60 +0,0 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log"
"github.com/mudler/LocalAI/pkg/model"
)
// TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
//
// @Summary Get TokenMetrics for Active Slot.
// @Accept json
// @Produce audio/x-wav
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/tokenMetrics [get]
// @Router /tokenMetrics [get]
func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TokenMetricsRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Token Metrics for model: %s", modelFile)
response, err := backend.TokenMetrics(modelFile, ml, appConfig, *cfg)
if err != nil {
return err
}
return c.JSON(response)
}
}

View File

@@ -17,14 +17,12 @@ func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConf
if err != nil {
return err
}
loadedModels := ml.ListModels()
for b := range appConfig.ExternalGRPCBackends {
availableBackends = append(availableBackends, b)
}
return c.JSON(
schema.SystemInformationResponse{
Backends: availableBackends,
Models: loadedModels,
},
)
}

View File

@@ -1,58 +0,0 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
// TokenizeEndpoint exposes a REST API to tokenize the content
// @Summary Tokenize the input.
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TokenizeRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
if err != nil {
return err
}
c.JSON(tokenResponse)
return nil
}
}

View File

@@ -13,10 +13,15 @@ import (
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
cl *config.BackendConfigLoader, ml *model.ModelLoader, modelStatus func() (map[string]string, map[string]string)) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, _ := services.ListModels(cl, ml, "", true)
backendConfigs := cl.GetAllBackendConfigs()
galleryConfigs := map[string]*gallery.Config{}
modelsWithBackendConfig := map[string]interface{}{}
for _, m := range backendConfigs {
modelsWithBackendConfig[m.Name] = nil
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
if err != nil {
continue
@@ -24,11 +29,17 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
galleryConfigs[m.Name] = cfg
}
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
// Get model statuses to display in the UI the operation in progress
processingModels, taskTypes := modelStatus()
modelsWithoutConfig := []string{}
for _, m := range models {
if _, ok := modelsWithBackendConfig[m]; !ok {
modelsWithoutConfig = append(modelsWithoutConfig, m)
}
}
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),

View File

@@ -10,7 +10,6 @@ import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
@@ -84,7 +83,7 @@ func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
if !modelExists(cl, ml, request.Model) {
log.Warn().Msgf("Model: %s was not found in list of models.", request.Model)
return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Model %q not found", request.Model)))
return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found")
}
if request.Tools == nil {
@@ -148,7 +147,7 @@ func ListAssistantsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoade
// Convert string limit to integer
limit, err := strconv.Atoi(limitQuery)
if err != nil {
return c.Status(http.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Invalid limit query value: %s", limitQuery)))
return c.Status(http.StatusBadRequest).SendString(fmt.Sprintf("Invalid limit query value: %s", limitQuery))
}
// Sort assistants
@@ -226,7 +225,7 @@ func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant {
func modelExists(cl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string) (found bool) {
found = false
models, err := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
models, err := services.ListModels(cl, ml, "", true)
if err != nil {
return
}
@@ -289,7 +288,7 @@ func GetAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
}
}
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)))
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID))
}
}
@@ -338,11 +337,11 @@ func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model
}
}
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find file_id: %s", request.FileID)))
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find file_id: %s", request.FileID))
}
}
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find %q", assistantID)))
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find %q", assistantID))
}
}
@@ -443,7 +442,7 @@ func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
return c.Status(fiber.StatusOK).JSON(newAssistant)
}
}
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)))
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID))
}
}
@@ -514,9 +513,9 @@ func GetAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoa
if assistantFile.ID == fileId {
return c.Status(fiber.StatusOK).JSON(assistantFile)
}
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant file with file_id: %s", fileId)))
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with file_id: %s", fileId))
}
}
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant file with assistant_id: %s", assistantID)))
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with assistant_id: %s", assistantID))
}
}

View File

@@ -161,12 +161,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
// Set CorrelationID
correlationID := c.Get("X-Correlation-ID")
if len(strings.TrimSpace(correlationID)) == 0 {
correlationID = id
}
c.Set("X-Correlation-ID", correlationID)
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
if err != nil {
@@ -450,7 +444,6 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
c.Set("X-Correlation-ID", id)
responses := make(chan schema.OpenAIResponse)
@@ -647,16 +640,8 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
for _, m := range input.Messages {
images = append(images, m.StringImages...)
}
videos := []string{}
for _, m := range input.Messages {
videos = append(videos, m.StringVideos...)
}
audios := []string{}
for _, m := range input.Messages {
audios = append(audios, m.StringAudios...)
}
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, ml, *config, o, nil)
if err != nil {
log.Error().Err(err).Msg("model inference failed")
return "", err

View File

@@ -57,8 +57,6 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
}
return func(c *fiber.Ctx) error {
// Add Correlation
c.Set("X-Correlation-ID", id)
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)

View File

@@ -8,7 +8,6 @@ import (
"sync/atomic"
"time"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
@@ -50,7 +49,7 @@ func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli
err = c.SaveFile(file, savePath)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + bluemonday.StrictPolicy().Sanitize(err.Error()))
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error())
}
f := schema.File{
@@ -122,7 +121,7 @@ func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Applicat
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.JSON(file)
@@ -144,14 +143,14 @@ func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
if err != nil {
// If the file doesn't exist then we should just continue to remove it
if !errors.Is(err, os.ErrNotExist) {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err)))
return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err))
}
}
@@ -181,12 +180,12 @@ func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.Send(fileContents)

View File

@@ -27,17 +27,9 @@ func ComputeChoices(
for _, m := range req.Messages {
images = append(images, m.StringImages...)
}
videos := []string{}
for _, m := range req.Messages {
videos = append(videos, m.StringVideos...)
}
audios := []string{}
for _, m := range req.Messages {
audios = append(audios, m.StringAudios...)
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, *config, o, tokenCallback)
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, loader, *config, o, tokenCallback)
if err != nil {
return result, backend.TokenUsage{}, err
}

View File

@@ -18,32 +18,32 @@ func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader)
filter := c.Query("filter")
// By default, exclude any loose files that are already referenced by a configuration file.
var policy services.LooseFilePolicy
if c.QueryBool("excludeConfigured", true) {
policy = services.SKIP_IF_CONFIGURED
} else {
policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
}
excludeConfigured := c.QueryBool("excludeConfigured", true)
filterFn, err := config.BuildNameFilterFn(filter)
dataModels, err := modelList(bcl, ml, filter, excludeConfigured)
if err != nil {
return err
}
modelNames, err := services.ListModels(bcl, ml, filterFn, policy)
if err != nil {
return err
}
// Map from a slice of names to a slice of OpenAIModel response objects
dataModels := []schema.OpenAIModel{}
for _, m := range modelNames {
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
return c.JSON(schema.ModelsDataResponse{
Object: "list",
Data: dataModels,
})
}
}
func modelList(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) {
models, err := services.ListModels(bcl, ml, filter, excludeConfigured)
if err != nil {
return nil, err
}
dataModels := []schema.OpenAIModel{}
// Then iterate through the loose files:
for _, m := range models {
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
return dataModels, nil
}

View File

@@ -6,22 +6,15 @@ import (
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
type correlationIDKeyType string
// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"
func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
@@ -31,14 +24,9 @@ func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLo
}
received, _ := json.Marshal(input)
// Extract or generate the correlation ID
correlationID := c.Get("X-Correlation-ID", uuid.New().String())
ctx, cancel := context.WithCancel(o.Context)
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID)
input.Context = ctxWithCorrelationID
input.Context = ctx
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
@@ -147,7 +135,7 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
}
// Decode each request's message content
imgIndex, vidIndex, audioIndex := 0, 0, 0
index := 0
for i, m := range input.Messages {
switch content := m.Content.(type) {
case string:
@@ -156,58 +144,20 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
CONTENT:
for _, pp := range c {
switch pp.Type {
case "text":
if pp.Type == "text" {
input.Messages[i].StringContent = pp.Text
case "video", "video_url":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding video: %s", err)
continue CONTENT
}
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
t := "[vid-{{.ID}}]{{.Text}}"
if config.TemplateConfig.Video != "" {
t = config.TemplateConfig.Video
}
// set a placeholder for each image
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, vidIndex, input.Messages[i].StringContent)
vidIndex++
case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
if err != nil {
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
index++
} else {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
t := "[audio-{{.ID}}]{{.Text}}"
if config.TemplateConfig.Audio != "" {
t = config.TemplateConfig.Audio
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, audioIndex, input.Messages[i].StringContent)
audioIndex++
case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
t := "[img-{{.ID}}]{{.Text}}"
if config.TemplateConfig.Image != "" {
t = config.TemplateConfig.Image
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, imgIndex, input.Messages[i].StringContent)
imgIndex++
}
}
}

View File

@@ -1,95 +0,0 @@
package middleware
import (
"crypto/subtle"
"errors"
"github.com/dave-gray101/v2keyauth"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/keyauth"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/config"
)
// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
// Currently this requires an upstream patch - and feature patches are no longer accepted to v2
// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate.
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key"}, keyauth.ConfigDefault.AuthScheme)
if err != nil {
return nil, err
}
return &v2keyauth.Config{
CustomKeyLookup: customLookup,
Next: getApiKeyRequiredFilterFunction(applicationConfig),
Validator: getApiKeyValidationFunction(applicationConfig),
ErrorHandler: getApiKeyErrorHandler(applicationConfig),
AuthScheme: "Bearer",
}, nil
}
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
return func(ctx *fiber.Ctx, err error) error {
if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
if len(applicationConfig.ApiKeys) == 0 {
return ctx.Next() // if no keys are set up, any error we get here is not an error.
}
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(403)
}
return ctx.Status(403).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
}
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(500)
}
return err
}
}
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
if applicationConfig.UseSubtleKeyComparison {
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything
}
for _, validKey := range applicationConfig.ApiKeys {
if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 {
return true, nil
}
}
return false, v2keyauth.ErrMissingOrMalformedAPIKey
}
}
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything
}
for _, validKey := range applicationConfig.ApiKeys {
if apiKey == validKey {
return true, nil
}
}
return false, v2keyauth.ErrMissingOrMalformedAPIKey
}
}
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
if applicationConfig.DisableApiKeyRequirementForHttpGet {
return func(c *fiber.Ctx) bool {
if c.Method() != "GET" {
return false
}
for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
if rx.MatchString(c.Path()) {
return true
}
}
return false
}
}
return func(c *fiber.Ctx) bool { return false }
}

View File

@@ -10,11 +10,12 @@ import (
func RegisterElevenLabsRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", elevenlabs.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/sound-generation", elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
app.Post("/v1/sound-generation", auth, elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
}

View File

@@ -1,13 +0,0 @@
package routes
import "github.com/gofiber/fiber/v2"
func HealthRoutes(app *fiber.App) {
// Service health checks
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
app.Get("/healthz", ok)
app.Get("/readyz", ok)
}

View File

@@ -11,7 +11,8 @@ import (
func RegisterJINARoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {
// POST endpoint to mimic the reranking
app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))

View File

@@ -15,55 +15,61 @@ func RegisterLocalAIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService) {
galleryService *services.GalleryService,
auth func(*fiber.Ctx) error) {
app.Get("/swagger/*", swagger.HandlerDefault) // default
// LocalAI API endpoints
if !appConfig.DisableGalleryEndpoint {
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
app.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Post("/models/delete/:name", auth, modelGalleryEndpointService.DeleteModelGalleryEndpoint())
app.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
}
app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
// Stores
sl := model.NewModelLoader("")
app.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))
app.Get("/metrics", localai.LocalAIMetricsEndpoint())
// Kubernetes health checks
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
app.Get("/healthz", ok)
app.Get("/readyz", ok)
app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint())
// Experimental Backend Statistics Module
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitorService))
app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitorService))
// p2p
if p2p.IsP2PEnabled() {
app.Get("/api/p2p", localai.ShowP2PNodes(appConfig))
app.Get("/api/p2p/token", localai.ShowP2PToken(appConfig))
app.Get("/api/p2p", auth, localai.ShowP2PNodes(appConfig))
app.Get("/api/p2p/token", auth, localai.ShowP2PToken(appConfig))
}
app.Get("/version", func(c *fiber.Ctx) error {
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
app.Get("/system", localai.SystemInformations(ml, appConfig))
// misc
app.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
app.Get("/system", auth, localai.SystemInformations(ml, appConfig))
}

View File

@@ -11,65 +11,66 @@ import (
func RegisterOpenAIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
// edit
app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig))
app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig))
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
// assistant
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
// files
app.Post("/v1/files", openai.UploadFilesEndpoint(cl, appConfig))
app.Post("/files", openai.UploadFilesEndpoint(cl, appConfig))
app.Get("/v1/files", openai.ListFilesEndpoint(cl, appConfig))
app.Get("/files", openai.ListFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
app.Get("/files/:file_id", openai.GetFilesEndpoint(cl, appConfig))
app.Delete("/v1/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
app.Delete("/files/:file_id", openai.DeleteFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig))
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
// completion
app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
// embeddings
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
// audio
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))
// images
app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig))
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
if appConfig.ImageDir != "" {
app.Static("/generated-images", appConfig.ImageDir)
@@ -80,6 +81,6 @@ func RegisterOpenAIRoutes(app *fiber.App,
}
// List models
app.Get("/v1/models", openai.ListModelsEndpoint(cl, ml))
app.Get("/models", openai.ListModelsEndpoint(cl, ml))
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
}

View File

@@ -6,7 +6,6 @@ import (
"sort"
"strings"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/elements"
@@ -60,7 +59,8 @@ func RegisterUIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService) {
galleryService *services.GalleryService,
auth func(*fiber.Ctx) error) {
// keeps the state of models that are being installed from the UI
var processingModels = NewModelOpCache()
@@ -85,10 +85,10 @@ func RegisterUIRoutes(app *fiber.App,
return processingModelsData, taskTypes
}
app.Get("/", localai.WelcomeEndpoint(appConfig, cl, ml, modelStatus))
app.Get("/", auth, localai.WelcomeEndpoint(appConfig, cl, ml, modelStatus))
if p2p.IsP2PEnabled() {
app.Get("/p2p", func(c *fiber.Ctx) error {
app.Get("/p2p", auth, func(c *fiber.Ctx) error {
summary := fiber.Map{
"Title": "LocalAI - P2P dashboard",
"Version": internal.PrintableVersion(),
@@ -104,17 +104,17 @@ func RegisterUIRoutes(app *fiber.App,
})
/* show nodes live! */
app.Get("/p2p/ui/workers", func(c *fiber.Ctx) error {
app.Get("/p2p/ui/workers", auth, func(c *fiber.Ctx) error {
return c.SendString(elements.P2PNodeBoxes(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))))
})
app.Get("/p2p/ui/workers-federation", func(c *fiber.Ctx) error {
app.Get("/p2p/ui/workers-federation", auth, func(c *fiber.Ctx) error {
return c.SendString(elements.P2PNodeBoxes(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))))
})
app.Get("/p2p/ui/workers-stats", func(c *fiber.Ctx) error {
app.Get("/p2p/ui/workers-stats", auth, func(c *fiber.Ctx) error {
return c.SendString(elements.P2PNodeStats(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID))))
})
app.Get("/p2p/ui/workers-federation-stats", func(c *fiber.Ctx) error {
app.Get("/p2p/ui/workers-federation-stats", auth, func(c *fiber.Ctx) error {
return c.SendString(elements.P2PNodeStats(p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID))))
})
}
@@ -122,7 +122,7 @@ func RegisterUIRoutes(app *fiber.App,
if !appConfig.DisableGalleryEndpoint {
// Show the Models page (all models)
app.Get("/browse", func(c *fiber.Ctx) error {
app.Get("/browse", auth, func(c *fiber.Ctx) error {
term := c.Query("term")
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
@@ -167,12 +167,12 @@ func RegisterUIRoutes(app *fiber.App,
// Show the models, filtered from the user input
// https://htmx.org/examples/active-search/
app.Post("/browse/search/models", func(c *fiber.Ctx) error {
app.Post("/browse/search/models", auth, func(c *fiber.Ctx) error {
form := struct {
Search string `form:"search"`
}{}
if err := c.BodyParser(&form); err != nil {
return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
return c.Status(fiber.StatusBadRequest).SendString(err.Error())
}
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
@@ -188,7 +188,7 @@ func RegisterUIRoutes(app *fiber.App,
// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
// https://htmx.org/examples/progress-bar/
app.Post("/browse/install/model/:id", func(c *fiber.Ctx) error {
app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
log.Debug().Msgf("UI job submitted to install : %+v\n", galleryID)
@@ -215,7 +215,7 @@ func RegisterUIRoutes(app *fiber.App,
// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
// https://htmx.org/examples/progress-bar/
app.Post("/browse/delete/model/:id", func(c *fiber.Ctx) error {
app.Post("/browse/delete/model/:id", auth, func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!
log.Debug().Msgf("UI job submitted to delete : %+v\n", galleryID)
var galleryName = galleryID
@@ -255,7 +255,7 @@ func RegisterUIRoutes(app *fiber.App,
// Display the job current progress status
// If the job is done, we trigger the /browse/job/:uid route
// https://htmx.org/examples/progress-bar/
app.Get("/browse/job/progress/:uid", func(c *fiber.Ctx) error {
app.Get("/browse/job/progress/:uid", auth, func(c *fiber.Ctx) error {
jobUID := strings.Clone(c.Params("uid")) // note: strings.Clone is required for multiple requests!
status := galleryService.GetStatus(jobUID)
@@ -279,7 +279,7 @@ func RegisterUIRoutes(app *fiber.App,
// this route is hit when the job is done, and we display the
// final state (for now just displays "Installation completed")
app.Get("/browse/job/:uid", func(c *fiber.Ctx) error {
app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error {
jobUID := strings.Clone(c.Params("uid")) // note: strings.Clone is required for multiple requests!
status := galleryService.GetStatus(jobUID)
@@ -303,8 +303,8 @@ func RegisterUIRoutes(app *fiber.App,
}
// Show the Chat page
app.Get("/chat/:model", func(c *fiber.Ctx) error {
backendConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
app.Get("/chat/:model", auth, func(c *fiber.Ctx) error {
backendConfigs, _ := services.ListModels(cl, ml, "", true)
summary := fiber.Map{
"Title": "LocalAI - Chat with " + c.Params("model"),
@@ -318,8 +318,8 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/chat", summary)
})
app.Get("/talk/", func(c *fiber.Ctx) error {
backendConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
app.Get("/talk/", auth, func(c *fiber.Ctx) error {
backendConfigs, _ := services.ListModels(cl, ml, "", true)
if len(backendConfigs) == 0 {
// If no model is available redirect to the index which suggests how to install models
@@ -338,9 +338,9 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/talk", summary)
})
app.Get("/chat/", func(c *fiber.Ctx) error {
app.Get("/chat/", auth, func(c *fiber.Ctx) error {
backendConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
backendConfigs, _ := services.ListModels(cl, ml, "", true)
if len(backendConfigs) == 0 {
// If no model is available redirect to the index which suggests how to install models
@@ -359,7 +359,7 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/chat", summary)
})
app.Get("/text2image/:model", func(c *fiber.Ctx) error {
app.Get("/text2image/:model", auth, func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
summary := fiber.Map{
@@ -374,7 +374,7 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/text2image", summary)
})
app.Get("/text2image/", func(c *fiber.Ctx) error {
app.Get("/text2image/", auth, func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
@@ -395,7 +395,7 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/text2image", summary)
})
app.Get("/tts/:model", func(c *fiber.Ctx) error {
app.Get("/tts/:model", auth, func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()
summary := fiber.Map{
@@ -410,7 +410,7 @@ func RegisterUIRoutes(app *fiber.App,
return c.Render("views/tts", summary)
})
app.Get("/tts/", func(c *fiber.Ctx) error {
app.Get("/tts/", auth, func(c *fiber.Ctx) error {
backendConfigs := cl.GetAllBackendConfigs()

Some files were not shown because too many files have changed in this diff Show More