Mirror of https://github.com/ollama/ollama.git, synced 2026-02-25 19:46:55 -05:00.

Compare commits: v0.17.1-rc...main (3 commits)

- 9bf41969f0
- 0f23b7bff5
- 4e57d2094e
```diff
@@ -35,6 +35,7 @@ import (
 var (
 	wv           = &Webview{}
 	uiServerPort int
+	appStore     *store.Store
 )
 
 var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
@@ -208,6 +209,7 @@ func main() {
 	uiServerPort = port
 
 	st := &store.Store{}
+	appStore = st
 
 	// Enable CORS in development mode
 	if devMode {
@@ -360,8 +362,7 @@ func startHiddenTasks() {
 		slog.Info("deferring pending update for fast startup")
 	} else {
 		// Check if auto-update is enabled before automatically upgrading
-		st := &store.Store{}
-		settings, err := st.Settings()
+		settings, err := appStore.Settings()
 		if err != nil {
 			slog.Warn("failed to load settings for upgrade check", "error", err)
 		} else if !settings.AutoUpdateEnabled {
```
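The three hunks above consolidate `store.Store` usage: `main()` publishes its instance through the new package-level `appStore`, and the upgrade check in `startHiddenTasks()` reads settings from that shared instance instead of constructing a throwaway store. A minimal sketch of the pattern; `Store` and `Settings` below are illustrative stand-ins, not the app's real API:

```go
package main

import "sync"

// Stand-in for the app's store package.
type Store struct {
	mu       sync.Mutex
	settings map[string]bool
}

func (s *Store) Settings() map[string]bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.settings
}

// One instance, assigned in main and shared by every later reader, so the
// upgrade check observes the same in-memory state as the rest of the app.
var appStore *Store

func main() {
	st := &Store{settings: map[string]bool{"AutoUpdateEnabled": true}}
	appStore = st

	_ = appStore.Settings() // callers read the shared instance
}
```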
```diff
@@ -289,6 +289,7 @@ func (u *Updater) TriggerImmediateCheck() {
 
 func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
 	u.checkNow = make(chan struct{}, 1)
+	u.checkNow <- struct{}{} // Trigger first check after initial delay
 	go func() {
 		// Don't blast an update message immediately after startup
 		time.Sleep(UpdateCheckInitialDelay)
@@ -333,7 +334,7 @@ func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(str
 			continue
 		}
 
-		// Download successful - show tray notification (regardless of toggle state)
+		// Download successful - show tray notification
 		err = cb(resp.UpdateVersion)
 		if err != nil {
 			slog.Warn("failed to register update available with tray", "error", err)
```
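Pre-seeding the capacity-1 `checkNow` channel guarantees the checker runs one pass as soon as the initial delay elapses, rather than waiting a full tick interval. A sketch of how such a trigger channel typically pairs with the checker loop; only `checkNow`, `TriggerImmediateCheck`, and the pre-seeded token appear in this diff, so the loop body and interval below are assumptions:

```go
package main

import (
	"context"
	"time"
)

type Updater struct{ checkNow chan struct{} }

// TriggerImmediateCheck wakes the loop without blocking: with a buffered
// channel of capacity 1, a second trigger while one is already pending
// coalesces into that pending check.
func (u *Updater) TriggerImmediateCheck() {
	select {
	case u.checkNow <- struct{}{}:
	default:
	}
}

func (u *Updater) loop(ctx context.Context, check func()) {
	ticker := time.NewTicker(time.Hour) // interval is a stand-in
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C: // periodic check
		case <-u.checkNow: // startup token or an explicit trigger
		}
		check()
	}
}

func main() {
	u := &Updater{checkNow: make(chan struct{}, 1)}
	u.checkNow <- struct{}{} // pre-seed the first check
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()
	u.loop(ctx, func() { println("checked") })
}
```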
```diff
@@ -351,10 +351,13 @@ func TestTriggerImmediateCheck(t *testing.T) {
 
 	updater.StartBackgroundUpdaterChecker(ctx, cb)
 
-	// Wait for goroutine to start and pass initial delay
-	time.Sleep(10 * time.Millisecond)
+	// Wait for the initial check that fires after the initial delay
+	select {
+	case <-checkDone:
+	case <-time.After(2 * time.Second):
+		t.Fatal("initial check did not happen")
+	}
 
 	// With 1 hour interval, no check should have happened yet
 	initialCount := checkCount.Load()
 
 	// Trigger immediate check
```
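The test previously slept 10ms and hoped the goroutine had gotten past its initial delay; it now blocks on a `checkDone` signal with a generous deadline, which is deterministic on slow machines and faster on quick ones. A sketch of that waiting step as a helper, assuming `checkDone` is signalled from the test's callback (that wiring is outside this diff):

```go
package updater

import (
	"testing"
	"time"
)

// waitForCheck blocks until the background checker signals one completed
// check, failing the test if it takes too long. Unlike time.Sleep, this
// cannot pass before the event happens and wastes no time after it does.
func waitForCheck(t *testing.T, checkDone <-chan struct{}) {
	t.Helper()
	select {
	case <-checkDone:
	case <-time.After(2 * time.Second):
		t.Fatal("check did not happen in time")
	}
}
```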
```diff
@@ -9,59 +9,99 @@ import (
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
+	"github.com/ollama/ollama/x/mlxrunner/model/base"
 )
 
-// CacheEntry stores a single sequence
-type CacheEntry struct {
-	Tokens []int32
-	Caches []cache.Cache
+type kvCache struct {
+	// For now we only support a single entry, so this is just one sequence
+	tokens []int32
+	caches []cache.Cache
 }
 
-// FindNearestCache finds the longest common prefix between tokens and the cached sequence
-func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
-	if r.cache == nil {
-		slog.Info("Cache miss", "left", len(tokens))
-		return nil, tokens
+// cacheSession manages caches for a single pipeline run.
+// Callers should append generated tokens to outputs and
+// defer close to save the cache state.
+type cacheSession struct {
+	cache   *kvCache
+	inputs  []int32
+	outputs []int32
+
+	caches    []cache.Cache
+	remaining []int32
+}
+
+// begin prepares caches for a new request. It finds the nearest
+// matching cache or creates new caches if none match.
+func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
+	if len(c.caches) == 0 {
+		if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
+			c.caches = cacheFactory.NewCaches()
+		} else {
+			c.caches = make([]cache.Cache, m.NumLayers())
+			for i := range c.caches {
+				c.caches[i] = cache.NewKVCache()
+			}
+		}
 	}
 
+	// Find longest common prefix
+	remaining := c.findRemaining(inputs)
+
+	return &cacheSession{
+		cache:     c,
+		inputs:    inputs,
+		caches:    c.caches,
+		remaining: remaining,
+	}
+}
+
+// close saves the token state if the forward pass ran.
+func (s *cacheSession) close() {
+	if offset := s.caches[0].Offset(); offset > 0 {
+		// Ensure that if we have run the forward pass and set the metadata
+		// that we also actually have the data
+		arrays := make([]*mlx.Array, 0, 2*len(s.caches))
+		for _, c := range s.caches {
+			k, v := c.State()
+			arrays = append(arrays, k, v)
+		}
+		mlx.AsyncEval(arrays...)
+
+		s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
+	}
+}
+
+// findRemaining finds the longest common prefix between tokens and the cached
+// sequence, trims stale cache entries, and returns the remaining tokens.
+func (c *kvCache) findRemaining(tokens []int32) []int32 {
 	prefix := 0
-	for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
+	for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
 		prefix++
 	}
 
-	switch {
-	case prefix == 0:
-		for _, c := range r.cache.Caches {
-			c.Free()
-		}
-		r.cache = nil
-		slog.Info("Cache miss", "left", len(tokens))
-		return nil, tokens
-	case prefix < len(r.cache.Tokens):
-		trim := len(r.cache.Tokens) - prefix
-		for _, c := range r.cache.Caches {
-			c.Trim(trim)
-		}
-		r.cache.Tokens = r.cache.Tokens[:prefix]
+	if prefix < len(c.tokens) {
+		trim := len(c.tokens) - prefix
+		for _, kv := range c.caches {
+			kv.Trim(trim)
+		}
+		c.tokens = c.tokens[:prefix]
 	}
 
-	slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
-	return r.cache.Caches, tokens[prefix:]
+	if prefix == 0 {
+		slog.Info("Cache miss", "left", len(tokens))
+	} else {
+		slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
+	}
+
+	return tokens[prefix:]
 }
 
-func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
-	r.cache = &CacheEntry{
-		Tokens: tokens,
-		Caches: caches,
-	}
-}
-
-func (c *CacheEntry) LogCache() {
+func (c *kvCache) log() {
+	if len(c.caches) == 0 {
+		return
+	}
+
 	var totalBytes int
-	for _, kv := range c.Caches {
+	for _, kv := range c.caches {
 		k, v := kv.State()
 		totalBytes += k.NumBytes() + v.NumBytes()
 	}
-	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
+	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
}
```
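`findRemaining` is the core of this single-slot prefix cache: keep the longest common prefix between the cached sequence and the new prompt, `Trim` the stale suffix out of every layer's KV cache, and return only the tokens that still need a forward pass. A standalone illustration of the arithmetic, with plain slices standing in for the MLX-backed caches:

```go
package main

import "fmt"

// splitOnPrefix mirrors findRemaining's bookkeeping: trim is how many
// cached entries are stale, remaining is what still needs prefill.
func splitOnPrefix(cached, tokens []int32) (trim int, remaining []int32) {
	prefix := 0
	for prefix < len(tokens) && prefix < len(cached) && tokens[prefix] == cached[prefix] {
		prefix++
	}
	// Entries past the common prefix no longer match the prompt and must
	// be trimmed before the new tokens are fed through the model.
	return len(cached) - prefix, tokens[prefix:]
}

func main() {
	cached := []int32{1, 2, 3, 4, 5}
	tokens := []int32{1, 2, 3, 9}
	trim, remaining := splitOnPrefix(cached, tokens)
	fmt.Println(trim, remaining) // 2 [9]: drop 2 stale entries, prefill only token 9
}
```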
```diff
@@ -10,7 +10,6 @@ import (
 	"time"
 
 	"github.com/ollama/ollama/logutil"
-	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 )
 
@@ -19,6 +18,23 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		return errors.New("model not loaded")
 	}
 
+	var (
+		sample, logprobs         *mlx.Array
+		nextSample, nextLogprobs *mlx.Array
+	)
+
+	defer func() {
+		mlx.Unpin(sample, logprobs)
+		mlx.Unpin(nextSample, nextLogprobs)
+		mlx.Sweep()
+		mlx.ClearCache()
+
+		if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
+			mlx.LogArrays()
+			r.cache.log()
+		}
+	}()
+
 	enableCompile := true
 	if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
 		enableCompile = modelCompile.EnableCompile()
```
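Hoisting `sample`/`logprobs` and `nextSample`/`nextLogprobs` to the top of the function lets one deferred closure release whatever is pinned on any exit: early error, cancellation, EOS, or normal completion. That discipline only works if the loop nils out values it has already handed off, which a later hunk does (`nextSample, nextLogprobs = nil, nil`). A self-contained sketch, with `unpin` standing in for `mlx.Unpin`:

```go
package main

import "fmt"

type array struct{ id int }

// unpin stands in for mlx.Unpin; the deferred cleanup relies on it
// tolerating nil arguments, since on some paths next was never assigned
// or was already handed off.
func unpin(arrays ...*array) {
	for _, a := range arrays {
		if a != nil {
			fmt.Println("unpin", a.id)
		}
	}
}

func generate() error {
	var cur, next *array
	// One deferred closure covers every exit path.
	defer func() { unpin(cur, next) }()

	cur = &array{id: 1}
	for i := 2; i <= 3; i++ {
		next = &array{id: i}
		// ... decode using cur ...
		unpin(cur)
		cur, next = next, nil // nil out next so the defer cannot double-unpin it
	}
	return nil
}

func main() { _ = generate() } // prints: unpin 1, unpin 2, then unpin 3 from the defer
```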
```diff
@@ -30,22 +46,19 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 	}
 
 	inputs := r.Tokenizer.Encode(request.Prompt, true)
+	session := r.cache.begin(r.Model, inputs)
+	defer session.close()
 
-	caches, tokens := r.FindNearestCache(inputs)
-	if len(caches) == 0 {
-		if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
-			caches = cacheFactory.NewCaches()
-		} else {
-			caches = make([]cache.Cache, r.Model.NumLayers())
-			for i := range caches {
-				caches[i] = cache.NewKVCache()
-			}
-		}
-	}
+	caches := session.caches
+	tokens := session.remaining
 
 	total, processed := len(tokens), 0
 	slog.Info("Prompt processing progress", "processed", processed, "total", total)
 	for total-processed > 1 {
 		if err := request.Ctx.Err(); err != nil {
 			return err
 		}
 
 		n := min(2<<10, total-processed-1)
 		r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
 		mlx.Sweep()
@@ -76,15 +89,18 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		return sample, logprobs
 	}
 
-	sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
+	sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
 
 	var b bytes.Buffer
 
 	now := time.Now()
 	final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
-	outputs := make([]int32, 0, request.Options.MaxTokens)
 	for i := range request.Options.MaxTokens {
-		nextSample, nextLogprobs := step(sample)
 		if err := request.Ctx.Err(); err != nil {
 			return err
 		}
 
+		nextSample, nextLogprobs = step(sample)
+
 		if i == 0 {
 			slog.Info("Prompt processing progress", "processed", total, "total", total)
@@ -94,43 +110,40 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		}
 
 		output := int32(sample.Int())
-		outputs = append(outputs, output)
+		session.outputs = append(session.outputs, output)
 
 		if r.Tokenizer.IsEOS(output) {
-			mlx.Unpin(nextSample, nextLogprobs)
 			final.Token = int(output)
 			final.DoneReason = 0
 			final.CompletionTokens = i
 			break
 		}
 
-		request.Responses <- Response{
-			Text:  r.Decode(output, &b),
-			Token: int(output),
+		select {
+		case <-request.Ctx.Done():
+			return request.Ctx.Err()
+		case request.Responses <- Response{
+			Text:  r.Decode(output, &b),
+			Token: int(output),
+		}:
 		}
 
+		mlx.Unpin(sample, logprobs)
+		sample, logprobs = nextSample, nextLogprobs
+		nextSample, nextLogprobs = nil, nil
+
 		if i%256 == 0 {
 			mlx.ClearCache()
 		}
-
-		sample, logprobs = nextSample, nextLogprobs
 	}
 
-	mlx.Unpin(sample, logprobs)
 	final.CompletionTokensDuration = time.Since(now)
-	request.Responses <- final
-	r.InsertCache(append(inputs, outputs...), caches)
-	mlx.Sweep()
-
-	if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
-		mlx.LogArrays()
-		if r.cache != nil {
-			r.cache.LogCache()
-		}
-	}
-
-	return nil
+	select {
+	case <-request.Ctx.Done():
+		return request.Ctx.Err()
+	case request.Responses <- final:
+		return nil
+	}
 }
 
 func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
```
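The recurring shape in this hunk is the context-aware send: a bare `request.Responses <- ...` blocks forever once the consumer is gone, while pairing the send with `request.Ctx.Done()` lets a cancelled request return early and run the deferred cache and array cleanup. Its generic form:

```go
package main

import (
	"context"
	"errors"
)

// sendCtx sends v on ch unless ctx is cancelled first, in which case it
// returns the context's error instead of blocking indefinitely.
func sendCtx[T any](ctx context.Context, ch chan<- T, v T) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case ch <- v:
		return nil
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	ch := make(chan int) // unbuffered, no receiver
	if err := sendCtx(ctx, ch, 42); !errors.Is(err, context.Canceled) {
		panic("expected cancellation")
	}
}
```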
```diff
@@ -12,7 +12,6 @@ import (
 
 	"golang.org/x/sync/errgroup"
 
-	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	"github.com/ollama/ollama/x/mlxrunner/model"
 	"github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -25,8 +24,9 @@ type Request struct {
 	Responses chan Response
 	Pipeline  func(Request) error
 
+	Ctx context.Context
+
 	sample.Sampler
-	caches []cache.Cache
 }
 
 type TextCompletionsRequest struct {
@@ -61,7 +61,7 @@ type Runner struct {
 	Model     base.Model
 	Tokenizer *tokenizer.Tokenizer
 	Requests  chan Request
-	cache     *CacheEntry
+	cache     kvCache
 }
 
 func (r *Runner) Load(modelName string) error {
@@ -157,7 +157,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
 			return nil
 		case request := <-r.Requests:
 			if err := request.Pipeline(request); err != nil {
-				break
+				slog.Info("Request terminated", "error", err)
 			}
 
 			close(request.Responses)
```
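The removed `break` looks like a classic Go pitfall: inside a `select` case, `break` terminates the `select`, not the enclosing `for`, so on a pipeline error it would skip the rest of the case body, including `close(request.Responses)`, leaving any client ranging over that channel blocked forever. Logging and falling through guarantees the channel is always closed. A small demonstration of the scoping rule:

```go
package main

import "fmt"

func main() {
	ch := make(chan int, 1)
	ch <- 1
	for i := 0; i < 2; i++ {
		select {
		case v := <-ch:
			if v == 1 {
				// break binds to the select, not the for loop: it skips the
				// rest of this case (e.g. a close()) and the loop continues.
				break
			}
			fmt.Println("handled", v)
		default:
			fmt.Println("nothing pending")
		}
	}
	fmt.Println("loop ran to completion") // proves break never exited the for
}
```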
```diff
@@ -5,6 +5,7 @@ package mlxrunner
 import (
 	"bytes"
 	"cmp"
+	"context"
 	"encoding/json"
 	"flag"
 	"fmt"
@@ -98,19 +99,36 @@ func Execute(args []string) error {
 			request.Options.TopK,
 		)
 
-		runner.Requests <- request
+		var cancel context.CancelFunc
+		request.Ctx, cancel = context.WithCancel(r.Context())
+		defer cancel()
+
+		select {
+		case <-r.Context().Done():
+			return
+		case runner.Requests <- request:
+		}
 
 		w.Header().Set("Content-Type", "application/jsonl")
 		w.WriteHeader(http.StatusOK)
 		enc := json.NewEncoder(w)
-		for response := range request.Responses {
-			if err := enc.Encode(response); err != nil {
-				slog.Error("Failed to encode response", "error", err)
-			}
-
-			if f, ok := w.(http.Flusher); ok {
-				f.Flush()
-			}
+		for {
+			select {
+			case <-r.Context().Done():
+				return
+			case response, ok := <-request.Responses:
+				if !ok {
+					return
+				}
+
+				if err := enc.Encode(response); err != nil {
+					slog.Error("Failed to encode response", "error", err)
+					return
+				}
+
+				if f, ok := w.(http.Flusher); ok {
+					f.Flush()
+				}
+			}
 		}
 	})
```
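With `request.Ctx` derived from `r.Context()` and both blocking points (submitting the request and draining responses) wrapped in `select`s, a client disconnect now cancels generation instead of wedging the handler. An end-to-end sketch of the same shape with a toy producer; the route, payloads, and timings are illustrative:

```go
package main

import (
	"context"
	"encoding/json"
	"net/http"
	"time"
)

func main() {
	http.HandleFunc("/gen", func(w http.ResponseWriter, r *http.Request) {
		ctx, cancel := context.WithCancel(r.Context())
		defer cancel() // handler exit, including disconnect, stops the worker

		results := make(chan string)
		go func() {
			defer close(results) // lets the drain loop below terminate
			for i := 0; i < 5; i++ {
				select {
				case <-ctx.Done():
					return
				case results <- "token":
					time.Sleep(10 * time.Millisecond)
				}
			}
		}()

		w.Header().Set("Content-Type", "application/jsonl")
		enc := json.NewEncoder(w)
		for {
			select {
			case <-r.Context().Done():
				return
			case v, ok := <-results:
				if !ok {
					return
				}
				if err := enc.Encode(v); err != nil {
					return
				}
				if f, ok := w.(http.Flusher); ok {
					f.Flush() // stream each JSON line immediately
				}
			}
		}
	})
	_ = http.ListenAndServe("127.0.0.1:8080", nil)
}
```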