Compare commits

..

1 Commits

Author SHA1 Message Date
Ettore Di Giacinto
83110891fd fix(go-grpc-server): always close resultChan
If the channel is not closed, a client call to a server that does not
implement PredictStream hangs indefinitely, because it waits for
resultChan to be consumed.

Now, when the prediction stream returns, we close the channel and wait
for the goroutine to finish.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-10-05 00:07:58 +02:00
128 changed files with 371 additions and 1615 deletions

View File

@@ -6,7 +6,6 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"github.com/microcosm-cc/bluemonday"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -280,12 +279,6 @@ func main() {
return return
} }
// Ensure that all arbitrary text content is sanitized before display
for i, m := range models {
models[i].Name = bluemonday.StrictPolicy().Sanitize(m.Name)
models[i].Description = bluemonday.StrictPolicy().Sanitize(m.Description)
}
// render the template // render the template
data := struct { data := struct {
Models []*GalleryModel Models []*GalleryModel

View File

@@ -9,8 +9,6 @@ updates:
directory: "/" directory: "/"
schedule: schedule:
interval: "weekly" interval: "weekly"
ignore:
- dependency-name: "github.com/mudler/LocalAI/pkg/grpc/proto"
- package-ecosystem: "github-actions" - package-ecosystem: "github-actions"
# Workflow files stored in the default location of `.github/workflows`. (You don't need to specify `/.github/workflows` for `directory`. You can use `directory: "/"`.) # Workflow files stored in the default location of `.github/workflows`. (You don't need to specify `/.github/workflows` for `directory`. You can use `directory: "/"`.)
directory: "/" directory: "/"

View File

@@ -33,7 +33,7 @@ jobs:
run: | run: |
CGO_ENABLED=0 make build-api CGO_ENABLED=0 make build-api
- name: rm - name: rm
uses: appleboy/ssh-action@v1.1.0 uses: appleboy/ssh-action@v1.0.3
with: with:
host: ${{ secrets.EXPLORER_SSH_HOST }} host: ${{ secrets.EXPLORER_SSH_HOST }}
username: ${{ secrets.EXPLORER_SSH_USERNAME }} username: ${{ secrets.EXPLORER_SSH_USERNAME }}
@@ -53,7 +53,7 @@ jobs:
rm: true rm: true
target: ./local-ai target: ./local-ai
- name: restarting - name: restarting
uses: appleboy/ssh-action@v1.1.0 uses: appleboy/ssh-action@v1.0.3
with: with:
host: ${{ secrets.EXPLORER_SSH_HOST }} host: ${{ secrets.EXPLORER_SSH_HOST }}
username: ${{ secrets.EXPLORER_SSH_USERNAME }} username: ${{ secrets.EXPLORER_SSH_USERNAME }}

View File

@@ -79,7 +79,7 @@ jobs:
args: ${{ steps.summarize.outputs.message }} args: ${{ steps.summarize.outputs.message }}
- name: Setup tmate session if fails - name: Setup tmate session if fails
if: ${{ failure() }} if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.19 uses: mxschmitt/action-tmate@v3.18
with: with:
detached: true detached: true
connect-timeout-seconds: 180 connect-timeout-seconds: 180
@@ -161,7 +161,7 @@ jobs:
TWITTER_ACCESS_TOKEN_SECRET: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }} TWITTER_ACCESS_TOKEN_SECRET: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }}
- name: Setup tmate session if fails - name: Setup tmate session if fails
if: ${{ failure() }} if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.19 uses: mxschmitt/action-tmate@v3.18
with: with:
detached: true detached: true
connect-timeout-seconds: 180 connect-timeout-seconds: 180

View File

