From 3813d14cae28d52b111a413b9a02efb49c20bcbb Mon Sep 17 00:00:00 2001
From: Pascal Bleser
Date: Thu, 4 Sep 2025 22:16:44 +0200
Subject: [PATCH] refactor(groupware): session cache and DNS autodiscovery
* move the logging of the username and session state away from pkg/jmap
and into services/groupware
* introduce more decoupling for the session cache, as well as moving
the implementation into groupware_session.go
---
pkg/jmap/jmap_api.go | 20 +-
pkg/jmap/jmap_client.go | 2 +
pkg/jmap/jmap_http.go | 42 ++--
pkg/jmap/jmap_session.go | 10 -
.../groupware/pkg/groupware/groupware_dns.go | 35 +--
.../pkg/groupware/groupware_framework.go | 167 +++++-------
.../pkg/groupware/groupware_request.go | 4 +-
.../groupware/pkg/groupware/groupware_reva.go | 16 +-
.../pkg/groupware/groupware_session.go | 238 +++++++++++++++---
9 files changed, 330 insertions(+), 204 deletions(-)
diff --git a/pkg/jmap/jmap_api.go b/pkg/jmap/jmap_api.go
index d02276b236..e6f77997a5 100644
--- a/pkg/jmap/jmap_api.go
+++ b/pkg/jmap/jmap_api.go
@@ -23,15 +23,13 @@ type BlobClient interface {
}
const (
- logOperation = "operation"
- logUsername = "username"
- logMailboxId = "mailbox-id"
- logFetchBodies = "fetch-bodies"
- logOffset = "offset"
- logLimit = "limit"
- logDownloadUrl = "download-url"
- logBlobId = "blob-id"
- logUploadUrl = "download-url"
- logSessionState = "session-state"
- logSince = "since"
+ logOperation = "operation"
+ logMailboxId = "mailbox-id"
+ logFetchBodies = "fetch-bodies"
+ logOffset = "offset"
+ logLimit = "limit"
+ logDownloadUrl = "download-url"
+ logBlobId = "blob-id"
+ logUploadUrl = "download-url"
+ logSince = "since"
)
diff --git a/pkg/jmap/jmap_client.go b/pkg/jmap/jmap_client.go
index 1fad905a1e..f55db0661e 100644
--- a/pkg/jmap/jmap_client.go
+++ b/pkg/jmap/jmap_client.go
@@ -16,6 +16,8 @@ type Client struct {
io.Closer
}
+var _ io.Closer = &Client{}
+
func (j *Client) Close() error {
return j.api.Close()
}
diff --git a/pkg/jmap/jmap_http.go b/pkg/jmap/jmap_http.go
index 56d1a1ade2..bdc633bd37 100644
--- a/pkg/jmap/jmap_http.go
+++ b/pkg/jmap/jmap_http.go
@@ -123,10 +123,11 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger
//sessionUrl := baseurl.JoinPath(".well-known", "jmap")
sessionUrlStr := sessionUrl.String()
endpoint := endpointOf(sessionUrl)
+ logger = log.From(logger.With().Str(logEndpoint, endpoint))
req, err := http.NewRequest(http.MethodGet, sessionUrlStr, nil)
if err != nil {
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create GET request for %v", sessionUrl)
+ logger.Error().Err(err).Msgf("failed to create GET request for %v", sessionUrl)
return SessionResponse{}, SimpleError{code: JmapErrorInvalidHttpRequest, err: err}
}
h.auth(username, logger, req)
@@ -135,12 +136,12 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger
res, err := h.client.Do(req)
if err != nil {
h.listener.OnFailedRequest(endpoint, err)
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform GET %v", sessionUrl)
+ logger.Error().Err(err).Msgf("failed to perform GET %v", sessionUrl)
return SessionResponse{}, SimpleError{code: JmapErrorInvalidHttpRequest, err: err}
}
if res.StatusCode < 200 || res.StatusCode > 299 {
h.listener.OnFailedRequestWithStatus(endpoint, res.StatusCode)
- logger.Error().Str(logEndpoint, endpoint).Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 200")
+ logger.Error().Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 200")
return SessionResponse{}, SimpleError{code: JmapErrorServerResponse, err: fmt.Errorf("JMAP API response status is %v", res.Status)}
}
h.listener.OnSuccessfulRequest(endpoint, res.StatusCode)
@@ -156,7 +157,7 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger
body, err := io.ReadAll(res.Body)
if err != nil {
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to read response body")
+ logger.Error().Err(err).Msg("failed to read response body")
h.listener.OnResponseBodyReadingError(endpoint, err)
return SessionResponse{}, SimpleError{code: JmapErrorReadingResponseBody, err: err}
}
@@ -164,7 +165,7 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger
var data SessionResponse
err = json.Unmarshal(body, &data)
if err != nil {
- logger.Error().Str(logEndpoint, endpoint).Str(logHttpUrl, sessionUrlStr).Err(err).Msg("failed to decode JSON payload from .well-known/jmap response")
+ logger.Error().Str(logHttpUrl, sessionUrlStr).Err(err).Msg("failed to decode JSON payload from .well-known/jmap response")
h.listener.OnResponseBodyUnmarshallingError(endpoint, err)
return SessionResponse{}, SimpleError{code: JmapErrorDecodingResponseBody, err: err}
}
@@ -175,16 +176,17 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger
func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, session *Session, request Request) ([]byte, Error) {
jmapUrl := session.JmapUrl.String()
endpoint := session.JmapEndpoint
+ logger = log.From(logger.With().Str(logEndpoint, endpoint))
bodyBytes, err := json.Marshal(request)
if err != nil {
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to marshall JSON payload")
+ logger.Error().Err(err).Msg("failed to marshall JSON payload")
return nil, SimpleError{code: JmapErrorEncodingRequestBody, err: err}
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, jmapUrl, bytes.NewBuffer(bodyBytes))
if err != nil {
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create POST request for %v", jmapUrl)
+ logger.Error().Err(err).Msgf("failed to create POST request for %v", jmapUrl)
return nil, SimpleError{code: JmapErrorCreatingRequest, err: err}
}
req.Header.Add("Content-Type", "application/json")
@@ -194,7 +196,7 @@ func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, sessio
res, err := h.client.Do(req)
if err != nil {
h.listener.OnFailedRequest(endpoint, err)
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform POST %v", jmapUrl)
+ logger.Error().Err(err).Msgf("failed to perform POST %v", jmapUrl)
return nil, SimpleError{code: JmapErrorSendingRequest, err: err}
}
if res.StatusCode < 200 || res.StatusCode > 299 {
@@ -214,7 +216,7 @@ func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, sessio
body, err := io.ReadAll(res.Body)
if err != nil {
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to read response body")
+ logger.Error().Err(err).Msg("failed to read response body")
h.listener.OnResponseBodyReadingError(endpoint, err)
return nil, SimpleError{code: JmapErrorServerResponse, err: err}
}
@@ -223,9 +225,11 @@ func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, sessio
}
func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, session *Session, uploadUrl string, endpoint string, contentType string, body io.Reader) (UploadedBlob, Error) {
+ logger = log.From(logger.With().Str(logEndpoint, endpoint))
+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl, body)
if err != nil {
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create POST request for %v", uploadUrl)
+ logger.Error().Err(err).Msgf("failed to create POST request for %v", uploadUrl)
return UploadedBlob{}, SimpleError{code: JmapErrorCreatingRequest, err: err}
}
req.Header.Add("Content-Type", contentType)
@@ -235,12 +239,12 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s
res, err := h.client.Do(req)
if err != nil {
h.listener.OnFailedRequest(endpoint, err)
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform POST %v", uploadUrl)
+ logger.Error().Err(err).Msgf("failed to perform POST %v", uploadUrl)
return UploadedBlob{}, SimpleError{code: JmapErrorSendingRequest, err: err}
}
if res.StatusCode < 200 || res.StatusCode > 299 {
h.listener.OnFailedRequestWithStatus(endpoint, res.StatusCode)
- logger.Error().Str(logEndpoint, endpoint).Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx")
+ logger.Error().Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx")
return UploadedBlob{}, SimpleError{code: JmapErrorServerResponse, err: err}
}
if res.Body != nil {
@@ -255,7 +259,7 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s
responseBody, err := io.ReadAll(res.Body)
if err != nil {
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to read response body")
+ logger.Error().Err(err).Msg("failed to read response body")
h.listener.OnResponseBodyReadingError(endpoint, err)
return UploadedBlob{}, SimpleError{code: JmapErrorServerResponse, err: err}
}
@@ -263,7 +267,7 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s
var result UploadedBlob
err = json.Unmarshal(responseBody, &result)
if err != nil {
- logger.Error().Str(logEndpoint, endpoint).Str(logHttpUrl, uploadUrl).Err(err).Msg("failed to decode JSON payload from the upload response")
+ logger.Error().Str(logHttpUrl, uploadUrl).Err(err).Msg("failed to decode JSON payload from the upload response")
h.listener.OnResponseBodyUnmarshallingError(endpoint, err)
return UploadedBlob{}, SimpleError{code: JmapErrorDecodingResponseBody, err: err}
}
@@ -272,9 +276,11 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s
}
func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger, session *Session, downloadUrl string, endpoint string) (*BlobDownload, Error) {
+ logger = log.From(logger.With().Str(logEndpoint, endpoint))
+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)
if err != nil {
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create GET request for %v", downloadUrl)
+ logger.Error().Err(err).Msgf("failed to create GET request for %v", downloadUrl)
return nil, SimpleError{code: JmapErrorCreatingRequest, err: err}
}
req.Header.Add("User-Agent", h.userAgent)
@@ -283,7 +289,7 @@ func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger,
res, err := h.client.Do(req)
if err != nil {
h.listener.OnFailedRequest(endpoint, err)
- logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform GET %v", downloadUrl)
+ logger.Error().Err(err).Msgf("failed to perform GET %v", downloadUrl)
return nil, SimpleError{code: JmapErrorSendingRequest, err: err}
}
if res.StatusCode == http.StatusNotFound {
@@ -291,7 +297,7 @@ func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger,
}
if res.StatusCode < 200 || res.StatusCode > 299 {
h.listener.OnFailedRequestWithStatus(endpoint, res.StatusCode)
- logger.Error().Str(logEndpoint, endpoint).Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx")
+ logger.Error().Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx")
return nil, SimpleError{code: JmapErrorServerResponse, err: err}
}
h.listener.OnSuccessfulRequest(endpoint, res.StatusCode)
@@ -301,7 +307,7 @@ func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger,
if sizeStr != "" {
size, err = strconv.Atoi(sizeStr)
if err != nil {
- logger.Warn().Err(err).Str(logEndpoint, endpoint).Msgf("failed to parse Content-Length blob download response header value '%v'", sizeStr)
+ logger.Warn().Err(err).Msgf("failed to parse Content-Length blob download response header value '%v'", sizeStr)
size = -1
}
}
diff --git a/pkg/jmap/jmap_session.go b/pkg/jmap/jmap_session.go
index 3b2bd703ed..67a5ec4c14 100644
--- a/pkg/jmap/jmap_session.go
+++ b/pkg/jmap/jmap_session.go
@@ -4,8 +4,6 @@ import (
"errors"
"fmt"
"net/url"
-
- "github.com/opencloud-eu/opencloud/pkg/log"
)
type SessionEventListener interface {
@@ -101,14 +99,6 @@ func newSession(sessionResponse SessionResponse) (Session, Error) {
}, nil
}
-// Create a new log.Logger that is decorated with fields containing information about the Session.
-func (s Session) DecorateLogger(l log.Logger) *log.Logger {
- return log.From(l.With().
- Str(logUsername, log.SafeString(s.Username)).
- Str(logEndpoint, log.SafeString(s.JmapEndpoint)).
- Str(logSessionState, log.SafeString(string(s.State))))
-}
-
func endpointOf(u *url.URL) string {
if u != nil {
return fmt.Sprintf("%s://%s", u.Scheme, u.Host)
diff --git a/services/groupware/pkg/groupware/groupware_dns.go b/services/groupware/pkg/groupware/groupware_dns.go
index 6ef89b677f..15ef97d1e2 100644
--- a/services/groupware/pkg/groupware/groupware_dns.go
+++ b/services/groupware/pkg/groupware/groupware_dns.go
@@ -17,17 +17,22 @@ var (
)
type DnsSessionUrlResolver struct {
- defaultSessionUrl *url.URL
- defaultDomain string
- domainGreenList []string
- domainRedList []string
- config *dns.ClientConfig
- client *dns.Client
+ defaultSessionUrlSupplier func(string) (*url.URL, *GroupwareError)
+ defaultDomain string
+ domainGreenList []string
+ domainRedList []string
+ config *dns.ClientConfig
+ client *dns.Client
}
-func NewDnsSessionUrlResolver(defaultSessionUrl *url.URL, defaultDomain string,
- config *dns.ClientConfig, domainGreenList []string, domainRedList []string,
- dialTimeout time.Duration, readTimeout time.Duration,
+func NewDnsSessionUrlResolver(
+ defaultSessionUrlSupplier func(string) (*url.URL, *GroupwareError),
+ defaultDomain string,
+ config *dns.ClientConfig,
+ domainGreenList []string,
+ domainRedList []string,
+ dialTimeout time.Duration,
+ readTimeout time.Duration,
) (DnsSessionUrlResolver, error) {
// TODO the whole udp or tcp dialier configuration, see https://github.com/miekg/exdns/blob/master/q/q.go
@@ -37,10 +42,10 @@ func NewDnsSessionUrlResolver(defaultSessionUrl *url.URL, defaultDomain string,
}
return DnsSessionUrlResolver{
- defaultSessionUrl: defaultSessionUrl,
- defaultDomain: defaultDomain,
- config: config,
- client: c,
+ defaultSessionUrlSupplier: defaultSessionUrlSupplier,
+ defaultDomain: defaultDomain,
+ config: config,
+ client: c,
}, nil
}
@@ -75,7 +80,7 @@ func (d DnsSessionUrlResolver) Resolve(username string) (*url.URL, *GroupwareErr
// nevertheless then?
if d.defaultDomain == "" {
// we don't, then let's fall back to the static session URL instead
- return d.defaultSessionUrl, nil
+ return d.defaultSessionUrlSupplier(username)
}
} else {
domain = parts[len(parts)-1]
@@ -133,7 +138,7 @@ func (d DnsSessionUrlResolver) Resolve(username string) (*url.URL, *GroupwareErr
}
}
- return d.defaultSessionUrl, nil
+ return d.defaultSessionUrlSupplier(username)
}
func (d DnsSessionUrlResolver) dnsQuery(c *dns.Client, msg *dns.Msg) (*dns.Msg, error) {
diff --git a/services/groupware/pkg/groupware/groupware_framework.go b/services/groupware/pkg/groupware/groupware_framework.go
index d4d00fd954..1baf84ca92 100644
--- a/services/groupware/pkg/groupware/groupware_framework.go
+++ b/services/groupware/pkg/groupware/groupware_framework.go
@@ -18,8 +18,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
- "github.com/jellydator/ttlcache/v3"
-
cmap "github.com/orcaman/concurrent-map"
"github.com/opencloud-eu/opencloud/pkg/jmap"
@@ -30,8 +28,9 @@ import (
)
const (
- logUsername = "username" // this should match jmap.logUsername to avoid having the field twice in the logs under different keys
+ logUsername = "username"
logUserId = "user-id"
+ logSessionState = "session-state"
logAccountId = "account-id"
logErrorId = "error-id"
logErrorCode = "code"
@@ -51,17 +50,17 @@ const (
logMethod = "method"
)
-// Minimalistic representation of a User, containing only the attributes that are
+// Minimalistic representation of a user, containing only the attributes that are
// necessary for the Groupware implementation.
-type User interface {
+type user interface {
GetUsername() string
GetId() string
}
// Provides a User that is associated with a request.
-type UserProvider interface {
+type userProvider interface {
// Provide the user for JMAP operations.
- GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (User, error)
+ GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (user, error)
}
// Background job that needs to be executed asynchronously by the Groupware.
@@ -89,9 +88,9 @@ type Groupware struct {
defaultEmailLimit uint
maxBodyValueBytes uint
// Caches successful and failed Sessions by the username.
- sessionCache *ttlcache.Cache[sessionKey, cachedSession]
+ sessionCache sessionCache
jmap *jmap.Client
- userProvider UserProvider
+ userProvider userProvider
// SSE events that need to be pushed to clients.
eventChannel chan Event
// Background jobs that need to be executed.
@@ -127,10 +126,14 @@ type Event struct {
Body any
}
+// A jmap.HttpJmapApiClientEventListener implementation that records those JMAP
+// events as metric increments.
type groupwareHttpJmapApiClientMetricsRecorder struct {
m *metrics.Metrics
}
+var _ jmap.HttpJmapApiClientEventListener = groupwareHttpJmapApiClientMetricsRecorder{}
+
func (r groupwareHttpJmapApiClientMetricsRecorder) OnSuccessfulRequest(endpoint string, status int) {
r.m.SuccessfulRequestPerEndpointCounter.With(metrics.Endpoint(endpoint)).Inc()
}
@@ -147,8 +150,6 @@ func (r groupwareHttpJmapApiClientMetricsRecorder) OnResponseBodyUnmarshallingEr
r.m.ResponseBodyUnmarshallingErrorPerEndpointCounter.With(metrics.Endpoint(endpoint)).Inc()
}
-var _ jmap.HttpJmapApiClientEventListener = groupwareHttpJmapApiClientMetricsRecorder{}
-
func NewGroupware(config *config.Config, logger *log.Logger, mux *chi.Mux, prometheusRegistry prometheus.Registerer) (*Groupware, error) {
baseUrl, err := url.Parse(config.Mail.BaseUrl)
if err != nil {
@@ -190,111 +191,68 @@ func NewGroupware(config *config.Config, logger *log.Logger, mux *chi.Mux, prome
m := metrics.New(prometheusRegistry, logger)
// TODO add timeouts and other meaningful configuration settings for the HTTP client
- tr := http.DefaultTransport.(*http.Transport).Clone()
- tr.ResponseHeaderTimeout = responseHeaderTimeout
+ httpTransport := http.DefaultTransport.(*http.Transport).Clone()
+ httpTransport.ResponseHeaderTimeout = responseHeaderTimeout
if insecureTls {
tlsConfig := &tls.Config{InsecureSkipVerify: true} // TODO make configurable
- tr.TLSClientConfig = tlsConfig
+ httpTransport.TLSClientConfig = tlsConfig
}
- c := *http.DefaultClient
- c.Transport = tr
+ httpClient := *http.DefaultClient
+ httpClient.Transport = httpTransport
- userProvider := NewRevaContextUsernameProvider()
+ userProvider := newRevaContextUsernameProvider()
jmapMetricsAdapter := groupwareHttpJmapApiClientMetricsRecorder{m: m}
api := jmap.NewHttpJmapClient(
- &c,
+ &httpClient,
masterUsername,
masterPassword,
jmapMetricsAdapter,
)
+ // api implements all three interfaces:
jmapClient := jmap.NewClient(api, api, api)
- var sessionCache *ttlcache.Cache[sessionKey, cachedSession]
- {
- sessionUrlResolver := func(_ string) (*url.URL, *GroupwareError) {
- return sessionUrl, nil
- }
- if useDnsForSessionResolution {
- defaultSessionDomain := "example.com" // TODO default domain from configuration
- // TODO resolv.conf or other configuration
- conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
- if err != nil {
- return nil, GroupwareInitializationError{Message: "failed to parse DNS client configuration from /etc/resolv.conf", Err: err}
- }
-
- var domainGreenList []string = nil // TODO domain greenlist from configuration
- var domainRedList []string = nil // TODO domain redlist from configuration
-
- dialTimeout := time.Duration(2) * time.Second // TODO configuration
- readTimeout := time.Duration(2) * time.Second // TODO configuration
-
- dnsSessionUrlResolver, err := NewDnsSessionUrlResolver(
- sessionUrl,
- defaultSessionDomain,
- conf,
- domainGreenList,
- domainRedList,
- dialTimeout,
- readTimeout,
- )
- if err != nil {
- return nil, GroupwareInitializationError{Message: "failed to instantiate the DNS session URL resolver", Err: err}
- }
- sessionUrlResolver = dnsSessionUrlResolver.Resolve
+ sessionCacheBuilder := newSessionCacheBuilder(
+ sessionUrl,
+ logger,
+ jmapClient.FetchSession,
+ prometheusRegistry,
+ m,
+ sessionCacheMaxCapacity,
+ sessionCacheTtl,
+ sessionFailureCacheTtl,
+ )
+ if useDnsForSessionResolution {
+ conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
+ if err != nil {
+ return nil, GroupwareInitializationError{Message: "failed to parse DNS client configuration from /etc/resolv.conf", Err: err}
}
- sessionLoader := &sessionCacheLoader{
- logger: logger,
- jmapClient: &jmapClient,
- errorTtl: sessionFailureCacheTtl,
- sessionUrlProvider: sessionUrlResolver,
- }
+ var dnsDomainGreenList []string = nil // TODO domain greenlist from configuration
+ var dnsDomainRedList []string = nil // TODO domain redlist from configuration
+ dnsDialTimeout := time.Duration(2) * time.Second // TODO DNS server connection timeout configuration
+ dnsReadTimeout := time.Duration(2) * time.Second // TODO DNS server response reading timeout configuration
+ defaultDomain := "example.com" // TODO default domain when the username is not an email address configuration
- sessionCache = ttlcache.New(
- ttlcache.WithCapacity[sessionKey, cachedSession](sessionCacheMaxCapacity),
- ttlcache.WithTTL[sessionKey, cachedSession](sessionCacheTtl),
- ttlcache.WithDisableTouchOnHit[sessionKey, cachedSession](),
- ttlcache.WithLoader(sessionLoader),
+ sessionCacheBuilder = sessionCacheBuilder.withDnsAutoDiscovery(
+ defaultDomain,
+ conf,
+ dnsDialTimeout,
+ dnsReadTimeout,
+ dnsDomainGreenList,
+ dnsDomainRedList,
)
- go sessionCache.Start()
-
- prometheusRegistry.Register(sessionCacheMetricsCollector{desc: m.SessionCacheDesc, supply: sessionCache.Metrics})
}
- sessionCache.OnEviction(func(c context.Context, r ttlcache.EvictionReason, item *ttlcache.Item[sessionKey, cachedSession]) {
- if logger.Trace().Enabled() {
- reason := ""
- switch r {
- case ttlcache.EvictionReasonDeleted:
- reason = "deleted"
- case ttlcache.EvictionReasonCapacityReached:
- reason = "capacity reached"
- case ttlcache.EvictionReasonExpired:
- reason = fmt.Sprintf("expired after %v", item.TTL())
- case ttlcache.EvictionReasonMaxCostExceeded:
- reason = "max cost exceeded"
- }
- if reason == "" {
- reason = fmt.Sprintf("unknown (%v)", r)
- }
- spentInCache := time.Since(item.Value().Since())
- tipe := "successful"
- if !item.Value().Success() {
- tipe = "failed"
- }
- logger.Trace().Msgf("%s session cache eviction of user '%v' after %v: %v", tipe, item.Key(), spentInCache, reason)
- }
- })
+ sessionCache, err := sessionCacheBuilder.build()
- sessionEventListener := sessionEventListener{
- sessionCache: sessionCache,
- logger: logger,
- counter: m.OutdatedSessionsCounter,
+ if err != nil {
+ // assuming that the error was logged in great detail upstream
+ return nil, GroupwareInitializationError{Message: "failed to initialize the session cache", Err: err}
}
- jmapClient.AddSessionEventListener(&sessionEventListener)
+ jmapClient.AddSessionEventListener(sessionCache)
// A channel to process SSE Events with a single worker.
eventChannel := make(chan Event, eventChannelSize)
@@ -426,7 +384,7 @@ func (g *Groupware) listenForEvents() {
}
}
-func (g *Groupware) push(user User, typ string, body any) {
+func (g *Groupware) push(user user, typ string, body any) {
g.metrics.SSEEventsCounter.WithLabelValues(typ).Inc()
g.eventChannel <- Event{Type: typ, Stream: user.GetUsername(), Body: body}
}
@@ -467,16 +425,13 @@ func (g *Groupware) ServeSSE(w http.ResponseWriter, r *http.Request) {
}
// Provide a JMAP Session for the
-func (g *Groupware) session(user User, _ *http.Request, _ context.Context, _ *log.Logger) (jmap.Session, bool, *GroupwareError) {
- item := g.sessionCache.Get(toSessionKey(user.GetUsername()))
- if item != nil {
- value := item.Value()
- if value != nil {
- if value.Success() {
- return value.Get(), true, nil
- } else {
- return jmap.Session{}, false, value.Error()
- }
+func (g *Groupware) session(user user, _ *http.Request, _ context.Context, _ *log.Logger) (jmap.Session, bool, *GroupwareError) {
+ s := g.sessionCache.Get(user.GetUsername())
+ if s != nil {
+ if s.Success() {
+ return s.Get(), true, nil
+ } else {
+ return jmap.Session{}, false, s.Error()
}
}
return jmap.Session{}, false, nil
@@ -562,7 +517,7 @@ func (g *Groupware) withSession(w http.ResponseWriter, r *http.Request, handler
g.serveError(w, r, apiError(errorId, *gwerr))
return Response{}, false
}
- decoratedLogger := session.DecorateLogger(*logger)
+ decoratedLogger := decorateLogger(logger, session)
req := Request{
g: g,
@@ -666,7 +621,7 @@ func (g *Groupware) stream(w http.ResponseWriter, r *http.Request, handler func(
return
}
- decoratedLogger := session.DecorateLogger(*logger)
+ decoratedLogger := decorateLogger(logger, session)
req := Request{
r: r,
diff --git a/services/groupware/pkg/groupware/groupware_request.go b/services/groupware/pkg/groupware/groupware_request.go
index 06bdc0e402..cbe01b6e2c 100644
--- a/services/groupware/pkg/groupware/groupware_request.go
+++ b/services/groupware/pkg/groupware/groupware_request.go
@@ -27,7 +27,7 @@ import (
// the parameter list of every single handler function
type Request struct {
g *Groupware
- user User
+ user user
r *http.Request
ctx context.Context
logger *log.Logger
@@ -38,7 +38,7 @@ func (r Request) push(typ string, event any) {
r.g.push(r.user, typ, event)
}
-func (r Request) GetUser() User {
+func (r Request) GetUser() user {
return r.user
}
diff --git a/services/groupware/pkg/groupware/groupware_reva.go b/services/groupware/pkg/groupware/groupware_reva.go
index d9f0a50a47..4e2704f4ce 100644
--- a/services/groupware/pkg/groupware/groupware_reva.go
+++ b/services/groupware/pkg/groupware/groupware_reva.go
@@ -15,9 +15,9 @@ import (
type revaContextUserProvider struct {
}
-var _ UserProvider = revaContextUserProvider{}
+var _ userProvider = revaContextUserProvider{}
-func NewRevaContextUsernameProvider() UserProvider {
+func newRevaContextUsernameProvider() userProvider {
return revaContextUserProvider{}
}
@@ -27,26 +27,26 @@ var (
errUserNotInRevaContext = errors.New("failed to find user in reva context")
)
-func (r revaContextUserProvider) GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (User, error) {
+func (r revaContextUserProvider) GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (user, error) {
u, ok := revactx.ContextGetUser(ctx)
if !ok {
err := errUserNotInRevaContext
logger.Error().Err(err).Ctx(ctx).Msgf("could not get user: user not in reva context: %v", ctx)
return nil, err
}
- return RevaUser{user: u}, nil
+ return revaUser{user: u}, nil
}
-type RevaUser struct {
+type revaUser struct {
user *userv1beta1.User
}
-func (r RevaUser) GetUsername() string {
+func (r revaUser) GetUsername() string {
return r.user.GetUsername()
}
-func (r RevaUser) GetId() string {
+func (r revaUser) GetId() string {
return r.user.GetId().GetOpaqueId()
}
-var _ User = RevaUser{}
+var _ user = revaUser{}
diff --git a/services/groupware/pkg/groupware/groupware_session.go b/services/groupware/pkg/groupware/groupware_session.go
index e5b63035ff..a85748f55b 100644
--- a/services/groupware/pkg/groupware/groupware_session.go
+++ b/services/groupware/pkg/groupware/groupware_session.go
@@ -1,10 +1,13 @@
package groupware
import (
+ "context"
+ "fmt"
"net/url"
"time"
"github.com/jellydator/ttlcache/v3"
+ "github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"github.com/opencloud-eu/opencloud/pkg/jmap"
@@ -12,28 +15,41 @@ import (
"github.com/opencloud-eu/opencloud/services/groupware/pkg/metrics"
)
-type sessionKey string
+// An alias for the internal session cache key, which might become something composed in the future.
+type sessionCacheKey string
-func toSessionKey(username string) sessionKey {
- return sessionKey(username)
+func toSessionCacheKey(username string) sessionCacheKey {
+ return sessionCacheKey(username)
}
-func usernameFromSessionKey(key sessionKey) string {
- return string(key)
+func (k sessionCacheKey) username() string {
+ return string(k)
}
+// Interface for cached sessions in the session cache.
+// The purpose here is mainly to be able to also persist failed
+// attempts to retrieve a session.
type cachedSession interface {
+ // Whether the Session retrieval was successful or not.
Success() bool
+ // When Success() returns true, one may use this method to retrieve the actual JMAP Session.
Get() jmap.Session
+ // When Success() returns false, one may use this method to retrieve the error that caused the failure.
Error() *GroupwareError
+ // The timestamp of when this cached session information was obtained, regardless of success or failure.
Since() time.Time
}
+// An implementation of a cachedSession that succeeded.
type succeededSession struct {
- since time.Time
+ // Timestamp of when this succeededSession was created.
+ since time.Time
+ // The JMAP Session itself.
session jmap.Session
}
+var _ cachedSession = succeededSession{}
+
func (s succeededSession) Success() bool {
return true
}
@@ -47,18 +63,21 @@ func (s succeededSession) Since() time.Time {
return s.since
}
-var _ cachedSession = succeededSession{}
-
+// An implementation of a cachedSession that failed.
type failedSession struct {
+ // Timestamp of when this failedSession was created.
since time.Time
- err *GroupwareError
+ // The error that caused the Session acquisition to fail.
+ err *GroupwareError
}
+var _ cachedSession = failedSession{}
+
func (s failedSession) Success() bool {
return false
}
func (s failedSession) Get() jmap.Session {
- panic("this should never be called")
+ panic(fmt.Sprintf("never call %T.Get()", failedSession{}))
}
func (s failedSession) Error() *GroupwareError {
return s.err
@@ -67,23 +86,27 @@ func (s failedSession) Since() time.Time {
return s.since
}
-var _ cachedSession = failedSession{}
-
+// Implements the ttlcache.Loader interface, by loading JMAP Sessions for users
+// using the jmap.Client.
type sessionCacheLoader struct {
- logger *log.Logger
+ logger *log.Logger
+ // A minimalistic contract for supplying the JMAP Session URL for a given username.
sessionUrlProvider func(username string) (*url.URL, *GroupwareError)
- jmapClient *jmap.Client
- errorTtl time.Duration
+ // A minimalistic contract for supplying JMAP Sessions using various input parameters.
+ sessionSupplier func(sessionUrl *url.URL, username string, logger *log.Logger) (jmap.Session, jmap.Error)
+ errorTtl time.Duration
}
-func (l *sessionCacheLoader) Load(c *ttlcache.Cache[sessionKey, cachedSession], key sessionKey) *ttlcache.Item[sessionKey, cachedSession] {
- username := usernameFromSessionKey(key)
+var _ ttlcache.Loader[sessionCacheKey, cachedSession] = &sessionCacheLoader{}
+
+func (l *sessionCacheLoader) Load(c *ttlcache.Cache[sessionCacheKey, cachedSession], key sessionCacheKey) *ttlcache.Item[sessionCacheKey, cachedSession] {
+ username := key.username()
sessionUrl, gwerr := l.sessionUrlProvider(username)
if gwerr != nil {
l.logger.Warn().Str("username", username).Str("code", gwerr.Code).Msgf("failed to determine session URL for '%v'", key)
return c.Set(key, failedSession{since: time.Now(), err: gwerr}, l.errorTtl)
}
- session, jerr := l.jmapClient.FetchSession(sessionUrl, username, l.logger)
+ session, jerr := l.sessionSupplier(sessionUrl, username, l.logger)
if jerr != nil {
l.logger.Warn().Str("username", username).Err(jerr).Msgf("failed to create session for '%v'", key)
return c.Set(key, failedSession{since: time.Now(), err: groupwareErrorFromJmap(jerr)}, l.errorTtl)
@@ -93,28 +116,168 @@ func (l *sessionCacheLoader) Load(c *ttlcache.Cache[sessionKey, cachedSession],
}
}
-var _ ttlcache.Loader[sessionKey, cachedSession] = &sessionCacheLoader{}
-
-// Listens to JMAP Session outdated events, in order to remove outdated Sessions
-// from the Groupware Session cache.
-type sessionEventListener struct {
- logger *log.Logger
- sessionCache *ttlcache.Cache[sessionKey, cachedSession]
- counter prometheus.Counter
+type sessionCache interface {
+ Get(username string) cachedSession
+ jmap.SessionEventListener
}
-func (l sessionEventListener) OnSessionOutdated(session *jmap.Session, newSessionState jmap.SessionState) {
- // it's enough to remove the session from the cache, as it will be fetched on-demand
- // the next time an operation is performed on behalf of the user
- l.sessionCache.Delete(toSessionKey(session.Username))
- if l.counter != nil {
- l.counter.Inc()
+type ttlcacheSessionCache struct {
+ sessionCache *ttlcache.Cache[sessionCacheKey, cachedSession]
+ outdatedSessionCounter prometheus.Counter
+ logger *log.Logger
+}
+
+var _ sessionCache = &ttlcacheSessionCache{}
+var _ jmap.SessionEventListener = &ttlcacheSessionCache{}
+
+func (c *ttlcacheSessionCache) Get(username string) cachedSession {
+ item := c.sessionCache.Get(toSessionCacheKey(username))
+ if item != nil {
+ return item.Value()
+ } else {
+ return nil
+ }
+}
+
+type sessionCacheBuilder struct {
+ logger *log.Logger
+ sessionSupplier func(sessionUrl *url.URL, username string, logger *log.Logger) (jmap.Session, jmap.Error)
+ defaultUrlResolver func(string) (*url.URL, *GroupwareError)
+ sessionUrlResolverFactory func() (func(string) (*url.URL, *GroupwareError), *GroupwareInitializationError)
+ prometheusRegistry prometheus.Registerer
+ m *metrics.Metrics
+ sessionCacheMaxCapacity uint64
+ sessionCacheTtl time.Duration
+ sessionFailureCacheTtl time.Duration
+}
+
+func newSessionCacheBuilder(
+ sessionUrl *url.URL,
+ logger *log.Logger,
+ sessionSupplier func(sessionUrl *url.URL, username string, logger *log.Logger) (jmap.Session, jmap.Error),
+ prometheusRegistry prometheus.Registerer,
+ m *metrics.Metrics,
+ sessionCacheMaxCapacity uint64,
+ sessionCacheTtl time.Duration,
+ sessionFailureCacheTtl time.Duration,
+) *sessionCacheBuilder {
+ defaultUrlResolver := func(_ string) (*url.URL, *GroupwareError) {
+ return sessionUrl, nil
}
- l.logger.Trace().Msgf("removed outdated session for user '%v': state %v -> %v", session.Username, session.State, newSessionState)
+ return &sessionCacheBuilder{
+ logger: logger,
+ sessionSupplier: sessionSupplier,
+ defaultUrlResolver: defaultUrlResolver,
+ sessionUrlResolverFactory: func() (func(string) (*url.URL, *GroupwareError), *GroupwareInitializationError) {
+ return defaultUrlResolver, nil
+ },
+ prometheusRegistry: prometheusRegistry,
+ m: m,
+ sessionCacheMaxCapacity: sessionCacheMaxCapacity,
+ sessionCacheTtl: sessionCacheTtl,
+ sessionFailureCacheTtl: sessionFailureCacheTtl,
+ }
}
-var _ jmap.SessionEventListener = sessionEventListener{}
+func (b *sessionCacheBuilder) withDnsAutoDiscovery(
+ defaultSessionDomain string,
+ config *dns.ClientConfig,
+ dnsDialTimeout time.Duration,
+ dnsReadTimeout time.Duration,
+ domainGreenList []string,
+ domainRedList []string,
+) *sessionCacheBuilder {
+ dnsSessionUrlResolverFactory := func() (func(string) (*url.URL, *GroupwareError), *GroupwareInitializationError) {
+ d, err := NewDnsSessionUrlResolver(
+ b.defaultUrlResolver,
+ defaultSessionDomain,
+ config,
+ domainGreenList,
+ domainRedList,
+ dnsDialTimeout,
+ dnsReadTimeout,
+ )
+ if err != nil {
+ return nil, &GroupwareInitializationError{Message: "failed to instantiate the DNS session URL resolver", Err: err}
+ } else {
+ return d.Resolve, nil
+ }
+ }
+ b.sessionUrlResolverFactory = dnsSessionUrlResolverFactory
+ return b
+}
+
+func (b sessionCacheBuilder) build() (sessionCache, error) {
+ var cache *ttlcache.Cache[sessionCacheKey, cachedSession]
+
+ sessionUrlResolver, err := b.sessionUrlResolverFactory()
+ if err != nil {
+ return nil, err
+ }
+
+ sessionLoader := &sessionCacheLoader{
+ logger: b.logger,
+ sessionSupplier: b.sessionSupplier,
+ errorTtl: b.sessionFailureCacheTtl,
+ sessionUrlProvider: sessionUrlResolver,
+ }
+
+ cache = ttlcache.New(
+ ttlcache.WithCapacity[sessionCacheKey, cachedSession](b.sessionCacheMaxCapacity),
+ ttlcache.WithTTL[sessionCacheKey, cachedSession](b.sessionCacheTtl),
+ ttlcache.WithDisableTouchOnHit[sessionCacheKey, cachedSession](),
+ ttlcache.WithLoader(sessionLoader),
+ )
+
+ b.prometheusRegistry.Register(sessionCacheMetricsCollector{desc: b.m.SessionCacheDesc, supply: cache.Metrics})
+
+ cache.OnEviction(func(c context.Context, r ttlcache.EvictionReason, item *ttlcache.Item[sessionCacheKey, cachedSession]) {
+ if b.logger.Trace().Enabled() {
+ reason := ""
+ switch r {
+ case ttlcache.EvictionReasonDeleted:
+ reason = "deleted"
+ case ttlcache.EvictionReasonCapacityReached:
+ reason = "capacity reached"
+ case ttlcache.EvictionReasonExpired:
+ reason = fmt.Sprintf("expired after %v", item.TTL())
+ case ttlcache.EvictionReasonMaxCostExceeded:
+ reason = "max cost exceeded"
+ }
+ if reason == "" {
+ reason = fmt.Sprintf("unknown (%v)", r)
+ }
+ spentInCache := time.Since(item.Value().Since())
+ tipe := "successful"
+ if !item.Value().Success() {
+ tipe = "failed"
+ }
+ b.logger.Trace().Msgf("%s session cache eviction of user '%v' after %v: %v", tipe, item.Key(), spentInCache, reason)
+ }
+ })
+
+ s := &ttlcacheSessionCache{
+ sessionCache: cache,
+ logger: b.logger,
+ outdatedSessionCounter: b.m.OutdatedSessionsCounter,
+ }
+
+ go cache.Start()
+
+ return s, nil
+}
+
+func (c ttlcacheSessionCache) OnSessionOutdated(session *jmap.Session, newSessionState jmap.SessionState) {
+ // it's enough to remove the session from the cache, as it will be fetched on-demand
+ // the next time an operation is performed on behalf of the user
+ c.sessionCache.Delete(toSessionCacheKey(session.Username))
+ if c.outdatedSessionCounter != nil {
+ c.outdatedSessionCounter.Inc()
+ }
+
+ c.logger.Trace().Msgf("removed outdated session for user '%v': state %v -> %v", session.Username, session.State, newSessionState)
+}
// A Prometheus Collector for the Session cache metrics.
type sessionCacheMetricsCollector struct {
@@ -134,3 +297,10 @@ func (s sessionCacheMetricsCollector) Collect(ch chan<- prometheus.Metric) {
}
var _ prometheus.Collector = sessionCacheMetricsCollector{}
+
+// Create a new log.Logger that is decorated with fields containing information about the Session.
+func decorateLogger(l *log.Logger, session jmap.Session) *log.Logger {
+ return log.From(l.With().
+ Str(logUsername, log.SafeString(session.Username)).
+ Str(logSessionState, log.SafeString(string(session.State))))
+}