Compare commits

..

4 Commits

Author SHA1 Message Date
Roy Han
907b038ff0 reduce error footprint 2024-07-15 10:57:01 -07:00
royjhan
1f73889f34 Merge branch 'royh-batchembed' into royh-embed-parallel 2024-07-12 16:44:12 -07:00
Roy Han
5a8f8e96e0 clean up 2024-07-12 16:35:25 -07:00
Roy Han
7cddd6d741 parallelized 2024-07-12 16:08:12 -07:00
6 changed files with 82 additions and 80 deletions

View File

@@ -356,7 +356,7 @@ func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse,
return &resp, nil
}
// Embeddings generates an embedding from a model.
// Embeddings generates embeddings from a model. (Legacy)
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {

View File

@@ -187,10 +187,6 @@ type EmbedRequest struct {
Truncate *bool `json:"truncate,omitempty"`
// Images is an optional list of base64-encoded images accompanying this
// request, for multimodal models.
Images []ImageData `json:"images,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}

View File

@@ -1855,8 +1855,6 @@ struct llama_server_context
if (batch.n_tokens == 0)
{
// HANGING HERE
LOG_INFO("no tokens to process", {});
all_slots_are_idle = true;
return true;
}
@@ -3194,21 +3192,12 @@ int main(int argc, char **argv) {
prompt = prompt[0];
}
json image_data;
if (body.count("image_data") != 0)
{
image_data = body["image_data"];
}
else {
image_data = "";
}
// create and queue the task
json responses;
{
const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, { {"prompt", prompt}, {"image_data", image_data} }, true, -1);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
// get the result
task_result result = llama.queue_results.recv(id_task);

View File

@@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embed(ctx context.Context, input []string, images []ImageData) ([][]float32, error)
Embed(ctx context.Context, input []string) ([][]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@@ -860,15 +860,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
type EmbedRequest struct {
Content []string `json:"content"`
Images []ImageData `json:"image_data"`
Content []string `json:"content"`
}
type EmbedResponse struct {
Embedding [][]float32 `json:"embedding"`
}
func (s *llmServer) Embed(ctx context.Context, input []string, images []ImageData) ([][]float32, error) {
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
@@ -883,7 +882,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string, images []ImageDat
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(EmbedRequest{Content: input, Images: images})
data, err := json.Marshal(EmbedRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}

View File

@@ -18,6 +18,7 @@ import (
"path/filepath"
"slices"
"strings"
"sync"
"syscall"
"time"
@@ -259,50 +260,38 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
truncate := true
if req.Truncate != nil && !*req.Truncate {
truncate = false
if req.Truncate == nil {
truncate := true
req.Truncate = &truncate
}
var input []string
reqEmbed := []string{}
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: 0, Data: req.Images[i]}
}
if req.Images != nil {
// for _, _ := range images {
// input = append(input, fmt.Sprintf("[img-%d]", i.ID))
input = append(input, "[img-0]")
// }
req.Input = ""
}
switch i := req.Input.(type) {
switch embeddings := req.Input.(type) {
case string:
if len(i) > 0 {
input = append(input, i)
if embeddings == "" {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
reqEmbed = []string{embeddings}
case []any:
for _, v := range i {
if len(embeddings) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
for _, v := range embeddings {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
input = append(input, v.(string))
reqEmbed = append(reqEmbed, v.(string))
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
@@ -315,35 +304,64 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
reqEmbedArray := make([]string, len(reqEmbed))
errCh := make(chan error, 1)
successCh := make(chan bool, 1)
sem := make(chan struct{}, 2)
var wg sync.WaitGroup
var mu sync.Mutex
for i, s := range reqEmbed {
wg.Add(1)
sem <- struct{}{}
go func(i int, s string) {
defer wg.Done()
defer func() { <-sem }()
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
errCh <- err
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if *req.Truncate {
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
errCh <- err
return
}
} else {
errCh <- err
return
}
}
mu.Lock()
reqEmbedArray[i] = s
mu.Unlock()
}(i, s)
}
go func() {
wg.Wait()
successCh <- true
close(errCh)
}()
select {
case err := <-errCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
case success := <-successCh:
if !success {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to process all embeddings"})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
input[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), input, images)
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
if err != nil {
slog.Error("embedding generation failed", "error", err)
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
@@ -398,7 +416,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}, nil)
embedding, err := r.Embed(c.Request.Context(), []string{req.Prompt})
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
@@ -406,14 +424,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embedding := make([]float64, len(embeddings[0]))
embedding64 := make([]float64, len(embedding[0]))
for i, v := range embeddings[0] {
embedding[i] = float64(v)
for i, v := range embedding[0] {
embedding64[i] = float64(v)
}
resp := api.EmbeddingResponse{
Embedding: embedding,
Embedding: embedding64,
}
c.JSON(http.StatusOK, resp)
}
@@ -1042,7 +1060,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy
r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler)
r.POST("/api/copy", s.CopyModelHandler)

View File

@@ -660,7 +660,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embed(ctx context.Context, input []string, images []llm.ImageData) ([][]float32, error) {
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
return s.embedResp, s.embedRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {