Mirror of https://github.com/ollama/ollama.git, synced 2025-12-28 18:18:02 -05:00

Compare commits: 8 commits on branch brucemacd/...progress-f
Commits in this comparison (SHA1):

- fcfbb06f1b
- e8d35d0de0
- e13e7c8d94
- 78f403ff45
- 08a299e1d0
- f9c7ead160
- 5930aaeb1a
- faf67db089
.gitignore (vendored), 3 lines changed

@@ -14,6 +14,3 @@ test_data
 __debug_bin*
 llama/build
 llama/vendor
-model/testdata/models/*
-!model/testdata/models/*.md
-!model/testdata/models/*.json
CMakeLists.txt

@@ -24,7 +24,7 @@ set(GGML_LLAMAFILE ON)
 set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
 set(GGML_CUDA_GRAPHS ON)
 
-if((NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
+if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
     OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
     set(GGML_CPU_ALL_VARIANTS ON)
 endif()
api/client.go

@@ -126,7 +126,8 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
            return err
        }
    }
-   return nil
+
+   return ctx.Err()
 }
 
 const maxBufferSize = 512 * format.KiloByte

@@ -189,7 +190,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
        }
    }
 
-   return nil
+   return ctx.Err()
 }
 
 // GenerateResponseFunc is a function that [Client.Generate] invokes every time
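Swapping `return nil` for `return ctx.Err()` leaves the success path unchanged, since `ctx.Err()` is nil while the context is still live, but lets a cancelled request surface `context.Canceled` to the caller. A minimal standalone sketch of that standard-library behavior (not ollama code, just an illustration with a hypothetical `finish` helper):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// finish mimics the tail of a request helper that reports the context
// state instead of returning a bare nil.
func finish(ctx context.Context) error {
	// nil while the context is live; context.Canceled (or
	// context.DeadlineExceeded) once it is done.
	return ctx.Err()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	fmt.Println(finish(ctx)) // <nil>

	cancel()
	fmt.Println(errors.Is(finish(ctx), context.Canceled)) // true
}
```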
cmd/cmd.go (38 lines changed)

@@ -15,13 +15,11 @@ import (
    "net"
    "net/http"
    "os"
-   "os/signal"
    "path/filepath"
    "runtime"
    "strconv"
    "strings"
    "sync/atomic"
-   "syscall"
    "time"
 
    "github.com/containerd/console"

@@ -330,6 +328,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
        if err := PullHandler(cmd, []string{name}); err != nil {
            return nil, err
        }
+
        return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
    }
    return info, err

@@ -858,17 +857,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
    spinner := progress.NewSpinner("")
    p.Add("", spinner)
 
-   cancelCtx, cancel := context.WithCancel(cmd.Context())
-   defer cancel()
-
-   sigChan := make(chan os.Signal, 1)
-   signal.Notify(sigChan, syscall.SIGINT)
-
-   go func() {
-       <-sigChan
-       cancel()
-   }()
-
    var state *displayResponseState = &displayResponseState{}
    var latest api.ChatResponse
    var fullResponse strings.Builder

@@ -903,10 +891,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
        req.KeepAlive = opts.KeepAlive
    }
 
-   if err := client.Chat(cancelCtx, req, fn); err != nil {
-       if errors.Is(err, context.Canceled) {
-           return nil, nil
-       }
+   if err := client.Chat(cmd.Context(), req, fn); err != nil {
        return nil, err
    }
 

@@ -946,17 +931,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
        generateContext = []int{}
    }
 
-   ctx, cancel := context.WithCancel(cmd.Context())
-   defer cancel()
-
-   sigChan := make(chan os.Signal, 1)
-   signal.Notify(sigChan, syscall.SIGINT)
-
-   go func() {
-       <-sigChan
-       cancel()
-   }()
-
    var state *displayResponseState = &displayResponseState{}
 
    fn := func(response api.GenerateResponse) error {

@@ -992,10 +966,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
        KeepAlive: opts.KeepAlive,
    }
 
-   if err := client.Generate(ctx, &request, fn); err != nil {
-       if errors.Is(err, context.Canceled) {
-           return nil
-       }
+   if err := client.Generate(cmd.Context(), &request, fn); err != nil {
        return err
    }
 

@@ -1017,8 +988,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
        latest.Summary()
    }
 
-   ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
-   cmd.SetContext(ctx)
+   cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
 
    return nil
 }
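With the per-command SIGINT plumbing removed, `chat` and `generate` simply propagate whatever error `client.Chat` or `client.Generate` returns, and the caller decides whether a cancelled context counts as a clean exit. A hedged sketch of that caller-side check; `runChat` here is a hypothetical stand-in, not a function from cmd/cmd.go:

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// runChat stands in for a helper that propagates the client's error,
// including cancellation, instead of swallowing it locally.
func runChat(ctx context.Context) error {
	<-ctx.Done()
	return ctx.Err()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // simulate Ctrl-C having cancelled the root context

	// The caller, not each helper, treats cancellation as a clean exit.
	if err := runChat(ctx); err != nil && !errors.Is(err, context.Canceled) {
		fmt.Println("error:", err)
		return
	}
	fmt.Println("exited cleanly")
}
```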
main.go (14 lines changed)

@@ -2,6 +2,8 @@ package main
 
 import (
    "context"
+   "os"
+   "os/signal"
 
    "github.com/spf13/cobra"
 

@@ -9,5 +11,15 @@ import (
 )
 
 func main() {
-   cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
+   ctx, cancel := context.WithCancel(context.Background())
+   defer cancel()
+
+   sigChan := make(chan os.Signal, 1)
+   signal.Notify(sigChan, os.Interrupt)
+   go func() {
+       <-sigChan
+       cancel()
+   }()
+
+   cobra.CheckErr(cmd.NewCLI().ExecuteContext(ctx))
 }
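The new `main` wires Ctrl-C to context cancellation by hand with a channel and a goroutine. The standard library also packages the same pattern as `signal.NotifyContext`; a small runnable sketch of that shorthand, independent of the ollama CLI:

```go
package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
)

func main() {
	// signal.NotifyContext is the stdlib shorthand for the
	// channel-plus-goroutine pattern in the diff above: the returned
	// context is cancelled on the first os.Interrupt.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
	defer stop()

	fmt.Println("press Ctrl-C to cancel")
	<-ctx.Done()
	fmt.Println("context cancelled:", ctx.Err())
}
```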
model/model_external_test.go, deleted (entire file removed)

@@ -1,138 +0,0 @@
// Package model_test provides external tests for the model package.
// This test file specifically tests the forward pass functionality on models.
// It is in a separate package (model_test) to avoid import cycles while still
// being able to test the public API of the model package.
package model_test

import (
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"
    "strings"
    "testing"

    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model"
    "github.com/ollama/ollama/sample"

    _ "github.com/ollama/ollama/model/models"
)

type modelTest struct {
    Prompt            string   `json:"prompt"`
    OutputContainsOne []string `json:"output_contains_one"`
}

func TestForwardSimple(t *testing.T) {
    if testing.Short() {
        t.Skip("skipping in short mode")
    }

    // Read all JSON files from testdata/models
    files, err := os.ReadDir("testdata/models")
    if err != nil {
        t.Fatal(err)
    }

    for _, file := range files {
        if !strings.HasSuffix(file.Name(), ".json") {
            continue
        }

        jsonPath := filepath.Join("testdata/models", file.Name())
        ggufPath := filepath.Join("testdata/models", strings.TrimSuffix(file.Name(), ".json")+".gguf")

        // Skip if no corresponding .gguf file exists
        if _, err := os.Stat(ggufPath); err != nil {
            t.Logf("skipping %s: no corresponding GGUF file found", file.Name())
            continue
        }

        data, err := os.ReadFile(jsonPath)
        if err != nil {
            t.Fatal(err)
        }

        var test modelTest
        if err := json.Unmarshal(data, &test); err != nil {
            t.Fatal(err)
        }

        t.Run(strings.TrimSuffix(file.Name(), ".json"), func(t *testing.T) {
            m, err := model.New(ggufPath)
            if err != nil {
                t.Fatal(err)
            }

            m.Config().Cache.Init(m.Backend(), ml.DTypeF32, 2048)

            inputs, err := m.(model.TextProcessor).Encode(test.Prompt)
            if err != nil {
                t.Fatal(err)
            }

            var result []string
            for len(result) < 100 { // Limit to 100 tokens max
                options := model.Options{
                    Inputs:    inputs,
                    Positions: make([]int32, len(inputs)),
                    Sequences: make([]int, len(inputs)),
                    Outputs:   []int32{int32(len(inputs) - 1)},
                }
                for i := range options.Positions {
                    options.Positions[i] = int32(i)
                    options.Sequences[i] = 0
                }

                ctx := m.Backend().NewContext()

                modelOutput, err := model.Forward(ctx, m, options)
                if err != nil {
                    ctx.Close()
                    t.Fatal(fmt.Errorf("forward pass failed: %v", err))
                }

                f32s := modelOutput.Floats()
                logits := make([]float64, len(f32s))
                for i, f32 := range f32s {
                    logits[i] = float64(f32)
                }

                token, err := sample.Sample(logits, sample.Greedy())
                if err != nil {
                    ctx.Close()
                    t.Fatal(fmt.Errorf("sampling failed: %v", err))
                }

                ctx.Close()

                // Greedy sampling: take the token with the highest logit
                nextToken := int32(token[0])
                if m.(model.TextProcessor).Is(nextToken, model.SpecialEOS) {
                    break
                }

                piece, err := m.(model.TextProcessor).Decode([]int32{nextToken})
                if err != nil {
                    t.Fatal(err)
                }

                result = append(result, piece)
                output := strings.Join(result, "")

                for _, expectedOutput := range test.OutputContainsOne {
                    if strings.Contains(output, expectedOutput) {
                        t.Logf("Test passed with output: %q (matched expected: %q)", output, expectedOutput)
                        return
                    }
                }

                // Maintain full context by appending new token
                inputs = append(inputs, nextToken)
            }

            t.Fatalf("Expected output containing one of %q but got: %q", test.OutputContainsOne, strings.Join(result, ""))
        })
    }
}
model/testdata/models/README.md (vendored), deleted (10 lines removed)

@@ -1,10 +0,0 @@
# Test Model Directory

This directory is used for storing model files (like `.gguf` files) that are required to run the tests in `model_external_test.go`.

## Usage

- Place any model files you need for testing in this directory
- The test file will look for any model files here (e.g., `llama3.gguf`)
- All non-markdown files in this directory are git-ignored to prevent large model files from being committed to the repository
- Only `.md` files (like this README) will be tracked in git
model/testdata/models/qwen2_5.json (vendored), deleted (7 lines removed)

@@ -1,7 +0,0 @@
{
  "prompt": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n",
  "output_contains_one": [
    "Hello",
    "Hi"
  ]
}
progress/progress.go

@@ -1,6 +1,7 @@
 package progress
 
 import (
+   "bufio"
    "fmt"
    "io"
    "sync"

@@ -13,7 +14,8 @@ type State interface {
 
 type Progress struct {
    mu sync.Mutex
-   w  io.Writer
+   // buffer output to minimize flickering on all terminals
+   w *bufio.Writer
 
    pos int
 

@@ -22,7 +24,7 @@ type Progress struct {
 }
 
 func NewProgress(w io.Writer) *Progress {
-   p := &Progress{w: w}
+   p := &Progress{w: bufio.NewWriter(w)}
    go p.start()
    return p
 }

@@ -47,26 +49,29 @@ func (p *Progress) stop() bool {
 func (p *Progress) Stop() bool {
    stopped := p.stop()
    if stopped {
-       fmt.Fprint(p.w, "\n")
+       fmt.Fprintln(p.w)
    }
+
+   // show cursor
+   fmt.Fprint(p.w, "\033[?25h")
+   p.w.Flush()
    return stopped
 }
 
 func (p *Progress) StopAndClear() bool {
-   fmt.Fprint(p.w, "\033[?25l")
-   defer fmt.Fprint(p.w, "\033[?25h")
-
    stopped := p.stop()
    if stopped {
        // clear all progress lines
-       for i := range p.pos {
-           if i > 0 {
-               fmt.Fprint(p.w, "\033[A")
-           }
-           fmt.Fprint(p.w, "\033[2K\033[1G")
+       for range p.pos - 1 {
+           fmt.Fprint(p.w, "\033[A")
        }
+
+       fmt.Fprint(p.w, "\033[2K", "\033[1G")
    }
 
+   // show cursor
+   fmt.Fprint(p.w, "\033[?25h")
+   p.w.Flush()
    return stopped
 }

@@ -81,30 +86,31 @@ func (p *Progress) render() {
    p.mu.Lock()
    defer p.mu.Unlock()
 
-   fmt.Fprint(p.w, "\033[?25l")
-   defer fmt.Fprint(p.w, "\033[?25h")
+   fmt.Fprint(p.w, "\033[?2026h")
+   defer fmt.Fprint(p.w, "\033[?2026l")
 
    // clear already rendered progress lines
-   for i := range p.pos {
-       if i > 0 {
-           fmt.Fprint(p.w, "\033[A")
-       }
-       fmt.Fprint(p.w, "\033[2K\033[1G")
+   for range p.pos - 1 {
+       fmt.Fprint(p.w, "\033[A")
    }
 
+   fmt.Fprint(p.w, "\033[1G")
+
    // render progress lines
    for i, state := range p.states {
-       fmt.Fprint(p.w, state.String())
+       fmt.Fprint(p.w, state.String(), "\033[K")
        if i < len(p.states)-1 {
            fmt.Fprint(p.w, "\n")
        }
    }
 
    p.pos = len(p.states)
+   p.w.Flush()
 }
 
 func (p *Progress) start() {
    p.ticker = time.NewTicker(100 * time.Millisecond)
+   // hide cursor
+   fmt.Fprint(p.w, "\033[?25l")
    for range p.ticker.C {
        p.render()
    }
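Taken together, the progress changes batch each repaint in a `bufio.Writer`, wrap it in the terminal's synchronized-output sequences (`ESC[?2026h` / `ESC[?2026l`), and overwrite lines with `ESC[K` instead of blanking and redrawing them, which is what removes the flicker. A standalone sketch of those techniques on a single status line (a demo of the same escape sequences, not the repository's `Progress` type):

```go
package main

import (
	"bufio"
	"fmt"
	"os"
	"time"
)

// redraw overwrites one status line in place. The escape codes and text
// are batched in a bufio.Writer and wrapped in "synchronized output"
// sequences so supporting terminals repaint the frame atomically.
func redraw(w *bufio.Writer, line string) {
	fmt.Fprint(w, "\033[?2026h")  // begin synchronized update
	fmt.Fprint(w, "\033[1G")      // move to column 1
	fmt.Fprint(w, line, "\033[K") // overwrite, then erase to end of line
	fmt.Fprint(w, "\033[?2026l")  // end synchronized update
	w.Flush()                     // emit the whole frame at once
}

func main() {
	w := bufio.NewWriter(os.Stdout)
	for i := 0; i <= 5; i++ {
		redraw(w, fmt.Sprintf("step %d/5", i))
		time.Sleep(200 * time.Millisecond)
	}
	fmt.Fprintln(w)
	w.Flush()
}
```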