mirror of
https://github.com/ollama/ollama.git
synced 2026-01-20 21:40:54 -05:00
Compare commits
1 Commits
main
...
mxyng/asyn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ef2106b47 |
89
cmd/cmd.go
89
cmd/cmd.go
@@ -43,7 +43,6 @@ import (
|
||||
"github.com/ollama/ollama/runner"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
@@ -205,7 +204,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
spinner.Stop()
|
||||
|
||||
req.Model = modelName
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
@@ -219,42 +217,29 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
var g errgroup.Group
|
||||
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
|
||||
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||
for blob, err := range createBlobs(req.Files, req.Adapters) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
files := syncmap.NewSyncMap[string, string]()
|
||||
for f, digest := range req.Files {
|
||||
g.Go(func() error {
|
||||
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: this is incorrect since the file might be in a subdirectory
|
||||
// instead this should take the path relative to the model directory
|
||||
// but the current implementation does not allow this
|
||||
files.Store(filepath.Base(f), digest)
|
||||
return nil
|
||||
_, err := createBlob(cmd, client, blob.Abs, blob.Digest, p)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
adapters := syncmap.NewSyncMap[string, string]()
|
||||
for f, digest := range req.Adapters {
|
||||
g.Go(func() error {
|
||||
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: same here
|
||||
adapters.Store(filepath.Base(f), digest)
|
||||
return nil
|
||||
})
|
||||
if _, ok := req.Files[blob.Rel]; ok {
|
||||
req.Files[blob.Rel] = blob.Digest
|
||||
} else if _, ok := req.Adapters[blob.Rel]; ok {
|
||||
req.Adapters[blob.Rel] = blob.Digest
|
||||
}
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Files = files.Items()
|
||||
req.Adapters = adapters.Items()
|
||||
spinner.Stop()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
@@ -292,54 +277,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
|
||||
realPath, err := filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
bin, err := os.Open(realPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
// Get file info to retrieve the size
|
||||
fileInfo, err := bin.Stat()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fileSize := fileInfo.Size()
|
||||
|
||||
var pw progressWriter
|
||||
status := fmt.Sprintf("copying file %s 0%%", digest)
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
defer spinner.Stop()
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
|
||||
case <-done:
|
||||
spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
n atomic.Int64
|
||||
}
|
||||
|
||||
103
cmd/create.go
Normal file
103
cmd/create.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// blob describes one model file staged for upload: Rel is the path key as it
// appears in the create request (relative name), Abs is the absolute path on
// disk to read from, and Digest is the file's checksum in "sha256:<hex>" form.
type blob struct {
	Rel, Abs, Digest string
}
|
||||
|
||||
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
|
||||
realPath, err := filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
bin, err := os.Open(realPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
// Get file info to retrieve the size
|
||||
fileInfo, err := bin.Stat()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fileSize := fileInfo.Size()
|
||||
|
||||
var pw progressWriter
|
||||
status := fmt.Sprintf("copying file %s 0%%", digest)
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
defer spinner.Stop()
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
|
||||
case <-done:
|
||||
spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
func createBlobs(mappings ...map[string]string) iter.Seq2[blob, error] {
|
||||
return func(yield func(blob, error) bool) {
|
||||
for _, mapping := range mappings {
|
||||
for rel, abs := range mapping {
|
||||
if abs, ok := strings.CutPrefix(abs, "abs:"); ok {
|
||||
f, err := os.Open(abs)
|
||||
if err != nil {
|
||||
yield(blob{}, err)
|
||||
return
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
yield(blob{}, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
yield(blob{}, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !yield(blob{
|
||||
Rel: rel,
|
||||
Abs: abs,
|
||||
Digest: fmt.Sprintf("sha256:%x", h.Sum(nil)),
|
||||
}, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
161
parser/parser.go
161
parser/parser.go
@@ -3,22 +3,20 @@ package parser
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
"golang.org/x/text/transform"
|
||||
|
||||
@@ -54,7 +52,10 @@ var deprecatedParameters = []string{
|
||||
|
||||
// CreateRequest creates a new *api.CreateRequest from an existing Modelfile
|
||||
func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) {
|
||||
req := &api.CreateRequest{}
|
||||
req := &api.CreateRequest{
|
||||
Files: make(map[string]string),
|
||||
Adapters: make(map[string]string),
|
||||
}
|
||||
|
||||
var messages []api.Message
|
||||
var licenses []string
|
||||
@@ -63,12 +64,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
for _, c := range f.Commands {
|
||||
switch c.Name {
|
||||
case "model":
|
||||
path, err := expandPath(c.Args, relativeDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
digestMap, err := fileDigestMap(path)
|
||||
files, err := filesMap(c.Args, relativeDir)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
req.From = c.Args
|
||||
continue
|
||||
@@ -76,25 +72,14 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if req.Files == nil {
|
||||
req.Files = digestMap
|
||||
} else {
|
||||
for k, v := range digestMap {
|
||||
req.Files[k] = v
|
||||
}
|
||||
}
|
||||
maps.Copy(req.Files, files)
|
||||
case "adapter":
|
||||
path, err := expandPath(c.Args, relativeDir)
|
||||
files, err := filesMap(c.Args, relativeDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
digestMap, err := fileDigestMap(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Adapters = digestMap
|
||||
maps.Copy(req.Adapters, files)
|
||||
case "template":
|
||||
req.Template = c.Args
|
||||
case "system":
|
||||
@@ -154,106 +139,66 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func fileDigestMap(path string) (map[string]string, error) {
|
||||
fl := make(map[string]string)
|
||||
func filesMap(args, base string) (map[string]string, error) {
|
||||
path, err := expandPath(args, base)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var files []string
|
||||
if fi.IsDir() {
|
||||
fs, err := filesForModel(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, f := range fs {
|
||||
f, err := filepath.EvalSymlinks(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(path, f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !filepath.IsLocal(rel) {
|
||||
return nil, fmt.Errorf("insecure path: %s", rel)
|
||||
}
|
||||
|
||||
files = append(files, f)
|
||||
}
|
||||
} else {
|
||||
files = []string{path}
|
||||
mapping := make(map[string]string)
|
||||
if !fi.IsDir() {
|
||||
return map[string]string{
|
||||
filepath.Base(path): "abs:" + path,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var g errgroup.Group
|
||||
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
|
||||
for _, f := range files {
|
||||
g.Go(func() error {
|
||||
digest, err := digestForFile(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
fl[f] = digest
|
||||
return nil
|
||||
})
|
||||
root, err := os.OpenRoot(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
files, err := filesForModel(root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fl, nil
|
||||
for _, file := range files {
|
||||
// create a temporary mapping from relative path to absolute path
|
||||
mapping[file] = "abs:" + filepath.Join(root.Name(), file)
|
||||
}
|
||||
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
func digestForFile(filename string) (string, error) {
|
||||
filepath, err := filepath.EvalSymlinks(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
bin, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, bin); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func filesForModel(path string) ([]string, error) {
|
||||
func filesForModel(root *os.Root) ([]string, error) {
|
||||
detectContentType := func(path string) (string, error) {
|
||||
f, err := os.Open(path)
|
||||
f, err := root.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var b bytes.Buffer
|
||||
b.Grow(512)
|
||||
|
||||
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
|
||||
bts := make([]byte, 512)
|
||||
n, err := io.ReadFull(f, bts)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// short read, use what we have
|
||||
bts = bts[:n]
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
|
||||
contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
|
||||
return contentType, nil
|
||||
}
|
||||
|
||||
glob := func(pattern, contentType string) ([]string, error) {
|
||||
matches, err := filepath.Glob(pattern)
|
||||
matches, err := fs.Glob(root.FS(), pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -262,7 +207,7 @@ func filesForModel(path string) ([]string, error) {
|
||||
if ct, err := detectContentType(match); err != nil {
|
||||
return nil, err
|
||||
} else if len(contentType) > 0 && ct != contentType {
|
||||
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
|
||||
return nil, fmt.Errorf("invalid content type: expected %s for %s, got %s", ct, match, contentType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,25 +216,25 @@ func filesForModel(path string) ([]string, error) {
|
||||
|
||||
var files []string
|
||||
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
|
||||
if st, _ := glob(filepath.Join(path, "model*.safetensors"), ""); len(st) > 0 {
|
||||
if st, _ := glob("model*.safetensors", ""); len(st) > 0 {
|
||||
// safetensors files might be unresolved git lfs references; skip if they are
|
||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||
files = append(files, st...)
|
||||
} else if st, _ := glob(filepath.Join(path, "consolidated*.safetensors"), ""); len(st) > 0 {
|
||||
} else if st, _ := glob("consolidated*.safetensors", ""); len(st) > 0 {
|
||||
// covers consolidated.safetensors
|
||||
files = append(files, st...)
|
||||
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
||||
} else if pt, _ := glob("pytorch_model*.bin", "application/zip"); len(pt) > 0 {
|
||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
||||
files = append(files, pt...)
|
||||
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
|
||||
} else if pt, _ := glob("consolidated*.pth", "application/zip"); len(pt) > 0 {
|
||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||
// covers consolidated.x.pth, consolidated.pth
|
||||
files = append(files, pt...)
|
||||
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 {
|
||||
} else if gg, _ := glob("*.gguf", "application/octet-stream"); len(gg) > 0 {
|
||||
// covers gguf files ending in .gguf
|
||||
files = append(files, gg...)
|
||||
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 {
|
||||
} else if gg, _ := glob("*.bin", "application/octet-stream"); len(gg) > 0 {
|
||||
// covers gguf files ending in .bin
|
||||
files = append(files, gg...)
|
||||
} else {
|
||||
@@ -297,7 +242,7 @@ func filesForModel(path string) ([]string, error) {
|
||||
}
|
||||
|
||||
// add configuration files, json files are detected as text/plain
|
||||
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
|
||||
js, err := glob("*.json", "text/plain")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -305,7 +250,7 @@ func filesForModel(path string) ([]string, error) {
|
||||
|
||||
// bert models require a nested config.json
|
||||
// TODO(mxyng): merge this with the glob above
|
||||
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
|
||||
js, err = glob("**/*.json", "text/plain")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -313,9 +258,9 @@ func filesForModel(path string) ([]string, error) {
|
||||
|
||||
// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
|
||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||
if tks, _ := glob("tokenizer.model", "application/octet-stream"); len(tks) > 0 {
|
||||
files = append(files, tks...)
|
||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
||||
} else if tks, _ := glob("**/tokenizer.model", "text/plain"); len(tks) > 0 {
|
||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||
files = append(files, tks...)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package parser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,6 +14,7 @@ import (
|
||||
"unicode/utf16"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/encoding"
|
||||
@@ -775,25 +775,13 @@ MESSAGE assistant Hi! How are you?
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(actual, c.expected); diff != "" {
|
||||
if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) {
|
||||
t.Helper()
|
||||
|
||||
h := sha256.New()
|
||||
n, err := io.Copy(h, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||
}
|
||||
|
||||
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
|
||||
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) string {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
|
||||
@@ -808,19 +796,12 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Calculate sha256 of file
|
||||
if _, err := f.Seek(0, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
digest, _ := getSHA256Digest(t, f)
|
||||
|
||||
return f.Name(), digest
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
func TestCreateRequestFiles(t *testing.T) {
|
||||
n1, d1 := createBinFile(t, nil, nil)
|
||||
n2, d2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
|
||||
n1 := createBinFile(t, nil, nil)
|
||||
n2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
|
||||
|
||||
cases := []struct {
|
||||
input string
|
||||
@@ -828,11 +809,20 @@ func TestCreateRequestFiles(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
fmt.Sprintf("FROM %s", n1),
|
||||
&api.CreateRequest{Files: map[string]string{n1: d1}},
|
||||
&api.CreateRequest{
|
||||
Files: map[string]string{
|
||||
filepath.Base(n1): "abs:" + n1,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
fmt.Sprintf("FROM %s\nFROM %s", n1, n2),
|
||||
&api.CreateRequest{Files: map[string]string{n1: d1, n2: d2}},
|
||||
&api.CreateRequest{
|
||||
Files: map[string]string{
|
||||
filepath.Base(n1): "abs:" + n1,
|
||||
filepath.Base(n2): "abs:" + n2,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -852,7 +842,7 @@ func TestCreateRequestFiles(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(actual, c.expected); diff != "" {
|
||||
if diff := cmp.Diff(actual, c.expected, cmpopts.EquateEmpty()); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
@@ -860,15 +850,15 @@ func TestCreateRequestFiles(t *testing.T) {
|
||||
|
||||
func TestFilesForModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(string) error
|
||||
wantFiles []string
|
||||
wantErr bool
|
||||
expectErrType error
|
||||
name string
|
||||
setup func(*testing.T, *os.Root)
|
||||
want []string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "safetensors model files",
|
||||
setup: func(dir string) error {
|
||||
setup: func(t *testing.T, root *os.Root) {
|
||||
t.Helper()
|
||||
files := []string{
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
@@ -876,13 +866,12 @@ func TestFilesForModel(t *testing.T) {
|
||||
"tokenizer.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
want: []string{
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
"config.json",
|
||||
@@ -891,7 +880,7 @@ func TestFilesForModel(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "safetensors with both tokenizer.json and tokenizer.model",
|
||||
setup: func(dir string) error {
|
||||
setup: func(t *testing.T, root *os.Root) {
|
||||
// Create binary content for tokenizer.model (application/octet-stream)
|
||||
binaryContent := make([]byte, 512)
|
||||
for i := range binaryContent {
|
||||
@@ -903,17 +892,16 @@ func TestFilesForModel(t *testing.T) {
|
||||
"tokenizer.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
// Write tokenizer.model as binary
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile("tokenizer.model", binaryContent, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
want: []string{
|
||||
"model-00001-of-00001.safetensors",
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
@@ -922,46 +910,44 @@ func TestFilesForModel(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "safetensors with consolidated files - prefers model files",
|
||||
setup: func(dir string) error {
|
||||
setup: func(t *testing.T, root *os.Root) {
|
||||
files := []string{
|
||||
"model-00001-of-00001.safetensors",
|
||||
"consolidated.safetensors",
|
||||
"config.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
want: []string{
|
||||
"model-00001-of-00001.safetensors", // consolidated files should be excluded
|
||||
"config.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "safetensors without model-.safetensors files - uses consolidated",
|
||||
setup: func(dir string) error {
|
||||
setup: func(t *testing.T, root *os.Root) {
|
||||
files := []string{
|
||||
"consolidated.safetensors",
|
||||
"config.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile(file, []byte("test content"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
want: []string{
|
||||
"consolidated.safetensors",
|
||||
"config.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pytorch model files",
|
||||
setup: func(dir string) error {
|
||||
setup: func(t *testing.T, root *os.Root) {
|
||||
// Create a file that will be detected as application/zip
|
||||
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header
|
||||
files := []string{
|
||||
@@ -974,13 +960,12 @@ func TestFilesForModel(t *testing.T) {
|
||||
if file == "config.json" {
|
||||
content = []byte(`{"config": true}`)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile(file, content, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
want: []string{
|
||||
"pytorch_model-00001-of-00002.bin",
|
||||
"pytorch_model-00002-of-00002.bin",
|
||||
"config.json",
|
||||
@@ -988,7 +973,7 @@ func TestFilesForModel(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "consolidated pth files",
|
||||
setup: func(dir string) error {
|
||||
setup: func(t *testing.T, root *os.Root) {
|
||||
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04}
|
||||
files := []string{
|
||||
"consolidated.00.pth",
|
||||
@@ -1000,13 +985,12 @@ func TestFilesForModel(t *testing.T) {
|
||||
if file == "config.json" {
|
||||
content = []byte(`{"config": true}`)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile(file, content, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
want: []string{
|
||||
"consolidated.00.pth",
|
||||
"consolidated.01.pth",
|
||||
"config.json",
|
||||
@@ -1014,7 +998,7 @@ func TestFilesForModel(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "gguf files",
|
||||
setup: func(dir string) error {
|
||||
setup: func(t *testing.T, root *os.Root) {
|
||||
// Create binary content that will be detected as application/octet-stream
|
||||
binaryContent := make([]byte, 512)
|
||||
for i := range binaryContent {
|
||||
@@ -1029,20 +1013,19 @@ func TestFilesForModel(t *testing.T) {
|
||||
if file == "config.json" {
|
||||
content = []byte(`{"config": true}`)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile(file, content, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
want: []string{
|
||||
"model.gguf",
|
||||
"config.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bin files as gguf",
|
||||
setup: func(dir string) error {
|
||||
setup: func(t *testing.T, root *os.Root) {
|
||||
binaryContent := make([]byte, 512)
|
||||
for i := range binaryContent {
|
||||
binaryContent[i] = byte(i % 256)
|
||||
@@ -1056,35 +1039,32 @@ func TestFilesForModel(t *testing.T) {
|
||||
if file == "config.json" {
|
||||
content = []byte(`{"config": true}`)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile(file, content, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
want: []string{
|
||||
"model.bin",
|
||||
"config.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no model files found",
|
||||
setup: func(dir string) error {
|
||||
setup: func(t *testing.T, root *os.Root) {
|
||||
// Only create non-model files
|
||||
files := []string{"README.md", "config.json"}
|
||||
for _, file := range files {
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("content"), 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile(file, []byte("content"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantErr: true,
|
||||
expectErrType: ErrModelNotFound,
|
||||
wantErr: ErrModelNotFound,
|
||||
},
|
||||
{
|
||||
name: "invalid content type for pytorch model",
|
||||
setup: func(dir string) error {
|
||||
setup: func(t *testing.T, root *os.Root) {
|
||||
// Create pytorch model file with wrong content type (text instead of zip)
|
||||
files := []string{
|
||||
"pytorch_model.bin",
|
||||
@@ -1092,68 +1072,32 @@ func TestFilesForModel(t *testing.T) {
|
||||
}
|
||||
for _, file := range files {
|
||||
content := []byte("plain text content")
|
||||
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||
return err
|
||||
if err := root.WriteFile(file, content, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantErr: true,
|
||||
wantErr: ErrModelNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testDir := filepath.Join(tmpDir, tt.name)
|
||||
if err := os.MkdirAll(testDir, 0o755); err != nil {
|
||||
t.Fatalf("Failed to create test directory: %v", err)
|
||||
}
|
||||
|
||||
if err := tt.setup(testDir); err != nil {
|
||||
t.Fatalf("Setup failed: %v", err)
|
||||
}
|
||||
|
||||
files, err := filesForModel(testDir)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error, but got none")
|
||||
}
|
||||
if tt.expectErrType != nil && err != tt.expectErrType {
|
||||
t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
root, err := os.OpenRoot(t.TempDir())
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
t.Fatalf("Failed to open root: %v", err)
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
tt.setup(t, root)
|
||||
|
||||
files, err := filesForModel(root)
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("want %v error, got %v", tt.wantErr, err)
|
||||
}
|
||||
|
||||
var relativeFiles []string
|
||||
for _, file := range files {
|
||||
rel, err := filepath.Rel(testDir, file)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get relative path: %v", err)
|
||||
}
|
||||
relativeFiles = append(relativeFiles, rel)
|
||||
}
|
||||
|
||||
if len(relativeFiles) != len(tt.wantFiles) {
|
||||
t.Errorf("Expected %d files, got %d: %v", len(tt.wantFiles), len(relativeFiles), relativeFiles)
|
||||
}
|
||||
|
||||
fileSet := make(map[string]bool)
|
||||
for _, file := range relativeFiles {
|
||||
fileSet[file] = true
|
||||
}
|
||||
|
||||
for _, wantFile := range tt.wantFiles {
|
||||
if !fileSet[wantFile] {
|
||||
t.Errorf("Missing expected file: %s", wantFile)
|
||||
}
|
||||
if diff := cmp.Diff(tt.want, files); diff != "" {
|
||||
t.Errorf("filesForModel() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
package syncmap
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SyncMap is a minimal generic map guarded by a read/write lock, safe for
// concurrent use by multiple goroutines.
type SyncMap[K comparable, V any] struct {
	mu    sync.RWMutex
	items map[K]V
}

// NewSyncMap returns an empty, ready-to-use SyncMap.
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
	return &SyncMap[K, V]{items: make(map[K]V)}
}

// Load returns the value stored under key and reports whether it was present.
func (sm *SyncMap[K, V]) Load(key K) (V, bool) {
	sm.mu.RLock()
	defer sm.mu.RUnlock()
	v, ok := sm.items[key]
	return v, ok
}

// Store sets the value for key, replacing any previous value.
func (sm *SyncMap[K, V]) Store(key K, value V) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.items[key] = value
}

// Items returns a shallow copy of the map's current contents, so callers can
// iterate without holding the lock.
func (sm *SyncMap[K, V]) Items() map[K]V {
	sm.mu.RLock()
	defer sm.mu.RUnlock()
	return maps.Clone(sm.items)
}
|
||||
Reference in New Issue
Block a user