Compare commits

86 Commits

Author SHA1 Message Date
Ettore Di Giacinto
495191a54a fix(llama.cpp): fix eos without cache 2024-03-18 12:14:16 +01:00
Ettore Di Giacinto
b790fca180 fix(whisper.cpp): Add stubs and -lcuda 2024-03-18 12:13:39 +01:00
Ettore Di Giacinto
0663f66205 deps(whisper.cpp): update, fix cublas build 2024-03-16 10:38:57 +01:00
LocalAI [bot]
5826fb8e6d ⬆️ Update mudler/go-piper (#1844)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-15 23:51:03 +00:00
Ettore Di Giacinto
89351f1a7d feat(embeddings): do not require to be configured (#1842)
Certain engines require knowing at model load time whether
the embedding feature has to be enabled; however, it is impractical
to have to set it for ALL the backends that support embeddings.

There are backends, such as transformers and sentencetransformers, that seamlessly handle
both cases without needing this setting to be explicitly enabled.

The case exists only for ggml-based models that need to enable
feature sets during model loading (and thus setting `embeddings` is
required); most of the other engines do not require this.

This change disables the check done on the code side, making it easier to use
embeddings by not having to explicitly specify `embeddings: true`.

Part of: https://github.com/mudler/LocalAI/issues/1373
2024-03-15 18:14:23 +01:00
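The practical effect of the commit above on a model definition can be sketched as follows; everything except the `embeddings: true` key (quoted from the commit) is an illustrative assumption rather than a config taken from the repository:

```yaml
# Minimal hypothetical model config: with #1842, backends such as
# sentencetransformers no longer need the flag set explicitly.
name: my-embedder                # illustrative name
backend: sentencetransformers
parameters:
  model: all-MiniLM-L6-v2        # illustrative model id
# embeddings: true               # still needed for ggml-based backends, optional elsewhere
```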
Ettore Di Giacinto
ae2e4fc2fe docs(transformers): add docs section about transformers (#1841) 2024-03-15 18:13:30 +01:00
Dave
db199f61da fix: osx build default.metallib (#1837)
* port osx fix from refactor pr to slim pr
* manually bump llama.cpp version to unstick CI?
2024-03-15 08:18:58 +00:00
LocalAI [bot]
44adbd2c75 ⬆️ Update go-skynet/go-llama.cpp (#1835)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-14 23:06:42 +00:00
Ettore Di Giacinto
20136ca8b7 feat(tts): add Elevenlabs and OpenAI TTS compatibility layer (#1834)
* feat(elevenlabs): map elevenlabs API support to TTS

This allows elevenlabs clients to work automatically with LocalAI by
supporting the elevenlabs API.

The elevenlabs server endpoint is implemented so that it is wired to the
TTS endpoints.

Fixes: https://github.com/mudler/LocalAI/issues/1809

* feat(openai/tts): compat layer with openai tts

Fixes: #1276

* fix: adapt tts CLI
2024-03-14 23:08:34 +01:00
Dave
45d520f913 fix: OSX Build Files for llama.cpp (#1836)
bot ate my changes, separate branch
2024-03-14 23:07:47 +01:00
fakezeta
3882130911 feat: Add Bitsandbytes quantization for transformer backend enhancement #1775 and fix: Transformer backend error on CUDA #1774 (#1823)
* fixes #1775 and #1774

Add BitsAndBytes Quantization and fixes embedding on CUDA devices

* Manage 4bit and 8 bit quantization

Manage different BitsAndBytes options with the quantization: parameter in yaml

* fix compilation errors on non CUDA environment
2024-03-14 23:06:30 +01:00
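A hedged sketch of the `quantization:` YAML parameter referenced above; the field placement and the accepted values (4-bit vs 8-bit names) are assumptions to be checked against the PR:

```yaml
# Hypothetical transformers model config using BitsAndBytes quantization.
name: my-transformers-model        # illustrative name
backend: transformers
parameters:
  model: org/model-on-huggingface  # illustrative repo id
quantization: bnb_4bit             # assumed value; an 8-bit variant would select 8-bit BitsAndBytes
```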
cryptk
a6b540737f fix: missing OpenCL libraries from docker containers during clblas docker build (#1830) 2024-03-14 08:40:37 +01:00
LocalAI [bot]
f82065703d ⬆️ Update ggerganov/llama.cpp (#1827)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-14 08:39:39 +01:00
cryptk
b423af001d fix: the correct BUILD_TYPE for OpenCL is clblas (with no t) (#1828) 2024-03-14 08:39:21 +01:00
Ettore Di Giacinto
b9e77d394b feat(model-help): display help text in markdown (#1825)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-03-13 21:50:46 +01:00
Ettore Di Giacinto
57222497ec fix(docker-compose): update docker compose file (#1824)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-03-13 17:57:45 +01:00
LocalAI [bot]
5c5f07c1e7 ⬆️ Update ggerganov/llama.cpp (#1821)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-13 10:05:46 +01:00
Ettore Di Giacinto
f895d06605 fix(config): set better defaults for inferencing (#1822)
* fix(defaults): set better defaults for inferencing

This changeset aims to provide better defaults and to properly detect when
no inference settings are provided with the model.

If not specified, we default to mirostat sampling and offload all the
GPU layers (if a GPU is detected).

Related to https://github.com/mudler/LocalAI/issues/1373 and https://github.com/mudler/LocalAI/issues/1723

* Adapt tests

* Also pre-initialize default seed
2024-03-13 10:05:30 +01:00
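As a rough illustration of the defaults described in #1822 above, a config like the following should no longer be necessary; the key names below are assumptions based on LocalAI's config conventions, not copied from the change:

```yaml
# Hypothetical explicit settings that the new defaults now pick automatically
# when no inference settings are provided with the model.
name: my-llm                 # illustrative name
parameters:
  model: model.gguf          # illustrative file
  seed: -1                   # seed is now pre-initialized by default
mirostat: 2                  # mirostat sampling becomes the default sampler
gpu_layers: 99               # assumed way to offload all layers when a GPU is detected
```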
Ettore Di Giacinto
bc8f648a91 fix(doc/examples): set defaults to mirostat (#1820)
The default sampler on some models doesn't return enough candidates, which
leads to a false sense of randomness. Tracing back the code, it looks
like with the temperature sampler there might not be enough
candidates to pick from, and since the seed and "randomness" take effect
while picking a good candidate, this yields the same results over and
over.

Fixes https://github.com/mudler/LocalAI/issues/1723 by updating the
examples and documentation to use mirostat instead.
2024-03-11 19:49:03 +01:00
LocalAI [bot]
8e57f4df31 ⬆️ Update ggerganov/llama.cpp (#1818)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-11 00:02:37 +01:00
LocalAI [bot]
a08cc5adbb ⬆️ Update ggerganov/llama.cpp (#1816)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-10 09:32:09 +01:00
LocalAI [bot]
595a73fce4 ⬆️ Update ggerganov/llama.cpp (#1813)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-09 09:27:06 +01:00
LocalAI [bot]
dc919e08e8 ⬆️ Update ggerganov/llama.cpp (#1811)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-08 08:21:25 +01:00
Ettore Di Giacinto
5d1018495f feat(intel): add diffusers/transformers support (#1746)
* feat(intel): add diffusers support

* try to consume upstream container image

* Debug

* Manually install deps

* Map transformers/hf cache dir to modelpath if not specified

* fix(compel): update initialization, pass by all gRPC options

* fix: add dependencies, implement transformers for xpu

* base it from the oneapi image

* Add pillow

* set threads if specified when launching the API

* Skip conda install if intel

* defaults to non-intel

* ci: add to pipelines

* prepare compel only if enabled

* Skip conda install if intel

* fix cleanup

* Disable compel by default

* Install torch 2.1.0 with Intel

* Skip conda on some setups

* Detect python

* Quiet output

* Do not override system python with conda

* Prefer python3

* Fixups

* exllama2: do not install without conda (overrides pytorch version)

* exllama/exllama2: do not install if not using cuda

* Add missing dataset dependency

* Small fixups, symlink to python, add requirements

* Add neural_speed to the deps

* correctly handle model offloading

* fix: device_map == xpu

* go back at calling python, fixed at dockerfile level

* Exllama2 restricted to only nvidia gpus

* Tokenizer to xpu
2024-03-07 14:37:45 +01:00
LocalAI [bot]
ad6fd7a991 ⬆️ Update ggerganov/llama.cpp (#1805)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-06 23:28:31 +01:00
LocalAI [bot]
e022b5959e ⬆️ Update mudler/go-stable-diffusion (#1802)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-05 23:39:57 +00:00
LocalAI [bot]
db7f4955a1 ⬆️ Update ggerganov/llama.cpp (#1801)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-05 21:50:27 +00:00
Dave
5c69dd155f feat(autogpt/transformers): consume trust_remote_code (#1799)
trusting remote code by default is a danger to our users
2024-03-05 19:47:15 +01:00
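With this change, trusting remote code presumably becomes an explicit per-model opt-in; a hedged sketch follows (the key name is inferred from the commit title, the rest is illustrative):

```yaml
# Hypothetical opt-in: remote code is no longer trusted by default.
name: my-hf-model               # illustrative name
backend: transformers
parameters:
  model: some-org/some-model    # illustrative repo id
trust_remote_code: true         # explicit opt-in, key name assumed from the commit title
```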
TwinFin
504f2e8bf4 Update Backend Dependencies (#1797)
* Update transformers.yml

Signed-off-by: TwinFin <57421631+TwinFinz@users.noreply.github.com>

* Update transformers-rocm.yml

Signed-off-by: TwinFin <57421631+TwinFinz@users.noreply.github.com>

* Update transformers-nvidia.yml

Signed-off-by: TwinFin <57421631+TwinFinz@users.noreply.github.com>

---------

Signed-off-by: TwinFin <57421631+TwinFinz@users.noreply.github.com>
2024-03-05 10:10:00 +00:00
Luna Midori
e586dc2924 Edit links in readme and integrations page (#1796)
* Update integrations.md

Signed-off-by: Luna Midori <118759930+lunamidori5@users.noreply.github.com>

* Update README.md

Signed-off-by: Luna Midori <118759930+lunamidori5@users.noreply.github.com>

* Update README.md

Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Signed-off-by: Luna Midori <118759930+lunamidori5@users.noreply.github.com>

* Update README.md

Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Signed-off-by: Luna Midori <118759930+lunamidori5@users.noreply.github.com>

---------

Signed-off-by: Luna Midori <118759930+lunamidori5@users.noreply.github.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-03-05 10:14:30 +01:00
Ettore Di Giacinto
333f918005 Update integrations.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-03-05 09:45:54 +01:00
LocalAI [bot]
c8e29033c2 ⬆️ Update ggerganov/llama.cpp (#1794)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-05 08:59:09 +01:00
LocalAI [bot]
d0bd961bde ⬆️ Update ggerganov/llama.cpp (#1791)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-04 09:44:21 +01:00
Ettore Di Giacinto
006511ee25 Revert "feat(assistant): Initial implementation of assistants api" (#1790)
Revert "feat(assistant): Initial implementation of assistants api (#1761)"

This reverts commit 4ab72146cd.
2024-03-03 10:31:06 +01:00
Steven Christou
4ab72146cd feat(assistant): Initial implementation of assistants api (#1761)
Initial implementation of assistants api
2024-03-03 08:50:43 +01:00
LocalAI [bot]
b60a3fc879 ⬆️ Update ggerganov/llama.cpp (#1789)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-03 08:49:23 +01:00
Ettore Di Giacinto
a0eeb74957 Update hot topics/roadmap
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-03-02 09:35:40 +01:00
LocalAI [bot]
daa0b8741c ⬆️ Update ggerganov/llama.cpp (#1785)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-03-01 22:38:24 +00:00
Ludovic Leroux
939411300a Bump vLLM version + more options when loading models in vLLM (#1782)
* Bump vLLM version to 0.3.2

* Add vLLM model loading options

* Remove transformers-exllama

* Fix install exllama
2024-03-01 22:48:53 +01:00
Dave
1c312685aa refactor: move remaining api packages to core (#1731)
* core 1

* api/openai/files fix

* core 2 - core/config

* move over core api.go and tests to the start of core/http

* move over localai specific endpoints to core/http, begin the service/endpoint split there

* refactor big chunk on the plane

* refactor chunk 2 on plane, next step: port and modify changes to request.go

* easy fixes for request.go, major changes not done yet

* lintfix

* json tag lintfix?

* gitignore and .keep files

* strange fix attempt: rename the config dir?
2024-03-01 16:19:53 +01:00
LocalAI [bot]
316de82f51 ⬆️ Update ggerganov/llama.cpp (#1779)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-29 22:33:30 +00:00
Ettore Di Giacinto
9068bc5271 Create SECURITY.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-29 19:53:04 +01:00
Oussama
31a4c9c9d3 Fix Command Injection Vulnerability (#1778)
* Added fix for command injection

* changed function name from sh to runCommand
2024-02-29 18:32:29 +00:00
Ettore Di Giacinto
c1966af2cf ci: reduce stress on self-hosted runners (#1776)
Split jobs between self-hosted runners and the free public runners provided by GitHub

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-29 11:40:08 +01:00
LocalAI [bot]
c665898652 ⬆️ Update donomii/go-rwkv.cpp (#1771)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-28 23:50:27 +00:00
LocalAI [bot]
f651a660aa ⬆️ Update ggerganov/llama.cpp (#1772)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-28 23:02:30 +01:00
Ettore Di Giacinto
ba672b51da Update README.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-28 16:03:38 +01:00
Ettore Di Giacinto
be498c5dd9 Update openai-functions.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-28 15:58:31 +01:00
Ettore Di Giacinto
6e95beccb9 Update overview.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-28 15:24:08 +01:00
Ettore Di Giacinto
c8be839481 Update openai-functions.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-27 23:24:46 +01:00
LocalAI [bot]
c7e08813a5 ⬆️ Update ggerganov/llama.cpp (#1767)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-27 23:12:51 +01:00
LocalAI [bot]
d21a6b33ab ⬆️ Update ggerganov/llama.cpp (#1756)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-27 18:07:51 +00:00
Joshua Waring
9112cf153e Update integrations.md (#1765)
Added a JetBrains-compatible plugin for LocalAI

Signed-off-by: Joshua Waring <Joshhua5@users.noreply.github.com>
2024-02-27 17:35:59 +01:00
Ettore Di Giacinto
3868ac8402 Update README.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-27 15:44:15 +01:00
Ettore Di Giacinto
3f09010227 Update README.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-27 15:43:15 +01:00
Ettore Di Giacinto
d6cf82aba3 fix(tests): re-enable tests after code move (#1764)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-02-27 15:04:19 +01:00
Ettore Di Giacinto
dfe54639b1 Update README.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-27 10:37:56 +01:00
Ettore Di Giacinto
bc5f5aa538 deps(llama.cpp): update (#1759)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-02-26 13:18:44 +01:00
Ettore Di Giacinto
05818e0425 fix(functions): handle correctly when there are no results (#1758) 2024-02-26 08:38:23 +01:00
Sertaç Özercan
7f72a61104 ci: add stablediffusion to release (#1757)
Signed-off-by: Sertac Ozercan <sozercan@gmail.com>
2024-02-25 23:06:18 +00:00
LocalAI [bot]
8e45d47740 ⬆️ Update ggerganov/llama.cpp (#1753)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-25 10:03:19 +01:00
LocalAI [bot]
71771d1e9b ⬆️ Update docs version mudler/LocalAI (#1752)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-25 10:02:52 +01:00
Ettore Di Giacinto
aa098e4d0b fix(sse): do not omit empty finish_reason (#1745)
Fixes https://github.com/mudler/LocalAI/issues/1744
2024-02-24 11:51:59 +01:00
Ludovic Leroux
0135e1e3b9 fix: vllm - use AsyncLLMEngine to allow true streaming mode (#1749)
* fix: use vllm AsyncLLMEngine to bring true stream

The current vLLM implementation uses the LLMEngine, which was designed for offline batch inference; as a result, streaming mode outputs all blobs at once at the end of the inference.

This PR reworks the gRPC server to use asyncio and gRPC.aio, in combination with vLLM's AsyncLLMEngine to bring true stream mode.

This PR also passes more parameters to vLLM during inference (presence_penalty, frequency_penalty, stop, ignore_eos, seed, ...).

* Remove unused import
2024-02-24 11:48:45 +01:00
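A sketch of a model YAML passing some of the extra sampling parameters now forwarded to vLLM; the keys mirror the parameters listed in the PR description, but their exact spelling and placement in LocalAI's config are assumptions:

```yaml
# Hypothetical vLLM model config exercising the newly forwarded options.
name: my-vllm-model                         # illustrative name
backend: vllm
parameters:
  model: mistralai/Mistral-7B-Instruct-v0.2 # illustrative repo id
  frequency_penalty: 0.2
  presence_penalty: 0.2
  seed: 42
  ignore_eos: false
stopwords:
  - "</s>"
```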
LocalAI [bot]
ff88c390bb ⬆️ Update ggerganov/llama.cpp (#1750)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-24 00:06:46 +01:00
LocalAI [bot]
d825821a22 ⬆️ Update ggerganov/llama.cpp (#1740)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-23 00:07:15 +01:00
Luna Midori
cbed6ab1bb Update README.md (#1739)
* Update README.md

Signed-off-by: Luna Midori <118759930+lunamidori5@users.noreply.github.com>

* Update README.md

Signed-off-by: Luna Midori <118759930+lunamidori5@users.noreply.github.com>

---------

Signed-off-by: Luna Midori <118759930+lunamidori5@users.noreply.github.com>
2024-02-22 16:35:06 +01:00
LocalAI [bot]
6fc122fa1a ⬆️ Update ggerganov/llama.cpp (#1705)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-22 09:33:23 +00:00
Ettore Di Giacinto
feba38be36 examples(mistral-openorca): add stopword
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-02-22 00:15:08 +01:00
Ettore Di Giacinto
ba85d0bcad feat(upload-api): do not display error if uploadedFiles.json is not present
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-02-22 00:15:08 +01:00
Ettore Di Giacinto
ad3623dd8d examples(phi-2): strip newline at the end of the prompt template
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-21 23:17:51 +01:00
Ettore Di Giacinto
8292781045 deps(llama.cpp): update, support Gemma models (#1734)
deps(llama.cpp): update

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-02-21 17:23:38 +01:00
Ettore Di Giacinto
54ec6348fa deps(llama.cpp): update (#1714)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2024-02-21 11:35:44 +01:00
Dave
255748bcba MQTT Startup Refactoring Part 1: core/ packages part 1 (#1728)
This PR specifically introduces a `core` folder and moves the following packages over, without any other changes:

- `api/backend`
- `api/config`
- `api/options`
- `api/schema`

Once this is merged and we confirm there are no regressions, I can migrate over the remaining changes piece by piece to split up application startup, backend services, HTTP, and MQTT, as was the goal of the earlier PRs!
2024-02-21 01:21:19 +00:00
Chakib Benziane
594eb468df Add TTS dependency for cuda based builds fixes #1727 (#1730)
Signed-off-by: Chakib Benziane <contact@blob42.xyz>
2024-02-20 21:59:43 +01:00
Ettore Di Giacinto
960d314e4f feat(tools): Parallel function calling (#1726)
feat(tools): support returning multiple tools choices

Fixes: https://github.com/mudler/LocalAI/issues/1275
2024-02-20 21:58:45 +01:00
Ettore Di Giacinto
ed3b50622b Update README.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-20 19:55:36 +01:00
Ettore Di Giacinto
9f2235c208 Update README.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-19 19:49:00 +01:00
Ettore Di Giacinto
4ec50bfc41 Update README.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-19 19:03:09 +01:00
Ettore Di Giacinto
51b67a247a Update README.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-18 13:37:16 +01:00
Steven Christou
01205fd4c0 Initial implementation of upload files api. (#1703)
* Initial implementation of upload files api.

* Move sanitize method to utils.

* Save uploaded data to uploads folder.

* Avoid loop if we do not have a purpose.

* Minor cleanup of the API and fix a bug where deleting a duplicate filename caused an error.

* Revert defer of saving config

* Moved creation of directory to startup.

* Make file names unique when storing on disk.

* Add test for files api.

* Update dependencies.
2024-02-18 10:12:02 +00:00
Ettore Di Giacinto
c72808f18b feat(tools): support Tool calls in the API (#1715)
* feat(tools): support Tools in the API

Co-authored-by: Stephan Aßmus <stephan.assmus@sap.com>

* feat(tools): support function streaming

* Adhere to new return types when using tools instead of functions

* Keep backward compatibility with function calling

* Evaluate function names in chat templates

* Disable recovery with --debug

* Correctly stream out the entire result

* Detect when llm chooses to reply and to not perform any action in SSE

* Feedback from code review

---------

Co-authored-by: Stephan Aßmus <stephan.assmus@sap.com>
2024-02-17 10:00:34 +01:00
Ettore Di Giacinto
6b539a2972 Update README.md
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-16 15:22:35 +01:00
LocalAI [bot]
2151d21862 ⬆️ Update docs version mudler/LocalAI (#1718)
* ⬆️ Update docs version mudler/LocalAI

Signed-off-by: GitHub <noreply@github.com>

* Update docs/data/version.json

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

---------

Signed-off-by: GitHub <noreply@github.com>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2024-02-16 15:11:53 +01:00
fenfir
fb0a4c5d9a Build docker container for ROCm (#1595)
* Dockerfile changes to build for ROCm

* Adjust linker flags for ROCm

* Update conda env for diffusers and transformers to use ROCm pytorch

* Update transformers conda env for ROCm

* ci: build hipblas images

* fixup rebase

* use self-hosted

Signed-off-by: mudler <mudler@localai.io>

* specify LD_LIBRARY_PATH only when BUILD_TYPE=hipblas

---------

Signed-off-by: mudler <mudler@localai.io>
Co-authored-by: mudler <mudler@localai.io>
2024-02-16 15:08:50 +01:00
Ettore Di Giacinto
e690bf387a fix(tts): fix regression when supplying backend from requests (#1713)
fixes #1707

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2024-02-15 17:33:06 +01:00
140 changed files with 5380 additions and 2992 deletions

View File

@@ -3,3 +3,4 @@ models
examples/chatbot-ui/models
examples/rwkv/models
examples/**/models
Dockerfile

.env
View File

@@ -18,7 +18,7 @@
## Default path for models
#
MODELS_PATH=/models
# MODELS_PATH=/models
## Enable debug mode
# DEBUG=true

View File

@@ -51,6 +51,22 @@ jobs:
image-type: 'extras'
runs-on: 'arc-runner-set'
base-image: "ubuntu:22.04"
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-hipblas'
ffmpeg: 'false'
image-type: 'extras'
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
runs-on: 'arc-runner-set'
- build-type: 'sycl_f16'
platforms: 'linux/amd64'
tag-latest: 'false'
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: 'sycl-f16-ffmpeg'
ffmpeg: 'true'
image-type: 'extras'
runs-on: 'arc-runner-set'
core-image-build:
uses: ./.github/workflows/image_build.yml
with:
@@ -97,4 +113,4 @@ jobs:
ffmpeg: 'true'
image-type: 'core'
runs-on: 'ubuntu-latest'
base-image: "ubuntu:22.04"
base-image: "ubuntu:22.04"

View File

@@ -13,7 +13,7 @@ concurrency:
cancel-in-progress: true
jobs:
extras-image-build:
self-hosted-jobs:
uses: ./.github/workflows/image_build.yml
with:
tag-latest: ${{ matrix.tag-latest }}
@@ -37,6 +37,7 @@ jobs:
max-parallel: ${{ github.event_name != 'pull_request' && 2 || 4 }}
matrix:
include:
# Extra images
- build-type: ''
#platforms: 'linux/amd64,linux/arm64'
platforms: 'linux/amd64'
@@ -103,35 +104,39 @@ jobs:
image-type: 'extras'
base-image: "ubuntu:22.04"
runs-on: 'arc-runner-set'
core-image-build:
uses: ./.github/workflows/image_build.yml
with:
tag-latest: ${{ matrix.tag-latest }}
tag-suffix: ${{ matrix.tag-suffix }}
ffmpeg: ${{ matrix.ffmpeg }}
image-type: ${{ matrix.image-type }}
build-type: ${{ matrix.build-type }}
cuda-major-version: ${{ matrix.cuda-major-version }}
cuda-minor-version: ${{ matrix.cuda-minor-version }}
platforms: ${{ matrix.platforms }}
runs-on: ${{ matrix.runs-on }}
base-image: ${{ matrix.base-image }}
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
strategy:
matrix:
include:
- build-type: ''
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-ffmpeg-core'
tag-suffix: '-hipblas-ffmpeg'
ffmpeg: 'true'
image-type: 'core'
base-image: "ubuntu:22.04"
runs-on: 'ubuntu-latest'
image-type: 'extras'
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
runs-on: 'arc-runner-set'
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-hipblas'
ffmpeg: 'false'
image-type: 'extras'
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
runs-on: 'arc-runner-set'
- build-type: 'sycl_f16'
platforms: 'linux/amd64'
tag-latest: 'false'
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: '-sycl-f16-ffmpeg'
ffmpeg: 'true'
image-type: 'extras'
runs-on: 'arc-runner-set'
- build-type: 'sycl_f32'
platforms: 'linux/amd64'
tag-latest: 'false'
base-image: "intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04"
tag-suffix: '-sycl-f32-ffmpeg'
ffmpeg: 'true'
image-type: 'extras'
runs-on: 'arc-runner-set'
# Core images
- build-type: 'sycl_f16'
platforms: 'linux/amd64'
tag-latest: 'false'
@@ -164,6 +169,52 @@ jobs:
ffmpeg: 'true'
image-type: 'core'
runs-on: 'arc-runner-set'
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-hipblas-ffmpeg-core'
ffmpeg: 'true'
image-type: 'core'
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
runs-on: 'arc-runner-set'
- build-type: 'hipblas'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-hipblas-core'
ffmpeg: 'false'
image-type: 'core'
base-image: "rocm/dev-ubuntu-22.04:6.0-complete"
runs-on: 'arc-runner-set'
core-image-build:
uses: ./.github/workflows/image_build.yml
with:
tag-latest: ${{ matrix.tag-latest }}
tag-suffix: ${{ matrix.tag-suffix }}
ffmpeg: ${{ matrix.ffmpeg }}
image-type: ${{ matrix.image-type }}
build-type: ${{ matrix.build-type }}
cuda-major-version: ${{ matrix.cuda-major-version }}
cuda-minor-version: ${{ matrix.cuda-minor-version }}
platforms: ${{ matrix.platforms }}
runs-on: ${{ matrix.runs-on }}
base-image: ${{ matrix.base-image }}
secrets:
dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
strategy:
matrix:
include:
- build-type: ''
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-ffmpeg-core'
ffmpeg: 'true'
image-type: 'core'
base-image: "ubuntu:22.04"
runs-on: 'ubuntu-latest'
- build-type: 'cublas'
cuda-major-version: "11"
cuda-minor-version: "7"

View File

@@ -89,6 +89,35 @@ jobs:
files: |
release/*
build-stablediffusion:
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v4
with:
submodules: true
- uses: actions/setup-go@v4
with:
go-version: '>=1.21.0'
- name: Dependencies
run: |
sudo apt-get install -y --no-install-recommends libopencv-dev
sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
- name: Build stablediffusion
run: |
make backend-assets/grpc/stablediffusion
mkdir -p release && cp backend-assets/grpc/stablediffusion release
- uses: actions/upload-artifact@v3
with:
name: stablediffusion
path: release/
- name: Release
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')
with:
files: |
release/*
build-macOS:
strategy:
matrix:

.gitignore
View File

@@ -21,6 +21,7 @@ local-ai
!charts/*
# prevent above rules from omitting the api/localai folder
!api/localai
!core/**/localai
# Ignore models
models/*
@@ -34,6 +35,7 @@ release/
.idea
# Generated during build
backend-assets/
backend-assets/*
!backend-assets/.keep
prepare
/ggml-metal.metal

View File

@@ -1,10 +1,11 @@
ARG GO_VERSION=1.21
ARG IMAGE_TYPE=extras
ARG BASE_IMAGE=ubuntu:22.04
# extras or core
FROM ${BASE_IMAGE} as requirements-core
USER root
ARG GO_VERSION=1.21.7
ARG BUILD_TYPE
ARG CUDA_MAJOR_VERSION=11
@@ -22,7 +23,7 @@ RUN apt-get update && \
apt-get install -y ca-certificates curl patch pip cmake git && apt-get clean
# Install Go
RUN curl -L -s https://go.dev/dl/go$GO_VERSION.linux-$TARGETARCH.tar.gz | tar -v -C /usr/local -xz
RUN curl -L -s https://go.dev/dl/go$GO_VERSION.linux-$TARGETARCH.tar.gz | tar -C /usr/local -xz
ENV PATH $PATH:/usr/local/go/bin
COPY --chmod=644 custom-ca-certs/* /usr/local/share/ca-certificates/
@@ -42,8 +43,12 @@ RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \
apt-get install -y cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} && apt-get clean \
; fi
# Cuda
ENV PATH /usr/local/cuda/bin:${PATH}
# HipBLAS requirements
ENV PATH /opt/rocm/bin:${PATH}
# OpenBLAS requirements and stable diffusion
RUN apt-get install -y \
libopenblas-dev \
@@ -70,10 +75,16 @@ RUN curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmo
apt-get install -y conda && apt-get clean
ENV PATH="/root/.cargo/bin:${PATH}"
RUN apt-get install -y python3-pip && apt-get clean
RUN pip install --upgrade pip
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
RUN apt-get install -y espeak-ng espeak && apt-get clean
RUN if [ ! -e /usr/bin/python ]; then \
ln -s /usr/bin/python3 /usr/bin/python \
; fi
###################################
###################################
@@ -94,6 +105,13 @@ COPY . .
COPY .git .
RUN make prepare
# If we are building with clblas support, we need the libraries for the builds
RUN if [ "${BUILD_TYPE}" = "clblas" ]; then \
apt-get update && \
apt-get install -y libclblast-dev && \
apt-get clean \
; fi
# stablediffusion does not tolerate a newer version of abseil, build it first
RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
@@ -137,6 +155,13 @@ RUN if [ "${FFMPEG}" = "true" ]; then \
apt-get install -y ffmpeg && apt-get clean \
; fi
# Add OpenCL
RUN if [ "${BUILD_TYPE}" = "clblas" ]; then \
apt-get update && \
apt-get install -y libclblast1 && \
apt-get clean \
; fi
WORKDIR /build
# we start fresh & re-copy all assets because `make build` does not clean up nicely after itself
@@ -161,43 +186,43 @@ COPY --from=builder /build/backend-assets/grpc/stablediffusion ./backend-assets/
## Duplicated from Makefile to avoid having a big layer that's hard to push
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/autogptq \
make -C backend/python/autogptq \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/bark \
make -C backend/python/bark \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/diffusers \
make -C backend/python/diffusers \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/vllm \
make -C backend/python/vllm \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/mamba \
make -C backend/python/mamba \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers \
make -C backend/python/sentencetransformers \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/transformers \
make -C backend/python/transformers \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/vall-e-x \
make -C backend/python/vall-e-x \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/exllama \
make -C backend/python/exllama \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/exllama2 \
make -C backend/python/exllama2 \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/petals \
make -C backend/python/petals \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/transformers-musicgen \
make -C backend/python/transformers-musicgen \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/coqui \
make -C backend/python/coqui \
; fi
# Make sure the models directory exists

View File

@@ -4,11 +4,11 @@ GOVET=$(GOCMD) vet
BINARY_NAME=local-ai
# llama.cpp versions
GOLLAMA_VERSION?=aeba71ee842819da681ea537e78846dc75949ac0
GOLLAMA_VERSION?=6a8041ef6b46d4712afc3ae791d1c2d73da0ad1c
GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
CPPLLAMA_VERSION?=f026f8120f97090d34a52b3dc023c82e0ede3f7d
CPPLLAMA_VERSION?=4755afd1cbd40d93c017e5b98c39796f52345314
# gpt4all version
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
@@ -16,19 +16,19 @@ GPT4ALL_VERSION?=27a8b020c36b0df8f8b82a252d261cda47cf44b8
# go-rwkv version
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=633c5a3485c403cb2520693dc0991a25dace9f0f
RWKV_VERSION?=661e7ae26d442f5cfebd2a0881b44e8c55949ec6
# whisper.cpp version
WHISPER_CPP_VERSION?=37a709f6558c6d9783199e2b8cbb136e1c41d346
WHISPER_CPP_VERSION?=a56f435fd475afd7edf02bfbf9f8c77f527198c2
# bert.cpp version
BERT_VERSION?=6abe312cded14042f6b7c3cd8edf082713334a4d
# go-piper version
PIPER_VERSION?=d6b6275ba037dabdba4a8b65dfdf6b2a73a67f07
PIPER_VERSION?=9d0100873a7dbb0824dfea40e8cec70a1b110759
# stablediffusion version
STABLEDIFFUSION_VERSION?=d5d2be8e7e395c2d73ceef61e6fe8d240f2cd831
STABLEDIFFUSION_VERSION?=362df9da29f882dbf09ade61972d16a1f53c3485
# tinydream version
TINYDREAM_VERSION?=772a9c0d9aaf768290e63cca3c904fe69faf677a
@@ -44,6 +44,8 @@ BUILD_ID?=git
TEST_DIR=/tmp/test
TEST_FLAKES?=5
RANDOM := $(shell bash -c 'echo $$RANDOM')
VERSION?=$(shell git describe --always --tags || echo "dev" )
@@ -89,14 +91,19 @@ ifeq ($(BUILD_TYPE),openblas)
export WHISPER_OPENBLAS=1
endif
ifeq ($(BUILD_TYPE),cublas)
CGO_LDFLAGS+=-lcublas -lcudart -L$(CUDA_LIBPATH)
CGO_LDFLAGS+=-lcublas -lcudart -lculibos -lcublasLt -L$(CUDA_LIBPATH)
export LLAMA_CUBLAS=1
# required by whisper.cpp
export WHISPER_CUBLAS=1
CGO_LDFLAGS+=-L$(CUDA_PATH)/stubs -lcuda
endif
ifeq ($(BUILD_TYPE),hipblas)
ROCM_HOME ?= /opt/rocm
ROCM_PATH ?= /opt/rocm
LD_LIBRARY_PATH ?= /opt/rocm/lib:/opt/rocm/llvm/lib
export CXX=$(ROCM_HOME)/llvm/bin/clang++
export CC=$(ROCM_HOME)/llvm/bin/clang
# llama-ggml has no hipblas support, so override it here.
@@ -105,7 +112,7 @@ ifeq ($(BUILD_TYPE),hipblas)
GPU_TARGETS ?= gfx900,gfx90a,gfx1030,gfx1031,gfx1100
AMDGPU_TARGETS ?= "$(GPU_TARGETS)"
CMAKE_ARGS+=-DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS="$(AMDGPU_TARGETS)" -DGPU_TARGETS="$(GPU_TARGETS)"
CGO_LDFLAGS += -O3 --rtlib=compiler-rt -unwindlib=libgcc -lhipblas -lrocblas --hip-link
CGO_LDFLAGS += -O3 --rtlib=compiler-rt -unwindlib=libgcc -lhipblas -lrocblas --hip-link -L${ROCM_HOME}/lib/llvm/lib
endif
ifeq ($(BUILD_TYPE),metal)
@@ -153,6 +160,7 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
GRPC_BACKENDS?=$(ALL_GRPC_BACKENDS) $(OPTIONAL_GRPC)
TEST_PATHS?=./api/... ./pkg/... ./core/...
# If empty, then we build all
ifeq ($(GRPC_BACKENDS),)
@@ -248,7 +256,7 @@ sources/go-piper/libpiper_binding.a: sources/go-piper
$(MAKE) -C sources/go-piper libpiper_binding.a example/main
backend/cpp/llama/llama.cpp:
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama llama.cpp
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama llama.cpp
get-sources: backend/cpp/llama/llama.cpp sources/go-llama sources/go-llama-ggml sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion sources/go-tiny-dream
touch $@
@@ -326,7 +334,7 @@ test-models/testmodel:
cp tests/models_fixtures/* test-models
prepare-test: grpcs
cp -rf backend-assets api
cp -rf backend-assets core/http
cp tests/models_fixtures/* test-models
test: prepare test-models/testmodel grpcs
@@ -334,7 +342,7 @@ test: prepare test-models/testmodel grpcs
export GO_TAGS="tts stablediffusion"
$(MAKE) prepare-test
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts 5 --fail-fast -v -r ./api ./pkg
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
$(MAKE) test-gpt4all
$(MAKE) test-llama
$(MAKE) test-llama-gguf
@@ -363,23 +371,23 @@ teardown-e2e:
test-gpt4all: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r $(TEST_PATHS)
test-llama: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r $(TEST_PATHS)
test-llama-gguf: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r ./api ./pkg
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r $(TEST_PATHS)
test-tts: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r ./api ./pkg
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r $(TEST_PATHS)
test-stablediffusion: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r ./api ./pkg
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r $(TEST_PATHS)
test-container:
docker build --target requirements -t local-ai-test-container .
@@ -457,9 +465,6 @@ backend-assets/grpc/llama: backend-assets/grpc sources/go-llama/libbinding.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-llama LIBRARY_PATH=$(CURDIR)/sources/go-llama \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./backend/go/llm/llama/
# TODO: every binary should have its own folder instead, so can have different implementations
ifeq ($(BUILD_TYPE),metal)
cp backend/cpp/llama/llama.cpp/ggml-metal.metal backend-assets/grpc/
endif
## BACKEND CPP LLAMA START
# Sets the variables in case it has to build the gRPC locally.
@@ -480,7 +485,7 @@ ifdef BUILD_GRPC_FOR_BACKEND_LLAMA
CMAKE_ARGS="${CMAKE_ARGS} ${ADDED_CMAKE_ARGS}" LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
else
echo "BUILD_GRPC_FOR_BACKEND_LLAMA is not defined."
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
endif
## BACKEND CPP LLAMA END
@@ -489,7 +494,7 @@ backend-assets/grpc/llama-cpp: backend-assets/grpc backend/cpp/llama/grpc-server
cp -rfv backend/cpp/llama/grpc-server backend-assets/grpc/llama-cpp
# TODO: every binary should have its own folder instead, so can have different metal implementations
ifeq ($(BUILD_TYPE),metal)
cp backend/cpp/llama/llama.cpp/build/bin/ggml-metal.metal backend-assets/grpc/
cp backend/cpp/llama/llama.cpp/build/bin/default.metallib backend-assets/grpc/
endif
backend-assets/grpc/llama-ggml: backend-assets/grpc sources/go-llama-ggml/libbinding.a
@@ -514,6 +519,7 @@ backend-assets/grpc/langchain-huggingface: backend-assets/grpc
backend-assets/grpc/stablediffusion: backend-assets/grpc
if [ ! -f backend-assets/grpc/stablediffusion ]; then \
$(MAKE) sources/go-stable-diffusion; \
$(MAKE) sources/go-stable-diffusion/libstablediffusion.a; \
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-stable-diffusion/ LIBRARY_PATH=$(CURDIR)/sources/go-stable-diffusion/ \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./backend/go/image/stablediffusion; \
@@ -551,3 +557,10 @@ docker-image-intel:
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
--build-arg GO_TAGS="none" \
--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .
docker-image-intel-xpu:
docker build \
--build-arg BASE_IMAGE=intel/oneapi-basekit:2024.0.1-devel-ubuntu22.04 \
--build-arg IMAGE_TYPE=$(IMAGE_TYPE) \
--build-arg GO_TAGS="none" \
--build-arg BUILD_TYPE=sycl_f32 -t $(DOCKER_IMAGE) .

View File

@@ -43,19 +43,24 @@
[Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
- Intel GPU support (sycl): https://github.com/mudler/LocalAI/issues/1653
- Parallel function calling: https://github.com/mudler/LocalAI/pull/1726
- Upload file API: https://github.com/mudler/LocalAI/pull/1703
- Tools API support: https://github.com/mudler/LocalAI/pull/1715
- LLaVa 1.6: https://github.com/mudler/LocalAI/pull/1714
- ROCm container images: https://github.com/mudler/LocalAI/pull/1595
- Intel GPU support (sycl, transformers, diffusers): https://github.com/mudler/LocalAI/issues/1653
- Deprecation of old backends: https://github.com/mudler/LocalAI/issues/1651
- Mamba support: https://github.com/mudler/LocalAI/pull/1589
- Start and share models with config file: https://github.com/mudler/LocalAI/pull/1522
- 🐸 Coqui: https://github.com/mudler/LocalAI/pull/1489
- Inline templates: https://github.com/mudler/LocalAI/pull/1452
- Mixtral: https://github.com/mudler/LocalAI/pull/1449
- Img2vid https://github.com/mudler/LocalAI/pull/1442
- Musicgen https://github.com/mudler/LocalAI/pull/1387
Hot topics (looking for contributors):
- Backends v2: https://github.com/mudler/LocalAI/issues/1126
- Improving UX v2: https://github.com/mudler/LocalAI/issues/1373
- Assistant API: https://github.com/mudler/LocalAI/issues/1273
- Moderation endpoint: https://github.com/mudler/LocalAI/issues/999
- Vulkan: https://github.com/mudler/LocalAI/issues/1647
If you want to help and contribute, issues up for grabs: https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3A%22up+for+grabs%22
@@ -64,7 +69,7 @@ If you want to help and contribute, issues up for grabs: https://github.com/mudl
For a detailed step-by-step introduction, refer to the [Getting Started](https://localai.io/basics/getting_started/index.html) guide. For those in a hurry, here's a straightforward one-liner to launch a LocalAI instance with [phi-2](https://huggingface.co/microsoft/phi-2) using `docker`:
```
docker run -ti -p 8080:8080 localai/localai:v2.7.0-ffmpeg-core phi-2
docker run -ti -p 8080:8080 localai/localai:v2.9.0-ffmpeg-core phi-2
```
## 🚀 [Features](https://localai.io/features/)
@@ -94,10 +99,6 @@ WebUIs:
Model galleries
- https://github.com/go-skynet/model-gallery
Auto Docker / Model setup
- https://io.midori-ai.xyz/howtos/easy-localai-installer/
- https://io.midori-ai.xyz/howtos/easy-model-installer/
Other:
- Helm chart https://github.com/go-skynet/helm-charts
@@ -108,6 +109,7 @@ Other:
- Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack
- Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot
- Examples: https://github.com/mudler/LocalAI/tree/master/examples/
### 🔗 Resources
@@ -119,6 +121,8 @@ Other:
## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social)
- [Run LocalAI on AWS EKS with Pulumi](https://www.pulumi.com/ai/answers/tiZMDoZzZV6TLxgDXNBnFE/deploying-helm-charts-on-aws-eks)
- [Run LocalAI on AWS](https://staleks.hashnode.dev/installing-localai-on-aws-ec2-instance)
- [Create a slackbot for teams and OSS projects that answer to documentation](https://mudler.pm/posts/smart-slackbot-for-teams/)
- [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE)
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/)

SECURITY.md
View File

@@ -0,0 +1,42 @@
# Security Policy
## Introduction
At LocalAI, we take the security of our software seriously. We understand the importance of protecting our community from vulnerabilities and are committed to ensuring the safety and security of our users.
## Supported Versions
We provide support and updates for certain versions of our software. The following table outlines which versions are currently supported with security updates:
| Version | Supported |
| ------- | ------------------ |
| > 2.0 | :white_check_mark: |
| < 2.0 | :x: |
Please ensure that you are using a supported version to receive the latest security updates.
## Reporting a Vulnerability
We encourage the responsible disclosure of any security vulnerabilities. If you believe you've found a security issue in our software, we kindly ask you to follow the steps below to report it to us:
1. **Email Us:** Send an email to [security@localai.io](mailto:security@localai.io) with a detailed report. Please do not disclose the vulnerability publicly or to any third parties before it has been addressed by us.
2. **Expect a Response:** We aim to acknowledge receipt of vulnerability reports within 48 hours. Our security team will review your report and work closely with you to understand the impact and ensure a thorough investigation.
3. **Collaboration:** If the vulnerability is accepted, we will work with you and our community to address the issue promptly. We'll keep you informed throughout the resolution process and may request additional information or collaboration.
4. **Disclosure:** Once the vulnerability has been resolved, we encourage a coordinated disclosure. We believe in transparency and will work with you to ensure that our community is informed in a responsible manner.
## Use of Third-Party Platforms
As a Free and Open Source Software (FOSS) organization, we do not offer monetary bounties. However, researchers who wish to report vulnerabilities can also do so via [Huntr](https://huntr.dev/bounties), a platform that recognizes contributions to open source security.
## Contact
For any security-related inquiries beyond vulnerability reporting, please contact us at [security@localai.io](mailto:security@localai.io).
## Acknowledgments
We appreciate the efforts of those who contribute to the security of our project. Your responsible disclosure is invaluable to the safety and integrity of LocalAI.
Thank you for helping us keep LocalAI secure.

View File

@@ -1,288 +0,0 @@
package api
import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/localai"
"github.com/go-skynet/LocalAI/api/openai"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/startup"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
options := options.NewOptions(opts...)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if options.Debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
startup.PreloadModelsConfigurations(options.ModelLibraryURL, options.Loader.ModelPath, options.ModelsURL...)
cl := config.NewConfigLoader()
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if options.ConfigFile != "" {
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if err := cl.Preload(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error downloading models: %s", err.Error())
}
if options.PreloadJSONModels != "" {
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
if options.Debug {
for _, v := range cl.ListConfigs() {
cfg, _ := cl.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
options.Loader.StopAllGRPC()
}()
if options.WatchDog {
wd := model.NewWatchDog(
options.Loader,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
options.Loader.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
}
return options, cl, nil
}
func App(opts ...options.AppOption) (*fiber.App, error) {
options, cl, err := Startup(opts...)
if err != nil {
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}
// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: options.DisableMessage,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
},
})
if options.Debug {
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
}
// Default middleware config
app.Use(recover.New())
if options.Metrics != nil {
app.Use(metrics.APIMiddleware(options.Metrics))
}
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(options.ApiKeys) == 0 {
return c.Next()
}
// Check for api_keys.json file
fileContent, err := os.ReadFile("api_keys.json")
if err == nil {
// Parse JSON content from the file
var fileKeys []string
err := json.Unmarshal(fileContent, &fileKeys)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
}
// Add file keys to options.ApiKeys
options.ApiKeys = append(options.ApiKeys, fileKeys...)
}
if len(options.ApiKeys) == 0 {
return c.Next()
}
authHeader := c.Get("Authorization")
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
for _, key := range options.ApiKeys {
if apiKey == key {
return c.Next()
}
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if options.CORS {
var c func(ctx *fiber.Ctx) error
if options.CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
}
app.Use(c)
}
// LocalAI API endpoints
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
galleryService.Start(options.Context, cl)
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
// Make sure directories exists
os.MkdirAll(options.ImageDir, 0755)
os.MkdirAll(options.AudioDir, 0755)
os.MkdirAll(options.Loader.ModelPath, 0755)
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
if options.ImageDir != "" {
app.Static("/generated-images", options.ImageDir)
}
if options.AudioDir != "" {
app.Static("/generated-audio", options.AudioDir)
}
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
// Kubernetes health checks
app.Get("/healthz", ok)
app.Get("/readyz", ok)
// Experimental Backend Statistics Module
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
app.Get("/metrics", metrics.MetricsHandler())
return app, nil
}

View File

@@ -1,61 +0,0 @@
package backend
import (
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context),
model.WithModel(c.Model),
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
PipelineType: c.Diffusers.PipelineType,
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
}),
})
inferenceModel, err := loader.BackendLoader(
opts...,
)
if err != nil {
return nil, err
}
fn := func() error {
_, err := inferenceModel.GenerateImage(
o.Context,
&proto.GenerateImageRequest{
Height: int32(height),
Width: int32(width),
Mode: int32(mode),
Step: int32(step),
Seed: int32(seed),
CLIPSkip: int32(c.Diffusers.ClipSkip),
PositivePrompt: positive_prompt,
NegativePrompt: negative_prompt,
Dst: dst,
Src: src,
EnableParameters: c.Diffusers.EnableParameters,
})
return err
}
return fn, nil
}
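Tying this backend function to the /v1/images/generations route registered earlier, a hedged client sketch follows; the model name, prompt, size, and base URL are illustrative assumptions, and the response shape depends on the configured diffusion backend.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Sketch only: assumes a diffusers/stable-diffusion style model is configured locally.
	body, _ := json.Marshal(map[string]interface{}{
		"model":  "stablediffusion", // placeholder model name
		"prompt": "a cat wearing a space helmet",
		"size":   "512x512",
	})
	resp, err := http.Post("http://localhost:8080/v1/images/generations", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	var out map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	// When ImageDir is set, generated files are also served under /generated-images.
	fmt.Println(out)
}
```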


@@ -1,129 +0,0 @@
package backend
import (
"os"
"path/filepath"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
)
func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option {
if o.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend())
}
if o.ParallelBackendRequests {
opts = append(opts, model.EnableParallelRequests)
}
if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
}
if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
return opts
}
func gRPCModelOpts(c config.Config) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
}
return &pb.ModelOptions{
ContextSize: int32(c.ContextSize),
Seed: int32(c.Seed),
NBatch: int32(b),
NoMulMatQ: c.NoMulMatQ,
CUDA: c.CUDA, // diffusers, transformers
DraftModel: c.DraftModel,
AudioPath: c.VallE.AudioPath,
Quantization: c.Quantization,
MMProj: c.MMProj,
YarnExtFactor: c.YarnExtFactor,
YarnAttnFactor: c.YarnAttnFactor,
YarnBetaFast: c.YarnBetaFast,
YarnBetaSlow: c.YarnBetaSlow,
LoraAdapter: c.LoraAdapter,
LoraBase: c.LoraBase,
LoraScale: c.LoraScale,
NGQA: c.NGQA,
RMSNormEps: c.RMSNormEps,
F16Memory: c.F16,
MLock: c.MMlock,
RopeFreqBase: c.RopeFreqBase,
RopeScaling: c.RopeScaling,
Type: c.ModelType,
RopeFreqScale: c.RopeFreqScale,
NUMA: c.NUMA,
Embeddings: c.Embeddings,
LowVRAM: c.LowVRAM,
NGPULayers: int32(c.NGPULayers),
MMap: c.MMap,
MainGPU: c.MainGPU,
Threads: int32(c.Threads),
TensorSplit: c.TensorSplit,
// AutoGPTQ
ModelBaseName: c.AutoGPTQ.ModelBaseName,
Device: c.AutoGPTQ.Device,
UseTriton: c.AutoGPTQ.Triton,
UseFastTokenizer: c.AutoGPTQ.UseFastTokenizer,
// RWKV
Tokenizer: c.Tokenizer,
}
}
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
promptCachePath := ""
if c.PromptCachePath != "" {
p := filepath.Join(modelPath, c.PromptCachePath)
os.MkdirAll(filepath.Dir(p), 0755)
promptCachePath = p
}
return &pb.PredictOptions{
Temperature: float32(c.Temperature),
TopP: float32(c.TopP),
NDraft: c.NDraft,
TopK: int32(c.TopK),
Tokens: int32(c.Maxtokens),
Threads: int32(c.Threads),
PromptCacheAll: c.PromptCacheAll,
PromptCacheRO: c.PromptCacheRO,
PromptCachePath: promptCachePath,
F16KV: c.F16,
DebugMode: c.Debug,
Grammar: c.Grammar,
NegativePromptScale: c.NegativePromptScale,
RopeFreqBase: c.RopeFreqBase,
RopeFreqScale: c.RopeFreqScale,
NegativePrompt: c.NegativePrompt,
Mirostat: int32(c.LLMConfig.Mirostat),
MirostatETA: float32(c.LLMConfig.MirostatETA),
MirostatTAU: float32(c.LLMConfig.MirostatTAU),
Debug: c.Debug,
StopPrompts: c.StopWords,
Repeat: int32(c.RepeatPenalty),
NKeep: int32(c.Keep),
Batch: int32(c.Batch),
IgnoreEOS: c.IgnoreEOS,
Seed: int32(c.Seed),
FrequencyPenalty: float32(c.FrequencyPenalty),
MLock: c.MMlock,
MMap: c.MMap,
MainGPU: c.MainGPU,
TensorSplit: c.TensorSplit,
TailFreeSamplingZ: float32(c.TFZ),
TypicalP: float32(c.TypicalP),
}
}


@@ -1,39 +0,0 @@
package backend
import (
"context"
"fmt"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) {
opts := modelOpts(c, o, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(c.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
})
whisperModel, err := o.Loader.BackendLoader(opts...)
if err != nil {
return nil, err
}
if whisperModel == nil {
return nil, fmt.Errorf("could not load whisper model")
}
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(c.Threads),
})
}


@@ -1,84 +0,0 @@
package backend
import (
"context"
"fmt"
"os"
"path/filepath"
api_config "github.com/go-skynet/LocalAI/api/config"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}
counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option, c config.Config) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
}
grpcOpts := gRPCModelOpts(c)
opts := modelOpts(api_config.Config{}, o, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
piperModel, err := o.Loader.BackendLoader(opts...)
if err != nil {
return "", nil, err
}
if piperModel == nil {
return "", nil, fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
filePath := filepath.Join(o.AudioDir, fileName)
// If the model file is not empty, we pass it joined with the model path
modelPath := ""
if modelFile != "" {
if bb != model.TransformersMusicGen {
modelPath = filepath.Join(o.Loader.ModelPath, modelFile)
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
return "", nil, err
}
} else {
modelPath = modelFile
}
}
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
Text: text,
Model: modelPath,
Dst: filePath,
})
return filePath, res, err
}


@@ -1,162 +0,0 @@
package localai
import (
"context"
"fmt"
"strings"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/api/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}
type BackendMonitor struct {
configLoader *config.ConfigLoader
options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor {
return BackendMonitor{
configLoader: configLoader,
options: options,
}
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetConfig(model)
var backend string
if exists {
backend = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backend = model
}
if !strings.HasSuffix(backend, ".bin") {
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.options.Loader.GetGRPCPID(backend)
if err != nil {
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
return nil, err
}
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
backendProcess, err := gopsutil.NewProcess(int32(pid))
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
return nil, err
}
memInfo, err := backendProcess.MemoryInfo()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
return nil, err
}
memPercent, err := backendProcess.MemoryPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
return nil, err
}
cpuPercent, err := backendProcess.CPUPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
return nil, err
}
return &BackendMonitorResponse{
MemoryInfo: memInfo,
MemoryPercent: memPercent,
CPUPercent: cpuPercent,
}, nil
}
func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) {
input := new(BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", err
}
config, exists := bm.configLoader.GetConfig(input.Model)
var backendId string
if exists {
backendId = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backendId = input.Model
}
if !strings.HasSuffix(backendId, ".bin") {
backendId = fmt.Sprintf("%s.bin", backendId)
}
return backendId, nil
}
func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendId, err := bm.getModelLoaderIDFromCtx(c)
if err != nil {
return err
}
model := bm.options.Loader.CheckIsLoaded(backendId)
if model == "" {
return fmt.Errorf("backend %s is not currently loaded", backendId)
}
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
}
return c.JSON(proto.StatusResponse{
State: proto.StatusResponse_ERROR,
Memory: &proto.MemoryUsageData{
Total: val.MemoryInfo.VMS,
Breakdown: map[string]uint64{
"gopsutil-RSS": val.MemoryInfo.RSS,
},
},
})
}
return c.JSON(status)
}
}
func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendId, err := bm.getModelLoaderIDFromCtx(c)
if err != nil {
return err
}
return bm.options.Loader.ShutdownModel(backendId)
}
}
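A hedged sketch of exercising the experimental monitor endpoint above: the route is registered as a GET yet the handler parses a JSON body for the model name, so the request is built explicitly; the model name and base URL are assumptions.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Hypothetical model name; it must match a backend that is currently loaded.
	payload := bytes.NewBufferString(`{"model": "my-model"}`)
	req, err := http.NewRequest(http.MethodGet, "http://localhost:8080/backend/monitor", payload)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	data, _ := io.ReadAll(resp.Body)
	// Either the backend's gRPC status or, on RPC failure, a gopsutil-based memory sample.
	fmt.Println(string(data))
}
```

A POST to /backend/shutdown with the same body asks the loader to stop that backend.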


@@ -1,326 +0,0 @@
package localai
import (
"context"
"fmt"
"os"
"slices"
"strings"
"sync"
json "github.com/json-iterator/go"
"gopkg.in/yaml.v3"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
type galleryOp struct {
req gallery.GalleryModel
id string
galleries []gallery.Gallery
galleryName string
}
type galleryOpStatus struct {
FileName string `json:"file_name"`
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
Progress float64 `json:"progress"`
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
}
type galleryApplier struct {
modelPath string
sync.Mutex
C chan galleryOp
statuses map[string]*galleryOpStatus
}
func NewGalleryService(modelPath string) *galleryApplier {
return &galleryApplier{
modelPath: modelPath,
C: make(chan galleryOp),
statuses: make(map[string]*galleryOpStatus),
}
}
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetGalleryConfigFromURL(req.URL)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses
}
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.C:
utils.ResetDownloadTimers()
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error
updateError := func(e error) {
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
// progressCallback updates the operation status and displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.galleryName != "" {
if strings.Contains(op.galleryName, "@") {
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
}
} else {
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
}
if err != nil {
updateError(err)
continue
}
// Reload models
err = cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
err = cm.Preload(g.modelPath)
if err != nil {
updateError(err)
continue
}
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
}
}
}()
}
type galleryModel struct {
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
ID string `json:"id"`
}
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else {
if strings.Contains(r.ID, "@") {
err = gallery.InstallModelFromGallery(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} else {
err = gallery.InstallModelFromGalleryByName(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
}
}
}
return err
}
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
dat, err := os.ReadFile(s)
if err != nil {
return err
}
var requests []galleryModel
if err := yaml.Unmarshal(dat, &requests); err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
/// Endpoint Service
type ModelGalleryService struct {
galleries []gallery.Gallery
modelPath string
galleryApplier *galleryApplier
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
return ModelGalleryService{
galleries: galleries,
modelPath: modelPath,
galleryApplier: galleryApplier,
}
}
func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.getStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.getAllStatus())
}
}
func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.C <- galleryOp{
req: input.GalleryModel,
id: uuid.String(),
galleryName: input.ID,
galleries: mgs.galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
if err != nil {
return err
}
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s already exists", input.Name)
}
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
log.Debug().Msgf("Adding %+v to gallery list", *input)
mgs.galleries = append(mgs.galleries, *input)
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s is not currently registered", input.Name)
}
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
})
return c.Send(nil)
}
}
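One plausible client flow for the gallery service above is to POST an apply request and then poll the returned status URL until the job reports itself processed; the gallery id used here is purely illustrative.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

func main() {
	// Illustrative gallery id in the "<gallery>@<model>" form handled by the applier.
	body := []byte(`{"id": "model-gallery@bert-embeddings"}`)
	resp, err := http.Post("http://localhost:8080/models/apply", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	var job struct {
		ID        string `json:"uuid"`
		StatusURL string `json:"status"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&job); err != nil {
		panic(err)
	}
	// Poll /models/jobs/<uuid> until the download and install complete.
	for {
		st, err := http.Get(job.StatusURL)
		if err != nil {
			panic(err)
		}
		var status map[string]interface{}
		if err := json.NewDecoder(st.Body).Decode(&status); err != nil {
			st.Body.Close()
			panic(err)
		}
		st.Body.Close()
		fmt.Printf("progress: %v\n", status["progress"])
		if processed, _ := status["processed"].(bool); processed {
			break
		}
		time.Sleep(2 * time.Second)
	}
}
```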


@@ -1,53 +0,0 @@
package localai
import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
fiberContext "github.com/go-skynet/LocalAI/api/ctx"
"github.com/rs/zerolog/log"
"github.com/go-skynet/LocalAI/api/options"
"github.com/gofiber/fiber/v2"
)
type TTSRequest struct {
Model string `json:"model" yaml:"model"`
Input string `json:"input" yaml:"input"`
Backend string `json:"backend" yaml:"backend"`
}
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := config.Load(modelFile, o.Loader.ModelPath, cm, false, 0, 0, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
}
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, o.Loader, o, *cfg)
if err != nil {
return err
}
return c.Download(filePath)
}
}
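A minimal sketch of driving the /tts route above from Go; the voice model file, backend name, and base URL are assumptions about a local setup with a piper voice installed.

```go
package main

import (
	"bytes"
	"io"
	"net/http"
	"os"
)

func main() {
	// Placeholder voice model; any model the piper backend can load should work.
	body := []byte(`{"model": "en-us-amy-low.onnx", "backend": "piper", "input": "Hello from LocalAI"}`)
	resp, err := http.Post("http://localhost:8080/tts", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, err := os.Create("output.wav")
	if err != nil {
		panic(err)
	}
	defer out.Close()
	// The handler responds with a download of the generated WAV file.
	if _, err := io.Copy(out, resp.Body); err != nil {
		panic(err)
	}
}
```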


@@ -1,399 +0,0 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
emptyMessage := ""
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
responses <- resp
return true
})
close(responses)
}
return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
modelFile, input, err := readRequest(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
// Allow the user to set custom actions via config file
// to be "embedded" in each model
noActionName := "answer"
noActionDescription := "use this action to answer without performing any action"
if config.FunctionsConfig.NoActionFunctionName != "" {
noActionName = config.FunctionsConfig.NoActionFunctionName
}
if config.FunctionsConfig.NoActionDescriptionName != "" {
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
// process functions if we have any defined or if we have a function call string
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
log.Debug().Msgf("Response needs to process functions")
processFunctions = true
noActionGrammar := grammar.Function{
Name: noActionName,
Description: noActionDescription,
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "The message to reply the user with",
}},
},
}
// Append the no action function
funcs = append(funcs, input.Functions...)
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("")
} else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
}
// functions are not supported in stream mode (yet?)
toStream := input.Stream && !processFunctions
log.Debug().Msgf("Parameters: %+v", config)
var predInput string
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// if this is a function call, we might want to customize the role so we can better convey that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role passed in the request
if i.FunctionCall != nil && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
MessageIndex: messageIndex,
}
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check separately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
if toStream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
if toStream {
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, o.Loader, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
break
}
w.Flush()
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: "stop",
Index: 0,
Delta: &schema.Message{Content: &emptyMessage},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
}
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
if processFunctions {
// As we have to change the result before processing, we can't stream the answer (yet?)
ss := map[string]interface{}{}
// This prevents newlines from breaking JSON parsing for clients
s = utils.EscapeNewLines(s)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name := ss["function"]
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually wants a stringified object
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
d, _ := json.Marshal(args)
ss["arguments"] = string(d)
ss["name"] = func_name
// if the selected function is the no-action one, reply with a plain message
if func_name == noActionName {
log.Debug().Msgf("nothing to do, computing a reply")
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
json.Unmarshal([]byte(d), &arguments)
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
case string:
if message != "" {
log.Debug().Msgf("Reply received from LLM: %s", message)
message = backend.Finetune(*config, predInput, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
return
}
}
}
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in terms of CPU) another computation
config.Grammar = ""
images := []string{}
for _, m := range input.Messages {
images = append(images, m.StringImages...)
}
predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil)
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
}
prediction, err := predFunc()
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
}
fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response)
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
} else {
// otherwise reply with the function call
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &schema.Message{Role: "assistant", FunctionCall: ss},
})
}
return
}
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
}, nil)
if err != nil {
return err
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
respData, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", respData)
// Return the prediction in the response body
return c.JSON(resp)
}
}
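When "stream": true is requested, the handler above emits server-sent events, one "data: " line per chunk and a final "data: [DONE]" sentinel, so a consumer can be sketched roughly as follows; the URL and model name are assumptions.

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// Hypothetical local instance and model name.
	body := []byte(`{"model": "my-model", "stream": true, "messages": [{"role": "user", "content": "Tell me a joke"}]}`)
	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			break
		}
		fmt.Println(payload) // each payload is a chat.completion.chunk JSON object
	}
}
```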


@@ -126,6 +126,11 @@ message ModelOptions {
// vllm
string Quantization = 40;
float GPUMemoryUtilization = 50;
bool TrustRemoteCode = 51;
bool EnforceEager = 52;
int32 SwapSpace = 53;
int32 MaxModelLen = 54;
string MMProj = 41;
@@ -186,6 +191,7 @@ message TTSRequest {
string text = 1;
string model = 2;
string dst = 3;
string voice = 4;
}
message TokenizationResponse {


@@ -2,16 +2,20 @@
## XXX: In some versions of CMake clip wasn't being built before llama.
## This is a hack for now, but it should be fixed in the future.
set(TARGET myclip)
add_library(${TARGET} clip.cpp clip.h)
add_library(${TARGET} clip.cpp clip.h llava.cpp llava.h)
install(TARGETS ${TARGET} LIBRARY)
target_link_libraries(${TARGET} PRIVATE common ggml ${CMAKE_THREAD_LIBS_INIT})
target_include_directories(myclip PUBLIC .)
target_include_directories(myclip PUBLIC ../..)
target_include_directories(myclip PUBLIC ../../common)
target_link_libraries(${TARGET} PRIVATE common ggml llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if (NOT MSVC)
target_compile_options(${TARGET} PRIVATE -Wno-cast-qual) # stb_image.h
endif()
# END CLIP hack
set(TARGET grpc-server)
# END CLIP hack
set(CMAKE_CXX_STANDARD 17)
cmake_minimum_required(VERSION 3.15)
set(TARGET grpc-server)


@@ -12,12 +12,13 @@ ifeq ($(BUILD_TYPE),cublas)
# to CMAKE_ARGS automatically
else ifeq ($(BUILD_TYPE),openblas)
CMAKE_ARGS+=-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS
# If build type is clblast (openCL) we set -DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
else ifeq ($(BUILD_TYPE),clblast)
# If build type is clblas (openCL) we set -DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
else ifeq ($(BUILD_TYPE),clblas)
CMAKE_ARGS+=-DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
else ifeq ($(BUILD_TYPE),hipblas)
CMAKE_ARGS+=-DLLAMA_HIPBLAS=ON
# If it's OSX, DO NOT embed the metal library - -DLLAMA_METAL_EMBED_LIBRARY=ON requires further investigation
endif
ifeq ($(BUILD_TYPE),sycl_f16)
@@ -45,6 +46,9 @@ llama.cpp/examples/grpc-server:
## XXX: In some versions of CMake clip wasn't being built before llama.
## This is a hack for now, but it should be fixed in the future.
cp -rfv llama.cpp/examples/llava/clip.h llama.cpp/examples/grpc-server/clip.h
cp -rfv llama.cpp/examples/llava/llava.cpp llama.cpp/examples/grpc-server/llava.cpp
echo '#include "llama.h"' > llama.cpp/examples/grpc-server/llava.h
cat llama.cpp/examples/llava/llava.h >> llama.cpp/examples/grpc-server/llava.h
cp -rfv llama.cpp/examples/llava/clip.cpp llama.cpp/examples/grpc-server/clip.cpp
rebuild:


@@ -11,7 +11,8 @@
#include <memory>
#include <string>
#include <getopt.h>
#include "../llava/clip.h"
#include "clip.h"
#include "llava.h"
#include "stb_image.h"
#include "common.h"
#include "json.hpp"
@@ -32,6 +33,7 @@
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <atomic>
#include <signal.h>
using grpc::Server;
using grpc::ServerBuilder;
@@ -51,12 +53,16 @@ struct server_params
std::string hostname = "127.0.0.1";
std::vector<std::string> api_keys;
std::string public_path = "examples/server/public";
std::string chat_template = "";
int32_t port = 8080;
int32_t read_timeout = 600;
int32_t write_timeout = 600;
bool slots_endpoint = true;
bool metrics_endpoint = false;
};
bool server_verbose = false;
bool server_log_json = true;
static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
{
@@ -172,6 +178,7 @@ struct llama_client_slot
int32_t n_decoded = 0;
int32_t n_remaining = -1;
int32_t i_batch = -1;
int32_t n_predict = -1;
int32_t num_prompt_tokens = 0;
int32_t num_prompt_tokens_processed = 0;
@@ -311,12 +318,76 @@ struct llama_client_slot
}
void print_timings() const {
LOG_TEE("\n");
LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing / num_prompt_tokens_processed, 1e3 / t_prompt_processing * num_prompt_tokens_processed);
LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, t_token_generation, n_decoded,t_token_generation / n_decoded, 1e3 / t_token_generation * n_decoded);
LOG_TEE("%s: total time = %10.2f ms\n", __func__, t_prompt_processing + t_token_generation);
char buffer[512];
double t_token = t_prompt_processing / num_prompt_tokens_processed;
double n_tokens_second = 1e3 / t_prompt_processing * num_prompt_tokens_processed;
sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
t_prompt_processing, num_prompt_tokens_processed,
t_token, n_tokens_second);
LOG_INFO(buffer, {
{"slot_id", id},
{"task_id", task_id},
{"t_prompt_processing", t_prompt_processing},
{"num_prompt_tokens_processed", num_prompt_tokens_processed},
{"t_token", t_token},
{"n_tokens_second", n_tokens_second},
});
t_token = t_token_generation / n_decoded;
n_tokens_second = 1e3 / t_token_generation * n_decoded;
sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
t_token_generation, n_decoded,
t_token, n_tokens_second);
LOG_INFO(buffer, {
{"slot_id", id},
{"task_id", task_id},
{"t_token_generation", t_token_generation},
{"n_decoded", n_decoded},
{"t_token", t_token},
{"n_tokens_second", n_tokens_second},
});
sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
LOG_INFO(buffer, {
{"slot_id", id},
{"task_id", task_id},
{"t_prompt_processing", t_prompt_processing},
{"t_token_generation", t_token_generation},
{"t_total", t_prompt_processing + t_token_generation},
});
}
};
struct llama_metrics {
uint64_t n_prompt_tokens_processed_total = 0;
uint64_t n_tokens_predicted_total = 0;
uint64_t n_prompt_tokens_processed = 0;
uint64_t t_prompt_processing = 0;
uint64_t n_tokens_predicted = 0;
uint64_t t_tokens_generation = 0;
void on_prompt_eval(const llama_client_slot &slot) {
n_prompt_tokens_processed_total += slot.num_prompt_tokens_processed;
n_prompt_tokens_processed += slot.num_prompt_tokens_processed;
t_prompt_processing += slot.t_prompt_processing;
}
void on_prediction(const llama_client_slot &slot) {
n_tokens_predicted_total += slot.n_decoded;
n_tokens_predicted += slot.n_decoded;
t_tokens_generation += slot.t_token_generation;
}
void reset_bucket() {
n_prompt_tokens_processed = 0;
t_prompt_processing = 0;
n_tokens_predicted = 0;
t_tokens_generation = 0;
}
};
@@ -349,10 +420,13 @@ struct llama_server_context
// slots / clients
std::vector<llama_client_slot> slots;
json default_generation_settings_for_props;
llama_server_queue queue_tasks;
llama_server_response queue_results;
llama_metrics metrics;
~llama_server_context()
{
if (ctx)
@@ -372,7 +446,7 @@ struct llama_server_context
params = params_;
if (!params.mmproj.empty()) {
multimodal = true;
LOG_TEE("Multi Modal Mode Enabled");
LOG_INFO("Multi Modal Mode Enabled", {});
clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
if(clp_ctx == nullptr) {
LOG_ERROR("unable to load clip model", {{"model", params.mmproj}});
@@ -409,21 +483,35 @@ struct llama_server_context
return true;
}
void validate_model_chat_template(server_params & sparams) {
llama_chat_message chat[] = {{"user", "test"}};
std::vector<char> buf(1);
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
if (res < 0) {
LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
}
}
void initialize() {
// create slots
all_slots_are_idle = true;
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
LOG_TEE("Available slots:\n");
LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
for (int i = 0; i < params.n_parallel; i++)
{
llama_client_slot slot;
slot.id = i;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
LOG_INFO("new slot", {
{"slot_id", slot.id},
{"n_ctx_slot", slot.n_ctx}
});
const int ga_n = params.grp_attn_n;
const int ga_w = params.grp_attn_w;
@@ -433,7 +521,12 @@ struct llama_server_context
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
LOG_INFO("slot self-extend", {
{"slot_id", slot.id},
{"ga_n", ga_n},
{"ga_w", ga_w}
});
}
slot.ga_i = 0;
@@ -445,11 +538,10 @@ struct llama_server_context
slots.push_back(slot);
}
batch = llama_batch_init(n_ctx, 0, params.n_parallel);
default_generation_settings_for_props = get_formated_generation(slots.front());
default_generation_settings_for_props["seed"] = -1;
// empty system prompt
system_prompt = "";
system_tokens.clear();
batch = llama_batch_init(n_ctx, 0, params.n_parallel);
}
std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
@@ -526,28 +618,40 @@ struct llama_server_context
bool launch_slot_with_data(llama_client_slot* &slot, json data) {
slot_params default_params;
llama_sampling_params default_sparams;
slot->params.stream = json_value(data, "stream", false);
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot->sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot->sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
slot->params.seed = json_value(data, "seed", default_params.seed);
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
slot->params.stream = json_value(data, "stream", false);
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
slot->params.seed = json_value(data, "seed", default_params.seed);
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
// Might be better to reject the request with a 400 ?
LOG_WARNING("Max tokens to predict exceeds server configuration", {
{"params.n_predict", slot->params.n_predict},
{"slot.n_predict", slot->n_predict},
});
slot->params.n_predict = slot->n_predict;
}
// infill
if (data.count("input_prefix") != 0)
@@ -626,18 +730,36 @@ struct llama_server_context
const int n_vocab = llama_n_vocab(model);
for (const auto &el : *logit_bias)
{
if (el.is_array() && el.size() == 2 && el[0].is_number_integer())
if (el.is_array() && el.size() == 2)
{
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab)
float bias;
if (el[1].is_number())
{
if (el[1].is_number())
bias = el[1].get<float>();
}
else if (el[1].is_boolean() && !el[1].get<bool>())
{
bias = -INFINITY;
}
else
{
continue;
}
if (el[0].is_number_integer())
{
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab)
{
slot->sparams.logit_bias[tok] = el[1].get<float>();
slot->sparams.logit_bias[tok] = bias;
}
else if (el[1].is_boolean() && !el[1].get<bool>())
}
else if (el[0].is_string())
{
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks)
{
slot->sparams.logit_bias[tok] = -INFINITY;
slot->sparams.logit_bias[tok] = bias;
}
}
}
@@ -658,6 +780,24 @@ struct llama_server_context
}
}
const auto &samplers_sequence = data.find("samplers");
if (samplers_sequence != data.end() && samplers_sequence->is_array())
{
std::vector<std::string> sampler_names;
for (const auto &sampler_name : *samplers_sequence)
{
if (sampler_name.is_string())
{
sampler_names.emplace_back(sampler_name);
}
}
slot->sparams.samplers_sequence = sampler_types_from_names(sampler_names, false);
}
else
{
slot->sparams.samplers_sequence = default_sparams.samplers_sequence;
}
if (multimodal)
{
const auto &images_data = data.find("image_data");
@@ -672,10 +812,16 @@ struct llama_server_context
img_sl.img_data = clip_image_u8_init();
if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data))
{
LOG_TEE("slot %i - failed to load image [id: %i]\n", slot->id, img_sl.id);
LOG_ERROR("failed to load image", {
{"slot_id", slot->id},
{"img_sl_id", img_sl.id}
});
return false;
}
LOG_TEE("slot %i - loaded image\n", slot->id);
LOG_VERBOSE("image loaded", {
{"slot_id", slot->id},
{"img_sl_id", img_sl.id}
});
img_sl.request_encode_image = true;
slot->images.push_back(img_sl);
}
@@ -735,7 +881,10 @@ struct llama_server_context
all_slots_are_idle = false;
LOG_TEE("slot %i is processing [task id: %i]\n", slot->id, slot->task_id);
LOG_INFO("slot is processing task", {
{"slot_id", slot->id},
{"task_id", slot->task_id},
});
return true;
}
@@ -747,27 +896,44 @@ struct llama_server_context
}
void update_system_prompt() {
system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
llama_batch_clear(batch);
kv_cache_clear();
system_tokens.clear();
for (int i = 0; i < (int) system_tokens.size(); ++i)
{
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
}
if (!system_prompt.empty()) {
system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
if (llama_decode(ctx, batch) != 0)
{
LOG_TEE("%s: llama_decode() failed\n", __func__);
return;
}
llama_batch_clear(batch);
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < params.n_parallel; ++i)
{
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
for (int i = 0; i < (int)system_tokens.size(); ++i)
{
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
}
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
{
const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
if (llama_decode(ctx, batch_view) != 0)
{
LOG_TEE("%s: llama_decode() failed\n", __func__);
return;
}
}
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < params.n_parallel; ++i)
{
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
}
}
LOG_TEE("system prompt updated\n");
@@ -789,10 +955,8 @@ struct llama_server_context
name_user = sys_props.value("anti_prompt", "");
name_assistant = sys_props.value("assistant_name", "");
if (slots.size() > 0)
{
notify_system_prompt_changed();
}
notify_system_prompt_changed();
}
static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
@@ -920,7 +1084,7 @@ struct llama_server_context
slot.has_next_token = false;
}
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
if (result.tok == llama_token_eos(model))
{
slot.stopped_eos = true;
slot.has_next_token = false;
@@ -950,28 +1114,12 @@ struct llama_server_context
{
continue;
}
clip_image_f32 * img_res = clip_image_f32_init();
if (!clip_image_preprocess(clp_ctx, img.img_data, img_res, /*pad2square =*/ true))
{
if (!llava_image_embed_make_with_clip_img(clp_ctx, params.n_threads, img.img_data, &img.image_embedding, &img.image_tokens)) {
LOG_TEE("Error processing the given image");
clip_free(clp_ctx);
return false;
}
img.image_tokens = clip_n_patches(clp_ctx);
img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx));
if (!img.image_embedding)
{
LOG_TEE("Unable to allocate memory for image embeddings\n");
clip_free(clp_ctx);
return false;
}
LOG_TEE("slot %i - encoding image [id: %i]\n", slot.id, img.id);
if (!clip_image_encode(clp_ctx, params.n_threads, img_res, img.image_embedding))
{
LOG_TEE("Unable to encode image\n");
return false;
}
clip_image_f32_free(img_res);
img.request_encode_image = false;
}
@@ -990,21 +1138,25 @@ struct llama_server_context
queue_results.send(res);
}
json get_model_props()
{
return get_formated_generation(slots[0]);
}
json get_formated_generation(llama_client_slot &slot)
{
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
std::vector<std::string> samplers_sequence;
for (const auto &sampler_type : slot.sparams.samplers_sequence)
{
samplers_sequence.emplace_back(sampler_type_to_name_string(sampler_type));
}
return json {
{"n_ctx", slot.n_ctx},
{"n_predict", slot.n_predict},
{"model", params.model_alias},
{"seed", slot.params.seed},
{"temperature", slot.sparams.temp},
{"dynatemp_range", slot.sparams.dynatemp_range},
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
{"top_k", slot.sparams.top_k},
{"top_p", slot.sparams.top_p},
{"min_p", slot.sparams.min_p},
@@ -1027,7 +1179,9 @@ struct llama_server_context
{"stream", slot.params.stream},
{"logit_bias", slot.sparams.logit_bias},
{"n_probs", slot.sparams.n_probs},
{"min_keep", slot.sparams.min_keep},
{"grammar", slot.sparams.grammar},
{"samplers", samplers_sequence}
};
}
@@ -1166,13 +1320,30 @@ struct llama_server_context
task.multitask_id = multitask_id;
// when a completion task's prompt array is not a singleton, we split it into multiple requests
if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
{
split_multiprompt_task(task_id, task);
}
// otherwise, it's a single-prompt task, we actually queue it
queue_tasks.post(task);
// if there's numbers in the prompt array it will be treated as an array of tokens
if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
bool numbers = false;
for (const auto& e : task.data.at("prompt")) {
if (e.is_number()) {
numbers = true;
break;
}
}
// NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
// it will completely stall the server. I don't know where the bug for this is.
//
// if there are numbers, it needs to be treated like a single prompt,
// queue_tasks handles a mix of strings and numbers just fine.
if (numbers) {
queue_tasks.post(task);
} else {
split_multiprompt_task(task_id, task);
}
} else {
queue_tasks.post(task);
}
}
// for multiple images processing
@@ -1254,7 +1425,10 @@ struct llama_server_context
void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
{
int prompt_count = multiprompt_task.data.at("prompt").size();
assert(prompt_count > 1);
if (prompt_count <= 1) {
send_error(multiprompt_task, "error while handling multiple prompts");
return;
}
// generate all the ID for subtask
std::vector<int> subtask_ids(prompt_count);
@@ -1286,7 +1460,7 @@ struct llama_server_context
if (slot == nullptr)
{
// if no slot is available, we defer this task for processing later
LOG_VERBOSE("no slot is available", {});
LOG_VERBOSE("no slot is available", {{"task_id", task.id}});
queue_tasks.defer(task);
break;
}
@@ -1360,7 +1534,7 @@ struct llama_server_context
bool update_slots() {
if (system_need_update)
{
LOG_TEE("updating system prompt\n");
LOG_INFO("updating system prompt", {});
update_system_prompt();
}
@@ -1370,12 +1544,13 @@ struct llama_server_context
{
if (system_prompt.empty() && clean_kv_cache)
{
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});
kv_cache_clear();
}
return true;
}
LOG_VERBOSE("posting NEXT_RESPONSE", {});
task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
@@ -1406,6 +1581,7 @@ struct llama_server_context
}
// decode any currently ongoing sequences
LOG_VERBOSE("decoding ongoing sequences", {});
for (auto & slot : slots)
{
// release the slot
@@ -1415,7 +1591,15 @@ struct llama_server_context
slot.command = NONE;
slot.t_last_used = ggml_time_us();
LOG_TEE("slot %d released (%d tokens in cache)\n", slot.id, (int) slot.cache_tokens.size());
LOG_INFO("slot released", {
{"slot_id", slot.id},
{"task_id", slot.task_id},
{"n_ctx", n_ctx},
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"truncated", slot.truncated}
});
queue_tasks.notify_slot_changed();
continue;
@@ -1542,6 +1726,14 @@ struct llama_server_context
}
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
// the last token of the cache is not in the KV cache until the next call to llama_decode
// (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size())
{
slot.n_past -= 1;
}
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
if (slot.ga_n != 1)
@@ -1563,19 +1755,23 @@ struct llama_server_context
slot.ga_i = ga_i;
}
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
LOG_INFO("slot progression", {
{ "slot_id", slot.id },
{ "task_id", slot.task_id },
{ "n_past", slot.n_past },
{ "num_prompt_tokens_processed", slot.num_prompt_tokens_processed }
});
}
LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
slot.cache_tokens = prompt_tokens;
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
{
// we have to evaluate at least 1 token to generate logits.
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
{ "slot_id", slot.id },
{ "task_id", slot.task_id }
});
slot.n_past--;
if (slot.ga_i > 0)
{
@@ -1583,6 +1779,14 @@ struct llama_server_context
}
}
int p0 = (int) system_tokens.size() + slot.n_past;
LOG_INFO("kv cache rm [p0, end)", {
{ "slot_id", slot.id },
{ "task_id", slot.task_id },
{ "p0", p0 }
});
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
@@ -1616,7 +1820,13 @@ struct llama_server_context
if (has_images && !ingest_images(slot, n_batch))
{
LOG_TEE("failed processing images\n");
LOG_ERROR("failed processing images", {
{"slot_id", slot.id},
{"task_id", slot.task_id},
});
// FIXME @phymbert: to be properly tested
// early returning without changing the slot state will block the slot forever
// no one at the moment is checking the return value
return false;
}
@@ -1658,9 +1868,9 @@ struct llama_server_context
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
slot.n_past_se -= bd;
@@ -1716,7 +1926,7 @@ struct llama_server_context
send_embedding(slot);
slot.release();
slot.i_batch = -1;
return true;
continue;
}
completion_token_output result;
@@ -1729,6 +1939,7 @@ struct llama_server_context
{
slot.t_start_genereration = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
@@ -1751,11 +1962,14 @@ struct llama_server_context
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
}
slot.i_batch = -1;
}
}
LOG_VERBOSE("slots updated", {});
return true;
}
@@ -1784,18 +1998,6 @@ static json format_partial_response(
return res;
}
static json format_tokenizer_response(const std::vector<llama_token> &tokens)
{
return json{
{"tokens", tokens}};
}
static json format_detokenized_response(std::string content)
{
return json{
{"content", content}};
}
struct token_translator
{
llama_context * ctx;
@@ -1819,6 +2021,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
}
}
std::function<void(int)> shutdown_handler;
inline void signal_handler(int signal) { shutdown_handler(signal); }
/////////////////////////////////
////////////////////////////////
//////// LOCALAI code starts below here
@@ -2051,9 +2256,9 @@ static void params_parse(const backend::ModelOptions* request,
params.use_mmap = request->mmap();
params.embedding = request->embeddings();
if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
else { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
else { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
if ( request->yarnextfactor() != 0.0f ) {
params.yarn_ext_factor = request->yarnextfactor();
}
@@ -2089,7 +2294,8 @@ public:
gpt_params params;
params_parse(request, params);
llama_backend_init(params.numa);
llama_backend_init();
llama_numa_init(params.numa);
// load the model
if (!llama.load_model(params))
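The n_past adjustments in the hunks above are the heart of the eos-without-cache fix: when the whole prompt is already present in cache_tokens, the last cached token was sampled but never run through llama_decode, so at least one token must be re-evaluated to obtain fresh logits. A minimal Python sketch of that reuse logic, assuming common_part() simply measures the shared prefix (the helper below is illustrative, not the C++ code):

# Hypothetical sketch of the prompt-cache reuse logic shown in the diff above.
def plan_prompt_reuse(cache_tokens: list[int], prompt_tokens: list[int]) -> tuple[int, int]:
    """Return (n_past, n_to_process): tokens reusable from cache vs. tokens to evaluate."""
    # common_part(): length of the shared prefix between the cached and the new prompt
    n_past = 0
    for a, b in zip(cache_tokens, prompt_tokens):
        if a != b:
            break
        n_past += 1

    # The last cached token was sampled but never decoded, so it is not in the KV cache yet.
    if n_past > 0 and n_past == len(cache_tokens):
        n_past -= 1

    # We always need to evaluate at least one token to obtain logits.
    if n_past == len(prompt_tokens) and n_past > 0:
        n_past -= 1

    return n_past, len(prompt_tokens) - n_past

# Example: the whole prompt is already cached; one token is still re-evaluated.
assert plan_prompt_reuse([1, 2, 3], [1, 2, 3]) == (2, 1)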

View File

@@ -8,24 +8,24 @@ import (
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/wav"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/core/schema"
)
func sh(c string) (string, error) {
cmd := exec.Command("/bin/sh", "-c", c)
func runCommand(command []string) (string, error) {
cmd := exec.Command(command[0], command[1:]...)
cmd.Env = os.Environ()
o, err := cmd.CombinedOutput()
return string(o), err
out, err := cmd.CombinedOutput()
return string(out), err
}
// AudioToWav converts audio to wav for transcribe. It bashes out to ffmpeg
// audioToWav converts audio to WAV for transcription.
// TODO: use https://github.com/mccoyst/ogg?
func audioToWav(src, dst string) error {
out, err := sh(fmt.Sprintf("ffmpeg -i %s -format s16le -ar 16000 -ac 1 -acodec pcm_s16le %s", src, dst))
command := []string{"ffmpeg", "-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst}
out, err := runCommand(command)
if err != nil {
return fmt.Errorf("error: %w out: %s", err, out)
}
return nil
}
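Replacing the /bin/sh -c string with an argv slice means the ffmpeg arguments are passed verbatim, so paths containing spaces or shell metacharacters need no quoting and can never be interpreted by a shell. A rough Python analogue of the same design choice, purely for illustration:

import subprocess

def audio_to_wav(src: str, dst: str) -> None:
    # Arguments are passed as a list, so the shell never interprets src/dst.
    cmd = ["ffmpeg", "-i", src, "-format", "s16le",
           "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {result.stdout} {result.stderr}")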

View File

@@ -4,7 +4,7 @@ package main
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

View File

@@ -33,7 +33,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
model = AutoGPTQForCausalLM.from_quantized(request.Model,
model_basename=request.ModelBaseName,
use_safetensors=True,
trust_remote_code=True,
trust_remote_code=request.TrustRemoteCode,
device=device,
use_triton=request.UseTriton,
quantize_config=None)

View File

@@ -71,7 +71,7 @@ dependencies:
- regex==2023.10.3
- requests==2.31.0
- rouge==1.0.1
- safetensors==0.3.3
- safetensors>=0.3.3
- six==1.16.0
- sympy==1.12
- tokenizers==0.14.0

View File

File diff suppressed because one or more lines are too long

View File

File diff suppressed because one or more lines are too long

View File

@@ -4,6 +4,17 @@ ifeq ($(BUILD_TYPE), cublas)
CONDA_ENV_PATH = "transformers-nvidia.yml"
endif
ifeq ($(BUILD_TYPE), hipblas)
CONDA_ENV_PATH = "transformers-rocm.yml"
endif
# Intel GPUs are expected to have their dependencies installed in the main Python
# environment, so we skip the conda installation for SYCL builds.
# https://github.com/intel/intel-extension-for-pytorch/issues/538
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
export SKIP_CONDA=1
endif
.PHONY: transformers
transformers:
@echo "Installing $(CONDA_ENV_PATH)..."

View File

@@ -1,24 +1,38 @@
#!/bin/bash
set -ex
SKIP_CONDA=${SKIP_CONDA:-0}
# Check if the environment exists
conda_env_exists(){
! conda list --name "${@}" >/dev/null 2>/dev/null
}
if conda_env_exists "transformers" ; then
echo "Creating virtual environment..."
conda env create --name transformers --file $1
echo "Virtual environment created."
else
echo "Virtual environment already exists."
if [ $SKIP_CONDA -eq 1 ]; then
echo "Skipping conda environment installation"
else
export PATH=$PATH:/opt/conda/bin
if conda_env_exists "transformers" ; then
echo "Creating virtual environment..."
conda env create --name transformers --file $1
echo "Virtual environment created."
else
echo "Virtual environment already exists."
fi
fi
if [ -d "/opt/intel" ]; then
# Intel GPU: if the directory exists, we assume we are using the Intel image
# (no conda env)
# https://github.com/intel/intel-extension-for-pytorch/issues/538
pip install intel-extension-for-transformers datasets sentencepiece tiktoken neural_speed
fi
if [ "$PIP_CACHE_PURGE" = true ] ; then
export PATH=$PATH:/opt/conda/bin
# Activate conda environment
source activate transformers
if [ $SKIP_CONDA -eq 0 ]; then
# Activate conda environment
source activate transformers
fi
pip cache purge
fi

View File

@@ -30,12 +30,14 @@ dependencies:
- async-timeout==4.0.3
- attrs==23.1.0
- bark==0.1.5
- bitsandbytes==0.43.0
- boto3==1.28.61
- botocore==1.31.61
- certifi==2023.7.22
- TTS==0.22.0
- charset-normalizer==3.3.0
- datasets==2.14.5
- sentence-transformers==2.2.2
- sentence-transformers==2.5.1 # Updated Version
- sentencepiece==0.1.99
- dill==0.3.7
- einops==0.7.0
@@ -80,8 +82,8 @@ dependencies:
- requests==2.31.0
- rouge==1.0.1
- s3transfer==0.7.0
- safetensors==0.3.3
- scipy==1.11.3
- safetensors>=0.4.1
- scipy==1.12.0 # Updated Version
- six==1.16.0
- sympy==1.12
- tokenizers
@@ -112,7 +114,7 @@ dependencies:
- sudachipy
- sudachidict_core
- vocos
- vllm==0.2.7
- transformers>=4.36.0 # Required for Mixtral.
- vllm==0.3.2
- transformers>=4.38.2 # Updated Version
- xformers==0.0.23.post1
prefix: /opt/conda/envs/transformers

View File

@@ -0,0 +1,109 @@
name: transformers
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.08.22=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- ncurses=6.4=h6a678d5_0
- openssl=3.0.11=h7f8727e_2
- pip=23.2.1=py311h06a4308_0
- python=3.11.5=h955ad1f_0
- readline=8.2=h5eee18b_0
- setuptools=68.0.0=py311h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.41.2=py311h06a4308_0
- xz=5.4.2=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- --pre
- --extra-index-url https://download.pytorch.org/whl/nightly/
- accelerate==0.23.0
- aiohttp==3.8.5
- aiosignal==1.3.1
- async-timeout==4.0.3
- attrs==23.1.0
- bark==0.1.5
- boto3==1.28.61
- botocore==1.31.61
- certifi==2023.7.22
- TTS==0.22.0
- charset-normalizer==3.3.0
- datasets==2.14.5
- sentence-transformers==2.5.1 # Updated Version
- sentencepiece==0.1.99
- dill==0.3.7
- einops==0.7.0
- encodec==0.1.1
- filelock==3.12.4
- frozenlist==1.4.0
- fsspec==2023.6.0
- funcy==2.0
- grpcio==1.59.0
- huggingface-hub
- idna==3.4
- jinja2==3.1.2
- jmespath==1.0.1
- markupsafe==2.1.3
- mpmath==1.3.0
- multidict==6.0.4
- multiprocess==0.70.15
- networkx
- numpy==1.26.0
- packaging==23.2
- pandas
- peft==0.5.0
- protobuf==4.24.4
- psutil==5.9.5
- pyarrow==13.0.0
- python-dateutil==2.8.2
- pytz==2023.3.post1
- pyyaml==6.0.1
- regex==2023.10.3
- requests==2.31.0
- rouge==1.0.1
- s3transfer==0.7.0
- safetensors>=0.4.1
- scipy==1.12.0 # Updated Version
- six==1.16.0
- sympy==1.12
- tokenizers
- torch
- torchaudio
- tqdm==4.66.1
- triton==2.1.0
- typing-extensions==4.8.0
- tzdata==2023.3
- auto-gptq==0.6.0
- urllib3==1.26.17
- xxhash==3.4.1
- yarl==1.9.2
- soundfile
- langid
- wget
- unidecode
- pyopenjtalk-prebuilt
- pypinyin
- inflect
- cn2an
- jieba
- eng_to_ipa
- openai-whisper
- matplotlib
- gradio==3.41.2
- nltk
- sudachipy
- sudachidict_core
- vocos
- vllm==0.3.2
- transformers>=4.38.2 # Updated Version
- xformers==0.0.23.post1
prefix: /opt/conda/envs/transformers

View File

@@ -36,7 +36,7 @@ dependencies:
- TTS==0.22.0
- charset-normalizer==3.3.0
- datasets==2.14.5
- sentence-transformers==2.2.2
- sentence-transformers==2.5.1 # Updated Version
- sentencepiece==0.1.99
- dill==0.3.7
- einops==0.7.0
@@ -69,8 +69,8 @@ dependencies:
- requests==2.31.0
- rouge==1.0.1
- s3transfer==0.7.0
- safetensors==0.3.3
- scipy==1.11.3
- safetensors>=0.4.1
- scipy==1.12.0 # Updated Version
- six==1.16.0
- sympy==1.12
- tokenizers
@@ -101,7 +101,7 @@ dependencies:
- sudachipy
- sudachidict_core
- vocos
- vllm==0.2.7
- transformers>=4.36.0 # Required for Mixtral.
- vllm==0.3.2
- transformers>=4.38.2 # Updated Version
- xformers==0.0.23.post1
prefix: /opt/conda/envs/transformers
prefix: /opt/conda/envs/transformers

View File

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,15 @@
CONDA_ENV_PATH = "diffusers.yml"
export CONDA_ENV_PATH = "diffusers.yml"
ifeq ($(BUILD_TYPE), hipblas)
export CONDA_ENV_PATH = "diffusers-rocm.yml"
endif
# Intel GPUs are expected to have their dependencies installed in the main Python
# environment, so we skip the conda installation for SYCL builds.
# https://github.com/intel/intel-extension-for-pytorch/issues/538
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
export SKIP_CONDA=1
endif
.PHONY: diffusers
diffusers:
@@ -12,4 +23,4 @@ run:
@echo "Diffusers run."
test:
bash test.sh
bash test.sh

View File

@@ -21,14 +21,15 @@ from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipelin
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from diffusers.utils import load_image,export_to_video
from compel import Compel
from compel import Compel, ReturnedEmbeddingsType
from transformers import CLIPTextModel
from safetensors.torch import load_file
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
COMPEL=os.environ.get("COMPEL", "1") == "1"
COMPEL=os.environ.get("COMPEL", "0") == "1"
XPU=os.environ.get("XPU", "0") == "1"
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"
CHUNK_SIZE=os.environ.get("CHUNK_SIZE", "8")
@@ -36,6 +37,10 @@ FPS=os.environ.get("FPS", "7")
DISABLE_CPU_OFFLOAD=os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
FRAMES=os.environ.get("FRAMES", "64")
if XPU:
import intel_extension_for_pytorch as ipex
print(ipex.xpu.get_device_name(0))
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
@@ -231,8 +236,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.SchedulerType != "":
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
if not self.img2vid:
self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
if COMPEL:
self.compel = Compel(
tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2 ],
text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
requires_pooled=[False, True]
)
if request.ControlNet:
@@ -247,6 +257,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.pipe.to('cuda')
if self.controlnet:
self.controlnet.to('cuda')
if XPU:
self.pipe = self.pipe.to("xpu")
# Assume directory from request.ModelFile.
# Only if request.LoraAdapter is not an absolute path
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
@@ -386,8 +398,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
image = {}
if COMPEL:
conditioning = self.compel.build_conditioning_tensor(prompt)
kwargs["prompt_embeds"]= conditioning
conditioning, pooled = self.compel.build_conditioning_tensor(prompt)
kwargs["prompt_embeds"] = conditioning
kwargs["pooled_prompt_embeds"] = pooled
# pass the kwargs dictionary to the self.pipe method
image = self.pipe(
guidance_scale=self.cfg_scale,
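For SDXL pipelines, Compel has to run the prompt through both text encoders and also return the pooled embedding, which is why the code above passes lists for tokenizer/text_encoder and sets requires_pooled=[False, True]. A standalone sketch of the same pattern, assuming a stock SDXL checkpoint (the model id is only an example, not LocalAI configuration):

from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableDiffusionXLPipeline
import torch

# Any SDXL checkpoint works here; the model id is an example value.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

# With requires_pooled, Compel returns both the token embeddings and the pooled embedding.
conditioning, pooled = compel("a cat playing with a ball++ in the forest")
image = pipe(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, num_inference_steps=30).images[0]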

View File

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,64 @@
name: diffusers
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.08.22=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- ncurses=6.4=h6a678d5_0
- openssl=3.0.11=h7f8727e_2
- pip=23.2.1=py311h06a4308_0
- python=3.11.5=h955ad1f_0
- readline=8.2=h5eee18b_0
- setuptools=68.0.0=py311h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- tzdata=2023c=h04d1e81_0
- wheel=0.41.2=py311h06a4308_0
- xz=5.4.2=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- --pre
- --extra-index-url https://download.pytorch.org/whl/nightly/
- accelerate>=0.11.0
- certifi==2023.7.22
- charset-normalizer==3.3.0
- compel==2.0.2
- diffusers==0.24.0
- filelock==3.12.4
- fsspec==2023.9.2
- grpcio==1.59.0
- huggingface-hub>=0.19.4
- idna==3.4
- importlib-metadata==6.8.0
- jinja2==3.1.2
- markupsafe==2.1.3
- mpmath==1.3.0
- networkx==3.1
- numpy==1.26.0
- omegaconf
- packaging==23.2
- pillow==10.0.1
- protobuf==4.24.4
- psutil==5.9.5
- pyparsing==3.1.1
- pyyaml==6.0.1
- regex==2023.10.3
- requests==2.31.0
- safetensors==0.4.0
- sympy==1.12
- tqdm==4.66.1
- transformers>=4.25.1
- triton==2.1.0
- typing-extensions==4.8.0
- urllib3==2.0.6
- zipp==3.17.0
- torch
prefix: /opt/conda/envs/diffusers

View File

@@ -71,4 +71,4 @@ dependencies:
- typing-extensions==4.8.0
- urllib3==2.0.6
- zipp==3.17.0
prefix: /opt/conda/envs/diffusers
prefix: /opt/conda/envs/diffusers

View File

@@ -1,24 +1,50 @@
#!/bin/bash
set -ex
SKIP_CONDA=${SKIP_CONDA:-0}
# Check if the environment exists
conda_env_exists(){
! conda list --name "${@}" >/dev/null 2>/dev/null
}
if conda_env_exists "diffusers" ; then
echo "Creating virtual environment..."
conda env create --name diffusers --file $1
echo "Virtual environment created."
else
echo "Virtual environment already exists."
if [ $SKIP_CONDA -eq 1 ]; then
echo "Skipping conda environment installation"
else
export PATH=$PATH:/opt/conda/bin
if conda_env_exists "diffusers" ; then
echo "Creating virtual environment..."
conda env create --name diffusers --file $1
echo "Virtual environment created."
else
echo "Virtual environment already exists."
fi
fi
if [ -d "/opt/intel" ]; then
# Intel GPU: If the directory exists, we assume we are using the Intel image
# https://github.com/intel/intel-extension-for-pytorch/issues/538
pip install torch==2.1.0a0 \
torchvision==0.16.0a0 \
torchaudio==2.1.0a0 \
intel-extension-for-pytorch==2.1.10+xpu \
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install google-api-python-client \
grpcio \
grpcio-tools \
diffusers==0.24.0 \
"transformers>=4.25.1" \
accelerate \
compel==2.0.2 \
Pillow
fi
if [ "$PIP_CACHE_PURGE" = true ] ; then
export PATH=$PATH:/opt/conda/bin
# Activate conda environment
source activate diffusers
if [ $SKIP_CONDA -ne 1 ]; then
# Activate conda environment
source activate diffusers
fi
pip cache purge
fi

View File

@@ -3,10 +3,15 @@
##
## A bash script wrapper that runs the diffusers server with conda
export PATH=$PATH:/opt/conda/bin
# Activate conda environment
source activate diffusers
if [ -d "/opt/intel" ]; then
# Assumes we are using the Intel oneAPI container image
# https://github.com/intel/intel-extension-for-pytorch/issues/538
export XPU=1
else
export PATH=$PATH:/opt/conda/bin
# Activate conda environment
source activate diffusers
fi
# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

View File

@@ -1,7 +1,8 @@
export CONDA_ENV_PATH = "exllama.yml"
.PHONY: exllama
exllama:
$(MAKE) -C ../common-env/transformers
bash install.sh
bash install.sh ${CONDA_ENV_PATH}
.PHONY: run
run:

View File

File diff suppressed because one or more lines are too long

View File

@@ -1,14 +1,27 @@
#!/bin/bash
set -ex
##
## A bash script that installs the required dependencies of exllama and prepares the environment
export PATH=$PATH:/opt/conda/bin
# Activate conda environment
source activate transformers
if [ "$BUILD_TYPE" != "cublas" ]; then
echo "[exllama] Attention!!! Nvidia GPU is required - skipping installation"
exit 0
fi
echo $CONDA_PREFIX
# Check if the environment exists
conda_env_exists(){
! conda list --name "${@}" >/dev/null 2>/dev/null
}
if conda_env_exists "exllama" ; then
echo "Creating virtual environment..."
conda env create --name exllama --file $1
echo "Virtual environment created."
else
echo "Virtual environment already exists."
fi
source activate exllama
git clone https://github.com/turboderp/exllama $CONDA_PREFIX/exllama && pushd $CONDA_PREFIX/exllama && pip install -r requirements.txt && popd

View File

@@ -2,11 +2,10 @@
##
## A bash script wrapper that runs the exllama server with conda
export PATH=$PATH:/opt/conda/bin
# Activate conda environment
source activate transformers
source activate exllama
# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

View File

File diff suppressed because one or more lines are too long

View File

@@ -2,10 +2,14 @@
set -e
##
## A bash script that installs the required dependencies of exllamav2 and prepares the environment
export PATH=$PATH:/opt/conda/bin
export SHA=c0ddebaaaf8ffd1b3529c2bb654e650bce2f790f
# Activate conda environment
if [ "$BUILD_TYPE" != "cublas" ]; then
echo "[exllamav2] Attention!!! Nvidia GPU is required - skipping installation"
exit 0
fi
export PATH=$PATH:/opt/conda/bin
source activate transformers
echo $CONDA_PREFIX

View File

File diff suppressed because one or more lines are too long

View File

@@ -2,13 +2,14 @@
set -e
##
## A bash script that installs the required dependencies of mamba and prepares the environment
export PATH=$PATH:/opt/conda/bin
if [ "$BUILD_TYPE" != "cublas" ]; then
echo "[mamba] Attention!!! nvcc is required - skipping installation"
exit 0
fi
export PATH=$PATH:/opt/conda/bin
# Activate conda environment
source activate transformers

View File

@@ -1,7 +1,7 @@
.PHONY: petals
petals:
@echo "Creating virtual environment..."
@conda env create --name petals --file petals.yml
bash install.sh "petals.yml"
@echo "Virtual environment created."
.PHONY: run

View File

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,5 @@
#!/bin/bash
export PATH=$PATH:/opt/conda/bin
conda env create --name petals --file $1

View File

File diff suppressed because one or more lines are too long

View File

File diff suppressed because one or more lines are too long

View File

File diff suppressed because one or more lines are too long

View File

@@ -3,10 +3,16 @@
##
## A bash script wrapper that runs the transformers server with conda
export PATH=$PATH:/opt/conda/bin
# Activate conda environment
source activate transformers
if [ -d "/opt/intel" ]; then
# Assumes we are using the Intel oneAPI container image
# https://github.com/intel/intel-extension-for-pytorch/issues/538
export XPU=1
else
export PATH=$PATH:/opt/conda/bin
# Activate conda environment
source activate transformers
fi
# get the directory where the bash script is located
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

View File

@@ -16,7 +16,15 @@ import backend_pb2_grpc
import grpc
import torch
import torch.cuda
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed
XPU=os.environ.get("XPU", "0") == "1"
if XPU:
import intel_extension_for_pytorch as ipex
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModel, set_seed
else:
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -67,22 +75,60 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
A Result object that contains the result of the LoadModel operation.
"""
model_name = request.Model
compute = "auto"
if request.F16Memory == True:
compute=torch.bfloat16
self.CUDA = request.CUDA
device_map="cpu"
quantization = None
if self.CUDA:
if request.Device:
device_map=request.Device
else:
device_map="cuda:0"
if request.Quantization == "bnb_4bit":
quantization = BitsAndBytesConfig(
load_in_4bit = True,
bnb_4bit_compute_dtype = compute,
bnb_4bit_quant_type = "nf4",
bnb_4bit_use_double_quant = True,
load_in_8bit = False,
)
elif request.Quantization == "bnb_8bit":
quantization = BitsAndBytesConfig(
load_in_4bit=False,
bnb_4bit_compute_dtype = None,
load_in_8bit=True,
)
try:
if request.Type == "AutoModelForCausalLM":
self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
if XPU:
# request.Quantization drives 4-bit loading on XPU
xpu_4bit = request.Quantization == "xpu_4bit"
self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode,
device_map="xpu", load_in_4bit=xpu_4bit)
else:
self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute)
else:
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
self.XPU = False
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.CUDA = False
if request.CUDA or torch.cuda.is_available():
if XPU:
self.XPU = True
try:
print("Loading model", model_name, "to CUDA.", file=sys.stderr)
self.model = self.model.to("cuda")
self.CUDA = True
print("Optimizing model", model_name, "to XPU.", file=sys.stderr)
self.model = ipex.optimize_transformers(self.model, inplace=True, dtype=torch.float16, device="xpu")
except Exception as err:
print("Not using CUDA:", err, file=sys.stderr)
print("Not using XPU:", err, file=sys.stderr)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
# Implement your logic here for the LoadModel service
@@ -109,13 +155,17 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
# Create word embeddings
model_output = self.model(**encoded_input)
if self.CUDA:
encoded_input = encoded_input.to("cuda")
with torch.no_grad():
model_output = self.model(**encoded_input)
# Pool to get sentence embeddings, i.e. generate one 1024-dimensional vector for the entire sentence
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']).detach().numpy()
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
print("Embeddings:", sentence_embeddings, file=sys.stderr)
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
def Predict(self, request, context):
"""
@@ -139,13 +189,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
inputs = self.tokenizer(request.Prompt, return_tensors="pt").input_ids
if self.CUDA:
inputs = inputs.to("cuda")
if XPU:
inputs = inputs.to("xpu")
outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
generated_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Remove prompt from response if present
if request.Prompt in generated_text:
generated_text = generated_text.replace(request.Prompt, "")
outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP, do_sample=True, pad_token_id=self.tokenizer.eos_token_id)
generated_text = self.tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]
return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
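The quantization handling above boils down to choosing a BitsAndBytesConfig before calling from_pretrained. A condensed, self-contained sketch of the same switch (the load_quantized helper and its defaults are illustrative, not part of the LocalAI backend):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def load_quantized(model_name: str, quantization: str | None, compute=torch.bfloat16):
    """Mirror of the bnb_4bit / bnb_8bit switch above; model_name is just an example value."""
    quant_config = None
    if quantization == "bnb_4bit":
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    elif quantization == "bnb_8bit":
        quant_config = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quant_config,
        device_map="cuda:0" if quant_config else "cpu",
        torch_dtype=compute,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer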

View File

@@ -1,3 +1,7 @@
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
export SKIP_CONDA=1
endif
.PHONY: ttsvalle
ttsvalle:
$(MAKE) -C ../common-env/transformers

View File

File diff suppressed because one or more lines are too long

View File

@@ -2,13 +2,16 @@
##
## A bash script that installs the required dependencies of VALL-E-X and prepares the environment
export PATH=$PATH:/opt/conda/bin
export SHA=3faaf8ccadb154d63b38070caf518ce9309ea0f4
# Activate conda environment
source activate transformers
SKIP_CONDA=${SKIP_CONDA:-0}
echo $CONDA_PREFIX
if [ $SKIP_CONDA -ne 1 ]; then
source activate transformers
else
export PATH=$PATH:/opt/conda/bin
CONDA_PREFIX=$PWD
fi
git clone https://github.com/Plachtaa/VALL-E-X.git $CONDA_PREFIX/vall-e-x && pushd $CONDA_PREFIX/vall-e-x && git checkout -b build $SHA && popd

View File

@@ -79,7 +79,7 @@ dependencies:
- pypinyin==0.49.0
- python-multipart==0.0.6
- regex==2023.10.3
- safetensors==0.4.0
- safetensors>=0.4.0
- semantic-version==2.10.0
- soundfile==0.12.1
- starlette==0.27.0

View File

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
import asyncio
from concurrent import futures
import time
import argparse
import signal
import sys
@@ -10,7 +10,10 @@ import backend_pb2
import backend_pb2_grpc
import grpc
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -79,16 +82,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns:
backend_pb2.Result: The load model result.
"""
engine_args = AsyncEngineArgs(
model=request.Model,
)
if request.Quantization != "":
engine_args.quantization = request.Quantization
if request.GPUMemoryUtilization != 0:
engine_args.gpu_memory_utilization = request.GPUMemoryUtilization
if request.TrustRemoteCode:
engine_args.trust_remote_code = request.TrustRemoteCode
if request.EnforceEager:
engine_args.enforce_eager = request.EnforceEager
if request.SwapSpace != 0:
engine_args.swap_space = request.SwapSpace
if request.MaxModelLen != 0:
engine_args.max_model_len = request.MaxModelLen
try:
if request.Quantization != "":
self.llm = LLM(model=request.Model, quantization=request.Quantization)
else:
self.llm = LLM(model=request.Model)
self.llm = AsyncLLMEngine.from_engine_args(engine_args)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(message="Model loaded successfully", success=True)
def Predict(self, request, context):
async def Predict(self, request, context):
"""
Generates text based on the given prompt and sampling parameters.
@@ -99,24 +116,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns:
backend_pb2.Reply: The predict result.
"""
if request.TopP == 0:
request.TopP = 0.9
gen = self._predict(request, context, streaming=False)
res = await gen.__anext__()
return res
max_tokens = 200
if request.Tokens > 0:
max_tokens = request.Tokens
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
outputs = self.llm.generate([request.Prompt], sampling_params)
generated_text = outputs[0].outputs[0].text
# Remove prompt from response if present
if request.Prompt in generated_text:
generated_text = generated_text.replace(request.Prompt, "")
return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
def PredictStream(self, request, context):
async def PredictStream(self, request, context):
"""
Generates text based on the given prompt and sampling parameters, and streams the results.
@@ -127,30 +131,84 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
Returns:
backend_pb2.Result: The predict stream result.
"""
yield self.Predict(request, context)
iterations = self._predict(request, context, streaming=True)
try:
async for iteration in iterations:
yield iteration
finally:
await iterations.aclose()
def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
async def _predict(self, request, context, streaming=False):
# Build sampling parameters
sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
if request.TopP != 0:
sampling_params.top_p = request.TopP
if request.Tokens > 0:
sampling_params.max_tokens = request.Tokens
if request.Temperature != 0:
sampling_params.temperature = request.Temperature
if request.TopK != 0:
sampling_params.top_k = request.TopK
if request.PresencePenalty != 0:
sampling_params.presence_penalty = request.PresencePenalty
if request.FrequencyPenalty != 0:
sampling_params.frequency_penalty = request.FrequencyPenalty
if request.StopPrompts:
sampling_params.stop = request.StopPrompts
if request.IgnoreEOS:
sampling_params.ignore_eos = request.IgnoreEOS
if request.Seed != 0:
sampling_params.seed = request.Seed
# Generate text
request_id = random_uuid()
outputs = self.llm.generate(request.Prompt, sampling_params, request_id)
# Stream the results
generated_text = ""
try:
async for request_output in outputs:
iteration_text = request_output.outputs[0].text
if streaming:
# Remove text already sent as vllm concatenates the text from previous yields
delta_iteration_text = iteration_text.removeprefix(generated_text)
# Send the partial result
yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))
# Keep track of text generated
generated_text = iteration_text
finally:
await outputs.aclose()
# If streaming, we already sent everything
if streaming:
return
# Sending the final generated text
yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
async def serve(address):
# Start asyncio gRPC server
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
# Add the servicer to the server
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
# Bind the server to the address
server.add_insecure_port(address)
server.start()
# Gracefully shutdown the server on SIGTERM or SIGINT
loop = asyncio.get_event_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(
sig, lambda: asyncio.ensure_future(server.stop(5))
)
# Start the server
await server.start()
print("Server started. Listening on: " + address, file=sys.stderr)
# Define the signal handler function
def signal_handler(sig, frame):
print("Received termination signal. Shutting down...")
server.stop(0)
sys.exit(0)
# Set the signal handlers for SIGINT and SIGTERM
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
# Wait for the server to be terminated
await server.wait_for_termination()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
@@ -159,4 +217,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
serve(args.addr)
asyncio.run(serve(args.addr))
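The comment about removing text already sent captures how vLLM's async engine streams: each RequestOutput carries the cumulative text for the request, so the server forwards only the new suffix. A standalone sketch of that delta pattern, assuming vLLM 0.3.x and an example model name:

# Illustrative sketch of the streaming-delta pattern used above: vLLM's async engine
# yields the cumulative text for a request, so only the new suffix should be sent out.
import asyncio
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

async def stream_deltas(engine: AsyncLLMEngine, prompt: str):
    params = SamplingParams(max_tokens=200, top_p=0.9)
    sent = ""
    async for output in engine.generate(prompt, params, random_uuid()):
        text = output.outputs[0].text          # cumulative text so far
        delta = text.removeprefix(sent)        # only what has not been sent yet
        sent = text
        yield delta

async def main():
    # The model name is an example; any vLLM-supported model would do.
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    async for chunk in stream_deltas(engine, "Once upon a time"):
        print(chunk, end="", flush=True)

if __name__ == "__main__":
    asyncio.run(main())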

0
configuration/.keep Normal file
View File

View File

@@ -3,36 +3,32 @@ package backend
import (
"fmt"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
if !c.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
}
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
modelFile := backendConfig.Model
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
grpcOpts := gRPCModelOpts(backendConfig)
var inferenceModel interface{}
var err error
opts := modelOpts(c, o, []model.Option{
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(*backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithContext(appConfig.Context),
})
if c.Backend == "" {
if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(c.Backend))
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
@@ -43,7 +39,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
switch model := inferenceModel.(type) {
case grpc.Backend:
fn = func() ([]float32, error) {
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
if len(tokens) > 0 {
embeds := []int32{}
@@ -52,7 +48,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
}
predictOptions.EmbeddingTokens = embeds
res, err := model.Embeddings(o.Context, predictOptions)
res, err := model.Embeddings(appConfig.Context, predictOptions)
if err != nil {
return nil, err
}
@@ -61,7 +57,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
}
predictOptions.Embeddings = s
res, err := model.Embeddings(o.Context, predictOptions)
res, err := model.Embeddings(appConfig.Context, predictOptions)
if err != nil {
return nil, err
}

52
core/backend/image.go Normal file
View File

@@ -0,0 +1,52 @@
package backend
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
threads := backendConfig.Threads
if *threads == 0 && appConfig.Threads != 0 {
threads = &appConfig.Threads
}
gRPCOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(backendConfig.Backend),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithThreads(uint32(*threads)),
model.WithContext(appConfig.Context),
model.WithModel(backendConfig.Model),
model.WithLoadGRPCLoadModelOpts(gRPCOpts),
})
inferenceModel, err := loader.BackendLoader(
opts...,
)
if err != nil {
return nil, err
}
fn := func() error {
_, err := inferenceModel.GenerateImage(
appConfig.Context,
&proto.GenerateImageRequest{
Height: int32(height),
Width: int32(width),
Mode: int32(mode),
Step: int32(step),
Seed: int32(seed),
CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
PositivePrompt: positive_prompt,
NegativePrompt: negative_prompt,
Dst: dst,
Src: src,
EnableParameters: backendConfig.Diffusers.EnableParameters,
})
return err
}
return fn, nil
}

View File

@@ -8,8 +8,8 @@ import (
"sync"
"unicode/utf8"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model"
@@ -26,9 +26,12 @@ type TokenUsage struct {
Completion int
}
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
threads := c.Threads
if *threads == 0 && o.Threads != 0 {
threads = &o.Threads
}
grpcOpts := gRPCModelOpts(c)
var inferenceModel grpc.Backend
@@ -36,7 +39,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode
opts := modelOpts(c, o, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
model.WithThreads(uint32(*threads)), // some models use this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
@@ -140,7 +143,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
func Finetune(config config.Config, input, prediction string) string {
func Finetune(config config.BackendConfig, input, prediction string) string {
if config.Echo {
prediction = input + prediction
}

141
core/backend/options.go Normal file
View File

@@ -0,0 +1,141 @@
package backend
import (
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/core/config"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
if so.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend())
}
if so.ParallelBackendRequests {
opts = append(opts, model.EnableParallelRequests)
}
if c.GRPC.Attempts != 0 {
opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts))
}
if c.GRPC.AttemptsSleepTime != 0 {
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
}
for k, v := range so.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
return opts
}
func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
}
return &pb.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
PipelineType: c.Diffusers.PipelineType,
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
F16Memory: *c.F16,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
ContextSize: int32(*c.ContextSize),
Seed: int32(*c.Seed),
NBatch: int32(b),
NoMulMatQ: c.NoMulMatQ,
DraftModel: c.DraftModel,
AudioPath: c.VallE.AudioPath,
Quantization: c.Quantization,
GPUMemoryUtilization: c.GPUMemoryUtilization,
TrustRemoteCode: c.TrustRemoteCode,
EnforceEager: c.EnforceEager,
SwapSpace: int32(c.SwapSpace),
MaxModelLen: int32(c.MaxModelLen),
MMProj: c.MMProj,
YarnExtFactor: c.YarnExtFactor,
YarnAttnFactor: c.YarnAttnFactor,
YarnBetaFast: c.YarnBetaFast,
YarnBetaSlow: c.YarnBetaSlow,
NGQA: c.NGQA,
RMSNormEps: c.RMSNormEps,
MLock: *c.MMlock,
RopeFreqBase: c.RopeFreqBase,
RopeScaling: c.RopeScaling,
Type: c.ModelType,
RopeFreqScale: c.RopeFreqScale,
NUMA: c.NUMA,
Embeddings: c.Embeddings,
LowVRAM: *c.LowVRAM,
NGPULayers: int32(*c.NGPULayers),
MMap: *c.MMap,
MainGPU: c.MainGPU,
Threads: int32(*c.Threads),
TensorSplit: c.TensorSplit,
// AutoGPTQ
ModelBaseName: c.AutoGPTQ.ModelBaseName,
Device: c.AutoGPTQ.Device,
UseTriton: c.AutoGPTQ.Triton,
UseFastTokenizer: c.AutoGPTQ.UseFastTokenizer,
// RWKV
Tokenizer: c.Tokenizer,
}
}
func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions {
promptCachePath := ""
if c.PromptCachePath != "" {
p := filepath.Join(modelPath, c.PromptCachePath)
os.MkdirAll(filepath.Dir(p), 0755)
promptCachePath = p
}
return &pb.PredictOptions{
Temperature: float32(*c.Temperature),
TopP: float32(*c.TopP),
NDraft: c.NDraft,
TopK: int32(*c.TopK),
Tokens: int32(*c.Maxtokens),
Threads: int32(*c.Threads),
PromptCacheAll: c.PromptCacheAll,
PromptCacheRO: c.PromptCacheRO,
PromptCachePath: promptCachePath,
F16KV: *c.F16,
DebugMode: *c.Debug,
Grammar: c.Grammar,
NegativePromptScale: c.NegativePromptScale,
RopeFreqBase: c.RopeFreqBase,
RopeFreqScale: c.RopeFreqScale,
NegativePrompt: c.NegativePrompt,
Mirostat: int32(*c.LLMConfig.Mirostat),
MirostatETA: float32(*c.LLMConfig.MirostatETA),
MirostatTAU: float32(*c.LLMConfig.MirostatTAU),
Debug: *c.Debug,
StopPrompts: c.StopWords,
Repeat: int32(c.RepeatPenalty),
NKeep: int32(c.Keep),
Batch: int32(c.Batch),
IgnoreEOS: c.IgnoreEOS,
Seed: int32(*c.Seed),
FrequencyPenalty: float32(c.FrequencyPenalty),
MLock: *c.MMlock,
MMap: *c.MMap,
MainGPU: c.MainGPU,
TensorSplit: c.TensorSplit,
TailFreeSamplingZ: float32(c.TFZ),
TypicalP: float32(c.TypicalP),
}
}

View File

@@ -0,0 +1,38 @@
package backend
import (
"context"
"fmt"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) {
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(backendConfig.Model),
model.WithContext(appConfig.Context),
model.WithThreads(uint32(*backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
})
whisperModel, err := ml.BackendLoader(opts...)
if err != nil {
return nil, err
}
if whisperModel == nil {
return nil, fmt.Errorf("could not load whisper model")
}
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(*backendConfig.Threads),
})
}

89
core/backend/tts.go Normal file
View File

@@ -0,0 +1,89 @@
package backend
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}
counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}
func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
}
grpcOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
ttsModel, err := loader.BackendLoader(opts...)
if err != nil {
return "", nil, err
}
if ttsModel == nil {
return "", nil, fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(appConfig.AudioDir, 0755); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}
fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)
// If the model file is not empty, we pass it joined with the model path
modelPath := ""
if modelFile != "" {
// If the model file is not empty, we pass it joined with the model path
// Checking first that it exists and is not outside ModelPath
// TODO: we should actually first check if the modelFile is looking like
// a FS path
mp := filepath.Join(loader.ModelPath, modelFile)
if _, err := os.Stat(mp); err == nil {
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
return "", nil, err
}
modelPath = mp
} else {
modelPath = modelFile
}
}
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Dst: filePath,
})
return filePath, res, err
}

View File

@@ -1,4 +1,4 @@
package options
package config
import (
"context"
@@ -6,27 +6,25 @@ import (
"encoding/json"
"time"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
type Option struct {
type ApplicationConfig struct {
Context context.Context
ConfigFile string
Loader *model.ModelLoader
ModelPath string
UploadLimitMB, Threads, ContextSize int
F16 bool
Debug, DisableMessage bool
ImageDir string
AudioDir string
UploadDir string
CORS bool
PreloadJSONModels string
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string
Metrics *metrics.Metrics
ModelLibraryURL string
@@ -51,10 +49,10 @@ type Option struct {
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
}
type AppOption func(*Option)
type AppOption func(*ApplicationConfig)
func NewOptions(o ...AppOption) *Option {
opt := &Option{
func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
opt := &ApplicationConfig{
Context: context.Background(),
UploadLimitMB: 15,
Threads: 1,
@@ -69,63 +67,69 @@ func NewOptions(o ...AppOption) *Option {
}
func WithModelsURL(urls ...string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ModelsURL = urls
}
}
func WithModelPath(path string) AppOption {
return func(o *ApplicationConfig) {
o.ModelPath = path
}
}
func WithCors(b bool) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.CORS = b
}
}
func WithModelLibraryURL(url string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ModelLibraryURL = url
}
}
var EnableWatchDog = func(o *Option) {
var EnableWatchDog = func(o *ApplicationConfig) {
o.WatchDog = true
}
var EnableWatchDogIdleCheck = func(o *Option) {
var EnableWatchDogIdleCheck = func(o *ApplicationConfig) {
o.WatchDog = true
o.WatchDogIdle = true
}
var EnableWatchDogBusyCheck = func(o *Option) {
var EnableWatchDogBusyCheck = func(o *ApplicationConfig) {
o.WatchDog = true
o.WatchDogBusy = true
}
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.WatchDogBusyTimeout = t
}
}
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.WatchDogIdleTimeout = t
}
}
var EnableSingleBackend = func(o *Option) {
var EnableSingleBackend = func(o *ApplicationConfig) {
o.SingleBackend = true
}
var EnableParallelBackendRequests = func(o *Option) {
var EnableParallelBackendRequests = func(o *ApplicationConfig) {
o.ParallelBackendRequests = true
}
var EnableGalleriesAutoload = func(o *Option) {
var EnableGalleriesAutoload = func(o *ApplicationConfig) {
o.AutoloadGalleries = true
}
func WithExternalBackend(name string, uri string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
if o.ExternalGRPCBackends == nil {
o.ExternalGRPCBackends = make(map[string]string)
}
@@ -134,27 +138,26 @@ func WithExternalBackend(name string, uri string) AppOption {
}
func WithCorsAllowOrigins(b string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.CORSAllowOrigins = b
}
}
func WithBackendAssetsOutput(out string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.AssetsDestination = out
}
}
func WithBackendAssets(f embed.FS) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.BackendAssets = f
}
}
func WithStringGalleries(galls string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
if galls == "" {
log.Debug().Msgf("no galleries to load")
o.Galleries = []gallery.Gallery{}
return
}
@@ -167,96 +170,96 @@ func WithStringGalleries(galls string) AppOption {
}
func WithGalleries(galleries []gallery.Gallery) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.Galleries = append(o.Galleries, galleries...)
}
}
func WithContext(ctx context.Context) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.Context = ctx
}
}
func WithYAMLConfigPreload(configFile string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.PreloadModelsFromPath = configFile
}
}
func WithJSONStringPreload(configFile string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.PreloadJSONModels = configFile
}
}
func WithConfigFile(configFile string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ConfigFile = configFile
}
}
func WithModelLoader(loader *model.ModelLoader) AppOption {
return func(o *Option) {
o.Loader = loader
}
}
func WithUploadLimitMB(limit int) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.UploadLimitMB = limit
}
}
func WithThreads(threads int) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.Threads = threads
}
}
func WithContextSize(ctxSize int) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ContextSize = ctxSize
}
}
func WithF16(f16 bool) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.F16 = f16
}
}
func WithDebug(debug bool) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.Debug = debug
}
}
func WithDisableMessage(disableMessage bool) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.DisableMessage = disableMessage
}
}
func WithAudioDir(audioDir string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.AudioDir = audioDir
}
}
func WithImageDir(imageDir string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ImageDir = imageDir
}
}
func WithUploadDir(uploadDir string) AppOption {
return func(o *ApplicationConfig) {
o.UploadDir = uploadDir
}
}
func WithApiKeys(apiKeys []string) AppOption {
return func(o *Option) {
return func(o *ApplicationConfig) {
o.ApiKeys = apiKeys
}
}
func WithMetrics(meter *metrics.Metrics) AppOption {
return func(o *Option) {
o.Metrics = meter
}
}
// func WithMetrics(meter *metrics.Metrics) AppOption {
// return func(o *StartupOptions) {
// o.Metrics = meter
// }
// }

View File

@@ -1,27 +1,31 @@
package api_config
package config
import (
"errors"
"fmt"
"io/fs"
"math/rand"
"os"
"path/filepath"
"strings"
"sync"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
"github.com/charmbracelet/glamour"
)
type Config struct {
PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
type BackendConfig struct {
schema.PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
F16 bool `yaml:"f16"`
Threads int `yaml:"threads"`
Debug bool `yaml:"debug"`
F16 *bool `yaml:"f16"`
Threads *int `yaml:"threads"`
Debug *bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
@@ -104,29 +108,34 @@ type LLMConfig struct {
PromptCachePath string `yaml:"prompt_cache_path"`
PromptCacheAll bool `yaml:"prompt_cache_all"`
PromptCacheRO bool `yaml:"prompt_cache_ro"`
MirostatETA float64 `yaml:"mirostat_eta"`
MirostatTAU float64 `yaml:"mirostat_tau"`
Mirostat int `yaml:"mirostat"`
NGPULayers int `yaml:"gpu_layers"`
MMap bool `yaml:"mmap"`
MMlock bool `yaml:"mmlock"`
LowVRAM bool `yaml:"low_vram"`
MirostatETA *float64 `yaml:"mirostat_eta"`
MirostatTAU *float64 `yaml:"mirostat_tau"`
Mirostat *int `yaml:"mirostat"`
NGPULayers *int `yaml:"gpu_layers"`
MMap *bool `yaml:"mmap"`
MMlock *bool `yaml:"mmlock"`
LowVRAM *bool `yaml:"low_vram"`
Grammar string `yaml:"grammar"`
StopWords []string `yaml:"stopwords"`
Cutstrings []string `yaml:"cutstrings"`
TrimSpace []string `yaml:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix"`
ContextSize int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
MMProj string `yaml:"mmproj"`
ContextSize *int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len"` // vLLM
MMProj string `yaml:"mmproj"`
RopeScaling string `yaml:"rope_scaling"`
ModelType string `yaml:"type"`
@@ -148,6 +157,7 @@ type Functions struct {
DisableNoAction bool `yaml:"disable_no_action"`
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"`
ParallelCalls bool `yaml:"parallel_calls"`
}
type TemplateConfig struct {
@@ -158,108 +168,209 @@ type TemplateConfig struct {
Functions string `yaml:"function"`
}
type ConfigLoader struct {
configs map[string]Config
sync.Mutex
}
func (c *Config) SetFunctionCallString(s string) {
func (c *BackendConfig) SetFunctionCallString(s string) {
c.functionCallString = s
}
func (c *Config) SetFunctionCallNameString(s string) {
func (c *BackendConfig) SetFunctionCallNameString(s string) {
c.functionCallNameString = s
}
func (c *Config) ShouldUseFunctions() bool {
func (c *BackendConfig) ShouldUseFunctions() bool {
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
}
func (c *Config) ShouldCallSpecificFunction() bool {
func (c *BackendConfig) ShouldCallSpecificFunction() bool {
return len(c.functionCallNameString) > 0
}
func (c *Config) FunctionToCall() string {
func (c *BackendConfig) FunctionToCall() string {
return c.functionCallNameString
}
// Load a config file for a model
func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ctx int, f16 bool) (*Config, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(modelPath, modelName+".yaml")
func (cfg *BackendConfig) SetDefaults(debug bool, threads, ctx int, f16 bool) {
defaultTopP := 0.7
defaultTopK := 80
defaultTemp := 0.9
defaultMaxTokens := 2048
defaultMirostat := 2
defaultMirostatTAU := 5.0
defaultMirostatETA := 0.1
var cfg *Config
// Try to offload all GPU layers (if GPU is found)
defaultNGPULayers := 99999999
defaults := func() {
cfg = DefaultConfig(modelName)
cfg.ContextSize = ctx
cfg.Threads = threads
cfg.F16 = f16
cfg.Debug = debug
trueV := true
falseV := false
if cfg.Seed == nil {
// random number generator seed
defaultSeed := int(rand.Int31())
cfg.Seed = &defaultSeed
}
cfgExisting, exists := cm.GetConfig(modelName)
if !exists {
if cfg.TopK == nil {
cfg.TopK = &defaultTopK
}
if cfg.MMap == nil {
// MMap is enabled by default
cfg.MMap = &trueV
}
if cfg.MMlock == nil {
// MMlock is disabled by default
cfg.MMlock = &falseV
}
if cfg.TopP == nil {
cfg.TopP = &defaultTopP
}
if cfg.Temperature == nil {
cfg.Temperature = &defaultTemp
}
if cfg.Maxtokens == nil {
cfg.Maxtokens = &defaultMaxTokens
}
if cfg.Mirostat == nil {
cfg.Mirostat = &defaultMirostat
}
if cfg.MirostatETA == nil {
cfg.MirostatETA = &defaultMirostatETA
}
if cfg.MirostatTAU == nil {
cfg.MirostatTAU = &defaultMirostatTAU
}
if cfg.NGPULayers == nil {
cfg.NGPULayers = &defaultNGPULayers
}
if cfg.LowVRAM == nil {
cfg.LowVRAM = &falseV
}
// Values passed from the top level are treated as defaults (no implicit defaults)
// defaults are set by the user
if ctx == 0 {
ctx = 1024
}
if cfg.ContextSize == nil {
cfg.ContextSize = &ctx
}
if threads == 0 {
// Threads can't be 0
threads = 4
}
if cfg.Threads == nil {
cfg.Threads = &threads
}
if cfg.F16 == nil {
cfg.F16 = &f16
}
if debug {
cfg.Debug = &debug
}
}
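A short sketch of what the pointer fields buy: SetDefaults only fills values that are still nil, so an explicit zero from a user's YAML now survives (this assumes the embedded structs promote these fields, as the method body above already relies on):

package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/core/config"
)

func main() {
	zero := 0
	cfg := &config.BackendConfig{}
	cfg.Mirostat = &zero // user explicitly disabled mirostat (a real zero, not "unset")

	cfg.SetDefaults(false, 4, 512, false) // debug, threads, ctx, f16

	fmt.Println(*cfg.Mirostat) // still 0: the explicit zero is preserved
	fmt.Println(*cfg.TopK)     // 80: the unset pointer picked up the default
}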
////// Config Loader ////////
type BackendConfigLoader struct {
configs map[string]BackendConfig
sync.Mutex
}
type LoadOptions struct {
debug bool
threads, ctxSize int
f16 bool
}
func LoadOptionDebug(debug bool) ConfigLoaderOption {
return func(o *LoadOptions) {
o.debug = debug
}
}
func LoadOptionThreads(threads int) ConfigLoaderOption {
return func(o *LoadOptions) {
o.threads = threads
}
}
func LoadOptionContextSize(ctxSize int) ConfigLoaderOption {
return func(o *LoadOptions) {
o.ctxSize = ctxSize
}
}
func LoadOptionF16(f16 bool) ConfigLoaderOption {
return func(o *LoadOptions) {
o.f16 = f16
}
}
type ConfigLoaderOption func(*LoadOptions)
func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
for _, l := range options {
l(lo)
}
}
// Load a config file for a model
func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
lo := &LoadOptions{}
lo.Apply(opts...)
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
Model: modelName,
},
}
cfgExisting, exists := cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
// Try loading a model config file
modelConfig := filepath.Join(modelPath, modelName+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
if err := cl.LoadBackendConfig(modelConfig); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cm.GetConfig(modelName)
cfgExisting, exists = cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
defaults()
}
} else {
defaults()
}
} else {
cfg = &cfgExisting
}
// Set the parameters for the language model prediction
//updateConfig(cfg, input)
// Don't allow 0 as setting
if cfg.Threads == 0 {
if threads != 0 {
cfg.Threads = threads
} else {
cfg.Threads = 4
}
}
// Enforce debug flag if passed from CLI
if debug {
cfg.Debug = true
}
cfg.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16)
return cfg, nil
}
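A minimal, assumed call site for the new loader options (the model name and path are placeholders):

package main

import (
	"log"

	"github.com/go-skynet/LocalAI/core/config"
)

func main() {
	cl := config.NewBackendConfigLoader()
	cfg, err := cl.LoadBackendConfigFileByName("my-model", "/path/to/models", // placeholders
		config.LoadOptionDebug(true),
		config.LoadOptionThreads(8),
		config.LoadOptionContextSize(2048),
		config.LoadOptionF16(false),
	)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("resolved backend: %s, context size: %d", cfg.Backend, *cfg.ContextSize)
}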
func defaultPredictOptions(modelFile string) PredictionOptions {
return PredictionOptions{
TopP: 0.7,
TopK: 80,
Maxtokens: 512,
Temperature: 0.9,
Model: modelFile,
func NewBackendConfigLoader() *BackendConfigLoader {
return &BackendConfigLoader{
configs: make(map[string]BackendConfig),
}
}
func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
lo := &LoadOptions{}
lo.Apply(opts...)
func DefaultConfig(modelFile string) *Config {
return &Config{
PredictionOptions: defaultPredictOptions(modelFile),
}
}
func NewConfigLoader() *ConfigLoader {
return &ConfigLoader{
configs: make(map[string]Config),
}
}
func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{}
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
@@ -268,11 +379,18 @@ func ReadConfigFile(file string) ([]*Config, error) {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
for _, cc := range *c {
cc.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16)
}
return *c, nil
}
func ReadConfig(file string) (*Config, error) {
c := &Config{}
func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
lo := &LoadOptions{}
lo.Apply(opts...)
c := &BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
@@ -281,13 +399,14 @@ func ReadConfig(file string) (*Config, error) {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
c.SetDefaults(lo.debug, lo.threads, lo.ctxSize, lo.f16)
return c, nil
}
func (cm *ConfigLoader) LoadConfigFile(file string) error {
func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfigFile(file)
c, err := ReadBackendConfigFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
@@ -298,49 +417,49 @@ func (cm *ConfigLoader) LoadConfigFile(file string) error {
return nil
}
func (cm *ConfigLoader) LoadConfig(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfig(file)
func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
cl.Lock()
defer cl.Unlock()
c, err := ReadBackendConfig(file, opts...)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cm.configs[c.Name] = *c
cl.configs[c.Name] = *c
return nil
}
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
cm.Lock()
defer cm.Unlock()
v, exists := cm.configs[m]
func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
cl.Lock()
defer cl.Unlock()
v, exists := cl.configs[m]
return v, exists
}
func (cm *ConfigLoader) GetAllConfigs() []Config {
cm.Lock()
defer cm.Unlock()
var res []Config
for _, v := range cm.configs {
func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
cl.Lock()
defer cl.Unlock()
var res []BackendConfig
for _, v := range cl.configs {
res = append(res, v)
}
return res
}
func (cm *ConfigLoader) ListConfigs() []string {
cm.Lock()
defer cm.Unlock()
func (cl *BackendConfigLoader) ListBackendConfigs() []string {
cl.Lock()
defer cl.Unlock()
var res []string
for k := range cm.configs {
for k := range cl.configs {
res = append(res, k)
}
return res
}
// Preload prepare models if they are not local but url or huggingface repositories
func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()
func (cl *BackendConfigLoader) Preload(modelPath string) error {
cl.Lock()
defer cl.Unlock()
status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
@@ -348,7 +467,21 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
log.Info().Msgf("Preloading models from %s", modelPath)
for i, config := range cm.configs {
renderMode := "dark"
if os.Getenv("COLOR") != "" {
renderMode = os.Getenv("COLOR")
}
glamText := func(t string) {
out, err := glamour.Render(t, renderMode)
if err == nil && os.Getenv("NO_COLOR") == "" {
fmt.Println(out)
} else {
fmt.Println(t)
}
}
for i, config := range cl.configs {
// Download files and verify their SHA
for _, file := range config.DownloadFiles {
@@ -380,25 +513,29 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
}
}
cc := cm.configs[i]
cc := cl.configs[i]
c := &cc
c.PredictionOptions.Model = md5Name
cm.configs[i] = *c
cl.configs[i] = *c
}
if cm.configs[i].Name != "" {
log.Info().Msgf("Model name: %s", cm.configs[i].Name)
if cl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
}
if cm.configs[i].Description != "" {
log.Info().Msgf("Model description: %s", cm.configs[i].Description)
if cl.configs[i].Description != "" {
//glamText("**Description**")
glamText(cl.configs[i].Description)
}
if cm.configs[i].Usage != "" {
log.Info().Msgf("Model usage: \n%s", cm.configs[i].Usage)
if cl.configs[i].Usage != "" {
//glamText("**Usage**")
glamText(cl.configs[i].Usage)
}
}
return nil
}
func (cm *ConfigLoader) LoadConfigs(path string) error {
// LoadBackendConfigsFromPath reads all the configurations of the models from a path
// (non-recursive)
func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
entries, err := os.ReadDir(path)
@@ -418,7 +555,7 @@ func (cm *ConfigLoader) LoadConfigs(path string) error {
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
continue
}
c, err := ReadConfig(filepath.Join(path, file.Name()))
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)
if err == nil {
cm.configs[c.Name] = *c
}


@@ -1,11 +1,10 @@
package api_config_test
package config_test
import (
"os"
. "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/model"
. "github.com/go-skynet/LocalAI/core/config"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
@@ -19,7 +18,7 @@ var _ = Describe("Test cases for config related functions", func() {
Context("Test Read configuration functions", func() {
configFile = os.Getenv("CONFIG_FILE")
It("Test ReadConfigFile", func() {
config, err := ReadConfigFile(configFile)
config, err := ReadBackendConfigFile(configFile)
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
// two configs in config.yaml
@@ -28,29 +27,26 @@ var _ = Describe("Test cases for config related functions", func() {
})
It("Test LoadConfigs", func() {
cm := NewConfigLoader()
opts := options.NewOptions()
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH"))
options.WithModelLoader(modelLoader)(opts)
err := cm.LoadConfigs(opts.Loader.ModelPath)
cm := NewBackendConfigLoader()
opts := NewApplicationConfig()
err := cm.LoadBackendConfigsFromPath(opts.ModelPath)
Expect(err).To(BeNil())
Expect(cm.ListConfigs()).ToNot(BeNil())
Expect(cm.ListBackendConfigs()).ToNot(BeNil())
// config should includes gpt4all models's api.config
Expect(cm.ListConfigs()).To(ContainElements("gpt4all"))
Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all"))
// config should includes gpt2 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2"))
Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all-2"))
// config should includes text-embedding-ada-002 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002"))
Expect(cm.ListBackendConfigs()).To(ContainElements("text-embedding-ada-002"))
// config should includes rwkv_test models's api.config
Expect(cm.ListConfigs()).To(ContainElements("rwkv_test"))
Expect(cm.ListBackendConfigs()).To(ContainElements("rwkv_test"))
// config should includes whisper-1 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("whisper-1"))
Expect(cm.ListBackendConfigs()).To(ContainElements("whisper-1"))
})
})
})

core/http/api.go (new file, +242 lines)

@@ -0,0 +1,242 @@
package http
import (
"encoding/json"
"errors"
"os"
"strings"
"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
)
func readAuthHeader(c *fiber.Ctx) string {
authHeader := c.Get("Authorization")
// elevenlabs
xApiKey := c.Get("xi-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
// anthropic
xApiKey = c.Get("x-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}
return authHeader
}
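The ordering above means an Anthropic-style x-api-key overrides an ElevenLabs xi-api-key, which in turn overrides a plain Authorization header; a small standalone sketch of that precedence (not the fiber handler itself):

package main

import "fmt"

// normalizeAuth mirrors readAuthHeader's precedence on plain strings.
func normalizeAuth(authorization, xiAPIKey, xAPIKey string) string {
	header := authorization
	if xiAPIKey != "" { // ElevenLabs clients
		header = "Bearer " + xiAPIKey
	}
	if xAPIKey != "" { // Anthropic-style clients take precedence
		header = "Bearer " + xAPIKey
	}
	return header
}

func main() {
	fmt.Println(normalizeAuth("Bearer sk-original", "elevenlabs-key", ""))              // Bearer elevenlabs-key
	fmt.Println(normalizeAuth("Bearer sk-original", "elevenlabs-key", "anthropic-key")) // Bearer anthropic-key
}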
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: appConfig.DisableMessage,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
},
})
if appConfig.Debug {
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
}
// Default middleware config
if !appConfig.Debug {
app.Use(recover.New())
}
metricsService, err := services.NewLocalAIMetricsService()
if err != nil {
return nil, err
}
if metricsService != nil {
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
app.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
})
}
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(appConfig.ApiKeys) == 0 {
return c.Next()
}
// Check for api_keys.json file
fileContent, err := os.ReadFile("api_keys.json")
if err == nil {
// Parse JSON content from the file
var fileKeys []string
err := json.Unmarshal(fileContent, &fileKeys)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
}
// Add file keys to options.ApiKeys
appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...)
}
if len(appConfig.ApiKeys) == 0 {
return c.Next()
}
authHeader := readAuthHeader(c)
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
// If it's a bearer token
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
for _, key := range appConfig.ApiKeys {
if apiKey == key {
return c.Next()
}
}
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if appConfig.CORS {
var c func(ctx *fiber.Ctx) error
if appConfig.CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
}
app.Use(c)
}
// LocalAI API endpoints
galleryService := services.NewGalleryService(appConfig.ModelPath)
galleryService.Start(appConfig.Context, cl)
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
// Load upload json
openai.LoadUploadConfig(appConfig.UploadDir)
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
// files
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
if appConfig.ImageDir != "" {
app.Static("/generated-images", appConfig.ImageDir)
}
if appConfig.AudioDir != "" {
app.Static("/generated-audio", appConfig.AudioDir)
}
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
// Kubernetes health checks
app.Get("/healthz", ok)
app.Get("/readyz", ok)
// Experimental Backend Statistics Module
backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
app.Get("/metrics", localai.LocalAIMetricsEndpoint())
return app, nil
}
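For context, a sketch of how this constructor is wired end to end, mirroring the pattern the tests below use (the model path and listen address are placeholders):

package main

import (
	"context"

	"github.com/go-skynet/LocalAI/core/config"
	"github.com/go-skynet/LocalAI/core/http"
	"github.com/go-skynet/LocalAI/core/startup"
)

func main() {
	cl, ml, appConfig, err := startup.Startup(
		config.WithContext(context.Background()),
		config.WithModelPath("/models"), // placeholder
		config.WithDebug(true),
	)
	if err != nil {
		panic(err)
	}

	app, err := http.App(cl, ml, appConfig)
	if err != nil {
		panic(err)
	}
	if err := app.Listen("127.0.0.1:8080"); err != nil { // placeholder address
		panic(err)
	}
}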


@@ -1,4 +1,4 @@
package api_test
package http_test
import (
"bytes"
@@ -13,9 +13,10 @@ import (
"path/filepath"
"runtime"
. "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/core/config"
. "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/startup"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model"
@@ -127,25 +128,33 @@ var backendAssets embed.FS
var _ = Describe("API test", func() {
var app *fiber.App
var modelLoader *model.ModelLoader
var client *openai.Client
var client2 *openaigo.Client
var c context.Context
var cancel context.CancelFunc
var tmpdir string
var modelDir string
var bcl *config.BackendConfigLoader
var ml *model.ModelLoader
var applicationConfig *config.ApplicationConfig
commonOpts := []options.AppOption{
options.WithDebug(true),
options.WithDisableMessage(true),
commonOpts := []config.AppOption{
config.WithDebug(true),
config.WithDisableMessage(true),
}
Context("API with ephemeral models", func() {
BeforeEach(func() {
BeforeEach(func(sc SpecContext) {
var err error
tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
modelLoader = model.NewModelLoader(tmpdir)
modelDir = filepath.Join(tmpdir, "models")
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
err = os.Mkdir(backendAssetsDir, 0755)
Expect(err).ToNot(HaveOccurred())
c, cancel = context.WithCancel(context.Background())
g := []gallery.GalleryModel{
@@ -172,16 +181,18 @@ var _ = Describe("API test", func() {
},
}
metricsService, err := metrics.SetupMetrics()
bcl, ml, applicationConfig, err = startup.Startup(
append(commonOpts,
config.WithContext(c),
config.WithGalleries(galleries),
config.WithModelPath(modelDir),
config.WithBackendAssets(backendAssets),
config.WithBackendAssetsOutput(backendAssetsDir))...)
Expect(err).ToNot(HaveOccurred())
app, err = App(
append(commonOpts,
options.WithMetrics(metricsService),
options.WithContext(c),
options.WithGalleries(galleries),
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
app, err = App(bcl, ml, applicationConfig)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@@ -198,15 +209,21 @@ var _ = Describe("API test", func() {
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
AfterEach(func(sc SpecContext) {
cancel()
app.Shutdown()
os.RemoveAll(tmpdir)
if app != nil {
err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
Expect(err).ToNot(HaveOccurred())
_, err = os.ReadDir(tmpdir)
Expect(err).To(HaveOccurred())
})
Context("Applying models", func() {
It("applies models from a gallery", func() {
It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/available")
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
@@ -228,10 +245,10 @@ var _ = Describe("API test", func() {
}, "360s", "10s").Should(Equal(true))
Expect(resp["message"]).ToNot(ContainSubstring("error"))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
dat, err := os.ReadFile(filepath.Join(modelDir, "bert2.yaml"))
Expect(err).ToNot(HaveOccurred())
_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml"))
_, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
@@ -253,6 +270,7 @@ var _ = Describe("API test", func() {
}
})
It("overrides models", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Name: "bert",
@@ -270,7 +288,7 @@ var _ = Describe("API test", func() {
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
@@ -294,7 +312,7 @@ var _ = Describe("API test", func() {
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{}
@@ -368,7 +386,7 @@ var _ = Describe("API test", func() {
var res map[string]string
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
Expect(err).ToNot(HaveOccurred())
Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res))
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
@@ -483,8 +501,11 @@ var _ = Describe("API test", func() {
var err error
tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
modelDir = filepath.Join(tmpdir, "models")
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
err = os.Mkdir(backendAssetsDir, 0755)
Expect(err).ToNot(HaveOccurred())
modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background())
galleries := []gallery.Gallery{
@@ -494,21 +515,20 @@ var _ = Describe("API test", func() {
},
}
metricsService, err := metrics.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
app, err = App(
bcl, ml, applicationConfig, err = startup.Startup(
append(commonOpts,
options.WithContext(c),
options.WithMetrics(metricsService),
options.WithAudioDir(tmpdir),
options.WithImageDir(tmpdir),
options.WithGalleries(galleries),
options.WithModelLoader(modelLoader),
options.WithBackendAssets(backendAssets),
options.WithBackendAssetsOutput(tmpdir))...,
config.WithContext(c),
config.WithAudioDir(tmpdir),
config.WithImageDir(tmpdir),
config.WithGalleries(galleries),
config.WithModelPath(modelDir),
config.WithBackendAssets(backendAssets),
config.WithBackendAssetsOutput(tmpdir))...,
)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@@ -527,8 +547,14 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
app.Shutdown()
os.RemoveAll(tmpdir)
if app != nil {
err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
Expect(err).ToNot(HaveOccurred())
_, err = os.ReadDir(tmpdir)
Expect(err).To(HaveOccurred())
})
It("installs and is capable to run tts", Label("tts"), func() {
if runtime.GOOS != "linux" {
@@ -599,20 +625,20 @@ var _ = Describe("API test", func() {
Context("API query", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
modelPath := os.Getenv("MODELS_PATH")
c, cancel = context.WithCancel(context.Background())
metricsService, err := metrics.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
var err error
app, err = App(
bcl, ml, applicationConfig, err = startup.Startup(
append(commonOpts,
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
options.WithContext(c),
options.WithModelLoader(modelLoader),
options.WithMetrics(metricsService),
config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
config.WithContext(c),
config.WithModelPath(modelPath),
)...)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@@ -630,7 +656,10 @@ var _ = Describe("API test", func() {
})
AfterEach(func() {
cancel()
app.Shutdown()
if app != nil {
err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
})
It("returns the models list", func() {
models, err := client.ListModels(context.TODO())
@@ -811,20 +840,20 @@ var _ = Describe("API test", func() {
Context("Config file", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
modelPath := os.Getenv("MODELS_PATH")
c, cancel = context.WithCancel(context.Background())
metricsService, err := metrics.SetupMetrics()
Expect(err).ToNot(HaveOccurred())
app, err = App(
var err error
bcl, ml, applicationConfig, err = startup.Startup(
append(commonOpts,
options.WithContext(c),
options.WithMetrics(metricsService),
options.WithModelLoader(modelLoader),
options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
config.WithContext(c),
config.WithModelPath(modelPath),
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@@ -840,7 +869,10 @@ var _ = Describe("API test", func() {
})
AfterEach(func() {
cancel()
app.Shutdown()
if app != nil {
err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
})
It("can generate chat completions from config file (list1)", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})


@@ -1,4 +1,4 @@
package api_test
package http_test
import (
"testing"



@@ -0,0 +1,55 @@
package elevenlabs
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.ElevenLabsTTSRequest)
voiceID := c.Params("voice-id")
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
}
log.Debug().Msgf("Request for model: %s", modelFile)
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg)
if err != nil {
return err
}
return c.Download(filePath)
}
}
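A rough client-side sketch for this route; the JSON field names ("text", "model_id"), the voice ID, and the port are assumptions to verify against schema.ElevenLabsTTSRequest:

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"os"
)

func main() {
	// Field names and values are illustrative; check schema.ElevenLabsTTSRequest.
	body := bytes.NewBufferString(`{"text":"Hello from LocalAI","model_id":"my-tts-model"}`)
	req, err := http.NewRequest(http.MethodPost,
		"http://localhost:8080/v1/text-to-speech/my-voice-id", body)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("xi-api-key", "sk-local-example") // readAuthHeader maps this to a Bearer token

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, err := os.Create("speech.out") // the handler replies with a file download
	if err != nil {
		panic(err)
	}
	defer out.Close()
	n, _ := io.Copy(out, resp.Body)
	fmt.Println(resp.Status, n, "bytes written")
}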

View File

@@ -0,0 +1,36 @@
package localai
import (
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/services"
"github.com/gofiber/fiber/v2"
)
func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
resp, err := bm.CheckAndSample(input.Model)
if err != nil {
return err
}
return c.JSON(resp)
}
}
func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
return bm.ShutdownModel(input.Model)
}
}
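A minimal client sketch for the shutdown handler; the "model" field name is assumed from schema.BackendMonitorRequest and the port is a placeholder (note that /backend/monitor is registered as a GET that parses the same body):

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	payload := []byte(`{"model":"gpt4all"}`) // field name assumed from schema.BackendMonitorRequest

	resp, err := http.Post("http://localhost:8080/backend/shutdown", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(out))
}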


@@ -0,0 +1,146 @@
package localai
import (
"encoding/json"
"fmt"
"slices"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
type ModelGalleryEndpointService struct {
galleries []gallery.Gallery
modelPath string
galleryApplier *services.GalleryService
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
return ModelGalleryEndpointService{
galleries: galleries,
modelPath: modelPath,
galleryApplier: galleryApplier,
}
}
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.GetAllStatus())
}
}
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.C <- gallery.GalleryOp{
Req: input.GalleryModel,
Id: uuid.String(),
GalleryName: input.ID,
Galleries: mgs.galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
if err != nil {
return err
}
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s already exists", input.Name)
}
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
log.Debug().Msgf("Adding %+v to gallery list", *input)
mgs.galleries = append(mgs.galleries, *input)
return c.Send(dat)
}
}
func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s is not currently registered", input.Name)
}
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
})
return c.Send(nil)
}
}


@@ -0,0 +1,43 @@
package localai
import (
"time"
"github.com/go-skynet/LocalAI/core/services"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func LocalAIMetricsEndpoint() fiber.Handler {
return adaptor.HTTPHandler(promhttp.Handler())
}
type apiMiddlewareConfig struct {
Filter func(c *fiber.Ctx) bool
metricsService *services.LocalAIMetricsService
}
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
cfg := apiMiddlewareConfig{
metricsService: metrics,
Filter: func(c *fiber.Ctx) bool {
return c.Path() == "/metrics"
},
}
return func(c *fiber.Ctx) error {
if cfg.Filter != nil && cfg.Filter(c) {
return c.Next()
}
path := c.Path()
method := c.Method()
start := time.Now()
err := c.Next()
elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metricsService.ObserveAPICall(method, path, elapsed)
return err
}
}


@@ -0,0 +1,55 @@
package localai
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
}
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg)
if err != nil {
return err
}
return c.Download(filePath)
}
}


@@ -0,0 +1,609 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
emptyMessage := ""
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: usage.Prompt,
CompletionTokens: usage.Completion,
TotalTokens: usage.Prompt + usage.Completion,
},
}
responses <- resp
return true
})
close(responses)
}
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
result := ""
_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s
// TODO: Change generated BNF grammar to be compliant with the schema so we can
// stream the result token by token here.
return true
})
results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
noActionToRun := len(results) > 0 && results[0].name == noAction
switch {
case noActionToRun:
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
resp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
Object: "chat.completion.chunk",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
responses <- resp
default:
for i, ss := range results {
name, args := ss.name, ss.arguments
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
},
},
},
}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
responses <- schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Arguments: args,
},
},
},
}}},
Object: "chat.completion.chunk",
}
}
}
close(responses)
}
return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
modelFile, input, err := readRequest(c, ml, startupOptions, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("Configuration read: %+v", config)
// Allow the user to set custom actions via config file
// to be "embedded" in each model
noActionName := "answer"
noActionDescription := "use this action to answer without performing any action"
if config.FunctionsConfig.NoActionFunctionName != "" {
noActionName = config.FunctionsConfig.NoActionFunctionName
}
if config.FunctionsConfig.NoActionDescriptionName != "" {
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
}
if input.ResponseFormat.Type == "json_object" {
input.Grammar = grammar.JSONBNF
}
// process functions if we have any defined or if we have a function call string
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
log.Debug().Msgf("Response needs to process functions")
processFunctions = true
noActionGrammar := grammar.Function{
Name: noActionName,
Description: noActionDescription,
Parameters: map[string]interface{}{
"properties": map[string]interface{}{
"message": map[string]interface{}{
"type": "string",
"description": "The message to reply the user with",
}},
},
}
// Append the no action function
funcs = append(funcs, input.Functions...)
if !config.FunctionsConfig.DisableNoAction {
funcs = append(funcs, noActionGrammar)
}
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
funcs = funcs.Select(config.FunctionToCall())
}
// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
} else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
}
// functions are not supported in stream mode (yet?)
toStream := input.Stream
log.Debug().Msgf("Parameters: %+v", config)
var predInput string
suppressConfigSystemPrompt := false
mess := []string{}
for messageIndex, i := range input.Messages {
var content string
role := i.Role
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
if i.FunctionCall != nil && i.Role == "assistant" {
roleFn := "assistant_function_call"
r := config.Roles[roleFn]
if r != "" {
role = roleFn
}
}
r := config.Roles[role]
contentExists := i.Content != nil && i.StringContent != ""
// First attempt to populate content via a chat message specific template
if config.TemplateConfig.ChatMessage != "" {
chatMessageData := model.ChatMessageTemplateData{
SystemPrompt: config.SystemPrompt,
Role: r,
RoleName: role,
Content: i.StringContent,
FunctionName: i.Name,
MessageIndex: messageIndex,
}
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil {
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
} else {
if templatedChatMessage == "" {
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
}
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
content = templatedChatMessage
}
}
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
if content == "" {
if r != "" {
if contentExists {
content = fmt.Sprint(r, i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + fmt.Sprint(r, " ", string(j))
} else {
content = fmt.Sprint(r, " ", string(j))
}
}
}
} else {
if contentExists {
content = fmt.Sprint(i.StringContent)
}
if i.FunctionCall != nil {
j, err := json.Marshal(i.FunctionCall)
if err == nil {
if contentExists {
content += "\n" + string(j)
} else {
content = string(j)
}
}
}
}
// Special Handling: System. We care if it was printed at all, not the r branch, so check separately
if contentExists && role == "system" {
suppressConfigSystemPrompt = true
}
}
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
log.Debug().Msgf("Prompt (before templating): %s", predInput)
if toStream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Functions != "" && processFunctions {
templateFile = config.TemplateConfig.Functions
}
if templateFile != "" {
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}
log.Debug().Msgf("Prompt (after templating): %s", predInput)
if processFunctions {
log.Debug().Msgf("Grammar: %+v", config.Grammar)
}
switch {
case toStream:
responses := make(chan schema.OpenAIResponse)
if !processFunctions {
go process(predInput, input, config, ml, responses)
} else {
go processTools(noActionName, predInput, input, config, ml, responses)
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
toolsCalled := false
for ev := range responses {
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
break
}
w.Flush()
}
finishReason := "stop"
if toolsCalled {
finishReason = "tool_calls"
} else if toolsCalled && len(input.Tools) == 0 {
finishReason = "function_call"
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: finishReason,
Index: 0,
Delta: &schema.Message{Content: &emptyMessage},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return nil
// no streaming mode
default:
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if !processFunctions {
// no function is called, just reply and use stop as finish reason
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}
results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
noActionsToRun := len(results) > 0 && results[0].name == noActionName
switch {
case noActionsToRun:
result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
*c = append(*c, schema.Choice{
Message: &schema.Message{Role: "assistant", Content: &result}})
default:
toolChoice := schema.Choice{
Message: &schema.Message{
Role: "assistant",
},
}
if len(input.Tools) > 0 {
toolChoice.FinishReason = "tool_calls"
}
for _, ss := range results {
name, args := ss.name, ss.arguments
if len(input.Tools) > 0 {
// If we are using tools, we condense the function calls into
// a single response choice with all the tools
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
)
} else {
// otherwise we return more choices directly
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &schema.Message{
Role: "assistant",
FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
},
})
}
}
if len(input.Tools) > 0 {
// we need to append our result if we are using tools
*c = append(*c, toolChoice)
}
}
}, nil)
if err != nil {
return err
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
Usage: schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
},
}
respData, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", respData)
// Return the prediction in the response body
return c.JSON(resp)
}
}
}
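A client-side sketch that exercises this endpoint the way the test suite does, assuming the sashabaranov/go-openai client; the base URL, key, and model name are placeholders (the key is ignored when no API keys are configured):

package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	cfg := openai.DefaultConfig("sk-local-example") // ignored unless ApiKeys is set
	cfg.BaseURL = "http://localhost:8080/v1"        // point the OpenAI client at LocalAI
	client := openai.NewClientWithConfig(cfg)

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model:    "gpt4all", // placeholder model name
		Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "Say hello"}},
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Choices[0].Message.Content)
}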
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
log.Debug().Msgf("nothing to do, computing a reply")
// If there is a message that the LLM already sends as part of the JSON reply, use it
arguments := map[string]interface{}{}
json.Unmarshal([]byte(args), &arguments)
m, exists := arguments["message"]
if exists {
switch message := m.(type) {
case string:
if message != "" {
log.Debug().Msgf("Reply received from LLM: %s", message)
message = backend.Finetune(*config, prompt, message)
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
return message, nil
}
}
}
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in terms of CPU/GPU) another computation
config.Grammar = ""
images := []string{}
for _, m := range input.Messages {
images = append(images, m.StringImages...)
}
predFunc, err := backend.ModelInference(input.Context, prompt, images, ml, *config, o, nil)
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return "", err
}
prediction, err := predFunc()
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return "", err
}
return backend.Finetune(*config, prompt, prediction.Response), nil
}
type funcCallResults struct {
name string
arguments string
}
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
results := []funcCallResults{}
// TODO: use generics to avoid this code duplication
if multipleResults {
ss := []map[string]interface{}{}
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
for _, s := range ss {
func_name, ok := s["function"]
if !ok {
continue
}
args, ok := s["arguments"]
if !ok {
continue
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
continue
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
} else {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevents newlines from breaking JSON parsing for clients
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)
// The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := ss["function"]
if !ok {
return results
}
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually wants a stringified object
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
return results
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
return results
}
results = append(results, funcCallResults{name: funcName, arguments: string(d)})
}
return results
}
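To make the two shapes concrete, an illustrative sketch of the grammar output this parser expects — a single object with "function"/"arguments" keys, or an array of them when parallel_calls is enabled (the function name and arguments are made up):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	single := `{"function":"get_weather","arguments":{"location":"Rome","unit":"celsius"}}`
	parallel := `[{"function":"get_weather","arguments":{"location":"Rome"}},
	              {"function":"get_weather","arguments":{"location":"Paris"}}]`

	var one map[string]interface{}
	_ = json.Unmarshal([]byte(single), &one)
	args, _ := json.Marshal(one["arguments"]) // stringified, as OpenAI clients expect
	fmt.Println(one["function"], string(args))

	var many []map[string]interface{}
	_ = json.Unmarshal([]byte(parallel), &many)
	for _, call := range many {
		args, _ := json.Marshal(call["arguments"])
		fmt.Println(call["function"], string(args))
	}
}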


@@ -8,10 +8,10 @@ import (
"fmt"
"time"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@@ -21,12 +21,12 @@ import (
)
// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
id := uuid.New().String()
created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{
ID: id,
Created: created,
@@ -53,14 +53,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
}
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, o, true)
modelFile, input, err := readRequest(c, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
log.Debug().Msgf("`input`: %+v", input)
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@@ -84,7 +84,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
@@ -100,7 +100,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
predInput := config.PromptStrings[0]
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput,
})
if err == nil {
@@ -111,7 +111,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, o.Loader, responses)
go process(predInput, input, config, ml, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@@ -153,7 +153,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
for k, i := range config.PromptStrings {
if templateFile != "" {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
@@ -164,7 +164,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
}
r, tokenUsage, err := ComputeChoices(
input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {

View File

@@ -5,10 +5,10 @@ import (
"fmt"
"time"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
@@ -16,14 +16,14 @@ import (
"github.com/rs/zerolog/log"
)
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, o, true)
modelFile, input, err := readRequest(c, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@@ -33,7 +33,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}
@@ -46,7 +46,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
for _, i := range config.InputStrings {
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
@@ -57,7 +57,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
}
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
}, nil)
if err != nil {

View File

@@ -5,25 +5,26 @@ import (
"fmt"
"time"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/go-skynet/LocalAI/api/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/embeddings
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readRequest(c, o, true)
model, input, err := readRequest(c, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@@ -33,7 +34,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o)
embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
if err != nil {
return err
}
@@ -47,7 +48,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o)
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
if err != nil {
return err
}

View File

@@ -0,0 +1,218 @@
package openai
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
var uploadedFiles []File
const uploadedFilesFile = "uploadedFiles.json"
// File represents the structure of a file object from the OpenAI API.
type File struct {
ID string `json:"id"` // Unique identifier for the file
Object string `json:"object"` // Type of the object (e.g., "file")
Bytes int `json:"bytes"` // Size of the file in bytes
CreatedAt time.Time `json:"created_at"` // The time at which the file was created
Filename string `json:"filename"` // The name of the file
Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.)
}
func saveUploadConfig(uploadDir string) {
file, err := json.MarshalIndent(uploadedFiles, "", " ")
if err != nil {
log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err)
}
err = os.WriteFile(filepath.Join(uploadDir, uploadedFilesFile), file, 0644)
if err != nil {
log.Error().Msgf("Failed to save uploadedFiles to file: %s", err)
}
}
func LoadUploadConfig(uploadPath string) {
uploadFilePath := filepath.Join(uploadPath, uploadedFilesFile)
_, err := os.Stat(uploadFilePath)
if os.IsNotExist(err) {
log.Debug().Msgf("No uploadedFiles file found at %s", uploadFilePath)
return
}
file, err := os.ReadFile(uploadFilePath)
if err != nil {
log.Error().Msgf("Failed to read file: %s", err)
} else {
err = json.Unmarshal(file, &uploadedFiles)
if err != nil {
log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err)
}
}
}
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := c.FormFile("file")
if err != nil {
return err
}
// Check the file size
if file.Size > int64(appConfig.UploadLimitMB*1024*1024) {
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, appConfig.UploadLimitMB))
}
purpose := c.FormValue("purpose", "") // TODO: store files in purpose-specific directories
if purpose == "" {
return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined")
}
// Sanitize the filename to prevent directory traversal
filename := utils.SanitizeFileName(file.Filename)
savePath := filepath.Join(appConfig.UploadDir, filename)
// Check if file already exists
if _, err := os.Stat(savePath); !os.IsNotExist(err) {
return c.Status(fiber.StatusBadRequest).SendString("File already exists")
}
err = c.SaveFile(file, savePath)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error())
}
f := File{
ID: fmt.Sprintf("file-%d", time.Now().Unix()),
Object: "file",
Bytes: int(file.Size),
CreatedAt: time.Now(),
Filename: file.Filename,
Purpose: purpose,
}
uploadedFiles = append(uploadedFiles, f)
saveUploadConfig(appConfig.UploadDir)
return c.Status(fiber.StatusOK).JSON(f)
}
}
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
type ListFiles struct {
Data []File
Object string
}
return func(c *fiber.Ctx) error {
var listFiles ListFiles
purpose := c.Query("purpose")
if purpose == "" {
listFiles.Data = uploadedFiles
} else {
for _, f := range uploadedFiles {
if purpose == f.Purpose {
listFiles.Data = append(listFiles.Data, f)
}
}
}
listFiles.Object = "list"
return c.Status(fiber.StatusOK).JSON(listFiles)
}
}
func getFileFromRequest(c *fiber.Ctx) (*File, error) {
id := c.Params("file_id")
if id == "" {
return nil, fmt.Errorf("file_id parameter is required")
}
for _, f := range uploadedFiles {
if id == f.ID {
return &f, nil
}
}
return nil, fmt.Errorf("unable to find file id %s", id)
}
// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.JSON(file)
}
}
// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
type DeleteStatus struct {
Id string
Object string
Deleted bool
}
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
if err != nil {
// If the file doesn't exist on disk, ignore the error and still remove it from the list
if !errors.Is(err, os.ErrNotExist) {
return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err))
}
}
// Remove upload from list
for i, f := range uploadedFiles {
if f.ID == file.ID {
uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...)
break
}
}
saveUploadConfig(appConfig.UploadDir)
return c.JSON(DeleteStatus{
Id: file.ID,
Object: "file",
Deleted: true,
})
}
}
// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.Send(fileContents)
}
}
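For reference, a successful upload responds with a JSON body shaped by the File struct above. A minimal sketch of that shape, produced by marshalling the same fields (all values here are illustrative):

package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// Mirrors the File struct defined in the endpoint above; values are made up.
type File struct {
	ID        string    `json:"id"`
	Object    string    `json:"object"`
	Bytes     int       `json:"bytes"`
	CreatedAt time.Time `json:"created_at"`
	Filename  string    `json:"filename"`
	Purpose   string    `json:"purpose"`
}

func main() {
	f := File{
		ID:        fmt.Sprintf("file-%d", time.Now().Unix()),
		Object:    "file",
		Bytes:     5242880, // 5 MiB
		CreatedAt: time.Now(),
		Filename:  "test.txt",
		Purpose:   "fine-tune",
	}
	out, _ := json.MarshalIndent(f, "", "  ")
	fmt.Println(string(out))
}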

View File

@@ -0,0 +1,287 @@
package openai
import (
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"github.com/go-skynet/LocalAI/core/config"
utils2 "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"testing"
)
type ListFiles struct {
Data []File
Object string
}
func startUpApp() (app *fiber.App, option *config.ApplicationConfig, loader *config.BackendConfigLoader) {
// Preparing the mocked objects
loader = &config.BackendConfigLoader{}
option = &config.ApplicationConfig{
UploadLimitMB: 10,
UploadDir: "test_dir",
}
_ = os.RemoveAll(option.UploadDir)
app = fiber.New(fiber.Config{
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
})
// Create a Test Server
app.Post("/files", UploadFilesEndpoint(loader, option))
app.Get("/files", ListFilesEndpoint(loader, option))
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
return
}
func TestUploadFileExceedSizeLimit(t *testing.T) {
// Preparing the mocked objects
loader := &config.BackendConfigLoader{}
option := &config.ApplicationConfig{
UploadLimitMB: 10,
UploadDir: "test_dir",
}
_ = os.RemoveAll(option.UploadDir)
app := fiber.New(fiber.Config{
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
})
// Create a Test Server
app.Post("/files", UploadFilesEndpoint(loader, option))
app.Get("/files", ListFilesEndpoint(loader, option))
app.Get("/files/:file_id", GetFilesEndpoint(loader, option))
app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option))
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) {
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
assert.Contains(t, bodyToString(resp, t), "exceeds upload limit")
})
t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) {
resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option)
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
assert.Contains(t, bodyToString(resp, t), "Purpose is not defined")
})
t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) {
f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option)
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option)
fmt.Println(f1)
fmt.Printf("ERror: %v", err)
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
assert.Contains(t, bodyToString(resp, t), "File already exists")
})
t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) {
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
// Check if the file exists on disk
filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName("test.txt"))
_, err := os.Stat(filePath)
assert.False(t, os.IsNotExist(err))
assert.Equal(t, file.Bytes, 5242880)
assert.NotEmpty(t, file.CreatedAt)
assert.Equal(t, file.Filename, "test.txt")
assert.Equal(t, file.Purpose, "fine-tune")
})
t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) {
resp, err := CallListFilesEndpoint(t, app, "")
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
listFiles := responseToListFile(t, resp)
if len(listFiles.Data) != len(uploadedFiles) {
t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data))
}
})
t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) {
_ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
resp, err := CallListFilesEndpoint(t, app, "fine-tune")
assert.NoError(t, err)
listFiles := responseToListFile(t, resp)
if len(listFiles.Data) != 1 {
t.Errorf("Expected 1 file, got %v files", len(listFiles.Data))
}
})
t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) {
resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune")
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
listFiles := responseToListFile(t, resp)
if len(listFiles.Data) != 0 {
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
}
})
t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) {
req := httptest.NewRequest("GET", "/files", nil)
resp, _ := app.Test(req)
assert.Equal(t, 200, resp.StatusCode)
var listFiles ListFiles
if err := json.Unmarshal(bodyToByteArray(resp, t), &listFiles); err != nil {
t.Errorf("Failed to decode response: %v", err)
return
}
if len(listFiles.Data) != 0 {
t.Errorf("Expected 0 file, got %v files", len(listFiles.Data))
}
})
}
func CallListFilesEndpoint(t *testing.T, app *fiber.App, purpose string) (*http.Response, error) {
var target string
if purpose != "" {
target = fmt.Sprintf("/files?purpose=%s", purpose)
} else {
target = "/files"
}
req := httptest.NewRequest("GET", target, nil)
return app.Test(req)
}
func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
request := httptest.NewRequest("GET", "/files?file_id="+fileId, nil)
return app.Test(request)
}
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
// Create a test file of the requested size
file := createTestFile(t, fileName, fileSize, appConfig)
// Creating a new HTTP Request
body, writer := newMultipartFile(file.Name(), tag, purpose)
req := httptest.NewRequest(http.MethodPost, "/files", body)
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
return app.Test(req)
}
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File {
// Create a test file of the requested size
file := createTestFile(t, fileName, fileSize, appConfig)
// Creating a new HTTP Request
body, writer := newMultipartFile(file.Name(), tag, purpose)
req := httptest.NewRequest(http.MethodPost, "/files", body)
req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType())
resp, err := app.Test(req)
assert.NoError(t, err)
f := responseToFile(t, resp)
id := f.ID
t.Cleanup(func() {
_, err := CallFilesDeleteEndpoint(t, app, id)
assert.NoError(t, err)
})
return f
}
func CallFilesDeleteEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) {
target := fmt.Sprintf("/files/%s", fileId)
req := httptest.NewRequest(http.MethodDelete, target, nil)
return app.Test(req)
}
// Helper to create multi-part file
func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipart.Writer) {
body := new(strings.Builder)
writer := multipart.NewWriter(body)
file, _ := os.Open(filePath)
defer file.Close()
part, _ := writer.CreateFormFile(tag, filepath.Base(filePath))
io.Copy(part, file)
if purpose != "" {
_ = writer.WriteField("purpose", purpose)
}
writer.Close()
return strings.NewReader(body.String()), writer
}
// Helper to create test files
func createTestFile(t *testing.T, name string, sizeMB int, option *config.ApplicationConfig) *os.File {
err := os.MkdirAll(option.UploadDir, 0755)
if err != nil {
t.Fatalf("Error MKDIR: %v", err)
}
file, _ := os.Create(name)
file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // fill the file with sizeMB megabytes of data
t.Cleanup(func() {
os.Remove(name)
os.RemoveAll(option.UploadDir)
})
return file
}
func bodyToString(resp *http.Response, t *testing.T) string {
return string(bodyToByteArray(resp, t))
}
func bodyToByteArray(resp *http.Response, t *testing.T) []byte {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
return bodyBytes
}
func responseToFile(t *testing.T, resp *http.Response) File {
var file File
responseToString := bodyToString(resp, t)
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&file)
if err != nil {
t.Errorf("Failed to decode response: %s", err)
}
return file
}
func responseToListFile(t *testing.T, resp *http.Response) ListFiles {
var listFiles ListFiles
responseToString := bodyToString(resp, t)
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles)
if err != nil {
fmt.Printf("Failed to decode response: %s", err)
}
return listFiles
}

View File

@@ -13,12 +13,12 @@ import (
"strings"
"time"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/core/backend"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
@@ -59,9 +59,9 @@ func downloadFile(url string) (string, error) {
*
*/
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, o, false)
m, input, err := readRequest(c, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@@ -71,7 +71,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
}
log.Debug().Msgf("Loading model: %+v", m)
config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false)
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@@ -104,7 +104,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
}
// Create a temporary file
outputFile, err := os.CreateTemp(o.ImageDir, "b64")
outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
if err != nil {
return err
}
@@ -133,15 +133,15 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
sizeParts := strings.Split(input.Size, "x")
if len(sizeParts) != 2 {
return fmt.Errorf("Invalid value for 'size'")
return fmt.Errorf("invalid value for 'size'")
}
width, err := strconv.Atoi(sizeParts[0])
if err != nil {
return fmt.Errorf("Invalid value for 'size'")
return fmt.Errorf("invalid value for 'size'")
}
height, err := strconv.Atoi(sizeParts[1])
if err != nil {
return fmt.Errorf("Invalid value for 'size'")
return fmt.Errorf("invalid value for 'size'")
}
b64JSON := false
@@ -179,7 +179,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
tempDir := ""
if !b64JSON {
tempDir = o.ImageDir
tempDir = appConfig.ImageDir
}
// Create a temporary file
outputFile, err := os.CreateTemp(tempDir, "b64")
@@ -196,7 +196,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
baseURL := c.BaseURL()
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o)
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
if err != nil {
return err
}

View File

@@ -1,18 +1,18 @@
package openai
import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ComputeChoices(
req *schema.OpenAIRequest,
predInput string,
config *config.Config,
o *options.Option,
config *config.BackendConfig,
o *config.ApplicationConfig,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {

View File

@@ -3,15 +3,15 @@ package openai
import (
"regexp"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error {
func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := loader.ListModels()
models, err := ml.ListModels()
if err != nil {
return err
}
@@ -40,7 +40,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
excludeConfigured := c.QueryBool("excludeConfigured", true)
// Start with the known configurations
for _, c := range cm.GetAllConfigs() {
for _, c := range cl.GetAllBackendConfigs() {
if excludeConfigured {
mm[c.Model] = nil
}

View File

@@ -5,24 +5,22 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"strings"
config "github.com/go-skynet/LocalAI/api/config"
fiberContext "github.com/go-skynet/LocalAI/api/ctx"
options "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *schema.OpenAIRequest, error) {
func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
@@ -30,9 +28,13 @@ func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *sch
received, _ := json.Marshal(input)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, firstModel)
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
return modelFile, input, err
}
@@ -49,7 +51,7 @@ func getBase64Image(s string) (string, error) {
defer resp.Body.Close()
// read the image data into memory
data, err := ioutil.ReadAll(resp.Body)
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
@@ -68,14 +70,14 @@ func getBase64Image(s string) (string, error) {
return "", fmt.Errorf("not valid string")
}
func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != 0 {
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != 0 {
if input.TopP != nil {
config.TopP = input.TopP
}
@@ -115,11 +117,11 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
config.Grammar = input.Grammar
}
if input.Temperature != 0 {
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != 0 {
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
@@ -136,6 +138,20 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice grammar.Tool
json.Unmarshal([]byte(input.ToolsChoice.(string)), &toolChoice)
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
index := 0
for i, m := range input.Messages {
@@ -177,30 +193,14 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
config.Batch = input.Batch
}
if input.F16 {
config.F16 = input.F16
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != 0 {
if input.Seed != nil {
config.Seed = input.Seed
}
if input.Mirostat != 0 {
config.LLMConfig.Mirostat = input.Mirostat
}
if input.MirostatETA != 0 {
config.LLMConfig.MirostatETA = input.MirostatETA
}
if input.MirostatTAU != 0 {
config.LLMConfig.MirostatTAU = input.MirostatTAU
}
if input.TypicalP != 0 {
config.TypicalP = input.TypicalP
}
@@ -255,8 +255,13 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
}
}
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) {
cfg, err := config.Load(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16)
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
config.LoadOptionDebug(debug),
config.LoadOptionThreads(threads),
config.LoadOptionContextSize(ctx),
config.LoadOptionF16(f16),
)
// Set the parameters for the language model prediction
updateRequestConfig(cfg, input)

View File

@@ -8,23 +8,23 @@ import (
"path"
"path/filepath"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// https://platform.openai.com/docs/api-reference/audio/create
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, o, false)
m, input, err := readRequest(c, ml, appConfig, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@@ -59,7 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
log.Debug().Msgf("Audio file copied to: %+v", dst)
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig)
if err != nil {
return err
}

View File

@@ -0,0 +1,6 @@
package schema
type ElevenLabsTTSRequest struct {
Text string `json:"text" yaml:"text"`
ModelID string `json:"model_id" yaml:"model_id"`
}

core/schema/localai.go (new file, 22 lines)
View File

@@ -0,0 +1,22 @@
package schema
import (
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}
type TTSRequest struct {
Model string `json:"model" yaml:"model"`
Input string `json:"input" yaml:"input"`
Voice string `json:"voice" yaml:"voice"`
Backend string `json:"backend" yaml:"backend"`
}
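The TTSRequest above is the payload for LocalAI's own TTS endpoint. A minimal sketch of building that payload in Go; the model, voice, and backend values are illustrative, not taken from the diff:

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors TTSRequest from core/schema/localai.go above; values are illustrative.
type TTSRequest struct {
	Model   string `json:"model"`
	Input   string `json:"input"`
	Voice   string `json:"voice"`
	Backend string `json:"backend"`
}

func main() {
	req := TTSRequest{
		Model:   "voice-model.onnx", // hypothetical voice model file
		Input:   "Hello from LocalAI",
		Voice:   "default",
		Backend: "piper",
	}
	out, _ := json.Marshal(req)
	fmt.Println(string(out))
	// {"model":"voice-model.onnx","input":"Hello from LocalAI","voice":"default","backend":"piper"}
}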

View File

@@ -3,8 +3,6 @@ package schema
import (
"context"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/grammar"
)
@@ -49,7 +47,7 @@ type OpenAIResponse struct {
type Choice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason,omitempty"`
FinishReason string `json:"finish_reason"`
Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"`
Text string `json:"text,omitempty"`
@@ -68,6 +66,10 @@ type ContentURL struct {
type Message struct {
// The message role
Role string `json:"role,omitempty" yaml:"role"`
// The message name (used for tools calls)
Name string `json:"name,omitempty" yaml:"name"`
// The message content
Content interface{} `json:"content" yaml:"content"`
@@ -76,6 +78,20 @@ type Message struct {
// A result of a function call
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty" yaml:"tool_call,omitempty"`
}
type ToolCall struct {
Index int `json:"index"`
ID string `json:"id"`
Type string `json:"type"`
FunctionCall FunctionCall `json:"function"`
}
type FunctionCall struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments"`
}
type OpenAIModel struct {
@@ -90,10 +106,10 @@ type ChatCompletionResponseFormat struct {
}
type OpenAIRequest struct {
config.PredictionOptions
PredictionOptions
Context context.Context
Cancel context.CancelFunc
Context context.Context `json:"-"`
Cancel context.CancelFunc `json:"-"`
// whisper
File string `json:"file" validate:"required"`
@@ -117,6 +133,9 @@ type OpenAIRequest struct {
Functions []grammar.Function `json:"functions" yaml:"functions"`
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
Tools []grammar.Tool `json:"tools,omitempty" yaml:"tools"`
ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"`
Stream bool `json:"stream"`
// Image (not supported by OpenAI)

View File

@@ -1,4 +1,4 @@
package api_config
package schema
type PredictionOptions struct {
@@ -12,28 +12,23 @@ type PredictionOptions struct {
N int `json:"n"`
// Common options between all the API calls, part of the OpenAI spec
TopP float64 `json:"top_p" yaml:"top_p"`
TopK int `json:"top_k" yaml:"top_k"`
Temperature float64 `json:"temperature" yaml:"temperature"`
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
Echo bool `json:"echo"`
TopP *float64 `json:"top_p" yaml:"top_p"`
TopK *int `json:"top_k" yaml:"top_k"`
Temperature *float64 `json:"temperature" yaml:"temperature"`
Maxtokens *int `json:"max_tokens" yaml:"max_tokens"`
Echo bool `json:"echo"`
// Custom parameters - not present in the OpenAI API
Batch int `json:"batch" yaml:"batch"`
F16 bool `json:"f16" yaml:"f16"`
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
Keep int `json:"n_keep" yaml:"n_keep"`
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
Mirostat int `json:"mirostat" yaml:"mirostat"`
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
TFZ float64 `json:"tfz" yaml:"tfz"`
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
Seed int `json:"seed" yaml:"seed"`
Seed *int `json:"seed" yaml:"seed"`
NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"`
RopeFreqBase float32 `json:"rope_freq_base" yaml:"rope_freq_base"`
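The switch from value fields (TopK int, Temperature float64, Seed int, ...) to pointers is what lets updateRequestConfig above test != nil instead of != 0: a pointer distinguishes "field omitted" from "field explicitly set to zero". A minimal sketch of that distinction, using a stand-in struct rather than the real PredictionOptions:

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-in for a single pointer-typed prediction option.
type opts struct {
	TopK *int `json:"top_k"`
}

func main() {
	var omitted, explicit opts
	_ = json.Unmarshal([]byte(`{}`), &omitted)            // field not sent: pointer stays nil
	_ = json.Unmarshal([]byte(`{"top_k": 0}`), &explicit) // caller explicitly asked for 0

	fmt.Println(omitted.TopK == nil)                         // true -> keep the configured default
	fmt.Println(explicit.TopK != nil && *explicit.TopK == 0) // true -> honour the explicit zero
}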

View File

@@ -0,0 +1,140 @@
package services
import (
"context"
"fmt"
"strings"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitor struct {
configLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor {
return BackendMonitor{
configLoader: configLoader,
modelLoader: modelLoader,
options: appConfig,
}
}
func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
config, exists := bm.configLoader.GetBackendConfig(modelName)
var backendId string
if exists {
backendId = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backendId = modelName
}
if !strings.HasSuffix(backendId, ".bin") {
backendId = fmt.Sprintf("%s.bin", backendId)
}
return backendId, nil
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetBackendConfig(model)
var backend string
if exists {
backend = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backend = model
}
if !strings.HasSuffix(backend, ".bin") {
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.modelLoader.GetGRPCPID(backend)
if err != nil {
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
return nil, err
}
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
backendProcess, err := gopsutil.NewProcess(int32(pid))
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
return nil, err
}
memInfo, err := backendProcess.MemoryInfo()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
return nil, err
}
memPercent, err := backendProcess.MemoryPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
return nil, err
}
cpuPercent, err := backendProcess.CPUPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
return nil, err
}
return &schema.BackendMonitorResponse{
MemoryInfo: memInfo,
MemoryPercent: memPercent,
CPUPercent: cpuPercent,
}, nil
}
func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
if err != nil {
return nil, err
}
modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
if modelAddr == "" {
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
}
status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
}
return &proto.StatusResponse{
State: proto.StatusResponse_ERROR,
Memory: &proto.MemoryUsageData{
Total: val.MemoryInfo.VMS,
Breakdown: map[string]uint64{
"gopsutil-RSS": val.MemoryInfo.RSS,
},
},
}, nil
}
return status, nil
}
func (bm BackendMonitor) ShutdownModel(modelName string) error {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
if err != nil {
return err
}
return bm.modelLoader.ShutdownModel(backendId)
}
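The two lookup paths above share the same naming fallback: if no backend config matches the model name, the raw name is used, and a ".bin" suffix is appended when missing. A minimal self-contained sketch of just that fallback (model names are illustrative):

package main

import (
	"fmt"
	"strings"
)

// Mirrors the fallback naming used by getModelLoaderIDFromModelName and
// SampleLocalBackendProcess above.
func backendID(modelName string) string {
	id := modelName
	if !strings.HasSuffix(id, ".bin") {
		id = fmt.Sprintf("%s.bin", id)
	}
	return id
}

func main() {
	fmt.Println(backendID("luna-ai-llama2"))     // luna-ai-llama2.bin
	fmt.Println(backendID("ggml-gpt4all-j.bin")) // unchanged
}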

Some files were not shown because too many files have changed in this diff.