Compare commits

..

1 Commits

Author SHA1 Message Date
Patrick Devine
61349a8ec6 tests: move csv output to benstat format 2025-10-26 18:24:35 -07:00
12 changed files with 133 additions and 395 deletions

View File

@@ -4,9 +4,7 @@ package integration
import (
"context"
"errors"
"math"
"strings"
"testing"
"time"
@@ -301,197 +299,3 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
return client.Embed(ctx, &req)
}
func TestEmbedTruncation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
t.Run("single input token count", func(t *testing.T) {
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if res.PromptEvalCount <= 0 {
t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
}
})
t.Run("batch parallel token counting", func(t *testing.T) {
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"cat", "dog and mouse", "bird"},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount <= 0 {
t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
}
})
t.Run("truncation single input", func(t *testing.T) {
truncTrue := true
longInput := strings.Repeat("word ", 100)
req := api.EmbedRequest{
Model: "all-minilm",
Input: longInput,
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 50},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if res.PromptEvalCount > 50 {
t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount)
}
if res.PromptEvalCount == 0 {
t.Fatal("expected non-zero token count after truncation")
}
})
t.Run("truncation batch", func(t *testing.T) {
truncTrue := true
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 30},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount > 90 {
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
}
})
t.Run("runner token count accuracy", func(t *testing.T) {
baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"}
baseRes, err := embedTestHelper(ctx, client, t, baseline)
if err != nil {
t.Fatal(err)
}
batch := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"test", "test", "test"},
}
batchRes, err := embedTestHelper(ctx, client, t, batch)
if err != nil {
t.Fatal(err)
}
expectedCount := baseRes.PromptEvalCount * 3
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
}
})
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler
// where api.StatusError errors should maintain their original status code.
func TestEmbedStatusCode(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Pull the model if needed
if err := PullIfMissing(ctx, client, "all-minilm"); err != nil {
t.Fatal(err)
}
t.Run("truncation error status code", func(t *testing.T) {
truncFalse := false
longInput := strings.Repeat("word ", 100)
req := api.EmbedRequest{
Model: "all-minilm",
Input: longInput,
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(ctx, client, t, req)
if err == nil {
t.Fatal("expected error when truncate=false with long input")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error (likely 400 Bad Request)
// not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
// Verify the error message is meaningful
if !strings.Contains(err.Error(), "context length") {
t.Errorf("expected error message to mention context length, got: %v", err)
}
})
t.Run("batch truncation error status code", func(t *testing.T) {
truncFalse := false
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{
"short input",
strings.Repeat("very long input ", 100),
"another short input",
},
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(ctx, client, t, req)
if err == nil {
t.Fatal("expected error when one input exceeds context with truncate=false")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error, not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
})
}

View File

@@ -161,11 +161,12 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
}
testCases := []struct {
name string
prompt string
anyResp []string
}{
{blueSkyPrompt, blueSkyExpected},
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}},
{"blue_sky", blueSkyPrompt, blueSkyExpected},
{"max", maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}},
}
var gpuPercent int
for _, tc := range testCases {
@@ -259,25 +260,20 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
}
}
}
// Round the logged prompt count for comparisons across versions/configurations which can vary slightly
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
"MODEL",
"CONTEXT",
"GPU PERCENT",
"APPROX PROMPT COUNT",
"LOAD TIME",
"PROMPT EVAL TPS",
"EVAL TPS",
)
fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
model,
numCtx,
gpuPercent,
(resp.PromptEvalCount/10)*10,
float64(resp.LoadDuration)/1000000000.0,
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
)
prefillTimePerToken := float64(resp.PromptEvalDuration.Nanoseconds()) / float64(resp.PromptEvalCount)
prefillTokensPerSec := float64(resp.PromptEvalCount) / (float64(resp.PromptEvalDuration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(os.Stderr, "BenchmarkModel/name=%s-%s/%d/step=%s %d %.2f ns/token %.2f token/sec\n",
model, tc.name, numCtx, "prefill", resp.PromptEvalCount, prefillTimePerToken, prefillTokensPerSec)
evalTimePerToken := float64(resp.EvalDuration.Nanoseconds()) / float64(resp.EvalCount)
evalTokensPerSec := float64(resp.EvalCount) / (float64(resp.EvalDuration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(os.Stderr, "BenchmarkModel/name=%s-%s/%d/step=%s %d %.2f ns/token %.2f token/sec\n",
model, tc.name, numCtx, "generate", resp.EvalCount, evalTimePerToken, evalTokensPerSec)
fmt.Fprintf(os.Stderr, "BenchmarkMode/name=%s-%s/%d 1 %d ns/request\n",
model, tc.name, numCtx, resp.TotalDuration.Nanoseconds())
fmt.Fprintf(os.Stderr, "BenchmarkMode/name=%s-%s/%d/step=%s 1 %d ns/request\n",
model, tc.name, numCtx, "load", resp.LoadDuration.Nanoseconds())
}
}
})

