diff --git a/error.go b/error.go index 9e2bf6d..6c76313 100644 --- a/error.go +++ b/error.go @@ -19,4 +19,6 @@ var ( ErrCantFindDraftRevision = errors.New("can't find a draft revision") ErrWrongUsageOfGetLinkKR = errors.New("internal error for GetLinkKR - nil passed in for link") ErrWrongUsageOfGetLink = errors.New("internal error for getLink - empty linkID passed in") + ErrFileEncryptionReturnDataMissing = errors.New("internal error for file upload - encRetChan returns nil") + ErrFileEncryptionMissingIndex = errors.New("internal error for file upload - data from encRetChan is missing an index") ) diff --git a/file.go b/file.go index 5481cc6..e96b6d6 100644 --- a/file.go +++ b/file.go @@ -456,7 +456,7 @@ func (protonDrive *ProtonDrive) createFileUploadDraft(ctx context.Context, paren } func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, newSessionKey *crypto.SessionKey, newNodeKR *crypto.KeyRing, file io.Reader, linkID, revisionID string) ([]byte, int64, error) { - type PendingUploadBlocks struct { + type PendingUploadBlock struct { blockUploadInfo proton.BlockUploadInfo encData []byte } @@ -467,7 +467,7 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n totalFileSize := int64(0) - pendingUploadBlocks := make([]PendingUploadBlocks, 0) + pendingUploadBlocks := make([]*PendingUploadBlock, 0) manifestSignatureData := make([]byte, 0) uploadPendingBlocks := func() error { if len(pendingUploadBlocks) == 0 { @@ -491,24 +491,21 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n return err } - errChan := make(chan error) - uploadBlockWrapper := func(ctx context.Context, errChan chan error, bareURL, token string, block io.Reader) { - // log.Println("Before semaphore") + uploadErrChan := make(chan error) + uploadBlockWrapper := func(ctx context.Context, uploadErrChan chan error, bareURL, token string, block io.Reader) { if err := protonDrive.blockUploadSemaphore.Acquire(ctx, 1); err != nil { - errChan <- err + uploadErrChan <- err } defer protonDrive.blockUploadSemaphore.Release(1) - // log.Println("After semaphore") - // defer log.Println("Release semaphore") - errChan <- protonDrive.c.UploadBlock(ctx, bareURL, token, block) + uploadErrChan <- protonDrive.c.UploadBlock(ctx, bareURL, token, block) } for i := range blockUploadResp { - go uploadBlockWrapper(ctx, errChan, blockUploadResp[i].BareURL, blockUploadResp[i].Token, bytes.NewReader(pendingUploadBlocks[i].encData)) + go uploadBlockWrapper(ctx, uploadErrChan, blockUploadResp[i].BareURL, blockUploadResp[i].Token, bytes.NewReader(pendingUploadBlocks[i].encData)) } for i := 0; i < len(blockUploadResp); i++ { - err := <-errChan + err := <-uploadErrChan if err != nil { return err } @@ -519,11 +516,109 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n return nil } + type encChanData struct { + hash []byte + pendingBlock *PendingUploadBlock + err error + } + encryptDataBlock := func(encRetChan chan *encChanData, ctx context.Context, idx int, data []byte) { + if err := protonDrive.blockCryptoSemaphore.Acquire(ctx, 1); err != nil { + encRetChan <- &encChanData{ + err: err, + } + } + defer protonDrive.blockCryptoSemaphore.Release(1) + + // encrypt data + dataPlainMessage := crypto.NewPlainMessage(data) + encData, err := newSessionKey.Encrypt(dataPlainMessage) + if err != nil { + encRetChan <- &encChanData{ + err: err, + } + } + + encSignature, err := protonDrive.AddrKR.SignDetachedEncrypted(dataPlainMessage, newNodeKR) + if err != nil { + encRetChan <- &encChanData{ + err: err, + } + } + encSignatureStr, err := encSignature.GetArmored() + if err != nil { + encRetChan <- &encChanData{ + err: err, + } + } + + h := sha256.New() + h.Write(encData) + hash := h.Sum(nil) + base64Hash := base64.StdEncoding.EncodeToString(hash) + if err != nil { + encRetChan <- &encChanData{ + err: err, + } + } + + encRetChan <- &encChanData{ + hash: hash, + pendingBlock: &PendingUploadBlock{ + blockUploadInfo: proton.BlockUploadInfo{ + Index: idx, // iOS drive: BE starts with 1 + Size: int64(len(encData)), + EncSignature: encSignatureStr, + Hash: base64Hash, + }, + encData: encData, + }, + err: nil, + } + } + + checkAndUploadPendingBlocks := func(encRetChan chan *encChanData, nextIndex, pendingBlocks *int) error { + cacheEncChanData := make(map[int]*encChanData) + for i := 0; i < *pendingBlocks; i++ { + tmp := <-encRetChan + + if tmp.err != nil { + return tmp.err + } + + if tmp == nil { + return ErrFileEncryptionReturnDataMissing + } + cacheEncChanData[tmp.pendingBlock.blockUploadInfo.Index] = tmp + } + + for i := 0; i < len(cacheEncChanData); i++ { + if val, ok := cacheEncChanData[*nextIndex]; ok { + manifestSignatureData = append(manifestSignatureData, val.hash...) + + pendingUploadBlocks = append(pendingUploadBlocks, val.pendingBlock) + } else { + return ErrFileEncryptionMissingIndex + } + + *nextIndex++ + } + + err := uploadPendingBlocks() + if err != nil { + return err + } + + *pendingBlocks = 0 + return nil + } + shouldContinue := true + encRetChan := make(chan *encChanData) + nextIdx := 1 + pendingBlocks := 0 for i := 1; shouldContinue; i++ { if (i-1) > 0 && (i-1)%UPLOAD_BATCH_BLOCK_SIZE == 0 { - err := uploadPendingBlocks() - if err != nil { + if err := checkAndUploadPendingBlocks(encRetChan, &nextIdx, &pendingBlocks); err != nil { return nil, 0, err } } @@ -547,43 +642,10 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n data = data[:readBytes] totalFileSize += int64(readBytes) - // encrypt data - dataPlainMessage := crypto.NewPlainMessage(data) - encData, err := newSessionKey.Encrypt(dataPlainMessage) - if err != nil { - return nil, 0, err - } - - encSignature, err := protonDrive.AddrKR.SignDetachedEncrypted(dataPlainMessage, newNodeKR) - if err != nil { - return nil, 0, err - } - encSignatureStr, err := encSignature.GetArmored() - if err != nil { - return nil, 0, err - } - - h := sha256.New() - h.Write(encData) - hash := h.Sum(nil) - base64Hash := base64.StdEncoding.EncodeToString(hash) - if err != nil { - return nil, 0, err - } - manifestSignatureData = append(manifestSignatureData, hash...) - - pendingUploadBlocks = append(pendingUploadBlocks, PendingUploadBlocks{ - blockUploadInfo: proton.BlockUploadInfo{ - Index: i, // iOS drive: BE starts with 1 - Size: int64(len(encData)), - EncSignature: encSignatureStr, - Hash: base64Hash, - }, - encData: encData, - }) + pendingBlocks += 1 + go encryptDataBlock(encRetChan, ctx, i, data) } - err := uploadPendingBlocks() - if err != nil { + if err := checkAndUploadPendingBlocks(encRetChan, &nextIdx, &pendingBlocks); err != nil { return nil, 0, err }