Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Yang
087beb40ed refactor filesForModel 2025-07-30 13:49:12 -07:00
Michael Yang
19279d778d accept - to create from stdin 2025-07-30 13:23:33 -07:00
3 changed files with 262 additions and 312 deletions

View File

@@ -64,54 +64,37 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
// errModelfileNotFound is the sentinel error reporting that the Modelfile the
// user asked for could not be found.
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
// getModelfileName resolves the path of the Modelfile selected by the "file"
// flag on cmd. An empty flag value falls back to "Modelfile" in the current
// directory. The returned path is absolute.
//
// The path must be stat-able: any error from filepath.Abs or os.Stat is
// returned unchanged together with an empty name, so callers can still test
// the error with os.IsNotExist.
func getModelfileName(cmd *cobra.Command) (string, error) {
	// The flag lookup error is intentionally discarded; a missing flag is
	// treated the same as an empty value.
	name, _ := cmd.Flags().GetString("file")
	if name == "" {
		name = "Modelfile"
	}

	resolved, err := filepath.Abs(name)
	if err != nil {
		return "", err
	}

	if _, err := os.Stat(resolved); err != nil {
		return "", err
	}

	return resolved, nil
}
func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
}
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
reader = f
defer f.Close()
filename, err := cmd.Flags().GetString("file")
if err != nil {
return fmt.Errorf("error retrieving file flag: %w", err)
}
modelfile, err := parser.ParseFile(reader)
var r, fallback io.Reader
switch filename {
case "-":
r = os.Stdin
case "":
filename = "Modelfile"
fallback = strings.NewReader("FROM .")
fallthrough
default:
r, err = os.Open(filename)
if errors.Is(err, os.ErrNotExist) && fallback != nil {
r = fallback
} else if errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("%w: Modelfile %q does not exist, please create it or use --file to specify a different file", err, filename)
} else if err != nil {
return err
} else {
defer r.(*os.File).Close()
}
}
modelfile, err := parser.ParseFile(r)
if err != nil {
return err
}
@@ -127,10 +110,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner.Stop()
req.Model = args[0]
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
}
req.Quantize, _ = cmd.Flags().GetString("quantize")
client, err := api.ClientFromEnvironment()
if err != nil {

View File

@@ -3,10 +3,13 @@ package cmd
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
@@ -18,6 +21,13 @@ import (
"github.com/ollama/ollama/types/model"
)
// mockServer starts an httptest server that answers every request with h and
// points OLLAMA_HOST at it for the duration of the test. The server is shut
// down and the environment variable restored automatically when t finishes.
func mockServer(t *testing.T, h http.HandlerFunc) {
	t.Helper()

	srv := httptest.NewServer(h)
	// Register the shutdown before touching the environment so the cleanup
	// order (env restored first, then server closed) is deterministic.
	t.Cleanup(srv.Close)

	t.Setenv("OLLAMA_HOST", srv.URL)
}
func TestShowInfo(t *testing.T) {
t.Run("bare details", func(t *testing.T) {
var b bytes.Buffer
@@ -351,101 +361,6 @@ func TestDeleteHandler(t *testing.T) {
}
}
// TestGetModelfileName exercises getModelfileName with and without the "file"
// flag set, and with and without the referenced file present on disk.
func TestGetModelfileName(t *testing.T) {
	type testCase struct {
		name          string
		modelfileName string
		fileExists    bool
		expectedName  string
		expectedErr   error
	}

	cases := []testCase{
		{
			name:          "no modelfile specified, no modelfile exists",
			modelfileName: "",
			fileExists:    false,
			expectedName:  "",
			expectedErr:   os.ErrNotExist,
		},
		{
			name:          "no modelfile specified, modelfile exists",
			modelfileName: "",
			fileExists:    true,
			expectedName:  "Modelfile",
			expectedErr:   nil,
		},
		{
			name:          "modelfile specified, no modelfile exists",
			modelfileName: "crazyfile",
			fileExists:    false,
			expectedName:  "",
			expectedErr:   os.ErrNotExist,
		},
		{
			name:          "modelfile specified, modelfile exists",
			modelfileName: "anotherfile",
			fileExists:    true,
			expectedName:  "anotherfile",
			expectedErr:   nil,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cmd := &cobra.Command{Use: "fakecmd"}
			cmd.Flags().String("file", "", "path to modelfile")

			// When no file is created the expectation comes straight from
			// the table; otherwise it is the generated temp file path.
			want := tc.expectedName

			switch {
			case tc.fileExists:
				pattern := tc.modelfileName
				if pattern == "" {
					pattern = "Modelfile"
				}

				tmp, err := os.CreateTemp(t.TempDir(), pattern)
				if err != nil {
					t.Fatalf("temp modelfile creation failed: %v", err)
				}
				defer tmp.Close()

				want = tmp.Name()
				if err := cmd.Flags().Set("file", want); err != nil {
					t.Fatalf("couldn't set file flag: %v", err)
				}
			case tc.modelfileName != "":
				if err := cmd.Flags().Set("file", tc.modelfileName); err != nil {
					t.Fatalf("couldn't set file flag: %v", err)
				}
			}

			got, gotErr := getModelfileName(cmd)
			if got != want {
				t.Errorf("expected filename: '%s' actual filename: '%s'", want, got)
			}

			// Not-exist errors arrive as *fs.PathError, so they are matched
			// with os.IsNotExist instead of direct comparison.
			if tc.expectedErr == os.ErrNotExist {
				if !os.IsNotExist(gotErr) {
					t.Errorf("expected err: %v actual err: %v", tc.expectedErr, gotErr)
				}
			} else if gotErr != tc.expectedErr {
				t.Errorf("expected err: %v actual err: %v", tc.expectedErr, gotErr)
			}
		})
	}
}
func TestPushHandler(t *testing.T) {
tests := []struct {
name string
@@ -661,128 +576,165 @@ func TestListHandler(t *testing.T) {
}
func TestCreateHandler(t *testing.T) {
tests := []struct {
name string
modelName string
modelFile string
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
expectedError string
expectedOutput string
cases := []struct {
name string
filename func(*testing.T) string
wantRequest api.CreateRequest
wantErr error
}{
{
name: "successful create",
modelName: "test-model",
modelFile: "FROM foo",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/create": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
name: "not exist",
filename: func(*testing.T) string { return "not_exist" },
wantErr: os.ErrNotExist,
},
{
name: "stdin",
filename: func(t *testing.T) string {
r, w, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
req := api.CreateRequest{}
if _, err := w.WriteString("FROM test"); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
stdin := os.Stdin
t.Cleanup(func() { os.Stdin = stdin })
os.Stdin = r
return "-"
},
wantRequest: api.CreateRequest{
Model: "stdin",
From: "test",
},
},
{
name: "default",
filename: func(t *testing.T) string {
t.Chdir(t.TempDir())
f, err := os.Create("Modelfile")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := f.WriteString("FROM test"); err != nil {
t.Fatal(err)
}
return ""
},
wantRequest: api.CreateRequest{
Model: "default",
From: "test",
},
},
{
name: "default safetensors",
filename: func(t *testing.T) string {
t.Chdir(t.TempDir())
f, err := os.Create("model.safetensors")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := f.Truncate(1); err != nil {
t.Fatal(err)
}
return ""
},
wantRequest: api.CreateRequest{
Model: "default_safetensors",
Files: map[string]string{
"model.safetensors": "sha256:6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d",
},
},
},
{
name: "file flag",
filename: func(t *testing.T) string {
f, err := os.CreateTemp(t.TempDir(), filepath.Base(t.Name()))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := f.WriteString("FROM test"); err != nil {
t.Fatal(err)
}
return f.Name()
},
wantRequest: api.CreateRequest{
Model: "file_flag",
From: "test",
},
},
{
name: "insecure path",
filename: func(t *testing.T) string {
t.Chdir(t.TempDir())
if err := os.Symlink("../../../../../../nope", "model.safetensors"); err != nil {
t.Fatal(err)
}
return ""
},
wantErr: fmt.Errorf("openat %s: path escapes from parent", "model.safetensors"),
},
}
var cmd cobra.Command
cmd.SetContext(t.Context())
cmd.Flags().String("file", "", "")
cmd.Flags().String("quantize", "", "")
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
mockServer(t, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if r.URL.Path == "/api/create" {
var req api.CreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Model != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
if diff := cmp.Diff(tt.wantRequest, req); diff != "" {
t.Errorf("Create request mismatch (-want +got):\n%s", diff)
}
if req.From != "foo" {
t.Errorf("expected from 'foo', got %s", req.From)
}
responses := []api.ProgressResponse{
{Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"},
{Status: "writing manifest"},
{Status: "success"},
}
for _, resp := range responses {
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.(http.Flusher).Flush()
}
},
},
expectedOutput: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler, ok := tt.serverResponse[r.URL.Path]
if !ok {
t.Errorf("unexpected request to %s", r.URL.Path)
} else if strings.HasPrefix(r.URL.Path, "/api/blobs/") {
w.WriteHeader(http.StatusOK)
} else {
http.Error(w, "not found", http.StatusNotFound)
return
}
handler(w, r)
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tempFile.Name())
})
if _, err := tempFile.WriteString(tt.modelFile); err != nil {
t.Fatal(err)
var filename string
if tt.filename != nil {
filename = tt.filename(t)
}
if err := tempFile.Close(); err != nil {
if err := cmd.Flags().Set("file", filename); err != nil {
t.Fatal(err)
}
cmd := &cobra.Command{}
cmd.Flags().String("file", "", "")
if err := cmd.Flags().Set("file", tempFile.Name()); err != nil {
if err := CreateHandler(&cmd, []string{filepath.Base(t.Name())}); err != tt.wantErr &&
err.Error() != tt.wantErr.Error() &&
!errors.Is(err, tt.wantErr) {
t.Fatal(err)
}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(t.Context())
// Redirect stderr to capture progress output
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
// Capture stdout for the "Model pushed" message
oldStdout := os.Stdout
outR, outW, _ := os.Pipe()
os.Stdout = outW
err = CreateHandler(cmd, []string{tt.modelName})
// Restore stderr
w.Close()
os.Stderr = oldStderr
// drain the pipe
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
// Restore stdout and get output
outW.Close()
os.Stdout = oldStdout
stdout, _ := io.ReadAll(outR)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if tt.expectedOutput != "" {
if got := string(stdout); got != tt.expectedOutput {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
}
}
}
})
}
}

View File

@@ -7,6 +7,8 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"iter"
"net/http"
"os"
"os/user"
@@ -148,31 +150,23 @@ func fileDigestMap(path string) (map[string]string, error) {
}
var files []string
if fi.IsDir() {
fs, err := filesForModel(path)
if !fi.IsDir() {
files = []string{path}
} else {
root, err := os.OpenRoot(path)
if err != nil {
return nil, err
}
defer root.Close()
fs, err := filesForModel(root.FS())
if err != nil {
return nil, err
}
for _, f := range fs {
f, err := filepath.EvalSymlinks(f)
if err != nil {
return nil, err
}
rel, err := filepath.Rel(path, f)
if err != nil {
return nil, err
}
if !filepath.IsLocal(rel) {
return nil, fmt.Errorf("insecure path: %s", rel)
}
files = append(files, f)
files = append(files, filepath.Join(path, f))
}
} else {
files = []string{path}
}
var mu sync.Mutex
@@ -218,67 +212,90 @@ func digestForFile(filename string) (string, error) {
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
}
func filesForModel(path string) ([]string, error) {
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
func detectContentType(fsys fs.FS, path string) (string, error) {
f, err := fsys.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
bts := make([]byte, 512)
n, err := io.ReadFull(f, bts)
if errors.Is(err, io.ErrUnexpectedEOF) {
bts = bts[:n]
} else if err != nil {
return "", err
}
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
return contentType, nil
}
// matchFirst walks patternsContentTypes as consecutive (glob pattern, content
// type) pairs and returns an iterator over the files in fsys selected by the
// first pattern that matches at least one file.
//
// For that pattern, every matched file whose detected content type equals the
// paired content type is yielded; files with a different content type are
// silently skipped. Later patterns are not consulted once a pattern has
// produced matches, even if none of them had the expected content type.
// Errors from globbing or content type detection are yielded with an empty
// name, and iteration continues unless the consumer stops.
//
// patternsContentTypes must contain an even number of elements.
func matchFirst(fsys fs.FS, patternsContentTypes ...string) iter.Seq2[string, error] {
	return func(yield func(string, error) bool) {
		for i := 0; i < len(patternsContentTypes); i += 2 {
			pattern, want := patternsContentTypes[i], patternsContentTypes[i+1]

			found, err := fs.Glob(fsys, pattern)
			if err != nil {
				if !yield("", err) {
					return
				}
				continue
			}

			if len(found) == 0 {
				// Nothing matched this pattern; try the next pair.
				continue
			}

			for _, name := range found {
				got, err := detectContentType(fsys, name)
				switch {
				case err != nil:
					if !yield("", err) {
						return
					}
				case got == want:
					if !yield(name, nil) {
						return
					}
				}
			}

			// The first pattern with matches wins; stop here.
			return
		}
	}
}
func collect[E any](it iter.Seq2[E, error]) (s []E, _ error) {
for v, err := range it {
if err != nil {
return nil, err
}
for _, match := range matches {
if ct, err := detectContentType(match); err != nil {
return nil, err
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
}
}
return matches, nil
s = append(s, v)
}
return s, nil
}
var files []string
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
func filesForModel(fsys fs.FS) ([]string, error) {
files, err := collect(matchFirst(
fsys,
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
"*.safetensors", "application/octet-stream",
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
"pytorch_model*.bin", "application/zip",
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 {
"consolidated*.pth", "application/zip",
// covers gguf files ending in .gguf
files = append(files, gg...)
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 {
"*.gguf", "application/octet-stream",
// covers gguf files ending in .bin
files = append(files, gg...)
} else {
return nil, ErrModelNotFound
"*.bin", "application/octet-stream",
))
if err != nil {
return nil, err
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
js, err := collect(matchFirst(fsys, "*.json", "text/plain"))
if err != nil {
return nil, err
}
@@ -286,7 +303,7 @@ func filesForModel(path string) ([]string, error) {
// bert models require a nested config.json
// TODO(mxyng): merge this with the glob above
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
js, err = collect(matchFirst(fsys, "**/*.json", "text/plain"))
if err != nil {
return nil, err
}
@@ -296,14 +313,15 @@ func filesForModel(path string) ([]string, error) {
if !slices.ContainsFunc(files, func(s string) bool {
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
}) {
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
tokenizers, err := collect(matchFirst(fsys,
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
"tokenizer.model", "application/octet-stream",
))
if err != nil {
return nil, err
}
files = append(files, tokenizers...)
}
return files, nil