refactor(groupware): session cache and DNS autodiscovery

* move the logging of the username and session state away from pkg/jmap
   and into services/groupware

 * introduce more decoupling for the session cache, as well as moving
   the implementation into groupware_session.go
This commit is contained in:
Pascal Bleser
2025-09-04 22:16:44 +02:00
parent e63a7c4bc5
commit 3813d14cae
9 changed files with 330 additions and 204 deletions

View File

@@ -23,15 +23,13 @@ type BlobClient interface {
}
const (
logOperation = "operation"
logUsername = "username"
logMailboxId = "mailbox-id"
logFetchBodies = "fetch-bodies"
logOffset = "offset"
logLimit = "limit"
logDownloadUrl = "download-url"
logBlobId = "blob-id"
logUploadUrl = "download-url"
logSessionState = "session-state"
logSince = "since"
logOperation = "operation"
logMailboxId = "mailbox-id"
logFetchBodies = "fetch-bodies"
logOffset = "offset"
logLimit = "limit"
logDownloadUrl = "download-url"
logBlobId = "blob-id"
logUploadUrl = "download-url"
logSince = "since"
)

View File

@@ -16,6 +16,8 @@ type Client struct {
io.Closer
}
var _ io.Closer = &Client{}
func (j *Client) Close() error {
return j.api.Close()
}

View File

