Compare commits

...

4 Commits

Author SHA1 Message Date
jmorganca
bf63d18b11 linter 2025-12-20 22:46:46 -08:00
jmorganca
bdb9ea4772 cleanup 2025-12-20 22:37:14 -08:00
jmorganca
9a8c2a4635 revert unwanted changes 2025-12-20 22:00:35 -08:00
jmorganca
2aee6c172b server: stream hash verification during download
Hash blob data while downloading (by trying to using page cache as much as possible)
instead of after, improving download speeds. Add configurable download concurrency
(default 48) and part size (default 64MB) for faster downloads on high-bandwidth
connections.
2025-12-20 21:43:18 -08:00
4 changed files with 576 additions and 110 deletions

View File

@@ -2,9 +2,11 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"log/slog"
"math"
@@ -13,6 +15,7 @@ import (
"net/url"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
@@ -31,9 +34,38 @@ const maxRetries = 6
var (
errMaxRetriesExceeded = errors.New("max retries exceeded")
errPartStalled = errors.New("part stalled")
errPartSlow = errors.New("part too slow")
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
)
// speedTracker tracks download speeds and computes rolling median.
type speedTracker struct {
mu sync.Mutex
speeds []float64 // bytes per second
}
func (s *speedTracker) Record(bytesPerSec float64) {
s.mu.Lock()
defer s.mu.Unlock()
s.speeds = append(s.speeds, bytesPerSec)
// Keep last 30 samples (flushes stale speeds faster when conditions change)
if len(s.speeds) > 30 {
s.speeds = s.speeds[1:]
}
}
func (s *speedTracker) Median() float64 {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.speeds) < 10 {
return 0 // not enough data for reliable median
}
sorted := slices.Clone(s.speeds)
slices.Sort(sorted)
return sorted[len(sorted)/2]
}
var blobDownloadManager sync.Map
type blobDownload struct {
@@ -58,9 +90,6 @@ type blobDownloadPart struct {
Size int64
Completed atomic.Int64
lastUpdatedMu sync.Mutex
lastUpdated time.Time
*blobDownload `json:"-"`
}
@@ -94,32 +123,130 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
return nil
}
const (
numDownloadParts = 16
minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte
var (
downloadPartSize = int64(envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte
downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 32)
)
func envInt(key string, defaultVal int) int {
if s := os.Getenv(key); s != "" {
if v, err := strconv.Atoi(s); err == nil {
return v
}
}
return defaultVal
}
// streamHasher reads a file sequentially and hashes it as chunks complete.
// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency.
// Works by trying to read from OS page cache - data just written should still be in RAM.
type streamHasher struct {
file *os.File
hasher hash.Hash
parts []*blobDownloadPart
total int64 // total bytes to hash
hashed atomic.Int64
mu sync.Mutex
cond *sync.Cond
completed []bool
done bool
err error
}
func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher {
h := &streamHasher{
file: file,
hasher: sha256.New(),
parts: parts,
total: total,
completed: make([]bool, len(parts)),
}
h.cond = sync.NewCond(&h.mu)
return h
}
// Done signals that a part has been written to disk.
func (h *streamHasher) Done(partIndex int) {
h.mu.Lock()
h.completed[partIndex] = true
h.cond.Broadcast()
h.mu.Unlock()
}
// Run reads and hashes the file sequentially
func (h *streamHasher) Run() {
buf := make([]byte, 64*1024) // 64KB read buffer
var offset int64
for i, part := range h.parts {
// Wait for this part to be written
h.mu.Lock()
for !h.completed[i] && !h.done {
h.cond.Wait()
}
if h.done {
h.mu.Unlock()
return
}
h.mu.Unlock()
// Read and hash part
remaining := part.Size
for remaining > 0 {
n := int64(len(buf))
if n > remaining {
n = remaining
}
nr, err := h.file.ReadAt(buf[:n], offset)
if err != nil && err != io.EOF {
h.mu.Lock()
h.err = err
h.mu.Unlock()
return
}
h.hasher.Write(buf[:nr])
offset += int64(nr)
remaining -= int64(nr)
h.hashed.Store(offset)
}
}
}
// Stop signals the hasher to exit early.
func (h *streamHasher) Stop() {
h.mu.Lock()
h.done = true
h.cond.Broadcast()
h.mu.Unlock()
}
// Hashed returns bytes hashed so far.
func (h *streamHasher) Hashed() int64 {
return h.hashed.Load()
}
// Digest returns the computed hash.
func (h *streamHasher) Digest() string {
return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil))
}
// Err returns any error from hashing.
func (h *streamHasher) Err() error {
h.mu.Lock()
defer h.mu.Unlock()
return h.err
}
func (p *blobDownloadPart) Name() string {
return strings.Join([]string{
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
}, "-")
}
func (p *blobDownloadPart) StartsAt() int64 {
return p.Offset + p.Completed.Load()
}
func (p *blobDownloadPart) StopsAt() int64 {
return p.Offset + p.Size
}
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b)
p.blobDownload.Completed.Add(int64(n))
p.lastUpdatedMu.Lock()
p.lastUpdated = time.Now()
p.lastUpdatedMu.Unlock()
return n, nil
}
@@ -151,14 +278,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
size := b.Total / numDownloadParts
switch {
case size < minDownloadPartSize:
size = minDownloadPartSize
case size > maxDownloadPartSize:
size = maxDownloadPartSize
}
size := downloadPartSize
var offset int64
for offset < b.Total {
if offset+size > b.Total {
@@ -270,44 +390,75 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
// Download chunks to disk, hash sequentially
// The hasher follows behind the downloaders, reading recently-written
// data from OS page cache (RAM) rather than disk.
sh := newStreamHasher(file, b.Parts, b.Total)
tracker := &speedTracker{}
hashDone := make(chan struct{})
go func() {
sh.Run()
close(hashDone)
}()
g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts)
g.SetLimit(downloadConcurrency)
for i := range b.Parts {
part := b.Parts[i]
if part.Completed.Load() == part.Size {
sh.Done(part.N)
continue
}
g.Go(func() error {
var err error
var slowRetries int
for try := 0; try < maxRetries; try++ {
w := io.NewOffsetWriter(file, part.StartsAt())
err = b.downloadChunk(inner, directURL, w, part)
// After 3 slow retries, stop checking slowness and let it complete
skipSlowCheck := slowRetries >= 3
err = b.downloadChunk(inner, directURL, file, part, tracker, skipSlowCheck)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space
return err
case errors.Is(err, errPartStalled):
try--
continue
case errors.Is(err, errPartSlow):
// Kill slow request, retry immediately (stays within concurrency limit)
slowRetries++
try--
continue
case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
time.Sleep(sleep)
continue
default:
sh.Done(part.N)
return nil
}
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
})
}
if err := g.Wait(); err != nil {
sh.Stop()
return err
}
// Wait for hasher to finish
<-hashDone
if err := sh.Err(); err != nil {
return err
}
// Verify hash
if computed := sh.Digest(); computed != b.Digest {
return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
}
// explicitly close the file so we can rename it
if err := file.Close(); err != nil {
return err
@@ -326,38 +477,64 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return nil
}
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
// downloadChunk streams a part directly to disk at its offset.
// If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries).
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error {
g, ctx := errgroup.WithContext(ctx)
startTime := time.Now()
var bytesAtLastCheck atomic.Int64
g.Go(func() error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.Offset, part.Offset+part.Size-1))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress
b.Completed.Add(-n)
return err
w := io.NewOffsetWriter(file, part.Offset)
buf := make([]byte, 32*1024)
var written int64
for written < part.Size {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, werr := w.Write(buf[:n]); werr != nil {
return werr
}
written += int64(n)
b.Completed.Add(int64(n))
bytesAtLastCheck.Store(written)
}
if err == io.EOF {
break
}
if err != nil {
b.Completed.Add(-written)
return err
}
}
part.Completed.Add(n)
if err := b.writePart(part.Name(), part); err != nil {
return err
// Record speed for this part
elapsed := time.Since(startTime).Seconds()
if elapsed > 0 {
tracker.Record(float64(part.Size) / elapsed)
}
// return nil or context.Canceled or UnexpectedEOF (resumable)
return err
part.Completed.Store(part.Size)
return b.writePart(part.Name(), part)
})
g.Go(func() error {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
var lastBytes int64
checksWithoutProgress := 0
for {
select {
case <-ticker.C:
@@ -365,19 +542,35 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
return nil
}
part.lastUpdatedMu.Lock()
lastUpdated := part.lastUpdated
part.lastUpdatedMu.Unlock()
currentBytes := bytesAtLastCheck.Load()
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
// reset last updated
part.lastUpdatedMu.Lock()
part.lastUpdated = time.Time{}
part.lastUpdatedMu.Unlock()
return errPartStalled
// Check for stall (no progress for 10 seconds)
if currentBytes == lastBytes {
checksWithoutProgress++
if checksWithoutProgress >= 10 {
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
return errPartStalled
}
} else {
checksWithoutProgress = 0
}
lastBytes = currentBytes
// Check for slow speed after 5+ seconds (only for multi-part downloads)
// Skip if we've already retried for slowness too many times
elapsed := time.Since(startTime).Seconds()
if !skipSlowCheck && elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 {
currentSpeed := float64(currentBytes) / elapsed
median := tracker.Median()
// If we're below 10% of median speed, flag as slow
if median > 0 && currentSpeed < median*0.1 {
slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); retrying",
b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
return errPartSlow
}
}
case <-ctx.Done():
return ctx.Err()
}
@@ -463,21 +656,21 @@ type downloadOpts struct {
}
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
func downloadBlob(ctx context.Context, opts downloadOpts) error {
if opts.digest == "" {
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
return fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
}
fp, err := GetBlobsPath(opts.digest)
if err != nil {
return false, err
return err
}
fi, err := os.Stat(fp)
switch {
case errors.Is(err, os.ErrNotExist):
case err != nil:
return false, err
return err
default:
opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", opts.digest[7:19]),
@@ -486,7 +679,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
Completed: fi.Size(),
})
return true, nil
return nil
}
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
@@ -496,12 +689,12 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return false, err
return err
}
//nolint:contextcheck
go download.Run(context.Background(), requestURL, opts.regOpts)
}
return false, download.Wait(ctx, opts.fn)
return download.Wait(ctx, opts.fn)
}

