mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-04 03:32:40 -05:00
Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bed9570e48 | ||
|
|
c6bf67f446 | ||
|
|
5ee186b8e5 | ||
|
|
94817b557c | ||
|
|
26e1496075 | ||
|
|
92fca8ae74 | ||
|
|
7fa5b8401d | ||
|
|
0eac0402e1 | ||
|
|
c71c729bc2 | ||
|
|
e459f114cd | ||
|
|
982a7e86a8 | ||
|
|
94916749c5 | ||
|
|
5ce5f87a26 | ||
|
|
1d2ae46ddc | ||
|
|
71ac331f90 | ||
|
|
47cc95fc9f | ||
|
|
3feb632eb4 | ||
|
|
236497e331 | ||
|
|
a38dc497b2 | ||
|
|
28ed52fa94 | ||
|
|
e995b95c94 | ||
|
|
8379cce209 | ||
|
|
3c6b798522 | ||
|
|
c18770a61a | ||
|
|
6352448b72 |
3
.github/workflows/bump_deps.yaml
vendored
3
.github/workflows/bump_deps.yaml
vendored
@@ -12,6 +12,9 @@ jobs:
|
||||
- repository: "go-skynet/go-llama.cpp"
|
||||
variable: "GOLLAMA_VERSION"
|
||||
branch: "master"
|
||||
- repository: "go-skynet/go-llama.cpp"
|
||||
variable: "GOLLAMA_GRAMMAR_VERSION"
|
||||
branch: "master"
|
||||
- repository: "go-skynet/go-ggml-transformers.cpp"
|
||||
variable: "GOGGMLTRANSFORMERS_VERSION"
|
||||
branch: "master"
|
||||
|
||||
32
.github/workflows/image.yml
vendored
32
.github/workflows/image.yml
vendored
@@ -59,6 +59,38 @@ jobs:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Release space from worker
|
||||
run: |
|
||||
echo "Listing top largest packages"
|
||||
pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
|
||||
head -n 30 <<< "${pkgs}"
|
||||
echo
|
||||
df -h
|
||||
echo
|
||||
sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
|
||||
sudo apt-get remove --auto-remove android-sdk-platform-tools || true
|
||||
sudo apt-get purge --auto-remove android-sdk-platform-tools || true
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo apt-get remove -y '^mono-.*' || true
|
||||
sudo apt-get remove -y '^ghc-.*' || true
|
||||
sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
|
||||
sudo apt-get remove -y 'php.*' || true
|
||||
sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
|
||||
sudo apt-get remove -y '^google-.*' || true
|
||||
sudo apt-get remove -y azure-cli || true
|
||||
sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
|
||||
sudo apt-get remove -y '^gfortran-.*' || true
|
||||
sudo apt-get autoremove -y
|
||||
sudo apt-get clean
|
||||
echo
|
||||
echo "Listing top largest packages"
|
||||
pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
|
||||
head -n 30 <<< "${pkgs}"
|
||||
echo
|
||||
sudo rm -rfv build || true
|
||||
df -h
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
|
||||
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -29,6 +29,7 @@ jobs:
|
||||
|
||||
sudo apt-get install -y ca-certificates cmake curl patch
|
||||
sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
|
||||
sudo pip install -r extra/requirements.txt
|
||||
|
||||
sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \
|
||||
curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v1.11.0.tar.gz" | \
|
||||
@@ -45,7 +46,6 @@ jobs:
|
||||
sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /lib64/ && \
|
||||
sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \
|
||||
sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. /usr/include/
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
ESPEAK_DATA="/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data" GO_TAGS="tts stablediffusion" make test
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -3,9 +3,10 @@ go-llama
|
||||
/gpt4all
|
||||
go-stable-diffusion
|
||||
go-piper
|
||||
/go-bert
|
||||
go-ggllm
|
||||
/piper
|
||||
|
||||
__pycache__/
|
||||
*.a
|
||||
get-sources
|
||||
|
||||
|
||||
@@ -11,10 +11,15 @@ ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
|
||||
ENV BUILD_TYPE=${BUILD_TYPE}
|
||||
ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/extra/grpc/huggingface/huggingface.py"
|
||||
ARG GO_TAGS="stablediffusion tts"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y ca-certificates cmake curl patch
|
||||
apt-get install -y ca-certificates cmake curl patch pip
|
||||
|
||||
# Extras requirements
|
||||
COPY extra/requirements.txt /build/extra/requirements.txt
|
||||
RUN pip install -r /build/extra/requirements.txt && rm -rf /build/extra/requirements.txt
|
||||
|
||||
# CuBLAS requirements
|
||||
RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \
|
||||
|
||||
54
Makefile
54
Makefile
@@ -5,19 +5,20 @@ BINARY_NAME=local-ai
|
||||
|
||||
# llama.cpp versions
|
||||
# Temporarly pinned to https://github.com/go-skynet/go-llama.cpp/pull/124
|
||||
GOLLAMA_VERSION?=cb8d7cd4cb95725a04504a9e3a26dd72a12b69ac
|
||||
GOLLAMA_VERSION?=f3a6ee0ef53d667f110d28fcf9b808bdca741c07
|
||||
|
||||
GOLLAMA_GRAMMAR_VERSION?=cb8d7cd4cb95725a04504a9e3a26dd72a12b69ac
|
||||
# Temporary set a specific version of llama.cpp
|
||||
# containing: https://github.com/ggerganov/llama.cpp/pull/1773 and
|
||||
# rebased on top of master.
|
||||
# This pin can be dropped when the PR above is merged, and go-llama has merged changes as well
|
||||
# Set empty to use the version pinned by go-llama
|
||||
LLAMA_CPP_REPO?=https://github.com/mudler/llama.cpp
|
||||
LLAMA_CPP_VERSION?=48ce8722a05a018681634af801fd0fd45b3a87cc
|
||||
LLAMA_CPP_GRAMMAR_REPO?=https://github.com/mudler/llama.cpp
|
||||
LLAMA_CPP_GRAMMAR_VERSION?=48ce8722a05a018681634af801fd0fd45b3a87cc
|
||||
|
||||
# gpt4all version
|
||||
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
|
||||
GPT4ALL_VERSION?=cfd70b69fcf5e587b8e0e3e9b9aaa90e19cbbc51
|
||||
GPT4ALL_VERSION?=8aba2c9009fb6bc723f623c614e265b41722e4e3
|
||||
|
||||
# go-ggml-transformers version
|
||||
GOGGMLTRANSFORMERS_VERSION?=ffb09d7dd71e2cbc6c5d7d05357d230eea6f369a
|
||||
@@ -30,7 +31,7 @@ RWKV_VERSION?=c898cd0f62df8f2a7830e53d1d513bef4f6f792b
|
||||
WHISPER_CPP_VERSION?=85ed71aaec8e0612a84c0b67804bde75aa75a273
|
||||
|
||||
# bert.cpp version
|
||||
BERT_VERSION?=6069103f54b9969c02e789d0fb12a23bd614285f
|
||||
BERT_VERSION?=6abe312cded14042f6b7c3cd8edf082713334a4d
|
||||
|
||||
# go-piper version
|
||||
PIPER_VERSION?=56b8a81b4760a6fbee1a82e62f007ae7e8f010a7
|
||||
@@ -188,7 +189,7 @@ go-ggml-transformers:
|
||||
cd go-ggml-transformers && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1
|
||||
|
||||
go-ggml-transformers/libtransformers.a: go-ggml-transformers
|
||||
$(MAKE) -C go-ggml-transformers libtransformers.a
|
||||
$(MAKE) -C go-ggml-transformers BUILD_TYPE=$(BUILD_TYPE) libtransformers.a
|
||||
|
||||
whisper.cpp:
|
||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||
@@ -200,21 +201,29 @@ whisper.cpp/libwhisper.a: whisper.cpp
|
||||
go-llama:
|
||||
git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
|
||||
cd go-llama && git checkout -b build $(GOLLAMA_VERSION) && git submodule update --init --recursive --depth 1
|
||||
ifneq ($(LLAMA_CPP_REPO),)
|
||||
cd go-llama && rm -rf llama.cpp && git clone $(LLAMA_CPP_REPO) llama.cpp && cd llama.cpp && git checkout -b build $(LLAMA_CPP_VERSION) && git submodule update --init --recursive --depth 1
|
||||
|
||||
go-llama-grammar:
|
||||
git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama-grammar
|
||||
cd go-llama-grammar && git checkout -b build $(GOLLAMA_GRAMMAR_VERSION) && git submodule update --init --recursive --depth 1
|
||||
ifneq ($(LLAMA_CPP_GRAMMAR_REPO),)
|
||||
cd go-llama-grammar && rm -rf llama.cpp && git clone $(LLAMA_CPP_GRAMMAR_REPO) llama.cpp && cd llama.cpp && git checkout -b build $(LLAMA_CPP_GRAMMAR_VERSION) && git submodule update --init --recursive --depth 1
|
||||
endif
|
||||
|
||||
go-llama/libbinding.a: go-llama
|
||||
$(MAKE) -C go-llama BUILD_TYPE=$(BUILD_TYPE) libbinding.a
|
||||
|
||||
go-llama-grammar/libbinding.a: go-llama-grammar
|
||||
$(MAKE) -C go-llama-grammar BUILD_TYPE=$(BUILD_TYPE) libbinding.a
|
||||
|
||||
go-piper/libpiper_binding.a:
|
||||
$(MAKE) -C go-piper libpiper_binding.a example/main
|
||||
|
||||
get-sources: go-llama go-ggllm go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion
|
||||
get-sources: go-llama go-ggllm go-llama-grammar go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion
|
||||
touch $@
|
||||
|
||||
replace:
|
||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
|
||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp-grammar=$(shell pwd)/go-llama-grammar
|
||||
$(GOCMD) mod edit -replace github.com/nomic-ai/gpt4all/gpt4all-bindings/golang=$(shell pwd)/gpt4all/gpt4all-bindings/golang
|
||||
$(GOCMD) mod edit -replace github.com/go-skynet/go-ggml-transformers.cpp=$(shell pwd)/go-ggml-transformers
|
||||
$(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(shell pwd)/go-rwkv
|
||||
@@ -232,6 +241,7 @@ prepare-sources: get-sources replace
|
||||
rebuild: ## Rebuilds the project
|
||||
$(GOCMD) clean -cache
|
||||
$(MAKE) -C go-llama clean
|
||||
$(MAKE) -C go-llama-grammar clean
|
||||
$(MAKE) -C gpt4all/gpt4all-bindings/golang/ clean
|
||||
$(MAKE) -C go-ggml-transformers clean
|
||||
$(MAKE) -C go-rwkv clean
|
||||
@@ -272,9 +282,6 @@ build: grpcs prepare ## Build the project
|
||||
$(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
|
||||
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./
|
||||
ifeq ($(BUILD_TYPE),metal)
|
||||
cp go-llama/build/bin/ggml-metal.metal .
|
||||
endif
|
||||
|
||||
dist: build
|
||||
mkdir -p release
|
||||
@@ -303,7 +310,7 @@ test: prepare test-models/testmodel grpcs
|
||||
@echo 'Running tests'
|
||||
export GO_TAGS="tts stablediffusion"
|
||||
$(MAKE) prepare-test
|
||||
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||
HUGGINGFACE_GRPC=$(abspath ./)/extra/grpc/huggingface/huggingface.py TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
|
||||
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg
|
||||
$(MAKE) test-gpt4all
|
||||
$(MAKE) test-llama
|
||||
@@ -328,9 +335,7 @@ test-stablediffusion: prepare-test
|
||||
|
||||
test-container:
|
||||
docker build --target requirements -t local-ai-test-container .
|
||||
docker run --name localai-tests -e GO_TAGS=$(GO_TAGS) -ti -v $(abspath ./):/build local-ai-test-container make test
|
||||
docker rm localai-tests
|
||||
docker rmi local-ai-test-container
|
||||
docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container
|
||||
|
||||
## Help:
|
||||
help: ## Show this help.
|
||||
@@ -344,10 +349,15 @@ help: ## Show this help.
|
||||
else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \
|
||||
}' $(MAKEFILE_LIST)
|
||||
|
||||
protogen:
|
||||
protogen: protogen-go protogen-python
|
||||
|
||||
protogen-go:
|
||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
||||
pkg/grpc/proto/backend.proto
|
||||
|
||||
protogen-python:
|
||||
python -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/huggingface/ --grpc_python_out=extra/grpc/huggingface/ pkg/grpc/proto/backend.proto
|
||||
|
||||
## GRPC
|
||||
|
||||
backend-assets/grpc:
|
||||
@@ -360,6 +370,14 @@ backend-assets/grpc/falcon: backend-assets/grpc go-ggllm/libggllm.a
|
||||
backend-assets/grpc/llama: backend-assets/grpc go-llama/libbinding.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/
|
||||
# TODO: every binary should have its own folder instead, so can have different metal implementations
|
||||
ifeq ($(BUILD_TYPE),metal)
|
||||
cp go-llama/build/bin/ggml-metal.metal backend-assets/grpc/
|
||||
endif
|
||||
|
||||
backend-assets/grpc/llama-grammar: backend-assets/grpc go-llama-grammar/libbinding.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama-grammar LIBRARY_PATH=$(shell pwd)/go-llama-grammar \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama-grammar ./cmd/grpc/llama-grammar/
|
||||
|
||||
backend-assets/grpc/gpt4all: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-bindings/golang/libgpt4all.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ \
|
||||
@@ -424,4 +442,4 @@ backend-assets/grpc/whisper: backend-assets/grpc whisper.cpp/libwhisper.a
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/whisper.cpp LIBRARY_PATH=$(shell pwd)/whisper.cpp \
|
||||
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./cmd/grpc/whisper/
|
||||
|
||||
grpcs: prepare backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC)
|
||||
grpcs: prepare backend-assets/grpc/langchain-huggingface backend-assets/grpc/llama-grammar backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC)
|
||||
@@ -185,7 +185,7 @@ LocalAI can be installed inside Kubernetes with helm. See [installation instruct
|
||||
|
||||
## Supported API endpoints
|
||||
|
||||
See the [list of the supported API endpoints](https://localai.io/api-endpoints/index.html) and how to configure image generation and audio transcription.
|
||||
See the [list of the LocalAI features](https://localai.io/features/index.html) for a full tour of the available API endpoints.
|
||||
|
||||
## Frequently asked questions
|
||||
|
||||
|
||||
@@ -125,6 +125,11 @@ var _ = Describe("API test", func() {
|
||||
var cancel context.CancelFunc
|
||||
var tmpdir string
|
||||
|
||||
commonOpts := []options.AppOption{
|
||||
options.WithDebug(true),
|
||||
options.WithDisableMessage(true),
|
||||
}
|
||||
|
||||
Context("API with ephemeral models", func() {
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
@@ -143,7 +148,7 @@ var _ = Describe("API test", func() {
|
||||
Name: "bert2",
|
||||
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
|
||||
Overrides: map[string]interface{}{"foo": "bar"},
|
||||
AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
|
||||
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
|
||||
},
|
||||
}
|
||||
out, err := yaml.Marshal(g)
|
||||
@@ -159,9 +164,10 @@ var _ = Describe("API test", func() {
|
||||
}
|
||||
|
||||
app, err = App(
|
||||
options.WithContext(c),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))
|
||||
append(commonOpts,
|
||||
options.WithContext(c),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
@@ -291,7 +297,7 @@ var _ = Describe("API test", func() {
|
||||
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
|
||||
URL: "github:go-skynet/model-gallery/openllama_3b.yaml",
|
||||
Name: "openllama_3b",
|
||||
Overrides: map[string]string{},
|
||||
Overrides: map[string]string{"backend": "llama-grammar"},
|
||||
})
|
||||
|
||||
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
|
||||
@@ -400,13 +406,14 @@ var _ = Describe("API test", func() {
|
||||
}
|
||||
|
||||
app, err = App(
|
||||
options.WithContext(c),
|
||||
options.WithAudioDir(tmpdir),
|
||||
options.WithImageDir(tmpdir),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(tmpdir),
|
||||
append(commonOpts,
|
||||
options.WithContext(c),
|
||||
options.WithAudioDir(tmpdir),
|
||||
options.WithImageDir(tmpdir),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(tmpdir))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
@@ -500,7 +507,12 @@ var _ = Describe("API test", func() {
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
var err error
|
||||
app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader))
|
||||
app, err = App(
|
||||
append(commonOpts,
|
||||
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
||||
options.WithContext(c),
|
||||
options.WithModelLoader(modelLoader),
|
||||
)...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
@@ -524,7 +536,7 @@ var _ = Describe("API test", func() {
|
||||
It("returns the models list", func() {
|
||||
models, err := client.ListModels(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(models.Models)).To(Equal(10))
|
||||
Expect(len(models.Models)).To(Equal(11))
|
||||
})
|
||||
It("can generate completions", func() {
|
||||
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"})
|
||||
@@ -555,9 +567,10 @@ var _ = Describe("API test", func() {
|
||||
})
|
||||
|
||||
It("returns errors", func() {
|
||||
backends := len(model.AutoLoadBackends) + 1 // +1 for huggingface
|
||||
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:"))
|
||||
Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends)))
|
||||
})
|
||||
It("transcribes audio", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
@@ -601,6 +614,36 @@ var _ = Describe("API test", func() {
|
||||
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
|
||||
})
|
||||
|
||||
Context("External gRPC calls", func() {
|
||||
It("calculate embeddings with huggingface", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("test supported only on linux")
|
||||
}
|
||||
resp, err := client.CreateEmbeddings(
|
||||
context.Background(),
|
||||
openai.EmbeddingRequest{
|
||||
Model: openai.AdaCodeSearchCode,
|
||||
Input: []string{"sun", "cat"},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
|
||||
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
|
||||
|
||||
sunEmbedding := resp.Data[0].Embedding
|
||||
resp2, err := client.CreateEmbeddings(
|
||||
context.Background(),
|
||||
openai.EmbeddingRequest{
|
||||
Model: openai.AdaCodeSearchCode,
|
||||
Input: []string{"sun"},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
|
||||
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
|
||||
})
|
||||
})
|
||||
|
||||
Context("backends", func() {
|
||||
It("runs rwkv completion", func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
@@ -673,7 +716,12 @@ var _ = Describe("API test", func() {
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
var err error
|
||||
app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader), options.WithConfigFile(os.Getenv("CONFIG_FILE")))
|
||||
app, err = App(
|
||||
append(commonOpts,
|
||||
options.WithContext(c),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
@@ -695,7 +743,7 @@ var _ = Describe("API test", func() {
|
||||
It("can generate chat completions from config file", func() {
|
||||
models, err := client.ListModels(context.TODO())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(models.Models)).To(Equal(12))
|
||||
Expect(len(models.Models)).To(Equal(13))
|
||||
})
|
||||
It("can generate chat completions from config file", func() {
|
||||
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
|
||||
|
||||
@@ -30,6 +30,10 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
||||
model.WithContext(o.Context),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||
} else {
|
||||
|
||||
@@ -15,12 +15,20 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
|
||||
}
|
||||
|
||||
inferenceModel, err := loader.BackendLoader(
|
||||
opts := []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithContext(o.Context),
|
||||
model.WithModelFile(c.ImageGenerationAssets),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
inferenceModel, err := loader.BackendLoader(
|
||||
opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
||||
@@ -27,12 +30,32 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
|
||||
model.WithContext(o.Context),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
if c.Backend != "" {
|
||||
opts = append(opts, model.WithBackendString(c.Backend))
|
||||
}
|
||||
|
||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||
if o.AutoloadGalleries { // experimental
|
||||
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||
} else {
|
||||
opts = append(opts, model.WithBackendString(c.Backend))
|
||||
inferenceModel, err = loader.BackendLoader(opts...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -50,6 +73,9 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
|
||||
return ss, err
|
||||
} else {
|
||||
reply, err := inferenceModel.Predict(o.Context, opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return reply.Message, err
|
||||
}
|
||||
}
|
||||
|
||||
42
api/backend/transcript.go
Normal file
42
api/backend/transcript.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) {
|
||||
opts := []model.Option{
|
||||
model.WithBackendString(model.WhisperBackend),
|
||||
model.WithModelFile(c.Model),
|
||||
model.WithContext(o.Context),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
whisperModel, err := o.Loader.BackendLoader(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if whisperModel == nil {
|
||||
return nil, fmt.Errorf("could not load whisper model")
|
||||
}
|
||||
|
||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||
Dst: audio,
|
||||
Language: language,
|
||||
Threads: uint32(c.Threads),
|
||||
})
|
||||
}
|
||||
72
api/backend/tts.go
Normal file
72
api/backend/tts.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
func generateUniqueFileName(dir, baseName, ext string) string {
|
||||
counter := 1
|
||||
fileName := baseName + ext
|
||||
|
||||
for {
|
||||
filePath := filepath.Join(dir, fileName)
|
||||
_, err := os.Stat(filePath)
|
||||
if os.IsNotExist(err) {
|
||||
return fileName
|
||||
}
|
||||
|
||||
counter++
|
||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
||||
}
|
||||
}
|
||||
|
||||
func ModelTTS(text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) {
|
||||
opts := []model.Option{
|
||||
model.WithBackendString(model.PiperBackend),
|
||||
model.WithModelFile(modelFile),
|
||||
model.WithContext(o.Context),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
piperModel, err := o.Loader.BackendLoader(opts...)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if piperModel == nil {
|
||||
return "", nil, fmt.Errorf("could not load piper model")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
|
||||
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||
}
|
||||
|
||||
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
|
||||
filePath := filepath.Join(o.AudioDir, fileName)
|
||||
|
||||
modelPath := filepath.Join(o.Loader.ModelPath, modelFile)
|
||||
|
||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Dst: filePath,
|
||||
})
|
||||
|
||||
return filePath, res, err
|
||||
}
|
||||
@@ -49,6 +49,8 @@ type Config struct {
|
||||
functionCallString, functionCallNameString string
|
||||
|
||||
FunctionsConfig Functions `yaml:"function"`
|
||||
|
||||
SystemPrompt string `yaml:"system_prompt"`
|
||||
}
|
||||
|
||||
type Functions struct {
|
||||
@@ -58,10 +60,11 @@ type Functions struct {
|
||||
}
|
||||
|
||||
type TemplateConfig struct {
|
||||
Completion string `yaml:"completion"`
|
||||
Functions string `yaml:"function"`
|
||||
Chat string `yaml:"chat"`
|
||||
Edit string `yaml:"edit"`
|
||||
Chat string `yaml:"chat"`
|
||||
ChatMessage string `yaml:"chat_message"`
|
||||
Completion string `yaml:"completion"`
|
||||
Edit string `yaml:"edit"`
|
||||
Functions string `yaml:"function"`
|
||||
}
|
||||
|
||||
type ConfigLoader struct {
|
||||
|
||||
@@ -4,13 +4,15 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/json-iterator/go"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -80,6 +82,8 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
||||
case <-c.Done():
|
||||
return
|
||||
case op := <-g.C:
|
||||
utils.ResetDownloadTimers()
|
||||
|
||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
||||
|
||||
// updates the status with an error
|
||||
@@ -90,13 +94,17 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
||||
// displayDownload displays the download progress
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||
displayDownload(fileName, current, total, percentage)
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
||||
}
|
||||
|
||||
var err error
|
||||
// if the request contains a gallery name, we apply the gallery from the gallery list
|
||||
if op.galleryName != "" {
|
||||
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||
if strings.Contains(op.galleryName, "@") {
|
||||
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||
}
|
||||
} else {
|
||||
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
|
||||
}
|
||||
@@ -119,31 +127,6 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
||||
}()
|
||||
}
|
||||
|
||||
var lastProgress time.Time = time.Now()
|
||||
var startTime time.Time = time.Now()
|
||||
|
||||
func displayDownload(fileName string, current string, total string, percentage float64) {
|
||||
currentTime := time.Now()
|
||||
|
||||
if currentTime.Sub(lastProgress) >= 5*time.Second {
|
||||
|
||||
lastProgress = currentTime
|
||||
|
||||
// calculate ETA based on percentage and elapsed time
|
||||
var eta time.Duration
|
||||
if percentage > 0 {
|
||||
elapsed := currentTime.Sub(startTime)
|
||||
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
|
||||
}
|
||||
|
||||
if total != "" {
|
||||
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
|
||||
} else {
|
||||
log.Debug().Msgf("Downloading: %s", current)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type galleryModel struct {
|
||||
gallery.GalleryModel
|
||||
ID string `json:"id"`
|
||||
@@ -165,10 +148,11 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler
|
||||
}
|
||||
|
||||
for _, r := range requests {
|
||||
utils.ResetDownloadTimers()
|
||||
if r.ID == "" {
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload)
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload)
|
||||
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,10 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
@@ -20,22 +13,6 @@ type TTSRequest struct {
|
||||
Input string `json:"input" yaml:"input"`
|
||||
}
|
||||
|
||||
func generateUniqueFileName(dir, baseName, ext string) string {
|
||||
counter := 1
|
||||
fileName := baseName + ext
|
||||
|
||||
for {
|
||||
filePath := filepath.Join(dir, fileName)
|
||||
_, err := os.Stat(filePath)
|
||||
if os.IsNotExist(err) {
|
||||
return fileName
|
||||
}
|
||||
|
||||
counter++
|
||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
||||
}
|
||||
}
|
||||
|
||||
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
@@ -45,40 +22,10 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
piperModel, err := o.Loader.BackendLoader(
|
||||
model.WithBackendString(model.PiperBackend),
|
||||
model.WithModelFile(input.Model),
|
||||
model.WithContext(o.Context),
|
||||
model.WithAssetDir(o.AssetsDestination))
|
||||
filePath, _, err := backend.ModelTTS(input.Input, input.Model, o.Loader, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if piperModel == nil {
|
||||
return fmt.Errorf("could not load piper model")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed creating audio directory: %s", err)
|
||||
}
|
||||
|
||||
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
|
||||
filePath := filepath.Join(o.AudioDir, fileName)
|
||||
|
||||
modelPath := filepath.Join(o.Loader.ModelPath, input.Model)
|
||||
|
||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
|
||||
Text: input.Input,
|
||||
Model: modelPath,
|
||||
Dst: filePath,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Download(filePath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ type OpenAIResponse struct {
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
Index int `json:"index"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Delta *Message `json:"delta,omitempty"`
|
||||
|
||||
@@ -43,12 +43,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
||||
return func(c *fiber.Ctx) error {
|
||||
processFunctions := false
|
||||
funcs := grammar.Functions{}
|
||||
model, input, err := readInput(c, o.Loader, true)
|
||||
modelFile, input, err := readInput(c, o.Loader, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
@@ -110,9 +110,10 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
||||
var predInput string
|
||||
|
||||
mess := []string{}
|
||||
for _, i := range input.Messages {
|
||||
for messageIndex, i := range input.Messages {
|
||||
var content string
|
||||
role := i.Role
|
||||
|
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if i.FunctionCall != nil && i.Role == "assistant" {
|
||||
@@ -124,31 +125,55 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
||||
}
|
||||
r := config.Roles[role]
|
||||
contentExists := i.Content != nil && *i.Content != ""
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, " ", *i.Content)
|
||||
// First attempt to populate content via a chat message specific template
|
||||
if config.TemplateConfig.ChatMessage != "" {
|
||||
chatMessageData := model.ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: r,
|
||||
RoleName: role,
|
||||
Content: *i.Content,
|
||||
MessageIndex: messageIndex,
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||
} else {
|
||||
content = fmt.Sprint(r, " ", string(j))
|
||||
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
if err != nil {
|
||||
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
||||
} else {
|
||||
if templatedChatMessage == "" {
|
||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||
}
|
||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||
content = templatedChatMessage
|
||||
}
|
||||
}
|
||||
// If this model doesn't have such a template, or if
|
||||
if content == "" {
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, " ", *i.Content)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||
} else {
|
||||
content = fmt.Sprint(r, " ", string(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(*i.Content)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + string(j)
|
||||
} else {
|
||||
content = string(j)
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(*i.Content)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + string(j)
|
||||
} else {
|
||||
content = string(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -181,10 +206,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
||||
}
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
||||
Input string
|
||||
Functions []grammar.Function
|
||||
}{
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: predInput,
|
||||
Functions: funcs,
|
||||
})
|
||||
@@ -302,7 +324,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
||||
|
||||
return
|
||||
}
|
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}})
|
||||
*c = append(*c, Choice{FinishReason: "stop", Index: 0, Message: &Message{Role: "assistant", Content: &s}})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -38,14 +38,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readInput(c, o.Loader, true)
|
||||
modelFile, input, err := readInput(c, o.Loader, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("`input`: %+v", input)
|
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
@@ -76,9 +76,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
||||
predInput := config.PromptStrings[0]
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
||||
Input string
|
||||
}{
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: predInput,
|
||||
})
|
||||
if err == nil {
|
||||
@@ -122,11 +120,9 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
||||
}
|
||||
|
||||
var result []Choice
|
||||
for _, i := range config.PromptStrings {
|
||||
for k, i := range config.PromptStrings {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
||||
Input string
|
||||
}{
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: i,
|
||||
})
|
||||
if err == nil {
|
||||
@@ -135,7 +131,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
||||
}
|
||||
|
||||
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
|
||||
*c = append(*c, Choice{Text: s})
|
||||
*c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -6,18 +6,19 @@ import (
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readInput(c, o.Loader, true)
|
||||
modelFile, input, err := readInput(c, o.Loader, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
@@ -33,10 +34,10 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
||||
var result []Choice
|
||||
for _, i := range config.InputStrings {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
|
||||
Input string
|
||||
Instruction string
|
||||
}{Input: i})
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: i,
|
||||
Instruction: input.Instruction,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -9,10 +8,9 @@ import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -61,25 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||
|
||||
whisperModel, err := o.Loader.BackendLoader(
|
||||
model.WithBackendString(model.WhisperBackend),
|
||||
model.WithModelFile(config.Model),
|
||||
model.WithContext(o.Context),
|
||||
model.WithThreads(uint32(config.Threads)),
|
||||
model.WithAssetDir(o.AssetsDestination))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if whisperModel == nil {
|
||||
return fmt.Errorf("could not load whisper model")
|
||||
}
|
||||
|
||||
tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||
Dst: dst,
|
||||
Language: input.Language,
|
||||
Threads: uint32(config.Threads),
|
||||
})
|
||||
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -28,6 +28,10 @@ type Option struct {
|
||||
|
||||
BackendAssets embed.FS
|
||||
AssetsDestination string
|
||||
|
||||
ExternalGRPCBackends map[string]string
|
||||
|
||||
AutoloadGalleries bool
|
||||
}
|
||||
|
||||
type AppOption func(*Option)
|
||||
@@ -53,6 +57,19 @@ func WithCors(b bool) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
var EnableGalleriesAutoload = func(o *Option) {
|
||||
o.AutoloadGalleries = true
|
||||
}
|
||||
|
||||
func WithExternalBackend(name string, uri string) AppOption {
|
||||
return func(o *Option) {
|
||||
if o.ExternalGRPCBackends == nil {
|
||||
o.ExternalGRPCBackends = make(map[string]string)
|
||||
}
|
||||
o.ExternalGRPCBackends[name] = uri
|
||||
}
|
||||
}
|
||||
|
||||
func WithCorsAllowOrigins(b string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.CORSAllowOrigins = b
|
||||
|
||||
25
cmd/grpc/llama-grammar/main.go
Normal file
25
cmd/grpc/llama-grammar/main.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
// GRPC Falcon server
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama-grammar"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &llama.LLM{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -64,6 +64,14 @@ A ready to use example to show e2e how to integrate LocalAI with langchain
|
||||
|
||||
[Check it out here](https://github.com/go-skynet/LocalAI/tree/master/examples/langchain-python/)
|
||||
|
||||
### LocalAI functions
|
||||
|
||||
_by [@mudler](https://github.com/mudler)_
|
||||
|
||||
A ready to use example to show how to use OpenAI functions with LocalAI
|
||||
|
||||
[Check it out here](https://github.com/go-skynet/LocalAI/tree/master/examples/functions/)
|
||||
|
||||
### LocalAI WebUI
|
||||
|
||||
_by [@dhruvgera](https://github.com/dhruvgera)_
|
||||
|
||||
9
examples/functions/.env
Normal file
9
examples/functions/.env
Normal file
@@ -0,0 +1,9 @@
|
||||
OPENAI_API_KEY=sk---anystringhere
|
||||
OPENAI_API_BASE=http://api:8080/v1
|
||||
# Models to preload at start
|
||||
# Here we configure gpt4all as gpt-3.5-turbo and bert as embeddings
|
||||
PRELOAD_MODELS=[{"url": "github:go-skynet/model-gallery/openllama-7b-open-instruct.yaml", "name": "gpt-3.5-turbo"}]
|
||||
|
||||
## Change the default number of threads
|
||||
#THREADS=14
|
||||
|
||||
5
examples/functions/Dockerfile
Normal file
5
examples/functions/Dockerfile
Normal file
@@ -0,0 +1,5 @@
|
||||
FROM python:3.10-bullseye
|
||||
COPY . /app
|
||||
WORKDIR /app
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
ENTRYPOINT [ "python", "./functions-openai.py" ];
|
||||
18
examples/functions/README.md
Normal file
18
examples/functions/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# LocalAI functions
|
||||
|
||||
Example of using LocalAI functions, see the [OpenAI](https://openai.com/blog/function-calling-and-other-api-updates) blog post.
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
# Clone LocalAI
|
||||
git clone https://github.com/go-skynet/LocalAI
|
||||
|
||||
cd LocalAI/examples/functions
|
||||
|
||||
docker-compose run --rm functions
|
||||
```
|
||||
|
||||
Note: The example automatically downloads the `openllama` model as it is under a permissive license.
|
||||
|
||||
See the `.env` configuration file to set a different model with the [model-gallery](https://github.com/go-skynet/model-gallery) by editing `PRELOAD_MODELS`.
|
||||
23
examples/functions/docker-compose.yaml
Normal file
23
examples/functions/docker-compose.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
version: "3.9"
|
||||
services:
|
||||
api:
|
||||
image: quay.io/go-skynet/local-ai:master
|
||||
ports:
|
||||
- 8080:8080
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- DEBUG=true
|
||||
- MODELS_PATH=/models
|
||||
volumes:
|
||||
- ./models:/models:cached
|
||||
command: ["/usr/bin/local-ai" ]
|
||||
functions:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
depends_on:
|
||||
api:
|
||||
condition: service_healthy
|
||||
env_file:
|
||||
- .env
|
||||
76
examples/functions/functions-openai.py
Normal file
76
examples/functions/functions-openai.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import openai
|
||||
import json
|
||||
|
||||
# Example dummy function hard coded to return the same weather
|
||||
# In production, this could be your backend API or an external API
|
||||
def get_current_weather(location, unit="fahrenheit"):
|
||||
"""Get the current weather in a given location"""
|
||||
weather_info = {
|
||||
"location": location,
|
||||
"temperature": "72",
|
||||
"unit": unit,
|
||||
"forecast": ["sunny", "windy"],
|
||||
}
|
||||
return json.dumps(weather_info)
|
||||
|
||||
|
||||
def run_conversation():
|
||||
# Step 1: send the conversation and available functions to GPT
|
||||
messages = [{"role": "user", "content": "What's the weather like in Boston?"}]
|
||||
functions = [
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
]
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
function_call="auto", # auto is default, but we'll be explicit
|
||||
)
|
||||
response_message = response["choices"][0]["message"]
|
||||
|
||||
# Step 2: check if GPT wanted to call a function
|
||||
if response_message.get("function_call"):
|
||||
# Step 3: call the function
|
||||
# Note: the JSON response may not always be valid; be sure to handle errors
|
||||
available_functions = {
|
||||
"get_current_weather": get_current_weather,
|
||||
} # only one function in this example, but you can have multiple
|
||||
function_name = response_message["function_call"]["name"]
|
||||
fuction_to_call = available_functions[function_name]
|
||||
function_args = json.loads(response_message["function_call"]["arguments"])
|
||||
function_response = fuction_to_call(
|
||||
location=function_args.get("location"),
|
||||
unit=function_args.get("unit"),
|
||||
)
|
||||
|
||||
# Step 4: send the info on the function call and function response to GPT
|
||||
messages.append(response_message) # extend conversation with assistant's reply
|
||||
messages.append(
|
||||
{
|
||||
"role": "function",
|
||||
"name": function_name,
|
||||
"content": function_response,
|
||||
}
|
||||
) # extend conversation with function response
|
||||
second_response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
) # get a new response from GPT where it can see the function response
|
||||
return second_response
|
||||
|
||||
|
||||
print(run_conversation())
|
||||
2
examples/functions/requirements.txt
Normal file
2
examples/functions/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
langchain==0.0.234
|
||||
openai==0.27.8
|
||||
@@ -22,7 +22,7 @@ services:
|
||||
- 'PRELOAD_MODELS=[{"url": "github:go-skynet/model-gallery/gpt4all-j.yaml", "name": "gpt-3.5-turbo"}, {"url": "github:go-skynet/model-gallery/stablediffusion.yaml"}, {"url": "github:go-skynet/model-gallery/whisper-base.yaml", "name": "whisper-1"}]'
|
||||
volumes:
|
||||
- ./models:/models:cached
|
||||
command: ["/usr/bin/local-ai" ]
|
||||
command: ["/usr/bin/local-ai"]
|
||||
chatgpt_telegram_bot:
|
||||
container_name: chatgpt_telegram_bot
|
||||
command: python3 bot/bot.py
|
||||
|
||||
49
extra/grpc/huggingface/backend_pb2.py
Normal file
49
extra/grpc/huggingface/backend_pb2.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: backend.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\xa4\x05\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\t\"\xac\x02\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto'
|
||||
_globals['_HEALTHMESSAGE']._serialized_start=26
|
||||
_globals['_HEALTHMESSAGE']._serialized_end=41
|
||||
_globals['_PREDICTOPTIONS']._serialized_start=44
|
||||
_globals['_PREDICTOPTIONS']._serialized_end=720
|
||||
_globals['_REPLY']._serialized_start=722
|
||||
_globals['_REPLY']._serialized_end=746
|
||||
_globals['_MODELOPTIONS']._serialized_start=749
|
||||
_globals['_MODELOPTIONS']._serialized_end=1049
|
||||
_globals['_RESULT']._serialized_start=1051
|
||||
_globals['_RESULT']._serialized_end=1093
|
||||
_globals['_EMBEDDINGRESULT']._serialized_start=1095
|
||||
_globals['_EMBEDDINGRESULT']._serialized_end=1132
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_start=1134
|
||||
_globals['_TRANSCRIPTREQUEST']._serialized_end=1201
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_start=1203
|
||||
_globals['_TRANSCRIPTRESULT']._serialized_end=1281
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_start=1283
|
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_end=1372
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_start=1375
|
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_end=1533
|
||||
_globals['_TTSREQUEST']._serialized_start=1535
|
||||
_globals['_TTSREQUEST']._serialized_end=1589
|
||||
_globals['_BACKEND']._serialized_start=1592
|
||||
_globals['_BACKEND']._serialized_end=2083
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
297
extra/grpc/huggingface/backend_pb2_grpc.py
Normal file
297
extra/grpc/huggingface/backend_pb2_grpc.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import backend_pb2 as backend__pb2
|
||||
|
||||
|
||||
class BackendStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Health = channel.unary_unary(
|
||||
'/backend.Backend/Health',
|
||||
request_serializer=backend__pb2.HealthMessage.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Predict = channel.unary_unary(
|
||||
'/backend.Backend/Predict',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.LoadModel = channel.unary_unary(
|
||||
'/backend.Backend/LoadModel',
|
||||
request_serializer=backend__pb2.ModelOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.PredictStream = channel.unary_stream(
|
||||
'/backend.Backend/PredictStream',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.Reply.FromString,
|
||||
)
|
||||
self.Embedding = channel.unary_unary(
|
||||
'/backend.Backend/Embedding',
|
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString,
|
||||
response_deserializer=backend__pb2.EmbeddingResult.FromString,
|
||||
)
|
||||
self.GenerateImage = channel.unary_unary(
|
||||
'/backend.Backend/GenerateImage',
|
||||
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
self.AudioTranscription = channel.unary_unary(
|
||||
'/backend.Backend/AudioTranscription',
|
||||
request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.TranscriptResult.FromString,
|
||||
)
|
||||
self.TTS = channel.unary_unary(
|
||||
'/backend.Backend/TTS',
|
||||
request_serializer=backend__pb2.TTSRequest.SerializeToString,
|
||||
response_deserializer=backend__pb2.Result.FromString,
|
||||
)
|
||||
|
||||
|
||||
class BackendServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def Health(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Predict(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Embedding(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GenerateImage(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def AudioTranscription(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def TTS(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_BackendServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Health': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Health,
|
||||
request_deserializer=backend__pb2.HealthMessage.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Predict': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Predict,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'LoadModel': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.LoadModel,
|
||||
request_deserializer=backend__pb2.ModelOptions.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'PredictStream': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.PredictStream,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.Reply.SerializeToString,
|
||||
),
|
||||
'Embedding': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Embedding,
|
||||
request_deserializer=backend__pb2.PredictOptions.FromString,
|
||||
response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
|
||||
),
|
||||
'GenerateImage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GenerateImage,
|
||||
request_deserializer=backend__pb2.GenerateImageRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
'AudioTranscription': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.AudioTranscription,
|
||||
request_deserializer=backend__pb2.TranscriptRequest.FromString,
|
||||
response_serializer=backend__pb2.TranscriptResult.SerializeToString,
|
||||
),
|
||||
'TTS': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.TTS,
|
||||
request_deserializer=backend__pb2.TTSRequest.FromString,
|
||||
response_serializer=backend__pb2.Result.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'backend.Backend', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class Backend(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def Health(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
|
||||
backend__pb2.HealthMessage.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Predict(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def LoadModel(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
|
||||
backend__pb2.ModelOptions.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def PredictStream(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.Reply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Embedding(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
|
||||
backend__pb2.PredictOptions.SerializeToString,
|
||||
backend__pb2.EmbeddingResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def GenerateImage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
|
||||
backend__pb2.GenerateImageRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def AudioTranscription(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
|
||||
backend__pb2.TranscriptRequest.SerializeToString,
|
||||
backend__pb2.TranscriptResult.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def TTS(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
|
||||
backend__pb2.TTSRequest.SerializeToString,
|
||||
backend__pb2.Result.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
67
extra/grpc/huggingface/huggingface.py
Executable file
67
extra/grpc/huggingface/huggingface.py
Executable file
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message="OK")
|
||||
def LoadModel(self, request, context):
|
||||
model_name = request.Model
|
||||
model_name = os.path.basename(model_name)
|
||||
try:
|
||||
self.model = SentenceTransformer(model_name)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
# Implement your logic here for the LoadModel service
|
||||
# Replace this with your desired response
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
def Embedding(self, request, context):
|
||||
# Implement your logic here for the Embedding service
|
||||
# Replace this with your desired response
|
||||
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
|
||||
sentence_embeddings = self.model.encode(request.Embeddings)
|
||||
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
||||
4
extra/requirements.txt
Normal file
4
extra/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
sentence_transformers
|
||||
grpcio
|
||||
google
|
||||
protobuf
|
||||
1
go.mod
1
go.mod
@@ -39,6 +39,7 @@ require (
|
||||
require (
|
||||
github.com/dlclark/regexp2 v1.8.1 // indirect
|
||||
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect
|
||||
github.com/go-skynet/go-llama.cpp-grammar v0.0.0-20230703203849-ffa57fbc3a12 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/golang/snappy v0.0.2 // indirect
|
||||
github.com/klauspost/pgzip v1.2.5 // indirect
|
||||
|
||||
30
main.go
30
main.go
@@ -4,6 +4,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
api "github.com/go-skynet/LocalAI/api"
|
||||
@@ -40,6 +41,10 @@ func main() {
|
||||
Name: "f16",
|
||||
EnvVars: []string{"F16"},
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "autoload-galleries",
|
||||
EnvVars: []string{"AUTOLOAD_GALLERIES"},
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "debug",
|
||||
EnvVars: []string{"DEBUG"},
|
||||
@@ -108,6 +113,11 @@ func main() {
|
||||
EnvVars: []string{"BACKEND_ASSETS_PATH"},
|
||||
Value: "/tmp/localai/backend_data",
|
||||
},
|
||||
&cli.StringSliceFlag{
|
||||
Name: "external-grpc-backends",
|
||||
Usage: "A list of external grpc backends",
|
||||
EnvVars: []string{"EXTERNAL_GRPC_BACKENDS"},
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "context-size",
|
||||
Usage: "Default context size of the model",
|
||||
@@ -138,7 +148,8 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||
UsageText: `local-ai [options]`,
|
||||
Copyright: "Ettore Di Giacinto",
|
||||
Action: func(ctx *cli.Context) error {
|
||||
app, err := api.App(
|
||||
|
||||
opts := []options.AppOption{
|
||||
options.WithConfigFile(ctx.String("config-file")),
|
||||
options.WithJSONStringPreload(ctx.String("preload-models")),
|
||||
options.WithYAMLConfigPreload(ctx.String("preload-models-config")),
|
||||
@@ -155,7 +166,22 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||
options.WithThreads(ctx.Int("threads")),
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
||||
options.WithUploadLimitMB(ctx.Int("upload-limit")))
|
||||
options.WithUploadLimitMB(ctx.Int("upload-limit")),
|
||||
}
|
||||
|
||||
externalgRPC := ctx.StringSlice("external-grpc-backends")
|
||||
// split ":" to get backend name and the uri
|
||||
for _, v := range externalgRPC {
|
||||
backend := v[:strings.IndexByte(v, ':')]
|
||||
uri := v[strings.IndexByte(v, ':')+1:]
|
||||
opts = append(opts, options.WithExternalBackend(backend, uri))
|
||||
}
|
||||
|
||||
if ctx.Bool("autoload-galleries") {
|
||||
opts = append(opts, options.EnableGalleriesAutoload)
|
||||
}
|
||||
|
||||
app, err := api.App(opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -18,23 +18,15 @@ type Gallery struct {
|
||||
|
||||
// Installs a model from the gallery (galleryname@modelname)
|
||||
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||
|
||||
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
models, err := AvailableGalleryModels(galleries, basePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applyModel := func(model *GalleryModel) error {
|
||||
config, err := GetGalleryConfigFromURL(model.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
installName := model.Name
|
||||
if req.Name != "" {
|
||||
model.Name = req.Name
|
||||
installName = req.Name
|
||||
}
|
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||
@@ -45,20 +37,58 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
||||
return err
|
||||
}
|
||||
|
||||
if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil {
|
||||
if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
models, err := AvailableGalleryModels(galleries, basePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
model, err := FindGallery(models, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return applyModel(model)
|
||||
}
|
||||
|
||||
func FindGallery(models []*GalleryModel, name string) (*GalleryModel, error) {
|
||||
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
for _, model := range models {
|
||||
if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
|
||||
return applyModel(model)
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no gallery found with name %q", name)
|
||||
}
|
||||
|
||||
// InstallModelFromGalleryByName loads a model from the gallery by specifying only the name (first match wins)
|
||||
func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||
models, err := AvailableGalleryModels(galleries, basePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
var model *GalleryModel
|
||||
for _, m := range models {
|
||||
if name == m.Name {
|
||||
model = m
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("no model found with name %q", name)
|
||||
if model == nil {
|
||||
return fmt.Errorf("no model found with name %q", name)
|
||||
}
|
||||
|
||||
return InstallModelFromGallery(galleries, fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name), basePath, req, downloadStatus)
|
||||
}
|
||||
|
||||
// List available models
|
||||
|
||||
@@ -18,9 +18,17 @@ func (f Functions) ToJSONStructure() JSONFunctionStructure {
|
||||
//tt := t.(string)
|
||||
|
||||
properties := function.Parameters["properties"]
|
||||
defs := function.Parameters["$defs"]
|
||||
dat, _ := json.Marshal(properties)
|
||||
dat2, _ := json.Marshal(defs)
|
||||
prop := map[string]interface{}{}
|
||||
defsD := map[string]interface{}{}
|
||||
|
||||
json.Unmarshal(dat, &prop)
|
||||
json.Unmarshal(dat2, &defsD)
|
||||
if js.Defs == nil {
|
||||
js.Defs = defsD
|
||||
}
|
||||
js.OneOf = append(js.OneOf, Item{
|
||||
Type: "object",
|
||||
Properties: Properties{
|
||||
|
||||
@@ -16,6 +16,7 @@ var (
|
||||
PRIMITIVE_RULES = map[string]string{
|
||||
"boolean": `("true" | "false") space`,
|
||||
"number": `[0-9]+ space`, // TODO complete
|
||||
"integer": `[0-9]+ space`, // TODO complete
|
||||
"string": `"\"" [ \t!#-\[\]-~]* "\"" space`, // TODO complete
|
||||
"null": `"null" space`,
|
||||
}
|
||||
@@ -82,7 +83,7 @@ func (sc *JSONSchemaConverter) formatGrammar() string {
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string) string {
|
||||
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) string {
|
||||
st, existType := schema["type"]
|
||||
var schemaType string
|
||||
if existType {
|
||||
@@ -101,18 +102,21 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string)
|
||||
|
||||
if oneOfExists {
|
||||
for i, altSchema := range oneOfSchemas {
|
||||
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i))
|
||||
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
|
||||
alternatives = append(alternatives, alternative)
|
||||
}
|
||||
} else if anyOfExists {
|
||||
for i, altSchema := range anyOfSchemas {
|
||||
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i))
|
||||
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
|
||||
alternatives = append(alternatives, alternative)
|
||||
}
|
||||
}
|
||||
|
||||
rule := strings.Join(alternatives, " | ")
|
||||
return sc.addRule(ruleName, rule)
|
||||
} else if ref, exists := schema["$ref"].(string); exists {
|
||||
referencedSchema := sc.resolveReference(ref, rootSchema)
|
||||
return sc.visit(referencedSchema, name, rootSchema)
|
||||
} else if constVal, exists := schema["const"]; exists {
|
||||
return sc.addRule(ruleName, sc.formatLiteral(constVal))
|
||||
} else if enumVals, exists := schema["enum"].([]interface{}); exists {
|
||||
@@ -152,7 +156,7 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string)
|
||||
for i, propPair := range propPairs {
|
||||
propName := propPair.propName
|
||||
propSchema := propPair.propSchema
|
||||
propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName))
|
||||
propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
|
||||
|
||||
if i > 0 {
|
||||
rule.WriteString(` "," space`)
|
||||
@@ -164,7 +168,7 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string)
|
||||
rule.WriteString(` "}" space`)
|
||||
return sc.addRule(ruleName, rule.String())
|
||||
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
|
||||
itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName))
|
||||
itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
|
||||
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
|
||||
return sc.addRule(ruleName, rule)
|
||||
} else {
|
||||
@@ -175,9 +179,30 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string)
|
||||
return sc.addRule(schemaType, primitiveRule)
|
||||
}
|
||||
}
|
||||
func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) map[string]interface{} {
|
||||
if !strings.HasPrefix(ref, "#/$defs/") {
|
||||
panic(fmt.Sprintf("Invalid reference format: %s", ref))
|
||||
}
|
||||
|
||||
defKey := strings.TrimPrefix(ref, "#/$defs/")
|
||||
definitions, exists := rootSchema["$defs"].(map[string]interface{})
|
||||
if !exists {
|
||||
fmt.Println(rootSchema)
|
||||
|
||||
panic("No definitions found in the schema")
|
||||
}
|
||||
|
||||
def, exists := definitions[defKey].(map[string]interface{})
|
||||
if !exists {
|
||||
fmt.Println(definitions)
|
||||
|
||||
panic(fmt.Sprintf("Definition not found: %s", defKey))
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string {
|
||||
sc.visit(schema, "")
|
||||
sc.visit(schema, "", schema)
|
||||
return sc.formatGrammar()
|
||||
}
|
||||
|
||||
@@ -212,8 +237,9 @@ type Item struct {
|
||||
}
|
||||
|
||||
type JSONFunctionStructure struct {
|
||||
OneOf []Item `json:"oneOf,omitempty"`
|
||||
AnyOf []Item `json:"anyOf,omitempty"`
|
||||
OneOf []Item `json:"oneOf,omitempty"`
|
||||
AnyOf []Item `json:"anyOf,omitempty"`
|
||||
Defs map[string]interface{} `json:"$defs,omitempty"`
|
||||
}
|
||||
|
||||
func (j JSONFunctionStructure) Grammar(propOrder string) string {
|
||||
|
||||
170
pkg/grpc/llm/llama-grammar/llama.go
Normal file
170
pkg/grpc/llm/llama-grammar/llama.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package llama
|
||||
|
||||
// This is a wrapper to statisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/go-llama.cpp-grammar"
|
||||
)
|
||||
|
||||
type LLM struct {
|
||||
base.Base
|
||||
|
||||
llama *llama.LLama
|
||||
}
|
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
||||
llamaOpts := []llama.ModelOption{}
|
||||
|
||||
if opts.ContextSize != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize)))
|
||||
}
|
||||
if opts.F16Memory {
|
||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
||||
}
|
||||
if opts.Embeddings {
|
||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
|
||||
}
|
||||
if opts.NGPULayers != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
|
||||
}
|
||||
|
||||
llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap))
|
||||
llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU))
|
||||
llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit))
|
||||
if opts.NBatch != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch)))
|
||||
} else {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(512))
|
||||
}
|
||||
|
||||
if opts.NUMA {
|
||||
llamaOpts = append(llamaOpts, llama.EnableNUMA)
|
||||
}
|
||||
|
||||
if opts.LowVRAM {
|
||||
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
|
||||
}
|
||||
|
||||
model, err := llama.New(opts.Model, llamaOpts...)
|
||||
llm.llama = model
|
||||
return err
|
||||
}
|
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
|
||||
predictOptions := []llama.PredictOption{
|
||||
llama.SetTemperature(float64(opts.Temperature)),
|
||||
llama.SetTopP(float64(opts.TopP)),
|
||||
llama.SetTopK(int(opts.TopK)),
|
||||
llama.SetTokens(int(opts.Tokens)),
|
||||
llama.SetThreads(int(opts.Threads)),
|
||||
}
|
||||
|
||||
if opts.PromptCacheAll {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
|
||||
}
|
||||
|
||||
if opts.PromptCacheRO {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
|
||||
}
|
||||
|
||||
predictOptions = append(predictOptions, llama.WithGrammar(opts.Grammar))
|
||||
|
||||
// Expected absolute path
|
||||
if opts.PromptCachePath != "" {
|
||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath))
|
||||
}
|
||||
|
||||
if opts.Mirostat != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat)))
|
||||
}
|
||||
|
||||
if opts.MirostatETA != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatETA(float64(opts.MirostatETA)))
|
||||
}
|
||||
|
||||
if opts.MirostatTAU != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatTAU(float64(opts.MirostatTAU)))
|
||||
}
|
||||
|
||||
if opts.Debug {
|
||||
predictOptions = append(predictOptions, llama.Debug)
|
||||
}
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...))
|
||||
|
||||
if opts.PresencePenalty != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetPenalty(float64(opts.PresencePenalty)))
|
||||
}
|
||||
|
||||
if opts.NKeep != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep)))
|
||||
}
|
||||
|
||||
if opts.Batch != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch)))
|
||||
}
|
||||
|
||||
if opts.F16KV {
|
||||
predictOptions = append(predictOptions, llama.EnableF16KV)
|
||||
}
|
||||
|
||||
if opts.IgnoreEOS {
|
||||
predictOptions = append(predictOptions, llama.IgnoreEOS)
|
||||
}
|
||||
|
||||
if opts.Seed != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed)))
|
||||
}
|
||||
|
||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(float64(opts.FrequencyPenalty)))
|
||||
predictOptions = append(predictOptions, llama.SetMlock(opts.MLock))
|
||||
predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit))
|
||||
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ)))
|
||||
predictOptions = append(predictOptions, llama.SetTypicalP(float64(opts.TypicalP)))
|
||||
return predictOptions
|
||||
}
|
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
|
||||
predictOptions := buildPredictOptions(opts)
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool {
|
||||
results <- token
|
||||
return true
|
||||
}))
|
||||
|
||||
go func() {
|
||||
_, err := llm.llama.Predict(opts.Prompt, predictOptions...)
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
close(results)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
predictOptions := buildPredictOptions(opts)
|
||||
|
||||
if len(opts.EmbeddingTokens) > 0 {
|
||||
tokens := []int{}
|
||||
for _, t := range opts.EmbeddingTokens {
|
||||
tokens = append(tokens, int(t))
|
||||
}
|
||||
return llm.llama.TokenEmbeddings(tokens, predictOptions...)
|
||||
}
|
||||
|
||||
return llm.llama.Embeddings(opts.Embeddings, predictOptions...)
|
||||
}
|
||||
@@ -71,8 +71,6 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
|
||||
}
|
||||
|
||||
predictOptions = append(predictOptions, llama.WithGrammar(opts.Grammar))
|
||||
|
||||
// Expected absolute path
|
||||
if opts.PromptCachePath != "" {
|
||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath))
|
||||
|
||||
@@ -3,7 +3,9 @@ package tts
|
||||
// This is a wrapper to statisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
@@ -16,6 +18,9 @@ type Piper struct {
|
||||
}
|
||||
|
||||
func (sd *Piper) Load(opts *pb.ModelOptions) error {
|
||||
if filepath.Ext(opts.Model) != ".onnx" {
|
||||
return fmt.Errorf("unsupported model type %s (should end with .onnx)", opts.Model)
|
||||
}
|
||||
var err error
|
||||
// Note: the Model here is a path to a directory containing the model files
|
||||
sd.piper, err = New(opts.LibrarySearchPath)
|
||||
|
||||
@@ -19,8 +19,6 @@ import (
|
||||
process "github.com/mudler/go-processmanager"
|
||||
)
|
||||
|
||||
const tokenizerSuffix = ".tokenizer.json"
|
||||
|
||||
const (
|
||||
LlamaBackend = "llama"
|
||||
BloomzBackend = "bloomz"
|
||||
@@ -37,6 +35,7 @@ const (
|
||||
Gpt4All = "gpt4all"
|
||||
FalconBackend = "falcon"
|
||||
FalconGGMLBackend = "falcon-ggml"
|
||||
LlamaGrammarBackend = "llama-grammar"
|
||||
|
||||
BertEmbeddingsBackend = "bert-embeddings"
|
||||
RwkvBackend = "rwkv"
|
||||
@@ -44,17 +43,15 @@ const (
|
||||
StableDiffusionBackend = "stablediffusion"
|
||||
PiperBackend = "piper"
|
||||
LCHuggingFaceBackend = "langchain-huggingface"
|
||||
//GGLLMFalconBackend = "falcon"
|
||||
)
|
||||
|
||||
var autoLoadBackends []string = []string{
|
||||
var AutoLoadBackends []string = []string{
|
||||
LlamaBackend,
|
||||
Gpt4All,
|
||||
RwkvBackend,
|
||||
FalconBackend,
|
||||
WhisperBackend,
|
||||
GPTNeoXBackend,
|
||||
BertEmbeddingsBackend,
|
||||
LlamaGrammarBackend,
|
||||
FalconGGMLBackend,
|
||||
GPTJBackend,
|
||||
Gpt2Backend,
|
||||
@@ -63,6 +60,10 @@ var autoLoadBackends []string = []string{
|
||||
ReplitBackend,
|
||||
StarcoderBackend,
|
||||
BloomzBackend,
|
||||
RwkvBackend,
|
||||
WhisperBackend,
|
||||
StableDiffusionBackend,
|
||||
PiperBackend,
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) StopGRPC() {
|
||||
@@ -71,75 +72,116 @@ func (ml *ModelLoader) StopGRPC() {
|
||||
}
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
|
||||
// Make sure the process is executable
|
||||
if err := os.Chmod(grpcProcess, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Loading GRPC Process", grpcProcess)
|
||||
|
||||
log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
|
||||
|
||||
grpcControlProcess := process.New(
|
||||
process.WithTemporaryStateDir(),
|
||||
process.WithName(grpcProcess),
|
||||
process.WithArgs("--addr", serverAddress))
|
||||
|
||||
ml.grpcProcesses[id] = grpcControlProcess
|
||||
|
||||
if err := grpcControlProcess.Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
|
||||
// clean up process
|
||||
go func() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
<-c
|
||||
grpcControlProcess.Stop()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Could not tail stderr")
|
||||
}
|
||||
for line := range t.Lines {
|
||||
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Could not tail stdout")
|
||||
}
|
||||
for line := range t.Lines {
|
||||
log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// starts the grpcModelProcess for the backend, and returns a grpc client
|
||||
// It also loads the model
|
||||
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
|
||||
return func(s string) (*grpc.Client, error) {
|
||||
log.Debug().Msgf("Loading GRPC Model", backend, *o)
|
||||
log.Debug().Msgf("Loading GRPC Model %s: %+v", backend, *o)
|
||||
|
||||
grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
|
||||
var client *grpc.Client
|
||||
|
||||
// Check if the file exists
|
||||
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
|
||||
}
|
||||
|
||||
// Make sure the process is executable
|
||||
if err := os.Chmod(grpcProcess, 0755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Loading GRPC Process", grpcProcess)
|
||||
port, err := freeport.GetFreePort()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverAddress := fmt.Sprintf("localhost:%d", port)
|
||||
|
||||
log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress)
|
||||
|
||||
grpcControlProcess := process.New(
|
||||
process.WithTemporaryStateDir(),
|
||||
process.WithName(grpcProcess),
|
||||
process.WithArgs("--addr", serverAddress))
|
||||
|
||||
ml.grpcProcesses[o.modelFile] = grpcControlProcess
|
||||
|
||||
if err := grpcControlProcess.Run(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// clean up process
|
||||
go func() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
<-c
|
||||
grpcControlProcess.Stop()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
|
||||
getFreeAddress := func() (string, error) {
|
||||
port, err := freeport.GetFreePort()
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Could not tail stderr")
|
||||
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||
}
|
||||
for line := range t.Lines {
|
||||
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text)
|
||||
return fmt.Sprintf("127.0.0.1:%d", port), nil
|
||||
}
|
||||
|
||||
// Check if the backend is provided as external
|
||||
if uri, ok := o.externalBackends[backend]; ok {
|
||||
log.Debug().Msgf("Loading external backend: %s", uri)
|
||||
// check if uri is a file or a address
|
||||
if _, err := os.Stat(uri); err == nil {
|
||||
serverAddress, err := getFreeAddress()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||
}
|
||||
// Make sure the process is executable
|
||||
if err := ml.startProcess(uri, o.modelFile, serverAddress); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("GRPC Service Started")
|
||||
|
||||
client = grpc.NewClient(serverAddress)
|
||||
} else {
|
||||
// address
|
||||
client = grpc.NewClient(uri)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
|
||||
} else {
|
||||
grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
|
||||
// Check if the file exists
|
||||
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
|
||||
}
|
||||
|
||||
serverAddress, err := getFreeAddress()
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Could not tail stdout")
|
||||
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||
}
|
||||
for line := range t.Lines {
|
||||
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text)
|
||||
|
||||
// Make sure the process is executable
|
||||
if err := ml.startProcess(grpcProcess, o.modelFile, serverAddress); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debug().Msgf("GRPC Service Started")
|
||||
log.Debug().Msgf("GRPC Service Started")
|
||||
|
||||
client := grpc.NewClient(serverAddress)
|
||||
client = grpc.NewClient(serverAddress)
|
||||
}
|
||||
|
||||
// Wait for the service to start up
|
||||
ready := false
|
||||
@@ -154,11 +196,6 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
|
||||
|
||||
if !ready {
|
||||
log.Debug().Msgf("GRPC Service NOT ready")
|
||||
log.Debug().Msgf("Alive: ", grpcControlProcess.IsAlive())
|
||||
log.Debug().Msgf(fmt.Sprintf("GRPC Service Exitcode:"))
|
||||
|
||||
log.Debug().Msgf(grpcControlProcess.ExitCode())
|
||||
|
||||
return nil, fmt.Errorf("grpc service not ready")
|
||||
}
|
||||
|
||||
@@ -169,10 +206,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
|
||||
|
||||
res, err := client.LoadModel(o.context, &options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("could not load model: %w", err)
|
||||
}
|
||||
if !res.Success {
|
||||
return nil, fmt.Errorf("could not load model: %s", res.Message)
|
||||
return nil, fmt.Errorf("could not load model (no success): %s", res.Message)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
@@ -185,8 +222,15 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
|
||||
log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile)
|
||||
|
||||
backend := strings.ToLower(o.backendString)
|
||||
|
||||
// if an external backend is provided, use it
|
||||
_, externalBackendExists := o.externalBackends[backend]
|
||||
if externalBackendExists {
|
||||
return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o))
|
||||
}
|
||||
|
||||
switch backend {
|
||||
case LlamaBackend, GPTJBackend, DollyBackend,
|
||||
case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend,
|
||||
MPTBackend, Gpt2Backend, FalconBackend,
|
||||
GPTNeoXBackend, ReplitBackend, StarcoderBackend, BloomzBackend,
|
||||
RwkvBackend, LCHuggingFaceBackend, BertEmbeddingsBackend, FalconGGMLBackend, StableDiffusionBackend, WhisperBackend:
|
||||
@@ -205,8 +249,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
|
||||
func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
|
||||
o := NewOptions(opts...)
|
||||
|
||||
log.Debug().Msgf("Loading model '%s' greedly", o.modelFile)
|
||||
|
||||
// Is this really needed? BackendLoader already does this
|
||||
ml.mu.Lock()
|
||||
if m := ml.checkIsLoaded(o.modelFile); m != nil {
|
||||
@@ -217,25 +259,38 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
|
||||
ml.mu.Unlock()
|
||||
var err error
|
||||
|
||||
for _, b := range autoLoadBackends {
|
||||
if b == BloomzBackend || b == WhisperBackend || b == RwkvBackend { // do not autoload bloomz/whisper/rwkv
|
||||
continue
|
||||
}
|
||||
log.Debug().Msgf("[%s] Attempting to load", b)
|
||||
// autoload also external backends
|
||||
allBackendsToAutoLoad := []string{}
|
||||
allBackendsToAutoLoad = append(allBackendsToAutoLoad, AutoLoadBackends...)
|
||||
for _, b := range o.externalBackends {
|
||||
allBackendsToAutoLoad = append(allBackendsToAutoLoad, b)
|
||||
}
|
||||
log.Debug().Msgf("Loading model '%s' greedly from all the available backends: %s", o.modelFile, strings.Join(allBackendsToAutoLoad, ", "))
|
||||
|
||||
model, modelerr := ml.BackendLoader(
|
||||
for _, b := range allBackendsToAutoLoad {
|
||||
log.Debug().Msgf("[%s] Attempting to load", b)
|
||||
options := []Option{
|
||||
WithBackendString(b),
|
||||
WithModelFile(o.modelFile),
|
||||
WithLoadGRPCLLMModelOpts(o.gRPCOptions),
|
||||
WithThreads(o.threads),
|
||||
WithAssetDir(o.assetDir),
|
||||
)
|
||||
}
|
||||
|
||||
for k, v := range o.externalBackends {
|
||||
options = append(options, WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
model, modelerr := ml.BackendLoader(options...)
|
||||
if modelerr == nil && model != nil {
|
||||
log.Debug().Msgf("[%s] Loads OK", b)
|
||||
return model, nil
|
||||
} else if modelerr != nil {
|
||||
err = multierror.Append(err, modelerr)
|
||||
log.Debug().Msgf("[%s] Fails: %s", b, modelerr.Error())
|
||||
} else if model == nil {
|
||||
err = multierror.Append(err, fmt.Errorf("backend returned no usable model"))
|
||||
log.Debug().Msgf("[%s] Fails: %s", b, "backend returned no usable model")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,43 +4,81 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
|
||||
grammar "github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
process "github.com/mudler/go-processmanager"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Rather than pass an interface{} to the prompt template:
|
||||
// These are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file
|
||||
// Please note: Not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or tolerant of zero values.
|
||||
type PromptTemplateData struct {
|
||||
Input string
|
||||
Instruction string
|
||||
Functions []grammar.Function
|
||||
MessageIndex int
|
||||
}
|
||||
|
||||
// TODO: Ask mudler about FunctionCall stuff being useful at the message level?
|
||||
type ChatMessageTemplateData struct {
|
||||
SystemPrompt string
|
||||
Role string
|
||||
RoleName string
|
||||
Content string
|
||||
MessageIndex int
|
||||
}
|
||||
|
||||
// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go?
|
||||
// Technically, order doesn't _really_ matter, but the count must stay in sync, see tests/integration/reflect_test.go
|
||||
type TemplateType int
|
||||
|
||||
const (
|
||||
ChatPromptTemplate TemplateType = iota
|
||||
ChatMessageTemplate
|
||||
CompletionPromptTemplate
|
||||
EditPromptTemplate
|
||||
FunctionsPromptTemplate
|
||||
|
||||
// The following TemplateType is **NOT** a valid value and MUST be last. It exists to make the sanity integration tests simpler!
|
||||
IntegrationTestTemplate
|
||||
)
|
||||
|
||||
// new idea: what if we declare a struct of these here, and use a loop to check?
|
||||
|
||||
// TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl
|
||||
type ModelLoader struct {
|
||||
ModelPath string
|
||||
mu sync.Mutex
|
||||
// TODO: this needs generics
|
||||
models map[string]*grpc.Client
|
||||
grpcProcesses map[string]*process.Process
|
||||
promptsTemplates map[string]*template.Template
|
||||
models map[string]*grpc.Client
|
||||
grpcProcesses map[string]*process.Process
|
||||
templates map[TemplateType]map[string]*template.Template
|
||||
}
|
||||
|
||||
func NewModelLoader(modelPath string) *ModelLoader {
|
||||
return &ModelLoader{
|
||||
ModelPath: modelPath,
|
||||
models: make(map[string]*grpc.Client),
|
||||
promptsTemplates: make(map[string]*template.Template),
|
||||
grpcProcesses: make(map[string]*process.Process),
|
||||
nml := &ModelLoader{
|
||||
ModelPath: modelPath,
|
||||
models: make(map[string]*grpc.Client),
|
||||
templates: make(map[TemplateType]map[string]*template.Template),
|
||||
grpcProcesses: make(map[string]*process.Process),
|
||||
}
|
||||
nml.initializeTemplateMap()
|
||||
return nml
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) ExistsInModelPath(s string) bool {
|
||||
_, err := os.Stat(filepath.Join(ml.ModelPath, s))
|
||||
return err == nil
|
||||
return existsInPath(ml.ModelPath, s)
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) ListModels() ([]string, error) {
|
||||
files, err := ioutil.ReadDir(ml.ModelPath)
|
||||
files, err := os.ReadDir(ml.ModelPath)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
@@ -58,63 +96,6 @@ func (ml *ModelLoader) ListModels() ([]string, error) {
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, error) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
|
||||
m, ok := ml.promptsTemplates[modelName]
|
||||
if !ok {
|
||||
modelFile := filepath.Join(ml.ModelPath, modelName)
|
||||
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
t, exists := ml.promptsTemplates[modelName]
|
||||
if exists {
|
||||
m = t
|
||||
}
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading any template")
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
|
||||
// Check if the template was already loaded
|
||||
if _, ok := ml.promptsTemplates[modelName]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the model path exists
|
||||
// skip any error here - we run anyway if a template does not exist
|
||||
modelTemplateFile := fmt.Sprintf("%s.tmpl", modelName)
|
||||
|
||||
if !ml.ExistsInModelPath(modelTemplateFile) {
|
||||
return nil
|
||||
}
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Parse(string(dat))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ml.promptsTemplates[modelName] = tmpl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Client, error)) (*grpc.Client, error) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
@@ -134,10 +115,13 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Cl
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If there is a prompt template, load it
|
||||
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: Add a helper method to iterate all prompt templates associated with a config if and only if it's YAML?
|
||||
// Minor perf loss here until this is fixed, but we initialize on first request
|
||||
|
||||
// // If there is a prompt template, load it
|
||||
// if err := ml.loadTemplateIfExists(modelName); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
ml.models[modelName] = model
|
||||
return model, nil
|
||||
@@ -148,9 +132,9 @@ func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client {
|
||||
log.Debug().Msgf("Model already loaded in memory: %s", s)
|
||||
|
||||
if !m.HealthCheck(context.Background()) {
|
||||
log.Debug().Msgf("GRPC Model not responding", s)
|
||||
log.Debug().Msgf("GRPC Model not responding: %s", s)
|
||||
if !ml.grpcProcesses[s].IsAlive() {
|
||||
log.Debug().Msgf("GRPC Process is not responding", s)
|
||||
log.Debug().Msgf("GRPC Process is not responding: %s", s)
|
||||
// stop and delete the process, this forces to re-load the model and re-create again the service
|
||||
ml.grpcProcesses[s].Stop()
|
||||
delete(ml.grpcProcesses, s)
|
||||
@@ -164,3 +148,81 @@ func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
|
||||
// TODO: should this check be improved?
|
||||
if templateType == ChatMessageTemplate {
|
||||
return "", fmt.Errorf("invalid templateType: ChatMessage")
|
||||
}
|
||||
return ml.evaluateTemplate(templateType, templateName, in)
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
|
||||
return ml.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
|
||||
}
|
||||
|
||||
func existsInPath(path string, s string) bool {
|
||||
_, err := os.Stat(filepath.Join(path, s))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) initializeTemplateMap() {
|
||||
// This also seems somewhat clunky as we reference the Test / End of valid data value slug, but it works?
|
||||
for tt := TemplateType(0); tt < IntegrationTestTemplate; tt++ {
|
||||
ml.templates[tt] = make(map[string]*template.Template)
|
||||
}
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) evaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
|
||||
m, ok := ml.templates[templateType][templateName]
|
||||
if !ok {
|
||||
// return "", fmt.Errorf("template not loaded: %s", templateName)
|
||||
loadErr := ml.loadTemplateIfExists(templateType, templateName)
|
||||
if loadErr != nil {
|
||||
return "", loadErr
|
||||
}
|
||||
m = ml.templates[templateType][templateName] // ok is not important since we check m on the next line, and wealready checked
|
||||
}
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("failed loading a template for %s", templateName)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) loadTemplateIfExists(templateType TemplateType, templateName string) error {
|
||||
// Check if the template was already loaded
|
||||
if _, ok := ml.templates[templateType][templateName]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the model path exists
|
||||
// skip any error here - we run anyway if a template does not exist
|
||||
modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
|
||||
|
||||
if !ml.ExistsInModelPath(modelTemplateFile) {
|
||||
return nil
|
||||
}
|
||||
|
||||
dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Parse(string(dat))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ml.templates[templateType][templateName] = tmpl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,10 +14,21 @@ type Options struct {
|
||||
context context.Context
|
||||
|
||||
gRPCOptions *pb.ModelOptions
|
||||
|
||||
externalBackends map[string]string
|
||||
}
|
||||
|
||||
type Option func(*Options)
|
||||
|
||||
func WithExternalBackend(name string, uri string) Option {
|
||||
return func(o *Options) {
|
||||
if o.externalBackends == nil {
|
||||
o.externalBackends = make(map[string]string)
|
||||
}
|
||||
o.externalBackends[name] = uri
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendString(backend string) Option {
|
||||
return func(o *Options) {
|
||||
o.backendString = backend
|
||||
|
||||
37
pkg/utils/logging.go
Normal file
37
pkg/utils/logging.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var lastProgress time.Time = time.Now()
|
||||
var startTime time.Time = time.Now()
|
||||
|
||||
func ResetDownloadTimers() {
|
||||
lastProgress = time.Now()
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
func DisplayDownloadFunction(fileName string, current string, total string, percentage float64) {
|
||||
currentTime := time.Now()
|
||||
|
||||
if currentTime.Sub(lastProgress) >= 5*time.Second {
|
||||
|
||||
lastProgress = currentTime
|
||||
|
||||
// calculate ETA based on percentage and elapsed time
|
||||
var eta time.Duration
|
||||
if percentage > 0 {
|
||||
elapsed := currentTime.Sub(startTime)
|
||||
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
|
||||
}
|
||||
|
||||
if total != "" {
|
||||
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
|
||||
} else {
|
||||
log.Debug().Msgf("Downloading: %s", current)
|
||||
}
|
||||
}
|
||||
}
|
||||
7
prompt-templates/llama2-chat-message.tmpl
Normal file
7
prompt-templates/llama2-chat-message.tmpl
Normal file
@@ -0,0 +1,7 @@
|
||||
{{if eq .RoleName "assistant"}}{{.Content}}{{else}}
|
||||
[INST]
|
||||
{{if .SystemPrompt}}{{.SystemPrompt}}{{else if eq .RoleName "system"}}<<SYS>>{{.Content}}<</SYS>>
|
||||
|
||||
{{else if .Content}}{{.Content}}{{end}}
|
||||
[/INST]
|
||||
{{end}}
|
||||
23
tests/integration/reflect_test.go
Normal file
23
tests/integration/reflect_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package integration_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Integration Tests involving reflection in liue of code generation", func() {
|
||||
Context("config.TemplateConfig and model.TemplateType must stay in sync", func() {
|
||||
|
||||
ttc := reflect.TypeOf(config.TemplateConfig{})
|
||||
|
||||
It("TemplateConfig and TemplateType should have the same number of valid values", func() {
|
||||
const lastValidTemplateType = model.IntegrationTestTemplate - 1
|
||||
Expect(lastValidTemplateType).To(Equal(ttc.NumField()))
|
||||
})
|
||||
|
||||
})
|
||||
})
|
||||
5
tests/models_fixtures/grpc.yaml
Normal file
5
tests/models_fixtures/grpc.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
name: code-search-ada-code-001
|
||||
backend: huggingface
|
||||
embeddings: true
|
||||
parameters:
|
||||
model: all-MiniLM-L6-v2
|
||||
Reference in New Issue
Block a user