Compare commits

..

4 Commits

Author SHA1 Message Date
Ettore Di Giacinto
d6ea1a67cf Merge branch 'master' into ci/public-runner 2025-02-08 11:00:45 +01:00
Ettore Di Giacinto
4c145b037b Merge branch 'master' into ci/public-runner 2025-01-23 15:40:25 +01:00
Ettore Di Giacinto
96c080cc64 Merge branch 'master' into ci/public-runner 2025-01-18 18:36:31 +01:00
Ettore Di Giacinto
97ab9b4d92 chore(ci): try to run some jobs on public runners
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-01-18 09:18:45 +01:00
154 changed files with 3406 additions and 245448 deletions

View File

@@ -0,0 +1,23 @@
meta {
name: musicgen
type: http
seq: 1
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/v1/sound-generation
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model_id": "facebook/musicgen-small",
"text": "Exciting 80s Newscast Interstitial",
"duration_seconds": 8
}
}

View File

@@ -0,0 +1,17 @@
meta {
name: backend monitor
type: http
seq: 4
}
get {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/backend/monitor
body: json
auth: none
}
body:json {
{
"model": "{{DEFAULT_MODEL}}"
}
}

View File

@@ -0,0 +1,21 @@
meta {
name: backend-shutdown
type: http
seq: 3
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/backend/shutdown
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model": "{{DEFAULT_MODEL}}"
}
}

View File

@@ -0,0 +1,5 @@
{
"version": "1",
"name": "LocalAI Test Requests",
"type": "collection"
}

View File

@@ -0,0 +1,6 @@
vars {
HOST: localhost
PORT: 8080
DEFAULT_MODEL: gpt-3.5-turbo
PROTOCOL: http://
}

View File

@@ -0,0 +1,11 @@
meta {
name: get models list
type: http
seq: 2
}
get {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/models
body: none
auth: none
}

View File

@@ -0,0 +1,25 @@
meta {
name: Generate image
type: http
seq: 1
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/v1/images/generations
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"prompt": "<positive prompt>|<negative prompt>",
"model": "model-name",
"step": 51,
"size": "1024x1024",
"image": ""
}
}

View File

@@ -0,0 +1,24 @@
meta {
name: -completions
type: http
seq: 4
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/completions
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model": "{{DEFAULT_MODEL}}",
"prompt": "function downloadFile(string url, string outputPath) {",
"max_tokens": 256,
"temperature": 0.5
}
}

View File

@@ -0,0 +1,23 @@
meta {
name: -edits
type: http
seq: 5
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/edits
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model": "{{DEFAULT_MODEL}}",
"input": "What day of the wek is it?",
"instruction": "Fix the spelling mistakes"
}
}

View File

@@ -0,0 +1,22 @@
meta {
name: -embeddings
type: http
seq: 6
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/embeddings
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model": "{{DEFAULT_MODEL}}",
"input": "A STRANGE GAME.\nTHE ONLY WINNING MOVE IS NOT TO PLAY.\n\nHOW ABOUT A NICE GAME OF CHESS?"
}
}

View File

@@ -0,0 +1,30 @@
meta {
name: chat completion -simple- 1 message-
type: http
seq: 4
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/chat/completions
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model": "{{DEFAULT_MODEL}}",
"messages": [
{
"role": "user",
"content": "How could one use friction to cook an egg?"
}
],
"max_tokens": 256,
"temperature": 0.2,
"grammar": ""
}
}

View File

@@ -0,0 +1,29 @@
meta {
name: chat-completions -long-
type: http
seq: 5
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/chat/completions
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model": "{{DEFAULT_MODEL}}",
"messages": [{"role": "system", "content": "You are a helpful, intelligent pirate assistant. Always answer as helpfully as possible, while being a pirate and using pirate language. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
{"role": "user", "content": "How could one use electricity to cook an egg?"},
{"role": "assistant",
"content": "Shiver me timbers! Using electricity to cook an egg? Well, matey, I reckon that be a right curious idea! *adjusts spectacles* Now, I ain't sure if this be possible, but I'll do me best to help ye out.\n\nFirst things first, ye gotta understand that electricity be a powerful force, and it can't just cook an egg on its own. Ye see, electricity be like a mighty wind that can make things happen, but it needs somethin' to work with. So, if ye want to use electricity to cook an egg, ye gotta find a way to harness that power and make it do the cookin'.\n\nNow, I know what ye might be thinkin': \"How do I harness the power of electricity to cook an egg?\" Well, matey, there be a few ways to do it. One way be to use a special device called an \"electric frying pan.\" This be a pan that has a built-in heating element that gets hot when ye plug it into a wall socket. When the element gets hot, ye can crack an egg into the pan and watch as it cook"
},
{"role": "user", "content": "I don't have one of those, just a raw wire and plenty of power! How do we get it done?"}],
"max_tokens": 1024,
"temperature": 0.5
}
}

View File

@@ -0,0 +1,25 @@
meta {
name: chat-completions -stream-
type: http
seq: 6
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/chat/completions
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model": "{{DEFAULT_MODEL}}",
"messages": [{"role": "user", "content": "Explain how I can set sail on the ocean using only power generated by seagulls?"}],
"max_tokens": 256,
"temperature": 0.9,
"stream": true
}
}

View File

@@ -0,0 +1,22 @@
meta {
name: add model gallery
type: http
seq: 10
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/models/galleries
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"url": "file:///home/dave/projects/model-gallery/huggingface/TheBloke__CodeLlama-7B-Instruct-GGML.yaml",
"name": "test"
}
}

View File

@@ -0,0 +1,21 @@
meta {
name: delete model gallery
type: http
seq: 11
}
delete {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/models/galleries
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"name": "test"
}
}

View File

@@ -0,0 +1,11 @@
meta {
name: list MODELS in galleries
type: http
seq: 7
}
get {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/models/available
body: none
auth: none
}

View File

@@ -0,0 +1,11 @@
meta {
name: list model GALLERIES
type: http
seq: 8
}
get {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/models/galleries
body: none
auth: none
}

View File

@@ -0,0 +1,11 @@
meta {
name: model delete
type: http
seq: 7
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/models/galleries
body: none
auth: none
}

View File

@@ -0,0 +1,21 @@
meta {
name: model gallery apply -gist-
type: http
seq: 12
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/models/apply
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"id": "TheBloke__CodeLlama-7B-Instruct-GGML__codellama-7b-instruct.ggmlv3.Q2_K.bin"
}
}

View File

@@ -0,0 +1,22 @@
meta {
name: model gallery apply
type: http
seq: 9
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/models/apply
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"id": "dave@TheBloke__CodeLlama-7B-Instruct-GGML__codellama-7b-instruct.ggmlv3.Q3_K_S.bin",
"name": "codellama7b"
}
}

View File

Binary file not shown.

View File

@@ -0,0 +1,16 @@
meta {
name: transcribe
type: http
seq: 1
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/v1/audio/transcriptions
body: multipartForm
auth: none
}
body:multipart-form {
file: @file(transcription/gb1.ogg)
model: whisper-1
}

View File

@@ -0,0 +1,22 @@
meta {
name: -tts
type: http
seq: 2
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/tts
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"model": "{{DEFAULT_MODEL}}",
"input": "A STRANGE GAME.\nTHE ONLY WINNING MOVE IS NOT TO PLAY.\n\nHOW ABOUT A NICE GAME OF CHESS?"
}
}

View File

@@ -0,0 +1,23 @@
meta {
name: musicgen
type: http
seq: 2
}
post {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/tts
body: json
auth: none
}
headers {
Content-Type: application/json
}
body:json {
{
"backend": "transformers",
"model": "facebook/musicgen-small",
"input": "80s Synths playing Jazz"
}
}

2
.github/labeler.yml vendored
View File

@@ -1,4 +1,4 @@
enhancement:
enhancements:
- head-branch: ['^feature', 'feature']
dependencies:

View File

@@ -9,7 +9,7 @@ jobs:
fail-fast: false
matrix:
include:
- repository: "ggml-org/llama.cpp"
- repository: "ggerganov/llama.cpp"
variable: "CPPLLAMA_VERSION"
branch: "master"
- repository: "ggerganov/whisper.cpp"

View File

@@ -33,7 +33,7 @@ jobs:
run: |
CGO_ENABLED=0 make build-api
- name: rm
uses: appleboy/ssh-action@v1.2.2
uses: appleboy/ssh-action@v1.2.0
with:
host: ${{ secrets.EXPLORER_SSH_HOST }}
username: ${{ secrets.EXPLORER_SSH_USERNAME }}
@@ -53,7 +53,7 @@ jobs:
rm: true
target: ./local-ai
- name: restarting
uses: appleboy/ssh-action@v1.2.2
uses: appleboy/ssh-action@v1.2.0
with:
host: ${{ secrets.EXPLORER_SSH_HOST }}
username: ${{ secrets.EXPLORER_SSH_USERNAME }}

View File