View File

@@ -69,7 +69,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error)
Embedding(ctx context.Context, input string) ([]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@@ -1545,16 +1545,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
type EmbeddingRequest struct {
Content string `json:"content"`
Truncate bool `json:"truncate"`
Content string `json:"content"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
PromptEvalCount int `json:"prompt_eval_count"`
Embedding []float32 `json:"embedding"`
}
func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
logutil.Trace("embedding request", "input", input)
if err := s.sem.Acquire(ctx, 1); err != nil {
@@ -1563,54 +1561,51 @@ func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool)
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
return nil, 0, err
return nil, err
}
defer s.sem.Release(1)
// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
if err != nil {
return nil, 0, err
return nil, err
} else if status != ServerStatusReady {
return nil, 0, fmt.Errorf("unexpected server status: %s", status)
return nil, fmt.Errorf("unexpected server status: %s", status)
}
data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate})
data, err := json.Marshal(EmbeddingRequest{Content: input})
if err != nil {
return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, 0, fmt.Errorf("error creating embed request: %w", err)
return nil, fmt.Errorf("error creating embed request: %w", err)
}
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, 0, fmt.Errorf("do embedding request: %w", err)
return nil, fmt.Errorf("do embedding request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, 0, fmt.Errorf("error reading embed response: %w", err)
return nil, fmt.Errorf("error reading embed response: %w", err)
}
if resp.StatusCode >= 400 {
log.Printf("llm embedding error: %s", body)
return nil, 0, api.StatusError{
StatusCode: resp.StatusCode,
ErrorMessage: string(body),
}
return nil, fmt.Errorf("%s", body)
}
var e EmbeddingResponse
if err := json.Unmarshal(body, &e); err != nil {
return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
return e.Embedding, e.PromptEvalCount, nil
return e.Embedding, nil
}
type TokenizeRequest struct {

View File

@@ -2,6 +2,7 @@ package gemma3
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
@@ -52,5 +53,10 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return m, nil
}

View File

@@ -182,18 +182,16 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
if cache != nil {
cacheType := cacheTypeSWA
if (i+1)%gemmaGlobalCacheCount == 0 {
cacheType = cacheTypeCausal
}
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
cacheType := cacheTypeSWA
if (i+1)%gemmaGlobalCacheCount == 0 {
cacheType = cacheTypeCausal
}
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
var lastLayerOutputs ml.Tensor

View File

@@ -10,20 +10,19 @@ import (
func TestToEmbeddingList(t *testing.T) {
testCases := []struct {
name string
embeddings [][]float32
format string
expectType string // "float" or "base64"
expectBase64 []string
expectCount int
promptEval int
name string
embeddings [][]float32
format string
expectType string // "float" or "base64"
expectCount int
promptEval int
}{
{"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", nil, 1, 10},
{"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", []string{"zczMPc3MTL6amZk+"}, 1, 5},
{"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", nil, 1, 0},
{"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", nil, 1, 0},
{"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", []string{"zczMPc3MTD4=", "mpmZPs3MzD4=", "AAAAP5qZGT8="}, 3, 0},
{"empty embeddings", nil, "float", "", nil, 0, 0},
{"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", 1, 10},
{"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", 1, 5},
{"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", 1, 0},
{"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", 1, 0},
{"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", 3, 0},
{"empty embeddings", nil, "float", "", 0, 0},
}
for _, tc := range testCases {
@@ -57,24 +56,11 @@ func TestToEmbeddingList(t *testing.T) {
t.Errorf("expected []float32, got %T", result.Data[0].Embedding)
}
case "base64":
for i, data := range result.Data {
embStr, ok := data.Embedding.(string)
if !ok {
t.Errorf("embedding %d: expected string, got %T", i, data.Embedding)
continue
}
// Verify it's valid base64
if _, err := base64.StdEncoding.DecodeString(embStr); err != nil {
t.Errorf("embedding %d: invalid base64: %v", i, err)
}
// Compare against expected base64 string if provided
if tc.expectBase64 != nil && i < len(tc.expectBase64) {
if embStr != tc.expectBase64[i] {
t.Errorf("embedding %d: expected base64 %q, got %q", i, tc.expectBase64[i], embStr)
}
}
embStr, ok := result.Data[0].Embedding.(string)
if !ok {
t.Errorf("expected string, got %T", result.Data[0].Embedding)
} else if _, err := base64.StdEncoding.DecodeString(embStr); err != nil {
t.Errorf("invalid base64: %v", err)
}
}

View File

@@ -709,13 +709,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
embedding: true,
truncate: req.Truncate,
// TODO (jmorganca): this should be provided by the server via the
// request options and truncated here in the runner, instead of relying on
// the server's truncate logic
truncate: true,
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
@@ -758,8 +758,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: embedding,
PromptEvalCount: seq.numPromptInputs,
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}

View File

@@ -948,13 +948,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
embedding: true,
truncate: req.Truncate,
// TODO (jmorganca): this should be provided by the server via the
// request options and truncated here in the runner, instead of relying on
// the server's truncate logic
truncate: true,
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
@@ -995,8 +995,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
}
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: <-seq.embedding,
PromptEvalCount: seq.numPromptInputs,
Embedding: <-seq.embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}

View File

@@ -119,27 +119,6 @@ func (s *Server) CreateHandler(c *gin.Context) {
if err != nil {
ch <- gin.H{"error": err.Error()}
}
if err == nil && !remote && (config.Renderer == "" || config.Parser == "") {
manifest, mErr := ParseNamedManifest(fromName)
if mErr == nil && manifest.Config.Digest != "" {
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
if pErr == nil {
if cfgFile, fErr := os.Open(configPath); fErr == nil {
var baseConfig ConfigV2
if decErr := json.NewDecoder(cfgFile).Decode(&baseConfig); decErr == nil {
if config.Renderer == "" {
config.Renderer = baseConfig.Renderer
}
if config.Parser == "" {
config.Parser = baseConfig.Parser
}
}
cfgFile.Close()
}
}
}
}
}
} else if r.Files != nil {
baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)

View File

@@ -21,7 +21,6 @@ import (
"os/signal"
"slices"
"strings"
"sync/atomic"
"syscall"
"time"
@@ -660,7 +659,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@@ -673,12 +672,61 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
kvData, _, err := getModelData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var count int
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
return
}
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen--
}
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
ctxLen--
}
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
if ctxLen <= 0 {
// return error if the truncated input would be empty or just special tokens
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
count += len(tokens)
input[i] = s
}
var g errgroup.Group
embeddings := make([][]float32, len(input))
var totalTokens uint64
for i, text := range input {
g.Go(func() error {
embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate)
embedding, err := r.Embedding(c.Request.Context(), text)
if err != nil {
return err
}
@@ -688,18 +736,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
embedding = normalize(embedding[:req.Dimensions])
}
embeddings[i] = embedding
atomic.AddUint64(&totalTokens, uint64(tokenCount))
return nil
})
}
if err := g.Wait(); err != nil {
var serr api.StatusError
if errors.As(err, &serr) {
c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)})
} else {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
}
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return
}
@@ -708,7 +750,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
Embeddings: embeddings,
TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: int(totalTokens),
PromptEvalCount: count,
}
c.JSON(http.StatusOK, resp)
}
@@ -754,7 +796,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true)
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return

View File

@@ -188,72 +188,6 @@ func TestCreateFromModel(t *testing.T) {
})
}
func TestCreateFromModelInheritsRendererParser(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
const (
renderer = "custom-renderer"
parser = "custom-parser"
)
_, digest := createBinFile(t, nil, nil)
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "base",
Files: map[string]string{"base.gguf": digest},
Renderer: renderer,
Parser: parser,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "child",
From: "base",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
manifest, err := ParseNamedManifest(model.ParseName("child"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
if manifest.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for child manifest")
}
configPath, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}
cfgFile, err := os.Open(configPath)
if err != nil {
t.Fatalf("open config blob: %v", err)
}
defer cfgFile.Close()
var cfg ConfigV2
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
t.Fatalf("decode config: %v", err)
}
if cfg.Renderer != renderer {
t.Fatalf("expected renderer %q, got %q", renderer, cfg.Renderer)
}
if cfg.Parser != parser {
t.Fatalf("expected parser %q, got %q", parser, cfg.Parser)
}
}
func TestCreateRemovesLayers(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -780,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
return s.completionResp
}
func (s *mockLlm) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
return s.embeddingResp, 0, s.embeddingRespErr
func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {