This commit is contained in:
binwiederhier
2025-07-21 17:44:00 +02:00
parent 51af114b2e
commit f59df0f40a
13 changed files with 333 additions and 140 deletions

View File

@@ -232,7 +232,7 @@ cli-deps-update:
go get -u
go install honnef.co/go/tools/cmd/staticcheck@latest
go install golang.org/x/lint/golint@latest
go install github.com/goreleaser/goreleaser@latest
go install github.com/goreleaser/goreleaser/v2@latest
cli-build-results:
cat dist/config.yaml

View File

@@ -105,8 +105,10 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic
return err
}
u, err := manager.User(username)
if err == user.ErrUserNotFound {
if errors.Is(err, user.ErrUserNotFound) {
return fmt.Errorf("user %s does not exist", username)
} else if err != nil {
return err
} else if u.Role == user.RoleAdmin {
return fmt.Errorf("user %s is an admin user, access control entries have no effect", username)
}
@@ -175,7 +177,7 @@ func showAllAccess(c *cli.Context, manager *user.Manager) error {
func showUserAccess(c *cli.Context, manager *user.Manager, username string) error {
users, err := manager.User(username)
if err == user.ErrUserNotFound {
if errors.Is(err, user.ErrUserNotFound) {
return fmt.Errorf("user %s does not exist", username)
} else if err != nil {
return err
@@ -193,19 +195,27 @@ func showUsers(c *cli.Context, manager *user.Manager, users []*user.User) error
if u.Tier != nil {
tier = u.Tier.Name
}
fmt.Fprintf(c.App.ErrWriter, "user %s (role: %s, tier: %s)\n", u.Name, u.Role, tier)
provisioned := ""
if u.Provisioned {
provisioned = ", provisioned user"
}
fmt.Fprintf(c.App.ErrWriter, "user %s (role: %s, tier: %s%s)\n", u.Name, u.Role, tier, provisioned)
if u.Role == user.RoleAdmin {
fmt.Fprintf(c.App.ErrWriter, "- read-write access to all topics (admin role)\n")
} else if len(grants) > 0 {
for _, grant := range grants {
if grant.Allow.IsReadWrite() {
fmt.Fprintf(c.App.ErrWriter, "- read-write access to topic %s\n", grant.TopicPattern)
} else if grant.Allow.IsRead() {
fmt.Fprintf(c.App.ErrWriter, "- read-only access to topic %s\n", grant.TopicPattern)
} else if grant.Allow.IsWrite() {
fmt.Fprintf(c.App.ErrWriter, "- write-only access to topic %s\n", grant.TopicPattern)
grantProvisioned := ""
if grant.Provisioned {
grantProvisioned = ", provisioned access entry"
}
if grant.Permission.IsReadWrite() {
fmt.Fprintf(c.App.ErrWriter, "- read-write access to topic %s%s\n", grant.TopicPattern, grantProvisioned)
} else if grant.Permission.IsRead() {
fmt.Fprintf(c.App.ErrWriter, "- read-only access to topic %s%s\n", grant.TopicPattern, grantProvisioned)
} else if grant.Permission.IsWrite() {
fmt.Fprintf(c.App.ErrWriter, "- write-only access to topic %s%s\n", grant.TopicPattern, grantProvisioned)
} else {
fmt.Fprintf(c.App.ErrWriter, "- no access to topic %s\n", grant.TopicPattern)
fmt.Fprintf(c.App.ErrWriter, "- no access to topic %s%s\n", grant.TopicPattern, grantProvisioned)
}
}
} else {

View File

@@ -48,7 +48,8 @@ var flagsServe = append(
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-startup-queries", Aliases: []string{"auth_startup_queries"}, EnvVars: []string{"NTFY_AUTH_STARTUP_QUERIES"}, Usage: "queries run when the auth database is initialized"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-provisioned-users", Aliases: []string{"auth_provisioned_users"}, EnvVars: []string{"NTFY_AUTH_PROVISIONED_USERS"}, Usage: "pre-provisioned declarative users"}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-provision-users", Aliases: []string{"auth_provision_users"}, EnvVars: []string{"NTFY_AUTH_PROVISION_USERS"}, Usage: "pre-provisioned declarative users"}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-provision-access", Aliases: []string{"auth_provision_access"}, EnvVars: []string{"NTFY_AUTH_PROVISION_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}),
@@ -155,8 +156,8 @@ func execServe(c *cli.Context) error {
authFile := c.String("auth-file")
authStartupQueries := c.String("auth-startup-queries")
authDefaultAccess := c.String("auth-default-access")
authProvisionedUsersRaw := c.StringSlice("auth-provisioned-users")
//authProvisionedAccessRaw := c.StringSlice("auth-provisioned-access")
authProvisionUsersRaw := c.StringSlice("auth-provision-users")
authProvisionAccessRaw := c.StringSlice("auth-provision-access")
attachmentCacheDir := c.String("attachment-cache-dir")
attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit")
attachmentFileSizeLimitStr := c.String("attachment-file-size-limit")
@@ -352,27 +353,13 @@ func execServe(c *cli.Context) error {
if err != nil {
return errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
}
authProvisionedUsers := make([]*user.User, 0)
for _, userLine := range authProvisionedUsersRaw {
parts := strings.Split(userLine, ":")
if len(parts) != 3 {
return fmt.Errorf("invalid provisioned user %s, expected format: 'name:hash:role'", userLine)
}
username := strings.TrimSpace(parts[0])
passwordHash := strings.TrimSpace(parts[1])
role := user.Role(strings.TrimSpace(parts[2]))
if !user.AllowedUsername(username) {
return fmt.Errorf("invalid provisioned user %s, username invalid", userLine)
} else if passwordHash == "" {
return fmt.Errorf("invalid provisioned user %s, password hash cannot be empty", userLine)
} else if !user.AllowedRole(role) {
return fmt.Errorf("invalid provisioned user %s, role %s is not allowed, allowed roles are 'admin' or 'user'", userLine, role)
}
authProvisionedUsers = append(authProvisionedUsers, &user.User{
Name: username,
Hash: passwordHash,
Role: role,
})
authProvisionUsers, err := parseProvisionUsers(authProvisionUsersRaw)
if err != nil {
return err
}
authProvisionAccess, err := parseProvisionAccess(authProvisionUsers, authProvisionAccessRaw)
if err != nil {
return err
}
// Special case: Unset default
@@ -429,8 +416,8 @@ func execServe(c *cli.Context) error {
conf.AuthFile = authFile
conf.AuthStartupQueries = authStartupQueries
conf.AuthDefault = authDefault
conf.AuthProvisionedUsers = authProvisionedUsers
conf.AuthProvisionedAccess = nil // FIXME
conf.AuthProvisionedUsers = authProvisionUsers
conf.AuthProvisionedAccess = authProvisionAccess
conf.AttachmentCacheDir = attachmentCacheDir
conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit
conf.AttachmentFileSizeLimit = attachmentFileSizeLimit
@@ -544,6 +531,76 @@ func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
return
}
func parseProvisionUsers(usersRaw []string) ([]*user.User, error) {
provisionUsers := make([]*user.User, 0)
for _, userLine := range usersRaw {
parts := strings.Split(userLine, ":")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid auth-provision-users: %s, expected format: 'name:hash:role'", userLine)
}
username := strings.TrimSpace(parts[0])
passwordHash := strings.TrimSpace(parts[1])
role := user.Role(strings.TrimSpace(parts[2]))
if !user.AllowedUsername(username) {
return nil, fmt.Errorf("invalid auth-provision-users: %s, username invalid", userLine)
} else if passwordHash == "" {
return nil, fmt.Errorf("invalid auth-provision-users: %s, password hash cannot be empty", userLine)
} else if !user.AllowedRole(role) {
return nil, fmt.Errorf("invalid auth-provision-users: %s, role %s is not allowed, allowed roles are 'admin' or 'user'", userLine, role)
}
provisionUsers = append(provisionUsers, &user.User{
Name: username,
Hash: passwordHash,
Role: role,
Provisioned: true,
})
}
return provisionUsers, nil
}
func parseProvisionAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[string][]*user.Grant, error) {
access := make(map[string][]*user.Grant)
for _, accessLine := range provisionAccessRaw {
parts := strings.Split(accessLine, ":")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid auth-provision-access: %s, expected format: 'user:topic:permission'", accessLine)
}
username := strings.TrimSpace(parts[0])
if username == userEveryone {
username = user.Everyone
}
provisionUser, exists := util.Find(provisionUsers, func(u *user.User) bool {
return u.Name == username
})
if username != user.Everyone {
if !exists {
return nil, fmt.Errorf("invalid auth-provision-access: %s, user %s is not provisioned", accessLine, username)
} else if !user.AllowedUsername(username) {
return nil, fmt.Errorf("invalid auth-provision-access: %s, username %s invalid", accessLine, username)
} else if provisionUser.Role != user.RoleUser {
return nil, fmt.Errorf("invalid auth-provision-access: %s, user %s is not a regular user, only regular users can have ACL entries", accessLine, username)
}
}
topic := strings.TrimSpace(parts[1])
if !user.AllowedTopicPattern(topic) {
return nil, fmt.Errorf("invalid auth-provision-access: %s, topic pattern %s invalid", accessLine, topic)
}
permission, err := user.ParsePermission(strings.TrimSpace(parts[2]))
if err != nil {
return nil, fmt.Errorf("invalid auth-provision-access: %s, permission %s invalid, %s", accessLine, parts[2], err.Error())
}
if _, exists := access[username]; !exists {
access[username] = make([]*user.Grant, 0)
}
access[username] = append(access[username], &user.Grant{
TopicPattern: topic,
Permission: permission,
Provisioned: true,
})
}
return access, nil
}
func reloadLogLevel(inputSource altsrc.InputSourceContext) error {
newLevelStr, err := inputSource.String("log-level")
if err != nil {

View File

@@ -349,8 +349,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
Filename: authFile,
StartupQueries: authStartupQueries,
DefaultAccess: authDefault,
ProvisionedUsers: nil, //FIXME
ProvisionedAccess: nil, //FIXME
ProvisionEnabled: false, // Do not re-provision users on manager initialization
BcryptCost: user.DefaultUserPasswordBcryptCost,
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
}

31
go.sum
View File

@@ -1,11 +1,7 @@
cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY=
cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw=
cloud.google.com/go v0.121.3 h1:84RD+hQXNdY5Sw/MWVAx5O9Aui/rd5VQ9HEcdN19afo=
cloud.google.com/go v0.121.3/go.mod h1:6vWF3nJWRrEUv26mMB3FEIU/o1MQNVPG1iHdisa2SJc=
cloud.google.com/go v0.121.4 h1:cVvUiY0sX0xwyxPwdSU2KsF9knOVmtRyAMt8xou0iTs=
cloud.google.com/go v0.121.4/go.mod h1:XEBchUiHFJbz4lKBZwYBDHV/rSyfFktk737TLDU089s=
cloud.google.com/go/auth v0.16.2 h1:QvBAGFPLrDeoiNjyfVunhQ10HKNYuOwZ5noee0M5df4=
cloud.google.com/go/auth v0.16.2/go.mod h1:sRBas2Y1fB1vZTdurouM0AzuYQBMZinrUYL8EufhtEA=
cloud.google.com/go/auth v0.16.3 h1:kabzoQ9/bobUmnseYnBO6qQG7q4a/CffFRlJSxv2wCc=
cloud.google.com/go/auth v0.16.3/go.mod h1:NucRGjaXfzP1ltpcQ7On/VTZ0H4kWB5Jy+Y9Dnm76fA=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
@@ -26,8 +22,6 @@ cloud.google.com/go/storage v1.55.0 h1:NESjdAToN9u1tmhVqhXCaCwYBuvEhZLLv0gBr+2zn
cloud.google.com/go/storage v1.55.0/go.mod h1:ztSmTTwzsdXe5syLVS0YsbFxXuvEmEyZj7v7zChEmuY=
cloud.google.com/go/trace v1.11.6 h1:2O2zjPzqPYAHrn3OKl029qlqG6W8ZdYaOWRyr8NgMT4=
cloud.google.com/go/trace v1.11.6/go.mod h1:GA855OeDEBiBMzcckLPE2kDunIpC72N+Pq8WFieFjnI=
firebase.google.com/go/v4 v4.16.1 h1:Kl5cgXmM0VOWDGT1UAx6b0T2UFWa14ak0CvYqeI7Py4=
firebase.google.com/go/v4 v4.16.1/go.mod h1:aAPJq/bOyb23tBlc1K6GR+2E8sOGAeJSc8wIJVgl9SM=
firebase.google.com/go/v4 v4.17.0 h1:Bih69QV/k0YKPA1qUX04ln0aPT9IERrAo2ezibcngzE=
firebase.google.com/go/v4 v4.17.0/go.mod h1:aAPJq/bOyb23tBlc1K6GR+2E8sOGAeJSc8wIJVgl9SM=
github.com/AlekSi/pointer v1.2.0 h1:glcy/gc4h8HnG2Z3ZECSzZ1IX1x2JxRVuDzaJwQE0+w=
@@ -87,8 +81,6 @@ github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.3 h1:kkGXqQOBSDDWRhWNXTFpqGSCMyh/PLnqUvMGJPDJDs0=
github.com/golang-jwt/jwt/v5 v5.2.3/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
@@ -106,8 +98,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4=
github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
github.com/googleapis/gax-go/v2 v2.14.2 h1:eBLnkZ9635krYIPD+ag1USrOAI0Nr0QYF3+/3GqO0k0=
github.com/googleapis/gax-go/v2 v2.14.2/go.mod h1:ON64QhlJkhVtSqp4v1uaK92VyZ2gmvDQsweuyLV+8+w=
github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo=
github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
@@ -194,8 +184,6 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
@@ -212,8 +200,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
@@ -225,8 +211,6 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -241,8 +225,6 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
@@ -254,8 +236,6 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -269,8 +249,6 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
@@ -283,23 +261,14 @@ golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.240.0 h1:PxG3AA2UIqT1ofIzWV2COM3j3JagKTKSwy7L6RHNXNU=
google.golang.org/api v0.240.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50=
google.golang.org/api v0.242.0 h1:7Lnb1nfnpvbkCiZek6IXKdJ0MFuAZNAJKQfA1ws62xg=
google.golang.org/api v0.242.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50=
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw=
google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s=
google.golang.org/genproto v0.0.0-20250715232539-7130f93afb79 h1:Nt6z9UHqSlIdIGJdz6KhTIs2VRx/iOsA5iE8bmQNcxs=
google.golang.org/genproto v0.0.0-20250715232539-7130f93afb79/go.mod h1:kTmlBHMPqR5uCZPBvwa2B18mvubkjyY3CRLI0c6fj0s=
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY=
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc=
google.golang.org/genproto/googleapis/api v0.0.0-20250715232539-7130f93afb79 h1:iOye66xuaAK0WnkPuhQPUFy8eJcmwUXqGGP3om6IxX8=
google.golang.org/genproto/googleapis/api v0.0.0-20250715232539-7130f93afb79/go.mod h1:HKJDgKsFUnv5VAGeQjz8kxcgDP0HoE0iZNp0OdZNlhE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250715232539-7130f93afb79 h1:1ZwqphdOdWYXsUHgMpU/101nCtf/kSp9hOrcvFsnl10=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250715232539-7130f93afb79/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok=

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net/netip"
"path/filepath"
"strings"
"time"
@@ -286,6 +287,12 @@ type messageCache struct {
// newSqliteCache creates a SQLite file-backed cache
func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*messageCache, error) {
// Check the parent directory of the database file (makes for friendly error messages)
parentDir := filepath.Dir(filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
}
// Open database
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err

View File

@@ -200,8 +200,9 @@ func New(conf *Config) (*Server, error) {
Filename: conf.AuthFile,
StartupQueries: conf.AuthStartupQueries,
DefaultAccess: conf.AuthDefault,
ProvisionedUsers: conf.AuthProvisionedUsers,
ProvisionedAccess: conf.AuthProvisionedAccess,
ProvisionEnabled: true, // Enable provisioning of users and access
ProvisionUsers: conf.AuthProvisionedUsers,
ProvisionAccess: conf.AuthProvisionedAccess,
BcryptCost: conf.AuthBcryptCost,
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
}

View File

@@ -82,6 +82,10 @@
# set to "read-write" (default), "read-only", "write-only" or "deny-all".
# - auth-startup-queries allows you to run commands when the database is initialized, e.g. to enable
# WAL mode. This is similar to cache-startup-queries. See above for details.
# - auth-provision-users is a list of users that are automatically created when the server starts.
# Each entry is in the format "<username>:<bcrypt-hash>:<role>", e.g. "phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:user"
# - auth-provision-access is a list of access control entries that are automatically created when the server starts.
# Each entry is in the format "<username>:<topic-pattern>:<access>", e.g. "phil:mytopic:rw" or "phil:phil-*:rw".
#
# Debian/RPM package users:
# Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package
@@ -94,6 +98,8 @@
# auth-file: <filename>
# auth-default-access: "read-write"
# auth-startup-queries:
# auth-provision-users:
# auth-provision-access:
# If set, the X-Forwarded-For header (or whatever is configured in proxy-forwarded-header) is used to determine
# the visitor IP address instead of the remote address of the connection.

View File

@@ -25,7 +25,7 @@ func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visit
for i, g := range grants[u.ID] {
userGrants[i] = &apiUserGrantResponse{
Topic: g.TopicPattern,
Permission: g.Allow.String(),
Permission: g.Permission.String(),
}
}
usersResponse[i] = &apiUserResponse{

View File

@@ -12,6 +12,7 @@ import (
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
"net/netip"
"path/filepath"
"strings"
"sync"
"time"
@@ -75,6 +76,7 @@ const (
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
@@ -97,6 +99,7 @@ const (
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
@@ -121,8 +124,8 @@ const (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
COMMIT;
`
@@ -132,26 +135,26 @@ const (
`
selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
`
selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
@@ -165,8 +168,8 @@ const (
`
insertUserQuery = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
selectUsernamesQuery = `
SELECT user
@@ -189,18 +192,18 @@ const (
deleteUserQuery = `DELETE FROM user WHERE user = ?`
upsertUserAccessQuery = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
`
selectUserAllAccessQuery = `
SELECT user_id, topic, read, write
SELECT user_id, topic, read, write, provisioned
FROM user_access
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
`
selectUserAccessQuery = `
SELECT topic, read, write
SELECT topic, read, write, provisioned
FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
@@ -244,7 +247,8 @@ const (
WHERE user_id = (SELECT id FROM user WHERE user = ?)
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
`
deleteTopicAccessQuery = `
deleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = 1`
deleteTopicAccessQuery = `
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
AND topic = ?
@@ -427,6 +431,15 @@ const (
migrate4To5UpdateQueries = `
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
`
// 5 -> 6
migrate5To6UpdateQueries = `
ALTER TABLE user ADD COLUMN provisioned INT NOT NULL DEFAULT (0);
ALTER TABLE user ALTER COLUMN provisioned DROP DEFAULT;
ALTER TABLE user_access ADD COLUMN provisioned INT NOT NULL DEFAULT (0);
ALTER TABLE user_access ALTER COLUMN provisioned DROP DEFAULT;
`
)
var (
@@ -435,6 +448,7 @@ var (
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
}
)
@@ -452,8 +466,9 @@ type Config struct {
Filename string // Database filename, e.g. "/var/lib/ntfy/user.db"
StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers
DefaultAccess Permission // Default permission if no ACL matches
ProvisionedUsers []*User // Predefined users to create on startup
ProvisionedAccess map[string][]*Grant // Predefined access grants to create on startup
ProvisionEnabled bool // Enable auto-provisioning of users and access grants
ProvisionUsers []*User // Predefined users to create on startup
ProvisionAccess map[string][]*Grant // Predefined access grants to create on startup
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
BcryptCost int // Cost of generated passwords; lowering makes testing faster
}
@@ -469,6 +484,11 @@ func NewManager(config *Config) (*Manager, error) {
if config.QueueWriterInterval.Seconds() <= 0 {
config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval
}
// Check the parent directory of the database file (makes for friendly error messages)
parentDir := filepath.Dir(config.Filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
}
// Open DB and run setup queries
db, err := sql.Open("sqlite3", config.Filename)
if err != nil {
@@ -486,7 +506,7 @@ func NewManager(config *Config) (*Manager, error) {
statsQueue: make(map[string]*Stats),
tokenQueue: make(map[string]*TokenUpdate),
}
if err := manager.provisionUsers(); err != nil {
if err := manager.maybeProvisionUsersAndAccess(); err != nil {
return nil, err
}
go manager.asyncQueueWriter(config.QueueWriterInterval)
@@ -586,7 +606,7 @@ func (a *Manager) Tokens(userID string) ([]*Token, error) {
tokens := make([]*Token, 0)
for {
token, err := a.readToken(rows)
if err == ErrTokenNotFound {
if errors.Is(err, ErrTokenNotFound) {
break
} else if err != nil {
return nil, err
@@ -884,6 +904,13 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
// AddUser adds a user with the given username, password and role
func (a *Manager) AddUser(username, password string, role Role, hashed bool) error {
return execTx(a.db, func(tx *sql.Tx) error {
return a.addUserTx(tx, username, password, role, hashed, false)
})
}
// AddUser adds a user with the given username, password and role
func (a *Manager) addUserTx(tx *sql.Tx, username, password string, role Role, hashed, provisioned bool) error {
if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument
}
@@ -899,8 +926,8 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err
}
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique {
if _, err = tx.Exec(insertUserQuery, userID, username, hash, role, syncTopic, provisioned, now); err != nil {
if errors.Is(err, sqlite3.ErrConstraintUnique) {
return ErrUserExists
}
return err
@@ -911,11 +938,17 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err
// RemoveUser deletes the user with the given username. The function returns nil on success, even
// if the user did not exist in the first place.
func (a *Manager) RemoveUser(username string) error {
return execTx(a.db, func(tx *sql.Tx) error {
return a.removeUserTx(tx, username)
})
}
func (a *Manager) removeUserTx(tx *sql.Tx, username string) error {
if !AllowedUsername(username) {
return ErrInvalidArgument
}
// Rows in user_access, user_token, etc. are deleted via foreign keys
if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
if _, err := tx.Exec(deleteUserQuery, username); err != nil {
return err
}
return nil
@@ -1029,24 +1062,26 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var id, username, hash, role, prefs, syncTopic string
var provisioned bool
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
var messages, emails, calls int64
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
user := &User{
ID: id,
Name: username,
Hash: hash,
Role: Role(role),
Prefs: &Prefs{},
SyncTopic: syncTopic,
ID: id,
Name: username,
Hash: hash,
Role: Role(role),
Prefs: &Prefs{},
SyncTopic: syncTopic,
Provisioned: provisioned,
Stats: &Stats{
Messages: messages,
Emails: emails,
@@ -1097,8 +1132,8 @@ func (a *Manager) AllGrants() (map[string][]Grant, error) {
grants := make(map[string][]Grant, 0)
for rows.Next() {
var userID, topic string
var read, write bool
if err := rows.Scan(&userID, &topic, &read, &write); err != nil {
var read, write, provisioned bool
if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@@ -1108,7 +1143,8 @@ func (a *Manager) AllGrants() (map[string][]Grant, error) {
}
grants[userID] = append(grants[userID], Grant{
TopicPattern: fromSQLWildcard(topic),
Allow: NewPermission(read, write),
Permission: NewPermission(read, write),
Provisioned: provisioned,
})
}
return grants, nil
@@ -1124,15 +1160,16 @@ func (a *Manager) Grants(username string) ([]Grant, error) {
grants := make([]Grant, 0)
for rows.Next() {
var topic string
var read, write bool
if err := rows.Scan(&topic, &read, &write); err != nil {
var read, write, provisioned bool
if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
grants = append(grants, Grant{
TopicPattern: fromSQLWildcard(topic),
Allow: NewPermission(read, write),
Permission: NewPermission(read, write),
Provisioned: provisioned,
})
}
return grants, nil
@@ -1218,9 +1255,14 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
// ChangePassword changes a user's password
func (a *Manager) ChangePassword(username, password string, hashed bool) error {
return execTx(a.db, func(tx *sql.Tx) error {
return a.changePasswordTx(tx, username, password, hashed)
})
}
func (a *Manager) changePasswordTx(tx *sql.Tx, username, password string, hashed bool) error {
var hash []byte
var err error
if hashed {
hash = []byte(password)
} else {
@@ -1229,7 +1271,7 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
return err
}
}
if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
if _, err := tx.Exec(updateUserPassQuery, hash, username); err != nil {
return err
}
return nil
@@ -1238,14 +1280,20 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
// ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
// all existing access control entries (Grant) are removed, since they are no longer needed.
func (a *Manager) ChangeRole(username string, role Role) error {
return execTx(a.db, func(tx *sql.Tx) error {
return a.changeRoleTx(tx, username, role)
})
}
func (a *Manager) changeRoleTx(tx *sql.Tx, username string, role Role) error {
if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument
}
if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
if _, err := tx.Exec(updateUserRoleQuery, string(role), username); err != nil {
return err
}
if role == RoleAdmin {
if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
if _, err := tx.Exec(deleteUserAccessQuery, username, username); err != nil {
return err
}
}
@@ -1325,13 +1373,19 @@ func (a *Manager) AllowReservation(username string, topic string) error {
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
// owner may either be a user (username), or the system (empty).
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
return execTx(a.db, func(tx *sql.Tx) error {
return a.allowAccessTx(tx, username, topicPattern, permission, false)
})
}
func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string, permission Permission, provisioned bool) error {
if !AllowedUsername(username) && username != Everyone {
return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) {
return ErrInvalidArgument
}
owner := ""
if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
if _, err := tx.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner, provisioned); err != nil {
return err
}
return nil
@@ -1524,20 +1578,65 @@ func (a *Manager) Close() error {
return a.db.Close()
}
func (a *Manager) provisionUsers() error {
for _, user := range a.config.ProvisionedUsers {
if err := a.AddUser(user.Name, user.Hash, user.Role, true); err != nil && !errors.Is(err, ErrUserExists) {
return err
}
func (a *Manager) maybeProvisionUsersAndAccess() error {
if !a.config.ProvisionEnabled {
return nil
}
for username, grants := range a.config.ProvisionedAccess {
for _, grant := range grants {
if err := a.AllowAccess(username, grant.TopicPattern, grant.Allow); err != nil {
return err
users, err := a.Users()
if err != nil {
return err
}
provisionUsernames := util.Map(a.config.ProvisionUsers, func(u *User) string {
return u.Name
})
return execTx(a.db, func(tx *sql.Tx) error {
// Remove users that are provisioned, but not in the config anymore
for _, user := range users {
if user.Name == Everyone {
continue
} else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) {
log.Tag(tag).Info("Removing previously provisioned user %s", user.Name)
if err := a.removeUserTx(tx, user.Name); err != nil {
return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err)
}
}
}
}
return nil
// Add or update provisioned users
for _, user := range a.config.ProvisionUsers {
if user.Name == Everyone {
continue
}
existingUser, exists := util.Find(users, func(u *User) bool {
return u.Name == user.Name
})
if !exists {
log.Tag(tag).Info("Adding provisioned user %s", user.Name)
if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) {
return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err)
}
} else if existingUser.Hash != user.Hash || existingUser.Role != user.Role {
log.Tag(tag).Info("Updating provisioned user %s", user.Name)
if err := a.changePasswordTx(tx, user.Name, user.Hash, true); err != nil {
return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err)
}
if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil {
return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err)
}
}
}
// Remove and (re-)add provisioned grants
if _, err := tx.Exec(deleteUserAccessProvisionedQuery); err != nil {
return err
}
for username, grants := range a.config.ProvisionAccess {
for _, grant := range grants {
if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil {
return err
}
}
}
return nil
})
}
// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
@@ -1711,6 +1810,22 @@ func migrateFrom4(db *sql.DB) error {
return tx.Commit()
}
func migrateFrom5(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 6); err != nil {
return err
}
return tx.Commit()
}
func nullString(s string) sql.NullString {
if s == "" {
return sql.NullString{}
@@ -1724,3 +1839,18 @@ func nullInt64(v int64) sql.NullInt64 {
}
return sql.NullInt64{Int64: v, Valid: true}
}
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
if err := f(tx); err != nil {
if e := tx.Rollback(); e != nil {
return err
}
return err
}
return tx.Commit()
}

View File

@@ -489,12 +489,12 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
benGrants, err := a.Grants("ben")
require.Nil(t, err)
require.Equal(t, 1, len(benGrants))
require.Equal(t, PermissionReadWrite, benGrants[0].Allow)
require.Equal(t, PermissionReadWrite, benGrants[0].Permission)
everyoneGrants, err := a.Grants(Everyone)
require.Nil(t, err)
require.Equal(t, 1, len(everyoneGrants))
require.Equal(t, PermissionDenyAll, everyoneGrants[0].Allow)
require.Equal(t, PermissionDenyAll, everyoneGrants[0].Permission)
benReservations, err := a.Reservations("ben")
require.Nil(t, err)
@@ -1201,16 +1201,16 @@ func TestMigrationFrom1(t *testing.T) {
require.NotEqual(t, ben.SyncTopic, phil.SyncTopic)
require.Equal(t, 2, len(benGrants))
require.Equal(t, "secret", benGrants[0].TopicPattern)
require.Equal(t, PermissionRead, benGrants[0].Allow)
require.Equal(t, PermissionRead, benGrants[0].Permission)
require.Equal(t, "stats", benGrants[1].TopicPattern)
require.Equal(t, PermissionReadWrite, benGrants[1].Allow)
require.Equal(t, PermissionReadWrite, benGrants[1].Permission)
require.Equal(t, "u_everyone", everyone.ID)
require.Equal(t, Everyone, everyone.Name)
require.Equal(t, RoleAnonymous, everyone.Role)
require.Equal(t, 1, len(everyoneGrants))
require.Equal(t, "stats", everyoneGrants[0].TopicPattern)
require.Equal(t, PermissionRead, everyoneGrants[0].Allow)
require.Equal(t, PermissionRead, everyoneGrants[0].Permission)
}
func TestMigrationFrom4(t *testing.T) {

View File

@@ -12,17 +12,18 @@ import (
// User is a struct that represents a user
type User struct {
ID string
Name string
Hash string // password hash (bcrypt)
Token string // Only set if token was used to log in
Role Role
Prefs *Prefs
Tier *Tier
Stats *Stats
Billing *Billing
SyncTopic string
Deleted bool
ID string
Name string
Hash string // Password hash (bcrypt)
Token string // Only set if token was used to log in
Role Role
Prefs *Prefs
Tier *Tier
Stats *Stats
Billing *Billing
SyncTopic string
Provisioned bool // Whether the user was provisioned by the config file
Deleted bool // Whether the user was soft-deleted
}
// TierID returns the ID of the User.Tier, or an empty string if the user has no tier,
@@ -148,7 +149,8 @@ type Billing struct {
// Grant is a struct that represents an access control entry to a topic by a user
type Grant struct {
TopicPattern string // May include wildcard (*)
Allow Permission
Permission Permission
Provisioned bool // Whether the grant was provisioned by the config file
}
// Reservation is a struct that represents the ownership over a topic by a user

View File

@@ -120,6 +120,18 @@ func Filter[T any](slice []T, f func(T) bool) []T {
return result
}
// Find returns the first element in the slice that satisfies the given function, and a boolean indicating
// whether such an element was found. If no element is found, it returns the zero value of T and false.
func Find[T any](slice []T, f func(T) bool) (T, bool) {
for _, v := range slice {
if f(v) {
return v, true
}
}
var zero T
return zero, false
}
// RandomString returns a random string with a given length
func RandomString(length int) string {
return RandomStringPrefix("", length)