mirror of
https://github.com/ollama/ollama.git
synced 2026-01-20 21:40:54 -05:00
Compare commits
3 Commits
parth/decr
...
pdevine/au
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53a53702e0 | ||
|
|
be04fcde16 | ||
|
|
5968989a7f |
@@ -42,6 +42,23 @@ type Client struct {
|
|||||||
|
|
||||||
func checkError(resp *http.Response, body []byte) error {
|
func checkError(resp *http.Response, body []byte) error {
|
||||||
if resp.StatusCode < http.StatusBadRequest {
|
if resp.StatusCode < http.StatusBadRequest {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// streams can contain error message even with StatusOK
|
||||||
|
var errorResponse struct {
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, &errorResponse); err != nil {
|
||||||
|
return fmt.Errorf("unmarshal: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if errorResponse.Error != "" {
|
||||||
|
return errors.New(errorResponse.Error)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,25 +230,9 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
|||||||
scanBuf := make([]byte, 0, maxBufferSize)
|
scanBuf := make([]byte, 0, maxBufferSize)
|
||||||
scanner.Buffer(scanBuf, maxBufferSize)
|
scanner.Buffer(scanBuf, maxBufferSize)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
var errorResponse struct {
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
bts := scanner.Bytes()
|
bts := scanner.Bytes()
|
||||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
if err := checkError(response, bts); err != nil {
|
||||||
return fmt.Errorf("unmarshal: %w", err)
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
if response.StatusCode >= http.StatusBadRequest {
|
|
||||||
return StatusError{
|
|
||||||
StatusCode: response.StatusCode,
|
|
||||||
Status: response.Status,
|
|
||||||
ErrorMessage: errorResponse.Error,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if errorResponse.Error != "" {
|
|
||||||
return errors.New(errorResponse.Error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fn(bts); err != nil {
|
if err := fn(bts); err != nil {
|
||||||
|
|||||||
@@ -89,16 +89,6 @@ func TestClientStream(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantErr: "mid-stream error",
|
wantErr: "mid-stream error",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "http status error takes precedence over general error",
|
|
||||||
responses: []any{
|
|
||||||
testError{
|
|
||||||
message: "custom error message",
|
|
||||||
statusCode: http.StatusInternalServerError,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: "500",
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "successful stream completion",
|
name: "successful stream completion",
|
||||||
responses: []any{
|
responses: []any{
|
||||||
|
|||||||
35
auth/auth.go
35
auth/auth.go
@@ -18,6 +18,8 @@ import (
|
|||||||
|
|
||||||
const defaultPrivateKey = "id_ed25519"
|
const defaultPrivateKey = "id_ed25519"
|
||||||
|
|
||||||
|
var ErrInvalidToken = errors.New("invalid token")
|
||||||
|
|
||||||
func keyPath() (string, error) {
|
func keyPath() (string, error) {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -27,6 +29,39 @@ func keyPath() (string, error) {
|
|||||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseToken(token string) (key, sig []byte, _ error) {
|
||||||
|
keyData, sigData, ok := strings.Cut(token, ":")
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, fmt.Errorf("identity: parseToken: %w", ErrInvalidToken)
|
||||||
|
}
|
||||||
|
sig, err := base64.StdEncoding.DecodeString(sigData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("identity: parseToken: base64 decoding signature: %w", err)
|
||||||
|
}
|
||||||
|
return []byte(keyData), sig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Authenticate(token, checkData string) (ssh.PublicKey, error) {
|
||||||
|
keyShort, sigBytes, err := parseToken(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
keyLong := append([]byte("ssh-ed25519 "), keyShort...)
|
||||||
|
pub, _, _, _, err := ssh.ParseAuthorizedKey(keyLong)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pub.Verify([]byte(checkData), &ssh.Signature{
|
||||||
|
Format: pub.Type(),
|
||||||
|
Blob: sigBytes,
|
||||||
|
}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return pub, nil
|
||||||
|
}
|
||||||
|
|
||||||
func GetPublicKey() (string, error) {
|
func GetPublicKey() (string, error) {
|
||||||
keyPath, err := keyPath()
|
keyPath, err := keyPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
254
auth/authorized_keys.go
Normal file
254
auth/authorized_keys.go
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyEntry struct {
|
||||||
|
Name string
|
||||||
|
PublicKey string
|
||||||
|
Endpoints []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyPermission struct {
|
||||||
|
Name string
|
||||||
|
Endpoints []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type APIPermissions struct {
|
||||||
|
permissions map[string]*KeyPermission
|
||||||
|
lastModified time.Time
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
var ws = regexp.MustCompile(`\s+`)
|
||||||
|
|
||||||
|
func authkeyPath() (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(home, ".ollama", "authorized_keys"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAPIPermissions() *APIPermissions {
|
||||||
|
return &APIPermissions{
|
||||||
|
permissions: make(map[string]*KeyPermission),
|
||||||
|
mutex: sync.RWMutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ap *APIPermissions) ReloadIfNeeded() error {
|
||||||
|
ap.mutex.Lock()
|
||||||
|
defer ap.mutex.Unlock()
|
||||||
|
|
||||||
|
filename, err := authkeyPath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fileInfo, err := os.Stat(filename)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to stat file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fileInfo.ModTime().After(ap.lastModified) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.Open(filename)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
ap.lastModified = fileInfo.ModTime()
|
||||||
|
return ap.parse(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ap *APIPermissions) parse(r io.Reader) error {
|
||||||
|
ap.permissions = make(map[string]*KeyPermission)
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
var cnt int
|
||||||
|
for scanner.Scan() {
|
||||||
|
cnt += 1
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
|
||||||
|
if line == "" || strings.HasPrefix(line, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
line = ws.ReplaceAllString(line, " ")
|
||||||
|
|
||||||
|
entry, err := ap.parseLine(line)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn(fmt.Sprintf("authorized_keys line %d: skipping invalid line: %v\n", cnt, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var pubKeyStr string
|
||||||
|
|
||||||
|
if entry.PublicKey == "*" {
|
||||||
|
pubKeyStr = "*"
|
||||||
|
} else {
|
||||||
|
pubKey, err := ap.validateAndDecodeKey(entry)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn(fmt.Sprintf("authorized_keys line %d: invalid key for %s: %v\n", cnt, entry.Name, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pubKeyStr = pubKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if perm, exists := ap.permissions[pubKeyStr]; exists {
|
||||||
|
if perm.Name == "default" {
|
||||||
|
perm.Name = entry.Name
|
||||||
|
}
|
||||||
|
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
|
||||||
|
// skip redundant entries
|
||||||
|
continue
|
||||||
|
} else if len(entry.Endpoints) == 1 && entry.Endpoints[0] == "*" {
|
||||||
|
// overwrite redundant entries
|
||||||
|
perm.Endpoints = entry.Endpoints
|
||||||
|
} else {
|
||||||
|
perm.Endpoints = append(perm.Endpoints, entry.Endpoints...)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ap.permissions[pubKeyStr] = &KeyPermission{
|
||||||
|
Name: entry.Name,
|
||||||
|
Endpoints: entry.Endpoints,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return scanner.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ap *APIPermissions) parseLine(line string) (*KeyEntry, error) {
|
||||||
|
parts := strings.SplitN(line, " ", 4)
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return nil, fmt.Errorf("key type and public key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
kind, b64Key := parts[0], parts[1]
|
||||||
|
name := "default"
|
||||||
|
eps := "*"
|
||||||
|
|
||||||
|
if len(parts) >= 3 && parts[2] != "" {
|
||||||
|
if parts[2] != "*" {
|
||||||
|
name = parts[2]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) == 4 && parts[3] != "" {
|
||||||
|
eps = parts[3]
|
||||||
|
}
|
||||||
|
|
||||||
|
if kind != "ssh-ed25519" && kind != "*" {
|
||||||
|
return nil, fmt.Errorf("unsupported key type %s", kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
if kind == "*" && b64Key != "*" {
|
||||||
|
return nil, fmt.Errorf("unsupported key type")
|
||||||
|
}
|
||||||
|
|
||||||
|
var endpoints []string
|
||||||
|
if eps == "*" {
|
||||||
|
endpoints = []string{"*"}
|
||||||
|
} else {
|
||||||
|
for _, e := range strings.Split(eps, ",") {
|
||||||
|
e = strings.TrimSpace(e)
|
||||||
|
if e == "" {
|
||||||
|
return nil, fmt.Errorf("empty endpoint in list")
|
||||||
|
} else if e == "*" {
|
||||||
|
endpoints = []string{"*"}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
endpoints = append(endpoints, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KeyEntry{
|
||||||
|
PublicKey: b64Key,
|
||||||
|
Name: name,
|
||||||
|
Endpoints: endpoints,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ap *APIPermissions) validateAndDecodeKey(entry *KeyEntry) (string, error) {
|
||||||
|
keyBlob, err := base64.StdEncoding.DecodeString(entry.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("base64 decode: %w", err)
|
||||||
|
}
|
||||||
|
pub, err := ssh.ParsePublicKey(keyBlob)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse key: %w", err)
|
||||||
|
}
|
||||||
|
if pub.Type() != ssh.KeyAlgoED25519 {
|
||||||
|
return "", fmt.Errorf("key is not Ed25519")
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.PublicKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ap *APIPermissions) Authorize(pubKey ssh.PublicKey, endpoint string) (bool, string, error) {
|
||||||
|
if err := ap.ReloadIfNeeded(); err != nil {
|
||||||
|
return false, "unknown", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ap.mutex.RLock()
|
||||||
|
defer ap.mutex.RUnlock()
|
||||||
|
|
||||||
|
if wildcardPerm, exists := ap.permissions["*"]; exists {
|
||||||
|
if len(wildcardPerm.Endpoints) == 1 && wildcardPerm.Endpoints[0] == "*" {
|
||||||
|
return true, wildcardPerm.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, allowedEndpoint := range wildcardPerm.Endpoints {
|
||||||
|
if allowedEndpoint == endpoint {
|
||||||
|
return true, wildcardPerm.Name, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
keyString := string(ssh.MarshalAuthorizedKey(pubKey))
|
||||||
|
parts := strings.SplitN(keyString, " ", 2)
|
||||||
|
var base64Key string
|
||||||
|
if len(parts) > 1 {
|
||||||
|
base64Key = parts[1]
|
||||||
|
} else {
|
||||||
|
base64Key = parts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
base64Key = strings.TrimSpace(base64Key)
|
||||||
|
|
||||||
|
perm, exists := ap.permissions[base64Key]
|
||||||
|
if !exists {
|
||||||
|
return false, "unknown", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
|
||||||
|
return true, perm.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, allowedEndpoint := range perm.Endpoints {
|
||||||
|
if allowedEndpoint == endpoint {
|
||||||
|
return true, perm.Name, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, "unknown", nil
|
||||||
|
}
|
||||||
133
auth/authorized_keys_test.go
Normal file
133
auth/authorized_keys_test.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const validB64 = "AAAAC3NzaC1lZDI1NTE5AAAAICy1v/Sn0kGhu1LXzCsnx3wlk5ESdncS66JWo13yeJod"
|
||||||
|
|
||||||
|
func TestParse(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
file string
|
||||||
|
want map[string]*KeyPermission
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "two fields only defaults",
|
||||||
|
file: "ssh-ed25519 " + validB64 + "\n",
|
||||||
|
want: map[string]*KeyPermission{
|
||||||
|
validB64: {
|
||||||
|
Name: "default",
|
||||||
|
Endpoints: []string{"*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extra whitespace collapsed and default endpoints",
|
||||||
|
file: "ssh-ed25519 " + validB64 + " alice\n",
|
||||||
|
want: map[string]*KeyPermission{
|
||||||
|
validB64: {
|
||||||
|
Name: "alice",
|
||||||
|
Endpoints: []string{"*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "four fields full",
|
||||||
|
file: "ssh-ed25519 " + validB64 + " bob /api/foo,/api/bar\n",
|
||||||
|
want: map[string]*KeyPermission{
|
||||||
|
validB64: {
|
||||||
|
Name: "bob",
|
||||||
|
Endpoints: []string{"/api/foo", "/api/bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "comment lines ignored and multiple entries",
|
||||||
|
file: "# header\n\nssh-ed25519 " + validB64 + " user1\nssh-ed25519 " + validB64 + " user2 /api/x\n",
|
||||||
|
want: map[string]*KeyPermission{
|
||||||
|
validB64: {
|
||||||
|
Name: "user1",
|
||||||
|
Endpoints: []string{"*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three entries variety",
|
||||||
|
file: "ssh-ed25519 " + validB64 + "\nssh-ed25519 " + validB64 + " alice /api/a,/api/b\nssh-ed25519 " + validB64 + " bob /api/c\n",
|
||||||
|
want: map[string]*KeyPermission{
|
||||||
|
validB64: {
|
||||||
|
Name: "alice",
|
||||||
|
Endpoints: []string{"*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two entries w/ wildcard",
|
||||||
|
file: "ssh-ed25519 " + validB64 + " alice /api/a\n* * * /api/b\n",
|
||||||
|
want: map[string]*KeyPermission{
|
||||||
|
validB64: {
|
||||||
|
Name: "alice",
|
||||||
|
Endpoints: []string{"/api/a"},
|
||||||
|
},
|
||||||
|
"*": {
|
||||||
|
Name: "default",
|
||||||
|
Endpoints: []string{"/api/b"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tags for everyone",
|
||||||
|
file: "* * * /api/tags",
|
||||||
|
want: map[string]*KeyPermission{
|
||||||
|
"*": {
|
||||||
|
Name: "default",
|
||||||
|
Endpoints: []string{"/api/tags"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default name",
|
||||||
|
file: "* * somename",
|
||||||
|
want: map[string]*KeyPermission{
|
||||||
|
"*": {
|
||||||
|
Name: "somename",
|
||||||
|
Endpoints: []string{"*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported key type",
|
||||||
|
file: "ssh-rsa AAAAB3Nza...\n",
|
||||||
|
want: map[string]*KeyPermission{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad base64",
|
||||||
|
file: "ssh-ed25519 invalid@@@\n",
|
||||||
|
want: map[string]*KeyPermission{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "just an asterix",
|
||||||
|
file: "*\n",
|
||||||
|
want: map[string]*KeyPermission{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
perms := NewAPIPermissions()
|
||||||
|
err := perms.parse(bytes.NewBufferString(tc.file))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(perms.permissions) != len(tc.want) {
|
||||||
|
t.Fatalf("got %d entries, want %d", len(perms.permissions), len(tc.want))
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(perms.permissions, tc.want) {
|
||||||
|
t.Errorf("got %+v, want %+v", perms.permissions, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -28,6 +28,7 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/discover"
|
"github.com/ollama/ollama/discover"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
@@ -55,6 +56,8 @@ var mode string = gin.DebugMode
|
|||||||
type Server struct {
|
type Server struct {
|
||||||
addr net.Addr
|
addr net.Addr
|
||||||
sched *Scheduler
|
sched *Scheduler
|
||||||
|
|
||||||
|
perms *auth.APIPermissions
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -69,6 +72,38 @@ func init() {
|
|||||||
gin.SetMode(mode)
|
gin.SetMode(mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loggedFormatter(param gin.LogFormatterParams) string {
|
||||||
|
var statusColor, methodColor, resetColor string
|
||||||
|
if param.IsOutputColor() {
|
||||||
|
statusColor = param.StatusCodeColor()
|
||||||
|
methodColor = param.MethodColor()
|
||||||
|
resetColor = param.ResetColor()
|
||||||
|
}
|
||||||
|
|
||||||
|
if param.Latency > time.Minute {
|
||||||
|
param.Latency = param.Latency.Truncate(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
username := "default"
|
||||||
|
if userVal, exists := param.Keys["username"]; exists {
|
||||||
|
if name, ok := userVal.(string); ok {
|
||||||
|
username = name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"[Ollama] %s |%s %3d %s| %13v | %15s | %-20s |%s %-7s %s %#v\n%s",
|
||||||
|
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||||
|
statusColor, param.StatusCode, resetColor,
|
||||||
|
param.Latency,
|
||||||
|
param.ClientIP,
|
||||||
|
username,
|
||||||
|
methodColor, param.Method, resetColor,
|
||||||
|
param.Path,
|
||||||
|
param.ErrorMessage,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errRequired = errors.New("is required")
|
errRequired = errors.New("is required")
|
||||||
errBadTemplate = errors.New("template error")
|
errBadTemplate = errors.New("template error")
|
||||||
@@ -1111,6 +1146,43 @@ func allowedHost(host string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func allowedEndpointsMiddleware(perms *auth.APIPermissions) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if !envconfig.UseAuth() || (c.Request.Method == "HEAD" && c.Request.URL.Path == "/") {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")
|
||||||
|
if token == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey, err := auth.Authenticate(token, fmt.Sprintf("%s,%s", c.Request.Method, c.Request.RequestURI))
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("authentication error", "error", err)
|
||||||
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
authorized, name, err := perms.Authorize(pubKey, c.Request.URL.Path)
|
||||||
|
c.Set("username", name)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("authorization error", "error", err)
|
||||||
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !authorized {
|
||||||
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if addr == nil {
|
if addr == nil {
|
||||||
@@ -1177,10 +1249,13 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
}
|
}
|
||||||
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
|
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
|
||||||
|
|
||||||
r := gin.Default()
|
r := gin.New()
|
||||||
r.HandleMethodNotAllowed = true
|
r.HandleMethodNotAllowed = true
|
||||||
r.Use(
|
r.Use(
|
||||||
|
gin.LoggerWithFormatter(loggedFormatter),
|
||||||
|
gin.Recovery(),
|
||||||
cors.New(corsConfig),
|
cors.New(corsConfig),
|
||||||
|
allowedEndpointsMiddleware(s.perms),
|
||||||
allowedHostsMiddleware(s.addr),
|
allowedHostsMiddleware(s.addr),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1190,7 +1265,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
||||||
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
||||||
|
|
||||||
// Local model cache management (new implementation is at end of function)
|
// Local model cache management
|
||||||
r.POST("/api/pull", s.PullHandler)
|
r.POST("/api/pull", s.PullHandler)
|
||||||
r.POST("/api/push", s.PushHandler)
|
r.POST("/api/push", s.PushHandler)
|
||||||
r.HEAD("/api/tags", s.ListHandler)
|
r.HEAD("/api/tags", s.ListHandler)
|
||||||
@@ -1222,7 +1297,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
// wrap old with new
|
// wrap old with new
|
||||||
rs := ®istry.Local{
|
rs := ®istry.Local{
|
||||||
Client: rc,
|
Client: rc,
|
||||||
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
Logger: slog.Default(),
|
||||||
Fallback: r,
|
Fallback: r,
|
||||||
|
|
||||||
Prune: PruneLayers,
|
Prune: PruneLayers,
|
||||||
@@ -1267,6 +1342,12 @@ func Serve(ln net.Listener) error {
|
|||||||
|
|
||||||
s := &Server{addr: ln.Addr()}
|
s := &Server{addr: ln.Addr()}
|
||||||
|
|
||||||
|
if envconfig.UseAuth() {
|
||||||
|
perms := auth.NewAPIPermissions()
|
||||||
|
perms.ReloadIfNeeded()
|
||||||
|
s.perms = perms
|
||||||
|
}
|
||||||
|
|
||||||
var rc *ollama.Registry
|
var rc *ollama.Registry
|
||||||
if useClient2 {
|
if useClient2 {
|
||||||
var err error
|
var err error
|
||||||
|
|||||||
Reference in New Issue
Block a user