Compare commits

...

11 Commits

Author SHA1 Message Date
jmorganca
ebbebdf3b1 better logging / fix defaults 2025-12-20 21:05:34 -08:00
jmorganca
49393385ca better logging 2025-12-20 19:52:05 -08:00
jmorganca
12ff2d1461 tests 2025-12-20 19:23:09 -08:00
jmorganca
f90d968b8b skip 2025-12-20 19:10:35 -08:00
jmorganca
c623b256a3 better logs 2025-12-20 18:13:00 -08:00
jmorganca
8c8fb2f9f0 tweak 2025-12-20 18:10:24 -08:00
jmorganca
6e00a0c89a speed tracker 2025-12-20 18:02:52 -08:00
jmorganca
55b1ee2557 wip 2025-12-20 18:01:07 -08:00
jmorganca
51cb1155ba streaming 2025-12-20 17:39:55 -08:00
jmorganca
7c5b656bb3 wip 2025-12-20 17:24:03 -08:00
jmorganca
bddb27ab5b client2 updated 2025-12-20 16:16:34 -08:00
13 changed files with 887 additions and 411 deletions

View File

@@ -2,9 +2,11 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"log/slog"
"math"
@@ -31,9 +33,45 @@ const maxRetries = 6
var (
errMaxRetriesExceeded = errors.New("max retries exceeded")
errPartStalled = errors.New("part stalled")
errPartSlow = errors.New("part slow, racing")
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
)
// speedTracker tracks download speeds and computes rolling median.
type speedTracker struct {
mu sync.Mutex
speeds []float64 // bytes per second
}
func (s *speedTracker) Record(bytesPerSec float64) {
s.mu.Lock()
defer s.mu.Unlock()
s.speeds = append(s.speeds, bytesPerSec)
// Keep last 100 samples
if len(s.speeds) > 100 {
s.speeds = s.speeds[1:]
}
}
func (s *speedTracker) Median() float64 {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.speeds) < 3 {
return 0 // not enough data
}
// Simple median: sort a copy and take middle
sorted := make([]float64, len(s.speeds))
copy(sorted, s.speeds)
for i := range sorted {
for j := i + 1; j < len(sorted); j++ {
if sorted[j] < sorted[i] {
sorted[i], sorted[j] = sorted[j], sorted[i]
}
}
}
return sorted[len(sorted)/2]
}
var blobDownloadManager sync.Map
type blobDownload struct {
@@ -94,26 +132,127 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
return nil
}
const (
numDownloadParts = 16
minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte
var (
downloadPartSize = int64(envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte
downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 48)
)
func envInt(key string, defaultVal int) int {
if s := os.Getenv(key); s != "" {
if v, err := strconv.Atoi(s); err == nil {
return v
}
}
return defaultVal
}
// streamHasher reads a file sequentially and hashes it as chunks complete.
// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency.
// Works by reading from OS page cache - data just written is still in RAM.
type streamHasher struct {
file *os.File
hasher hash.Hash
parts []*blobDownloadPart
total int64 // total bytes to hash
hashed atomic.Int64
mu sync.Mutex
cond *sync.Cond
completed []bool
done bool
err error
}
func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher {
h := &streamHasher{
file: file,
hasher: sha256.New(),
parts: parts,
total: total,
completed: make([]bool, len(parts)),
}
h.cond = sync.NewCond(&h.mu)
return h
}
// MarkComplete signals that a part has been written to disk.
func (h *streamHasher) MarkComplete(partIndex int) {
h.mu.Lock()
h.completed[partIndex] = true
h.cond.Broadcast()
h.mu.Unlock()
}
// Run reads and hashes the file sequentially. Call in a goroutine.
func (h *streamHasher) Run() {
buf := make([]byte, 64*1024) // 64KB read buffer
var offset int64
for i, part := range h.parts {
// Wait for this part to be written
h.mu.Lock()
for !h.completed[i] && !h.done {
h.cond.Wait()
}
if h.done {
h.mu.Unlock()
return
}
h.mu.Unlock()
// Read and hash this part (from page cache)
remaining := part.Size
for remaining > 0 {
n := int64(len(buf))
if n > remaining {
n = remaining
}
nr, err := h.file.ReadAt(buf[:n], offset)
if err != nil && err != io.EOF {
h.mu.Lock()
h.err = err
h.mu.Unlock()
return
}
h.hasher.Write(buf[:nr])
offset += int64(nr)
remaining -= int64(nr)
h.hashed.Store(offset)
}
}
}
// Stop signals the hasher to exit early.
func (h *streamHasher) Stop() {
h.mu.Lock()
h.done = true
h.cond.Broadcast()
h.mu.Unlock()
}
// Hashed returns bytes hashed so far.
func (h *streamHasher) Hashed() int64 {
return h.hashed.Load()
}
// Digest returns the computed hash.
func (h *streamHasher) Digest() string {
return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil))
}
// Err returns any error from hashing.
func (h *streamHasher) Err() error {
h.mu.Lock()
defer h.mu.Unlock()
return h.err
}
func (p *blobDownloadPart) Name() string {
return strings.Join([]string{
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
}, "-")
}
func (p *blobDownloadPart) StartsAt() int64 {
return p.Offset + p.Completed.Load()
}
func (p *blobDownloadPart) StopsAt() int64 {
return p.Offset + p.Size
}
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b)
p.blobDownload.Completed.Add(int64(n))
@@ -151,14 +290,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
size := b.Total / numDownloadParts
switch {
case size < minDownloadPartSize:
size = minDownloadPartSize
case size > maxDownloadPartSize:
size = maxDownloadPartSize
}
size := downloadPartSize
var offset int64
for offset < b.Total {
if offset+size > b.Total {
@@ -220,9 +352,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
defer file.Close()
setSparse(file)
_ = file.Truncate(b.Total)
directURL, err := func() (*url.URL, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
@@ -270,44 +399,106 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
// Download chunks to disk, hash by reading from page cache.
// Memory: ~64KB (hasher read buffer only), regardless of concurrency.
// The hasher follows behind the downloaders, reading recently-written
// data from OS page cache (RAM) rather than disk.
sh := newStreamHasher(file, b.Parts, b.Total)
tracker := &speedTracker{}
// Start hasher goroutine
hashDone := make(chan struct{})
go func() {
sh.Run()
close(hashDone)
}()
// Log progress periodically
// Page cache warning: if spread > 4GB, hasher may hit disk instead of RAM
const pageCacheWarningBytes = 4 << 30 // 4GB
progressDone := make(chan struct{})
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
downloaded := b.Completed.Load()
hashed := sh.Hashed()
dlPct := int(downloaded * 100 / b.Total)
hPct := int(hashed * 100 / b.Total)
spread := dlPct - hPct
spreadBytes := downloaded - hashed
slog.Debug(fmt.Sprintf("progress: downloaded %d%% | hashed %d%% | spread %d%%", dlPct, hPct, spread))
if spreadBytes > pageCacheWarningBytes {
slog.Debug("page cache pressure", "ahead", fmt.Sprintf("%.1fGB", float64(spreadBytes)/(1<<30)))
}
case <-progressDone:
return
}
}
}()
g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts)
g.SetLimit(downloadConcurrency)
for i := range b.Parts {
part := b.Parts[i]
if part.Completed.Load() == part.Size {
sh.MarkComplete(part.N)
continue
}
g.Go(func() error {
var err error
var slowRetries int
for try := 0; try < maxRetries; try++ {
w := io.NewOffsetWriter(file, part.StartsAt())
err = b.downloadChunk(inner, directURL, w, part)
// After 3 slow retries, stop checking slowness and let it complete
skipSlowCheck := slowRetries >= 3
err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space
return err
case errors.Is(err, errPartStalled):
try--
continue
case errors.Is(err, errPartSlow):
// Kill slow request, retry immediately (stays within concurrency limit)
slowRetries++
try--
continue
case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
time.Sleep(sleep)
continue
default:
sh.MarkComplete(part.N)
return nil
}
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
})
}
if err := g.Wait(); err != nil {
close(progressDone)
sh.Stop()
return err
}
// Wait for hasher to finish
<-hashDone
close(progressDone)
if err := sh.Err(); err != nil {
return err
}
// Verify hash
if computed := sh.Digest(); computed != b.Digest {
return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
}
// explicitly close the file so we can rename it
if err := file.Close(); err != nil {
return err
@@ -326,38 +517,69 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return nil
}
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
// downloadChunkToDisk streams a part directly to disk at its offset.
// Memory: ~32KB (read buffer only).
// If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries).
func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error {
g, ctx := errgroup.WithContext(ctx)
startTime := time.Now()
var bytesAtLastCheck atomic.Int64
g.Go(func() error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.Offset, part.Offset+part.Size-1))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress
b.Completed.Add(-n)
return err
w := io.NewOffsetWriter(file, part.Offset)
buf := make([]byte, 32*1024)
var written int64
for written < part.Size {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, werr := w.Write(buf[:n]); werr != nil {
return werr
}
written += int64(n)
b.Completed.Add(int64(n))
bytesAtLastCheck.Store(written)
part.lastUpdatedMu.Lock()
part.lastUpdated = time.Now()
part.lastUpdatedMu.Unlock()
}
if err == io.EOF {
break
}
if err != nil {
b.Completed.Add(-written)
return err
}
}
part.Completed.Add(n)
if err := b.writePart(part.Name(), part); err != nil {
return err
// Record speed for this part
elapsed := time.Since(startTime).Seconds()
if elapsed > 0 {
tracker.Record(float64(part.Size) / elapsed)
}
// return nil or context.Canceled or UnexpectedEOF (resumable)
return err
part.Completed.Store(part.Size)
return b.writePart(part.Name(), part)
})
g.Go(func() error {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
var lastBytes int64
checksWithoutProgress := 0
for {
select {
case <-ticker.C:
@@ -365,19 +587,47 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
return nil
}
currentBytes := bytesAtLastCheck.Load()
// Check for complete stall (30 seconds no progress)
part.lastUpdatedMu.Lock()
lastUpdated := part.lastUpdated
part.lastUpdatedMu.Unlock()
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
// reset last updated
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
part.lastUpdatedMu.Lock()
part.lastUpdated = time.Time{}
part.lastUpdatedMu.Unlock()
return errPartStalled
}
// Check for slow speed after 5+ seconds (only for multi-part downloads)
// Skip if we've already retried for slowness too many times
elapsed := time.Since(startTime).Seconds()
if !skipSlowCheck && elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 {
currentSpeed := float64(currentBytes) / elapsed
median := tracker.Median()
// If we're below 10% of median speed, flag as slow
if median > 0 && currentSpeed < median*0.1 {
slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); retrying",
b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
return errPartSlow
}
}
// Also check if speed dropped significantly mid-download
if currentBytes == lastBytes {
checksWithoutProgress++
if checksWithoutProgress >= 10 {
slog.Info(fmt.Sprintf("%s part %d no progress for 10s; retrying", b.Digest[7:19], part.N))
return errPartStalled
}
} else {
checksWithoutProgress = 0
}
lastBytes = currentBytes
case <-ctx.Done():
return ctx.Err()
}

