Add insert kr to cache

This commit is contained in:
Chun-Hung Tseng
2023-07-13 19:55:59 +02:00
parent d293198b51
commit 4924a926d6
7 changed files with 161 additions and 85 deletions

165
cache.go
View File

@@ -4,11 +4,13 @@ import (
"context"
"sync"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/henrybear327/go-proton-api"
)
type cacheEntry struct {
link *proton.Link
kr *crypto.KeyRing
}
type cache struct {
@@ -25,41 +27,92 @@ func newCache(disableCaching bool) *cache {
}
}
func (linkCache *cache) _getLink(linkID string) *proton.Link {
if linkCache.disableCaching {
func (cache *cache) _get(linkID string) *cacheEntry {
if cache.disableCaching {
return nil
}
linkCache.RLock()
defer linkCache.RUnlock()
cache.RLock()
defer cache.RUnlock()
if data, ok := linkCache.data[linkID]; ok && data.link != nil {
return data.link
if data, ok := cache.data[linkID]; ok {
return data
}
return nil
}
func (linkCache *cache) _insertLink(linkID string, link *proton.Link) {
if linkCache.disableCaching {
func (cache *cache) _insert(linkID string, link *proton.Link, kr *crypto.KeyRing) {
if cache.disableCaching {
return
}
linkCache.Lock()
defer linkCache.Unlock()
cache.Lock()
defer cache.Unlock()
linkCache.data[linkID] = &cacheEntry{
cache.data[linkID] = &cacheEntry{
link: link,
kr: kr,
}
}
func (protonDrive *ProtonDrive) getLink(ctx context.Context, linkID string) (*proton.Link, error) {
// attempt to get from cache first
if link := protonDrive.cache._getLink(linkID); link != nil {
// log.Println("From cache")
return link, nil
/* The original non-caching version, which resolves the keyring recursively */
func (protonDrive *ProtonDrive) _getLinkKRByID(ctx context.Context, linkID string) (*crypto.KeyRing, error) {
if linkID == "" {
// most likely someone requested root link's parent link, which happen to be ""
// return protonDrive.MainShareKR.Copy() // we need to return a deep copy since the keyring will be freed by the caller when it finishes using the keyring -> now we go through caching, so we won't clear kr
return protonDrive.MainShareKR, nil
}
link, err := protonDrive.getLink(ctx, linkID)
if err != nil {
return nil, err
}
return protonDrive._getLinkKR(ctx, link)
}
/* The original non-caching version, which resolves the keyring recursively */
func (protonDrive *ProtonDrive) _getLinkKR(ctx context.Context, link *proton.Link) (*crypto.KeyRing, error) {
if link.ParentLinkID == "" { // link is rootLink
nodeKR, err := link.GetKeyRing(protonDrive.MainShareKR, protonDrive.AddrKR)
if err != nil {
return nil, err
}
return nodeKR, nil
}
parentLink, err := protonDrive.getLink(ctx, link.ParentLinkID)
if err != nil {
return nil, err
}
// parentNodeKR is used to decrypt the current node's KR, as each node has its keyring, which can be decrypted by its parent
parentNodeKR, err := protonDrive._getLinkKR(ctx, parentLink)
if err != nil {
return nil, err
}
nodeKR, err := link.GetKeyRing(parentNodeKR, protonDrive.AddrKR)
if err != nil {
return nil, err
}
return nodeKR, nil
}
func (protonDrive *ProtonDrive) getLink(ctx context.Context, linkID string) (*proton.Link, error) {
if linkID == "" {
// this is only possible when doing rootLink's parent, which should be handled beforehand
return nil, ErrWrongUsageOfGetLink
}
// attempt to get from cache first
if data := protonDrive.cache._get(linkID); data != nil && data.link != nil {
return data.link, nil
}
// log.Println("Not from cache")
// no cached data, fetch
link, err := protonDrive.c.GetLink(ctx, protonDrive.MainShare.ShareID, linkID)
if err != nil {
@@ -67,7 +120,83 @@ func (protonDrive *ProtonDrive) getLink(ctx context.Context, linkID string) (*pr
}
// populate cache
protonDrive.cache._insertLink(linkID, &link)
protonDrive.cache._insert(linkID, &link, nil)
return &link, nil
}
func (protonDrive *ProtonDrive) getLinkKR(ctx context.Context, link *proton.Link) (*crypto.KeyRing, error) {
if protonDrive.cache.disableCaching {
return protonDrive._getLinkKR(ctx, link)
}
if link == nil {
return nil, ErrWrongUsageOfGetLinkKR
}
// attempt to get from cache first
if data := protonDrive.cache._get(link.LinkID); data != nil && data.link != nil {
if data.kr != nil {
return data.kr, nil
}
// decrypt keyring and cache it
parentNodeKR, err := protonDrive.getLinkKRByID(ctx, data.link.ParentLinkID)
if err != nil {
return nil, err
}
kr, err := data.link.GetKeyRing(parentNodeKR, protonDrive.AddrKR)
if err != nil {
return nil, err
}
data.kr = kr
return data.kr, nil
}
// no cached data, fetch
protonDrive.cache._insert(link.LinkID, link, nil)
return protonDrive.getLinkKR(ctx, link)
}
func (protonDrive *ProtonDrive) getLinkKRByID(ctx context.Context, linkID string) (*crypto.KeyRing, error) {
if protonDrive.cache.disableCaching {
return protonDrive._getLinkKRByID(ctx, linkID)
}
if linkID == "" {
return protonDrive.MainShareKR, nil
}
// attempt to get from cache first
if data := protonDrive.cache._get(linkID); data != nil && data.link != nil {
return protonDrive.getLinkKR(ctx, data.link)
}
// log.Println("Not from cache")
// no cached data, fetch
link, err := protonDrive.getLink(ctx, linkID)
if err != nil {
return nil, err
}
return protonDrive.getLinkKR(ctx, link)
}
// TODO: handle removal upon rmdir, mv, etc. cases
// func (protonDrive *ProtonDrive) clearCache() {
// if protonDrive.cache.disableCaching {
// return
// }
// protonDrive.cache.Lock()
// defer protonDrive.cache.Unlock()
// for _, entry := range protonDrive.cache.data {
// entry.kr.ClearPrivateParams()
// }
// protonDrive.cache.data = make(map[string]*cacheEntry)
// }

1
cache_test.go Normal file
View File

@@ -0,0 +1 @@
package proton_api_bridge

View File

@@ -17,4 +17,6 @@ var (
ErrDraftExists = errors.New("a draft exist - usually this means a file is being uploaded at another client, or, there was a failed upload attempt")
ErrCantFindActiveRevision = errors.New("can't find an active revision")
ErrCantFindDraftRevision = errors.New("can't find a draft revision")
ErrWrongUsageOfGetLinkKR = errors.New("internal error for GetLinkKR - nil passed in for link")
ErrWrongUsageOfGetLink = errors.New("internal error for getLink - empty linkID passed in")
)

View File

@@ -68,7 +68,7 @@ func (protonDrive *ProtonDrive) GetActiveRevisionWithAttrs(ctx context.Context,
return nil, nil, err
}
nodeKR, err := protonDrive.getNodeKR(ctx, link)
nodeKR, err := protonDrive.getLinkKR(ctx, link)
if err != nil {
return nil, nil, err
}
@@ -94,7 +94,7 @@ func (protonDrive *ProtonDrive) DownloadFile(ctx context.Context, link *proton.L
return nil, nil, ErrLinkTypeMustToBeFileType
}
parentNodeKR, err := protonDrive.getNodeKRByID(ctx, link.ParentLinkID)
parentNodeKR, err := protonDrive.getLinkKRByID(ctx, link.ParentLinkID)
if err != nil {
return nil, nil, err
}
@@ -218,7 +218,7 @@ func (protonDrive *ProtonDrive) handleRevisionConflict(ctx context.Context, link
}
func (protonDrive *ProtonDrive) createFileUploadDraft(ctx context.Context, parentLink *proton.Link, filename string, modTime time.Time, mimeType string) (string, string, *crypto.SessionKey, *crypto.KeyRing, error) {
parentNodeKR, err := protonDrive.getNodeKR(ctx, parentLink)
parentNodeKR, err := protonDrive.getLinkKR(ctx, parentLink)
if err != nil {
return "", "", nil, nil, err
}
@@ -336,7 +336,7 @@ func (protonDrive *ProtonDrive) createFileUploadDraft(ctx context.Context, paren
linkID = link.LinkID
// get original newSessionKey and newNodeKR
parentNodeKR, err = protonDrive.getNodeKRByID(ctx, link.ParentLinkID)
parentNodeKR, err = protonDrive.getLinkKRByID(ctx, link.ParentLinkID)
if err != nil {
return "", "", nil, nil, err
}

View File

@@ -32,16 +32,14 @@ func (protonDrive *ProtonDrive) ListDirectory(
}
if childrenLinks != nil {
folderParentKR, err := protonDrive.getNodeKRByID(ctx, folderLink.ParentLinkID)
folderParentKR, err := protonDrive.getLinkKRByID(ctx, folderLink.ParentLinkID)
if err != nil {
return nil, err
}
defer folderParentKR.ClearPrivateParams()
folderLinkKR, err := folderLink.GetKeyRing(folderParentKR, protonDrive.AddrKR)
if err != nil {
return nil, err
}
defer folderLinkKR.ClearPrivateParams()
for i := range childrenLinks {
if childrenLinks[i].State != proton.LinkStateActive {
@@ -142,7 +140,6 @@ func (protonDrive *ProtonDrive) ListDirectoriesRecursively(
if err != nil {
return err
}
defer linkKR.ClearPrivateParams()
for _, childLink := range childrenLinks {
err = protonDrive.ListDirectoriesRecursively(ctx, linkKR, &childLink, download, maxDepth, curDepth+1, excludeRoot, currentPath, paths)
@@ -167,7 +164,7 @@ func (protonDrive *ProtonDrive) CreateNewFolderByID(ctx context.Context, parentL
}
func (protonDrive *ProtonDrive) CreateNewFolder(ctx context.Context, parentLink *proton.Link, folderName string) (string, error) {
parentNodeKR, err := protonDrive.getNodeKR(ctx, parentLink)
parentNodeKR, err := protonDrive.getLinkKR(ctx, parentLink)
if err != nil {
return "", err
}
@@ -282,7 +279,7 @@ func (protonDrive *ProtonDrive) moveLink(ctx context.Context, srcLink *proton.Li
SignatureAddress: protonDrive.signatureAddress,
}
dstParentKR, err := protonDrive.getNodeKR(ctx, dstParentLink)
dstParentKR, err := protonDrive.getLinkKR(ctx, dstParentLink)
if err != nil {
return err
}
@@ -301,7 +298,7 @@ func (protonDrive *ProtonDrive) moveLink(ctx context.Context, srcLink *proton.Li
return err
}
srcParentKR, err := protonDrive.getNodeKRByID(ctx, srcLink.ParentLinkID)
srcParentKR, err := protonDrive.getLinkKRByID(ctx, srcLink.ParentLinkID)
if err != nil {
return err
}

View File

@@ -1,51 +0,0 @@
package proton_api_bridge
import (
"context"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/henrybear327/go-proton-api"
)
func (protonDrive *ProtonDrive) getNodeKRByID(ctx context.Context, linkID string) (*crypto.KeyRing, error) {
if linkID == "" {
// most likely someone requested parent link, which happen to be ""
return protonDrive.MainShareKR.Copy() // we need to return a deep copy since the keyring will be freed by the caller when it finishes using the keyring
}
link, err := protonDrive.getLink(ctx, linkID)
if err != nil {
return nil, err
}
return protonDrive.getNodeKR(ctx, link)
}
func (protonDrive *ProtonDrive) getNodeKR(ctx context.Context, link *proton.Link) (*crypto.KeyRing, error) {
if link.ParentLinkID == "" {
nodeKR, err := link.GetKeyRing(protonDrive.MainShareKR, protonDrive.AddrKR)
if err != nil {
return nil, err
}
return nodeKR, nil
}
parentLink, err := protonDrive.getLink(ctx, link.ParentLinkID)
if err != nil {
return nil, err
}
// parentNodeKR is used to decrypt the current node's KR, as each node has its keyring, which can be decrypted by its parent
parentNodeKR, err := protonDrive.getNodeKR(ctx, parentLink)
if err != nil {
return nil, err
}
nodeKR, err := link.GetKeyRing(parentNodeKR, protonDrive.AddrKR)
if err != nil {
return nil, err
}
return nodeKR, nil
}

View File

@@ -37,7 +37,7 @@ func (protonDrive *ProtonDrive) SearchByNameRecursivelyByID(ctx context.Context,
if folderLink.Type != proton.LinkTypeFolder {
return nil, ErrLinkTypeMustToBeFolderType
}
folderKeyRing, err := protonDrive.getNodeKRByID(ctx, folderLink.ParentLinkID)
folderKeyRing, err := protonDrive.getLinkKRByID(ctx, folderLink.ParentLinkID)
if err != nil {
return nil, err
}
@@ -55,7 +55,7 @@ func (protonDrive *ProtonDrive) SearchByNameRecursively(ctx context.Context, fol
if folderLink.Type != proton.LinkTypeFolder {
return nil, ErrLinkTypeMustToBeFolderType
}
folderKeyRing, err := protonDrive.getNodeKRByID(ctx, folderLink.ParentLinkID)
folderKeyRing, err := protonDrive.getLinkKRByID(ctx, folderLink.ParentLinkID)
if err != nil {
return nil, err
}
@@ -98,7 +98,6 @@ func (protonDrive *ProtonDrive) searchByNameRecursively(
if err != nil {
return nil, err
}
defer linkKR.ClearPrivateParams()
for _, childLink := range childrenLinks {
ret, err := protonDrive.searchByNameRecursively(ctx, linkKR, &childLink, targetName, linkType, listAllActiveOrDraftFiles)
@@ -150,7 +149,7 @@ func (protonDrive *ProtonDrive) SearchByNameInActiveFolder(
return nil, nil
}
parentNodeKR, err := protonDrive.getNodeKRByID(ctx, folderLink.ParentLinkID)
parentNodeKR, err := protonDrive.getLinkKRByID(ctx, folderLink.ParentLinkID)
if err != nil {
return nil, err
}
@@ -160,7 +159,6 @@ func (protonDrive *ProtonDrive) SearchByNameInActiveFolder(
if err != nil {
return nil, err
}
defer folderLinkKR.ClearPrivateParams()
childrenLinks, err := protonDrive.c.ListChildren(ctx, protonDrive.MainShare.ShareID, folderLink.LinkID, true)
if err != nil {