Compare commits

...

1 Commits

Author SHA1 Message Date
Bruce MacDonald
46faf61a14 ... 2025-06-06 16:34:53 -07:00
2 changed files with 65 additions and 2 deletions

View File

@@ -24,7 +24,7 @@ import (
var stream bool = false
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
func createBinFile(t testing.TB, kv map[string]any, ti []*ggml.Tensor) (string, string) {
t.Helper()
t.Setenv("OLLAMA_MODELS", cmp.Or(os.Getenv("OLLAMA_MODELS"), t.TempDir()))
@@ -71,7 +71,7 @@ func (t *responseRecorder) CloseNotify() <-chan bool {
return make(chan bool)
}
func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
func createRequest(t testing.TB, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
t.Helper()
// if OLLAMA_MODELS is not set, set it to the temp directory
t.Setenv("OLLAMA_MODELS", cmp.Or(os.Getenv("OLLAMA_MODELS"), t.TempDir()))

View File

@@ -2,6 +2,7 @@ package server
import (
"encoding/json"
"fmt"
"net/http"
"slices"
"testing"
@@ -64,3 +65,65 @@ func TestList(t *testing.T) {
t.Fatalf("expected slices to be equal %v", actualNames)
}
}
// BenchmarkListHandler measures ListHandler against model stores of
// increasing size. Each size runs as its own sub-benchmark named
// "models_<count>".
func BenchmarkListHandler(b *testing.B) {
	gin.SetMode(gin.TestMode)

	// Higher model counts are included to simulate real-world scenarios.
	for _, n := range []int{50, 100, 250, 500, 1000, 2000} {
		name := fmt.Sprintf("models_%d", n)
		b.Run(name, func(sub *testing.B) {
			benchmarkListWithModelCount(sub, n)
		})
	}
}
// benchmarkListWithModelCount creates modelCount models in a fresh model
// directory and then times repeated ListHandler requests against them.
//
// Model creation and a one-off check of the list response both run before the
// timer is reset, so the measured region contains only the ListHandler
// requests issued through createRequest.
func benchmarkListWithModelCount(b *testing.B, modelCount int) {
	// Point OLLAMA_MODELS at a temporary directory owned by this benchmark.
	b.Setenv("OLLAMA_MODELS", b.TempDir())

	var s Server

	b.Logf("Creating %d models for benchmark...", modelCount)
	for i := range modelCount {
		_, digest := createBinFile(b, nil, nil)
		w := createRequest(b, s.CreateHandler, api.CreateRequest{
			Name:  fmt.Sprintf("testmodel%d:latest", i),
			Files: map[string]string{"test.gguf": digest},
		})
		// Stop at the first failed create instead of surfacing it later as a
		// confusing model-count mismatch.
		if w.Code != http.StatusOK {
			b.Fatalf("creating model %d: expected status code 200, actual %d", i, w.Code)
		}

		// Log progress for large numbers; i+1 models exist at this point.
		if created := i + 1; modelCount >= 500 && created%100 == 0 {
			b.Logf("Created %d/%d models", created, modelCount)
		}
	}
	b.Logf("Setup complete, starting benchmark with %d models", modelCount)

	// Verify the handler returns every created model once, outside the timed
	// region, so decoding the response does not skew the measurement.
	w := createRequest(b, s.ListHandler, nil)
	if w.Code != http.StatusOK {
		b.Fatalf("expected status code 200, actual %d", w.Code)
	}
	var resp api.ListResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		b.Fatal(err)
	}
	if len(resp.Models) != modelCount {
		b.Fatalf("expected %d models, got %d", modelCount, len(resp.Models))
	}

	// Reset timer to exclude setup and verification time.
	b.ResetTimer()

	for range b.N {
		if w := createRequest(b, s.ListHandler, nil); w.Code != http.StatusOK {
			b.Fatalf("expected status code 200, actual %d", w.Code)
		}
	}
}