mirror of
https://github.com/ollama/ollama.git
synced 2026-01-03 13:10:17 -05:00
Compare commits
4 Commits
royh-imgem
...
royh-embed
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
907b038ff0 | ||
|
|
1f73889f34 | ||
|
|
5a8f8e96e0 | ||
|
|
7cddd6d741 |
@@ -356,7 +356,7 @@ func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse,
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Embeddings generates an embedding from a model.
|
||||
// Embeddings generates embeddings from a model. (Legacy)
|
||||
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
||||
var resp EmbeddingResponse
|
||||
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
|
||||
|
||||
@@ -187,10 +187,6 @@ type EmbedRequest struct {
|
||||
|
||||
Truncate *bool `json:"truncate,omitempty"`
|
||||
|
||||
// Images is an optional list of base64-encoded images accompanying this
|
||||
// request, for multimodal models.
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
13
llm/ext_server/server.cpp
vendored
13
llm/ext_server/server.cpp
vendored
@@ -1855,8 +1855,6 @@ struct llama_server_context
|
||||
|
||||
if (batch.n_tokens == 0)
|
||||
{
|
||||
// HANGING HERE
|
||||
LOG_INFO("no tokens to process", {});
|
||||
all_slots_are_idle = true;
|
||||
return true;
|
||||
}
|
||||
@@ -3194,21 +3192,12 @@ int main(int argc, char **argv) {
|
||||
prompt = prompt[0];
|
||||
}
|
||||
|
||||
json image_data;
|
||||
if (body.count("image_data") != 0)
|
||||
{
|
||||
image_data = body["image_data"];
|
||||
}
|
||||
else {
|
||||
image_data = "";
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
json responses;
|
||||
{
|
||||
const int id_task = llama.queue_tasks.get_new_id();
|
||||
llama.queue_results.add_waiting_task_id(id_task);
|
||||
llama.request_completion(id_task, { {"prompt", prompt}, {"image_data", image_data} }, true, -1);
|
||||
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
|
||||
|
||||
// get the result
|
||||
task_result result = llama.queue_results.recv(id_task);
|
||||
|
||||
@@ -33,7 +33,7 @@ type LlamaServer interface {
|
||||
Ping(ctx context.Context) error
|
||||
WaitUntilRunning(ctx context.Context) error
|
||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||
Embed(ctx context.Context, input []string, images []ImageData) ([][]float32, error)
|
||||
Embed(ctx context.Context, input []string) ([][]float32, error)
|
||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||
Close() error
|
||||
@@ -860,15 +860,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
|
||||
type EmbedRequest struct {
|
||||
Content []string `json:"content"`
|
||||
Images []ImageData `json:"image_data"`
|
||||
Content []string `json:"content"`
|
||||
}
|
||||
|
||||
type EmbedResponse struct {
|
||||
Embedding [][]float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Embed(ctx context.Context, input []string, images []ImageData) ([][]float32, error) {
|
||||
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
return nil, err
|
||||
@@ -883,7 +882,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string, images []ImageDat
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(EmbedRequest{Content: input, Images: images})
|
||||
data, err := json.Marshal(EmbedRequest{Content: input})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
}
|
||||
|
||||
132
server/routes.go
132
server/routes.go
@@ -18,6 +18,7 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -259,50 +260,38 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
truncate := true
|
||||
|
||||
if req.Truncate != nil && !*req.Truncate {
|
||||
truncate = false
|
||||
if req.Truncate == nil {
|
||||
truncate := true
|
||||
req.Truncate = &truncate
|
||||
}
|
||||
|
||||
var input []string
|
||||
reqEmbed := []string{}
|
||||
|
||||
images := make([]llm.ImageData, len(req.Images))
|
||||
for i := range req.Images {
|
||||
images[i] = llm.ImageData{ID: 0, Data: req.Images[i]}
|
||||
}
|
||||
|
||||
if req.Images != nil {
|
||||
// for _, _ := range images {
|
||||
// input = append(input, fmt.Sprintf("[img-%d]", i.ID))
|
||||
input = append(input, "[img-0]")
|
||||
// }
|
||||
req.Input = ""
|
||||
}
|
||||
|
||||
switch i := req.Input.(type) {
|
||||
switch embeddings := req.Input.(type) {
|
||||
case string:
|
||||
if len(i) > 0 {
|
||||
input = append(input, i)
|
||||
if embeddings == "" {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
reqEmbed = []string{embeddings}
|
||||
case []any:
|
||||
for _, v := range i {
|
||||
if len(embeddings) == 0 {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
|
||||
for _, v := range embeddings {
|
||||
if _, ok := v.(string); !ok {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||
return
|
||||
}
|
||||
input = append(input, v.(string))
|
||||
reqEmbed = append(reqEmbed, v.(string))
|
||||
}
|
||||
default:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||
return
|
||||
}
|
||||
|
||||
if len(input) == 0 {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
@@ -315,35 +304,64 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
for i, s := range input {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
reqEmbedArray := make([]string, len(reqEmbed))
|
||||
errCh := make(chan error, 1)
|
||||
successCh := make(chan bool, 1)
|
||||
sem := make(chan struct{}, 2)
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
for i, s := range reqEmbed {
|
||||
wg.Add(1)
|
||||
sem <- struct{}{}
|
||||
go func(i int, s string) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }()
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if *req.Truncate {
|
||||
tokens = tokens[:ctxLen]
|
||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
} else {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
mu.Lock()
|
||||
reqEmbedArray[i] = s
|
||||
mu.Unlock()
|
||||
}(i, s)
|
||||
}
|
||||
go func() {
|
||||
wg.Wait()
|
||||
successCh <- true
|
||||
close(errCh)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
case success := <-successCh:
|
||||
if !success {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to process all embeddings"})
|
||||
return
|
||||
}
|
||||
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if !truncate {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
||||
return
|
||||
}
|
||||
|
||||
tokens = tokens[:ctxLen]
|
||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
input[i] = s
|
||||
}
|
||||
|
||||
embeddings, err := r.Embed(c.Request.Context(), input, images)
|
||||
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
|
||||
|
||||
if err != nil {
|
||||
slog.Error("embedding generation failed", "error", err)
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
}
|
||||
@@ -398,7 +416,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}, nil)
|
||||
embedding, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
||||
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
@@ -406,14 +424,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
embedding := make([]float64, len(embeddings[0]))
|
||||
embedding64 := make([]float64, len(embedding[0]))
|
||||
|
||||
for i, v := range embeddings[0] {
|
||||
embedding[i] = float64(v)
|
||||
for i, v := range embedding[0] {
|
||||
embedding64[i] = float64(v)
|
||||
}
|
||||
|
||||
resp := api.EmbeddingResponse{
|
||||
Embedding: embedding,
|
||||
Embedding: embedding64,
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
@@ -1042,7 +1060,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy
|
||||
r.POST("/api/create", s.CreateModelHandler)
|
||||
r.POST("/api/push", s.PushModelHandler)
|
||||
r.POST("/api/copy", s.CopyModelHandler)
|
||||
|
||||
@@ -660,7 +660,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
|
||||
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
return s.completionResp
|
||||
}
|
||||
func (s *mockLlm) Embed(ctx context.Context, input []string, images []llm.ImageData) ([][]float32, error) {
|
||||
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
||||
return s.embedResp, s.embedRespErr
|
||||
}
|
||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
|
||||
Reference in New Issue
Block a user