mirror of
https://github.com/ProtonMail/go-proton-api.git
synced 2025-12-23 23:57:50 -05:00
275 lines
7.5 KiB
Go
275 lines
7.5 KiB
Go
package server
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/Masterminds/semver/v3"
|
|
"github.com/ProtonMail/go-proton-api"
|
|
"github.com/ProtonMail/go-proton-api/server/backend"
|
|
"github.com/bradenaw/juniper/xslices"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type AuthCacher interface {
|
|
GetAuthInfo(username string) (proton.AuthInfo, bool)
|
|
SetAuthInfo(username string, info proton.AuthInfo)
|
|
GetAuth(username string) (proton.Auth, bool)
|
|
SetAuth(username string, auth proton.Auth)
|
|
}
|
|
|
|
// StatusHook is a function that can be used to modify the response code of a call.
|
|
type StatusHook func(*http.Request) (int, bool)
|
|
|
|
type Server struct {
|
|
// r is the gin router.
|
|
r *gin.Engine
|
|
|
|
// s is the underlying server.
|
|
s *httptest.Server
|
|
|
|
// b is the server backend, which manages accounts, messages, attachments, etc.
|
|
b *backend.Backend
|
|
|
|
// callWatchers records callWatchers received by the server.
|
|
callWatchers []callWatcher
|
|
callWatchersLock sync.RWMutex
|
|
|
|
// statusHooks are hooks that can be used to modify the response code of a call.
|
|
statusHooks []StatusHook
|
|
statusHooksLock sync.RWMutex
|
|
|
|
// domain is the test server domain.
|
|
domain string
|
|
|
|
// minAppVersion is the minimum app version that the server will accept.
|
|
minAppVersion *semver.Version
|
|
|
|
// proxyOrigin is the URL of the origin server when the server is a proxy.
|
|
proxyOrigin string
|
|
|
|
// proxyTransport is the transport to use when the server is a proxy.
|
|
proxyTransport *http.Transport
|
|
|
|
// authCacher can optionally be set to cache proxied auth calls.
|
|
authCacher AuthCacher
|
|
|
|
// offline is whether to pretend the server is offline and return 5xx errors.
|
|
offline bool
|
|
|
|
// rateLimit is the rate limiter for the server.
|
|
rateLimit *rateLimiter
|
|
}
|
|
|
|
func New(opts ...Option) *Server {
|
|
builder := newServerBuilder()
|
|
|
|
for _, opt := range opts {
|
|
opt.config(builder)
|
|
}
|
|
|
|
return builder.build()
|
|
}
|
|
|
|
// GetHostURL returns the API root to make calls to.
|
|
func (s *Server) GetHostURL() string {
|
|
return s.s.URL
|
|
}
|
|
|
|
// GetProxyURL returns the API root to make calls to which should be proxied.
|
|
func (s *Server) GetProxyURL() string {
|
|
return s.s.URL + "/proxy"
|
|
}
|
|
|
|
// GetDomain returns the domain of the server (e.g. "proton.local").
|
|
func (s *Server) GetDomain() string {
|
|
return s.domain
|
|
}
|
|
|
|
// AddCallWatcher adds a call watcher to the server.
|
|
func (s *Server) AddCallWatcher(fn func(Call), paths ...string) {
|
|
s.callWatchersLock.Lock()
|
|
defer s.callWatchersLock.Unlock()
|
|
|
|
s.callWatchers = append(s.callWatchers, newCallWatcher(fn, paths...))
|
|
}
|
|
|
|
// AddStatusHook adds a status hook to the server.
|
|
func (s *Server) AddStatusHook(fn StatusHook) {
|
|
s.statusHooksLock.Lock()
|
|
defer s.statusHooksLock.Unlock()
|
|
|
|
s.statusHooks = append(s.statusHooks, fn)
|
|
}
|
|
|
|
// CreateUser creates a new server user with the given username and password.
|
|
// A single address will be created for the user, derived from the username and the server's domain.
|
|
func (s *Server) CreateUser(username string, password []byte) (string, string, error) {
|
|
userID, err := s.b.CreateUser(username, password)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
addrID, err := s.b.CreateAddress(userID, username+"@"+s.domain, password, true, proton.AddressStatusEnabled, proton.AddressTypeOriginal, true)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
return userID, addrID, nil
|
|
}
|
|
|
|
func (s *Server) RemoveUser(userID string) error {
|
|
return s.b.RemoveUser(userID)
|
|
}
|
|
|
|
func (s *Server) RefreshUser(userID string, refresh proton.RefreshFlag) error {
|
|
return s.b.RefreshUser(userID, refresh)
|
|
}
|
|
|
|
func (s *Server) GetUserKeyIDs(userID string) ([]string, error) {
|
|
user, err := s.b.GetUser(userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return xslices.Map(user.Keys, func(key proton.Key) string {
|
|
return key.ID
|
|
}), nil
|
|
}
|
|
|
|
func (s *Server) CreateUserKey(userID string, password []byte) error {
|
|
return s.b.CreateUserKey(userID, password)
|
|
}
|
|
|
|
func (s *Server) RemoveUserKey(userID, keyID string) error {
|
|
return s.b.RemoveUserKey(userID, keyID)
|
|
}
|
|
|
|
func (s *Server) CreateAddress(userID, email string, password []byte, withSend bool) (string, error) {
|
|
return s.b.CreateAddress(userID, email, password, true, proton.AddressStatusEnabled, proton.AddressTypeOriginal, withSend)
|
|
}
|
|
|
|
func (s *Server) CreateExternalAddress(userID, email string, password []byte, withSend bool) (string, error) {
|
|
return s.b.CreateAddress(userID, email, password, true, proton.AddressStatusEnabled, proton.AddressTypeExternal, withSend)
|
|
}
|
|
|
|
func (s *Server) CreateAddressAsUpdate(userID, email string, password []byte) (string, error) {
|
|
return s.b.CreateAddressAsUpdate(userID, email, password, true, proton.AddressStatusEnabled, proton.AddressTypeOriginal)
|
|
}
|
|
|
|
func (s *Server) ChangeAddressAllowSend(userID, addrID string, allowSend bool) error {
|
|
return s.b.ChangeAddressAllowSend(userID, addrID, allowSend)
|
|
}
|
|
|
|
func (s *Server) SetAddressOrder(userID string, addrIDs []string) error {
|
|
return s.b.SetAddressOrder(userID, addrIDs)
|
|
}
|
|
|
|
func (s *Server) ChangeAddressType(userID, addrId string, addrType proton.AddressType) error {
|
|
return s.b.ChangeAddressType(userID, addrId, addrType)
|
|
}
|
|
|
|
func (s *Server) ChangeAddressDisplayName(userID, addrID, displayName string) error {
|
|
return s.b.ChangeAddressDisplayName(userID, addrID, displayName)
|
|
}
|
|
|
|
func (s *Server) RemoveAddress(userID, addrID string) error {
|
|
return s.b.RemoveAddress(userID, addrID)
|
|
}
|
|
|
|
func (s *Server) CreateAddressKey(userID, addrID string, password []byte) error {
|
|
return s.b.CreateAddressKey(userID, addrID, password)
|
|
}
|
|
|
|
func (s *Server) RemoveAddressKey(userID, addrID, keyID string) error {
|
|
return s.b.RemoveAddressKey(userID, addrID, keyID)
|
|
}
|
|
|
|
func (s *Server) CreateLabel(userID, name, parentID string, labelType proton.LabelType) (string, error) {
|
|
label, err := s.b.CreateLabel(userID, name, parentID, labelType)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return label.ID, nil
|
|
}
|
|
|
|
func (s *Server) GetLabels(userID string) ([]proton.Label, error) {
|
|
return s.b.GetLabels(userID)
|
|
}
|
|
|
|
func (s *Server) LabelMessage(userID, msgID, labelID string) error {
|
|
return s.b.LabelMessages(userID, labelID, msgID)
|
|
}
|
|
|
|
func (s *Server) UnlabelMessage(userID, msgID, labelID string) error {
|
|
return s.b.UnlabelMessages(userID, labelID, msgID)
|
|
}
|
|
|
|
func (s *Server) AddAddressCreatedEvent(userID, addrID string) error {
|
|
return s.b.AddAddressCreatedUpdate(userID, addrID)
|
|
}
|
|
|
|
func (s *Server) AddAddressUpdatedEvent(userID, addrID string) error {
|
|
return s.b.AddAddressUpdatedEvent(userID, addrID)
|
|
}
|
|
|
|
func (s *Server) AddLabelCreatedEvent(userID, labelID string) error {
|
|
return s.b.AddLabelCreatedUpdate(userID, labelID)
|
|
}
|
|
|
|
func (s *Server) AddMessageCreatedEvent(userID, messageID string) error {
|
|
return s.b.AddMessageCreatedUpdate(userID, messageID)
|
|
}
|
|
|
|
// SetMaxUpdatesPerEvent
|
|
func (s *Server) SetMaxUpdatesPerEvent(max int) {
|
|
s.b.SetMaxUpdatesPerEvent(max)
|
|
}
|
|
|
|
func (s *Server) SetAuthLife(authLife time.Duration) {
|
|
s.b.SetAuthLife(authLife)
|
|
}
|
|
|
|
func (s *Server) SetMinAppVersion(minAppVersion *semver.Version) {
|
|
s.minAppVersion = minAppVersion
|
|
}
|
|
|
|
func (s *Server) SetOffline(offline bool) {
|
|
s.offline = offline
|
|
}
|
|
|
|
func (s *Server) RevokeUser(userID string) error {
|
|
sessions, err := s.b.GetSessions(userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, session := range sessions {
|
|
if err := s.b.DeleteSession(userID, session.UID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) Close() {
|
|
s.proxyTransport.CloseIdleConnections()
|
|
s.s.Close()
|
|
}
|
|
|
|
func (s *Server) PushFeatureFlag(flagName string) {
|
|
s.b.PushFeatureFlag(flagName)
|
|
}
|
|
|
|
func (s *Server) DeleteFeatureFlags() {
|
|
s.b.DeleteFeatureFlags()
|
|
}
|
|
|
|
func (s *Server) GetObservabilityStatistics() backend.ObservabilityStatistics {
|
|
return s.b.GetObservabilityStatistics()
|
|
}
|