Compare commits

..

4 Commits

Author SHA1 Message Date
Ettore Di Giacinto
ed954d66c3 Do not take all CPU by default (#50) 2023-04-21 00:55:19 +02:00
Ettore Di Giacinto
f816dfae65 Add support for stablelm (#48)
Signed-off-by: mudler <mudler@mocaccino.org>
2023-04-21 00:06:55 +02:00
Ettore Di Giacinto
142bcd66ca Cleanup makefile, fix dep versions (#46)
Signed-off-by: mudler <mudler@c3os.io>
2023-04-20 19:49:06 +02:00
Ettore Di Giacinto
1c4fbaae20 Add support for cerebras (#45)
Signed-off-by: mudler <mudler@c3os.io>
2023-04-20 19:33:36 +02:00
7 changed files with 265 additions and 26 deletions

View File

@@ -3,6 +3,8 @@ GOTEST=$(GOCMD) test
GOVET=$(GOCMD) vet
BINARY_NAME=local-ai
GOLLAMA_VERSION?=llama.cpp-5ecff35
GOGPT4ALLJ_VERSION?=1f548782d80d48b9a0fac33aae6f129358787bc0
GOGPT2_VERSION?=1c24f5b86ac428cd5e81dae1f1427b1463bd2b06
GREEN := $(shell tput -Txterm setaf 2)
YELLOW := $(shell tput -Txterm setaf 3)
@@ -17,19 +19,23 @@ all: help
## Build:
build: prepare ## Build the project
C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j $(GOCMD) build -o $(BINARY_NAME) ./
C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp $(GOCMD) build -o $(BINARY_NAME) ./
buildgeneric: prepare-generic ## Build the project
C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j $(GOCMD) build -o $(BINARY_NAME) ./
C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp $(GOCMD) build -o $(BINARY_NAME) ./
## GPT4ALL-J
go-gpt4all-j:
git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j
git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j && cd go-gpt4all-j && git checkout -b build $(GOGPT4ALLJ_VERSION)
# This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
@find ./go-gpt4all-j -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +
go-gpt4all-j/libgptj.a: go-gpt4all-j
$(MAKE) -C go-gpt4all-j libgptj.a
@@ -37,6 +43,23 @@ go-gpt4all-j/libgptj.a: go-gpt4all-j
go-gpt4all-j/libgptj.a-generic: go-gpt4all-j
$(MAKE) -C go-gpt4all-j generic-libgptj.a
# CEREBRAS GPT
go-gpt2.cpp:
git clone --recurse-submodules https://github.com/go-skynet/go-gpt2.cpp go-gpt2.cpp && cd go-gpt2.cpp && git checkout -b build $(GOGPT2_VERSION)
# This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
@find ./go-gpt2.cpp -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
@find ./go-gpt2.cpp -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gpt2_/g' {} +
@find ./go-gpt2.cpp -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gpt2_/g' {} +
@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gpt2_/g' {} +
go-gpt2.cpp/libgpt2.a: go-gpt2.cpp
$(MAKE) -C go-gpt2.cpp libgpt2.a
go-gpt2.cpp/libgpt2.a-generic: go-gpt2.cpp
$(MAKE) -C go-gpt2.cpp generic-libgpt2.a
go-llama:
git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
$(MAKE) -C go-llama libbinding.a
@@ -45,17 +68,19 @@ go-llama-generic:
git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
$(MAKE) -C go-llama generic-libbinding.a
prepare: go-llama go-gpt4all-j/libgptj.a
replace:
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt2.cpp=$(shell pwd)/go-gpt2.cpp
prepare: go-llama go-gpt4all-j/libgptj.a go-gpt2.cpp/libgpt2.a replace
prepare-generic: go-llama-generic go-gpt4all-j/libgptj.a-generic go-gpt2.cpp/libgpt2.a-generic replace
prepare-generic: go-llama-generic go-gpt4all-j/libgptj.a-generic
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
clean: ## Remove build related file
rm -fr ./go-llama
rm -rf ./go-gpt4all-j
rm -rf ./go-gpt2.cpp
rm -rf $(BINARY_NAME)
## Run:

View File

@@ -12,13 +12,22 @@ LocalAI is a straightforward, drop-in replacement API compatible with OpenAI for
- OpenAI compatible API
- Supports multiple-models
- Once loaded the first time, it keep models loaded in memory for faster inference
- Provides a simple command line interface that allows text generation directly from the terminal
- Support for prompt templates
- Doesn't shell-out, but uses C bindings for a faster inference and better performance. Uses [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) and [go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp).
## Model compatibility
It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp) and also [GPT4ALL-J](https://github.com/nomic-ai/gpt4all).
It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp) supports also [GPT4ALL-J](https://github.com/nomic-ai/gpt4all) and [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml).
Tested with:
- Vicuna
- Alpaca
- [GPT4ALL](https://github.com/nomic-ai/gpt4all)
- [GPT4ALL-J](https://gpt4all.io/models/ggml-gpt4all-j.bin)
- Koala
- [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml)
It should also be compatible with StableLM and GPTNeoX ggml models (untested)
Note: You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`.
@@ -97,8 +106,6 @@ And you'll see:
└───────────────────────────────────────────────────┘
```
Note: Models have to end up with `.bin` so can be listed by the `/models` endpoint.
You can control the API server options with command line arguments:
```
@@ -110,7 +117,7 @@ The API takes takes the following parameters:
| Parameter | Environment Variable | Default Value | Description |
| ------------ | -------------------- | ------------- | -------------------------------------- |
| models-path | MODELS_PATH | | The path where you have models (ending with `.bin`). |
| threads | THREADS | CPU cores | The number of threads to use for text generation. |
| threads | THREADS | Number of Physical cores | The number of threads to use for text generation. |
| address | ADDRESS | :8080 | The address and port to listen on. |
| context-size | CONTEXT_SIZE | 512 | Default token context size. |

View File

@@ -8,6 +8,7 @@ import (
"sync"
model "github.com/go-skynet/LocalAI/pkg/model"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/gofiber/fiber/v2"
@@ -73,6 +74,8 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
var err error
var model *llama.LLama
var gptModel *gptj.GPTJ
var gpt2Model *gpt2.GPT2
var stableLMModel *gpt2.StableLM
input := new(OpenAIRequest)
// Get input data from the request body
@@ -97,7 +100,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
}
// Try to load the model with both
var llamaerr error
var llamaerr, gpt2err, gptjerr, stableerr error
llamaOpts := []llama.ModelOption{}
if ctx != 0 {
llamaOpts = append(llamaOpts, llama.SetContext(ctx))
@@ -106,11 +109,18 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
}
// TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation..
model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
if llamaerr != nil {
gptModel, err = loader.LoadGPTJModel(modelFile)
if err != nil {
return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
if gptjerr != nil {
gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
if gpt2err != nil {
stableLMModel, stableerr = loader.LoadStableLMModel(modelFile)
if stableerr != nil {
return fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so we want to catch both errors
}
}
}
}
@@ -176,6 +186,54 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
var predFunc func() (string, error)
switch {
case stableLMModel != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(temperature),
gpt2.SetTopP(topP),
gpt2.SetTopK(topK),
gpt2.SetTokens(tokens),
gpt2.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
}
if input.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
}
return stableLMModel.Predict(
predInput,
predictOptions...,
)
}
case gpt2Model != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(temperature),
gpt2.SetTopP(topP),
gpt2.SetTopK(topK),
gpt2.SetTokens(tokens),
gpt2.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
}
if input.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
}
return gpt2Model.Predict(
predInput,
predictOptions...,
)
}
case gptModel != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model

11
go.mod
View File

@@ -3,22 +3,31 @@ module github.com/go-skynet/LocalAI
go 1.19
require (
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94
github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640
github.com/gofiber/fiber/v2 v2.42.0
github.com/jaypipes/ghw v0.10.0
github.com/rs/zerolog v1.29.1
github.com/urfave/cli/v2 v2.25.0
)
require (
github.com/StackExchange/wmi v1.2.1 // indirect
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/ghodss/yaml v1.0.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/jaypipes/pcidb v1.0.0 // indirect
github.com/klauspost/compress v1.15.9 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/philhofer/fwd v1.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 // indirect
@@ -29,4 +38,6 @@ require (
github.com/valyala/tcplisten v1.0.0 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
golang.org/x/sys v0.6.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
howett.net/plist v1.0.0 // indirect
)

29
go.sum
View File

@@ -1,9 +1,19 @@
github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA=
github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4 h1:GkGuqnhDFKlCsT6Bo8sdY00A7rFXCzfU1nBOSS4ZnYM=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrrMvlIq+g0/ltXjDdLeNtz0uc4wJ4Qs15GFU4ba4c=
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E=
@@ -16,8 +26,16 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jaypipes/ghw v0.10.0 h1:UHu9UX08Py315iPojADFPOkmjTsNzHj4g4adsNKKteY=
github.com/jaypipes/ghw v0.10.0/go.mod h1:jeJGbkRB2lL3/gxYzNYzEDETV1ZJ56OKr+CSeSEym+g=
github.com/jaypipes/pcidb v1.0.0 h1:vtZIfkiCUE42oYbJS0TAq9XSfSmcsgo9IdxSm9qzYU8=
github.com/jaypipes/pcidb v1.0.0/go.mod h1:TnYUvqhPBzCKnH34KrIX22kAeEbDCSRJ9cqLRCuNDfk=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY=
github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
@@ -27,10 +45,13 @@ github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPn
github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/onsi/ginkgo/v2 v2.9.2 h1:BA2GMJOtfGAfagzYtrAlufIP0lq6QERkFmHLMLPwFSU=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/philhofer/fwd v1.1.1 h1:GdGcTjf5RNAxwS4QLsiMzJYj5KEvPJD3Abr261yRQXQ=
github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
@@ -71,6 +92,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -95,4 +117,11 @@ golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=

15
main.go
View File

@@ -2,13 +2,12 @@ package main
import (
"os"
"runtime"
api "github.com/go-skynet/LocalAI/api"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/jaypipes/ghw"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/urfave/cli/v2"
)
@@ -21,6 +20,12 @@ func main() {
os.Exit(1)
}
threads := 4
cpu, err := ghw.CPU()
if err == nil {
threads = int(cpu.TotalCores)
}
app := &cli.App{
Name: "LocalAI",
Usage: "OpenAI compatible API for running LLaMA/GPT models locally on CPU with consumer grade hardware.",
@@ -37,7 +42,7 @@ func main() {
Name: "threads",
DefaultText: "Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested.",
EnvVars: []string{"THREADS"},
Value: runtime.NumCPU(),
Value: threads,
},
&cli.StringFlag{
Name: "models-path",
@@ -66,9 +71,11 @@ Some of the models compatible are:
- Koala
- GPT4ALL
- GPT4ALL-J
- Cerebras
- Alpaca
- StableLM (ggml quantized)
It uses llama.cpp and gpt4all as backend, supporting all the models supported by both.
It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
`,
UsageText: `local-ai [options]`,
Copyright: "go-skynet authors",

View File

@@ -12,20 +12,32 @@ import (
"github.com/rs/zerolog/log"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
)
type ModelLoader struct {
modelPath string
mu sync.Mutex
models map[string]*llama.LLama
gptmodels map[string]*gptj.GPTJ
modelPath string
mu sync.Mutex
models map[string]*llama.LLama
gptmodels map[string]*gptj.GPTJ
gpt2models map[string]*gpt2.GPT2
gptstablelmmodels map[string]*gpt2.StableLM
promptsTemplates map[string]*template.Template
}
func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{modelPath: modelPath, gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
return &ModelLoader{
modelPath: modelPath,
gpt2models: make(map[string]*gpt2.GPT2),
gptmodels: make(map[string]*gptj.GPTJ),
gptstablelmmodels: make(map[string]*gpt2.StableLM),
models: make(map[string]*llama.LLama),
promptsTemplates: make(map[string]*template.Template),
}
}
func (ml *ModelLoader) ExistsInModelPath(s string) bool {
@@ -98,6 +110,77 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
return nil
}
func (ml *ModelLoader) LoadStableLMModel(modelName string) (*gpt2.StableLM, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
// Check if we already have a loaded model
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}
if m, ok := ml.gptstablelmmodels[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
}
// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gpt2.NewStableLM(modelFile)
if err != nil {
return nil, err
}
// If there is a prompt template, load it
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
return nil, err
}
ml.gptstablelmmodels[modelName] = model
return model, err
}
func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
// Check if we already have a loaded model
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}
if m, ok := ml.gpt2models[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
}
// TODO: This needs refactoring, it's really bad to have it in here
// Check if we have a GPTStable model loaded instead - if we do we return an error so the API tries with StableLM
if _, ok := ml.gptstablelmmodels[modelName]; ok {
log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
return nil, fmt.Errorf("this model is a GPTStableLM one")
}
// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gpt2.New(modelFile)
if err != nil {
return nil, err
}
// If there is a prompt template, load it
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
return nil, err
}
ml.gpt2models[modelName] = model
return model, err
}
func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
@@ -112,6 +195,17 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
return m, nil
}
// TODO: This needs refactoring, it's really bad to have it in here
// Check if we have a GPT2 model loaded instead - if we do we return an error so the API tries with GPT2
if _, ok := ml.gpt2models[modelName]; ok {
log.Debug().Msgf("Model is GPT2: %s", modelName)
return nil, fmt.Errorf("this model is a GPT2 one")
}
if _, ok := ml.gptstablelmmodels[modelName]; ok {
log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
return nil, fmt.Errorf("this model is a GPTStableLM one")
}
// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
@@ -152,6 +246,14 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
log.Debug().Msgf("Model is GPTJ: %s", modelName)
return nil, fmt.Errorf("this model is a GPTJ one")
}
if _, ok := ml.gpt2models[modelName]; ok {
log.Debug().Msgf("Model is GPT2: %s", modelName)
return nil, fmt.Errorf("this model is a GPT2 one")
}
if _, ok := ml.gptstablelmmodels[modelName]; ok {
log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
return nil, fmt.Errorf("this model is a GPTStableLM one")
}
// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName)