mirror of
https://github.com/ProtonMail/go-proton-api.git
synced 2025-12-23 23:57:50 -05:00
test(GODT-2181): Fix test server proxy goroutine leaks
This commit is contained in:
@@ -30,5 +30,5 @@ func (m *Manager) QuarkRes(ctx context.Context, command string, args ...string)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []byte(goquery.NewDocumentFromNode(doc).Find(".content").Text()), nil
|
||||
return []byte(strings.TrimSpace(goquery.NewDocumentFromNode(doc).Find(".content").Text())), nil
|
||||
}
|
||||
|
||||
@@ -338,6 +338,11 @@ func (b *Backend) RemoveAddressKey(userID, addrID, keyID string) error {
|
||||
})
|
||||
}
|
||||
|
||||
// TODO: Implement this when we support subscriptions in the test server.
|
||||
func (b *Backend) CreateSubscription(userID, planID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Backend) CreateMessage(
|
||||
userID, addrID string,
|
||||
subject string,
|
||||
|
||||
@@ -18,6 +18,9 @@ func (s *Backend) RunQuarkCommand(command string, args ...string) (any, error) {
|
||||
case "user:create:address":
|
||||
return s.quarkUserCreateAddress(args...)
|
||||
|
||||
case "user:create:subscription":
|
||||
return s.quarkUserCreateSubscription(args...)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown command: %s", command)
|
||||
}
|
||||
@@ -26,7 +29,7 @@ func (s *Backend) RunQuarkCommand(command string, args ...string) (any, error) {
|
||||
func (s *Backend) quarkEncryptionID(args ...string) (string, error) {
|
||||
fs := flag.NewFlagSet("encryption:id", flag.ContinueOnError)
|
||||
|
||||
// Required arguments.
|
||||
// Positional arguments.
|
||||
// arg0: value
|
||||
|
||||
decrypt := fs.Bool("decrypt", false, "decrypt the given encrypted ID")
|
||||
@@ -46,11 +49,9 @@ func (s *Backend) quarkEncryptionID(args ...string) (string, error) {
|
||||
func (s *Backend) quarkUserCreate(args ...string) (proton.User, error) {
|
||||
fs := flag.NewFlagSet("user:create", flag.ContinueOnError)
|
||||
|
||||
// Required arguments.
|
||||
// Flag arguments.
|
||||
name := fs.String("name", "", "new user's name")
|
||||
pass := fs.String("password", "", "new user's password")
|
||||
|
||||
// Optional arguments.
|
||||
newAddr := fs.Bool("create-address", false, "create the user's default address, will not automatically setup the address key")
|
||||
genKeys := fs.String("gen-keys", "", "generate new address keys for the user")
|
||||
|
||||
@@ -76,12 +77,12 @@ func (s *Backend) quarkUserCreate(args ...string) (proton.User, error) {
|
||||
func (s *Backend) quarkUserCreateAddress(args ...string) (proton.Address, error) {
|
||||
fs := flag.NewFlagSet("user:create:address", flag.ContinueOnError)
|
||||
|
||||
// Required arguments.
|
||||
// Positional arguments.
|
||||
// arg0: userID
|
||||
// arg1: password
|
||||
// arg2: email
|
||||
|
||||
// Optional arguments.
|
||||
// Flag arguments.
|
||||
genKeys := fs.String("gen-keys", "", "generate new address keys for the user")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
@@ -96,3 +97,23 @@ func (s *Backend) quarkUserCreateAddress(args ...string) (proton.Address, error)
|
||||
|
||||
return s.GetAddress(fs.Arg(0), addrID)
|
||||
}
|
||||
|
||||
func (s *Backend) quarkUserCreateSubscription(args ...string) (any, error) {
|
||||
fs := flag.NewFlagSet("user:create:subscription", flag.ContinueOnError)
|
||||
|
||||
// Positional arguments.
|
||||
// arg0: userID
|
||||
|
||||
// Flag arguments.
|
||||
planID := fs.String("planID", "", "plan ID for the user")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.CreateSubscription(fs.Arg(0), *planID); err != nil {
|
||||
return proton.Address{}, fmt.Errorf("failed to create subscription: %w", err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func newProxy(proxyOrigin, base, path string) http.HandlerFunc {
|
||||
func newProxy(proxyOrigin, base, path string, transport http.RoundTripper) http.HandlerFunc {
|
||||
origin, err := url.Parse(proxyOrigin)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -28,13 +28,13 @@ func newProxy(proxyOrigin, base, path string) http.HandlerFunc {
|
||||
req.Host = origin.Host
|
||||
},
|
||||
|
||||
Transport: proton.InsecureTransport(),
|
||||
Transport: transport,
|
||||
}).ServeHTTP
|
||||
}
|
||||
|
||||
func (s *Server) handleProxy(base string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
proxy := newProxyServer(s.proxyOrigin, base)
|
||||
proxy := newProxyServer(s.proxyOrigin, base, s.proxyTransport)
|
||||
|
||||
proxy.handle("/", s.handleProxyAll)
|
||||
|
||||
@@ -138,13 +138,16 @@ type proxyServer struct {
|
||||
mux *http.ServeMux
|
||||
|
||||
origin, base string
|
||||
|
||||
transport http.RoundTripper
|
||||
}
|
||||
|
||||
func newProxyServer(origin, base string) *proxyServer {
|
||||
func newProxyServer(origin, base string, transport http.RoundTripper) *proxyServer {
|
||||
return &proxyServer{
|
||||
mux: http.NewServeMux(),
|
||||
origin: origin,
|
||||
base: base,
|
||||
mux: http.NewServeMux(),
|
||||
origin: origin,
|
||||
base: base,
|
||||
transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +161,7 @@ func (s *proxyServer) handle(path string, h func(func(string) HandlerFunc) http.
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
// Call the proxy, capturing whatever data it writes.
|
||||
newProxy(s.origin, s.base, path)(&writerWrapper{w, buf}, r)
|
||||
newProxy(s.origin, s.base, path, s.transport)(&writerWrapper{w, buf}, r)
|
||||
|
||||
// If there is a gzip header entry, decode it.
|
||||
if strings.Contains(w.Header().Get("Content-Encoding"), "gzip") {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -42,6 +43,9 @@ type Server struct {
|
||||
// proxyOrigin is the URL of the origin server when the server is a proxy.
|
||||
proxyOrigin string
|
||||
|
||||
// proxyTransport is the transport to use when the server is a proxy.
|
||||
proxyTransport *http.Transport
|
||||
|
||||
// authCacher can optionally be set to cache proxied auth calls.
|
||||
authCacher AuthCacher
|
||||
|
||||
@@ -192,5 +196,6 @@ func (s *Server) RevokeUser(userID string) error {
|
||||
}
|
||||
|
||||
func (s *Server) Close() {
|
||||
s.proxyTransport.CloseIdleConnections()
|
||||
s.s.Close()
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"time"
|
||||
@@ -12,12 +13,13 @@ import (
|
||||
)
|
||||
|
||||
type serverBuilder struct {
|
||||
withTLS bool
|
||||
domain string
|
||||
logger io.Writer
|
||||
origin string
|
||||
cacher AuthCacher
|
||||
rateLimiter *rateLimiter
|
||||
withTLS bool
|
||||
domain string
|
||||
logger io.Writer
|
||||
origin string
|
||||
proxyTransport *http.Transport
|
||||
cacher AuthCacher
|
||||
rateLimiter *rateLimiter
|
||||
}
|
||||
|
||||
func newServerBuilder() *serverBuilder {
|
||||
@@ -30,10 +32,11 @@ func newServerBuilder() *serverBuilder {
|
||||
}
|
||||
|
||||
return &serverBuilder{
|
||||
withTLS: true,
|
||||
domain: "proton.local",
|
||||
logger: logger,
|
||||
origin: proton.DefaultHostURL,
|
||||
withTLS: true,
|
||||
domain: "proton.local",
|
||||
logger: logger,
|
||||
origin: proton.DefaultHostURL,
|
||||
proxyTransport: &http.Transport{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,10 +47,11 @@ func (builder *serverBuilder) build() *Server {
|
||||
r: gin.New(),
|
||||
b: backend.New(time.Hour, builder.domain),
|
||||
|
||||
domain: builder.domain,
|
||||
proxyOrigin: builder.origin,
|
||||
authCacher: builder.cacher,
|
||||
rateLimit: builder.rateLimiter,
|
||||
domain: builder.domain,
|
||||
proxyOrigin: builder.origin,
|
||||
authCacher: builder.cacher,
|
||||
rateLimit: builder.rateLimiter,
|
||||
proxyTransport: builder.proxyTransport,
|
||||
}
|
||||
|
||||
if builder.withTLS {
|
||||
@@ -161,3 +165,17 @@ type withRateLimit struct {
|
||||
func (opt withRateLimit) config(builder *serverBuilder) {
|
||||
builder.rateLimiter = newRateLimiter(opt.limit, opt.window)
|
||||
}
|
||||
|
||||
func WithProxyTransport(transport *http.Transport) Option {
|
||||
return &withProxyTransport{
|
||||
transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
type withProxyTransport struct {
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (opt withProxyTransport) config(builder *serverBuilder) {
|
||||
builder.proxyTransport = opt.transport
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user