@@ -2,10 +2,9 @@ name: 'generate and publish GRPC docker caches'
on:
workflow_dispatch:
schedule:
# daily at midnight
- cron: '0 0 * * *'
push:
branches:
- master
concurrency:
group: grpc-cache-${{ github.head_ref || github.ref }}-${{ github.repository }}
@@ -17,7 +16,7 @@ jobs:
matrix:
include:
- grpc-base-image: ubuntu:22.04
runs-on: 'arc-runner-set'
runs-on: 'ubuntu-latest'
platforms: 'linux/amd64,linux/arm64'
runs-on: ${{matrix.runs-on}}
steps:

View File

@@ -53,7 +53,7 @@ jobs:
tag-suffix: '-cublas-cuda12-ffmpeg'
ffmpeg: 'true'
image-type: 'extras'
runs-on: 'arc-runner-set'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
makeflags: "--jobs=3 --output-sync=target"
# - build-type: 'hipblas'

View File

@@ -310,11 +310,6 @@ jobs:
tags: ${{ steps.meta_aio_dockerhub.outputs.tags }}
labels: ${{ steps.meta_aio_dockerhub.outputs.labels }}
- name: Cleanup
run: |
docker builder prune -f
docker system prune --force --volumes --all
- name: Latest tag
# run this on branches, when it is a tag and there is a latest-image defined
if: github.event_name != 'pull_request' && inputs.latest-image != '' && github.ref_type == 'tag'

View File

@@ -1,6 +1,6 @@
MIT License
Copyright (c) 2023-2025 Ettore Di Giacinto (mudler@localai.io)
Copyright (c) 2023-2024 Ettore Di Giacinto (mudler@localai.io)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -6,7 +6,7 @@ BINARY_NAME=local-ai
DETECT_LIBS?=true
# llama.cpp versions
CPPLLAMA_VERSION?=10f2e81809bbb69ecfe64fc8b4686285f84b0c07
CPPLLAMA_VERSION?=d2fe216fb2fb7ca8627618c9ea3a2e7886325780
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp
@@ -22,7 +22,7 @@ BARKCPP_VERSION?=v1.0.0
# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=19d876ee300a055629926ff836489901f734f2b7
STABLEDIFFUSION_GGML_VERSION?=d46ed5e184b97c2018dc2e8105925bdb8775e02c
ONNX_VERSION?=1.20.0
ONNX_ARCH?=x64

View File

@@ -212,7 +212,7 @@ A huge thank you to our generous sponsors who support this project covering CI e
<p align="center">
<a href="https://www.spectrocloud.com/" target="blank">
<img height="200" src="https://github.com/user-attachments/assets/72eab1dd-8b93-4fc0-9ade-84db49f24962">
<img height="200" src="https://github.com/go-skynet/LocalAI/assets/2420543/68a6f3cb-8a65-4a4d-99b5-6417a8905512">
</a>
<a href="https://www.premai.io/" target="blank">
<img height="200" src="https://github.com/mudler/LocalAI/assets/2420543/42e4ca83-661e-4f79-8e46-ae43689683d6"> <br>

View File

@@ -1,8 +0,0 @@
backend: silero-vad
name: silero-vad
parameters:
model: silero-vad.onnx
download_files:
- filename: silero-vad.onnx
uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

View File

@@ -129,7 +129,7 @@ detect_gpu
detect_gpu_size
PROFILE="${PROFILE:-$GPU_SIZE}" # default to cpu
export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vad.yaml,/aio/${PROFILE}/vision.yaml}"
export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
check_vars

View File

@@ -1,8 +0,0 @@
backend: silero-vad
name: silero-vad
parameters:
model: silero-vad.onnx
download_files:
- filename: silero-vad.onnx
uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

View File

@@ -1,8 +0,0 @@
backend: silero-vad
name: silero-vad
parameters:
model: silero-vad.onnx
download_files:
- filename: silero-vad.onnx
uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

View File

