Compare commits

...

1 Commits

Author SHA1 Message Date
Michael Yang
9ef2106b47 cmd: create blob in parallel with checksum
a simple optimisation where once a blob has been checksummed, immediately
upload it; don't wait for all files to be checksummed before starting the
upload.
2026-01-20 12:09:02 -08:00
5 changed files with 246 additions and 355 deletions

View File

@@ -43,7 +43,6 @@ import (
"github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
@@ -205,7 +204,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
spinner.Stop()
req.Model = modelName
quantize, _ := cmd.Flags().GetString("quantize")
@@ -219,42 +217,29 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
g.SetLimit(runtime.GOMAXPROCS(0))
for blob, err := range createBlobs(req.Files, req.Adapters) {
if err != nil {
return err
}
files := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Files {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
// TODO: this is incorrect since the file might be in a subdirectory
// instead this should take the path relative to the model directory
// but the current implementation does not allow this
files.Store(filepath.Base(f), digest)
return nil
_, err := createBlob(cmd, client, blob.Abs, blob.Digest, p)
return err
})
}
adapters := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Adapters {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
// TODO: same here
adapters.Store(filepath.Base(f), digest)
return nil
})
if _, ok := req.Files[blob.Rel]; ok {
req.Files[blob.Rel] = blob.Digest
} else if _, ok := req.Adapters[blob.Rel]; ok {
req.Adapters[blob.Rel] = blob.Digest
}
}
if err := g.Wait(); err != nil {
return err
}
req.Files = files.Items()
req.Adapters = adapters.Items()
spinner.Stop()
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error {
@@ -292,54 +277,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}
// createBlob uploads the file at path to the server under the given digest,
// showing a percentage spinner on p while the copy is in progress. It
// returns the digest on success.
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
	// resolve symlinks so the real file contents are read and sized
	realPath, err := filepath.EvalSymlinks(path)
	if err != nil {
		return "", err
	}
	bin, err := os.Open(realPath)
	if err != nil {
		return "", err
	}
	defer bin.Close()
	// Get file info to retrieve the size
	fileInfo, err := bin.Stat()
	if err != nil {
		return "", err
	}
	fileSize := fileInfo.Size()
	// pw counts bytes as they stream through the TeeReader below so the
	// spinner goroutine can compute a percentage concurrently
	var pw progressWriter
	status := fmt.Sprintf("copying file %s 0%%", digest)
	spinner := progress.NewSpinner(status)
	p.Add(status, spinner)
	defer spinner.Stop()
	// done signals the progress goroutine to print 100% and exit; it is
	// closed (via defer) before spinner.Stop runs, in LIFO defer order
	done := make(chan struct{})
	defer close(done)
	go func() {
		ticker := time.NewTicker(60 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				// NOTE(review): divides by fileSize — panics for a
				// zero-length file; confirm callers never pass one
				spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
			case <-done:
				spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
				return
			}
		}
	}()
	// stream the file to the server, updating pw as a side effect
	if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
		return "", err
	}
	return digest, nil
}
// progressWriter tracks how many bytes have passed through it; n is atomic
// because it is read concurrently by the progress-spinner goroutine.
// (Its io.Writer implementation is defined elsewhere in this file —
// presumably Write adds len(p) to n; confirm against the full source.)
type progressWriter struct {
	n atomic.Int64
}

103
cmd/create.go Normal file
View File

