mirror of
https://github.com/ollama/ollama.git
synced 2025-12-27 01:30:39 -05:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1178fd2cbb | ||
|
|
97c15b601a |
@@ -4,5 +4,4 @@ llama/build
|
||||
.vscode
|
||||
ollama
|
||||
app
|
||||
web
|
||||
.env
|
||||
web
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,3 +6,4 @@
|
||||
dist
|
||||
ollama
|
||||
/ggml-metal.metal
|
||||
build
|
||||
|
||||
40
CMakeLists.txt
Normal file
40
CMakeLists.txt
Normal file
@@ -0,0 +1,40 @@
|
||||
cmake_minimum_required(VERSION 3.14) # 3.11 or later for FetchContent, but some features might require newer versions
|
||||
|
||||
project(llama_cpp)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
llama_cpp_gguf
|
||||
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
|
||||
GIT_TAG 6381d4e
|
||||
)
|
||||
|
||||
FetchContent_Declare(
|
||||
llama_cpp_ggml
|
||||
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
|
||||
GIT_TAG dadbed9
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(llama_cpp_ggml)
|
||||
|
||||
add_subdirectory(${llama_cpp_ggml_SOURCE_DIR}/examples EXCLUDE_FROM_ALL)
|
||||
add_executable(llama_cpp ${llama_cpp_ggml_SOURCE_DIR}/examples/server/server.cpp)
|
||||
include_directories(${llama_cpp_ggml_SOURCE_DIR})
|
||||
include_directories(${llama_cpp_ggml_SOURCE_DIR}/examples)
|
||||
target_compile_features(llama_cpp PRIVATE cxx_std_11)
|
||||
target_link_libraries(llama_cpp PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if (APPLE)
|
||||
add_executable(llama_cpp_metal ${llama_cpp_ggml_SOURCE_DIR}/examples/server/server.cpp)
|
||||
target_compile_options(llama_cpp_metal PRIVATE -DLLAMA_STATIC=ON -DLLAMA_METAL=ON -DGGML_USE_METAL=1)
|
||||
target_compile_features(llama_cpp_metal PRIVATE cxx_std_11)
|
||||
target_link_libraries(llama_cpp_metal PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
configure_file(${llama_cpp_SOURCE_DIR}/ggml-metal.metal ${CMAKE_BINARY_DIR}/ggml-metal.metal COPYONLY)
|
||||
else()
|
||||
add_executable(llama_cpp_cublas ${llama_cpp_ggml_SOURCE_DIR}/examples/server/server.cpp)
|
||||
target_compile_definitions(llama_cpp_cublas PRIVATE -DLLAMA_STATIC=ON -DLLAMA_CUBLAS=ON)
|
||||
target_compile_options(llama_cpp_cublas PRIVATE -DLLAMA_CUBLAS=ON -DLLAMA_STATIC=ON)
|
||||
target_compile_features(llama_cpp_cublas PRIVATE cxx_std_11)
|
||||
target_link_libraries(llama_cpp_cublas PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
endif()
|
||||
@@ -10,10 +10,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
const DefaultHost = "localhost:11434"
|
||||
@@ -86,21 +83,21 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
reqBody = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
requestURL := c.Base.JoinPath(path)
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
|
||||
url := c.Base.JoinPath(path).String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
for k, v := range c.Headers {
|
||||
request.Header[k] = v
|
||||
req.Header[k] = v
|
||||
}
|
||||
|
||||
respObj, err := c.HTTP.Do(request)
|
||||
respObj, err := c.HTTP.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -134,15 +131,13 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
buf = bytes.NewBuffer(bts)
|
||||
}
|
||||
|
||||
requestURL := c.Base.JoinPath(path)
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
|
||||
request, err := http.NewRequestWithContext(ctx, method, c.Base.JoinPath(path).String(), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
response, err := http.DefaultClient.Do(request)
|
||||
if err != nil {
|
||||
|
||||
28
cmd/cmd.go
28
cmd/cmd.go
@@ -30,7 +30,6 @@ import (
|
||||
"github.com/jmorganca/ollama/format"
|
||||
"github.com/jmorganca/ollama/progressbar"
|
||||
"github.com/jmorganca/ollama/server"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
@@ -98,20 +97,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
insecure, err := cmd.Flags().GetBool("insecure")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mp := server.ParseModelPath(args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if mp.ProtocolScheme == "http" && !insecure {
|
||||
return fmt.Errorf("insecure protocol http")
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath(false)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -120,7 +106,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
_, err = os.Stat(fp)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := pull(args[0], insecure); err != nil {
|
||||
if err := pull(args[0], false); err != nil {
|
||||
var apiStatusError api.StatusError
|
||||
if !errors.As(err, &apiStatusError) {
|
||||
return err
|
||||
@@ -521,10 +507,6 @@ func generateInteractive(cmd *cobra.Command, model string) error {
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
mp := server.ParseModelPath(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifest, err := server.GetManifest(mp)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get a manifest for this model")
|
||||
@@ -587,7 +569,7 @@ func generateBatch(cmd *cobra.Command, model string) error {
|
||||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
host, port := "127.0.0.1", "11434"
|
||||
var host, port = "127.0.0.1", "11434"
|
||||
|
||||
parts := strings.Split(os.Getenv("OLLAMA_HOST"), ":")
|
||||
if ip := net.ParseIP(parts[0]); ip != nil {
|
||||
@@ -648,7 +630,7 @@ func initializeKeypair() error {
|
||||
return fmt.Errorf("could not create directory %w", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600)
|
||||
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -660,7 +642,7 @@ func initializeKeypair() error {
|
||||
|
||||
pubKeyData := ssh.MarshalAuthorizedKey(sshPrivateKey.PublicKey())
|
||||
|
||||
err = os.WriteFile(pubKeyPath, pubKeyData, 0o644)
|
||||
err = os.WriteFile(pubKeyPath, pubKeyData, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -732,7 +714,6 @@ func NewCLI() *cobra.Command {
|
||||
CompletionOptions: cobra.CompletionOptions{
|
||||
DisableDefaultCmd: true,
|
||||
},
|
||||
Version: version.Version,
|
||||
}
|
||||
|
||||
cobra.EnableCommandSorting = false
|
||||
@@ -756,7 +737,6 @@ func NewCLI() *cobra.Command {
|
||||
}
|
||||
|
||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
serveCmd := &cobra.Command{
|
||||
Use: "serve",
|
||||
|
||||
@@ -15,7 +15,6 @@ const (
|
||||
ModelType3B ModelType = 26
|
||||
ModelType7B ModelType = 32
|
||||
ModelType13B ModelType = 40
|
||||
ModelType34B ModelType = 48
|
||||
ModelType30B ModelType = 60
|
||||
ModelType65B ModelType = 80
|
||||
)
|
||||
@@ -28,8 +27,6 @@ func (mt ModelType) String() string {
|
||||
return "7B"
|
||||
case ModelType13B:
|
||||
return "13B"
|
||||
case ModelType34B:
|
||||
return "34B"
|
||||
case ModelType30B:
|
||||
return "30B"
|
||||
case ModelType65B:
|
||||
|
||||
@@ -105,7 +105,6 @@ enum e_model {
|
||||
MODEL_7B,
|
||||
MODEL_13B,
|
||||
MODEL_30B,
|
||||
MODEL_34B,
|
||||
MODEL_65B,
|
||||
MODEL_70B,
|
||||
};
|
||||
@@ -149,7 +148,6 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
|
||||
{ MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB },
|
||||
{ MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
|
||||
{ MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB },
|
||||
{ MODEL_34B, ((size_t) n_ctx / 9ull + 160ull) * MB },
|
||||
{ MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess
|
||||
{ MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB },
|
||||
};
|
||||
@@ -163,7 +161,6 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
|
||||
{ MODEL_7B, 160ull * MB },
|
||||
{ MODEL_13B, 192ull * MB },
|
||||
{ MODEL_30B, 256ull * MB },
|
||||
{ MODEL_34B, 256ull * MB },
|
||||
{ MODEL_65B, 384ull * MB }, // guess
|
||||
{ MODEL_70B, 304ull * MB },
|
||||
};
|
||||
@@ -178,7 +175,6 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
|
||||
{ MODEL_7B, 10ull * MB },
|
||||
{ MODEL_13B, 12ull * MB },
|
||||
{ MODEL_30B, 16ull * MB },
|
||||
{ MODEL_34B, 16ull * MB },
|
||||
{ MODEL_65B, 24ull * MB }, // guess
|
||||
{ MODEL_70B, 24ull * MB },
|
||||
};
|
||||
@@ -194,7 +190,6 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
|
||||
{ MODEL_7B, 512ull * kB },
|
||||
{ MODEL_13B, 640ull * kB },
|
||||
{ MODEL_30B, 768ull * kB },
|
||||
{ MODEL_34B, 768ull * kB },
|
||||
{ MODEL_65B, 1280ull * kB },
|
||||
{ MODEL_70B, 1280ull * kB },
|
||||
};
|
||||
@@ -210,7 +205,6 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
|
||||
{ MODEL_7B, 128ull },
|
||||
{ MODEL_13B, 160ull },
|
||||
{ MODEL_30B, 208ull },
|
||||
{ MODEL_34B, 208ull },
|
||||
{ MODEL_65B, 256ull },
|
||||
{ MODEL_70B, 256ull },
|
||||
};
|
||||
@@ -1059,7 +1053,6 @@ static const char *llama_model_type_name(e_model type) {
|
||||
case MODEL_7B: return "7B";
|
||||
case MODEL_13B: return "13B";
|
||||
case MODEL_30B: return "30B";
|
||||
case MODEL_34B: return "34B";
|
||||
case MODEL_65B: return "65B";
|
||||
case MODEL_70B: return "70B";
|
||||
default: LLAMA_ASSERT(false);
|
||||
@@ -1107,7 +1100,6 @@ static void llama_model_load_internal(
|
||||
case 26: model.type = e_model::MODEL_3B; break;
|
||||
case 32: model.type = e_model::MODEL_7B; break;
|
||||
case 40: model.type = e_model::MODEL_13B; break;
|
||||
case 48: model.type = e_model::MODEL_34B; break;
|
||||
case 60: model.type = e_model::MODEL_30B; break;
|
||||
case 80: model.type = e_model::MODEL_65B; break;
|
||||
default:
|
||||
@@ -1128,8 +1120,6 @@ static void llama_model_load_internal(
|
||||
LLAMA_LOG_WARN("%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
|
||||
model.type = e_model::MODEL_70B;
|
||||
hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model
|
||||
} else if (model.type == e_model::MODEL_34B && n_gqa == 8) {
|
||||
hparams.f_ffn_mult = 1.0f; // from the params.json of the 34B model
|
||||
}
|
||||
|
||||
hparams.rope_freq_base = rope_freq_base;
|
||||
|
||||
16
llm/llama.go
16
llm/llama.go
@@ -117,21 +117,7 @@ func (llm *llamaModel) ModelFamily() ModelFamily {
|
||||
}
|
||||
|
||||
func (llm *llamaModel) ModelType() ModelType {
|
||||
switch llm.hyperparameters.NumLayer {
|
||||
case 26:
|
||||
return ModelType3B
|
||||
case 32:
|
||||
return ModelType7B
|
||||
case 40:
|
||||
return ModelType13B
|
||||
case 60:
|
||||
return ModelType30B
|
||||
case 80:
|
||||
return ModelType65B
|
||||
}
|
||||
|
||||
// TODO: find a better default
|
||||
return ModelType7B
|
||||
return ModelType30B
|
||||
}
|
||||
|
||||
func (llm *llamaModel) FileType() FileType {
|
||||
|
||||
@@ -2,12 +2,9 @@
|
||||
|
||||
mkdir -p dist
|
||||
|
||||
GO_LDFLAGS="-X github.com/jmorganca/ollama/version.Version=$VERSION"
|
||||
GO_LDFLAGS="$GO_LDFLAGS -X github.com/jmorganca/ollama/server.mode=release"
|
||||
|
||||
# build universal binary
|
||||
CGO_ENABLED=1 GOARCH=arm64 go build -ldflags "$GO_LDFLAGS" -o dist/ollama-darwin-arm64
|
||||
CGO_ENABLED=1 GOARCH=amd64 go build -ldflags "$GO_LDFLAGS" -o dist/ollama-darwin-amd64
|
||||
CGO_ENABLED=1 GOARCH=arm64 go build -o dist/ollama-darwin-arm64
|
||||
CGO_ENABLED=1 GOARCH=amd64 go build -o dist/ollama-darwin-amd64
|
||||
lipo -create -output dist/ollama dist/ollama-darwin-arm64 dist/ollama-darwin-amd64
|
||||
rm dist/ollama-darwin-amd64 dist/ollama-darwin-arm64
|
||||
codesign --deep --force --options=runtime --sign "$APPLE_IDENTITY" --timestamp dist/ollama
|
||||
|
||||
@@ -12,10 +12,8 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -45,34 +43,21 @@ func generateNonce(length int) (string, error) {
|
||||
return base64.RawURLEncoding.EncodeToString(nonce), nil
|
||||
}
|
||||
|
||||
func (r AuthRedirect) URL() (*url.URL, error) {
|
||||
redirectURL, err := url.Parse(r.Realm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values := redirectURL.Query()
|
||||
|
||||
values.Add("service", r.Service)
|
||||
|
||||
for _, s := range strings.Split(r.Scope, " ") {
|
||||
values.Add("scope", s)
|
||||
}
|
||||
|
||||
values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
|
||||
func (r AuthRedirect) URL() (string, error) {
|
||||
nonce, err := generateNonce(16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
values.Add("nonce", nonce)
|
||||
|
||||
redirectURL.RawQuery = values.Encode()
|
||||
return redirectURL, nil
|
||||
scopes := []string{}
|
||||
for _, s := range strings.Split(r.Scope, " ") {
|
||||
scopes = append(scopes, fmt.Sprintf("scope=%s", s))
|
||||
}
|
||||
scopeStr := strings.Join(scopes, "&")
|
||||
return fmt.Sprintf("%s?service=%s&%s&ts=%d&nonce=%s", r.Realm, r.Service, scopeStr, time.Now().Unix(), nonce), nil
|
||||
}
|
||||
|
||||
func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *RegistryOptions) (string, error) {
|
||||
redirectURL, err := redirData.URL()
|
||||
url, err := redirData.URL()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -92,18 +77,28 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *Registry
|
||||
|
||||
s := SignatureData{
|
||||
Method: "GET",
|
||||
Path: redirectURL.String(),
|
||||
Path: url,
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(s.Path, "http") {
|
||||
if regOpts.Insecure {
|
||||
s.Path = "http://" + url
|
||||
} else {
|
||||
s.Path = "https://" + url
|
||||
}
|
||||
}
|
||||
|
||||
sig, err := s.Sign(rawKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Authorization", sig)
|
||||
resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, regOpts)
|
||||
headers := map[string]string{
|
||||
"Authorization": sig,
|
||||
}
|
||||
|
||||
resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't get token: %q", err)
|
||||
}
|
||||
|
||||
@@ -155,13 +155,12 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
|
||||
}
|
||||
}
|
||||
|
||||
requestURL := opts.mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", f.Digest)
|
||||
url := fmt.Sprintf("%s/v2/%s/blobs/%s", opts.mp.Registry, opts.mp.GetNamespaceRepository(), f.Digest)
|
||||
headers := map[string]string{
|
||||
"Range": fmt.Sprintf("bytes=%d-", size),
|
||||
}
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-", size))
|
||||
|
||||
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts)
|
||||
resp, err := makeRequest(ctx, "GET", url, headers, nil, opts.regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't download blob: %v", err)
|
||||
return fmt.Errorf("%w: %w", errDownload, err)
|
||||
|
||||
223
server/images.go
223
server/images.go
@@ -12,12 +12,10 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -25,7 +23,6 @@ import (
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
"github.com/jmorganca/ollama/parser"
|
||||
"github.com/jmorganca/ollama/vector"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
const MaxRetries = 3
|
||||
@@ -157,6 +154,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
|
||||
|
||||
func GetModel(name string) (*Model, error) {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
manifest, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -274,7 +272,6 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
case "model":
|
||||
fn(api.ProgressResponse{Status: "looking for model"})
|
||||
embed.model = c.Args
|
||||
|
||||
mp := ParseModelPath(c.Args)
|
||||
mf, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
@@ -289,7 +286,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
mf, err = GetManifest(mp)
|
||||
mf, err = GetManifest(ParseModelPath(c.Args))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file after pull: %v", err)
|
||||
}
|
||||
@@ -328,27 +325,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||
}
|
||||
|
||||
if mf != nil {
|
||||
sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sourceBlob, err := os.Open(sourceBlobPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceBlob.Close()
|
||||
|
||||
var source ConfigV2
|
||||
if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// copie the model metadata
|
||||
config.ModelFamily = source.ModelFamily
|
||||
config.ModelType = source.ModelType
|
||||
config.FileType = source.FileType
|
||||
|
||||
log.Printf("manifest = %#v", mf)
|
||||
for _, l := range mf.Layers {
|
||||
newLayer, err := GetLayerWithBufferFromLayer(l)
|
||||
if err != nil {
|
||||
@@ -678,6 +655,7 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
|
||||
|
||||
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
manifest := ManifestV2{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
@@ -808,14 +786,11 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
|
||||
}
|
||||
|
||||
func CopyModel(src, dest string) error {
|
||||
srcModelPath := ParseModelPath(src)
|
||||
srcPath, err := srcModelPath.GetManifestPath(false)
|
||||
srcPath, err := ParseModelPath(src).GetManifestPath(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destModelPath := ParseModelPath(dest)
|
||||
destPath, err := destModelPath.GetManifestPath(true)
|
||||
destPath, err := ParseModelPath(dest).GetManifestPath(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -838,6 +813,7 @@ func CopyModel(src, dest string) error {
|
||||
|
||||
func DeleteModel(name string) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
manifest, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -917,11 +893,8 @@ func DeleteModel(name string) error {
|
||||
|
||||
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return fmt.Errorf("insecure protocol http")
|
||||
}
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
manifest, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
@@ -962,8 +935,8 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.HasPrefix(path.Base(location.Path), "sha256:") {
|
||||
layer.Digest = path.Base(location.Path)
|
||||
if strings.HasPrefix(path.Base(location), "sha256:") {
|
||||
layer.Digest = path.Base(location)
|
||||
fn(api.ProgressResponse{
|
||||
Status: "using existing layer",
|
||||
Digest: layer.Digest,
|
||||
@@ -980,17 +953,17 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pushing manifest"})
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
}
|
||||
|
||||
manifestJSON, err := json.Marshal(manifest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
|
||||
resp, err := makeRequestWithRetry(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1004,10 +977,6 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return fmt.Errorf("insecure protocol http")
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
|
||||
manifest, err := pullModelManifest(ctx, mp, regOpts)
|
||||
@@ -1074,11 +1043,12 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
|
||||
headers := map[string]string{
|
||||
"Accept": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
}
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
|
||||
resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't get manifest: %v", err)
|
||||
return nil, err
|
||||
@@ -1137,12 +1107,35 @@ func GetSHA256Digest(r io.Reader) (string, int) {
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
|
||||
}
|
||||
|
||||
type requestContextKey string
|
||||
|
||||
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
|
||||
url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
|
||||
if layer.From != "" {
|
||||
url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
|
||||
}
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, "POST", url, nil, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't start upload: %v", err)
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Extract UUID location from header
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
return "", fmt.Errorf("location header is missing in response")
|
||||
}
|
||||
|
||||
return location, nil
|
||||
}
|
||||
|
||||
// Function to check if a blob already exists in the Docker registry
|
||||
func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
|
||||
url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
|
||||
|
||||
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
|
||||
resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't check for blob: %v", err)
|
||||
return false, err
|
||||
@@ -1153,10 +1146,110 @@ func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpt
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
// TODO allow resumability
|
||||
// TODO allow canceling uploads via DELETE
|
||||
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Open(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
completed := 0
|
||||
chunkSize := 10 * 1024 * 1024
|
||||
|
||||
for {
|
||||
r, w := io.Pipe()
|
||||
defer r.Close()
|
||||
|
||||
limit := completed + chunkSize
|
||||
if chunkSize >= layer.Size-completed {
|
||||
limit = layer.Size
|
||||
chunkSize = layer.Size - completed
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer w.Close()
|
||||
for {
|
||||
n, err := io.CopyN(w, f, 1024*1024)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("error copying pipe: %v", err),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: completed,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
completed += int(n)
|
||||
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("uploading %s", layer.Digest),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: completed,
|
||||
})
|
||||
|
||||
if completed >= limit {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
headers := make(map[string]string)
|
||||
headers["Content-Type"] = "application/octet-stream"
|
||||
headers["Content-Length"] = strconv.Itoa(chunkSize)
|
||||
headers["Content-Range"] = fmt.Sprintf("%d-%d", completed, limit-1)
|
||||
|
||||
resp, err := makeRequest(ctx, "PATCH", url, headers, r, regOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
url = resp.Header.Get("Location")
|
||||
if completed >= layer.Size {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
|
||||
|
||||
headers := make(map[string]string)
|
||||
headers["Content-Type"] = "application/octet-stream"
|
||||
headers["Content-Length"] = "0"
|
||||
|
||||
// finish the upload
|
||||
resp, err := makeRequest(ctx, "PUT", url, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't finish upload: %v", err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method, url string, headers map[string]string, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
var status string
|
||||
for try := 0; try < MaxRetries; try++ {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
resp, err := makeRequest(ctx, method, url, headers, body, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't start upload: %v", err)
|
||||
return nil, err
|
||||
@@ -1192,27 +1285,29 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
return nil, fmt.Errorf("max retry exceeded: %v", status)
|
||||
}
|
||||
|
||||
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
if requestURL.Scheme != "http" && regOpts.Insecure {
|
||||
requestURL.Scheme = "http"
|
||||
func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
if !strings.HasPrefix(url, "http") {
|
||||
if regOpts.Insecure {
|
||||
url = "http://" + url
|
||||
} else {
|
||||
url = "https://" + url
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if headers != nil {
|
||||
req.Header = headers
|
||||
}
|
||||
|
||||
if regOpts.Token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
|
||||
} else if regOpts.Username != "" && regOpts.Password != "" {
|
||||
req.SetBasicAuth(regOpts.Username, regOpts.Password)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -25,46 +23,42 @@ const (
|
||||
DefaultProtocolScheme = "https"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
func ParseModelPath(name string) ModelPath {
|
||||
mp := ModelPath{
|
||||
ProtocolScheme: DefaultProtocolScheme,
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "",
|
||||
Tag: DefaultTag,
|
||||
}
|
||||
slashParts := strings.Split(name, "/")
|
||||
var registry, namespace, repository, tag string
|
||||
|
||||
before, after, found := strings.Cut(name, "://")
|
||||
if found {
|
||||
mp.ProtocolScheme = before
|
||||
name = after
|
||||
}
|
||||
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
switch len(slashParts) {
|
||||
case 3:
|
||||
mp.Registry = parts[0]
|
||||
mp.Namespace = parts[1]
|
||||
mp.Repository = parts[2]
|
||||
registry = slashParts[0]
|
||||
namespace = slashParts[1]
|
||||
repository = strings.Split(slashParts[2], ":")[0]
|
||||
case 2:
|
||||
mp.Namespace = parts[0]
|
||||
mp.Repository = parts[1]
|
||||
registry = DefaultRegistry
|
||||
namespace = slashParts[0]
|
||||
repository = strings.Split(slashParts[1], ":")[0]
|
||||
case 1:
|
||||
mp.Repository = parts[0]
|
||||
registry = DefaultRegistry
|
||||
namespace = DefaultNamespace
|
||||
repository = strings.Split(slashParts[0], ":")[0]
|
||||
default:
|
||||
fmt.Println("Invalid image format.")
|
||||
return ModelPath{}
|
||||
}
|
||||
|
||||
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
|
||||
mp.Repository = repo
|
||||
mp.Tag = tag
|
||||
colonParts := strings.Split(slashParts[len(slashParts)-1], ":")
|
||||
if len(colonParts) == 2 {
|
||||
tag = colonParts[1]
|
||||
} else {
|
||||
tag = DefaultTag
|
||||
}
|
||||
|
||||
return mp
|
||||
return ModelPath{
|
||||
ProtocolScheme: DefaultProtocolScheme,
|
||||
Registry: registry,
|
||||
Namespace: namespace,
|
||||
Repository: repository,
|
||||
Tag: tag,
|
||||
}
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetNamespaceRepository() string {
|
||||
@@ -101,13 +95,6 @@ func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func (mp ModelPath) BaseURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: mp.ProtocolScheme,
|
||||
Host: mp.Registry,
|
||||
}
|
||||
}
|
||||
|
||||
func GetManifestPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseModelPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg string
|
||||
want ModelPath
|
||||
}{
|
||||
{
|
||||
"full path https",
|
||||
"https://example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"full path http",
|
||||
"http://example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "http",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no protocol",
|
||||
"example.com/ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no registry",
|
||||
"ns/repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no namespace",
|
||||
"repo:tag",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
},
|
||||
{
|
||||
"no tag",
|
||||
"repo",
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: DefaultTag,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ParseModelPath(tc.arg)
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("got: %q want: %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -25,20 +25,6 @@ import (
|
||||
"github.com/jmorganca/ollama/vector"
|
||||
)
|
||||
|
||||
var mode string = gin.DebugMode
|
||||
|
||||
func init() {
|
||||
switch mode {
|
||||
case gin.DebugMode:
|
||||
case gin.ReleaseMode:
|
||||
case gin.TestMode:
|
||||
default:
|
||||
mode = gin.DebugMode
|
||||
}
|
||||
|
||||
gin.SetMode(mode)
|
||||
}
|
||||
|
||||
var loaded struct {
|
||||
mu sync.Mutex
|
||||
|
||||
@@ -371,7 +357,6 @@ func ListModelsHandler(c *gin.Context) {
|
||||
return nil
|
||||
}
|
||||
tag := path[:slashIndex] + ":" + path[slashIndex+1:]
|
||||
|
||||
mp := ParseModelPath(tag)
|
||||
manifest, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
|
||||
125
server/upload.go
125
server/upload.go
@@ -1,125 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, error) {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
|
||||
if layer.From != "" {
|
||||
values := requestURL.Query()
|
||||
values.Add("mount", layer.Digest)
|
||||
values.Add("from", layer.From)
|
||||
requestURL.RawQuery = values.Encode()
|
||||
}
|
||||
|
||||
resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't start upload: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Extract UUID location from header
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
return nil, fmt.Errorf("location header is missing in response")
|
||||
}
|
||||
|
||||
return url.Parse(location)
|
||||
}
|
||||
|
||||
func uploadBlobChunked(ctx context.Context, mp ModelPath, requestURL *url.URL, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
// TODO allow resumability
|
||||
// TODO allow canceling uploads via DELETE
|
||||
|
||||
fp, err := GetBlobsPath(layer.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Open(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var completed int64
|
||||
chunkSize := 10 * 1024 * 1024
|
||||
|
||||
for {
|
||||
chunk := int64(layer.Size) - completed
|
||||
if chunk > int64(chunkSize) {
|
||||
chunk = int64(chunkSize)
|
||||
}
|
||||
|
||||
sectionReader := io.NewSectionReader(f, int64(completed), chunk)
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", strconv.Itoa(int(chunk)))
|
||||
headers.Set("Content-Range", fmt.Sprintf("%d-%d", completed, completed+sectionReader.Size()-1))
|
||||
resp, err := makeRequestWithRetry(ctx, "PATCH", requestURL, headers, sectionReader, regOpts)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("error uploading chunk: %v", err),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: int(completed),
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
completed += sectionReader.Size()
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("uploading %s", layer.Digest),
|
||||
Digest: layer.Digest,
|
||||
Total: layer.Size,
|
||||
Completed: int(completed),
|
||||
})
|
||||
|
||||
requestURL, err = url.Parse(resp.Header.Get("Location"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if completed >= int64(layer.Size) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
values := requestURL.Query()
|
||||
values.Add("digest", layer.Digest)
|
||||
requestURL.RawQuery = values.Encode()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", "0")
|
||||
|
||||
// finish the upload
|
||||
resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
|
||||
if err != nil {
|
||||
log.Printf("couldn't finish upload: %v", err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
package version
|
||||
|
||||
var Version string = "0.0.0"
|
||||
Reference in New Issue
Block a user