mirror of
https://github.com/ollama/ollama.git
synced 2026-01-16 11:29:26 -05:00
Compare commits
4 Commits
royh-imgem
...
royh-embed
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
907b038ff0 | ||
|
|
1f73889f34 | ||
|
|
5a8f8e96e0 | ||
|
|
7cddd6d741 |
@@ -18,6 +18,7 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -304,30 +305,59 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
reqEmbedArray := make([]string, len(reqEmbed))
|
||||
errCh := make(chan error, 1)
|
||||
successCh := make(chan bool, 1)
|
||||
sem := make(chan struct{}, 2)
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
for i, s := range reqEmbed {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if *req.Truncate {
|
||||
tokens = tokens[:ctxLen]
|
||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
||||
wg.Add(1)
|
||||
sem <- struct{}{}
|
||||
go func(i int, s string) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }()
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
reqEmbedArray[i] = s
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if *req.Truncate {
|
||||
tokens = tokens[:ctxLen]
|
||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
} else {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
mu.Lock()
|
||||
reqEmbedArray[i] = s
|
||||
mu.Unlock()
|
||||
}(i, s)
|
||||
}
|
||||
go func() {
|
||||
wg.Wait()
|
||||
successCh <- true
|
||||
close(errCh)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
case success := <-successCh:
|
||||
if !success {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to process all embeddings"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
|
||||
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user