Compare commits

...

41 Commits

Author SHA1 Message Date
Ettore Di Giacinto
de36a48861 Update gpt4all to fix thread counts (#249) 2023-05-13 09:37:46 +02:00
mudler
961ca93219 docs: Update README 2023-05-13 00:46:48 +02:00
Ettore Di Giacinto
557ccc5ad8 examples: add langchain-chroma example (#248) 2023-05-12 22:20:07 +02:00
Ettore Di Giacinto
2488c445b6 feat: bert.cpp token embeddings (#241) 2023-05-12 17:16:49 +02:00
Ettore Di Giacinto
b4241d0a0d tests: enable whisper (#239) 2023-05-12 14:10:18 +02:00
Ettore Di Giacinto
8250391e49 Add support for gptneox/replit (#238) 2023-05-12 11:36:35 +02:00
Ettore Di Giacinto
fd1df4e971 whisper: add tests and allow to set upload size (#237) 2023-05-12 10:04:20 +02:00
ci-robbot [bot]
5115b2faa3 ⬆️ Update go-skynet/go-llama.cpp (#219)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2023-05-11 23:43:55 +02:00
ci-robbot [bot]
93e82a8bf4 ⬆️ Update go-skynet/go-gpt2.cpp (#220)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2023-05-11 23:43:44 +02:00
Ettore Di Giacinto
4413defca5 feat: add starcoder (#236) 2023-05-11 20:20:07 +02:00
Ettore Di Giacinto
f359e1c6c4 fix: dolly/rp (#235) 2023-05-11 19:38:27 +02:00
renovate[bot]
1bc87d582d fix(deps): update module github.com/sashabaranov/go-openai to v1.9.4 (#230)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2023-05-11 18:02:19 +02:00
renovate[bot]
a86a383357 fix(deps): update github.com/donomii/go-rwkv.cpp digest to 07166da (#224)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2023-05-11 18:01:52 +02:00
renovate[bot]
16f02c7b30 fix(deps): update github.com/go-skynet/go-bert.cpp digest to ec771ec (#223)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2023-05-11 18:01:35 +02:00
Ettore Di Giacinto
fe2706890c Update README.md 2023-05-11 17:32:13 +02:00
Ettore Di Giacinto
85f0f8227d refactor: drop code dups (#234) 2023-05-11 16:34:16 +02:00
Ettore Di Giacinto
59e3c02002 make use of new bindings for gpt4all (#232) 2023-05-11 14:31:19 +02:00
Matthew Campbell
032dee256f Keep whisper models in memory (#233) 2023-05-11 14:05:07 +02:00
Matthew Campbell
6b5e2b2bf5 Upload transcription API wasn't reading the data from the post (#229) 2023-05-11 10:43:05 +02:00
renovate[bot]
6fc303de87 fix(deps): update github.com/go-skynet/go-llama.cpp digest to 70593fc (#221)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2023-05-11 10:34:34 +02:00
renovate[bot]
6ad6e4873d fix(deps): update github.com/ggerganov/whisper.cpp/bindings/go digest to 1d17cd5 (#216)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2023-05-11 01:14:34 +02:00
ci-robbot [bot]
d6d7391da8 ⬆️ Update donomii/go-rwkv.cpp (#225)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2023-05-11 01:13:28 +02:00
Ettore Di Giacinto
11675932ac feat: add dolly/redpajama/bloomz models support (#214) 2023-05-11 01:12:58 +02:00
Ettore Di Giacinto
f02202e1e1 update README 2023-05-10 15:51:16 +02:00
Ettore Di Giacinto
f8ee20991c feat: add bert.cpp embeddings (#222) 2023-05-10 15:20:21 +02:00
Cedrik Boudreau
e6db14e2f1 Added spark in projects (#215) 2023-05-10 14:05:44 +02:00
Dave
d00886abea Tiny .gitignore suggestion (#213) 2023-05-09 20:03:29 +02:00
renovate[bot]
4873d2bfa1 fix(deps): update github.com/go-skynet/go-llama.cpp digest to f4d26f4 (#212)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2023-05-09 12:17:15 +02:00
Ettore Di Giacinto
9f426578cf feat: add transcript endpoint (#211) 2023-05-09 11:43:50 +02:00
ci-robbot [bot]
9d01b695a8 ⬆️ Update go-skynet/go-llama.cpp (#209)
Signed-off-by: GitHub <noreply@github.com>
Co-authored-by: mudler <mudler@users.noreply.github.com>
2023-05-08 22:37:16 +02:00
Ettore Di Giacinto
93829ab228 docs: update news 2023-05-08 22:34:12 +02:00
renovate[bot]
dd234f86d5 fix(deps): update github.com/go-skynet/go-llama.cpp digest to c03e8ad (#208)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2023-05-08 20:30:52 +02:00
mudler
3daff6f1aa doc: update README with embeddings docs 2023-05-08 20:02:59 +02:00
Ettore Di Giacinto
89dfa0f5fc feat: add experimental support for embeddings as arrays (#207) 2023-05-08 19:31:18 +02:00
renovate[bot]
bc03c492a0 fix(deps): update module github.com/gofiber/fiber/v2 to v2.45.0 (#206)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2023-05-07 17:46:31 +02:00
Fabian Hachenberg
f50a4c1454 Added runpod.io template for LocalAI to examples (#203) 2023-05-07 10:58:15 +02:00
mudler
d13d4d95ce examples: fix default parameter 2023-05-07 10:13:57 +02:00
renovate[bot]
428790ec06 fix(deps): update github.com/go-skynet/go-llama.cpp digest to cf9b522 (#202)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
2023-05-07 09:14:00 +02:00
mudler
4f551ce414 examples: add update index example, update README 2023-05-07 09:05:24 +02:00
mudler
6ed7b10273 examples: add langchain agent example 2023-05-07 08:14:01 +02:00
mudler
02979566ee examples: better defaults 2023-05-07 00:58:30 +02:00
36 changed files with 1144 additions and 352 deletions

View File

@@ -21,6 +21,18 @@ jobs:
- repository: "donomii/go-rwkv.cpp"
variable: "RWKV_VERSION"
branch: "main"
- repository: "ggerganov/whisper.cpp"
variable: "WHISPER_CPP_VERSION"
branch: "master"
- repository: "go-skynet/go-bert.cpp"
variable: "BERT_VERSION"
branch: "master"
- repository: "go-skynet/bloomz.cpp"
variable: "BLOOMZ_VERSION"
branch: "main"
- repository: "go-skynet/gpt4all"
variable: "GPT4ALL_VERSION"
branch: "main"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

View File

@@ -21,7 +21,7 @@ jobs:
- name: Dependencies
run: |
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install build-essential ffmpeg
- name: Test
run: |
make test
@@ -38,7 +38,7 @@ jobs:
- name: Dependencies
run: |
brew update
brew install sdl2
brew install sdl2 ffmpeg
- name: Test
run: |
make test

6
.gitignore vendored
View File

@@ -3,6 +3,7 @@ go-llama
go-gpt4all-j
go-gpt2
go-rwkv
whisper.cpp
# LocalAI build binary
LocalAI
@@ -12,4 +13,7 @@ local-ai
# Ignore models
models/*
test-models/
test-models/
# just in case
.DS_Store

116
Makefile
View File

@@ -3,11 +3,16 @@ GOTEST=$(GOCMD) test
GOVET=$(GOCMD) vet
BINARY_NAME=local-ai
GOLLAMA_VERSION?=cf9b522db63898dcc5eb86e37c979ab85cbd583e
GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109
GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa
GOLLAMA_VERSION?=70593fccbe4b01dedaab805b0f25cb58192c7b38
GPT4ALL_REPO?=https://github.com/go-skynet/gpt4all
GPT4ALL_VERSION?=a330bfe26e9e35ca402e16df18973a3b162fb4db
GOGPT2_VERSION?=92421a8cf61ed6e03babd9067af292b094cb1307
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=af62fcc432be2847acb6e0688b2c2491d6588d58
RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47
WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993
BERT_VERSION?=ac22f8f74aec5e31bc46242c17e7d511f127856b
BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1
GREEN := $(shell tput -Txterm setaf 2)
YELLOW := $(shell tput -Txterm setaf 3)
@@ -15,8 +20,8 @@ WHITE := $(shell tput -Txterm setaf 7)
CYAN := $(shell tput -Txterm setaf 6)
RESET := $(shell tput -Txterm sgr0)
C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv
LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv
C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz
LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz
# Use this if you want to set the default behavior
ifndef BUILD_TYPE
@@ -33,19 +38,34 @@ endif
all: help
## GPT4ALL-J
go-gpt4all-j:
git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j
cd go-gpt4all-j && git checkout -b build $(GOGPT4ALLJ_VERSION) && git submodule update --init --recursive --depth 1
## GPT4ALL
gpt4all:
git clone --recurse-submodules $(GPT4ALL_REPO) gpt4all
cd gpt4all && git checkout -b build $(GPT4ALL_VERSION) && git submodule update --init --recursive --depth 1
# This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
@find ./go-gpt4all-j -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +
@find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/set_console_color/set_gptj_console_color/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/set_console_color/set_gptj_console_color/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.go" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.txt" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +
mv ./gpt4all/gpt4all-backend/llama.cpp/llama_util.h ./gpt4all/gpt4all-backend/llama.cpp/gptjllama_util.h
## BERT embeddings
go-bert:
git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp go-bert
cd go-bert && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1
@find ./go-bert -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} +
@find ./go-bert -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} +
@find ./go-bert -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} +
## RWKV
go-rwkv:
@@ -58,8 +78,23 @@ go-rwkv:
go-rwkv/librwkv.a: go-rwkv
cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. && cp ggml/src/libggml.a ..
go-gpt4all-j/libgptj.a: go-gpt4all-j
$(MAKE) -C go-gpt4all-j $(GENERIC_PREFIX)libgptj.a
## bloomz
bloomz:
git clone --recurse-submodules https://github.com/go-skynet/bloomz.cpp bloomz
@find ./bloomz -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} +
@find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} +
@find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} +
@find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} +
@find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} +
bloomz/libbloomz.a: bloomz
cd bloomz && make libbloomz.a
go-bert/libgobert.a: go-bert
$(MAKE) -C go-bert libgobert.a
gpt4all/gpt4all-bindings/golang/libgpt4all.a: gpt4all
$(MAKE) -C gpt4all/gpt4all-bindings/golang/ $(GENERIC_PREFIX)libgpt4all.a
## CEREBRAS GPT
go-gpt2:
@@ -69,13 +104,24 @@ go-gpt2:
@find ./go-gpt2 -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
@find ./go-gpt2 -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
@find ./go-gpt2 -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
@find ./go-gpt2 -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gpt2_/g' {} +
@find ./go-gpt2 -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gpt2_/g' {} +
@find ./go-gpt2 -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_print_usage/gpt2_print_usage/g' {} +
@find ./go-gpt2 -type f -name "*.h" -exec sed -i'' -e 's/gpt_print_usage/gpt2_print_usage/g' {} +
@find ./go-gpt2 -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_params_parse/gpt2_params_parse/g' {} +
@find ./go-gpt2 -type f -name "*.h" -exec sed -i'' -e 's/gpt_params_parse/gpt2_params_parse/g' {} +
@find ./go-gpt2 -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_random_prompt/gpt2_random_prompt/g' {} +
@find ./go-gpt2 -type f -name "*.h" -exec sed -i'' -e 's/gpt_random_prompt/gpt2_random_prompt/g' {} +
@find ./go-gpt2 -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gpt2_/g' {} +
go-gpt2/libgpt2.a: go-gpt2
$(MAKE) -C go-gpt2 $(GENERIC_PREFIX)libgpt2.a
whisper.cpp:
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp && git checkout -b build $(WHISPER_CPP_VERSION) && git submodule update --init --recursive --depth 1
whisper.cpp/libwhisper.a: whisper.cpp
cd whisper.cpp && make libwhisper.a
go-llama:
git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
cd go-llama && git checkout -b build $(GOLLAMA_VERSION) && git submodule update --init --recursive --depth 1
@@ -85,28 +131,36 @@ go-llama/libbinding.a: go-llama
replace:
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
$(GOCMD) mod edit -replace github.com/nomic/gpt4all/gpt4all-bindings/golang=$(shell pwd)/gpt4all/gpt4all-bindings/golang
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt2.cpp=$(shell pwd)/go-gpt2
$(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(shell pwd)/go-rwkv
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(shell pwd)/whisper.cpp
$(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(shell pwd)/go-bert
$(GOCMD) mod edit -replace github.com/go-skynet/bloomz.cpp=$(shell pwd)/bloomz
prepare-sources: go-llama go-gpt2 go-gpt4all-j go-rwkv
prepare-sources: go-llama go-gpt2 gpt4all go-rwkv whisper.cpp go-bert bloomz replace
$(GOCMD) mod download
## GENERIC
rebuild: ## Rebuilds the project
$(MAKE) -C go-llama clean
$(MAKE) -C go-gpt4all-j clean
$(MAKE) -C gpt4all/gpt4all-bindings/golang/ clean
$(MAKE) -C go-gpt2 clean
$(MAKE) -C go-rwkv clean
$(MAKE) -C whisper.cpp clean
$(MAKE) -C go-bert clean
$(MAKE) -C bloomz clean
$(MAKE) build
prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-gpt2/libgpt2.a go-rwkv/librwkv.a replace ## Prepares for building
prepare: prepare-sources gpt4all/gpt4all-bindings/golang/libgpt4all.a go-llama/libbinding.a go-bert/libgobert.a go-gpt2/libgpt2.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building
clean: ## Remove build related file
rm -fr ./go-llama
rm -rf ./go-gpt4all-j
rm -rf ./gpt4all
rm -rf ./go-gpt2
rm -rf ./go-rwkv
rm -rf ./go-bert
rm -rf ./bloomz
rm -rf $(BINARY_NAME)
## Build:
@@ -114,7 +168,7 @@ clean: ## Remove build related file
build: prepare ## Build the project
$(info ${GREEN}I local-ai build info:${RESET})
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -o $(BINARY_NAME) ./
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -x -o $(BINARY_NAME) ./
generic-build: ## Build the project using generic
BUILD_TYPE="generic" $(MAKE) build
@@ -125,12 +179,16 @@ run: prepare ## run local-ai
test-models/testmodel:
mkdir test-models
mkdir test-dir
wget https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerberas-111m-q4_0.bin -O test-models/testmodel
wget https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
wget https://huggingface.co/skeskinen/ggml/resolve/main/all-MiniLM-L6-v2/ggml-model-q4_0.bin -O test-models/bert
wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
cp tests/fixtures/* test-models
test: prepare test-models/testmodel
cp tests/fixtures/* test-models
@C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./...
@C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api
## Help:
help: ## Show this help.

View File

@@ -9,7 +9,7 @@
[![](https://dcbadge.vercel.app/api/server/uJAeKSAGDy?style=flat-square&theme=default-inverted)](https://discord.gg/uJAeKSAGDy)
**LocalAI** is a drop-in replacement REST API compatible with OpenAI for local CPU inferencing. It allows to run models locally or on-prem with consumer grade hardware. It is based on [llama.cpp](https://github.com/ggerganov/llama.cpp), [gpt4all](https://github.com/nomic-ai/gpt4all), [rwkv.cpp](https://github.com/saharNooby/rwkv.cpp) and [ggml](https://github.com/ggerganov/ggml), including support GPT4ALL-J which is licensed under Apache 2.0.
**LocalAI** is a drop-in replacement REST API compatible with OpenAI for local CPU inferencing. It allows to run models locally or on-prem with consumer grade hardware, supporting multiple models families. Supports also GPT4ALL-J which is licensed under Apache 2.0.
- OpenAI compatible API
- Supports multiple models
@@ -19,10 +19,17 @@
LocalAI is a community-driven project, focused on making the AI accessible to anyone. Any contribution, feedback and PR is welcome! It was initially created by [mudler](https://github.com/mudler/) at the [SpectroCloud OSS Office](https://github.com/spectrocloud).
LocalAI uses C++ bindings for optimizing speed. It is based on [llama.cpp](https://github.com/ggerganov/llama.cpp), [gpt4all](https://github.com/nomic-ai/gpt4all), [rwkv.cpp](https://github.com/saharNooby/rwkv.cpp), [ggml](https://github.com/ggerganov/ggml), [whisper.cpp](https://github.com/ggerganov/whisper.cpp) for audio transcriptions, and [bert.cpp](https://github.com/skeskinen/bert.cpp) for embedding.
See [examples on how to integrate LocalAI](https://github.com/go-skynet/LocalAI/tree/master/examples/).
## News
- 12-05-2023: __v1.10.0__ released! 🔥🔥 Updated `gpt4all` bindings. Added support for GPTNeox (experimental), RedPajama (experimental), Starcoder (experimental), Replit (experimental), MosaicML MPT. Also now `embeddings` endpoint supports tokens arrays. See the [langchain-chroma](https://github.com/go-skynet/LocalAI/tree/master/examples/langchain-chroma) example! Note - this update does NOT include https://github.com/ggerganov/llama.cpp/pull/1405 which makes models incompatible.
- 11-05-2023: __v1.9.0__ released! 🔥 Important whisper updates ( https://github.com/go-skynet/LocalAI/pull/233 https://github.com/go-skynet/LocalAI/pull/229 ) and extended gpt4all model families support ( https://github.com/go-skynet/LocalAI/pull/232 ). Redpajama/dolly experimental ( https://github.com/go-skynet/LocalAI/pull/214 )
- 10-05-2023: __v1.8.0__ released! 🔥 Added support for fast and accurate embeddings with `bert.cpp` ( https://github.com/go-skynet/LocalAI/pull/222 )
- 09-05-2023: Added experimental support for transcriptions endpoint ( https://github.com/go-skynet/LocalAI/pull/211 )
- 08-05-2023: Support for embeddings with models using the `llama.cpp` backend ( https://github.com/go-skynet/LocalAI/pull/207 )
- 02-05-2023: Support for `rwkv.cpp` models ( https://github.com/go-skynet/LocalAI/pull/158 ) and for `/edits` endpoint
- 01-05-2023: Support for SSE stream of tokens in `llama.cpp` backends ( https://github.com/go-skynet/LocalAI/pull/152 )
@@ -30,7 +37,8 @@ Twitter: [@LocalAI_API](https://twitter.com/LocalAI_API) and [@mudler_it](https:
### Blogs and articles
- [Tutorial to use k8sgpt with LocalAI](https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65) - excellent usecase for localAI, using AI to analyse Kubernetes clusters.
- [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/) by Ettore Di Giacinto
- [Tutorial to use k8sgpt with LocalAI](https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65) - excellent usecase for localAI, using AI to analyse Kubernetes clusters. by Tyller Gillson
## Contribute and help
@@ -68,7 +76,7 @@ Note: You might need to convert older models to the new format, see [here](https
A full example on how to run a rwkv model is in the [examples](https://github.com/go-skynet/LocalAI/tree/master/examples/rwkv).
Note: rwkv models have an associated tokenizer along that needs to be provided with it:
Note: rwkv models needs to specify the backend `rwkv` in the YAML config files and have an associated tokenizer along that needs to be provided with it:
```
36464540 -rw-r--r-- 1 mudler mudler 1.2G May 3 10:51 rwkv_small
@@ -86,6 +94,30 @@ It should also be compatible with StableLM and GPTNeoX ggml models (untested).
Depending on the model you are attempting to run might need more RAM or CPU resources. Check out also [here](https://github.com/ggerganov/llama.cpp#memorydisk-requirements) for `ggml` based backends. `rwkv` is less expensive on resources.
### Model compatibility table
<details>
| Backend | Compatible models | Completion/Chat endpoint | Audio transcription | Embeddings support | Token stream support | Github | Bindings |
|-----------------|-----------------------|--------------------------|---------------------|-----------------------------------|----------------------|--------------------------------------------|-------------------------------------------|
| llama | Vicuna, Alpaca, LLaMa | yes | no | yes (doesn't seem to be accurate) | yes | https://github.com/ggerganov/llama.cpp | https://github.com/go-skynet/go-llama.cpp |
| gpt4all-llama | Vicuna, Alpaca, LLaMa | yes | no | no | yes | https://github.com/nomic-ai/gpt4all | https://github.com/go-skynet/gpt4all |
| gpt4all-mpt | MPT | yes | no | no | yes | https://github.com/nomic-ai/gpt4all | https://github.com/go-skynet/gpt4all |
| gpt4all-j | GPT4ALL-J | yes | no | no | yes | https://github.com/nomic-ai/gpt4all | https://github.com/go-skynet/gpt4all |
| gpt2 | GPT/NeoX, Cerebras | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
| dolly | Dolly | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
| redpajama | RedPajama | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
| stableLM | StableLM GPT/NeoX | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
| replit | Replit | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
| gptneox | GPT NeoX | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
| starcoder | Starcoder | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
| bloomz | Bloom | yes | no | no | no | https://github.com/NouamaneTazi/bloomz.cpp | https://github.com/go-skynet/bloomz.cpp |
| rwkv | RWKV | yes | no | no | yes | https://github.com/saharNooby/rwkv.cpp | https://github.com/donomii/go-rwkv.cpp |
| bert-embeddings | bert | no | no | yes | no | https://github.com/skeskinen/bert.cpp | https://github.com/go-skynet/go-bert.cpp |
| whisper | whisper | no | yes | no | no | https://github.com/ggerganov/whisper.cpp | https://github.com/ggerganov/whisper.cpp |
</details>
## Usage
> `LocalAI` comes by default as a container image. You can check out all the available images with corresponding tags [here](https://quay.io/repository/go-skynet/local-ai?tab=tags&tag=latest).
@@ -529,6 +561,52 @@ curl http://localhost:8080/v1/models
</details>
### Embeddings
<details>
The embedding endpoint is experimental and enabled only if the model is configured with `embeddings: true` in its `yaml` file, for example:
```yaml
name: text-embedding-ada-002
parameters:
model: bert
embeddings: true
backend: "bert-embeddings"
```
There is an example available [here](https://github.com/go-skynet/LocalAI/tree/master/examples/query_data/).
Note: embeddings is supported only with `llama.cpp` compatible models and `bert` models. bert is more performant and available independently of the LLM model.
</details>
### Transcriptions endpoint
<details>
Note: requires ffmpeg in the container image, which is currently not shipped due to licensing issues. We will prepare separated images with ffmpeg. (stay tuned!)
Download one of the models from https://huggingface.co/ggerganov/whisper.cpp/tree/main in the `models` folder, and create a YAML file for your model:
```yaml
name: whisper-1
backend: whisper
parameters:
model: whisper-en
```
The transcriptions endpoint then can be tested like so:
```
wget --quiet --show-progress -O gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
curl http://localhost:8080/v1/audio/transcriptions -H "Content-Type: multipart/form-data" -F file="@$PWD/gb1.ogg" -F model="whisper-1"
{"text":"My fellow Americans, this day has brought terrible news and great sadness to our country.At nine o'clock this morning, Mission Control in Houston lost contact with our Space ShuttleColumbia.A short time later, debris was seen falling from the skies above Texas.The Columbia's lost.There are no survivors.One board was a crew of seven.Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain DavidBrown, Commander William McCool, Dr. Kultna Shavla, and Elon Ramon, a colonel in the IsraeliAir Force.These men and women assumed great risk in the service to all humanity.In an age when spaceflight has come to seem almost routine, it is easy to overlook thedangers of travel by rocket and the difficulties of navigating the fierce outer atmosphere ofthe Earth.These astronauts knew the dangers, and they faced them willingly, knowing they had a highand noble purpose in life.Because of their courage and daring and idealism, we will miss them all the more.All Americans today are thinking as well of the families of these men and women who havebeen given this sudden shock and grief.You're not alone.Our entire nation agrees with you, and those you loved will always have the respect andgratitude of this country.The cause in which they died will continue.Mankind has led into the darkness beyond our world by the inspiration of discovery andthe longing to understand.Our journey into space will go on.In the skies today, we saw destruction and tragedy.As farther than we can see, there is comfort and hope.In the words of the prophet Isaiah, \"Lift your eyes and look to the heavens who createdall these, he who brings out the starry hosts one by one and calls them each by name.\"Because of his great power and mighty strength, not one of them is missing.The same creator who names the stars also knows the names of the seven souls we mourntoday.The crew of the shuttle Columbia did not return safely to Earth yet we can pray that all aresafely home.May God bless the grieving families and may God continue to bless America.[BLANK_AUDIO]"}
```
</details>
## Frequently asked questions
Here are answers to some of the most common questions.
@@ -590,6 +668,7 @@ Feel free to open up a PR to get your project listed!
- [Kairos](https://github.com/kairos-io/kairos)
- [k8sgpt](https://github.com/k8sgpt-ai/k8sgpt#running-local-models)
- [Spark](https://github.com/cedriking/spark)
## Blog posts and other articles
@@ -621,6 +700,7 @@ MIT
- [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
- [go-skynet/go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp)
- [go-skynet/go-gpt2.cpp](https://github.com/go-skynet/go-gpt2.cpp)
- [go-skynet/go-bert.cpp](https://github.com/go-skynet/go-bert.cpp)
- [donomii/go-rwkv.cpp](https://github.com/donomii/go-rwkv.cpp)
## Acknowledgements

View File

@@ -12,7 +12,7 @@ import (
"github.com/rs/zerolog/log"
)
func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App {
func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
@@ -20,6 +20,7 @@ func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16
// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: disableMessage,
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
@@ -84,6 +85,8 @@ func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Get("/v1/models", listModels(loader, cm))
app.Get("/models", listModels(loader, cm))

View File

@@ -3,6 +3,8 @@ package api_test
import (
"context"
"os"
"path/filepath"
"runtime"
. "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/pkg/model"
@@ -23,7 +25,7 @@ var _ = Describe("API test", func() {
Context("API query", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
app = App("", modelLoader, 1, 512, false, true, true)
app = App("", modelLoader, 15, 1, 512, false, true, true)
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@@ -45,8 +47,7 @@ var _ = Describe("API test", func() {
It("returns the models list", func() {
models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(3))
Expect(models.Models[0].ID).To(Equal("testmodel"))
Expect(len(models.Models)).To(Equal(7))
})
It("can generate completions", func() {
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"})
@@ -79,15 +80,55 @@ var _ = Describe("API test", func() {
It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 5 errors occurred:"))
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:"))
})
It("transcribes audio", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateTranscription(
context.Background(),
openai.AudioRequest{
Model: openai.Whisper1,
FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"),
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting"))
})
It("calculate embeddings", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaEmbeddingV2,
Input: []string{"sun", "cat"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
sunEmbedding := resp.Data[0].Embedding
resp2, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaEmbeddingV2,
Input: []string{"sun"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
})
})
Context("Config file", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
app = App(os.Getenv("CONFIG_FILE"), modelLoader, 1, 512, false, true, true)
app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true)
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@@ -108,8 +149,7 @@ var _ = Describe("API test", func() {
models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(5))
Expect(models.Models[0].ID).To(Equal("testmodel"))
Expect(len(models.Models)).To(Equal(9))
})
It("can generate chat completions from config file", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
@@ -134,5 +174,6 @@ var _ = Describe("API test", func() {
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).ToNot(BeEmpty())
})
})
})

View File

@@ -33,6 +33,7 @@ type Config struct {
Mirostat int `yaml:"mirostat"`
PromptStrings, InputStrings []string
InputToken [][]int
}
type TemplateConfig struct {
@@ -186,8 +187,15 @@ func updateConfig(config *Config, input *OpenAIRequest) {
}
case []interface{}:
for _, pp := range inputs {
if s, ok := pp.(string); ok {
config.InputStrings = append(config.InputStrings, s)
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
@@ -277,5 +285,10 @@ func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug
}
}
// Enforce debug flag if passed from CLI
if debug {
config.Debug = true
}
return config, input, nil
}

View File

@@ -5,9 +5,17 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
model "github.com/go-skynet/LocalAI/pkg/model"
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
@@ -69,6 +77,11 @@ type OpenAIModel struct {
type OpenAIRequest struct {
Model string `json:"model" yaml:"model"`
// whisper
File string `json:"file" validate:"required"`
ResponseFormat string `json:"response_format"`
Language string `json:"language"`
// Prompt is read only by completion API calls
Prompt interface{} `json:"prompt" yaml:"prompt"`
@@ -177,10 +190,23 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
log.Debug().Msgf("Parameter Config: %+v", config)
items := []Item{}
for i, s := range config.InputStrings {
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := ModelEmbedding(s, loader, *config)
embedFn, err := ModelEmbedding("", s, loader, *config)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := ModelEmbedding(s, []int{}, loader, *config)
if err != nil {
return err
}
@@ -372,6 +398,70 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
}
}
// https://platform.openai.com/docs/api-reference/audio/create
func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
// retrieve the file data from the request
file, err := c.FormFile("file")
if err != nil {
return err
}
f, err := file.Open()
if err != nil {
return err
}
defer f.Close()
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return err
}
defer os.RemoveAll(dir)
dst := filepath.Join(dir, path.Base(file.Filename))
dstFile, err := os.Create(dst)
if err != nil {
return err
}
if _, err := io.Copy(dstFile, f); err != nil {
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
return err
}
log.Debug().Msgf("Audio file copied to: %+v", dst)
whisperModel, err := loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads))
if err != nil {
return err
}
if whisperModel == nil {
return fmt.Errorf("could not load whisper model")
}
w, ok := whisperModel.(whisper.Model)
if !ok {
return fmt.Errorf("loader returned non-whisper object")
}
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads))
if err != nil {
return err
}
log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr})
}
}
func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := loader.ListModels()

View File

@@ -8,9 +8,11 @@ import (
"github.com/donomii/go-rwkv.cpp"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/bloomz.cpp"
bert "github.com/go-skynet/go-bert.cpp"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
gpt4all "github.com/nomic/gpt4all/gpt4all-bindings/golang"
)
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
@@ -32,7 +34,7 @@ func defaultLLamaOpts(c Config) []llama.ModelOption {
return llamaOpts
}
func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
if !c.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
}
@@ -57,8 +59,19 @@ func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]fl
case *llama.LLama:
fn = func() ([]float32, error) {
predictOptions := buildLLamaPredictOptions(c)
if len(tokens) > 0 {
return model.TokenEmbeddings(tokens, predictOptions...)
}
return model.Embeddings(s, predictOptions...)
}
// bert embeddings
case *bert.Bert:
fn = func() ([]float32, error) {
if len(tokens) > 0 {
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
}
return model.Embeddings(s, bert.SetThreads(c.Threads))
}
default:
fn = func() ([]float32, error) {
return nil, fmt.Errorf("embeddings not supported by the backend")
@@ -186,6 +199,122 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
return response, nil
}
case *gpt2.GPTNeoX:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(c.Temperature),
gpt2.SetTopP(c.TopP),
gpt2.SetTopK(c.TopK),
gpt2.SetTokens(c.Maxtokens),
gpt2.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
}
return model.Predict(
s,
predictOptions...,
)
}
case *gpt2.Replit:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(c.Temperature),
gpt2.SetTopP(c.TopP),
gpt2.SetTopK(c.TopK),
gpt2.SetTokens(c.Maxtokens),
gpt2.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
}
return model.Predict(
s,
predictOptions...,
)
}
case *gpt2.Starcoder:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(c.Temperature),
gpt2.SetTopP(c.TopP),
gpt2.SetTopK(c.TopK),
gpt2.SetTokens(c.Maxtokens),
gpt2.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
}
return model.Predict(
s,
predictOptions...,
)
}
case *gpt2.RedPajama:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(c.Temperature),
gpt2.SetTopP(c.TopP),
gpt2.SetTopK(c.TopK),
gpt2.SetTokens(c.Maxtokens),
gpt2.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
}
return model.Predict(
s,
predictOptions...,
)
}
case *bloomz.Bloomz:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []bloomz.PredictOption{
bloomz.SetTemperature(c.Temperature),
bloomz.SetTopP(c.TopP),
bloomz.SetTopK(c.TopK),
bloomz.SetTokens(c.Maxtokens),
bloomz.SetThreads(c.Threads),
}
if c.Seed != 0 {
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
}
return model.Predict(
s,
predictOptions...,
)
}
case *gpt2.StableLM:
fn = func() (string, error) {
// Generate the prediction using the language model
@@ -210,6 +339,30 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
predictOptions...,
)
}
case *gpt2.Dolly:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(c.Temperature),
gpt2.SetTopP(c.TopP),
gpt2.SetTopK(c.TopK),
gpt2.SetTokens(c.Maxtokens),
gpt2.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
}
return model.Predict(
s,
predictOptions...,
)
}
case *gpt2.GPT2:
fn = func() (string, error) {
// Generate the prediction using the language model
@@ -234,29 +387,35 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
predictOptions...,
)
}
case *gptj.GPTJ:
case *gpt4all.Model:
supportStreams = true
fn = func() (string, error) {
if tokenCallback != nil {
model.SetTokenCallback(tokenCallback)
}
// Generate the prediction using the language model
predictOptions := []gptj.PredictOption{
gptj.SetTemperature(c.Temperature),
gptj.SetTopP(c.TopP),
gptj.SetTopK(c.TopK),
gptj.SetTokens(c.Maxtokens),
gptj.SetThreads(c.Threads),
predictOptions := []gpt4all.PredictOption{
gpt4all.SetTemperature(c.Temperature),
gpt4all.SetTopP(c.TopP),
gpt4all.SetTopK(c.TopK),
gpt4all.SetTokens(c.Maxtokens),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gptj.SetBatch(c.Batch))
predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gptj.SetSeed(c.Seed))
}
return model.Predict(
str, er := model.Predict(
s,
predictOptions...,
)
// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
// after a stream event has occurred
model.SetTokenCallback(nil)
return str, er
}
case *llama.LLama:
supportStreams = true

View File

@@ -65,7 +65,7 @@ Run a slack bot which lets you talk directly with a model
[Check it out here](https://github.com/go-skynet/LocalAI/tree/master/examples/slack-bot/)
### Question answering on documents
### Question answering on documents with llama-index
_by [@mudler](https://github.com/mudler)_
@@ -73,6 +73,22 @@ Shows how to integrate with [Llama-Index](https://gpt-index.readthedocs.io/en/st
[Check it out here](https://github.com/go-skynet/LocalAI/tree/master/examples/query_data/)
### Question answering on documents with langchain and chroma
_by [@mudler](https://github.com/mudler)_
Shows how to integrate with `Langchain` and `Chroma` to enable question answering on a set of documents.
[Check it out here](https://github.com/go-skynet/LocalAI/tree/master/examples/langchain-chroma/)
### Template for Runpod.io
_by [@fHachenberg](https://github.com/fHachenberg)_
Allows to run any LocalAI-compatible model as a backend on the servers of https://runpod.io
[Check it out here](https://runpod.io/gsc?template=uv9mtqnrd0&ref=984wlcra)
## Want to contribute?
Create an issue, and put `Example: <description>` in the title! We will post your examples here.

View File

@@ -0,0 +1,54 @@
# Data query example
This example makes use of [langchain and chroma](https://blog.langchain.dev/langchain-chroma/) to enable question answering on a set of documents.
## Setup
Download the models and start the API:
```bash
# Clone LocalAI
git clone https://github.com/go-skynet/LocalAI
cd LocalAI/examples/query_data
wget https://huggingface.co/skeskinen/ggml/resolve/main/all-MiniLM-L6-v2/ggml-model-q4_0.bin -O models/bert
wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j
# start with docker-compose
docker-compose up -d --build
```
### Python requirements
```
pip install -r requirements.txt
```
### Create a storage
In this step we will create a local vector database from our document set, so later we can ask questions on it with the LLM.
```bash
export OPENAI_API_BASE=http://localhost:8080/v1
export OPENAI_API_KEY=sk-
wget https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt
python store.py
```
After it finishes, a directory "storage" will be created with the vector index database.
## Query
We can now query the dataset.
```bash
export OPENAI_API_BASE=http://localhost:8080/v1
export OPENAI_API_KEY=sk-
python query.py
# President Trump recently stated during a press conference regarding tax reform legislation that "we're getting rid of all these loopholes." He also mentioned that he wants to simplify the system further through changes such as increasing the standard deduction amount and making other adjustments aimed at reducing taxpayers' overall burden.
```
Keep in mind now things are hit or miss!

View File

@@ -0,0 +1 @@
{{.Input}}

View File

@@ -0,0 +1,5 @@
name: text-embedding-ada-002
parameters:
model: bert
backend: bert-embeddings
embeddings: true

View File

@@ -0,0 +1,16 @@
name: gpt-3.5-turbo
parameters:
model: ggml-gpt4all-j
top_k: 80
temperature: 0.2
top_p: 0.7
context_size: 1024
stopwords:
- "HUMAN:"
- "GPT:"
roles:
user: " "
system: " "
template:
completion: completion
chat: gpt4all

View File

@@ -0,0 +1,4 @@
The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
### Prompt:
{{.Input}}
### Response:

View File

@@ -0,0 +1,31 @@
import os
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter,CharacterTextSplitter
from langchain.llms import OpenAI
from langchain.chains import VectorDBQA
from langchain.document_loaders import TextLoader
base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1')
# Load and process the text
loader = TextLoader('state_of_the_union.txt')
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=70)
texts = text_splitter.split_documents(documents)
# Embed and store the texts
# Supplying a persist_directory will store the embeddings on disk
persist_directory = 'db'
embedding = OpenAIEmbeddings()
# Now we can load the persisted database from disk, and use it as normal.
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)
qa = VectorDBQA.from_chain_type(llm=OpenAI(temperature=0, model_name="gpt-3.5-turbo", openai_api_base=base_path), chain_type="stuff", vectorstore=vectordb)
query = "What the president said about taxes ?"
print(qa.run(query))

View File

@@ -0,0 +1,4 @@
langchain==0.0.160
openai==0.27.6
chromadb==0.3.21
llama-index==0.6.2

View File

@@ -0,0 +1,28 @@
import os
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter,TokenTextSplitter,CharacterTextSplitter
from langchain.llms import OpenAI
from langchain.chains import VectorDBQA
from langchain.document_loaders import TextLoader
base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1')
# Load and process the text
loader = TextLoader('state_of_the_union.txt')
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=70)
#text_splitter = TokenTextSplitter()
texts = text_splitter.split_documents(documents)
# Embed and store the texts
# Supplying a persist_directory will store the embeddings on disk
persist_directory = 'db'
embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
vectordb = Chroma.from_documents(documents=texts, embedding=embedding, persist_directory=persist_directory)
vectordb.persist()
vectordb = None

View File

@@ -30,4 +30,6 @@ export OPENAI_API_KEY=sk-
python test.py
# A good company name for a company that makes colorful socks would be "Colorsocks".
python agent.py
```

View File

@@ -0,0 +1,44 @@
## This is a fork/based from https://gist.github.com/wiseman/4a706428eaabf4af1002a07a114f61d6
from io import StringIO
import sys
import os
from typing import Dict, Optional
from langchain.agents import load_tools
from langchain.agents import initialize_agent
from langchain.agents.tools import Tool
from langchain.llms import OpenAI
base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1')
model_name = os.environ.get('MODEL_NAME', 'gpt-3.5-turbo')
class PythonREPL:
"""Simulates a standalone Python REPL."""
def __init__(self):
pass
def run(self, command: str) -> str:
"""Run command and returns anything printed."""
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
try:
exec(command, globals())
sys.stdout = old_stdout
output = mystdout.getvalue()
except Exception as e:
sys.stdout = old_stdout
output = str(e)
return output
llm = OpenAI(temperature=0.0, openai_api_base=base_path, model_name=model_name)
python_repl = Tool(
"Python REPL",
PythonREPL().run,
"""A Python shell. Use this to execute python commands. Input should be a valid python command.
If you expect output it should be printed out.""",
)
tools = [python_repl]
agent = initialize_agent(tools, llm, agent="zero-shot-react-description", verbose=True)
agent.run("What is the 10th fibonacci number?")

View File

@@ -4,13 +4,15 @@ This example makes use of [Llama-Index](https://gpt-index.readthedocs.io/en/stab
It loosely follows [the quickstart](https://gpt-index.readthedocs.io/en/stable/guides/primer/usage_pattern.html).
Summary of the steps:
- prepare the dataset (and store it into `data`)
- prepare a vector index database to run queries on
- run queries
## Requirements
For this in order to work, you will need a model compatible with the `llama.cpp` backend. This is will not work with gpt4all.
The example uses `WizardLM`. Edit the config files in `models/` accordingly to specify the model you use (change `HERE`).
You will also need a training data set. Copy that over `data`.
You will need a training data set. Copy that over `data`.
## Setup
@@ -22,13 +24,16 @@ git clone https://github.com/go-skynet/LocalAI
cd LocalAI/examples/query_data
# Copy your models, edit config files accordingly
wget https://huggingface.co/skeskinen/ggml/resolve/main/all-MiniLM-L6-v2/ggml-model-q4_0.bin -O models/bert
wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j
# start with docker-compose
docker-compose up -d --build
```
### Create a storage:
### Create a storage
In this step we will create a local vector database from our document set, so later we can ask questions on it with the LLM.
```bash
export OPENAI_API_BASE=http://localhost:8080/v1
@@ -41,9 +46,22 @@ After it finishes, a directory "storage" will be created with the vector index d
## Query
We can now query the dataset.
```bash
export OPENAI_API_BASE=http://localhost:8080/v1
export OPENAI_API_KEY=sk-
python query.py
```
## Update
To update our vector database, run `update.py`
```bash
export OPENAI_API_BASE=http://localhost:8080/v1
export OPENAI_API_KEY=sk-
python update.py
```

View File

@@ -1,18 +1,6 @@
name: text-embedding-ada-002
parameters:
model: HERE
top_k: 80
temperature: 0.2
top_p: 0.7
context_size: 1024
model: bert
threads: 14
stopwords:
- "HUMAN:"
- "GPT:"
roles:
user: " "
system: " "
backend: bert-embeddings
embeddings: true
template:
completion: completion
chat: gpt4all

View File

@@ -1,12 +1,11 @@
name: gpt-3.5-turbo
parameters:
model: HERE
model: ggml-gpt4all-j
top_k: 80
temperature: 0.2
top_p: 0.7
context_size: 1024
threads: 14
embeddings: true
stopwords:
- "HUMAN:"
- "GPT:"
@@ -15,4 +14,4 @@ roles:
system: " "
template:
completion: completion
chat: wizardlm
chat: gpt4all

View File

@@ -1,3 +0,0 @@
{{.Input}}
### Response:

View File

@@ -10,10 +10,10 @@ from llama_index import StorageContext, load_index_from_storage
base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1')
# This example uses text-davinci-003 by default; feel free to change if desired
llm_predictor = LLMPredictor(llm=OpenAI(temperature=0, model_name="gpt-3.5-turbo",openai_api_base=base_path))
llm_predictor = LLMPredictor(llm=OpenAI(temperature=0, model_name="gpt-3.5-turbo", openai_api_base=base_path))
# Configure prompt parameters and initialise helper
max_input_size = 1024
max_input_size = 500
num_output = 256
max_chunk_overlap = 20
@@ -29,5 +29,7 @@ storage_context = StorageContext.from_defaults(persist_dir='./storage')
index = load_index_from_storage(storage_context, service_context=service_context, )
query_engine = index.as_query_engine()
response = query_engine.query("XXXXXX your question here XXXXX")
print(response)
data = input("Question: ")
response = query_engine.query(data)
print(response)

View File

@@ -13,15 +13,15 @@ base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1')
llm_predictor = LLMPredictor(llm=OpenAI(temperature=0, model_name="gpt-3.5-turbo", openai_api_base=base_path))
# Configure prompt parameters and initialise helper
max_input_size = 256
num_output = 256
max_chunk_overlap = 10
max_input_size = 400
num_output = 400
max_chunk_overlap = 30
prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap)
# Load documents from the 'data' directory
documents = SimpleDirectoryReader('data').load_data()
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit = 257)
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit = 400)
index = GPTVectorStoreIndex.from_documents(documents, service_context=service_context)
index.storage_context.persist(persist_dir="./storage")

View File

@@ -0,0 +1,32 @@
import os
# Uncomment to specify your OpenAI API key here (local testing only, not in production!), or add corresponding environment variable (recommended)
# os.environ['OPENAI_API_KEY']= ""
from llama_index import LLMPredictor, PromptHelper, SimpleDirectoryReader, ServiceContext
from langchain.llms.openai import OpenAI
from llama_index import StorageContext, load_index_from_storage
base_path = os.environ.get('OPENAI_API_BASE', 'http://localhost:8080/v1')
# This example uses text-davinci-003 by default; feel free to change if desired
llm_predictor = LLMPredictor(llm=OpenAI(temperature=0, model_name="gpt-3.5-turbo", openai_api_base=base_path))
# Configure prompt parameters and initialise helper
max_input_size = 512
num_output = 256
max_chunk_overlap = 20
prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap)
# Load documents from the 'data' directory
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
# rebuild storage context
storage_context = StorageContext.from_defaults(persist_dir='./storage')
# load index
index = load_index_from_storage(storage_context, service_context=service_context, )
documents = SimpleDirectoryReader('data').load_data()
index.refresh(documents)
index.storage_context.persist(persist_dir="./storage")

20
go.mod
View File

@@ -3,17 +3,22 @@ module github.com/go-skynet/LocalAI
go 1.19
require (
github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708
github.com/donomii/go-rwkv.cpp v0.0.0-20230510174014-07166da10cb2
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230509153812-1d17cd5bb37a
github.com/go-audio/wav v1.1.0
github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f
github.com/go-skynet/go-bert.cpp v0.0.0-20230510124618-ec771ec71557
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c
github.com/go-skynet/go-llama.cpp v0.0.0-20230505100647-691d479d3675
github.com/gofiber/fiber/v2 v2.44.0
github.com/go-skynet/go-llama.cpp v0.0.0-20230510072905-70593fccbe4b
github.com/gofiber/fiber/v2 v2.45.0
github.com/hashicorp/go-multierror v1.1.1
github.com/onsi/ginkgo/v2 v2.9.4
github.com/onsi/gomega v1.27.6
github.com/otiai10/copy v1.11.0
github.com/otiai10/openaigo v1.1.0
github.com/rs/zerolog v1.29.1
github.com/sashabaranov/go-openai v1.9.3
github.com/sashabaranov/go-openai v1.9.4
github.com/swaggo/swag v1.16.1
github.com/urfave/cli/v2 v2.25.3
github.com/valyala/fasthttp v1.47.0
@@ -26,6 +31,8 @@ require (
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
github.com/andybalholm/brotli v1.0.5 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/go-audio/audio v1.0.0 // indirect
github.com/go-audio/riff v1.0.0 // indirect
github.com/go-logr/logr v1.2.4 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/jsonreference v0.19.6 // indirect
@@ -42,6 +49,7 @@ require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.18 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/nomic/gpt4all/gpt4all-bindings/golang v0.0.0-00010101000000-000000000000 // indirect
github.com/philhofer/fwd v1.1.2 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
@@ -52,7 +60,7 @@ require (
github.com/valyala/tcplisten v1.0.0 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
golang.org/x/net v0.9.0 // indirect
golang.org/x/sys v0.7.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
golang.org/x/tools v0.8.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect

43
go.sum
View File

@@ -18,6 +18,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be h1:3Hic97PY6hcw/SY44RuR7kyONkxd744RFeRrqckzwNQ=
github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM=
github.com/donomii/go-rwkv.cpp v0.0.0-20230510174014-07166da10cb2 h1:YNbUAyIRtaLODitigJU1EM5ubmMu5FmHtYAayJD6Vbg=
github.com/donomii/go-rwkv.cpp v0.0.0-20230510174014-07166da10cb2/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35 h1:sMg/SgnMPS/HNUO/2kGm72vl8R9TmNIwgLFr2TNwR3g=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230509153812-1d17cd5bb37a h1:MlyiDLNCM/wjbv8U5Elj18NvaAgl61SGiRUpqQz5dfs=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230509153812-1d17cd5bb37a/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
@@ -30,18 +42,23 @@ github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM=
github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708 h1:cfOi4TWvQ6JsAm9Q1A8I8j9YfNy10bmIfwOiyGyU5wQ=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f h1:GW8RQa1RVeDF1dOuAP/y6xWVC+BRtf9tJOuEza6Asbg=
github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA=
github.com/go-skynet/go-bert.cpp v0.0.0-20230510101404-7bb183b147ea h1:8Isk9D+Auth5OuXVAQPC3MO+5zF/2S7mvs2JZLw6a+8=
github.com/go-skynet/go-bert.cpp v0.0.0-20230510101404-7bb183b147ea/go.mod h1:NHwIVvsg7Jh6p0M4uBLVmSMEaPUia6O6yjXUpLWVJmQ=
github.com/go-skynet/go-bert.cpp v0.0.0-20230510124618-ec771ec71557 h1:LD66fKtvP2lmyuuKL8pBat/pVTKUbLs3L5fM/5lyi4w=
github.com/go-skynet/go-bert.cpp v0.0.0-20230510124618-ec771ec71557/go.mod h1:NHwIVvsg7Jh6p0M4uBLVmSMEaPUia6O6yjXUpLWVJmQ=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6 h1:XshpypO6ekU09CI19vuzke2a1Es1lV5ZaxA7CUehu0E=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c h1:48I7jpLNGiQeBmF0SFVVbREh8vlG0zN13v9LH5ctXis=
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
github.com/go-skynet/go-llama.cpp v0.0.0-20230504223241-67ff6a4db244/go.mod h1:LvSQx5QAYBAMpWkbyVFFDiM1Tzj8LP55DvmUM3hbRMY=
github.com/go-skynet/go-llama.cpp v0.0.0-20230505100647-691d479d3675 h1:plXywr95RghidIHPHl+O/zpcNXenEeS6w/6WftFNr9E=
github.com/go-skynet/go-llama.cpp v0.0.0-20230505100647-691d479d3675/go.mod h1:LvSQx5QAYBAMpWkbyVFFDiM1Tzj8LP55DvmUM3hbRMY=
github.com/go-skynet/go-llama.cpp v0.0.0-20230510072905-70593fccbe4b h1:qqxrjY8fYDXQahmCMTCACahm1tbiqHLPUHALkFLyBfo=
github.com/go-skynet/go-llama.cpp v0.0.0-20230510072905-70593fccbe4b/go.mod h1:DLfsPD7tYYnpksERH83HSf7qVNW3FIwmz7/zfYO0/6I=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofiber/fiber/v2 v2.44.0 h1:Z90bEvPcJM5GFJnu1py0E1ojoerkyew3iiNJ78MQCM8=
github.com/gofiber/fiber/v2 v2.44.0/go.mod h1:VTMtb/au8g01iqvHyaCzftuM/xmZgKOZCtFzz6CdV9w=
github.com/gofiber/fiber/v2 v2.45.0 h1:p4RpkJT9GAW6parBSbcNFH2ApnAuW3OzaQzbOCoDu+s=
github.com/gofiber/fiber/v2 v2.45.0/go.mod h1:DNl0/c37WLe0g92U6lx1VMQuxGUQY5V7EIaVoEsUffc=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
@@ -82,7 +99,9 @@ github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE=
github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/otiai10/mint v1.4.1 h1:HOVBfKP1oXIc0wWo9hZ8JLdZtyCPWqjvmFDuVZ0yv2Y=
github.com/otiai10/copy v1.11.0 h1:OKBD80J/mLBrwnzXqGtFCzprFSGioo30JcmR4APsNwc=
github.com/otiai10/copy v1.11.0/go.mod h1:rSaLseMUsZFFbsFGc7wCJnnkTAvdc5L6VWxPE4308Ww=
github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks=
github.com/otiai10/openaigo v1.1.0 h1:zRvGBqZUW5PCMgdkJNsPVTBd8tOLCMTipXE5wD2pdTg=
github.com/otiai10/openaigo v1.1.0/go.mod h1:792bx6AWTS61weDi2EzKpHHnTF4eDMAlJ5GvAk/mgPg=
github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
@@ -100,6 +119,8 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sashabaranov/go-openai v1.9.3 h1:uNak3Rn5pPsKRs9bdT7RqRZEyej/zdZOEI2/8wvrFtM=
github.com/sashabaranov/go-openai v1.9.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM=
github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4=
github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94/go.mod h1:90zrgN3D/WJsDd1iXHT96alCoN2KJo6/4x1DZC3wZs8=
github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4=
@@ -108,7 +129,7 @@ github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJ
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/swaggo/swag v1.16.1 h1:fTNRhKstPKxcnoKsytm4sahr8FaYzUcT7i1/3nd/fBg=
github.com/swaggo/swag v1.16.1/go.mod h1:9/LMvHycG3NFHfR6LwvikHv5iFvmPADQ359cKikGxto=
github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw=
@@ -161,8 +182,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=

View File

@@ -62,6 +62,12 @@ func main() {
EnvVars: []string{"CONTEXT_SIZE"},
Value: 512,
},
&cli.IntFlag{
Name: "upload-limit",
DefaultText: "Default upload-limit. MB",
EnvVars: []string{"UPLOAD_LIMIT"},
Value: 15,
},
},
Description: `
LocalAI is a drop-in replacement OpenAI API which runs inference locally.
@@ -81,7 +87,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
Copyright: "go-skynet authors",
Action: func(ctx *cli.Context) error {
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address"))
return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address"))
},
}

182
pkg/model/initializers.go Normal file
View File

@@ -0,0 +1,182 @@
package model
import (
"fmt"
"strings"
rwkv "github.com/donomii/go-rwkv.cpp"
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
bloomz "github.com/go-skynet/bloomz.cpp"
bert "github.com/go-skynet/go-bert.cpp"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/hashicorp/go-multierror"
gpt4all "github.com/nomic/gpt4all/gpt4all-bindings/golang"
"github.com/rs/zerolog/log"
)
const tokenizerSuffix = ".tokenizer.json"
const (
LlamaBackend = "llama"
BloomzBackend = "bloomz"
StarcoderBackend = "starcoder"
StableLMBackend = "stablelm"
DollyBackend = "dolly"
RedPajamaBackend = "redpajama"
GPTNeoXBackend = "gptneox"
ReplitBackend = "replit"
Gpt2Backend = "gpt2"
Gpt4AllLlamaBackend = "gpt4all-llama"
Gpt4AllMptBackend = "gpt4all-mpt"
Gpt4AllJBackend = "gpt4all-j"
BertEmbeddingsBackend = "bert-embeddings"
RwkvBackend = "rwkv"
WhisperBackend = "whisper"
)
var backends []string = []string{
LlamaBackend,
Gpt4AllLlamaBackend,
Gpt4AllMptBackend,
Gpt4AllJBackend,
Gpt2Backend,
WhisperBackend,
RwkvBackend,
BloomzBackend,
StableLMBackend,
DollyBackend,
RedPajamaBackend,
GPTNeoXBackend,
ReplitBackend,
BertEmbeddingsBackend,
StarcoderBackend,
}
var starCoder = func(modelFile string) (interface{}, error) {
return gpt2.NewStarcoder(modelFile)
}
var redPajama = func(modelFile string) (interface{}, error) {
return gpt2.NewRedPajama(modelFile)
}
var dolly = func(modelFile string) (interface{}, error) {
return gpt2.NewDolly(modelFile)
}
var gptNeoX = func(modelFile string) (interface{}, error) {
return gpt2.NewGPTNeoX(modelFile)
}
var replit = func(modelFile string) (interface{}, error) {
return gpt2.NewReplit(modelFile)
}
var stableLM = func(modelFile string) (interface{}, error) {
return gpt2.NewStableLM(modelFile)
}
var bertEmbeddings = func(modelFile string) (interface{}, error) {
return bert.New(modelFile)
}
var bloomzLM = func(modelFile string) (interface{}, error) {
return bloomz.New(modelFile)
}
var gpt2LM = func(modelFile string) (interface{}, error) {
return gpt2.New(modelFile)
}
var whisperModel = func(modelFile string) (interface{}, error) {
return whisper.New(modelFile)
}
func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) {
return func(s string) (interface{}, error) {
return llama.New(s, opts...)
}
}
func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) {
return func(s string) (interface{}, error) {
return gpt4all.New(s, opts...)
}
}
func rwkvLM(tokenFile string, threads uint32) func(string) (interface{}, error) {
return func(s string) (interface{}, error) {
model := rwkv.LoadFiles(s, tokenFile, threads)
if model == nil {
return nil, fmt.Errorf("could not load model")
}
return model, nil
}
}
func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
switch strings.ToLower(backendString) {
case LlamaBackend:
return ml.LoadModel(modelFile, llamaLM(llamaOpts...))
case BloomzBackend:
return ml.LoadModel(modelFile, bloomzLM)
case StableLMBackend:
return ml.LoadModel(modelFile, stableLM)
case DollyBackend:
return ml.LoadModel(modelFile, dolly)
case RedPajamaBackend:
return ml.LoadModel(modelFile, redPajama)
case Gpt2Backend:
return ml.LoadModel(modelFile, gpt2LM)
case GPTNeoXBackend:
return ml.LoadModel(modelFile, gptNeoX)
case ReplitBackend:
return ml.LoadModel(modelFile, replit)
case StarcoderBackend:
return ml.LoadModel(modelFile, starCoder)
case Gpt4AllLlamaBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.LLaMAType)))
case Gpt4AllMptBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.MPTType)))
case Gpt4AllJBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.GPTJType)))
case BertEmbeddingsBackend:
return ml.LoadModel(modelFile, bertEmbeddings)
case RwkvBackend:
return ml.LoadModel(modelFile, rwkvLM(modelFile+tokenizerSuffix, threads))
case WhisperBackend:
return ml.LoadModel(modelFile, whisperModel)
default:
return nil, fmt.Errorf("backend unsupported: %s", backendString)
}
}
func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (interface{}, error) {
log.Debug().Msgf("Loading models greedly")
ml.mu.Lock()
m, exists := ml.models[modelFile]
if exists {
ml.mu.Unlock()
return m, nil
}
ml.mu.Unlock()
var err error
for _, b := range backends {
if b == BloomzBackend || b == WhisperBackend || b == RwkvBackend { // do not autoload bloomz/whisper/rwkv
continue
}
log.Debug().Msgf("[%s] Attempting to load", b)
model, modelerr := ml.BackendLoader(b, modelFile, llamaOpts, threads)
if modelerr == nil && model != nil {
log.Debug().Msgf("[%s] Loads OK", b)
return model, nil
} else if modelerr != nil {
err = multierror.Append(err, modelerr)
log.Debug().Msgf("[%s] Fails: %s", b, modelerr.Error())
}
}
return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
}

View File

@@ -10,36 +10,22 @@ import (
"sync"
"text/template"
"github.com/hashicorp/go-multierror"
"github.com/rs/zerolog/log"
rwkv "github.com/donomii/go-rwkv.cpp"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
)
type ModelLoader struct {
ModelPath string
mu sync.Mutex
models map[string]*llama.LLama
gptmodels map[string]*gptj.GPTJ
gpt2models map[string]*gpt2.GPT2
gptstablelmmodels map[string]*gpt2.StableLM
rwkv map[string]*rwkv.RwkvState
promptsTemplates map[string]*template.Template
// TODO: this needs generics
models map[string]interface{}
promptsTemplates map[string]*template.Template
}
func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{
ModelPath: modelPath,
gpt2models: make(map[string]*gpt2.GPT2),
gptmodels: make(map[string]*gptj.GPTJ),
gptstablelmmodels: make(map[string]*gpt2.StableLM),
models: make(map[string]*llama.LLama),
rwkv: make(map[string]*rwkv.RwkvState),
promptsTemplates: make(map[string]*template.Template),
ModelPath: modelPath,
models: make(map[string]interface{}),
promptsTemplates: make(map[string]*template.Template),
}
}
@@ -124,143 +110,11 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
return nil
}
func (ml *ModelLoader) LoadStableLMModel(modelName string) (*gpt2.StableLM, error) {
func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (interface{}, error)) (interface{}, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
// Check if we already have a loaded model
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}
if m, ok := ml.gptstablelmmodels[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
}
// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.ModelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gpt2.NewStableLM(modelFile)
if err != nil {
return nil, err
}
// If there is a prompt template, load it
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
return nil, err
}
ml.gptstablelmmodels[modelName] = model
return model, err
}
func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
// Check if we already have a loaded model
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}
if m, ok := ml.gpt2models[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
}
// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.ModelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gpt2.New(modelFile)
if err != nil {
return nil, err
}
// If there is a prompt template, load it
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
return nil, err
}
ml.gpt2models[modelName] = model
return model, err
}
func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
// Check if we already have a loaded model
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}
if m, ok := ml.gptmodels[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
}
// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.ModelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gptj.New(modelFile)
if err != nil {
return nil, err
}
// If there is a prompt template, load it
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
return nil, err
}
ml.gptmodels[modelName] = model
return model, err
}
func (ml *ModelLoader) LoadRWKV(modelName, tokenFile string, threads uint32) (*rwkv.RwkvState, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
log.Debug().Msgf("Loading model name: %s", modelName)
// Check if we already have a loaded model
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}
if m, ok := ml.rwkv[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
}
// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.ModelPath, modelName)
tokenPath := filepath.Join(ml.ModelPath, tokenFile)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model := rwkv.LoadFiles(modelFile, tokenPath, threads)
if model == nil {
return nil, fmt.Errorf("could not load model")
}
ml.rwkv[modelName] = model
return model, nil
}
func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
log.Debug().Msgf("Loading model name: %s", modelName)
// Check if we already have a loaded model
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}
if m, ok := ml.models[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
@@ -270,7 +124,7 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
modelFile := filepath.Join(ml.ModelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := llama.New(modelFile, opts...)
model, err := loader(modelFile)
if err != nil {
return nil, err
}
@@ -281,85 +135,5 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
}
ml.models[modelName] = model
return model, err
}
const tokenizerSuffix = ".tokenizer.json"
var loadedModels map[string]interface{} = map[string]interface{}{}
var muModels sync.Mutex
func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
switch strings.ToLower(backendString) {
case "llama":
return ml.LoadLLaMAModel(modelFile, llamaOpts...)
case "stablelm":
return ml.LoadStableLMModel(modelFile)
case "gpt2":
return ml.LoadGPT2Model(modelFile)
case "gptj":
return ml.LoadGPTJModel(modelFile)
case "rwkv":
return ml.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
default:
return nil, fmt.Errorf("backend unsupported: %s", backendString)
}
}
func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
updateModels := func(model interface{}) {
muModels.Lock()
defer muModels.Unlock()
loadedModels[modelFile] = model
}
muModels.Lock()
m, exists := loadedModels[modelFile]
if exists {
muModels.Unlock()
return m, nil
}
muModels.Unlock()
model, modelerr := ml.LoadLLaMAModel(modelFile, llamaOpts...)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
model, modelerr = ml.LoadGPTJModel(modelFile)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
model, modelerr = ml.LoadGPT2Model(modelFile)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
model, modelerr = ml.LoadStableLMModel(modelFile)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
model, modelerr = ml.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}
return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
return model, nil
}

90
pkg/whisper/whisper.go Normal file
View File

@@ -0,0 +1,90 @@
package whisper
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
wav "github.com/go-audio/wav"
)
func sh(c string) (string, error) {
cmd := exec.Command("/bin/sh", "-c", c)
cmd.Env = os.Environ()
o, err := cmd.CombinedOutput()
return string(o), err
}
// AudioToWav converts audio to wav for transcribe. It bashes out to ffmpeg
// TODO: use https://github.com/mccoyst/ogg?
func audioToWav(src, dst string) error {
out, err := sh(fmt.Sprintf("ffmpeg -i %s -format s16le -ar 16000 -ac 1 -acodec pcm_s16le %s", src, dst))
if err != nil {
return fmt.Errorf("error: %w out: %s", err, out)
}
return nil
}
func Transcript(model whisper.Model, audiopath, language string, threads uint) (string, error) {
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return "", err
}
defer os.RemoveAll(dir)
convertedPath := filepath.Join(dir, "converted.wav")
if err := audioToWav(audiopath, convertedPath); err != nil {
return "", err
}
// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return "", err
}
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
return "", err
}
data := buf.AsFloat32Buffer().Data
// Process samples
context, err := model.NewContext()
if err != nil {
return "", err
}
context.SetThreads(threads)
if language != "" {
context.SetLanguage(language)
} else {
context.SetLanguage("auto")
}
if err := context.Process(data, nil); err != nil {
return "", err
}
text := ""
for {
segment, err := context.NextSegment()
if err != nil {
break
}
text += segment.Text
}
return text, nil
}

6
tests/fixtures/embeddings.yaml vendored Normal file
View File

@@ -0,0 +1,6 @@
name: text-embedding-ada-002
parameters:
model: bert
threads: 14
backend: bert-embeddings
embeddings: true

4
tests/fixtures/whisper.yaml vendored Normal file
View File

@@ -0,0 +1,4 @@
name: whisper-1
backend: whisper
parameters:
model: whisper-en