diff --git a/drive.go b/drive.go index f46f870..7a0758f 100644 --- a/drive.go +++ b/drive.go @@ -3,8 +3,10 @@ package proton_api_bridge import ( "context" "log" + "runtime" "github.com/henrybear327/Proton-API-Bridge/common" + "golang.org/x/sync/semaphore" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/henrybear327/go-proton-api" @@ -27,6 +29,8 @@ type ProtonDrive struct { signatureAddress string linkCache *linkCache + + uploadEncryptionSem *semaphore.Weighted } func NewDefaultConfig() *common.Config { @@ -144,6 +148,8 @@ func NewProtonDrive(ctx context.Context, config *common.Config, authHandler prot signatureAddress: mainShare.Creator, linkCache: newLinkCache(config.DisableLinkCaching), + + uploadEncryptionSem: semaphore.NewWeighted(int64(runtime.GOMAXPROCS(0))), }, credentials, nil } diff --git a/file.go b/file.go index 5e818b8..05fb6f4 100644 --- a/file.go +++ b/file.go @@ -7,9 +7,11 @@ import ( "crypto/sha256" "encoding/base64" "io" + "log" "mime" "os" "path/filepath" + "sync" "time" "github.com/ProtonMail/gopenpgp/v2/crypto" @@ -404,7 +406,58 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n } shouldContinue := true + processData := func(blockIndex int, data []byte, wg *sync.WaitGroup, errorCh chan error) { + defer protonDrive.uploadEncryptionSem.Release(1) + defer wg.Done() + + log.Println("processData start", linkID, blockIndex) + defer log.Println("processData done", linkID, blockIndex) + + // encrypt data + dataPlainMessage := crypto.NewPlainMessage(data) + encData, err := newSessionKey.Encrypt(dataPlainMessage) + if err != nil { + errorCh <- err + return + } + + encSignature, err := protonDrive.AddrKR.SignDetachedEncrypted(dataPlainMessage, newNodeKR) + if err != nil { + errorCh <- err + return + } + encSignatureStr, err := encSignature.GetArmored() + if err != nil { + errorCh <- err + return + } + + h := sha256.New() + h.Write(encData) + hash := h.Sum(nil) + base64Hash := base64.StdEncoding.EncodeToString(hash) + if err != nil { + errorCh <- err + return + } + manifestSignatureData = append(manifestSignatureData, hash...) + + pendingUploadBlocks = append(pendingUploadBlocks, PendingUploadBlocks{ + blockUploadInfo: proton.BlockUploadInfo{ + Index: blockIndex, // iOS drive: BE starts with 1 + Size: int64(len(encData)), + EncSignature: encSignatureStr, + Hash: base64Hash, + }, + encData: encData, + }) + + errorCh <- nil + } + var wg sync.WaitGroup + errorCh := make(chan error, UPLOAD_BATCH_BLOCK_SIZE) for i := 1; shouldContinue; i++ { + log.Println(linkID, i) // read at most data of size UPLOAD_BLOCK_SIZE data := make([]byte, UPLOAD_BLOCK_SIZE) // FIXME: get block size from the server config instead of hardcoding it readBytes, err := file.Read(data) @@ -424,46 +477,36 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n data = data[:readBytes] totalFileSize += int64(readBytes) - // encrypt data - dataPlainMessage := crypto.NewPlainMessage(data) - encData, err := newSessionKey.Encrypt(dataPlainMessage) + if err := protonDrive.uploadEncryptionSem.Acquire(ctx, 1); err != nil { + return nil, 0, err + } + wg.Add(1) + go processData(i, data, &wg, errorCh) if err != nil { return nil, 0, err } - encSignature, err := protonDrive.AddrKR.SignDetachedEncrypted(dataPlainMessage, newNodeKR) - if err != nil { - return nil, 0, err - } - encSignatureStr, err := encSignature.GetArmored() - if err != nil { - return nil, 0, err - } - - h := sha256.New() - h.Write(encData) - hash := h.Sum(nil) - base64Hash := base64.StdEncoding.EncodeToString(hash) - if err != nil { - return nil, 0, err - } - manifestSignatureData = append(manifestSignatureData, hash...) - - pendingUploadBlocks = append(pendingUploadBlocks, PendingUploadBlocks{ - blockUploadInfo: proton.BlockUploadInfo{ - Index: i, // iOS drive: BE starts with 1 - Size: int64(len(encData)), - EncSignature: encSignatureStr, - Hash: base64Hash, - }, - encData: encData, - }) - if (i-1) > 0 && (i-1)%UPLOAD_BATCH_BLOCK_SIZE == 0 { + wg.Wait() + close(errorCh) + for err := range errorCh { + if err != nil { + return nil, 0, err + } + } err = uploadPendingBlocks() if err != nil { return nil, 0, err } + + errorCh = make(chan error, UPLOAD_BATCH_BLOCK_SIZE) + } + } + wg.Wait() + close(errorCh) + for err := range errorCh { + if err != nil { + return nil, 0, err } } err := uploadPendingBlocks()