Compare commits

...

10 Commits
v0.5 ... v0.6

Author SHA1 Message Date
mudler
93d8977ba2 Return model list 2023-04-10 12:02:40 +02:00
mudler
f43aeeb4a1 Add both API endpoints (completion, chat) 2023-04-09 12:30:55 +02:00
mudler
c17dcc5e9d Allow to inject prompt as part of the call 2023-04-09 09:36:19 +02:00
mudler
4a932483e1 Small fixup to template loading 2023-04-08 11:59:40 +02:00
mudler
b710147b95 Add mutex on same models (parallel isn't supported yet) 2023-04-08 11:45:36 +02:00
mudler
ba70363330 Use template input 2023-04-08 11:24:25 +02:00
mudler
9fb581739b Allow to template model prompts inputs 2023-04-08 10:46:51 +02:00
mudler
48aca246e3 Drop unused interactive mode 2023-04-07 11:31:14 +02:00
mudler
12eee097b7 Make it compatible with openAI api, support multiple models
Signed-off-by: mudler <mudler@c3os.io>
2023-04-07 11:30:59 +02:00
mudler
b33d015b8c Use go-llama.cpp 2023-04-07 10:08:15 +02:00
7 changed files with 403 additions and 203 deletions

View File

@@ -14,10 +14,11 @@ go-deps:
build:
FROM +go-deps
WORKDIR /build
RUN git clone https://github.com/go-skynet/llama
RUN cd llama && make libllama.a
RUN git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp
RUN cd go-llama.cpp && make libbinding.a
COPY . .
RUN C_INCLUDE_PATH=/build/llama LIBRARY_PATH=/build/llama go build -o llama-cli ./
RUN go mod edit -replace github.com/go-skynet/go-llama.cpp=/build/go-llama.cpp
RUN C_INCLUDE_PATH=$GOPATH/src/github.com/go-skynet/go-llama.cpp LIBRARY_PATH=$GOPATH/src/github.com/go-skynet/go-llama.cpp go build -o llama-cli ./
SAVE ARTIFACT llama-cli AS LOCAL llama-cli
image:

268
api.go
View File

@@ -2,24 +2,282 @@ package main
import (
"embed"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
llama "github.com/go-skynet/llama/go"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/recover"
)
type OpenAIResponse struct {
Created int `json:"created,omitempty"`
Object string `json:"chat.completion,omitempty"`
ID string `json:"id,omitempty"`
Model string `json:"model,omitempty"`
Choices []Choice `json:"choices,omitempty"`
}
type Choice struct {
Index int `json:"index,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Message Message `json:"message,omitempty"`
Text string `json:"text,omitempty"`
}
type Message struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
}
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
}
//go:embed index.html
var indexHTML embed.FS
func api(l *llama.LLama, listenAddr string, threads int) error {
func completionEndpoint(defaultModel *llama.LLama, loader *ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
var err error
var model *llama.LLama
// Get input data from the request body
input := new(struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
})
if err := c.BodyParser(input); err != nil {
return err
}
if input.Model == "" {
if defaultModel == nil {
return fmt.Errorf("no default model loaded, and no model specified")
}
model = defaultModel
} else {
model, err = loader.LoadModel(input.Model)
if err != nil {
return err
}
}
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
if input.Model != "" {
mutexMap.Lock()
l, ok := mutexes[input.Model]
if !ok {
m := &sync.Mutex{}
mutexes[input.Model] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()
} else {
defaultMutex.Lock()
defer defaultMutex.Unlock()
}
// Set the parameters for the language model prediction
topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
if err != nil {
return err
}
topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
if err != nil {
return err
}
temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
if err != nil {
return err
}
tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
if err != nil {
return err
}
predInput := input.Prompt
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(input.Model, struct {
Input string
}{Input: input.Prompt})
if err == nil {
predInput = templatedInput
}
// Generate the prediction using the language model
prediction, err := model.Predict(
predInput,
llama.SetTemperature(temperature),
llama.SetTopP(topP),
llama.SetTopK(topK),
llama.SetTokens(tokens),
llama.SetThreads(threads),
)
if err != nil {
return err
}
// Return the prediction in the response body
return c.JSON(OpenAIResponse{
Model: input.Model,
Choices: []Choice{{Text: prediction}},
})
}
}
func chatEndpoint(defaultModel *llama.LLama, loader *ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
var err error
var model *llama.LLama
// Get input data from the request body
input := new(struct {
Messages []Message `json:"messages"`
Model string `json:"model"`
})
if err := c.BodyParser(input); err != nil {
return err
}
if input.Model == "" {
if defaultModel == nil {
return fmt.Errorf("no default model loaded, and no model specified")
}
model = defaultModel
} else {
model, err = loader.LoadModel(input.Model)
if err != nil {
return err
}
}
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
if input.Model != "" {
mutexMap.Lock()
l, ok := mutexes[input.Model]
if !ok {
m := &sync.Mutex{}
mutexes[input.Model] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()
} else {
defaultMutex.Lock()
defer defaultMutex.Unlock()
}
// Set the parameters for the language model prediction
topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
if err != nil {
return err
}
topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
if err != nil {
return err
}
temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
if err != nil {
return err
}
tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
if err != nil {
return err
}
mess := []string{}
for _, i := range input.Messages {
mess = append(mess, i.Content)
}
predInput := strings.Join(mess, "\n")
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(input.Model, struct {
Input string
}{Input: predInput})
if err == nil {
predInput = templatedInput
}
// Generate the prediction using the language model
prediction, err := model.Predict(
predInput,
llama.SetTemperature(temperature),
llama.SetTopP(topP),
llama.SetTopK(topK),
llama.SetTokens(tokens),
llama.SetThreads(threads),
)
if err != nil {
return err
}
// Return the prediction in the response body
return c.JSON(OpenAIResponse{
Model: input.Model,
Choices: []Choice{{Message: Message{Role: "assistant", Content: prediction}}},
})
}
}
func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, threads int) error {
app := fiber.New()
// Default middleware config
app.Use(recover.New())
app.Use(cors.New())
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
var mutex = &sync.Mutex{}
mu := map[string]*sync.Mutex{}
var mumutex = &sync.Mutex{}
// openAI compatible API endpoint
app.Post("/v1/chat/completions", chatEndpoint(defaultModel, loader, threads, mutex, mumutex, mu))
app.Post("/v1/completions", completionEndpoint(defaultModel, loader, threads, mutex, mumutex, mu))
app.Get("/v1/models", func(c *fiber.Ctx) error {
models, err := loader.ListModels()
if err != nil {
return err
}
dataModels := []OpenAIModel{}
for _, m := range models {
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
}
return c.JSON(struct {
Object string `json:"object"`
Data []OpenAIModel `json:"data"`
}{
Object: "list",
Data: dataModels,
})
})
app.Use("/", filesystem.New(filesystem.Config{
Root: http.FS(indexHTML),
NotFoundFile: "index.html",
}))
/*
curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{
"text": "What is an alpaca?",
@@ -29,8 +287,6 @@ func api(l *llama.LLama, listenAddr string, threads int) error {
"tokens": 100
}'
*/
var mutex = &sync.Mutex{}
// Endpoint to generate the prediction
app.Post("/predict", func(c *fiber.Ctx) error {
mutex.Lock()
@@ -65,7 +321,7 @@ func api(l *llama.LLama, listenAddr string, threads int) error {
}
// Generate the prediction using the language model
prediction, err := l.Predict(
prediction, err := defaultModel.Predict(
input.Text,
llama.SetTemperature(temperature),
llama.SetTopP(topP),
@@ -86,6 +342,6 @@ func api(l *llama.LLama, listenAddr string, threads int) error {
})
// Start the server
app.Listen(":8080")
app.Listen(listenAddr)
return nil
}

1
go.mod
View File

@@ -17,6 +17,7 @@ require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/containerd/console v1.0.3 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/go-skynet/go-llama.cpp v0.0.0-20230405204601-5429d2339021 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/klauspost/compress v1.15.9 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect

4
go.sum
View File

@@ -19,6 +19,10 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/go-skynet/go-llama.cpp v0.0.0-20230404185816-24b85a924f09 h1:WPUWvw7DOv3WUuhtNfv+xJVE2CCTGa1op1PKGcNk2Bk=
github.com/go-skynet/go-llama.cpp v0.0.0-20230404185816-24b85a924f09/go.mod h1:yD5HHNAHPReBlvWGWUr9OcMeE5BJH3xOUDtKCwjxdEQ=
github.com/go-skynet/go-llama.cpp v0.0.0-20230405204601-5429d2339021 h1:SsUkTjdCCAJjULfspizf99Sfw8Fx9OAHF30kp3i6cxc=
github.com/go-skynet/go-llama.cpp v0.0.0-20230405204601-5429d2339021/go.mod h1:yD5HHNAHPReBlvWGWUr9OcMeE5BJH3xOUDtKCwjxdEQ=
github.com/go-skynet/llama v0.0.0-20230321172246-7be5326e18cc h1:NcmO8mA7iRZIX0Qy2SjcsSaV14+g87MiTey1neUJaFQ=
github.com/go-skynet/llama v0.0.0-20230321172246-7be5326e18cc/go.mod h1:ZtYsAIud4cvP9VTTI9uhdgR1uCwaO/gGKnZZ95h9i7w=
github.com/go-skynet/llama v0.0.0-20230325223742-a3563a2690ba h1:u6OhAqlWFHsTjfWKePdK2kP4/mTyXX5vsmKwrK5QX6o=

View File

@@ -1,142 +0,0 @@
package main
// A simple program demonstrating the text area component from the Bubbles
// component library.
import (
"fmt"
"strings"
"github.com/charmbracelet/bubbles/textarea"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
llama "github.com/go-skynet/llama/go"
)
func startInteractive(l *llama.LLama, opts ...llama.PredictOption) error {
p := tea.NewProgram(initialModel(l, opts...))
_, err := p.Run()
return err
}
type (
errMsg error
)
type model struct {
viewport viewport.Model
messages *[]string
textarea textarea.Model
senderStyle lipgloss.Style
err error
l *llama.LLama
opts []llama.PredictOption
predictC chan string
}
func initialModel(l *llama.LLama, opts ...llama.PredictOption) model {
ta := textarea.New()
ta.Placeholder = "Send a message..."
ta.Focus()
ta.Prompt = "┃ "
ta.CharLimit = 280
ta.SetWidth(200)
ta.SetHeight(3)
// Remove cursor line styling
ta.FocusedStyle.CursorLine = lipgloss.NewStyle()
ta.ShowLineNumbers = false
vp := viewport.New(200, 5)
vp.SetContent(`Welcome to llama-cli. Type a message and press Enter to send. Alpaca doesn't keep context of the whole chat (yet).`)
ta.KeyMap.InsertNewline.SetEnabled(false)
predictChannel := make(chan string)
messages := []string{}
m := model{
textarea: ta,
messages: &messages,
viewport: vp,
senderStyle: lipgloss.NewStyle().Foreground(lipgloss.Color("5")),
err: nil,
l: l,
opts: opts,
predictC: predictChannel,
}
go func() {
for p := range predictChannel {
str, _ := templateString(emptyInput, struct {
Instruction string
Input string
}{Instruction: p})
res, _ := l.Predict(
str,
opts...,
)
mm := *m.messages
*m.messages = mm[:len(mm)-1]
*m.messages = append(*m.messages, m.senderStyle.Render("llama: ")+res)
m.viewport.SetContent(strings.Join(*m.messages, "\n"))
ta.Reset()
m.viewport.GotoBottom()
}
}()
return m
}
func (m model) Init() tea.Cmd {
return textarea.Blink
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var (
tiCmd tea.Cmd
vpCmd tea.Cmd
)
m.textarea, tiCmd = m.textarea.Update(msg)
m.viewport, vpCmd = m.viewport.Update(msg)
switch msg := msg.(type) {
case tea.WindowSizeMsg:
// m.viewport.Width = msg.Width
// m.viewport.Height = msg.Height
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc:
fmt.Println(m.textarea.Value())
return m, tea.Quit
case tea.KeyEnter:
*m.messages = append(*m.messages, m.senderStyle.Render("You: ")+m.textarea.Value(), m.senderStyle.Render("Loading response..."))
m.predictC <- m.textarea.Value()
m.viewport.SetContent(strings.Join(*m.messages, "\n"))
m.textarea.Reset()
m.viewport.GotoBottom()
}
// We handle errors just like any other message
case errMsg:
m.err = msg
return m, nil
}
return m, tea.Batch(tiCmd, vpCmd)
}
func (m model) View() string {
return fmt.Sprintf(
"%s\n\n%s",
m.viewport.View(),
m.textarea.View(),
) + "\n\n"
}

