Compare commits

..

5 Commits

Author SHA1 Message Date
Roy Han
e210f8763f merge conflicts 2024-07-12 15:09:05 -07:00
royjhan
3971c2333f Merge branch 'main' into royh-precision 2024-07-12 15:07:36 -07:00
Roy Han
c71698426c Separate Rounding Functions 2024-06-24 11:09:08 -07:00
Roy Han
f93cdfdfae Standardize with ollama.com 2024-06-24 10:53:15 -07:00
Roy Han
af370ac178 Parameter Precision 2024-06-20 10:38:31 -07:00
10 changed files with 105 additions and 105 deletions

View File

@@ -293,7 +293,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
### Terminal

View File

@@ -657,7 +657,7 @@ func showInfo(resp *api.ShowResponse) {
modelData := [][]string{
{"arch", arch},
{"parameters", resp.Details.ParameterSize},
{"parameters", format.Parameters(uint64(resp.ModelInfo["general.parameter_count"].(float64)))},
{"quantization", resp.Details.QuantizationLevel},
{"context length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))},
{"embedding length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64))},
@@ -671,7 +671,7 @@ func showInfo(resp *api.ShowResponse) {
if resp.ProjectorInfo != nil {
projectorData := [][]string{
{"arch", "clip"},
{"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))},
{"parameters", format.Parameters(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))},
}
if projectorType, ok := resp.ProjectorInfo["clip.projector_type"]; ok {

View File

@@ -9,9 +9,10 @@ const (
Thousand = 1000
Million = Thousand * 1000
Billion = Million * 1000
Trillion = Billion * 1000
)
func HumanNumber(b uint64) string {
func RoundedParameter(b uint64) string {
switch {
case b >= Billion:
number := float64(b) / Billion
@@ -31,3 +32,33 @@ func HumanNumber(b uint64) string {
return fmt.Sprintf("%d", b)
}
}
func Parameters(b uint64) string {
switch {
case b >= Trillion:
number := float64(b) / Trillion
return fmt.Sprintf("%sT", decimalPlace(number))
case b >= Billion:
number := float64(b) / Billion
return fmt.Sprintf("%sB", decimalPlace(number))
case b >= Million:
number := float64(b) / Million
return fmt.Sprintf("%sM", decimalPlace(number))
case b >= Thousand:
number := float64(b) / Thousand
return fmt.Sprintf("%sK", decimalPlace(number))
default:
return fmt.Sprintf("%d", b)
}
}
func decimalPlace(number float64) string {
switch {
case number >= 100:
return fmt.Sprintf("%.0f", number)
case number >= 10:
return fmt.Sprintf("%.1f", number)
default:
return fmt.Sprintf("%.2f", number)
}
}

View File

@@ -4,7 +4,7 @@ import (
"testing"
)
func TestHumanNumber(t *testing.T) {
func TestRoundedParameter(t *testing.T) {
type testCase struct {
input uint64
expected string
@@ -24,7 +24,34 @@ func TestHumanNumber(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.expected, func(t *testing.T) {
result := HumanNumber(tc.input)
result := RoundedParameter(tc.input)
if result != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, result)
}
})
}
}
func TestParameters(t *testing.T) {
type testCase struct {
input uint64
expected string
}
testCases := []testCase{
{26000000, "26.0M"},
{26000000000, "26.0B"},
{1000, "1.00K"},
{1000000, "1.00M"},
{1000000000, "1.00B"},
{1000000000000, "1.00T"},
{100, "100"},
{206000000, "206M"},
}
for _, tc := range testCases {
t.Run(tc.expected, func(t *testing.T) {
result := Parameters(tc.input)
if result != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, result)
}

View File

@@ -127,7 +127,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
// On linux, over-allocating CPU memory will almost always result in an error
if runtime.GOOS == "linux" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := systemFreeMemory + systemSwapFreeMemory
available := min(systemTotalMemory, systemFreeMemory+systemSwapFreeMemory)
if systemMemoryRequired > available {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))

View File

@@ -466,7 +466,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
if baseLayer.GGML != nil {
config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
config.ModelType = cmp.Or(config.ModelType, format.RoundedParameter(baseLayer.GGML.KV().ParameterCount()))
config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"log/slog"
"slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
@@ -16,18 +17,26 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// pull out any system messages which should always be included in the prompt
var system []api.Message
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
system = append(system, m)
return true
}
return false
})
if len(system) == 0 && m.System != "" {
// add model system prompt since it wasn't provided
system = append(system, api.Message{Role: "system", Content: m.System})
}
// always include the last message
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
return "", nil, err

View File

@@ -6,7 +6,6 @@ import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
@@ -165,19 +164,6 @@ func TestChatPrompt(t *testing.T) {
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
{
name: "out of order system",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
},
}
tmpl, err := template.Parse(`
@@ -201,10 +187,6 @@ func TestChatPrompt(t *testing.T) {
t.Errorf("expected %q, got %q", tt.prompt, prompt)
}
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if len(images) != len(tt.images) {
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
}

View File

@@ -102,7 +102,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -130,8 +129,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
checkpointLoaded := time.Now()
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
@@ -194,48 +191,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
ch := make(chan any)
go func() {
// TODO (jmorganca): avoid building the response twice both here and below
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
}, func(r llm.CompletionResponse) {
ch <- api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
DoneReason: cr.DoneReason,
Response: r.Content,
Done: r.Done,
DoneReason: r.DoneReason,
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration,
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
}
if cr.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Context = append(req.Context, tokens...)
}
}
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
@@ -1147,8 +1122,6 @@ func (s *Server) ProcessHandler(c *gin.Context) {
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -1168,8 +1141,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
checkpointLoaded := time.Now()
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
@@ -1198,7 +1169,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
ch <- api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
@@ -1211,13 +1182,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
if r.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
}

View File

@@ -149,19 +149,27 @@ type Values struct {
}
func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages)
system, collated := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
"Messages": collated,
})
}
system = ""
var b bytes.Buffer
var prompt, response string
for _, m := range messages {
execute := func () error {
for i, m := range collated {
switch m.Role {
case "system":
system = m.Content
case "user":
prompt = m.Content
case "assistant":
response = m.Content
}
if i != len(collated)-1 && prompt != "" && response != "" {
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
@@ -173,26 +181,6 @@ func (t *Template) Execute(w io.Writer, v Values) error {
system = ""
prompt = ""
response = ""
return nil
}
switch m.Role {
case "system":
if prompt != "" || response != "" {
if err := execute(); err != nil {
return err
}
}
system = m.Content
case "user":
if response != "" {
if err := execute(); err != nil {
return err
}
}
prompt = m.Content
case "assistant":
response = m.Content
}
}
@@ -211,7 +199,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"System": "",
"Prompt": prompt,
}); err != nil {
return err