mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 04:51:17 -05:00
Compare commits
2 Commits
parth/decr
...
mxyng/crea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
087beb40ed | ||
|
|
19279d778d |
72
cmd/cmd.go
72
cmd/cmd.go
@@ -64,54 +64,37 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
|
|||||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
|
|
||||||
|
|
||||||
func getModelfileName(cmd *cobra.Command) (string, error) {
|
|
||||||
filename, _ := cmd.Flags().GetString("file")
|
|
||||||
|
|
||||||
if filename == "" {
|
|
||||||
filename = "Modelfile"
|
|
||||||
}
|
|
||||||
|
|
||||||
absName, err := filepath.Abs(filename)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = os.Stat(absName)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return absName, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
p := progress.NewProgress(os.Stderr)
|
p := progress.NewProgress(os.Stderr)
|
||||||
defer p.Stop()
|
defer p.Stop()
|
||||||
|
|
||||||
var reader io.Reader
|
filename, err := cmd.Flags().GetString("file")
|
||||||
|
if err != nil {
|
||||||
filename, err := getModelfileName(cmd)
|
return fmt.Errorf("error retrieving file flag: %w", err)
|
||||||
if os.IsNotExist(err) {
|
|
||||||
if filename == "" {
|
|
||||||
reader = strings.NewReader("FROM .\n")
|
|
||||||
} else {
|
|
||||||
return errModelfileNotFound
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
} else {
|
|
||||||
f, err := os.Open(filename)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
reader = f
|
|
||||||
defer f.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
modelfile, err := parser.ParseFile(reader)
|
var r, fallback io.Reader
|
||||||
|
switch filename {
|
||||||
|
case "-":
|
||||||
|
r = os.Stdin
|
||||||
|
case "":
|
||||||
|
filename = "Modelfile"
|
||||||
|
fallback = strings.NewReader("FROM .")
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
r, err = os.Open(filename)
|
||||||
|
if errors.Is(err, os.ErrNotExist) && fallback != nil {
|
||||||
|
r = fallback
|
||||||
|
} else if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return fmt.Errorf("%w: Modelfile %q does not exist, please create it or use --file to specify a different file", err, filename)
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
defer r.(*os.File).Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
modelfile, err := parser.ParseFile(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -127,10 +110,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
|
|
||||||
req.Model = args[0]
|
req.Model = args[0]
|
||||||
quantize, _ := cmd.Flags().GetString("quantize")
|
req.Quantize, _ = cmd.Flags().GetString("quantize")
|
||||||
if quantize != "" {
|
|
||||||
req.Quantize = quantize
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
350
cmd/cmd_test.go
350
cmd/cmd_test.go
@@ -3,10 +3,13 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,6 +21,13 @@ import (
|
|||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func mockServer(t *testing.T, h http.HandlerFunc) {
|
||||||
|
t.Helper()
|
||||||
|
s := httptest.NewServer(h)
|
||||||
|
t.Cleanup(s.Close)
|
||||||
|
t.Setenv("OLLAMA_HOST", s.URL)
|
||||||
|
}
|
||||||
|
|
||||||
func TestShowInfo(t *testing.T) {
|
func TestShowInfo(t *testing.T) {
|
||||||
t.Run("bare details", func(t *testing.T) {
|
t.Run("bare details", func(t *testing.T) {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
@@ -351,101 +361,6 @@ func TestDeleteHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetModelfileName(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
modelfileName string
|
|
||||||
fileExists bool
|
|
||||||
expectedName string
|
|
||||||
expectedErr error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "no modelfile specified, no modelfile exists",
|
|
||||||
modelfileName: "",
|
|
||||||
fileExists: false,
|
|
||||||
expectedName: "",
|
|
||||||
expectedErr: os.ErrNotExist,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no modelfile specified, modelfile exists",
|
|
||||||
modelfileName: "",
|
|
||||||
fileExists: true,
|
|
||||||
expectedName: "Modelfile",
|
|
||||||
expectedErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "modelfile specified, no modelfile exists",
|
|
||||||
modelfileName: "crazyfile",
|
|
||||||
fileExists: false,
|
|
||||||
expectedName: "",
|
|
||||||
expectedErr: os.ErrNotExist,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "modelfile specified, modelfile exists",
|
|
||||||
modelfileName: "anotherfile",
|
|
||||||
fileExists: true,
|
|
||||||
expectedName: "anotherfile",
|
|
||||||
expectedErr: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cmd := &cobra.Command{
|
|
||||||
Use: "fakecmd",
|
|
||||||
}
|
|
||||||
cmd.Flags().String("file", "", "path to modelfile")
|
|
||||||
|
|
||||||
var expectedFilename string
|
|
||||||
|
|
||||||
if tt.fileExists {
|
|
||||||
var fn string
|
|
||||||
if tt.modelfileName != "" {
|
|
||||||
fn = tt.modelfileName
|
|
||||||
} else {
|
|
||||||
fn = "Modelfile"
|
|
||||||
}
|
|
||||||
|
|
||||||
tempFile, err := os.CreateTemp(t.TempDir(), fn)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("temp modelfile creation failed: %v", err)
|
|
||||||
}
|
|
||||||
defer tempFile.Close()
|
|
||||||
|
|
||||||
expectedFilename = tempFile.Name()
|
|
||||||
err = cmd.Flags().Set("file", expectedFilename)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't set file flag: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
expectedFilename = tt.expectedName
|
|
||||||
if tt.modelfileName != "" {
|
|
||||||
err := cmd.Flags().Set("file", tt.modelfileName)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't set file flag: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
actualFilename, actualErr := getModelfileName(cmd)
|
|
||||||
|
|
||||||
if actualFilename != expectedFilename {
|
|
||||||
t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.expectedErr != os.ErrNotExist {
|
|
||||||
if actualErr != tt.expectedErr {
|
|
||||||
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if !os.IsNotExist(actualErr) {
|
|
||||||
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPushHandler(t *testing.T) {
|
func TestPushHandler(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -661,128 +576,165 @@ func TestListHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateHandler(t *testing.T) {
|
func TestCreateHandler(t *testing.T) {
|
||||||
tests := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
modelName string
|
filename func(*testing.T) string
|
||||||
modelFile string
|
|
||||||
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
wantRequest api.CreateRequest
|
||||||
expectedError string
|
wantErr error
|
||||||
expectedOutput string
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "successful create",
|
name: "not exist",
|
||||||
modelName: "test-model",
|
filename: func(*testing.T) string { return "not_exist" },
|
||||||
modelFile: "FROM foo",
|
wantErr: os.ErrNotExist,
|
||||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
},
|
||||||
"/api/create": func(w http.ResponseWriter, r *http.Request) {
|
{
|
||||||
if r.Method != http.MethodPost {
|
name: "stdin",
|
||||||
t.Errorf("expected POST request, got %s", r.Method)
|
filename: func(t *testing.T) string {
|
||||||
}
|
r, w, err := os.Pipe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
req := api.CreateRequest{}
|
if _, err := w.WriteString("FROM test"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdin := os.Stdin
|
||||||
|
t.Cleanup(func() { os.Stdin = stdin })
|
||||||
|
os.Stdin = r
|
||||||
|
return "-"
|
||||||
|
},
|
||||||
|
wantRequest: api.CreateRequest{
|
||||||
|
Model: "stdin",
|
||||||
|
From: "test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
filename: func(t *testing.T) string {
|
||||||
|
t.Chdir(t.TempDir())
|
||||||
|
f, err := os.Create("Modelfile")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if _, err := f.WriteString("FROM test"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
},
|
||||||
|
wantRequest: api.CreateRequest{
|
||||||
|
Model: "default",
|
||||||
|
From: "test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default safetensors",
|
||||||
|
filename: func(t *testing.T) string {
|
||||||
|
t.Chdir(t.TempDir())
|
||||||
|
f, err := os.Create("model.safetensors")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if err := f.Truncate(1); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
},
|
||||||
|
wantRequest: api.CreateRequest{
|
||||||
|
Model: "default_safetensors",
|
||||||
|
Files: map[string]string{
|
||||||
|
"model.safetensors": "sha256:6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "file flag",
|
||||||
|
filename: func(t *testing.T) string {
|
||||||
|
f, err := os.CreateTemp(t.TempDir(), filepath.Base(t.Name()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if _, err := f.WriteString("FROM test"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.Name()
|
||||||
|
},
|
||||||
|
wantRequest: api.CreateRequest{
|
||||||
|
Model: "file_flag",
|
||||||
|
From: "test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "insecure path",
|
||||||
|
filename: func(t *testing.T) string {
|
||||||
|
t.Chdir(t.TempDir())
|
||||||
|
if err := os.Symlink("../../../../../../nope", "model.safetensors"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
},
|
||||||
|
wantErr: fmt.Errorf("openat %s: path escapes from parent", "model.safetensors"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd cobra.Command
|
||||||
|
cmd.SetContext(t.Context())
|
||||||
|
cmd.Flags().String("file", "", "")
|
||||||
|
cmd.Flags().String("quantize", "", "")
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/api/create" {
|
||||||
|
var req api.CreateRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Model != "test-model" {
|
if diff := cmp.Diff(tt.wantRequest, req); diff != "" {
|
||||||
t.Errorf("expected model name 'test-model', got %s", req.Name)
|
t.Errorf("Create request mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
} else if strings.HasPrefix(r.URL.Path, "/api/blobs/") {
|
||||||
if req.From != "foo" {
|
w.WriteHeader(http.StatusOK)
|
||||||
t.Errorf("expected from 'foo', got %s", req.From)
|
} else {
|
||||||
}
|
|
||||||
|
|
||||||
responses := []api.ProgressResponse{
|
|
||||||
{Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"},
|
|
||||||
{Status: "writing manifest"},
|
|
||||||
{Status: "success"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, resp := range responses {
|
|
||||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.(http.Flusher).Flush()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedOutput: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
handler, ok := tt.serverResponse[r.URL.Path]
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("unexpected request to %s", r.URL.Path)
|
|
||||||
http.Error(w, "not found", http.StatusNotFound)
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
handler(w, r)
|
})
|
||||||
}))
|
|
||||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
|
||||||
t.Cleanup(mockServer.Close)
|
|
||||||
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer os.Remove(tempFile.Name())
|
|
||||||
|
|
||||||
if _, err := tempFile.WriteString(tt.modelFile); err != nil {
|
var filename string
|
||||||
t.Fatal(err)
|
if tt.filename != nil {
|
||||||
|
filename = tt.filename(t)
|
||||||
}
|
}
|
||||||
if err := tempFile.Close(); err != nil {
|
|
||||||
|
if err := cmd.Flags().Set("file", filename); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := &cobra.Command{}
|
if err := CreateHandler(&cmd, []string{filepath.Base(t.Name())}); err != tt.wantErr &&
|
||||||
cmd.Flags().String("file", "", "")
|
err.Error() != tt.wantErr.Error() &&
|
||||||
if err := cmd.Flags().Set("file", tempFile.Name()); err != nil {
|
!errors.Is(err, tt.wantErr) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Flags().Bool("insecure", false, "")
|
|
||||||
cmd.SetContext(t.Context())
|
|
||||||
|
|
||||||
// Redirect stderr to capture progress output
|
|
||||||
oldStderr := os.Stderr
|
|
||||||
r, w, _ := os.Pipe()
|
|
||||||
os.Stderr = w
|
|
||||||
|
|
||||||
// Capture stdout for the "Model pushed" message
|
|
||||||
oldStdout := os.Stdout
|
|
||||||
outR, outW, _ := os.Pipe()
|
|
||||||
os.Stdout = outW
|
|
||||||
|
|
||||||
err = CreateHandler(cmd, []string{tt.modelName})
|
|
||||||
|
|
||||||
// Restore stderr
|
|
||||||
w.Close()
|
|
||||||
os.Stderr = oldStderr
|
|
||||||
// drain the pipe
|
|
||||||
if _, err := io.ReadAll(r); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restore stdout and get output
|
|
||||||
outW.Close()
|
|
||||||
os.Stdout = oldStdout
|
|
||||||
stdout, _ := io.ReadAll(outR)
|
|
||||||
|
|
||||||
if tt.expectedError == "" {
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("expected no error, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.expectedOutput != "" {
|
|
||||||
if got := string(stdout); got != tt.expectedOutput {
|
|
||||||
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
152
parser/parser.go
152
parser/parser.go
@@ -7,6 +7,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"iter"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
@@ -148,31 +150,23 @@ func fileDigestMap(path string) (map[string]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var files []string
|
var files []string
|
||||||
if fi.IsDir() {
|
if !fi.IsDir() {
|
||||||
fs, err := filesForModel(path)
|
files = []string{path}
|
||||||
|
} else {
|
||||||
|
root, err := os.OpenRoot(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer root.Close()
|
||||||
|
|
||||||
|
fs, err := filesForModel(root.FS())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range fs {
|
for _, f := range fs {
|
||||||
f, err := filepath.EvalSymlinks(f)
|
files = append(files, filepath.Join(path, f))
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rel, err := filepath.Rel(path, f)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !filepath.IsLocal(rel) {
|
|
||||||
return nil, fmt.Errorf("insecure path: %s", rel)
|
|
||||||
}
|
|
||||||
|
|
||||||
files = append(files, f)
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
files = []string{path}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
@@ -218,67 +212,90 @@ func digestForFile(filename string) (string, error) {
|
|||||||
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
|
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func filesForModel(path string) ([]string, error) {
|
func detectContentType(fsys fs.FS, path string) (string, error) {
|
||||||
detectContentType := func(path string) (string, error) {
|
f, err := fsys.Open(path)
|
||||||
f, err := os.Open(path)
|
if err != nil {
|
||||||
if err != nil {
|
return "", err
|
||||||
return "", err
|
}
|
||||||
}
|
defer f.Close()
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
bts := make([]byte, 512)
|
||||||
b.Grow(512)
|
n, err := io.ReadFull(f, bts)
|
||||||
|
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
|
bts = bts[:n]
|
||||||
return "", err
|
} else if err != nil {
|
||||||
}
|
return "", err
|
||||||
|
|
||||||
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
|
|
||||||
return contentType, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
glob := func(pattern, contentType string) ([]string, error) {
|
contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
|
||||||
matches, err := filepath.Glob(pattern)
|
return contentType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func matchFirst(fsys fs.FS, patternsContentTypes ...string) iter.Seq2[string, error] {
|
||||||
|
return func(yield func(string, error) bool) {
|
||||||
|
for i := 0; i < len(patternsContentTypes); i += 2 {
|
||||||
|
pattern := patternsContentTypes[i]
|
||||||
|
contentType := patternsContentTypes[i+1]
|
||||||
|
matches, err := fs.Glob(fsys, pattern)
|
||||||
|
if err != nil {
|
||||||
|
if !yield("", err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matches) > 0 {
|
||||||
|
for _, match := range matches {
|
||||||
|
if ct, err := detectContentType(fsys, match); err != nil {
|
||||||
|
if !yield("", err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if ct == contentType {
|
||||||
|
if !yield(match, nil) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func collect[E any](it iter.Seq2[E, error]) (s []E, _ error) {
|
||||||
|
for v, err := range it {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
s = append(s, v)
|
||||||
for _, match := range matches {
|
|
||||||
if ct, err := detectContentType(match); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if ct != contentType {
|
|
||||||
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return matches, nil
|
|
||||||
}
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
var files []string
|
func filesForModel(fsys fs.FS) ([]string, error) {
|
||||||
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
files, err := collect(matchFirst(
|
||||||
|
fsys,
|
||||||
// safetensors files might be unresolved git lfs references; skip if they are
|
// safetensors files might be unresolved git lfs references; skip if they are
|
||||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||||
files = append(files, st...)
|
"*.safetensors", "application/octet-stream",
|
||||||
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
|
||||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||||
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
||||||
files = append(files, pt...)
|
"pytorch_model*.bin", "application/zip",
|
||||||
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
|
|
||||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||||
// covers consolidated.x.pth, consolidated.pth
|
// covers consolidated.x.pth, consolidated.pth
|
||||||
files = append(files, pt...)
|
"consolidated*.pth", "application/zip",
|
||||||
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 {
|
|
||||||
// covers gguf files ending in .gguf
|
// covers gguf files ending in .gguf
|
||||||
files = append(files, gg...)
|
"*.gguf", "application/octet-stream",
|
||||||
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 {
|
|
||||||
// covers gguf files ending in .bin
|
// covers gguf files ending in .bin
|
||||||
files = append(files, gg...)
|
"*.bin", "application/octet-stream",
|
||||||
} else {
|
))
|
||||||
return nil, ErrModelNotFound
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// add configuration files, json files are detected as text/plain
|
// add configuration files, json files are detected as text/plain
|
||||||
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
|
js, err := collect(matchFirst(fsys, "*.json", "text/plain"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -286,7 +303,7 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
|
|
||||||
// bert models require a nested config.json
|
// bert models require a nested config.json
|
||||||
// TODO(mxyng): merge this with the glob above
|
// TODO(mxyng): merge this with the glob above
|
||||||
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
|
js, err = collect(matchFirst(fsys, "**/*.json", "text/plain"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -296,14 +313,15 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
if !slices.ContainsFunc(files, func(s string) bool {
|
if !slices.ContainsFunc(files, func(s string) bool {
|
||||||
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
|
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
|
||||||
}) {
|
}) {
|
||||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
tokenizers, err := collect(matchFirst(fsys,
|
||||||
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
||||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||||
files = append(files, tks...)
|
"tokenizer.model", "application/octet-stream",
|
||||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
))
|
||||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
if err != nil {
|
||||||
files = append(files, tks...)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
files = append(files, tokenizers...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return files, nil
|
return files, nil
|
||||||
|
|||||||
Reference in New Issue
Block a user