70
main.go
View File

@@ -8,7 +8,7 @@ import (
"runtime"
"text/template"
llama "github.com/go-skynet/llama/go"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/urfave/cli/v2"
)
@@ -33,12 +33,6 @@ var nonEmptyInput string = `Below is an instruction that describes a task, paire
func llamaFromOptions(ctx *cli.Context) (*llama.LLama, error) {
opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
if ctx.Bool("alpaca") {
opts = append(opts, llama.EnableAlpaca)
}
if ctx.Bool("gpt4all") {
opts = append(opts, llama.EnableGPT4All)
}
return llama.New(ctx.String("model"), opts...)
}
@@ -92,16 +86,6 @@ var modelFlags = []cli.Flag{
EnvVars: []string{"TOP_K"},
Value: 20,
},
&cli.BoolFlag{
Name: "alpaca",
EnvVars: []string{"ALPACA"},
Value: true,
},
&cli.BoolFlag{
Name: "gpt4all",
EnvVars: []string{"GPT4ALL"},
Value: false,
},
}
func main() {
@@ -134,24 +118,6 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
`,
Copyright: "go-skynet authors",
Commands: []*cli.Command{
{
Flags: modelFlags,
Name: "interactive",
Action: func(ctx *cli.Context) error {
l, err := llamaFromOptions(ctx)
if err != nil {
fmt.Println("Loading the model failed:", err.Error())
os.Exit(1)
}
return startInteractive(l, llama.SetTemperature(ctx.Float64("temperature")),
llama.SetTopP(ctx.Float64("topp")),
llama.SetTopK(ctx.Int("topk")),
llama.SetTokens(ctx.Int("tokens")),
llama.SetThreads(ctx.Int("threads")))
},
},
{
Name: "api",
@@ -162,24 +128,18 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
Value: runtime.NumCPU(),
},
&cli.StringFlag{
Name: "model",
EnvVars: []string{"MODEL_PATH"},
Name: "models-path",
EnvVars: []string{"MODELS_PATH"},
},
&cli.StringFlag{
Name: "default-model",
EnvVars: []string{"default-model"},
},
&cli.StringFlag{
Name: "address",
EnvVars: []string{"ADDRESS"},
Value: ":8080",
},
&cli.BoolFlag{
Name: "alpaca",
EnvVars: []string{"ALPACA"},
Value: true,
},
&cli.BoolFlag{
Name: "gpt4all",
EnvVars: []string{"GPT4ALL"},
Value: false,
},
&cli.IntFlag{
Name: "context-size",
EnvVars: []string{"CONTEXT_SIZE"},
@@ -187,13 +147,19 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
},
},
Action: func(ctx *cli.Context) error {
l, err := llamaFromOptions(ctx)
if err != nil {
fmt.Println("Loading the model failed:", err.Error())
os.Exit(1)
var defaultModel *llama.LLama
defModel := ctx.String("default-model")
if defModel != "" {
opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
var err error
defaultModel, err = llama.New(ctx.String("default-model"), opts...)
if err != nil {
return err
}
}
return api(l, ctx.String("address"), ctx.Int("threads"))
return api(defaultModel, NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
},
},
},

