test(GODT-2181): Fix test server proxy goroutine leaks

This commit is contained in:
James Houlahan
2022-12-13 03:40:57 +01:00
committed by James
parent fd06b106da
commit 0ed178e88e
6 changed files with 81 additions and 29 deletions

View File

@@ -30,5 +30,5 @@ func (m *Manager) QuarkRes(ctx context.Context, command string, args ...string)
return nil, err
}
return []byte(goquery.NewDocumentFromNode(doc).Find(".content").Text()), nil
return []byte(strings.TrimSpace(goquery.NewDocumentFromNode(doc).Find(".content").Text())), nil
}

View File

@@ -338,6 +338,11 @@ func (b *Backend) RemoveAddressKey(userID, addrID, keyID string) error {
})
}
// TODO: Implement this when we support subscriptions in the test server.
func (b *Backend) CreateSubscription(userID, planID string) error {
return nil
}
func (b *Backend) CreateMessage(
userID, addrID string,
subject string,

View File

@@ -18,6 +18,9 @@ func (s *Backend) RunQuarkCommand(command string, args ...string) (any, error) {
case "user:create:address":
return s.quarkUserCreateAddress(args...)
case "user:create:subscription":
return s.quarkUserCreateSubscription(args...)
default:
return nil, fmt.Errorf("unknown command: %s", command)
}
@@ -26,7 +29,7 @@ func (s *Backend) RunQuarkCommand(command string, args ...string) (any, error) {
func (s *Backend) quarkEncryptionID(args ...string) (string, error) {
fs := flag.NewFlagSet("encryption:id", flag.ContinueOnError)
// Required arguments.
// Positional arguments.
// arg0: value
decrypt := fs.Bool("decrypt", false, "decrypt the given encrypted ID")
@@ -46,11 +49,9 @@ func (s *Backend) quarkEncryptionID(args ...string) (string, error) {
func (s *Backend) quarkUserCreate(args ...string) (proton.User, error) {
fs := flag.NewFlagSet("user:create", flag.ContinueOnError)
// Required arguments.
// Flag arguments.
name := fs.String("name", "", "new user's name")
pass := fs.String("password", "", "new user's password")
// Optional arguments.
newAddr := fs.Bool("create-address", false, "create the user's default address, will not automatically setup the address key")
genKeys := fs.String("gen-keys", "", "generate new address keys for the user")
@@ -76,12 +77,12 @@ func (s *Backend) quarkUserCreate(args ...string) (proton.User, error) {
func (s *Backend) quarkUserCreateAddress(args ...string) (proton.Address, error) {
fs := flag.NewFlagSet("user:create:address", flag.ContinueOnError)
// Required arguments.
// Positional arguments.
// arg0: userID
// arg1: password
// arg2: email
// Optional arguments.
// Flag arguments.
genKeys := fs.String("gen-keys", "", "generate new address keys for the user")
if err := fs.Parse(args); err != nil {
@@ -96,3 +97,23 @@ func (s *Backend) quarkUserCreateAddress(args ...string) (proton.Address, error)
return s.GetAddress(fs.Arg(0), addrID)
}
func (s *Backend) quarkUserCreateSubscription(args ...string) (any, error) {
fs := flag.NewFlagSet("user:create:subscription", flag.ContinueOnError)
// Positional arguments.
// arg0: userID
// Flag arguments.
planID := fs.String("planID", "", "plan ID for the user")
if err := fs.Parse(args); err != nil {
return nil, err
}
if err := s.CreateSubscription(fs.Arg(0), *planID); err != nil {
return proton.Address{}, fmt.Errorf("failed to create subscription: %w", err)
}
return nil, nil
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/gin-gonic/gin"
)
func newProxy(proxyOrigin, base, path string) http.HandlerFunc {
func newProxy(proxyOrigin, base, path string, transport http.RoundTripper) http.HandlerFunc {
origin, err := url.Parse(proxyOrigin)
if err != nil {
panic(err)
@@ -28,13 +28,13 @@ func newProxy(proxyOrigin, base, path string) http.HandlerFunc {
req.Host = origin.Host
},
Transport: proton.InsecureTransport(),
Transport: transport,
}).ServeHTTP
}
func (s *Server) handleProxy(base string) gin.HandlerFunc {
return func(c *gin.Context) {
proxy := newProxyServer(s.proxyOrigin, base)
proxy := newProxyServer(s.proxyOrigin, base, s.proxyTransport)
proxy.handle("/", s.handleProxyAll)
@@ -138,13 +138,16 @@ type proxyServer struct {
mux *http.ServeMux
origin, base string
transport http.RoundTripper
}
func newProxyServer(origin, base string) *proxyServer {
func newProxyServer(origin, base string, transport http.RoundTripper) *proxyServer {
return &proxyServer{
mux: http.NewServeMux(),
origin: origin,
base: base,
mux: http.NewServeMux(),
origin: origin,
base: base,
transport: transport,
}
}
@@ -158,7 +161,7 @@ func (s *proxyServer) handle(path string, h func(func(string) HandlerFunc) http.
buf := new(bytes.Buffer)
// Call the proxy, capturing whatever data it writes.
newProxy(s.origin, s.base, path)(&writerWrapper{w, buf}, r)
newProxy(s.origin, s.base, path, s.transport)(&writerWrapper{w, buf}, r)
// If there is a gzip header entry, decode it.
if strings.Contains(w.Header().Get("Content-Encoding"), "gzip") {

View File

@@ -1,6 +1,7 @@
package server
import (
"net/http"
"net/http/httptest"
"sync"
"time"
@@ -42,6 +43,9 @@ type Server struct {
// proxyOrigin is the URL of the origin server when the server is a proxy.
proxyOrigin string
// proxyTransport is the transport to use when the server is a proxy.
proxyTransport *http.Transport
// authCacher can optionally be set to cache proxied auth calls.
authCacher AuthCacher
@@ -192,5 +196,6 @@ func (s *Server) RevokeUser(userID string) error {
}
func (s *Server) Close() {
s.proxyTransport.CloseIdleConnections()
s.s.Close()
}

View File

@@ -2,6 +2,7 @@ package server
import (
"io"
"net/http"
"net/http/httptest"
"os"
"time"
@@ -12,12 +13,13 @@ import (
)
type serverBuilder struct {
withTLS bool
domain string
logger io.Writer
origin string
cacher AuthCacher
rateLimiter *rateLimiter
withTLS bool
domain string
logger io.Writer
origin string
proxyTransport *http.Transport
cacher AuthCacher
rateLimiter *rateLimiter
}
func newServerBuilder() *serverBuilder {
@@ -30,10 +32,11 @@ func newServerBuilder() *serverBuilder {
}
return &serverBuilder{
withTLS: true,
domain: "proton.local",
logger: logger,
origin: proton.DefaultHostURL,
withTLS: true,
domain: "proton.local",
logger: logger,
origin: proton.DefaultHostURL,
proxyTransport: &http.Transport{},
}
}
@@ -44,10 +47,11 @@ func (builder *serverBuilder) build() *Server {
r: gin.New(),
b: backend.New(time.Hour, builder.domain),
domain: builder.domain,
proxyOrigin: builder.origin,
authCacher: builder.cacher,
rateLimit: builder.rateLimiter,
domain: builder.domain,
proxyOrigin: builder.origin,
authCacher: builder.cacher,
rateLimit: builder.rateLimiter,
proxyTransport: builder.proxyTransport,
}
if builder.withTLS {
@@ -161,3 +165,17 @@ type withRateLimit struct {
func (opt withRateLimit) config(builder *serverBuilder) {
builder.rateLimiter = newRateLimiter(opt.limit, opt.window)
}
func WithProxyTransport(transport *http.Transport) Option {
return &withProxyTransport{
transport: transport,
}
}
type withProxyTransport struct {
transport *http.Transport
}
func (opt withProxyTransport) config(builder *serverBuilder) {
builder.proxyTransport = opt.transport
}