320
server/download_test.go Normal file
View File

@@ -0,0 +1,320 @@
package server
import (
"crypto/rand"
"crypto/sha256"
"fmt"
"os"
"sync"
"testing"
)
func TestSpeedTracker_Median(t *testing.T) {
s := &speedTracker{}
// Less than 10 samples returns 0
for i := range 9 {
s.Record(float64(100 + i*10))
}
if got := s.Median(); got != 0 {
t.Errorf("expected 0 with < 10 samples, got %f", got)
}
// With 10+ samples, returns median
s.Record(190)
// Samples: [100, 110, 120, 130, 140, 150, 160, 170, 180, 190] -> median = 150
if got := s.Median(); got != 150 {
t.Errorf("expected median 150, got %f", got)
}
// Add more samples
s.Record(50)
// Samples: [100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 50]
// sorted = [50, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190] -> median = 140
if got := s.Median(); got != 140 {
t.Errorf("expected median 140, got %f", got)
}
}
func TestSpeedTracker_RollingWindow(t *testing.T) {
s := &speedTracker{}
// Add 35 samples (should keep only last 30)
for i := range 35 {
s.Record(float64(i))
}
s.mu.Lock()
if len(s.speeds) != 30 {
t.Errorf("expected 30 samples, got %d", len(s.speeds))
}
// First sample should be 5 (0-4 were dropped)
if s.speeds[0] != 5 {
t.Errorf("expected first sample to be 5, got %f", s.speeds[0])
}
s.mu.Unlock()
}
func TestSpeedTracker_Concurrent(t *testing.T) {
s := &speedTracker{}
var wg sync.WaitGroup
for i := range 100 {
wg.Add(1)
go func(v int) {
defer wg.Done()
s.Record(float64(v))
s.Median() // concurrent read
}(i)
}
wg.Wait()
// Should not panic, and should have reasonable state
s.mu.Lock()
if len(s.speeds) == 0 || len(s.speeds) > 100 {
t.Errorf("unexpected speeds length: %d", len(s.speeds))
}
s.mu.Unlock()
}
func TestStreamHasher_Sequential(t *testing.T) {
// Create temp file
f, err := os.CreateTemp(t.TempDir(), "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
// Write test data
data := []byte("hello world, this is a test of the stream hasher")
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
// Create parts
parts := []*blobDownloadPart{
{Offset: 0, Size: int64(len(data))},
}
sh := newStreamHasher(f, parts, int64(len(data)))
// Mark complete and run
sh.Done(0)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
<-done
// Verify digest
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
if got := sh.Digest(); got != expected {
t.Errorf("digest mismatch: got %s, want %s", got, expected)
}
if err := sh.Err(); err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func TestStreamHasher_OutOfOrderCompletion(t *testing.T) {
// Create temp file
f, err := os.CreateTemp(t.TempDir(), "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
// Write test data (3 parts of 10 bytes each)
data := []byte("0123456789ABCDEFGHIJabcdefghij")
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
// Create 3 parts
parts := []*blobDownloadPart{
{N: 0, Offset: 0, Size: 10},
{N: 1, Offset: 10, Size: 10},
{N: 2, Offset: 20, Size: 10},
}
sh := newStreamHasher(f, parts, int64(len(data)))
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Mark parts complete out of order: 2, 0, 1
sh.Done(2)
sh.Done(0) // This should trigger hashing of part 0
sh.Done(1) // This should trigger hashing of parts 1 and 2
<-done
// Verify digest
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
if got := sh.Digest(); got != expected {
t.Errorf("digest mismatch: got %s, want %s", got, expected)
}
}
func TestStreamHasher_Stop(t *testing.T) {
// Create temp file
f, err := os.CreateTemp(t.TempDir(), "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
parts := []*blobDownloadPart{
{Offset: 0, Size: 100},
}
sh := newStreamHasher(f, parts, 100)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Stop without completing any parts
sh.Stop()
<-done
// Should exit cleanly without error
if err := sh.Err(); err != nil {
t.Errorf("unexpected error after Stop: %v", err)
}
}
func TestStreamHasher_HashedProgress(t *testing.T) {
// Create temp file with known data
f, err := os.CreateTemp(t.TempDir(), "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
data := make([]byte, 1000)
rand.Read(data)
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
parts := []*blobDownloadPart{
{N: 0, Offset: 0, Size: 500},
{N: 1, Offset: 500, Size: 500},
}
sh := newStreamHasher(f, parts, 1000)
// Initially no progress
if got := sh.Hashed(); got != 0 {
t.Errorf("expected 0 hashed initially, got %d", got)
}
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Complete part 0
sh.Done(0)
// Give hasher time to process
for range 100 {
if sh.Hashed() >= 500 {
break
}
}
// Complete part 1
sh.Done(1)
<-done
if got := sh.Hashed(); got != 1000 {
t.Errorf("expected 1000 hashed, got %d", got)
}
}
func BenchmarkSpeedTracker_Record(b *testing.B) {
s := &speedTracker{}
b.ResetTimer()
for i := range b.N {
s.Record(float64(i))
}
}
func BenchmarkSpeedTracker_Median(b *testing.B) {
s := &speedTracker{}
// Pre-populate with 100 samples
for i := range 100 {
s.Record(float64(i))
}
b.ResetTimer()
for range b.N {
s.Median()
}
}
func BenchmarkStreamHasher(b *testing.B) {
// Create temp file with test data
f, err := os.CreateTemp(b.TempDir(), "streamhasher_bench")
if err != nil {
b.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
size := 64 * 1024 * 1024 // 64MB
data := make([]byte, size)
rand.Read(data)
if _, err := f.Write(data); err != nil {
b.Fatal(err)
}
parts := []*blobDownloadPart{
{Offset: 0, Size: int64(size)},
}
b.SetBytes(int64(size))
b.ResetTimer()
for range b.N {
sh := newStreamHasher(f, parts, int64(size))
sh.Done(0)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
<-done
}
}
func BenchmarkHashThroughput(b *testing.B) {
// Baseline: raw SHA256 throughput on this machine
size := 256 * 1024 * 1024 // 256MB
data := make([]byte, size)
rand.Read(data)
b.SetBytes(int64(size))
b.ResetTimer()
for range b.N {
h := sha256.New()
h.Write(data)
h.Sum(nil)
}
}

View File

@@ -620,43 +620,19 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
if err := downloadBlob(ctx, downloadOpts{
mp: mp,
digest: layer.Digest,
regOpts: regOpts,
fn: fn,
})
if err != nil {
}); err != nil {
return err
}
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if skipVerify[layer.Digest] {
continue
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
return err
}
}
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
@@ -859,25 +835,3 @@ func parseRegistryChallenge(authStr string) registryChallenge {
Scope: getValue(authStr, "scope"),
}
}
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
if err != nil {
return err
}
f, err := os.Open(fp)
if err != nil {
return err
}
defer f.Close()
fileDigest, _ := GetSHA256Digest(f)
if digest != fileDigest {
return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
}
return nil
}

View File

@@ -2395,4 +2395,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
}
return msgs
}