From 4924a926d6546a1bc102905fa147c8e5f414e772 Mon Sep 17 00:00:00 2001 From: Chun-Hung Tseng Date: Thu, 13 Jul 2023 19:55:59 +0200 Subject: [PATCH] Add insert kr to cache --- cache.go | 165 ++++++++++++++++++++++++++++++++++++++++++++------ cache_test.go | 1 + error.go | 2 + file.go | 8 +-- folder.go | 11 ++-- keyring.go | 51 ---------------- search.go | 8 +-- 7 files changed, 161 insertions(+), 85 deletions(-) create mode 100644 cache_test.go delete mode 100644 keyring.go diff --git a/cache.go b/cache.go index 89f4905..690a1c8 100644 --- a/cache.go +++ b/cache.go @@ -4,11 +4,13 @@ import ( "context" "sync" + "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/henrybear327/go-proton-api" ) type cacheEntry struct { link *proton.Link + kr *crypto.KeyRing } type cache struct { @@ -25,41 +27,92 @@ func newCache(disableCaching bool) *cache { } } -func (linkCache *cache) _getLink(linkID string) *proton.Link { - if linkCache.disableCaching { +func (cache *cache) _get(linkID string) *cacheEntry { + if cache.disableCaching { return nil } - linkCache.RLock() - defer linkCache.RUnlock() + cache.RLock() + defer cache.RUnlock() - if data, ok := linkCache.data[linkID]; ok && data.link != nil { - return data.link + if data, ok := cache.data[linkID]; ok { + return data } return nil } -func (linkCache *cache) _insertLink(linkID string, link *proton.Link) { - if linkCache.disableCaching { +func (cache *cache) _insert(linkID string, link *proton.Link, kr *crypto.KeyRing) { + if cache.disableCaching { return } - linkCache.Lock() - defer linkCache.Unlock() + cache.Lock() + defer cache.Unlock() - linkCache.data[linkID] = &cacheEntry{ + cache.data[linkID] = &cacheEntry{ link: link, + kr: kr, } } -func (protonDrive *ProtonDrive) getLink(ctx context.Context, linkID string) (*proton.Link, error) { - // attempt to get from cache first - if link := protonDrive.cache._getLink(linkID); link != nil { - // log.Println("From cache") - return link, nil +/* The original non-caching version, which resolves the keyring recursively */ +func (protonDrive *ProtonDrive) _getLinkKRByID(ctx context.Context, linkID string) (*crypto.KeyRing, error) { + if linkID == "" { + // most likely someone requested root link's parent link, which happen to be "" + // return protonDrive.MainShareKR.Copy() // we need to return a deep copy since the keyring will be freed by the caller when it finishes using the keyring -> now we go through caching, so we won't clear kr + + return protonDrive.MainShareKR, nil + } + + link, err := protonDrive.getLink(ctx, linkID) + if err != nil { + return nil, err + } + + return protonDrive._getLinkKR(ctx, link) +} + +/* The original non-caching version, which resolves the keyring recursively */ +func (protonDrive *ProtonDrive) _getLinkKR(ctx context.Context, link *proton.Link) (*crypto.KeyRing, error) { + if link.ParentLinkID == "" { // link is rootLink + nodeKR, err := link.GetKeyRing(protonDrive.MainShareKR, protonDrive.AddrKR) + if err != nil { + return nil, err + } + + return nodeKR, nil + } + + parentLink, err := protonDrive.getLink(ctx, link.ParentLinkID) + if err != nil { + return nil, err + } + + // parentNodeKR is used to decrypt the current node's KR, as each node has its keyring, which can be decrypted by its parent + parentNodeKR, err := protonDrive._getLinkKR(ctx, parentLink) + if err != nil { + return nil, err + } + + nodeKR, err := link.GetKeyRing(parentNodeKR, protonDrive.AddrKR) + if err != nil { + return nil, err + } + + return nodeKR, nil +} + +func (protonDrive *ProtonDrive) getLink(ctx context.Context, linkID string) (*proton.Link, error) { + if linkID == "" { + // this is only possible when doing rootLink's parent, which should be handled beforehand + return nil, ErrWrongUsageOfGetLink + } + + // attempt to get from cache first + if data := protonDrive.cache._get(linkID); data != nil && data.link != nil { + return data.link, nil } - // log.Println("Not from cache") // no cached data, fetch link, err := protonDrive.c.GetLink(ctx, protonDrive.MainShare.ShareID, linkID) if err != nil { @@ -67,7 +120,83 @@ func (protonDrive *ProtonDrive) getLink(ctx context.Context, linkID string) (*pr } // populate cache - protonDrive.cache._insertLink(linkID, &link) + protonDrive.cache._insert(linkID, &link, nil) return &link, nil } + +func (protonDrive *ProtonDrive) getLinkKR(ctx context.Context, link *proton.Link) (*crypto.KeyRing, error) { + if protonDrive.cache.disableCaching { + return protonDrive._getLinkKR(ctx, link) + } + + if link == nil { + return nil, ErrWrongUsageOfGetLinkKR + } + + // attempt to get from cache first + if data := protonDrive.cache._get(link.LinkID); data != nil && data.link != nil { + if data.kr != nil { + return data.kr, nil + } + + // decrypt keyring and cache it + parentNodeKR, err := protonDrive.getLinkKRByID(ctx, data.link.ParentLinkID) + if err != nil { + return nil, err + } + + kr, err := data.link.GetKeyRing(parentNodeKR, protonDrive.AddrKR) + if err != nil { + return nil, err + } + data.kr = kr + return data.kr, nil + } + + // no cached data, fetch + protonDrive.cache._insert(link.LinkID, link, nil) + + return protonDrive.getLinkKR(ctx, link) +} + +func (protonDrive *ProtonDrive) getLinkKRByID(ctx context.Context, linkID string) (*crypto.KeyRing, error) { + if protonDrive.cache.disableCaching { + return protonDrive._getLinkKRByID(ctx, linkID) + } + + if linkID == "" { + return protonDrive.MainShareKR, nil + } + + // attempt to get from cache first + if data := protonDrive.cache._get(linkID); data != nil && data.link != nil { + return protonDrive.getLinkKR(ctx, data.link) + } + + // log.Println("Not from cache") + // no cached data, fetch + link, err := protonDrive.getLink(ctx, linkID) + if err != nil { + return nil, err + } + + return protonDrive.getLinkKR(ctx, link) +} + +// TODO: handle removal upon rmdir, mv, etc. cases + +// func (protonDrive *ProtonDrive) clearCache() { +// if protonDrive.cache.disableCaching { +// return +// } + +// protonDrive.cache.Lock() +// defer protonDrive.cache.Unlock() + +// for _, entry := range protonDrive.cache.data { +// entry.kr.ClearPrivateParams() +// } + +// protonDrive.cache.data = make(map[string]*cacheEntry) +// } diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..d88881f --- /dev/null +++ b/cache_test.go @@ -0,0 +1 @@ +package proton_api_bridge diff --git a/error.go b/error.go index 44ee430..9e2bf6d 100644 --- a/error.go +++ b/error.go @@ -17,4 +17,6 @@ var ( ErrDraftExists = errors.New("a draft exist - usually this means a file is being uploaded at another client, or, there was a failed upload attempt") ErrCantFindActiveRevision = errors.New("can't find an active revision") ErrCantFindDraftRevision = errors.New("can't find a draft revision") + ErrWrongUsageOfGetLinkKR = errors.New("internal error for GetLinkKR - nil passed in for link") + ErrWrongUsageOfGetLink = errors.New("internal error for getLink - empty linkID passed in") ) diff --git a/file.go b/file.go index 5e818b8..6903a2f 100644 --- a/file.go +++ b/file.go @@ -68,7 +68,7 @@ func (protonDrive *ProtonDrive) GetActiveRevisionWithAttrs(ctx context.Context, return nil, nil, err } - nodeKR, err := protonDrive.getNodeKR(ctx, link) + nodeKR, err := protonDrive.getLinkKR(ctx, link) if err != nil { return nil, nil, err } @@ -94,7 +94,7 @@ func (protonDrive *ProtonDrive) DownloadFile(ctx context.Context, link *proton.L return nil, nil, ErrLinkTypeMustToBeFileType } - parentNodeKR, err := protonDrive.getNodeKRByID(ctx, link.ParentLinkID) + parentNodeKR, err := protonDrive.getLinkKRByID(ctx, link.ParentLinkID) if err != nil { return nil, nil, err } @@ -218,7 +218,7 @@ func (protonDrive *ProtonDrive) handleRevisionConflict(ctx context.Context, link } func (protonDrive *ProtonDrive) createFileUploadDraft(ctx context.Context, parentLink *proton.Link, filename string, modTime time.Time, mimeType string) (string, string, *crypto.SessionKey, *crypto.KeyRing, error) { - parentNodeKR, err := protonDrive.getNodeKR(ctx, parentLink) + parentNodeKR, err := protonDrive.getLinkKR(ctx, parentLink) if err != nil { return "", "", nil, nil, err } @@ -336,7 +336,7 @@ func (protonDrive *ProtonDrive) createFileUploadDraft(ctx context.Context, paren linkID = link.LinkID // get original newSessionKey and newNodeKR - parentNodeKR, err = protonDrive.getNodeKRByID(ctx, link.ParentLinkID) + parentNodeKR, err = protonDrive.getLinkKRByID(ctx, link.ParentLinkID) if err != nil { return "", "", nil, nil, err } diff --git a/folder.go b/folder.go index 5f3203f..ebccae6 100644 --- a/folder.go +++ b/folder.go @@ -32,16 +32,14 @@ func (protonDrive *ProtonDrive) ListDirectory( } if childrenLinks != nil { - folderParentKR, err := protonDrive.getNodeKRByID(ctx, folderLink.ParentLinkID) + folderParentKR, err := protonDrive.getLinkKRByID(ctx, folderLink.ParentLinkID) if err != nil { return nil, err } - defer folderParentKR.ClearPrivateParams() folderLinkKR, err := folderLink.GetKeyRing(folderParentKR, protonDrive.AddrKR) if err != nil { return nil, err } - defer folderLinkKR.ClearPrivateParams() for i := range childrenLinks { if childrenLinks[i].State != proton.LinkStateActive { @@ -142,7 +140,6 @@ func (protonDrive *ProtonDrive) ListDirectoriesRecursively( if err != nil { return err } - defer linkKR.ClearPrivateParams() for _, childLink := range childrenLinks { err = protonDrive.ListDirectoriesRecursively(ctx, linkKR, &childLink, download, maxDepth, curDepth+1, excludeRoot, currentPath, paths) @@ -167,7 +164,7 @@ func (protonDrive *ProtonDrive) CreateNewFolderByID(ctx context.Context, parentL } func (protonDrive *ProtonDrive) CreateNewFolder(ctx context.Context, parentLink *proton.Link, folderName string) (string, error) { - parentNodeKR, err := protonDrive.getNodeKR(ctx, parentLink) + parentNodeKR, err := protonDrive.getLinkKR(ctx, parentLink) if err != nil { return "", err } @@ -282,7 +279,7 @@ func (protonDrive *ProtonDrive) moveLink(ctx context.Context, srcLink *proton.Li SignatureAddress: protonDrive.signatureAddress, } - dstParentKR, err := protonDrive.getNodeKR(ctx, dstParentLink) + dstParentKR, err := protonDrive.getLinkKR(ctx, dstParentLink) if err != nil { return err } @@ -301,7 +298,7 @@ func (protonDrive *ProtonDrive) moveLink(ctx context.Context, srcLink *proton.Li return err } - srcParentKR, err := protonDrive.getNodeKRByID(ctx, srcLink.ParentLinkID) + srcParentKR, err := protonDrive.getLinkKRByID(ctx, srcLink.ParentLinkID) if err != nil { return err } diff --git a/keyring.go b/keyring.go deleted file mode 100644 index 9783ef0..0000000 --- a/keyring.go +++ /dev/null @@ -1,51 +0,0 @@ -package proton_api_bridge - -import ( - "context" - - "github.com/ProtonMail/gopenpgp/v2/crypto" - "github.com/henrybear327/go-proton-api" -) - -func (protonDrive *ProtonDrive) getNodeKRByID(ctx context.Context, linkID string) (*crypto.KeyRing, error) { - if linkID == "" { - // most likely someone requested parent link, which happen to be "" - return protonDrive.MainShareKR.Copy() // we need to return a deep copy since the keyring will be freed by the caller when it finishes using the keyring - } - - link, err := protonDrive.getLink(ctx, linkID) - if err != nil { - return nil, err - } - - return protonDrive.getNodeKR(ctx, link) -} - -func (protonDrive *ProtonDrive) getNodeKR(ctx context.Context, link *proton.Link) (*crypto.KeyRing, error) { - if link.ParentLinkID == "" { - nodeKR, err := link.GetKeyRing(protonDrive.MainShareKR, protonDrive.AddrKR) - if err != nil { - return nil, err - } - - return nodeKR, nil - } - - parentLink, err := protonDrive.getLink(ctx, link.ParentLinkID) - if err != nil { - return nil, err - } - - // parentNodeKR is used to decrypt the current node's KR, as each node has its keyring, which can be decrypted by its parent - parentNodeKR, err := protonDrive.getNodeKR(ctx, parentLink) - if err != nil { - return nil, err - } - - nodeKR, err := link.GetKeyRing(parentNodeKR, protonDrive.AddrKR) - if err != nil { - return nil, err - } - - return nodeKR, nil -} diff --git a/search.go b/search.go index faf54d7..7eb241d 100644 --- a/search.go +++ b/search.go @@ -37,7 +37,7 @@ func (protonDrive *ProtonDrive) SearchByNameRecursivelyByID(ctx context.Context, if folderLink.Type != proton.LinkTypeFolder { return nil, ErrLinkTypeMustToBeFolderType } - folderKeyRing, err := protonDrive.getNodeKRByID(ctx, folderLink.ParentLinkID) + folderKeyRing, err := protonDrive.getLinkKRByID(ctx, folderLink.ParentLinkID) if err != nil { return nil, err } @@ -55,7 +55,7 @@ func (protonDrive *ProtonDrive) SearchByNameRecursively(ctx context.Context, fol if folderLink.Type != proton.LinkTypeFolder { return nil, ErrLinkTypeMustToBeFolderType } - folderKeyRing, err := protonDrive.getNodeKRByID(ctx, folderLink.ParentLinkID) + folderKeyRing, err := protonDrive.getLinkKRByID(ctx, folderLink.ParentLinkID) if err != nil { return nil, err } @@ -98,7 +98,6 @@ func (protonDrive *ProtonDrive) searchByNameRecursively( if err != nil { return nil, err } - defer linkKR.ClearPrivateParams() for _, childLink := range childrenLinks { ret, err := protonDrive.searchByNameRecursively(ctx, linkKR, &childLink, targetName, linkType, listAllActiveOrDraftFiles) @@ -150,7 +149,7 @@ func (protonDrive *ProtonDrive) SearchByNameInActiveFolder( return nil, nil } - parentNodeKR, err := protonDrive.getNodeKRByID(ctx, folderLink.ParentLinkID) + parentNodeKR, err := protonDrive.getLinkKRByID(ctx, folderLink.ParentLinkID) if err != nil { return nil, err } @@ -160,7 +159,6 @@ func (protonDrive *ProtonDrive) SearchByNameInActiveFolder( if err != nil { return nil, err } - defer folderLinkKR.ClearPrivateParams() childrenLinks, err := protonDrive.c.ListChildren(ctx, protonDrive.MainShare.ShareID, folderLink.LinkID, true) if err != nil {