Compare commits

...

3 Commits

Author SHA1 Message Date
Patrick Devine
53a53702e0 gofumpt the file 2025-07-29 21:37:05 -07:00
Patrick Devine
be04fcde16 feed the linter 2025-07-29 21:20:22 -07:00
Patrick Devine
5968989a7f server: add authorized_keys file
This change adds an "authorized_keys" file similar to sshd which can control
access to an Ollama server. The file itself is very simple and consists of
various entries for Ollama public keys.

The format is:

<key format> <public key> <name> [<endpoint>,...]

Examples:

ssh-ed25519 AAAAC3NzaC1lZDI1NT... bob /api/tags,/api/ps,/api/show,/api/generate,/api/chat

Use the "*" wildcard symbol to substitute any value, e.g.:

To grant full access to "bob":
ssh-ed25519 AAAAC3NzaC1lZDI1NT... bob *

To allow all callers to view tags (i.e. "ollama ls"):
* * * /api/tags

- The key format must be set to "ssh-ed25519" or set to the wildcard character.
- The public key must be an ssh based ed25519 (Ollama) public key or set to the wildcard
  character.
- Name can be any string you wish to associate with the public key. Note that if a public
  key is used in more than one entry in the file, the first instance of the name will be
  used and subsequent name values will be ignored.
- Endpoints is a comma separated list of Ollama Server API endpoints or the wildcard
  character. The HTTP method is not currently needed, but could be added in the future.
2025-07-29 20:50:01 -07:00
6 changed files with 525 additions and 31 deletions

View File

