From 12fdfb8b782add6060feadb9058b79bc96e9ea2c Mon Sep 17 00:00:00 2001 From: Ericson Soares Date: Sat, 19 Oct 2024 02:48:27 -0300 Subject: [PATCH] Update sync messages push and pull Also fix scalability issues on sync design --- Cargo.lock | Bin 345122 -> 345145 bytes Cargo.toml | 2 +- .../crates/cloud-services/src/sync/receive.rs | 185 +------- core/crates/cloud-services/src/sync/send.rs | 419 ++--------------- .../src/isolated_file_path_data.rs | 5 +- .../heavy-lifting/src/file_identifier/job.rs | 32 +- .../src/file_identifier/tasks/identifier.rs | 14 +- .../src/file_identifier/tasks/mod.rs | 12 +- core/crates/heavy-lifting/src/indexer/job.rs | 30 +- core/crates/heavy-lifting/src/indexer/mod.rs | 91 ++-- .../heavy-lifting/src/indexer/shallow.rs | 2 +- .../heavy-lifting/src/indexer/tasks/saver.rs | 17 +- .../src/indexer/tasks/updater.rs | 26 +- .../heavy-lifting/src/job_system/report.rs | 2 + .../heavy-lifting/src/media_processor/job.rs | 30 +- core/crates/prisma-helpers/src/lib.rs | 14 + core/crates/sync/src/ingest_utils.rs | 154 +++---- core/crates/sync/src/lib.rs | 7 +- core/crates/sync/tests/lib.rs | 234 ---------- core/crates/sync/tests/mock_instance.rs | 143 ------ core/src/api/files.rs | 83 ++-- core/src/api/labels.rs | 31 +- core/src/api/search/saved.rs | 51 +-- core/src/api/tags.rs | 118 +++-- core/src/library/manager/mod.rs | 1 + core/src/location/manager/watcher/android.rs | 13 +- core/src/location/manager/watcher/ios.rs | 10 +- core/src/location/manager/watcher/linux.rs | 13 +- core/src/location/manager/watcher/macos.rs | 10 +- core/src/location/manager/watcher/mod.rs | 16 +- core/src/location/manager/watcher/utils.rs | 427 ++++++++---------- core/src/location/manager/watcher/windows.rs | 10 +- core/src/location/mod.rs | 295 +++++------- core/src/object/fs/old_copy.rs | 4 +- core/src/object/tag/mod.rs | 4 +- .../object/validation/old_validator_job.rs | 19 +- core/src/old_job/manager.rs | 1 + core/src/old_job/report.rs | 1 + core/src/volume/mod.rs | 95 ++-- crates/crypto/src/cloud/decrypt.rs | 2 +- crates/crypto/src/cloud/secret_key.rs | 2 +- crates/crypto/src/primitives.rs | 4 +- crates/sync-generator/src/model.rs | 41 +- crates/sync-generator/src/sync_data.rs | 139 ++++-- crates/sync/src/crdt.rs | 10 +- crates/sync/src/factory.rs | 57 ++- crates/utils/Cargo.toml | 2 + 47 files changed, 1118 insertions(+), 1760 deletions(-) delete mode 100644 core/crates/sync/tests/lib.rs delete mode 100644 core/crates/sync/tests/mock_instance.rs diff --git a/Cargo.lock b/Cargo.lock index 674aab99ecb23a43ded9898271510f0734fa0183..bbff9ba73223261d4949c704e553fa9abbeae1a0 100644 GIT binary patch delta 105 zcmZ3~A-c0ew4sHug=q`(saXce=E;T@CdnqLhRO(Lnn{Y8S(2fFiGiu5d6JQtiIJ&6 zVp_7fsj2z&fO{+w(|;~z;@&Pho4H1D`uZKrV$%&)G4o9KUct=WzJ3q$_Vs&MirWDN C03%)i delta 90 zcmdnlA-bqTw4sHug=q`(sab|;NhXGg#>pnBNy-SOWumcZijk?2iMf%biKUsPNvfq~ mVzQ-yQHo{aWJe*1>Aw~;ac!5K&0M3{zF`mZ_6>VjI@ Result<(devices::PubId, DateTime), Error> { // FIXME(@fogodev): If we don't have the key hash, we need to fetch it from another device in the group if possible let Some(secret_key) = key_manager.get_key(sync_group_pub_id, &key_hash).await else { return Err(Error::MissingKeyHash); }; - let response = http_client - .get(signed_download_link) - .send() - .await - .map_err(Error::DownloadSyncMessages)? - .error_for_status() - .map_err(Error::ErrorResponseDownloadSyncMessages)?; + debug!( + size = encrypted_messages.len(), + "Received encrypted sync messages collection" + ); - let crdt_ops = if let Some(size) = response.content_length() { - debug!(size, "Received encrypted sync messages collection"); - extract_messages_known_size(response, size, secret_key, original_device_pub_id).await - } else { - debug!("Received encrypted sync messages collection of unknown size"); - extract_messages_unknown_size(response, secret_key, original_device_pub_id).await - }?; + let crdt_ops = decrypt_messages(encrypted_messages, secret_key, original_device_pub_id).await?; assert_eq!( crdt_ops.len(), @@ -285,44 +264,28 @@ async fn handle_single_message( Ok((original_device_pub_id, end_time)) } -#[instrument(skip(response, size, secret_key), err)] -async fn extract_messages_known_size( - response: Response, - size: u64, +#[instrument(skip(encrypted_messages, secret_key), fields(messages_size = %encrypted_messages.len()), err)] +async fn decrypt_messages( + encrypted_messages: Vec, secret_key: SecretKey, devices::PubId(device_pub_id): devices::PubId, ) -> Result, Error> { - let plain_text = if size <= EncryptedBlock::CIPHER_TEXT_SIZE as u64 { - OneShotDecryption::decrypt( - &secret_key, - response - .bytes() - .await - .map_err(Error::ErrorResponseDownloadReadBytesSyncMessages)? - .as_ref() - .into(), - ) - .map_err(Error::Decrypt)? + let plain_text = if encrypted_messages.len() <= EncryptedBlock::CIPHER_TEXT_SIZE { + OneShotDecryption::decrypt(&secret_key, encrypted_messages.as_slice().into()) + .map_err(Error::Decrypt)? } else { - let mut reader = StreamReader::new(response.bytes_stream().map_err(|e| { - error!(?e, "Failed to read sync messages bytes stream"); - io::Error::new(io::ErrorKind::Other, e) - })); + let (nonce, cipher_text) = encrypted_messages.split_at(size_of::()); - let mut nonce = StreamNonce::default(); + let mut plain_text = Vec::with_capacity(cipher_text.len()); - reader - .read_exact(&mut nonce) - .await - .map_err(Error::ReadNonceStreamDecryption)?; - - // TODO: Reimplement using async streaming with serde if it ever gets implemented - - let mut plain_text = vec![]; - - StreamDecryption::decrypt(&secret_key, &nonce, reader, &mut plain_text) - .await - .map_err(Error::Decrypt)?; + StreamDecryption::decrypt( + &secret_key, + nonce.try_into().expect("we split the correct amount"), + cipher_text, + &mut plain_text, + ) + .await + .map_err(Error::Decrypt)?; plain_text }; @@ -332,34 +295,6 @@ async fn extract_messages_known_size( .map_err(Error::DeserializationFailureToPullSyncMessages) } -#[instrument(skip_all, err)] -async fn extract_messages_unknown_size( - response: Response, - secret_key: SecretKey, - devices::PubId(device_pub_id): devices::PubId, -) -> Result, Error> { - let plain_text = match UnknownDownloadKind::new(response).await? { - UnknownDownloadKind::OneShot(buffer) => { - OneShotDecryption::decrypt(&secret_key, buffer.as_slice().into()) - .map_err(Error::Decrypt)? - } - - UnknownDownloadKind::Stream((nonce, reader)) => { - let mut plain_text = vec![]; - - StreamDecryption::decrypt(&secret_key, &nonce, reader, &mut plain_text) - .await - .map_err(Error::Decrypt)?; - - plain_text - } - }; - - rmp_serde::from_slice::(&plain_text) - .map(|compressed_ops| compressed_ops.into_ops(device_pub_id)) - .map_err(Error::DeserializationFailureToPullSyncMessages) -} - #[instrument(skip_all, err)] pub async fn write_cloud_ops_to_db( ops: Vec, @@ -411,73 +346,3 @@ impl LastTimestampKeeper { .map_err(Error::FailedToWriteLastTimestampKeeper) } } - -struct UnknownDownloadSizeStreamer { - stream_reader: Box, - buffer: Vec, - was_read: usize, -} - -enum UnknownDownloadKind { - OneShot(Vec), - Stream((StreamNonce, UnknownDownloadSizeStreamer)), -} - -impl UnknownDownloadKind { - async fn new(response: Response) -> Result { - let mut buffer = Vec::with_capacity(EncryptedBlock::CIPHER_TEXT_SIZE * 2); - - let mut stream = response.bytes_stream(); - - while let Some(res) = stream.next().await { - buffer.extend(res.map_err(Error::ErrorResponseDownloadReadBytesSyncMessages)?); - if buffer.len() > EncryptedBlock::CIPHER_TEXT_SIZE { - break; - } - } - - if buffer.len() < size_of::() { - return Err(Error::IncompleteDownloadBytesSyncMessages); - } - - if buffer.len() <= EncryptedBlock::CIPHER_TEXT_SIZE { - Ok(Self::OneShot(buffer)) - } else { - let nonce_size = size_of::(); - - Ok(Self::Stream(( - StreamNonce::try_from(&buffer[..nonce_size]).expect("passing the right nonce size"), - UnknownDownloadSizeStreamer { - stream_reader: Box::new(StreamReader::new(stream.map_err(|e| { - error!(?e, "Failed to read sync messages bytes stream"); - io::Error::new(io::ErrorKind::Other, e) - }))), - buffer, - was_read: nonce_size, - }, - ))) - } - } -} - -impl AsyncRead for UnknownDownloadSizeStreamer { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if buf.remaining() == 0 { - return Poll::Ready(Ok(())); - } - - if self.was_read == self.buffer.len() { - Pin::new(&mut self.stream_reader).poll_read(cx, buf) - } else { - let len = std::cmp::min(self.buffer.len() - self.was_read, buf.remaining()); - buf.put_slice(&self.buffer[self.was_read..(self.was_read + len)]); - self.was_read += len; - - Poll::Ready(Ok(())) - } - } -} diff --git a/core/crates/cloud-services/src/sync/send.rs b/core/crates/cloud-services/src/sync/send.rs index 2e36b8118..4fd3842da 100644 --- a/core/crates/cloud-services/src/sync/send.rs +++ b/core/crates/cloud-services/src/sync/send.rs @@ -6,7 +6,7 @@ use sd_actors::{Actor, Stopper}; use sd_cloud_schema::{ devices, error::{ClientSideError, NotFoundError}, - sync::{self, groups, messages}, + sync::{groups, messages}, Client, Service, }; use sd_crypto::{ @@ -18,8 +18,7 @@ use sd_utils::{datetime_to_timestamp, timestamp_to_datetime}; use std::{ future::IntoFuture, - num::NonZero, - pin::{pin, Pin}, + pin::pin, sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -27,24 +26,20 @@ use std::{ time::{Duration, UNIX_EPOCH}, }; -use async_stream::try_stream; use chrono::{DateTime, Utc}; -use futures::{FutureExt, SinkExt, Stream, StreamExt, TryStream, TryStreamExt}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use futures_concurrency::future::{Race, TryJoin}; -use quic_rpc::{client::UpdateSink, pattern::bidi_streaming, transport::quinn::QuinnConnection}; -use reqwest_middleware::reqwest::{header, Body}; +use quic_rpc::transport::quinn::QuinnConnection; use tokio::{ - spawn, - sync::{broadcast, oneshot, Notify, Semaphore}, + sync::{broadcast, Notify}, time::sleep, }; -use tracing::{debug, error}; +use tracing::error; use uuid::Uuid; use super::{SyncActors, ONE_MINUTE}; const TEN_SECONDS: Duration = Duration::from_secs(10); -const THIRTY_SECONDS: Duration = Duration::from_secs(30); const MESSAGES_COLLECTION_SIZE: u32 = 100_000; @@ -60,18 +55,6 @@ enum LoopStatus { type LatestTimestamp = NTP64; -type PushResponsesStream = Pin< - Box< - dyn Stream< - Item = Result< - Result, - bidi_streaming::ItemError>, - >, - > + Send - + Sync, - >, ->; - #[derive(Debug)] pub struct Sender { sync_group_pub_id: groups::PubId, @@ -205,17 +188,7 @@ impl Sender { let messages_bytes = rmp_serde::to_vec_named(&compressed_ops) .map_err(Error::SerializationFailureToPushSyncMessages)?; - let plain_text_size = messages_bytes.len(); - let expected_blob_size = if plain_text_size <= EncryptedBlock::PLAIN_TEXT_SIZE { - OneShotEncryption::cipher_text_size(&secret_key, plain_text_size) - } else { - StreamEncryption::cipher_text_size(&secret_key, plain_text_size) - } as u64; - - debug!(?expected_blob_size, ?key_hash, "Preparing sync message"); - - let (mut push_updates, mut push_responses) = self - .cloud_client + self.cloud_client .sync() .messages() .push(messages::push::Request { @@ -228,54 +201,15 @@ impl Sender { device_pub_id: current_device_pub_id, key_hash: key_hash.clone(), operations_count, - start_time, - end_time, - expected_blob_size, + time_range: (start_time, end_time), + encrypted_messages: encrypt_messages( + &secret_key, + &mut self.rng, + messages_bytes, + ) + .await?, }) - .await?; - - let Some(response) = push_responses.next().await else { - return Err(Error::EmptyResponse("push initial response")); - }; - - let messages::push::Response(response_kind) = response??; - - match response_kind { - messages::push::ResponseKind::SinglePresignedUrl(url) => { - upload_to_single_url( - url, - secret_key.clone(), - self.cloud_services.http_client(), - messages_bytes, - &mut self.rng, - ) - .await?; - } - messages::push::ResponseKind::ManyPresignedUrls(urls) => { - upload_to_many_urls( - urls, - secret_key.clone(), - self.cloud_services.http_client().clone(), - messages_bytes, - &mut self.rng, - &mut push_updates, - &mut push_responses, - ) - .await?; - } - messages::push::ResponseKind::Pong => { - return Err(Error::UnexpectedResponse( - "Pong on first messages push request", - )) - } - messages::push::ResponseKind::End => { - return Err(Error::UnexpectedResponse( - "End on first messages push request", - )) - } - } - - finalize_protocol(&mut push_updates, &mut push_responses).await?; + .await??; status = LoopStatus::SentMessages; } @@ -303,8 +237,7 @@ impl Sender { .get_access_token() .await?, group_pub_id: self.sync_group_pub_id, - current_device_pub_id, - kind: messages::get_latest_time::Kind::ForCurrentDevice, + kind: messages::get_latest_time::Kind::ForCurrentDevice(current_device_pub_id), }) .await? { @@ -328,320 +261,44 @@ impl Sender { } } -async fn finalize_protocol( - push_updates: &mut UpdateSink< - Service, - QuinnConnection, - messages::push::RequestUpdate, - sync::Service, - >, - push_responses: &mut PushResponsesStream, -) -> Result<(), Error> { - push_updates - .send(messages::push::RequestUpdate( - messages::push::UpdateKind::End, - )) - .await - .map_err(Error::EndUpdatePushSyncMessages)?; - - let Some(response) = push_responses.next().await else { - return Err(Error::EmptyResponse("push initial response")); - }; - - let messages::push::Response(response_kind) = response??; - - match response_kind { - messages::push::ResponseKind::SinglePresignedUrl(_) - | messages::push::ResponseKind::ManyPresignedUrls(_) => { - return Err(Error::UnexpectedResponse( - "Urls responses on final messages push response", - )) - } - messages::push::ResponseKind::Pong => { - return Err(Error::UnexpectedResponse( - "Pong on final message push response", - )) - } - messages::push::ResponseKind::End => { - /* - Everything is awesome! - */ - } - } - - Ok(()) -} - -async fn upload_to_many_urls( - urls: Vec, - secret_key: SecretKey, - http_client: reqwest_middleware::ClientWithMiddleware, - messages_bytes: Vec, +async fn encrypt_messages( + secret_key: &SecretKey, rng: &mut CryptoRng, - push_updates: &mut UpdateSink< - Service, - QuinnConnection, - messages::push::RequestUpdate, - sync::Service, - >, - push_responses: &mut PushResponsesStream, -) -> Result<(), Error> { - let stop_ping_pong = Arc::new(AtomicBool::new(false)); - let (out_tx, mut out_rx) = oneshot::channel(); - let rng = CryptoRng::from_seed(rng.generate_fixed()); - - let handle = spawn(handle_multipart_upload( - urls, - secret_key, - http_client, - messages_bytes, - rng, - Arc::clone(&stop_ping_pong), - out_tx, - )); - - loop { - if stop_ping_pong.load(Ordering::Acquire) { - break; - } - - if let Err(e) = push_updates - .send(messages::push::RequestUpdate( - messages::push::UpdateKind::Ping, - )) - .await - { - error!(?e, "Failed to send push ping update"); - sleep(TEN_SECONDS).await; - continue; - } - - let Some(response) = push_responses.next().await else { - error!("Empty response from push ping response"); - continue; - }; - - match response { - Ok(Ok(messages::push::Response( - messages::push::ResponseKind::SinglePresignedUrl(_) - | messages::push::ResponseKind::ManyPresignedUrls(_), - ))) => { - unreachable!("can't receive url if we didn't send an initial request") - } - - Ok(Ok(messages::push::Response(messages::push::ResponseKind::Pong))) => { - /* - Everything is awesome! - */ - } - Ok(Ok(messages::push::Response(messages::push::ResponseKind::End))) => { - unreachable!("Can't receive an End if we didn't send an End first"); - } - - Ok(Err(e)) => { - error!(?e, "Error from push ping response"); - sleep(TEN_SECONDS).await; - continue; - } - - Err(e) => { - error!(?e, "Error from push ping response"); - sleep(TEN_SECONDS).await; - continue; - } - } - - if stop_ping_pong.load(Ordering::Acquire) { - break; - } - - sleep(THIRTY_SECONDS).await; - } - - let Ok(out) = out_rx.try_recv() else { - // SAFETY: This try_recv error can only happen if the upload task panicked - // so we're good to unwrap the error. - let e = handle.await.expect_err("upload task panicked"); - error!(?e, "Critical error while uploading sync messages"); - return Err(Error::CriticalErrorWhileUploadingSyncMessages); - }; - - out -} - -async fn handle_multipart_upload( - urls: Vec, - secret_key: SecretKey, - http_client: reqwest_middleware::ClientWithMiddleware, messages_bytes: Vec, - rng: CryptoRng, - stop_ping_pong: Arc, - out_tx: oneshot::Sender>, -) { - async fn inner( - urls: Vec, - secret_key: SecretKey, - http_client: reqwest_middleware::ClientWithMiddleware, - messages_bytes: Vec, - mut rng: CryptoRng, - ) -> Result<(), Error> { - let urls_count = urls.len(); - let message_size = messages_bytes.len(); - let blocks_per_url = message_size / urls_count / EncryptedBlock::PLAIN_TEXT_SIZE; - let cipher_text_size = StreamEncryption::cipher_text_size(&secret_key, message_size); - - let parallel_upload_semaphore = Arc::new(Semaphore::new( - std::thread::available_parallelism() - .map(NonZero::get) - .unwrap_or(1), +) -> Result, Error> { + if messages_bytes.len() <= EncryptedBlock::PLAIN_TEXT_SIZE { + let mut nonce_and_cipher_text = Vec::with_capacity(OneShotEncryption::cipher_text_size( + secret_key, + messages_bytes.len(), )); - // If we're uploading to many URLs, it implies that the message size is bigger than a single - // encryption block, so we always use stream encryption. - - let mut buffers = vec![Vec::with_capacity(cipher_text_size / urls_count); urls_count]; - let (nonce, cipher_stream) = - StreamEncryption::encrypt(&secret_key, messages_bytes.as_slice(), &mut rng); - - buffers[0].extend_from_slice(&nonce); - - let mut cipher_stream = pin!(cipher_stream); - - let mut handles = Vec::with_capacity(urls_count); - - for (idx, (mut buffer, url)) in buffers.into_iter().zip(urls).enumerate() { - for _ in 0..blocks_per_url { - if let Some(cipher_res) = cipher_stream.next().await { - buffer.extend(cipher_res.map_err(Error::Encrypt)?); - } else { - return Err(Error::UnexpectedEndOfStream); - } - } - - handles.push(spawn(upload_part( - idx, - url, - http_client.clone(), - buffer, - Arc::clone(¶llel_upload_semaphore), - ))); - } - - assert!( - cipher_stream.next().await.is_none(), - "Unexpected ciphered bytes still on stream" - ); - - handles.try_join().await.map_err(|e| { - error!(?e, "Error while uploading sync messages"); - Error::CriticalErrorWhileUploadingSyncMessages - })?; - - Ok(()) - } - - let res = inner(urls, secret_key, http_client, messages_bytes, rng).await; - stop_ping_pong.store(true, Ordering::Release); - out_tx - .send(res) - .expect("upload output channel never closes"); -} - -async fn upload_part( - idx: usize, - url: reqwest::Url, - http_client: reqwest_middleware::ClientWithMiddleware, - buffer: Vec, - parallel_upload_semaphore: Arc, -) -> Result<(), Error> { - let _permit = parallel_upload_semaphore - .acquire() - .await - .expect("Semaphore never closes"); - - let response = http_client - .put(url) - .header(header::CONTENT_LENGTH, buffer.len()) - .body(buffer) - .send() - .await - .map_err(Error::UploadSyncMessages)? - .error_for_status() - .map_err(Error::ErrorResponseUploadSyncMessages)?; - - debug!(?response, idx, "Uploaded sync messages part"); - - Ok(()) -} - -async fn upload_to_single_url( - url: reqwest::Url, - secret_key: SecretKey, - http_client: &reqwest_middleware::ClientWithMiddleware, - messages_bytes: Vec, - rng: &mut CryptoRng, -) -> Result<(), Error> { - let (cipher_text_size, body) = if messages_bytes.len() <= EncryptedBlock::PLAIN_TEXT_SIZE { let EncryptedBlock { nonce, cipher_text } = - OneShotEncryption::encrypt(&secret_key, messages_bytes.as_slice(), rng) + OneShotEncryption::encrypt(secret_key, messages_bytes.as_slice(), rng) .map_err(Error::Encrypt)?; - let cipher_text_size = nonce.len() + cipher_text.len(); + nonce_and_cipher_text.extend_from_slice(nonce.as_slice()); + nonce_and_cipher_text.extend(&cipher_text); - let mut body_bytes = Vec::with_capacity(cipher_text_size); - body_bytes.extend_from_slice(nonce.as_slice()); - body_bytes.extend(&cipher_text); - - (cipher_text_size, Body::from(body_bytes)) + Ok(nonce_and_cipher_text) } else { let mut rng = CryptoRng::from_seed(rng.generate_fixed()); - let cipher_text_size = - StreamEncryption::cipher_text_size(&secret_key, messages_bytes.len()); + let mut nonce_and_cipher_text = Vec::with_capacity(StreamEncryption::cipher_text_size( + secret_key, + messages_bytes.len(), + )); - let body_bytes = stream_encryption(secret_key, messages_bytes, &mut rng) - .try_fold( - Vec::with_capacity(cipher_text_size), - |mut body_bytes, ciphered_chunk| async move { - body_bytes.extend(ciphered_chunk); - Ok(body_bytes) - }, - ) - .await?; - - (cipher_text_size, Body::from(body_bytes)) - }; - - http_client - .put(url) - .header(header::CONTENT_LENGTH, cipher_text_size) - .body(body) - .send() - .await - .map_err(Error::UploadSyncMessages)? - .error_for_status() - .map_err(Error::ErrorResponseUploadSyncMessages)?; - - Ok(()) -} - -fn stream_encryption( - secret_key: SecretKey, - messages_bytes: Vec, - rng: &mut CryptoRng, -) -> impl TryStream, Error = Error> + Send + 'static { - let mut rng = CryptoRng::from_seed(rng.generate_fixed()); - - try_stream! { let (nonce, cipher_stream) = - StreamEncryption::encrypt(&secret_key, messages_bytes.as_slice(), &mut rng); + StreamEncryption::encrypt(secret_key, messages_bytes.as_slice(), &mut rng); + + nonce_and_cipher_text.extend_from_slice(nonce.as_slice()); let mut cipher_stream = pin!(cipher_stream); - yield nonce.to_vec(); - - while let Some(res) = cipher_stream.next().await { - yield res.map_err(Error::Encrypt)?; + while let Some(ciphered_chunk) = cipher_stream.try_next().await.map_err(Error::Encrypt)? { + nonce_and_cipher_text.extend(ciphered_chunk); } + + Ok(nonce_and_cipher_text) } } diff --git a/core/crates/file-path-helper/src/isolated_file_path_data.rs b/core/crates/file-path-helper/src/isolated_file_path_data.rs index 3e89cce0f..fe83bbee9 100644 --- a/core/crates/file-path-helper/src/isolated_file_path_data.rs +++ b/core/crates/file-path-helper/src/isolated_file_path_data.rs @@ -2,7 +2,7 @@ use sd_core_prisma_helpers::{ file_path_for_file_identifier, file_path_for_media_processor, file_path_for_object_validator, file_path_to_full_path, file_path_to_handle_custom_uri, file_path_to_handle_p2p_serve_file, file_path_to_isolate, file_path_to_isolate_with_id, file_path_to_isolate_with_pub_id, - file_path_walker, file_path_with_object, + file_path_walker, file_path_watcher_remove, file_path_with_object, }; use sd_prisma::prisma::{file_path, location}; @@ -506,7 +506,8 @@ impl_from_db!( file_path_to_isolate_with_pub_id, file_path_walker, file_path_to_isolate_with_id, - file_path_with_object + file_path_with_object, + file_path_watcher_remove ); impl_from_db_without_location_id!( diff --git a/core/crates/heavy-lifting/src/file_identifier/job.rs b/core/crates/heavy-lifting/src/file_identifier/job.rs index a75bbb2ca..249ea57f2 100644 --- a/core/crates/heavy-lifting/src/file_identifier/job.rs +++ b/core/crates/heavy-lifting/src/file_identifier/job.rs @@ -14,7 +14,11 @@ use crate::{ use sd_core_file_path_helper::IsolatedFilePathData; use sd_core_prisma_helpers::{file_path_for_file_identifier, CasId}; -use sd_prisma::prisma::{device, file_path, location, SortOrder}; +use sd_prisma::{ + prisma::{device, file_path, location, SortOrder}, + prisma_sync, +}; +use sd_sync::{sync_db_not_null_entry, OperationFactory}; use sd_task_system::{ AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId, TaskOutput, TaskStatus, @@ -267,15 +271,25 @@ impl Job for FileIdentifier { .. } = self; - ctx.db() - .location() - .update( - location::id::equals(location.id), - vec![location::scan_state::set( - LocationScanState::FilesIdentified as i32, - )], + let (sync_param, db_param) = sync_db_not_null_entry!( + LocationScanState::FilesIdentified as i32, + location::scan_state + ); + + ctx.sync() + .write_op( + ctx.db(), + ctx.sync().shared_update( + prisma_sync::location::SyncId { + pub_id: location.pub_id.clone(), + }, + [sync_param], + ), + ctx.db() + .location() + .update(location::id::equals(location.id), vec![db_param]) + .select(location::select!({ id })), ) - .exec() .await .map_err(file_identifier::Error::from)?; diff --git a/core/crates/heavy-lifting/src/file_identifier/tasks/identifier.rs b/core/crates/heavy-lifting/src/file_identifier/tasks/identifier.rs index 22d244cac..125a72713 100644 --- a/core/crates/heavy-lifting/src/file_identifier/tasks/identifier.rs +++ b/core/crates/heavy-lifting/src/file_identifier/tasks/identifier.rs @@ -12,11 +12,11 @@ use sd_prisma::{ prisma::{device, file_path, location, PrismaClient}, prisma_sync, }; -use sd_sync::OperationFactory; +use sd_sync::{sync_db_entry, OperationFactory}; use sd_task_system::{ ExecStatus, Interrupter, InterruptionKind, IntoAnyTaskOutput, SerializableTask, Task, TaskId, }; -use sd_utils::{error::FileIOError, msgpack}; +use sd_utils::error::FileIOError; use std::{ collections::HashMap, convert::identity, future::IntoFuture, mem, path::PathBuf, pin::pin, @@ -403,19 +403,17 @@ async fn assign_cas_id_to_file_paths( let (ops, queries) = identified_files .iter() .map(|(pub_id, IdentifiedFile { cas_id, .. })| { + let (sync_param, db_param) = sync_db_entry!(cas_id, file_path::cas_id); + ( sync.shared_update( prisma_sync::file_path::SyncId { pub_id: pub_id.to_db(), }, - file_path::cas_id::NAME, - msgpack!(cas_id), + [sync_param], ), db.file_path() - .update( - file_path::pub_id::equals(pub_id.to_db()), - vec![file_path::cas_id::set(cas_id.into())], - ) + .update(file_path::pub_id::equals(pub_id.to_db()), vec![db_param]) // We don't need any data here, just the id avoids receiving the entire object // as we can't pass an empty select macro call .select(file_path::select!({ id })), diff --git a/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs b/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs index 13e4d7d9f..59f75d0a9 100644 --- a/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs +++ b/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs @@ -9,7 +9,7 @@ use sd_prisma::{ prisma_sync, }; use sd_sync::{option_sync_db_entry, sync_db_entry, sync_entry, CRDTOperation, OperationFactory}; -use sd_utils::{chain_optional_iter, msgpack}; +use sd_utils::chain_optional_iter; use std::collections::{HashMap, HashSet}; @@ -47,10 +47,12 @@ fn connect_file_path_to_object<'db>( prisma_sync::file_path::SyncId { pub_id: file_path_pub_id.to_db(), }, - file_path::object::NAME, - msgpack!(prisma_sync::object::SyncId { - pub_id: object_pub_id.to_db(), - }), + [sync_entry!( + prisma_sync::object::SyncId { + pub_id: object_pub_id.to_db(), + }, + file_path::object + )], ), db.file_path() .update( diff --git a/core/crates/heavy-lifting/src/indexer/job.rs b/core/crates/heavy-lifting/src/indexer/job.rs index d910fec81..cf19fbb90 100644 --- a/core/crates/heavy-lifting/src/indexer/job.rs +++ b/core/crates/heavy-lifting/src/indexer/job.rs @@ -16,7 +16,11 @@ use sd_core_file_path_helper::IsolatedFilePathData; use sd_core_indexer_rules::{IndexerRule, IndexerRuler}; use sd_core_prisma_helpers::location_with_indexer_rules; -use sd_prisma::prisma::{device, location}; +use sd_prisma::{ + prisma::{device, location}, + prisma_sync, +}; +use sd_sync::{sync_db_not_null_entry, OperationFactory}; use sd_task_system::{ AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId, TaskOutput, TaskStatus, @@ -269,7 +273,7 @@ impl Job for Indexer { .await?; } - update_location_size(location.id, ctx.db(), &ctx).await?; + update_location_size(location.id, location.pub_id.clone(), &ctx).await?; metadata.mean_db_write_time += start_size_update_time.elapsed(); } @@ -287,13 +291,23 @@ impl Job for Indexer { "all tasks must be completed here" ); - ctx.db() - .location() - .update( - location::id::equals(location.id), - vec![location::scan_state::set(LocationScanState::Indexed as i32)], + let (sync_param, db_param) = + sync_db_not_null_entry!(LocationScanState::Indexed as i32, location::scan_state); + + ctx.sync() + .write_op( + ctx.db(), + ctx.sync().shared_update( + prisma_sync::location::SyncId { + pub_id: location.pub_id.clone(), + }, + [sync_param], + ), + ctx.db() + .location() + .update(location::id::equals(location.id), vec![db_param]) + .select(location::select!({ id })), ) - .exec() .await .map_err(indexer::Error::from)?; diff --git a/core/crates/heavy-lifting/src/indexer/mod.rs b/core/crates/heavy-lifting/src/indexer/mod.rs index 1ad78902b..6880e6d91 100644 --- a/core/crates/heavy-lifting/src/indexer/mod.rs +++ b/core/crates/heavy-lifting/src/indexer/mod.rs @@ -10,11 +10,11 @@ use sd_prisma::{ prisma::{file_path, indexer_rule, location, PrismaClient, SortOrder}, prisma_sync, }; -use sd_sync::OperationFactory; +use sd_sync::{sync_db_entry, OperationFactory}; use sd_utils::{ db::{size_in_bytes_from_db, size_in_bytes_to_db, MissingFieldError}, error::{FileIOError, NonUtf8PathError}, - from_bytes_to_uuid, msgpack, + from_bytes_to_uuid, }; use std::{ @@ -146,22 +146,20 @@ async fn update_directory_sizes( .map(|file_path| { let size_bytes = iso_paths_and_sizes .get(&IsolatedFilePathData::try_from(&file_path)?) - .map(|size| size.to_be_bytes().to_vec()) + .map(|size| size_in_bytes_to_db(*size)) .expect("must be here"); + let (sync_param, db_param) = sync_db_entry!(size_bytes, file_path::size_in_bytes_bytes); + Ok(( sync.shared_update( prisma_sync::file_path::SyncId { pub_id: file_path.pub_id.clone(), }, - file_path::size_in_bytes_bytes::NAME, - msgpack!(size_bytes), + [sync_param], ), db.file_path() - .update( - file_path::pub_id::equals(file_path.pub_id), - vec![file_path::size_in_bytes_bytes::set(Some(size_bytes))], - ) + .update(file_path::pub_id::equals(file_path.pub_id), vec![db_param]) .select(file_path::select!({ id })), )) }) @@ -178,35 +176,45 @@ async fn update_directory_sizes( async fn update_location_size( location_id: location::id::Type, - db: &PrismaClient, + location_pub_id: location::pub_id::Type, ctx: &impl OuterContext, ) -> Result<(), Error> { - let total_size = db - .file_path() - .find_many(vec![ - file_path::location_id::equals(Some(location_id)), - file_path::materialized_path::equals(Some("/".to_string())), - ]) - .select(file_path::select!({ size_in_bytes_bytes })) - .exec() - .await? - .into_iter() - .filter_map(|file_path| { - file_path - .size_in_bytes_bytes - .map(|size_in_bytes_bytes| size_in_bytes_from_db(&size_in_bytes_bytes)) - }) - .sum::(); + let db = ctx.db(); + let sync = ctx.sync(); - db.location() - .update( - location::id::equals(location_id), - vec![location::size_in_bytes::set(Some( - total_size.to_be_bytes().to_vec(), - ))], - ) - .exec() - .await?; + let total_size = size_in_bytes_to_db( + db.file_path() + .find_many(vec![ + file_path::location_id::equals(Some(location_id)), + file_path::materialized_path::equals(Some("/".to_string())), + ]) + .select(file_path::select!({ size_in_bytes_bytes })) + .exec() + .await? + .into_iter() + .filter_map(|file_path| { + file_path + .size_in_bytes_bytes + .map(|size_in_bytes_bytes| size_in_bytes_from_db(&size_in_bytes_bytes)) + }) + .sum::(), + ); + + let (sync_param, db_param) = sync_db_entry!(total_size, location::size_in_bytes); + + sync.write_op( + db, + sync.shared_update( + prisma_sync::location::SyncId { + pub_id: location_pub_id, + }, + [sync_param], + ), + db.location() + .update(location::id::equals(location_id), vec![db_param]) + .select(location::select!({ id })), + ) + .await?; ctx.invalidate_query("locations.list"); ctx.invalidate_query("locations.get"); @@ -334,18 +342,19 @@ pub async fn reverse_update_directories_sizes( { let size_bytes = size_in_bytes_to_db(size); + let (sync_param, db_param) = + sync_db_entry!(size_bytes, file_path::size_in_bytes_bytes); + Some(( sync.shared_update( prisma_sync::file_path::SyncId { pub_id: pub_id.clone(), }, - file_path::size_in_bytes_bytes::NAME, - msgpack!(size_bytes), - ), - db.file_path().update( - file_path::pub_id::equals(pub_id), - vec![file_path::size_in_bytes_bytes::set(Some(size_bytes))], + [sync_param], ), + db.file_path() + .update(file_path::pub_id::equals(pub_id), vec![db_param]) + .select(file_path::select!({ id })), )) } else { warn!("Got a missing ancestor for a file_path in the database, ignoring..."); diff --git a/core/crates/heavy-lifting/src/indexer/shallow.rs b/core/crates/heavy-lifting/src/indexer/shallow.rs index 90a22eead..1bc55b556 100644 --- a/core/crates/heavy-lifting/src/indexer/shallow.rs +++ b/core/crates/heavy-lifting/src/indexer/shallow.rs @@ -136,7 +136,7 @@ pub async fn shallow( .await?; } - update_location_size(location.id, db, ctx).await?; + update_location_size(location.id, location.pub_id, ctx).await?; } if indexed_count > 0 || removed_count > 0 { diff --git a/core/crates/heavy-lifting/src/indexer/tasks/saver.rs b/core/crates/heavy-lifting/src/indexer/tasks/saver.rs index 9fbe24554..c5d0951d0 100644 --- a/core/crates/heavy-lifting/src/indexer/tasks/saver.rs +++ b/core/crates/heavy-lifting/src/indexer/tasks/saver.rs @@ -9,10 +9,7 @@ use sd_prisma::{ }; use sd_sync::{sync_db_entry, sync_entry, OperationFactory}; use sd_task_system::{ExecStatus, Interrupter, IntoAnyTaskOutput, SerializableTask, Task, TaskId}; -use sd_utils::{ - db::{inode_to_db, size_in_bytes_to_db}, - msgpack, -}; +use sd_utils::db::{inode_to_db, size_in_bytes_to_db}; use std::{sync::Arc, time::Duration}; @@ -121,13 +118,13 @@ impl Task for Saver { new file_paths and they were not identified yet" ); - let (sync_params, db_params): (Vec<_>, Vec<_>) = [ + let (sync_params, db_params) = [ ( - ( - location::NAME, - msgpack!(prisma_sync::location::SyncId { + sync_entry!( + prisma_sync::location::SyncId { pub_id: location_pub_id.clone() - }), + }, + location ), location_id::set(Some(*location_id)), ), @@ -152,7 +149,7 @@ impl Task for Saver { ), ] .into_iter() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); ( sync.shared_create( diff --git a/core/crates/heavy-lifting/src/indexer/tasks/updater.rs b/core/crates/heavy-lifting/src/indexer/tasks/updater.rs index 91eb72899..80cf3d6f4 100644 --- a/core/crates/heavy-lifting/src/indexer/tasks/updater.rs +++ b/core/crates/heavy-lifting/src/indexer/tasks/updater.rs @@ -93,7 +93,7 @@ impl Task for Updater { check_interruption!(interrupter); - let (sync_stuff, paths_to_update) = walked_entries + let (crdt_ops, paths_to_update) = walked_entries .drain(..) .map( |WalkedEntry { @@ -138,18 +138,12 @@ impl Task for Updater { .unzip::<_, _, Vec<_>, Vec<_>>(); ( - sync_params - .into_iter() - .map(|(field, value)| { - sync.shared_update( - prisma_sync::file_path::SyncId { - pub_id: pub_id.to_db(), - }, - field, - value, - ) - }) - .collect::>(), + sync.shared_update( + prisma_sync::file_path::SyncId { + pub_id: pub_id.to_db(), + }, + sync_params, + ), db.file_path() .update(file_path::pub_id::equals(pub_id.into()), db_params) // selecting id to avoid fetching whole object from database @@ -159,9 +153,7 @@ impl Task for Updater { ) .unzip::<_, _, Vec<_>, Vec<_>>(); - let ops = sync_stuff.into_iter().flatten().collect::>(); - - if ops.is_empty() && paths_to_update.is_empty() { + if crdt_ops.is_empty() && paths_to_update.is_empty() { return Ok(ExecStatus::Done( Output { updated_count: 0, @@ -172,7 +164,7 @@ impl Task for Updater { } let updated = sync - .write_ops(db, (ops, paths_to_update)) + .write_ops(db, (crdt_ops, paths_to_update)) .await .map_err(indexer::Error::from)?; diff --git a/core/crates/heavy-lifting/src/job_system/report.rs b/core/crates/heavy-lifting/src/job_system/report.rs index 3d536e7dd..b747b8195 100644 --- a/core/crates/heavy-lifting/src/job_system/report.rs +++ b/core/crates/heavy-lifting/src/job_system/report.rs @@ -290,6 +290,7 @@ impl Report { .map(|id| job::parent::connect(job::id::equals(id.as_bytes().to_vec())))], ), ) + .select(job::select!({ id })) .exec() .await .map_err(ReportError::Create)?; @@ -318,6 +319,7 @@ impl Report { job::date_completed::set(self.completed_at.map(Into::into)), ], ) + .select(job::select!({ id })) .exec() .await .map_err(ReportError::Update)?; diff --git a/core/crates/heavy-lifting/src/media_processor/job.rs b/core/crates/heavy-lifting/src/media_processor/job.rs index cadeb5f03..fb622e162 100644 --- a/core/crates/heavy-lifting/src/media_processor/job.rs +++ b/core/crates/heavy-lifting/src/media_processor/job.rs @@ -14,7 +14,11 @@ use sd_core_file_path_helper::IsolatedFilePathData; use sd_core_prisma_helpers::file_path_for_media_processor; use sd_file_ext::extensions::Extension; -use sd_prisma::prisma::{location, PrismaClient}; +use sd_prisma::{ + prisma::{location, PrismaClient}, + prisma_sync, +}; +use sd_sync::{sync_db_not_null_entry, OperationFactory}; use sd_task_system::{ AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId, TaskOutput, TaskStatus, TaskSystemError, @@ -214,15 +218,23 @@ impl Job for MediaProcessor { .. } = self; - ctx.db() - .location() - .update( - location::id::equals(location.id), - vec![location::scan_state::set( - LocationScanState::Completed as i32, - )], + let (sync_param, db_param) = + sync_db_not_null_entry!(LocationScanState::Completed as i32, location::scan_state); + + ctx.sync() + .write_op( + ctx.db(), + ctx.sync().shared_update( + prisma_sync::location::SyncId { + pub_id: location.pub_id.clone(), + }, + [sync_param], + ), + ctx.db() + .location() + .update(location::id::equals(location.id), vec![db_param]) + .select(location::select!({ id })), ) - .exec() .await .map_err(media_processor::Error::from)?; diff --git a/core/crates/prisma-helpers/src/lib.rs b/core/crates/prisma-helpers/src/lib.rs index 48a400e65..ee8f11bb9 100644 --- a/core/crates/prisma-helpers/src/lib.rs +++ b/core/crates/prisma-helpers/src/lib.rs @@ -74,6 +74,20 @@ file_path::select!(file_path_for_media_processor { pub_id } }); +file_path::select!(file_path_watcher_remove { + id + pub_id + location_id + materialized_path + is_dir + name + extension + object: select { + id + pub_id + } + +}); file_path::select!(file_path_to_isolate { location_id materialized_path diff --git a/core/crates/sync/src/ingest_utils.rs b/core/crates/sync/src/ingest_utils.rs index 3cc4c8a68..6c77a96b7 100644 --- a/core/crates/sync/src/ingest_utils.rs +++ b/core/crates/sync/src/ingest_utils.rs @@ -1,7 +1,7 @@ use sd_core_prisma_helpers::DevicePubId; use sd_prisma::{ - prisma::{crdt_operation, PrismaClient, SortOrder}, + prisma::{crdt_operation, PrismaClient}, prisma_sync::ModelSyncData, }; use sd_sync::{ @@ -17,6 +17,8 @@ use uuid::Uuid; use super::{db_operation::write_crdt_op_to_db, Error, TimestampPerDevice}; +crdt_operation::select!(crdt_operation_id { id }); + // where the magic happens #[instrument(skip(clock, ops), fields(operations_count = %ops.len()), err)] pub async fn process_crdt_operations( @@ -24,7 +26,7 @@ pub async fn process_crdt_operations( timestamp_per_device: &TimestampPerDevice, db: &PrismaClient, device_pub_id: DevicePubId, - model: ModelId, + model_id: ModelId, record_id: RecordId, mut ops: Vec, ) -> Result<(), Error> { @@ -50,7 +52,7 @@ pub async fn process_crdt_operations( .find(|op| matches!(op.data, CRDTOperationData::Delete)) { trace!("Deleting operation"); - handle_crdt_deletion(db, &device_pub_id, model, record_id, delete_op).await?; + handle_crdt_deletion(db, &device_pub_id, model_id, record_id, delete_op).await?; } // Create + > 0 Update - overwrites the create's data with the updates else if let Some(timestamp) = ops @@ -61,23 +63,22 @@ pub async fn process_crdt_operations( trace!("Create + Updates operations"); // conflict resolution - let delete = db + let delete_count = db .crdt_operation() - .find_first(vec![ - crdt_operation::model::equals(i32::from(model)), + .count(vec![ + crdt_operation::model::equals(i32::from(model_id)), crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), crdt_operation::kind::equals(OperationKind::Delete.to_string()), ]) - .order_by(crdt_operation::timestamp::order(SortOrder::Desc)) .exec() .await?; - if delete.is_some() { + if delete_count > 0 { debug!("Found a previous delete operation with the same SyncId, will ignore these operations"); return Ok(()); } - handle_crdt_create_and_updates(db, &device_pub_id, model, record_id, ops, timestamp) + handle_crdt_create_and_updates(db, &device_pub_id, model_id, record_id, ops, timestamp) .await?; } // > 0 Update - batches updates with a fake Create op @@ -87,51 +88,57 @@ pub async fn process_crdt_operations( let mut data = BTreeMap::new(); for op in ops.into_iter().rev() { - let CRDTOperationData::Update { field, value } = op.data else { + let CRDTOperationData::Update(fields_and_values) = op.data else { unreachable!("Create + Delete should be filtered out!"); }; - data.insert(field, (value, op.timestamp)); + for (field, value) in fields_and_values { + data.insert(field, (value, op.timestamp)); + } } // conflict resolution - let (create, updates) = db + let (create, newer_updates_count) = db ._batch(( - db.crdt_operation() - .find_first(vec![ - crdt_operation::model::equals(i32::from(model)), - crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), - crdt_operation::kind::equals(OperationKind::Create.to_string()), - ]) - .order_by(crdt_operation::timestamp::order(SortOrder::Desc)), + db.crdt_operation().count(vec![ + crdt_operation::model::equals(i32::from(model_id)), + crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), + crdt_operation::kind::equals(OperationKind::Create.to_string()), + ]), data.iter() .map(|(k, (_, timestamp))| { - Ok(db - .crdt_operation() - .find_first(vec![ - crdt_operation::timestamp::gt({ - #[allow(clippy::cast_possible_wrap)] - // SAFETY: we had to store using i64 due to SQLite limitations - { - timestamp.as_u64() as i64 - } - }), - crdt_operation::model::equals(i32::from(model)), - crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), - crdt_operation::kind::equals(OperationKind::Update(k).to_string()), - ]) - .order_by(crdt_operation::timestamp::order(SortOrder::Desc))) + Ok(db.crdt_operation().count(vec![ + crdt_operation::timestamp::gt({ + #[allow(clippy::cast_possible_wrap)] + // SAFETY: we had to store using i64 due to SQLite limitations + { + timestamp.as_u64() as i64 + } + }), + crdt_operation::model::equals(i32::from(model_id)), + crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), + crdt_operation::kind::contains(format!(":{k}:")), + ])) }) .collect::, Error>>()?, )) .await?; - if create.is_none() { + if create == 0 { warn!("Failed to find a previous create operation with the same SyncId"); return Ok(()); } - handle_crdt_updates(db, &device_pub_id, model, record_id, data, updates).await?; + let keys = data.keys().cloned().collect::>(); + + // remove entries if we possess locally more recent updates for this field + for (update, key) in newer_updates_count.into_iter().zip(keys) { + if update > 0 { + data.remove(&key); + } + } + + handle_crdt_updates(db, &device_pub_id, model_id, record_id, data).await?; } // read the timestamp for the operation's device, or insert one if it doesn't exist @@ -157,24 +164,15 @@ async fn handle_crdt_updates( device_pub_id: &DevicePubId, model_id: ModelId, record_id: rmpv::Value, - mut data: BTreeMap, - updates: Vec>, + data: BTreeMap, ) -> Result<(), Error> { - let keys = data.keys().cloned().collect::>(); let device_pub_id = sd_sync::DevicePubId::from(device_pub_id); - // does the same thing as processing ops one-by-one and returning early if a newer op was found - for (update, key) in updates.into_iter().zip(keys) { - if update.is_some() { - data.remove(&key); - } - } - db._transaction() .with_timeout(30 * 10000) .with_max_wait(30 * 10000) .run(|db| async move { - // fake operation to batch them all at once + // fake operation to batch them all at once, inserting the latest data on appropriate table ModelSyncData::from_op(CRDTOperation { device_pub_id, model_id, @@ -185,35 +183,32 @@ async fn handle_crdt_updates( .map(|(k, (data, _))| (k.clone(), data.clone())) .collect(), ), - }) - .ok_or(Error::InvalidModelId(model_id))? + })? .exec(&db) .await?; - // need to only apply ops that haven't been filtered out - data.into_iter() - .map(|(field, (value, timestamp))| { - let record_id = record_id.clone(); - let db = &db; - - async move { - write_crdt_op_to_db( - &CRDTOperation { - device_pub_id, - model_id, - record_id, - timestamp, - data: CRDTOperationData::Update { field, value }, - }, - db, - ) - .await + let (fields_and_values, latest_timestamp) = data.into_iter().fold( + (BTreeMap::new(), NTP64::default()), + |(mut fields_and_values, mut latest_time_stamp), (field, (value, timestamp))| { + fields_and_values.insert(field, value); + if timestamp > latest_time_stamp { + latest_time_stamp = timestamp; } - }) - .collect::>() - .try_join() - .await - .map(|_| ()) + (fields_and_values, latest_time_stamp) + }, + ); + + write_crdt_op_to_db( + &CRDTOperation { + device_pub_id, + model_id, + record_id, + timestamp: latest_timestamp, + data: CRDTOperationData::Update(fields_and_values), + }, + &db, + ) + .await }) .await } @@ -244,8 +239,11 @@ async fn handle_crdt_create_and_updates( break; } - CRDTOperationData::Update { field, value } => { - data.insert(field.clone(), value.clone()); + CRDTOperationData::Update(fields_and_values) => { + for (field, value) in fields_and_values { + data.insert(field.clone(), value.clone()); + } + applied_ops.push(op); } } @@ -262,8 +260,7 @@ async fn handle_crdt_create_and_updates( record_id: record_id.clone(), timestamp, data: CRDTOperationData::Create(data), - }) - .ok_or(Error::InvalidModelId(model_id))? + })? .exec(&db) .await?; @@ -314,10 +311,7 @@ async fn handle_crdt_deletion( .with_timeout(30 * 10000) .with_max_wait(30 * 10000) .run(|db| async move { - ModelSyncData::from_op(op.clone()) - .ok_or(Error::InvalidModelId(model))? - .exec(&db) - .await?; + ModelSyncData::from_op(op.clone())?.exec(&db).await?; write_crdt_op_to_db(&op, &db).await }) diff --git a/core/crates/sync/src/lib.rs b/core/crates/sync/src/lib.rs index 903fb812b..56822509c 100644 --- a/core/crates/sync/src/lib.rs +++ b/core/crates/sync/src/lib.rs @@ -27,7 +27,10 @@ #![forbid(deprecated_in_future)] #![allow(clippy::missing_errors_doc, clippy::module_name_repetitions)] -use sd_prisma::prisma::{cloud_crdt_operation, crdt_operation}; +use sd_prisma::{ + prisma::{cloud_crdt_operation, crdt_operation}, + prisma_sync, +}; use sd_utils::uuid_to_bytes; use std::{collections::HashMap, sync::Arc}; @@ -66,6 +69,8 @@ pub enum Error { Deserialization(#[from] rmp_serde::decode::Error), #[error("database error: {0}")] Database(#[from] prisma_client_rust::QueryError), + #[error("PrismaSync error: {0}")] + PrismaSync(#[from] prisma_sync::Error), #[error("invalid model id: {0}")] InvalidModelId(ModelId), #[error("tried to write an empty operations list")] diff --git a/core/crates/sync/tests/lib.rs b/core/crates/sync/tests/lib.rs deleted file mode 100644 index 5c9dbf584..000000000 --- a/core/crates/sync/tests/lib.rs +++ /dev/null @@ -1,234 +0,0 @@ -// mod mock_instance; - -// use sd_core_sync::*; - -// use sd_prisma::{prisma::location, prisma_sync}; -// use sd_sync::*; -// use sd_utils::{msgpack, uuid_to_bytes}; - -// use mock_instance::Device; -// use tracing::info; -// use tracing_test::traced_test; -// use uuid::Uuid; - -// const MOCK_LOCATION_NAME: &str = "Location 0"; -// const MOCK_LOCATION_PATH: &str = "/User/Anon/Documents"; - -// async fn write_test_location(instance: &Device) -> location::Data { -// let location_pub_id = Uuid::new_v4(); - -// let (sync_ops, db_ops): (Vec<_>, Vec<_>) = [ -// sync_db_entry!(MOCK_LOCATION_NAME, location::name), -// sync_db_entry!(MOCK_LOCATION_PATH, location::path), -// ] -// .into_iter() -// .unzip(); - -// let location = instance -// .sync -// .write_op( -// &instance.db, -// instance.sync.shared_create( -// prisma_sync::location::SyncId { -// pub_id: uuid_to_bytes(&location_pub_id), -// }, -// sync_ops, -// ), -// instance -// .db -// .location() -// .create(uuid_to_bytes(&location_pub_id), db_ops), -// ) -// .await -// .expect("failed to create mock location"); - -// instance -// .sync -// .write_ops(&instance.db, { -// let (sync_ops, db_ops): (Vec<_>, Vec<_>) = [ -// sync_db_entry!(1024, location::total_capacity), -// sync_db_entry!(512, location::available_capacity), -// ] -// .into_iter() -// .unzip(); - -// ( -// sync_ops -// .into_iter() -// .map(|(k, v)| { -// instance.sync.shared_update( -// prisma_sync::location::SyncId { -// pub_id: uuid_to_bytes(&location_pub_id), -// }, -// k, -// v, -// ) -// }) -// .collect::>(), -// instance -// .db -// .location() -// .update(location::id::equals(location.id), db_ops), -// ) -// }) -// .await -// .expect("failed to create mock location"); - -// location -// } - -// #[tokio::test] -// #[traced_test] -// async fn writes_operations_and_rows_together() -> Result<(), Box> { -// let instance = Device::new(Uuid::new_v4()).await; - -// write_test_location(&instance).await; - -// let operations = instance -// .db -// .crdt_operation() -// .find_many(vec![]) -// .exec() -// .await?; - -// // 1 create, 2 update -// assert_eq!(operations.len(), 3); -// assert_eq!(operations[0].model, prisma_sync::location::MODEL_ID as i32); - -// let out = instance.sync.get_ops(100, vec![]).await?; - -// assert_eq!(out.len(), 3); - -// let locations = instance.db.location().find_many(vec![]).exec().await?; - -// assert_eq!(locations.len(), 1); -// let location = locations.first().unwrap(); -// assert_eq!(location.name.as_deref(), Some(MOCK_LOCATION_NAME)); -// assert_eq!(location.path.as_deref(), Some(MOCK_LOCATION_PATH)); - -// Ok(()) -// } - -// #[tokio::test] -// #[traced_test] -// async fn operations_send_and_ingest() -> Result<(), Box> { -// let instance1 = Device::new(Uuid::new_v4()).await; -// let instance2 = Device::new(Uuid::new_v4()).await; - -// let mut instance2_sync_rx = instance2.sync_rx.resubscribe(); - -// info!("Created instances!"); - -// Device::pair(&instance1, &instance2).await; - -// info!("Paired instances!"); - -// write_test_location(&instance1).await; - -// info!("Created mock location!"); - -// assert!(matches!( -// instance2_sync_rx.recv().await?, -// SyncEvent::Ingested -// )); - -// let out = instance2.sync.get_ops(100, vec![]).await?; - -// assert_locations_equality( -// &instance1.db.location().find_many(vec![]).exec().await?[0], -// &instance2.db.location().find_many(vec![]).exec().await?[0], -// ); - -// assert_eq!(out.len(), 3); - -// instance1.teardown().await; -// instance2.teardown().await; - -// Ok(()) -// } - -// #[tokio::test] -// async fn no_update_after_delete() -> Result<(), Box> { -// let instance1 = Device::new(Uuid::new_v4()).await; -// let instance2 = Device::new(Uuid::new_v4()).await; - -// let mut instance2_sync_rx = instance2.sync_rx.resubscribe(); - -// Device::pair(&instance1, &instance2).await; - -// let location = write_test_location(&instance1).await; - -// assert!(matches!( -// instance2_sync_rx.recv().await?, -// SyncEvent::Ingested -// )); - -// instance2 -// .sync -// .write_op( -// &instance2.db, -// instance2.sync.shared_delete(prisma_sync::location::SyncId { -// pub_id: location.pub_id.clone(), -// }), -// instance2.db.location().delete_many(vec![]), -// ) -// .await?; - -// assert!(matches!( -// instance1.sync_rx.resubscribe().recv().await?, -// SyncEvent::Ingested -// )); - -// instance1 -// .sync -// .write_op( -// &instance1.db, -// instance1.sync.shared_update( -// prisma_sync::location::SyncId { -// pub_id: location.pub_id.clone(), -// }, -// "name", -// msgpack!("New Location"), -// ), -// instance1.db.location().find_many(vec![]), -// ) -// .await?; - -// // one spare update operation that actually gets ignored by instance 2 -// assert_eq!(instance1.db.crdt_operation().count(vec![]).exec().await?, 5); -// assert_eq!(instance2.db.crdt_operation().count(vec![]).exec().await?, 4); - -// assert_eq!(instance1.db.location().count(vec![]).exec().await?, 0); -// // the whole point of the test - the update (which is ingested as an upsert) should be ignored -// assert_eq!(instance2.db.location().count(vec![]).exec().await?, 0); - -// instance1.teardown().await; -// instance2.teardown().await; - -// Ok(()) -// } - -// fn assert_locations_equality(l1: &location::Data, l2: &location::Data) { -// assert_eq!(l1.pub_id, l2.pub_id, "pub id"); -// assert_eq!(l1.name, l2.name, "name"); -// assert_eq!(l1.path, l2.path, "path"); -// assert_eq!(l1.total_capacity, l2.total_capacity, "total capacity"); -// assert_eq!( -// l1.available_capacity, l2.available_capacity, -// "available capacity" -// ); -// assert_eq!(l1.size_in_bytes, l2.size_in_bytes, "size in bytes"); -// assert_eq!(l1.is_archived, l2.is_archived, "is archived"); -// assert_eq!( -// l1.generate_preview_media, l2.generate_preview_media, -// "generate preview media" -// ); -// assert_eq!( -// l1.sync_preview_media, l2.sync_preview_media, -// "sync preview media" -// ); -// assert_eq!(l1.hidden, l2.hidden, "hidden"); -// assert_eq!(l1.date_created, l2.date_created, "date created"); -// assert_eq!(l1.scan_state, l2.scan_state, "scan state"); -// assert_eq!(l1.instance_id, l2.instance_id, "instance id"); -// } diff --git a/core/crates/sync/tests/mock_instance.rs b/core/crates/sync/tests/mock_instance.rs deleted file mode 100644 index 9dd5f1aff..000000000 --- a/core/crates/sync/tests/mock_instance.rs +++ /dev/null @@ -1,143 +0,0 @@ -// use sd_core_sync::*; - -// use sd_prisma::prisma; -// use sd_sync::CompressedCRDTOperationsPerModelPerDevice; - -// use std::sync::{atomic::AtomicBool, Arc}; - -// use tokio::{fs, spawn, sync::broadcast}; -// use tracing::{info, instrument, warn, Instrument}; -// use uuid::Uuid; - -// fn db_path(id: Uuid) -> String { -// format!("/tmp/test-{id}.db") -// } - -// #[derive(Clone)] -// pub struct Device { -// pub pub_id: DevicePubId, -// pub db: Arc, -// pub sync: Arc, -// pub sync_rx: Arc>, -// } - -// impl Device { -// pub async fn new(id: Uuid) -> Arc { -// let url = format!("file:{}", db_path(id)); -// let device_pub_id = DevicePubId::from(id); - -// let db = Arc::new( -// prisma::PrismaClient::_builder() -// .with_url(url.to_string()) -// .build() -// .await -// .unwrap(), -// ); - -// db._db_push().await.unwrap(); - -// db.device() -// .create(device_pub_id.to_db(), vec![]) -// .exec() -// .await -// .unwrap(); - -// // let (sync, sync_rx) = sd_core_sync::SyncManager::new( -// // Arc::clone(&db), -// // &device_pub_id, -// // Arc::new(AtomicBool::new(true)), -// // Default::default(), -// // ) -// // .await -// // .expect("failed to create sync manager"); - -// // Arc::new(Self { -// // pub_id: device_pub_id, -// // db, -// // sync: Arc::new(sync), -// // sync_rx: Arc::new(sync_rx), -// // }) -// } - -// pub async fn teardown(&self) { -// fs::remove_file(db_path(Uuid::from(&self.pub_id))) -// .await -// .unwrap(); -// } - -// pub async fn pair(instance1: &Arc, instance2: &Arc) { -// #[instrument(skip(left, right))] -// async fn half(left: &Arc, right: &Arc, context: &'static str) { -// left.db -// .device() -// .create(right.pub_id.to_db(), vec![]) -// .exec() -// .await -// .unwrap(); - -// spawn({ -// let mut sync_rx_left = left.sync_rx.resubscribe(); -// let right = Arc::clone(right); - -// async move { -// while let Ok(msg) = sync_rx_left.recv().await { -// info!(?msg, "sync_rx_left received message"); -// if matches!(msg, SyncEvent::Created) { -// right -// .sync -// .ingest -// .event_tx -// .send(ingest::Event::Notification) -// .await -// .unwrap(); -// info!("sent notification to instance 2"); -// } -// } -// } -// .in_current_span() -// }); - -// spawn({ -// let left = Arc::clone(left); -// let right = Arc::clone(right); - -// async move { -// while let Ok(msg) = right.sync.ingest.req_rx.recv().await { -// info!(?msg, "right instance received request"); -// match msg { -// ingest::Request::Messages { timestamps, tx } => { -// let messages = left.sync.get_ops(100, timestamps).await.unwrap(); - -// let ingest = &right.sync.ingest; - -// ingest -// .event_tx -// .send(ingest::Event::Messages(ingest::MessagesEvent { -// messages: CompressedCRDTOperationsPerModelPerDevice::new( -// messages, -// ), -// has_more: false, -// device_pub_id: left.pub_id.clone(), -// wait_tx: None, -// })) -// .await -// .unwrap(); - -// if tx.send(()).is_err() { -// warn!("failed to send ack to instance 1"); -// } -// } -// ingest::Request::FinishedIngesting => { -// right.sync.tx.send(SyncEvent::Ingested).unwrap(); -// } -// } -// } -// } -// .in_current_span() -// }); -// } - -// half(instance1, instance2, "instance1 -> instance2").await; -// half(instance2, instance1, "instance2 -> instance1").await; -// } -// } diff --git a/core/src/api/files.rs b/core/src/api/files.rs index 8e0f29992..155fd2884 100644 --- a/core/src/api/files.rs +++ b/core/src/api/files.rs @@ -28,8 +28,8 @@ use sd_prisma::{ prisma::{file_path, location, object}, prisma_sync, }; -use sd_sync::OperationFactory; -use sd_utils::{db::maybe_missing, error::FileIOError, msgpack}; +use sd_sync::{sync_db_entry, sync_db_nullable_entry, sync_entry, OperationFactory}; +use sd_utils::{db::maybe_missing, error::FileIOError}; use std::{ ffi::OsString, @@ -195,19 +195,19 @@ pub(crate) fn mount() -> AlphaRouter { ) })?; + let (sync_param, db_param) = sync_db_nullable_entry!(args.note, object::note); + sync.write_op( db, sync.shared_update( prisma_sync::object::SyncId { pub_id: object.pub_id, }, - object::note::NAME, - msgpack!(&args.note), - ), - db.object().update( - object::id::equals(args.id), - vec![object::note::set(args.note)], + [sync_param], ), + db.object() + .update(object::id::equals(args.id), vec![db_param]) + .select(object::select!({ id })), ) .await?; @@ -241,19 +241,19 @@ pub(crate) fn mount() -> AlphaRouter { ) })?; + let (sync_param, db_param) = sync_db_entry!(args.favorite, object::favorite); + sync.write_op( db, sync.shared_update( prisma_sync::object::SyncId { pub_id: object.pub_id, }, - object::favorite::NAME, - msgpack!(&args.favorite), - ), - db.object().update( - object::id::equals(args.id), - vec![object::favorite::set(Some(args.favorite))], + [sync_param], ), + db.object() + .update(object::id::equals(args.id), vec![db_param]) + .select(object::select!({ id })), ) .await?; @@ -346,19 +346,20 @@ pub(crate) fn mount() -> AlphaRouter { let date_accessed = Utc::now().into(); - let (ops, object_ids): (Vec<_>, Vec<_>) = objects + let (ops, object_ids) = objects .into_iter() - .map(|d| { + .map(|object| { ( sync.shared_update( - prisma_sync::object::SyncId { pub_id: d.pub_id }, - object::date_accessed::NAME, - msgpack!(date_accessed), + prisma_sync::object::SyncId { + pub_id: object.pub_id, + }, + [sync_entry!(date_accessed, object::date_accessed)], ), - d.id, + object.id, ) }) - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); if !ops.is_empty() && !object_ids.is_empty() { sync.write_ops( @@ -392,19 +393,20 @@ pub(crate) fn mount() -> AlphaRouter { .exec() .await?; - let (ops, object_ids): (Vec<_>, Vec<_>) = objects + let (ops, object_ids) = objects .into_iter() - .map(|d| { + .map(|object| { ( sync.shared_update( - prisma_sync::object::SyncId { pub_id: d.pub_id }, - object::date_accessed::NAME, - msgpack!(nil), + prisma_sync::object::SyncId { + pub_id: object.pub_id, + }, + [sync_entry!(nil, object::date_accessed)], ), - d.id, + object.id, ) }) - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); if !ops.is_empty() && !object_ids.is_empty() { sync.write_ops( @@ -487,11 +489,32 @@ pub(crate) fn mount() -> AlphaRouter { path = %full_path.display(), "File not found in the file system, will remove from database;", ); - library + + let file_path_pub_id = library .db .file_path() - .delete(file_path::id::equals(args.file_path_ids[0])) + .find_unique(file_path::id::equals(args.file_path_ids[0])) + .select(file_path::select!({ pub_id })) .exec() + .await? + .ok_or(LocationError::FilePath(FilePathError::IdNotFound( + args.file_path_ids[0], + )))? + .pub_id; + + library + .sync + .write_op( + &library.db, + library.sync.shared_delete( + prisma_sync::file_path::SyncId { + pub_id: file_path_pub_id, + }, + ), + library.db.file_path().delete(file_path::id::equals( + args.file_path_ids[0], + )), + ) .await .map_err(LocationError::from)?; diff --git a/core/src/api/labels.rs b/core/src/api/labels.rs index 9aaaf30e3..eed08d8d3 100644 --- a/core/src/api/labels.rs +++ b/core/src/api/labels.rs @@ -116,7 +116,7 @@ pub(crate) fn mount() -> AlphaRouter { .procedure( "delete", R.with2(library()) - .mutation(|(_, library), label_id: i32| async move { + .mutation(|(_, library), label_id: label::id::Type| async move { let Library { db, sync, .. } = library.as_ref(); let label = db @@ -131,6 +131,35 @@ pub(crate) fn mount() -> AlphaRouter { ) })?; + let delete_ops = db + .label_on_object() + .find_many(vec![label_on_object::label_id::equals(label_id)]) + .select(label_on_object::select!({ object: select { pub_id } })) + .exec() + .await? + .into_iter() + .map(|label_on_object| { + sync.relation_delete(prisma_sync::label_on_object::SyncId { + label: prisma_sync::label::SyncId { + name: label.name.clone(), + }, + object: prisma_sync::object::SyncId { + pub_id: label_on_object.object.pub_id, + }, + }) + }) + .collect::>(); + + sync.write_ops( + db, + ( + delete_ops, + db.label_on_object() + .delete_many(vec![label_on_object::label_id::equals(label_id)]), + ), + ) + .await?; + sync.write_op( db, sync.shared_delete(prisma_sync::label::SyncId { name: label.name }), diff --git a/core/src/api/search/saved.rs b/core/src/api/search/saved.rs index 37dec602e..957474c49 100644 --- a/core/src/api/search/saved.rs +++ b/core/src/api/search/saved.rs @@ -69,7 +69,7 @@ pub(crate) fn mount() -> AlphaRouter { let pub_id = Uuid::now_v7().as_bytes().to_vec(); let date_created: DateTime = Utc::now().into(); - let (sync_params, db_params): (Vec<_>, Vec<_>) = chain_optional_iter( + let (sync_params, db_params) = chain_optional_iter( [ sync_db_entry!(date_created, saved_search::date_created), sync_db_entry!(args.name, saved_search::name), @@ -96,7 +96,7 @@ pub(crate) fn mount() -> AlphaRouter { ], ) .into_iter() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); sync.write_op( db, @@ -106,7 +106,9 @@ pub(crate) fn mount() -> AlphaRouter { }, sync_params, ), - db.saved_search().create(pub_id, db_params), + db.saved_search() + .create(pub_id, db_params) + .select(saved_search::select!({ id })), ) .await?; @@ -162,7 +164,7 @@ pub(crate) fn mount() -> AlphaRouter { rspc::Error::new(rspc::ErrorCode::NotFound, "search not found".into()) })?; - let (ops, db_params): (Vec<_>, Vec<_>) = chain_optional_iter( + let (sync_params, db_params) = chain_optional_iter( [sync_db_entry!(updated_at, saved_search::date_modified)], [ option_sync_db_entry!(args.name.flatten(), saved_search::name), @@ -173,34 +175,23 @@ pub(crate) fn mount() -> AlphaRouter { ], ) .into_iter() - .map(|((k, v), p)| { - ( - sync.shared_update( - prisma_sync::saved_search::SyncId { - pub_id: search.pub_id.clone(), - }, - k, - v, - ), - p, - ) - }) - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); - if !ops.is_empty() && !db_params.is_empty() { - sync.write_ops( - db, - ( - ops, - db.saved_search() - .update_unchecked(saved_search::id::equals(id), db_params), - ), - ) - .await?; + sync.write_op( + db, + sync.shared_update( + prisma_sync::saved_search::SyncId { + pub_id: search.pub_id.clone(), + }, + sync_params, + ), + db.saved_search() + .update_unchecked(saved_search::id::equals(id), db_params), + ) + .await?; - invalidate_query!(library, "search.saved.list"); - invalidate_query!(library, "search.saved.get"); - } + invalidate_query!(library, "search.saved.list"); + invalidate_query!(library, "search.saved.get"); Ok(()) } diff --git a/core/src/api/tags.rs b/core/src/api/tags.rs index 0d71b848c..0035ea592 100644 --- a/core/src/api/tags.rs +++ b/core/src/api/tags.rs @@ -4,7 +4,7 @@ use sd_prisma::{ prisma::{device, file_path, object, tag, tag_on_object}, prisma_sync, }; -use sd_sync::{option_sync_db_entry, sync_entry, OperationFactory}; +use sd_sync::{option_sync_db_entry, sync_db_entry, sync_entry, OperationFactory}; use std::collections::BTreeMap; @@ -286,13 +286,17 @@ pub(crate) fn mount() -> AlphaRouter { pub color: Option, } - R.with2(library()) - .mutation(|(_, library), args: TagUpdateArgs| async move { + R.with2(library()).mutation( + |(_, library), TagUpdateArgs { id, name, color }: TagUpdateArgs| async move { + if name.is_none() && color.is_none() { + return Ok(()); + } + let Library { sync, db, .. } = library.as_ref(); let tag = db .tag() - .find_unique(tag::id::equals(args.id)) + .find_unique(tag::id::equals(id)) .select(tag::select!({ pub_id })) .exec() .await? @@ -301,68 +305,88 @@ pub(crate) fn mount() -> AlphaRouter { "Error finding tag in db".into(), ))?; - db.tag() - .update( - tag::id::equals(args.id), - vec![tag::date_modified::set(Some(Utc::now().into()))], - ) - .exec() - .await?; - - let (sync_params, db_params): (Vec<_>, Vec<_>) = [ - option_sync_db_entry!(args.name, tag::name), - option_sync_db_entry!(args.color, tag::color), + let (sync_params, db_params) = [ + option_sync_db_entry!(name, tag::name), + option_sync_db_entry!(color, tag::color), + Some(sync_db_entry!(Utc::now(), tag::date_modified)), ] .into_iter() .flatten() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); - if sync_params.is_empty() && db_params.is_empty() { - return Ok(()); - } - - sync.write_ops( + sync.write_op( db, - ( - sync_params - .into_iter() - .map(|(k, v)| { - sync.shared_update( - prisma_sync::tag::SyncId { - pub_id: tag.pub_id.clone(), - }, - k, - v, - ) - }) - .collect(), - db.tag().update(tag::id::equals(args.id), db_params), + sync.shared_update( + prisma_sync::tag::SyncId { + pub_id: tag.pub_id.clone(), + }, + sync_params, ), + db.tag() + .update(tag::id::equals(id), db_params) + .select(tag::select!({ id })), ) .await?; invalidate_query!(library, "tags.list"); Ok(()) - }) + }, + ) }) .procedure( "delete", R.with2(library()) - .mutation(|(_, library), tag_id: i32| async move { - library - .db - .tag_on_object() - .delete_many(vec![tag_on_object::tag_id::equals(tag_id)]) - .exec() - .await?; + .mutation(|(_, library), tag_id: tag::id::Type| async move { + let Library { sync, db, .. } = &*library; - library - .db + let tag_pub_id = db .tag() - .delete(tag::id::equals(tag_id)) + .find_unique(tag::id::equals(tag_id)) + .select(tag::select!({ pub_id })) .exec() - .await?; + .await? + .ok_or(rspc::Error::new( + rspc::ErrorCode::NotFound, + "Tag not found".to_string(), + ))? + .pub_id; + + let delete_ops = db + .tag_on_object() + .find_many(vec![tag_on_object::tag_id::equals(tag_id)]) + .select(tag_on_object::select!({ object: select { pub_id } })) + .exec() + .await? + .into_iter() + .map(|tag_on_object| { + sync.relation_delete(prisma_sync::tag_on_object::SyncId { + tag: prisma_sync::tag::SyncId { + pub_id: tag_pub_id.clone(), + }, + object: prisma_sync::object::SyncId { + pub_id: tag_on_object.object.pub_id, + }, + }) + }) + .collect::>(); + + sync.write_ops( + db, + ( + delete_ops, + db.tag_on_object() + .delete_many(vec![tag_on_object::tag_id::equals(tag_id)]), + ), + ) + .await?; + + sync.write_op( + db, + sync.shared_delete(prisma_sync::tag::SyncId { pub_id: tag_pub_id }), + db.tag().delete(tag::id::equals(tag_id)), + ) + .await?; invalidate_query!(library, "tags.list"); diff --git a/core/src/library/manager/mod.rs b/core/src/library/manager/mod.rs index 7649576e6..06aa0b8c9 100644 --- a/core/src/library/manager/mod.rs +++ b/core/src/library/manager/mod.rs @@ -537,6 +537,7 @@ impl Libraries { )), ], ) + .select(instance::select!({ id })) .exec() .await?; } diff --git a/core/src/location/manager/watcher/android.rs b/core/src/location/manager/watcher/android.rs index 01bd8a2a1..723f2e076 100644 --- a/core/src/location/manager/watcher/android.rs +++ b/core/src/location/manager/watcher/android.rs @@ -27,6 +27,7 @@ use super::{ #[derive(Debug)] pub(super) struct EventHandler { location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: Arc, node: Arc, last_events_eviction_check: Instant, @@ -40,9 +41,18 @@ pub(super) struct EventHandler { } impl super::EventHandler for EventHandler { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self { + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self + where + Self: Sized, + { Self { location_id, + location_pub_id, library, node, last_events_eviction_check: Instant::now(), @@ -182,6 +192,7 @@ impl super::EventHandler for EventHandler { &mut self.to_recalculate_size, &mut self.path_and_instant_buffer, self.location_id, + self.location_pub_id.clone(), &self.library, ) .await diff --git a/core/src/location/manager/watcher/ios.rs b/core/src/location/manager/watcher/ios.rs index 3a9c91500..25f0a49fd 100644 --- a/core/src/location/manager/watcher/ios.rs +++ b/core/src/location/manager/watcher/ios.rs @@ -33,6 +33,7 @@ use super::{ #[derive(Debug)] pub(super) struct EventHandler { location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: Arc, node: Arc, last_events_eviction_check: Instant, @@ -48,12 +49,18 @@ pub(super) struct EventHandler { } impl super::EventHandler for EventHandler { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self where Self: Sized, { Self { location_id, + location_pub_id, library, node, last_events_eviction_check: Instant::now(), @@ -183,6 +190,7 @@ impl super::EventHandler for EventHandler { &mut self.to_recalculate_size, &mut self.path_and_instant_buffer, self.location_id, + self.location_pub_id.clone(), &self.library, ) .await diff --git a/core/src/location/manager/watcher/linux.rs b/core/src/location/manager/watcher/linux.rs index 0ec459a3c..34d37ed15 100644 --- a/core/src/location/manager/watcher/linux.rs +++ b/core/src/location/manager/watcher/linux.rs @@ -32,6 +32,7 @@ use super::{ #[derive(Debug)] pub(super) struct EventHandler { location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: Arc, node: Arc, last_events_eviction_check: Instant, @@ -45,9 +46,18 @@ pub(super) struct EventHandler { } impl super::EventHandler for EventHandler { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self { + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self + where + Self: Sized, + { Self { location_id, + location_pub_id, library, node, last_events_eviction_check: Instant::now(), @@ -187,6 +197,7 @@ impl super::EventHandler for EventHandler { &mut self.to_recalculate_size, &mut self.path_and_instant_buffer, self.location_id, + self.location_pub_id.clone(), &self.library, ) .await diff --git a/core/src/location/manager/watcher/macos.rs b/core/src/location/manager/watcher/macos.rs index 11486cd20..4d3b1ffec 100644 --- a/core/src/location/manager/watcher/macos.rs +++ b/core/src/location/manager/watcher/macos.rs @@ -42,6 +42,7 @@ use super::{ #[derive(Debug)] pub(super) struct EventHandler { location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: Arc, node: Arc, last_events_eviction_check: Instant, @@ -57,12 +58,18 @@ pub(super) struct EventHandler { } impl super::EventHandler for EventHandler { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self where Self: Sized, { Self { location_id, + location_pub_id, library, node, last_events_eviction_check: Instant::now(), @@ -206,6 +213,7 @@ impl super::EventHandler for EventHandler { &mut self.to_recalculate_size, &mut self.path_and_instant_buffer, self.location_id, + self.location_pub_id.clone(), &self.library, ) .await diff --git a/core/src/location/manager/watcher/mod.rs b/core/src/location/manager/watcher/mod.rs index 81b70ef87..d63709740 100644 --- a/core/src/location/manager/watcher/mod.rs +++ b/core/src/location/manager/watcher/mod.rs @@ -4,7 +4,7 @@ use sd_core_indexer_rules::{IndexerRule, IndexerRuler}; use sd_core_prisma_helpers::{location_ids_and_path, location_with_indexer_rules}; use sd_prisma::prisma::{location, PrismaClient}; -use sd_utils::db::maybe_missing; +use sd_utils::{db::maybe_missing, uuid_to_bytes}; use std::{ collections::HashSet, @@ -76,7 +76,12 @@ const THIRTY_SECONDS: Duration = Duration::from_secs(30); const HUNDRED_MILLIS: Duration = Duration::from_millis(100); trait EventHandler: 'static { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self where Self: Sized; @@ -200,7 +205,12 @@ impl LocationWatcher { Stop, } - let mut event_handler = Handler::new(location_id, Arc::clone(&library), Arc::clone(&node)); + let mut event_handler = Handler::new( + location_id, + uuid_to_bytes(&location_pub_id), + Arc::clone(&library), + Arc::clone(&node), + ); let mut last_event_at = Instant::now(); diff --git a/core/src/location/manager/watcher/utils.rs b/core/src/location/manager/watcher/utils.rs index 0adaf9f8c..88b065810 100644 --- a/core/src/location/manager/watcher/utils.rs +++ b/core/src/location/manager/watcher/utils.rs @@ -27,7 +27,9 @@ use sd_core_indexer_rules::{ seed::{GitIgnoreRules, GITIGNORE}, IndexerRuler, RulerDecision, }; -use sd_core_prisma_helpers::{file_path_with_object, object_ids, CasId, ObjectPubId}; +use sd_core_prisma_helpers::{ + file_path_watcher_remove, file_path_with_object, object_ids, CasId, ObjectPubId, +}; use sd_file_ext::{ extensions::{AudioExtension, ImageExtension, VideoExtension}, @@ -37,11 +39,11 @@ use sd_prisma::{ prisma::{device, file_path, location, object}, prisma_sync, }; -use sd_sync::{sync_entry, OperationFactory}; +use sd_sync::{option_sync_db_entry, sync_db_entry, sync_entry, OperationFactory}; use sd_utils::{ - db::{inode_from_db, inode_to_db, maybe_missing}, + chain_optional_iter, + db::{inode_from_db, inode_to_db, maybe_missing, size_in_bytes_to_db}, error::FileIOError, - msgpack, }; #[cfg(target_family = "unix")] @@ -354,32 +356,32 @@ async fn inner_create_file( let device_pub_id = sync.device_pub_id.to_db(); + let (sync_params, db_params) = [ + sync_db_entry!(date_created, object::date_created), + sync_db_entry!(int_kind, object::kind), + ( + sync_entry!( + prisma_sync::device::SyncId { + pub_id: device_pub_id.clone() + }, + object::device + ), + object::device::connect(device::pub_id::equals(device_pub_id)), + ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + sync.write_op( db, sync.shared_create( prisma_sync::object::SyncId { pub_id: pub_id.to_db(), }, - [ - sync_entry!(date_created, object::date_created), - sync_entry!(int_kind, object::kind), - sync_entry!( - prisma_sync::device::SyncId { - pub_id: device_pub_id.clone() - }, - object::device - ), - ], + sync_params, ), db.object() - .create( - pub_id.into(), - vec![ - object::date_created::set(Some(date_created)), - object::kind::set(Some(int_kind)), - object::device::connect(device::pub_id::equals(device_pub_id)), - ], - ) + .create(pub_id.into(), db_params) .select(object_ids::select()), ) .await? @@ -391,17 +393,21 @@ async fn inner_create_file( prisma_sync::location::SyncId { pub_id: created_file.pub_id.clone(), }, - file_path::object::NAME, - msgpack!(prisma_sync::object::SyncId { - pub_id: object_pub_id.clone() - }), - ), - db.file_path().update( - file_path::pub_id::equals(created_file.pub_id.clone()), - vec![file_path::object::connect(object::pub_id::equals( - object_pub_id.clone(), - ))], + [sync_entry!( + prisma_sync::object::SyncId { + pub_id: object_pub_id.clone() + }, + file_path::object + )], ), + db.file_path() + .update( + file_path::pub_id::equals(created_file.pub_id.clone()), + vec![file_path::object::connect(object::pub_id::equals( + object_pub_id.clone(), + ))], + ) + .select(file_path::select!({ id })), ) .await?; @@ -590,34 +596,22 @@ async fn inner_update_file( let is_hidden = path_is_hidden(full_path, &fs_metadata); if file_path.cas_id.as_deref() != cas_id.as_ref().map(CasId::as_str) { - let (sync_params, db_params): (Vec<_>, Vec<_>) = { - use file_path::*; - + let (sync_params, db_params) = chain_optional_iter( [ - ( - (cas_id::NAME, msgpack!(file_path.cas_id)), - Some(cas_id::set(file_path.cas_id.clone())), + sync_db_entry!( + size_in_bytes_to_db(fs_metadata.len()), + file_path::size_in_bytes_bytes ), - ( - ( - size_in_bytes_bytes::NAME, - msgpack!(fs_metadata.len().to_be_bytes().to_vec()), - ), - Some(size_in_bytes_bytes::set(Some( - fs_metadata.len().to_be_bytes().to_vec(), - ))), + sync_db_entry!( + DateTime::::from(fs_metadata.modified_or_now()), + file_path::date_modified ), - { - let date = DateTime::::from(fs_metadata.modified_or_now()).into(); - - ( - (date_modified::NAME, msgpack!(date)), - Some(date_modified::set(Some(date))), - ) - }, - { - // TODO: Should this be a skip rather than a null-set? - let checksum = if file_path.integrity_checksum.is_some() { + ], + [ + option_sync_db_entry!(file_path.cas_id.clone(), file_path::cas_id), + option_sync_db_entry!( + if file_path.integrity_checksum.is_some() { + // TODO: Should this be a skip rather than a null-set? // If a checksum was already computed, we need to recompute it Some( file_checksum(full_path) @@ -626,68 +620,39 @@ async fn inner_update_file( ) } else { None - }; - - ( - (integrity_checksum::NAME, msgpack!(checksum)), - Some(integrity_checksum::set(checksum)), - ) - }, - { - if current_inode != inode { - ( - (inode::NAME, msgpack!(inode)), - Some(inode::set(Some(inode_to_db(inode)))), - ) - } else { - ((inode::NAME, msgpack!(nil)), None) - } - }, - { - if is_hidden != file_path.hidden.unwrap_or_default() { - ( - (hidden::NAME, msgpack!(inode)), - Some(hidden::set(Some(is_hidden))), - ) - } else { - ((hidden::NAME, msgpack!(nil)), None) - } - }, - ] - .into_iter() - .filter_map(|(sync_param, maybe_db_param)| { - maybe_db_param.map(|db_param| (sync_param, db_param)) - }) - .unzip() - }; - - let ops = sync_params - .into_iter() - .map(|(field, value)| { - sync.shared_update( - prisma_sync::file_path::SyncId { - pub_id: file_path.pub_id.clone(), }, - field, - value, - ) - }) - .collect::>(); - - if !ops.is_empty() && !db_params.is_empty() { - // file content changed - sync.write_ops( - db, - ( - ops, - db.file_path().update( - file_path::pub_id::equals(file_path.pub_id.clone()), - db_params, - ), + file_path::integrity_checksum ), - ) - .await?; - } + option_sync_db_entry!( + (current_inode != inode).then(|| inode_to_db(inode)), + file_path::inode + ), + option_sync_db_entry!( + (is_hidden != file_path.hidden.unwrap_or_default()).then_some(is_hidden), + file_path::hidden + ), + ], + ) + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + // file content changed + sync.write_op( + db, + sync.shared_update( + prisma_sync::file_path::SyncId { + pub_id: file_path.pub_id.clone(), + }, + sync_params, + ), + db.file_path() + .update( + file_path::pub_id::equals(file_path.pub_id.clone()), + db_params, + ) + .select(file_path::select!({ id })), + ) + .await?; if let Some(ref object) = file_path.object { let int_kind = kind as i32; @@ -699,19 +664,18 @@ async fn inner_update_file( .await? == 1 { if object.kind.map(|k| k != int_kind).unwrap_or_default() { + let (sync_param, db_param) = sync_db_entry!(int_kind, object::kind); sync.write_op( db, sync.shared_update( prisma_sync::object::SyncId { pub_id: object.pub_id.clone(), }, - object::kind::NAME, - msgpack!(int_kind), - ), - db.object().update( - object::id::equals(object.id), - vec![object::kind::set(Some(int_kind))], + [sync_param], ), + db.object() + .update(object::id::equals(object.id), vec![db_param]) + .select(object::select!({ id })), ) .await?; } @@ -722,31 +686,31 @@ async fn inner_update_file( let device_pub_id = sync.device_pub_id.to_db(); + let (sync_params, db_params) = [ + sync_db_entry!(date_created, object::date_created), + sync_db_entry!(int_kind, object::kind), + ( + sync_entry!( + prisma_sync::device::SyncId { + pub_id: device_pub_id.clone() + }, + object::device + ), + object::device::connect(device::pub_id::equals(device_pub_id)), + ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + sync.write_op( db, sync.shared_create( prisma_sync::object::SyncId { pub_id: pub_id.to_db(), }, - [ - sync_entry!(date_created, object::date_created), - sync_entry!(int_kind, object::kind), - sync_entry!( - prisma_sync::device::SyncId { - pub_id: device_pub_id.clone() - }, - object::device - ), - ], - ), - db.object().create( - pub_id.to_db(), - vec![ - object::date_created::set(Some(date_created)), - object::kind::set(Some(int_kind)), - object::device::connect(device::pub_id::equals(device_pub_id)), - ], + sync_params, ), + db.object().create(pub_id.to_db(), db_params), ) .await?; @@ -756,17 +720,21 @@ async fn inner_update_file( prisma_sync::location::SyncId { pub_id: file_path.pub_id.clone(), }, - file_path::object::NAME, - msgpack!(prisma_sync::object::SyncId { - pub_id: pub_id.to_db() - }), - ), - db.file_path().update( - file_path::pub_id::equals(file_path.pub_id.clone()), - vec![file_path::object::connect(object::pub_id::equals( - pub_id.into(), - ))], + [sync_entry!( + prisma_sync::object::SyncId { + pub_id: pub_id.to_db() + }, + file_path::object + )], ), + db.file_path() + .update( + file_path::pub_id::equals(file_path.pub_id.clone()), + vec![file_path::object::connect(object::pub_id::equals( + pub_id.into(), + ))], + ) + .select(file_path::select!({ id })), ) .await?; } @@ -874,21 +842,22 @@ async fn inner_update_file( invalidate_query!(library, "search.paths"); invalidate_query!(library, "search.objects"); } else if is_hidden != file_path.hidden.unwrap_or_default() { - sync.write_ops( + let (sync_param, db_param) = sync_db_entry!(is_hidden, file_path::hidden); + + sync.write_op( db, - ( - vec![sync.shared_update( - prisma_sync::file_path::SyncId { - pub_id: file_path.pub_id.clone(), - }, - file_path::hidden::NAME, - msgpack!(is_hidden), - )], - db.file_path().update( - file_path::pub_id::equals(file_path.pub_id.clone()), - vec![file_path::hidden::set(Some(is_hidden))], - ), + sync.shared_update( + prisma_sync::file_path::SyncId { + pub_id: file_path.pub_id.clone(), + }, + [sync_param], ), + db.file_path() + .update( + file_path::pub_id::equals(file_path.pub_id.clone()), + vec![db_param], + ) + .select(file_path::select!({ id })), ) .await?; @@ -972,7 +941,7 @@ pub(super) async fn rename( .await?; let total_paths_count = paths.len(); - let (sync_params, db_params): (Vec<_>, Vec<_>) = paths + let (sync_params, db_params) = paths .into_iter() .filter_map(|path| path.materialized_path.map(|mp| (path.id, path.pub_id, mp))) .map(|(id, pub_id, mp)| { @@ -981,19 +950,20 @@ pub(super) async fn rename( &format!("{}/{}/", new_parts.materialized_path, new_parts.name), ); + let (sync_param, db_param) = + sync_db_entry!(new_path, file_path::materialized_path); + ( sync.shared_update( sd_prisma::prisma_sync::file_path::SyncId { pub_id }, - file_path::materialized_path::NAME, - msgpack!(&new_path), - ), - db.file_path().update( - file_path::id::equals(id), - vec![file_path::materialized_path::set(Some(new_path))], + [sync_param], ), + db.file_path() + .update(file_path::id::equals(id), vec![db_param]) + .select(file_path::select!({ id })), ) }) - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); if !sync_params.is_empty() && !db_params.is_empty() { sync.write_ops(db, (sync_params, db_params)).await?; @@ -1002,65 +972,38 @@ pub(super) async fn rename( trace!(%total_paths_count, "Updated file_paths;"); } - let is_hidden = path_is_hidden(new_path, &new_path_metadata); - - let date_modified = DateTime::::from(new_path_metadata.modified_or_now()).into(); - - let (sync_params, db_params): (Vec<_>, Vec<_>) = [ - ( - ( - file_path::materialized_path::NAME, - msgpack!(new_path_materialized_str), - ), - file_path::materialized_path::set(Some(new_path_materialized_str)), + let (sync_params, db_params) = [ + sync_db_entry!(new_path_materialized_str, file_path::materialized_path), + sync_db_entry!(new_parts.name.to_string(), file_path::name), + sync_db_entry!(new_parts.extension.to_string(), file_path::extension), + sync_db_entry!( + DateTime::::from(new_path_metadata.modified_or_now()), + file_path::date_modified ), - ( - (file_path::name::NAME, msgpack!(new_parts.name)), - file_path::name::set(Some(new_parts.name.to_string())), - ), - ( - (file_path::extension::NAME, msgpack!(new_parts.extension)), - file_path::extension::set(Some(new_parts.extension.to_string())), - ), - ( - (file_path::date_modified::NAME, msgpack!(&date_modified)), - file_path::date_modified::set(Some(date_modified)), - ), - ( - (file_path::hidden::NAME, msgpack!(is_hidden)), - file_path::hidden::set(Some(is_hidden)), + sync_db_entry!( + path_is_hidden(new_path, &new_path_metadata), + file_path::hidden ), ] .into_iter() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); - let ops = sync_params - .into_iter() - .map(|(k, v)| { - sync.shared_update( - prisma_sync::file_path::SyncId { - pub_id: file_path.pub_id.clone(), - }, - k, - v, - ) - }) - .collect::>(); + sync.write_op( + db, + sync.shared_update( + prisma_sync::file_path::SyncId { + pub_id: file_path.pub_id.clone(), + }, + sync_params, + ), + db.file_path() + .update(file_path::pub_id::equals(file_path.pub_id), db_params) + .select(file_path::select!({ id })), + ) + .await?; - if !ops.is_empty() && !db_params.is_empty() { - sync.write_ops( - db, - ( - ops, - db.file_path() - .update(file_path::pub_id::equals(file_path.pub_id), db_params), - ), - ) - .await?; - - invalidate_query!(library, "search.paths"); - invalidate_query!(library, "search.objects"); - } + invalidate_query!(library, "search.paths"); + invalidate_query!(library, "search.objects"); } Ok(()) @@ -1084,19 +1027,20 @@ pub(super) async fn remove( &location_path, full_path, )?) + .select(file_path_watcher_remove::select()) .exec() .await? else { return Ok(()); }; - remove_by_file_path(location_id, full_path, &file_path, library).await + remove_by_file_path(location_id, full_path, file_path, library).await } async fn remove_by_file_path( location_id: location::id::Type, path: impl AsRef + Send, - file_path: &file_path::Data, + file_path: file_path_watcher_remove::Data, library: &Library, ) -> Result<(), LocationManagerError> { // check file still exists on disk @@ -1120,28 +1064,42 @@ async fn remove_by_file_path( delete_directory( library, location_id, - Some(&IsolatedFilePathData::try_from(file_path)?), + Some(&IsolatedFilePathData::try_from(&file_path)?), ) .await?; } else { sync.write_op( db, sync.shared_delete(prisma_sync::file_path::SyncId { - pub_id: file_path.pub_id.clone(), + pub_id: file_path.pub_id, }), db.file_path().delete(file_path::id::equals(file_path.id)), ) .await?; - if let Some(object_id) = file_path.object_id { - db.object() - .delete_many(vec![ - object::id::equals(object_id), + if let Some(object) = file_path.object { + // If this object doesn't have any other file paths, delete it + if db + .object() + .count(vec![ + object::id::equals(object.id), // https://www.prisma.io/docs/reference/api-reference/prisma-client-reference#none object::file_paths::none(vec![]), ]) .exec() + .await? == 1 + { + sync.write_op( + db, + sync.shared_delete(prisma_sync::object::SyncId { + pub_id: object.pub_id, + }), + db.object() + .delete(object::id::equals(object.id)) + .select(object::select!({ id })), + ) .await?; + } } } } @@ -1210,6 +1168,7 @@ pub(super) async fn recalculate_directories_size( candidates: &mut HashMap, buffer: &mut Vec<(PathBuf, Instant)>, location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: &Library, ) -> Result<(), LocationManagerError> { let mut location_path_cache = None; @@ -1268,7 +1227,7 @@ pub(super) async fn recalculate_directories_size( } if should_update_location_size { - update_location_size(location_id, library).await?; + update_location_size(location_id, location_pub_id, library).await?; } if should_invalidate { diff --git a/core/src/location/manager/watcher/windows.rs b/core/src/location/manager/watcher/windows.rs index a9b24c54c..bd85693e8 100644 --- a/core/src/location/manager/watcher/windows.rs +++ b/core/src/location/manager/watcher/windows.rs @@ -39,6 +39,7 @@ use super::{ #[derive(Debug)] pub(super) struct EventHandler { location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: Arc, node: Arc, last_events_eviction_check: Instant, @@ -54,12 +55,18 @@ pub(super) struct EventHandler { } impl super::EventHandler for EventHandler { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self where Self: Sized, { Self { location_id, + location_pub_id, library, node, last_events_eviction_check: Instant::now(), @@ -277,6 +284,7 @@ impl super::EventHandler for EventHandler { &mut self.to_recalculate_size, &mut self.path_and_instant_buffer, self.location_id, + self.location_pub_id.clone(), &self.library, ) .await diff --git a/core/src/location/mod.rs b/core/src/location/mod.rs index a1fd20073..a4a998995 100644 --- a/core/src/location/mod.rs +++ b/core/src/location/mod.rs @@ -18,9 +18,9 @@ use sd_prisma::{ }; use sd_sync::*; use sd_utils::{ - db::{maybe_missing, size_in_bytes_to_db, MissingFieldError}, + db::{maybe_missing, size_in_bytes_from_db, size_in_bytes_to_db}, error::{FileIOError, NonUtf8PathError}, - msgpack, uuid_to_bytes, + uuid_to_bytes, }; use std::{ @@ -304,63 +304,36 @@ impl LocationUpdateArgs { let name = self.name.clone(); - let (sync_params, db_params): (Vec<_>, Vec<_>) = [ - self.name - .filter(|name| location.name.as_ref() != Some(name)) - .map(|v| { - ( - (location::name::NAME, msgpack!(v)), - location::name::set(Some(v)), - ) - }), - self.generate_preview_media.map(|v| { - ( - (location::generate_preview_media::NAME, msgpack!(v)), - location::generate_preview_media::set(Some(v)), - ) - }), - self.sync_preview_media.map(|v| { - ( - (location::sync_preview_media::NAME, msgpack!(v)), - location::sync_preview_media::set(Some(v)), - ) - }), - self.hidden.map(|v| { - ( - (location::hidden::NAME, msgpack!(v)), - location::hidden::set(Some(v)), - ) - }), - self.path.clone().map(|v| { - ( - (location::path::NAME, msgpack!(v)), - location::path::set(Some(v)), - ) - }), + let (sync_params, db_params) = [ + option_sync_db_entry!( + self.name + .filter(|name| location.name.as_ref() != Some(name)), + location::name + ), + option_sync_db_entry!( + self.generate_preview_media, + location::generate_preview_media + ), + option_sync_db_entry!(self.sync_preview_media, location::sync_preview_media), + option_sync_db_entry!(self.hidden, location::hidden), + option_sync_db_entry!(self.path.clone(), location::path), ] .into_iter() .flatten() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); if !sync_params.is_empty() { - sync.write_ops( + sync.write_op( db, - ( - sync_params - .into_iter() - .map(|p| { - sync.shared_update( - prisma_sync::location::SyncId { - pub_id: location.pub_id.clone(), - }, - p.0, - p.1, - ) - }) - .collect(), - db.location() - .update(location::id::equals(self.id), db_params), + sync.shared_update( + prisma_sync::location::SyncId { + pub_id: location.pub_id.clone(), + }, + sync_params, ), + db.location() + .update(location::id::equals(self.id), db_params) + .select(location::select!({ id })), ) .await?; @@ -651,33 +624,25 @@ pub async fn relink_location( .map(str::to_string) .ok_or_else(|| NonUtf8PathError(location_path.into()))?; - sync.write_op( - db, - sync.shared_update( - prisma_sync::location::SyncId { - pub_id: pub_id.clone(), - }, - location::path::NAME, - msgpack!(path), - ), - db.location().update( - location::pub_id::equals(pub_id.clone()), - vec![location::path::set(Some(path))], - ), - ) - .await?; + let (sync_param, db_param) = sync_db_entry!(path, location::path); - let location_id = db - .location() - .find_unique(location::pub_id::equals(pub_id)) - .select(location::select!({ id })) - .exec() + let location_id = sync + .write_op( + db, + sync.shared_update( + prisma_sync::location::SyncId { + pub_id: pub_id.clone(), + }, + [sync_param], + ), + db.location() + .update(location::pub_id::equals(pub_id.clone()), vec![db_param]) + .select(location::select!({ id })), + ) .await? - .ok_or_else(|| { - LocationError::MissingField(MissingFieldError::new("missing id of location")) - })?; + .id; - Ok(location_id.id) + Ok(location_id) } #[derive(Debug)] @@ -1002,45 +967,44 @@ async fn check_nested_location( #[instrument(skip_all, err)] pub async fn update_location_size( location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: &Library, -) -> Result<(), QueryError> { - let Library { db, .. } = library; +) -> Result<(), sd_core_sync::Error> { + let Library { db, sync, .. } = library; - let total_size = db - .file_path() - .find_many(vec![ - file_path::location_id::equals(Some(location_id)), - file_path::materialized_path::equals(Some("/".to_string())), - ]) - .select(file_path::select!({ size_in_bytes_bytes })) - .exec() - .await? - .into_iter() - .filter_map(|file_path| { - file_path.size_in_bytes_bytes.map(|size_in_bytes_bytes| { - u64::from_be_bytes([ - size_in_bytes_bytes[0], - size_in_bytes_bytes[1], - size_in_bytes_bytes[2], - size_in_bytes_bytes[3], - size_in_bytes_bytes[4], - size_in_bytes_bytes[5], - size_in_bytes_bytes[6], - size_in_bytes_bytes[7], - ]) + let total_size = size_in_bytes_to_db( + db.file_path() + .find_many(vec![ + file_path::location_id::equals(Some(location_id)), + file_path::materialized_path::equals(Some("/".to_string())), + ]) + .select(file_path::select!({ size_in_bytes_bytes })) + .exec() + .await? + .into_iter() + .filter_map(|file_path| { + file_path + .size_in_bytes_bytes + .map(|size_in_bytes_bytes| size_in_bytes_from_db(&size_in_bytes_bytes)) }) - }) - .sum::(); + .sum::(), + ); - db.location() - .update( - location::id::equals(location_id), - vec![location::size_in_bytes::set(Some( - total_size.to_be_bytes().to_vec(), - ))], - ) - .exec() - .await?; + let (sync_param, db_param) = sync_db_entry!(total_size, location::size_in_bytes); + + sync.write_op( + db, + sync.shared_update( + prisma_sync::location::SyncId { + pub_id: location_pub_id, + }, + [sync_param], + ), + db.location() + .update(location::id::equals(location_id), vec![db_param]) + .select(location::select!({ id })), + ) + .await?; invalidate_query!(library, "locations.list"); invalidate_query!(library, "locations.get"); @@ -1100,69 +1064,60 @@ pub async fn create_file_path( location_id, ))?; - let (sync_params, db_params): (Vec<_>, Vec<_>) = { - use file_path::{ - cas_id, date_created, date_indexed, date_modified, device, extension, hidden, inode, - is_dir, location, materialized_path, name, size_in_bytes_bytes, - }; + let device_pub_id = sync.device_pub_id.to_db(); - let device_pub_id = sync.device_pub_id.to_db(); - - [ - ( - sync_entry!( - prisma_sync::location::SyncId { - pub_id: location.pub_id - }, - location - ), - location::connect(prisma::location::id::equals(location.id)), + let (sync_params, db_params) = [ + ( + sync_entry!( + prisma_sync::location::SyncId { + pub_id: location.pub_id + }, + file_path::location ), - ( - sync_entry!(cas_id, cas_id), - cas_id::set(cas_id.map(Into::into)), + file_path::location::connect(prisma::location::id::equals(location.id)), + ), + ( + sync_entry!(cas_id, file_path::cas_id), + file_path::cas_id::set(cas_id.map(Into::into)), + ), + sync_db_entry!(materialized_path, file_path::materialized_path), + sync_db_entry!(name, file_path::name), + sync_db_entry!(extension, file_path::extension), + sync_db_entry!( + size_in_bytes_to_db(metadata.size_in_bytes), + file_path::size_in_bytes_bytes + ), + sync_db_entry!(inode_to_db(metadata.inode), file_path::inode), + sync_db_entry!(is_dir, file_path::is_dir), + sync_db_entry!(metadata.created_at, file_path::date_created), + sync_db_entry!(metadata.modified_at, file_path::date_modified), + sync_db_entry!(indexed_at, file_path::date_indexed), + sync_db_entry!(metadata.hidden, file_path::hidden), + ( + sync_entry!( + prisma_sync::device::SyncId { + pub_id: device_pub_id.clone() + }, + file_path::device ), - sync_db_entry!(materialized_path, materialized_path), - sync_db_entry!(name, name), - sync_db_entry!(extension, extension), - sync_db_entry!( - size_in_bytes_to_db(metadata.size_in_bytes), - size_in_bytes_bytes - ), - sync_db_entry!(inode_to_db(metadata.inode), inode), - sync_db_entry!(is_dir, is_dir), - sync_db_entry!(metadata.created_at, date_created), - sync_db_entry!(metadata.modified_at, date_modified), - sync_db_entry!(indexed_at, date_indexed), - sync_db_entry!(metadata.hidden, hidden), - ( - sync_entry!( - prisma_sync::device::SyncId { - pub_id: device_pub_id.clone() - }, - device - ), - device::connect(prisma::device::pub_id::equals(device_pub_id)), - ), - ] - .into_iter() - .unzip() - }; + file_path::device::connect(prisma::device::pub_id::equals(device_pub_id)), + ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); let pub_id = sd_utils::uuid_to_bytes(&Uuid::now_v7()); - let created_path = sync - .write_op( - db, - sync.shared_create( - prisma_sync::file_path::SyncId { - pub_id: pub_id.clone(), - }, - sync_params, - ), - db.file_path().create(pub_id, db_params), - ) - .await?; - - Ok(created_path) + sync.write_op( + db, + sync.shared_create( + prisma_sync::file_path::SyncId { + pub_id: pub_id.clone(), + }, + sync_params, + ), + db.file_path().create(pub_id, db_params), + ) + .await + .map_err(Into::into) } diff --git a/core/src/object/fs/old_copy.rs b/core/src/object/fs/old_copy.rs index 8b760b920..2d7b0fb70 100644 --- a/core/src/object/fs/old_copy.rs +++ b/core/src/object/fs/old_copy.rs @@ -323,8 +323,8 @@ impl StatefulJob for OldFileCopierJobInit { .await?; dirs.extend(more_dirs); - let (dir_source_file_data, dir_target_full_path): (Vec<_>, Vec<_>) = - dirs.into_iter().unzip(); + let (dir_source_file_data, dir_target_full_path) = + dirs.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); let step_files = dir_source_file_data .into_iter() diff --git a/core/src/object/tag/mod.rs b/core/src/object/tag/mod.rs index 98238462b..34b609a83 100644 --- a/core/src/object/tag/mod.rs +++ b/core/src/object/tag/mod.rs @@ -23,14 +23,14 @@ impl TagCreateArgs { ) -> Result { let pub_id = Uuid::now_v7().as_bytes().to_vec(); - let (sync_params, db_params): (Vec<_>, Vec<_>) = [ + let (sync_params, db_params) = [ sync_db_entry!(self.name, tag::name), sync_db_entry!(self.color, tag::color), sync_db_entry!(false, tag::is_hidden), sync_db_entry!(Utc::now(), tag::date_created), ] .into_iter() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); sync.write_op( db, diff --git a/core/src/object/validation/old_validator_job.rs b/core/src/object/validation/old_validator_job.rs index d90fc56cb..7ddd42938 100644 --- a/core/src/object/validation/old_validator_job.rs +++ b/core/src/object/validation/old_validator_job.rs @@ -15,8 +15,8 @@ use sd_prisma::{ prisma::{file_path, location}, prisma_sync, }; -use sd_sync::OperationFactory; -use sd_utils::{db::maybe_missing, error::FileIOError, msgpack}; +use sd_sync::{sync_db_entry, OperationFactory}; +use sd_utils::{db::maybe_missing, error::FileIOError}; use std::{ hash::{Hash, Hasher}, @@ -157,19 +157,22 @@ impl StatefulJob for OldObjectValidatorJobInit { .await .map_err(|e| ValidatorError::FileIO(FileIOError::from((full_path, e))))?; + let (sync_param, db_param) = sync_db_entry!(checksum, file_path::integrity_checksum); + sync.write_op( db, sync.shared_update( prisma_sync::file_path::SyncId { pub_id: file_path.pub_id.clone(), }, - file_path::integrity_checksum::NAME, - msgpack!(&checksum), - ), - db.file_path().update( - file_path::pub_id::equals(file_path.pub_id.clone()), - vec![file_path::integrity_checksum::set(Some(checksum))], + [sync_param], ), + db.file_path() + .update( + file_path::pub_id::equals(file_path.pub_id.clone()), + vec![db_param], + ) + .select(file_path::select!({ id })), ) .await?; } diff --git a/core/src/old_job/manager.rs b/core/src/old_job/manager.rs index f47164759..c9e5cc892 100644 --- a/core/src/old_job/manager.rs +++ b/core/src/old_job/manager.rs @@ -320,6 +320,7 @@ impl OldJobs { job::id::equals(job.id.as_bytes().to_vec()), vec![job::status::set(Some(JobStatus::Canceled as i32))], ) + .select(job::select!({ id })) .exec() .await?; } diff --git a/core/src/old_job/report.rs b/core/src/old_job/report.rs index ed40df23d..af7333267 100644 --- a/core/src/old_job/report.rs +++ b/core/src/old_job/report.rs @@ -395,6 +395,7 @@ impl OldJobReport { job::date_completed::set(self.completed_at.map(Into::into)), ], ) + .select(job::select!({ id })) .exec() .await?; Ok(()) diff --git a/core/src/volume/mod.rs b/core/src/volume/mod.rs index 50d9d1b15..9519d639a 100644 --- a/core/src/volume/mod.rs +++ b/core/src/volume/mod.rs @@ -7,7 +7,7 @@ use sd_prisma::{ prisma::{device, storage_statistics, PrismaClient}, prisma_sync, }; -use sd_sync::{sync_entry, OperationFactory}; +use sd_sync::{sync_db_not_null_entry, sync_entry, OperationFactory}; use sd_utils::uuid_to_bytes; use std::{ @@ -531,67 +531,66 @@ async fn update_storage_statistics( .map(|s| s.pub_id); if let Some(storage_statistics_pub_id) = storage_statistics_pub_id { - sync.write_ops( - db, - ( - [ - sync_entry!(total_capacity, storage_statistics::total_capacity), - sync_entry!(available_capacity, storage_statistics::available_capacity), - ] - .into_iter() - .map(|(field, value)| { - sync.shared_update( - prisma_sync::storage_statistics::SyncId { - pub_id: storage_statistics_pub_id.clone(), - }, - field, - value, - ) - }) - .collect(), - db.storage_statistics() - .update( - storage_statistics::pub_id::equals(storage_statistics_pub_id), - vec![ - storage_statistics::total_capacity::set(total_capacity as i64), - storage_statistics::available_capacity::set(available_capacity as i64), - ], - ) - // We don't need any data here, just the id avoids receiving the entire object - // as we can't pass an empty select macro call - .select(storage_statistics::select!({ id })), + let (sync_params, db_params) = [ + sync_db_not_null_entry!(total_capacity as i64, storage_statistics::total_capacity), + sync_db_not_null_entry!( + available_capacity as i64, + storage_statistics::available_capacity ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + sync.write_op( + db, + sync.shared_update( + prisma_sync::storage_statistics::SyncId { + pub_id: storage_statistics_pub_id.clone(), + }, + sync_params, + ), + db.storage_statistics() + .update( + storage_statistics::pub_id::equals(storage_statistics_pub_id), + db_params, + ) + // We don't need any data here, just the id avoids receiving the entire object + // as we can't pass an empty select macro call + .select(storage_statistics::select!({ id })), ) .await?; } else { let new_storage_statistics_id = uuid_to_bytes(&Uuid::now_v7()); + let (sync_params, db_params) = [ + sync_db_not_null_entry!(total_capacity as i64, storage_statistics::total_capacity), + sync_db_not_null_entry!( + available_capacity as i64, + storage_statistics::available_capacity + ), + ( + sync_entry!( + prisma_sync::device::SyncId { + pub_id: device_pub_id.clone() + }, + storage_statistics::device + ), + storage_statistics::device::connect(device::pub_id::equals(device_pub_id)), + ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + sync.write_op( db, sync.shared_create( prisma_sync::storage_statistics::SyncId { pub_id: new_storage_statistics_id.clone(), }, - [ - sync_entry!(total_capacity, storage_statistics::total_capacity), - sync_entry!(available_capacity, storage_statistics::available_capacity), - sync_entry!( - prisma_sync::device::SyncId { - pub_id: device_pub_id.clone() - }, - storage_statistics::device - ), - ], + sync_params, ), db.storage_statistics() - .create( - new_storage_statistics_id, - vec![ - storage_statistics::total_capacity::set(total_capacity as i64), - storage_statistics::available_capacity::set(available_capacity as i64), - storage_statistics::device::connect(device::pub_id::equals(device_pub_id)), - ], - ) + .create(new_storage_statistics_id, db_params) // We don't need any data here, just the id avoids receiving the entire object // as we can't pass an empty select macro call .select(storage_statistics::select!({ id })), diff --git a/crates/crypto/src/cloud/decrypt.rs b/crates/crypto/src/cloud/decrypt.rs index c45a99110..94913f64b 100644 --- a/crates/crypto/src/cloud/decrypt.rs +++ b/crates/crypto/src/cloud/decrypt.rs @@ -31,7 +31,7 @@ impl OneShotDecryption for SecretKey { EncryptedBlockRef { nonce, cipher_text }: EncryptedBlockRef<'_>, ) -> Result, Error> { XChaCha20Poly1305::new(&self.0) - .decrypt(&nonce, cipher_text) + .decrypt(nonce, cipher_text) .map_err(|aead::Error| Error::Decrypt) } diff --git a/crates/crypto/src/cloud/secret_key.rs b/crates/crypto/src/cloud/secret_key.rs index c1df94f9f..2477684ad 100644 --- a/crates/crypto/src/cloud/secret_key.rs +++ b/crates/crypto/src/cloud/secret_key.rs @@ -191,7 +191,7 @@ mod tests { let EncryptedBlock { nonce, cipher_text } = key.encrypt(message, &mut rng).unwrap(); let mut bytes = Vec::with_capacity(nonce.len() + cipher_text.len()); - bytes.extend_from_slice(&nonce); + bytes.extend_from_slice(nonce.as_slice()); bytes.extend(cipher_text); assert_eq!( diff --git a/crates/crypto/src/primitives.rs b/crates/crypto/src/primitives.rs index a37981a01..1d8335fc1 100644 --- a/crates/crypto/src/primitives.rs +++ b/crates/crypto/src/primitives.rs @@ -16,7 +16,7 @@ pub struct EncryptedBlock { } pub struct EncryptedBlockRef<'e> { - pub nonce: OneShotNonce, + pub nonce: &'e OneShotNonce, pub cipher_text: &'e [u8], } @@ -25,7 +25,7 @@ impl<'e> From<&'e [u8]> for EncryptedBlockRef<'e> { let (nonce, cipher_text) = cipher_text.split_at(size_of::()); Self { - nonce: OneShotNonce::try_from(nonce).expect("we split the correct amount"), + nonce: nonce.try_into().expect("we split the correct amount"), cipher_text, } } diff --git a/crates/sync-generator/src/model.rs b/crates/sync-generator/src/model.rs index 767c1d820..e171634b8 100644 --- a/crates/sync-generator/src/model.rs +++ b/crates/sync-generator/src/model.rs @@ -46,7 +46,7 @@ pub fn module((model, sync_type): ModelWithSyncType<'_>) -> Module { RefinedFieldWalker::Scalar(scalar_field) => { (!scalar_field.is_in_required_relation()).then(|| { quote! { - #model_name_snake::#field_name_snake::set(::rmpv::ext::from_value(val).unwrap()), + #model_name_snake::#field_name_snake::set(::rmpv::ext::from_value(val)?), } }) } @@ -59,11 +59,19 @@ pub fn module((model, sync_type): ModelWithSyncType<'_>) -> Module { |i| { if i.count() == 1 { Some(quote! {{ - let val: std::collections::HashMap = ::rmpv::ext::from_value(val).unwrap(); - let val = val.into_iter().next().unwrap(); + + let (field, value) = ::rmpv + ::ext + ::from_value::>(val)? + .into_iter() + .next() + .ok_or(Error::MissingRelationData { + field: field.to_string(), + model: #relation_model_name_snake::NAME.to_string() + })?; #model_name_snake::#field_name_snake::connect( - #relation_model_name_snake::UniqueWhereParam::deserialize(&val.0, val.1).unwrap() + #relation_model_name_snake::UniqueWhereParam::deserialize(&field, value)? ) }}) } else { @@ -81,10 +89,13 @@ pub fn module((model, sync_type): ModelWithSyncType<'_>) -> Module { } else { quote! { impl #model_name_snake::SetParam { - pub fn deserialize(field: &str, val: ::rmpv::Value) -> Option { - Some(match field { + pub fn deserialize(field: &str, val: ::rmpv::Value) -> Result { + Ok(match field { #(#field_matches)* - _ => return None + _ => return Err(Error::FieldNotFound { + field: field.to_string(), + model: #model_name_snake::NAME.to_string(), + }), }) } } @@ -97,9 +108,12 @@ pub fn module((model, sync_type): ModelWithSyncType<'_>) -> Module { Module::new( model.name(), quote! { - use super::prisma::*; + use super::Error; + use prisma_client_rust::scalar_types::*; + use super::prisma::*; + #sync_id #set_param_impl @@ -172,7 +186,7 @@ fn process_unique_params(model: Walker<'_, ModelId>, model_name_snake: &Ident) - Some(quote!(#model_name_snake::#field_name_snake::NAME => #model_name_snake::#field_name_snake::equals( - ::rmpv::ext::from_value(val).unwrap() + ::rmpv::ext::from_value(val)? ), )) } @@ -185,10 +199,13 @@ fn process_unique_params(model: Walker<'_, ModelId>, model_name_snake: &Ident) - } else { quote! { impl #model_name_snake::UniqueWhereParam { - pub fn deserialize(field: &str, val: ::rmpv::Value) -> Option { - Some(match field { + pub fn deserialize(field: &str, val: ::rmpv::Value) -> Result { + Ok(match field { #(#field_matches)* - _ => return None + _ => return Err(Error::FieldNotFound { + field: field.to_string(), + model: #model_name_snake::NAME.to_string(), + }) }) } } diff --git a/crates/sync-generator/src/sync_data.rs b/crates/sync-generator/src/sync_data.rs index 9e9fdd937..e8ee713e6 100644 --- a/crates/sync-generator/src/sync_data.rs +++ b/crates/sync-generator/src/sync_data.rs @@ -7,7 +7,7 @@ use prisma_models::walkers::{FieldWalker, ScalarFieldWalker}; use crate::{ModelSyncType, ModelWithSyncType}; pub fn enumerate(models: &[ModelWithSyncType<'_>]) -> TokenStream { - let (variants, matches): (Vec<_>, Vec<_>) = models + let (variants, matches) = models .iter() .filter_map(|(model, sync_type)| { let model_name_snake = snake_ident(model.name()); @@ -26,12 +26,12 @@ pub fn enumerate(models: &[ModelWithSyncType<'_>]) -> TokenStream { quote!(#model_name_pascal(#model_name_snake::SyncId, sd_sync::CRDTOperationData)), quote! { #model_name_snake::MODEL_ID => - Self::#model_name_pascal(rmpv::ext::from_value(op.record_id).ok()?, op.data) + Self::#model_name_pascal(rmpv::ext::from_value(op.record_id)?, op.data) }, ) }) }) - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); let exec_matches = models.iter().filter_map(|(model, sync_type)| { let model_name_pascal = pascal_ident(model.name()); @@ -54,20 +54,22 @@ pub fn enumerate(models: &[ModelWithSyncType<'_>]) -> TokenStream { }) }); + let error_enum = declare_error_enum(); + quote! { pub enum ModelSyncData { #(#variants),* } impl ModelSyncData { - pub fn from_op(op: sd_sync::CRDTOperation) -> Option { - Some(match op.model_id { + pub fn from_op(op: sd_sync::CRDTOperation) -> Result { + Ok(match op.model_id { #(#matches),*, - _ => return None + _ => return Err(Error::InvalidModelId(op.model_id)), }) } - pub async fn exec(self, db: &prisma::PrismaClient) -> prisma_client_rust::Result<()> { + pub async fn exec(self, db: &prisma::PrismaClient) -> Result<(), Error> { match self { #(#exec_matches),* } @@ -75,6 +77,69 @@ pub fn enumerate(models: &[ModelWithSyncType<'_>]) -> TokenStream { Ok(()) } } + + #error_enum + } +} + +fn declare_error_enum() -> TokenStream { + quote! { + #[derive(Debug)] + pub enum Error { + Rmpv(rmpv::ext::Error), + RmpSerialize(rmp_serde::encode::Error), + Prisma(prisma_client_rust::QueryError), + InvalidModelId(sd_sync::ModelId), + FieldNotFound { field: String, model: String }, + MissingRelationData { field: String, model: String }, + RelatedEntryNotFound { field: String, model: String }, + } + + impl From for Error { + fn from(e: rmpv::ext::Error) -> Self { + Self::Rmpv(e) + } + } + + impl From for Error { + fn from(e: rmp_serde::encode::Error) -> Self { + Self::RmpSerialize(e) + } + } + + impl From for Error { + fn from(e: prisma_client_rust::QueryError) -> Self { + Self::Prisma(e) + } + } + + impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Rmpv(e) => write!(f, "Failed to serialize or deserialize rmpv data: {e}"), + Self::RmpSerialize(e) => write!(f, "Failed to serialize rmp data: {e}"), + Self::Prisma(e) => write!(f, "Prisma error: {e}"), + Self::InvalidModelId(id) => write!(f, "Invalid model id: {id}"), + Self::FieldNotFound { field, model } => { + write!(f, "Field '{field}' not found in model '{model}'") + } + Self::MissingRelationData { field, model } => { + write!( + f, + "Field '{field}' missing relation data in model '{model}'" + ) + } + Self::RelatedEntryNotFound { field, model } => { + write!( + f, + "Related entry for field '{field}' not found in table '{model}'" + ) + } + } + } + } + + impl std::error::Error for Error {} } } @@ -103,6 +168,7 @@ fn handle_crdt_ops_relation( .and_then(|(_m, sync)| sync.as_ref()) .map(|sync| snake_ident(sync.sync_id()[0].name())) .expect("missing sync id field name for relation"); + let item_model_name_snake = snake_ident(item.related_model().name()); let item_field_name_snake = snake_ident(item.name()); @@ -155,11 +221,15 @@ fn handle_crdt_ops_relation( vec![], ) .exec() - .await - .ok(); + .await?; }, - sd_sync::CRDTOperationData::Update { field, value } => { - let data = vec![prisma::#model_name_snake::SetParam::deserialize(&field, value).unwrap()]; + + sd_sync::CRDTOperationData::Update(data) => { + let data = data.into_iter() + .map(|(field, value)| { + prisma::#model_name_snake::SetParam::deserialize(&field, value) + }) + .collect::, _>>()?; db.#model_name_snake() .upsert( @@ -171,15 +241,14 @@ fn handle_crdt_ops_relation( data, ) .exec() - .await - .ok(); + .await?; }, + sd_sync::CRDTOperationData::Delete => { db.#model_name_snake() .delete(id) .exec() - .await - .ok(); + .await?; }, } } @@ -198,8 +267,10 @@ fn handle_crdt_ops_shared( .expect("missing fields") .next() .expect("empty fields"); + let id_name_snake = snake_ident(scalar_field.name()); let field_name_snake = snake_ident(rel.name()); + let opposite_model_name_snake = snake_ident( rel.opposite_relation_field() .expect("missing opposite relation field") @@ -211,12 +282,16 @@ fn handle_crdt_ops_shared( id.#field_name_snake.pub_id.clone() )); + let pub_id_field = format!("{field_name_snake}::pub_id"); + let rel_fetch = quote! { let rel = db.#opposite_model_name_snake() .find_unique(#relation_equals_condition) .exec() - .await? - .unwrap(); + .await?.ok_or_else(|| Error::RelatedEntryNotFound { + field: #pub_id_field.to_string(), + model: prisma::#opposite_model_name_snake::NAME.to_string(), + })?; }; ( @@ -226,6 +301,7 @@ fn handle_crdt_ops_shared( relation_equals_condition, ) } + RefinedFieldWalker::Scalar(s) => { let field_name_snake = snake_ident(s.name()); let thing = quote!(id.#field_name_snake.clone()); @@ -238,24 +314,12 @@ fn handle_crdt_ops_shared( #get_id match data { - sd_sync::CRDTOperationData::Create(data) => { - let data: Vec<_> = data.into_iter().map(|(field, value)| { - prisma::#model_name_snake::SetParam::deserialize(&field, value).unwrap() - }).collect(); - - db.#model_name_snake() - .upsert( - prisma::#model_name_snake::#id_name_snake::equals(#equals_value), - prisma::#model_name_snake::create(#create_id, data.clone()), - data - ) - .exec() - .await?; - }, - sd_sync::CRDTOperationData::Update { field, value } => { - let data = vec![ - prisma::#model_name_snake::SetParam::deserialize(&field, value).unwrap() - ]; + sd_sync::CRDTOperationData::Create(data) | sd_sync::CRDTOperationData::Update(data) => { + let data = data.into_iter() + .map(|(field, value)| { + prisma::#model_name_snake::SetParam::deserialize(&field, value) + }) + .collect::, _>>()?; db.#model_name_snake() .upsert( @@ -266,6 +330,7 @@ fn handle_crdt_ops_shared( .exec() .await?; }, + sd_sync::CRDTOperationData::Delete => { db.#model_name_snake() .delete(prisma::#model_name_snake::#id_name_snake::equals(#equals_value)) @@ -275,8 +340,8 @@ fn handle_crdt_ops_shared( db.crdt_operation() .delete_many(vec![ prisma::crdt_operation::model::equals(#model_id as i32), - prisma::crdt_operation::record_id::equals(rmp_serde::to_vec(&id).unwrap()), - prisma::crdt_operation::kind::equals(sd_sync::OperationKind::Create.to_string()) + prisma::crdt_operation::record_id::equals(rmp_serde::to_vec(&id)?), + prisma::crdt_operation::kind::equals(sd_sync::OperationKind::Create.to_string()), ]) .exec() .await?; diff --git a/crates/sync/src/crdt.rs b/crates/sync/src/crdt.rs index 13eda3ffa..3cbdf23d2 100644 --- a/crates/sync/src/crdt.rs +++ b/crates/sync/src/crdt.rs @@ -7,7 +7,7 @@ use uhlc::NTP64; pub enum OperationKind<'a> { Create, - Update(&'a str), + Update(Vec<&'a str>), Delete, } @@ -15,7 +15,7 @@ impl fmt::Display for OperationKind<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { OperationKind::Create => write!(f, "c"), - OperationKind::Update(field) => write!(f, "u:{field}"), + OperationKind::Update(fields) => write!(f, "u:{}:", fields.join(":")), OperationKind::Delete => write!(f, "d"), } } @@ -26,7 +26,7 @@ pub enum CRDTOperationData { #[serde(rename = "c")] Create(BTreeMap), #[serde(rename = "u")] - Update { field: String, value: rmpv::Value }, + Update(BTreeMap), #[serde(rename = "d")] Delete, } @@ -41,7 +41,9 @@ impl CRDTOperationData { pub fn as_kind(&self) -> OperationKind<'_> { match self { Self::Create(_) => OperationKind::Create, - Self::Update { field, .. } => OperationKind::Update(field), + Self::Update(fields_and_values) => { + OperationKind::Update(fields_and_values.keys().map(String::as_str).collect()) + } Self::Delete => OperationKind::Delete, } } diff --git a/crates/sync/src/factory.rs b/crates/sync/src/factory.rs index 9fed9a52f..7c73f8b5f 100644 --- a/crates/sync/src/factory.rs +++ b/crates/sync/src/factory.rs @@ -46,15 +46,16 @@ pub trait OperationFactory { fn shared_update( &self, id: impl SyncId, - field: impl Into, - value: rmpv::Value, + values: impl IntoIterator + 'static, ) -> CRDTOperation { self.new_op( &id, - CRDTOperationData::Update { - field: field.into(), - value, - }, + CRDTOperationData::Update( + values + .into_iter() + .map(|(name, value)| (name.to_string(), value)) + .collect(), + ), ) } @@ -77,20 +78,23 @@ pub trait OperationFactory { ), ) } + fn relation_update( &self, id: impl RelationSyncId, - field: impl Into, - value: rmpv::Value, + values: impl IntoIterator + 'static, ) -> CRDTOperation { self.new_op( &id, - CRDTOperationData::Update { - field: field.into(), - value, - }, + CRDTOperationData::Update( + values + .into_iter() + .map(|(name, value)| (name.to_string(), value)) + .collect(), + ), ) } + fn relation_delete( &self, id: impl RelationSyncId, @@ -101,9 +105,14 @@ pub trait OperationFactory { #[macro_export] macro_rules! sync_entry { + (nil, $($prisma_column_module:tt)+) => { + ($($prisma_column_module)+::NAME, ::sd_utils::msgpack!(nil)) + }; + ($value:expr, $($prisma_column_module:tt)+) => { ($($prisma_column_module)+::NAME, ::sd_utils::msgpack!($value)) - } + }; + } #[macro_export] @@ -124,6 +133,28 @@ macro_rules! sync_db_entry { }} } +#[macro_export] +macro_rules! sync_db_nullable_entry { + ($value:expr, $($prisma_column_module:tt)+) => {{ + let value = $value.into(); + ( + $crate::sync_entry!(&value, $($prisma_column_module)+), + $($prisma_column_module)+::set(value) + ) + }} +} + +#[macro_export] +macro_rules! sync_db_not_null_entry { + ($value:expr, $($prisma_column_module:tt)+) => {{ + let value = $value.into(); + ( + $crate::sync_entry!(&value, $($prisma_column_module)+), + $($prisma_column_module)+::set(value) + ) + }} +} + #[macro_export] macro_rules! option_sync_db_entry { ($value:expr, $($prisma_column_module:tt)+) => { diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 45891b344..2960a21c4 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -14,6 +14,8 @@ sd-prisma = { path = "../prisma" } # Workspace dependencies chrono = { workspace = true } prisma-client-rust = { workspace = true } +rmp-serde = { workspace = true } +rmpv = { workspace = true } rspc = { workspace = true, features = ["unstable"] } thiserror = { workspace = true } tracing = { workspace = true }