@@ -123,7 +123,7 @@ jobs:
release/* release/*
- name: Setup tmate session if tests fail - name: Setup tmate session if tests fail
if: ${{ failure() }} if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.19 uses: mxschmitt/action-tmate@v3.18
with: with:
detached: true detached: true
connect-timeout-seconds: 180 connect-timeout-seconds: 180
@@ -232,7 +232,7 @@ jobs:
release/* release/*
- name: Setup tmate session if tests fail - name: Setup tmate session if tests fail
if: ${{ failure() }} if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.19 uses: mxschmitt/action-tmate@v3.18
with: with:
detached: true detached: true
connect-timeout-seconds: 180 connect-timeout-seconds: 180
@@ -308,7 +308,7 @@ jobs:
release/* release/*
- name: Setup tmate session if tests fail - name: Setup tmate session if tests fail
if: ${{ failure() }} if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.19 uses: mxschmitt/action-tmate@v3.18
with: with:
detached: true detached: true
connect-timeout-seconds: 180 connect-timeout-seconds: 180
@@ -350,7 +350,7 @@ jobs:
release/* release/*
- name: Setup tmate session if tests fail - name: Setup tmate session if tests fail
if: ${{ failure() }} if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.19 uses: mxschmitt/action-tmate@v3.18
with: with:
detached: true detached: true
connect-timeout-seconds: 180 connect-timeout-seconds: 180

View File

@@ -105,14 +105,6 @@ jobs:
tests-parler-tts: tests-parler-tts:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Force Install GIT latest
run: |
sudo apt-get update \
&& sudo apt-get install -y software-properties-common \
&& sudo apt-get update \
&& sudo add-apt-repository -y ppa:git-core/ppa \
&& sudo apt-get update \
&& sudo apt-get install -y git
- name: Clone - name: Clone
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:

View File

@@ -133,7 +133,7 @@ jobs:
PATH="$PATH:/root/go/bin" GO_TAGS="stablediffusion tts" make --jobs 5 --output-sync=target test PATH="$PATH:/root/go/bin" GO_TAGS="stablediffusion tts" make --jobs 5 --output-sync=target test
- name: Setup tmate session if tests fail - name: Setup tmate session if tests fail
if: ${{ failure() }} if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.19 uses: mxschmitt/action-tmate@v3.18
with: with:
detached: true detached: true
connect-timeout-seconds: 180 connect-timeout-seconds: 180
@@ -197,7 +197,7 @@ jobs:
make run-e2e-aio make run-e2e-aio
- name: Setup tmate session if tests fail - name: Setup tmate session if tests fail
if: ${{ failure() }} if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.19 uses: mxschmitt/action-tmate@v3.18
with: with:
detached: true detached: true
connect-timeout-seconds: 180 connect-timeout-seconds: 180
@@ -235,7 +235,7 @@ jobs:
BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DGGML_F16C=OFF -DGGML_AVX512=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF" make --jobs 4 --output-sync=target test BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DGGML_F16C=OFF -DGGML_AVX512=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF" make --jobs 4 --output-sync=target test
- name: Setup tmate session if tests fail - name: Setup tmate session if tests fail
if: ${{ failure() }} if: ${{ failure() }}
uses: mxschmitt/action-tmate@v3.19 uses: mxschmitt/action-tmate@v3.18
with: with:
detached: true detached: true
connect-timeout-seconds: 180 connect-timeout-seconds: 180

View File

@@ -9,8 +9,6 @@ FROM ${BASE_IMAGE} AS requirements-core
USER root USER root
ARG GO_VERSION=1.22.6 ARG GO_VERSION=1.22.6
ARG CMAKE_VERSION=3.26.4
ARG CMAKE_FROM_SOURCE=false
ARG TARGETARCH ARG TARGETARCH
ARG TARGETVARIANT ARG TARGETVARIANT
@@ -23,25 +21,13 @@ RUN apt-get update && \
build-essential \ build-essential \
ccache \ ccache \
ca-certificates \ ca-certificates \
curl libssl-dev \ cmake \
curl \
git \ git \
unzip upx-ucl && \ unzip upx-ucl && \
apt-get clean && \ apt-get clean && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# Install CMake (the version in 22.04 is too old)
RUN <<EOT bash
if [ "${CMAKE_FROM_SOURCE}}" = "true" ]; then
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
else
apt-get update && \
apt-get install -y \
cmake && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
fi
EOT
# Install Go # Install Go
RUN curl -L -s https://go.dev/dl/go${GO_VERSION}.linux-${TARGETARCH}.tar.gz | tar -C /usr/local -xz RUN curl -L -s https://go.dev/dl/go${GO_VERSION}.linux-${TARGETARCH}.tar.gz | tar -C /usr/local -xz
ENV PATH=$PATH:/root/go/bin:/usr/local/go/bin ENV PATH=$PATH:/root/go/bin:/usr/local/go/bin
@@ -202,8 +188,6 @@ FROM ${GRPC_BASE_IMAGE} AS grpc
# This is a bit of a hack, but it's required in order to be able to effectively cache this layer in CI # This is a bit of a hack, but it's required in order to be able to effectively cache this layer in CI
ARG GRPC_MAKEFLAGS="-j4 -Otarget" ARG GRPC_MAKEFLAGS="-j4 -Otarget"
ARG GRPC_VERSION=v1.65.0 ARG GRPC_VERSION=v1.65.0
ARG CMAKE_FROM_SOURCE=false
ARG CMAKE_VERSION=3.26.4
ENV MAKEFLAGS=${GRPC_MAKEFLAGS} ENV MAKEFLAGS=${GRPC_MAKEFLAGS}
@@ -212,24 +196,12 @@ WORKDIR /build
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
ca-certificates \ ca-certificates \
build-essential curl libssl-dev \ build-essential \
cmake \
git && \ git && \
apt-get clean && \ apt-get clean && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# Install CMake (the version in 22.04 is too old)
RUN <<EOT bash
if [ "${CMAKE_FROM_SOURCE}}" = "true" ]; then
curl -L -s https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && tar xvf cmake.tar.gz && cd cmake-${CMAKE_VERSION} && ./configure && make && make install
else
apt-get update && \
apt-get install -y \
cmake && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
fi
EOT
# We install GRPC to a different prefix here so that we can copy in only the build artifacts later # We install GRPC to a different prefix here so that we can copy in only the build artifacts later
# saves several hundred MB on the final docker image size vs copying in the entire GRPC source tree # saves several hundred MB on the final docker image size vs copying in the entire GRPC source tree
# and running make install in the target container # and running make install in the target container

View File

@@ -8,7 +8,7 @@ DETECT_LIBS?=true
# llama.cpp versions # llama.cpp versions
GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp
GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
CPPLLAMA_VERSION?=0a1c750c80147687df267114c81956757cc14382 CPPLLAMA_VERSION?=d5ed2b929d85bbd7dbeecb690880f07d9d7a6077
# go-rwkv version # go-rwkv version
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
@@ -16,7 +16,7 @@ RWKV_VERSION?=661e7ae26d442f5cfebd2a0881b44e8c55949ec6
# whisper.cpp version # whisper.cpp version
WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp
WHISPER_CPP_VERSION?=0fbaac9c891055796456df7b9122a70c220f9ca1 WHISPER_CPP_VERSION?=ccc2547210e09e3a1785817383ab770389bb442b
# bert.cpp version # bert.cpp version
BERT_REPO?=https://github.com/go-skynet/go-bert.cpp BERT_REPO?=https://github.com/go-skynet/go-bert.cpp
@@ -470,13 +470,13 @@ run-e2e-image:
run-e2e-aio: protogen-go run-e2e-aio: protogen-go
@echo 'Running e2e AIO tests' @echo 'Running e2e AIO tests'
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e-aio
test-e2e: test-e2e:
@echo 'Running e2e tests' @echo 'Running e2e tests'
BUILD_TYPE=$(BUILD_TYPE) \ BUILD_TYPE=$(BUILD_TYPE) \
LOCALAI_API=http://$(E2E_BRIDGE_IP):5390/v1 \ LOCALAI_API=http://$(E2E_BRIDGE_IP):5390/v1 \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e
teardown-e2e: teardown-e2e:
rm -rf $(TEST_DIR) || true rm -rf $(TEST_DIR) || true
@@ -484,24 +484,24 @@ teardown-e2e:
test-llama: prepare-test test-llama: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r $(TEST_PATHS)
test-llama-gguf: prepare-test test-llama-gguf: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r $(TEST_PATHS)
test-tts: prepare-test test-tts: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r $(TEST_PATHS)
test-stablediffusion: prepare-test test-stablediffusion: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r $(TEST_PATHS)
test-stores: backend-assets/grpc/local-store test-stores: backend-assets/grpc/local-store
mkdir -p tests/integration/backend-assets/grpc mkdir -p tests/integration/backend-assets/grpc
cp -f backend-assets/grpc/local-store tests/integration/backend-assets/grpc/ cp -f backend-assets/grpc/local-store tests/integration/backend-assets/grpc/
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts $(TEST_FLAKES) -v -r tests/integration $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts 1 -v -r tests/integration
test-container: test-container:
docker build --target requirements -t local-ai-test-container . docker build --target requirements -t local-ai-test-container .

View File

@@ -66,21 +66,6 @@ docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-aio-cpu
# docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-12 # docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-12
``` ```
To load models:
```bash
# From the model gallery (see available models with `local-ai models list`, in the WebUI from the model tab, or visiting https://models.localai.io)
local-ai run llama-3.2-1b-instruct:q4_k_m
# Start LocalAI with the phi-2 model directly from huggingface
local-ai run huggingface://TheBloke/phi-2-GGUF/phi-2.Q8_0.gguf
# Install and run a model from the Ollama OCI registry
local-ai run ollama://gemma:2b
# Run a model from a configuration file
local-ai run https://gist.githubusercontent.com/.../phi-2.yaml
# Install and run a model from a standard OCI registry (e.g., Docker Hub)
local-ai run oci://localai/phi-2:latest
```
[💻 Getting started](https://localai.io/basics/getting_started/index.html) [💻 Getting started](https://localai.io/basics/getting_started/index.html)
## 📰 Latest project news ## 📰 Latest project news

View File

@@ -219,7 +219,6 @@ message ModelOptions {
int32 SwapSpace = 53; int32 SwapSpace = 53;
int32 MaxModelLen = 54; int32 MaxModelLen = 54;
int32 TensorParallelSize = 55; int32 TensorParallelSize = 55;
string LoadFormat = 58;
string MMProj = 41; string MMProj = 41;

View File

@@ -113,7 +113,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
std::string ret; std::string ret;
for (; begin != end; ++begin) for (; begin != end; ++begin)
{ {
ret += common_token_to_piece(ctx, *begin); ret += llama_token_to_piece(ctx, *begin);
} }
return ret; return ret;
} }
@@ -121,7 +121,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
// format incomplete utf-8 multibyte character for output // format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
{ {
std::string out = token == -1 ? "" : common_token_to_piece(ctx, token); std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
// if the size is 1 and first bit is 1, meaning it's a partial character // if the size is 1 and first bit is 1, meaning it's a partial character
// (size > 1 meaning it's already a known token) // (size > 1 meaning it's already a known token)
if (out.size() == 1 && (out[0] & 0x80) == 0x80) if (out.size() == 1 && (out[0] & 0x80) == 0x80)
@@ -203,8 +203,8 @@ struct llama_client_slot
std::string stopping_word; std::string stopping_word;
// sampling // sampling
struct common_sampler_params sparams; struct gpt_sampler_params sparams;
common_sampler *ctx_sampling = nullptr; gpt_sampler *ctx_sampling = nullptr;
int32_t ga_i = 0; // group-attention state int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1; // group-attention factor int32_t ga_n = 1; // group-attention factor
@@ -257,7 +257,7 @@ struct llama_client_slot
images.clear(); images.clear();
} }
bool has_budget(common_params &global_params) { bool has_budget(gpt_params &global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) if (params.n_predict == -1 && global_params.n_predict == -1)
{ {
return true; // limitless return true; // limitless
@@ -391,39 +391,6 @@ struct llama_metrics {
} }
}; };
struct llava_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};
struct llama_server_context struct llama_server_context
{ {
llama_model *model = nullptr; llama_model *model = nullptr;
@@ -431,7 +398,7 @@ struct llama_server_context
clip_ctx *clp_ctx = nullptr; clip_ctx *clp_ctx = nullptr;
common_params params; gpt_params params;
llama_batch batch; llama_batch batch;
@@ -474,7 +441,7 @@ struct llama_server_context
} }
} }
bool load_model(const common_params &params_) bool load_model(const gpt_params &params_)
{ {
params = params_; params = params_;
if (!params.mmproj.empty()) { if (!params.mmproj.empty()) {
@@ -491,9 +458,9 @@ struct llama_server_context
} }
} }
common_init_result common_init = common_init_from_params(params); llama_init_result llama_init = llama_init_from_gpt_params(params);
model = common_init.model; model = llama_init.model;
ctx = common_init.context; ctx = llama_init.context;
if (model == nullptr) if (model == nullptr)
{ {
LOG_ERR("unable to load model: %s", params.model.c_str()); LOG_ERR("unable to load model: %s", params.model.c_str());
@@ -611,12 +578,12 @@ struct llama_server_context
std::vector<llama_token> p; std::vector<llama_token> p;
if (first) if (first)
{ {
p = common_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL); p = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
first = false; first = false;
} }
else else
{ {
p = common_tokenize(ctx, s, false, TMP_FORCE_SPECIAL); p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
} }
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
} }
@@ -633,7 +600,7 @@ struct llama_server_context
else else
{ {
auto s = json_prompt.template get<std::string>(); auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL); prompt_tokens = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
} }
return prompt_tokens; return prompt_tokens;
@@ -662,7 +629,7 @@ struct llama_server_context
bool launch_slot_with_data(llama_client_slot* &slot, json data) { bool launch_slot_with_data(llama_client_slot* &slot, json data) {
slot_params default_params; slot_params default_params;
common_sampler_params default_sparams; gpt_sampler_params default_sparams;
slot->params.stream = json_value(data, "stream", false); slot->params.stream = json_value(data, "stream", false);
slot->params.cache_prompt = json_value(data, "cache_prompt", false); slot->params.cache_prompt = json_value(data, "cache_prompt", false);
@@ -802,7 +769,7 @@ struct llama_server_context
} }
else if (el[0].is_string()) else if (el[0].is_string())
{ {
auto toks = common_tokenize(model, el[0].get<std::string>(), false); auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) for (auto tok : toks)
{ {
slot->sparams.logit_bias.push_back({tok, bias}); slot->sparams.logit_bias.push_back({tok, bias});
@@ -834,7 +801,7 @@ struct llama_server_context
sampler_names.emplace_back(name); sampler_names.emplace_back(name);
} }
} }
slot->sparams.samplers = common_sampler_types_from_names(sampler_names, false); slot->sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
} }
else else
{ {
@@ -918,9 +885,9 @@ struct llama_server_context
if (slot->ctx_sampling != nullptr) if (slot->ctx_sampling != nullptr)
{ {
common_sampler_free(slot->ctx_sampling); gpt_sampler_free(slot->ctx_sampling);
} }
slot->ctx_sampling = common_sampler_init(model, slot->sparams); slot->ctx_sampling = gpt_sampler_init(model, slot->sparams);
//llama_set_rng_seed(ctx, slot->params.seed); //llama_set_rng_seed(ctx, slot->params.seed);
slot->command = LOAD_PROMPT; slot->command = LOAD_PROMPT;
@@ -947,13 +914,13 @@ struct llama_server_context
system_tokens.clear(); system_tokens.clear();
if (!system_prompt.empty()) { if (!system_prompt.empty()) {
system_tokens = common_tokenize(ctx, system_prompt, add_bos_token); system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
common_batch_clear(batch); llama_batch_clear(batch);
for (int i = 0; i < (int)system_tokens.size(); ++i) for (int i = 0; i < (int)system_tokens.size(); ++i)
{ {
common_batch_add(batch, system_tokens[i], i, { 0 }, false); llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
} }
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch) for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
@@ -967,6 +934,7 @@ struct llama_server_context
batch.n_seq_id + i, batch.n_seq_id + i,
batch.seq_id + i, batch.seq_id + i,
batch.logits + i, batch.logits + i,
0, 0, 0, // unused
}; };
if (llama_decode(ctx, batch_view) != 0) if (llama_decode(ctx, batch_view) != 0)
{ {
@@ -1041,7 +1009,7 @@ struct llama_server_context
bool process_token(completion_token_output &result, llama_client_slot &slot) { bool process_token(completion_token_output &result, llama_client_slot &slot) {
// remember which tokens were sampled - used for repetition penalties during sampling // remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = common_token_to_piece(ctx, result.tok); const std::string token_str = llama_token_to_piece(ctx, result.tok);
slot.sampled = result.tok; slot.sampled = result.tok;
// search stop word and delete it // search stop word and delete it
@@ -1192,7 +1160,7 @@ struct llama_server_context
samplers.reserve(slot.sparams.samplers.size()); samplers.reserve(slot.sparams.samplers.size());
for (const auto & sampler : slot.sparams.samplers) for (const auto & sampler : slot.sparams.samplers)
{ {
samplers.emplace_back(common_sampler_type_to_str(sampler)); samplers.emplace_back(gpt_sampler_type_to_str(sampler));
} }
return json { return json {
@@ -1248,7 +1216,7 @@ struct llama_server_context
if (slot.sparams.n_probs > 0) if (slot.sparams.n_probs > 0)
{ {
std::vector<completion_token_output> probs_output = {}; std::vector<completion_token_output> probs_output = {};
const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false); const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size()); size_t probs_pos = std::min(slot.sent_token_probs_index, slot.generated_token_probs.size());
size_t probs_stop_pos = std::min(slot.sent_token_probs_index + to_send_toks.size(), slot.generated_token_probs.size()); size_t probs_stop_pos = std::min(slot.sent_token_probs_index + to_send_toks.size(), slot.generated_token_probs.size());
if (probs_pos < probs_stop_pos) if (probs_pos < probs_stop_pos)
@@ -1300,7 +1268,7 @@ struct llama_server_context
std::vector<completion_token_output> probs = {}; std::vector<completion_token_output> probs = {};
if (!slot.params.stream && slot.stopped_word) if (!slot.params.stream && slot.stopped_word)
{ {
const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size()); probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size());
} }
else else
@@ -1411,6 +1379,7 @@ struct llama_server_context
batch.n_seq_id + i, batch.n_seq_id + i,
batch.seq_id + i, batch.seq_id + i,
batch.logits + i, batch.logits + i,
0, 0, 0, // unused
}; };
if (llama_decode(ctx, batch_view)) if (llama_decode(ctx, batch_view))
{ {
@@ -1429,9 +1398,8 @@ struct llama_server_context
} }
const int n_embd = llama_n_embd(model); const int n_embd = llama_n_embd(model);
float * embd = img.image_embedding + i * n_embd; llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, slot.n_past, 0); if (llama_decode(ctx, batch_img))
if (llama_decode(ctx, llava_batch.batch))
{ {
LOG("%s : failed to eval image\n", __func__); LOG("%s : failed to eval image\n", __func__);
return false; return false;
@@ -1440,7 +1408,7 @@ struct llama_server_context
} }
image_idx++; image_idx++;
common_batch_clear(batch); llama_batch_clear(batch);
// append prefix of next image // append prefix of next image
const auto json_prompt = (image_idx >= (int) slot.images.size()) ? const auto json_prompt = (image_idx >= (int) slot.images.size()) ?
@@ -1450,7 +1418,7 @@ struct llama_server_context
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
for (int i = 0; i < (int) append_tokens.size(); ++i) for (int i = 0; i < (int) append_tokens.size(); ++i)
{ {
common_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true); llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
slot.n_past += 1; slot.n_past += 1;
} }
} }
@@ -1582,7 +1550,7 @@ struct llama_server_context
update_system_prompt(); update_system_prompt();
} }
common_batch_clear(batch); llama_batch_clear(batch);
if (all_slots_are_idle) if (all_slots_are_idle)
{ {
@@ -1660,7 +1628,7 @@ struct llama_server_context
// TODO: we always have to take into account the "system_tokens" // TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow // this is not great and needs to be improved somehow
common_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true); llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
slot.n_past += 1; slot.n_past += 1;
} }
@@ -1754,7 +1722,7 @@ struct llama_server_context
if (!slot.params.cache_prompt) if (!slot.params.cache_prompt)
{ {
common_sampler_reset(slot.ctx_sampling); gpt_sampler_reset(slot.ctx_sampling);
slot.n_past = 0; slot.n_past = 0;
slot.n_past_se = 0; slot.n_past_se = 0;
@@ -1766,7 +1734,7 @@ struct llama_server_context
// push the prompt into the sampling context (do not apply grammar) // push the prompt into the sampling context (do not apply grammar)
for (auto &token : prompt_tokens) for (auto &token : prompt_tokens)
{ {
common_sampler_accept(slot.ctx_sampling, token, false); gpt_sampler_accept(slot.ctx_sampling, token, false);
} }
slot.n_past = common_part(slot.cache_tokens, prompt_tokens); slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
@@ -1858,7 +1826,7 @@ struct llama_server_context
ga_i += ga_w/ga_n; ga_i += ga_w/ga_n;
} }
} }
common_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false); llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
slot_npast++; slot_npast++;
} }
@@ -1936,6 +1904,7 @@ struct llama_server_context
batch.n_seq_id + i, batch.n_seq_id + i,
batch.seq_id + i, batch.seq_id + i,
batch.logits + i, batch.logits + i,
0, 0, 0, // unused
}; };
const int ret = llama_decode(ctx, batch_view); const int ret = llama_decode(ctx, batch_view);
@@ -1974,9 +1943,9 @@ struct llama_server_context
} }
completion_token_output result; completion_token_output result;
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i); const llama_token id = gpt_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
common_sampler_accept(slot.ctx_sampling, id, true); gpt_sampler_accept(slot.ctx_sampling, id, true);
slot.n_decoded += 1; slot.n_decoded += 1;
if (slot.n_decoded == 1) if (slot.n_decoded == 1)
@@ -1987,7 +1956,7 @@ struct llama_server_context
} }
result.tok = id; result.tok = id;
const auto * cur_p = common_sampler_get_candidates(slot.ctx_sampling); const auto * cur_p = gpt_sampler_get_candidates(slot.ctx_sampling);
for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) { for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
result.probs.push_back({ result.probs.push_back({
@@ -2040,7 +2009,7 @@ static json format_partial_response(
struct token_translator struct token_translator
{ {
llama_context * ctx; llama_context * ctx;
std::string operator()(llama_token tok) const { return common_token_to_piece(ctx, tok); } std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); }
std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); } std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); }
}; };
@@ -2234,7 +2203,7 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
// } // }
static void params_parse(const backend::ModelOptions* request, static void params_parse(const backend::ModelOptions* request,
common_params & params) { gpt_params & params) {
// this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809 // this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
@@ -2342,7 +2311,7 @@ public:
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) { grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) {
// Implement LoadModel RPC // Implement LoadModel RPC
common_params params; gpt_params params;
params_parse(request, params); params_parse(request, params);
llama_backend_init(); llama_backend_init();

View File

@@ -1,2 +1,2 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118 torch

View File

@@ -1 +1 @@
torch==2.4.1 torch

View File

@@ -1,2 +1,2 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0 --extra-index-url https://download.pytorch.org/whl/rocm6.0
torch==2.4.1+rocm6.0 torch

View File

@@ -1,6 +1,6 @@
accelerate accelerate
auto-gptq==0.7.1 auto-gptq==0.7.1
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
certifi certifi
transformers transformers

View File

@@ -1,4 +1,4 @@
transformers transformers
accelerate accelerate
torch==2.4.1 torch
torchaudio==2.4.1 torchaudio

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118 torch
torchaudio==2.4.1+cu118 torchaudio
transformers transformers
accelerate accelerate

View File

@@ -1,4 +1,4 @@
torch==2.4.1 torch
torchaudio==2.4.1 torchaudio
transformers transformers
accelerate accelerate

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0 --extra-index-url https://download.pytorch.org/whl/rocm6.0
torch==2.4.1+rocm6.0 torch
torchaudio==2.4.1+rocm6.0 torchaudio
transformers transformers
accelerate accelerate

View File

@@ -1,4 +1,4 @@
bark==0.1.5 bark==0.1.5
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
certifi certifi

View File

@@ -1,2 +1,2 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf

View File

@@ -1,4 +1,3 @@
transformers transformers
accelerate accelerate
torch==2.4.1 torch
coqui-tts

View File

@@ -1,6 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118 torch
torchaudio==2.4.1+cu118 torchaudio
transformers transformers
accelerate accelerate
coqui-tts

View File

@@ -1,5 +1,4 @@
torch==2.4.1 torch
torchaudio==2.4.1 torchaudio
transformers transformers
accelerate accelerate
coqui-tts

View File

@@ -1,6 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0 --extra-index-url https://download.pytorch.org/whl/rocm6.0
torch==2.4.1+rocm6.0 torch
torchaudio==2.4.1+rocm6.0 torchaudio
transformers transformers
accelerate accelerate
coqui-tts

View File

@@ -6,4 +6,3 @@ optimum[openvino]
setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406 setuptools==75.1.0 # https://github.com/mudler/LocalAI/issues/2406
transformers transformers
accelerate accelerate
coqui-tts

View File

@@ -1,4 +1,4 @@
grpcio==1.67.0 coqui-tts
grpcio==1.66.2
protobuf protobuf
certifi certifi
packaging==24.1

View File

@@ -19,7 +19,7 @@ class TestBackendServicer(unittest.TestCase):
This method sets up the gRPC service by starting the server This method sets up the gRPC service by starting the server
""" """
self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"]) self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
time.sleep(30) time.sleep(10)
def tearDown(self) -> None: def tearDown(self) -> None:
""" """

View File

@@ -5,5 +5,5 @@ accelerate
compel compel
peft peft
sentencepiece sentencepiece
torch==2.4.1 torch
optimum-quanto optimum-quanto

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118 torch
diffusers diffusers
opencv-python opencv-python
transformers transformers

View File

@@ -1,4 +1,4 @@
torch==2.4.1 torch
diffusers diffusers
opencv-python opencv-python
transformers transformers

View File

@@ -1,5 +1,5 @@
setuptools setuptools
grpcio==1.67.0 grpcio==1.66.2
pillow pillow
protobuf protobuf
certifi certifi

View File

@@ -1,3 +1,3 @@
transformers transformers
accelerate accelerate
torch==2.4.1 torch

View File

@@ -1,4 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118 torch
transformers transformers
accelerate accelerate

View File

@@ -1,3 +1,3 @@
torch==2.4.1 torch
transformers transformers
accelerate accelerate

View File

@@ -1,4 +1,4 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
certifi certifi
wheel wheel

View File

@@ -1,2 +1,2 @@
torch==2.4.1 torch
transformers transformers

View File

@@ -1,3 +1,3 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118 torch
transformers transformers

View File

@@ -1,2 +1,2 @@
torch==2.4.1 torch
transformers transformers

View File

@@ -1,3 +1,3 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
certifi certifi

View File

@@ -1,3 +1 @@
torch==2.4.1 torch
git+https://github.com/myshell-ai/MeloTTS.git
git+https://github.com/myshell-ai/OpenVoice.git

View File

@@ -1,4 +1,2 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118 torch
git+https://github.com/myshell-ai/MeloTTS.git
git+https://github.com/myshell-ai/OpenVoice.git

View File

@@ -1,3 +1 @@
torch==2.4.1 torch
git+https://github.com/myshell-ai/MeloTTS.git
git+https://github.com/myshell-ai/OpenVoice.git

View File

@@ -1,4 +1,2 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0 --extra-index-url https://download.pytorch.org/whl/rocm6.0
torch==2.4.1+rocm6.0 torch
git+https://github.com/myshell-ai/MeloTTS.git
git+https://github.com/myshell-ai/OpenVoice.git

View File

@@ -2,22 +2,22 @@
intel-extension-for-pytorch intel-extension-for-pytorch
torch torch
optimum[openvino] optimum[openvino]
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
librosa==0.9.1 librosa==0.9.1
faster-whisper==0.9.0 faster-whisper==1.0.3
pydub==0.25.1 pydub==0.25.1
wavmark==0.0.3 wavmark==0.0.3
numpy==1.22.0 numpy==1.26.4
eng_to_ipa==0.0.2 eng_to_ipa==0.0.2
inflect==7.0.0 inflect==7.0.0
unidecode==1.3.7 unidecode==1.3.7
whisper-timestamped==1.14.2 whisper-timestamped==1.15.4
openai openai
python-dotenv python-dotenv
pypinyin==0.50.0 pypinyin==0.50.0
cn2an==0.5.22 cn2an==0.5.22
jieba==0.42.1 jieba==0.42.1
gradio==4.44.1
langid==1.1.6 langid==1.1.6
git+https://github.com/myshell-ai/MeloTTS.git git+https://github.com/myshell-ai/MeloTTS.git
git+https://github.com/myshell-ai/OpenVoice.git

View File

@@ -1,10 +1,10 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
librosa librosa
faster-whisper faster-whisper
pydub==0.25.1 pydub==0.25.1
wavmark==0.0.3 wavmark==0.0.3
numpy==1.22.0 numpy
eng_to_ipa==0.0.2 eng_to_ipa==0.0.2
inflect inflect
unidecode unidecode
@@ -13,8 +13,8 @@ openai
python-dotenv python-dotenv
pypinyin pypinyin
cn2an==0.5.22 cn2an==0.5.22
networkx==2.8.8
jieba==0.42.1 jieba==0.42.1
gradio==3.48.0 gradio
langid==1.1.6 langid==1.1.6
llvmlite==0.43.0 git+https://github.com/myshell-ai/MeloTTS.git
git+https://github.com/myshell-ai/OpenVoice.git

View File

@@ -19,7 +19,7 @@ class TestBackendServicer(unittest.TestCase):
This method sets up the gRPC service by starting the server This method sets up the gRPC service by starting the server
""" """
self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"]) self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
time.sleep(30) time.sleep(10)
def tearDown(self) -> None: def tearDown(self) -> None:
""" """

View File

@@ -15,12 +15,12 @@ installRequirements
# https://github.com/descriptinc/audiotools/issues/101 # https://github.com/descriptinc/audiotools/issues/101
# incompatible protobuf versions. # incompatible protobuf versions.
# PYDIR=python3.10 PYDIR=python3.10
# pyenv="${MY_DIR}/venv/lib/${PYDIR}/site-packages/google/protobuf/internal/" pyenv="${MY_DIR}/venv/lib/${PYDIR}/site-packages/google/protobuf/internal/"
# if [ ! -d ${pyenv} ]; then if [ ! -d ${pyenv} ]; then
# echo "(parler-tts/install.sh): Error: ${pyenv} does not exist" echo "(parler-tts/install.sh): Error: ${pyenv} does not exist"
# exit 1 exit 1
# fi fi
# curl -L https://raw.githubusercontent.com/protocolbuffers/protobuf/main/python/google/protobuf/internal/builder.py -o ${pyenv}/builder.py curl -L https://raw.githubusercontent.com/protocolbuffers/protobuf/main/python/google/protobuf/internal/builder.py -o ${pyenv}/builder.py

View File

@@ -1,4 +1,3 @@
git+https://github.com/huggingface/parler-tts.git@8e465f1b5fcd223478e07175cb40494d19ffbe17 git+https://github.com/huggingface/parler-tts.git@8e465f1b5fcd223478e07175cb40494d19ffbe17
llvmlite==0.43.0 llvmlite==0.43.0
numba==0.60.0 numba==0.60.0
git+https://github.com/descriptinc/audiotools

View File

@@ -1,3 +1,3 @@
transformers transformers
accelerate accelerate
torch==2.4.1 torch

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118 torch
torchaudio==2.4.1+cu118 torchaudio
transformers transformers
accelerate accelerate

View File

@@ -1,4 +1,4 @@
torch==2.4.1 torch
torchaudio==2.4.1 torchaudio
transformers transformers
accelerate accelerate

View File

@@ -1,4 +1,4 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
certifi certifi
llvmlite==0.43.0 llvmlite==0.43.0

View File

@@ -1,4 +1,4 @@
transformers transformers
accelerate accelerate
torch==2.4.1 torch
rerankers[transformers] rerankers[transformers]

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
transformers transformers
accelerate accelerate
torch==2.4.1+cu118 torch
rerankers[transformers] rerankers[transformers]

View File

@@ -1,4 +1,4 @@
transformers transformers
accelerate accelerate
torch==2.4.1 torch
rerankers[transformers] rerankers[transformers]

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0 --extra-index-url https://download.pytorch.org/whl/rocm6.0
transformers transformers
accelerate accelerate
torch==2.4.1+rocm6.0 torch
rerankers[transformers] rerankers[transformers]

View File

@@ -1,3 +1,3 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
certifi certifi

View File

@@ -1,6 +1,6 @@
torch==2.4.1 torch
accelerate accelerate
transformers transformers
bitsandbytes bitsandbytes
sentence-transformers==3.2.0 sentence-transformers==3.1.1
transformers transformers

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118 torch
accelerate accelerate
sentence-transformers==3.2.0 sentence-transformers==3.1.1
transformers transformers

View File

@@ -1,4 +1,4 @@
torch==2.4.1 torch
accelerate accelerate
sentence-transformers==3.2.0 sentence-transformers==3.1.1
transformers transformers

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0 --extra-index-url https://download.pytorch.org/whl/rocm6.0
torch==2.4.1+rocm6.0 torch
accelerate accelerate
sentence-transformers==3.2.0 sentence-transformers==3.1.1
transformers transformers

View File

@@ -4,5 +4,5 @@ torch
optimum[openvino] optimum[openvino]
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406 setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406
accelerate accelerate
sentence-transformers==3.2.0 sentence-transformers==3.1.1
transformers transformers

View File

@@ -1,4 +1,4 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
certifi certifi
datasets datasets

View File

@@ -1,3 +1,3 @@
transformers transformers
accelerate accelerate
torch==2.4.1 torch

View File

@@ -1,4 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
transformers transformers
accelerate accelerate
torch==2.4.1+cu118 torch

View File

@@ -1,3 +1,3 @@
transformers transformers
accelerate accelerate
torch==2.4.1 torch

View File

@@ -1,4 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0 --extra-index-url https://download.pytorch.org/whl/rocm6.0
transformers transformers
accelerate accelerate
torch==2.4.1+rocm6.0 torch

View File

@@ -1,4 +1,4 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
scipy==1.14.0 scipy==1.14.0
certifi certifi

View File

@@ -72,13 +72,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns: Returns:
A Result object that contains the result of the LoadModel operation. A Result object that contains the result of the LoadModel operation.
""" """
model_name = request.Model model_name = request.Model
# Check to see if the Model exists in the filesystem already.
if os.path.exists(request.ModelFile):
model_name = request.ModelFile
compute = torch.float16 compute = torch.float16
if request.F16Memory == True: if request.F16Memory == True:
compute=torch.bfloat16 compute=torch.bfloat16

View File

@@ -1,4 +1,4 @@
torch==2.4.1 torch
accelerate accelerate
transformers transformers
bitsandbytes bitsandbytes

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118 torch
accelerate accelerate
transformers transformers
bitsandbytes bitsandbytes

View File

@@ -1,4 +1,4 @@
torch==2.4.1 torch
accelerate accelerate
transformers transformers
bitsandbytes bitsandbytes

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0 --extra-index-url https://download.pytorch.org/whl/rocm6.0
torch==2.4.1+rocm6.0 torch
accelerate accelerate
transformers transformers
bitsandbytes bitsandbytes

View File

@@ -1,4 +1,4 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
certifi certifi
setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406 setuptools==69.5.1 # https://github.com/mudler/LocalAI/issues/2406

View File

@@ -1,3 +1,3 @@
accelerate accelerate
torch==2.4.1 torch
torchaudio==2.4.1 torchaudio

View File

@@ -1,4 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
accelerate accelerate
torch==2.4.1+cu118 torch
torchaudio==2.4.1+cu118 torchaudio

View File

@@ -1,3 +1,3 @@
accelerate accelerate
torch==2.4.1 torch
torchaudio==2.4.1 torchaudio

View File

@@ -1,3 +1,3 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
certifi certifi

View File

@@ -19,8 +19,6 @@ from vllm.utils import random_uuid
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
import base64
import io
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -95,8 +93,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.Quantization != "": if request.Quantization != "":
engine_args.quantization = request.Quantization engine_args.quantization = request.Quantization
if request.LoadFormat != "":
engine_args.load_format = request.LoadFormat
if request.GPUMemoryUtilization != 0: if request.GPUMemoryUtilization != 0:
engine_args.gpu_memory_utilization = request.GPUMemoryUtilization engine_args.gpu_memory_utilization = request.GPUMemoryUtilization
if request.TrustRemoteCode: if request.TrustRemoteCode:
@@ -221,15 +217,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# Generate text using the LLM engine # Generate text using the LLM engine
request_id = random_uuid() request_id = random_uuid()
print(f"Generating text with request_id: {request_id}", file=sys.stderr) print(f"Generating text with request_id: {request_id}", file=sys.stderr)
multi_modal_data = {}
if image_data:
multi_modal_data["image"] = image_data
if video_data:
multi_modal_data["video"] = video_data
outputs = self.llm.generate( outputs = self.llm.generate(
{ {
"prompt": prompt, "prompt": prompt,
"multi_modal_data": multi_modal_data if multi_modal_data else None, "multi_modal_data": {
"image": image_data if image_data else None,
"video": video_data if video_data else None,
} if image_data or video_data else None,
}, },
sampling_params=sampling_params, sampling_params=sampling_params,
request_id=request_id, request_id=request_id,
@@ -268,22 +262,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def load_image(self, image_path: str): def load_image(self, image_path: str):
""" """
Load an image from the given file path or base64 encoded data. Load an image from the given file path.
Args: Args:
image_path (str): The path to the image file or base64 encoded data. image_path (str): The path to the image file.
Returns: Returns:
Image: The loaded image. Image: The loaded image.
""" """
try: try:
return Image.open(image_path)
image_data = base64.b64decode(image_path)
image = Image.open(io.BytesIO(image_data))
return image
except Exception as e: except Exception as e:
print(f"Error loading image {image_path}: {e}", file=sys.stderr) print(f"Error loading image {image_path}: {e}", file=sys.stderr)
return None return self.load_video(image_path)
def load_video(self, video_path: str): def load_video(self, video_path: str):
""" """
@@ -296,15 +287,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Video: The loaded video. Video: The loaded video.
""" """
try: try:
timestamp = str(int(time.time() * 1000)) # Generate timestamp video = VideoAsset(name=video_path).np_ndarrays
p = f"/tmp/vl-{timestamp}.data" # Use timestamp in filename
with open(p, "wb") as f:
f.write(base64.b64decode(video_path))
video = VideoAsset(name=p).np_ndarrays
os.remove(p)
return video return video
except Exception as e: except Exception as e:
print(f"Error loading video {video_path}: {e}", file=sys.stderr) print(f"Error loading video {image_path}: {e}", file=sys.stderr)
return None return None
async def serve(address): async def serve(address):

View File

@@ -13,16 +13,14 @@ if [ "x${BUILD_PROFILE}" == "xintel" ]; then
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi fi
# We don't embed this into the images as it is a large dependency and not always needed. if [ "x${BUILD_TYPE}" == "x" ]; then
# Besides, the speed inference are not actually usable in the current state for production use-cases.
if [ "x${BUILD_TYPE}" == "x" ] && [ "x${FROM_SOURCE}" == "xtrue" ]; then
ensureVenv ensureVenv
# https://docs.vllm.ai/en/v0.6.1/getting_started/cpu-installation.html # https://docs.vllm.ai/en/v0.6.1/getting_started/cpu-installation.html
if [ ! -d vllm ]; then if [ ! -d vllm ]; then
git clone https://github.com/vllm-project/vllm git clone https://github.com/vllm-project/vllm
fi fi
pushd vllm pushd vllm
uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.67.0 protobuf bitsandbytes uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.66.2 protobuf bitsandbytes
uv pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu uv pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
VLLM_TARGET_DEVICE=cpu python setup.py install VLLM_TARGET_DEVICE=cpu python setup.py install
popd popd

View File

@@ -1,3 +1,3 @@
accelerate accelerate
torch==2.4.1 torch
transformers transformers

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
accelerate accelerate
torch==2.4.1+cu118 torch
transformers transformers
bitsandbytes bitsandbytes

View File

@@ -1,4 +1,4 @@
accelerate accelerate
torch==2.4.1 torch
transformers transformers
bitsandbytes bitsandbytes

View File

@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0 --extra-index-url https://download.pytorch.org/whl/rocm6.0
accelerate accelerate
torch==2.4.1+rocm6.0 torch
transformers transformers
bitsandbytes bitsandbytes

View File

@@ -1,4 +1,4 @@
grpcio==1.67.0 grpcio==1.66.2
protobuf protobuf
certifi certifi
setuptools setuptools

View File

@@ -2,7 +2,6 @@ package backend
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"os" "os"
"regexp" "regexp"
@@ -78,16 +77,6 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
switch ct := message.Content.(type) { switch ct := message.Content.(type) {
case string: case string:
protoMessages[i].Content = ct protoMessages[i].Content = ct
case []interface{}:
// If using the tokenizer template, in case of multimodal we want to keep the multimodal content as and return only strings here
data, _ := json.Marshal(ct)
resultData := []struct {
Text string `json:"text"`
}{}
json.Unmarshal(data, &resultData)
for _, r := range resultData {
protoMessages[i].Content += r.Text
}
default: default:
return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct) return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
} }

View File

@@ -139,7 +139,6 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
DraftModel: c.DraftModel, DraftModel: c.DraftModel,
AudioPath: c.VallE.AudioPath, AudioPath: c.VallE.AudioPath,
Quantization: c.Quantization, Quantization: c.Quantization,
LoadFormat: c.LoadFormat,
GPUMemoryUtilization: c.GPUMemoryUtilization, GPUMemoryUtilization: c.GPUMemoryUtilization,
TrustRemoteCode: c.TrustRemoteCode, TrustRemoteCode: c.TrustRemoteCode,
EnforceEager: c.EnforceEager, EnforceEager: c.EnforceEager,

View File

@@ -53,7 +53,6 @@ type RunCMD struct {
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"` OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"` UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"`
DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"` DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"`
DisableMetricsEndpoint bool `env:"LOCALAI_DISABLE_METRICS_ENDPOINT,DISABLE_METRICS_ENDPOINT" default:"false" help:"Disable the /metrics endpoint" group:"api"`
HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/browse/?$,^/talk/?$,^/p2p/?$,^/chat/?$,^/text2image/?$,^/tts/?$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"` HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/browse/?$,^/talk/?$,^/p2p/?$,^/chat/?$,^/text2image/?$,^/tts/?$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"`
Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"` Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"`
Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"` Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"`
@@ -109,10 +108,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithLoadToMemory(r.LoadToMemory), config.WithLoadToMemory(r.LoadToMemory),
} }
if r.DisableMetricsEndpoint {
opts = append(opts, config.DisableMetricsEndpoint)
}
token := "" token := ""
if r.Peer2Peer || r.Peer2PeerToken != "" { if r.Peer2Peer || r.Peer2PeerToken != "" {
log.Info().Msg("P2P mode enabled") log.Info().Msg("P2P mode enabled")

View File

@@ -39,7 +39,6 @@ type ApplicationConfig struct {
OpaqueErrors bool OpaqueErrors bool
UseSubtleKeyComparison bool UseSubtleKeyComparison bool
DisableApiKeyRequirementForHttpGet bool DisableApiKeyRequirementForHttpGet bool
DisableMetrics bool
HttpGetExemptedEndpoints []*regexp.Regexp HttpGetExemptedEndpoints []*regexp.Regexp
DisableGalleryEndpoint bool DisableGalleryEndpoint bool
LoadToMemory []string LoadToMemory []string
@@ -351,10 +350,6 @@ func WithDisableApiKeyRequirementForHttpGet(required bool) AppOption {
} }
} }
var DisableMetricsEndpoint AppOption = func(o *ApplicationConfig) {
o.DisableMetrics = true
}
func WithHttpGetExemptedEndpoints(endpoints []string) AppOption { func WithHttpGetExemptedEndpoints(endpoints []string) AppOption {
return func(o *ApplicationConfig) { return func(o *ApplicationConfig) {
o.HttpGetExemptedEndpoints = []*regexp.Regexp{} o.HttpGetExemptedEndpoints = []*regexp.Regexp{}

View File

@@ -143,7 +143,6 @@ type LLMConfig struct {
DraftModel string `yaml:"draft_model"` DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"` NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"` Quantization string `yaml:"quantization"`
LoadFormat string `yaml:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager"` // vLLM EnforceEager bool `yaml:"enforce_eager"` // vLLM
@@ -198,7 +197,9 @@ type TemplateConfig struct {
// It defaults to \n // It defaults to \n
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"` JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
Multimodal string `yaml:"multimodal"` Video string `yaml:"video"`
Image string `yaml:"image"`
Audio string `yaml:"audio"`
} }
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error { func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {

View File

@@ -109,21 +109,19 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
app.Use(recover.New()) app.Use(recover.New())
} }
if !appConfig.DisableMetrics { metricsService, err := services.NewLocalAIMetricsService()
metricsService, err := services.NewLocalAIMetricsService() if err != nil {
if err != nil { return nil, err
return nil, err
}
if metricsService != nil {
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
app.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
})
}
} }
// Health Checks should always be exempt from auth, so register these first
if metricsService != nil {
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
app.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
})
}
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(app) routes.HealthRoutes(app)
kaConfig, err := middleware.GetKeyAuthConfig(appConfig) kaConfig, err := middleware.GetKeyAuthConfig(appConfig)

View File

@@ -12,7 +12,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http" . "github.com/mudler/LocalAI/core/http"
@@ -951,7 +950,7 @@ var _ = Describe("API test", func() {
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}}) openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue()) Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(strings.ToLower(resp.Choices[0].Message.Content)).To(Or(ContainSubstring("sure"), ContainSubstring("five"))) Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}}) stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@@ -970,7 +969,7 @@ var _ = Describe("API test", func() {
tokens++ tokens++
} }
Expect(text).ToNot(BeEmpty()) Expect(text).ToNot(BeEmpty())
Expect(strings.ToLower(text)).To(Or(ContainSubstring("sure"), ContainSubstring("five"))) Expect(text).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
Expect(tokens).ToNot(Or(Equal(1), Equal(0))) Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
}) })

View File

@@ -6,7 +6,6 @@ import (
"github.com/chasefleming/elem-go" "github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs" "github.com/chasefleming/elem-go/attrs"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
@@ -42,7 +41,7 @@ func DoneProgress(galleryID, text string, showDelete bool) string {
"tabindex": "-1", "tabindex": "-1",
"autofocus": "", "autofocus": "",
}, },
elem.Text(bluemonday.StrictPolicy().Sanitize(text)), elem.Text(text),
), ),
elem.If(showDelete, deleteButton(galleryID, modelName), reInstallButton(galleryID)), elem.If(showDelete, deleteButton(galleryID, modelName), reInstallButton(galleryID)),
).Render() ).Render()
@@ -58,7 +57,7 @@ func ErrorProgress(err, galleryName string) string {
"tabindex": "-1", "tabindex": "-1",
"autofocus": "", "autofocus": "",
}, },
elem.Text("Error "+bluemonday.StrictPolicy().Sanitize(err)), elem.Text("Error "+err),
), ),
installButton(galleryName), installButton(galleryName),
).Render() ).Render()
@@ -171,7 +170,7 @@ func P2PNodeBoxes(nodes []p2p.NodeData) string {
attrs.Props{ attrs.Props{
"class": "text-gray-200 font-semibold ml-2 mr-1", "class": "text-gray-200 font-semibold ml-2 mr-1",
}, },
elem.Text(bluemonday.StrictPolicy().Sanitize(n.ID)), elem.Text(n.ID),
), ),
elem.Text("Status: "), elem.Text("Status: "),
elem.If( elem.If(
@@ -228,7 +227,7 @@ func StartProgressBar(uid, progress, text string) string {
"tabindex": "-1", "tabindex": "-1",
"autofocus": "", "autofocus": "",
}, },
elem.Text(bluemonday.StrictPolicy().Sanitize(text)), //Perhaps overly defensive elem.Text(text),
elem.Div(attrs.Props{ elem.Div(attrs.Props{
"hx-get": "/browse/job/progress/" + uid, "hx-get": "/browse/job/progress/" + uid,
"hx-trigger": "every 600ms", "hx-trigger": "every 600ms",
@@ -250,7 +249,9 @@ func cardSpan(text, icon string) elem.Node {
"class": icon + " pr-2", "class": icon + " pr-2",
}), }),
elem.Text(bluemonday.StrictPolicy().Sanitize(text)), elem.Text(text),
//elem.Text(text),
) )
} }
@@ -284,9 +285,11 @@ func searchableElement(text, icon string) elem.Node {
elem.I(attrs.Props{ elem.I(attrs.Props{
"class": icon + " pr-2", "class": icon + " pr-2",
}), }),
elem.Text(bluemonday.StrictPolicy().Sanitize(text)), elem.Text(text),
), ),
), ),
//elem.Text(text),
) )
} }
@@ -300,7 +303,7 @@ func link(text, url string) elem.Node {
elem.I(attrs.Props{ elem.I(attrs.Props{
"class": "fas fa-link pr-2", "class": "fas fa-link pr-2",
}), }),
elem.Text(bluemonday.StrictPolicy().Sanitize(text)), elem.Text(text),
) )
} }
func installButton(galleryName string) elem.Node { func installButton(galleryName string) elem.Node {
@@ -384,13 +387,13 @@ func ListModels(models []*gallery.GalleryModel, processTracker ProcessTracker, g
attrs.Props{ attrs.Props{
"class": "mb-2 text-xl font-bold leading-tight", "class": "mb-2 text-xl font-bold leading-tight",
}, },
elem.Text(bluemonday.StrictPolicy().Sanitize(m.Name)), elem.Text(m.Name),
), ),
elem.P( elem.P(
attrs.Props{ attrs.Props{
"class": "mb-4 text-sm [&:not(:hover)]:truncate text-base", "class": "mb-4 text-sm [&:not(:hover)]:truncate text-base",
}, },
elem.Text(bluemonday.StrictPolicy().Sanitize(m.Description)), elem.Text(m.Description),
), ),
) )
} }

View File

@@ -13,10 +13,15 @@ import (
func WelcomeEndpoint(appConfig *config.ApplicationConfig, func WelcomeEndpoint(appConfig *config.ApplicationConfig,
cl *config.BackendConfigLoader, ml *model.ModelLoader, modelStatus func() (map[string]string, map[string]string)) func(*fiber.Ctx) error { cl *config.BackendConfigLoader, ml *model.ModelLoader, modelStatus func() (map[string]string, map[string]string)) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
models, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
backendConfigs := cl.GetAllBackendConfigs() backendConfigs := cl.GetAllBackendConfigs()
galleryConfigs := map[string]*gallery.Config{} galleryConfigs := map[string]*gallery.Config{}
modelsWithBackendConfig := map[string]interface{}{}
for _, m := range backendConfigs { for _, m := range backendConfigs {
modelsWithBackendConfig[m.Name] = nil
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name) cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
if err != nil { if err != nil {
continue continue
@@ -24,15 +29,13 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
galleryConfigs[m.Name] = cfg galleryConfigs[m.Name] = cfg
} }
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
// Get model statuses to display in the UI the operation in progress // Get model statuses to display in the UI the operation in progress
processingModels, taskTypes := modelStatus() processingModels, taskTypes := modelStatus()
summary := fiber.Map{ summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(), "Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(), "Version": internal.PrintableVersion(),
"Models": modelsWithoutConfig, "Models": models,
"ModelsConfig": backendConfigs, "ModelsConfig": backendConfigs,
"GalleryConfig": galleryConfigs, "GalleryConfig": galleryConfigs,
"IsP2PEnabled": p2p.IsP2PEnabled(), "IsP2PEnabled": p2p.IsP2PEnabled(),

View File

@@ -10,7 +10,6 @@ import (
"time" "time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
@@ -84,7 +83,7 @@ func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
if !modelExists(cl, ml, request.Model) { if !modelExists(cl, ml, request.Model) {
log.Warn().Msgf("Model: %s was not found in list of models.", request.Model) log.Warn().Msgf("Model: %s was not found in list of models.", request.Model)
return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Model %q not found", request.Model))) return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found")
} }
if request.Tools == nil { if request.Tools == nil {
@@ -148,7 +147,7 @@ func ListAssistantsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoade
// Convert string limit to integer // Convert string limit to integer
limit, err := strconv.Atoi(limitQuery) limit, err := strconv.Atoi(limitQuery)
if err != nil { if err != nil {
return c.Status(http.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Invalid limit query value: %s", limitQuery))) return c.Status(http.StatusBadRequest).SendString(fmt.Sprintf("Invalid limit query value: %s", limitQuery))
} }
// Sort assistants // Sort assistants
@@ -289,7 +288,7 @@ func GetAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
} }
} }
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant with id: %s", assistantID))) return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID))
} }
} }
@@ -338,11 +337,11 @@ func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model
} }
} }
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find file_id: %s", request.FileID))) return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find file_id: %s", request.FileID))
} }
} }
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find %q", assistantID))) return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find %q", assistantID))
} }
} }
@@ -443,7 +442,7 @@ func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
return c.Status(fiber.StatusOK).JSON(newAssistant) return c.Status(fiber.StatusOK).JSON(newAssistant)
} }
} }
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant with id: %s", assistantID))) return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID))
} }
} }
@@ -514,9 +513,9 @@ func GetAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoa
if assistantFile.ID == fileId { if assistantFile.ID == fileId {
return c.Status(fiber.StatusOK).JSON(assistantFile) return c.Status(fiber.StatusOK).JSON(assistantFile)
} }
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant file with file_id: %s", fileId))) return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with file_id: %s", fileId))
} }
} }
return c.Status(fiber.StatusNotFound).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to find assistant file with assistant_id: %s", assistantID))) return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with assistant_id: %s", assistantID))
} }
} }

View File

@@ -8,7 +8,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
@@ -50,7 +49,7 @@ func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli
err = c.SaveFile(file, savePath) err = c.SaveFile(file, savePath)
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + bluemonday.StrictPolicy().Sanitize(err.Error())) return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error())
} }
f := schema.File{ f := schema.File{
@@ -122,7 +121,7 @@ func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Applicat
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c) file, err := getFileFromRequest(c)
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error())) return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
} }
return c.JSON(file) return c.JSON(file)
@@ -144,14 +143,14 @@ func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c) file, err := getFileFromRequest(c)
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error())) return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
} }
err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename)) err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
if err != nil { if err != nil {
// If the file doesn't exist then we should just continue to remove it // If the file doesn't exist then we should just continue to remove it
if !errors.Is(err, os.ErrNotExist) { if !errors.Is(err, os.ErrNotExist) {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err))) return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err))
} }
} }
@@ -181,12 +180,12 @@ func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c) file, err := getFileFromRequest(c)
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error())) return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
} }
fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename)) fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(bluemonday.StrictPolicy().Sanitize(err.Error())) return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
} }
return c.Send(fileContents) return c.Send(fileContents)

View File

@@ -149,10 +149,6 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
// Decode each request's message content // Decode each request's message content
imgIndex, vidIndex, audioIndex := 0, 0, 0 imgIndex, vidIndex, audioIndex := 0, 0, 0
for i, m := range input.Messages { for i, m := range input.Messages {
nrOfImgsInMessage := 0
nrOfVideosInMessage := 0
nrOfAudiosInMessage := 0
switch content := m.Content.(type) { switch content := m.Content.(type) {
case string: case string:
input.Messages[i].StringContent = content input.Messages[i].StringContent = content
@@ -160,16 +156,11 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
dat, _ := json.Marshal(content) dat, _ := json.Marshal(content)
c := []schema.Content{} c := []schema.Content{}
json.Unmarshal(dat, &c) json.Unmarshal(dat, &c)
textContent := ""
// we will template this at the end
CONTENT: CONTENT:
for _, pp := range c { for _, pp := range c {
switch pp.Type { switch pp.Type {
case "text": case "text":
textContent += pp.Text input.Messages[i].StringContent = pp.Text
//input.Messages[i].StringContent = pp.Text
case "video", "video_url": case "video", "video_url":
// Decode content as base64 either if it's an URL or base64 text // Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
@@ -178,8 +169,14 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
continue CONTENT continue CONTENT
} }
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
t := "[vid-{{.ID}}]{{.Text}}"
if config.TemplateConfig.Video != "" {
t = config.TemplateConfig.Video
}
// set a placeholder for each image
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, vidIndex, input.Messages[i].StringContent)
vidIndex++ vidIndex++
nrOfVideosInMessage++
case "audio_url", "audio": case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text // Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
@@ -188,8 +185,13 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
continue CONTENT continue CONTENT
} }
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
t := "[audio-{{.ID}}]{{.Text}}"
if config.TemplateConfig.Audio != "" {
t = config.TemplateConfig.Audio
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, audioIndex, input.Messages[i].StringContent)
audioIndex++ audioIndex++
nrOfAudiosInMessage++
case "image_url", "image": case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text // Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
@@ -198,21 +200,16 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
continue CONTENT continue CONTENT
} }
t := "[img-{{.ID}}]{{.Text}}"
if config.TemplateConfig.Image != "" {
t = config.TemplateConfig.Image
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, imgIndex, input.Messages[i].StringContent)
imgIndex++ imgIndex++
nrOfImgsInMessage++
} }
} }
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
TotalVideos: vidIndex,
TotalAudios: audioIndex,
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage,
AudiosInMessage: nrOfAudiosInMessage,
}, textContent)
} }
} }

View File

@@ -7,7 +7,6 @@ import (
"github.com/dave-gray101/v2keyauth" "github.com/dave-gray101/v2keyauth"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/keyauth" "github.com/gofiber/fiber/v2/middleware/keyauth"
"github.com/microcosm-cc/bluemonday"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
) )
@@ -39,7 +38,7 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
if applicationConfig.OpaqueErrors { if applicationConfig.OpaqueErrors {
return ctx.SendStatus(403) return ctx.SendStatus(403)
} }
return ctx.Status(403).SendString(bluemonday.StrictPolicy().Sanitize(err.Error())) return ctx.Status(403).SendString(err.Error())
} }
if applicationConfig.OpaqueErrors { if applicationConfig.OpaqueErrors {
return ctx.SendStatus(500) return ctx.SendStatus(500)

Some files were not shown because too many files have changed in this diff Show More