@@ -0,0 +1,103 @@
package cmd
import (
"crypto/sha256"
"fmt"
"io"
"iter"
"os"
"path/filepath"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
)
// blob describes one model file to be uploaded: Rel is the path key used in
// the request maps, Abs is the absolute path on disk, and Digest is the
// file's "sha256:<hex>" checksum.
type blob struct {
	Rel, Abs, Digest string
}
// createBlob uploads the file at path to the server under the given digest,
// displaying a percentage spinner on p while the copy is in progress. It
// returns the digest on success.
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
	// resolve symlinks so the real file contents are read and sized
	realPath, err := filepath.EvalSymlinks(path)
	if err != nil {
		return "", err
	}

	bin, err := os.Open(realPath)
	if err != nil {
		return "", err
	}
	defer bin.Close()

	// Get file info to retrieve the size
	fileInfo, err := bin.Stat()
	if err != nil {
		return "", err
	}
	fileSize := fileInfo.Size()

	// pw counts bytes as they stream through the TeeReader below so the
	// spinner goroutine can compute a percentage concurrently
	var pw progressWriter
	status := fmt.Sprintf("copying file %s 0%%", digest)
	spinner := progress.NewSpinner(status)
	p.Add(status, spinner)
	defer spinner.Stop()

	// done signals the progress goroutine to print 100% and exit; it is
	// closed (via defer) before spinner.Stop runs, in LIFO defer order
	done := make(chan struct{})
	defer close(done)

	go func() {
		ticker := time.NewTicker(60 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				// guard against division by zero for a zero-length file
				if fileSize > 0 {
					spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
				}
			case <-done:
				spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
				return
			}
		}
	}()

	// stream the file to the server, updating pw as a side effect
	if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
		return "", err
	}

	return digest, nil
}
// digestFile computes the "sha256:<hex>" digest of the file at path. The file
// is always closed, even when copying fails partway through, and a Close
// error is reported unless an earlier error already occurred.
func digestFile(path string) (_ string, err error) {
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer func() {
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return "", err
	}
	return fmt.Sprintf("sha256:%x", h.Sum(nil)), nil
}

// createBlobs yields one blob per "abs:"-tagged entry in mappings,
// checksumming each file as it is reached so callers can start uploading
// immediately instead of waiting for every checksum. Entries without the
// "abs:" prefix are skipped (they are assumed to already hold a digest).
// Iteration stops after the first error, which is yielded once.
func createBlobs(mappings ...map[string]string) iter.Seq2[blob, error] {
	return func(yield func(blob, error) bool) {
		for _, mapping := range mappings {
			for rel, tagged := range mapping {
				abs, ok := strings.CutPrefix(tagged, "abs:")
				if !ok {
					continue
				}
				digest, err := digestFile(abs)
				if err != nil {
					yield(blob{}, err)
					return
				}
				if !yield(blob{
					Rel:    rel,
					Abs:    abs,
					Digest: digest,
				}, nil) {
					return
				}
			}
		}
	}
}

View File

