mirror of
https://github.com/mudler/LocalAI.git
synced 2026-02-03 11:13:31 -05:00
Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
624092cb99 | ||
|
|
a422a883ac | ||
|
|
7858a97254 | ||
|
|
5556aa46dd | ||
|
|
eb4257f946 | ||
|
|
ae30bd346d | ||
|
|
93d8977ba2 | ||
|
|
f43aeeb4a1 | ||
|
|
c17dcc5e9d | ||
|
|
4a932483e1 | ||
|
|
b710147b95 | ||
|
|
ba70363330 | ||
|
|
9fb581739b | ||
|
|
48aca246e3 | ||
|
|
12eee097b7 | ||
|
|
b33d015b8c | ||
|
|
b7c0a108f5 | ||
|
|
f694a89c28 | ||
|
|
be682e6c2f | ||
|
|
bf85a31f9e | ||
|
|
d69048e0b0 | ||
|
|
827f189163 | ||
|
|
a23deb5ec7 | ||
|
|
999676b106 | ||
|
|
c61b023bc8 | ||
|
|
650a22aef1 | ||
|
|
17b1724f7c | ||
|
|
e860e62036 | ||
|
|
1f45ff8cd6 | ||
|
|
abee34f60a | ||
|
|
dbc70dc13c | ||
|
|
55142065eb | ||
|
|
d83d2293b5 | ||
|
|
467ce5a7aa |
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
llama-cli
|
||||
25
Earthfile
25
Earthfile
@@ -11,37 +11,22 @@ go-deps:
|
||||
SAVE ARTIFACT go.mod AS LOCAL go.mod
|
||||
SAVE ARTIFACT go.sum AS LOCAL go.sum
|
||||
|
||||
model-image:
|
||||
ARG MODEL_IMAGE=quay.io/go-skynet/models:ggml2-alpaca-7b-v0.2
|
||||
FROM $MODEL_IMAGE
|
||||
SAVE ARTIFACT /models/model.bin
|
||||
|
||||
build:
|
||||
FROM +go-deps
|
||||
WORKDIR /build
|
||||
RUN git clone https://github.com/go-skynet/llama
|
||||
RUN cd llama && make libllama.a
|
||||
RUN git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp
|
||||
RUN cd go-llama.cpp && make libbinding.a
|
||||
COPY . .
|
||||
RUN C_INCLUDE_PATH=/build/llama LIBRARY_PATH=/build/llama go build -o llama-cli ./
|
||||
RUN go mod edit -replace github.com/go-skynet/go-llama.cpp=/build/go-llama.cpp
|
||||
RUN C_INCLUDE_PATH=$GOPATH/src/github.com/go-skynet/go-llama.cpp LIBRARY_PATH=$GOPATH/src/github.com/go-skynet/go-llama.cpp go build -o llama-cli ./
|
||||
SAVE ARTIFACT llama-cli AS LOCAL llama-cli
|
||||
|
||||
image:
|
||||
FROM +go-deps
|
||||
ARG IMAGE=alpaca-cli
|
||||
COPY +model-image/model.bin /model.bin
|
||||
ARG IMAGE=alpaca-cli-nomodel
|
||||
COPY +build/llama-cli /llama-cli
|
||||
ENV MODEL_PATH=/model.bin
|
||||
ENTRYPOINT [ "/llama-cli" ]
|
||||
SAVE IMAGE --push $IMAGE
|
||||
|
||||
lite-image:
|
||||
FROM +go-deps
|
||||
ARG IMAGE=alpaca-cli-nomodel
|
||||
COPY +build/llama-cli /llama-cli
|
||||
ENV MODEL_PATH=/model.bin
|
||||
ENTRYPOINT [ "/llama-cli" ]
|
||||
SAVE IMAGE --push $IMAGE-lite
|
||||
|
||||
image-all:
|
||||
BUILD --platform=linux/amd64 --platform=linux/arm64 +image
|
||||
BUILD --platform=linux/amd64 --platform=linux/arm64 +lite-image
|
||||
177
README.md
177
README.md
@@ -1,14 +1,29 @@
|
||||
## :camel: llama-cli
|
||||
|
||||
|
||||
llama-cli is a straightforward golang CLI interface for [llama.cpp](https://github.com/ggerganov/llama.cpp), providing a simple API and a command line interface that allows text generation using a GPT-based model like llama directly from the terminal.
|
||||
llama-cli is a straightforward golang CLI interface and API compatible with OpenAI for [llama.cpp](https://github.com/ggerganov/llama.cpp), it supports multiple-models and also provides a simple command line interface that allows text generation using a GPT-based model like llama directly from the terminal.
|
||||
|
||||
It is compatible with the models supported by `llama.cpp`. You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`.
|
||||
|
||||
`llama-cli` doesn't shell-out, it uses https://github.com/go-skynet/go-llama.cpp, which is a golang binding of [llama.cpp](https://github.com/ggerganov/llama.cpp).
|
||||
|
||||
## Container images
|
||||
|
||||
The `llama-cli` [container images](https://quay.io/repository/go-skynet/llama-cli?tab=tags&tag=latest) come preloaded with the [alpaca.cpp 7B](https://github.com/antimatter15/alpaca.cpp) model, enabling you to start making predictions immediately! To begin, run:
|
||||
`llama-cli` comes by default as a container image. You can check out all the available images with corresponding tags [here](https://quay.io/repository/go-skynet/llama-cli?tab=tags&tag=latest)
|
||||
|
||||
To begin, run:
|
||||
|
||||
```
|
||||
docker run -ti --rm quay.io/go-skynet/llama-cli:v0.2 --instruction "What's an alpaca?" --topk 10000
|
||||
docker run -ti --rm quay.io/go-skynet/llama-cli:v0.6 --instruction "What's an alpaca?" --topk 10000 --model ...
|
||||
```
|
||||
|
||||
Where `--model` is the path of the model you want to use.
|
||||
|
||||
Note: you need to mount a volume to the docker container in order to load a model, for instance:
|
||||
|
||||
```
|
||||
# assuming your model is in /path/to/your/models/foo.bin
|
||||
docker run -v /path/to/your/models:/models -ti --rm quay.io/go-skynet/llama-cli:v0.6 --instruction "What's an alpaca?" --topk 10000 --model /models/foo.bin
|
||||
```
|
||||
|
||||
You will receive a response like the following:
|
||||
@@ -37,7 +52,6 @@ llama-cli --model <model_path> --instruction <instruction> [--input <input>] [--
|
||||
| top_p | TOP_P | 0.85 | The cumulative probability for top-p sampling. |
|
||||
| top_k | TOP_K | 20 | The number of top-k tokens to consider for text generation. |
|
||||
| context-size | CONTEXT_SIZE | 512 | Default token context size. |
|
||||
| alpaca | ALPACA | true | Set to true for alpaca models. |
|
||||
|
||||
Here's an example of using `llama-cli`:
|
||||
|
||||
@@ -47,14 +61,14 @@ llama-cli --model ~/ggml-alpaca-7b-q4.bin --instruction "What's an alpaca?"
|
||||
|
||||
This will generate text based on the given model and instruction.
|
||||
|
||||
## Advanced usage
|
||||
## API
|
||||
|
||||
`llama-cli` also provides an API for running text generation as a service.
|
||||
`llama-cli` also provides an API for running text generation as a service. The models once loaded the first time will be kept in memory.
|
||||
|
||||
Example of starting the API with `docker`:
|
||||
|
||||
```bash
|
||||
docker run -p 8080:8080 -ti --rm quay.io/go-skynet/llama-cli:v0.2 api
|
||||
docker run -p 8080:8080 -ti --rm quay.io/go-skynet/llama-cli:v0.6 api --models-path /path/to/models --context-size 700 --threads 4
|
||||
```
|
||||
|
||||
And you'll see:
|
||||
@@ -69,35 +83,70 @@ And you'll see:
|
||||
└───────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
Note: Models have to end up with `.bin`.
|
||||
|
||||
You can control the API server options with command line arguments:
|
||||
|
||||
```
|
||||
llama-cli api --model <model_path> [--address <address>] [--threads <num_threads>]
|
||||
llama-cli api --models-path <model_path> [--address <address>] [--threads <num_threads>]
|
||||
```
|
||||
|
||||
The API takes takes the following:
|
||||
|
||||
| Parameter | Environment Variable | Default Value | Description |
|
||||
| ------------ | -------------------- | ------------- | -------------------------------------- |
|
||||
| model | MODEL_PATH | | The path to the pre-trained GPT-based model. |
|
||||
| models-path | MODELS_PATH | | The path where you have models (ending with `.bin`). |
|
||||
| threads | THREADS | CPU cores | The number of threads to use for text generation. |
|
||||
| address | ADDRESS | :8080 | The address and port to listen on. |
|
||||
| context-size | CONTEXT_SIZE | 512 | Default token context size. |
|
||||
| alpaca | ALPACA | true | Set to true for alpaca models. |
|
||||
|
||||
Once the server is running, you can start making requests to it using HTTP, using the OpenAI API.
|
||||
|
||||
Once the server is running, you can make requests to it using HTTP. For example, to generate text based on an instruction, you can send a POST request to the `/predict` endpoint with the instruction as the request body:
|
||||
### Supported OpenAI API endpoints
|
||||
|
||||
You can check out the [OpenAI API reference](https://platform.openai.com/docs/api-reference/chat/create).
|
||||
|
||||
Following the list of endpoints/parameters supported.
|
||||
|
||||
#### Chat completions
|
||||
|
||||
For example, to generate a chat completion, you can send a POST request to the `/v1/chat/completions` endpoint with the instruction as the request body:
|
||||
|
||||
```
|
||||
curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{
|
||||
"text": "What is an alpaca?",
|
||||
"topP": 0.8,
|
||||
"topK": 50,
|
||||
"temperature": 0.7,
|
||||
"tokens": 100
|
||||
}'
|
||||
curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
|
||||
"model": "ggml-koala-7b-model-q4_0-r2.bin",
|
||||
"messages": [{"role": "user", "content": "Say this is a test!"}],
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
|
||||
Available additional parameters: `top_p`, `top_k`, `max_tokens`
|
||||
|
||||
#### Completions
|
||||
|
||||
For example, to generate a comletion, you can send a POST request to the `/v1/completions` endpoint with the instruction as the request body:
|
||||
```
|
||||
curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
|
||||
"model": "ggml-koala-7b-model-q4_0-r2.bin",
|
||||
"prompt": "A long time ago in a galaxy far, far away",
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
|
||||
Available additional parameters: `top_p`, `top_k`, `max_tokens`
|
||||
|
||||
#### List models
|
||||
|
||||
You can list all the models available with:
|
||||
|
||||
```
|
||||
curl http://localhost:8080/v1/models
|
||||
```
|
||||
|
||||
## Web interface
|
||||
|
||||
There is also available a simple web interface (for instance, http://localhost:8080/) which can be used as a playground.
|
||||
|
||||
Note: The API doesn't inject a template for talking to the instance, while the CLI does. You have to use a prompt similar to what's described in the standford-alpaca docs: https://github.com/tatsu-lab/stanford_alpaca#data-release, for instance:
|
||||
|
||||
```
|
||||
@@ -109,32 +158,29 @@ Below is an instruction that describes a task. Write a response that appropriate
|
||||
### Response:
|
||||
```
|
||||
|
||||
Note: You can use a use a default template for every model in your model path, by creating a corresponding file with the `.tmpl` suffix. For instance, if the model is called `foo.bin`, you can create a sibiling file, `foo.bin.tmpl` which will be used as a default prompt, for instance:
|
||||
|
||||
```
|
||||
Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction:
|
||||
{{.Input}}
|
||||
|
||||
### Response:
|
||||
```
|
||||
|
||||
## Using other models
|
||||
|
||||
You can use the lite images ( for example `quay.io/go-skynet/llama-cli:v0.2-lite`) that don't ship any model, and specify a model binary to be used for inference with `--model`.
|
||||
gpt4all (https://github.com/nomic-ai/gpt4all) works as well, however the original model needs to be converted (same applies for old alpaca models, too):
|
||||
|
||||
13B and 30B models are known to work:
|
||||
|
||||
### 13B
|
||||
|
||||
```
|
||||
# Download the model image, extract the model
|
||||
docker run --name model --entrypoint /models quay.io/go-skynet/models:ggml2-alpaca-13b-v0.2
|
||||
docker cp model:/models/model.bin ./
|
||||
|
||||
# Use the model with llama-cli
|
||||
docker run -v $PWD:/models -p 8080:8080 -ti --rm quay.io/go-skynet/llama-cli:v0.2-lite api --model /models/model.bin
|
||||
```
|
||||
|
||||
### 30B
|
||||
|
||||
```
|
||||
# Download the model image, extract the model
|
||||
docker run --name model --entrypoint /models quay.io/go-skynet/models:ggml2-alpaca-30b-v0.2
|
||||
docker cp model:/models/model.bin ./
|
||||
|
||||
# Use the model with llama-cli
|
||||
docker run -v $PWD:/models -p 8080:8080 -ti --rm quay.io/go-skynet/llama-cli:v0.2-lite api --model /models/model.bin
|
||||
```bash
|
||||
wget -O tokenizer.model https://huggingface.co/decapoda-research/llama-30b-hf/resolve/main/tokenizer.model
|
||||
mkdir models
|
||||
cp gpt4all.. models/
|
||||
git clone https://gist.github.com/eiz/828bddec6162a023114ce19146cb2b82
|
||||
pip install sentencepiece
|
||||
python 828bddec6162a023114ce19146cb2b82/gistfile1.txt models tokenizer.model
|
||||
# There will be a new model with the ".tmp" extension, you have to use that one!
|
||||
```
|
||||
|
||||
### Golang client API
|
||||
@@ -152,7 +198,7 @@ import (
|
||||
|
||||
func main() {
|
||||
|
||||
cli := client.NewClient("http://ip:30007")
|
||||
cli := client.NewClient("http://ip:port")
|
||||
|
||||
out, err := cli.Predict("What's an alpaca?")
|
||||
if err != nil {
|
||||
@@ -163,10 +209,55 @@ func main() {
|
||||
}
|
||||
```
|
||||
|
||||
### Windows compatibility
|
||||
|
||||
It should work, however you need to make sure you give enough resources to the container. See https://github.com/go-skynet/llama-cli/issues/2
|
||||
|
||||
### Kubernetes
|
||||
|
||||
You can run the API directly in Kubernetes:
|
||||
|
||||
```bash
|
||||
kubectl apply -f https://raw.githubusercontent.com/go-skynet/llama-cli/master/kubernetes/deployment.yaml
|
||||
```
|
||||
```
|
||||
|
||||
### Build locally
|
||||
|
||||
Pre-built images might fit well for most of the modern hardware, however you can and might need to build the images manually.
|
||||
|
||||
In order to build the `llama-cli` container image locally you can use `docker`:
|
||||
|
||||
```
|
||||
# build the image as "alpaca-image"
|
||||
docker run --privileged -v /var/run/docker.sock:/var/run/docker.sock --rm -t -v "$(pwd)":/workspace -v earthly-tmp:/tmp/earthly:rw earthly/earthly:v0.7.2 +image --IMAGE=alpaca-image
|
||||
# run the image
|
||||
docker run alpaca-image --instruction "What's an alpaca?"
|
||||
```
|
||||
|
||||
Or build the binary with:
|
||||
|
||||
```
|
||||
# build the image as "alpaca-image"
|
||||
docker run --privileged -v /var/run/docker.sock:/var/run/docker.sock --rm -t -v "$(pwd)":/workspace -v earthly-tmp:/tmp/earthly:rw earthly/earthly:v0.7.2 +build
|
||||
# run the binary
|
||||
./llama-cli --instruction "What's an alpaca?"
|
||||
```
|
||||
|
||||
## Short-term roadmap
|
||||
|
||||
- [x] Mimic OpenAI API (https://github.com/go-skynet/llama-cli/issues/10)
|
||||
- Binary releases (https://github.com/go-skynet/llama-cli/issues/6)
|
||||
- Upstream our golang bindings to llama.cpp (https://github.com/ggerganov/llama.cpp/issues/351)
|
||||
- [x] Multi-model support
|
||||
- Have a webUI!
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
||||
- https://github.com/tatsu-lab/stanford_alpaca
|
||||
- https://github.com/cornelk/llama-go for the initial ideas
|
||||
- https://github.com/antimatter15/alpaca.cpp for the light model version (this is compatible and tested only with that checkpoint model!)
|
||||
|
||||
78
api.go
78
api.go
@@ -1,78 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
llama "github.com/go-skynet/llama/go"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func api(l *llama.LLama, listenAddr string, threads int) error {
|
||||
app := fiber.New()
|
||||
|
||||
/*
|
||||
curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{
|
||||
"text": "What is an alpaca?",
|
||||
"topP": 0.8,
|
||||
"topK": 50,
|
||||
"temperature": 0.7,
|
||||
"tokens": 100
|
||||
}'
|
||||
*/
|
||||
|
||||
// Endpoint to generate the prediction
|
||||
app.Post("/predict", func(c *fiber.Ctx) error {
|
||||
// Get input data from the request body
|
||||
input := new(struct {
|
||||
Text string `json:"text"`
|
||||
})
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate the prediction using the language model
|
||||
prediction, err := l.Predict(
|
||||
input.Text,
|
||||
llama.SetTemperature(temperature),
|
||||
llama.SetTopP(topP),
|
||||
llama.SetTopK(topK),
|
||||
llama.SetTokens(tokens),
|
||||
llama.SetThreads(threads),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(struct {
|
||||
Prediction string `json:"prediction"`
|
||||
}{
|
||||
Prediction: prediction,
|
||||
})
|
||||
})
|
||||
|
||||
// Start the server
|
||||
app.Listen(":8080")
|
||||
return nil
|
||||
}
|
||||
275
api/api.go
Normal file
275
api/api.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
model "github.com/go-skynet/llama-cli/pkg/model"
|
||||
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/filesystem"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
)
|
||||
|
||||
type OpenAIResponse struct {
|
||||
Created int `json:"created,omitempty"`
|
||||
Object string `json:"chat.completion,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Choices []Choice `json:"choices,omitempty"`
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Message Message `json:"message,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
type OpenAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
|
||||
// Prompt is read only by completion API calls
|
||||
Prompt string `json:"prompt"`
|
||||
// Messages is readh only by chat/completion API calls
|
||||
Messages []Message `json:"messages"`
|
||||
|
||||
// Common options between all the API calls
|
||||
TopP float64 `json:"top_p"`
|
||||
TopK int `json:"top_k"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
Maxtokens int `json:"max_tokens"`
|
||||
}
|
||||
|
||||
//go:embed index.html
|
||||
var indexHTML embed.FS
|
||||
|
||||
func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
var err error
|
||||
var model *llama.LLama
|
||||
|
||||
input := new(OpenAIRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if input.Model == "" {
|
||||
if defaultModel == nil {
|
||||
return fmt.Errorf("no default model loaded, and no model specified")
|
||||
}
|
||||
model = defaultModel
|
||||
} else {
|
||||
model, err = loader.LoadModel(input.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
if input.Model != "" {
|
||||
mutexMap.Lock()
|
||||
l, ok := mutexes[input.Model]
|
||||
if !ok {
|
||||
m := &sync.Mutex{}
|
||||
mutexes[input.Model] = m
|
||||
l = m
|
||||
}
|
||||
mutexMap.Unlock()
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
} else {
|
||||
defaultMutex.Lock()
|
||||
defer defaultMutex.Unlock()
|
||||
}
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
topP := input.TopP
|
||||
if topP == 0 {
|
||||
topP = 0.7
|
||||
}
|
||||
topK := input.TopK
|
||||
if topK == 0 {
|
||||
topK = 80
|
||||
}
|
||||
|
||||
temperature := input.Temperature
|
||||
if temperature == 0 {
|
||||
temperature = 0.9
|
||||
}
|
||||
|
||||
tokens := input.Maxtokens
|
||||
if tokens == 0 {
|
||||
tokens = 512
|
||||
}
|
||||
|
||||
predInput := input.Prompt
|
||||
if chat {
|
||||
mess := []string{}
|
||||
for _, i := range input.Messages {
|
||||
mess = append(mess, i.Content)
|
||||
}
|
||||
|
||||
predInput = strings.Join(mess, "\n")
|
||||
}
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := loader.TemplatePrefix(input.Model, struct {
|
||||
Input string
|
||||
}{Input: predInput})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
}
|
||||
|
||||
// Generate the prediction using the language model
|
||||
prediction, err := model.Predict(
|
||||
predInput,
|
||||
llama.SetTemperature(temperature),
|
||||
llama.SetTopP(topP),
|
||||
llama.SetTopK(topK),
|
||||
llama.SetTokens(tokens),
|
||||
llama.SetThreads(threads),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if chat {
|
||||
// Return the chat prediction in the response body
|
||||
return c.JSON(OpenAIResponse{
|
||||
Model: input.Model,
|
||||
Choices: []Choice{{Message: Message{Role: "assistant", Content: prediction}}},
|
||||
})
|
||||
}
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(OpenAIResponse{
|
||||
Model: input.Model,
|
||||
Choices: []Choice{{Text: prediction}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr string, threads int) error {
|
||||
app := fiber.New()
|
||||
|
||||
// Default middleware config
|
||||
app.Use(recover.New())
|
||||
app.Use(cors.New())
|
||||
|
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
var mutex = &sync.Mutex{}
|
||||
mu := map[string]*sync.Mutex{}
|
||||
var mumutex = &sync.Mutex{}
|
||||
|
||||
// openAI compatible API endpoint
|
||||
app.Post("/v1/chat/completions", openAIEndpoint(true, defaultModel, loader, threads, mutex, mumutex, mu))
|
||||
app.Post("/v1/completions", openAIEndpoint(false, defaultModel, loader, threads, mutex, mumutex, mu))
|
||||
app.Get("/v1/models", func(c *fiber.Ctx) error {
|
||||
models, err := loader.ListModels()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dataModels := []OpenAIModel{}
|
||||
for _, m := range models {
|
||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
|
||||
}
|
||||
return c.JSON(struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIModel `json:"data"`
|
||||
}{
|
||||
Object: "list",
|
||||
Data: dataModels,
|
||||
})
|
||||
})
|
||||
|
||||
app.Use("/", filesystem.New(filesystem.Config{
|
||||
Root: http.FS(indexHTML),
|
||||
NotFoundFile: "index.html",
|
||||
}))
|
||||
|
||||
/*
|
||||
curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{
|
||||
"text": "What is an alpaca?",
|
||||
"topP": 0.8,
|
||||
"topK": 50,
|
||||
"temperature": 0.7,
|
||||
"tokens": 100
|
||||
}'
|
||||
*/
|
||||
// Endpoint to generate the prediction
|
||||
app.Post("/predict", func(c *fiber.Ctx) error {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
// Get input data from the request body
|
||||
input := new(struct {
|
||||
Text string `json:"text"`
|
||||
})
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate the prediction using the language model
|
||||
prediction, err := defaultModel.Predict(
|
||||
input.Text,
|
||||
llama.SetTemperature(temperature),
|
||||
llama.SetTopP(topP),
|
||||
llama.SetTopK(topK),
|
||||
llama.SetTokens(tokens),
|
||||
llama.SetThreads(threads),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(struct {
|
||||
Prediction string `json:"prediction"`
|
||||
}{
|
||||
Prediction: prediction,
|
||||
})
|
||||
})
|
||||
|
||||
// Start the server
|
||||
app.Listen(listenAddr)
|
||||
return nil
|
||||
}
|
||||
120
api/index.html
Normal file
120
api/index.html
Normal file
@@ -0,0 +1,120 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>llama-cli</title>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css" crossorigin="anonymous" referrerpolicy="no-referrer" />
|
||||
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css">
|
||||
</head>
|
||||
<style>
|
||||
@keyframes rotating {
|
||||
from {
|
||||
transform: rotate(0deg);
|
||||
}
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
.waiting {
|
||||
animation: rotating 1s linear infinite;
|
||||
}
|
||||
|
||||
</style>
|
||||
<body>
|
||||
|
||||
<div class="container mt-5" x-data="{ templates:[
|
||||
{
|
||||
name: 'Alpaca: Instruction without input',
|
||||
text: `Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction:
|
||||
{{.Instruction}}
|
||||
|
||||
### Response:`,
|
||||
},
|
||||
{
|
||||
name: 'Alpaca: Instruction with input',
|
||||
text: `Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction:
|
||||
{{.Instruction}}
|
||||
|
||||
### Input:
|
||||
{{.Input}}
|
||||
|
||||
### Response:`,
|
||||
}
|
||||
], selectedTemplate: '', selectedTemplateText: '' }">
|
||||
<h1>llama-cli API</h1>
|
||||
<div class="form-group">
|
||||
<label for="inputText">Input Text:</label>
|
||||
<textarea class="form-control" id="inputText" rows="6" placeholder="Your text input here..." x-text="selectedTemplateText"></textarea>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="templateSelect">Select Template:</label>
|
||||
<select class="form-control" id="templateSelect" x-model="selectedTemplateText">
|
||||
<option value="">None</option>
|
||||
<template x-for="(template, index) in templates" :key="index">
|
||||
<option :value="template.text" x-text="template.name"></option>
|
||||
</template>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="topP">Top P:</label>
|
||||
<input type="range" step="0.01" min="0" max="1" class="form-control" id="topP" value="0.20" name="topP" onchange="this.nextElementSibling.value = this.value" required>
|
||||
<output>0.20</output>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="topK">Top K:</label>
|
||||
<input type="number" class="form-control" id="topK" value="10000" name="topK" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="temperature">Temperature:</label>
|
||||
<input type="range" step="0.01" min="0" max="1" value="0.9" class="form-control" id="temperature" name="temperature" onchange="this.nextElementSibling.value = this.value" required>
|
||||
<output>0.9</output>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="tokens">Tokens:</label>
|
||||
<input type="number" class="form-control" id="tokens" name="tokens" value="128" required>
|
||||
</div>
|
||||
<button class="btn btn-primary" x-on:click="submitRequest()">Submit <i class="fas fa-paper-plane"></i></button>
|
||||
<hr>
|
||||
<div class="form-group">
|
||||
<label for="outputText">Output Text:</label>
|
||||
<textarea class="form-control" id="outputText" rows="5" readonly></textarea>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
|
||||
<script>
|
||||
function submitRequest() {
|
||||
var button = document.querySelector("i.fa-paper-plane");
|
||||
button.classList.add("waiting");
|
||||
var text = document.getElementById("inputText").value;
|
||||
var url = "/predict";
|
||||
var data = {
|
||||
"text": text,
|
||||
"topP": document.getElementById("topP").value,
|
||||
"topK": document.getElementById("topK").value,
|
||||
"temperature": document.getElementById("temperature").value,
|
||||
"tokens": document.getElementById("tokens").value
|
||||
};
|
||||
fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
body: JSON.stringify(data)
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
document.getElementById("outputText").value = data.prediction;
|
||||
button.classList.remove("waiting");
|
||||
})
|
||||
.catch(error => { console.error(error); button.classList.remove("waiting"); });
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
7
go.mod
7
go.mod
@@ -6,7 +6,7 @@ require (
|
||||
github.com/charmbracelet/bubbles v0.15.0
|
||||
github.com/charmbracelet/bubbletea v0.23.2
|
||||
github.com/charmbracelet/lipgloss v0.7.1
|
||||
github.com/go-skynet/llama v0.0.0-20230321172246-7be5326e18cc
|
||||
github.com/go-skynet/llama v0.0.0-20230329165201-84efc8db3647
|
||||
github.com/gofiber/fiber/v2 v2.42.0
|
||||
github.com/urfave/cli/v2 v2.25.0
|
||||
)
|
||||
@@ -17,6 +17,7 @@ require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/containerd/console v1.0.3 // indirect
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230405204601-5429d2339021 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/klauspost/compress v1.15.9 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
@@ -40,6 +41,6 @@ require (
|
||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.6.0 // indirect
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
|
||||
golang.org/x/text v0.3.7 // indirect
|
||||
golang.org/x/term v0.5.0 // indirect
|
||||
golang.org/x/text v0.7.0 // indirect
|
||||
)
|
||||
|
||||
14
go.sum
14
go.sum
@@ -19,8 +19,16 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu
|
||||
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230404185816-24b85a924f09 h1:WPUWvw7DOv3WUuhtNfv+xJVE2CCTGa1op1PKGcNk2Bk=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230404185816-24b85a924f09/go.mod h1:yD5HHNAHPReBlvWGWUr9OcMeE5BJH3xOUDtKCwjxdEQ=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230405204601-5429d2339021 h1:SsUkTjdCCAJjULfspizf99Sfw8Fx9OAHF30kp3i6cxc=
|
||||
github.com/go-skynet/go-llama.cpp v0.0.0-20230405204601-5429d2339021/go.mod h1:yD5HHNAHPReBlvWGWUr9OcMeE5BJH3xOUDtKCwjxdEQ=
|
||||
github.com/go-skynet/llama v0.0.0-20230321172246-7be5326e18cc h1:NcmO8mA7iRZIX0Qy2SjcsSaV14+g87MiTey1neUJaFQ=
|
||||
github.com/go-skynet/llama v0.0.0-20230321172246-7be5326e18cc/go.mod h1:ZtYsAIud4cvP9VTTI9uhdgR1uCwaO/gGKnZZ95h9i7w=
|
||||
github.com/go-skynet/llama v0.0.0-20230325223742-a3563a2690ba h1:u6OhAqlWFHsTjfWKePdK2kP4/mTyXX5vsmKwrK5QX6o=
|
||||
github.com/go-skynet/llama v0.0.0-20230325223742-a3563a2690ba/go.mod h1:ZtYsAIud4cvP9VTTI9uhdgR1uCwaO/gGKnZZ95h9i7w=
|
||||
github.com/go-skynet/llama v0.0.0-20230329165201-84efc8db3647 h1:W6qHHD/Bv6wRXSzdv38gWMAXgw3fklHyEblfw88uEUU=
|
||||
github.com/go-skynet/llama v0.0.0-20230329165201-84efc8db3647/go.mod h1:ZtYsAIud4cvP9VTTI9uhdgR1uCwaO/gGKnZZ95h9i7w=
|
||||
github.com/gofiber/fiber/v2 v2.42.0 h1:Fnp7ybWvS+sjNQsFvkhf4G8OhXswvB6Vee8hM/LyS+8=
|
||||
github.com/gofiber/fiber/v2 v2.42.0/go.mod h1:3+SGNjqMh5VQH5Vz2Wdi43zTIV16ktlFd3x3R6O1Zlc=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
@@ -108,13 +116,15 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
|
||||
142
interactive.go
142
interactive.go
@@ -1,142 +0,0 @@
|
||||
package main
|
||||
|
||||
// A simple program demonstrating the text area component from the Bubbles
|
||||
// component library.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textarea"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
llama "github.com/go-skynet/llama/go"
|
||||
)
|
||||
|
||||
func startInteractive(l *llama.LLama, opts ...llama.PredictOption) error {
|
||||
p := tea.NewProgram(initialModel(l, opts...))
|
||||
|
||||
_, err := p.Run()
|
||||
return err
|
||||
}
|
||||
|
||||
type (
|
||||
errMsg error
|
||||
)
|
||||
|
||||
type model struct {
|
||||
viewport viewport.Model
|
||||
messages *[]string
|
||||
textarea textarea.Model
|
||||
senderStyle lipgloss.Style
|
||||
err error
|
||||
l *llama.LLama
|
||||
opts []llama.PredictOption
|
||||
|
||||
predictC chan string
|
||||
}
|
||||
|
||||
func initialModel(l *llama.LLama, opts ...llama.PredictOption) model {
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = "Send a message..."
|
||||
ta.Focus()
|
||||
|
||||
ta.Prompt = "┃ "
|
||||
ta.CharLimit = 280
|
||||
|
||||
ta.SetWidth(200)
|
||||
ta.SetHeight(3)
|
||||
|
||||
// Remove cursor line styling
|
||||
ta.FocusedStyle.CursorLine = lipgloss.NewStyle()
|
||||
|
||||
ta.ShowLineNumbers = false
|
||||
|
||||
vp := viewport.New(200, 5)
|
||||
vp.SetContent(`Welcome to llama-cli. Type a message and press Enter to send. Alpaca doesn't keep context of the whole chat (yet).`)
|
||||
|
||||
ta.KeyMap.InsertNewline.SetEnabled(false)
|
||||
|
||||
predictChannel := make(chan string)
|
||||
messages := []string{}
|
||||
m := model{
|
||||
textarea: ta,
|
||||
messages: &messages,
|
||||
viewport: vp,
|
||||
senderStyle: lipgloss.NewStyle().Foreground(lipgloss.Color("5")),
|
||||
err: nil,
|
||||
l: l,
|
||||
opts: opts,
|
||||
predictC: predictChannel,
|
||||
}
|
||||
go func() {
|
||||
for p := range predictChannel {
|
||||
str, _ := templateString(emptyInput, struct {
|
||||
Instruction string
|
||||
Input string
|
||||
}{Instruction: p})
|
||||
res, _ := l.Predict(
|
||||
str,
|
||||
opts...,
|
||||
)
|
||||
|
||||
mm := *m.messages
|
||||
*m.messages = mm[:len(mm)-1]
|
||||
*m.messages = append(*m.messages, m.senderStyle.Render("llama: ")+res)
|
||||
m.viewport.SetContent(strings.Join(*m.messages, "\n"))
|
||||
ta.Reset()
|
||||
m.viewport.GotoBottom()
|
||||
}
|
||||
}()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
return textarea.Blink
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var (
|
||||
tiCmd tea.Cmd
|
||||
vpCmd tea.Cmd
|
||||
)
|
||||
|
||||
m.textarea, tiCmd = m.textarea.Update(msg)
|
||||
m.viewport, vpCmd = m.viewport.Update(msg)
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
|
||||
// m.viewport.Width = msg.Width
|
||||
// m.viewport.Height = msg.Height
|
||||
case tea.KeyMsg:
|
||||
switch msg.Type {
|
||||
case tea.KeyCtrlC, tea.KeyEsc:
|
||||
fmt.Println(m.textarea.Value())
|
||||
return m, tea.Quit
|
||||
case tea.KeyEnter:
|
||||
*m.messages = append(*m.messages, m.senderStyle.Render("You: ")+m.textarea.Value(), m.senderStyle.Render("Loading response..."))
|
||||
m.predictC <- m.textarea.Value()
|
||||
m.viewport.SetContent(strings.Join(*m.messages, "\n"))
|
||||
m.textarea.Reset()
|
||||
m.viewport.GotoBottom()
|
||||
}
|
||||
|
||||
// We handle errors just like any other message
|
||||
case errMsg:
|
||||
m.err = msg
|
||||
return m, nil
|
||||
}
|
||||
|
||||
return m, tea.Batch(tiCmd, vpCmd)
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
return fmt.Sprintf(
|
||||
"%s\n\n%s",
|
||||
m.viewport.View(),
|
||||
m.textarea.View(),
|
||||
) + "\n\n"
|
||||
}
|
||||
@@ -25,7 +25,7 @@ spec:
|
||||
- name: llama
|
||||
args:
|
||||
- api
|
||||
image: quay.io/go-skynet/llama-cli:v0.1
|
||||
image: quay.io/go-skynet/llama-cli:v0.3
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
|
||||
61
main.go
61
main.go
@@ -8,7 +8,10 @@ import (
|
||||
"runtime"
|
||||
"text/template"
|
||||
|
||||
llama "github.com/go-skynet/llama/go"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
api "github.com/go-skynet/llama-cli/api"
|
||||
model "github.com/go-skynet/llama-cli/pkg/model"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
@@ -33,10 +36,6 @@ var nonEmptyInput string = `Below is an instruction that describes a task, paire
|
||||
|
||||
func llamaFromOptions(ctx *cli.Context) (*llama.LLama, error) {
|
||||
opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
|
||||
if ctx.Bool("alpaca") {
|
||||
opts = append(opts, llama.EnableAlpaca)
|
||||
}
|
||||
|
||||
return llama.New(ctx.String("model"), opts...)
|
||||
}
|
||||
|
||||
@@ -90,11 +89,6 @@ var modelFlags = []cli.Flag{
|
||||
EnvVars: []string{"TOP_K"},
|
||||
Value: 20,
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "alpaca",
|
||||
EnvVars: []string{"ALPACA"},
|
||||
Value: true,
|
||||
},
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -127,24 +121,6 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
||||
`,
|
||||
Copyright: "go-skynet authors",
|
||||
Commands: []*cli.Command{
|
||||
{
|
||||
Flags: modelFlags,
|
||||
Name: "interactive",
|
||||
Action: func(ctx *cli.Context) error {
|
||||
|
||||
l, err := llamaFromOptions(ctx)
|
||||
if err != nil {
|
||||
fmt.Println("Loading the model failed:", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return startInteractive(l, llama.SetTemperature(ctx.Float64("temperature")),
|
||||
llama.SetTopP(ctx.Float64("topp")),
|
||||
llama.SetTopK(ctx.Int("topk")),
|
||||
llama.SetTokens(ctx.Int("tokens")),
|
||||
llama.SetThreads(ctx.Int("threads")))
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
Name: "api",
|
||||
@@ -155,19 +131,18 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
||||
Value: runtime.NumCPU(),
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "model",
|
||||
EnvVars: []string{"MODEL_PATH"},
|
||||
Name: "models-path",
|
||||
EnvVars: []string{"MODELS_PATH"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "default-model",
|
||||
EnvVars: []string{"default-model"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "address",
|
||||
EnvVars: []string{"ADDRESS"},
|
||||
Value: ":8080",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "alpaca",
|
||||
EnvVars: []string{"ALPACA"},
|
||||
Value: true,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "context-size",
|
||||
EnvVars: []string{"CONTEXT_SIZE"},
|
||||
@@ -175,13 +150,19 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
|
||||
},
|
||||
},
|
||||
Action: func(ctx *cli.Context) error {
|
||||
l, err := llamaFromOptions(ctx)
|
||||
if err != nil {
|
||||
fmt.Println("Loading the model failed:", err.Error())
|
||||
os.Exit(1)
|
||||
|
||||
var defaultModel *llama.LLama
|
||||
defModel := ctx.String("default-model")
|
||||
if defModel != "" {
|
||||
opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
|
||||
var err error
|
||||
defaultModel, err = llama.New(ctx.String("default-model"), opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return api(l, ctx.String("address"), ctx.Int("threads"))
|
||||
return api.Start(defaultModel, model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
114
pkg/model/loader.go
Normal file
114
pkg/model/loader.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
)
|
||||
|
||||
type ModelLoader struct {
|
||||
modelPath string
|
||||
mu sync.Mutex
|
||||
models map[string]*llama.LLama
|
||||
promptsTemplates map[string]*template.Template
|
||||
}
|
||||
|
||||
func NewModelLoader(modelPath string) *ModelLoader {
|
||||
return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) ListModels() ([]string, error) {
|
||||
files, err := ioutil.ReadDir(ml.modelPath)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
models := []string{}
|
||||
for _, file := range files {
|
||||
if strings.HasSuffix(file.Name(), ".bin") {
|
||||
models = append(models, strings.TrimRight(file.Name(), ".bin"))
|
||||
}
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, error) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
|
||||
m, ok := ml.promptsTemplates[modelName]
|
||||
if !ok {
|
||||
// try to find a s.bin
|
||||
modelBin := fmt.Sprintf("%s.bin", modelName)
|
||||
m, ok = ml.promptsTemplates[modelBin]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no prompt template available")
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := m.Execute(&buf, in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
|
||||
// Check if we already have a loaded model
|
||||
modelFile := filepath.Join(ml.modelPath, modelName)
|
||||
|
||||
if m, ok := ml.models[modelFile]; ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Check if the model path exists
|
||||
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
|
||||
// try to find a s.bin
|
||||
modelBin := fmt.Sprintf("%s.bin", modelFile)
|
||||
if _, err := os.Stat(modelBin); os.IsNotExist(err) {
|
||||
return nil, err
|
||||
} else {
|
||||
modelName = fmt.Sprintf("%s.bin", modelName)
|
||||
modelFile = modelBin
|
||||
}
|
||||
}
|
||||
|
||||
// Load the model and keep it in memory for later use
|
||||
model, err := llama.New(modelFile, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If there is a prompt template, load it
|
||||
|
||||
modelTemplateFile := fmt.Sprintf("%s.tmpl", modelFile)
|
||||
// Check if the model path exists
|
||||
if _, err := os.Stat(modelTemplateFile); err == nil {
|
||||
dat, err := os.ReadFile(modelTemplateFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("prompt").Parse(string(dat))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ml.promptsTemplates[modelName] = tmpl
|
||||
}
|
||||
|
||||
ml.models[modelFile] = model
|
||||
return model, err
|
||||
}
|
||||
Reference in New Issue
Block a user