From ad20ba8bd02a0dfd5045c47963532cf4349edddc Mon Sep 17 00:00:00 2001 From: Chun-Hung Tseng Date: Sat, 15 Jul 2023 17:03:53 +0200 Subject: [PATCH] Refactor file download to return file ReadCloser, and also support on-demand block decryption --- README.md | 3 +- constants.go | 5 +- drive_test_helper.go | 10 +++- file.go | 109 ++++++++++++++++++++++++++++++++++--------- folder.go | 7 ++- 5 files changed, 107 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 9426a4c..54086b5 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ We are using a fork of the [proton-go-api](https://github.com/henrybear327/go-pr # Drive APIs > In collaboration with Azimjon Pulatov, in memory of our good old days at Meta, London, in the summer of 2022. +> Thanks to Anson Chen for the motivation and some initial help on various matters! Currently, the development are split into 2 versions. V1 supports the features [required by rclone](https://github.com/henrybear327/rclone/blob/master/fs/types.go), such as `file listing`. As the unit and integration tests from rclone have all been passed, we would stabilize this and then move onto developing V2. @@ -49,6 +50,7 @@ V2 will bring in optimizations and enhancements, such as optimizing uploading an - [x] File actions - [x] Download - [x] Download empty file + - [x] Improve large file download handling - [x] Properly handle large files and empty files (check iOS codebase) - esp. large files, where buffering in-memory will screw up the runtime - [x] Check signature and hash @@ -117,7 +119,6 @@ V2 will bring in optimizations and enhancements, such as optimizing uploading an - [ ] Figure out the bottleneck by doing some profiling - [ ] File - [ ] Parallel download / upload -> enc/dec is expensive - - [ ] Improve large file download handling - [ ] [Filename encoding](https://github.com/ProtonMail/WebClients/blob/b4eba99d241af4fdae06ff7138bd651a40ef5d3c/applications/drive/src/app/store/_links/validation.ts#L51) - [ ] Commit back to proton-go-api and switch to using upstream (make sure the tag is at the tip though) - [ ] Support legacy 2-password mode diff --git a/constants.go b/constants.go index c298f5a..c304e5d 100644 --- a/constants.go +++ b/constants.go @@ -1,6 +1,7 @@ package proton_api_bridge var ( - UPLOAD_BLOCK_SIZE = 4 * 1024 * 1024 // 4 MB - UPLOAD_BATCH_BLOCK_SIZE = 8 + UPLOAD_BLOCK_SIZE = 4 * 1024 * 1024 // 4 MB + UPLOAD_BATCH_BLOCK_SIZE = 8 + DOWNLOAD_BATCH_BLOCK_SIZE = 8 ) diff --git a/drive_test_helper.go b/drive_test_helper.go index a5251c7..c91ab6f 100644 --- a/drive_test_helper.go +++ b/drive_test_helper.go @@ -211,7 +211,12 @@ func downloadFile(t *testing.T, ctx context.Context, protonDrive *ProtonDrive, p if targetFileLink == nil { t.Fatalf("File %v not found", name) } else { - downloadedData, fileSystemAttr, err := protonDrive.DownloadFileByID(ctx, targetFileLink.LinkID) + reader, sizeOnServer, fileSystemAttr, err := protonDrive.DownloadFileByID(ctx, targetFileLink.LinkID) + if err != nil { + t.Fatal(err) + } + + downloadedData, err := io.ReadAll(reader) if err != nil { t.Fatal(err) } @@ -220,6 +225,9 @@ func downloadFile(t *testing.T, ctx context.Context, protonDrive *ProtonDrive, p if fileSystemAttr == nil { t.Fatalf("FileSystemAttr should not be nil") } else { + if sizeOnServer == fileSystemAttr.Size { + t.Fatalf("Not possible due to encryption file overhead") + } if len(downloadedData) != int(fileSystemAttr.Size) { t.Fatalf("Downloaded file size != uploaded file size: %#v vs %#v", len(downloadedData), int(fileSystemAttr.Size)) } diff --git a/file.go b/file.go index decbc89..fe33619 100644 --- a/file.go +++ b/file.go @@ -22,13 +22,49 @@ type FileSystemAttrs struct { Size int64 } -func (protonDrive *ProtonDrive) DownloadFileByID(ctx context.Context, linkID string) ([]byte, *FileSystemAttrs, error) { +type FileDownloadReader struct { + protonDrive *ProtonDrive + ctx context.Context + + data *bytes.Buffer + nodeKR *crypto.KeyRing + sessionKey *crypto.SessionKey + revision *proton.Revision + nextRevision int + + isEOF bool +} + +func (r *FileDownloadReader) Read(p []byte) (int, error) { + if r.data.Len() == 0 { + // we download and decrypt more content + err := r.downloadFileOnRead() + if err != nil { + return 0, err + } + + if r.isEOF { + // if the file has been downloaded entirely, we return EOF + return 0, io.EOF + } + } + + return r.data.Read(p) +} + +func (r *FileDownloadReader) Close() error { + r.protonDrive = nil + + return nil +} + +func (protonDrive *ProtonDrive) DownloadFileByID(ctx context.Context, linkID string) (io.ReadCloser, int64, *FileSystemAttrs, error) { /* It's like event system, we need to get the latest information before creating the move request! */ protonDrive.removeLinkIDFromCache(linkID, false) link, err := protonDrive.getLink(ctx, linkID) if err != nil { - return nil, nil, err + return nil, 0, nil, err } return protonDrive.DownloadFile(ctx, link) @@ -127,50 +163,79 @@ func (protonDrive *ProtonDrive) GetActiveRevisionWithAttrs(ctx context.Context, }, nil } -func (protonDrive *ProtonDrive) DownloadFile(ctx context.Context, link *proton.Link) ([]byte, *FileSystemAttrs, error) { +func (protonDrive *ProtonDrive) DownloadFile(ctx context.Context, link *proton.Link) (io.ReadCloser, int64, *FileSystemAttrs, error) { if link.Type != proton.LinkTypeFile { - return nil, nil, ErrLinkTypeMustToBeFileType + return nil, 0, nil, ErrLinkTypeMustToBeFileType } parentNodeKR, err := protonDrive.getLinkKRByID(ctx, link.ParentLinkID) if err != nil { - return nil, nil, err + return nil, 0, nil, err } nodeKR, err := link.GetKeyRing(parentNodeKR, protonDrive.AddrKR) if err != nil { - return nil, nil, err + return nil, 0, nil, err } sessionKey, err := link.GetSessionKey(protonDrive.AddrKR, nodeKR) if err != nil { - return nil, nil, err + return nil, 0, nil, err } revision, fileSystemAttrs, err := protonDrive.GetActiveRevisionWithAttrs(ctx, link) if err != nil { - return nil, nil, err + return nil, 0, nil, err } - buffer := bytes.NewBuffer(nil) - for i := range revision.Blocks { - // TODO: parallel download - blockReader, err := protonDrive.c.GetBlock(ctx, revision.Blocks[i].BareURL, revision.Blocks[i].Token) - if err != nil { - return nil, nil, err - } - defer blockReader.Close() + reader := &FileDownloadReader{ + protonDrive: protonDrive, + ctx: ctx, - err = decryptBlockIntoBuffer(sessionKey, protonDrive.AddrKR, nodeKR, revision.Blocks[i].Hash, revision.Blocks[i].EncSignature, buffer, blockReader) - if err != nil { - return nil, nil, err - } + data: bytes.NewBuffer(nil), + nodeKR: nodeKR, + sessionKey: sessionKey, + revision: revision, + nextRevision: 0, + + isEOF: false, + } + + err = reader.downloadFileOnRead() + if err != nil { + return nil, 0, nil, err } if fileSystemAttrs != nil { - return buffer.Bytes(), fileSystemAttrs, nil + return reader, link.Size, fileSystemAttrs, nil } - return buffer.Bytes(), nil, nil + return reader, link.Size, nil, nil +} + +func (reader *FileDownloadReader) downloadFileOnRead() error { + if len(reader.revision.Blocks) == 0 || len(reader.revision.Blocks) == reader.nextRevision { + reader.isEOF = true + return nil + } + + offset := reader.nextRevision + for i := offset; i-offset < DOWNLOAD_BATCH_BLOCK_SIZE && i < len(reader.revision.Blocks); i++ { + // TODO: parallel download + blockReader, err := reader.protonDrive.c.GetBlock(reader.ctx, reader.revision.Blocks[i].BareURL, reader.revision.Blocks[i].Token) + if err != nil { + return err + } + defer blockReader.Close() + + err = decryptBlockIntoBuffer(reader.sessionKey, reader.protonDrive.AddrKR, reader.nodeKR, reader.revision.Blocks[i].Hash, reader.revision.Blocks[i].EncSignature, reader.data, blockReader) + if err != nil { + return err + } + + reader.nextRevision = i + 1 + } + + return nil } func (protonDrive *ProtonDrive) UploadFileByReader(ctx context.Context, parentLinkID string, filename string, modTime time.Time, file io.Reader, testParam int) (string, int64, error) { diff --git a/folder.go b/folder.go index 776ae49..f08198b 100644 --- a/folder.go +++ b/folder.go @@ -2,6 +2,7 @@ package proton_api_bridge import ( "context" + "io" "log" "os" "time" @@ -105,11 +106,15 @@ func (protonDrive *ProtonDrive) ListDirectoriesRecursively( log.Println("Downloading", currentPath) defer log.Println("Completes downloading", currentPath) - byteArray, _, err := protonDrive.DownloadFile(ctx, link) + reader, _, _, err := protonDrive.DownloadFile(ctx, link) if err != nil { return err } + byteArray, err := io.ReadAll(reader) + if err != nil { + return err + } err = os.WriteFile("./"+protonDrive.Config.DataFolderName+"/"+currentPath, byteArray, 0777) if err != nil { return err