@@ -165,6 +165,7 @@ message Reply {
message GrammarTrigger {
string word = 1;
bool at_start = 2;
}
message ModelOptions {
@@ -228,11 +229,6 @@ message ModelOptions {
int32 MaxModelLen = 54;
int32 TensorParallelSize = 55;
string LoadFormat = 58;
bool DisableLogStatus = 66;
string DType = 67;
int32 LimitImagePerPrompt = 68;
int32 LimitVideoPerPrompt = 69;
int32 LimitAudioPerPrompt = 70;
string MMProj = 41;

View File

@@ -469,7 +469,7 @@ struct llama_server_context
bool has_eos_token = true;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<common_grammar_trigger> grammar_trigger_words;
int32_t n_ctx; // total context for all clients / slots
@@ -709,7 +709,7 @@ struct llama_server_context
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
slot->sparams.grammar_triggers = grammar_triggers;
slot->sparams.grammar_trigger_words = grammar_trigger_words;
slot->sparams.grammar_lazy = grammar_lazy;
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
@@ -1155,14 +1155,6 @@ struct llama_server_context
slot.has_next_token = false;
}
if (slot.n_past >= slot.n_ctx) {
slot.truncated = true;
slot.stopped_limit = true;
slot.has_next_token = false;
LOG_VERBOSE("stopped due to running out of context capacity", {});
}
if (result.tok == llama_vocab_eos(vocab) || llama_vocab_is_eog(vocab, result.tok))
{
slot.stopped_eos = true;
@@ -1350,7 +1342,7 @@ struct llama_server_context
queue_results.send(res);
}
void send_embedding(llama_client_slot &slot, const llama_batch & batch)
void send_embedding(llama_client_slot &slot)
{
task_result res;
res.id = slot.task_id;
@@ -1372,38 +1364,10 @@ struct llama_server_context
else
{
const float *data = llama_get_embeddings(ctx);
std::vector<float> embd_res(n_embd, 0.0f);
std::vector<std::vector<float>> embedding;
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
}
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
if (embd == NULL) {
LOG("failed to get embeddings");
continue;
}
// normalize only when there is pooling
// TODO: configurable
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
embedding.push_back(embd_res);
} else {
embedding.push_back({ embd, embd + n_embd });
}
}
// OAI compat
std::vector<float> embedding(data, data + n_embd);
res.result_json = json
{
{"embedding", embedding[0] },
{"embedding", embedding },
};
}
queue_results.send(res);
@@ -1663,17 +1627,17 @@ struct llama_server_context
{
if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
{
// this check is redundant (for good)
// we should never get here, because generation should already stopped in process_token()
// START LOCALAI changes
// Temporary disable context-shifting as it can lead to infinite loops (issue: https://github.com/ggerganov/llama.cpp/issues/3969)
// See: https://github.com/mudler/LocalAI/issues/1333
// Context is exhausted, release the slot
slot.release();
send_final_response(slot);
slot.has_next_token = false;
LOG_ERROR("context is exhausted, release the slot", {});
slot.cache_tokens.clear();
slot.n_past = 0;
slot.truncated = false;
slot.has_next_token = true;
LOG("Context exhausted. Slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size());
continue;
// END LOCALAI changes
@@ -2024,7 +1988,7 @@ struct llama_server_context
// prompt evaluated for embedding
if (slot.embedding)
{
send_embedding(slot, batch_view);
send_embedding(slot);
slot.release();
slot.i_batch = -1;
continue;
@@ -2421,12 +2385,12 @@ static void params_parse(const backend::ModelOptions* request,
llama.grammar_lazy = true;
for (int i = 0; i < request->grammartriggers_size(); i++) {
common_grammar_trigger trigger;
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
trigger.value = request->grammartriggers(i).word();
// trigger.at_start = request->grammartriggers(i).at_start();
llama.grammar_triggers.push_back(trigger);
trigger.word = request->grammartriggers(i).word();
trigger.at_start = request->grammartriggers(i).at_start();
llama.grammar_trigger_words.push_back(trigger);
LOG_INFO("grammar trigger", {
{ "word", trigger.value },
{ "word", trigger.word },
{ "at_start", trigger.at_start }
});
}
}

View File

@@ -1,13 +1,13 @@
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 7f892beb..0517e529 100644
index 3cd0d2fa..6c5e811a 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -2766,7 +2766,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
int patch_offset = ctx->has_class_embedding ? 1 : 0;
@@ -2608,7 +2608,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
int* patches_data = (int*)malloc(ggml_nbytes(patches));
for (int i = 0; i < num_patches; i++) {
- patches_data[i] = i + patch_offset;
+ patches_data[i] = i + 1;
- patches_data[i] = i + 1;
+ patches_data[i] = i;
}
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
free(patches_data);

View File

@@ -1,7 +1,5 @@
#!/bin/bash
set -e
## Patches
## Apply patches from the `patches` directory
for patch in $(ls patches); do

View File

@@ -35,8 +35,6 @@ const char* sample_method_str[] = {
"ipndm",
"ipndm_v",
"lcm",
"ddim_trailing",
"tcd",
};
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
@@ -175,7 +173,6 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
-1, //clip_skip
cfg_scale, // sfg_scale
3.5f,
0, // eta
width,
height,
sample_method,

View File

@@ -1,6 +1,6 @@
accelerate
auto-gptq==0.7.1
grpcio==1.71.0
grpcio==1.70.0
protobuf
certifi
transformers

View File

@@ -1,4 +1,4 @@
bark==0.1.5
grpcio==1.71.0
grpcio==1.70.0
protobuf
certifi

View File

@@ -1,3 +1,3 @@
grpcio==1.71.0
grpcio==1.70.0
protobuf
grpcio-tools

View File

@@ -1,4 +1,4 @@
transformers==4.48.3
transformers
accelerate
torch==2.4.1
coqui-tts

View File

@@ -1,6 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118
torchaudio==2.4.1+cu118
transformers==4.48.3
transformers
accelerate
coqui-tts

View File

@@ -1,5 +1,5 @@
torch==2.4.1
torchaudio==2.4.1
transformers==4.48.3
transformers
accelerate
coqui-tts

View File

@@ -1,6 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0
torch==2.4.1+rocm6.0
torchaudio==2.4.1+rocm6.0
transformers==4.48.3
transformers
accelerate
coqui-tts

View File

@@ -5,6 +5,6 @@ torchaudio==2.3.1+cxx11.abi
oneccl_bind_pt==2.3.100+xpu
optimum[openvino]
setuptools
transformers==4.48.3
transformers
accelerate
coqui-tts

View File

@@ -1,4 +1,4 @@
grpcio==1.71.0
grpcio==1.70.0
protobuf
certifi
packaging==24.1

View File

@@ -159,18 +159,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
torchType = torch.float16
variant = "fp16"
options = request.Options
# empty dict
self.options = {}
# The options are a list of strings in this form optname:optvalue
# We are storing all the options in a dict so we can use it later when
# generating the images
for opt in options:
key, value = opt.split(":")
self.options[key] = value
local = False
modelFile = request.Model
@@ -453,9 +441,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# create a dictionary of parameters by using the keys from EnableParameters and the values from defaults
kwargs = {key: options.get(key) for key in keys if key in options}
# populate kwargs from self.options.
kwargs.update(self.options)
# Set seed
if request.seed > 0:
kwargs["generator"] = torch.Generator(device=self.device).manual_seed(

View File

@@ -1,5 +1,5 @@
setuptools
grpcio==1.71.0
grpcio==1.70.0
pillow
protobuf
certifi

View File

@@ -1,4 +1,4 @@
grpcio==1.71.0
grpcio==1.70.0
protobuf
certifi
wheel

View File

@@ -1,3 +1,3 @@
grpcio==1.71.0
grpcio==1.70.0
protobuf
grpcio-tools

View File

@@ -1,4 +1,4 @@
grpcio==1.71.0
grpcio==1.70.0
protobuf
phonemizer
scipy

View File

@@ -1,3 +1,3 @@
grpcio==1.71.0
grpcio==1.70.0
protobuf
certifi

View File

@@ -1,4 +1,4 @@
grpcio==1.71.0
grpcio==1.70.0
protobuf
certifi
setuptools

View File

@@ -109,17 +109,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
engine_args.swap_space = request.SwapSpace
if request.MaxModelLen != 0:
engine_args.max_model_len = request.MaxModelLen
if request.DisableLogStatus:
engine_args.disable_log_status = request.DisableLogStatus
if request.DType != "":
engine_args.dtype = request.DType
if request.LimitImagePerPrompt != 0 or request.LimitVideoPerPrompt != 0 or request.LimitAudioPerPrompt != 0:
# limit-mm-per-prompt defaults to 1 per modality, based on vLLM docs
engine_args.limit_mm_per_prompt = {
"image": max(request.LimitImagePerPrompt, 1),
"video": max(request.LimitVideoPerPrompt, 1),
"audio": max(request.LimitAudioPerPrompt, 1)
}
try:
self.llm = AsyncLLMEngine.from_engine_args(engine_args)
@@ -280,7 +269,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def load_image(self, image_path: str):
"""
Load an image from the given file path or base64 encoded data.
Args:
image_path (str): The path to the image file or base64 encoded data.
@@ -299,7 +288,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
def load_video(self, video_path: str):
"""
Load a video from the given file path.
Args:
video_path (str): The path to the image file.
@@ -346,4 +335,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
asyncio.run(serve(args.addr))
asyncio.run(serve(args.addr))

View File

@@ -1,4 +1,4 @@
grpcio==1.71.0
grpcio==1.70.0
protobuf
certifi
setuptools

View File

@@ -145,7 +145,13 @@ func New(opts ...config.AppOption) (*Application, error) {
if options.LoadToMemory != nil {
for _, m := range options.LoadToMemory {
cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options)
cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath,
config.LoadOptionDebug(options.Debug),
config.LoadOptionThreads(options.Threads),
config.LoadOptionContextSize(options.ContextSize),
config.LoadOptionF16(options.F16),
config.ModelPath(options.ModelPath),
)
if err != nil {
return nil, err
}

View File

@@ -33,7 +33,7 @@ type TokenUsage struct {
TimingTokenGeneration float64
}
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
// Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -48,7 +48,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
}
}
opts := ModelOptions(*c, o)
opts := ModelOptions(c, o)
inferenceModel, err := loader.Load(opts...)
if err != nil {
return nil, err
@@ -84,7 +84,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(*c, loader.ModelPath)
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
opts.Messages = protoMessages
opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate
@@ -116,11 +116,6 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
}
if tokenCallback != nil {
if c.TemplateConfig.ReplyPrefix != "" {
tokenCallback(c.TemplateConfig.ReplyPrefix, tokenUsage)
}
ss := ""
var partialRune []byte
@@ -170,13 +165,8 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
response := string(reply.Message)
if c.TemplateConfig.ReplyPrefix != "" {
response = c.TemplateConfig.ReplyPrefix + response
}
return LLMResponse{
Response: response,
Response: string(reply.Message),
Usage: tokenUsage,
}, err
}

View File

@@ -122,6 +122,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
for _, t := range c.FunctionsConfig.GrammarConfig.GrammarTriggers {
triggers = append(triggers, &pb.GrammarTrigger{
Word: t.Word,
AtStart: t.AtStart,
})
}
@@ -158,12 +159,6 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
SwapSpace: int32(c.SwapSpace),
MaxModelLen: int32(c.MaxModelLen),
TensorParallelSize: int32(c.TensorParallelSize),
DisableLogStatus: c.DisableLogStatus,
DType: c.DType,
// LimitMMPerPrompt vLLM
LimitImagePerPrompt: int32(c.LimitMMPerPrompt.LimitImagePerPrompt),
LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt),
LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt),
MMProj: c.MMProj,
FlashAttention: c.FlashAttention,
CacheTypeKey: c.CacheTypeK,

View File

@@ -9,10 +9,10 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
opts := ModelOptions(backendConfig, appConfig)
rerankModel, err := loader.Load(opts...)
func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
rerankModel, err := loader.Load(opts...)
if err != nil {
return nil, err
}

View File

@@ -13,6 +13,7 @@ import (
)
func SoundGeneration(
modelFile string,
text string,
duration *float32,
temperature *float32,
@@ -24,9 +25,8 @@ func SoundGeneration(
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
opts := ModelOptions(backendConfig, appConfig)
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
soundGenModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
}
@@ -44,7 +44,7 @@ func SoundGeneration(
res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
Text: text,
Model: backendConfig.Model,
Model: modelFile,
Dst: filePath,
Sample: doSample,
Duration: duration,

View File

@@ -4,17 +4,19 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/model"
model "github.com/mudler/LocalAI/pkg/model"
)
func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
modelFile := backendConfig.Model
var inferenceModel grpc.Backend
var err error
opts := ModelOptions(backendConfig, appConfig)
inferenceModel, err = loader.Load(opts...)
opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
inferenceModel, err = loader.Load(opts...)
if err != nil {
return schema.TokenizeResponse{}, err
}

View File

@@ -47,7 +47,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
tks = append(tks, int(t))
}
tr.Segments = append(tr.Segments,
schema.TranscriptionSegment{
schema.Segment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),

View File

@@ -14,22 +14,28 @@ import (
)
func ModelTTS(
backend,
text,
modelFile,
voice,
language string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
opts := ModelOptions(backendConfig, appConfig, model.WithDefaultBackendString(model.PiperBackend))
ttsModel, err := loader.Load(opts...)
bb := backend
if bb == "" {
bb = model.PiperBackend
}
opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile))
ttsModel, err := loader.Load(opts...)
if err != nil {
return "", nil, err
}
if ttsModel == nil {
return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
return "", nil, fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
@@ -39,21 +45,22 @@ func ModelTTS(
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)
// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
// This should be addressed in a follow up PR soon.
// Copying it over nearly verbatim, as TTS backends are not functional without this.
// If the model file is not empty, we pass it joined with the model path
modelPath := ""
// Checking first that it exists and is not outside ModelPath
// TODO: we should actually first check if the modelFile is looking like
// a FS path
mp := filepath.Join(loader.ModelPath, backendConfig.Model)
if _, err := os.Stat(mp); err == nil {
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
return "", nil, err
if modelFile != "" {
// If the model file is not empty, we pass it joined with the model path
// Checking first that it exists and is not outside ModelPath
// TODO: we should actually first check if the modelFile is looking like
// a FS path
mp := filepath.Join(loader.ModelPath, modelFile)
if _, err := os.Stat(mp); err == nil {
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
return "", nil, err
}
modelPath = mp
} else {
modelPath = modelFile
}
modelPath = mp
} else {
modelPath = backendConfig.Model // skip this step if it fails?????
}
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{

View File

@@ -1,38 +0,0 @@
package backend
import (
"context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
)
func VAD(request *schema.VADRequest,
ctx context.Context,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig) (*schema.VADResponse, error) {
opts := ModelOptions(backendConfig, appConfig)
vadModel, err := ml.Load(opts...)
if err != nil {
return nil, err
}
req := proto.VADRequest{
Audio: request.Audio,
}
resp, err := vadModel.VAD(ctx, &req)
if err != nil {
return nil, err
}
segments := []schema.VADSegment{}
for _, s := range resp.Segments {
segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End})
}
return &schema.VADResponse{
Segments: segments,
}, nil
}

View File

@@ -86,14 +86,13 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
options := config.BackendConfig{}
options.SetDefaults()
options.Backend = t.Backend
options.Model = t.Model
var inputFile *string
if t.InputFile != "" {
inputFile = &t.InputFile
}
filePath, _, err := backend.SoundGeneration(text,
filePath, _, err := backend.SoundGeneration(t.Model, text,
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)

View File

@@ -52,10 +52,8 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
options := config.BackendConfig{}
options.SetDefaults()
options.Backend = t.Backend
options.Model = t.Model
filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, t.Language, ml, opts, options)
if err != nil {
return err
}

View File

@@ -130,28 +130,25 @@ type LLMConfig struct {
TrimSpace []string `yaml:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix"`
ContextSize *int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraAdapters []string `yaml:"lora_adapters"`
LoraScales []float32 `yaml:"lora_scales"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
LoadFormat string `yaml:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
DisableLogStatus bool `yaml:"disable_log_stats"` // vLLM
DType string `yaml:"dtype"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt"` // vLLM
MMProj string `yaml:"mmproj"`
ContextSize *int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraAdapters []string `yaml:"lora_adapters"`
LoraScales []float32 `yaml:"lora_scales"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
LoadFormat string `yaml:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
MMProj string `yaml:"mmproj"`
FlashAttention bool `yaml:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading"`
@@ -169,13 +166,6 @@ type LLMConfig struct {
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
}
// LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
type LimitMMPerPrompt struct {
LimitImagePerPrompt int `yaml:"image"`
LimitVideoPerPrompt int `yaml:"video"`
LimitAudioPerPrompt int `yaml:"audio"`
}
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
type AutoGPTQ struct {
ModelBaseName string `yaml:"model_base_name"`
@@ -213,8 +203,6 @@ type TemplateConfig struct {
Multimodal string `yaml:"multimodal"`
JinjaTemplate bool `yaml:"jinja_template"`
ReplyPrefix string `yaml:"reply_prefix"`
}
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
@@ -224,15 +212,7 @@ func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
return err
}
*c = BackendConfig(aux)
c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
// Make sure the usecases are valid, we rewrite with what we identified
c.KnownUsecaseStrings = []string{}
for k, usecase := range GetAllBackendConfigUsecases() {
if c.HasUsecases(usecase) {
c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
}
}
return nil
}
@@ -457,21 +437,19 @@ func (c *BackendConfig) HasTemplate() bool {
type BackendConfigUsecases int
const (
FLAG_ANY BackendConfigUsecases = 0b00000000000
FLAG_CHAT BackendConfigUsecases = 0b00000000001
FLAG_COMPLETION BackendConfigUsecases = 0b00000000010
FLAG_EDIT BackendConfigUsecases = 0b00000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b00000001000
FLAG_RERANK BackendConfigUsecases = 0b00000010000
FLAG_IMAGE BackendConfigUsecases = 0b00000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b00001000000
FLAG_TTS BackendConfigUsecases = 0b00010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b00100000000
FLAG_TOKENIZE BackendConfigUsecases = 0b01000000000
FLAG_VAD BackendConfigUsecases = 0b10000000000
FLAG_ANY BackendConfigUsecases = 0b000000000
FLAG_CHAT BackendConfigUsecases = 0b000000001
FLAG_COMPLETION BackendConfigUsecases = 0b000000010
FLAG_EDIT BackendConfigUsecases = 0b000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b000001000
FLAG_RERANK BackendConfigUsecases = 0b000010000
FLAG_IMAGE BackendConfigUsecases = 0b000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b001000000
FLAG_TTS BackendConfigUsecases = 0b010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000
// Common Subsets
FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
)
func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
@@ -486,16 +464,10 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
"FLAG_TTS": FLAG_TTS,
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
"FLAG_TOKENIZE": FLAG_TOKENIZE,
"FLAG_VAD": FLAG_VAD,
"FLAG_LLM": FLAG_LLM,
}
}
func stringToFlag(s string) string {
return "FLAG_" + strings.ToUpper(s)
}
func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
if len(input) == 0 {
return nil
@@ -503,7 +475,7 @@ func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
result := FLAG_ANY
flags := GetAllBackendConfigUsecases()
for _, str := range input {
flag, exists := flags[stringToFlag(str)]
flag, exists := flags["FLAG_"+strings.ToUpper(str)]
if exists {
result |= flag
}
@@ -577,18 +549,5 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
}
}
if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE {
tokenizeCapableBackends := []string{"llama.cpp", "rwkv"}
if !slices.Contains(tokenizeCapableBackends, c.Backend) {
return false
}
}
if (u & FLAG_VAD) == FLAG_VAD {
if c.Backend != "silero-vad" {
return false
}
}
return true
}

