diff --git a/internal.go b/internal.go index 533de4a..179b28a 100644 --- a/internal.go +++ b/internal.go @@ -30,5 +30,5 @@ func (m *Manager) QuarkRes(ctx context.Context, command string, args ...string) return nil, err } - return []byte(goquery.NewDocumentFromNode(doc).Find(".content").Text()), nil + return []byte(strings.TrimSpace(goquery.NewDocumentFromNode(doc).Find(".content").Text())), nil } diff --git a/server/backend/backend.go b/server/backend/backend.go index 7d8da26..753bcf7 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -338,6 +338,11 @@ func (b *Backend) RemoveAddressKey(userID, addrID, keyID string) error { }) } +// TODO: Implement this when we support subscriptions in the test server. +func (b *Backend) CreateSubscription(userID, planID string) error { + return nil +} + func (b *Backend) CreateMessage( userID, addrID string, subject string, diff --git a/server/backend/quark.go b/server/backend/quark.go index d256838..46505c4 100644 --- a/server/backend/quark.go +++ b/server/backend/quark.go @@ -18,6 +18,9 @@ func (s *Backend) RunQuarkCommand(command string, args ...string) (any, error) { case "user:create:address": return s.quarkUserCreateAddress(args...) + case "user:create:subscription": + return s.quarkUserCreateSubscription(args...) + default: return nil, fmt.Errorf("unknown command: %s", command) } @@ -26,7 +29,7 @@ func (s *Backend) RunQuarkCommand(command string, args ...string) (any, error) { func (s *Backend) quarkEncryptionID(args ...string) (string, error) { fs := flag.NewFlagSet("encryption:id", flag.ContinueOnError) - // Required arguments. + // Positional arguments. // arg0: value decrypt := fs.Bool("decrypt", false, "decrypt the given encrypted ID") @@ -46,11 +49,9 @@ func (s *Backend) quarkEncryptionID(args ...string) (string, error) { func (s *Backend) quarkUserCreate(args ...string) (proton.User, error) { fs := flag.NewFlagSet("user:create", flag.ContinueOnError) - // Required arguments. + // Flag arguments. name := fs.String("name", "", "new user's name") pass := fs.String("password", "", "new user's password") - - // Optional arguments. newAddr := fs.Bool("create-address", false, "create the user's default address, will not automatically setup the address key") genKeys := fs.String("gen-keys", "", "generate new address keys for the user") @@ -76,12 +77,12 @@ func (s *Backend) quarkUserCreate(args ...string) (proton.User, error) { func (s *Backend) quarkUserCreateAddress(args ...string) (proton.Address, error) { fs := flag.NewFlagSet("user:create:address", flag.ContinueOnError) - // Required arguments. + // Positional arguments. // arg0: userID // arg1: password // arg2: email - // Optional arguments. + // Flag arguments. genKeys := fs.String("gen-keys", "", "generate new address keys for the user") if err := fs.Parse(args); err != nil { @@ -96,3 +97,23 @@ func (s *Backend) quarkUserCreateAddress(args ...string) (proton.Address, error) return s.GetAddress(fs.Arg(0), addrID) } + +func (s *Backend) quarkUserCreateSubscription(args ...string) (any, error) { + fs := flag.NewFlagSet("user:create:subscription", flag.ContinueOnError) + + // Positional arguments. + // arg0: userID + + // Flag arguments. + planID := fs.String("planID", "", "plan ID for the user") + + if err := fs.Parse(args); err != nil { + return nil, err + } + + if err := s.CreateSubscription(fs.Arg(0), *planID); err != nil { + return proton.Address{}, fmt.Errorf("failed to create subscription: %w", err) + } + + return nil, nil +} diff --git a/server/proxy.go b/server/proxy.go index 89e8ed1..70b3868 100644 --- a/server/proxy.go +++ b/server/proxy.go @@ -14,7 +14,7 @@ import ( "github.com/gin-gonic/gin" ) -func newProxy(proxyOrigin, base, path string) http.HandlerFunc { +func newProxy(proxyOrigin, base, path string, transport http.RoundTripper) http.HandlerFunc { origin, err := url.Parse(proxyOrigin) if err != nil { panic(err) @@ -28,13 +28,13 @@ func newProxy(proxyOrigin, base, path string) http.HandlerFunc { req.Host = origin.Host }, - Transport: proton.InsecureTransport(), + Transport: transport, }).ServeHTTP } func (s *Server) handleProxy(base string) gin.HandlerFunc { return func(c *gin.Context) { - proxy := newProxyServer(s.proxyOrigin, base) + proxy := newProxyServer(s.proxyOrigin, base, s.proxyTransport) proxy.handle("/", s.handleProxyAll) @@ -138,13 +138,16 @@ type proxyServer struct { mux *http.ServeMux origin, base string + + transport http.RoundTripper } -func newProxyServer(origin, base string) *proxyServer { +func newProxyServer(origin, base string, transport http.RoundTripper) *proxyServer { return &proxyServer{ - mux: http.NewServeMux(), - origin: origin, - base: base, + mux: http.NewServeMux(), + origin: origin, + base: base, + transport: transport, } } @@ -158,7 +161,7 @@ func (s *proxyServer) handle(path string, h func(func(string) HandlerFunc) http. buf := new(bytes.Buffer) // Call the proxy, capturing whatever data it writes. - newProxy(s.origin, s.base, path)(&writerWrapper{w, buf}, r) + newProxy(s.origin, s.base, path, s.transport)(&writerWrapper{w, buf}, r) // If there is a gzip header entry, decode it. if strings.Contains(w.Header().Get("Content-Encoding"), "gzip") { diff --git a/server/server.go b/server/server.go index 3e797c8..f1403bb 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "net/http" "net/http/httptest" "sync" "time" @@ -42,6 +43,9 @@ type Server struct { // proxyOrigin is the URL of the origin server when the server is a proxy. proxyOrigin string + // proxyTransport is the transport to use when the server is a proxy. + proxyTransport *http.Transport + // authCacher can optionally be set to cache proxied auth calls. authCacher AuthCacher @@ -192,5 +196,6 @@ func (s *Server) RevokeUser(userID string) error { } func (s *Server) Close() { + s.proxyTransport.CloseIdleConnections() s.s.Close() } diff --git a/server/server_builder.go b/server/server_builder.go index fda65cd..6545090 100644 --- a/server/server_builder.go +++ b/server/server_builder.go @@ -2,6 +2,7 @@ package server import ( "io" + "net/http" "net/http/httptest" "os" "time" @@ -12,12 +13,13 @@ import ( ) type serverBuilder struct { - withTLS bool - domain string - logger io.Writer - origin string - cacher AuthCacher - rateLimiter *rateLimiter + withTLS bool + domain string + logger io.Writer + origin string + proxyTransport *http.Transport + cacher AuthCacher + rateLimiter *rateLimiter } func newServerBuilder() *serverBuilder { @@ -30,10 +32,11 @@ func newServerBuilder() *serverBuilder { } return &serverBuilder{ - withTLS: true, - domain: "proton.local", - logger: logger, - origin: proton.DefaultHostURL, + withTLS: true, + domain: "proton.local", + logger: logger, + origin: proton.DefaultHostURL, + proxyTransport: &http.Transport{}, } } @@ -44,10 +47,11 @@ func (builder *serverBuilder) build() *Server { r: gin.New(), b: backend.New(time.Hour, builder.domain), - domain: builder.domain, - proxyOrigin: builder.origin, - authCacher: builder.cacher, - rateLimit: builder.rateLimiter, + domain: builder.domain, + proxyOrigin: builder.origin, + authCacher: builder.cacher, + rateLimit: builder.rateLimiter, + proxyTransport: builder.proxyTransport, } if builder.withTLS { @@ -161,3 +165,17 @@ type withRateLimit struct { func (opt withRateLimit) config(builder *serverBuilder) { builder.rateLimiter = newRateLimiter(opt.limit, opt.window) } + +func WithProxyTransport(transport *http.Transport) Option { + return &withProxyTransport{ + transport: transport, + } +} + +type withProxyTransport struct { + transport *http.Transport +} + +func (opt withProxyTransport) config(builder *serverBuilder) { + builder.proxyTransport = opt.transport +}