319
server/download_test.go Normal file
View File

@@ -0,0 +1,319 @@
package server
import (
"crypto/rand"
"crypto/sha256"
"fmt"
"os"
"sync"
"testing"
)
func TestSpeedTracker_Median(t *testing.T) {
s := &speedTracker{}
// Less than 3 samples returns 0
s.Record(100)
s.Record(200)
if got := s.Median(); got != 0 {
t.Errorf("expected 0 with < 3 samples, got %f", got)
}
// With 3+ samples, returns median
s.Record(300)
// Samples: [100, 200, 300] -> median = 200
if got := s.Median(); got != 200 {
t.Errorf("expected median 200, got %f", got)
}
// Add more samples
s.Record(50)
s.Record(250)
// Samples: [100, 200, 300, 50, 250] sorted = [50, 100, 200, 250, 300] -> median = 200
if got := s.Median(); got != 200 {
t.Errorf("expected median 200, got %f", got)
}
}
func TestSpeedTracker_RollingWindow(t *testing.T) {
s := &speedTracker{}
// Add 105 samples (should keep only last 100)
for i := 0; i < 105; i++ {
s.Record(float64(i))
}
s.mu.Lock()
if len(s.speeds) != 100 {
t.Errorf("expected 100 samples, got %d", len(s.speeds))
}
// First sample should be 5 (0-4 were dropped)
if s.speeds[0] != 5 {
t.Errorf("expected first sample to be 5, got %f", s.speeds[0])
}
s.mu.Unlock()
}
func TestSpeedTracker_Concurrent(t *testing.T) {
s := &speedTracker{}
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(v int) {
defer wg.Done()
s.Record(float64(v))
s.Median() // concurrent read
}(i)
}
wg.Wait()
// Should not panic, and should have reasonable state
s.mu.Lock()
if len(s.speeds) == 0 || len(s.speeds) > 100 {
t.Errorf("unexpected speeds length: %d", len(s.speeds))
}
s.mu.Unlock()
}
func TestStreamHasher_Sequential(t *testing.T) {
// Create temp file
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
// Write test data
data := []byte("hello world, this is a test of the stream hasher")
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
// Create parts
parts := []*blobDownloadPart{
{Offset: 0, Size: int64(len(data))},
}
sh := newStreamHasher(f, parts, int64(len(data)))
// Mark complete and run
sh.MarkComplete(0)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
<-done
// Verify digest
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
if got := sh.Digest(); got != expected {
t.Errorf("digest mismatch: got %s, want %s", got, expected)
}
if err := sh.Err(); err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func TestStreamHasher_OutOfOrderCompletion(t *testing.T) {
// Create temp file
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
// Write test data (3 parts of 10 bytes each)
data := []byte("0123456789ABCDEFGHIJabcdefghij")
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
// Create 3 parts
parts := []*blobDownloadPart{
{N: 0, Offset: 0, Size: 10},
{N: 1, Offset: 10, Size: 10},
{N: 2, Offset: 20, Size: 10},
}
sh := newStreamHasher(f, parts, int64(len(data)))
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Mark parts complete out of order: 2, 0, 1
sh.MarkComplete(2)
sh.MarkComplete(0) // This should trigger hashing of part 0
sh.MarkComplete(1) // This should trigger hashing of parts 1 and 2
<-done
// Verify digest
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
if got := sh.Digest(); got != expected {
t.Errorf("digest mismatch: got %s, want %s", got, expected)
}
}
func TestStreamHasher_Stop(t *testing.T) {
// Create temp file
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
parts := []*blobDownloadPart{
{Offset: 0, Size: 100},
}
sh := newStreamHasher(f, parts, 100)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Stop without completing any parts
sh.Stop()
<-done
// Should exit cleanly without error
if err := sh.Err(); err != nil {
t.Errorf("unexpected error after Stop: %v", err)
}
}
func TestStreamHasher_HashedProgress(t *testing.T) {
// Create temp file with known data
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
data := make([]byte, 1000)
rand.Read(data)
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
parts := []*blobDownloadPart{
{N: 0, Offset: 0, Size: 500},
{N: 1, Offset: 500, Size: 500},
}
sh := newStreamHasher(f, parts, 1000)
// Initially no progress
if got := sh.Hashed(); got != 0 {
t.Errorf("expected 0 hashed initially, got %d", got)
}
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Complete part 0
sh.MarkComplete(0)
// Give hasher time to process
for i := 0; i < 100; i++ {
if sh.Hashed() >= 500 {
break
}
}
// Complete part 1
sh.MarkComplete(1)
<-done
if got := sh.Hashed(); got != 1000 {
t.Errorf("expected 1000 hashed, got %d", got)
}
}
func BenchmarkSpeedTracker_Record(b *testing.B) {
s := &speedTracker{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
s.Record(float64(i))
}
}
func BenchmarkSpeedTracker_Median(b *testing.B) {
s := &speedTracker{}
// Pre-populate with 100 samples
for i := 0; i < 100; i++ {
s.Record(float64(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
s.Median()
}
}
func BenchmarkStreamHasher(b *testing.B) {
// Create temp file with test data
f, err := os.CreateTemp("", "streamhasher_bench")
if err != nil {
b.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
size := 64 * 1024 * 1024 // 64MB
data := make([]byte, size)
rand.Read(data)
if _, err := f.Write(data); err != nil {
b.Fatal(err)
}
parts := []*blobDownloadPart{
{Offset: 0, Size: int64(size)},
}
b.SetBytes(int64(size))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sh := newStreamHasher(f, parts, int64(size))
sh.MarkComplete(0)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
<-done
}
}
func BenchmarkHashThroughput(b *testing.B) {
// Baseline: raw SHA256 throughput on this machine
size := 256 * 1024 * 1024 // 256MB
data := make([]byte, size)
rand.Read(data)
b.SetBytes(int64(size))
b.ResetTimer()
for i := 0; i < b.N; i++ {
h := sha256.New()
h.Write(data)
h.Sum(nil)
}
}

View File

@@ -620,9 +620,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
_, err := downloadBlob(ctx, downloadOpts{
mp: mp,
digest: layer.Digest,
regOpts: regOpts,
@@ -631,31 +630,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if err != nil {
return err
}
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if skipVerify[layer.Digest] {
continue
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
return err
}
}
// Note: Digest verification now happens inline during download in blobDownload.run()
// via the orderedWriter, so no separate verification pass is needed.
fn(api.ProgressResponse{Status: "writing manifest"})

View File

@@ -10,7 +10,6 @@ import (
"hash"
"io"
"io/fs"
"iter"
"os"
"path/filepath"
"strings"
@@ -327,21 +326,19 @@ func (c *DiskCache) GetFile(d Digest) string {
return absJoin(c.dir, "blobs", filename)
}
// Links returns a sequence of link names. The sequence is in lexical order.
// Links returns a slice of link names in lexical order.
// Names are converted from their relative path form to their name form but are
// not guaranteed to be valid. Callers should validate the names before using.
func (c *DiskCache) Links() iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
for path, err := range c.links() {
if err != nil {
yield("", err)
return
}
if !yield(pathToName(path), nil) {
return
}
}
func (c *DiskCache) Links() ([]string, error) {
paths, err := c.links()
if err != nil {
return nil, err
}
names := make([]string, len(paths))
for i, path := range paths {
names[i] = pathToName(path)
}
return names, nil
}
// pathToName converts a path to a name. It is the inverse of nameToPath. The
@@ -372,10 +369,11 @@ func (c *DiskCache) manifestPath(name string) (string, error) {
}
maybe := filepath.Join("manifests", np)
for l, err := range c.links() {
if err != nil {
return "", err
}
paths, err := c.links()
if err != nil {
return "", err
}
for _, l := range paths {
if strings.EqualFold(maybe, l) {
return filepath.Join(c.dir, l), nil
}
@@ -383,22 +381,10 @@ func (c *DiskCache) manifestPath(name string) (string, error) {
return filepath.Join(c.dir, maybe), nil
}
// links returns a sequence of links in the cache in lexical order.
func (c *DiskCache) links() iter.Seq2[string, error] {
// TODO(bmizerany): reuse empty dirnames if exist
return func(yield func(string, error) bool) {
fsys := os.DirFS(c.dir)
manifests, err := fs.Glob(fsys, "manifests/*/*/*/*")
if err != nil {
yield("", err)
return
}
for _, manifest := range manifests {
if !yield(manifest, nil) {
return
}
}
}
// links returns a slice of link paths in the cache in lexical order.
func (c *DiskCache) links() ([]string, error) {
fsys := os.DirFS(c.dir)
return fs.Glob(fsys, "manifests/*/*/*/*")
}
type checkWriter struct {

View File

@@ -466,12 +466,9 @@ func testManifestNameReuse(t *testing.T) {
t.Fatalf("g = %v, want %v", g, w)
}
var got []string
for l, err := range c.links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
got, err := c.links()
if err != nil {
t.Fatal(err)
}
want := []string{"manifests/h/n/m/t"}
if !slices.Equal(got, want) {
@@ -487,12 +484,9 @@ func testManifestNameReuse(t *testing.T) {
err = c.Link("h/n/m:T", d1)
check(err)
got = got[:0]
for l, err := range c.links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
got, err = c.links()
if err != nil {
t.Fatal(err)
}
// we should have only one link that is same case as the last link
@@ -554,12 +548,9 @@ func TestNames(t *testing.T) {
check(c.Link("h/n/m:t", mkdigest("1")))
check(c.Link("h/n/m:u", mkdigest("2")))
var got []string
for l, err := range c.Links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
got, err := c.Links()
if err != nil {
t.Fatal(err)
}
want := []string{"h/n/m:t", "h/n/m:u"}
if !slices.Equal(got, want) {

View File

@@ -19,7 +19,6 @@ import (
"fmt"
"io"
"io/fs"
"iter"
"log/slog"
"net/http"
"os"
@@ -546,18 +545,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
})
}()
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Note the chunksum stream
// interruption, but do not cancel
// in-flight downloads. We can still
// make progress on them. Once they are
// done, ErrIncomplete will be returned
// below.
update(0, err)
break
}
err = r.chunksums(ctx, name, l, func(cs chunksum) bool {
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
@@ -569,7 +557,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
_, err := c.Get(cacheKeyDigest)
if err == nil {
update(cs.Chunk.Size(), ErrCached)
continue
return true // continue
}
wg.Add(1)
@@ -620,6 +608,13 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// Record the downloading of this chunk.
return blob.PutBytes(c, cacheKeyDigest, cacheKey)
})
return true // continue processing chunks
})
if err != nil {
// Note the chunksum stream interruption, but do not cancel
// in-flight downloads. We can still make progress on them.
// Once they are done, ErrIncomplete will be returned below.
update(0, err)
}
return nil
@@ -674,19 +669,6 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
return nil
}
func (m *Manifest) All() iter.Seq[*Layer] {
return func(yield func(*Layer) bool) {
if !yield(m.Config) {
return
}
for _, l := range m.Layers {
if !yield(l) {
return
}
}
}
}
func (m *Manifest) Size() int64 {
var size int64
if m.Config != nil {
@@ -811,125 +793,114 @@ type chunksum struct {
Digest blob.Digest
}
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
return func(yield func(chunksum, error) bool) {
scheme, n, _, err := r.parseNameExtended(name)
// chunksums calls fn for each chunksum in the layer. If the layer is under the
// chunking threshold, a single chunksum covering the entire layer is passed to fn.
// If the layer is over the chunking threshold, chunksums are read from the chunksums endpoint.
// Returns an error if the chunksum stream fails, or nil if all chunksums were processed.
// If fn returns false, iteration stops early and chunksums returns nil.
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer, fn func(chunksum) bool) error {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
}
fn(cs)
return nil
}
// The response is a sequence of chunksums.
//
// Chunksums are chunks of a larger blob that can be
// downloaded and verified independently.
//
// The chunksums endpoint is a GET request that returns a
// sequence of chunksums in the following format:
//
// > GET /v2/<namespace>/<model>/chunksums/<digest>
//
// < HTTP/1.1 200 OK
// < Content-Location: <blobURL>
// <
// < <digest> <start>-<end>
// < ...
//
// The <blobURL> is the URL to download the chunks from and
// each <digest> is the digest of the chunk, and <start>-<end>
// is the range the chunk in the blob.
//
// Ranges may be used directly in Range headers like
// "bytes=<start>-<end>".
//
// The chunksums returned are guaranteed to be contiguous and
// include all bytes of the layer. If the stream is cut short,
// clients should retry.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
if err != nil {
return err
}
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != 200 {
return fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
}
blobURL := res.Header.Get("Content-Location")
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
return s.Err()
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, err)
return
return fmt.Errorf("invalid digest: %q", s.Bytes())
}
if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(cs, nil)
return
return err
}
// The response is a sequence of chunksums.
//
// Chunksums are chunks of a larger blob that can be
// downloaded and verified independently.
//
// The chunksums endpoint is a GET request that returns a
// sequence of chunksums in the following format:
//
// > GET /v2/<namespace>/<model>/chunksums/<digest>
//
// < HTTP/1.1 200 OK
// < Content-Location: <blobURL>
// <
// < <digest> <start>-<end>
// < ...
//
// The <blobURL> is the URL to download the chunks from and
// each <digest> is the digest of the chunk, and <start>-<end>
// is the range the chunk in the blob.
//
// Ranges may be used directly in Range headers like
// "bytes=<start>-<end>".
//
// The chunksums returned are guaranteed to be contiguous and
// include all bytes of the layer. If the stream is cut short,
// clients should retry.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
chunk, err := parseChunk(s.Bytes())
if err != nil {
yield(chunksum{}, err)
return
return fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes())
}
res, err := sendRequest(r.client(), req)
if err != nil {
yield(chunksum{}, err)
return
cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
defer res.Body.Close()
if res.StatusCode != 200 {
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
yield(chunksum{}, err)
return
}
blobURL := res.Header.Get("Content-Location")
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
}
return
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
return
}
if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(chunksum{}, err)
return
}
chunk, err := parseChunk(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
return
}
cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
if !yield(cs, nil) {
return
}
if !fn(cs) {
return nil
}
}
}
@@ -1176,8 +1147,8 @@ func splitExtended(s string) (scheme, name, digest string) {
return scheme, s, digest
}
// parseChunk parses a string in the form "start-end" and returns the Chunk.
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
// parseChunk parses a byte slice in the form "start-end" and returns the Chunk.
func parseChunk(s []byte) (blob.Chunk, error) {
startPart, endPart, found := strings.Cut(string(s), "-")
if !found {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)

View File

@@ -27,46 +27,20 @@ type Trace struct {
}
func (t *Trace) update(l *Layer, n int64, err error) {
if t.Update != nil {
if t != nil && t.Update != nil {
t.Update(l, n, err)
}
}
type traceKey struct{}
// WithTrace adds a trace to the context for transfer progress reporting.
// WithTrace attaches a Trace to the context for transfer progress reporting.
func WithTrace(ctx context.Context, t *Trace) context.Context {
old := traceFromContext(ctx)
if old == t {
// No change, return the original context. This also prevents
// infinite recursion below, if the caller passes the same
// Trace.
return ctx
}
// Create a new Trace that wraps the old one, if any. If we used the
// same pointer t, we end up with a recursive structure.
composed := &Trace{
Update: func(l *Layer, n int64, err error) {
if old != nil {
old.update(l, n, err)
}
t.update(l, n, err)
},
}
return context.WithValue(ctx, traceKey{}, composed)
return context.WithValue(ctx, traceKey{}, t)
}
var emptyTrace = &Trace{}
// traceFromContext returns the Trace associated with ctx, or an empty Trace if
// none is found.
//
// It never returns nil.
// traceFromContext returns the Trace associated with ctx, or nil if none.
func traceFromContext(ctx context.Context) *Trace {
t, _ := ctx.Value(traceKey{}).(*Trace)
if t == nil {
return emptyTrace
}
return t
}

View File

@@ -2,44 +2,46 @@ package backoff
import (
"context"
"iter"
"math/rand/v2"
"time"
)
func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
var n int
return func(yield func(int, error) bool) {
var t *time.Timer
for {
if ctx.Err() != nil {
yield(n, ctx.Err())
return
}
// Retry calls fn repeatedly with exponential backoff until it returns nil,
// a non-retryable error (shouldRetry returns false), or the context is cancelled.
// The shouldRetry function determines if an error is retryable.
// Returns the last error encountered, or nil if fn succeeded.
func Retry(ctx context.Context, maxBackoff time.Duration, shouldRetry func(error) bool, fn func() error) error {
var t *time.Timer
for n := 0; ; n++ {
if err := ctx.Err(); err != nil {
return err
}
if !yield(n, nil) {
return
}
err := fn()
if err == nil {
return nil
}
if !shouldRetry(err) {
return err
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
if t == nil {
t = time.NewTimer(d)
} else {
t.Reset(d)
}
select {
case <-ctx.Done():
t.Stop()
case <-t.C:
}
if t == nil {
t = time.NewTimer(d)
} else {
t.Reset(d)
}
select {
case <-ctx.Done():
t.Stop()
return ctx.Err()
case <-t.C:
}
}
}

View File

@@ -10,31 +10,70 @@ import (
"time"
)
func TestLoop(t *testing.T) {
func TestRetry(t *testing.T) {
synctest.Run(func() {
last := -1
n := 0
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
for n, err := range Loop(ctx, 100*time.Millisecond) {
if !errors.Is(err, ctx.Err()) {
t.Errorf("err = %v, want nil", err)
}
if err != nil {
break
}
if n != last+1 {
t.Errorf("n = %d, want %d", n, last+1)
}
last = n
err := Retry(ctx, 100*time.Millisecond, func(err error) bool { return true }, func() error {
n++
if n > 5 {
cancel()
}
return errors.New("keep going")
})
if !errors.Is(err, context.Canceled) {
t.Errorf("err = %v, want context.Canceled", err)
}
if last != 6 {
t.Errorf("last = %d, want 6", last)
if n != 6 {
t.Errorf("n = %d, want 6", n)
}
})
}
func TestRetrySuccess(t *testing.T) {
synctest.Run(func() {
n := 0
err := Retry(t.Context(), 100*time.Millisecond, func(err error) bool { return true }, func() error {
n++
if n >= 3 {
return nil // success
}
return errors.New("retry")
})
if err != nil {
t.Errorf("err = %v, want nil", err)
}
if n != 3 {
t.Errorf("n = %d, want 3", n)
}
})
}
func TestRetryNonRetryable(t *testing.T) {
synctest.Run(func() {
permanent := errors.New("permanent error")
n := 0
err := Retry(t.Context(), 100*time.Millisecond, func(err error) bool {
return !errors.Is(err, permanent)
}, func() error {
n++
if n >= 2 {
return permanent
}
return errors.New("retry")
})
if !errors.Is(err, permanent) {
t.Errorf("err = %v, want permanent", err)
}
if n != 2 {
t.Errorf("n = %d, want 2", n)
}
})
}

View File

@@ -3,37 +3,46 @@
package backoff
import (
"errors"
"testing"
"testing/synctest"
"time"
)
func TestLoopAllocs(t *testing.T) {
var errRetry = errors.New("retry")
func TestRetryAllocs(t *testing.T) {
for i := range 3 {
got := testing.AllocsPerRun(1000, func() {
for tick := range Loop(t.Context(), 1) {
tick := 0
Retry(t.Context(), 1, func(err error) bool { return true }, func() error {
tick++
if tick >= i {
break
return nil
}
}
return errRetry
})
})
want := float64(0)
if i > 0 {
want = 3 // due to time.NewTimer
}
if got > want {
t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
t.Errorf("[%d ticks]: allocs = %v, want <= %v", i, got, want)
}
}
}
func BenchmarkLoop(b *testing.B) {
func BenchmarkRetry(b *testing.B) {
ctx := b.Context()
synctest.Run(func() {
for n := range Loop(ctx, 100*time.Millisecond) {
n := 0
Retry(ctx, 100*time.Millisecond, func(err error) bool { return true }, func() error {
n++
if n == b.N {
break
return nil
}
}
return errRetry
})
})
}

View File

@@ -231,7 +231,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "DELETE" {
return errMethodNotAllowed
}
p, err := decodeUserJSON[*params](r.Body)
p, err := decodeParams(r.Body)
if err != nil {
return err
}
@@ -261,7 +261,7 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
return errMethodNotAllowed
}
p, err := decodeUserJSON[*params](r.Body)
p, err := decodeParams(r.Body)
if err != nil {
return err
}
@@ -293,10 +293,14 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
}
}
t := time.NewTicker(1<<63 - 1) // "unstarted" timer
// ticker controls periodic progress flushing. It starts paused (very long
// interval) and is activated by start() once all layers are registered,
// so clients see a complete total before progress begins.
ticker := time.NewTicker(1 << 62) // effectively paused until started
defer ticker.Stop()
start := sync.OnceFunc(func() {
flushProgress() // flush initial state
t.Reset(100 * time.Millisecond)
flushProgress()
ticker.Reset(100 * time.Millisecond)
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
@@ -320,36 +324,21 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
})
}()
// Block flushing progress updates until every
// layer is accounted for. Clients depend on a
// complete model size to calculate progress
// correctly; if they use an incomplete total,
// progress indicators would erratically jump
// as new layers are registered.
start()
},
})
done := make(chan error, 1)
go func() (err error) {
defer func() { done <- err }()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := s.Client.Pull(ctx, p.model())
if canRetry(err) {
continue
}
return err
}
return nil
go func() {
done <- backoff.Retry(ctx, 3*time.Second, canRetry, func() error {
return s.Client.Pull(ctx, p.model())
})
}()
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
for {
select {
case <-t.C:
case <-ticker.C:
flushProgress()
case err := <-done:
flushProgress()
@@ -374,20 +363,13 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
}
}
func decodeUserJSON[T any](r io.Reader) (T, error) {
var v T
err := json.NewDecoder(r).Decode(&v)
func decodeParams(r io.Reader) (*params, error) {
var p params
err := json.NewDecoder(r).Decode(&p)
if err == nil {
return v, nil
return &p, nil
}
var zero T
// Not sure why, but I can't seem to be able to use:
//
// errors.As(err, &json.UnmarshalTypeError{})
//
// This is working fine in stdlib, so I'm not sure what rules changed
// and why this no longer works here. So, we do it the verbose way.
var a *json.UnmarshalTypeError
var b *json.SyntaxError
if errors.As(err, &a) || errors.As(err, &b) {
@@ -396,7 +378,7 @@ func decodeUserJSON[T any](r io.Reader) (T, error) {
if errors.Is(err, io.EOF) {
err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
}
return zero, err
return nil, err
}
func canRetry(err error) bool {
@@ -408,10 +390,8 @@ func canRetry(err error) bool {
return oe.Temporary()
}
s := err.Error()
return cmp.Or(
errors.Is(err, context.DeadlineExceeded),
strings.Contains(s, "unreachable"),
strings.Contains(s, "no route to host"),
strings.Contains(s, "connection reset by peer"),
)
return errors.Is(err, context.DeadlineExceeded) ||
strings.Contains(s, "unreachable") ||
strings.Contains(s, "no route to host") ||
strings.Contains(s, "connection reset by peer")
}

View File

@@ -1,8 +0,0 @@
//go:build !windows
package server
import "os"
func setSparse(*os.File) {
}

View File

@@ -1,17 +0,0 @@
package server
import (
"os"
"golang.org/x/sys/windows"
)
func setSparse(file *os.File) {
// exFat (and other FS types) don't support sparse files, so ignore errors
windows.DeviceIoControl( //nolint:errcheck
windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
nil, 0,
nil, 0,
nil, nil,
)
}