View File

@@ -81,10 +81,10 @@ func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption)
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot read config file %q: %w", file, err)
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot unmarshal config file %q: %w", file, err)
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
for _, cc := range *c {
@@ -101,10 +101,10 @@ func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*Backen
c := &BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("readBackendConfigFromFile cannot read config file %q: %w", file, err)
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("readBackendConfigFromFile cannot unmarshal config file %q: %w", file, err)
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
c.SetDefaults(opts...)
@@ -117,9 +117,7 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: modelName,
},
Model: modelName,
},
}
@@ -147,15 +145,6 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
return cfg, nil
}
func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) {
return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath,
LoadOptionDebug(appConfig.Debug),
LoadOptionThreads(appConfig.Threads),
LoadOptionContextSize(appConfig.ContextSize),
LoadOptionF16(appConfig.F16),
ModelPath(appConfig.ModelPath))
}
// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
@@ -178,7 +167,7 @@ func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoa
defer bcl.Unlock()
c, err := readBackendConfigFromFile(file, opts...)
if err != nil {
return fmt.Errorf("LoadBackendConfig cannot read config file %q: %w", file, err)
return fmt.Errorf("cannot read config file: %w", err)
}
if c.Validate() {
@@ -335,10 +324,9 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return fmt.Errorf("LoadBackendConfigsFromPath cannot read directory '%s': %w", path, err)
return fmt.Errorf("cannot read directory '%s': %w", path, err)
}
files := make([]fs.FileInfo, 0, len(entries))
for _, entry := range entries {
@@ -356,13 +344,13 @@ func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...
}
c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
if err != nil {
log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadBackendConfigsFromPath cannot read config file")
log.Error().Err(err).Msgf("cannot read config file: %s", file.Name())
continue
}
if c.Validate() {
bcl.configs[c.Name] = *c
} else {
log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid")
log.Error().Err(err).Msgf("config is not valid")
}
}

View File

@@ -161,11 +161,10 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
}
// We try to guess only if we don't have a template defined already
guessPath := filepath.Join(modelPath, cfg.ModelFileName())
f, err := gguf.ParseGGUFFile(guessPath)
f, err := gguf.ParseGGUFFile(filepath.Join(modelPath, cfg.ModelFileName()))
if err != nil {
// Only valid for gguf files
log.Debug().Str("filePath", guessPath).Msg("guessDefaultsFromFile: not a GGUF file")
log.Debug().Msgf("guessDefaultsFromFile: %s", "not a GGUF file")
return
}

View File

