diff --git a/pkg/jmap/jmap_api.go b/pkg/jmap/jmap_api.go index d02276b236..e6f77997a5 100644 --- a/pkg/jmap/jmap_api.go +++ b/pkg/jmap/jmap_api.go @@ -23,15 +23,13 @@ type BlobClient interface { } const ( - logOperation = "operation" - logUsername = "username" - logMailboxId = "mailbox-id" - logFetchBodies = "fetch-bodies" - logOffset = "offset" - logLimit = "limit" - logDownloadUrl = "download-url" - logBlobId = "blob-id" - logUploadUrl = "download-url" - logSessionState = "session-state" - logSince = "since" + logOperation = "operation" + logMailboxId = "mailbox-id" + logFetchBodies = "fetch-bodies" + logOffset = "offset" + logLimit = "limit" + logDownloadUrl = "download-url" + logBlobId = "blob-id" + logUploadUrl = "download-url" + logSince = "since" ) diff --git a/pkg/jmap/jmap_client.go b/pkg/jmap/jmap_client.go index 1fad905a1e..f55db0661e 100644 --- a/pkg/jmap/jmap_client.go +++ b/pkg/jmap/jmap_client.go @@ -16,6 +16,8 @@ type Client struct { io.Closer } +var _ io.Closer = &Client{} + func (j *Client) Close() error { return j.api.Close() } diff --git a/pkg/jmap/jmap_http.go b/pkg/jmap/jmap_http.go index 56d1a1ade2..bdc633bd37 100644 --- a/pkg/jmap/jmap_http.go +++ b/pkg/jmap/jmap_http.go @@ -123,10 +123,11 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger //sessionUrl := baseurl.JoinPath(".well-known", "jmap") sessionUrlStr := sessionUrl.String() endpoint := endpointOf(sessionUrl) + logger = log.From(logger.With().Str(logEndpoint, endpoint)) req, err := http.NewRequest(http.MethodGet, sessionUrlStr, nil) if err != nil { - logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create GET request for %v", sessionUrl) + logger.Error().Err(err).Msgf("failed to create GET request for %v", sessionUrl) return SessionResponse{}, SimpleError{code: JmapErrorInvalidHttpRequest, err: err} } h.auth(username, logger, req) @@ -135,12 +136,12 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger res, err := h.client.Do(req) if err != nil { h.listener.OnFailedRequest(endpoint, err) - logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform GET %v", sessionUrl) + logger.Error().Err(err).Msgf("failed to perform GET %v", sessionUrl) return SessionResponse{}, SimpleError{code: JmapErrorInvalidHttpRequest, err: err} } if res.StatusCode < 200 || res.StatusCode > 299 { h.listener.OnFailedRequestWithStatus(endpoint, res.StatusCode) - logger.Error().Str(logEndpoint, endpoint).Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 200") + logger.Error().Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 200") return SessionResponse{}, SimpleError{code: JmapErrorServerResponse, err: fmt.Errorf("JMAP API response status is %v", res.Status)} } h.listener.OnSuccessfulRequest(endpoint, res.StatusCode) @@ -156,7 +157,7 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger body, err := io.ReadAll(res.Body) if err != nil { - logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to read response body") + logger.Error().Err(err).Msg("failed to read response body") h.listener.OnResponseBodyReadingError(endpoint, err) return SessionResponse{}, SimpleError{code: JmapErrorReadingResponseBody, err: err} } @@ -164,7 +165,7 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger var data SessionResponse err = json.Unmarshal(body, &data) if err != nil { - logger.Error().Str(logEndpoint, endpoint).Str(logHttpUrl, sessionUrlStr).Err(err).Msg("failed to decode JSON payload from .well-known/jmap response") + logger.Error().Str(logHttpUrl, sessionUrlStr).Err(err).Msg("failed to decode JSON payload from .well-known/jmap response") h.listener.OnResponseBodyUnmarshallingError(endpoint, err) return SessionResponse{}, SimpleError{code: JmapErrorDecodingResponseBody, err: err} } @@ -175,16 +176,17 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, session *Session, request Request) ([]byte, Error) { jmapUrl := session.JmapUrl.String() endpoint := session.JmapEndpoint + logger = log.From(logger.With().Str(logEndpoint, endpoint)) bodyBytes, err := json.Marshal(request) if err != nil { - logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to marshall JSON payload") + logger.Error().Err(err).Msg("failed to marshall JSON payload") return nil, SimpleError{code: JmapErrorEncodingRequestBody, err: err} } req, err := http.NewRequestWithContext(ctx, http.MethodPost, jmapUrl, bytes.NewBuffer(bodyBytes)) if err != nil { - logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create POST request for %v", jmapUrl) + logger.Error().Err(err).Msgf("failed to create POST request for %v", jmapUrl) return nil, SimpleError{code: JmapErrorCreatingRequest, err: err} } req.Header.Add("Content-Type", "application/json") @@ -194,7 +196,7 @@ func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, sessio res, err := h.client.Do(req) if err != nil { h.listener.OnFailedRequest(endpoint, err) - logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform POST %v", jmapUrl) + logger.Error().Err(err).Msgf("failed to perform POST %v", jmapUrl) return nil, SimpleError{code: JmapErrorSendingRequest, err: err} } if res.StatusCode < 200 || res.StatusCode > 299 { @@ -214,7 +216,7 @@ func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, sessio body, err := io.ReadAll(res.Body) if err != nil { - logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to read response body") + logger.Error().Err(err).Msg("failed to read response body") h.listener.OnResponseBodyReadingError(endpoint, err) return nil, SimpleError{code: JmapErrorServerResponse, err: err} } @@ -223,9 +225,11 @@ func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, sessio } func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, session *Session, uploadUrl string, endpoint string, contentType string, body io.Reader) (UploadedBlob, Error) { + logger = log.From(logger.With().Str(logEndpoint, endpoint)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl, body) if err != nil { - logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create POST request for %v", uploadUrl) + logger.Error().Err(err).Msgf("failed to create POST request for %v", uploadUrl) return UploadedBlob{}, SimpleError{code: JmapErrorCreatingRequest, err: err} } req.Header.Add("Content-Type", contentType) @@ -235,12 +239,12 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s res, err := h.client.Do(req) if err != nil { h.listener.OnFailedRequest(endpoint, err) - logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform POST %v", uploadUrl) + logger.Error().Err(err).Msgf("failed to perform POST %v", uploadUrl) return UploadedBlob{}, SimpleError{code: JmapErrorSendingRequest, err: err} } if res.StatusCode < 200 || res.StatusCode > 299 { h.listener.OnFailedRequestWithStatus(endpoint, res.StatusCode) - logger.Error().Str(logEndpoint, endpoint).Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx") + logger.Error().Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx") return UploadedBlob{}, SimpleError{code: JmapErrorServerResponse, err: err} } if res.Body != nil { @@ -255,7 +259,7 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s responseBody, err := io.ReadAll(res.Body) if err != nil { - logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to read response body") + logger.Error().Err(err).Msg("failed to read response body") h.listener.OnResponseBodyReadingError(endpoint, err) return UploadedBlob{}, SimpleError{code: JmapErrorServerResponse, err: err} } @@ -263,7 +267,7 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s var result UploadedBlob err = json.Unmarshal(responseBody, &result) if err != nil { - logger.Error().Str(logEndpoint, endpoint).Str(logHttpUrl, uploadUrl).Err(err).Msg("failed to decode JSON payload from the upload response") + logger.Error().Str(logHttpUrl, uploadUrl).Err(err).Msg("failed to decode JSON payload from the upload response") h.listener.OnResponseBodyUnmarshallingError(endpoint, err) return UploadedBlob{}, SimpleError{code: JmapErrorDecodingResponseBody, err: err} } @@ -272,9 +276,11 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s } func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger, session *Session, downloadUrl string, endpoint string) (*BlobDownload, Error) { + logger = log.From(logger.With().Str(logEndpoint, endpoint)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil) if err != nil { - logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create GET request for %v", downloadUrl) + logger.Error().Err(err).Msgf("failed to create GET request for %v", downloadUrl) return nil, SimpleError{code: JmapErrorCreatingRequest, err: err} } req.Header.Add("User-Agent", h.userAgent) @@ -283,7 +289,7 @@ func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger, res, err := h.client.Do(req) if err != nil { h.listener.OnFailedRequest(endpoint, err) - logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform GET %v", downloadUrl) + logger.Error().Err(err).Msgf("failed to perform GET %v", downloadUrl) return nil, SimpleError{code: JmapErrorSendingRequest, err: err} } if res.StatusCode == http.StatusNotFound { @@ -291,7 +297,7 @@ func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger, } if res.StatusCode < 200 || res.StatusCode > 299 { h.listener.OnFailedRequestWithStatus(endpoint, res.StatusCode) - logger.Error().Str(logEndpoint, endpoint).Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx") + logger.Error().Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx") return nil, SimpleError{code: JmapErrorServerResponse, err: err} } h.listener.OnSuccessfulRequest(endpoint, res.StatusCode) @@ -301,7 +307,7 @@ func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger, if sizeStr != "" { size, err = strconv.Atoi(sizeStr) if err != nil { - logger.Warn().Err(err).Str(logEndpoint, endpoint).Msgf("failed to parse Content-Length blob download response header value '%v'", sizeStr) + logger.Warn().Err(err).Msgf("failed to parse Content-Length blob download response header value '%v'", sizeStr) size = -1 } } diff --git a/pkg/jmap/jmap_session.go b/pkg/jmap/jmap_session.go index 3b2bd703ed..67a5ec4c14 100644 --- a/pkg/jmap/jmap_session.go +++ b/pkg/jmap/jmap_session.go @@ -4,8 +4,6 @@ import ( "errors" "fmt" "net/url" - - "github.com/opencloud-eu/opencloud/pkg/log" ) type SessionEventListener interface { @@ -101,14 +99,6 @@ func newSession(sessionResponse SessionResponse) (Session, Error) { }, nil } -// Create a new log.Logger that is decorated with fields containing information about the Session. -func (s Session) DecorateLogger(l log.Logger) *log.Logger { - return log.From(l.With(). - Str(logUsername, log.SafeString(s.Username)). - Str(logEndpoint, log.SafeString(s.JmapEndpoint)). - Str(logSessionState, log.SafeString(string(s.State)))) -} - func endpointOf(u *url.URL) string { if u != nil { return fmt.Sprintf("%s://%s", u.Scheme, u.Host) diff --git a/services/groupware/pkg/groupware/groupware_dns.go b/services/groupware/pkg/groupware/groupware_dns.go index 6ef89b677f..15ef97d1e2 100644 --- a/services/groupware/pkg/groupware/groupware_dns.go +++ b/services/groupware/pkg/groupware/groupware_dns.go @@ -17,17 +17,22 @@ var ( ) type DnsSessionUrlResolver struct { - defaultSessionUrl *url.URL - defaultDomain string - domainGreenList []string - domainRedList []string - config *dns.ClientConfig - client *dns.Client + defaultSessionUrlSupplier func(string) (*url.URL, *GroupwareError) + defaultDomain string + domainGreenList []string + domainRedList []string + config *dns.ClientConfig + client *dns.Client } -func NewDnsSessionUrlResolver(defaultSessionUrl *url.URL, defaultDomain string, - config *dns.ClientConfig, domainGreenList []string, domainRedList []string, - dialTimeout time.Duration, readTimeout time.Duration, +func NewDnsSessionUrlResolver( + defaultSessionUrlSupplier func(string) (*url.URL, *GroupwareError), + defaultDomain string, + config *dns.ClientConfig, + domainGreenList []string, + domainRedList []string, + dialTimeout time.Duration, + readTimeout time.Duration, ) (DnsSessionUrlResolver, error) { // TODO the whole udp or tcp dialier configuration, see https://github.com/miekg/exdns/blob/master/q/q.go @@ -37,10 +42,10 @@ func NewDnsSessionUrlResolver(defaultSessionUrl *url.URL, defaultDomain string, } return DnsSessionUrlResolver{ - defaultSessionUrl: defaultSessionUrl, - defaultDomain: defaultDomain, - config: config, - client: c, + defaultSessionUrlSupplier: defaultSessionUrlSupplier, + defaultDomain: defaultDomain, + config: config, + client: c, }, nil } @@ -75,7 +80,7 @@ func (d DnsSessionUrlResolver) Resolve(username string) (*url.URL, *GroupwareErr // nevertheless then? if d.defaultDomain == "" { // we don't, then let's fall back to the static session URL instead - return d.defaultSessionUrl, nil + return d.defaultSessionUrlSupplier(username) } } else { domain = parts[len(parts)-1] @@ -133,7 +138,7 @@ func (d DnsSessionUrlResolver) Resolve(username string) (*url.URL, *GroupwareErr } } - return d.defaultSessionUrl, nil + return d.defaultSessionUrlSupplier(username) } func (d DnsSessionUrlResolver) dnsQuery(c *dns.Client, msg *dns.Msg) (*dns.Msg, error) { diff --git a/services/groupware/pkg/groupware/groupware_framework.go b/services/groupware/pkg/groupware/groupware_framework.go index d4d00fd954..1baf84ca92 100644 --- a/services/groupware/pkg/groupware/groupware_framework.go +++ b/services/groupware/pkg/groupware/groupware_framework.go @@ -18,8 +18,6 @@ import ( "github.com/prometheus/client_golang/prometheus" - "github.com/jellydator/ttlcache/v3" - cmap "github.com/orcaman/concurrent-map" "github.com/opencloud-eu/opencloud/pkg/jmap" @@ -30,8 +28,9 @@ import ( ) const ( - logUsername = "username" // this should match jmap.logUsername to avoid having the field twice in the logs under different keys + logUsername = "username" logUserId = "user-id" + logSessionState = "session-state" logAccountId = "account-id" logErrorId = "error-id" logErrorCode = "code" @@ -51,17 +50,17 @@ const ( logMethod = "method" ) -// Minimalistic representation of a User, containing only the attributes that are +// Minimalistic representation of a user, containing only the attributes that are // necessary for the Groupware implementation. -type User interface { +type user interface { GetUsername() string GetId() string } // Provides a User that is associated with a request. -type UserProvider interface { +type userProvider interface { // Provide the user for JMAP operations. - GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (User, error) + GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (user, error) } // Background job that needs to be executed asynchronously by the Groupware. @@ -89,9 +88,9 @@ type Groupware struct { defaultEmailLimit uint maxBodyValueBytes uint // Caches successful and failed Sessions by the username. - sessionCache *ttlcache.Cache[sessionKey, cachedSession] + sessionCache sessionCache jmap *jmap.Client - userProvider UserProvider + userProvider userProvider // SSE events that need to be pushed to clients. eventChannel chan Event // Background jobs that need to be executed. @@ -127,10 +126,14 @@ type Event struct { Body any } +// A jmap.HttpJmapApiClientEventListener implementation that records those JMAP +// events as metric increments. type groupwareHttpJmapApiClientMetricsRecorder struct { m *metrics.Metrics } +var _ jmap.HttpJmapApiClientEventListener = groupwareHttpJmapApiClientMetricsRecorder{} + func (r groupwareHttpJmapApiClientMetricsRecorder) OnSuccessfulRequest(endpoint string, status int) { r.m.SuccessfulRequestPerEndpointCounter.With(metrics.Endpoint(endpoint)).Inc() } @@ -147,8 +150,6 @@ func (r groupwareHttpJmapApiClientMetricsRecorder) OnResponseBodyUnmarshallingEr r.m.ResponseBodyUnmarshallingErrorPerEndpointCounter.With(metrics.Endpoint(endpoint)).Inc() } -var _ jmap.HttpJmapApiClientEventListener = groupwareHttpJmapApiClientMetricsRecorder{} - func NewGroupware(config *config.Config, logger *log.Logger, mux *chi.Mux, prometheusRegistry prometheus.Registerer) (*Groupware, error) { baseUrl, err := url.Parse(config.Mail.BaseUrl) if err != nil { @@ -190,111 +191,68 @@ func NewGroupware(config *config.Config, logger *log.Logger, mux *chi.Mux, prome m := metrics.New(prometheusRegistry, logger) // TODO add timeouts and other meaningful configuration settings for the HTTP client - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.ResponseHeaderTimeout = responseHeaderTimeout + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.ResponseHeaderTimeout = responseHeaderTimeout if insecureTls { tlsConfig := &tls.Config{InsecureSkipVerify: true} // TODO make configurable - tr.TLSClientConfig = tlsConfig + httpTransport.TLSClientConfig = tlsConfig } - c := *http.DefaultClient - c.Transport = tr + httpClient := *http.DefaultClient + httpClient.Transport = httpTransport - userProvider := NewRevaContextUsernameProvider() + userProvider := newRevaContextUsernameProvider() jmapMetricsAdapter := groupwareHttpJmapApiClientMetricsRecorder{m: m} api := jmap.NewHttpJmapClient( - &c, + &httpClient, masterUsername, masterPassword, jmapMetricsAdapter, ) + // api implements all three interfaces: jmapClient := jmap.NewClient(api, api, api) - var sessionCache *ttlcache.Cache[sessionKey, cachedSession] - { - sessionUrlResolver := func(_ string) (*url.URL, *GroupwareError) { - return sessionUrl, nil - } - if useDnsForSessionResolution { - defaultSessionDomain := "example.com" // TODO default domain from configuration - // TODO resolv.conf or other configuration - conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") - if err != nil { - return nil, GroupwareInitializationError{Message: "failed to parse DNS client configuration from /etc/resolv.conf", Err: err} - } - - var domainGreenList []string = nil // TODO domain greenlist from configuration - var domainRedList []string = nil // TODO domain redlist from configuration - - dialTimeout := time.Duration(2) * time.Second // TODO configuration - readTimeout := time.Duration(2) * time.Second // TODO configuration - - dnsSessionUrlResolver, err := NewDnsSessionUrlResolver( - sessionUrl, - defaultSessionDomain, - conf, - domainGreenList, - domainRedList, - dialTimeout, - readTimeout, - ) - if err != nil { - return nil, GroupwareInitializationError{Message: "failed to instantiate the DNS session URL resolver", Err: err} - } - sessionUrlResolver = dnsSessionUrlResolver.Resolve + sessionCacheBuilder := newSessionCacheBuilder( + sessionUrl, + logger, + jmapClient.FetchSession, + prometheusRegistry, + m, + sessionCacheMaxCapacity, + sessionCacheTtl, + sessionFailureCacheTtl, + ) + if useDnsForSessionResolution { + conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") + if err != nil { + return nil, GroupwareInitializationError{Message: "failed to parse DNS client configuration from /etc/resolv.conf", Err: err} } - sessionLoader := &sessionCacheLoader{ - logger: logger, - jmapClient: &jmapClient, - errorTtl: sessionFailureCacheTtl, - sessionUrlProvider: sessionUrlResolver, - } + var dnsDomainGreenList []string = nil // TODO domain greenlist from configuration + var dnsDomainRedList []string = nil // TODO domain redlist from configuration + dnsDialTimeout := time.Duration(2) * time.Second // TODO DNS server connection timeout configuration + dnsReadTimeout := time.Duration(2) * time.Second // TODO DNS server response reading timeout configuration + defaultDomain := "example.com" // TODO default domain when the username is not an email address configuration - sessionCache = ttlcache.New( - ttlcache.WithCapacity[sessionKey, cachedSession](sessionCacheMaxCapacity), - ttlcache.WithTTL[sessionKey, cachedSession](sessionCacheTtl), - ttlcache.WithDisableTouchOnHit[sessionKey, cachedSession](), - ttlcache.WithLoader(sessionLoader), + sessionCacheBuilder = sessionCacheBuilder.withDnsAutoDiscovery( + defaultDomain, + conf, + dnsDialTimeout, + dnsReadTimeout, + dnsDomainGreenList, + dnsDomainRedList, ) - go sessionCache.Start() - - prometheusRegistry.Register(sessionCacheMetricsCollector{desc: m.SessionCacheDesc, supply: sessionCache.Metrics}) } - sessionCache.OnEviction(func(c context.Context, r ttlcache.EvictionReason, item *ttlcache.Item[sessionKey, cachedSession]) { - if logger.Trace().Enabled() { - reason := "" - switch r { - case ttlcache.EvictionReasonDeleted: - reason = "deleted" - case ttlcache.EvictionReasonCapacityReached: - reason = "capacity reached" - case ttlcache.EvictionReasonExpired: - reason = fmt.Sprintf("expired after %v", item.TTL()) - case ttlcache.EvictionReasonMaxCostExceeded: - reason = "max cost exceeded" - } - if reason == "" { - reason = fmt.Sprintf("unknown (%v)", r) - } - spentInCache := time.Since(item.Value().Since()) - tipe := "successful" - if !item.Value().Success() { - tipe = "failed" - } - logger.Trace().Msgf("%s session cache eviction of user '%v' after %v: %v", tipe, item.Key(), spentInCache, reason) - } - }) + sessionCache, err := sessionCacheBuilder.build() - sessionEventListener := sessionEventListener{ - sessionCache: sessionCache, - logger: logger, - counter: m.OutdatedSessionsCounter, + if err != nil { + // assuming that the error was logged in great detail upstream + return nil, GroupwareInitializationError{Message: "failed to initialize the session cache", Err: err} } - jmapClient.AddSessionEventListener(&sessionEventListener) + jmapClient.AddSessionEventListener(sessionCache) // A channel to process SSE Events with a single worker. eventChannel := make(chan Event, eventChannelSize) @@ -426,7 +384,7 @@ func (g *Groupware) listenForEvents() { } } -func (g *Groupware) push(user User, typ string, body any) { +func (g *Groupware) push(user user, typ string, body any) { g.metrics.SSEEventsCounter.WithLabelValues(typ).Inc() g.eventChannel <- Event{Type: typ, Stream: user.GetUsername(), Body: body} } @@ -467,16 +425,13 @@ func (g *Groupware) ServeSSE(w http.ResponseWriter, r *http.Request) { } // Provide a JMAP Session for the -func (g *Groupware) session(user User, _ *http.Request, _ context.Context, _ *log.Logger) (jmap.Session, bool, *GroupwareError) { - item := g.sessionCache.Get(toSessionKey(user.GetUsername())) - if item != nil { - value := item.Value() - if value != nil { - if value.Success() { - return value.Get(), true, nil - } else { - return jmap.Session{}, false, value.Error() - } +func (g *Groupware) session(user user, _ *http.Request, _ context.Context, _ *log.Logger) (jmap.Session, bool, *GroupwareError) { + s := g.sessionCache.Get(user.GetUsername()) + if s != nil { + if s.Success() { + return s.Get(), true, nil + } else { + return jmap.Session{}, false, s.Error() } } return jmap.Session{}, false, nil @@ -562,7 +517,7 @@ func (g *Groupware) withSession(w http.ResponseWriter, r *http.Request, handler g.serveError(w, r, apiError(errorId, *gwerr)) return Response{}, false } - decoratedLogger := session.DecorateLogger(*logger) + decoratedLogger := decorateLogger(logger, session) req := Request{ g: g, @@ -666,7 +621,7 @@ func (g *Groupware) stream(w http.ResponseWriter, r *http.Request, handler func( return } - decoratedLogger := session.DecorateLogger(*logger) + decoratedLogger := decorateLogger(logger, session) req := Request{ r: r, diff --git a/services/groupware/pkg/groupware/groupware_request.go b/services/groupware/pkg/groupware/groupware_request.go index 06bdc0e402..cbe01b6e2c 100644 --- a/services/groupware/pkg/groupware/groupware_request.go +++ b/services/groupware/pkg/groupware/groupware_request.go @@ -27,7 +27,7 @@ import ( // the parameter list of every single handler function type Request struct { g *Groupware - user User + user user r *http.Request ctx context.Context logger *log.Logger @@ -38,7 +38,7 @@ func (r Request) push(typ string, event any) { r.g.push(r.user, typ, event) } -func (r Request) GetUser() User { +func (r Request) GetUser() user { return r.user } diff --git a/services/groupware/pkg/groupware/groupware_reva.go b/services/groupware/pkg/groupware/groupware_reva.go index d9f0a50a47..4e2704f4ce 100644 --- a/services/groupware/pkg/groupware/groupware_reva.go +++ b/services/groupware/pkg/groupware/groupware_reva.go @@ -15,9 +15,9 @@ import ( type revaContextUserProvider struct { } -var _ UserProvider = revaContextUserProvider{} +var _ userProvider = revaContextUserProvider{} -func NewRevaContextUsernameProvider() UserProvider { +func newRevaContextUsernameProvider() userProvider { return revaContextUserProvider{} } @@ -27,26 +27,26 @@ var ( errUserNotInRevaContext = errors.New("failed to find user in reva context") ) -func (r revaContextUserProvider) GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (User, error) { +func (r revaContextUserProvider) GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (user, error) { u, ok := revactx.ContextGetUser(ctx) if !ok { err := errUserNotInRevaContext logger.Error().Err(err).Ctx(ctx).Msgf("could not get user: user not in reva context: %v", ctx) return nil, err } - return RevaUser{user: u}, nil + return revaUser{user: u}, nil } -type RevaUser struct { +type revaUser struct { user *userv1beta1.User } -func (r RevaUser) GetUsername() string { +func (r revaUser) GetUsername() string { return r.user.GetUsername() } -func (r RevaUser) GetId() string { +func (r revaUser) GetId() string { return r.user.GetId().GetOpaqueId() } -var _ User = RevaUser{} +var _ user = revaUser{} diff --git a/services/groupware/pkg/groupware/groupware_session.go b/services/groupware/pkg/groupware/groupware_session.go index e5b63035ff..a85748f55b 100644 --- a/services/groupware/pkg/groupware/groupware_session.go +++ b/services/groupware/pkg/groupware/groupware_session.go @@ -1,10 +1,13 @@ package groupware import ( + "context" + "fmt" "net/url" "time" "github.com/jellydator/ttlcache/v3" + "github.com/miekg/dns" "github.com/prometheus/client_golang/prometheus" "github.com/opencloud-eu/opencloud/pkg/jmap" @@ -12,28 +15,41 @@ import ( "github.com/opencloud-eu/opencloud/services/groupware/pkg/metrics" ) -type sessionKey string +// An alias for the internal session cache key, which might become something composed in the future. +type sessionCacheKey string -func toSessionKey(username string) sessionKey { - return sessionKey(username) +func toSessionCacheKey(username string) sessionCacheKey { + return sessionCacheKey(username) } -func usernameFromSessionKey(key sessionKey) string { - return string(key) +func (k sessionCacheKey) username() string { + return string(k) } +// Interface for cached sessions in the session cache. +// The purpose here is mainly to be able to also persist failed +// attempts to retrieve a session. type cachedSession interface { + // Whether the Session retrieval was successful or not. Success() bool + // When Success() returns true, one may use this method to retrieve the actual JMAP Session. Get() jmap.Session + // When Success() returns false, one may use this method to retrieve the error that caused the failure. Error() *GroupwareError + // The timestamp of when this cached session information was obtained, regardless of success or failure. Since() time.Time } +// An implementation of a cachedSession that succeeded. type succeededSession struct { - since time.Time + // Timestamp of when this succeededSession was created. + since time.Time + // The JMAP Session itself. session jmap.Session } +var _ cachedSession = succeededSession{} + func (s succeededSession) Success() bool { return true } @@ -47,18 +63,21 @@ func (s succeededSession) Since() time.Time { return s.since } -var _ cachedSession = succeededSession{} - +// An implementation of a cachedSession that failed. type failedSession struct { + // Timestamp of when this failedSession was created. since time.Time - err *GroupwareError + // The error that caused the Session acquisition to fail. + err *GroupwareError } +var _ cachedSession = failedSession{} + func (s failedSession) Success() bool { return false } func (s failedSession) Get() jmap.Session { - panic("this should never be called") + panic(fmt.Sprintf("never call %T.Get()", failedSession{})) } func (s failedSession) Error() *GroupwareError { return s.err @@ -67,23 +86,27 @@ func (s failedSession) Since() time.Time { return s.since } -var _ cachedSession = failedSession{} - +// Implements the ttlcache.Loader interface, by loading JMAP Sessions for users +// using the jmap.Client. type sessionCacheLoader struct { - logger *log.Logger + logger *log.Logger + // A minimalistic contract for supplying the JMAP Session URL for a given username. sessionUrlProvider func(username string) (*url.URL, *GroupwareError) - jmapClient *jmap.Client - errorTtl time.Duration + // A minimalistic contract for supplying JMAP Sessions using various input parameters. + sessionSupplier func(sessionUrl *url.URL, username string, logger *log.Logger) (jmap.Session, jmap.Error) + errorTtl time.Duration } -func (l *sessionCacheLoader) Load(c *ttlcache.Cache[sessionKey, cachedSession], key sessionKey) *ttlcache.Item[sessionKey, cachedSession] { - username := usernameFromSessionKey(key) +var _ ttlcache.Loader[sessionCacheKey, cachedSession] = &sessionCacheLoader{} + +func (l *sessionCacheLoader) Load(c *ttlcache.Cache[sessionCacheKey, cachedSession], key sessionCacheKey) *ttlcache.Item[sessionCacheKey, cachedSession] { + username := key.username() sessionUrl, gwerr := l.sessionUrlProvider(username) if gwerr != nil { l.logger.Warn().Str("username", username).Str("code", gwerr.Code).Msgf("failed to determine session URL for '%v'", key) return c.Set(key, failedSession{since: time.Now(), err: gwerr}, l.errorTtl) } - session, jerr := l.jmapClient.FetchSession(sessionUrl, username, l.logger) + session, jerr := l.sessionSupplier(sessionUrl, username, l.logger) if jerr != nil { l.logger.Warn().Str("username", username).Err(jerr).Msgf("failed to create session for '%v'", key) return c.Set(key, failedSession{since: time.Now(), err: groupwareErrorFromJmap(jerr)}, l.errorTtl) @@ -93,28 +116,168 @@ func (l *sessionCacheLoader) Load(c *ttlcache.Cache[sessionKey, cachedSession], } } -var _ ttlcache.Loader[sessionKey, cachedSession] = &sessionCacheLoader{} - -// Listens to JMAP Session outdated events, in order to remove outdated Sessions -// from the Groupware Session cache. -type sessionEventListener struct { - logger *log.Logger - sessionCache *ttlcache.Cache[sessionKey, cachedSession] - counter prometheus.Counter +type sessionCache interface { + Get(username string) cachedSession + jmap.SessionEventListener } -func (l sessionEventListener) OnSessionOutdated(session *jmap.Session, newSessionState jmap.SessionState) { - // it's enough to remove the session from the cache, as it will be fetched on-demand - // the next time an operation is performed on behalf of the user - l.sessionCache.Delete(toSessionKey(session.Username)) - if l.counter != nil { - l.counter.Inc() +type ttlcacheSessionCache struct { + sessionCache *ttlcache.Cache[sessionCacheKey, cachedSession] + outdatedSessionCounter prometheus.Counter + logger *log.Logger +} + +var _ sessionCache = &ttlcacheSessionCache{} +var _ jmap.SessionEventListener = &ttlcacheSessionCache{} + +func (c *ttlcacheSessionCache) Get(username string) cachedSession { + item := c.sessionCache.Get(toSessionCacheKey(username)) + if item != nil { + return item.Value() + } else { + return nil + } +} + +type sessionCacheBuilder struct { + logger *log.Logger + sessionSupplier func(sessionUrl *url.URL, username string, logger *log.Logger) (jmap.Session, jmap.Error) + defaultUrlResolver func(string) (*url.URL, *GroupwareError) + sessionUrlResolverFactory func() (func(string) (*url.URL, *GroupwareError), *GroupwareInitializationError) + prometheusRegistry prometheus.Registerer + m *metrics.Metrics + sessionCacheMaxCapacity uint64 + sessionCacheTtl time.Duration + sessionFailureCacheTtl time.Duration +} + +func newSessionCacheBuilder( + sessionUrl *url.URL, + logger *log.Logger, + sessionSupplier func(sessionUrl *url.URL, username string, logger *log.Logger) (jmap.Session, jmap.Error), + prometheusRegistry prometheus.Registerer, + m *metrics.Metrics, + sessionCacheMaxCapacity uint64, + sessionCacheTtl time.Duration, + sessionFailureCacheTtl time.Duration, +) *sessionCacheBuilder { + defaultUrlResolver := func(_ string) (*url.URL, *GroupwareError) { + return sessionUrl, nil } - l.logger.Trace().Msgf("removed outdated session for user '%v': state %v -> %v", session.Username, session.State, newSessionState) + return &sessionCacheBuilder{ + logger: logger, + sessionSupplier: sessionSupplier, + defaultUrlResolver: defaultUrlResolver, + sessionUrlResolverFactory: func() (func(string) (*url.URL, *GroupwareError), *GroupwareInitializationError) { + return defaultUrlResolver, nil + }, + prometheusRegistry: prometheusRegistry, + m: m, + sessionCacheMaxCapacity: sessionCacheMaxCapacity, + sessionCacheTtl: sessionCacheTtl, + sessionFailureCacheTtl: sessionFailureCacheTtl, + } } -var _ jmap.SessionEventListener = sessionEventListener{} +func (b *sessionCacheBuilder) withDnsAutoDiscovery( + defaultSessionDomain string, + config *dns.ClientConfig, + dnsDialTimeout time.Duration, + dnsReadTimeout time.Duration, + domainGreenList []string, + domainRedList []string, +) *sessionCacheBuilder { + dnsSessionUrlResolverFactory := func() (func(string) (*url.URL, *GroupwareError), *GroupwareInitializationError) { + d, err := NewDnsSessionUrlResolver( + b.defaultUrlResolver, + defaultSessionDomain, + config, + domainGreenList, + domainRedList, + dnsDialTimeout, + dnsReadTimeout, + ) + if err != nil { + return nil, &GroupwareInitializationError{Message: "failed to instantiate the DNS session URL resolver", Err: err} + } else { + return d.Resolve, nil + } + } + b.sessionUrlResolverFactory = dnsSessionUrlResolverFactory + return b +} + +func (b sessionCacheBuilder) build() (sessionCache, error) { + var cache *ttlcache.Cache[sessionCacheKey, cachedSession] + + sessionUrlResolver, err := b.sessionUrlResolverFactory() + if err != nil { + return nil, err + } + + sessionLoader := &sessionCacheLoader{ + logger: b.logger, + sessionSupplier: b.sessionSupplier, + errorTtl: b.sessionFailureCacheTtl, + sessionUrlProvider: sessionUrlResolver, + } + + cache = ttlcache.New( + ttlcache.WithCapacity[sessionCacheKey, cachedSession](b.sessionCacheMaxCapacity), + ttlcache.WithTTL[sessionCacheKey, cachedSession](b.sessionCacheTtl), + ttlcache.WithDisableTouchOnHit[sessionCacheKey, cachedSession](), + ttlcache.WithLoader(sessionLoader), + ) + + b.prometheusRegistry.Register(sessionCacheMetricsCollector{desc: b.m.SessionCacheDesc, supply: cache.Metrics}) + + cache.OnEviction(func(c context.Context, r ttlcache.EvictionReason, item *ttlcache.Item[sessionCacheKey, cachedSession]) { + if b.logger.Trace().Enabled() { + reason := "" + switch r { + case ttlcache.EvictionReasonDeleted: + reason = "deleted" + case ttlcache.EvictionReasonCapacityReached: + reason = "capacity reached" + case ttlcache.EvictionReasonExpired: + reason = fmt.Sprintf("expired after %v", item.TTL()) + case ttlcache.EvictionReasonMaxCostExceeded: + reason = "max cost exceeded" + } + if reason == "" { + reason = fmt.Sprintf("unknown (%v)", r) + } + spentInCache := time.Since(item.Value().Since()) + tipe := "successful" + if !item.Value().Success() { + tipe = "failed" + } + b.logger.Trace().Msgf("%s session cache eviction of user '%v' after %v: %v", tipe, item.Key(), spentInCache, reason) + } + }) + + s := &ttlcacheSessionCache{ + sessionCache: cache, + logger: b.logger, + outdatedSessionCounter: b.m.OutdatedSessionsCounter, + } + + go cache.Start() + + return s, nil +} + +func (c ttlcacheSessionCache) OnSessionOutdated(session *jmap.Session, newSessionState jmap.SessionState) { + // it's enough to remove the session from the cache, as it will be fetched on-demand + // the next time an operation is performed on behalf of the user + c.sessionCache.Delete(toSessionCacheKey(session.Username)) + if c.outdatedSessionCounter != nil { + c.outdatedSessionCounter.Inc() + } + + c.logger.Trace().Msgf("removed outdated session for user '%v': state %v -> %v", session.Username, session.State, newSessionState) +} // A Prometheus Collector for the Session cache metrics. type sessionCacheMetricsCollector struct { @@ -134,3 +297,10 @@ func (s sessionCacheMetricsCollector) Collect(ch chan<- prometheus.Metric) { } var _ prometheus.Collector = sessionCacheMetricsCollector{} + +// Create a new log.Logger that is decorated with fields containing information about the Session. +func decorateLogger(l *log.Logger, session jmap.Session) *log.Logger { + return log.From(l.With(). + Str(logUsername, log.SafeString(session.Username)). + Str(logSessionState, log.SafeString(string(session.State)))) +}