@@ -3,22 +3,20 @@ package parser
import (
"bufio"
"bytes"
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"maps"
"net/http"
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
"golang.org/x/mod/semver"
"golang.org/x/sync/errgroup"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
@@ -54,7 +52,10 @@ var deprecatedParameters = []string{
// CreateRequest creates a new *api.CreateRequest from an existing Modelfile
func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) {
req := &api.CreateRequest{}
req := &api.CreateRequest{
Files: make(map[string]string),
Adapters: make(map[string]string),
}
var messages []api.Message
var licenses []string
@@ -63,12 +64,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
for _, c := range f.Commands {
switch c.Name {
case "model":
path, err := expandPath(c.Args, relativeDir)
if err != nil {
return nil, err
}
digestMap, err := fileDigestMap(path)
files, err := filesMap(c.Args, relativeDir)
if errors.Is(err, os.ErrNotExist) {
req.From = c.Args
continue
@@ -76,25 +72,14 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
return nil, err
}
if req.Files == nil {
req.Files = digestMap
} else {
for k, v := range digestMap {
req.Files[k] = v
}
}
maps.Copy(req.Files, files)
case "adapter":
path, err := expandPath(c.Args, relativeDir)
files, err := filesMap(c.Args, relativeDir)
if err != nil {
return nil, err
}
digestMap, err := fileDigestMap(path)
if err != nil {
return nil, err
}
req.Adapters = digestMap
maps.Copy(req.Adapters, files)
case "template":
req.Template = c.Args
case "system":
@@ -154,106 +139,66 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
return req, nil
}
func fileDigestMap(path string) (map[string]string, error) {
fl := make(map[string]string)
func filesMap(args, base string) (map[string]string, error) {
path, err := expandPath(args, base)
if err != nil {
return nil, err
}
fi, err := os.Stat(path)
if err != nil {
return nil, err
}
var files []string
if fi.IsDir() {
fs, err := filesForModel(path)
if err != nil {
return nil, err
}
for _, f := range fs {
f, err := filepath.EvalSymlinks(f)
if err != nil {
return nil, err
}
rel, err := filepath.Rel(path, f)
if err != nil {
return nil, err
}
if !filepath.IsLocal(rel) {
return nil, fmt.Errorf("insecure path: %s", rel)
}
files = append(files, f)
}
} else {
files = []string{path}
mapping := make(map[string]string)
if !fi.IsDir() {
return map[string]string{
filepath.Base(path): "abs:" + path,
}, nil
}
var mu sync.Mutex
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
for _, f := range files {
g.Go(func() error {
digest, err := digestForFile(f)
if err != nil {
return err
}
mu.Lock()
defer mu.Unlock()
fl[f] = digest
return nil
})
root, err := os.OpenRoot(path)
if err != nil {
return nil, err
}
defer root.Close()
if err := g.Wait(); err != nil {
files, err := filesForModel(root)
if err != nil {
return nil, err
}
return fl, nil
for _, file := range files {
// create a temporary mapping from relative path to absolute path
mapping[file] = "abs:" + filepath.Join(root.Name(), file)
}
return mapping, nil
}
func digestForFile(filename string) (string, error) {
filepath, err := filepath.EvalSymlinks(filename)
if err != nil {
return "", err
}
bin, err := os.Open(filepath)
if err != nil {
return "", err
}
defer bin.Close()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return "", err
}
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
}
func filesForModel(path string) ([]string, error) {
func filesForModel(root *os.Root) ([]string, error) {
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
f, err := root.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
bts := make([]byte, 512)
n, err := io.ReadFull(f, bts)
if errors.Is(err, io.ErrUnexpectedEOF) {
// short read, use what we have
bts = bts[:n]
} else if err != nil {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
return contentType, nil
}
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
matches, err := fs.Glob(root.FS(), pattern)
if err != nil {
return nil, err
}
@@ -262,7 +207,7 @@ func filesForModel(path string) ([]string, error) {
if ct, err := detectContentType(match); err != nil {
return nil, err
} else if len(contentType) > 0 && ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
return nil, fmt.Errorf("invalid content type: expected %s for %s, got %s", ct, match, contentType)
}
}
@@ -271,25 +216,25 @@ func filesForModel(path string) ([]string, error) {
var files []string
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
if st, _ := glob(filepath.Join(path, "model*.safetensors"), ""); len(st) > 0 {
if st, _ := glob("model*.safetensors", ""); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "consolidated*.safetensors"), ""); len(st) > 0 {
} else if st, _ := glob("consolidated*.safetensors", ""); len(st) > 0 {
// covers consolidated.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
} else if pt, _ := glob("pytorch_model*.bin", "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
} else if pt, _ := glob("consolidated*.pth", "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 {
} else if gg, _ := glob("*.gguf", "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .gguf
files = append(files, gg...)
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 {
} else if gg, _ := glob("*.bin", "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .bin
files = append(files, gg...)
} else {
@@ -297,7 +242,7 @@ func filesForModel(path string) ([]string, error) {
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
js, err := glob("*.json", "text/plain")
if err != nil {
return nil, err
}
@@ -305,7 +250,7 @@ func filesForModel(path string) ([]string, error) {
// bert models require a nested config.json
// TODO(mxyng): merge this with the glob above
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
js, err = glob("**/*.json", "text/plain")
if err != nil {
return nil, err
}
@@ -313,9 +258,9 @@ func filesForModel(path string) ([]string, error) {
// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
// tokenizer.model might be a unresolved git lfs reference; error if it is
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
if tks, _ := glob("tokenizer.model", "application/octet-stream"); len(tks) > 0 {
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
} else if tks, _ := glob("**/tokenizer.model", "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}

View File

@@ -2,7 +2,6 @@ package parser
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
@@ -15,6 +14,7 @@ import (
"unicode/utf16"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/encoding"
@@ -775,25 +775,13 @@ MESSAGE assistant Hi! How are you?
t.Error(err)
}
if diff := cmp.Diff(actual, c.expected); diff != "" {
if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) {
t.Helper()
h := sha256.New()
n, err := io.Copy(h, r)
if err != nil {
t.Fatal(err)
}
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
@@ -808,19 +796,12 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
if err := ggml.WriteGGUF(f, base, ti); err != nil {
t.Fatal(err)
}
// Calculate sha256 of file
if _, err := f.Seek(0, 0); err != nil {
t.Fatal(err)
}
digest, _ := getSHA256Digest(t, f)
return f.Name(), digest
return f.Name()
}
func TestCreateRequestFiles(t *testing.T) {
n1, d1 := createBinFile(t, nil, nil)
n2, d2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
n1 := createBinFile(t, nil, nil)
n2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
cases := []struct {
input string
@@ -828,11 +809,20 @@ func TestCreateRequestFiles(t *testing.T) {
}{
{
fmt.Sprintf("FROM %s", n1),
&api.CreateRequest{Files: map[string]string{n1: d1}},
&api.CreateRequest{
Files: map[string]string{
filepath.Base(n1): "abs:" + n1,
},
},
},
{
fmt.Sprintf("FROM %s\nFROM %s", n1, n2),
&api.CreateRequest{Files: map[string]string{n1: d1, n2: d2}},
&api.CreateRequest{
Files: map[string]string{
filepath.Base(n1): "abs:" + n1,
filepath.Base(n2): "abs:" + n2,
},
},
},
}
@@ -852,7 +842,7 @@ func TestCreateRequestFiles(t *testing.T) {
t.Error(err)
}
if diff := cmp.Diff(actual, c.expected); diff != "" {
if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
@@ -860,15 +850,15 @@ func TestCreateRequestFiles(t *testing.T) {
func TestFilesForModel(t *testing.T) {
tests := []struct {
name string
setup func(string) error
wantFiles []string
wantErr bool
expectErrType error
name string
setup func(*testing.T, *os.Root)
want []string
wantErr error
}{
{
name: "safetensors model files",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
t.Helper()
files := []string{
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
@@ -876,13 +866,12 @@ func TestFilesForModel(t *testing.T) {
"tokenizer.json",
}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
return err
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
"config.json",
@@ -891,7 +880,7 @@ func TestFilesForModel(t *testing.T) {
},
{
name: "safetensors with both tokenizer.json and tokenizer.model",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
// Create binary content for tokenizer.model (application/octet-stream)
binaryContent := make([]byte, 512)
for i := range binaryContent {
@@ -903,17 +892,16 @@ func TestFilesForModel(t *testing.T) {
"tokenizer.json",
}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
return err
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
t.Fatal(err)
}
}
// Write tokenizer.model as binary
if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil {
return err
if err := root.WriteFile("tokenizer.model", binaryContent, 0o644); err != nil {
t.Fatal(err)
}
return nil
},
wantFiles: []string{
want: []string{
"model-00001-of-00001.safetensors",
"config.json",
"tokenizer.json",
@@ -922,46 +910,44 @@ func TestFilesForModel(t *testing.T) {
},
{
name: "safetensors with consolidated files - prefers model files",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
files := []string{
"model-00001-of-00001.safetensors",
"consolidated.safetensors",
"config.json",
}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
return err
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"model-00001-of-00001.safetensors", // consolidated files should be excluded
"config.json",
},
},
{
name: "safetensors without model-.safetensors files - uses consolidated",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
files := []string{
"consolidated.safetensors",
"config.json",
}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
return err
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"consolidated.safetensors",
"config.json",
},
},
{
name: "pytorch model files",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
// Create a file that will be detected as application/zip
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header
files := []string{
@@ -974,13 +960,12 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" {
content = []byte(`{"config": true}`)
}
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
return err
if err := root.WriteFile(file, content, 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"pytorch_model-00001-of-00002.bin",
"pytorch_model-00002-of-00002.bin",
"config.json",
@@ -988,7 +973,7 @@ func TestFilesForModel(t *testing.T) {
},
{
name: "consolidated pth files",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04}
files := []string{
"consolidated.00.pth",
@@ -1000,13 +985,12 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" {
content = []byte(`{"config": true}`)
}
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
return err
if err := root.WriteFile(file, content, 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"consolidated.00.pth",
"consolidated.01.pth",
"config.json",
@@ -1014,7 +998,7 @@ func TestFilesForModel(t *testing.T) {
},
{
name: "gguf files",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
// Create binary content that will be detected as application/octet-stream
binaryContent := make([]byte, 512)
for i := range binaryContent {
@@ -1029,20 +1013,19 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" {
content = []byte(`{"config": true}`)
}
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
return err
if err := root.WriteFile(file, content, 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"model.gguf",
"config.json",
},
},
{
name: "bin files as gguf",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
binaryContent := make([]byte, 512)
for i := range binaryContent {
binaryContent[i] = byte(i % 256)
@@ -1056,35 +1039,32 @@ func TestFilesForModel(t *testing.T) {
if file == "config.json" {
content = []byte(`{"config": true}`)
}
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
return err
if err := root.WriteFile(file, content, 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantFiles: []string{
want: []string{
"model.bin",
"config.json",
},
},
{
name: "no model files found",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
// Only create non-model files
files := []string{"README.md", "config.json"}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("content"), 0o644); err != nil {
return err
if err := root.WriteFile(file, []byte("content"), 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantErr: true,
expectErrType: ErrModelNotFound,
wantErr: ErrModelNotFound,
},
{
name: "invalid content type for pytorch model",
setup: func(dir string) error {
setup: func(t *testing.T, root *os.Root) {
// Create pytorch model file with wrong content type (text instead of zip)
files := []string{
"pytorch_model.bin",
@@ -1092,68 +1072,32 @@ func TestFilesForModel(t *testing.T) {
}
for _, file := range files {
content := []byte("plain text content")
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
return err
if err := root.WriteFile(file, content, 0o644); err != nil {
t.Fatal(err)
}
}
return nil
},
wantErr: true,
wantErr: ErrModelNotFound,
},
}
tmpDir := t.TempDir()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testDir := filepath.Join(tmpDir, tt.name)
if err := os.MkdirAll(testDir, 0o755); err != nil {
t.Fatalf("Failed to create test directory: %v", err)
}
if err := tt.setup(testDir); err != nil {
t.Fatalf("Setup failed: %v", err)
}
files, err := filesForModel(testDir)
if tt.wantErr {
if err == nil {
t.Error("Expected error, but got none")
}
if tt.expectErrType != nil && err != tt.expectErrType {
t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
}
return
}
root, err := os.OpenRoot(t.TempDir())
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
t.Fatalf("Failed to open root: %v", err)
}
defer root.Close()
tt.setup(t, root)
files, err := filesForModel(root)
if !errors.Is(err, tt.wantErr) {
t.Fatalf("want %v error, got %v", tt.wantErr, err)
}
var relativeFiles []string
for _, file := range files {
rel, err := filepath.Rel(testDir, file)
if err != nil {
t.Fatalf("Failed to get relative path: %v", err)
}
relativeFiles = append(relativeFiles, rel)
}
if len(relativeFiles) != len(tt.wantFiles) {
t.Errorf("Expected %d files, got %d: %v", len(tt.wantFiles), len(relativeFiles), relativeFiles)
}
fileSet := make(map[string]bool)
for _, file := range relativeFiles {
fileSet[file] = true
}
for _, wantFile := range tt.wantFiles {
if !fileSet[wantFile] {
t.Errorf("Missing expected file: %s", wantFile)
}
if diff := cmp.Diff(tt.want, files); diff != "" {
t.Errorf("filesForModel() mismatch (-want +got):\n%s", diff)
}
})
}

View File

@@ -1,38 +0,0 @@
package syncmap
import (
"maps"
"sync"
)
// SyncMap is a minimal generic map guarded by a read/write mutex, safe for
// concurrent use by multiple goroutines. The zero value is not usable; call
// NewSyncMap to construct one.
type SyncMap[K comparable, V any] struct {
	mu sync.RWMutex
	m  map[K]V
}

// NewSyncMap returns an empty, ready-to-use SyncMap.
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
	return &SyncMap[K, V]{m: map[K]V{}}
}

// Load returns the value stored under key and whether the key was present.
func (sm *SyncMap[K, V]) Load(key K) (V, bool) {
	sm.mu.RLock()
	defer sm.mu.RUnlock()
	v, ok := sm.m[key]
	return v, ok
}

// Store associates value with key, replacing any previous entry.
func (sm *SyncMap[K, V]) Store(key K, value V) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.m[key] = value
}

// Items returns a shallow copy of the map's current contents; mutating the
// returned map does not affect the SyncMap.
func (sm *SyncMap[K, V]) Items() map[K]V {
	sm.mu.RLock()
	defer sm.mu.RUnlock()
	return maps.Clone(sm.m)
}