Compare commits

..

1 Commits

Author SHA1 Message Date
Bruce MacDonald
31d04eb795 model: add a test for model forward pass during implementation
Adds a new test file to verify model forward pass behavior through
JSON-specified test cases. The framework loads model files (.gguf) and their
corresponding test specifications to validate expected outputs using greedy
sampling.
2025-02-18 14:21:10 -08:00
9 changed files with 216 additions and 47 deletions

3
.gitignore vendored
View File

@@ -14,3 +14,6 @@ test_data
__debug_bin*
llama/build
llama/vendor
model/testdata/models/*
!model/testdata/models/*.md
!model/testdata/models/*.json

View File

@@ -24,7 +24,7 @@ set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
if((NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
set(GGML_CPU_ALL_VARIANTS ON)
endif()

View File

@@ -126,8 +126,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return err
}
}
return ctx.Err()
return nil
}
const maxBufferSize = 512 * format.KiloByte
@@ -190,7 +189,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
}
}
return ctx.Err()
return nil
}
// GenerateResponseFunc is a function that [Client.Generate] invokes every time

View File

@@ -15,11 +15,13 @@ import (
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync/atomic"
"syscall"
"time"
"github.com/containerd/console"
@@ -328,7 +330,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err
}
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
}
return info, err
@@ -857,6 +858,17 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
spinner := progress.NewSpinner("")
p.Add("", spinner)
cancelCtx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
var latest api.ChatResponse
var fullResponse strings.Builder
@@ -891,7 +903,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cmd.Context(), req, fn); err != nil {
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
return nil, err
}
@@ -931,6 +946,17 @@ func generate(cmd *cobra.Command, opts runOptions) error {
generateContext = []int{}
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
fn := func(response api.GenerateResponse) error {
@@ -966,7 +992,10 @@ func generate(cmd *cobra.Command, opts runOptions) error {
KeepAlive: opts.KeepAlive,
}
if err := client.Generate(cmd.Context(), &request, fn); err != nil {
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
return err
}
@@ -988,7 +1017,8 @@ func generate(cmd *cobra.Command, opts runOptions) error {
latest.Summary()
}
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
return nil
}

14
main.go
View File

@@ -2,8 +2,6 @@ package main
import (
"context"
"os"
"os/signal"
"github.com/spf13/cobra"
@@ -11,15 +9,5 @@ import (
)
func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
go func() {
<-sigChan
cancel()
}()
cobra.CheckErr(cmd.NewCLI().ExecuteContext(ctx))
cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
}

View File

@@ -0,0 +1,138 @@
// Package model_test provides external tests for the model package.
// This test file specifically tests the forward pass functionality on models.
// It is in a separate package (model_test) to avoid import cycles while still
// being able to test the public API of the model package.
package model_test
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/sample"
_ "github.com/ollama/ollama/model/models"
)
type modelTest struct {
Prompt string `json:"prompt"`
OutputContainsOne []string `json:"output_contains_one"`
}
func TestForwardSimple(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
// Read all JSON files from testdata/models
files, err := os.ReadDir("testdata/models")
if err != nil {
t.Fatal(err)
}
for _, file := range files {
if !strings.HasSuffix(file.Name(), ".json") {
continue
}
jsonPath := filepath.Join("testdata/models", file.Name())
ggufPath := filepath.Join("testdata/models", strings.TrimSuffix(file.Name(), ".json")+".gguf")
// Skip if no corresponding .gguf file exists
if _, err := os.Stat(ggufPath); err != nil {
t.Logf("skipping %s: no corresponding GGUF file found", file.Name())
continue
}
data, err := os.ReadFile(jsonPath)
if err != nil {
t.Fatal(err)
}
var test modelTest
if err := json.Unmarshal(data, &test); err != nil {
t.Fatal(err)
}
t.Run(strings.TrimSuffix(file.Name(), ".json"), func(t *testing.T) {
m, err := model.New(ggufPath)
if err != nil {
t.Fatal(err)
}
m.Config().Cache.Init(m.Backend(), ml.DTypeF32, 2048)
inputs, err := m.(model.TextProcessor).Encode(test.Prompt)
if err != nil {
t.Fatal(err)
}
var result []string
for len(result) < 100 { // Limit to 100 tokens max
options := model.Options{
Inputs: inputs,
Positions: make([]int32, len(inputs)),
Sequences: make([]int, len(inputs)),
Outputs: []int32{int32(len(inputs) - 1)},
}
for i := range options.Positions {
options.Positions[i] = int32(i)
options.Sequences[i] = 0
}
ctx := m.Backend().NewContext()
modelOutput, err := model.Forward(ctx, m, options)
if err != nil {
ctx.Close()
t.Fatal(fmt.Errorf("forward pass failed: %v", err))
}
f32s := modelOutput.Floats()
logits := make([]float64, len(f32s))
for i, f32 := range f32s {
logits[i] = float64(f32)
}
token, err := sample.Sample(logits, sample.Greedy())
if err != nil {
ctx.Close()
t.Fatal(fmt.Errorf("sampling failed: %v", err))
}
ctx.Close()
// Greedy sampling: take the token with the highest logit
nextToken := int32(token[0])
if m.(model.TextProcessor).Is(nextToken, model.SpecialEOS) {
break
}
piece, err := m.(model.TextProcessor).Decode([]int32{nextToken})
if err != nil {
t.Fatal(err)
}
result = append(result, piece)
output := strings.Join(result, "")
for _, expectedOutput := range test.OutputContainsOne {
if strings.Contains(output, expectedOutput) {
t.Logf("Test passed with output: %q (matched expected: %q)", output, expectedOutput)
return
}
}
// Maintain full context by appending new token
inputs = append(inputs, nextToken)
}
t.Fatalf("Expected output containing one of %q but got: %q", test.OutputContainsOne, strings.Join(result, ""))
})
}
}

10
model/testdata/models/README.md vendored Normal file
View File

@@ -0,0 +1,10 @@
# Test Model Directory
This directory is used for storing model files (like `.gguf` files) that are required to run the tests in `model_external_test.go`.
## Usage
- Place any model files you need for testing in this directory
- The test file will look for any model files here (e.g., `llama3.gguf`)
- All non-markdown files in this directory are git-ignored to prevent large model files from being committed to the repository
- Only `.md` files (like this README) will be tracked in git

7
model/testdata/models/qwen2_5.json vendored Normal file
View File

@@ -0,0 +1,7 @@
{
"prompt": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n",
"output_contains_one": [
"Hello",
"Hi"
]
}

View File

@@ -1,7 +1,6 @@
package progress
import (
"bufio"
"fmt"
"io"
"sync"
@@ -14,8 +13,7 @@ type State interface {
type Progress struct {
mu sync.Mutex
// buffer output to minimize flickering on all terminals
w *bufio.Writer
w io.Writer
pos int
@@ -24,7 +22,7 @@ type Progress struct {
}
func NewProgress(w io.Writer) *Progress {
p := &Progress{w: bufio.NewWriter(w)}
p := &Progress{w: w}
go p.start()
return p
}
@@ -49,29 +47,26 @@ func (p *Progress) stop() bool {
func (p *Progress) Stop() bool {
stopped := p.stop()
if stopped {
fmt.Fprintln(p.w)
fmt.Fprint(p.w, "\n")
}
// show cursor
fmt.Fprint(p.w, "\033[?25h")
p.w.Flush()
return stopped
}
func (p *Progress) StopAndClear() bool {
fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")
stopped := p.stop()
if stopped {
// clear all progress lines
for range p.pos - 1 {
fmt.Fprint(p.w, "\033[A")
for i := range p.pos {
if i > 0 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[2K\033[1G")
}
fmt.Fprint(p.w, "\033[2K", "\033[1G")
}
// show cursor
fmt.Fprint(p.w, "\033[?25h")
p.w.Flush()
return stopped
}
@@ -86,31 +81,30 @@ func (p *Progress) render() {
p.mu.Lock()
defer p.mu.Unlock()
fmt.Fprint(p.w, "\033[?2026h")
defer fmt.Fprint(p.w, "\033[?2026l")
fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")
for range p.pos - 1 {
fmt.Fprint(p.w, "\033[A")
// clear already rendered progress lines
for i := range p.pos {
if i > 0 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[2K\033[1G")
}
fmt.Fprint(p.w, "\033[1G")
// render progress lines
for i, state := range p.states {
fmt.Fprint(p.w, state.String(), "\033[K")
fmt.Fprint(p.w, state.String())
if i < len(p.states)-1 {
fmt.Fprint(p.w, "\n")
}
}
p.pos = len(p.states)
p.w.Flush()
}
func (p *Progress) start() {
p.ticker = time.NewTicker(100 * time.Millisecond)
// hide cursor
fmt.Fprint(p.w, "\033[?25l")
for range p.ticker.C {
p.render()
}