@@ -123,10 +123,11 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger
//sessionUrl := baseurl.JoinPath(".well-known", "jmap")
sessionUrlStr := sessionUrl.String()
endpoint := endpointOf(sessionUrl)
logger = log.From(logger.With().Str(logEndpoint, endpoint))
req, err := http.NewRequest(http.MethodGet, sessionUrlStr, nil)
if err != nil {
logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create GET request for %v", sessionUrl)
logger.Error().Err(err).Msgf("failed to create GET request for %v", sessionUrl)
return SessionResponse{}, SimpleError{code: JmapErrorInvalidHttpRequest, err: err}
}
h.auth(username, logger, req)
@@ -135,12 +136,12 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger
res, err := h.client.Do(req)
if err != nil {
h.listener.OnFailedRequest(endpoint, err)
logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform GET %v", sessionUrl)
logger.Error().Err(err).Msgf("failed to perform GET %v", sessionUrl)
return SessionResponse{}, SimpleError{code: JmapErrorInvalidHttpRequest, err: err}
}
if res.StatusCode < 200 || res.StatusCode > 299 {
h.listener.OnFailedRequestWithStatus(endpoint, res.StatusCode)
logger.Error().Str(logEndpoint, endpoint).Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 200")
logger.Error().Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 200")
return SessionResponse{}, SimpleError{code: JmapErrorServerResponse, err: fmt.Errorf("JMAP API response status is %v", res.Status)}
}
h.listener.OnSuccessfulRequest(endpoint, res.StatusCode)
@@ -156,7 +157,7 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger
body, err := io.ReadAll(res.Body)
if err != nil {
logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to read response body")
logger.Error().Err(err).Msg("failed to read response body")
h.listener.OnResponseBodyReadingError(endpoint, err)
return SessionResponse{}, SimpleError{code: JmapErrorReadingResponseBody, err: err}
}
@@ -164,7 +165,7 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger
var data SessionResponse
err = json.Unmarshal(body, &data)
if err != nil {
logger.Error().Str(logEndpoint, endpoint).Str(logHttpUrl, sessionUrlStr).Err(err).Msg("failed to decode JSON payload from .well-known/jmap response")
logger.Error().Str(logHttpUrl, sessionUrlStr).Err(err).Msg("failed to decode JSON payload from .well-known/jmap response")
h.listener.OnResponseBodyUnmarshallingError(endpoint, err)
return SessionResponse{}, SimpleError{code: JmapErrorDecodingResponseBody, err: err}
}
@@ -175,16 +176,17 @@ func (h *HttpJmapClient) GetSession(sessionUrl *url.URL, username string, logger
func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, session *Session, request Request) ([]byte, Error) {
jmapUrl := session.JmapUrl.String()
endpoint := session.JmapEndpoint
logger = log.From(logger.With().Str(logEndpoint, endpoint))
bodyBytes, err := json.Marshal(request)
if err != nil {
logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to marshall JSON payload")
logger.Error().Err(err).Msg("failed to marshall JSON payload")
return nil, SimpleError{code: JmapErrorEncodingRequestBody, err: err}
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, jmapUrl, bytes.NewBuffer(bodyBytes))
if err != nil {
logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create POST request for %v", jmapUrl)
logger.Error().Err(err).Msgf("failed to create POST request for %v", jmapUrl)
return nil, SimpleError{code: JmapErrorCreatingRequest, err: err}
}
req.Header.Add("Content-Type", "application/json")
@@ -194,7 +196,7 @@ func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, sessio
res, err := h.client.Do(req)
if err != nil {
h.listener.OnFailedRequest(endpoint, err)
logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform POST %v", jmapUrl)
logger.Error().Err(err).Msgf("failed to perform POST %v", jmapUrl)
return nil, SimpleError{code: JmapErrorSendingRequest, err: err}
}
if res.StatusCode < 200 || res.StatusCode > 299 {
@@ -214,7 +216,7 @@ func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, sessio
body, err := io.ReadAll(res.Body)
if err != nil {
logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to read response body")
logger.Error().Err(err).Msg("failed to read response body")
h.listener.OnResponseBodyReadingError(endpoint, err)
return nil, SimpleError{code: JmapErrorServerResponse, err: err}
}
@@ -223,9 +225,11 @@ func (h *HttpJmapClient) Command(ctx context.Context, logger *log.Logger, sessio
}
func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, session *Session, uploadUrl string, endpoint string, contentType string, body io.Reader) (UploadedBlob, Error) {
logger = log.From(logger.With().Str(logEndpoint, endpoint))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl, body)
if err != nil {
logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create POST request for %v", uploadUrl)
logger.Error().Err(err).Msgf("failed to create POST request for %v", uploadUrl)
return UploadedBlob{}, SimpleError{code: JmapErrorCreatingRequest, err: err}
}
req.Header.Add("Content-Type", contentType)
@@ -235,12 +239,12 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s
res, err := h.client.Do(req)
if err != nil {
h.listener.OnFailedRequest(endpoint, err)
logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform POST %v", uploadUrl)
logger.Error().Err(err).Msgf("failed to perform POST %v", uploadUrl)
return UploadedBlob{}, SimpleError{code: JmapErrorSendingRequest, err: err}
}
if res.StatusCode < 200 || res.StatusCode > 299 {
h.listener.OnFailedRequestWithStatus(endpoint, res.StatusCode)
logger.Error().Str(logEndpoint, endpoint).Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx")
logger.Error().Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx")
return UploadedBlob{}, SimpleError{code: JmapErrorServerResponse, err: err}
}
if res.Body != nil {
@@ -255,7 +259,7 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s
responseBody, err := io.ReadAll(res.Body)
if err != nil {
logger.Error().Err(err).Str(logEndpoint, endpoint).Msg("failed to read response body")
logger.Error().Err(err).Msg("failed to read response body")
h.listener.OnResponseBodyReadingError(endpoint, err)
return UploadedBlob{}, SimpleError{code: JmapErrorServerResponse, err: err}
}
@@ -263,7 +267,7 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s
var result UploadedBlob
err = json.Unmarshal(responseBody, &result)
if err != nil {
logger.Error().Str(logEndpoint, endpoint).Str(logHttpUrl, uploadUrl).Err(err).Msg("failed to decode JSON payload from the upload response")
logger.Error().Str(logHttpUrl, uploadUrl).Err(err).Msg("failed to decode JSON payload from the upload response")
h.listener.OnResponseBodyUnmarshallingError(endpoint, err)
return UploadedBlob{}, SimpleError{code: JmapErrorDecodingResponseBody, err: err}
}
@@ -272,9 +276,11 @@ func (h *HttpJmapClient) UploadBinary(ctx context.Context, logger *log.Logger, s
}
func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger, session *Session, downloadUrl string, endpoint string) (*BlobDownload, Error) {
logger = log.From(logger.With().Str(logEndpoint, endpoint))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)
if err != nil {
logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to create GET request for %v", downloadUrl)
logger.Error().Err(err).Msgf("failed to create GET request for %v", downloadUrl)
return nil, SimpleError{code: JmapErrorCreatingRequest, err: err}
}
req.Header.Add("User-Agent", h.userAgent)
@@ -283,7 +289,7 @@ func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger,
res, err := h.client.Do(req)
if err != nil {
h.listener.OnFailedRequest(endpoint, err)
logger.Error().Err(err).Str(logEndpoint, endpoint).Msgf("failed to perform GET %v", downloadUrl)
logger.Error().Err(err).Msgf("failed to perform GET %v", downloadUrl)
return nil, SimpleError{code: JmapErrorSendingRequest, err: err}
}
if res.StatusCode == http.StatusNotFound {
@@ -291,7 +297,7 @@ func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger,
}
if res.StatusCode < 200 || res.StatusCode > 299 {
h.listener.OnFailedRequestWithStatus(endpoint, res.StatusCode)
logger.Error().Str(logEndpoint, endpoint).Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx")
logger.Error().Str(logHttpStatus, res.Status).Int(logHttpStatusCode, res.StatusCode).Msg("HTTP response status code is not 2xx")
return nil, SimpleError{code: JmapErrorServerResponse, err: err}
}
h.listener.OnSuccessfulRequest(endpoint, res.StatusCode)
@@ -301,7 +307,7 @@ func (h *HttpJmapClient) DownloadBinary(ctx context.Context, logger *log.Logger,
if sizeStr != "" {
size, err = strconv.Atoi(sizeStr)
if err != nil {
logger.Warn().Err(err).Str(logEndpoint, endpoint).Msgf("failed to parse Content-Length blob download response header value '%v'", sizeStr)
logger.Warn().Err(err).Msgf("failed to parse Content-Length blob download response header value '%v'", sizeStr)
size = -1
}
}

View File

@@ -4,8 +4,6 @@ import (
"errors"
"fmt"
"net/url"
"github.com/opencloud-eu/opencloud/pkg/log"
)
type SessionEventListener interface {
@@ -101,14 +99,6 @@ func newSession(sessionResponse SessionResponse) (Session, Error) {
}, nil
}
// Create a new log.Logger that is decorated with fields containing information about the Session.
func (s Session) DecorateLogger(l log.Logger) *log.Logger {
return log.From(l.With().
Str(logUsername, log.SafeString(s.Username)).
Str(logEndpoint, log.SafeString(s.JmapEndpoint)).
Str(logSessionState, log.SafeString(string(s.State))))
}
func endpointOf(u *url.URL) string {
if u != nil {
return fmt.Sprintf("%s://%s", u.Scheme, u.Host)

View File

@@ -17,17 +17,22 @@ var (
)
type DnsSessionUrlResolver struct {
defaultSessionUrl *url.URL
defaultDomain string
domainGreenList []string
domainRedList []string
config *dns.ClientConfig
client *dns.Client
defaultSessionUrlSupplier func(string) (*url.URL, *GroupwareError)
defaultDomain string
domainGreenList []string
domainRedList []string
config *dns.ClientConfig
client *dns.Client
}
func NewDnsSessionUrlResolver(defaultSessionUrl *url.URL, defaultDomain string,
config *dns.ClientConfig, domainGreenList []string, domainRedList []string,
dialTimeout time.Duration, readTimeout time.Duration,
func NewDnsSessionUrlResolver(
defaultSessionUrlSupplier func(string) (*url.URL, *GroupwareError),
defaultDomain string,
config *dns.ClientConfig,
domainGreenList []string,
domainRedList []string,
dialTimeout time.Duration,
readTimeout time.Duration,
) (DnsSessionUrlResolver, error) {
// TODO the whole udp or tcp dialier configuration, see https://github.com/miekg/exdns/blob/master/q/q.go
@@ -37,10 +42,10 @@ func NewDnsSessionUrlResolver(defaultSessionUrl *url.URL, defaultDomain string,
}
return DnsSessionUrlResolver{
defaultSessionUrl: defaultSessionUrl,
defaultDomain: defaultDomain,
config: config,
client: c,
defaultSessionUrlSupplier: defaultSessionUrlSupplier,
defaultDomain: defaultDomain,
config: config,
client: c,
}, nil
}
@@ -75,7 +80,7 @@ func (d DnsSessionUrlResolver) Resolve(username string) (*url.URL, *GroupwareErr
// nevertheless then?
if d.defaultDomain == "" {
// we don't, then let's fall back to the static session URL instead
return d.defaultSessionUrl, nil
return d.defaultSessionUrlSupplier(username)
}
} else {
domain = parts[len(parts)-1]
@@ -133,7 +138,7 @@ func (d DnsSessionUrlResolver) Resolve(username string) (*url.URL, *GroupwareErr
}
}
return d.defaultSessionUrl, nil
return d.defaultSessionUrlSupplier(username)
}
func (d DnsSessionUrlResolver) dnsQuery(c *dns.Client, msg *dns.Msg) (*dns.Msg, error) {

View File

@@ -18,8 +18,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/jellydator/ttlcache/v3"
cmap "github.com/orcaman/concurrent-map"
"github.com/opencloud-eu/opencloud/pkg/jmap"
@@ -30,8 +28,9 @@ import (
)
const (
logUsername = "username" // this should match jmap.logUsername to avoid having the field twice in the logs under different keys
logUsername = "username"
logUserId = "user-id"
logSessionState = "session-state"
logAccountId = "account-id"
logErrorId = "error-id"
logErrorCode = "code"
@@ -51,17 +50,17 @@ const (
logMethod = "method"
)
// Minimalistic representation of a User, containing only the attributes that are
// Minimalistic representation of a user, containing only the attributes that are
// necessary for the Groupware implementation.
type User interface {
type user interface {
GetUsername() string
GetId() string
}
// Provides a User that is associated with a request.
type UserProvider interface {
type userProvider interface {
// Provide the user for JMAP operations.
GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (User, error)
GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (user, error)
}
// Background job that needs to be executed asynchronously by the Groupware.
@@ -89,9 +88,9 @@ type Groupware struct {
defaultEmailLimit uint
maxBodyValueBytes uint
// Caches successful and failed Sessions by the username.
sessionCache *ttlcache.Cache[sessionKey, cachedSession]
sessionCache sessionCache
jmap *jmap.Client
userProvider UserProvider
userProvider userProvider
// SSE events that need to be pushed to clients.
eventChannel chan Event
// Background jobs that need to be executed.
@@ -127,10 +126,14 @@ type Event struct {
Body any
}
// A jmap.HttpJmapApiClientEventListener implementation that records those JMAP
// events as metric increments.
type groupwareHttpJmapApiClientMetricsRecorder struct {
m *metrics.Metrics
}
var _ jmap.HttpJmapApiClientEventListener = groupwareHttpJmapApiClientMetricsRecorder{}
func (r groupwareHttpJmapApiClientMetricsRecorder) OnSuccessfulRequest(endpoint string, status int) {
r.m.SuccessfulRequestPerEndpointCounter.With(metrics.Endpoint(endpoint)).Inc()
}
@@ -147,8 +150,6 @@ func (r groupwareHttpJmapApiClientMetricsRecorder) OnResponseBodyUnmarshallingEr
r.m.ResponseBodyUnmarshallingErrorPerEndpointCounter.With(metrics.Endpoint(endpoint)).Inc()
}
var _ jmap.HttpJmapApiClientEventListener = groupwareHttpJmapApiClientMetricsRecorder{}
func NewGroupware(config *config.Config, logger *log.Logger, mux *chi.Mux, prometheusRegistry prometheus.Registerer) (*Groupware, error) {
baseUrl, err := url.Parse(config.Mail.BaseUrl)
if err != nil {
@@ -190,111 +191,68 @@ func NewGroupware(config *config.Config, logger *log.Logger, mux *chi.Mux, prome
m := metrics.New(prometheusRegistry, logger)
// TODO add timeouts and other meaningful configuration settings for the HTTP client
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.ResponseHeaderTimeout = responseHeaderTimeout
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.ResponseHeaderTimeout = responseHeaderTimeout
if insecureTls {
tlsConfig := &tls.Config{InsecureSkipVerify: true} // TODO make configurable
tr.TLSClientConfig = tlsConfig
httpTransport.TLSClientConfig = tlsConfig
}
c := *http.DefaultClient
c.Transport = tr
httpClient := *http.DefaultClient
httpClient.Transport = httpTransport
userProvider := NewRevaContextUsernameProvider()
userProvider := newRevaContextUsernameProvider()
jmapMetricsAdapter := groupwareHttpJmapApiClientMetricsRecorder{m: m}
api := jmap.NewHttpJmapClient(
&c,
&httpClient,
masterUsername,
masterPassword,
jmapMetricsAdapter,
)
// api implements all three interfaces:
jmapClient := jmap.NewClient(api, api, api)
var sessionCache *ttlcache.Cache[sessionKey, cachedSession]
{
sessionUrlResolver := func(_ string) (*url.URL, *GroupwareError) {
return sessionUrl, nil
}
if useDnsForSessionResolution {
defaultSessionDomain := "example.com" // TODO default domain from configuration
// TODO resolv.conf or other configuration
conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return nil, GroupwareInitializationError{Message: "failed to parse DNS client configuration from /etc/resolv.conf", Err: err}
}
var domainGreenList []string = nil // TODO domain greenlist from configuration
var domainRedList []string = nil // TODO domain redlist from configuration
dialTimeout := time.Duration(2) * time.Second // TODO configuration
readTimeout := time.Duration(2) * time.Second // TODO configuration
dnsSessionUrlResolver, err := NewDnsSessionUrlResolver(
sessionUrl,
defaultSessionDomain,
conf,
domainGreenList,
domainRedList,
dialTimeout,
readTimeout,
)
if err != nil {
return nil, GroupwareInitializationError{Message: "failed to instantiate the DNS session URL resolver", Err: err}
}
sessionUrlResolver = dnsSessionUrlResolver.Resolve
sessionCacheBuilder := newSessionCacheBuilder(
sessionUrl,
logger,
jmapClient.FetchSession,
prometheusRegistry,
m,
sessionCacheMaxCapacity,
sessionCacheTtl,
sessionFailureCacheTtl,
)
if useDnsForSessionResolution {
conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return nil, GroupwareInitializationError{Message: "failed to parse DNS client configuration from /etc/resolv.conf", Err: err}
}
sessionLoader := &sessionCacheLoader{
logger: logger,
jmapClient: &jmapClient,
errorTtl: sessionFailureCacheTtl,
sessionUrlProvider: sessionUrlResolver,
}
var dnsDomainGreenList []string = nil // TODO domain greenlist from configuration
var dnsDomainRedList []string = nil // TODO domain redlist from configuration
dnsDialTimeout := time.Duration(2) * time.Second // TODO DNS server connection timeout configuration
dnsReadTimeout := time.Duration(2) * time.Second // TODO DNS server response reading timeout configuration
defaultDomain := "example.com" // TODO default domain when the username is not an email address configuration
sessionCache = ttlcache.New(
ttlcache.WithCapacity[sessionKey, cachedSession](sessionCacheMaxCapacity),
ttlcache.WithTTL[sessionKey, cachedSession](sessionCacheTtl),
ttlcache.WithDisableTouchOnHit[sessionKey, cachedSession](),
ttlcache.WithLoader(sessionLoader),
sessionCacheBuilder = sessionCacheBuilder.withDnsAutoDiscovery(
defaultDomain,
conf,
dnsDialTimeout,
dnsReadTimeout,
dnsDomainGreenList,
dnsDomainRedList,
)
go sessionCache.Start()
prometheusRegistry.Register(sessionCacheMetricsCollector{desc: m.SessionCacheDesc, supply: sessionCache.Metrics})
}
sessionCache.OnEviction(func(c context.Context, r ttlcache.EvictionReason, item *ttlcache.Item[sessionKey, cachedSession]) {
if logger.Trace().Enabled() {
reason := ""
switch r {
case ttlcache.EvictionReasonDeleted:
reason = "deleted"
case ttlcache.EvictionReasonCapacityReached:
reason = "capacity reached"
case ttlcache.EvictionReasonExpired:
reason = fmt.Sprintf("expired after %v", item.TTL())
case ttlcache.EvictionReasonMaxCostExceeded:
reason = "max cost exceeded"
}
if reason == "" {
reason = fmt.Sprintf("unknown (%v)", r)
}
spentInCache := time.Since(item.Value().Since())
tipe := "successful"
if !item.Value().Success() {
tipe = "failed"
}
logger.Trace().Msgf("%s session cache eviction of user '%v' after %v: %v", tipe, item.Key(), spentInCache, reason)
}
})
sessionCache, err := sessionCacheBuilder.build()
sessionEventListener := sessionEventListener{
sessionCache: sessionCache,
logger: logger,
counter: m.OutdatedSessionsCounter,
if err != nil {
// assuming that the error was logged in great detail upstream
return nil, GroupwareInitializationError{Message: "failed to initialize the session cache", Err: err}
}
jmapClient.AddSessionEventListener(&sessionEventListener)
jmapClient.AddSessionEventListener(sessionCache)
// A channel to process SSE Events with a single worker.
eventChannel := make(chan Event, eventChannelSize)
@@ -426,7 +384,7 @@ func (g *Groupware) listenForEvents() {
}
}
func (g *Groupware) push(user User, typ string, body any) {
func (g *Groupware) push(user user, typ string, body any) {
g.metrics.SSEEventsCounter.WithLabelValues(typ).Inc()
g.eventChannel <- Event{Type: typ, Stream: user.GetUsername(), Body: body}
}
@@ -467,16 +425,13 @@ func (g *Groupware) ServeSSE(w http.ResponseWriter, r *http.Request) {
}
// Provide a JMAP Session for the
func (g *Groupware) session(user User, _ *http.Request, _ context.Context, _ *log.Logger) (jmap.Session, bool, *GroupwareError) {
item := g.sessionCache.Get(toSessionKey(user.GetUsername()))
if item != nil {
value := item.Value()
if value != nil {
if value.Success() {
return value.Get(), true, nil
} else {
return jmap.Session{}, false, value.Error()
}
func (g *Groupware) session(user user, _ *http.Request, _ context.Context, _ *log.Logger) (jmap.Session, bool, *GroupwareError) {
s := g.sessionCache.Get(user.GetUsername())
if s != nil {
if s.Success() {
return s.Get(), true, nil
} else {
return jmap.Session{}, false, s.Error()
}
}
return jmap.Session{}, false, nil
@@ -562,7 +517,7 @@ func (g *Groupware) withSession(w http.ResponseWriter, r *http.Request, handler
g.serveError(w, r, apiError(errorId, *gwerr))
return Response{}, false
}
decoratedLogger := session.DecorateLogger(*logger)
decoratedLogger := decorateLogger(logger, session)
req := Request{
g: g,
@@ -666,7 +621,7 @@ func (g *Groupware) stream(w http.ResponseWriter, r *http.Request, handler func(
return
}
decoratedLogger := session.DecorateLogger(*logger)
decoratedLogger := decorateLogger(logger, session)
req := Request{
r: r,

View File

@@ -27,7 +27,7 @@ import (
// the parameter list of every single handler function
type Request struct {
g *Groupware
user User
user user
r *http.Request
ctx context.Context
logger *log.Logger
@@ -38,7 +38,7 @@ func (r Request) push(typ string, event any) {
r.g.push(r.user, typ, event)
}
func (r Request) GetUser() User {
func (r Request) GetUser() user {
return r.user
}

View File

@@ -15,9 +15,9 @@ import (
type revaContextUserProvider struct {
}
var _ UserProvider = revaContextUserProvider{}
var _ userProvider = revaContextUserProvider{}
func NewRevaContextUsernameProvider() UserProvider {
func newRevaContextUsernameProvider() userProvider {
return revaContextUserProvider{}
}
@@ -27,26 +27,26 @@ var (
errUserNotInRevaContext = errors.New("failed to find user in reva context")
)
func (r revaContextUserProvider) GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (User, error) {
func (r revaContextUserProvider) GetUser(req *http.Request, ctx context.Context, logger *log.Logger) (user, error) {
u, ok := revactx.ContextGetUser(ctx)
if !ok {
err := errUserNotInRevaContext
logger.Error().Err(err).Ctx(ctx).Msgf("could not get user: user not in reva context: %v", ctx)
return nil, err
}
return RevaUser{user: u}, nil
return revaUser{user: u}, nil
}
type RevaUser struct {
type revaUser struct {
user *userv1beta1.User
}
func (r RevaUser) GetUsername() string {
func (r revaUser) GetUsername() string {
return r.user.GetUsername()
}
func (r RevaUser) GetId() string {
func (r revaUser) GetId() string {
return r.user.GetId().GetOpaqueId()
}
var _ User = RevaUser{}
var _ user = revaUser{}

View File

@@ -1,10 +1,13 @@
package groupware
import (
"context"
"fmt"
"net/url"
"time"
"github.com/jellydator/ttlcache/v3"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"github.com/opencloud-eu/opencloud/pkg/jmap"
@@ -12,28 +15,41 @@ import (
"github.com/opencloud-eu/opencloud/services/groupware/pkg/metrics"
)
type sessionKey string
// An alias for the internal session cache key, which might become something composed in the future.
type sessionCacheKey string
func toSessionKey(username string) sessionKey {
return sessionKey(username)
func toSessionCacheKey(username string) sessionCacheKey {
return sessionCacheKey(username)
}
func usernameFromSessionKey(key sessionKey) string {
return string(key)
func (k sessionCacheKey) username() string {
return string(k)
}
// Interface for cached sessions in the session cache.
// The purpose here is mainly to be able to also persist failed
// attempts to retrieve a session.
type cachedSession interface {
// Whether the Session retrieval was successful or not.
Success() bool
// When Success() returns true, one may use this method to retrieve the actual JMAP Session.
Get() jmap.Session
// When Success() returns false, one may use this method to retrieve the error that caused the failure.
Error() *GroupwareError
// The timestamp of when this cached session information was obtained, regardless of success or failure.
Since() time.Time
}
// An implementation of a cachedSession that succeeded.
type succeededSession struct {
since time.Time
// Timestamp of when this succeededSession was created.
since time.Time
// The JMAP Session itself.
session jmap.Session
}
var _ cachedSession = succeededSession{}
func (s succeededSession) Success() bool {
return true
}
@@ -47,18 +63,21 @@ func (s succeededSession) Since() time.Time {
return s.since
}
var _ cachedSession = succeededSession{}
// An implementation of a cachedSession that failed.
type failedSession struct {
// Timestamp of when this failedSession was created.
since time.Time
err *GroupwareError
// The error that caused the Session acquisition to fail.
err *GroupwareError
}
var _ cachedSession = failedSession{}
func (s failedSession) Success() bool {
return false
}
func (s failedSession) Get() jmap.Session {
panic("this should never be called")
panic(fmt.Sprintf("never call %T.Get()", failedSession{}))
}
func (s failedSession) Error() *GroupwareError {
return s.err
@@ -67,23 +86,27 @@ func (s failedSession) Since() time.Time {
return s.since
}
var _ cachedSession = failedSession{}
// Implements the ttlcache.Loader interface, by loading JMAP Sessions for users
// using the jmap.Client.
type sessionCacheLoader struct {
logger *log.Logger
logger *log.Logger
// A minimalistic contract for supplying the JMAP Session URL for a given username.
sessionUrlProvider func(username string) (*url.URL, *GroupwareError)
jmapClient *jmap.Client
errorTtl time.Duration
// A minimalistic contract for supplying JMAP Sessions using various input parameters.
sessionSupplier func(sessionUrl *url.URL, username string, logger *log.Logger) (jmap.Session, jmap.Error)
errorTtl time.Duration
}
func (l *sessionCacheLoader) Load(c *ttlcache.Cache[sessionKey, cachedSession], key sessionKey) *ttlcache.Item[sessionKey, cachedSession] {
username := usernameFromSessionKey(key)
var _ ttlcache.Loader[sessionCacheKey, cachedSession] = &sessionCacheLoader{}
func (l *sessionCacheLoader) Load(c *ttlcache.Cache[sessionCacheKey, cachedSession], key sessionCacheKey) *ttlcache.Item[sessionCacheKey, cachedSession] {
username := key.username()
sessionUrl, gwerr := l.sessionUrlProvider(username)
if gwerr != nil {
l.logger.Warn().Str("username", username).Str("code", gwerr.Code).Msgf("failed to determine session URL for '%v'", key)
return c.Set(key, failedSession{since: time.Now(), err: gwerr}, l.errorTtl)
}
session, jerr := l.jmapClient.FetchSession(sessionUrl, username, l.logger)
session, jerr := l.sessionSupplier(sessionUrl, username, l.logger)
if jerr != nil {
l.logger.Warn().Str("username", username).Err(jerr).Msgf("failed to create session for '%v'", key)
return c.Set(key, failedSession{since: time.Now(), err: groupwareErrorFromJmap(jerr)}, l.errorTtl)
@@ -93,28 +116,168 @@ func (l *sessionCacheLoader) Load(c *ttlcache.Cache[sessionKey, cachedSession],
}
}
var _ ttlcache.Loader[sessionKey, cachedSession] = &sessionCacheLoader{}
// Listens to JMAP Session outdated events, in order to remove outdated Sessions
// from the Groupware Session cache.
type sessionEventListener struct {
logger *log.Logger
sessionCache *ttlcache.Cache[sessionKey, cachedSession]
counter prometheus.Counter
type sessionCache interface {
Get(username string) cachedSession
jmap.SessionEventListener
}
func (l sessionEventListener) OnSessionOutdated(session *jmap.Session, newSessionState jmap.SessionState) {
// it's enough to remove the session from the cache, as it will be fetched on-demand
// the next time an operation is performed on behalf of the user
l.sessionCache.Delete(toSessionKey(session.Username))
if l.counter != nil {
l.counter.Inc()
type ttlcacheSessionCache struct {
sessionCache *ttlcache.Cache[sessionCacheKey, cachedSession]
outdatedSessionCounter prometheus.Counter
logger *log.Logger
}
var _ sessionCache = &ttlcacheSessionCache{}
var _ jmap.SessionEventListener = &ttlcacheSessionCache{}
func (c *ttlcacheSessionCache) Get(username string) cachedSession {
item := c.sessionCache.Get(toSessionCacheKey(username))
if item != nil {
return item.Value()
} else {
return nil
}
}
type sessionCacheBuilder struct {
logger *log.Logger
sessionSupplier func(sessionUrl *url.URL, username string, logger *log.Logger) (jmap.Session, jmap.Error)
defaultUrlResolver func(string) (*url.URL, *GroupwareError)
sessionUrlResolverFactory func() (func(string) (*url.URL, *GroupwareError), *GroupwareInitializationError)
prometheusRegistry prometheus.Registerer
m *metrics.Metrics
sessionCacheMaxCapacity uint64
sessionCacheTtl time.Duration
sessionFailureCacheTtl time.Duration
}
func newSessionCacheBuilder(
sessionUrl *url.URL,
logger *log.Logger,
sessionSupplier func(sessionUrl *url.URL, username string, logger *log.Logger) (jmap.Session, jmap.Error),
prometheusRegistry prometheus.Registerer,
m *metrics.Metrics,
sessionCacheMaxCapacity uint64,
sessionCacheTtl time.Duration,
sessionFailureCacheTtl time.Duration,
) *sessionCacheBuilder {
defaultUrlResolver := func(_ string) (*url.URL, *GroupwareError) {
return sessionUrl, nil
}
l.logger.Trace().Msgf("removed outdated session for user '%v': state %v -> %v", session.Username, session.State, newSessionState)
return &sessionCacheBuilder{
logger: logger,
sessionSupplier: sessionSupplier,
defaultUrlResolver: defaultUrlResolver,
sessionUrlResolverFactory: func() (func(string) (*url.URL, *GroupwareError), *GroupwareInitializationError) {
return defaultUrlResolver, nil
},
prometheusRegistry: prometheusRegistry,
m: m,
sessionCacheMaxCapacity: sessionCacheMaxCapacity,
sessionCacheTtl: sessionCacheTtl,
sessionFailureCacheTtl: sessionFailureCacheTtl,
}
}
var _ jmap.SessionEventListener = sessionEventListener{}
func (b *sessionCacheBuilder) withDnsAutoDiscovery(
defaultSessionDomain string,
config *dns.ClientConfig,
dnsDialTimeout time.Duration,
dnsReadTimeout time.Duration,
domainGreenList []string,
domainRedList []string,
) *sessionCacheBuilder {
dnsSessionUrlResolverFactory := func() (func(string) (*url.URL, *GroupwareError), *GroupwareInitializationError) {
d, err := NewDnsSessionUrlResolver(
b.defaultUrlResolver,
defaultSessionDomain,
config,
domainGreenList,
domainRedList,
dnsDialTimeout,
dnsReadTimeout,
)
if err != nil {
return nil, &GroupwareInitializationError{Message: "failed to instantiate the DNS session URL resolver", Err: err}
} else {
return d.Resolve, nil
}
}
b.sessionUrlResolverFactory = dnsSessionUrlResolverFactory
return b
}
func (b sessionCacheBuilder) build() (sessionCache, error) {
var cache *ttlcache.Cache[sessionCacheKey, cachedSession]
sessionUrlResolver, err := b.sessionUrlResolverFactory()
if err != nil {
return nil, err
}
sessionLoader := &sessionCacheLoader{
logger: b.logger,
sessionSupplier: b.sessionSupplier,
errorTtl: b.sessionFailureCacheTtl,
sessionUrlProvider: sessionUrlResolver,
}
cache = ttlcache.New(
ttlcache.WithCapacity[sessionCacheKey, cachedSession](b.sessionCacheMaxCapacity),
ttlcache.WithTTL[sessionCacheKey, cachedSession](b.sessionCacheTtl),
ttlcache.WithDisableTouchOnHit[sessionCacheKey, cachedSession](),
ttlcache.WithLoader(sessionLoader),
)
b.prometheusRegistry.Register(sessionCacheMetricsCollector{desc: b.m.SessionCacheDesc, supply: cache.Metrics})
cache.OnEviction(func(c context.Context, r ttlcache.EvictionReason, item *ttlcache.Item[sessionCacheKey, cachedSession]) {
if b.logger.Trace().Enabled() {
reason := ""
switch r {
case ttlcache.EvictionReasonDeleted:
reason = "deleted"
case ttlcache.EvictionReasonCapacityReached:
reason = "capacity reached"
case ttlcache.EvictionReasonExpired:
reason = fmt.Sprintf("expired after %v", item.TTL())
case ttlcache.EvictionReasonMaxCostExceeded:
reason = "max cost exceeded"
}
if reason == "" {
reason = fmt.Sprintf("unknown (%v)", r)
}
spentInCache := time.Since(item.Value().Since())
tipe := "successful"
if !item.Value().Success() {
tipe = "failed"
}
b.logger.Trace().Msgf("%s session cache eviction of user '%v' after %v: %v", tipe, item.Key(), spentInCache, reason)
}
})
s := &ttlcacheSessionCache{
sessionCache: cache,
logger: b.logger,
outdatedSessionCounter: b.m.OutdatedSessionsCounter,
}
go cache.Start()
return s, nil
}
func (c ttlcacheSessionCache) OnSessionOutdated(session *jmap.Session, newSessionState jmap.SessionState) {
// it's enough to remove the session from the cache, as it will be fetched on-demand
// the next time an operation is performed on behalf of the user
c.sessionCache.Delete(toSessionCacheKey(session.Username))
if c.outdatedSessionCounter != nil {
c.outdatedSessionCounter.Inc()
}
c.logger.Trace().Msgf("removed outdated session for user '%v': state %v -> %v", session.Username, session.State, newSessionState)
}
// A Prometheus Collector for the Session cache metrics.
type sessionCacheMetricsCollector struct {
@@ -134,3 +297,10 @@ func (s sessionCacheMetricsCollector) Collect(ch chan<- prometheus.Metric) {
}
var _ prometheus.Collector = sessionCacheMetricsCollector{}
// Create a new log.Logger that is decorated with fields containing information about the Session.
func decorateLogger(l *log.Logger, session jmap.Session) *log.Logger {
return log.From(l.With().
Str(logUsername, log.SafeString(session.Username)).
Str(logSessionState, log.SafeString(string(session.State))))
}