@@ -29,8 +29,6 @@ func InstallModelFromGallery(galleries []config.Gallery, name string, basePath s
if err != nil {
return err
}
config.Description = model.Description
config.License = model.License
} else if len(model.ConfigFile) > 0 {
// TODO: is this worse than using the override method with a blank cfg yaml?
reYamlConfig, err := yaml.Marshal(model.ConfigFile)
@@ -116,7 +114,7 @@ func FindModel(models []*GalleryModel, name string, basePath string) *GalleryMod
// List available models
// Models galleries are a list of yaml files that are hosted on a remote server (for example github).
// Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
func AvailableGalleryModels(galleries []config.Gallery, basePath string) (GalleryModels, error) {
func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*GalleryModel, error) {
var models []*GalleryModel
// Get models from galleries

View File

@@ -62,15 +62,3 @@ func (gm GalleryModels) FindByName(name string) *GalleryModel {
}
return nil
}
func (gm GalleryModels) Paginate(pageNum int, itemsNum int) GalleryModels {
start := (pageNum - 1) * itemsNum
end := start + itemsNum
if start > len(gm) {
start = len(gm)
}
if end > len(gm) {
end = len(gm)
}
return gm[start:end]
}

View File

@@ -130,6 +130,7 @@ func API(application *application.Application) (*fiber.App, error) {
return metricsService.Shutdown()
})
}
}
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(router)
@@ -139,28 +140,6 @@ func API(application *application.Application) (*fiber.App, error) {
return nil, fmt.Errorf("failed to create key auth config: %w", err)
}
httpFS := http.FS(embedDirStatic)
router.Use(favicon.New(favicon.Config{
URL: "/favicon.ico",
FileSystem: httpFS,
File: "static/favicon.ico",
}))
router.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
if application.ApplicationConfig().ImageDir != "" {
router.Static("/generated-images", application.ApplicationConfig().ImageDir)
}
if application.ApplicationConfig().AudioDir != "" {
router.Static("/generated-audio", application.ApplicationConfig().AudioDir)
}
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
router.Use(v2keyauth.New(*kaConfig))
@@ -188,15 +167,27 @@ func API(application *application.Application) (*fiber.App, error) {
galleryService := services.NewGalleryService(application.ApplicationConfig())
galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
requestExtractor := middleware.NewRequestExtractor(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterElevenLabsRoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterLocalAIRoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
routes.RegisterOpenAIRoutes(router, requestExtractor, application)
routes.RegisterElevenLabsRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterLocalAIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
routes.RegisterOpenAIRoutes(router, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIRoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig(), galleryService)
}
routes.RegisterJINARoutes(router, requestExtractor, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterJINARoutes(router, application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig())
httpFS := http.FS(embedDirStatic)
router.Use(favicon.New(favicon.Config{
URL: "/favicon.ico",
FileSystem: httpFS,
File: "static/favicon.ico",
}))
router.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
// Define a custom 404 handler
// Note: keep this at the bottom!

47
core/http/ctx/fiber.go Normal file
View File

@@ -0,0 +1,47 @@
package fiberContext
import (
"fmt"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
// ModelFromContext returns the model from the context
// If no model is specified, it will take the first available
// Takes a model string as input which should be the one received from the user request.
// It returns the model name resolved from the context and an error if any.
func ModelFromContext(ctx *fiber.Ctx, cl *config.BackendConfigLoader, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
if ctx.Params("model") != "" {
modelInput = ctx.Params("model")
}
if ctx.Query("model") != "" {
modelInput = ctx.Query("model")
}
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // Reduced duplicate characters of Bearer
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelInput == "" && !bearerExists && firstModel {
models, _ := services.ListModels(cl, loader, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
if len(models) > 0 {
modelInput = models[0]
log.Debug().Msgf("No model specified, using: %s", modelInput)
} else {
log.Debug().Msgf("No model specified, returning error")
return "", fmt.Errorf("no model specified")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelInput = bearer
}
return modelInput, nil
}

View File

@@ -13,7 +13,7 @@ func installButton(galleryName string) elem.Node {
attrs.Props{
"data-twe-ripple-init": "",
"data-twe-ripple-color": "light",
"class": "float-right inline-flex items-center rounded-lg bg-blue-600 hover:bg-blue-700 px-4 py-2 text-sm font-medium text-white transition duration-300 ease-in-out shadow hover:shadow-lg",
"class": "float-right inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "browse/install/model/" + galleryName,
@@ -52,7 +52,7 @@ func infoButton(m *gallery.GalleryModel) elem.Node {
attrs.Props{
"data-twe-ripple-init": "",
"data-twe-ripple-color": "light",
"class": "inline-flex items-center rounded-lg bg-gray-700 hover:bg-gray-600 px-4 py-2 text-sm font-medium text-white transition duration-300 ease-in-out",
"class": "float-left inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
"data-modal-target": modalName(m),
"data-modal-toggle": modalName(m),
},

View File

@@ -17,7 +17,7 @@ const (
func cardSpan(text, icon string) elem.Node {
return elem.Span(
attrs.Props{
"class": "inline-flex items-center px-3 py-1 rounded-lg text-xs font-medium bg-gray-700/70 text-gray-300 border border-gray-600/50 mr-2 mb-2",
"class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2",
},
elem.I(attrs.Props{
"class": icon + " pr-2",
@@ -39,20 +39,19 @@ func searchableElement(text, icon string) elem.Node {
),
elem.Span(
attrs.Props{
"class": "inline-flex items-center text-xs px-3 py-1 rounded-full bg-gray-700/60 text-gray-300 border border-gray-600/50 hover:bg-gray-600 hover:text-gray-100 transition duration-200 ease-in-out",
"class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2 hover:bg-gray-300 hover:shadow-gray-2",
},
elem.A(
attrs.Props{
// "name": "search",
// "value": text,
//"class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2",
//"href": "#!",
"href": "browse?term=" + text,
//"hx-post": "browse/search/models",
//"hx-target": "#search-results",
"href": "#!",
"hx-post": "browse/search/models",
"hx-target": "#search-results",
// TODO: this doesn't work
// "hx-vals": `{ \"search\": \"` + text + `\" }`,
//"hx-indicator": ".htmx-indicator",
"hx-indicator": ".htmx-indicator",
},
elem.I(attrs.Props{
"class": icon + " pr-2",
@@ -102,7 +101,7 @@ func modalName(m *gallery.GalleryModel) string {
return m.Name + "-modal"
}
func modelModal(m *gallery.GalleryModel) elem.Node {
func modelDescription(m *gallery.GalleryModel) elem.Node {
urls := []elem.Node{}
for _, url := range m.URLs {
urls = append(urls,
@@ -117,125 +116,6 @@ func modelModal(m *gallery.GalleryModel) elem.Node {
)
}
return elem.Div(
attrs.Props{
"id": modalName(m),
"tabindex": "-1",
"aria-hidden": "true",
"class": "hidden overflow-y-auto overflow-x-hidden fixed top-0 right-0 left-0 z-50 justify-center items-center w-full md:inset-0 h-[calc(100%-1rem)] max-h-full",
},
elem.Div(
attrs.Props{
"class": "relative p-4 w-full max-w-2xl max-h-full",
},
elem.Div(
attrs.Props{
"class": "relative p-4 w-full max-w-2xl max-h-full bg-white rounded-lg shadow dark:bg-gray-700",
},
// header
elem.Div(
attrs.Props{
"class": "flex items-center justify-between p-4 md:p-5 border-b rounded-t dark:border-gray-600",
},
elem.H3(
attrs.Props{
"class": "text-xl font-semibold text-gray-900 dark:text-white",
},
elem.Text(bluemonday.StrictPolicy().Sanitize(m.Name)),
),
elem.Button( // close button
attrs.Props{
"class": "text-gray-400 bg-transparent hover:bg-gray-200 hover:text-gray-900 rounded-lg text-sm w-8 h-8 ms-auto inline-flex justify-center items-center dark:hover:bg-gray-600 dark:hover:text-white",
"data-modal-hide": modalName(m),
},
elem.Raw(
`<svg class="w-3 h-3" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 14 14">
<path stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m1 1 6 6m0 0 6 6M7 7l6-6M7 7l-6 6"/>
</svg>`,
),
elem.Span(
attrs.Props{
"class": "sr-only",
},
elem.Text("Close modal"),
),
),
),
// body
elem.Div(
attrs.Props{
"class": "p-4 md:p-5 space-y-4",
},
elem.Div(
attrs.Props{
"class": "flex justify-center items-center",
},
elem.Img(attrs.Props{
// "class": "rounded-t-lg object-fit object-center h-96",
"class": "lazy rounded-t-lg max-h-48 max-w-96 object-cover mt-3 entered loaded",
"src": m.Icon,
"loading": "lazy",
}),
),
elem.P(
attrs.Props{
"class": "text-base leading-relaxed text-gray-500 dark:text-gray-400",
},
elem.Text(bluemonday.StrictPolicy().Sanitize(m.Description)),
),
elem.Hr(
attrs.Props{},
),
elem.P(
attrs.Props{
"class": "text-sm font-semibold text-gray-900 dark:text-white",
},
elem.Text("Links"),
),
elem.Ul(
attrs.Props{},
urls...,
),
elem.If(
len(m.Tags) > 0,
elem.Div(
attrs.Props{},
elem.P(
attrs.Props{
"class": "text-sm mb-5 font-semibold text-gray-900 dark:text-white",
},
elem.Text("Tags"),
),
elem.Div(
attrs.Props{
"class": "flex flex-row flex-wrap content-center",
},
tagsNodes...,
),
),
elem.Div(attrs.Props{}),
),
),
// Footer
elem.Div(
attrs.Props{
"class": "flex items-center p-4 md:p-5 border-t border-gray-200 rounded-b dark:border-gray-600",
},
elem.Button(
attrs.Props{
"data-modal-hide": modalName(m),
"class": "py-2.5 px-5 ms-3 text-sm font-medium text-gray-900 focus:outline-none bg-white rounded-lg border border-gray-200 hover:bg-gray-100 hover:text-blue-700 focus:z-10 focus:ring-4 focus:ring-gray-100 dark:focus:ring-gray-700 dark:bg-gray-800 dark:text-gray-400 dark:border-gray-600 dark:hover:text-white dark:hover:bg-gray-700",
},
elem.Text("Close"),
),
),
),
),
)
}
func modelDescription(m *gallery.GalleryModel) elem.Node {
return elem.Div(
attrs.Props{
"class": "p-6 text-surface dark:text-white",
@@ -252,6 +132,122 @@ func modelDescription(m *gallery.GalleryModel) elem.Node {
},
elem.Text(bluemonday.StrictPolicy().Sanitize(m.Description)),
),
elem.Div(
attrs.Props{
"id": modalName(m),
"tabindex": "-1",
"aria-hidden": "true",
"class": "hidden overflow-y-auto overflow-x-hidden fixed top-0 right-0 left-0 z-50 justify-center items-center w-full md:inset-0 h-[calc(100%-1rem)] max-h-full",
},
elem.Div(
attrs.Props{
"class": "relative p-4 w-full max-w-2xl max-h-full",
},
elem.Div(
attrs.Props{
"class": "relative p-4 w-full max-w-2xl max-h-full bg-white rounded-lg shadow dark:bg-gray-700",
},
// header
elem.Div(
attrs.Props{
"class": "flex items-center justify-between p-4 md:p-5 border-b rounded-t dark:border-gray-600",
},
elem.H3(
attrs.Props{
"class": "text-xl font-semibold text-gray-900 dark:text-white",
},
elem.Text(bluemonday.StrictPolicy().Sanitize(m.Name)),
),
elem.Button( // close button
attrs.Props{
"class": "text-gray-400 bg-transparent hover:bg-gray-200 hover:text-gray-900 rounded-lg text-sm w-8 h-8 ms-auto inline-flex justify-center items-center dark:hover:bg-gray-600 dark:hover:text-white",
"data-modal-hide": modalName(m),
},
elem.Raw(
`<svg class="w-3 h-3" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 14 14">
<path stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m1 1 6 6m0 0 6 6M7 7l6-6M7 7l-6 6"/>
</svg>`,
),
elem.Span(
attrs.Props{
"class": "sr-only",
},
elem.Text("Close modal"),
),
),
),
// body
elem.Div(
attrs.Props{
"class": "p-4 md:p-5 space-y-4",
},
elem.Div(
attrs.Props{
"class": "flex justify-center items-center",
},
elem.Img(attrs.Props{
// "class": "rounded-t-lg object-fit object-center h-96",
"class": "lazy rounded-t-lg max-h-48 max-w-96 object-cover mt-3 entered loaded",
"src": m.Icon,
"loading": "lazy",
}),
),
elem.P(
attrs.Props{
"class": "text-base leading-relaxed text-gray-500 dark:text-gray-400",
},
elem.Text(bluemonday.StrictPolicy().Sanitize(m.Description)),
),
elem.Hr(
attrs.Props{},
),
elem.P(
attrs.Props{
"class": "text-sm font-semibold text-gray-900 dark:text-white",
},
elem.Text("Links"),
),
elem.Ul(
attrs.Props{},
urls...,
),
elem.If(
len(m.Tags) > 0,
elem.Div(
attrs.Props{},
elem.P(
attrs.Props{
"class": "text-sm mb-5 font-semibold text-gray-900 dark:text-white",
},
elem.Text("Tags"),
),
elem.Div(
attrs.Props{
"class": "flex flex-row flex-wrap content-center",
},
tagsNodes...,
),
),
elem.Div(attrs.Props{}),
),
),
// Footer
elem.Div(
attrs.Props{
"class": "flex items-center p-4 md:p-5 border-t border-gray-200 rounded-b dark:border-gray-600",
},
elem.Button(
attrs.Props{
"data-modal-hide": modalName(m),
"class": "py-2.5 px-5 ms-3 text-sm font-medium text-gray-900 focus:outline-none bg-white rounded-lg border border-gray-200 hover:bg-gray-100 hover:text-blue-700 focus:z-10 focus:ring-4 focus:ring-gray-100 dark:focus:ring-gray-700 dark:bg-gray-800 dark:text-gray-400 dark:border-gray-600 dark:hover:text-white dark:hover:bg-gray-700",
},
elem.Text("Close"),
),
),
),
),
),
)
}
@@ -401,7 +397,7 @@ func ListModels(models []*gallery.GalleryModel, processTracker ProcessTracker, g
modelsElements = append(modelsElements,
elem.Div(
attrs.Props{
"class": " me-4 mb-2 block rounded-lg bg-white shadow-secondary-1 dark:bg-gray-800 dark:bg-surface-dark dark:text-white text-surface pb-2 bg-gray-800/90 border border-gray-700/50 rounded-xl overflow-hidden transition-all duration-300 hover:shadow-lg hover:shadow-blue-900/20 hover:-translate-y-1 hover:border-blue-700/50",
"class": " me-4 mb-2 block rounded-lg bg-white shadow-secondary-1 dark:bg-gray-800 dark:bg-surface-dark dark:text-white text-surface pb-2",
},
elem.Div(
attrs.Props{
@@ -410,7 +406,6 @@ func ListModels(models []*gallery.GalleryModel, processTracker ProcessTracker, g
elems...,
),
),
modelModal(m),
)
}

View File

@@ -2,7 +2,6 @@ package elements
import (
"fmt"
"time"
"github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs"
@@ -19,6 +18,19 @@ func renderElements(n []elem.Node) string {
}
func P2PNodeStats(nodes []p2p.NodeData) string {
/*
<div class="bg-gray-800 p-6 rounded-lg shadow-lg text-left">
<p class="text-xl font-semibold text-gray-200">Total Workers Detected: {{ len .Nodes }}</p>
{{ $online := 0 }}
{{ range .Nodes }}
{{ if .IsOnline }}
{{ $online = add $online 1 }}
{{ end }}
{{ end }}
<p class="text-xl font-semibold text-gray-200">Total Online Workers: {{$online}}</p>
</div>
*/
online := 0
for _, n := range nodes {
if n.IsOnline() {
@@ -26,21 +38,27 @@ func P2PNodeStats(nodes []p2p.NodeData) string {
}
}
class := "text-blue-400"
class := "text-green-500"
if online == 0 {
class = "text-red-400"
class = "text-red-500"
}
/*
<i class="fas fa-circle animate-pulse text-green-500 ml-2 mr-1"></i>
*/
circle := elem.I(attrs.Props{
"class": "fas fa-circle animate-pulse " + class + " ml-2 mr-1",
})
nodesElements := []elem.Node{
elem.Span(
attrs.Props{
"class": class + " font-bold text-xl",
"class": class,
},
circle,
elem.Text(fmt.Sprintf("%d", online)),
),
elem.Span(
attrs.Props{
"class": "text-gray-300 text-xl",
"class": "text-gray-200",
},
elem.Text(fmt.Sprintf("/%d", len(nodes))),
),
@@ -50,73 +68,77 @@ func P2PNodeStats(nodes []p2p.NodeData) string {
}
func P2PNodeBoxes(nodes []p2p.NodeData) string {
/*
<div class="bg-gray-800 p-4 rounded-lg shadow-lg text-left">
<div class="flex items-center mb-2">
<i class="fas fa-desktop text-gray-400 mr-2"></i>
<span class="text-gray-200 font-semibold">{{.ID}}</span>
</div>
<p class="text-sm text-gray-400 mt-2 flex items-center">
Status:
<i class="fas fa-circle {{ if .IsOnline }}text-green-500{{ else }}text-red-500{{ end }} ml-2 mr-1"></i>
<span class="{{ if .IsOnline }}text-green-400{{ else }}text-red-400{{ end }}">
{{ if .IsOnline }}Online{{ else }}Offline{{ end }}
</span>
</p>
</div>
*/
nodesElements := []elem.Node{}
for _, n := range nodes {
nodeID := bluemonday.StrictPolicy().Sanitize(n.ID)
// Define status-specific classes
statusIconClass := "text-green-400"
statusText := "Online"
statusTextClass := "text-green-400"
if !n.IsOnline() {
statusIconClass = "text-red-400"
statusText = "Offline"
statusTextClass = "text-red-400"
}
nodesElements = append(nodesElements,
elem.Div(
attrs.Props{
"class": "bg-gray-800/80 border border-gray-700/50 rounded-xl p-4 shadow-lg transition-all duration-300 hover:shadow-blue-900/20 hover:border-blue-700/50",
"class": "bg-gray-700 p-6 rounded-lg shadow-lg text-left",
},
// Node ID and status indicator in top row
elem.Div(
elem.P(
attrs.Props{
"class": "flex items-center justify-between mb-3",
"class": "text-sm text-gray-400 mt-2 flex",
},
// Node ID with icon
elem.Div(
elem.I(
attrs.Props{
"class": "flex items-center",
"class": "fas fa-desktop text-gray-400 mr-2",
},
),
elem.Text("Name: "),
elem.Span(
attrs.Props{
"class": "text-gray-200 font-semibold ml-2 mr-1",
},
elem.Text(bluemonday.StrictPolicy().Sanitize(n.ID)),
),
elem.Text("Status: "),
elem.If(
n.IsOnline(),
elem.I(
attrs.Props{
"class": "fas fa-server text-blue-400 mr-2",
"class": "fas fa-circle animate-pulse text-green-500 ml-2 mr-1",
},
),
elem.I(
attrs.Props{
"class": "fas fa-circle animate-pulse text-red-500 ml-2 mr-1",
},
),
),
elem.If(
n.IsOnline(),
elem.Span(
attrs.Props{
"class": "text-green-400",
},
elem.Text("Online"),
),
elem.Span(
attrs.Props{
"class": "text-white font-medium",
"class": "text-red-400",
},
elem.Text(nodeID),
elem.Text("Offline"),
),
),
// Status indicator
elem.Div(
attrs.Props{
"class": "flex items-center",
},
elem.I(
attrs.Props{
"class": "fas fa-circle animate-pulse " + statusIconClass + " mr-1.5",
},
),
elem.Span(
attrs.Props{
"class": statusTextClass,
},
elem.Text(statusText),
),
),
),
// Bottom section with timestamp
elem.Div(
attrs.Props{
"class": "text-xs text-gray-400 pt-1 border-t border-gray-700/30",
},
elem.Text("Last updated: "+time.Now().UTC().Format("2006-01-02 15:04:05")),
),
))
}

View File

@@ -4,7 +4,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
@@ -17,21 +17,45 @@ import (
// @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
input := new(schema.ElevenLabsSoundGenerationRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Str("ModelID", input.ModelID).Msg("Model not found in context")
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Str("Request ModelID", input.ModelID).Err(err).Msg("error during LoadBackendConfigFileByName, using request ModelID")
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
}
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
if input.Duration != nil {
log.Debug().Float32("duration", *input.Duration).Msg("duration set")
}
if input.Temperature != nil {
log.Debug().Float32("temperature", *input.Temperature).Msg("temperature set")
}
// TODO: Support uploading files?
filePath, _, err := backend.SoundGeneration(input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -3,7 +3,7 @@ package elevenlabs
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@@ -20,21 +20,39 @@ import (
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.ElevenLabsTTSRequest)
voiceID := c.Params("voice-id")
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
}
log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request recieved")
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
}
log.Debug().Msgf("Request for model: %s", modelFile)
filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, "", voiceID, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -3,9 +3,9 @@ package jina
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
@@ -19,32 +19,58 @@ import (
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
req := new(schema.JINARerankRequest)
if err := c.BodyParser(req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "Cannot parse JSON",
})
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request recieved")
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
}
request := &proto.RerankRequest{
Query: input.Query,
TopN: int32(input.TopN),
Documents: input.Documents,
Query: req.Query,
TopN: int32(req.TopN),
Documents: req.Documents,
}
results, err := backend.Rerank(request, ml, appConfig, *cfg)
results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg)
if err != nil {
return err
}
response := &schema.JINARerankResponse{
Model: input.Model,
Model: req.Model,
}
for _, r := range results.Results {

View File

@@ -4,15 +4,13 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log"
"github.com/mudler/LocalAI/pkg/model"
)
// TODO: This is not yet in use. Needs middleware rework, since it is not referenced.
// TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID
//
// @Summary Get TokenMetrics for Active Slot.
@@ -31,13 +29,18 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
return err
}
modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || modelFile != "" {
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig)
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
log.Err(err)

View File

@@ -4,9 +4,10 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
// TokenizeEndpoint exposes a REST API to tokenize the content
@@ -15,21 +16,42 @@ import (
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return func(c *fiber.Ctx) error {
input := new(schema.TokenizeRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
if err != nil {
return err
}
return ctx.JSON(tokenResponse)
return c.JSON(tokenResponse)
}
}

View File

@@ -3,7 +3,7 @@ package localai
import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@@ -24,24 +24,37 @@ import (
// @Router /tts [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request recieved")
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if cfg.Backend == "" {
if input.Backend != "" {
cfg.Backend = input.Backend
} else {
cfg.Backend = model.PiperBackend
}
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
}
if input.Language != "" {
@@ -52,7 +65,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
cfg.Voice = input.Voice
}
filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -4,8 +4,9 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@@ -18,20 +19,45 @@ import (
// @Router /vad [post]
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
input := new(schema.VADRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request recieved")
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))
vadModel, err := ml.Load(opts...)
if err != nil {
return err
}
req := proto.VADRequest{
Audio: input.Audio,
}
resp, err := vadModel.VAD(c.Context(), &req)
if err != nil {
return err
}

View File

@@ -5,19 +5,18 @@ import (
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
@@ -175,20 +174,26 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
// Set CorrelationID
correlationID := c.Get("X-Correlation-ID")
if len(strings.TrimSpace(correlationID)) == 0 {
correlationID = id
}
c.Set("X-Correlation-ID", correlationID)
// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
funcs := input.Functions
shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions()
@@ -538,7 +543,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
audios = append(audios, m.StringAudios...)
}
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, o, nil)
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
if err != nil {
log.Error().Err(err).Msg("model inference failed")
return "", err

View File

@@ -10,13 +10,12 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
@@ -28,9 +27,10 @@ import (
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
@@ -63,18 +63,22 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
}
return func(c *fiber.Ctx) error {
// Handle Correlation
id := c.Get("X-Correlation-ID", uuid.New().String())
// Add Correlation
c.Set("X-Correlation-ID", id)
// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
log.Debug().Msgf("`input`: %+v", input)
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
if config.ResponseFormatMap != nil {
@@ -118,7 +122,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
responses := make(chan schema.OpenAIResponse)
go process(id, predInput, input, config, ml, responses, extraUsage)
go process(predInput, input, config, ml, responses, extraUsage)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {

View File

@@ -2,17 +2,16 @@ package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/rs/zerolog/log"
@@ -26,21 +25,20 @@ import (
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
// Opt-in extra usage flag
extraUsage := c.Get("Extra-Usage", "") != ""
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Edit Endpoint Input : %+v", input)
log.Debug().Msgf("Edit Endpoint Config: %+v", *config)
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)
var result []schema.Choice
totalTokenUsage := backend.TokenUsage{}

View File

@@ -2,11 +2,11 @@ package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/google/uuid"
@@ -23,14 +23,14 @@ import (
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
model, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Parameter Config: %+v", config)

View File

@@ -15,7 +15,6 @@ import (
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend"
@@ -67,23 +66,25 @@ func downloadFile(url string) (string, error) {
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
log.Error().Msg("Image Endpoint - Invalid Input")
return fiber.ErrBadRequest
m, input, err := readRequest(c, cl, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
log.Error().Msg("Image Endpoint - Invalid Config")
return fiber.ErrBadRequest
if m == "" {
m = "stablediffusion"
}
log.Debug().Msgf("Loading model: %+v", m)
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
src := ""
if input.File != "" {
fileData := []byte{}
var err error
// check if input.File is an URL, if so download it and save it
// to a temporary file
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {

View File

@@ -37,7 +37,7 @@ func ComputeChoices(
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, o, tokenCallback)
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, *config, o, tokenCallback)
if err != nil {
return result, backend.TokenUsage{}, err
}

View File

@@ -1,450 +1,326 @@
package middleware
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
type correlationIDKeyType string
// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"
type RequestExtractor struct {
backendConfigLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
}
func NewRequestExtractor(backendConfigLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
return &RequestExtractor{
backendConfigLoader: backendConfigLoader,
modelLoader: modelLoader,
applicationConfig: applicationConfig,
}
}
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
// TODO: Refactor to not return error if unchanged
func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && model != "" {
return
}
model = ctx.Params("model")
if (model == "") && ctx.Query("model") != "" {
model = ctx.Query("model")
}
if model == "" {
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
if bearer != "" {
exists, err := services.CheckIfModelExists(re.backendConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
if err == nil && exists {
model = bearer
}
}
}
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
}
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || localModelName == "" {
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
}
return ctx.Next()
}
}
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.BackendConfigFilterFn) fiber.Handler {
return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if localModelName != "" { // Don't overwrite existing values
return ctx.Next()
}
modelNames, err := services.ListModels(re.backendConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
if err != nil {
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
return ctx.Next()
}
if len(modelNames) == 0 {
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
// This is non-fatal - making it so was breaking the case of direct installation of raw models
// return errors.New("this endpoint requires at least one model to be installed")
return ctx.Next()
}
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
return ctx.Next()
}
}
// TODO: If context and cancel above belong on all methods, move that part of above into here!
// Otherwise, it's in its own method below for now
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
return func(ctx *fiber.Ctx) error {
input := initializer()
if input == nil {
return fmt.Errorf("unable to initialize body")
}
if err := ctx.BodyParser(input); err != nil {
return fmt.Errorf("failed parsing request body: %w", err)
}
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
if input.ModelName(nil) == "" {
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && localModelName != "" {
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
input.ModelName(&localModelName)
}
}
cfg, err := re.backendConfigLoader.LoadBackendConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
if err != nil {
log.Err(err)
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
} else if cfg.Model == "" && input.ModelName(nil) != "" {
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
cfg.Model = input.ModelName(nil)
}
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return ctx.Next()
}
}
func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
// Extract or generate the correlation ID
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
ctx.Set("X-Correlation-ID", correlationID)
c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
input.Context = ctxWithCorrelationID
input.Cancel = cancel
err := mergeOpenAIRequestAndBackendConfig(cfg, input)
if err != nil {
return err
}
if cfg.Model == "" {
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
cfg.Model = input.Model
}
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return ctx.Next()
}
func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *schema.OpenAIRequest) error {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != nil {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
if input.ResponseFormat != nil {
switch responseFormat := input.ResponseFormat.(type) {
case string:
config.ResponseFormat = responseFormat
case map[string]interface{}:
config.ResponseFormatMap = responseFormat
}
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
imgIndex, vidIndex, audioIndex := 0, 0, 0
for i, m := range input.Messages {
nrOfImgsInMessage := 0
nrOfVideosInMessage := 0
nrOfAudiosInMessage := 0
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
textContent := ""
// we will template this at the end
CONTENT:
for _, pp := range c {
switch pp.Type {
case "text":
textContent += pp.Text
//input.Messages[i].StringContent = pp.Text
case "video", "video_url":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding video: %s", err)
continue CONTENT
}
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
vidIndex++
nrOfVideosInMessage++
case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
audioIndex++
nrOfAudiosInMessage++
case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
imgIndex++
nrOfImgsInMessage++
}
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
TotalVideos: vidIndex,
TotalAudios: audioIndex,
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage,
AudiosInMessage: nrOfAudiosInMessage,
}, textContent)
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
// If a quality was defined as number, convert it to step
if input.Quality != "" {
q, err := strconv.Atoi(input.Quality)
if err == nil {
config.Step = q
}
}
if config.Validate() {
return nil
}
return fmt.Errorf("unable to validate configuration after merging")
}
package openai
import (
"context"
"encoding/json"
"fmt"
"strconv"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/templates"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
type correlationIDKeyType string
// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"
func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
}
received, _ := json.Marshal(input)
// Extract or generate the correlation ID
correlationID := c.Get("X-Correlation-ID", uuid.New().String())
ctx, cancel := context.WithCancel(o.Context)
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID)
input.Context = ctxWithCorrelationID
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
return modelFile, input, err
}
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != nil {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
if input.ResponseFormat != nil {
switch responseFormat := input.ResponseFormat.(type) {
case string:
config.ResponseFormat = responseFormat
case map[string]interface{}:
config.ResponseFormatMap = responseFormat
}
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
imgIndex, vidIndex, audioIndex := 0, 0, 0
for i, m := range input.Messages {
nrOfImgsInMessage := 0
nrOfVideosInMessage := 0
nrOfAudiosInMessage := 0
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
textContent := ""
// we will template this at the end
CONTENT:
for _, pp := range c {
switch pp.Type {
case "text":
textContent += pp.Text
//input.Messages[i].StringContent = pp.Text
case "video", "video_url":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding video: %s", err)
continue CONTENT
}
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
vidIndex++
nrOfVideosInMessage++
case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
audioIndex++
nrOfAudiosInMessage++
case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
imgIndex++
nrOfImgsInMessage++
}
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
TotalVideos: vidIndex,
TotalAudios: audioIndex,
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage,
AudiosInMessage: nrOfAudiosInMessage,
}, textContent)
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
// If a quality was defined as number, convert it to step
if input.Quality != "" {
q, err := strconv.Atoi(input.Quality)
if err == nil {
config.Step = q
}
}
}
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
config.LoadOptionDebug(debug),
config.LoadOptionThreads(threads),
config.LoadOptionContextSize(ctx),
config.LoadOptionF16(f16),
)
// Set the parameters for the language model prediction
updateRequestConfig(cfg, input)
if !cfg.Validate() {
return nil, nil, fmt.Errorf("failed to validate config")
}
return cfg, input, err
}

View File

@@ -1,6 +1,7 @@
package openai
import (
"fmt"
"io"
"net/http"
"os"
@@ -9,8 +10,6 @@ import (
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@@ -26,16 +25,15 @@ import (
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
m, input, err := readRequest(c, cl, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
if !ok || config == nil {
return fiber.ErrBadRequest
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request: %w", err)
}
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {

View File

@@ -4,26 +4,17 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterElevenLabsRoutes(app *fiber.App,
re *middleware.RequestExtractor,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }),
elevenlabs.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/text-to-speech/:voice-id", elevenlabs.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/sound-generation",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }),
elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
app.Post("/v1/sound-generation", elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
}

View File

@@ -3,22 +3,16 @@ package routes
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/jina"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterJINARoutes(app *fiber.App,
re *middleware.RequestExtractor,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig) {
// POST endpoint to mimic the reranking
app.Post("/v1/rerank",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }),
jina.JINARerankEndpoint(cl, ml, appConfig))
app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))
}

View File

@@ -5,16 +5,13 @@ import (
"github.com/gofiber/swagger"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
)
func RegisterLocalAIRoutes(router *fiber.App,
requestExtractor *middleware.RequestExtractor,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
@@ -36,18 +33,8 @@ func RegisterLocalAIRoutes(router *fiber.App,
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
}
router.Post("/tts",
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
localai.TTSEndpoint(cl, ml, appConfig))
vadChain := []fiber.Handler{
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VAD)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VADRequest) }),
localai.VADEndpoint(cl, ml, appConfig),
}
router.Post("/vad", vadChain...)
router.Post("/v1/vad", vadChain...)
router.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
router.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
// Stores
sl := model.NewModelLoader("")
@@ -60,14 +47,10 @@ func RegisterLocalAIRoutes(router *fiber.App,
router.Get("/metrics", localai.LocalAIMetricsEndpoint())
}
// Backend Statistics Module
// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
// Experimental Backend Statistics Module
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
router.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
// The v1/* urls are exactly the same as above - makes local e2e testing easier if they are registered.
router.Get("/v1/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService))
router.Post("/v1/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService))
// p2p
if p2p.IsP2PEnabled() {
@@ -84,9 +67,6 @@ func RegisterLocalAIRoutes(router *fiber.App,
router.Get("/system", localai.SystemInformations(ml, appConfig))
// misc
router.Post("/v1/tokenize",
requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)),
requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }),
localai.TokenizeEndpoint(cl, ml, appConfig))
router.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig))
}

View File

@@ -3,50 +3,51 @@ package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/endpoints/openai"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
)
func RegisterOpenAIRoutes(app *fiber.App,
re *middleware.RequestExtractor,
application *application.Application) {
// openAI compatible API endpoint
// chat
chatChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.ChatEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
}
app.Post("/v1/chat/completions", chatChain...)
app.Post("/chat/completions", chatChain...)
app.Post("/v1/chat/completions",
openai.ChatEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/chat/completions",
openai.ChatEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// edit
editChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.EditEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
}
app.Post("/v1/edits", editChain...)
app.Post("/edits", editChain...)
app.Post("/v1/edits",
openai.EditEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// completion
completionChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.CompletionEndpoint(application.BackendLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig()),
}
app.Post("/v1/completions", completionChain...)
app.Post("/completions", completionChain...)
app.Post("/v1/engines/:model/completions", completionChain...)
app.Post("/edits",
openai.EditEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// assistant
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
@@ -80,37 +81,53 @@ func RegisterOpenAIRoutes(app *fiber.App,
app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(application.BackendLoader(), application.ApplicationConfig()))
// embeddings
embeddingChain := []fiber.Handler{
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
}
app.Post("/v1/embeddings", embeddingChain...)
app.Post("/embeddings", embeddingChain...)
app.Post("/v1/engines/:model/embeddings", embeddingChain...)
// audio
app.Post("/v1/audio/transcriptions",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()),
// completion
app.Post("/v1/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/v1/audio/speech",
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
app.Post("/v1/engines/:model/completions",
openai.CompletionEndpoint(
application.BackendLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
),
)
// embeddings
app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
// audio
app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/audio/speech", localai.TTSEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
// images
app.Post("/v1/images/generations",
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
re.SetOpenAIRequest,
openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
app.Post("/v1/images/generations", openai.ImageEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))
if application.ApplicationConfig().ImageDir != "" {
app.Static("/generated-images", application.ApplicationConfig().ImageDir)
}
if application.ApplicationConfig().AudioDir != "" {
app.Static("/generated-audio", application.ApplicationConfig().AudioDir)
}
// List models
app.Get("/v1/models", openai.ListModelsEndpoint(application.BackendLoader(), application.ModelLoader(), application.ApplicationConfig()))

Some files were not shown because too many files have changed in this diff Show More