@@ -42,6 +42,23 @@ type Client struct {
func checkError(resp *http.Response, body []byte) error {
if resp.StatusCode < http.StatusBadRequest {
if len(body) == 0 {
return nil
}
// streams can contain error message even with StatusOK
var errorResponse struct {
Error string `json:"error,omitempty"`
}
if err := json.Unmarshal(body, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
}
return nil
}
@@ -213,25 +230,9 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
var errorResponse struct {
Error string `json:"error,omitempty"`
}
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: errorResponse.Error,
}
}
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
if err := checkError(response, bts); err != nil {
return err
}
if err := fn(bts); err != nil {

View File

@@ -89,16 +89,6 @@ func TestClientStream(t *testing.T) {
},
wantErr: "mid-stream error",
},
{
name: "http status error takes precedence over general error",
responses: []any{
testError{
message: "custom error message",
statusCode: http.StatusInternalServerError,
},
},
wantErr: "500",
},
{
name: "successful stream completion",
responses: []any{

View File

@@ -18,6 +18,8 @@ import (
const defaultPrivateKey = "id_ed25519"
var ErrInvalidToken = errors.New("invalid token")
func keyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
@@ -27,6 +29,39 @@ func keyPath() (string, error) {
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func parseToken(token string) (key, sig []byte, _ error) {
keyData, sigData, ok := strings.Cut(token, ":")
if !ok {
return nil, nil, fmt.Errorf("identity: parseToken: %w", ErrInvalidToken)
}
sig, err := base64.StdEncoding.DecodeString(sigData)
if err != nil {
return nil, nil, fmt.Errorf("identity: parseToken: base64 decoding signature: %w", err)
}
return []byte(keyData), sig, nil
}
func Authenticate(token, checkData string) (ssh.PublicKey, error) {
keyShort, sigBytes, err := parseToken(token)
if err != nil {
return nil, err
}
keyLong := append([]byte("ssh-ed25519 "), keyShort...)
pub, _, _, _, err := ssh.ParseAuthorizedKey(keyLong)
if err != nil {
return nil, err
}
if err := pub.Verify([]byte(checkData), &ssh.Signature{
Format: pub.Type(),
Blob: sigBytes,
}); err != nil {
return nil, err
}
return pub, nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {

254
auth/authorized_keys.go Normal file
View File

@@ -0,0 +1,254 @@
package auth
import (
"bufio"
"encoding/base64"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
type KeyEntry struct {
Name string
PublicKey string
Endpoints []string
}
type KeyPermission struct {
Name string
Endpoints []string
}
type APIPermissions struct {
permissions map[string]*KeyPermission
lastModified time.Time
mutex sync.RWMutex
}
var ws = regexp.MustCompile(`\s+`)
func authkeyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "authorized_keys"), nil
}
func NewAPIPermissions() *APIPermissions {
return &APIPermissions{
permissions: make(map[string]*KeyPermission),
mutex: sync.RWMutex{},
}
}
func (ap *APIPermissions) ReloadIfNeeded() error {
ap.mutex.Lock()
defer ap.mutex.Unlock()
filename, err := authkeyPath()
if err != nil {
return err
}
fileInfo, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("failed to stat file: %v", err)
}
if !fileInfo.ModTime().After(ap.lastModified) {
return nil
}
file, err := os.Open(filename)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
ap.lastModified = fileInfo.ModTime()
return ap.parse(file)
}
func (ap *APIPermissions) parse(r io.Reader) error {
ap.permissions = make(map[string]*KeyPermission)
scanner := bufio.NewScanner(r)
var cnt int
for scanner.Scan() {
cnt += 1
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
line = ws.ReplaceAllString(line, " ")
entry, err := ap.parseLine(line)
if err != nil {
slog.Warn(fmt.Sprintf("authorized_keys line %d: skipping invalid line: %v\n", cnt, err))
continue
}
var pubKeyStr string
if entry.PublicKey == "*" {
pubKeyStr = "*"
} else {
pubKey, err := ap.validateAndDecodeKey(entry)
if err != nil {
slog.Warn(fmt.Sprintf("authorized_keys line %d: invalid key for %s: %v\n", cnt, entry.Name, err))
continue
}
pubKeyStr = pubKey
}
if perm, exists := ap.permissions[pubKeyStr]; exists {
if perm.Name == "default" {
perm.Name = entry.Name
}
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
// skip redundant entries
continue
} else if len(entry.Endpoints) == 1 && entry.Endpoints[0] == "*" {
// overwrite redundant entries
perm.Endpoints = entry.Endpoints
} else {
perm.Endpoints = append(perm.Endpoints, entry.Endpoints...)
}
} else {
ap.permissions[pubKeyStr] = &KeyPermission{
Name: entry.Name,
Endpoints: entry.Endpoints,
}
}
}
return scanner.Err()
}
func (ap *APIPermissions) parseLine(line string) (*KeyEntry, error) {
parts := strings.SplitN(line, " ", 4)
if len(parts) < 2 {
return nil, fmt.Errorf("key type and public key not found")
}
kind, b64Key := parts[0], parts[1]
name := "default"
eps := "*"
if len(parts) >= 3 && parts[2] != "" {
if parts[2] != "*" {
name = parts[2]
}
}
if len(parts) == 4 && parts[3] != "" {
eps = parts[3]
}
if kind != "ssh-ed25519" && kind != "*" {
return nil, fmt.Errorf("unsupported key type %s", kind)
}
if kind == "*" && b64Key != "*" {
return nil, fmt.Errorf("unsupported key type")
}
var endpoints []string
if eps == "*" {
endpoints = []string{"*"}
} else {
for _, e := range strings.Split(eps, ",") {
e = strings.TrimSpace(e)
if e == "" {
return nil, fmt.Errorf("empty endpoint in list")
} else if e == "*" {
endpoints = []string{"*"}
break
}
endpoints = append(endpoints, e)
}
}
return &KeyEntry{
PublicKey: b64Key,
Name: name,
Endpoints: endpoints,
}, nil
}
func (ap *APIPermissions) validateAndDecodeKey(entry *KeyEntry) (string, error) {
keyBlob, err := base64.StdEncoding.DecodeString(entry.PublicKey)
if err != nil {
return "", fmt.Errorf("base64 decode: %w", err)
}
pub, err := ssh.ParsePublicKey(keyBlob)
if err != nil {
return "", fmt.Errorf("parse key: %w", err)
}
if pub.Type() != ssh.KeyAlgoED25519 {
return "", fmt.Errorf("key is not Ed25519")
}
return entry.PublicKey, nil
}
func (ap *APIPermissions) Authorize(pubKey ssh.PublicKey, endpoint string) (bool, string, error) {
if err := ap.ReloadIfNeeded(); err != nil {
return false, "unknown", err
}
ap.mutex.RLock()
defer ap.mutex.RUnlock()
if wildcardPerm, exists := ap.permissions["*"]; exists {
if len(wildcardPerm.Endpoints) == 1 && wildcardPerm.Endpoints[0] == "*" {
return true, wildcardPerm.Name, nil
}
for _, allowedEndpoint := range wildcardPerm.Endpoints {
if allowedEndpoint == endpoint {
return true, wildcardPerm.Name, nil
}
}
}
keyString := string(ssh.MarshalAuthorizedKey(pubKey))
parts := strings.SplitN(keyString, " ", 2)
var base64Key string
if len(parts) > 1 {
base64Key = parts[1]
} else {
base64Key = parts[0]
}
base64Key = strings.TrimSpace(base64Key)
perm, exists := ap.permissions[base64Key]
if !exists {
return false, "unknown", nil
}
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
return true, perm.Name, nil
}
for _, allowedEndpoint := range perm.Endpoints {
if allowedEndpoint == endpoint {
return true, perm.Name, nil
}
}
return false, "unknown", nil
}

View File

@@ -0,0 +1,133 @@
package auth
import (
"bytes"
"reflect"
"testing"
)
const validB64 = "AAAAC3NzaC1lZDI1NTE5AAAAICy1v/Sn0kGhu1LXzCsnx3wlk5ESdncS66JWo13yeJod"
func TestParse(t *testing.T) {
tests := []struct {
name string
file string
want map[string]*KeyPermission
}{
{
name: "two fields only defaults",
file: "ssh-ed25519 " + validB64 + "\n",
want: map[string]*KeyPermission{
validB64: {
Name: "default",
Endpoints: []string{"*"},
},
},
},
{
name: "extra whitespace collapsed and default endpoints",
file: "ssh-ed25519 " + validB64 + " alice\n",
want: map[string]*KeyPermission{
validB64: {
Name: "alice",
Endpoints: []string{"*"},
},
},
},
{
name: "four fields full",
file: "ssh-ed25519 " + validB64 + " bob /api/foo,/api/bar\n",
want: map[string]*KeyPermission{
validB64: {
Name: "bob",
Endpoints: []string{"/api/foo", "/api/bar"},
},
},
},
{
name: "comment lines ignored and multiple entries",
file: "# header\n\nssh-ed25519 " + validB64 + " user1\nssh-ed25519 " + validB64 + " user2 /api/x\n",
want: map[string]*KeyPermission{
validB64: {
Name: "user1",
Endpoints: []string{"*"},
},
},
},
{
name: "three entries variety",
file: "ssh-ed25519 " + validB64 + "\nssh-ed25519 " + validB64 + " alice /api/a,/api/b\nssh-ed25519 " + validB64 + " bob /api/c\n",
want: map[string]*KeyPermission{
validB64: {
Name: "alice",
Endpoints: []string{"*"},
},
},
},
{
name: "two entries w/ wildcard",
file: "ssh-ed25519 " + validB64 + " alice /api/a\n* * * /api/b\n",
want: map[string]*KeyPermission{
validB64: {
Name: "alice",
Endpoints: []string{"/api/a"},
},
"*": {
Name: "default",
Endpoints: []string{"/api/b"},
},
},
},
{
name: "tags for everyone",
file: "* * * /api/tags",
want: map[string]*KeyPermission{
"*": {
Name: "default",
Endpoints: []string{"/api/tags"},
},
},
},
{
name: "default name",
file: "* * somename",
want: map[string]*KeyPermission{
"*": {
Name: "somename",
Endpoints: []string{"*"},
},
},
},
{
name: "unsupported key type",
file: "ssh-rsa AAAAB3Nza...\n",
want: map[string]*KeyPermission{},
},
{
name: "bad base64",
file: "ssh-ed25519 invalid@@@\n",
want: map[string]*KeyPermission{},
},
{
name: "just an asterix",
file: "*\n",
want: map[string]*KeyPermission{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
perms := NewAPIPermissions()
err := perms.parse(bytes.NewBufferString(tc.file))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(perms.permissions) != len(tc.want) {
t.Fatalf("got %d entries, want %d", len(perms.permissions), len(tc.want))
}
if !reflect.DeepEqual(perms.permissions, tc.want) {
t.Errorf("got %+v, want %+v", perms.permissions, tc.want)
}
})
}
}

View File

@@ -28,6 +28,7 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
@@ -55,6 +56,8 @@ var mode string = gin.DebugMode
type Server struct {
addr net.Addr
sched *Scheduler
perms *auth.APIPermissions
}
func init() {
@@ -69,6 +72,38 @@ func init() {
gin.SetMode(mode)
}
func loggedFormatter(param gin.LogFormatterParams) string {
var statusColor, methodColor, resetColor string
if param.IsOutputColor() {
statusColor = param.StatusCodeColor()
methodColor = param.MethodColor()
resetColor = param.ResetColor()
}
if param.Latency > time.Minute {
param.Latency = param.Latency.Truncate(time.Second)
}
username := "default"
if userVal, exists := param.Keys["username"]; exists {
if name, ok := userVal.(string); ok {
username = name
}
}
return fmt.Sprintf(
"[Ollama] %s |%s %3d %s| %13v | %15s | %-20s |%s %-7s %s %#v\n%s",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
statusColor, param.StatusCode, resetColor,
param.Latency,
param.ClientIP,
username,
methodColor, param.Method, resetColor,
param.Path,
param.ErrorMessage,
)
}
var (
errRequired = errors.New("is required")
errBadTemplate = errors.New("template error")
@@ -1111,6 +1146,43 @@ func allowedHost(host string) bool {
return false
}
func allowedEndpointsMiddleware(perms *auth.APIPermissions) gin.HandlerFunc {
return func(c *gin.Context) {
if !envconfig.UseAuth() || (c.Request.Method == "HEAD" && c.Request.URL.Path == "/") {
c.Next()
return
}
token := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")
if token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
pubKey, err := auth.Authenticate(token, fmt.Sprintf("%s,%s", c.Request.Method, c.Request.RequestURI))
if err != nil {
slog.Error("authentication error", "error", err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
authorized, name, err := perms.Authorize(pubKey, c.Request.URL.Path)
c.Set("username", name)
if err != nil {
slog.Error("authorization error", "error", err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
if !authorized {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
c.Next()
}
}
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
return func(c *gin.Context) {
if addr == nil {
@@ -1177,10 +1249,13 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
}
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
r := gin.Default()
r := gin.New()
r.HandleMethodNotAllowed = true
r.Use(
gin.LoggerWithFormatter(loggedFormatter),
gin.Recovery(),
cors.New(corsConfig),
allowedEndpointsMiddleware(s.perms),
allowedHostsMiddleware(s.addr),
)
@@ -1190,7 +1265,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
// Local model cache management (new implementation is at end of function)
// Local model cache management
r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.HEAD("/api/tags", s.ListHandler)
@@ -1222,7 +1297,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// wrap old with new
rs := &registry.Local{
Client: rc,
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Logger: slog.Default(),
Fallback: r,
Prune: PruneLayers,
@@ -1267,6 +1342,12 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()}
if envconfig.UseAuth() {
perms := auth.NewAPIPermissions()
perms.ReloadIfNeeded()
s.perms = perms
}
var rc *ollama.Registry
if useClient2 {
var err error