Compare commits


22 Commits

Author SHA1 Message Date
Jeffrey Morgan
177b69a211 add missing entries for 34B 2023-08-25 18:35:35 -07:00
Michael Yang
dad63f0821 Merge pull request #411 from jmorganca/mxyng/34b
patch llama.cpp for 34B
2023-08-25 11:59:05 -07:00
Michael Yang
7a378f8b66 patch llama.cpp for 34B 2023-08-25 10:06:55 -07:00
Michael Yang
de0bdd7f29 Merge pull request #405 from jmorganca/mxyng/34b
add 34b model type
2023-08-24 10:37:22 -07:00
Michael Yang
b1cececb8e add 34b model type 2023-08-24 10:35:44 -07:00
Michael Yang
e0d39fa3bf Merge pull request #398 from jmorganca/mxyng/cleanup
Mxyng/cleanup
2023-08-22 15:51:41 -07:00
Michael Yang
968ced2e71 Merge pull request #393 from jmorganca/mxyng/net-url
use url.URL
2023-08-22 15:51:33 -07:00
Michael Yang
32d1a00017 remove unused requestContextKey 2023-08-22 10:49:54 -07:00
Michael Yang
04e2128273 move upload funcs to upload.go 2023-08-22 10:49:53 -07:00
Michael Yang
2cc634689b use url.URL 2023-08-22 10:49:07 -07:00
Michael Yang
8f827641b0 Merge pull request #397 from jmorganca/mxyng/release-mode
build release mode
2023-08-22 10:48:44 -07:00
Michael Yang
95187d7e1e build release mode 2023-08-22 09:52:43 -07:00
Michael Yang
9ec7e37534 Merge pull request #392 from jmorganca/mxyng/version
add version
2023-08-22 09:50:25 -07:00
Michael Yang
2c7f956b38 add version 2023-08-22 09:40:58 -07:00
Jeffrey Morgan
a9f6c56652 fix FROM instruction erroring when referring to a file 2023-08-22 09:39:42 -07:00
Ryan Baker
0a892419ad Strip protocol from model path (#377) 2023-08-21 21:56:56 -07:00
Jeffrey Morgan
e3054fc74e add .env to .dockerignore 2023-08-21 09:32:02 -07:00
Michael Yang
23c2485044 Merge pull request #381 from jmorganca/mxyng/fix-push-chunks
retry on unauthorized chunk push
2023-08-18 13:49:25 -07:00
Michael Yang
386c66f285 Merge pull request #378 from jmorganca/mxyng/copy-metadata-from-source
copy metadata from source
2023-08-18 13:49:09 -07:00
Michael Yang
3b49315f97 retry on unauthorized chunk push
The token printed for authorized requests has a lifetime of 1h. If an
upload exceeds 1h, a chunk push will fail since the token is created on
a "start upload" request.

This replaces the Pipe with SectionReader which is simpler and
implements Seek, a requirement for makeRequestWithRetry. This is
slightly worse than using a Pipe since the progress update is directly
tied to the chunk size instead of controlled separately.
2023-08-18 11:23:47 -07:00
Michael Yang
5ca05c2e88 fix ModelType() 2023-08-18 11:23:38 -07:00
Michael Yang
7eda70f23b copy metadata from source 2023-08-17 21:55:25 -07:00
18 changed files with 447 additions and 281 deletions

View File

@@ -4,4 +4,5 @@ llama/build
.vscode
ollama
app
web
web
.env

.gitignore (1 changed line)
View File

@@ -6,4 +6,3 @@
dist
ollama
/ggml-metal.metal
build

View File

@@ -1,40 +0,0 @@
cmake_minimum_required(VERSION 3.14) # 3.11 or later for FetchContent, but some features might require newer versions
project(llama_cpp)
include(FetchContent)
FetchContent_Declare(
llama_cpp_gguf
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
GIT_TAG 6381d4e
)
FetchContent_Declare(
llama_cpp_ggml
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
GIT_TAG dadbed9
)
FetchContent_MakeAvailable(llama_cpp_ggml)
add_subdirectory(${llama_cpp_ggml_SOURCE_DIR}/examples EXCLUDE_FROM_ALL)
add_executable(llama_cpp ${llama_cpp_ggml_SOURCE_DIR}/examples/server/server.cpp)
include_directories(${llama_cpp_ggml_SOURCE_DIR})
include_directories(${llama_cpp_ggml_SOURCE_DIR}/examples)
target_compile_features(llama_cpp PRIVATE cxx_std_11)
target_link_libraries(llama_cpp PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
if (APPLE)
add_executable(llama_cpp_metal ${llama_cpp_ggml_SOURCE_DIR}/examples/server/server.cpp)
target_compile_options(llama_cpp_metal PRIVATE -DLLAMA_STATIC=ON -DLLAMA_METAL=ON -DGGML_USE_METAL=1)
target_compile_features(llama_cpp_metal PRIVATE cxx_std_11)
target_link_libraries(llama_cpp_metal PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
configure_file(${llama_cpp_SOURCE_DIR}/ggml-metal.metal ${CMAKE_BINARY_DIR}/ggml-metal.metal COPYONLY)
else()
add_executable(llama_cpp_cublas ${llama_cpp_ggml_SOURCE_DIR}/examples/server/server.cpp)
target_compile_definitions(llama_cpp_cublas PRIVATE -DLLAMA_STATIC=ON -DLLAMA_CUBLAS=ON)
target_compile_options(llama_cpp_cublas PRIVATE -DLLAMA_CUBLAS=ON -DLLAMA_STATIC=ON)
target_compile_features(llama_cpp_cublas PRIVATE cxx_std_11)
target_link_libraries(llama_cpp_cublas PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
endif()

View File

@@ -10,7 +10,10 @@ import (
"net/http"
"net/url"
"os"
"runtime"
"strings"
"github.com/jmorganca/ollama/version"
)
const DefaultHost = "localhost:11434"
@@ -83,21 +86,21 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
reqBody = bytes.NewReader(data)
}
url := c.Base.JoinPath(path).String()
req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
requestURL := c.Base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
for k, v := range c.Headers {
req.Header[k] = v
request.Header[k] = v
}
respObj, err := c.HTTP.Do(req)
respObj, err := c.HTTP.Do(request)
if err != nil {
return err
}
@@ -131,13 +134,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
buf = bytes.NewBuffer(bts)
}
request, err := http.NewRequestWithContext(ctx, method, c.Base.JoinPath(path).String(), buf)
requestURL := c.Base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
if err != nil {
return err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
response, err := http.DefaultClient.Do(request)
if err != nil {
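
The client diff above routes URLs through url.URL.JoinPath, renames req to request, and attaches a User-Agent header built from the new version package plus runtime metadata. A minimal sketch of that header pattern on a bare net/http request follows; the Version variable, host, and endpoint path here are stand-ins, not the repository's actual layout.

package main

import (
	"fmt"
	"net/http"
	"runtime"
)

// Version stands in for the version package's Version variable; a release
// build would override it with -ldflags "-X <module>/version.Version=<v>".
var Version = "0.0.0"

func main() {
	// DefaultHost in the diff above is "localhost:11434"; the path is illustrative.
	req, err := http.NewRequest(http.MethodGet, "http://localhost:11434/", nil)
	if err != nil {
		panic(err)
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	// Same shape as the header added above: ollama/<version> (<arch> <os>) Go/<go version>
	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s",
		Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}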

View File

@@ -30,6 +30,7 @@ import (
"github.com/jmorganca/ollama/format"
"github.com/jmorganca/ollama/progressbar"
"github.com/jmorganca/ollama/server"
"github.com/jmorganca/ollama/version"
)
func CreateHandler(cmd *cobra.Command, args []string) error {
@@ -97,7 +98,20 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
func RunHandler(cmd *cobra.Command, args []string) error {
insecure, err := cmd.Flags().GetBool("insecure")
if err != nil {
return err
}
mp := server.ParseModelPath(args[0])
if err != nil {
return err
}
if mp.ProtocolScheme == "http" && !insecure {
return fmt.Errorf("insecure protocol http")
}
fp, err := mp.GetManifestPath(false)
if err != nil {
return err
@@ -106,7 +120,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
_, err = os.Stat(fp)
switch {
case errors.Is(err, os.ErrNotExist):
if err := pull(args[0], false); err != nil {
if err := pull(args[0], insecure); err != nil {
var apiStatusError api.StatusError
if !errors.As(err, &apiStatusError) {
return err
@@ -507,6 +521,10 @@ func generateInteractive(cmd *cobra.Command, model string) error {
args := strings.Fields(line)
if len(args) > 1 {
mp := server.ParseModelPath(model)
if err != nil {
return err
}
manifest, err := server.GetManifest(mp)
if err != nil {
fmt.Println("error: couldn't get a manifest for this model")
@@ -569,7 +587,7 @@ func generateBatch(cmd *cobra.Command, model string) error {
}
func RunServer(cmd *cobra.Command, _ []string) error {
var host, port = "127.0.0.1", "11434"
host, port := "127.0.0.1", "11434"
parts := strings.Split(os.Getenv("OLLAMA_HOST"), ":")
if ip := net.ParseIP(parts[0]); ip != nil {
@@ -630,7 +648,7 @@ func initializeKeypair() error {
return fmt.Errorf("could not create directory %w", err)
}
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0600)
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600)
if err != nil {
return err
}
@@ -642,7 +660,7 @@ func initializeKeypair() error {
pubKeyData := ssh.MarshalAuthorizedKey(sshPrivateKey.PublicKey())
err = os.WriteFile(pubKeyPath, pubKeyData, 0644)
err = os.WriteFile(pubKeyPath, pubKeyData, 0o644)
if err != nil {
return err
}
@@ -714,6 +732,7 @@ func NewCLI() *cobra.Command {
CompletionOptions: cobra.CompletionOptions{
DisableDefaultCmd: true,
},
Version: version.Version,
}
cobra.EnableCommandSorting = false
@@ -737,6 +756,7 @@ func NewCLI() *cobra.Command {
}
runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
serveCmd := &cobra.Command{
Use: "serve",

View File

@@ -1,4 +0,0 @@
#!/bin/bash
cmake -B build
make -C build

View File

@@ -15,6 +15,7 @@ const (
ModelType3B ModelType = 26
ModelType7B ModelType = 32
ModelType13B ModelType = 40
ModelType34B ModelType = 48
ModelType30B ModelType = 60
ModelType65B ModelType = 80
)
@@ -27,6 +28,8 @@ func (mt ModelType) String() string {
return "7B"
case ModelType13B:
return "13B"
case ModelType34B:
return "34B"
case ModelType30B:
return "30B"
case ModelType65B:

View File

@@ -105,6 +105,7 @@ enum e_model {
MODEL_7B,
MODEL_13B,
MODEL_30B,
MODEL_34B,
MODEL_65B,
MODEL_70B,
};
@@ -148,6 +149,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
{ MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB },
{ MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
{ MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB },
{ MODEL_34B, ((size_t) n_ctx / 9ull + 160ull) * MB },
{ MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess
{ MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB },
};
@@ -161,6 +163,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
{ MODEL_7B, 160ull * MB },
{ MODEL_13B, 192ull * MB },
{ MODEL_30B, 256ull * MB },
{ MODEL_34B, 256ull * MB },
{ MODEL_65B, 384ull * MB }, // guess
{ MODEL_70B, 304ull * MB },
};
@@ -175,6 +178,7 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
{ MODEL_7B, 10ull * MB },
{ MODEL_13B, 12ull * MB },
{ MODEL_30B, 16ull * MB },
{ MODEL_34B, 16ull * MB },
{ MODEL_65B, 24ull * MB }, // guess
{ MODEL_70B, 24ull * MB },
};
@@ -190,6 +194,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
{ MODEL_7B, 512ull * kB },
{ MODEL_13B, 640ull * kB },
{ MODEL_30B, 768ull * kB },
{ MODEL_34B, 768ull * kB },
{ MODEL_65B, 1280ull * kB },
{ MODEL_70B, 1280ull * kB },
};
@@ -205,6 +210,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
{ MODEL_7B, 128ull },
{ MODEL_13B, 160ull },
{ MODEL_30B, 208ull },
{ MODEL_34B, 208ull },
{ MODEL_65B, 256ull },
{ MODEL_70B, 256ull },
};
@@ -1053,6 +1059,7 @@ static const char *llama_model_type_name(e_model type) {
case MODEL_7B: return "7B";
case MODEL_13B: return "13B";
case MODEL_30B: return "30B";
case MODEL_34B: return "34B";
case MODEL_65B: return "65B";
case MODEL_70B: return "70B";
default: LLAMA_ASSERT(false);
@@ -1100,6 +1107,7 @@ static void llama_model_load_internal(
case 26: model.type = e_model::MODEL_3B; break;
case 32: model.type = e_model::MODEL_7B; break;
case 40: model.type = e_model::MODEL_13B; break;
case 48: model.type = e_model::MODEL_34B; break;
case 60: model.type = e_model::MODEL_30B; break;
case 80: model.type = e_model::MODEL_65B; break;
default:
@@ -1120,6 +1128,8 @@ static void llama_model_load_internal(
LLAMA_LOG_WARN("%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
model.type = e_model::MODEL_70B;
hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model
} else if (model.type == e_model::MODEL_34B && n_gqa == 8) {
hparams.f_ffn_mult = 1.0f; // from the params.json of the 34B model
}
hparams.rope_freq_base = rope_freq_base;

View File

@@ -117,7 +117,21 @@ func (llm *llamaModel) ModelFamily() ModelFamily {
}
func (llm *llamaModel) ModelType() ModelType {
return ModelType30B
switch llm.hyperparameters.NumLayer {
case 26:
return ModelType3B
case 32:
return ModelType7B
case 40:
return ModelType13B
case 60:
return ModelType30B
case 80:
return ModelType65B
}
// TODO: find a better default
return ModelType7B
}
func (llm *llamaModel) FileType() FileType {
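
ModelType() above replaces the hard-coded ModelType30B return with a switch on the layer count. Given the ModelType34B = 48 constant added earlier in this compare, a case for 48 layers presumably belongs in the same switch; a self-contained sketch of that mapping (the exact upstream switch may differ):

package main

import "fmt"

type ModelType int

const (
	ModelType3B  ModelType = 26
	ModelType7B  ModelType = 32
	ModelType13B ModelType = 40
	ModelType34B ModelType = 48
	ModelType30B ModelType = 60
	ModelType65B ModelType = 80
)

// modelTypeFromLayers mirrors the layer-count switch above, with the 34B case
// included; the default keeps the same TODO fallback.
func modelTypeFromLayers(numLayer uint32) ModelType {
	switch numLayer {
	case 26:
		return ModelType3B
	case 32:
		return ModelType7B
	case 40:
		return ModelType13B
	case 48:
		return ModelType34B
	case 60:
		return ModelType30B
	case 80:
		return ModelType65B
	}
	// TODO: find a better default
	return ModelType7B
}

func main() {
	fmt.Println(modelTypeFromLayers(48) == ModelType34B) // true
}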

View File

@@ -2,9 +2,12 @@
mkdir -p dist
GO_LDFLAGS="-X github.com/jmorganca/ollama/version.Version=$VERSION"
GO_LDFLAGS="$GO_LDFLAGS -X github.com/jmorganca/ollama/server.mode=release"
# build universal binary
CGO_ENABLED=1 GOARCH=arm64 go build -o dist/ollama-darwin-arm64
CGO_ENABLED=1 GOARCH=amd64 go build -o dist/ollama-darwin-amd64
CGO_ENABLED=1 GOARCH=arm64 go build -ldflags "$GO_LDFLAGS" -o dist/ollama-darwin-arm64
CGO_ENABLED=1 GOARCH=amd64 go build -ldflags "$GO_LDFLAGS" -o dist/ollama-darwin-amd64
lipo -create -output dist/ollama dist/ollama-darwin-arm64 dist/ollama-darwin-amd64
rm dist/ollama-darwin-amd64 dist/ollama-darwin-arm64
codesign --deep --force --options=runtime --sign "$APPLE_IDENTITY" --timestamp dist/ollama
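
The build script above now stamps version.Version and server.mode at link time before lipo merges the two architecture builds. On the Go side this only needs an exported package-level string, which is exactly what the new version/version.go at the end of this compare provides; reproduced here with an explanatory comment added for illustration:

// version/version.go
package version

// Version defaults to "0.0.0" and is replaced at build time, matching the
// script above:
//   go build -ldflags "-X github.com/jmorganca/ollama/version.Version=$VERSION"
var Version string = "0.0.0"

The same mechanism drives the gin release mode further down: -X github.com/jmorganca/ollama/server.mode=release overwrites the server package's mode variable.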

View File

@@ -12,8 +12,10 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"path"
"strconv"
"strings"
"time"
@@ -43,21 +45,34 @@ func generateNonce(length int) (string, error) {
return base64.RawURLEncoding.EncodeToString(nonce), nil
}
func (r AuthRedirect) URL() (string, error) {
func (r AuthRedirect) URL() (*url.URL, error) {
redirectURL, err := url.Parse(r.Realm)
if err != nil {
return nil, err
}
values := redirectURL.Query()
values.Add("service", r.Service)
for _, s := range strings.Split(r.Scope, " ") {
values.Add("scope", s)
}
values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
nonce, err := generateNonce(16)
if err != nil {
return "", err
return nil, err
}
scopes := []string{}
for _, s := range strings.Split(r.Scope, " ") {
scopes = append(scopes, fmt.Sprintf("scope=%s", s))
}
scopeStr := strings.Join(scopes, "&")
return fmt.Sprintf("%s?service=%s&%s&ts=%d&nonce=%s", r.Realm, r.Service, scopeStr, time.Now().Unix(), nonce), nil
values.Add("nonce", nonce)
redirectURL.RawQuery = values.Encode()
return redirectURL, nil
}
func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *RegistryOptions) (string, error) {
url, err := redirData.URL()
redirectURL, err := redirData.URL()
if err != nil {
return "", err
}
@@ -77,28 +92,18 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *Registry
s := SignatureData{
Method: "GET",
Path: url,
Path: redirectURL.String(),
Data: nil,
}
if !strings.HasPrefix(s.Path, "http") {
if regOpts.Insecure {
s.Path = "http://" + url
} else {
s.Path = "https://" + url
}
}
sig, err := s.Sign(rawKey)
if err != nil {
return "", err
}
headers := map[string]string{
"Authorization": sig,
}
resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
headers := make(http.Header)
headers.Set("Authorization", sig)
resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't get token: %q", err)
}
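
AuthRedirect.URL() above stops hand-formatting the token-service query with fmt.Sprintf and instead parses the realm and builds the query through url.Values, which escapes values and handles the repeated scope parameter cleanly. A self-contained sketch of the same construction; the realm, service, and scope strings are made-up placeholders:

package main

import (
	"fmt"
	"net/url"
	"strconv"
	"strings"
	"time"
)

// buildAuthURL follows the rewritten AuthRedirect.URL() above.
func buildAuthURL(realm, service, scope, nonce string) (*url.URL, error) {
	redirectURL, err := url.Parse(realm)
	if err != nil {
		return nil, err
	}

	values := redirectURL.Query()
	values.Add("service", service)
	for _, s := range strings.Split(scope, " ") {
		values.Add("scope", s) // repeated keys encode as scope=...&scope=...
	}
	values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
	values.Add("nonce", nonce)

	redirectURL.RawQuery = values.Encode()
	return redirectURL, nil
}

func main() {
	u, err := buildAuthURL("https://registry.example.com/token", "registry.example.com",
		"repository:ns/repo:pull repository:ns/repo:push", "abc123")
	if err != nil {
		panic(err)
	}
	fmt.Println(u.String())
}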

View File

@@ -155,12 +155,13 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
}
}
url := fmt.Sprintf("%s/v2/%s/blobs/%s", opts.mp.Registry, opts.mp.GetNamespaceRepository(), f.Digest)
headers := map[string]string{
"Range": fmt.Sprintf("bytes=%d-", size),
}
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", f.Digest)
resp, err := makeRequest(ctx, "GET", url, headers, nil, opts.regOpts)
headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-", size))
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts)
if err != nil {
log.Printf("couldn't download blob: %v", err)
return fmt.Errorf("%w: %w", errDownload, err)
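
doDownload above now derives the blob URL from the model path's BaseURL() plus JoinPath and sends its resume offset through an http.Header rather than a plain string map. A hedged sketch of issuing that kind of ranged request; the host, repository, and digest are placeholders:

package main

import (
	"fmt"
	"net/http"
	"net/url"
)

// rangeRequest builds a resumable blob request in the style of doDownload above.
func rangeRequest(base *url.URL, repo, digest string, offset int64) (*http.Request, error) {
	requestURL := base.JoinPath("v2", repo, "blobs", digest)

	req, err := http.NewRequest(http.MethodGet, requestURL.String(), nil)
	if err != nil {
		return nil, err
	}
	// Ask the registry for the remainder of the blob, starting at offset.
	req.Header.Set("Range", fmt.Sprintf("bytes=%d-", offset))
	return req, nil
}

func main() {
	base := &url.URL{Scheme: "https", Host: "registry.example.com"}
	req, err := rangeRequest(base, "ns/repo", "sha256:0123abcd", 1048576)
	if err != nil {
		panic(err)
	}
	fmt.Println(req.URL.String(), req.Header.Get("Range"))
}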

View File

@@ -12,10 +12,12 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
@@ -23,6 +25,7 @@ import (
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/vector"
"github.com/jmorganca/ollama/version"
)
const MaxRetries = 3
@@ -154,7 +157,6 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
manifest, err := GetManifest(mp)
if err != nil {
return nil, err
@@ -272,6 +274,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
case "model":
fn(api.ProgressResponse{Status: "looking for model"})
embed.model = c.Args
mp := ParseModelPath(c.Args)
mf, err := GetManifest(mp)
if err != nil {
@@ -286,7 +289,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err
}
mf, err = GetManifest(ParseModelPath(c.Args))
mf, err = GetManifest(mp)
if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err)
}
@@ -325,7 +328,27 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
}
if mf != nil {
log.Printf("manifest = %#v", mf)
sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
if err != nil {
return err
}
sourceBlob, err := os.Open(sourceBlobPath)
if err != nil {
return err
}
defer sourceBlob.Close()
var source ConfigV2
if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
return err
}
// copy the model metadata
config.ModelFamily = source.ModelFamily
config.ModelType = source.ModelType
config.FileType = source.FileType
for _, l := range mf.Layers {
newLayer, err := GetLayerWithBufferFromLayer(l)
if err != nil {
@@ -655,7 +678,6 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
mp := ParseModelPath(name)
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
@@ -786,11 +808,14 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
}
func CopyModel(src, dest string) error {
srcPath, err := ParseModelPath(src).GetManifestPath(false)
srcModelPath := ParseModelPath(src)
srcPath, err := srcModelPath.GetManifestPath(false)
if err != nil {
return err
}
destPath, err := ParseModelPath(dest).GetManifestPath(true)
destModelPath := ParseModelPath(dest)
destPath, err := destModelPath.GetManifestPath(true)
if err != nil {
return err
}
@@ -813,7 +838,6 @@ func CopyModel(src, dest string) error {
func DeleteModel(name string) error {
mp := ParseModelPath(name)
manifest, err := GetManifest(mp)
if err != nil {
return err
@@ -893,9 +917,12 @@ func DeleteModel(name string) error {
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
}
manifest, err := GetManifest(mp)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
@@ -935,8 +962,8 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return err
}
if strings.HasPrefix(path.Base(location), "sha256:") {
layer.Digest = path.Base(location)
if strings.HasPrefix(path.Base(location.Path), "sha256:") {
layer.Digest = path.Base(location.Path)
fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
@@ -953,17 +980,17 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
}
fn(api.ProgressResponse{Status: "pushing manifest"})
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{
"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
}
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
resp, err := makeRequestWithRetry(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil {
return err
}
@@ -977,6 +1004,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
}
fn(api.ProgressResponse{Status: "pulling manifest"})
manifest, err := pullModelManifest(ctx, mp, regOpts)
@@ -1043,12 +1074,11 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{
"Accept": "application/vnd.docker.distribution.manifest.v2+json",
}
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't get manifest: %v", err)
return nil, err
@@ -1107,35 +1137,12 @@ func GetSHA256Digest(r io.Reader) (string, int) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
}
type requestContextKey string
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
if layer.From != "" {
url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
}
resp, err := makeRequestWithRetry(ctx, "POST", url, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return "", err
}
defer resp.Body.Close()
// Extract UUID location from header
location := resp.Header.Get("Location")
if location == "" {
return "", fmt.Errorf("location header is missing in response")
}
return location, nil
}
// Function to check if a blob already exists in the Docker registry
func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't check for blob: %v", err)
return false, err
@@ -1146,110 +1153,10 @@ func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpt
return resp.StatusCode == http.StatusOK, nil
}
func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
// TODO allow resumability
// TODO allow canceling uploads via DELETE
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
f, err := os.Open(fp)
if err != nil {
return err
}
defer f.Close()
completed := 0
chunkSize := 10 * 1024 * 1024
for {
r, w := io.Pipe()
defer r.Close()
limit := completed + chunkSize
if chunkSize >= layer.Size-completed {
limit = layer.Size
chunkSize = layer.Size - completed
}
go func() {
defer w.Close()
for {
n, err := io.CopyN(w, f, 1024*1024)
if err != nil && !errors.Is(err, io.EOF) {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error copying pipe: %v", err),
Digest: layer.Digest,
Total: layer.Size,
Completed: completed,
})
return
}
completed += int(n)
fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", layer.Digest),
Digest: layer.Digest,
Total: layer.Size,
Completed: completed,
})
if completed >= limit {
return
}
}
}()
headers := make(map[string]string)
headers["Content-Type"] = "application/octet-stream"
headers["Content-Length"] = strconv.Itoa(chunkSize)
headers["Content-Range"] = fmt.Sprintf("%d-%d", completed, limit-1)
resp, err := makeRequest(ctx, "PATCH", url, headers, r, regOpts)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusAccepted {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
}
url = resp.Header.Get("Location")
if completed >= layer.Size {
break
}
}
url = fmt.Sprintf("%s&digest=%s", url, layer.Digest)
headers := make(map[string]string)
headers["Content-Type"] = "application/octet-stream"
headers["Content-Length"] = "0"
// finish the upload
resp, err := makeRequest(ctx, "PUT", url, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't finish upload: %v", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
}
return nil
}
func makeRequestWithRetry(ctx context.Context, method, url string, headers map[string]string, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
var status string
for try := 0; try < MaxRetries; try++ {
resp, err := makeRequest(ctx, method, url, headers, body, regOpts)
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return nil, err
@@ -1285,29 +1192,27 @@ func makeRequestWithRetry(ctx context.Context, method, url string, headers map[s
return nil, fmt.Errorf("max retry exceeded: %v", status)
}
func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
if !strings.HasPrefix(url, "http") {
if regOpts.Insecure {
url = "http://" + url
} else {
url = "https://" + url
}
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
if requestURL.Scheme != "http" && regOpts.Insecure {
requestURL.Scheme = "http"
}
req, err := http.NewRequestWithContext(ctx, method, url, body)
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
if err != nil {
return nil, err
}
if headers != nil {
req.Header = headers
}
if regOpts.Token != "" {
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
} else if regOpts.Username != "" && regOpts.Password != "" {
req.SetBasicAuth(regOpts.Username, regOpts.Password)
}
for k, v := range headers {
req.Header.Set(k, v)
}
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
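
Within CreateModel above, the FROM-model branch now opens the source model's config blob, decodes it, and copies ModelFamily, ModelType, and FileType onto the new model's config. A reduced sketch of that decode-and-copy step using a simplified stand-in struct; the field types and JSON tags here are assumptions, and the real ConfigV2 carries more fields:

package main

import (
	"encoding/json"
	"fmt"
	"os"
)

// configV2 is a stand-in for the server's ConfigV2 with only the copied fields.
type configV2 struct {
	ModelFamily string `json:"model_family"`
	ModelType   string `json:"model_type"`
	FileType    string `json:"file_type"`
}

// copyModelMetadata mirrors the metadata copy in CreateModel above.
func copyModelMetadata(sourceBlobPath string, config *configV2) error {
	sourceBlob, err := os.Open(sourceBlobPath)
	if err != nil {
		return err
	}
	defer sourceBlob.Close()

	var source configV2
	if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
		return err
	}

	// copy the model metadata
	config.ModelFamily = source.ModelFamily
	config.ModelType = source.ModelType
	config.FileType = source.FileType
	return nil
}

func main() {
	var config configV2
	if err := copyModelMetadata("/path/to/source-config.json", &config); err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Printf("%+v\n", config)
}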

View File

@@ -1,7 +1,9 @@
package server
import (
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"runtime"
@@ -23,42 +25,46 @@ const (
DefaultProtocolScheme = "https"
)
var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
)
func ParseModelPath(name string) ModelPath {
slashParts := strings.Split(name, "/")
var registry, namespace, repository, tag string
switch len(slashParts) {
case 3:
registry = slashParts[0]
namespace = slashParts[1]
repository = strings.Split(slashParts[2], ":")[0]
case 2:
registry = DefaultRegistry
namespace = slashParts[0]
repository = strings.Split(slashParts[1], ":")[0]
case 1:
registry = DefaultRegistry
namespace = DefaultNamespace
repository = strings.Split(slashParts[0], ":")[0]
default:
fmt.Println("Invalid image format.")
return ModelPath{}
}
colonParts := strings.Split(slashParts[len(slashParts)-1], ":")
if len(colonParts) == 2 {
tag = colonParts[1]
} else {
tag = DefaultTag
}
return ModelPath{
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: registry,
Namespace: namespace,
Repository: repository,
Tag: tag,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}
before, after, found := strings.Cut(name, "://")
if found {
mp.ProtocolScheme = before
name = after
}
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}
if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo
mp.Tag = tag
}
return mp
}
func (mp ModelPath) GetNamespaceRepository() string {
@@ -95,6 +101,13 @@ func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {
return path, nil
}
func (mp ModelPath) BaseURL() *url.URL {
return &url.URL{
Scheme: mp.ProtocolScheme,
Host: mp.Registry,
}
}
func GetManifestPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
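
ParseModelPath above is rewritten around a defaulted ModelPath plus strings.Cut for the optional scheme and tag, and the new BaseURL() lets the rest of this compare build registry URLs with url.URL instead of fmt.Sprintf. A small sketch of that composition using a trimmed stand-in type; only the fields this example needs are included:

package main

import (
	"fmt"
	"net/url"
)

// modelPath is a reduced stand-in for server.ModelPath.
type modelPath struct {
	ProtocolScheme, Registry, Namespace, Repository, Tag string
}

// baseURL mirrors ModelPath.BaseURL() above.
func (mp modelPath) baseURL() *url.URL {
	return &url.URL{Scheme: mp.ProtocolScheme, Host: mp.Registry}
}

func main() {
	mp := modelPath{
		ProtocolScheme: "https",
		Registry:       "example.com",
		Namespace:      "ns",
		Repository:     "repo",
		Tag:            "tag",
	}

	// Compose the manifest endpoint the way pullModelManifest now does,
	// rather than fmt.Sprintf("%s/v2/%s/manifests/%s", ...).
	requestURL := mp.baseURL().JoinPath("v2", mp.Namespace+"/"+mp.Repository, "manifests", mp.Tag)
	fmt.Println(requestURL.String()) // https://example.com/v2/ns/repo/manifests/tag
}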

server/modelpath_test.go (new file, 88 lines)
View File

@@ -0,0 +1,88 @@
package server
import "testing"
func TestParseModelPath(t *testing.T) {
tests := []struct {
name string
arg string
want ModelPath
}{
{
"full path https",
"https://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"full path http",
"http://example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "http",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no protocol",
"example.com/ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no registry",
"ns/repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
},
{
"no namespace",
"repo:tag",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: "tag",
},
},
{
"no tag",
"repo",
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: DefaultTag,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := ParseModelPath(tc.arg)
if got != tc.want {
t.Errorf("got: %q want: %q", got, tc.want)
}
})
}
}

View File

@@ -25,6 +25,20 @@ import (
"github.com/jmorganca/ollama/vector"
)
var mode string = gin.DebugMode
func init() {
switch mode {
case gin.DebugMode:
case gin.ReleaseMode:
case gin.TestMode:
default:
mode = gin.DebugMode
}
gin.SetMode(mode)
}
var loaded struct {
mu sync.Mutex
@@ -357,6 +371,7 @@ func ListModelsHandler(c *gin.Context) {
return nil
}
tag := path[:slashIndex] + ":" + path[slashIndex+1:]
mp := ParseModelPath(tag)
manifest, err := GetManifest(mp)
if err != nil {
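
The server diff above introduces a package-level mode string defaulting to gin.DebugMode and validates it in init() before calling gin.SetMode; the Darwin build script earlier in this compare overrides it with -X github.com/jmorganca/ollama/server.mode=release. A small illustrative sketch of the same guard, using gin.Mode() to confirm the effective setting (the main wrapper and variable target are stand-ins):

package main

import (
	"fmt"

	"github.com/gin-gonic/gin"
)

// mode mimics the server package's variable; a release build would set it
// through -ldflags, e.g. "-X main.mode=release" for this standalone sketch.
var mode string = gin.DebugMode

func main() {
	// Same validation as the init() above: unknown values fall back to debug.
	switch mode {
	case gin.DebugMode, gin.ReleaseMode, gin.TestMode:
	default:
		mode = gin.DebugMode
	}
	gin.SetMode(mode)

	fmt.Println("gin mode:", gin.Mode()) // "debug" unless overridden at link time
}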

server/upload.go (new file, 125 lines)
View File

@@ -0,0 +1,125 @@
package server
import (
"context"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"strconv"
"github.com/jmorganca/ollama/api"
)
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, error) {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
if layer.From != "" {
values := requestURL.Query()
values.Add("mount", layer.Digest)
values.Add("from", layer.From)
requestURL.RawQuery = values.Encode()
}
resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return nil, err
}
defer resp.Body.Close()
// Extract UUID location from header
location := resp.Header.Get("Location")
if location == "" {
return nil, fmt.Errorf("location header is missing in response")
}
return url.Parse(location)
}
func uploadBlobChunked(ctx context.Context, mp ModelPath, requestURL *url.URL, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
// TODO allow resumability
// TODO allow canceling uploads via DELETE
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
f, err := os.Open(fp)
if err != nil {
return err
}
defer f.Close()
var completed int64
chunkSize := 10 * 1024 * 1024
for {
chunk := int64(layer.Size) - completed
if chunk > int64(chunkSize) {
chunk = int64(chunkSize)
}
sectionReader := io.NewSectionReader(f, int64(completed), chunk)
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", strconv.Itoa(int(chunk)))
headers.Set("Content-Range", fmt.Sprintf("%d-%d", completed, completed+sectionReader.Size()-1))
resp, err := makeRequestWithRetry(ctx, "PATCH", requestURL, headers, sectionReader, regOpts)
if err != nil && !errors.Is(err, io.EOF) {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error uploading chunk: %v", err),
Digest: layer.Digest,
Total: layer.Size,
Completed: int(completed),
})
return err
}
defer resp.Body.Close()
completed += sectionReader.Size()
fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", layer.Digest),
Digest: layer.Digest,
Total: layer.Size,
Completed: int(completed),
})
requestURL, err = url.Parse(resp.Header.Get("Location"))
if err != nil {
return err
}
if completed >= int64(layer.Size) {
break
}
}
values := requestURL.Query()
values.Add("digest", layer.Digest)
requestURL.RawQuery = values.Encode()
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", "0")
// finish the upload
resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't finish upload: %v", err)
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body))
}
return nil
}
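
uploadBlobChunked above drops the io.Pipe in favor of io.NewSectionReader, so each chunk's request body implements Seek and can be replayed by makeRequestWithRetry; as the commit message notes, the trade-off is that progress is reported per chunk rather than streamed continuously. A standalone sketch of slicing a file into seekable chunks the same way; the HTTP plumbing is omitted and the file path is a placeholder:

package main

import (
	"fmt"
	"io"
	"os"
)

const chunkSize = 10 * 1024 * 1024 // 10 MiB, matching uploadBlobChunked above

// eachChunk walks a file in fixed-size sections; every section is an
// io.ReadSeeker, so a retrying HTTP helper can rewind and resend it.
func eachChunk(path string, fn func(offset int64, chunk *io.SectionReader) error) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	fi, err := f.Stat()
	if err != nil {
		return err
	}

	var completed int64
	for completed < fi.Size() {
		chunk := fi.Size() - completed
		if chunk > chunkSize {
			chunk = chunkSize
		}

		sectionReader := io.NewSectionReader(f, completed, chunk)
		if err := fn(completed, sectionReader); err != nil {
			return err
		}
		completed += sectionReader.Size()
	}
	return nil
}

func main() {
	err := eachChunk("/path/to/blob", func(offset int64, chunk *io.SectionReader) error {
		// A real caller would PATCH this chunk with a Content-Range header of
		// fmt.Sprintf("%d-%d", offset, offset+chunk.Size()-1), as above.
		fmt.Printf("chunk at offset %d, %d bytes\n", offset, chunk.Size())
		return nil
	})
	if err != nil {
		fmt.Println("error:", err)
	}
}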

version/version.go (new file, 3 lines)
View File

@@ -0,0 +1,3 @@
package version
var Version string = "0.0.0"