mirror of
https://github.com/ollama/ollama.git
synced 2026-01-03 04:59:19 -05:00
Compare commits
10 Commits
pdevine/lo
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7c751647d | ||
|
|
4e320b8b90 | ||
|
|
eb2b22b042 | ||
|
|
4ea4d2b189 | ||
|
|
8d76fa23ef | ||
|
|
74b44fdf8f | ||
|
|
65b88c544f | ||
|
|
a422ba39c9 | ||
|
|
d2ec22371e | ||
|
|
033cec232a |
@@ -187,6 +187,13 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
|
||||
|
||||
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
|
||||
|
||||
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
|
||||
|
||||
```
|
||||
# Allow all Chrome, Firefox, and Safari extensions
|
||||
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
|
||||
```
|
||||
|
||||
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
|
||||
|
||||
## Where are models stored?
|
||||
|
||||
@@ -583,39 +583,52 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
}
|
||||
|
||||
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||
if llm.KV().Uint("vision.block_count") == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for name, layer := range llm.Tensors().GroupLayers() {
|
||||
if name == "v" || strings.HasPrefix(name, "v.") {
|
||||
for _, tensor := range layer {
|
||||
weights += tensor.Size()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imageSize := uint64(llm.KV().Uint("vision.image_size"))
|
||||
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
|
||||
if patchSize == 0 {
|
||||
slog.Warn("unknown patch size for vision model")
|
||||
return
|
||||
}
|
||||
|
||||
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
|
||||
|
||||
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
|
||||
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
||||
numPatches++
|
||||
}
|
||||
|
||||
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
|
||||
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
|
||||
|
||||
switch llm.KV().Architecture() {
|
||||
case "mllama":
|
||||
for _, layer := range llm.Tensors().GroupLayers()["v"] {
|
||||
weights += layer.Size()
|
||||
}
|
||||
|
||||
kv := func(n string) uint64 {
|
||||
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
|
||||
return uint64(v)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
imageSize := kv("image_size")
|
||||
|
||||
maxNumTiles := kv("max_num_tiles")
|
||||
embeddingLength := kv("embedding_length")
|
||||
headCount := kv("attention.head_count")
|
||||
|
||||
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
|
||||
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
||||
numPatches++
|
||||
}
|
||||
|
||||
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
|
||||
|
||||
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
|
||||
|
||||
graphSize = 4 * (8 +
|
||||
imageSize*imageSize*kv("num_channels")*maxNumTiles +
|
||||
imageSize*imageSize*numChannels*maxNumTiles +
|
||||
embeddingLength*numPatches*maxNumTiles +
|
||||
9*embeddingLength*numPaddedPatches*maxNumTiles +
|
||||
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
|
||||
case "gemma3":
|
||||
graphSize = 4 * (imageSize*imageSize*numChannels +
|
||||
embeddingLength*patchSize +
|
||||
numPatches*numPatches*headCount)
|
||||
}
|
||||
|
||||
return weights, graphSize
|
||||
}
|
||||
|
||||
|
||||
@@ -218,8 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||
layerSize = blk.Size()
|
||||
layerSize += kv / f.KV().BlockCount()
|
||||
memoryWeights += blk.Size()
|
||||
}
|
||||
memoryWeights += layerSize
|
||||
|
||||
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
|
||||
// Stop allocating on GPU(s) once we hit the users target NumGPU
|
||||
@@ -376,7 +376,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
|
||||
// memory of the weights
|
||||
"total", format.HumanBytes2(m.memoryWeights),
|
||||
// memory of repeating layers
|
||||
"repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
|
||||
"repeating", format.HumanBytes2(m.memoryWeights),
|
||||
// memory of non-repeating layers
|
||||
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
|
||||
),
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
hiddenSize, numHeads, numKVHeads, headDim int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
@@ -37,6 +37,8 @@ func New(c ml.Config) (model.Model, error) {
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
// TODO: need to set this in the conversion for mistral:
|
||||
// tokenizer.ggml.pretokenizer = [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
@@ -53,6 +55,7 @@ func New(c ml.Config) (model.Model, error) {
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
@@ -75,24 +78,36 @@ type SelfAttention struct {
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
ropeType := uint32(0)
|
||||
// Get head dimension - use explicit value if available, otherwise calculate
|
||||
headDim := opts.headDim
|
||||
if headDim == 0 {
|
||||
headDim = opts.hiddenSize / opts.numHeads
|
||||
}
|
||||
|
||||
// Query projection and reshape
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
// Key projection and reshape
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
// Value projection and reshape
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
// Attention computation
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
// Reshape attention output for final projection
|
||||
outputDim := headDim * opts.numHeads
|
||||
kqv = kqv.Reshape(ctx, outputDim, batchSize)
|
||||
|
||||
// Apply output projection
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
|
||||
@@ -209,6 +209,326 @@ func TestLlama(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// tekken loads the Tekken tokenizer for testing
|
||||
func tekken(t testing.TB) TextProcessor {
|
||||
t.Helper()
|
||||
|
||||
// Load tokenizer config from mistral-small
|
||||
tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
|
||||
configFile, err := os.Open(tokenizerConfigPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer configFile.Close()
|
||||
|
||||
var config struct {
|
||||
AddBosToken bool `json:"add_bos_token"`
|
||||
AddEosToken bool `json:"add_eos_token"`
|
||||
BosToken struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"bos_token"`
|
||||
EosToken struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"eos_token"`
|
||||
}
|
||||
if err := json.NewDecoder(configFile).Decode(&config); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Load tokenizer.json which contains the vocabulary and other settings
|
||||
tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
|
||||
tokenizerFile, err := os.Open(tokenizerJsonPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tokenizerFile.Close()
|
||||
|
||||
var tokenizerData struct {
|
||||
Model struct {
|
||||
Type string `json:"type"`
|
||||
Vocab map[string]int32 `json:"vocab"`
|
||||
Merges []string `json:"merges"`
|
||||
} `json:"model"`
|
||||
AddedTokens []struct {
|
||||
Id int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
} `json:"added_tokens"`
|
||||
PreTokenizer struct {
|
||||
Type string `json:"type"`
|
||||
Pretokenizers []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
String string `json:"String"`
|
||||
} `json:"pattern"`
|
||||
Behavior string `json:"behavior"`
|
||||
} `json:"pretokenizers"`
|
||||
} `json:"pre_tokenizer"`
|
||||
}
|
||||
if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Extract the pattern from pre_tokenizer if available
|
||||
var pattern string
|
||||
if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
|
||||
pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
|
||||
}
|
||||
|
||||
// Combine regular vocab and added tokens
|
||||
vocab := tokenizerData.Model.Vocab
|
||||
|
||||
// Add special tokens from added_tokens
|
||||
for _, token := range tokenizerData.AddedTokens {
|
||||
vocab[token.Content] = token.Id
|
||||
}
|
||||
|
||||
// Create vocabulary arrays
|
||||
maxId := int32(-1)
|
||||
for _, id := range vocab {
|
||||
if id > maxId {
|
||||
maxId = id
|
||||
}
|
||||
}
|
||||
|
||||
vocabSize := int(maxId + 1)
|
||||
types := make([]uint32, vocabSize)
|
||||
tokens := make([]string, vocabSize)
|
||||
scores := make([]float32, vocabSize)
|
||||
|
||||
for token, id := range vocab {
|
||||
tokens[id] = token
|
||||
types[id] = TOKEN_TYPE_NORMAL
|
||||
|
||||
// Assign appropriate token types for special tokens
|
||||
if token == "<s>" {
|
||||
types[id] = TOKEN_TYPE_CONTROL
|
||||
} else if token == "</s>" {
|
||||
types[id] = TOKEN_TYPE_CONTROL
|
||||
} else if token == "[INST]" || token == "[/INST]" {
|
||||
types[id] = TOKEN_TYPE_CONTROL
|
||||
}
|
||||
}
|
||||
|
||||
// In Tekken, we don't need to load merges separately as they're part of the model
|
||||
var merges []string
|
||||
|
||||
// Create vocabulary object
|
||||
vocabObj := &Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Scores: scores,
|
||||
Merges: merges,
|
||||
BOS: vocab[config.BosToken.Content],
|
||||
EOS: vocab[config.EosToken.Content],
|
||||
AddBOS: config.AddBosToken,
|
||||
AddEOS: config.AddEosToken,
|
||||
}
|
||||
|
||||
// Use pattern from tokenizer.json if available
|
||||
if pattern != "" {
|
||||
// Ensure pattern has proper escaping for Go regexp
|
||||
pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
|
||||
return NewBytePairEncoding(pattern, vocabObj)
|
||||
}
|
||||
|
||||
// Fallback pattern if not found
|
||||
return NewBytePairEncoding(
|
||||
`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
|
||||
vocabObj,
|
||||
)
|
||||
}
|
||||
|
||||
func TestTekken(t *testing.T) {
|
||||
// Skip if the test data isn't available
|
||||
if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
|
||||
t.Skip("Mistral-small test data not available")
|
||||
}
|
||||
|
||||
tokenizer := tekken(t)
|
||||
|
||||
t.Run("whitespace_handling", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
|
||||
cases := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{" hello", " hello"},
|
||||
{"hello ", "hello "},
|
||||
{"hello world", "hello world"},
|
||||
{" hello world ", " hello world "},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
ids, err := tokenizer.Encode(tc.input, false)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to encode %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if decoded != tc.expected {
|
||||
t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("chat_templates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test the Tekken chat template format which doesn't have spaces after special tokens
|
||||
templates := []struct {
|
||||
input string
|
||||
expectSpace bool // whether we expect a space after special tokens
|
||||
}{
|
||||
{"<s>[INST]user message[/INST]", false},
|
||||
{"<s>[INST] user message[/INST]", true},
|
||||
{"<s>[INST]user message [/INST]", true},
|
||||
}
|
||||
|
||||
for _, tc := range templates {
|
||||
ids, err := tokenizer.Encode(tc.input, false)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to encode %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if there's a space after special tokens
|
||||
hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
|
||||
|
||||
if hasSpaceAfterINST != tc.expectSpace {
|
||||
t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
|
||||
hasSpaceAfterINST, tc.expectSpace, tc.input)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special_tokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test how Tekken handles special tokens
|
||||
cases := []struct {
|
||||
input string
|
||||
expected []string // We'll check if these tokens are in the decoded output
|
||||
}{
|
||||
{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
|
||||
{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
|
||||
{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
ids, err := tokenizer.Encode(tc.input, false)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to encode %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, expected := range tc.expected {
|
||||
if !strings.Contains(decoded, expected) {
|
||||
t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("vocabulary_coverage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Tekken has a larger vocabulary, so test coverage of various token types
|
||||
samples := []string{
|
||||
"Hello world!",
|
||||
"This is a test of the Tekken tokenizer.",
|
||||
"It has a considerably larger vocabulary size.",
|
||||
"Special characters: !@#$%^&*()",
|
||||
"Numbers: 1234567890",
|
||||
"Multiple languages: こんにちは 你好 안녕하세요",
|
||||
"Code snippets: def function(): return True",
|
||||
}
|
||||
|
||||
for _, sample := range samples {
|
||||
ids, err := tokenizer.Encode(sample, false)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to encode %q: %v", sample, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode tokens for %q: %v", sample, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if decoded != sample {
|
||||
t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("splitting_behavior", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test the splitting behavior which might differ from SentencePiece
|
||||
cases := map[string][]string{
|
||||
"Hello World!": {"Hello", " World", "!"},
|
||||
"user message": {"user", " message"},
|
||||
"[INST]hello": {"[INST]", "hello"},
|
||||
"hello[/INST]": {"hello", "[/INST]"},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("full_chat_sequence", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test a complete chat sequence with Tekken's format
|
||||
chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"
|
||||
|
||||
ids, err := tokenizer.Encode(chatSequence, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encode chat sequence: %v", err)
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode chat sequence tokens: %v", err)
|
||||
}
|
||||
|
||||
// In Tekken, the whitespace shouldn't be added after special tokens
|
||||
if strings.Contains(decoded, "[INST] ") {
|
||||
t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
|
||||
}
|
||||
|
||||
if strings.Contains(decoded, "[/INST] ") {
|
||||
t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
tokenizer := llama(b)
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
||||
|
||||
54
server/internal/cache/blob/cache.go
vendored
54
server/internal/cache/blob/cache.go
vendored
@@ -146,7 +146,7 @@ func debugger(err *error) func(step string) {
|
||||
// be in either of the following forms:
|
||||
//
|
||||
// @<digest>
|
||||
// <name>
|
||||
// <name>@<digest>
|
||||
// <name>
|
||||
//
|
||||
// If a digest is provided, it is returned as is and nothing else happens.
|
||||
@@ -160,8 +160,6 @@ func debugger(err *error) func(step string) {
|
||||
// hashed is passed to a PutBytes call to ensure that the manifest is in the
|
||||
// blob store. This is done to ensure that future calls to [Get] succeed in
|
||||
// these cases.
|
||||
//
|
||||
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
|
||||
func (c *DiskCache) Resolve(name string) (Digest, error) {
|
||||
name, digest := splitNameDigest(name)
|
||||
if digest != "" {
|
||||
@@ -279,18 +277,6 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
|
||||
// It returns an error if either the name or digest is invalid, or if link
|
||||
// creation encounters any issues.
|
||||
func (c *DiskCache) Link(name string, d Digest) error {
|
||||
// TODO(bmizerany): Move link handling from cache to registry.
|
||||
//
|
||||
// We originally placed links in the cache due to its storage
|
||||
// knowledge. However, the registry likely offers better context for
|
||||
// naming concerns, and our API design shouldn't be tightly coupled to
|
||||
// our on-disk format.
|
||||
//
|
||||
// Links work effectively when independent from physical location -
|
||||
// they can reference content with matching SHA regardless of storage
|
||||
// location. In an upcoming change, we plan to shift this
|
||||
// responsibility to the registry where it better aligns with the
|
||||
// system's conceptual model.
|
||||
manifest, err := c.manifestPath(name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -341,7 +327,9 @@ func (c *DiskCache) GetFile(d Digest) string {
|
||||
return absJoin(c.dir, "blobs", filename)
|
||||
}
|
||||
|
||||
// Links returns a sequence of links in the cache in lexical order.
|
||||
// Links returns a sequence of link names. The sequence is in lexical order.
|
||||
// Names are converted from their relative path form to their name form but are
|
||||
// not guaranteed to be valid. Callers should validate the names before using.
|
||||
func (c *DiskCache) Links() iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
for path, err := range c.links() {
|
||||
@@ -414,12 +402,14 @@ func (c *DiskCache) links() iter.Seq2[string, error] {
|
||||
}
|
||||
|
||||
type checkWriter struct {
|
||||
d Digest
|
||||
size int64
|
||||
n int64
|
||||
h hash.Hash
|
||||
d Digest
|
||||
f *os.File
|
||||
err error
|
||||
h hash.Hash
|
||||
|
||||
w io.Writer // underlying writer; set by creator
|
||||
n int64
|
||||
err error
|
||||
|
||||
testHookBeforeFinalWrite func(*os.File)
|
||||
}
|
||||
@@ -435,6 +425,10 @@ func (w *checkWriter) seterr(err error) error {
|
||||
// underlying writer is guaranteed to be the last byte of p as verified by the
|
||||
// hash.
|
||||
func (w *checkWriter) Write(p []byte) (int, error) {
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
|
||||
_, err := w.h.Write(p)
|
||||
if err != nil {
|
||||
return 0, w.seterr(err)
|
||||
@@ -453,7 +447,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
|
||||
if nextSize > w.size {
|
||||
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
|
||||
}
|
||||
n, err := w.f.Write(p)
|
||||
n, err := w.w.Write(p)
|
||||
w.n += int64(n)
|
||||
return n, w.seterr(err)
|
||||
}
|
||||
@@ -493,10 +487,12 @@ func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size
|
||||
|
||||
// Copy file to f, but also into h to double-check hash.
|
||||
cw := &checkWriter{
|
||||
d: out,
|
||||
size: size,
|
||||
h: sha256.New(),
|
||||
f: f,
|
||||
d: out,
|
||||
size: size,
|
||||
h: sha256.New(),
|
||||
f: f,
|
||||
w: f,
|
||||
|
||||
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
|
||||
}
|
||||
n, err := io.Copy(cw, file)
|
||||
@@ -532,11 +528,6 @@ func splitNameDigest(s string) (name, digest string) {
|
||||
var errInvalidName = errors.New("invalid name")
|
||||
|
||||
func nameToPath(name string) (_ string, err error) {
|
||||
if strings.Contains(name, "@") {
|
||||
// TODO(bmizerany): HACK: Fix names.Parse to validate.
|
||||
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
|
||||
return "", errInvalidName
|
||||
}
|
||||
n := names.Parse(name)
|
||||
if !n.IsFullyQualified() {
|
||||
return "", errInvalidName
|
||||
@@ -547,8 +538,7 @@ func nameToPath(name string) (_ string, err error) {
|
||||
func absJoin(pp ...string) string {
|
||||
abs, err := filepath.Abs(filepath.Join(pp...))
|
||||
if err != nil {
|
||||
// Likely a bug bug or a bad OS problem. Just panic.
|
||||
panic(err)
|
||||
panic(err) // this should never happen
|
||||
}
|
||||
return abs
|
||||
}
|
||||
|
||||
73
server/internal/cache/blob/chunked.go
vendored
Normal file
73
server/internal/cache/blob/chunked.go
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
package blob
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Chunk represents a range of bytes in a blob.
|
||||
type Chunk struct {
|
||||
Start int64
|
||||
End int64
|
||||
}
|
||||
|
||||
// Size returns end minus start plus one.
|
||||
func (c Chunk) Size() int64 {
|
||||
return c.End - c.Start + 1
|
||||
}
|
||||
|
||||
// Chunker writes to a blob in chunks.
|
||||
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
|
||||
type Chunker struct {
|
||||
digest Digest
|
||||
size int64
|
||||
f *os.File // nil means pre-validated
|
||||
}
|
||||
|
||||
// Chunked returns a new Chunker, ready for use storing a blob of the given
|
||||
// size in chunks.
|
||||
//
|
||||
// Use [Chunker.Put] to write data to the blob at specific offsets.
|
||||
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
|
||||
name := c.GetFile(d)
|
||||
info, err := os.Stat(name)
|
||||
if err == nil && info.Size() == size {
|
||||
return &Chunker{}, nil
|
||||
}
|
||||
f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Chunker{digest: d, size: size, f: f}, nil
|
||||
}
|
||||
|
||||
// Put copies chunk.Size() bytes from r to the blob at the given offset,
|
||||
// merging the data with the existing blob. It returns an error if any. As a
|
||||
// special case, if r has less than chunk.Size() bytes, Put returns
|
||||
// io.ErrUnexpectedEOF.
|
||||
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
|
||||
if c.f == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cw := &checkWriter{
|
||||
d: d,
|
||||
size: chunk.Size(),
|
||||
h: sha256.New(),
|
||||
f: c.f,
|
||||
w: io.NewOffsetWriter(c.f, chunk.Start),
|
||||
}
|
||||
|
||||
_, err := io.CopyN(cw, r, chunk.Size())
|
||||
if err != nil && errors.Is(err, io.EOF) {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the underlying file.
|
||||
func (c *Chunker) Close() error {
|
||||
return c.f.Close()
|
||||
}
|
||||
4
server/internal/cache/blob/digest.go
vendored
4
server/internal/cache/blob/digest.go
vendored
@@ -63,6 +63,10 @@ func (d Digest) Short() string {
|
||||
return fmt.Sprintf("%x", d.sum[:4])
|
||||
}
|
||||
|
||||
func (d Digest) Sum() [32]byte {
|
||||
return d.sum
|
||||
}
|
||||
|
||||
func (d Digest) Compare(other Digest) int {
|
||||
return slices.Compare(d.sum[:], other.sum[:])
|
||||
}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
package chunks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Chunk struct {
|
||||
Start, End int64
|
||||
}
|
||||
|
||||
func New(start, end int64) Chunk {
|
||||
return Chunk{start, end}
|
||||
}
|
||||
|
||||
// ParseRange parses a string in the form "unit=range" where unit is a string
|
||||
// and range is a string in the form "start-end". It returns the unit and the
|
||||
// range as a Chunk.
|
||||
func ParseRange(s string) (unit string, _ Chunk, _ error) {
|
||||
unit, r, _ := strings.Cut(s, "=")
|
||||
if r == "" {
|
||||
return unit, Chunk{}, nil
|
||||
}
|
||||
c, err := Parse(r)
|
||||
if err != nil {
|
||||
return "", Chunk{}, err
|
||||
}
|
||||
return unit, c, err
|
||||
}
|
||||
|
||||
// Parse parses a string in the form "start-end" and returns the Chunk.
|
||||
func Parse(s string) (Chunk, error) {
|
||||
startStr, endStr, _ := strings.Cut(s, "-")
|
||||
start, err := strconv.ParseInt(startStr, 10, 64)
|
||||
if err != nil {
|
||||
return Chunk{}, fmt.Errorf("invalid start: %v", err)
|
||||
}
|
||||
end, err := strconv.ParseInt(endStr, 10, 64)
|
||||
if err != nil {
|
||||
return Chunk{}, fmt.Errorf("invalid end: %v", err)
|
||||
}
|
||||
if start > end {
|
||||
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
|
||||
}
|
||||
return Chunk{start, end}, nil
|
||||
}
|
||||
|
||||
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
|
||||
// the range [0, size), in order.
|
||||
func Of(size, chunkSize int64) iter.Seq[Chunk] {
|
||||
return func(yield func(Chunk) bool) {
|
||||
for start := int64(0); start < size; start += chunkSize {
|
||||
end := min(start+chunkSize-1, size-1)
|
||||
if !yield(Chunk{start, end}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of Chunks of size chunkSize needed to cover the
|
||||
// range [0, size).
|
||||
func Count(size, chunkSize int64) int64 {
|
||||
return (size + chunkSize - 1) / chunkSize
|
||||
}
|
||||
|
||||
// Size returns end minus start plus one.
|
||||
func (c Chunk) Size() int64 {
|
||||
return c.End - c.Start + 1
|
||||
}
|
||||
|
||||
// String returns the string representation of the Chunk in the form
|
||||
// "{start}-{end}".
|
||||
func (c Chunk) String() string {
|
||||
return fmt.Sprintf("%d-%d", c.Start, c.End)
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package chunks
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOf(t *testing.T) {
|
||||
cases := []struct {
|
||||
total int64
|
||||
chunkSize int64
|
||||
want []Chunk
|
||||
}{
|
||||
{0, 1, nil},
|
||||
{1, 1, []Chunk{{0, 0}}},
|
||||
{1, 2, []Chunk{{0, 0}}},
|
||||
{2, 1, []Chunk{{0, 0}, {1, 1}}},
|
||||
{10, 9, []Chunk{{0, 8}, {9, 9}}},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got := slices.Collect(Of(tt.total, tt.chunkSize))
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSize(t *testing.T) {
|
||||
cases := []struct {
|
||||
c Chunk
|
||||
want int64
|
||||
}{
|
||||
{Chunk{0, 0}, 1},
|
||||
{Chunk{0, 1}, 2},
|
||||
{Chunk{3, 4}, 2},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got := tt.c.Size()
|
||||
if got != tt.want {
|
||||
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
cases := []struct {
|
||||
total int64
|
||||
chunkSize int64
|
||||
want int64
|
||||
}{
|
||||
{0, 1, 0},
|
||||
{1, 1, 1},
|
||||
{1, 2, 1},
|
||||
{2, 1, 2},
|
||||
{10, 9, 2},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
got := Count(tt.total, tt.chunkSize)
|
||||
if got != tt.want {
|
||||
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -35,10 +36,8 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/chunks"
|
||||
"github.com/ollama/ollama/server/internal/internal/backoff"
|
||||
"github.com/ollama/ollama/server/internal/internal/names"
|
||||
"github.com/ollama/ollama/server/internal/internal/syncs"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
@@ -66,12 +65,7 @@ var (
|
||||
const (
|
||||
// DefaultChunkingThreshold is the threshold at which a layer should be
|
||||
// split up into chunks when downloading.
|
||||
DefaultChunkingThreshold = 128 << 20
|
||||
|
||||
// DefaultMaxChunkSize is the default maximum size of a chunk to
|
||||
// download. It is configured based on benchmarks and aims to strike a
|
||||
// balance between download speed and memory usage.
|
||||
DefaultMaxChunkSize = 8 << 20
|
||||
DefaultChunkingThreshold = 64 << 20
|
||||
)
|
||||
|
||||
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
||||
@@ -211,8 +205,7 @@ type Registry struct {
|
||||
// pushing or pulling models. If zero, the number of streams is
|
||||
// determined by [runtime.GOMAXPROCS].
|
||||
//
|
||||
// Clients that want "unlimited" streams should set this to a large
|
||||
// number.
|
||||
// A negative value means no limit.
|
||||
MaxStreams int
|
||||
|
||||
// ChunkingThreshold is the maximum size of a layer to download in a single
|
||||
@@ -282,24 +275,13 @@ func DefaultRegistry() (*Registry, error) {
|
||||
}
|
||||
|
||||
func (r *Registry) maxStreams() int {
|
||||
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
||||
|
||||
// Large downloads require a writter stream, so ensure we have at least
|
||||
// two streams to avoid a deadlock.
|
||||
return max(n, 2)
|
||||
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
||||
}
|
||||
|
||||
func (r *Registry) maxChunkingThreshold() int64 {
|
||||
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
|
||||
}
|
||||
|
||||
// chunkSizeFor returns the chunk size for a layer of the given size. If the
|
||||
// size is less than or equal to the max chunking threshold, the size is
|
||||
// returned; otherwise, the max chunk size is returned.
|
||||
func (r *Registry) maxChunkSize() int64 {
|
||||
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
|
||||
}
|
||||
|
||||
type PushParams struct {
|
||||
// From is an optional destination name for the model. If empty, the
|
||||
// destination name is the same as the source name.
|
||||
@@ -426,6 +408,21 @@ func canRetry(err error) bool {
|
||||
return re.Status >= 500
|
||||
}
|
||||
|
||||
// trackingReader is an io.Reader that tracks the number of bytes read and
|
||||
// calls the update function with the layer, the number of bytes read.
|
||||
//
|
||||
// It always calls update with a nil error.
|
||||
type trackingReader struct {
|
||||
r io.Reader
|
||||
n *atomic.Int64
|
||||
}
|
||||
|
||||
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.r.Read(p)
|
||||
r.n.Add(int64(n))
|
||||
return
|
||||
}
|
||||
|
||||
// Pull pulls the model with the given name from the remote registry into the
|
||||
// cache.
|
||||
//
|
||||
@@ -434,11 +431,6 @@ func canRetry(err error) bool {
|
||||
// typically slower than splitting the model up across layers, and is mostly
|
||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||
func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, err := r.Resolve(ctx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -457,126 +449,95 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
return err == nil && info.Size == l.Size
|
||||
}
|
||||
|
||||
t := traceFromContext(ctx)
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(r.maxStreams())
|
||||
|
||||
layers := m.Layers
|
||||
if m.Config != nil && m.Config.Digest.IsValid() {
|
||||
layers = append(layers, m.Config)
|
||||
}
|
||||
|
||||
for _, l := range layers {
|
||||
// Send initial layer trace events to allow clients to have an
|
||||
// understanding of work to be done before work starts.
|
||||
t := traceFromContext(ctx)
|
||||
skip := make([]bool, len(layers))
|
||||
for i, l := range layers {
|
||||
t.update(l, 0, nil)
|
||||
if exists(l) {
|
||||
skip[i] = true
|
||||
t.update(l, l.Size, ErrCached)
|
||||
}
|
||||
}
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(r.maxStreams())
|
||||
for i, l := range layers {
|
||||
if skip[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
|
||||
req, err := r.newRequest(ctx, "GET", blobURL, nil)
|
||||
chunked, err := c.Chunked(l.Digest, l.Size)
|
||||
if err != nil {
|
||||
t.update(l, 0, err)
|
||||
continue
|
||||
}
|
||||
defer chunked.Close()
|
||||
|
||||
t.update(l, 0, nil)
|
||||
|
||||
if l.Size <= r.maxChunkingThreshold() {
|
||||
g.Go(func() error {
|
||||
// TODO(bmizerany): retry/backoff like below in
|
||||
// the chunking case
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
err = c.Put(l.Digest, res.Body, l.Size)
|
||||
if err == nil {
|
||||
t.update(l, l.Size, nil)
|
||||
}
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
q := syncs.NewRelayReader()
|
||||
var progress atomic.Int64
|
||||
for cs, err := range r.chunksums(ctx, name, l) {
|
||||
if err != nil {
|
||||
t.update(l, progress.Load(), err)
|
||||
break
|
||||
}
|
||||
|
||||
g.Go(func() (err error) {
|
||||
defer func() { q.CloseWithError(err) }()
|
||||
return c.Put(l.Digest, q, l.Size)
|
||||
})
|
||||
defer func() { t.update(l, progress.Load(), err) }()
|
||||
|
||||
var progress atomic.Int64
|
||||
|
||||
// We want to avoid extra round trips per chunk due to
|
||||
// redirects from the registry to the blob store, so
|
||||
// fire an initial request to get the final URL and
|
||||
// then use that URL for the chunk requests.
|
||||
req.Header.Set("Range", "bytes=0-0")
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.Body.Close()
|
||||
req = res.Request.WithContext(req.Context())
|
||||
|
||||
wp := writerPool{size: r.maxChunkSize()}
|
||||
|
||||
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
|
||||
ticket := q.Take()
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
q.CloseWithError(err)
|
||||
}
|
||||
ticket.Close()
|
||||
t.update(l, progress.Load(), err)
|
||||
}()
|
||||
|
||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := func() error {
|
||||
req := req.Clone(req.Context())
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
tw := wp.get()
|
||||
tw.Reset(ticket)
|
||||
defer wp.put(tw)
|
||||
|
||||
_, err = io.CopyN(tw, res.Body, chunk.Size())
|
||||
if err != nil {
|
||||
return maybeUnexpectedEOF(err)
|
||||
}
|
||||
if err := tw.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
total := progress.Add(chunk.Size())
|
||||
if total >= l.Size {
|
||||
q.Close()
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
if !canRetry(err) {
|
||||
return err
|
||||
}
|
||||
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
err := func() error {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
// Count bytes towards
|
||||
// progress, as they arrive, so
|
||||
// that our bytes piggyback
|
||||
// other chunk updates on
|
||||
// completion.
|
||||
//
|
||||
// This tactic is enough to
|
||||
// show "smooth" progress given
|
||||
// the current CLI client. In
|
||||
// the near future, the server
|
||||
// should report download rate
|
||||
// since it knows better than
|
||||
// a client that is measuring
|
||||
// rate based on wall-clock
|
||||
// time-since-last-update.
|
||||
body := &trackingReader{r: res.Body, n: &progress}
|
||||
|
||||
err = chunked.Put(cs.Chunk, cs.Digest, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
if !canRetry(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -615,8 +576,6 @@ type Manifest struct {
|
||||
Config *Layer `json:"config"`
|
||||
}
|
||||
|
||||
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
|
||||
|
||||
// Layer returns the layer with the given
|
||||
// digest, or nil if not found.
|
||||
func (m *Manifest) Layer(d blob.Digest) *Layer {
|
||||
@@ -643,10 +602,9 @@ func (m Manifest) MarshalJSON() ([]byte, error) {
|
||||
// last phase of the commit which expects it, but does nothing
|
||||
// with it. This will be fixed in a future release of
|
||||
// ollama.com.
|
||||
Config *Layer `json:"config"`
|
||||
Config Layer `json:"config"`
|
||||
}{
|
||||
M: M(m),
|
||||
Config: &Layer{Digest: emptyDigest},
|
||||
M: M(m),
|
||||
}
|
||||
return json.Marshal(v)
|
||||
}
|
||||
@@ -736,6 +694,123 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type chunksum struct {
|
||||
URL string
|
||||
Chunk blob.Chunk
|
||||
Digest blob.Digest
|
||||
}
|
||||
|
||||
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
|
||||
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
|
||||
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
|
||||
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
|
||||
return func(yield func(chunksum, error) bool) {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
|
||||
if l.Size < r.maxChunkingThreshold() {
|
||||
// any layer under the threshold should be downloaded
|
||||
// in one go.
|
||||
cs := chunksum{
|
||||
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
),
|
||||
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
|
||||
Digest: l.Digest,
|
||||
}
|
||||
yield(cs, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// A chunksums response is a sequence of chunksums in a
|
||||
// simple, easy to parse line-oriented format.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
//
|
||||
// << HTTP/1.1 200 OK
|
||||
// << Content-Location: <blobURL>
|
||||
// <<
|
||||
// << <digest> <start>-<end>
|
||||
// << ...
|
||||
//
|
||||
// The blobURL is the URL to download the chunks from.
|
||||
|
||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
||||
scheme,
|
||||
n.Host(),
|
||||
n.Namespace(),
|
||||
n.Model(),
|
||||
l.Digest,
|
||||
)
|
||||
|
||||
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
res, err := sendRequest(r.client(), req)
|
||||
if err != nil {
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
blobURL := res.Header.Get("Content-Location")
|
||||
|
||||
s := bufio.NewScanner(res.Body)
|
||||
s.Split(bufio.ScanWords)
|
||||
for {
|
||||
if !s.Scan() {
|
||||
if s.Err() != nil {
|
||||
yield(chunksum{}, s.Err())
|
||||
}
|
||||
return
|
||||
}
|
||||
d, err := blob.ParseDigest(s.Bytes())
|
||||
if err != nil {
|
||||
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
|
||||
return
|
||||
}
|
||||
|
||||
if !s.Scan() {
|
||||
err := s.Err()
|
||||
if err == nil {
|
||||
err = fmt.Errorf("missing chunk range for digest %s", d)
|
||||
}
|
||||
yield(chunksum{}, err)
|
||||
return
|
||||
}
|
||||
chunk, err := parseChunk(s.Bytes())
|
||||
if err != nil {
|
||||
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
|
||||
return
|
||||
}
|
||||
|
||||
cs := chunksum{
|
||||
URL: blobURL,
|
||||
Chunk: chunk,
|
||||
Digest: d,
|
||||
}
|
||||
if !yield(cs, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) client() *http.Client {
|
||||
if r.HTTPClient != nil {
|
||||
return r.HTTPClient
|
||||
@@ -898,13 +973,6 @@ func checkData(url string) string {
|
||||
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
|
||||
}
|
||||
|
||||
func maybeUnexpectedEOF(err error) error {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type publicError struct {
|
||||
wrapped error
|
||||
message string
|
||||
@@ -991,27 +1059,22 @@ func splitExtended(s string) (scheme, name, digest string) {
|
||||
return scheme, s, digest
|
||||
}
|
||||
|
||||
type writerPool struct {
|
||||
size int64 // set by the caller
|
||||
|
||||
mu sync.Mutex
|
||||
ws []*bufio.Writer
|
||||
}
|
||||
|
||||
func (p *writerPool) get() *bufio.Writer {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if len(p.ws) == 0 {
|
||||
return bufio.NewWriterSize(nil, int(p.size))
|
||||
// parseChunk parses a string in the form "start-end" and returns the Chunk.
|
||||
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
|
||||
startPart, endPart, found := strings.Cut(string(s), "-")
|
||||
if !found {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
|
||||
}
|
||||
w := p.ws[len(p.ws)-1]
|
||||
p.ws = p.ws[:len(p.ws)-1]
|
||||
return w
|
||||
}
|
||||
|
||||
func (p *writerPool) put(w *bufio.Writer) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
w.Reset(nil)
|
||||
p.ws = append(p.ws, w)
|
||||
start, err := strconv.ParseInt(startPart, 10, 64)
|
||||
if err != nil {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
|
||||
}
|
||||
end, err := strconv.ParseInt(endPart, 10, 64)
|
||||
if err != nil {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
|
||||
}
|
||||
if start > end {
|
||||
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
|
||||
}
|
||||
return blob.Chunk{Start: start, End: end}, nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/chunks"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
)
|
||||
|
||||
@@ -428,7 +427,7 @@ func TestRegistryPullCached(t *testing.T) {
|
||||
err := rc.Pull(ctx, "single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{6}
|
||||
want := []int64{0, 6}
|
||||
if !errors.Is(errors.Join(errs...), ErrCached) {
|
||||
t.Errorf("errs = %v; want %v", errs, ErrCached)
|
||||
}
|
||||
@@ -531,54 +530,6 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullChunking(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
||||
if r.URL.Host != "blob.store" {
|
||||
// The production registry redirects to the blob store.
|
||||
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
|
||||
return
|
||||
}
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
rng := r.Header.Get("Range")
|
||||
if rng == "" {
|
||||
http.Error(w, "missing range", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
io.WriteString(w, "remote"[c.Start:c.End+1])
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
|
||||
})
|
||||
|
||||
// Force chunking by setting the threshold to less than the size of the
|
||||
// layer.
|
||||
rc.ChunkingThreshold = 3
|
||||
rc.MaxChunkSize = 3
|
||||
|
||||
var reads []int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(d *Layer, n int64, err error) {
|
||||
if err != nil {
|
||||
t.Errorf("update %v %d %v", d, n, err)
|
||||
}
|
||||
reads = append(reads, n)
|
||||
},
|
||||
})
|
||||
|
||||
err := rc.Pull(ctx, "remote")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{0, 3, 6}
|
||||
if !slices.Equal(reads, want) {
|
||||
t.Errorf("reads = %v; want %v", reads, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryResolveByDigest(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/chunks"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func BenchmarkDownload(b *testing.B) {
|
||||
run := func(fileSize, chunkSize int64) {
|
||||
name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
|
||||
b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
|
||||
}
|
||||
|
||||
run(100<<20, 8<<20)
|
||||
run(100<<20, 16<<20)
|
||||
run(100<<20, 32<<20)
|
||||
run(100<<20, 64<<20)
|
||||
run(100<<20, 128<<20) // 1 chunk
|
||||
}
|
||||
|
||||
func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
|
||||
const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
|
||||
res, err := c.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
|
||||
return err
|
||||
}
|
||||
|
||||
var sleepTime atomic.Int64
|
||||
|
||||
func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
|
||||
client := &http.Client{
|
||||
Transport: func() http.RoundTripper {
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.DisableKeepAlives = true
|
||||
return tr
|
||||
}(),
|
||||
}
|
||||
defer client.CloseIdleConnections()
|
||||
|
||||
// warm up the client
|
||||
run(context.Background(), client, chunks.New(0, 1<<20))
|
||||
|
||||
b.SetBytes(fileSize)
|
||||
b.ReportAllocs()
|
||||
|
||||
// Give our CDN a min to breathe between benchmarks.
|
||||
time.Sleep(time.Duration(sleepTime.Swap(3)))
|
||||
|
||||
for b.Loop() {
|
||||
g, ctx := errgroup.WithContext(b.Context())
|
||||
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||
for chunk := range chunks.Of(fileSize, chunkSize) {
|
||||
g.Go(func() error { return run(ctx, client, chunk) })
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWrite(b *testing.B) {
|
||||
b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
|
||||
}
|
||||
|
||||
func benchmarkWrite(b *testing.B, chunkSize int) {
|
||||
b.ReportAllocs()
|
||||
|
||||
dir := b.TempDir()
|
||||
f, err := os.Create(filepath.Join(dir, "write-single"))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
data := make([]byte, chunkSize)
|
||||
b.SetBytes(int64(chunkSize))
|
||||
r := bytes.NewReader(data)
|
||||
for b.Loop() {
|
||||
r.Reset(data)
|
||||
_, err := io.Copy(f, r)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
// Package registry provides an http.Handler for handling local Ollama API
|
||||
// requests for performing tasks related to the ollama.com model registry and
|
||||
// the local disk cache.
|
||||
// Package registry implements an http.Handler for handling local Ollama API
|
||||
// model management requests. See [Local] for details.
|
||||
package registry
|
||||
|
||||
import (
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -18,16 +18,11 @@ import (
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
)
|
||||
|
||||
// Local is an http.Handler for handling local Ollama API requests for
|
||||
// performing tasks related to the ollama.com model registry combined with the
|
||||
// local disk cache.
|
||||
// Local implements an http.Handler for handling local Ollama API model
|
||||
// management requests, such as pushing, pulling, and deleting models.
|
||||
//
|
||||
// It is not concern of Local, or this package, to handle model creation, which
|
||||
// proceeds any registry operations for models it produces.
|
||||
//
|
||||
// NOTE: The package built for dealing with model creation should use
|
||||
// [DefaultCache] to access the blob store and not attempt to read or write
|
||||
// directly to the blob disk cache.
|
||||
// It can be arranged for all unknown requests to be passed through to a
|
||||
// fallback handler, if one is provided.
|
||||
type Local struct {
|
||||
Client *ollama.Registry // required
|
||||
Logger *slog.Logger // required
|
||||
@@ -63,6 +58,7 @@ func (e serverError) Error() string {
|
||||
var (
|
||||
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
|
||||
errNotFound = &serverError{404, "not_found", "not found"}
|
||||
errModelNotFound = &serverError{404, "not_found", "model not found"}
|
||||
errInternalError = &serverError{500, "internal_error", "internal server error"}
|
||||
)
|
||||
|
||||
@@ -175,8 +171,16 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
|
||||
}
|
||||
|
||||
type params struct {
|
||||
DeprecatedName string `json:"name"` // Use [params.model]
|
||||
Model string `json:"model"` // Use [params.model]
|
||||
// DeprecatedName is the name of the model to push, pull, or delete,
|
||||
// but is deprecated. New clients should use [Model] instead.
|
||||
//
|
||||
// Use [model()] to get the model name for both old and new API requests.
|
||||
DeprecatedName string `json:"name"`
|
||||
|
||||
// Model is the name of the model to push, pull, or delete.
|
||||
//
|
||||
// Use [model()] to get the model name for both old and new API requests.
|
||||
Model string `json:"model"`
|
||||
|
||||
// AllowNonTLS is a flag that indicates a client using HTTP
|
||||
// is doing so, deliberately.
|
||||
@@ -189,9 +193,18 @@ type params struct {
|
||||
// confusing flags such as this.
|
||||
AllowNonTLS bool `json:"insecure"`
|
||||
|
||||
// ProgressStream is a flag that indicates the client is expecting a stream of
|
||||
// progress updates.
|
||||
ProgressStream bool `json:"stream"`
|
||||
// Stream, if true, will make the server send progress updates in a
|
||||
// streaming of JSON objects. If false, the server will send a single
|
||||
// JSON object with the final status as "success", or an error object
|
||||
// if an error occurred.
|
||||
//
|
||||
// Unfortunately, this API was designed to be a bit awkward. Stream is
|
||||
// defined to default to true if not present, so we need a way to check
|
||||
// if the client decisively it to false. So, we use a pointer to a
|
||||
// bool. Gross.
|
||||
//
|
||||
// Use [stream()] to get the correct value for this field.
|
||||
Stream *bool `json:"stream"`
|
||||
}
|
||||
|
||||
// model returns the model name for both old and new API requests.
|
||||
@@ -199,6 +212,13 @@ func (p params) model() string {
|
||||
return cmp.Or(p.Model, p.DeprecatedName)
|
||||
}
|
||||
|
||||
func (p params) stream() bool {
|
||||
if p.Stream == nil {
|
||||
return true
|
||||
}
|
||||
return *p.Stream
|
||||
}
|
||||
|
||||
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "DELETE" {
|
||||
return errMethodNotAllowed
|
||||
@@ -212,16 +232,16 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return &serverError{404, "not_found", "model not found"}
|
||||
return errModelNotFound
|
||||
}
|
||||
if s.Prune == nil {
|
||||
return nil
|
||||
if s.Prune != nil {
|
||||
return s.Prune()
|
||||
}
|
||||
return s.Prune()
|
||||
return nil
|
||||
}
|
||||
|
||||
type progressUpdateJSON struct {
|
||||
Status string `json:"status"`
|
||||
Status string `json:"status,omitempty,omitzero"`
|
||||
Digest blob.Digest `json:"digest,omitempty,omitzero"`
|
||||
Total int64 `json:"total,omitempty,omitzero"`
|
||||
Completed int64 `json:"completed,omitempty,omitzero"`
|
||||
@@ -237,6 +257,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
return err
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
if !p.stream() {
|
||||
if err := s.Client.Pull(r.Context(), p.model()); err != nil {
|
||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||
return errModelNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
return enc.Encode(progressUpdateJSON{Status: "success"})
|
||||
}
|
||||
|
||||
maybeFlush := func() {
|
||||
fl, _ := w.(http.Flusher)
|
||||
if fl != nil {
|
||||
@@ -246,69 +277,67 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
defer maybeFlush()
|
||||
|
||||
var mu sync.Mutex
|
||||
enc := json.NewEncoder(w)
|
||||
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
|
||||
progress := make(map[*ollama.Layer]int64)
|
||||
|
||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||
Update: func(l *ollama.Layer, n int64, err error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
progressCopy := make(map[*ollama.Layer]int64, len(progress))
|
||||
pushUpdate := func() {
|
||||
defer maybeFlush()
|
||||
|
||||
// TODO(bmizerany): coalesce these updates; writing per
|
||||
// update is expensive
|
||||
// TODO(bmizerany): This scales poorly with more layers due to
|
||||
// needing to flush out them all in one big update. We _could_
|
||||
// just flush on the changed ones, or just track the whole
|
||||
// download. Needs more thought. This is fine for now.
|
||||
mu.Lock()
|
||||
maps.Copy(progressCopy, progress)
|
||||
mu.Unlock()
|
||||
for l, n := range progress {
|
||||
enc.Encode(progressUpdateJSON{
|
||||
Digest: l.Digest,
|
||||
Status: "pulling",
|
||||
Total: l.Size,
|
||||
Completed: n,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
t := time.NewTicker(time.Hour) // "unstarted" timer
|
||||
start := sync.OnceFunc(func() {
|
||||
pushUpdate()
|
||||
t.Reset(100 * time.Millisecond)
|
||||
})
|
||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||
Update: func(l *ollama.Layer, n int64, err error) {
|
||||
if n > 0 {
|
||||
start() // flush initial state
|
||||
}
|
||||
mu.Lock()
|
||||
progress[l] = n
|
||||
mu.Unlock()
|
||||
},
|
||||
})
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
// TODO(bmizerany): continue to support non-streaming responses
|
||||
done <- s.Client.Pull(ctx, p.model())
|
||||
}()
|
||||
|
||||
func() {
|
||||
t := time.NewTicker(100 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
mu.Lock()
|
||||
maybeFlush()
|
||||
mu.Unlock()
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
var status string
|
||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||
status = fmt.Sprintf("error: model %q not found", p.model())
|
||||
enc.Encode(progressUpdateJSON{Status: status})
|
||||
} else {
|
||||
status = fmt.Sprintf("error: %v", err)
|
||||
enc.Encode(progressUpdateJSON{Status: status})
|
||||
}
|
||||
return
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
pushUpdate()
|
||||
case err := <-done:
|
||||
pushUpdate()
|
||||
if err != nil {
|
||||
var status string
|
||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||
status = fmt.Sprintf("error: model %q not found", p.model())
|
||||
} else {
|
||||
status = fmt.Sprintf("error: %v", err)
|
||||
}
|
||||
|
||||
// These final updates are not strictly necessary, because they have
|
||||
// already happened at this point. Our pull handler code used to do
|
||||
// these steps after, not during, the pull, and they were slow, so we
|
||||
// wanted to provide feedback to users what was happening. For now, we
|
||||
// keep them to not jar users who are used to seeing them. We can phase
|
||||
// them out with a new and nicer UX later. One without progress bars
|
||||
// and digests that no one cares about.
|
||||
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
|
||||
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
|
||||
enc.Encode(progressUpdateJSON{Status: "success"})
|
||||
return
|
||||
enc.Encode(progressUpdateJSON{Status: status})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
@@ -160,7 +159,6 @@ var registryFS = sync.OnceValue(func() fs.FS {
|
||||
// to \n when parsing the txtar on Windows.
|
||||
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
|
||||
a := txtar.Parse(data)
|
||||
fmt.Printf("%q\n", a.Comment)
|
||||
fsys, err := txtar.FS(a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -179,7 +177,7 @@ func TestServerPull(t *testing.T) {
|
||||
w.WriteHeader(404)
|
||||
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
|
||||
default:
|
||||
t.Logf("serving file: %s", r.URL.Path)
|
||||
t.Logf("serving blob: %s", r.URL.Path)
|
||||
modelsHandler.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
@@ -188,7 +186,7 @@ func TestServerPull(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
if got.Code != 200 {
|
||||
t.Fatalf("Code = %d; want 200", got.Code)
|
||||
t.Errorf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
gotlines := got.Body.String()
|
||||
t.Logf("got:\n%s", gotlines)
|
||||
@@ -197,35 +195,29 @@ func TestServerPull(t *testing.T) {
|
||||
want, unwanted := strings.CutPrefix(want, "!")
|
||||
want = strings.TrimSpace(want)
|
||||
if !unwanted && !strings.Contains(gotlines, want) {
|
||||
t.Fatalf("! missing %q in body", want)
|
||||
t.Errorf("! missing %q in body", want)
|
||||
}
|
||||
if unwanted && strings.Contains(gotlines, want) {
|
||||
t.Fatalf("! unexpected %q in body", want)
|
||||
t.Errorf("! unexpected %q in body", want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
|
||||
`)
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
|
||||
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
|
||||
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
|
||||
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
|
||||
{"status":"verifying layers"}
|
||||
{"status":"writing manifest"}
|
||||
{"status":"success"}
|
||||
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
|
||||
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
|
||||
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
|
||||
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
|
||||
`)
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"error: model \"unknown\" not found"}
|
||||
`)
|
||||
|
||||
@@ -240,19 +232,39 @@ func TestServerPull(t *testing.T) {
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"error: invalid or missing name: \"\""}
|
||||
|
||||
!verifying
|
||||
!writing
|
||||
!success
|
||||
`)
|
||||
|
||||
// Non-streaming pulls
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
|
||||
checkResponse(got, `
|
||||
{"status":"success"}
|
||||
!digest
|
||||
!total
|
||||
!completed
|
||||
`)
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
|
||||
checkErrorResponse(t, got, 404, "not_found", "model not found")
|
||||
}
|
||||
|
||||
func TestServerUnknownPath(t *testing.T) {
|
||||
s := newTestServer(t, nil)
|
||||
got := s.send(t, "DELETE", "/api/unknown", `{}`)
|
||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||
|
||||
var fellback bool
|
||||
s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fellback = true
|
||||
})
|
||||
got = s.send(t, "DELETE", "/api/unknown", `{}`)
|
||||
if !fellback {
|
||||
t.Fatal("expected Fallback to be called")
|
||||
}
|
||||
if got.Code != 200 {
|
||||
t.Fatalf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
|
||||
|
||||
Reference in New Issue
Block a user