114
model_loader.go Normal file
View File

@@ -0,0 +1,114 @@
package main
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"sync"
"text/template"
llama "github.com/go-skynet/go-llama.cpp"
)
type ModelLoader struct {
modelPath string
mu sync.Mutex
models map[string]*llama.LLama
promptsTemplates map[string]*template.Template
}
func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
}
func (ml *ModelLoader) ListModels() ([]string, error) {
files, err := ioutil.ReadDir(ml.modelPath)
if err != nil {
return []string{}, err
}
models := []string{}
for _, file := range files {
if strings.HasSuffix(file.Name(), ".bin") {
models = append(models, strings.TrimRight(file.Name(), ".bin"))
}
}
return models, nil
}
func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
m, ok := ml.promptsTemplates[modelName]
if !ok {
// try to find a s.bin
modelBin := fmt.Sprintf("%s.bin", modelName)
m, ok = ml.promptsTemplates[modelBin]
if !ok {
return "", fmt.Errorf("no prompt template available")
}
}
var buf bytes.Buffer
if err := m.Execute(&buf, in); err != nil {
return "", err
}
return buf.String(), nil
}
func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
// Check if we already have a loaded model
modelFile := filepath.Join(ml.modelPath, modelName)
if m, ok := ml.models[modelFile]; ok {
return m, nil
}
// Check if the model path exists
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
// try to find a s.bin
modelBin := fmt.Sprintf("%s.bin", modelFile)
if _, err := os.Stat(modelBin); os.IsNotExist(err) {
return nil, err
} else {
modelName = fmt.Sprintf("%s.bin", modelName)
modelFile = modelBin
}
}
// Load the model and keep it in memory for later use
model, err := llama.New(modelFile, opts...)
if err != nil {
return nil, err
}
// If there is a prompt template, load it
modelTemplateFile := fmt.Sprintf("%s.tmpl", modelFile)
// Check if the model path exists
if _, err := os.Stat(modelTemplateFile); err == nil {
dat, err := os.ReadFile(modelTemplateFile)
if err != nil {
return nil, err
}
// Parse the template
tmpl, err := template.New("prompt").Parse(string(dat))
if err != nil {
return nil, err
}
ml.promptsTemplates[modelName] = tmpl
}
ml.models[modelFile] = model
return model, err
}