Compare commits

..

5 Commits

Author SHA1 Message Date
Matt Williams
9dd88dc040 Fixes for Bruces comments
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-11-06 14:16:24 -08:00
Matt Williams
3d8872bbbd update as per jmorganca comments
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-11-06 08:51:15 -08:00
Matt Williams
a1c8974975 also try other models
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-11-05 16:13:48 -08:00
Matt Williams
1aaaaa76a0 add examples
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-11-05 15:56:00 -08:00
Matt Williams
9411399cb4 Add new example for self querying retrieval
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-11-05 15:53:24 -08:00
47 changed files with 3254 additions and 1679 deletions

View File

@@ -29,7 +29,8 @@ curl https://ollama.ai/install.sh | sh
### Docker
The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `ollama/ollama` is available on Docker Hub.
The official [Ollama Docker image `ollama/ollama`](https://hub.docker.com/r/ollama/ollama)
is available on Docker Hub.
## Quickstart
@@ -159,7 +160,7 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
### Pass in prompt as arguments
```
$ ollama run llama2 "Summarize this file: $(cat README.md)"
$ ollama run llama2 "summarize this file:" "$(cat README.md)"
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
```
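For API users, the same piped-prompt pattern works over HTTP. A minimal TypeScript sketch (an illustration, not part of this change; assumes Node 18+ for the global `fetch` and an ES module context for top-level `await`):

```typescript
// Sketch: the CLI invocation above, reproduced against the HTTP API.
import { readFileSync } from "node:fs";

const prompt = `Summarize this file: ${readFileSync("README.md", "utf8")}`;

const res = await fetch("http://localhost:11434/api/generate", {
  method: "POST",
  body: JSON.stringify({ model: "llama2", prompt, stream: false }),
});
console.log((await res.json()).response);
```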
@@ -216,44 +217,21 @@ See the [API documentation](./docs/api.md) for all endpoints.
## Community Integrations
### Web & Desktop
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Web UI](https://github.com/ollama-webui/ollama-webui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-agi/blob/main/docs/config-ollama.md)
### Terminal
- [oterm](https://github.com/ggozad/oterm)
- [Ellama Emacs client](https://github.com/s-kostyaev/ellama)
- [Emacs client](https://github.com/zweifisch/ollama)
- [gen.nvim](https://github.com/David-Kunz/gen.nvim)
- [ollama.nvim](https://github.com/nomnivore/ollama.nvim)
- [gptel Emacs client](https://github.com/karthink/gptel)
### Libraries
- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa)
- [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html)
- [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
- [Ollama4j for Java](https://github.com/amithkoujalgi/ollama4j)
- [ModelFusion Typescript Library](https://modelfusion.dev/integration/model-provider/ollama)
- [OllamaKit for Swift](https://github.com/kevinhermawan/OllamaKit)
- [Ollama for Dart](https://github.com/breitburg/dart-ollama)
### Extensions & Plugins
- [Raycast extension](https://github.com/MassimilianoPasquini97/raycast_ollama)
- [Discollama](https://github.com/mxyng/discollama) (Discord bot inside the Ollama discord channel)
- [Continue](https://github.com/continuedev/continue)
- [Obsidian Ollama plugin](https://github.com/hinterdupfinger/obsidian-ollama)
- [Logseq Ollama plugin](https://github.com/omagdy7/ollama-logseq)
- [Dagger Chatbot](https://github.com/samalba/dagger-chatbot)
- [LiteLLM](https://github.com/BerriAI/litellm)
- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot)
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Dumbar](https://github.com/JerrySievert/Dumbar)
- [Emacs client](https://github.com/zweifisch/ollama)
- [oterm](https://github.com/ggozad/oterm)
- [Ellama Emacs client](https://github.com/s-kostyaev/ellama)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)

View File

@@ -72,7 +72,7 @@ func ClientFromEnvironment() (*Client, error) {
},
}
mockRequest, err := http.NewRequest(http.MethodHead, client.base.String(), nil)
mockRequest, err := http.NewRequest("HEAD", client.base.String(), nil)
if err != nil {
return nil, err
}

View File

@@ -7,7 +7,7 @@ BASE_URL = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')
# Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses.
# The final response object will include statistics and additional data from the request. Use the callback function to override
# the default handler.
def generate(model_name, prompt, system=None, template=None, format="", context=None, options=None, callback=None):
def generate(model_name, prompt, system=None, template=None, context=None, options=None, callback=None):
try:
url = f"{BASE_URL}/api/generate"
payload = {
@@ -16,8 +16,7 @@ def generate(model_name, prompt, system=None, template=None, format="", context=
"system": system,
"template": template,
"context": context,
"options": options,
"format": format,
"options": options
}
# Remove keys with None values
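For comparison, a rough TypeScript equivalent of this client function (a sketch, not part of the change; assumes Node 18+ for the global `fetch`):

```typescript
// Sketch: the same generate call in TypeScript. JSON.stringify drops
// undefined fields on its own, which plays the role of the "remove keys
// with None values" step in the Python client above.
const BASE_URL = process.env.OLLAMA_HOST ?? "http://localhost:11434";

async function generate(modelName: string, prompt: string, system?: string, template?: string) {
  const payload = { model: modelName, prompt, system, template };
  return fetch(`${BASE_URL}/api/generate`, {
    method: "POST",
    body: JSON.stringify(payload), // undefined fields are omitted here
  });
}
```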

View File

@@ -37,56 +37,10 @@ type GenerateRequest struct {
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Options map[string]interface{} `json:"options"`
}
// Options specified in GenerateRequest, if you add a new option here add it to the API docs also
type Options struct {
Runner
// Predict options used at runtime
NumKeep int `json:"num_keep,omitempty"`
Seed int `json:"seed,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float32 `json:"top_p,omitempty"`
TFSZ float32 `json:"tfs_z,omitempty"`
TypicalP float32 `json:"typical_p,omitempty"`
RepeatLastN int `json:"repeat_last_n,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
Mirostat int `json:"mirostat,omitempty"`
MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"`
PenalizeNewline bool `json:"penalize_newline,omitempty"`
Stop []string `json:"stop,omitempty"`
}
// Runner options which must be set when the model is loaded into memory
type Runner struct {
UseNUMA bool `json:"numa,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
NumBatch int `json:"num_batch,omitempty"`
NumGQA int `json:"num_gqa,omitempty"`
NumGPU int `json:"num_gpu,omitempty"`
MainGPU int `json:"main_gpu,omitempty"`
LowVRAM bool `json:"low_vram,omitempty"`
F16KV bool `json:"f16_kv,omitempty"`
LogitsAll bool `json:"logits_all,omitempty"`
VocabOnly bool `json:"vocab_only,omitempty"`
UseMMap bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
EmbeddingOnly bool `json:"embedding_only,omitempty"`
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
NumThread int `json:"num_thread,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
@@ -207,6 +161,49 @@ func (r *GenerateResponse) Summary() {
}
}
// Runner options which must be set when the model is loaded into memory
type Runner struct {
UseNUMA bool `json:"numa,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
NumBatch int `json:"num_batch,omitempty"`
NumGQA int `json:"num_gqa,omitempty"`
NumGPU int `json:"num_gpu,omitempty"`
MainGPU int `json:"main_gpu,omitempty"`
LowVRAM bool `json:"low_vram,omitempty"`
F16KV bool `json:"f16_kv,omitempty"`
LogitsAll bool `json:"logits_all,omitempty"`
VocabOnly bool `json:"vocab_only,omitempty"`
UseMMap bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
EmbeddingOnly bool `json:"embedding_only,omitempty"`
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
NumThread int `json:"num_thread,omitempty"`
}
type Options struct {
Runner
// Predict options used at runtime
NumKeep int `json:"num_keep,omitempty"`
Seed int `json:"seed,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float32 `json:"top_p,omitempty"`
TFSZ float32 `json:"tfs_z,omitempty"`
TypicalP float32 `json:"typical_p,omitempty"`
RepeatLastN int `json:"repeat_last_n,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
Mirostat int `json:"mirostat,omitempty"`
MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"`
PenalizeNewline bool `json:"penalize_newline,omitempty"`
Stop []string `json:"stop,omitempty"`
}
var ErrInvalidOpts = fmt.Errorf("invalid options")
func (opts *Options) FromMap(m map[string]interface{}) error {
@@ -296,7 +293,7 @@ func DefaultOptions() Options {
return Options{
// options set on request to runner
NumPredict: -1,
NumKeep: 0,
NumKeep: -1,
Temperature: 0.8,
TopK: 40,
TopP: 0.9,

View File

@@ -1,6 +1,7 @@
package cmd
import (
"bufio"
"context"
"crypto/ed25519"
"crypto/rand"
@@ -27,9 +28,9 @@ import (
"golang.org/x/term"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/editor"
"github.com/jmorganca/ollama/format"
"github.com/jmorganca/ollama/progressbar"
"github.com/jmorganca/ollama/readline"
"github.com/jmorganca/ollama/server"
"github.com/jmorganca/ollama/version"
)
@@ -101,14 +102,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
name := args[0]
// check if the model exists on the server
_, err = client.Show(context.Background(), &api.ShowRequest{Name: name})
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
if err := PullHandler(cmd, args); err != nil {
if err != nil {
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
if err := PullHandler(cmd, args); err != nil {
return err
}
case err != nil:
return err
}
case err != nil:
return err
}
return RunGenerate(cmd, args)
@@ -349,49 +352,34 @@ func pull(model string, insecure bool) error {
}
func RunGenerate(cmd *cobra.Command, args []string) error {
format, err := cmd.Flags().GetString("format")
if err != nil {
return err
}
if len(args) > 1 {
// join all args into a single prompt
wordWrap := false
if term.IsTerminal(int(os.Stdout.Fd())) {
wordWrap = true
}
prompts := args[1:]
// prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
if nowrap {
wordWrap = false
}
prompts = append([]string{string(in)}, prompts...)
return generate(cmd, args[0], strings.Join(args[1:], " "), wordWrap)
}
// output is being piped
if !term.IsTerminal(int(os.Stdout.Fd())) {
return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
if readline.IsTerminal(int(os.Stdin.Fd())) {
return generateInteractive(cmd, args[0])
}
wordWrap := os.Getenv("TERM") == "xterm-256color"
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
if nowrap {
wordWrap = false
}
// prompts are provided via stdin or args so don't enter interactive mode
if len(prompts) > 0 {
return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
}
return generateInteractive(cmd, args[0], wordWrap, format)
return generateBatch(cmd, args[0])
}
type generateContextKey string
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error {
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -407,7 +395,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
generateContext = []int{}
}
termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
termWidth, _, err := term.GetSize(int(0))
if err != nil {
wordWrap = false
}
@@ -428,7 +416,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
var currentLineLength int
var wordBuffer string
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format}
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
fn := func(response api.GenerateResponse) error {
if !spinner.IsFinished() {
spinner.Finish()
@@ -499,9 +487,9 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
return nil
}
func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
func generateInteractive(cmd *cobra.Command, model string) error {
// load the model
if err := generate(cmd, model, "", false, ""); err != nil {
if err := generate(cmd, model, "", false); err != nil {
return err
}
@@ -522,8 +510,6 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode")
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, "")
@@ -539,24 +525,45 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Fprintln(os.Stderr, "")
}
prompt := editor.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
prompt := readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
}
ed, err := editor.New(prompt)
scanner, err := readline.New(prompt)
if err != nil {
return err
}
var wordWrap bool
termType := os.Getenv("TERM")
if termType == "xterm-256color" {
wordWrap = true
}
// override wrapping if the user turned it off
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
if nowrap {
wordWrap = false
}
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
var multiLineBuffer string
for {
line, err := ed.HandleInput()
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
fmt.Println()
return nil
case errors.Is(err, editor.ErrInterrupt):
case errors.Is(err, readline.ErrInterrupt):
if line == "" {
fmt.Println("\nUse Ctrl-D or /bye to exit.")
}
@@ -569,6 +576,20 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
line = strings.TrimSpace(line)
switch {
case scanner.Prompt.UseAlt:
if strings.HasSuffix(line, `"""`) {
scanner.Prompt.UseAlt = false
multiLineBuffer += strings.TrimSuffix(line, `"""`)
line = multiLineBuffer
multiLineBuffer = ""
} else {
multiLineBuffer += line + " "
continue
}
case strings.HasPrefix(line, `"""`):
scanner.Prompt.UseAlt = true
multiLineBuffer = strings.TrimPrefix(line, `"""`) + " "
continue
case strings.HasPrefix(line, "/list"):
args := strings.Fields(line)
if err := ListHandler(cmd, args[1:]); err != nil {
@@ -579,9 +600,9 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
if len(args) > 1 {
switch args[1] {
case "history":
//scanner.HistoryEnable()
scanner.HistoryEnable()
case "nohistory":
//scanner.HistoryDisable()
scanner.HistoryDisable()
case "wordwrap":
wordWrap = true
fmt.Println("Set 'wordwrap' mode.")
@@ -594,16 +615,6 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
case "quiet":
cmd.Flags().Set("verbose", "false")
fmt.Println("Set 'quiet' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
format = ""
fmt.Println("Disabled format.")
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
@@ -677,13 +688,26 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
}
if len(line) > 0 && line[0] != '/' {
if err := generate(cmd, model, line, wordWrap, format); err != nil {
if err := generate(cmd, model, line, wordWrap); err != nil {
return err
}
}
}
}
func generateBatch(cmd *cobra.Command, model string) error {
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
prompt := scanner.Text()
fmt.Printf(">>> %s\n", prompt)
if err := generate(cmd, model, prompt, false); err != nil {
return err
}
}
return nil
}
func RunServer(cmd *cobra.Command, _ []string) error {
host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
if err != nil {
@@ -861,7 +885,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)")
serveCmd := &cobra.Command{
Use: "serve",

View File

@@ -41,36 +41,28 @@ Generate a response for a given prompt with a provided model. This is a streamin
Advanced parameters (optional):
- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system prompt (overrides what is defined in the `Modelfile`)
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
### JSON mode
Enable JSON mode by setting the `format` parameter to `json` and instructing the model to use JSON in the `prompt`. This will structure the response as valid JSON. See the JSON mode [example](#request-json-mode) below.
### Examples
#### Request
### Request
```shell
curl -X POST http://localhost:11434/api/generate -d '{
"model": "llama2",
"model": "llama2:7b",
"prompt": "Why is the sky blue?"
}'
```
#### Response
### Response
A stream of JSON objects is returned:
A stream of JSON objects:
```json
{
"model": "llama2",
"model": "llama2:7b",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"response": "The",
"done": false
@@ -94,7 +86,7 @@ To calculate how fast the response is generated in tokens per second (token/s),
```json
{
"model": "llama2",
"model": "llama2:7b",
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "",
"context": [1, 2, 3],
@@ -110,182 +102,6 @@ To calculate how fast the response is generated in tokens per second (token/s),
}
```
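To make the stream concrete, here is a sketch of a consumer in TypeScript: it reads the newline-delimited JSON objects, prints tokens as they arrive, and computes tokens per second from the final object (durations are reported in nanoseconds). Assumes Node 18+, where `fetch` response bodies are async-iterable web streams.

```typescript
// Sketch: consume the /api/generate stream (one JSON object per line).
const res = await fetch("http://localhost:11434/api/generate", {
  method: "POST",
  body: JSON.stringify({ model: "llama2", prompt: "Why is the sky blue?" }),
});

const decoder = new TextDecoder();
let buf = "";
for await (const chunk of res.body as any) { // async-iterable in Node 18+
  buf += decoder.decode(chunk, { stream: true });
  const lines = buf.split("\n");
  buf = lines.pop()!; // keep any partial line for the next chunk
  for (const line of lines.filter(Boolean)) {
    const obj = JSON.parse(line);
    if (obj.done) {
      // eval_duration is in nanoseconds, so scale to seconds
      console.log(`\n${(obj.eval_count / obj.eval_duration) * 1e9} token/s`);
    } else {
      process.stdout.write(obj.response);
    }
  }
}
```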
#### Request (No streaming)
```shell
curl -X POST http://localhost:11434/api/generate -d '{
"model": "llama2:7b",
"prompt": "Why is the sky blue?",
"stream": false
}'
```
#### Response
If `stream` is set to `false`, the response will be a single JSON object:
```json
{
"model": "llama2:7b",
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.",
"context": [1, 2, 3],
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 13,
"eval_duration": 1325948000
}
```
#### Request (Raw mode)
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
```shell
curl -X POST http://localhost:11434/api/generate -d '{
"model": "mistral",
"prompt": "[INST] why is the sky blue? [/INST]",
"raw": true,
"stream": false
}'
```
#### Response
```json
{
"model": "mistral",
"created_at": "2023-11-03T15:36:02.583064Z",
"response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
"done": true,
"total_duration": 14648695333,
"load_duration": 3302671417,
"prompt_eval_count": 14,
"prompt_eval_duration": 286243000,
"eval_count": 129,
"eval_duration": 10931424000
}
```
#### Request (JSON mode)
```shell
curl -X POST http://localhost:11434/api/generate -d '{
"model": "llama2",
"prompt": "What color is the sky at different times of the day? Respond using JSON",
"format": "json",
"stream": false
}'
```
#### Response
```json
{
"model": "llama2",
"created_at": "2023-11-09T21:07:55.186497Z",
"response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n",
"done": true,
"total_duration": 4661289125,
"load_duration": 1714434500,
"prompt_eval_count": 36,
"prompt_eval_duration": 264132000,
"eval_count": 75,
"eval_duration": 2112149000
}
```
The value of `response` will be a string containing JSON similar to:
```json
{
"morning": {
"color": "blue"
},
"noon": {
"color": "blue-gray"
},
"afternoon": {
"color": "warm gray"
},
"evening": {
"color": "orange"
}
}
```
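In other words, JSON mode gives you JSON inside a string, so one extra parse is needed. A sketch in TypeScript (assumes Node 18+ fetch):

```typescript
// Sketch: the `response` field in JSON mode is a string of JSON and needs
// its own JSON.parse to become an object like the one shown above.
const res = await fetch("http://localhost:11434/api/generate", {
  method: "POST",
  body: JSON.stringify({
    model: "llama2",
    prompt: "What color is the sky at different times of the day? Respond using JSON",
    format: "json",
    stream: false,
  }),
});
const { response } = await res.json();
const skyColors = JSON.parse(response);
console.log(skyColors.morning.color); // "blue" in the example above
```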
#### Request (With options)
If you want to set custom options for the model at runtime rather than in the Modelfile, you can do so with the `options` parameter. This example sets every available option, but you can set any of them individually and omit the ones you do not want to override.
```shell
curl -X POST http://localhost:11434/api/generate -d '{
"model": "llama2:7b",
"prompt": "Why is the sky blue?",
"stream": false,
"options": {
"num_keep": 5,
"seed": 42,
"num_predict": 100,
"top_k": 20,
"top_p": 0.9,
"tfs_z": 0.5,
"typical_p": 0.7,
"repeat_last_n": 33,
"temperature": 0.8,
"repeat_penalty": 1.2,
"presence_penalty": 1.5,
"frequency_penalty": 1.0,
"mirostat": 1,
"mirostat_tau": 0.8,
"mirostat_eta": 0.6,
"penalize_newline": true,
"stop": ["\n", "user:"],
"numa": false,
"num_ctx": 4,
"num_batch": 2,
"num_gqa": 1,
"num_gpu": 1,
"main_gpu": 0,
"low_vram": false,
"f16_kv": true,
"logits_all": false,
"vocab_only": false,
"use_mmap": true,
"use_mlock": false,
"embedding_only": false,
"rope_frequency_base": 1.1,
"rope_frequency_scale": 0.8,
"num_thread": 8
}
}'
```
#### Response
```json
{
"model": "llama2:7b",
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.",
"context": [1, 2, 3],
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 13,
"eval_duration": 1325948000
}
```
## Create a Model
```shell
@@ -298,11 +114,9 @@ Create a model from a [`Modelfile`](./modelfile.md)
- `name`: name of the model to create
- `path`: path to the Modelfile
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
### Examples
#### Request
### Request
```shell
curl -X POST http://localhost:11434/api/create -d '{
@@ -311,7 +125,7 @@ curl -X POST http://localhost:11434/api/create -d '{
}'
```
#### Response
### Response
A stream of JSON objects. When finished, `status` is `success`.
@@ -329,17 +143,13 @@ GET /api/tags
List models that are available locally.
### Examples
#### Request
### Request
```shell
curl http://localhost:11434/api/tags
```
#### Response
A single JSON object will be returned.
### Response
```json
{
@@ -370,9 +180,7 @@ Show details about a model including modelfile, template, parameters, license, a
- `name`: name of the model to show
### Examples
#### Request
### Request
```shell
curl http://localhost:11434/api/show -d '{
@@ -380,7 +188,7 @@ curl http://localhost:11434/api/show -d '{
}'
```
#### Response
### Response
```json
{
@@ -399,9 +207,7 @@ POST /api/copy
Copy a model. Creates a model with another name from an existing model.
### Examples
#### Request
### Request
```shell
curl http://localhost:11434/api/copy -d '{
@@ -410,10 +216,6 @@ curl http://localhost:11434/api/copy -d '{
}'
```
#### Response
The only response is a 200 OK if successful.
## Delete a Model
```shell
@@ -424,11 +226,9 @@ Delete a model and its data.
### Parameters
- `name`: model name to delete
- `model`: model name to delete
### Examples
#### Request
### Request
```shell
curl -X DELETE http://localhost:11434/api/delete -d '{
@@ -436,10 +236,6 @@ curl -X DELETE http://localhost:11434/api/delete -d '{
}'
```
#### Response
If successful, the only response is a 200 OK.
## Pull a Model
```shell
@@ -452,11 +248,9 @@ Download a model from the ollama library. Cancelled pulls are resumed from where
- `name`: name of the model to pull
- `insecure`: (optional) allow insecure connections to the library. Only use this if you are pulling from your own library during development.
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
### Examples
#### Request
### Request
```shell
curl -X POST http://localhost:11434/api/pull -d '{
@@ -464,51 +258,13 @@ curl -X POST http://localhost:11434/api/pull -d '{
}'
```
#### Response
If `stream` is not specified, or set to `true`, a stream of JSON objects is returned:
The first object is the manifest:
```json
{
"status": "pulling manifest"
}
```
Then there is a series of downloading responses. Until a download completes, the `completed` key may not be included. The number of files to be downloaded depends on the number of layers specified in the manifest.
### Response
```json
{
"status": "downloading digestname",
"digest": "digestname",
"total": 2142590208,
"completed": 241970
}
```
After all the files are downloaded, the final responses are:
```json
{
"status": "verifying sha256 digest"
}
{
"status": "writing manifest"
}
{
"status": "removing any unused layers"
}
{
"status": "success"
}
```
If `stream` is set to `false`, then the response is a single JSON object:
```json
{
"status": "success",
"total": 2142590208
}
```
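Progress can be derived from the streamed objects as `completed / total`. A sketch, using the example download object above:

```typescript
// Sketch: percent complete for the downloading status object shown above.
const total = 2142590208;
const completed = 241970;
console.log(`${((completed / total) * 100).toFixed(2)}% complete`); // "0.01% complete"
```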
@@ -524,11 +280,9 @@ Upload a model to a model library. Requires registering for ollama.ai and adding
- `name`: name of the model to push in the form of `<namespace>/<model>:<tag>`
- `insecure`: (optional) allow insecure connections to the library. Only use this if you are pushing to your library during development.
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
### Examples
#### Request
### Request
```shell
curl -X POST http://localhost:11434/api/push -d '{
@@ -536,9 +290,9 @@ curl -X POST http://localhost:11434/api/push -d '{
}'
```
#### Response
### Response
If `stream` is not specified, or set to `true`, a stream of JSON objects is returned:
Streaming response that starts with:
```json
{ "status": "retrieving manifest" }
@@ -571,12 +325,6 @@ Finally, when the upload is complete:
{"status":"success"}
```
If `stream` is set to `false`, then the response is a single JSON object:
```json
{ "status": "success" }
```
## Generate Embeddings
```shell
@@ -594,9 +342,7 @@ Advanced parameters:
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
### Examples
#### Request
### Request
```shell
curl -X POST http://localhost:11434/api/embeddings -d '{
@@ -605,7 +351,7 @@ curl -X POST http://localhost:11434/api/embeddings -d '{
}'
```
#### Response
### Response
```json
{

View File

@@ -74,25 +74,6 @@ systemctl restart ollama
- macOS: Raw model data is stored under `~/.ollama/models`.
- Linux: Raw model data is stored under `/usr/share/ollama/.ollama/models`.
Below the models directory you will find a structure similar to the following:
```shell
.
├── blobs
└── manifests
    └── registry.ollama.ai
        ├── f0rodo
        ├── library
        ├── mattw
        └── saikatkumardey
```
There is a `manifests/registry.ollama.ai/namespace` path. In the example above, the user has downloaded models from the official `library` namespace as well as the `f0rodo`, `mattw`, and `saikatkumardey` namespaces. Within each of those directories you will find a directory for each downloaded model, and inside it a file named for each tag. Each tag file is the manifest for the model.
The manifest lists all the layers used in this model. You will see a `media type` for each layer, along with a digest. That digest corresponds to a file in the `models/blobs` directory.
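As a quick way to inspect this, a hedged Node/TypeScript sketch that reads one tag manifest and prints its layers (the path follows the macOS layout above; the exact JSON field names are an assumption based on the OCI-style manifest described here):

```typescript
// Sketch: print each layer's media type and digest for one model tag.
// Assumptions: macOS path layout from above; `layers`, `mediaType`, and
// `digest` field names per the OCI-style manifest this FAQ describes.
import { readFileSync } from "node:fs";
import { homedir } from "node:os";
import { join } from "node:path";

const manifestPath = join(
  homedir(), ".ollama", "models", "manifests",
  "registry.ollama.ai", "library", "llama2", "latest",
);
const manifest = JSON.parse(readFileSync(manifestPath, "utf8"));
for (const layer of manifest.layers) {
  // each digest corresponds to a file under models/blobs
  console.log(layer.mediaType, layer.digest);
}
```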
### How can I change where Ollama stores models?
To modify where models are stored, you can use the `OLLAMA_MODELS` environment variable. Note that on Linux this means defining `OLLAMA_MODELS` in a drop-in `/etc/systemd/system/ollama.service.d` service file, reloading systemd, and restarting the ollama service.

View File

@@ -112,8 +112,8 @@ PARAMETER <parameter> <parametervalue>
| repeat_last_n | Sets how far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
| stop | Sets the stop sequences to use. | string | stop "AI assistant:" |
| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
| num_predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | int | num_predict 42 |
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
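The same parameters can also be set per request through the API's `options` field, using the names in the table as JSON keys. A sketch (assumes Node 18+ fetch):

```typescript
// Sketch: Modelfile PARAMETER values passed per-request via the API instead.
const res = await fetch("http://localhost:11434/api/generate", {
  method: "POST",
  body: JSON.stringify({
    model: "llama2",
    prompt: "Why is the sky blue?",
    stream: false,
    options: {
      temperature: 0.7,
      seed: 42,
      repeat_penalty: 1.1,
      repeat_last_n: 64,
      stop: ["AI assistant:"],
    },
  }),
});
console.log((await res.json()).response);
```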

View File

@@ -23,17 +23,13 @@ const answer = await ollama.call(`why is the sky blue?`);
console.log(answer);
```
That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
```bash
npm install cheerio
```
That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's build that part of the app.
```javascript
import { CheerioWebBaseLoader } from "langchain/document_loaders/web/cheerio";
const loader = new CheerioWebBaseLoader("https://en.wikipedia.org/wiki/2023_Hawaii_wildfires");
const data = await loader.load();
const data = loader.load();
```
That will load the document. Although this page is smaller than the Odyssey, it is certainly bigger than the context size for most LLMs. So we are going to need to split it into smaller pieces, and then select just the pieces relevant to our question. This is a great use for a vector datastore. In this example, we will use the **MemoryVectorStore** that is part of **LangChain**. But there is one more thing we need to get the content into the datastore. We have to run an embeddings process that converts the tokens in the text into a series of vectors. And for that, we are going to use **Tensorflow**. There is a lot of stuff going on in this one. First, install the **Tensorflow** components that we need.
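For orientation, a sketch of the splitting and vector-store steps just described (import paths as in the LangChain.js 0.0.x line this tutorial uses; `TensorFlowEmbeddings` also needs the `@tensorflow/tfjs-*` packages installed):

```typescript
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { TensorFlowEmbeddings } from "langchain/embeddings/tensorflow";

// Split the loaded page into overlapping chunks small enough for the model.
// `data` is the array of documents returned by the loader above.
const splitter = new RecursiveCharacterTextSplitter({
  chunkSize: 500,
  chunkOverlap: 20,
});
const splitDocs = await splitter.splitDocuments(data);

// Embed each chunk and hold everything in an in-memory vector store.
const vectorStore = await MemoryVectorStore.fromDocuments(
  splitDocs,
  new TensorFlowEmbeddings()
);
```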

View File

@@ -1,488 +0,0 @@
package editor
import (
"fmt"
"strings"
"github.com/emirpasic/gods/lists/arraylist"
"golang.org/x/term"
)
type Buffer struct {
PosX int
PosY int
Buf []*arraylist.List
Prompt *Prompt
WordWrap int
ScreenWidth int
ScreenHeight int
}
func NewBuffer(prompt *Prompt) (*Buffer, error) {
width, height, err := term.GetSize(0)
if err != nil {
fmt.Println("Error getting size:", err)
return nil, err
}
b := &Buffer{
PosX: 0,
PosY: 0,
Buf: []*arraylist.List{arraylist.New()},
Prompt: prompt,
ScreenWidth: width,
ScreenHeight: height,
}
return b, nil
}
func (b *Buffer) LineWidth() int {
return b.ScreenWidth - len(b.Prompt.Prompt)
}
func (b *Buffer) findWordAtPos(line string, pos int) string {
return ""
}
func (b *Buffer) addLine(row int) {
if row+1 == len(b.Buf) {
b.Buf = append(b.Buf, arraylist.New())
} else {
b.Buf = append(b.Buf, nil)
copy(b.Buf[row+2:], b.Buf[row+1:])
b.Buf[row+1] = arraylist.New()
}
}
func (b *Buffer) Add(r rune) {
switch r {
case CharCtrlJ, CharEnter:
b.addLine(b.PosY)
// handle Ctrl-J in the middle of a line
var remainingText string
if b.PosX < b.Buf[b.PosY].Size() {
fmt.Print(ClearToEOL)
remainingText = b.StringLine(b.PosX, b.PosY)
for cnt := 0; cnt < len(remainingText); cnt++ {
b.Buf[b.PosY].Remove(b.Buf[b.PosY].Size() - 1)
b.Buf[b.PosY+1].Add(rune(remainingText[cnt]))
}
}
b.PosY++
b.PosX = 0
fmt.Printf("\n... " + ClearToEOL)
b.drawRemaining()
default:
if b.PosX == b.Buf[b.PosY].Size() {
fmt.Printf("%c", r)
b.PosX++
b.Buf[b.PosY].Add(r)
wrap, prefix, offset := b.splitLineInsert(b.PosY, b.PosX)
if wrap {
fmt.Print(CursorHide + cursorLeftN(len(prefix)+1) + ClearToEOL)
fmt.Printf("\n%s... %s%c", ClearToEOL, prefix, r)
b.PosY++
b.PosX = offset
b.ResetCursor()
b.drawRemaining()
fmt.Print(CursorShow)
}
} else {
fmt.Printf("%c", r)
b.Buf[b.PosY].Insert(b.PosX, r)
b.PosX++
_, prefix, offset := b.splitLineInsert(b.PosY, b.PosX)
fmt.Print(CursorHide)
if b.PosX > b.Buf[b.PosY].Size() {
if offset > 0 {
fmt.Print(cursorLeftN(offset))
}
fmt.Print(ClearToEOL + CursorDown + CursorBOL + ClearToEOL)
fmt.Printf("... %s", prefix[:offset])
b.PosY++
b.PosX = offset
b.ResetCursor()
}
b.drawRemaining()
fmt.Print(CursorShow)
}
}
}
func (b *Buffer) ResetCursor() {
fmt.Print(CursorHide + CursorBOL)
fmt.Print(cursorRightN(b.PosX + len(b.Prompt.Prompt)))
fmt.Print(CursorShow)
}
func (b *Buffer) splitLineInsert(posY, posX int) (bool, string, int) {
line := b.StringLine(0, posY)
screenEdge := b.LineWidth() - 5
// if the current line doesn't need to be reflowed, none of the other
// lines will either
if len(line) <= screenEdge {
return false, "", 0
}
// we know we're going to have to insert onto the next line, so
// add another line if there isn't one already
if posY == len(b.Buf)-1 {
b.Buf = append(b.Buf, arraylist.New())
}
// make a truncated version of the current line
currLine := line[:screenEdge]
// figure out where the last space in the line is
idx := strings.LastIndex(currLine, " ")
// deal with strings that don't have spaces in them
if idx == -1 {
idx = len(currLine) - 1
}
// if the next line already has text on it, we need
// to add a space to insert our new word
if b.Buf[posY+1].Size() > 0 {
b.Buf[posY+1].Insert(0, ' ')
}
// calculate the number of characters we need to remove
// from the current line to add to the next one
totalChars := len(line) - idx - 1
for cnt := 0; cnt < totalChars; cnt++ {
b.Buf[posY].Remove(b.Buf[posY].Size() - 1)
b.Buf[posY+1].Insert(0, rune(line[len(line)-1-cnt]))
}
// remove the trailing space
b.Buf[posY].Remove(b.Buf[posY].Size() - 1)
// wrap any further lines
if b.Buf[posY+1].Size() > b.LineWidth()-5 {
b.splitLineInsert(posY+1, 0)
}
return true, currLine[idx+1:], posX - idx - 1
}
func (b *Buffer) drawRemaining() {
remainingText := b.StringFromRow(b.PosY)
remainingText = remainingText[b.PosX:]
fmt.Print(CursorHide + ClearToEOL)
var rowCount int
for _, c := range remainingText {
fmt.Print(string(c))
if c == '\n' {
fmt.Print("... " + ClearToEOL)
rowCount++
}
}
if rowCount > 0 {
fmt.Print(cursorUpN(rowCount))
}
b.ResetCursor()
}
func (b *Buffer) findWordBeginning(posX int) int {
for {
if posX < 0 {
return -1
}
r, ok := b.Buf[b.PosY].Get(posX)
if !ok {
return -1
} else if r.(rune) == ' ' {
return posX
}
posX--
}
}
func (b *Buffer) Delete() {
if b.PosX < b.Buf[b.PosY].Size()-1 {
b.Buf[b.PosY].Remove(b.PosX)
b.drawRemaining()
} else {
b.joinLines()
}
}
func (b *Buffer) joinLines() {
lineLen := b.Buf[b.PosY].Size()
for cnt := 0; cnt < lineLen; cnt++ {
r, _ := b.Buf[b.PosY].Get(0)
b.Buf[b.PosY].Remove(0)
b.Buf[b.PosY-1].Add(r)
}
}
func (b *Buffer) Remove() {
if b.PosX > 0 {
fmt.Print(CursorLeft + " " + CursorLeft)
b.PosX--
b.Buf[b.PosY].Remove(b.PosX)
if b.PosX < b.Buf[b.PosY].Size() {
fmt.Print(ClearToEOL)
b.drawRemaining()
}
} else if b.PosX == 0 && b.PosY > 0 {
b.joinLines()
lastPos := b.Buf[b.PosY-1].Size()
var cnt int
b.PosX = lastPos
b.PosY--
fmt.Print(CursorHide)
for {
if b.PosX+cnt > b.LineWidth()-5 {
// the concatenated line won't fit, so find the beginning of the word
// and copy the rest of the string from there
idx := b.findWordBeginning(b.PosX)
lineLen := b.Buf[b.PosY].Size()
for offset := idx + 1; offset < lineLen; offset++ {
r, _ := b.Buf[b.PosY].Get(idx + 1)
b.Buf[b.PosY].Remove(idx + 1)
b.Buf[b.PosY+1].Add(r)
}
// remove the trailing space
b.Buf[b.PosY].Remove(idx)
fmt.Print(CursorUp + ClearToEOL)
b.PosX = 0
b.drawRemaining()
fmt.Print(CursorDown)
if idx > 0 {
if lastPos-idx-1 > 0 {
b.PosX = lastPos - idx - 1
b.ResetCursor()
}
}
b.PosY++
break
}
r, ok := b.Buf[b.PosY].Get(b.PosX + cnt)
if !ok {
// found the end of the string
fmt.Print(CursorUp + cursorRightN(b.PosX) + ClearToEOL)
b.drawRemaining()
break
}
if r == ' ' {
// found the end of the word
lineLen := b.Buf[b.PosY].Size()
for offset := b.PosX + cnt + 1; offset < lineLen; offset++ {
r, _ := b.Buf[b.PosY].Get(b.PosX + cnt + 1)
b.Buf[b.PosY].Remove(b.PosX + cnt + 1)
b.Buf[b.PosY+1].Add(r)
}
fmt.Print(CursorUp + cursorRightN(b.PosX) + ClearToEOL)
b.drawRemaining()
break
}
cnt++
}
fmt.Print(CursorShow)
}
}
func (b *Buffer) RemoveBefore() {
for {
if b.PosX == 0 && b.PosY == 0 {
break
}
b.Remove()
}
}
func (b *Buffer) RemoveWordBefore() {
if b.PosX > 0 || b.PosY > 0 {
var foundNonspace bool
for {
xPos := b.PosX
yPos := b.PosY
v, _ := b.Buf[yPos].Get(xPos - 1)
if v == ' ' {
if !foundNonspace {
b.Remove()
} else {
break
}
} else {
foundNonspace = true
b.Remove()
}
if xPos == 0 && yPos == 0 {
break
}
}
}
}
func (b *Buffer) StringLine(x, y int) string {
if y >= len(b.Buf) {
return ""
}
var output string
for cnt := x; cnt < b.Buf[y].Size(); cnt++ {
r, _ := b.Buf[y].Get(cnt)
output += string(r.(rune))
}
return output
}
func (b *Buffer) String() string {
return b.StringFromRow(0)
}
func (b *Buffer) StringFromRow(n int) string {
var output []string
for _, row := range b.Buf[n:] {
var currLine string
for cnt := 0; cnt < row.Size(); cnt++ {
r, _ := row.Get(cnt)
currLine += string(r.(rune))
}
currLine = strings.TrimRight(currLine, " ")
output = append(output, currLine)
}
return strings.Join(output, "\n")
}
func (b *Buffer) cursorUp() {
fmt.Print(CursorUp)
b.ResetCursor()
}
func (b *Buffer) cursorDown() {
fmt.Print(CursorDown)
b.ResetCursor()
}
func (b *Buffer) MoveUp() {
if b.PosY > 0 {
b.PosY--
if b.Buf[b.PosY].Size() < b.PosX {
b.PosX = b.Buf[b.PosY].Size()
}
b.cursorUp()
} else {
fmt.Print("\a")
}
}
func (b *Buffer) MoveDown() {
if b.PosY < len(b.Buf)-1 {
b.PosY++
if b.Buf[b.PosY].Size() < b.PosX {
b.PosX = b.Buf[b.PosY].Size()
}
b.cursorDown()
} else {
fmt.Print("\a")
}
}
func (b *Buffer) MoveLeft() {
if b.PosX > 0 {
b.PosX--
fmt.Print(CursorLeft)
} else if b.PosY > 0 {
b.PosX = b.Buf[b.PosY-1].Size()
b.PosY--
b.cursorUp()
} else if b.PosX == 0 && b.PosY == 0 {
fmt.Print("\a")
}
}
func (b *Buffer) MoveRight() {
if b.PosX < b.Buf[b.PosY].Size() {
b.PosX++
fmt.Print(CursorRight)
} else if b.PosY < len(b.Buf)-1 {
b.PosY++
b.PosX = 0
b.cursorDown()
} else {
fmt.Print("\a")
}
}
func (b *Buffer) MoveToBOL() {
if b.PosX > 0 {
b.PosX = 0
b.ResetCursor()
}
}
func (b *Buffer) MoveToEOL() {
if b.PosX < b.Buf[b.PosY].Size() {
b.PosX = b.Buf[b.PosY].Size()
b.ResetCursor()
}
}
func (b *Buffer) MoveToEnd() {
fmt.Print(CursorHide)
yDiff := len(b.Buf)-1 - b.PosY
if yDiff > 0 {
fmt.Print(cursorDownN(yDiff))
}
b.PosY = len(b.Buf)-1
b.MoveToEOL()
fmt.Print(CursorShow)
}
func cursorLeftN(n int) string {
return fmt.Sprintf(CursorLeftN, n)
}
func cursorRightN(n int) string {
return fmt.Sprintf(CursorRightN, n)
}
func cursorUpN(n int) string {
return fmt.Sprintf(CursorUpN, n)
}
func cursorDownN(n int) string {
return fmt.Sprintf(CursorDownN, n)
}
func (b *Buffer) ClearScreen() {
fmt.Printf(CursorHide + ClearScreen + CursorReset + b.Prompt.Prompt)
if b.IsEmpty() {
ph := b.Prompt.Placeholder
fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
} else {
currPosX := b.PosX
currPosY := b.PosY
b.PosX = 0
b.PosY = 0
b.drawRemaining()
b.PosX = currPosX
b.PosY = currPosY
fmt.Print(CursorReset + cursorRightN(len(b.Prompt.Prompt)))
if b.PosY > 0 {
fmt.Print(cursorDownN(b.PosY))
}
if b.PosX > 0 {
fmt.Print(cursorRightN(b.PosX))
}
}
fmt.Print(CursorShow)
}
func (b *Buffer) IsEmpty() bool {
return len(b.Buf) == 1 && b.Buf[0].Empty()
}

View File

@@ -1,10 +0,0 @@
# Bash Shell examples
When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:
`ollama run llama2 < sourcequestions.txt`
This concept is used in the following example.
## Compare Models
`comparemodels.sh` is a script that runs all the questions in `sourcequestions.txt` using any 4 models of your choice that you have already pulled from the Ollama library or created locally.

View File

@@ -1,64 +0,0 @@
#! /usr/bin/env bash
# Compare multiple models by running them with the same questions
NUMBEROFCHOICES=4
SELECTIONS=()
declare -a SUMS=()
# Get the list of models
CHOICES=$(ollama list | awk '{print $1}')
# Select which models to run as a comparison
echo "Select $NUMBEROFCHOICES models to compare:"
select ITEM in $CHOICES; do
if [[ -n $ITEM ]]; then
echo "You have selected $ITEM"
SELECTIONS+=("$ITEM")
((COUNT++))
if [[ $COUNT -eq $NUMBEROFCHOICES ]]; then
break
fi
else
echo "Invalid selection"
fi
done
# Loop through each of the selected models
for ITEM in "${SELECTIONS[@]}"; do
echo "--------------------------------------------------------------"
echo "Loading the model $ITEM into memory"
ollama run "$ITEM" ""
echo "--------------------------------------------------------------"
echo "Running the questions through the model $ITEM"
COMMAND_OUTPUT=$(ollama run "$ITEM" --verbose < sourcequestions.txt 2>&1| tee /dev/stderr)
# eval duration is sometimes listed in seconds and sometimes in milliseconds.
# Add up the values for each model
SUM=$(echo "$COMMAND_OUTPUT" | awk '
/eval duration:/ {
value = $3
if (index(value, "ms") > 0) {
gsub("ms", "", value)
value /= 1000
} else {
gsub("s", "", value)
}
sum += value
}
END { print sum }')
SUMS+=("All questions for $ITEM completed in $SUM seconds")
done
echo ""
echo "--------------------------------------------------------------"
echo -e "Sums of eval durations for each run:"
for val in "${SUMS[@]}"; do
echo "$val"
done
echo "--------------------------------------------------------------"
echo "Comparison complete. Now you can decide"
echo "which model is best."
echo "--------------------------------------------------------------"

View File

@@ -1,7 +0,0 @@
Why is the sky blue
What is a black hole
Explain the big bang theory like I am 5?
What is the quickest way to win a game of Monopoly with 3 others?
Why does a vacuum bottle keep my coffee hot and my milkshake cold?
What is the difference between a meteor, a meteorite, and a meteoroid?
Create an array with 5 items and print to the console. Do this in Python, C#, Typescript, and Rust.

View File

@@ -1,36 +0,0 @@
# Deploy Ollama to Kubernetes
## Prerequisites
- Ollama: https://ollama.ai/download
- Kubernetes cluster. This example will use Google Kubernetes Engine.
## Steps
1. Create the Ollama namespace, daemon set, and service
```bash
kubectl apply -f cpu.yaml
```
1. Port forward the Ollama service to connect and use it locally
```bash
kubectl -n ollama port-forward service/ollama 11434:80
```
1. Pull and run a model, for example `orca-mini:3b`
```bash
ollama run orca-mini:3b
```
## (Optional) Hardware Acceleration
Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin). Follow the link for more details.
Once configured, create a GPU enabled Ollama deployment.
```bash
kubectl apply -f gpu.yaml
```

View File

@@ -1,42 +0,0 @@
---
apiVersion: v1
kind: Namespace
metadata:
name: ollama
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: ollama
namespace: ollama
spec:
selector:
matchLabels:
name: ollama
template:
metadata:
labels:
name: ollama
spec:
containers:
- name: ollama
image: ollama/ollama:latest
ports:
- name: http
containerPort: 11434
protocol: TCP
---
apiVersion: v1
kind: Service
metadata:
name: ollama
namespace: ollama
spec:
type: ClusterIP
selector:
name: ollama
ports:
- port: 80
name: http
targetPort: http
protocol: TCP

View File

@@ -1,56 +0,0 @@
---
apiVersion: v1
kind: Namespace
metadata:
name: ollama
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: ollama
namespace: ollama
spec:
strategy:
type: Recreate
selector:
matchLabels:
name: ollama
template:
metadata:
labels:
name: ollama
spec:
containers:
- name: ollama
image: ollama/ollama:latest
env:
- name: PATH
value: /usr/local/nvidia/bin:/usr/local/nvidia/lib64:/usr/bin:/usr/sbin:/bin:/sbin
- name: LD_LIBRARY_PATH
value: /usr/local/nvidia/lib64
ports:
- name: http
containerPort: 11434
protocol: TCP
resources:
limits:
nvidia.com/gpu: 1
tolerations:
- key: nvidia.com/gpu
operator: Exists
effect: NoSchedule
---
apiVersion: v1
kind: Service
metadata:
name: ollama
namespace: ollama
spec:
type: ClusterIP
selector:
name: ollama
ports:
- port: 80
name: http
targetPort: http
protocol: TCP

View File

@@ -0,0 +1,2 @@
node_modules
artcollection

View File

@@ -0,0 +1,73 @@
import { Chroma } from "langchain/vectorstores/chroma";
import { ChromaTranslator } from "langchain/retrievers/self_query/chroma";
import { Ollama } from "langchain/llms/ollama"
import { AttributeInfo } from "langchain/schema/query_constructor";
import { HuggingFaceTransformersEmbeddings } from "langchain/embeddings/hf_transformers";
import { SelfQueryRetriever } from "langchain/retrievers/self_query";
const modelName = "codellama";
// Define the attributes of the schema so that the model will know what to look for
const attributeInfo: AttributeInfo[] = [
{
name: "title",
type: "string",
description: "The title of the painting"
},
{
name: "date",
type: "integer",
description: "The four digit year when the painting was created"
},
{
name: "artistName",
type: "string",
description: "The first name and last name of the artist who created the painting. Always use the full name in the filter, even if it isn't included. If the query is 'van Gogh', the filter should be 'Vincent van Gogh'. Use Pierre-Auguste Renoir instead of just Renoir."
}
]
// Define the model used to generate embeddings, these capture the context of the input data
const embeddings = new HuggingFaceTransformersEmbeddings({
modelName: "Xenova/all-MiniLM-L6-v2",
});
// Run the model using Ollama
const llm = new Ollama({
model: modelName
})
const documentContents = "Description of the art";
const findArt = async () => {
// Load the saved vector store
const vectorStore = await Chroma.fromExistingCollection(embeddings, {
collectionName: "artcollection",
});
const retriever = SelfQueryRetriever.fromLLM({
llm, vectorStore, documentContents, attributeInfo, verbose: false, useOriginalQuery: true, structuredQueryTranslator: new ChromaTranslator()
});
// Get the query from the command line
const query = process.argv[2];
try {
const newquery = await retriever.getRelevantDocuments(query, [
// You can add callbacks to the retriever to get information about the process. In this case, show the output
// query from the LLM used to retrieve the documents
{
handleLLMEnd(output) {
console.log("This is the output from the LLM after it has come up with a filter")
const llmEndOutput = output.generations[0][0].text.replace(/\\"/gm, "'").replace(/\n/gm, "")
console.log(`output - ${JSON.stringify(llmEndOutput, null, 2)}`)
}
},
]);
console.log(newquery);
} catch (error) {
console.log(`There was an error getting the values: ${error}`);
}
}
findArt();

View File

@@ -0,0 +1,128 @@
import { Artwork, RawArtwork } from './types';
import { HuggingFaceTransformersEmbeddings } from 'langchain/embeddings/hf_transformers';
import { Chroma } from "langchain/vectorstores/chroma";
import { Document } from "langchain/document";
import { ChromaClient } from "chromadb";
const numberOfArtworks = 10;
// list of artists we are going to pull from the API
const artists = ["van Gogh", "Renoir", "Monet", "Picasso"]
const generateSource = async () => {
// Delete the existing vector store so that we don't get duplicate documents
await new ChromaClient().deleteCollection({
name: "artcollection",
});
const allartworkdocs = await getArt(artists);
// Create the vector store
const vectorStore = await Chroma.fromDocuments(allartworkdocs, embedding, { collectionName: "artcollection" });
console.log(`Created vector store with ${await vectorStore.collection?.count()} documents`);
}
const getArt = async (artists: string[]) => {
const artworks: Artwork[] = [];
const artistsWorkIds: number[] = []
for (const artist of artists) {
// First get the ids of the works by each artist
const thisIds = await fetchArtistWorkIds(artist);
console.log(`Fetching ${artist}`);
await (new Promise(r => setTimeout(r, 1000)));
artistsWorkIds.push(...thisIds);
};
// now get the actual artwork
const artwork = await fetchArtwork(artistsWorkIds);
return artwork
}
const fetchArtistWorkIds = async (artist: string): Promise<number[]> => {
const artistURL = `https://api.artic.edu/api/v1/artworks/search?q=${artist}&limit=${numberOfArtworks}`;
const response = await fetch(artistURL);
const json = await response.json();
const artistWorks: { id: number }[] = json.data;
return artistWorks.map((work) => work.id);
}
const embedding = new HuggingFaceTransformersEmbeddings({
modelName: "Xenova/all-MiniLM-L6-v2",
});
//Turns out there are some weird characters in the descriptions
const sanitize = (badstring: string): string => {
let goodstring = " ";
if (badstring !== null) {
goodstring = badstring
.replace(/<\s*a\s+[^>]*href\s*=\s*[\"']?([^\"' >]+)[\"' >]>/gm, "")
.replace(/<\/a>/gm, "")
.replace(/<\/?em>/gm, "")
.replace(/[\u2018\u2019]/gm, "")
.replace(/[\u201C\u201D]/gm, "")
.replace(/[\u2013\u2014]/gm, "-")
.replace(/[\u2026]/gm, "...")
.replace(/[\u00A0]/gm, " ")
.replace(/[\u00AD]/gm, "-")
.replace(/[\u00B0]/gm, " degrees ")
.replace(/[\u00B1]/gm, " plus or minus ")
.replace(/[\u00B2]/gm, " squared ")
.replace(/[\u00B3]/gm, " cubed ")
.replace(/[\u00B4]/gm, "'")
.replace(/[\u00B5]/gm, " micro ")
.replace(/[\u00B6]/gm, " paragraph ")
.replace(/[\u00B7]/gm, " dot ")
.replace(/[\u00B8]/gm, ",")
.replace(/[\u00B9]/gm, " first ")
.replace(/[\u00BA]/gm, " degrees ")
.replace(/[\u00BB]/gm, ">>")
.replace(/[\u00BC]/gm, " 1/4 ")
.replace(/[\u00BD]/gm, " 1/2 ")
.replace(/[\uFB01]/gm, "fi")
.replace(/[\uFB02]/gm, "fl")
.replace(/[\uFB03]/gm, "ffi")
.replace(/[\uFB04]/gm, "ffl")
.replace(/[\uFB05]/gm, "ft")
.replace(/[\uFB06\uFB07\uFB08]/gm, "st")
.replace(/[\u00D7]/gm, "x")
.replace(/[\u00E8\u00E9]/gm, "e")
.replace(/[\u00F1]/gm, "n")
.replace(/[\u00F6]/gm, "o")
.replace(/[\u00F8]/gm, "o")
.replace(/[\u00FC]/gm, "u")
.replace(/[\u00FF]/gm, "y")
.replace(/[\u0101\u0103\u00E0]/gm, "a")
.replace(/[\u00C9]/gm, "E")
.replace(/<p>/gm, "")
.replace(/<\/p>/gm, "")
.replace(/\n/gm, "");
};
return goodstring;
}
const fetchArtwork = async (workids: number[]) => {
const docsarray = [];
const artworks: Artwork[] = [];
for await (const workid of workids) {
const artworkURL = `https://api.artic.edu/api/v1/artworks/${workid}`;
const response = await fetch(artworkURL);
const json = await response.json();
const artworkraw: RawArtwork = await json.data as RawArtwork;
const description = sanitize(artworkraw.description)
if (description !== " ") {
const doc = new Document({
pageContent: description,
metadata: {
title: sanitize(artworkraw.title),
date: artworkraw.date_end,
artistName: artworkraw.artist_title,
}
});
docsarray.push(doc);
console.log("------------------")
console.log(`${artworkraw.title} - ${artworkraw.artist_title}`);
}
}
return docsarray;
}
generateSource();

View File

File diff suppressed because it is too large

View File

@@ -0,0 +1,20 @@
{
"name": "typescript-selfqueryingretreival",
"version": "1.0.0",
"description": "",
"main": "index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"keywords": [],
"author": "",
"license": "ISC",
"dependencies": {
"@xenova/transformers": "^2.7.0",
"chromadb": "^1.5.11",
"langchain": "^0.0.177",
"ollama-node": "^0.1.24",
"peggy": "^3.0.2",
"sharp": "^0.32.6"
}
}

View File

@@ -0,0 +1,111 @@
# Self Query Retrieval
Filtering your vector database results to get better answers from your LLM.
![sqr 2023-11-05 14_30_50](https://github.com/jmorganca/ollama/assets/633681/55afb7f5-ebd8-4c58-86ba-284594fd1ec8)
## TLDR
1. Install and run ChromaDB
   1. Run `git clone https://github.com/chroma-core/chroma.git`
   2. `cd chroma`
   3. `docker-compose up -d --build`
2. Navigate to this example's directory
3. `npm install`
4. `tsx ./GenerateSource.ts`
5. `tsx ./FindArt.ts "are there any paintings from the artist Pablo Picasso"`
Other questions to try:
- Are there any paintings painted in 1881
- Are there any paintings painted by Vincent van Gogh
Note: If you haven't used `tsx`, it's a more modern alternative to `ts-node` and works especially well when you have libraries that use different module types. You can find it at [https://github.com/esbuild-kit/tsx](https://github.com/esbuild-kit/tsx).
## Introduction
Retrieval Augmented Generation (RAG) is what developers usually reach for when they want to ask questions of all their notes. But often it doesn't give the results you need. That's because there is still too much information, and frequently it's the wrong information. When you ask a question, RAG will retrieve a set of documents that it thinks are relevant to the question and then hand them off to the LLM. If you ask "what is a transformer", it may grab excerpts from the Transformers paper you read recently, along with sections of your Intro to Electronics book. Even if you ask a better question, such as "what is a transformer in the context of electrical engineering", it may still grab excerpts from the Transformers paper. And that's because the Transformers paper is a very good match for the question. It's just not the right match.
Ideally, the Transformers paper and the Electronics book would be added to the database with some metadata, such as the topics or keywords. But RAG typically doesn't look at those metadata fields. And that's where Self Query Retrieval comes in. It's a way to use traditional database queries to narrow down the set of documents that RAG will use and thus get better results.
## How it works
There are a few things you need to do to enable Self Query Retrieval. First, there needs to be additional metadata about your content in the database. The examples in the Langchain documentation are based on movies, and the metadata includes the year, the director's name, the genre, and so on. Then you need to pass that metadata schema along with the query so the retriever can build a filter that gets the right documents.
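To make this concrete, here is a hypothetical sketch (not actual LangChain output; the real shape depends on the query translator) of what the self-query step produces for one of the questions above: a semantic query plus a structured filter over the metadata fields.

```typescript
// Hypothetical illustration for "Are there any paintings painted in 1881
// by Vincent van Gogh": the LLM rewrites the question into a search string
// plus a metadata filter, and the vector search only considers documents
// that pass the filter.
const structuredQuery = {
  query: "paintings", // embedded and used for similarity search
  filter: {
    artistName: "Vincent van Gogh", // exact match on metadata
    date: 1881,
  },
};
```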
## The code
There are two main parts to the code. First there is a `GenerateSource.ts` file and then there is a `FindArt.ts` file. Let's look at GenerateSource first.
### GenerateSource
The purpose of GenerateSource is to create our data source. For this example, we are using the [Art Institute of Chicago API](https://api.artic.edu/docs/#introduction), which is incredible. The data will be loaded into a vector database, which for this example is ChromaDB.
This could be any CSV file or other data source you have access to. The file would have a single descriptive column and then metadata columns. All the relevant columns from our dataset are added to a Document object, and that array of Documents is then loaded into ChromaDB. Finally, at the end, I verify that the documents were created by outputting a count to the screen.
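For reference, each artwork becomes a `Document` whose page content is the sanitized description and whose metadata carries the filterable fields (this mirrors the `fetchArtwork` code shown earlier in this diff):

```typescript
const doc = new Document({
  pageContent: description, // the sanitized artwork description
  metadata: {
    title: sanitize(artworkraw.title),
    date: artworkraw.date_end,
    artistName: artworkraw.artist_title,
  },
});
```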
```typescript
await new ChromaClient().deleteCollection({
name: "artcollection",
});
const vectorStore = await Chroma.fromDocuments(allartworkdocs,
embedding, { collectionName: "artcollection" });
console.log(`Created vector store with
${await vectorStore.collection?.count()} documents`);
```
### FindArt
To actually find the art, we need to start by loading the database:
```typescript
const vectorStore = await Chroma.fromExistingCollection(embeddings, {
collectionName: "artcollection",
});
```
Now we can create our Self Query Retriever. It needs the LLM, the vector store, a description of the documents, a description of each attribute in the metadata, and finally a structured query translator, which takes the query generated by the LLM and turns it into something usable by the database.
```typescript
const llm = new Ollama({
model: modelName
})
const documentContents = "Description of the art";
const attributeInfo: AttributeInfo[] = [
{
name: "title",
type: "string",
description: "The title of the painting"
},
{
name: "date",
type: "integer",
description: "The four digit year when the painting was created"
},
{
name: "artistName",
type: "string",
description: "The first name and last name of the artist who created the painting. Always use the full name in the filter, even if it isn't included. If the query is 'van Gogh', the filter should be 'Vincent van Gogh'. Use Pierre-Auguste Renoir instead of just Renoir."
}
]
const retriever = SelfQueryRetriever.fromLLM({
llm, vectorStore, documentContents, attributeInfo, verbose: false, useOriginalQuery: true, structuredQueryTranslator: new ChromaTranslator()
});
```
Now we can ask a question and get the results:
```typescript
const newquery = await retriever.getRelevantDocuments(query)
```
## Next Steps
When you run this example, you will get a set of documents from the database that may be a bit more relevant to your question. Now you could feed those to the LLM and get the actual answer to the question based on these documents.
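As a rough sketch of that step, reusing the `llm`, `query`, and `newquery` values from above (a simple prompt-stuffing approach, not the only way to do it):

```typescript
// Stuff the retrieved documents into a single prompt and ask the model.
const context = newquery.map((doc) => doc.pageContent).join("\n\n");
const answer = await llm.call(
  `Using only this context:\n\n${context}\n\nAnswer this question: ${query}`
);
console.log(answer);
```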
To take this further, you could work on getting more out of the dataset. It turns out that this works best if there is only a single possible value for any given field. Our artists are often referred to by their last name, but sometimes using their full name. It may be Vincent van Gogh, or just van Gogh. Another way to get around this is to build a better query translator that knows that the search could be for a substring of the full name. But that also requires looking into the metadata searching capabilities of the database.
Maybe it makes more sense to move the artist name and title of the work into the document itself. Then add some more metadata (there are at least 100 other attributes in the raw API that aren't used in this example.)
Also try different models. In testing so far, `codellama` seems to produce the most reliably usable filters. It's not perfect and can still create a filter that won't find anything. When a new code model comes out, try it to see if it performs better.
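Swapping models is a one-line change where the `Ollama` instance is created in `FindArt.ts` (the example reads the name from a `modelName` variable; here it's hardcoded for illustration):

```typescript
const llm = new Ollama({
  model: "codellama", // try other models you have pulled, e.g. "llama2"
});
```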

View File

@@ -0,0 +1,10 @@
{
"compilerOptions": {
"target": "es2016",
"module": "commonjs", /* Specify what module code is generated. */
"esModuleInterop": true, /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */
"forceConsistentCasingInFileNames": true, /* Ensure that casing is correct in imports. */
"strict": true, /* Enable all strict type-checking options. */
"skipLibCheck": true /* Skip type checking all .d.ts files. */
}
}

View File

@@ -0,0 +1,26 @@
export type RawArtwork = {
id: number;
title: string;
artist_display: string;
place_of_origin: string;
date_start: number;
date_end: number;
duration: number;
dimensions: string;
medium_display: string;
credit_line: string;
artwork_type_title: string;
department_title: string;
artist_title: string;
classification_title: string;
description: string;
}
export type Artwork = {
id: number;
title: string;
country: string;
date: number;
artist: string;
description: string;
}

View File

@@ -17,7 +17,7 @@ def generate(prompt, context):
for line in r.iter_lines():
body = json.loads(line)
response_part = body.get('response', '')
# the response streams one token at a time, print that as we receive it
# the response streams one token at a time, print that as we recieve it
print(response_part, end='', flush=True)
if 'error' in body:
@@ -35,4 +35,4 @@ def main():
print()
if __name__ == "__main__":
main()
main()

View File

@@ -1,25 +0,0 @@
package format
import (
"fmt"
"math"
)
const (
Thousand = 1000
Million = Thousand * 1000
Billion = Million * 1000
)
func HumanNumber(b uint64) string {
switch {
case b > Billion:
return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion))
case b > Million:
return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million))
case b > Thousand:
return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand))
default:
return fmt.Sprintf("%d", b)
}
}

go.mod
View File

@@ -11,6 +11,7 @@ require (
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
golang.org/x/sync v0.3.0
gonum.org/v1/gonum v0.14.0
)
require github.com/rivo/uniseg v0.2.0 // indirect

go.sum
View File

@@ -140,6 +140,8 @@ golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

View File

@@ -5,8 +5,6 @@ import (
"encoding/binary"
"fmt"
"io"
"github.com/jmorganca/ollama/format"
)
type containerGGUF struct {
@@ -23,8 +21,6 @@ type containerGGUF struct {
NumTensor uint64
NumKV uint64
}
parameters uint64
}
func (c *containerGGUF) Name() string {
@@ -79,14 +75,6 @@ func newGGUFModel(container *containerGGUF) *ggufModel {
}
}
func (llm *ggufModel) NumTensor() uint64 {
if llm.Version == 1 {
return uint64(llm.V1.NumTensor)
}
return llm.V2.NumTensor
}
func (llm *ggufModel) NumKV() uint64 {
if llm.Version == 1 {
return uint64(llm.V1.NumKV)
@@ -105,10 +93,6 @@ func (llm *ggufModel) ModelFamily() string {
}
func (llm *ggufModel) ModelType() string {
if llm.parameters > 0 {
return format.HumanNumber(llm.parameters)
}
switch llm.ModelFamily() {
case "llama":
if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
@@ -143,9 +127,13 @@ func (llm *ggufModel) FileType() string {
}
func (llm *ggufModel) Decode(r io.Reader) error {
// decode key-values
read := llm.readString
if llm.Version == 1 {
read = llm.readStringV1
}
for i := 0; uint64(i) < llm.NumKV(); i++ {
k, err := llm.readString(r)
k, err := read(r)
if err != nil {
return err
}
@@ -177,14 +165,24 @@ func (llm *ggufModel) Decode(r io.Reader) error {
case ggufTypeBool:
v = llm.readBool(r)
case ggufTypeString:
s, err := llm.readString(r)
fn := llm.readString
if llm.Version == 1 {
fn = llm.readStringV1
}
s, err := fn(r)
if err != nil {
return err
}
v = s
case ggufTypeArray:
a, err := llm.readArray(r)
fn := llm.readArray
if llm.Version == 1 {
fn = llm.readArrayV1
}
a, err := fn(r)
if err != nil {
return err
}
@@ -197,25 +195,6 @@ func (llm *ggufModel) Decode(r io.Reader) error {
llm.kv[k] = v
}
// decode tensors
for i := 0; uint64(i) < llm.NumTensor(); i++ {
if _, err := llm.readString(r); err != nil {
return err
}
dimensions := llm.readU32(r)
var elements uint64 = 1
for i := 0; uint32(i) < dimensions; i++ {
elements *= llm.readU64(r)
}
llm.readU32(r) // type
llm.readU64(r) // offset
llm.parameters += elements
}
return nil
}
@@ -311,10 +290,6 @@ func (llm ggufModel) readStringV1(r io.Reader) (string, error) {
}
func (llm ggufModel) readString(r io.Reader) (string, error) {
if llm.Version == 1 {
return llm.readStringV1(r)
}
var nameLength uint64
binary.Read(r, llm.bo, &nameLength)
@@ -364,10 +339,6 @@ func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) {
}
func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) {
if llm.Version == 1 {
return llm.readArrayV1(r)
}
atype := llm.readU32(r)
n := llm.readU64(r)

View File

@@ -27,34 +27,6 @@ import (
"github.com/jmorganca/ollama/format"
)
const jsonGrammar = `
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS
@@ -224,10 +196,7 @@ type llama struct {
Running
}
var (
errNvidiaSMI = errors.New("nvidia-smi command failed")
errAvailableVRAM = errors.New("not enough VRAM available, falling back to CPU only")
)
var errNoGPU = errors.New("nvidia-smi command failed")
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
@@ -236,7 +205,7 @@ func CheckVRAM() (int64, error) {
cmd.Stdout = &stdout
err := cmd.Run()
if err != nil {
return 0, errNvidiaSMI
return 0, errNoGPU
}
var freeMiB int64
@@ -257,8 +226,8 @@ func CheckVRAM() (int64, error) {
freeBytes := freeMiB * 1024 * 1024
if freeBytes < 2*format.GigaByte {
log.Printf("less than 2 GB VRAM available")
return 0, errAvailableVRAM
log.Printf("less than 2 GB VRAM available, falling back to CPU only")
freeMiB = 0
}
return freeBytes, nil
@@ -271,7 +240,7 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
if runtime.GOOS == "linux" {
freeBytes, err := CheckVRAM()
if err != nil {
if !errors.Is(err, errNvidiaSMI) {
if err.Error() != "nvidia-smi command failed" {
log.Print(err.Error())
}
// nvidia driver not installed or no nvidia GPU found
@@ -337,19 +306,13 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
params := []string{
"--model", model,
"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
"--n-gpu-layers", fmt.Sprintf("%d", numGPU),
"--embedding",
}
if opts.RopeFrequencyBase > 0 {
params = append(params, "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase))
}
if opts.RopeFrequencyScale > 0 {
params = append(params, "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale))
}
if opts.NumGQA > 0 {
params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
}
@@ -397,15 +360,7 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
runner.Path,
append(params, "--port", strconv.Itoa(port))...,
)
var libraryPaths []string
if libraryPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, libraryPath)
}
libraryPaths = append(libraryPaths, filepath.Dir(runner.Path))
cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", strings.Join(libraryPaths, ":")))
cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", filepath.Dir(runner.Path)))
cmd.Stdout = os.Stderr
statusWriter := NewStatusWriter()
cmd.Stderr = statusWriter
@@ -525,7 +480,7 @@ type prediction struct {
const maxBufferSize = 512 * format.KiloByte
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext)
if err != nil {
return err
@@ -560,10 +515,6 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
"stop": llm.Stop,
}
if format == "json" {
request["grammar"] = jsonGrammar
}
// Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)

View File

@@ -14,7 +14,7 @@ import (
)
type LLM interface {
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
Predict(context.Context, []int, string, func(api.GenerateResponse)) error
Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error)
@@ -85,10 +85,7 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
switch ggml.Name() {
case "gguf":
// TODO: gguf will load these options automatically from the model binary
opts.NumGQA = 0
opts.RopeFrequencyBase = 0.0
opts.RopeFrequencyScale = 0.0
opts.NumGQA = 0 // TODO: remove this when llama.cpp runners differ enough to need separate newLlama functions
return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
case "ggml", "ggmf", "ggjt", "ggla":
return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)

View File

@@ -291,7 +291,7 @@ func OptionShowDescriptionAtLineEnd() Option {
}
}
var defaultTheme = Theme{Saucer: "█", SaucerPadding: " ", BarStart: "", BarEnd: ""}
var defaultTheme = Theme{Saucer: "█", SaucerPadding: " ", BarStart: "|", BarEnd: "|"}
// NewOptions constructs a new instance of ProgressBar, with any options you specify
func NewOptions(max int, options ...Option) *ProgressBar {

readline/buffer.go Normal file
View File

@@ -0,0 +1,372 @@
package readline
import (
"fmt"
"os"
"github.com/emirpasic/gods/lists/arraylist"
"golang.org/x/term"
)
type Buffer struct {
Pos int
Buf *arraylist.List
Prompt *Prompt
LineWidth int
Width int
Height int
}
func NewBuffer(prompt *Prompt) (*Buffer, error) {
fd := int(os.Stdout.Fd())
width, height, err := term.GetSize(fd)
if err != nil {
fmt.Println("Error getting size:", err)
return nil, err
}
lwidth := width - len(prompt.Prompt)
if prompt.UseAlt {
lwidth = width - len(prompt.AltPrompt)
}
b := &Buffer{
Pos: 0,
Buf: arraylist.New(),
Prompt: prompt,
Width: width,
Height: height,
LineWidth: lwidth,
}
return b, nil
}
func (b *Buffer) MoveLeft() {
if b.Pos > 0 {
if b.Pos%b.LineWidth == 0 {
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
} else {
fmt.Print(CursorLeft)
}
b.Pos -= 1
}
}
func (b *Buffer) MoveLeftWord() {
if b.Pos > 0 {
var foundNonspace bool
for {
v, _ := b.Buf.Get(b.Pos - 1)
if v == ' ' {
if foundNonspace {
break
}
} else {
foundNonspace = true
}
b.MoveLeft()
if b.Pos == 0 {
break
}
}
}
}
func (b *Buffer) MoveRight() {
if b.Pos < b.Size() {
b.Pos += 1
if b.Pos%b.LineWidth == 0 {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(b.PromptSize()))
} else {
fmt.Print(CursorRight)
}
}
}
func (b *Buffer) MoveRightWord() {
if b.Pos < b.Size() {
for {
b.MoveRight()
v, _ := b.Buf.Get(b.Pos)
if v == ' ' {
break
}
if b.Pos == b.Size() {
break
}
}
}
}
func (b *Buffer) MoveToStart() {
if b.Pos > 0 {
currLine := b.Pos / b.LineWidth
if currLine > 0 {
for cnt := 0; cnt < currLine; cnt++ {
fmt.Print(CursorUp)
}
}
fmt.Printf(CursorBOL + cursorRightN(b.PromptSize()))
b.Pos = 0
}
}
func (b *Buffer) MoveToEnd() {
if b.Pos < b.Size() {
currLine := b.Pos / b.LineWidth
totalLines := b.Size() / b.LineWidth
if currLine < totalLines {
for cnt := 0; cnt < totalLines-currLine; cnt++ {
fmt.Print(CursorDown)
}
remainder := b.Size() % b.LineWidth
fmt.Printf(CursorBOL + cursorRightN(b.PromptSize()+remainder))
} else {
fmt.Print(cursorRightN(b.Size() - b.Pos))
}
b.Pos = b.Size()
}
}
func (b *Buffer) Size() int {
return b.Buf.Size()
}
func min(n, m int) int {
if n > m {
return m
}
return n
}
func (b *Buffer) PromptSize() int {
if b.Prompt.UseAlt {
return len(b.Prompt.AltPrompt)
}
return len(b.Prompt.Prompt)
}
func (b *Buffer) Add(r rune) {
if b.Pos == b.Buf.Size() {
fmt.Printf("%c", r)
b.Buf.Add(r)
b.Pos += 1
if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
fmt.Printf("\n%s", b.Prompt.AltPrompt)
}
} else {
fmt.Printf("%c", r)
b.Buf.Insert(b.Pos, r)
b.Pos += 1
if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
fmt.Printf("\n%s", b.Prompt.AltPrompt)
}
b.drawRemaining()
}
}
func (b *Buffer) drawRemaining() {
var place int
remainingText := b.StringN(b.Pos)
if b.Pos > 0 {
place = b.Pos % b.LineWidth
}
fmt.Print(CursorHide)
// render the rest of the current line
currLine := remainingText[:min(b.LineWidth-place, len(remainingText))]
if len(currLine) > 0 {
fmt.Printf(ClearToEOL + currLine)
fmt.Print(cursorLeftN(len(currLine)))
} else {
fmt.Print(ClearToEOL)
}
// render the other lines
if len(remainingText) > len(currLine) {
remaining := []rune(remainingText[len(currLine):])
var totalLines int
for i, c := range remaining {
if i%b.LineWidth == 0 {
fmt.Printf("\n%s", b.Prompt.AltPrompt)
totalLines += 1
}
fmt.Printf("%c", c)
}
fmt.Print(ClearToEOL)
fmt.Print(cursorUpN(totalLines))
fmt.Printf(CursorBOL + cursorRightN(b.Width-len(currLine)))
}
fmt.Print(CursorShow)
}
func (b *Buffer) Remove() {
if b.Buf.Size() > 0 && b.Pos > 0 {
if b.Pos%b.LineWidth == 0 {
// if the user backspaces over the word boundary, do this magic to clear the line
// and move to the end of the previous line
fmt.Printf(CursorBOL + ClearToEOL)
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width) + " " + CursorLeft)
} else {
fmt.Printf(CursorLeft + " " + CursorLeft)
}
var eraseExtraLine bool
if (b.Size()-1)%b.LineWidth == 0 {
eraseExtraLine = true
}
b.Pos -= 1
b.Buf.Remove(b.Pos)
if b.Pos < b.Size() {
b.drawRemaining()
// this erases a line which is left over when backspacing in the middle of a line and there
// are trailing characters which go over the line width boundary
if eraseExtraLine {
remainingLines := (b.Size() - b.Pos) / b.LineWidth
fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
place := b.Pos % b.LineWidth
fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.Prompt)))
}
}
}
}
func (b *Buffer) Delete() {
if b.Size() > 0 && b.Pos < b.Size() {
b.Buf.Remove(b.Pos)
b.drawRemaining()
if b.Size()%b.LineWidth == 0 {
if b.Pos != b.Size() {
remainingLines := (b.Size() - b.Pos) / b.LineWidth
fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
place := b.Pos % b.LineWidth
fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.Prompt)))
}
}
}
}
func (b *Buffer) DeleteBefore() {
if b.Pos > 0 {
for cnt := b.Pos - 1; cnt >= 0; cnt-- {
b.Remove()
}
}
}
func (b *Buffer) DeleteRemaining() {
if b.Size() > 0 && b.Pos < b.Size() {
charsToDel := b.Size() - b.Pos
for cnt := 0; cnt < charsToDel; cnt++ {
b.Delete()
}
}
}
func (b *Buffer) DeleteWord() {
if b.Buf.Size() > 0 && b.Pos > 0 {
var foundNonspace bool
for {
v, _ := b.Buf.Get(b.Pos - 1)
if v == ' ' {
if !foundNonspace {
b.Remove()
} else {
break
}
} else {
foundNonspace = true
b.Remove()
}
if b.Pos == 0 {
break
}
}
}
}
func (b *Buffer) ClearScreen() {
fmt.Printf(ClearScreen + CursorReset + b.Prompt.Prompt)
if b.IsEmpty() {
ph := b.Prompt.Placeholder
fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
} else {
currPos := b.Pos
b.Pos = 0
b.drawRemaining()
fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.Prompt)))
if currPos > 0 {
targetLine := currPos / b.LineWidth
if targetLine > 0 {
for cnt := 0; cnt < targetLine; cnt++ {
fmt.Print(CursorDown)
}
}
remainder := currPos % b.LineWidth
if remainder > 0 {
fmt.Print(cursorRightN(remainder))
}
if currPos%b.LineWidth == 0 {
fmt.Printf(CursorBOL + b.Prompt.AltPrompt)
}
}
b.Pos = currPos
}
}
func (b *Buffer) IsEmpty() bool {
return b.Buf.Empty()
}
func (b *Buffer) Replace(r []rune) {
b.Pos = 0
b.Buf.Clear()
fmt.Printf(ClearLine + CursorBOL + b.Prompt.Prompt)
for _, c := range r {
b.Add(c)
}
}
func (b *Buffer) String() string {
return b.StringN(0)
}
func (b *Buffer) StringN(n int) string {
return b.StringNM(n, 0)
}
func (b *Buffer) StringNM(n, m int) string {
var s string
if m == 0 {
m = b.Size()
}
for cnt := n; cnt < m; cnt++ {
c, _ := b.Buf.Get(cnt)
s += string(c.(rune))
}
return s
}
func cursorLeftN(n int) string {
return fmt.Sprintf(CursorLeftN, n)
}
func cursorRightN(n int) string {
return fmt.Sprintf(CursorRightN, n)
}
func cursorUpN(n int) string {
return fmt.Sprintf(CursorUpN, n)
}
func cursorDownN(n int) string {
return fmt.Sprintf(CursorDownN, n)
}

View File

@@ -1,4 +1,4 @@
package editor
package readline
import (
"errors"

readline/history.go Normal file
View File

@@ -0,0 +1,152 @@
package readline
import (
"bufio"
"errors"
"io"
"os"
"path/filepath"
"strings"
"github.com/emirpasic/gods/lists/arraylist"
)
type History struct {
Buf *arraylist.List
Autosave bool
Pos int
Limit int
Filename string
Enabled bool
}
func NewHistory() (*History, error) {
h := &History{
Buf: arraylist.New(),
Limit: 100, //resizeme
Autosave: true,
Enabled: true,
}
err := h.Init()
if err != nil {
return nil, err
}
return h, nil
}
func (h *History) Init() error {
home, err := os.UserHomeDir()
if err != nil {
return err
}
path := filepath.Join(home, ".ollama", "history")
h.Filename = path
//todo check if the file exists
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0600)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
defer f.Close()
r := bufio.NewReader(f)
for {
line, err := r.ReadString('\n')
if err != nil {
if err == io.EOF {
break
}
return err
}
line = strings.TrimSpace(line)
if len(line) == 0 {
continue
}
h.Add([]rune(line))
}
return nil
}
func (h *History) Add(l []rune) {
h.Buf.Add(l)
h.Compact()
h.Pos = h.Size()
if h.Autosave {
h.Save()
}
}
func (h *History) Compact() {
s := h.Buf.Size()
if s > h.Limit {
for cnt := 0; cnt < s-h.Limit; cnt++ {
h.Buf.Remove(0)
}
}
}
func (h *History) Clear() {
h.Buf.Clear()
}
func (h *History) Prev() []rune {
var line []rune
if h.Pos > 0 {
h.Pos -= 1
}
v, _ := h.Buf.Get(h.Pos)
line, _ = v.([]rune)
return line
}
func (h *History) Next() []rune {
var line []rune
if h.Pos < h.Buf.Size() {
h.Pos += 1
v, _ := h.Buf.Get(h.Pos)
line, _ = v.([]rune)
}
return line
}
func (h *History) Size() int {
return h.Buf.Size()
}
func (h *History) Save() error {
if !h.Enabled {
return nil
}
tmpFile := h.Filename + ".tmp"
f, err := os.OpenFile(tmpFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC|os.O_APPEND, 0666)
if err != nil {
return err
}
defer f.Close()
buf := bufio.NewWriter(f)
for cnt := 0; cnt < h.Size(); cnt++ {
v, _ := h.Buf.Get(cnt)
line, _ := v.([]rune)
buf.WriteString(string(line) + "\n")
}
buf.Flush()
f.Close()
if err = os.Rename(tmpFile, h.Filename); err != nil {
return err
}
return nil
}

View File

@@ -1,4 +1,4 @@
package editor
package readline
import (
"bufio"
@@ -23,6 +23,7 @@ type Terminal struct {
type Instance struct {
Prompt *Prompt
Terminal *Terminal
History *History
}
func New(prompt Prompt) (*Instance, error) {
@@ -31,33 +32,40 @@ func New(prompt Prompt) (*Instance, error) {
return nil, err
}
history, err := NewHistory()
if err != nil {
return nil, err
}
return &Instance{
Prompt: &prompt,
Terminal: term,
History: history,
}, nil
}
func (i *Instance) HandleInput() (string, error) {
func (i *Instance) Readline() (string, error) {
prompt := i.Prompt.Prompt
if i.Prompt.UseAlt {
prompt = i.Prompt.AltPrompt
}
fmt.Print(prompt)
termios, err := SetRawMode(syscall.Stdin)
fd := int(syscall.Stdin)
termios, err := SetRawMode(fd)
if err != nil {
return "", err
}
defer UnsetRawMode(syscall.Stdin, termios)
defer UnsetRawMode(fd, termios)
buf, _ := NewBuffer(i.Prompt)
var esc bool
var escex bool
var metaDel bool
var pasteMode PasteMode
fmt.Print(StartBracketedPaste)
defer fmt.Printf(EndBracketedPaste)
var currentLineBuf []rune
for {
if buf.IsEmpty() {
@@ -69,22 +77,33 @@ func (i *Instance) HandleInput() (string, error) {
}
r, err := i.Terminal.Read()
if err != nil {
return "", io.EOF
}
if buf.IsEmpty() {
fmt.Print(ClearToEOL)
}
if err != nil {
return "", io.EOF
}
if escex {
escex = false
switch r {
case KeyUp:
buf.MoveUp()
if i.History.Pos > 0 {
if i.History.Pos == i.History.Size() {
currentLineBuf = []rune(buf.String())
}
buf.Replace(i.History.Prev())
}
case KeyDown:
buf.MoveDown()
if i.History.Pos < i.History.Size() {
buf.Replace(i.History.Next())
if i.History.Pos == i.History.Size() {
buf.Replace(currentLineBuf)
}
}
case KeyLeft:
buf.MoveLeft()
case KeyRight:
@@ -104,16 +123,28 @@ func (i *Instance) HandleInput() (string, error) {
} else if code == CharBracketedPasteEnd {
pasteMode = PasteModeEnd
}
case KeyDel:
if buf.Size() > 0 {
buf.Delete()
}
metaDel = true
case MetaStart:
buf.MoveToBOL()
buf.MoveToStart()
case MetaEnd:
buf.MoveToEOL()
buf.MoveToEnd()
default:
// skip any keys we don't know about
continue
}
continue
} else if esc {
esc = false
switch r {
case 'b':
buf.MoveLeftWord()
case 'f':
buf.MoveRightWord()
case CharEscapeEx:
escex = true
}
@@ -128,9 +159,9 @@ func (i *Instance) HandleInput() (string, error) {
case CharInterrupt:
return "", ErrInterrupt
case CharLineStart:
buf.MoveToBOL()
buf.MoveToStart()
case CharLineEnd:
buf.MoveToEOL()
buf.MoveToEnd()
case CharBackward:
buf.MoveLeft()
case CharForward:
@@ -138,38 +169,56 @@ func (i *Instance) HandleInput() (string, error) {
case CharBackspace, CharCtrlH:
buf.Remove()
case CharTab:
// todo: convert back to real tabs
for cnt := 0; cnt < 8; cnt++ {
buf.Add(' ')
}
case CharDelete:
if len(buf.Buf) > 0 && buf.Buf[0].Size() > 0 {
if buf.Size() > 0 {
buf.Delete()
} else {
return "", io.EOF
}
case CharKill:
buf.DeleteRemaining()
case CharCtrlU:
buf.RemoveBefore()
buf.DeleteBefore()
case CharCtrlL:
buf.ClearScreen()
case CharCtrlW:
buf.RemoveWordBefore()
case CharCtrlJ:
buf.Add(r)
buf.DeleteWord()
case CharEnter:
if pasteMode == PasteModeStart {
buf.Add(r)
continue
output := buf.String()
if output != "" {
i.History.Add([]rune(output))
}
buf.MoveToEnd()
fmt.Println()
return buf.String(), nil
switch pasteMode {
case PasteModeStart:
output = `"""` + output
case PasteModeEnd:
output = output + `"""`
}
return output, nil
default:
if metaDel {
metaDel = false
continue
}
if r >= CharSpace || r == CharEnter {
buf.Add(r)
}
}
}
}
func (i *Instance) HistoryEnable() {
i.History.Enabled = true
}
func (i *Instance) HistoryDisable() {
i.History.Enabled = false
}
func NewTerminal() (*Terminal, error) {

View File

@@ -1,6 +1,6 @@
//go:build aix || darwin || dragonfly || freebsd || (linux && !appengine) || netbsd || openbsd || os400 || solaris
package editor
package readline
import (
"syscall"

View File

@@ -1,5 +1,5 @@
//go:build darwin || freebsd || netbsd || openbsd
package editor
package readline
import (
"syscall"

View File

@@ -1,5 +1,5 @@
//go:build linux || solaris
package editor
package readline
import (
"syscall"

View File

@@ -1,4 +1,4 @@
package editor
package readline
const (
CharNull = 0

View File

@@ -63,10 +63,7 @@ status "Installing ollama to $BINDIR..."
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama
install_success() {
status 'The Ollama API is now available at 0.0.0.0:11434.'
status 'Install complete. Run "ollama" from the command line.'
}
install_success() { status 'Install complete. Run "ollama" from the command line.'; }
trap install_success EXIT
# Everything from this point onwards is optional.
@@ -133,7 +130,6 @@ if check_gpu nvidia-smi; then
fi
if ! check_gpu lspci && ! check_gpu lshw; then
install_success
warning "No NVIDIA GPU detected. Ollama will run in CPU-only mode."
exit 0
fi
@@ -180,7 +176,7 @@ install_cuda_driver_apt() {
case $1 in
debian)
status 'Enabling contrib sources...'
$SUDO sed 's/main/contrib/' < /etc/apt/sources.list | $SUDO tee /etc/apt/sources.list.d/contrib.list > /dev/null
$SUDO sed 's/main/contrib/' < /etc/apt/sources.list | sudo tee /etc/apt/sources.list.d/contrib.list > /dev/null
;;
esac

View File

@@ -91,7 +91,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
}
s := SignatureData{
Method: http.MethodGet,
Method: "GET",
Path: redirectURL.String(),
Data: nil,
}
@@ -103,7 +103,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
headers := make(http.Header)
headers.Set("Authorization", sig)
resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, nil)
if err != nil {
log.Printf("couldn't get token: %q", err)
return "", err

View File

@@ -89,12 +89,17 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
}
if len(b.Parts) == 0 {
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
}
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
var size = b.Total / numDownloadParts
@@ -129,6 +134,7 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
@@ -149,10 +155,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
i := i
g.Go(func() error {
var err error
for try := 0; try < maxRetries; try++ {
w := io.NewOffsetWriter(file, part.StartsAt())
err = b.downloadChunk(inner, requestURL, w, part, opts)
err := b.downloadChunk(inner, requestURL, w, part, opts)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space
@@ -161,14 +166,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err)
continue
default:
if try > 0 {
log.Printf("%s part %d completed after %d retries", b.Digest[7:19], i, try)
}
return nil
}
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
return errors.New("max retries exceeded")
})
}
@@ -198,14 +200,14 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
if err != nil {
return err
}
defer resp.Body.Close()
n, err := io.Copy(w, io.TeeReader(resp.Body, b))
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
if err != nil && !errors.Is(err, context.Canceled) {
// rollback progress
b.Completed.Add(-n)
return err
@@ -216,7 +218,7 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
return err
}
// return nil or context.Canceled or UnexpectedEOF (resumable)
// return nil or context.Canceled
return err
}
@@ -306,8 +308,6 @@ type downloadOpts struct {
const maxRetries = 3
var errMaxRetriesExceeded = errors.New("max retries exceeded")
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error {
fp, err := GetBlobsPath(opts.digest)

View File

@@ -63,11 +63,15 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
First bool
System string
Prompt string
// deprecated: versions <= 0.0.7 used this to omit the system prompt
Context []int
}
vars.First = len(request.Context) == 0
vars.System = m.System
vars.Prompt = request.Prompt
vars.Context = request.Context
if request.System != "" {
vars.System = request.System
@@ -397,7 +401,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
if err != nil {
return err
}
newLayer.From = mp.GetShortTagname()
newLayer.From = mp.GetNamespaceRepository()
layers = append(layers, newLayer)
}
}
@@ -977,7 +981,46 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
layers = append(layers, &manifest.Config)
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
if err != nil {
return err
}
if exists {
fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
})
log.Printf("Layer %s already exists", layer.Digest)
continue
}
fn(api.ProgressResponse{
Status: "starting upload",
Digest: layer.Digest,
Total: layer.Size,
})
location, chunkSize, err := startUpload(ctx, mp, layer, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return err
}
if strings.HasPrefix(filepath.Base(location.Path), "sha256:") {
layer.Digest = filepath.Base(location.Path)
fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
})
continue
}
if err := uploadBlob(ctx, location, layer, chunkSize, regOpts, fn); err != nil {
log.Printf("error uploading blob: %v", err)
return err
}
@@ -994,7 +1037,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil {
return err
}
@@ -1116,12 +1159,22 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't get manifest: %v", err)
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("model not found")
}
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
}
var m *ManifestV2
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
@@ -1165,7 +1218,24 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
// Function to check if a blob already exists in the Docker registry
func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't check for blob: %v", err)
return false, err
}
defer resp.Body.Close()
// Check for success: If the blob exists, the Docker registry will respond with a 200 OK
return resp.StatusCode < http.StatusBadRequest, nil
}
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
var status string
for try := 0; try < maxRetries; try++ {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
@@ -1173,6 +1243,8 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
return nil, err
}
status = resp.Status
switch {
case resp.StatusCode == http.StatusUnauthorized:
auth := resp.Header.Get("www-authenticate")
@@ -1184,25 +1256,21 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
regOpts.Token = token
if body != nil {
body.Seek(0, io.SeekStart)
if _, err := body.Seek(0, io.SeekStart); err != nil {
return nil, err
}
}
continue
case resp.StatusCode == http.StatusNotFound:
return nil, os.ErrNotExist
case resp.StatusCode >= http.StatusBadRequest:
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
}
return nil, fmt.Errorf("%d: %s", resp.StatusCode, body)
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
default:
return resp, nil
}
}
return nil, errMaxRetriesExceeded
return nil, fmt.Errorf("max retry exceeded: %v", status)
}
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {

View File

@@ -158,17 +158,9 @@ func GenerateHandler(c *gin.Context) {
return
}
// validate the request
switch {
case req.Model == "":
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
model, err := GetModel(req.Model)
@@ -197,13 +189,10 @@ func GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now()
prompt := req.Prompt
if !req.Raw {
prompt, err = model.Prompt(req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt, err := model.Prompt(req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
@@ -226,15 +215,10 @@ func GenerateHandler(c *gin.Context) {
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if req.Raw {
// in raw mode the client must manage history on their own
r.Context = nil
}
ch <- r
}
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -381,9 +365,7 @@ func PushModelHandler(c *gin.Context) {
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
ctx := context.Background()
if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
@@ -713,7 +695,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
if runtime.GOOS == "linux" {
// check compatibility to log warnings
if _, err := llm.CheckVRAM(); err != nil {
log.Printf("Warning: GPU support may not be enabled, check you have installed GPU drivers: %v", err)
log.Printf("Warning: GPU support may not enabled, check you have installed install GPU drivers: %v", err)
}
}

View File

@@ -2,369 +2,218 @@ package server
import (
"context"
"crypto/md5"
"errors"
"fmt"
"hash"
"io"
"log"
"net/http"
"net/url"
"os"
"strings"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format"
"golang.org/x/sync/errgroup"
)
var blobUploadManager sync.Map
type blobUpload struct {
*Layer
Total int64
Completed atomic.Int64
Parts []blobUploadPart
nextURL chan *url.URL
context.CancelFunc
done bool
err error
references atomic.Int32
}
const (
numUploadParts = 64
minUploadPartSize int64 = 95 * 1000 * 1000
maxUploadPartSize int64 = 1000 * 1000 * 1000
redirectChunkSize int64 = 1024 * 1024 * 1024
regularChunkSize int64 = 95 * 1024 * 1024
)
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
p, err := GetBlobsPath(b.Digest)
if err != nil {
return err
}
if b.From != "" {
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
if layer.From != "" {
values := requestURL.Query()
values.Add("mount", b.Digest)
values.Add("from", b.From)
values.Add("mount", layer.Digest)
values.Add("from", layer.From)
requestURL.RawQuery = values.Encode()
}
resp, err := makeRequestWithRetry(ctx, http.MethodPost, requestURL, nil, nil, opts)
resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
if err != nil {
return err
log.Printf("couldn't start upload: %v", err)
return nil, 0, err
}
defer resp.Body.Close()
location := resp.Header.Get("Docker-Upload-Location")
chunkSize := redirectChunkSize
if location == "" {
location = resp.Header.Get("Location")
chunkSize = regularChunkSize
}
fi, err := os.Stat(p)
locationURL, err := url.Parse(location)
if err != nil {
return err
return nil, 0, err
}
b.Total = fi.Size()
var size = b.Total / numUploadParts
switch {
case size < minUploadPartSize:
size = minUploadPartSize
case size > maxUploadPartSize:
size = maxUploadPartSize
}
var offset int64
for offset < fi.Size() {
if offset+size > fi.Size() {
size = fi.Size() - offset
}
// set part.N to the current number of parts
b.Parts = append(b.Parts, blobUploadPart{blobUpload: b, N: len(b.Parts), Offset: offset, Size: size})
offset += size
}
log.Printf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
requestURL, err = url.Parse(location)
if err != nil {
return err
}
b.nextURL = make(chan *url.URL, 1)
b.nextURL <- requestURL
return nil
return locationURL, chunkSize, nil
}
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
// TODO allow resumability
// TODO allow canceling uploads via DELETE
p, err := GetBlobsPath(b.Digest)
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
b.err = err
return
return err
}
f, err := os.Open(p)
f, err := os.Open(fp)
if err != nil {
b.err = err
return
return err
}
defer f.Close()
g, inner := errgroup.WithContext(ctx)
g.SetLimit(numUploadParts)
for i := range b.Parts {
part := &b.Parts[i]
select {
case <-inner.Done():
case requestURL := <-b.nextURL:
g.Go(func() error {
var err error
for try := 0; try < maxRetries; try++ {
part.ReadSeeker = io.NewSectionReader(f, part.Offset, part.Size)
err = b.uploadChunk(inner, http.MethodPatch, requestURL, part, opts)
switch {
case errors.Is(err, context.Canceled):
return err
case errors.Is(err, errMaxRetriesExceeded):
return err
case err != nil:
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
continue
}
pw := ProgressWriter{
status: fmt.Sprintf("uploading %s", layer.Digest),
digest: layer.Digest,
total: layer.Size,
fn: fn,
}
return nil
}
for offset := int64(0); offset < layer.Size; {
chunk := layer.Size - offset
if chunk > chunkSize {
chunk = chunkSize
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
if err != nil {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error uploading chunk: %v", err),
Digest: layer.Digest,
Total: layer.Size,
Completed: offset,
})
return err
}
offset += chunk
location := resp.Header.Get("Docker-Upload-Location")
if location == "" {
location = resp.Header.Get("Location")
}
requestURL, err = url.Parse(location)
if err != nil {
return err
}
}
if err := g.Wait(); err != nil {
b.err = err
return
}
requestURL := <-b.nextURL
var sb strings.Builder
for _, part := range b.Parts {
sb.Write(part.Sum(nil))
}
md5sum := md5.Sum([]byte(sb.String()))
values := requestURL.Query()
values.Add("digest", b.Digest)
values.Add("etag", fmt.Sprintf("%x-%d", md5sum, len(b.Parts)))
values.Add("digest", layer.Digest)
requestURL.RawQuery = values.Encode()
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", "0")
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
if err != nil {
b.err = err
return
}
defer resp.Body.Close()
b.done = true
}
func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
part.Reset()
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
headers.Set("X-Redirect-Uploads", "1")
if method == http.MethodPatch {
headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
}
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(part.ReadSeeker, io.MultiWriter(part, part.Hash)), opts)
// finish the upload
resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't finish upload: %v", err)
return err
}
defer resp.Body.Close()
location := resp.Header.Get("Docker-Upload-Location")
if location == "" {
location = resp.Header.Get("Location")
if resp.StatusCode >= http.StatusBadRequest {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
}
nextURL, err := url.Parse(location)
if err != nil {
return err
}
switch {
case resp.StatusCode == http.StatusTemporaryRedirect:
b.nextURL <- nextURL
redirectURL, err := resp.Location()
if err != nil {
return err
}
for try := 0; try < maxRetries; try++ {
err = b.uploadChunk(ctx, http.MethodPut, redirectURL, part, nil)
switch {
case errors.Is(err, context.Canceled):
return err
case errors.Is(err, errMaxRetriesExceeded):
return err
case err != nil:
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
continue
}
return nil
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
case resp.StatusCode == http.StatusUnauthorized:
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
if err != nil {
return err
}
opts.Token = token
fallthrough
case resp.StatusCode >= http.StatusBadRequest:
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
return fmt.Errorf("http status %d %s: %s", resp.StatusCode, resp.Status, body)
}
if method == http.MethodPatch {
b.nextURL <- nextURL
}
return nil
}
func (b *blobUpload) acquire() {
b.references.Add(1)
}
func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
sectionReader := io.NewSectionReader(r, offset, limit)
func (b *blobUpload) release() {
if b.references.Add(-1) == 0 {
b.CancelFunc()
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", strconv.Itoa(int(limit)))
headers.Set("X-Redirect-Uploads", "1")
if method == http.MethodPatch {
headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
}
}
func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
b.acquire()
defer b.release()
for try := 0; try < maxRetries; try++ {
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sectionReader, pw), opts)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}
defer resp.Body.Close()
ticker := time.NewTicker(60 * time.Millisecond)
for {
select {
case <-ticker.C:
case <-ctx.Done():
return ctx.Err()
switch {
case resp.StatusCode == http.StatusTemporaryRedirect:
location, err := resp.Location()
if err != nil {
return nil, err
}
pw.completed = offset
if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
// retry
log.Printf("retrying redirected upload: %v", err)
continue
}
return resp, nil
case resp.StatusCode == http.StatusUnauthorized:
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
if err != nil {
return nil, err
}
opts.Token = token
pw.completed = offset
sectionReader = io.NewSectionReader(r, offset, limit)
continue
case resp.StatusCode >= http.StatusBadRequest:
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
}
fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", b.Digest),
Digest: b.Digest,
Total: b.Total,
Completed: b.Completed.Load(),
return resp, nil
}
return nil, fmt.Errorf("max retries exceeded")
}
type ProgressWriter struct {
status string
digest string
bucket int64
completed int64
total int64
fn func(api.ProgressResponse)
mu sync.Mutex
}
func (pw *ProgressWriter) Write(b []byte) (int, error) {
pw.mu.Lock()
defer pw.mu.Unlock()
n := len(b)
pw.bucket += int64(n)
// throttle status updates to not spam the client
if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total {
pw.completed += pw.bucket
pw.fn(api.ProgressResponse{
Status: pw.status,
Digest: pw.digest,
Total: pw.total,
Completed: pw.completed,
})
if b.done || b.err != nil {
return b.err
}
pw.bucket = 0
}
}
type blobUploadPart struct {
// N is the part number
N int
Offset int64
Size int64
hash.Hash
written int64
io.ReadSeeker
*blobUpload
}
func (p *blobUploadPart) Write(b []byte) (n int, err error) {
n = len(b)
p.written += int64(n)
p.Completed.Add(int64(n))
return n, nil
}
func (p *blobUploadPart) Reset() {
p.Seek(0, io.SeekStart)
p.Completed.Add(-int64(p.written))
p.written = 0
p.Hash = md5.New()
}
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
switch {
case errors.Is(err, os.ErrNotExist):
case err != nil:
return err
default:
defer resp.Body.Close()
fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", layer.Digest),
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
})
return nil
}
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
upload := data.(*blobUpload)
if !ok {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
blobUploadManager.Delete(layer.Digest)
return err
}
go upload.Run(context.Background(), opts)
}
return upload.Wait(ctx, fn)
}