Merge pull request #2633 from spacedriveapp/eng-1828-migration-to-new-cloud-api-system

[ENG-1828] Migrate to new Cloud Services API
This commit is contained in:
Lynx
2024-10-30 18:50:04 -07:00
committed by GitHub
261 changed files with 13665 additions and 9049 deletions

Cargo.lock (generated): binary file not shown.

View File

@@ -19,6 +19,9 @@ repository = "https://github.com/spacedriveapp/spacedrive"
rust-version = "1.81"
[workspace.dependencies]
# First party dependencies
sd-cloud-schema = { git = "https://github.com/spacedriveapp/cloud-services-schema", rev = "bbc69c5cb2" }
# Third party dependencies used by one or more of our crates
async-channel = "2.3"
async-stream = "0.3.6"
@@ -26,23 +29,25 @@ async-trait = "0.1.83"
axum = "0.7.7"
axum-extra = "0.9.4"
base64 = "0.22.1"
blake3 = "1.5"
blake3 = "1.5.4"
bytes = "1.7.1" # Update blocked by hyper
chrono = "0.4.38"
ed25519-dalek = "2.1"
flume = "0.11.0"
futures = "0.3.31"
futures-concurrency = "7.6"
globset = "0.4.15"
http = "1.1"
hyper = "1.5"
image = "0.24.9" # Update blocked due to https://github.com/image-rs/image/issues/2230
image = "0.25.4"
itertools = "0.13.0"
lending-stream = "1.0"
libc = "0.2"
libc = "0.2.159"
mimalloc = "0.1.43"
normpath = "1.3"
pin-project-lite = "0.2.14"
rand = "0.9.0-alpha.2"
regex = "1"
regex = "1.11"
reqwest = { version = "0.12.8", default-features = false }
rmp = "0.8.14"
rmp-serde = "1.3"
@@ -62,7 +67,8 @@ tracing-subscriber = "0.3.18"
tracing-test = "0.2.5"
uhlc = "0.8.0" # Must follow version used by specta
uuid = "1.10" # Must follow version used by specta
webp = "0.2.6" # Update blocked by image
webp = "0.3.0"
zeroize = "1.8"
[workspace.dependencies.rspc]
git = "https://github.com/spacedriveapp/rspc.git"

View File

@@ -20,26 +20,28 @@
"@sd/ui": "workspace:*",
"@t3-oss/env-core": "^0.7.1",
"@tanstack/react-query": "^5.59",
"@tauri-apps/api": "=2.0.2",
"@tauri-apps/plugin-dialog": "2.0.0",
"@tauri-apps/api": "=2.0.3",
"@tauri-apps/plugin-dialog": "2.0.1",
"@tauri-apps/plugin-http": "2.0.1",
"@tauri-apps/plugin-os": "2.0.0",
"@tauri-apps/plugin-shell": "2.0.0",
"@tauri-apps/plugin-shell": "2.0.1",
"consistent-hash": "^1.2.2",
"immer": "^10.0.3",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-router-dom": "=6.20.1",
"sonner": "^1.0.3"
"sonner": "^1.0.3",
"supertokens-web-js": "^0.13.0"
},
"devDependencies": {
"@sd/config": "workspace:*",
"@sentry/vite-plugin": "^2.16.0",
"@tauri-apps/cli": "2.0.1",
"@tauri-apps/cli": "2.0.4",
"@types/react": "^18.2.67",
"@types/react-dom": "^18.2.22",
"sass": "^1.72.0",
"typescript": "^5.6.2",
"vite": "^5.2.0",
"vite-tsconfig-paths": "^4.3.2"
"vite": "^5.4.9",
"vite-tsconfig-paths": "^5.0.1"
}
}

View File

@@ -35,12 +35,15 @@ uuid = { workspace = true, features = ["serde"] }
# Specific Desktop dependencies
# WARNING: Do NOT enable default features, as that vendors dbus (see below)
opener = { version = "0.7.1", features = ["reveal"], default-features = false }
specta-typescript = "=0.0.7"
tauri-plugin-dialog = "=2.0.2"
tauri-plugin-os = "=2.0.1"
tauri-plugin-shell = "=2.0.2"
tauri-plugin-updater = "=2.0.2"
opener = { version = "0.7.1", features = ["reveal"], default-features = false }
specta-typescript = "=0.0.7"
tauri-plugin-clipboard-manager = "=2.0.1"
tauri-plugin-deep-link = "=2.0.1"
tauri-plugin-dialog = "=2.0.3"
tauri-plugin-http = "=2.0.3"
tauri-plugin-os = "=2.0.1"
tauri-plugin-shell = "=2.0.2"
tauri-plugin-updater = "=2.0.2"
# memory allocator
mimalloc = { workspace = true }

View File

@@ -17,6 +17,7 @@
"dialog:allow-open",
"dialog:allow-save",
"dialog:allow-confirm",
"deep-link:default",
"os:allow-os-type",
"core:window:allow-close",
"core:window:allow-create",
@@ -24,6 +25,32 @@
"core:window:allow-minimize",
"core:window:allow-toggle-maximize",
"core:window:allow-start-dragging",
"core:webview:allow-internal-toggle-devtools"
"core:webview:allow-internal-toggle-devtools",
{
"identifier": "http:default",
"allow": [
{
"url": "http://ipc.localhost"
},
{
"url": "http://asset.localhost"
},
{
"url": "http://localhost:8001"
},
{
"url": "http://tauri.localhost"
},
{
"url": "http://localhost:9420"
},
{
"url": "https://auth.spacedrive.com"
},
{
"url": "https://plausible.io"
}
]
}
]
}
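The new `http:default` capability restricts which origins the webview can reach through tauri-plugin-http. A minimal sketch of how this surfaces in frontend code (the endpoint path is illustrative, not from this PR):

import { fetch } from '@tauri-apps/plugin-http';

async function checkSession() {
	// Allowed: auth.spacedrive.com is listed in the capability scope above.
	const res = await fetch('https://auth.spacedrive.com/api/auth/session');
	// A request to an origin outside the scope (e.g. https://example.com)
	// is rejected by the plugin before it ever leaves the app.
	return res.status;
}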

View File

@@ -14,13 +14,13 @@ use sd_core::{Node, NodeError};
use sd_fda::DiskAccess;
use serde::{Deserialize, Serialize};
use specta_typescript::Typescript;
use tauri::Emitter;
use tauri::{async_runtime::block_on, webview::PlatformWebview, AppHandle, Manager, WindowEvent};
use tauri::{Emitter, Listener};
use tauri_plugins::{sd_error_plugin, sd_server_plugin};
use tauri_specta::{collect_events, Builder};
use tokio::task::block_in_place;
use tokio::time::sleep;
use tracing::error;
use tracing::{debug, error};
mod file;
mod menu;
@@ -179,7 +179,11 @@ pub enum DragAndDropEvent {
Cancelled,
}
const CLIENT_ID: &str = "2abb241e-40b8-4517-a3e3-5594375c8fbb";
#[derive(Debug, Clone, Serialize, Deserialize, specta::Type, tauri_specta::Event)]
#[serde(rename_all = "camelCase")]
pub struct DeepLinkEvent {
data: String,
}
#[tokio::main]
async fn main() -> tauri::Result<()> {
@@ -221,9 +225,20 @@ async fn main() -> tauri::Result<()> {
tauri::Builder::default()
.invoke_handler(builder.invoke_handler())
.plugin(tauri_plugin_deep_link::init())
.setup(move |app| {
// We need the app handle to determine the data directory now.
// This means all the setup code has to live in `setup`; since `setup` doesn't support async, we `block_on`.
let handle = app.handle().clone();
app.listen("deep-link://new-url", move |event| {
let deep_link_event = DeepLinkEvent {
data: event.payload().to_string(),
};
println!("Deep link event={:?}", deep_link_event);
handle.emit("deeplink", deep_link_event).unwrap();
});
block_in_place(|| {
block_on(async move {
builder.mount_events(app);
@@ -239,10 +254,7 @@ async fn main() -> tauri::Result<()> {
// The `_guard` must be assigned to variable for flushing remaining logs on main exit through Drop
let (_guard, result) = match Node::init_logger(&data_dir) {
Ok(guard) => (
Some(guard),
Node::new(data_dir, sd_core::Env::new(CLIENT_ID)).await,
),
Ok(guard) => (Some(guard), Node::new(data_dir).await),
Err(err) => (None, Err(NodeError::Logger(err))),
};
@@ -256,7 +268,7 @@ async fn main() -> tauri::Result<()> {
}
};
let should_clear_localstorage = node.libraries.get_all().await.is_empty();
let should_clear_local_storage = node.libraries.get_all().await.is_empty();
handle.plugin(rspc::integrations::tauri::plugin(router, {
let node = node.clone();
@@ -266,8 +278,8 @@ async fn main() -> tauri::Result<()> {
handle.manage(node.clone());
handle.windows().iter().for_each(|(_, window)| {
if should_clear_localstorage {
println!("cleaning localStorage");
if should_clear_local_storage {
debug!("cleaning localStorage");
for webview in window.webviews() {
webview.eval("localStorage.clear();").ok();
}
@@ -344,6 +356,7 @@ async fn main() -> tauri::Result<()> {
.plugin(tauri_plugin_dialog::init())
.plugin(tauri_plugin_os::init())
.plugin(tauri_plugin_shell::init())
.plugin(tauri_plugin_http::init())
// TODO: Bring back Tauri Plugin Window State - it was buggy so we removed it.
.plugin(tauri_plugin_updater::Builder::new().build())
.plugin(updater::plugin())

View File

@@ -1,5 +1,5 @@
{
"$schema": "https://raw.githubusercontent.com/tauri-apps/tauri/tauri-v2.0.0-rc.2/core/tauri-config-schema/schema.json",
"$schema": "https://raw.githubusercontent.com/tauri-apps/tauri/tauri-v2.0.0-rc.8/crates/tauri-cli/tauri.config.schema.json",
"productName": "Spacedrive",
"identifier": "com.spacedrive.desktop",
"build": {
@@ -36,7 +36,12 @@
}
],
"security": {
"csp": "default-src webkit-pdfjs-viewer: asset: https://asset.localhost blob: data: filesystem: ws: wss: http: https: tauri: 'unsafe-eval' 'unsafe-inline' 'self' img-src: 'self'"
"csp": {
"default-src": "'self' webkit-pdfjs-viewer: asset: http://asset.localhost blob: data: filesystem: http: https: tauri:",
"connect-src": "'self' ipc: http://ipc.localhost ws: wss: http: https: tauri:",
"img-src": "'self' asset: http://asset.localhost blob: data: filesystem: http: https: tauri:",
"style-src": "'self' 'unsafe-inline' http: https: tauri:"
}
}
},
"bundle": {
@@ -100,6 +105,12 @@
"endpoints": [
"https://spacedrive.com/api/releases/tauri/{{version}}/{{target}}/{{arch}}"
]
},
"deep-link": {
"mobile": [],
"desktop": {
"schemes": ["spacedrive"]
}
}
}
}

View File

@@ -3,9 +3,10 @@ import { QueryClientProvider } from '@tanstack/react-query';
import { listen } from '@tauri-apps/api/event';
import { PropsWithChildren, startTransition, useEffect, useMemo, useRef, useState } from 'react';
import { createPortal } from 'react-dom';
import { RspcProvider } from '@sd/client';
import { RspcProvider, useBridgeMutation } from '@sd/client';
import {
createRoutes,
DeeplinkEvent,
ErrorPage,
KeybindEvent,
PlatformProvider,
@@ -17,14 +18,11 @@ import { RouteTitleContext } from '@sd/interface/hooks/useRouteTitle';
import '@sd/ui/style/style.scss';
import { useLocale } from '@sd/interface/hooks';
import { commands } from './commands';
import { platform } from './platform';
import { queryClient } from './query';
import { createMemoryRouterWithHistory } from './router';
import { createUpdater } from './updater';
import SuperTokens from 'supertokens-web-js';
import EmailPassword from 'supertokens-web-js/recipe/emailpassword';
import Passwordless from 'supertokens-web-js/recipe/passwordless';
import Session from 'supertokens-web-js/recipe/session';
import ThirdParty from 'supertokens-web-js/recipe/thirdparty';
// TODO: Bring this back once upstream is fixed up.
// const client = hooks.createClient({
// links: [
@@ -34,6 +32,32 @@ import { createUpdater } from './updater';
// tauriLink()
// ]
// });
import getCookieHandler from '@sd/interface/app/$libraryId/settings/client/account/handlers/cookieHandler';
import getWindowHandler from '@sd/interface/app/$libraryId/settings/client/account/handlers/windowHandler';
import { useLocale } from '@sd/interface/hooks';
import { AUTH_SERVER_URL, getTokens } from '@sd/interface/util';
import { commands } from './commands';
import { platform } from './platform';
import { queryClient } from './query';
import { createMemoryRouterWithHistory } from './router';
import { createUpdater } from './updater';
SuperTokens.init({
appInfo: {
apiDomain: AUTH_SERVER_URL,
apiBasePath: '/api/auth',
appName: 'Spacedrive Auth Service'
},
cookieHandler: getCookieHandler,
windowHandler: getWindowHandler,
recipeList: [
Session.init({ tokenTransferMethod: 'header' }),
EmailPassword.init(),
ThirdParty.init(),
Passwordless.init()
]
});
const startupError = (window as any).__SD_ERROR__ as string | undefined;
@@ -41,15 +65,31 @@ export default function App() {
useEffect(() => {
// This tells Tauri to show the current window because it's finished loading
commands.appReady();
// .then(() => {
// if (import.meta.env.PROD) window.fetch = fetch;
// });
}, []);
useEffect(() => {
const keybindListener = listen('keybind', (input) => {
document.dispatchEvent(new KeybindEvent(input.payload as string));
});
const deeplinkListener = listen('deeplink', async (data) => {
const payload = (data.payload as any).data as string;
if (!payload) return;
const json = JSON.parse(payload)[0];
if (!json) return;
// payload format: "spacedrive://-/URL"
if (typeof json !== 'string') return;
if (!json.startsWith('spacedrive://-')) return;
const url = (json as string).split('://-/')[1];
if (!url) return;
document.dispatchEvent(new DeeplinkEvent(url));
});
return () => {
keybindListener.then((unlisten) => unlisten());
deeplinkListener.then((unlisten) => unlisten());
};
}, []);
@@ -79,6 +119,15 @@ type RedirectPath = { pathname: string; search: string | undefined };
function AppInner() {
const [tabs, setTabs] = useState(() => [createTab()]);
const [selectedTabIndex, setSelectedTabIndex] = useState(0);
const tokens = getTokens();
const cloudBootstrap = useBridgeMutation('cloud.bootstrap');
useEffect(() => {
// If the access token and/or refresh token are missing, we need to skip the cloud bootstrap
if (tokens.accessToken.length === 0 || tokens.refreshToken.length === 0) return;
cloudBootstrap.mutate([tokens.accessToken, tokens.refreshToken]);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const selectedTab = tabs[selectedTabIndex]!;
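For reference, the Rust listener in main.rs emits the raw deep-link payload as a JSON-encoded array of URLs, so the `deeplink` handler above receives data shaped roughly like this (the path is illustrative):

// Example payload emitted by the `deeplink` event:
const event = { payload: { data: '["spacedrive://-/settings/cloud"]' } };

const json = JSON.parse((event.payload as any).data)[0]; // "spacedrive://-/settings/cloud"
const url = json.split('://-/')[1]; // "settings/cloud", dispatched via DeeplinkEvent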

View File

@@ -5,7 +5,8 @@
"declarationDir": "dist",
"paths": {
"~/*": ["./src/*"]
}
},
"moduleResolution": "bundler"
},
"include": ["src"],
"references": [

View File

@@ -19,18 +19,6 @@ const Pre: FC<{ children: React.ReactNode }> = ({ children }) => {
return (
<div ref={textInput} className="relative">
{/* <button
aria-label="Copy code"
type="button"
className="absolute right-2 top-2 z-10 rounded-md bg-app-box p-3 text-white/60 transition-colors duration-200 ease-in-out hover:bg-app-darkBox"
onClick={onCopy}
>
{copied ? (
<Check size={18} className="text-green-400" />
) : (
<Copy size={18} className="text-white opacity-70" />
)}
</button> */}
<Button
size="md"
rounding="both"

View File

@@ -312,6 +312,7 @@ buck-out/
.fakebuckversion
### ReactNative.Xcode Stack ###
ios/
## Xcode 8 and earlier
@@ -319,6 +320,7 @@ buck-out/
.gradle
**/build/
!src/**/build/
android/
# Ignore Gradle GUI config
gradle-app.setting

View File

@@ -15,17 +15,17 @@ err() {
if [ -z "${HOME:-}" ]; then
case "$(uname)" in
"Darwin")
HOME="$(CDPATH='' cd -- "$(osascript -e 'set output to (POSIX path of (path to home folder))')" && pwd -P)"
;;
"Linux")
HOME="$(CDPATH='' cd -- "$(getent passwd "$(id -un)" | cut -d: -f6)" && pwd -P)"
;;
*)
err "Your OS ($(uname)) is not supported by this script." \
'We would welcome a PR or some help adding your OS to this script.' \
'https://github.com/spacedriveapp/spacedrive/issues'
;;
"Darwin")
HOME="$(CDPATH='' cd -- "$(osascript -e 'set output to (POSIX path of (path to home folder))')" && pwd -P)"
;;
"Linux")
HOME="$(CDPATH='' cd -- "$(getent passwd "$(id -un)" | cut -d: -f6)" && pwd -P)"
;;
*)
err "Your OS ($(uname)) is not supported by this script." \
'We would welcome a PR or some help adding your OS to this script.' \
'https://github.com/spacedriveapp/spacedrive/issues'
;;
esac
export HOME
@@ -47,18 +47,19 @@ export PATH="${CARGO_HOME:-"${HOME}/.cargo"}/bin:$PATH"
if [ "${CI:-}" = "true" ]; then
# TODO: This need to be adjusted for future mobile release CI
case "$(uname -m)" in
"arm64" | "aarch64")
ANDROID_BUILD_TARGET_LIST="arm64-v8a"
;;
"x86_64")
ANDROID_BUILD_TARGET_LIST="x86_64"
;;
*)
err 'Unsupported architecture for CI build.'
;;
"arm64" | "aarch64")
ANDROID_BUILD_TARGET_LIST="arm64-v8a"
;;
"x86_64")
ANDROID_BUILD_TARGET_LIST="x86_64"
;;
*)
err 'Unsupported architecture for CI build.'
;;
esac
else
ANDROID_BUILD_TARGET_LIST="arm64-v8a armeabi-v7a x86_64"
# ANDROID_BUILD_TARGET_LIST="arm64-v8a armeabi-v7a x86_64"
ANDROID_BUILD_TARGET_LIST="arm64-v8a"
fi
# Configure build targets CLI arg for `cargo ndk`
@@ -69,4 +70,5 @@ for _target in $ANDROID_BUILD_TARGET_LIST; do
done
cd "${__dirname}/crate"
cargo ndk --platform 34 "$@" -o "$OUTPUT_DIRECTORY" build --release
cargo ndk --platform 34 "$@" -o "$OUTPUT_DIRECTORY" build
# \ --release

View File

@@ -37,9 +37,7 @@ pub extern "system" fn Java_com_spacedrive_core_SDCoreModule_registerCoreEventLi
if let Err(err) = result {
// TODO: Send rspc error or something here so we can show this in the UI.
// TODO: Maybe reinitialise the core cause it could be in an invalid state?
println!(
"Error in Java_com_spacedrive_core_SDCoreModule_registerCoreEventListener: {err:?}"
);
error!("Error in Java_com_spacedrive_core_SDCoreModule_registerCoreEventListener: {err:?}");
}
}

View File

@@ -7,7 +7,6 @@ license.workspace = true
repository.workspace = true
rust-version.workspace = true
# Spacedrive Sub-crates
[target.'cfg(target_os = "ios")'.dependencies]
sd-core = { default-features = false, features = [

View File

@@ -34,8 +34,6 @@ pub static SUBSCRIPTIONS: LazyLock<
pub static EVENT_SENDER: OnceLock<mpsc::Sender<Response>> = OnceLock::new();
pub const CLIENT_ID: &str = "d068776a-05b6-4aaa-9001-4d01734e1944";
pub struct MobileSender<'a> {
resp: &'a mut Option<Response>,
}
@@ -76,7 +74,7 @@ pub fn handle_core_msg(
None => {
let _guard = Node::init_logger(&data_dir);
let new_node = match Node::new(data_dir, sd_core::Env::new(CLIENT_ID)).await {
let new_node = match Node::new(data_dir).await {
Ok(node) => node,
Err(e) => {
error!(?e, "Failed to initialize node;");

View File

@@ -37,7 +37,8 @@ Pod::Spec.new do |s|
ffmpeg_frameworks = [
"-framework AudioToolbox",
"-framework VideoToolbox",
"-framework AVFoundation"
"-framework AVFoundation",
"-framework SystemConfiguration",
].join(' ')
s.xcconfig = {

View File

@@ -48,10 +48,10 @@ mkdir -p "$TARGET_DIRECTORY"
TARGET_DIRECTORY="$(CDPATH='' cd -- "$TARGET_DIRECTORY" && pwd -P)"
TARGET_CONFIG=debug
if [ "${CONFIGURATION:-}" = "Release" ]; then
set -- --release
TARGET_CONFIG=release
fi
# if [ "${CONFIGURATION:-}" = "Release" ]; then
# set -- --release
# TARGET_CONFIG=release
# fi
trap 'if [ -e "${CARGO_CONFIG}.bak" ]; then mv "${CARGO_CONFIG}.bak" "$CARGO_CONFIG"; fi' EXIT
@@ -59,21 +59,21 @@ trap 'if [ -e "${CARGO_CONFIG}.bak" ]; then mv "${CARGO_CONFIG}.bak" "$CARGO_CON
RUST_PATH="${CARGO_HOME:-"${HOME}/.cargo"}/bin:$(brew --prefix)/bin:$(env -i /bin/bash --noprofile --norc -c 'echo $PATH')"
if [ "${PLATFORM_NAME:-}" = "iphonesimulator" ]; then
case "$(uname -m)" in
"arm64" | "aarch64") # M series
sed -i.bak "s|FFMPEG_DIR = { force = true, value = \".*\" }|FFMPEG_DIR = { force = true, value = \"${DEPS}/aarch64-apple-ios-sim\" }|" "$CARGO_CONFIG"
env CARGO_FEATURE_STATIC=1 PATH="$RUST_PATH" cargo build -p sd-mobile-ios --target aarch64-apple-ios-sim "$@"
lipo -create -output "$TARGET_DIRECTORY"/libsd_mobile_iossim.a "${TARGET_DIRECTORY}/aarch64-apple-ios-sim/${TARGET_CONFIG}/libsd_mobile_ios.a"
symlink_libs "${DEPS}/aarch64-apple-ios-sim/lib" "$TARGET_DIRECTORY"
;;
"x86_64") # Intel
sed -i.bak "s|FFMPEG_DIR = { force = true, value = \".*\" }|FFMPEG_DIR = { force = true, value = \"${DEPS}/x86_64-apple-ios\" }|" "$CARGO_CONFIG"
env CARGO_FEATURE_STATIC=1 PATH="$RUST_PATH" cargo build -p sd-mobile-ios --target x86_64-apple-ios "$@"
lipo -create -output "$TARGET_DIRECTORY"/libsd_mobile_iossim.a "${TARGET_DIRECTORY}/x86_64-apple-ios/${TARGET_CONFIG}/libsd_mobile_ios.a"
symlink_libs "${DEPS}/x86_64-apple-ios/lib" "$TARGET_DIRECTORY"
;;
*)
err 'Unsupported architecture.'
;;
"arm64" | "aarch64") # M series
sed -i.bak "s|FFMPEG_DIR = { force = true, value = \".*\" }|FFMPEG_DIR = { force = true, value = \"${DEPS}/aarch64-apple-ios-sim\" }|" "$CARGO_CONFIG"
env CARGO_FEATURE_STATIC=1 PATH="$RUST_PATH" cargo build -p sd-mobile-ios --target aarch64-apple-ios-sim "$@"
lipo -create -output "$TARGET_DIRECTORY"/libsd_mobile_iossim.a "${TARGET_DIRECTORY}/aarch64-apple-ios-sim/${TARGET_CONFIG}/libsd_mobile_ios.a"
symlink_libs "${DEPS}/aarch64-apple-ios-sim/lib" "$TARGET_DIRECTORY"
;;
"x86_64") # Intel
sed -i.bak "s|FFMPEG_DIR = { force = true, value = \".*\" }|FFMPEG_DIR = { force = true, value = \"${DEPS}/x86_64-apple-ios\" }|" "$CARGO_CONFIG"
env CARGO_FEATURE_STATIC=1 PATH="$RUST_PATH" cargo build -p sd-mobile-ios --target x86_64-apple-ios "$@"
lipo -create -output "$TARGET_DIRECTORY"/libsd_mobile_iossim.a "${TARGET_DIRECTORY}/x86_64-apple-ios/${TARGET_CONFIG}/libsd_mobile_ios.a"
symlink_libs "${DEPS}/x86_64-apple-ios/lib" "$TARGET_DIRECTORY"
;;
*)
err 'Unsupported architecture.'
;;
esac
else
sed -i.bak "s|FFMPEG_DIR = { force = true, value = \".*\" }|FFMPEG_DIR = { force = true, value = \"${DEPS}/aarch64-apple-ios\" }|" "$CARGO_CONFIG"

View File

@@ -6,7 +6,7 @@
"private": true,
"scripts": {
"start": "expo start --dev-client",
"android": "expo run:android",
"android": "java -version 2>&1 | grep -q 'version \"17' && expo run:android || echo 'Java version 17 is required to be the running version. Please switch or uninstall the current version (Version '$(java -version 2>&1 | awk -F '\"' '/version/ {print $2}')').' && exit 1",
"ios": "expo run:ios",
"prebuild": "expo prebuild",
"xcode": "open ios/Spacedrive.xcworkspace",
@@ -22,11 +22,10 @@
"@gorhom/bottom-sheet": "^4.6.1",
"@hookform/resolvers": "^3.1.0",
"@spacedrive/rspc-client": "github:spacedriveapp/rspc#path:packages/client&6a77167495",
"@spacedrive/rspc-react": "github:spacedriveapp/rspc#path:packages/react&6a77167495",
"@react-native-async-storage/async-storage": "~1.23.1",
"@react-native-masked-view/masked-view": "^0.3.1",
"@react-navigation/bottom-tabs": "^6.5.19",
"@react-navigation/drawer": "^6.6.14",
"@react-navigation/drawer": "^6.6.15",
"@react-navigation/native": "^6.1.16",
"@react-navigation/native-stack": "^6.9.25",
"@sd/assets": "workspace:*",
@@ -37,12 +36,12 @@
"class-variance-authority": "^0.7.0",
"dayjs": "^1.11.10",
"event-target-polyfill": "^0.0.4",
"expo": "~51.0.28",
"expo-av": "^14.0.6",
"expo": "~51.0.32",
"expo-av": "^14.0.7",
"expo-blur": "^13.0.2",
"expo-build-properties": "~0.12.5",
"expo-haptics": "~13.0.1",
"expo-image": "^1.12.13",
"expo-image": "^1.12.15",
"expo-linking": "~6.3.1",
"expo-media-library": "~16.0.4",
"expo-splash-screen": "~0.27.5",
@@ -71,6 +70,7 @@
"react-native-wheel-color-picker": "^1.2.0",
"rive-react-native": "^6.2.3",
"solid-js": "^1.8.8",
"supertokens-react-native": "^5.1.2",
"twrnc": "^4.1.0",
"use-count-up": "^3.0.1",
"use-debounce": "^9.0.4",

View File

@@ -17,6 +17,7 @@ import { Alert, LogBox, Permission, PermissionsAndroid, Platform } from 'react-n
import { GestureHandlerRootView } from 'react-native-gesture-handler';
import { MenuProvider } from 'react-native-popup-menu';
import { SafeAreaProvider } from 'react-native-safe-area-context';
import SuperTokens from 'supertokens-react-native';
import { useSnapshot } from 'valtio';
import {
ClientContextProvider,
@@ -24,7 +25,9 @@ import {
LibraryContextProvider,
P2PContextProvider,
RspcProvider,
useBridgeMutation,
useBridgeQuery,
useBridgeSubscription,
useClientContext,
useInvalidateQuery,
usePlausibleEvent,
@@ -33,12 +36,13 @@ import {
} from '@sd/client';
import { GlobalModals } from './components/modal/GlobalModals';
import { Toast, toastConfig } from './components/primitive/Toast';
import { toast, Toast, toastConfig } from './components/primitive/Toast';
import { useTheme } from './hooks/useTheme';
import { changeTwTheme, tw } from './lib/tailwind';
import RootNavigator from './navigation';
import OnboardingNavigator from './navigation/OnboardingNavigator';
import { P2P } from './screens/p2p/P2P';
import { AUTH_SERVER_URL } from './utils';
import { currentLibraryStore } from './utils/nav';
LogBox.ignoreLogs(['Sending `onAnimatedValueUpdate` with no listeners registered.']);
@@ -129,6 +133,41 @@ function AppContainer() {
useInvalidateQuery();
const { id } = useSnapshot(currentLibraryStore);
const userResponse = useBridgeMutation('cloud.userResponse');
useBridgeSubscription(['cloud.listenCloudServicesNotifications'], {
onData: (d) => {
console.log('Received cloud service notification', d);
switch (d.kind) {
case 'ReceivedJoinSyncGroupRequest':
// WARNING: This is a debug solution to accept the device into the sync group. THIS SHOULD NOT MAKE IT TO PRODUCTION
userResponse.mutate({
kind: 'AcceptDeviceInSyncGroup',
data: {
ticket: d.data.ticket,
accepted: {
id: d.data.sync_group.library.pub_id,
name: d.data.sync_group.library.name,
description: null
}
}
});
// TODO: Move the code above into the dialog below (@Rocky43007)
// dialogManager.create((dp) => (
// <RequestAddDialog
// device_model={'MacBookPro'}
// device_name={"Arnab's Macbook"}
// library_name={"Arnab's Library"}
// {...dp}
// />
// ));
break;
default:
toast.info(`Cloud Service Notification: ${d.kind}`);
break;
}
}
});
return (
<SafeAreaProvider style={tw`flex-1 bg-black`}>
@@ -156,6 +195,10 @@ export default function App() {
useEffect(() => {
global.Intl = require('intl');
require('intl/locale-data/jsonp/en'); //TODO(@Rocky43007): Setup a way to import all the languages we support, once we add localization on mobile.
SuperTokens.init({
apiDomain: AUTH_SERVER_URL,
apiBasePath: '/api/auth'
});
SplashScreen.hideAsync();
if (Platform.OS === 'android') {
(async () => {
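A sketch of the notification shape the subscription handler above relies on, derived from its field accesses (the authoritative type lives in @sd/client and may differ):

type CloudServicesNotification =
	| {
			kind: 'ReceivedJoinSyncGroupRequest';
			data: {
				ticket: string;
				sync_group: { library: { pub_id: string; name: string } };
			};
	  }
	// Other kinds currently fall through to the toast in the default branch.
	| { kind: string; data?: unknown };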

View File

@@ -8,12 +8,13 @@ import { tw, twStyle } from '~/lib/tailwind';
type Props = {
route?: RouteProp<any, any>; // supporting title from the options object of navigation
navBack?: boolean; // whether to show the back icon
navBackTo?: string; // route to go back to
search?: boolean; // whether to show the search icon
title?: string; // in some cases - we want to override the route title
};
// Default header with search bar and button to open drawer
export default function Header({ route, navBack, title, search = false }: Props) {
export default function Header({ route, navBack, title, navBackTo, search = false }: Props) {
const navigation = useNavigation<DrawerNavigationHelpers>();
const headerHeight = useSafeAreaInsets().top;
const isAndroid = Platform.OS === 'android';
@@ -28,7 +29,13 @@ export default function Header({ route, navBack, title, search = false }: Props)
<View style={tw`w-full flex-row items-center justify-between`}>
<View style={tw`flex-row items-center gap-3`}>
{navBack ? (
<Pressable hitSlop={24} onPress={() => navigation.goBack()}>
<Pressable
hitSlop={24}
onPress={() => {
if (navBackTo) return navigation.navigate(navBackTo);
navigation.goBack();
}}
>
<ArrowLeft size={24} color={tw.color('ink')} />
</Pressable>
) : (

View File

@@ -2,13 +2,7 @@ import { BottomSheetFlatList } from '@gorhom/bottom-sheet';
import { NavigationProp, useNavigation } from '@react-navigation/native';
import { forwardRef } from 'react';
import { ActivityIndicator, Text, View } from 'react-native';
import {
CloudLibrary,
useBridgeMutation,
useBridgeQuery,
useClientContext,
useRspcContext
} from '@sd/client';
import { useBridgeMutation, useBridgeQuery, useClientContext, useRspcContext } from '@sd/client';
import { Modal, ModalRef } from '~/components/layout/Modal';
import { Button } from '~/components/primitive/Button';
import useForwardedRef from '~/hooks/useForwardedRef';
@@ -25,9 +19,9 @@ const ImportModalLibrary = forwardRef<ModalRef, unknown>((_, ref) => {
const { libraries } = useClientContext();
const cloudLibraries = useBridgeQuery(['cloud.library.list']);
const cloudLibraries = useBridgeQuery(['cloud.libraries.list', true]);
const cloudLibrariesData = cloudLibraries.data?.filter(
(cloudLibrary) => !libraries.data?.find((l) => l.uuid === cloudLibrary.uuid)
(cloudLibrary) => !libraries.data?.find((l) => l.uuid === cloudLibrary.pub_id)
);
return (
@@ -63,11 +57,11 @@ const ImportModalLibrary = forwardRef<ModalRef, unknown>((_, ref) => {
description="No cloud libraries available to join"
/>
}
keyExtractor={(item) => item.uuid}
keyExtractor={(item) => item.pub_id}
showsVerticalScrollIndicator={false}
renderItem={({ item }) => (
<CloudLibraryCard
data={item}
// data={item}
navigation={navigation}
modalRef={modalRef}
/>
@@ -81,38 +75,37 @@ const ImportModalLibrary = forwardRef<ModalRef, unknown>((_, ref) => {
});
interface Props {
data: CloudLibrary;
// data: CloudLibrary;
modalRef: React.RefObject<ModalRef>;
navigation: NavigationProp<RootStackParamList>;
}
const CloudLibraryCard = ({ data, modalRef, navigation }: Props) => {
const CloudLibraryCard = ({ modalRef, navigation }: Props) => {
const rspc = useRspcContext().queryClient;
const joinLibrary = useBridgeMutation(['cloud.library.join']);
// const joinLibrary = useBridgeMutation(['cloud.library.join']);
return (
<View
key={data.uuid}
style={tw`flex flex-row items-center justify-between gap-2 rounded-md border border-app-box bg-app p-2`}
>
<Text numberOfLines={1} style={tw`max-w-[80%] text-sm font-bold text-ink`}>
{data.name}
{'BOB'}
</Text>
<Button
size="sm"
variant="accent"
disabled={joinLibrary.isPending}
// disabled={joinLibrary.isPending}
onPress={async () => {
const library = await joinLibrary.mutateAsync(data.uuid);
// const library = await joinLibrary.mutateAsync(data.uuid);
rspc.setQueryData(['library.list'], (libraries: any) => {
// The invalidation system beat us to it
if ((libraries || []).find((l: any) => l.uuid === library.uuid))
return libraries;
// rspc.setQueryData(['library.list'], (libraries: any) => {
// // The invalidation system beat us to it
// if ((libraries || []).find((l: any) => l.uuid === library.uuid))
// return libraries;
return [...(libraries || []), library];
});
// return [...(libraries || []), library];
// });
currentLibraryStore.id = library.uuid;
// currentLibraryStore.id = library.uuid;
navigation.navigate('Root', {
screen: 'Home',
@@ -128,9 +121,10 @@ const CloudLibraryCard = ({ data, modalRef, navigation }: Props) => {
}}
>
<Text style={tw`text-sm font-medium text-white`}>
{joinLibrary.isPending && joinLibrary.variables === data.uuid
{/* {joinLibrary.isPending && joinLibrary.variables === data.uuid
? 'Joining...'
: 'Join'}
: 'Join'} */}
THIS FILE NEEDS TO BE UPDATED TO USE THE NEW LIBRARY SYSTEM IN THE FUTURE
</Text>
</Button>
</View>

View File

@@ -0,0 +1,66 @@
import { ArrowRight } from 'phosphor-react-native';
import React, { forwardRef } from 'react';
import { Text, View } from 'react-native';
import { HardwareModel } from '@sd/client';
import { Icon } from '~/components/icons/Icon';
import { Modal, ModalRef } from '~/components/layout/Modal';
import { hardwareModelToIcon } from '~/components/overview/Devices';
import { Button } from '~/components/primitive/Button';
import useForwardedRef from '~/hooks/useForwardedRef';
import { twStyle } from '~/lib/tailwind';
interface Props {
device_name: string;
device_model: HardwareModel;
library_name: string;
}
const JoinRequestModal = forwardRef<ModalRef, Props>((props, ref) => {
const modalRef = useForwardedRef(ref);
return (
<Modal ref={modalRef} snapPoints={['36']} title="Sync request">
<View style={twStyle('px-6')}>
<Text style={twStyle('mx-auto mt-2 text-center text-ink-dull')}>
A device is requesting to join one of your libraries. Please review the device
and the library it is requesting to join below.
</Text>
<View style={twStyle('my-7 flex-row items-center justify-center gap-10')}>
<View style={twStyle('flex flex-col items-center justify-center gap-2')}>
<Icon
// once the backend endpoint is populated, verify this fetches the correct icon for each device
name={hardwareModelToIcon(props.device_model)}
alt="Device icon"
size={48}
/>
<Text style={twStyle('text-sm font-bold text-ink')}>
{props.device_name}
</Text>
</View>
<ArrowRight weight="bold" color="#ABACBA" size={18} />
{/* library */}
<View style={twStyle('flex flex-col items-center justify-center gap-2')}>
<Icon
// once the backend endpoint is populated, verify this renders the correct library icon
name={'Book'}
alt="Library icon"
size={48}
/>
<Text style={twStyle('text-sm font-bold text-ink')}>
{props.library_name}
</Text>
</View>
</View>
<View style={twStyle('mx-auto flex-row justify-center gap-5')}>
<Button style={twStyle('flex-1')} variant="gray">
<Text style={twStyle('font-bold text-ink-dull')}>Cancel</Text>
</Button>
<Button style={twStyle('flex-1')} variant="accent">
<Text style={twStyle('font-bold text-ink')}>Accept</Text>
</Button>
</View>
</View>
</Modal>
);
});
export default JoinRequestModal;
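The intended call site is the commented-out dialog in the mobile App.tsx notification handler; a sketch of how the modal might be wired up (the ref handling and `present()` method are assumptions based on the bottom-sheet based Modal):

const modalRef = useRef<ModalRef>(null);

<JoinRequestModal
	ref={modalRef}
	device_name="Arnab's Macbook"
	device_model="MacBookPro"
	library_name="Arnab's Library"
/>;
// modalRef.current?.present() would then open the sheet on an incoming join request.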

View File

@@ -5,8 +5,9 @@ import React, { useEffect, useState } from 'react';
import { Platform, Text, View } from 'react-native';
import DeviceInfo from 'react-native-device-info';
import { ScrollView } from 'react-native-gesture-handler';
import { HardwareModel, NodeState, StatisticsResponse } from '@sd/client';
import { HardwareModel, NodeState, StatisticsResponse, useBridgeQuery } from '@sd/client';
import { tw, twStyle } from '~/lib/tailwind';
import { getTokens } from '~/utils';
import Fade from '../layout/Fade';
import { Button } from '../primitive/Button';
@@ -44,6 +45,23 @@ const Devices = ({ node, stats }: Props) => {
Omit<RNFS.FSInfoResultT, 'totalSpaceEx' | 'freeSpaceEx'>
>({ freeSpace: 0, totalSpace: 0 });
const [deviceName, setDeviceName] = useState<string>('');
const [accessToken, setAccessToken] = useState<string>('');
useEffect(() => {
(async () => {
const at = await getTokens();
setAccessToken(at.accessToken);
})();
}, []);
const devices = useBridgeQuery(['cloud.devices.list']);
// Refetch devices every 10 seconds
useEffect(() => {
const interval = setInterval(async () => {
await devices.refetch();
}, 10000);
return () => clearInterval(interval);
}, []);
useEffect(() => {
const getFSInfo = async () => {
@@ -74,7 +92,7 @@ const Devices = ({ node, stats }: Props) => {
}, [node]);
return (
<OverviewSection title="Devices" count={node ? 1 : 0}>
<OverviewSection title="Devices" count={node ? 1 + (devices.data?.length ?? 0) : 0}>
<View>
<Fade height={'100%'} width={30} color="black">
<ScrollView
@@ -93,6 +111,18 @@ const Devices = ({ node, stats }: Props) => {
connectionType={null}
/>
)}
{devices.data?.map((device) => (
<StatCard
key={device.pub_id}
name={device.name}
// TODO (Optional): Use Brand Type for Different Android Models/iOS Models using DeviceInfo.getBrand()
icon={hardwareModelToIcon(device.hardware_model)}
totalSpace={'0'}
freeSpace={'0'}
color="#0362FF"
connectionType={'cloud'}
/>
))}
<NewCard
icons={['Laptop', 'Server', 'SilverBox', 'Tablet']}
text="Spacedrive works best on all your devices."

View File

@@ -4,6 +4,8 @@ import { CompositeScreenProps } from '@react-navigation/native';
import { createNativeStackNavigator, NativeStackScreenProps } from '@react-navigation/native-stack';
import Header from '~/components/header/Header';
import SearchHeader from '~/components/header/SearchHeader';
import AccountLogin from '~/screens/settings/client/AccountSettings/AccountLogin';
import AccountProfile from '~/screens/settings/client/AccountSettings/AccountProfile';
import AppearanceSettingsScreen from '~/screens/settings/client/AppearanceSettings';
import ExtensionsSettingsScreen from '~/screens/settings/client/ExtensionsSettings';
import GeneralSettingsScreen from '~/screens/settings/client/GeneralSettings';
@@ -12,12 +14,10 @@ import PrivacySettingsScreen from '~/screens/settings/client/PrivacySettings';
import AboutScreen from '~/screens/settings/info/About';
import DebugScreen from '~/screens/settings/info/Debug';
import SupportScreen from '~/screens/settings/info/Support';
import CloudSettings from '~/screens/settings/library/CloudSettings/CloudSettings';
import EditLocationSettingsScreen from '~/screens/settings/library/EditLocationSettings';
import LibraryGeneralSettingsScreen from '~/screens/settings/library/LibraryGeneralSettings';
import LocationSettingsScreen from '~/screens/settings/library/LocationSettings';
import NodesSettingsScreen from '~/screens/settings/library/NodesSettings';
import SyncSettingsScreen from '~/screens/settings/library/SyncSettings';
import TagsSettingsScreen from '~/screens/settings/library/TagsSettings';
import SettingsScreen from '~/screens/settings/Settings';
@@ -46,6 +46,16 @@ export default function SettingsStack() {
component={GeneralSettingsScreen}
options={{ header: () => <Header navBack title="General" /> }}
/>
<Stack.Screen
name="AccountLogin"
component={AccountLogin}
options={{ header: () => <Header navBackTo="Settings" navBack title="Account" /> }}
/>
<Stack.Screen
name="AccountProfile"
component={AccountProfile}
options={{ header: () => <Header navBackTo="Settings" navBack title="Account" /> }}
/>
<Stack.Screen
name="LibrarySettings"
component={LibrarySettingsScreen}
@@ -94,16 +104,6 @@ export default function SettingsStack() {
component={TagsSettingsScreen}
options={{ header: () => <Header navBack title="Tags" /> }}
/>
<Stack.Screen
name="SyncSettings"
component={SyncSettingsScreen}
options={{ header: () => <Header navBack title="Sync" /> }}
/>
<Stack.Screen
name="CloudSettings"
component={CloudSettings}
options={{ header: () => <Header navBack title="Cloud" /> }}
/>
{/* <Stack.Screen
name="KeysSettings"
component={KeysSettingsScreen}
@@ -134,6 +134,8 @@ export type SettingsStackParamList = {
Settings: undefined;
// Client
GeneralSettings: undefined;
AccountLogin: undefined;
AccountProfile: undefined;
LibrarySettings: undefined;
AppearanceSettings: undefined;
PrivacySettings: undefined;

View File

@@ -12,9 +12,9 @@ import {
PuzzlePiece,
ShareNetwork,
ShieldCheck,
TagSimple
TagSimple,
UserCircle
} from 'phosphor-react-native';
import React from 'react';
import { Platform, SectionList, Text, TouchableWithoutFeedback, View } from 'react-native';
import { DebugState, useDebugState, useDebugStateEnabler, useLibraryQuery } from '@sd/client';
import ScreenContainer from '~/components/layout/ScreenContainer';
@@ -22,6 +22,7 @@ import { SettingsItem } from '~/components/settings/SettingsItem';
import { useEnableDrawer } from '~/hooks/useEnableDrawer';
import { tw, twStyle } from '~/lib/tailwind';
import { SettingsStackParamList, SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
import { useUserStore } from '~/stores/userStore';
type SectionType = {
title: string;
@@ -34,7 +35,10 @@ type SectionType = {
}[];
};
const sections: (debugState: DebugState) => SectionType[] = (debugState) => [
const sections: (
debugState: DebugState,
userInfo: ReturnType<typeof useUserStore>['userInfo']
) => SectionType[] = (debugState, userInfo) => [
{
title: 'Client',
data: [
@@ -44,6 +48,21 @@ const sections: (debugState: DebugState) => SectionType[] = (debugState) => [
title: 'General',
rounded: 'top'
},
...(userInfo
? ([
{
icon: UserCircle,
navigateTo: 'AccountProfile',
title: 'Account'
}
] as const)
: ([
{
icon: UserCircle,
navigateTo: 'AccountLogin',
title: 'Account'
}
] as const)),
{
icon: Books,
navigateTo: 'LibrarySettings',
@@ -158,12 +177,16 @@ function renderSectionHeader({ section }: { section: { title: string } }) {
export default function SettingsScreen({ navigation }: SettingsStackScreenProps<'Settings'>) {
const debugState = useDebugState();
const syncEnabled = useLibraryQuery(['sync.enabled']);
const userInfo = useUserStore().userInfo;
// Enables the drawer from react-navigation
useEnableDrawer();
return (
<ScreenContainer tabHeight={false} style={tw`gap-0 px-5 py-0`}>
<SectionList
contentContainerStyle={tw`py-6`}
sections={sections(debugState)}
sections={sections(debugState, userInfo)}
renderItem={({ item }) => (
<SettingsItem
syncEnabled={syncEnabled.data}

View File

@@ -0,0 +1,181 @@
import { MotiView } from 'moti';
import { AppleLogo, GithubLogo, GoogleLogo, IconProps } from 'phosphor-react-native';
import { useEffect, useState } from 'react';
import { Text, View } from 'react-native';
import { LinearTransition } from 'react-native-reanimated';
import Card from '~/components/layout/Card';
import ScreenContainer from '~/components/layout/ScreenContainer';
import { Button } from '~/components/primitive/Button';
import { tw, twStyle } from '~/lib/tailwind';
import { getUserStore, useUserStore } from '~/stores/userStore';
import { AUTH_SERVER_URL } from '~/utils';
import Login from './Login';
import Register from './Register';
const AccountTabs = ['Login', 'Register'] as const;
type SocialLogin = {
name: 'Github' | 'Google' | 'Apple';
icon: React.FC<IconProps>;
};
const SocialLogins: SocialLogin[] = [
{ name: 'Github', icon: GithubLogo },
{ name: 'Google', icon: GoogleLogo },
{ name: 'Apple', icon: AppleLogo }
];
const AccountLogin = () => {
const [activeTab, setActiveTab] = useState<'Login' | 'Register'>('Login');
const userInfo = useUserStore().userInfo;
useEffect(() => {
if (userInfo) return; // skip the fetch if user info is already present
async function _() {
const user_data = await fetch(`${AUTH_SERVER_URL}/api/user`, {
method: 'GET'
});
const data = await user_data.json();
if (data.message !== 'unauthorised') {
getUserStore().userInfo = data;
}
}
_();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
// FIXME: Currently opens in App.
// const socialLoginHandlers = (name: SocialLogin['name']) => {
// return {
// Github: async () => {
// try {
// const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({
// thirdPartyId: 'github',
// // This is where Github should redirect the user back after login or error.
// frontendRedirectURI: 'http://localhost:9420/api/auth/callback/github'
// });
// // we redirect the user to Github for auth.
// window.location.assign(authUrl);
// } catch (err: any) {
// if (err.isSuperTokensGeneralError === true) {
// // this may be a custom error message sent from the API by you.
// toast.error(err.message);
// } else {
// toast.error('Oops! Something went wrong.');
// }
// }
// },
// Google: async () => {
// try {
// const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({
// thirdPartyId: 'google',
// // This is where Google should redirect the user back after login or error.
// // This URL goes on the Google's dashboard as well.
// frontendRedirectURI: 'http://localhost:9420/api/auth/callback/google'
// });
// /*
// Example value of authUrl: https://accounts.google.com/o/oauth2/v2/auth/oauthchooseaccount?scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&access_type=offline&include_granted_scopes=true&response_type=code&client_id=1060725074195-kmeum4crr01uirfl2op9kd5acmi9jutn.apps.googleusercontent.com&state=5a489996a28cafc83ddff&redirect_uri=https%3A%2F%2Fsupertokens.io%2Fdev%2Foauth%2Fredirect-to-app&flowName=GeneralOAuthFlow
// */
// // we redirect the user to google for auth.
// window.location.assign(authUrl);
// } catch (err: any) {
// if (err.isSuperTokensGeneralError === true) {
// // this may be a custom error message sent from the API by you.
// toast.error(err.message);
// } else {
// toast.error('Oops! Something went wrong.');
// }
// }
// },
// Apple: async () => {
// try {
// const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({
// thirdPartyId: 'apple',
// // This is where Apple should redirect the user back after login or error.
// frontendRedirectURI: 'http://localhost:9420/api/auth/callback/apple'
// });
// // we redirect the user to Apple for auth.
// window.location.assign(authUrl);
// } catch (err: any) {
// if (err.isSuperTokensGeneralError === true) {
// // this may be a custom error message sent from the API by you.
// toast.error(err.message);
// } else {
// toast.error('Oops! Something went wrong.');
// }
// }
// }
// }[name]();
// };
return (
<ScreenContainer scrollview={false} style={tw`gap-2 px-6`}>
<View style={tw`flex flex-col justify-between gap-5 lg:flex-row`}>
<Card style={tw`relative flex w-full flex-col items-center justify-center`}>
<View style={tw`flex w-full flex-row gap-x-1.5`}>
{AccountTabs.map((text) => (
<Button
key={text}
onPress={() => {
setActiveTab(text);
}}
style={twStyle(
'relative flex-1 border-b border-app-line/50 p-2 text-center',
text === 'Login' ? 'rounded-tl-md' : 'rounded-tr-md'
)}
>
<Text
style={twStyle(
'relative z-10 text-sm',
text === activeTab ? 'font-bold text-ink' : 'text-ink-faint'
)}
>
{text}
</Text>
{text === activeTab && (
<MotiView
animate={{
borderRadius: text === 'Login' ? 0.3 : 0
}}
layout={LinearTransition.duration(200)}
style={tw`absolute inset-x-0 top-0 z-0 bg-app-line/60`}
/>
)}
</Button>
))}
</View>
<View style={tw`mt-3 flex w-full flex-col justify-center gap-1.5`}>
{activeTab === 'Login' ? <Login /> : <Register />}
{/* Disabled for now */}
{/* <View style={tw`flex items-center w-full gap-3 my-2`}>
<Divider />
<Text style={tw`text-xs text-ink-faint`}>OR</Text>
<Divider />
</View>
<View style={tw`flex justify-center gap-3`}>
{SocialLogins.map((social) => (
<Button
variant="outline"
onPress={async () => await socialLoginHandlers(social.name)}
key={social.name}
style={tw`p-3 border rounded-full border-app-line bg-app-input`}
>
<social.icon style={tw`text-white`} weight="bold" />
</Button>
))}
</View> */}
</View>
</Card>
</View>
</ScreenContainer>
);
};
export default AccountLogin;
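Since supertokens-react-native is initialized in App.tsx, the manual /api/user probe above could first be guarded by a local session check; a sketch using the SDK's documented API:

import SuperTokens from 'supertokens-react-native';

async function fetchUserIfLoggedIn() {
	// doesSessionExist() checks for a local session without a network round trip.
	if (!(await SuperTokens.doesSessionExist())) return;
	const res = await fetch(`${AUTH_SERVER_URL}/api/user`, { method: 'GET' });
	const data = await res.json();
	if (data.message !== 'unauthorised') getUserStore().userInfo = data;
}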

View File

@@ -0,0 +1,179 @@
import { useNavigation } from '@react-navigation/native';
import { Envelope } from 'phosphor-react-native';
import { useEffect, useState } from 'react';
import { Text, View } from 'react-native';
import {
SyncStatus,
useBridgeMutation,
useBridgeQuery,
useLibraryMutation,
useLibrarySubscription
} from '@sd/client';
import Card from '~/components/layout/Card';
import ScreenContainer from '~/components/layout/ScreenContainer';
import { Button } from '~/components/primitive/Button';
import { tw, twStyle } from '~/lib/tailwind';
import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
import { getUserStore, useUserStore } from '~/stores/userStore';
import { AUTH_SERVER_URL, getTokens } from '~/utils';
const AccountProfile = () => {
const userInfo = useUserStore().userInfo;
const emailName = userInfo ? userInfo.email.split('@')[0] : '';
const capitalizedEmailName = (emailName?.charAt(0).toUpperCase() ?? '') + emailName?.slice(1);
const navigator = useNavigation<SettingsStackScreenProps<'AccountLogin'>['navigation']>();
const cloudBootstrap = useBridgeMutation('cloud.bootstrap');
const devices = useBridgeQuery(['cloud.devices.list']);
const addLibraryToCloud = useLibraryMutation('cloud.libraries.create');
const listLibraries = useBridgeQuery(['cloud.libraries.list', true]);
const createSyncGroup = useLibraryMutation('cloud.syncGroups.create');
const listSyncGroups = useBridgeQuery(['cloud.syncGroups.list']);
const requestJoinSyncGroup = useBridgeMutation('cloud.syncGroups.request_join');
const currentDevice = useBridgeQuery(['cloud.devices.get_current_device']);
const [{ accessToken, refreshToken }, setTokens] = useState<{
accessToken: string;
refreshToken: string;
}>({
accessToken: '',
refreshToken: ''
});
useEffect(() => {
(async () => {
const { accessToken, refreshToken } = await getTokens();
setTokens({ accessToken, refreshToken });
})();
}, []);
const [syncStatus, setSyncStatus] = useState<SyncStatus | null>(null);
useLibrarySubscription(['sync.active'], {
onData: (data) => {
console.log('sync activity', data);
setSyncStatus(data);
}
});
async function signOut() {
await fetch(`${AUTH_SERVER_URL}/api/auth/signout`, {
method: 'POST'
});
navigator.navigate('AccountLogin');
getUserStore().userInfo = undefined;
}
return (
<ScreenContainer scrollview={false} style={tw`gap-2 px-6`}>
<View style={tw`flex flex-col justify-between gap-5 lg:flex-row`}>
<Card
style={tw`relative flex w-full flex-col items-center justify-center lg:max-w-[320px]`}
>
<View style={tw`w-full`}>
<Text style={tw`mx-auto mt-3 text-lg text-white`}>
Welcome{' '}
<Text style={tw`font-bold text-white`}>{capitalizedEmailName}</Text>
</Text>
<Card
style={tw`mt-4 flex-row items-center gap-2 overflow-hidden border-app-inputborder bg-app-input`}
>
<Envelope weight="fill" size={20} color="white" />
<Text numberOfLines={1} style={tw`max-w-[90%] text-white`}>
{userInfo ? userInfo.email : ''}
</Text>
</Card>
<Button variant="danger" style={tw`mt-3`} onPress={signOut}>
<Text style={tw`font-bold text-white`}>Sign out</Text>
</Button>
</View>
</Card>
{/* Sync activity */}
<View style={tw`mt-5 flex flex-col`}>
<Text style={tw`mb-2 text-md font-semibold`}>Sync Activity</Text>
<View style={tw`flex flex-row gap-2`}>
{Object.keys(syncStatus ?? {}).map((status, index) => (
<Card key={index} style={tw`flex w-full items-center p-4`}>
<View
style={twStyle(
'mr-2 size-[15px] rounded-full bg-app-box',
syncStatus?.[status as keyof SyncStatus]
? 'bg-accent'
: 'bg-app-input'
)}
/>
<Text style={tw`text-sm font-semibold`}>{status}</Text>
</Card>
))}
</View>
</View>
{/* Automatically list libraries */}
<View style={tw`mt-5 flex flex-col gap-3`}>
<Text style={tw`text-md font-semibold text-white`}>Cloud Libraries</Text>
{listLibraries.data?.map((library) => (
<Card key={library.pub_id} style={tw`w-full p-4`}>
<Text style={tw`text-sm font-semibold text-white`}>{library.name}</Text>
</Card>
)) || <Text style={tw`text-white`}>No libraries found.</Text>}
</View>
{/* Debug buttons */}
<Card style={tw`flex gap-2 text-white`}>
<Button
variant="gray"
onPress={async () => {
cloudBootstrap.mutate([accessToken.trim(), refreshToken.trim()]);
}}
>
<Text style={tw`text-white`}>Start Cloud Bootstrap</Text>
</Button>
<Button
variant="gray"
onPress={async () => {
addLibraryToCloud.mutate(null);
}}
>
<Text style={tw`text-white`}>Add Library to Cloud</Text>
</Button>
<Button
variant="gray"
onPress={async () => {
createSyncGroup.mutate(null);
}}
>
<Text style={tw`text-white`}>Create Sync Group</Text>
</Button>
</Card>
<View style={tw`mt-5 flex flex-col gap-3 text-white`}>
<Text style={tw`text-md font-semibold`}>Library Sync Groups</Text>
{listSyncGroups.data?.map((group) => (
<Card key={group.pub_id} style={tw`w-full p-4`}>
<Text style={tw`text-sm font-semibold text-white`}>
{group.library.name}
</Text>
<Button
style={tw`mt-2`}
onPress={async () => {
if (!currentDevice.data) await currentDevice.refetch();
if (currentDevice.data && devices.data) {
requestJoinSyncGroup.mutate({
asking_device: currentDevice.data,
sync_group: {
devices: devices.data,
...group
}
});
}
}}
>
<Text style={tw`text-white`}>Join Sync Group</Text>
</Button>
</Card>
)) || <Text style={tw`text-white`}>No sync groups found.</Text>}
</View>
</View>
</ScreenContainer>
);
};
export default AccountProfile;
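The sync-activity cards above treat SyncStatus as a map from activity name to an on/off flag; a sketch of the assumed shape (the authoritative type is exported by @sd/client):

// Inferred from `Object.keys(syncStatus ?? {})` and the truthiness check on
// each entry; the field names are illustrative, not from this PR.
type SyncStatus = Record<string, boolean>; // e.g. { ingest: true, send: false }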

View File

@@ -0,0 +1,195 @@
import AsyncStorage from '@react-native-async-storage/async-storage';
import { useNavigation } from '@react-navigation/native';
import { RSPCError } from '@spacedrive/rspc-client';
import { UseMutationResult } from '@tanstack/react-query';
import { useState } from 'react';
import { Controller } from 'react-hook-form';
import { Text, View } from 'react-native';
import { z } from 'zod';
import { useBridgeMutation, useZodForm } from '@sd/client';
import { Button } from '~/components/primitive/Button';
import { Input } from '~/components/primitive/Input';
import { toast } from '~/components/primitive/Toast';
import { tw, twStyle } from '~/lib/tailwind';
import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
import { getUserStore } from '~/stores/userStore';
import { AUTH_SERVER_URL } from '~/utils';
import ShowPassword from './ShowPassword';
const LoginSchema = z.object({
email: z.string().email({
message: 'Email is required'
}),
password: z.string().min(6, {
message: 'Password must be at least 6 characters'
})
});
const Login = () => {
const [showPassword, setShowPassword] = useState(false);
const form = useZodForm({
schema: LoginSchema,
defaultValues: {
email: '',
password: ''
}
});
const updateUserStore = getUserStore();
const navigator = useNavigation<SettingsStackScreenProps<'AccountProfile'>['navigation']>();
const cloudBootstrap = useBridgeMutation('cloud.bootstrap');
return (
<View>
<View style={tw`flex flex-col gap-1.5`}>
<Controller
control={form.control}
name="email"
render={({ field }) => (
<View style={tw`relative flex items-start`}>
<Input
{...field}
placeholder="Email"
style={twStyle(
`w-full`,
form.formState.errors.email && 'border-red-500'
)}
onChangeText={field.onChange}
/>
{form.formState.errors.email && (
<Text style={tw`my-1 text-xs text-red-500`}>
{form.formState.errors.email.message}
</Text>
)}
</View>
)}
/>
<Controller
control={form.control}
name="password"
render={({ field }) => (
<View style={tw`relative flex items-start`}>
<Input
{...field}
placeholder="Password"
style={twStyle(
`w-full`,
form.formState.errors.password && 'border-red-500'
)}
onChangeText={field.onChange}
secureTextEntry={!showPassword}
/>
{form.formState.errors.password && (
<Text style={tw`my-1 text-xs text-red-500`}>
{form.formState.errors.password.message}
</Text>
)}
<ShowPassword
showPassword={showPassword}
setShowPassword={setShowPassword}
/>
</View>
)}
/>
<Button
style={tw`mx-auto mt-2 w-full`}
variant="accent"
onPress={form.handleSubmit(async (data) => {
await signInClicked(
data.email,
data.password,
navigator,
cloudBootstrap,
updateUserStore
);
})}
disabled={form.formState.isSubmitting}
>
<Text style={tw`font-bold text-white`}>Submit</Text>
</Button>
</View>
</View>
);
};
async function signInClicked(
email: string,
password: string,
navigator: SettingsStackScreenProps<'AccountProfile'>['navigation'],
cloudBootstrap: UseMutationResult<null, RSPCError, [string, string], unknown>, // Cloud bootstrap mutation
updateUserStore: ReturnType<typeof getUserStore>
) {
try {
const req = await fetch(`${AUTH_SERVER_URL}/api/auth/signin`, {
method: 'POST',
headers: {
'Content-Type': 'application/json; charset=utf-8'
},
body: JSON.stringify({
formFields: [
{
id: 'email',
value: email
},
{
id: 'password',
value: password
}
]
})
});
const response: {
status: string;
reason?: string;
user?: {
id: string;
email: string;
timeJoined: number;
tenantIds: string[];
};
} = await req.json();
if (response.status === 'FIELD_ERROR') {
// response.reason?.forEach((formField) => {
// if (formField.id === 'email') {
// // Email validation failed (for example incorrect email syntax).
// toast.error(formField.error);
// }
// });
console.error('Field error: ', response.reason);
} else if (response.status === 'WRONG_CREDENTIALS_ERROR') {
toast.error('Email & password combination is incorrect.');
} else if (response.status === 'SIGN_IN_NOT_ALLOWED') {
// the reason string is a user friendly message
// about what went wrong. It can also contain a support code which users
// can tell you so you know why their sign in was not allowed.
toast.error(response.reason!);
} else {
// sign in successful. The session tokens are automatically handled by
// the frontend SDK.
cloudBootstrap.mutate([
req.headers.get('st-access-token')!,
req.headers.get('st-refresh-token')!
]);
toast.success('Sign in successful');
// Update the user store with the user info
updateUserStore.userInfo = response.user;
// Save the tokens to AsyncStorage, because the SuperTokens React Native SDK doesn't persist them correctly.
await AsyncStorage.setItem('access_token', req.headers.get('st-access-token')!);
await AsyncStorage.setItem('refresh_token', req.headers.get('st-refresh-token')!);
// Refresh the page to show the user is logged in
navigator.navigate('AccountProfile');
}
} catch (err: any) {
if (err.isSuperTokensGeneralError === true) {
// this may be a custom error message sent from the API by you.
toast.error(err.message);
} else {
console.error(err);
toast.error('Oops! Something went wrong.');
}
}
}
export default Login;
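signInClicked persists both tokens under the access_token and refresh_token keys, which is presumably what the getTokens helper used across the app reads back; a minimal sketch under that assumption:

import AsyncStorage from '@react-native-async-storage/async-storage';

export async function getTokens() {
	return {
		accessToken: (await AsyncStorage.getItem('access_token')) ?? '',
		refreshToken: (await AsyncStorage.getItem('refresh_token')) ?? ''
	};
}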

View File

@@ -0,0 +1,191 @@
import { zodResolver } from '@hookform/resolvers/zod';
import { useNavigation } from '@react-navigation/native';
import { useState } from 'react';
import { Controller, useForm } from 'react-hook-form';
import { Text, View } from 'react-native';
import { z } from 'zod';
import { Button } from '~/components/primitive/Button';
import { Input } from '~/components/primitive/Input';
import { toast } from '~/components/primitive/Toast';
import { tw, twStyle } from '~/lib/tailwind';
import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
import { AUTH_SERVER_URL } from '~/utils';
import ShowPassword from './ShowPassword';
const RegisterSchema = z
.object({
email: z.string().email({
message: 'Email is required'
}),
password: z.string().min(6, {
message: 'Password must be at least 6 characters'
}),
confirmPassword: z.string().min(6, {
message: 'Password must be at least 6 characters'
})
})
.refine((data) => data.password === data.confirmPassword, {
message: 'Passwords do not match',
path: ['confirmPassword']
});
type RegisterType = z.infer<typeof RegisterSchema>;
const Register = () => {
const [showPassword, setShowPassword] = useState(false);
// useZodForm seems to be outdated or needs fixing,
// as it does not support schemas that use zod.refine
const form = useForm<RegisterType>({
resolver: zodResolver(RegisterSchema),
defaultValues: {
email: '',
password: '',
confirmPassword: ''
}
});
const navigator = useNavigation<SettingsStackScreenProps<'AccountProfile'>['navigation']>();
return (
<View style={tw`flex flex-col gap-1.5`}>
<Controller
control={form.control}
name="email"
render={({ field }) => (
<Input
{...field}
style={twStyle(`w-full`, form.formState.errors.email && 'border-red-500')}
placeholder="Email"
onChangeText={field.onChange}
/>
)}
/>
{form.formState.errors.email && (
<Text style={tw`text-xs text-red-500`}>{form.formState.errors.email.message}</Text>
)}
<Controller
control={form.control}
name="password"
render={({ field }) => (
<View style={tw`relative flex items-center justify-center`}>
<Input
{...field}
placeholder="Password"
style={twStyle(
`w-full`,
form.formState.errors.password && 'border-red-500'
)}
onChangeText={field.onChange}
secureTextEntry={!showPassword}
/>
</View>
)}
/>
{form.formState.errors.password && (
<Text style={tw`text-xs text-red-500`}>
{form.formState.errors.password.message}
</Text>
)}
<Controller
control={form.control}
name="confirmPassword"
render={({ field }) => (
<View style={tw`relative flex items-start`}>
<Input
{...field}
placeholder="Confirm Password"
style={twStyle(
`w-full`,
form.formState.errors.confirmPassword && 'border-red-500'
)}
onChangeText={field.onChange}
secureTextEntry={!showPassword}
/>
{form.formState.errors.confirmPassword && (
<Text style={tw`my-1 text-xs text-red-500`}>
{form.formState.errors.confirmPassword.message}
</Text>
)}
<ShowPassword
showPassword={showPassword}
setShowPassword={setShowPassword}
plural={true}
/>
</View>
)}
/>
<Button
style={tw`mx-auto mt-2 w-full`}
variant="accent"
onPress={form.handleSubmit(
async (data) => await signUpClicked(data.email, data.password, navigator)
)}
disabled={form.formState.isSubmitting}
>
<Text style={tw`font-bold text-white`}>Submit</Text>
</Button>
</View>
);
};
async function signUpClicked(
email: string,
password: string,
navigator: SettingsStackScreenProps<'AccountProfile'>['navigation']
) {
try {
const req = await fetch(`${AUTH_SERVER_URL}/api/auth/signup`, {
method: 'POST',
headers: {
'Content-Type': 'application/json; charset=utf-8'
},
body: JSON.stringify({
formFields: [
{
id: 'email',
value: email
},
{
id: 'password',
value: password
}
]
})
});
const response: {
status: string;
reason?: string;
user?: {
id: string;
email: string;
timeJoined: number;
tenantIds: string[];
};
} = await req.json();
if (response.status === 'FIELD_ERROR') {
// one of the input formFields failed validation
console.error('Field error: ', response.reason);
} else if (response.status === 'SIGN_UP_NOT_ALLOWED') {
// the reason string is a user-friendly message
// about what went wrong. It can also contain a support code which users
// can tell you so you know why their sign-up was not allowed.
toast.error(response.reason!);
} else {
// sign up successful. The session tokens are automatically handled by
// the frontend SDK.
toast.success('Sign up successful');
navigator.navigate('AccountProfile');
}
} catch (err: any) {
if (err.isSuperTokensGeneralError === true) {
// this may be a custom error message sent from the API by you.
toast.error(err.message);
} else {
console.error(err);
toast.error('Oops! Something went wrong.');
}
}
}
export default Register;

View File

@@ -0,0 +1,29 @@
import { Eye, EyeClosed } from 'phosphor-react-native';
import { Text } from 'react-native';
import { Button } from '~/components/primitive/Button';
import { tw } from '~/lib/tailwind';
interface Props {
showPassword: boolean;
setShowPassword: (value: boolean) => void;
plural?: boolean;
}
const ShowPassword = ({ showPassword, setShowPassword, plural }: Props) => {
return (
<Button
variant="gray"
style={tw`mt-1.5 flex w-full flex-row items-center justify-center gap-2`}
onPressIn={() => setShowPassword(!showPassword)}
>
{!showPassword ? (
<EyeClosed size={14} color="white" />
) : (
<Eye size={14} color="white" />
)}
<Text style={tw`font-bold text-ink`}>Show Password{plural ? 's' : ''}</Text>
</Button>
);
};
export default ShowPassword;

View File

@@ -1,26 +1,52 @@
import { useQueryClient } from '@tanstack/react-query';
import React from 'react';
import { Text, View } from 'react-native';
import {
auth,
toggleFeatureFlag,
useBridgeMutation,
useBridgeQuery,
useDebugState,
useFeatureFlags
useFeatureFlags,
useLibraryMutation
} from '@sd/client';
import Card from '~/components/layout/Card';
import { Button } from '~/components/primitive/Button';
import { tw } from '~/lib/tailwind';
import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
import { getTokens } from '~/utils';
const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => {
const debugState = useDebugState();
const featureFlags = useFeatureFlags();
const origin = useBridgeQuery(['cloud.getApiOrigin']);
const setOrigin = useBridgeMutation(['cloud.setApiOrigin']);
const [tokens, setTokens] = React.useState({ accessToken: '', refreshToken: '' });
const accessToken = tokens.accessToken;
const refreshToken = tokens.refreshToken;
// const origin = useBridgeQuery(['cloud.getApiOrigin']);
// const setOrigin = useBridgeMutation(['cloud.setApiOrigin']);
const queryClient = useQueryClient();
React.useEffect(() => {
async function fetchTokens() {
	const { accessToken, refreshToken } = await getTokens();
	setTokens({ accessToken, refreshToken });
}
fetchTokens();
}, []);
const cloudBootstrap = useBridgeMutation(['cloud.bootstrap']);
const addLibraryToCloud = useLibraryMutation('cloud.libraries.create');
const requestJoinSyncGroup = useBridgeMutation('cloud.syncGroups.request_join');
const getGroup = useBridgeQuery([
'cloud.syncGroups.get',
{
pub_id: '01924497-a1be-76e3-b62f-9582ea15463a',
// pub_id: '01924a25-966b-7c00-a582-9eed3aadd2cd',
kind: 'WithDevices'
}
]);
// console.log(getGroup.data);
const currentDevice = useBridgeQuery(['cloud.devices.get_current_device']);
// console.log('Current Device: ', currentDevice.data);
const createSyncGroup = useLibraryMutation('cloud.syncGroups.create');
// const queryClient = useQueryClient();
return (
<View style={tw`flex-1 p-4`}>
@@ -31,7 +57,7 @@ const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => {
</Button>
<Text style={tw`text-ink`}>{JSON.stringify(featureFlags)}</Text>
<Text style={tw`text-ink`}>{JSON.stringify(debugState)}</Text>
<Button
{/* <Button
onPress={() => {
navigation.popToTop();
navigation.replace('Settings');
@@ -39,13 +65,13 @@ const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => {
}}
>
<Text style={tw`text-ink`}>Disable Debug Mode</Text>
</Button>
<Button
</Button> */}
{/* <Button
onPress={() => {
const url =
origin.data === 'https://app.spacedrive.com'
origin.data === 'https://api.spacedrive.com'
? 'http://localhost:3000'
: 'https://app.spacedrive.com';
: 'https://api.spacedrive.com';
setOrigin.mutateAsync(url).then(async () => {
await auth.logout();
await queryClient.invalidateQueries();
@@ -53,7 +79,7 @@ const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => {
}}
>
<Text style={tw`text-ink`}>Toggle API Route ({origin.data})</Text>
</Button>
</Button> */}
<Button
onPress={() => {
navigation.popToTop();
@@ -64,12 +90,53 @@ const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => {
>
<Text style={tw`text-ink`}>Go to Backfill Waiting Page</Text>
</Button>
<Button
{/* <Button
onPress={async () => {
await auth.logout();
}}
>
<Text style={tw`text-ink`}>Logout</Text>
</Button> */}
<Button
onPress={async () => {
const tokens = await getTokens();
cloudBootstrap.mutate([tokens.accessToken, tokens.refreshToken]);
}}
>
<Text style={tw`text-ink`}>Cloud Bootstrap</Text>
</Button>
<Button
onPress={async () => {
addLibraryToCloud.mutate(null);
}}
>
<Text style={tw`text-ink`}>Add Library to Cloud</Text>
</Button>
<Button
onPress={async () => {
createSyncGroup.mutate(null);
}}
>
<Text style={tw`text-ink`}>Create Sync Group</Text>
</Button>
<Button
onPress={async () => {
if (
currentDevice.data &&
getGroup.data &&
getGroup.data.kind === 'WithDevices'
) {
currentDevice.refetch();
console.log('Current Device: ', currentDevice.data);
console.log('Get Group: ', getGroup.data.data);
requestJoinSyncGroup.mutate({
sync_group: getGroup.data.data,
asking_device: currentDevice.data
});
}
}}
>
<Text style={tw`text-ink`}>Request Join Sync Group</Text>
</Button>
</Card>
</View>

View File

@@ -1,130 +0,0 @@
import { useMemo } from 'react';
import { ActivityIndicator, FlatList, Text, View } from 'react-native';
import { useLibraryContext, useLibraryMutation, useLibraryQuery } from '@sd/client';
import { Icon } from '~/components/icons/Icon';
import Card from '~/components/layout/Card';
import Empty from '~/components/layout/Empty';
import ScreenContainer from '~/components/layout/ScreenContainer';
import VirtualizedListWrapper from '~/components/layout/VirtualizedListWrapper';
import { Button } from '~/components/primitive/Button';
import { Divider } from '~/components/primitive/Divider';
import { styled, tw, twStyle } from '~/lib/tailwind';
import { useAuthStateSnapshot } from '~/stores/auth';
import Instance from './Instance';
import Library from './Library';
import Login from './Login';
import ThisInstance from './ThisInstance';
export const InfoBox = styled(View, 'rounded-md border gap-1 border-app bg-transparent p-2');
const CloudSettings = () => {
return (
<ScreenContainer scrollview={false} style={tw`gap-0 px-6 py-0`}>
<AuthSensitiveChild />
</ScreenContainer>
);
};
const AuthSensitiveChild = () => {
const authState = useAuthStateSnapshot();
if (authState.status === 'loggedIn') return <Authenticated />;
if (authState.status === 'notLoggedIn' || authState.status === 'loggingIn') return <Login />;
return null;
};
const Authenticated = () => {
const { library } = useLibraryContext();
const cloudLibrary = useLibraryQuery(['cloud.library.get'], { retry: false });
const createLibrary = useLibraryMutation(['cloud.library.create']);
const cloudInstances = useMemo(
() =>
cloudLibrary.data?.instances.filter(
(instance) => instance.uuid !== library.instance_id
),
[cloudLibrary.data, library.instance_id]
);
if (cloudLibrary.isLoading) {
return (
<View style={tw`flex-1 items-center justify-center`}>
<ActivityIndicator size="small" />
</View>
);
}
return (
<ScreenContainer
scrollview={Boolean(cloudLibrary.data)}
style={tw`gap-0`}
tabHeight={false}
>
{cloudLibrary.data ? (
<View style={tw`flex-col items-start gap-5`}>
<Library cloudLibrary={cloudLibrary.data} />
<ThisInstance cloudLibrary={cloudLibrary.data} />
<Card style={tw`w-full`}>
<View style={tw`flex-row items-center gap-2`}>
<View
style={tw`self-start rounded border border-app-lightborder bg-app-highlight px-1.5 py-[2px]`}
>
<Text style={tw`text-xs font-semibold text-ink`}>
{cloudInstances?.length}
</Text>
</View>
<Text style={tw`font-semibold text-ink`}>Instances</Text>
</View>
<Divider style={tw`mb-4 mt-2`} />
<VirtualizedListWrapper
scrollEnabled={false}
contentContainerStyle={tw`flex-1`}
horizontal
>
<FlatList
data={cloudInstances}
scrollEnabled={false}
ListEmptyComponent={
<Empty textStyle={tw`my-0`} description="No instances found" />
}
contentContainerStyle={twStyle(
cloudInstances?.length === 0 && 'flex-row'
)}
showsHorizontalScrollIndicator={false}
ItemSeparatorComponent={() => <View style={tw`h-2`} />}
renderItem={({ item }) => <Instance data={item} />}
keyExtractor={(item) => item.id}
numColumns={1}
/>
</VirtualizedListWrapper>
</Card>
</View>
) : (
<View style={tw`flex-1 justify-center`}>
<Card style={tw`relative p-6`}>
<Icon style={tw`mx-auto mb-2`} name="CloudSync" size={64} />
<Text style={tw`mx-auto text-center text-sm text-ink`}>
Uploading your library to the cloud will allow you to access your
library from other devices by signing in to your account and importing it.
</Text>
<Button
variant={'accent'}
style={tw`mx-auto mt-4 max-w-[82%]`}
disabled={createLibrary.isPending}
onPress={async () => await createLibrary.mutateAsync(null)}
>
{createLibrary.isPending ? (
<Text style={tw`text-ink`}>Connecting library...</Text>
) : (
<Text style={tw`font-medium text-ink`}>Connect library</Text>
)}
</Button>
</Card>
</View>
)}
</ScreenContainer>
);
};
export default CloudSettings;

View File

@@ -1,64 +0,0 @@
import { Text, View } from 'react-native';
import { CloudInstance, HardwareModel } from '@sd/client';
import { Icon } from '~/components/icons/Icon';
import { hardwareModelToIcon } from '~/components/overview/Devices';
import { tw } from '~/lib/tailwind';
import { InfoBox } from './CloudSettings';
interface Props {
data: CloudInstance;
}
const Instance = ({ data }: Props) => {
return (
<InfoBox style={tw`w-full gap-2`}>
<View>
<View style={tw`mx-auto my-2`}>
<Icon
name={
hardwareModelToIcon(data.metadata.device_model as HardwareModel) as any
}
size={60}
/>
</View>
<Text
numberOfLines={1}
style={tw`mb-3 px-1 text-center text-sm font-medium text-ink`}
>
{data.metadata.name}
</Text>
<InfoBox>
<View style={tw`flex-row items-center gap-1`}>
<Text style={tw`text-sm font-medium text-ink`}>Id:</Text>
<Text numberOfLines={1} style={tw`max-w-[250px] text-ink-dull`}>
{data.id}
</Text>
</View>
</InfoBox>
</View>
<View>
<InfoBox>
<View style={tw`flex-row items-center gap-1`}>
<Text style={tw`text-sm font-medium text-ink`}>UUID:</Text>
<Text numberOfLines={1} style={tw`max-w-[85%] text-ink-dull`}>
{data.uuid}
</Text>
</View>
</InfoBox>
</View>
<View>
<InfoBox>
<View style={tw`flex-row items-center gap-1`}>
<Text style={tw`text-sm font-medium text-ink`}>Public key:</Text>
<Text numberOfLines={1} style={tw`max-w-3/4 text-ink-dull`}>
{data.identity}
</Text>
</View>
</InfoBox>
</View>
</InfoBox>
);
};
export default Instance;

View File

@@ -1,66 +0,0 @@
import { CheckCircle, XCircle } from 'phosphor-react-native';
import { useMemo } from 'react';
import { Text, View } from 'react-native';
import { CloudLibrary, useLibraryContext, useLibraryMutation } from '@sd/client';
import Card from '~/components/layout/Card';
import { Button } from '~/components/primitive/Button';
import { Divider } from '~/components/primitive/Divider';
import { SettingsTitle } from '~/components/settings/SettingsContainer';
import { tw } from '~/lib/tailwind';
import { logout, useAuthStateSnapshot } from '~/stores/auth';
import { InfoBox } from './CloudSettings';
interface LibraryProps {
cloudLibrary?: CloudLibrary;
}
const Library = ({ cloudLibrary }: LibraryProps) => {
const authState = useAuthStateSnapshot();
const { library } = useLibraryContext();
const syncLibrary = useLibraryMutation(['cloud.library.sync']);
const thisInstance = useMemo(
() => cloudLibrary?.instances.find((instance) => instance.uuid === library.instance_id),
[cloudLibrary, library.instance_id]
);
return (
<Card style={tw`w-full`}>
<View style={tw`flex-row items-center justify-between`}>
<Text style={tw`font-medium text-ink`}>Library</Text>
{authState.status === 'loggedIn' && (
<Button variant="gray" size="sm" onPress={logout}>
<Text style={tw`text-xs font-semibold text-ink`}>Logout</Text>
</Button>
)}
</View>
<Divider style={tw`mb-4 mt-2`} />
<SettingsTitle style={tw`mb-2`}>Name</SettingsTitle>
<InfoBox>
<Text style={tw`text-ink`}>{cloudLibrary?.name}</Text>
</InfoBox>
<Button
disabled={syncLibrary.isPending || thisInstance !== undefined}
variant="gray"
onPress={() => syncLibrary.mutate(null)}
style={tw`mt-2 flex-row gap-1 py-2`}
>
{thisInstance ? (
<CheckCircle size={16} weight="fill" color={tw.color('green-400')} />
) : (
<XCircle
style={tw`rounded-full`}
size={16}
weight="fill"
color={tw.color('red-500')}
/>
)}
<Text style={tw`text-sm font-semibold text-ink`}>
{thisInstance !== undefined ? 'Library synced' : 'Library not synced'}
</Text>
</Button>
</Card>
);
};
export default Library;

View File

@@ -1,45 +0,0 @@
import { Text, View } from 'react-native';
import { Icon } from '~/components/icons/Icon';
import Card from '~/components/layout/Card';
import { Button } from '~/components/primitive/Button';
import { tw } from '~/lib/tailwind';
import { cancel, login, useAuthStateSnapshot } from '~/stores/auth';
const Login = () => {
const authState = useAuthStateSnapshot();
const buttonText = {
notLoggedIn: 'Login',
loggingIn: 'Cancel'
};
return (
<View style={tw`flex-1 flex-col items-center justify-center gap-2`}>
<Card style={tw`w-full items-center justify-center gap-2 p-6`}>
<View style={tw`flex-col items-center gap-2`}>
<Icon name="CloudSync" size={64} />
<Text style={tw`text-center text-sm text-ink`}>
Cloud Sync will upload your library to the cloud so you can access your
library from other devices by importing it from the cloud.
</Text>
</View>
{(authState.status === 'notLoggedIn' || authState.status === 'loggingIn') && (
<Button
variant="accent"
style={tw`mx-auto mt-4 max-w-[50%]`}
onPress={async (e) => {
e.preventDefault();
if (authState.status === 'loggingIn') {
await cancel();
} else {
await login();
}
}}
>
<Text style={tw`font-medium text-ink`}>{buttonText[authState.status]}</Text>
</Button>
)}
</Card>
</View>
);
};
export default Login;

View File

@@ -1,76 +0,0 @@
import { useMemo } from 'react';
import { Text, View } from 'react-native';
import { CloudLibrary, HardwareModel, useLibraryContext } from '@sd/client';
import { Icon } from '~/components/icons/Icon';
import Card from '~/components/layout/Card';
import { hardwareModelToIcon } from '~/components/overview/Devices';
import { Divider } from '~/components/primitive/Divider';
import { tw } from '~/lib/tailwind';
import { InfoBox } from './CloudSettings';
interface ThisInstanceProps {
cloudLibrary?: CloudLibrary;
}
const ThisInstance = ({ cloudLibrary }: ThisInstanceProps) => {
const { library } = useLibraryContext();
const thisInstance = useMemo(
() => cloudLibrary?.instances.find((instance) => instance.uuid === library.instance_id),
[cloudLibrary, library.instance_id]
);
if (!thisInstance) return null;
return (
<Card style={tw`w-full gap-2`}>
<View>
<Text style={tw`mb-1 font-semibold text-ink`}>This Instance</Text>
<Divider />
</View>
<View style={tw`mx-auto my-2 items-center`}>
<Icon
name={
hardwareModelToIcon(
thisInstance.metadata.device_model as HardwareModel
) as any
}
size={60}
/>
<Text numberOfLines={1} style={tw`px-1 font-semibold text-ink`}>
{thisInstance.metadata.name}
</Text>
</View>
<View>
<InfoBox>
<View style={tw`flex-row items-center gap-1`}>
<Text style={tw`text-sm font-medium text-ink`}>Id:</Text>
<Text style={tw`max-w-[250px] text-ink-dull`}>{thisInstance.id}</Text>
</View>
</InfoBox>
</View>
<View>
<InfoBox>
<View style={tw`flex-row items-center gap-1`}>
<Text style={tw`text-sm font-medium text-ink`}>UUID:</Text>
<Text numberOfLines={1} style={tw`max-w-[85%] text-ink-dull`}>
{thisInstance.uuid}
</Text>
</View>
</InfoBox>
</View>
<View>
<InfoBox>
<View style={tw`flex-row items-center gap-1`}>
<Text style={tw`text-sm font-medium text-ink`}>Public Key:</Text>
<Text numberOfLines={1} style={tw`max-w-3/4 text-ink-dull`}>
{thisInstance.identity}
</Text>
</View>
</InfoBox>
</View>
</Card>
);
};
export default ThisInstance;

View File

@@ -1,158 +0,0 @@
import { useIsFocused } from '@react-navigation/native';
import { inferSubscriptionResult } from '@spacedrive/rspc-client';
import { MotiView } from 'moti';
import { Circle } from 'phosphor-react-native';
import React, { useEffect, useRef, useState } from 'react';
import { Text, View } from 'react-native';
import {
Procedures,
useLibraryMutation,
useLibraryQuery,
useLibrarySubscription
} from '@sd/client';
import { Icon } from '~/components/icons/Icon';
import Card from '~/components/layout/Card';
import { ModalRef } from '~/components/layout/Modal';
import ScreenContainer from '~/components/layout/ScreenContainer';
import CloudModal from '~/components/modal/cloud/CloudModal';
import { Button } from '~/components/primitive/Button';
import { tw } from '~/lib/tailwind';
import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
const SyncSettingsScreen = ({ navigation }: SettingsStackScreenProps<'SyncSettings'>) => {
const syncEnabled = useLibraryQuery(['sync.enabled']);
const [data, setData] = useState<inferSubscriptionResult<Procedures, 'library.actors'>>({});
const modalRef = useRef<ModalRef>(null);
const [startBackfill, setStart] = useState(false);
const pageFocused = useIsFocused();
const [showCloudModal, setShowCloudModal] = useState(false);
useLibrarySubscription(['library.actors'], { onData: setData });
useEffect(() => {
if (startBackfill === true) {
navigation.navigate('BackfillWaitingStack', {
screen: 'BackfillWaiting'
});
setTimeout(() => setShowCloudModal(true), 1000);
}
}, [startBackfill, navigation]);
useEffect(() => {
if (pageFocused && showCloudModal) modalRef.current?.present();
return () => {
if (showCloudModal) setShowCloudModal(false);
};
}, [pageFocused, showCloudModal]);
return (
<ScreenContainer scrollview={false} style={tw`gap-0 px-6`}>
{syncEnabled.data === false ? (
<View style={tw`flex-1 justify-center`}>
<Card style={tw`relative flex-col items-center gap-5 p-6`}>
<View style={tw`flex-col items-center gap-2`}>
<Icon name="Sync" size={72} style={tw`mb-2`} />
<Text style={tw`text-center leading-5 text-ink`}>
With Sync, you can share your library with other devices using P2P
technology.
</Text>
<Text style={tw`text-center leading-5 text-ink`}>
Additionally, it allows you to enable Cloud services to upload your
library to the cloud, making it accessible on any of your devices.
</Text>
</View>
<Button
variant={'accent'}
style={tw`mx-auto max-w-[82%]`}
onPress={() => setStart(true)}
>
<Text style={tw`font-medium text-white`}>Start</Text>
</Button>
</Card>
</View>
) : (
<View style={tw`flex-row flex-wrap gap-2`}>
{Object.keys(data).map((key) => {
return (
<Card style={tw`w-[48%]`} key={key}>
<OnlineIndicator online={data[key] ?? false} />
<Text
key={key}
style={tw`mb-3 mt-1 flex-col items-center justify-center text-left text-xs text-white`}
>
{key}
</Text>
{data[key] ? <StopButton name={key} /> : <StartButton name={key} />}
</Card>
);
})}
</View>
)}
<CloudModal ref={modalRef} />
</ScreenContainer>
);
};
export default SyncSettingsScreen;
function OnlineIndicator({ online }: { online: boolean }) {
const size = 6;
return (
<View
style={tw`mb-1 h-6 w-6 items-center justify-center rounded-full border border-app-inputborder bg-app-input p-2`}
>
{online ? (
<View style={tw`relative items-center justify-center`}>
<MotiView
from={{ scale: 0, opacity: 1 }}
animate={{ scale: 3, opacity: 0 }}
transition={{
type: 'timing',
duration: 1500,
loop: true,
repeatReverse: false,
delay: 1000
}}
style={tw`absolute z-10 h-2 w-2 items-center justify-center rounded-full bg-green-500`}
/>
<View style={tw`h-2 w-2 rounded-full bg-green-500`} />
</View>
) : (
<Circle size={size} color={tw.color('red-400')} weight="fill" />
)}
</View>
);
}
function StartButton({ name }: { name: string }) {
const startActor = useLibraryMutation(['library.startActor']);
return (
<Button
variant="accent"
size="sm"
disabled={startActor.isPending}
onPress={() => startActor.mutate(name)}
>
<Text style={tw`text-xs font-medium text-ink`}>
{startActor.isPending ? 'Starting' : 'Start'}
</Text>
</Button>
);
}
function StopButton({ name }: { name: string }) {
const stopActor = useLibraryMutation(['library.stopActor']);
return (
<Button
variant="accent"
size="sm"
disabled={stopActor.isPending}
onPress={() => stopActor.mutate(name)}
>
<Text style={tw`text-xs font-medium text-ink`}>
{stopActor.isPending ? 'Stopping' : 'Stop'}
</Text>
</Button>
);
}

View File

@@ -18,16 +18,16 @@ export function useAuthStateSnapshot() {
return useSolidStore(store).state;
}
nonLibraryClient
.query(['auth.me'])
.then(() => (store.state = { status: 'loggedIn' }))
.catch((e) => {
if (e instanceof RSPCError && e.code === 401) {
// TODO: handle error?
console.error('error', e);
}
store.state = { status: 'notLoggedIn' };
});
// nonLibraryClient
// .query(['auth.me'])
// .then(() => (store.state = { status: 'loggedIn' }))
// .catch((e) => {
// if (e instanceof RSPCError && e.code === 401) {
// // TODO: handle error?
// console.error('error', e);
// }
// store.state = { status: 'notLoggedIn' };
// });
type CallbackStatus = 'success' | { error: string } | 'cancel';
const loginCallbacks = new Set<(status: CallbackStatus) => void>();
@@ -41,29 +41,29 @@ export function login() {
store.state = { status: 'loggingIn' };
let authCleanup = nonLibraryClient.addSubscription(['auth.loginSession'], {
onData(data) {
if (data === 'Complete') {
loginCallbacks.forEach((cb) => cb('success'));
} else if ('Error' in data) {
console.error('[auth] error: ', data.Error);
onError(data.Error);
} else {
console.log('[auth] verification url: ', data.Start.verification_url_complete);
Promise.resolve()
.then(() => Linking.openURL(data.Start.verification_url_complete))
.then(
(res) => {
authCleanup = res;
},
(e) => onError(e.message)
);
}
},
onError(e) {
onError(e.message);
}
});
// let authCleanup = nonLibraryClient.addSubscription(['auth.loginSession'], {
// onData(data) {
// if (data === 'Complete') {
// loginCallbacks.forEach((cb) => cb('success'));
// } else if ('Error' in data) {
// console.error('[auth] error: ', data.Error);
// onError(data.Error);
// } else {
// console.log('[auth] verification url: ', data.Start.verification_url_complete);
// Promise.resolve()
// .then(() => Linking.openURL(data.Start.verification_url_complete))
// .then(
// (res) => {
// authCleanup = res;
// },
// (e) => onError(e.message)
// );
// }
// },
// onError(e) {
// onError(e.message);
// }
// });
return new Promise<void>((res, rej) => {
const cb = async (status: CallbackStatus) => {
@@ -71,7 +71,7 @@ export function login() {
if (status === 'success') {
store.state = { status: 'loggedIn' };
nonLibraryClient.query(['auth.me']);
// nonLibraryClient.query(['auth.me']);
res();
} else {
store.state = { status: 'notLoggedIn' };
@@ -88,8 +88,8 @@ export function set_logged_in() {
export function logout() {
store.state = { status: 'loggingOut' };
nonLibraryClient.mutation(['auth.logout']);
nonLibraryClient.query(['auth.me']);
// nonLibraryClient.mutation(['auth.logout']);
// nonLibraryClient.query(['auth.me']);
store.state = { status: 'notLoggedIn' };
}

View File

@@ -0,0 +1,22 @@
import { proxy, useSnapshot } from 'valtio';
export type User = {
id: string;
email: string;
timeJoined: number;
tenantIds: string[];
};
const state = {
userInfo: undefined as User | undefined
};
const store = proxy({
...state
});
// for reading
export const useUserStore = () => useSnapshot(store);
// for writing
export const getUserStore = () => store;
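A short usage sketch for this store (the component and import path are hypothetical):

import { Text } from 'react-native';
import { getUserStore, useUserStore } from '~/stores/userStore';

function AccountEmail() {
	// Components read through the snapshot hook so they re-render on change.
	const { userInfo } = useUserStore();
	return <Text>{userInfo?.email ?? 'Not signed in'}</Text>;
}

// Writes go through the raw proxy, e.g. after a successful sign-in:
// getUserStore().userInfo = response.user;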

View File

@@ -0,0 +1,13 @@
import AsyncStorage from '@react-native-async-storage/async-storage';
export async function getTokens() {
const fetchedToken = await AsyncStorage.getItem('access_token');
const fetchedRefreshToken = await AsyncStorage.getItem('refresh_token');
return {
accessToken: fetchedToken ?? '',
refreshToken: fetchedRefreshToken ?? ''
};
}
// export const AUTH_SERVER_URL = __DEV__ ? 'http://localhost:9420' : 'https://auth.spacedrive.com';
export const AUTH_SERVER_URL = 'https://auth.spacedrive.com';
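A hedged sketch of how this utility feeds the cloud.bootstrap bridge mutation used on the Debug screen above (the hook is hypothetical; the mutation shape is taken from that screen):

import { useBridgeMutation } from '@sd/client';
import { getTokens } from '~/utils';

function useCloudBootstrap() {
	const cloudBootstrap = useBridgeMutation(['cloud.bootstrap']);
	return async () => {
		const { accessToken, refreshToken } = await getTokens();
		// Without stored tokens the user hasn't signed in yet; do nothing.
		if (!accessToken || !refreshToken) return;
		cloudBootstrap.mutate([accessToken, refreshToken]);
	};
}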

View File

@@ -144,19 +144,7 @@ async fn main() {
let state = AppState { auth };
let (node, router) = match Node::new(
data_dir,
sd_core::Env {
api_url: tokio::sync::Mutex::new(
std::env::var("SD_API_URL")
.unwrap_or_else(|_| "https://app.spacedrive.com".to_string()),
),
client_id: std::env::var("SD_CLIENT_ID")
.unwrap_or_else(|_| "04701823-a498-406e-aef9-22081c1dae34".to_string()),
},
)
.await
{
let (node, router) = match Node::new(data_dir).await {
Ok(d) => d,
Err(e) => {
panic!("{}", e.to_string())

View File

@@ -30,6 +30,6 @@
"storybook": "^8.0.1",
"tailwindcss": "^3.4.10",
"typescript": "^5.6.2",
"vite": "^5.2.0"
"vite": "^5.4.9"
}
}

View File

@@ -41,7 +41,7 @@
"rollup-plugin-visualizer": "^5.12.0",
"start-server-and-test": "^2.0.3",
"typescript": "^5.6.2",
"vite": "^5.2.0",
"vite-tsconfig-paths": "^4.3.2"
"vite": "^5.4.9",
"vite-tsconfig-paths": "^5.0.1"
}
}

View File

@@ -20,6 +20,7 @@ heif = ["sd-images/heif"]
[dependencies]
# Inner Core Sub-crates
sd-core-cloud-services = { path = "./crates/cloud-services" }
sd-core-file-path-helper = { path = "./crates/file-path-helper" }
sd-core-heavy-lifting = { path = "./crates/heavy-lifting" }
sd-core-indexer-rules = { path = "./crates/indexer-rules" }
@@ -29,7 +30,8 @@ sd-core-sync = { path = "./crates/sync" }
# Spacedrive Sub-crates
sd-actors = { path = "../crates/actors" }
sd-ai = { path = "../crates/ai", optional = true }
sd-cloud-api = { path = "../crates/cloud-api" }
sd-crypto = { path = "../crates/crypto" }
sd-ffmpeg = { path = "../crates/ffmpeg", optional = true }
sd-file-ext = { path = "../crates/file-ext" }
sd-images = { path = "../crates/images", features = ["rspc", "serde", "specta"] }
sd-media-metadata = { path = "../crates/media-metadata" }
@@ -49,6 +51,7 @@ async-trait = { workspace = true }
axum = { workspace = true, features = ["ws"] }
base64 = { workspace = true }
blake3 = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
futures = { workspace = true }
futures-concurrency = { workspace = true }
@@ -64,6 +67,7 @@ reqwest = { workspace = true, features = ["json", "native-tls-vendor
rmp-serde = { workspace = true }
rmpv = { workspace = true }
rspc = { workspace = true, features = ["alpha", "axum", "chrono", "unstable", "uuid"] }
sd-cloud-schema = { workspace = true }
serde = { workspace = true, features = ["derive", "rc"] }
serde_json = { workspace = true }
specta = { workspace = true }
@@ -75,12 +79,11 @@ tokio-stream = { workspace = true, features = ["fs"] }
tokio-util = { workspace = true, features = ["io"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
uuid = { workspace = true, features = ["serde", "v4"] }
uuid = { workspace = true, features = ["serde", "v4", "v7"] }
# Specific Core dependencies
async-recursion = "1.1"
base91 = "0.1.0"
bytes = "1.6"
ctor = "0.2.8"
directories = "5.0"
flate2 = "1.0"
@@ -98,6 +101,7 @@ sysinfo = "0.29.11" # Update blocked
tar = "0.4.41"
tower-service = "0.3.2"
tracing-appender = "0.2.3"
whoami = "1.5.2"
[dependencies.tokio]
features = ["io-util", "macros", "process", "rt-multi-thread", "sync", "time"]

View File

@@ -0,0 +1,55 @@
[package]
name = "sd-core-cloud-services"
version = "0.1.0"
edition = "2021"
[dependencies]
# Core Spacedrive Sub-crates
sd-core-sync = { path = "../sync" }
# Spacedrive Sub-crates
sd-actors = { path = "../../../crates/actors" }
sd-cloud-schema = { workspace = true }
sd-crypto = { path = "../../../crates/crypto" }
sd-prisma = { path = "../../../crates/prisma" }
sd-utils = { path = "../../../crates/utils" }
# Workspace dependencies
async-stream = { workspace = true }
base64 = { workspace = true }
blake3 = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
flume = { workspace = true }
futures = { workspace = true }
futures-concurrency = { workspace = true }
rmp-serde = { workspace = true }
rspc = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
specta = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["sync", "time"] }
tokio-stream = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true, features = ["serde"] }
zeroize = { workspace = true }
# External dependencies
anyhow = "1.0.86"
dashmap = "6.1.0"
iroh-net = { version = "0.27", features = ["discovery-local-network", "iroh-relay"] }
paste = "=1.0.15"
quic-rpc = { version = "0.12.1", features = ["quinn-transport"] }
quinn = { package = "iroh-quinn", version = "0.11" }
# Use whatever reqwest version reqwest-middleware uses; declared here only to enable extra features
reqwest = { version = "0.12", features = ["json", "native-tls-vendored", "stream"] }
reqwest-middleware = { version = "0.3", features = ["json"] }
reqwest-retry = "0.6"
rustls = { version = "=0.23.15", default-features = false, features = ["brotli", "ring", "std"] }
rustls-platform-verifier = "0.3.3"
[dev-dependencies]
tokio = { workspace = true, features = ["rt", "sync", "time"] }

View File

@@ -0,0 +1,358 @@
use crate::p2p::{NotifyUser, UserResponse};
use sd_cloud_schema::{Client, Service, ServicesALPN};
use std::{net::SocketAddr, sync::Arc, time::Duration};
use futures::Stream;
use iroh_net::relay::RelayUrl;
use quic_rpc::{transport::quinn::QuinnConnection, RpcClient};
use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint};
use reqwest::{IntoUrl, Url};
use reqwest_middleware::{reqwest, ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::sync::{Mutex, RwLock};
use tracing::warn;
use super::{
error::Error, key_manager::KeyManager, p2p::CloudP2P, token_refresher::TokenRefresher,
};
#[derive(Debug, Default, Clone)]
enum ClientState {
#[default]
NotConnected,
Connected(Client<QuinnConnection<Service>, Service>),
}
/// Cloud services are an optional feature that allows you to interact with the cloud services
/// of Spacedrive.
/// They're optional in two different ways:
/// - The cloud services depend on a user being logged in with our server.
/// - The user must be connected to the internet in the first place.
///
/// As we don't want to force the user to be connected to the internet, we have to make sure
/// that the core can always operate without the cloud services.
#[derive(Debug)]
pub struct CloudServices {
client_state: Arc<RwLock<ClientState>>,
get_cloud_api_address: Url,
http_client: ClientWithMiddleware,
domain_name: String,
pub cloud_p2p_dns_origin_name: String,
pub cloud_p2p_relay_url: RelayUrl,
pub cloud_p2p_dns_pkarr_url: Url,
pub token_refresher: TokenRefresher,
key_manager: Arc<RwLock<Option<Arc<KeyManager>>>>,
cloud_p2p: Arc<RwLock<Option<Arc<CloudP2P>>>>,
pub(crate) notify_user_tx: flume::Sender<NotifyUser>,
notify_user_rx: flume::Receiver<NotifyUser>,
user_response_tx: flume::Sender<UserResponse>,
pub(crate) user_response_rx: flume::Receiver<UserResponse>,
pub has_bootstrapped: Arc<Mutex<bool>>,
}
impl CloudServices {
/// Creates a new cloud services client that can be used to interact with the cloud services.
/// The client will try to connect to the cloud services on a best effort basis, as the user
/// might not be connected to the internet.
/// If the client fails to connect, it will try again the next time it's used.
pub async fn new(
get_cloud_api_address: impl IntoUrl + Send,
cloud_p2p_relay_url: impl IntoUrl + Send,
cloud_p2p_dns_pkarr_url: impl IntoUrl + Send,
cloud_p2p_dns_origin_name: String,
domain_name: String,
) -> Result<Self, Error> {
// `mut` is only exercised in release builds, where HTTPS-only mode is enforced below
#[cfg_attr(debug_assertions, allow(unused_mut))]
let mut http_client_builder = reqwest::Client::builder().timeout(Duration::from_secs(3));
#[cfg(not(debug_assertions))]
{
	http_client_builder = http_client_builder.https_only(true);
}
let cloud_p2p_relay_url = cloud_p2p_relay_url
.into_url()
.map_err(Error::InvalidUrl)?
.into();
let cloud_p2p_dns_pkarr_url = cloud_p2p_dns_pkarr_url
.into_url()
.map_err(Error::InvalidUrl)?;
let http_client =
ClientBuilder::new(http_client_builder.build().map_err(Error::HttpClientInit)?)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder().build_with_max_retries(3),
))
.build();
let get_cloud_api_address = get_cloud_api_address
.into_url()
.map_err(Error::InvalidUrl)?;
let client_state = match Self::init_client(
&http_client,
get_cloud_api_address.clone(),
domain_name.clone(),
)
.await
{
Ok(client) => Arc::new(RwLock::new(ClientState::Connected(client))),
Err(e) => {
warn!(
?e,
"Failed to initialize cloud services client; \
This is a best effort and we will continue in Not Connected mode"
);
Arc::new(RwLock::new(ClientState::NotConnected))
}
};
let (notify_user_tx, notify_user_rx) = flume::bounded(16);
let (user_response_tx, user_response_rx) = flume::bounded(16);
Ok(Self {
client_state,
token_refresher: TokenRefresher::new(
http_client.clone(),
get_cloud_api_address.clone(),
),
get_cloud_api_address,
http_client,
cloud_p2p_dns_origin_name,
cloud_p2p_relay_url,
cloud_p2p_dns_pkarr_url,
domain_name,
key_manager: Arc::default(),
cloud_p2p: Arc::default(),
notify_user_tx,
notify_user_rx,
user_response_tx,
user_response_rx,
has_bootstrapped: Arc::default(),
})
}
pub fn stream_user_notifications(&self) -> impl Stream<Item = NotifyUser> + '_ {
self.notify_user_rx.stream()
}
#[must_use]
pub const fn http_client(&self) -> &ClientWithMiddleware {
&self.http_client
}
/// Send back a user response to the Cloud P2P actor
///
/// # Panics
/// Will panic if the channel is closed, which should never happen
pub async fn send_user_response(&self, response: UserResponse) {
self.user_response_tx
.send_async(response)
.await
.expect("user response channel must never close");
}
async fn init_client(
http_client: &ClientWithMiddleware,
get_cloud_api_address: Url,
domain_name: String,
) -> Result<Client<QuinnConnection<Service>, Service>, Error> {
let cloud_api_address = http_client
.get(get_cloud_api_address)
.send()
.await
.map_err(Error::FailedToRequestApiAddress)?
.error_for_status()
.map_err(Error::AuthServerError)?
.text()
.await
.map_err(Error::FailedToExtractApiAddress)?
.parse::<SocketAddr>()?;
let mut crypto_config = {
#[cfg(debug_assertions)]
{
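// Debug-only escape hatch: accept any TLS certificate so local
// development servers with self-signed certs can be reached.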
#[derive(Debug)]
struct SkipServerVerification;
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
{
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
{
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA1,
rustls::SignatureScheme::ECDSA_SHA1_Legacy,
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::ED448,
]
}
}
rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
.dangerous()
.with_custom_certificate_verifier(Arc::new(SkipServerVerification))
.with_no_client_auth()
}
#[cfg(not(debug_assertions))]
{
rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
.dangerous()
.with_custom_certificate_verifier(Arc::new(
rustls_platform_verifier::Verifier::new(),
))
.with_no_client_auth()
}
};
crypto_config
.alpn_protocols
.extend([ServicesALPN::LATEST.to_vec()]);
let client_config = ClientConfig::new(Arc::new(
QuicClientConfig::try_from(crypto_config)
.expect("misconfigured TLS client config, this is a bug and should crash"),
));
let mut endpoint = Endpoint::client("[::]:0".parse().expect("hardcoded address"))
.map_err(Error::FailedToCreateEndpoint)?;
endpoint.set_default_client_config(client_config);
// TODO(@fogodev): It's possible that we can't keep the connection alive all the time,
// and need to use single shot connections. I will only be sure when we have
// actually battle-tested the cloud services in core.
Ok(Client::new(RpcClient::new(QuinnConnection::new(
endpoint,
cloud_api_address,
domain_name,
))))
}
/// Returns a client to the cloud services.
///
/// If the client is not connected, it will try to connect to the cloud services.
/// Available routes documented in
/// [`sd_cloud_schema::Service`](https://github.com/spacedriveapp/cloud-services-schema).
pub async fn client(&self) -> Result<Client<QuinnConnection<Service>, Service>, Error> {
if let ClientState::Connected(client) = { self.client_state.read().await.clone() } {
return Ok(client);
}
// If we're not connected, we need to try to connect.
let client = Self::init_client(
&self.http_client,
self.get_cloud_api_address.clone(),
self.domain_name.clone(),
)
.await?;
*self.client_state.write().await = ClientState::Connected(client.clone());
Ok(client)
}
pub async fn set_key_manager(&self, key_manager: KeyManager) {
self.key_manager
.write()
.await
.replace(Arc::new(key_manager));
}
pub async fn key_manager(&self) -> Result<Arc<KeyManager>, Error> {
self.key_manager
.read()
.await
.as_ref()
.map_or(Err(Error::KeyManagerNotInitialized), |key_manager| {
Ok(Arc::clone(key_manager))
})
}
pub async fn set_cloud_p2p(&self, cloud_p2p: CloudP2P) {
self.cloud_p2p.write().await.replace(Arc::new(cloud_p2p));
}
pub async fn cloud_p2p(&self) -> Result<Arc<CloudP2P>, Error> {
self.cloud_p2p
.read()
.await
.as_ref()
.map_or(Err(Error::CloudP2PNotInitialized), |cloud_p2p| {
Ok(Arc::clone(cloud_p2p))
})
}
}
#[cfg(test)]
mod tests {
use sd_cloud_schema::{auth, devices};
use super::*;
#[ignore]
#[tokio::test]
async fn test_client() {
let response = CloudServices::new(
"http://localhost:9420/cloud-api-address",
"http://relay.localhost:9999/",
"http://pkarr.localhost:9999/",
"dns.localhost:9999".to_string(),
"localhost".to_string(),
)
.await
.unwrap()
.client()
.await
.unwrap()
.devices()
.list(devices::list::Request {
access_token: auth::AccessToken("invalid".to_string()),
})
.await
.unwrap();
assert!(matches!(
response,
Err(sd_cloud_schema::Error::Client(
sd_cloud_schema::error::ClientSideError::Unauthorized
))
));
}
}
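A minimal caller-side sketch of the connection semantics described above: the client is best-effort at startup, and client() lazily reconnects on demand (the function name is hypothetical; it assumes the same crate-level types):

// Hypothetical caller inside the core.
async fn with_cloud(services: &CloudServices) -> Result<(), Error> {
	// Fails fast if the user is offline; succeeds once connectivity returns,
	// because `client()` retries the connection instead of caching the failure.
	let _client = services.client().await?;
	// Any route from `sd_cloud_schema::Service` can now be called on `_client`.
	Ok(())
}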

View File

@@ -0,0 +1,170 @@
use sd_cloud_schema::{cloud_p2p, sync::groups, Service};
use sd_utils::error::FileIOError;
use std::{io, net::AddrParseError};
use quic_rpc::{
pattern::{bidi_streaming, rpc, server_streaming},
transport::quinn::QuinnConnection,
};
#[derive(thiserror::Error, Debug)]
pub enum Error {
// Setup errors
#[error("Couldn't parse Cloud Services API address URL: {0}")]
InvalidUrl(reqwest::Error),
#[error("Failed to parse Cloud Services API address URL")]
FailedToParseRelayUrl,
#[error("Failed to initialize http client: {0}")]
HttpClientInit(reqwest::Error),
#[error("Failed to request Cloud Services API address from Auth Server route: {0}")]
FailedToRequestApiAddress(reqwest_middleware::Error),
#[error("Auth Server's Cloud Services API address route returned an error: {0}")]
AuthServerError(reqwest::Error),
#[error(
"Failed to extract response body from Auth Server's Cloud Services API address route: {0}"
)]
FailedToExtractApiAddress(reqwest::Error),
#[error("Failed to parse auth server's Cloud Services API address: {0}")]
FailedToParseApiAddress(#[from] AddrParseError),
#[error("Failed to create endpoint: {0}")]
FailedToCreateEndpoint(io::Error),
// Token refresher errors
#[error("Invalid token format, missing claims")]
MissingClaims,
#[error("Failed to decode access token data: {0}")]
DecodeAccessTokenData(#[from] base64::DecodeError),
#[error("Failed to deserialize access token json data: {0}")]
DeserializeAccessTokenData(#[from] serde_json::Error),
#[error("Token expired")]
TokenExpired,
#[error("Failed to request refresh token: {0}")]
RefreshTokenRequest(reqwest_middleware::Error),
#[error("Missing tokens on refresh response")]
MissingTokensOnRefreshResponse,
#[error("Failed to parse token header value to string: {0}")]
FailedToParseTokenHeaderValueToString(#[from] reqwest::header::ToStrError),
// Key Manager errors
#[error("Failed to handle File on KeyManager: {0}")]
FileIO(#[from] FileIOError),
#[error("Failed to handle key store serialization: {0}")]
KeyStoreSerialization(rmp_serde::encode::Error),
#[error("Failed to handle key store deserialization: {0}")]
KeyStoreDeserialization(rmp_serde::decode::Error),
#[error("Key store encryption related error: {{context: \"{context}\", source: {source}}}")]
KeyStoreCrypto {
#[source]
source: sd_crypto::Error,
context: &'static str,
},
#[error("Key manager not initialized")]
KeyManagerNotInitialized,
// Cloud P2P errors
#[error("Failed to create Cloud P2P endpoint: {0}")]
CreateCloudP2PEndpoint(anyhow::Error),
#[error("Failed to connect to Cloud P2P node: {0}")]
ConnectToCloudP2PNode(anyhow::Error),
#[error("Communication error with Cloud P2P node: {0}")]
CloudP2PRpcCommunication(#[from] rpc::Error<QuinnConnection<cloud_p2p::Service>>),
#[error("Cloud P2P not initialized")]
CloudP2PNotInitialized,
#[error("Failed to initialize LocalSwarmDiscovery: {0}")]
LocalSwarmDiscoveryInit(anyhow::Error),
#[error("Failed to initialize DhtDiscovery: {0}")]
DhtDiscoveryInit(anyhow::Error),
// Communication errors
#[error("Failed to communicate with RPC backend: {0}")]
RpcCommunication(#[from] rpc::Error<QuinnConnection<Service>>),
#[error("Failed to communicate with Server Streaming RPC backend: {0}")]
ServerStreamCommunication(#[from] server_streaming::Error<QuinnConnection<Service>>),
#[error("Failed to receive next response from Server Streaming RPC backend: {0}")]
ServerStreamRecv(#[from] server_streaming::ItemError<QuinnConnection<Service>>),
#[error("Failed to communicate with Bidi Streaming RPC backend: {0}")]
BidiStreamCommunication(#[from] bidi_streaming::Error<QuinnConnection<Service>>),
#[error("Failed to receive next response from Bidi Streaming RPC backend: {0}")]
BidiStreamRecv(#[from] bidi_streaming::ItemError<QuinnConnection<Service>>),
#[error("Error from backend: {0}")]
Backend(#[from] sd_cloud_schema::Error),
#[error("Failed to get access token from refresher: {0}")]
GetToken(#[from] GetTokenError),
#[error("Unexpected empty response from backend, context: {0}")]
EmptyResponse(&'static str),
#[error("Unexpected response from backend, context: {0}")]
UnexpectedResponse(&'static str),
// Sync error
#[error("Sync error: {0}")]
Sync(#[from] sd_core_sync::Error),
#[error("Tried to sync messages with a group without having needed key")]
MissingSyncGroupKey(groups::PubId),
#[error("Failed to encrypt sync messages: {0}")]
Encrypt(sd_crypto::Error),
#[error("Failed to decrypt sync messages: {0}")]
Decrypt(sd_crypto::Error),
#[error("Failed to upload sync messages: {0}")]
UploadSyncMessages(reqwest_middleware::Error),
#[error("Failed to download sync messages: {0}")]
DownloadSyncMessages(reqwest_middleware::Error),
#[error("Received an error response from uploading sync messages: {0}")]
ErrorResponseUploadSyncMessages(reqwest::Error),
#[error("Received an error response from downloading sync messages: {0}")]
ErrorResponseDownloadSyncMessages(reqwest::Error),
#[error(
"Received an error response from downloading sync messages while reading its bytes: {0}"
)]
ErrorResponseDownloadReadBytesSyncMessages(reqwest::Error),
#[error("Critical error while uploading sync messages")]
CriticalErrorWhileUploadingSyncMessages,
#[error("Failed to send End update to push sync messages")]
EndUpdatePushSyncMessages(io::Error),
#[error("Unexpected end of stream while encrypting sync messages")]
UnexpectedEndOfStream,
#[error("Failed to create directory to store timestamp keeper files")]
FailedToCreateTimestampKeepersDirectory(io::Error),
#[error("Failed to read last timestamp keeper for pulling sync messages: {0}")]
FailedToReadLastTimestampKeeper(io::Error),
#[error("Failed to handle last timestamp keeper serialization: {0}")]
LastTimestampKeeperSerialization(rmp_serde::encode::Error),
#[error("Failed to handle last timestamp keeper deserialization: {0}")]
LastTimestampKeeperDeserialization(rmp_serde::decode::Error),
#[error("Failed to write last timestamp keeper for pulling sync messages: {0}")]
FailedToWriteLastTimestampKeeper(io::Error),
#[error("Sync messages download and decrypt task panicked")]
SyncMessagesDownloadAndDecryptTaskPanicked,
#[error("Serialization failure to push sync messages: {0}")]
SerializationFailureToPushSyncMessages(rmp_serde::encode::Error),
#[error("Deserialization failure to pull sync messages: {0}")]
DeserializationFailureToPullSyncMessages(rmp_serde::decode::Error),
#[error("Read nonce stream decryption: {0}")]
ReadNonceStreamDecryption(io::Error),
#[error("Incomplete download bytes sync messages")]
IncompleteDownloadBytesSyncMessages,
// Temporary errors
#[error("Device missing secret key for decrypting sync messages")]
MissingKeyHash,
}
#[derive(thiserror::Error, Debug)]
pub enum GetTokenError {
#[error("Token refresher not initialized")]
RefresherNotInitialized,
#[error("Token refresher failed to refresh and need to be initialized again")]
FailedToRefresh,
}
impl From<Error> for rspc::Error {
fn from(e: Error) -> Self {
Self::with_cause(rspc::ErrorCode::InternalServerError, e.to_string(), e)
}
}
impl From<GetTokenError> for rspc::Error {
fn from(e: GetTokenError) -> Self {
Self::with_cause(rspc::ErrorCode::InternalServerError, e.to_string(), e)
}
}
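The two From impls above are what let these errors bubble straight out of rspc procedures; a minimal sketch (the procedure name is hypothetical):

use crate::CloudServices;

async fn example_procedure(services: &CloudServices) -> Result<(), rspc::Error> {
	// `Error` converts into `rspc::Error` through the `From` impl above,
	// so plain `?` works inside procedure bodies.
	let _client = services.client().await?;
	Ok(())
}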

View File

@@ -0,0 +1,331 @@
use crate::Error;
use sd_cloud_schema::{
sync::{groups, KeyHash},
NodeId, SecretKey as IrohSecretKey,
};
use sd_crypto::{
cloud::{decrypt, encrypt, secret_key::SecretKey},
primitives::{EncryptedBlock, OneShotNonce, StreamNonce},
CryptoRng,
};
use sd_utils::error::FileIOError;
use tracing::debug;
use std::{
collections::{BTreeMap, VecDeque},
fs::Metadata,
path::PathBuf,
pin::pin,
};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::{
fs,
io::{AsyncReadExt, AsyncWriteExt, BufWriter},
};
use zeroize::{Zeroize, ZeroizeOnDrop};
type KeyStack = VecDeque<(KeyHash, SecretKey)>;
#[derive(Serialize, Deserialize)]
pub struct KeyStore {
iroh_secret_key: IrohSecretKey,
keys: BTreeMap<groups::PubId, KeyStack>,
}
impl KeyStore {
pub const fn new(iroh_secret_key: IrohSecretKey) -> Self {
Self {
iroh_secret_key,
keys: BTreeMap::new(),
}
}
pub fn add_key(&mut self, group_pub_id: groups::PubId, key: SecretKey) {
self.keys.entry(group_pub_id).or_default().push_front((
KeyHash(blake3::hash(key.as_ref()).to_hex().to_string()),
key,
));
}
pub fn add_key_with_hash(
&mut self,
group_pub_id: groups::PubId,
key: SecretKey,
key_hash: KeyHash,
) {
debug!(
key_hash = key_hash.0,
?group_pub_id,
"Added single cloud sync key to key manager"
);
self.keys
.entry(group_pub_id)
.or_default()
.push_front((key_hash, key));
}
pub fn add_many_keys(
&mut self,
group_pub_id: groups::PubId,
keys: impl IntoIterator<Item = SecretKey, IntoIter = impl DoubleEndedIterator<Item = SecretKey>>,
) {
let group_entry = self.keys.entry(group_pub_id).or_default();
// We reverse the secret keys as an implementation detail to
// keep the keys in the same order as they were added, since we push onto a stack
for key in keys.into_iter().rev() {
let key_hash = blake3::hash(key.as_ref()).to_hex().to_string();
debug!(
key_hash,
?group_pub_id,
"Added cloud sync key to key manager"
);
group_entry.push_front((KeyHash(key_hash), key));
}
}
pub fn remove_group(&mut self, group_pub_id: groups::PubId) {
self.keys.remove(&group_pub_id);
}
pub fn iroh_secret_key(&self) -> IrohSecretKey {
self.iroh_secret_key.clone()
}
pub fn node_id(&self) -> NodeId {
self.iroh_secret_key.public()
}
pub fn get_key(&self, group_pub_id: groups::PubId, hash: &KeyHash) -> Option<SecretKey> {
self.keys.get(&group_pub_id).and_then(|group| {
group
.iter()
.find_map(|(key_hash, key)| (key_hash == hash).then(|| key.clone()))
})
}
pub fn get_latest_key(&self, group_pub_id: groups::PubId) -> Option<(KeyHash, SecretKey)> {
self.keys
.get(&group_pub_id)
.and_then(|group| group.front().cloned())
}
pub fn get_group_keys(&self, group_pub_id: groups::PubId) -> Vec<SecretKey> {
self.keys
.get(&group_pub_id)
.map(|group| group.iter().map(|(_key_hash, key)| key.clone()).collect())
.unwrap_or_default()
}
pub async fn encrypt(
&self,
key: &SecretKey,
rng: &mut CryptoRng,
keys_file_path: &PathBuf,
) -> Result<(), Error> {
let plain_text_bytes =
rmp_serde::to_vec_named(self).map_err(Error::KeyStoreSerialization)?;
let mut file = BufWriter::with_capacity(
EncryptedBlock::CIPHER_TEXT_SIZE,
fs::OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&keys_file_path)
.await
.map_err(|e| {
FileIOError::from((
&keys_file_path,
e,
"Failed to open space keys file to encrypt",
))
})?,
);
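// Stores that fit in a single AEAD block are encrypted one-shot;
// anything larger goes through the streaming encryption path below.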
if plain_text_bytes.len() < EncryptedBlock::PLAIN_TEXT_SIZE {
use encrypt::OneShotEncryption;
let EncryptedBlock { nonce, cipher_text } = key
.encrypt(&plain_text_bytes, rng)
.map_err(|e| Error::KeyStoreCrypto {
source: e,
context: "Failed to oneshot encrypt key store",
})?;
file.write_all(nonce.as_slice()).await.map_err(|e| {
FileIOError::from((
&keys_file_path,
e,
"Failed to write space keys file oneshot nonce",
))
})?;
file.write_all(cipher_text.as_slice()).await.map_err(|e| {
FileIOError::from((
&keys_file_path,
e,
"Failed to write space keys file oneshot cipher text",
))
})?;
} else {
use encrypt::StreamEncryption;
let (nonce, stream) = key.encrypt(plain_text_bytes.as_slice(), rng);
file.write_all(nonce.as_slice()).await.map_err(|e| {
FileIOError::from((
&keys_file_path,
e,
"Failed to write space keys file stream nonce",
))
})?;
let mut stream = pin!(stream);
while let Some(res) = stream.next().await {
file.write_all(&res.map_err(|e| Error::KeyStoreCrypto {
source: e,
context: "Failed to stream encrypt key store",
})?)
.await
.map_err(|e| {
FileIOError::from((
&keys_file_path,
e,
"Failed to write space keys file stream cipher text",
))
})?;
}
};
file.flush().await.map_err(|e| {
FileIOError::from((&keys_file_path, e, "Failed to flush space keys file")).into()
})
}
pub async fn decrypt(
key: &SecretKey,
metadata: Metadata,
keys_file_path: &PathBuf,
) -> Result<Self, Error> {
let mut file = fs::File::open(&keys_file_path).await.map_err(|e| {
FileIOError::from((
keys_file_path,
e,
"Failed to open space keys file to decrypt",
))
})?;
let usize_file_len =
usize::try_from(metadata.len()).expect("Failed to convert metadata length to usize");
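// Mirror the size-based choice made in `encrypt`: files no larger than one
// nonce + one encrypted block were written one-shot; larger files are streams.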
let key_store_bytes =
if usize_file_len <= EncryptedBlock::CIPHER_TEXT_SIZE + size_of::<OneShotNonce>() {
use decrypt::OneShotDecryption;
let mut nonce = OneShotNonce::default();
file.read_exact(&mut nonce).await.map_err(|e| {
FileIOError::from((
keys_file_path,
e,
"Failed to read space keys file oneshot nonce",
))
})?;
let mut cipher_text = vec![0u8; usize_file_len - size_of::<OneShotNonce>()];
file.read_exact(&mut cipher_text).await.map_err(|e| {
FileIOError::from((
keys_file_path,
e,
"Failed to read space keys file oneshot cipher text",
))
})?;
key.decrypt_owned(&EncryptedBlock { nonce, cipher_text })
.map_err(|e| Error::KeyStoreCrypto {
source: e,
context: "Failed to oneshot decrypt space keys file",
})?
} else {
use decrypt::StreamDecryption;
let mut nonce = StreamNonce::default();
let mut key_store_bytes = Vec::with_capacity(
(usize_file_len - size_of::<StreamNonce>()) / EncryptedBlock::CIPHER_TEXT_SIZE
* EncryptedBlock::PLAIN_TEXT_SIZE,
);
file.read_exact(&mut nonce).await.map_err(|e| {
FileIOError::from((
keys_file_path,
e,
"Failed to read space keys file stream nonce",
))
})?;
key.decrypt(&nonce, &mut file, &mut key_store_bytes)
.await
.map_err(|e| Error::KeyStoreCrypto {
source: e,
context: "Failed to stream decrypt space keys file",
})?;
key_store_bytes
};
let this = rmp_serde::from_slice::<Self>(&key_store_bytes)
.map_err(Error::KeyStoreDeserialization)?;
#[cfg(debug_assertions)]
{
use std::fmt::Write;
let mut key_hashes_log = String::new();
this.keys.iter().for_each(|(group_pub_id, key_stack)| {
writeln!(
key_hashes_log,
"Group: {group_pub_id:?} => KeyHashes: {:?}",
key_stack
.iter()
.map(|(KeyHash(key_hash), _)| key_hash)
.collect::<Vec<_>>()
)
.expect("Failed to write to key hashes log");
});
tracing::info!("Loaded key hashes: {key_hashes_log}");
}
Ok(this)
}
}
/// Zeroizes our secret keys and scrambles iroh's secret key, which doesn't implement `Zeroize`
impl Zeroize for KeyStore {
fn zeroize(&mut self) {
self.iroh_secret_key = IrohSecretKey::generate();
self.keys.values_mut().for_each(|group| {
group
.iter_mut()
.map(|(_key_hash, key)| key)
.for_each(Zeroize::zeroize);
});
self.keys = BTreeMap::new();
}
}
impl Drop for KeyStore {
fn drop(&mut self) {
self.zeroize();
}
}
impl ZeroizeOnDrop for KeyStore {}
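A round-trip sketch of the encrypt/decrypt pair above, under the same assumptions this file already makes (SecretKey, CryptoRng, and IrohSecretKey::generate); the path is a placeholder:

async fn key_store_round_trip(key: &SecretKey, rng: &mut CryptoRng) -> Result<(), Error> {
	let path = std::path::PathBuf::from("/tmp/space.keys");
	let store = KeyStore::new(IrohSecretKey::generate());
	// Small stores take the one-shot path; bigger ones stream block by block.
	store.encrypt(key, rng, &path).await?;
	let metadata = tokio::fs::metadata(&path)
		.await
		.expect("file was just written");
	let _restored = KeyStore::decrypt(key, metadata, &path).await?;
	Ok(())
}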

View File

@@ -0,0 +1,183 @@
use crate::Error;
use sd_cloud_schema::{
sync::{groups, KeyHash},
NodeId, SecretKey as IrohSecretKey,
};
use sd_crypto::{cloud::secret_key::SecretKey, CryptoRng};
use sd_utils::error::FileIOError;
use std::{
fmt,
path::{Path, PathBuf},
};
use tokio::{fs, sync::RwLock};
mod key_store;
use key_store::KeyStore;
const KEY_FILE_NAME: &str = "space.keys";
pub struct KeyManager {
master_key: SecretKey,
keys_file_path: PathBuf,
store: RwLock<KeyStore>,
}
impl KeyManager {
pub async fn new(
master_key: SecretKey,
iroh_secret_key: IrohSecretKey,
data_directory: impl AsRef<Path> + Send,
rng: &mut CryptoRng,
) -> Result<Self, Error> {
async fn inner(
master_key: SecretKey,
iroh_secret_key: IrohSecretKey,
keys_file_path: PathBuf,
rng: &mut CryptoRng,
) -> Result<KeyManager, Error> {
let store = KeyStore::new(iroh_secret_key);
store.encrypt(&master_key, rng, &keys_file_path).await?;
Ok(KeyManager {
master_key,
keys_file_path,
store: RwLock::new(store),
})
}
inner(
master_key,
iroh_secret_key,
data_directory.as_ref().join(KEY_FILE_NAME),
rng,
)
.await
}
pub async fn load(
master_key: SecretKey,
data_directory: impl AsRef<Path> + Send,
) -> Result<Self, Error> {
async fn inner(
master_key: SecretKey,
keys_file_path: PathBuf,
) -> Result<KeyManager, Error> {
Ok(KeyManager {
store: RwLock::new(
KeyStore::decrypt(
&master_key,
fs::metadata(&keys_file_path).await.map_err(|e| {
FileIOError::from((
&keys_file_path,
e,
"Failed to read space keys file",
))
})?,
&keys_file_path,
)
.await?,
),
master_key,
keys_file_path,
})
}
inner(master_key, data_directory.as_ref().join(KEY_FILE_NAME)).await
}
pub async fn iroh_secret_key(&self) -> IrohSecretKey {
self.store.read().await.iroh_secret_key()
}
pub async fn node_id(&self) -> NodeId {
self.store.read().await.node_id()
}
pub async fn add_key(
&self,
group_pub_id: groups::PubId,
key: SecretKey,
rng: &mut CryptoRng,
) -> Result<(), Error> {
let mut store = self.store.write().await;
store.add_key(group_pub_id, key);
// Keep holding the write lock here, so a concurrent write can't corrupt the file
store
.encrypt(&self.master_key, rng, &self.keys_file_path)
.await
}
pub async fn add_key_with_hash(
&self,
group_pub_id: groups::PubId,
key: SecretKey,
key_hash: KeyHash,
rng: &mut CryptoRng,
) -> Result<(), Error> {
let mut store = self.store.write().await;
store.add_key_with_hash(group_pub_id, key, key_hash);
// Keep holding the write lock here, so a concurrent write can't corrupt the file
store
.encrypt(&self.master_key, rng, &self.keys_file_path)
.await
}
pub async fn remove_group(
&self,
group_pub_id: groups::PubId,
rng: &mut CryptoRng,
) -> Result<(), Error> {
let mut store = self.store.write().await;
store.remove_group(group_pub_id);
// Keep holding the write lock here, so a concurrent write can't corrupt the file
store
.encrypt(&self.master_key, rng, &self.keys_file_path)
.await
}
pub async fn add_many_keys(
&self,
group_pub_id: groups::PubId,
keys: impl IntoIterator<
Item = SecretKey,
IntoIter = impl DoubleEndedIterator<Item = SecretKey> + Send,
> + Send,
rng: &mut CryptoRng,
) -> Result<(), Error> {
let mut store = self.store.write().await;
store.add_many_keys(group_pub_id, keys);
// Keep holding the write lock here, so a concurrent write can't corrupt the file
store
.encrypt(&self.master_key, rng, &self.keys_file_path)
.await
}
pub async fn get_latest_key(
&self,
group_pub_id: groups::PubId,
) -> Option<(KeyHash, SecretKey)> {
self.store.read().await.get_latest_key(group_pub_id)
}
pub async fn get_key(&self, group_pub_id: groups::PubId, hash: &KeyHash) -> Option<SecretKey> {
self.store.read().await.get_key(group_pub_id, hash)
}
pub async fn get_group_keys(&self, group_pub_id: groups::PubId) -> Vec<SecretKey> {
self.store.read().await.get_group_keys(group_pub_id)
}
}
impl fmt::Debug for KeyManager {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("KeyManager")
.field("master_key", &"[REDACTED]")
.field("keys_file_path", &self.keys_file_path)
.field("store", &"[REDACTED]")
.finish()
}
}

View File

@@ -0,0 +1,55 @@
#![recursion_limit = "256"]
#![warn(
clippy::all,
clippy::pedantic,
clippy::correctness,
clippy::perf,
clippy::style,
clippy::suspicious,
clippy::complexity,
clippy::nursery,
clippy::unwrap_used,
unused_qualifications,
rust_2018_idioms,
trivial_casts,
trivial_numeric_casts,
unused_allocation,
clippy::unnecessary_cast,
clippy::cast_lossless,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_precision_loss,
clippy::cast_sign_loss,
clippy::dbg_macro,
clippy::deprecated_cfg_attr,
clippy::separated_literal_suffix,
deprecated
)]
#![forbid(deprecated_in_future)]
#![allow(clippy::missing_errors_doc, clippy::module_name_repetitions)]
mod error;
mod client;
mod key_manager;
mod p2p;
mod sync;
mod token_refresher;
pub use client::CloudServices;
pub use error::{Error, GetTokenError};
pub use key_manager::KeyManager;
pub use p2p::{
CloudP2P, JoinSyncGroupResponse, JoinedLibraryCreateArgs, NotifyUser, Ticket, UserResponse,
};
pub use sync::{
declare_actors as declare_cloud_sync, SyncActors as CloudSyncActors,
SyncActorsState as CloudSyncActorsState,
};
// Re-exports
pub use quic_rpc::transport::quinn::QuinnConnection;
// Export URL for the auth server
pub const AUTH_SERVER_URL: &str = "https://auth.spacedrive.com";
// pub const AUTH_SERVER_URL: &str = "http://localhost:9420";

View File

@@ -0,0 +1,242 @@
use crate::{sync::ReceiveAndIngestNotifiers, CloudServices, Error};
use sd_cloud_schema::{
cloud_p2p::{authorize_new_device_in_sync_group, CloudP2PALPN, CloudP2PError},
devices::{self, Device},
libraries,
sync::groups::{self, GroupWithDevices},
SecretKey as IrohSecretKey,
};
use sd_crypto::{CryptoRng, SeedableRng};
use std::{sync::Arc, time::Duration};
use iroh_net::{
discovery::{
dns::DnsDiscovery, local_swarm_discovery::LocalSwarmDiscovery, pkarr::dht::DhtDiscovery,
ConcurrentDiscovery, Discovery,
},
relay::{RelayMap, RelayMode, RelayUrl},
Endpoint, NodeId,
};
use reqwest::Url;
use serde::{Deserialize, Serialize};
use tokio::{spawn, sync::oneshot, time::sleep};
use tracing::{debug, error, warn};
mod new_sync_messages_notifier;
mod runner;
use runner::Runner;
#[derive(Debug)]
pub struct JoinedLibraryCreateArgs {
pub pub_id: libraries::PubId,
pub name: String,
pub description: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, specta::Type)]
#[serde(transparent)]
#[repr(transparent)]
#[specta(rename = "CloudP2PTicket")]
pub struct Ticket(u64);
#[derive(Debug, Serialize, specta::Type)]
#[serde(tag = "kind", content = "data")]
#[specta(rename = "CloudP2PNotifyUser")]
pub enum NotifyUser {
ReceivedJoinSyncGroupRequest {
ticket: Ticket,
asking_device: Device,
sync_group: GroupWithDevices,
},
ReceivedJoinSyncGroupResponse {
response: JoinSyncGroupResponse,
sync_group: GroupWithDevices,
},
SendingJoinSyncGroupResponseError {
error: JoinSyncGroupError,
sync_group: GroupWithDevices,
},
TimedOutJoinRequest {
device: Device,
succeeded: bool,
},
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, specta::Type)]
pub enum JoinSyncGroupError {
Communication,
InternalServer,
Auth,
}
#[derive(Debug, Serialize, specta::Type)]
pub enum JoinSyncGroupResponse {
Accepted { authorizor_device: Device },
Failed(CloudP2PError),
CriticalError,
}
#[derive(Debug, Clone, Serialize, Deserialize, specta::Type)]
pub struct BasicLibraryCreationArgs {
pub id: libraries::PubId,
pub name: String,
pub description: Option<String>,
}
#[derive(Debug, Deserialize, specta::Type)]
#[serde(tag = "kind", content = "data")]
#[specta(rename = "CloudP2PUserResponse")]
pub enum UserResponse {
AcceptDeviceInSyncGroup {
ticket: Ticket,
accepted: Option<BasicLibraryCreationArgs>,
},
}
#[derive(Debug, Clone)]
pub struct CloudP2P {
msgs_tx: flume::Sender<runner::Message>,
}
impl CloudP2P {
pub async fn new(
current_device_pub_id: devices::PubId,
cloud_services: &CloudServices,
mut rng: CryptoRng,
iroh_secret_key: IrohSecretKey,
dns_origin_domain: String,
dns_pkarr_url: Url,
relay_url: RelayUrl,
) -> Result<Self, Error> {
let dht_discovery = DhtDiscovery::builder()
.secret_key(iroh_secret_key.clone())
.pkarr_relay(dns_pkarr_url)
.build()
.map_err(Error::DhtDiscoveryInit)?;
let endpoint = Endpoint::builder()
.alpns(vec![CloudP2PALPN::LATEST.to_vec()])
.discovery(Box::new(ConcurrentDiscovery::from_services(vec![
Box::new(DnsDiscovery::new(dns_origin_domain)),
Box::new(
LocalSwarmDiscovery::new(iroh_secret_key.public())
.map_err(Error::LocalSwarmDiscoveryInit)?,
),
Box::new(dht_discovery.clone()),
])))
.secret_key(iroh_secret_key)
.relay_mode(RelayMode::Custom(RelayMap::from_url(relay_url)))
.bind()
.await
.map_err(Error::CreateCloudP2PEndpoint)?;
spawn({
let endpoint = endpoint.clone();
async move {
loop {
let Ok(node_addr) = endpoint.node_addr().await.map_err(|e| {
warn!(?e, "Failed to get direct addresses to force publish on DHT");
}) else {
sleep(Duration::from_secs(5)).await;
continue;
};
debug!("Force publishing peer on DHT");
return dht_discovery.publish(&node_addr.info);
}
}
});
let (msgs_tx, msgs_rx) = flume::bounded(16);
spawn({
let runner = Runner::new(
current_device_pub_id,
cloud_services,
msgs_tx.clone(),
endpoint,
)
.await?;
let user_response_rx = cloud_services.user_response_rx.clone();
async move {
// All cloned runners share a single state with internal mutability
while let Err(e) = spawn(runner.clone().run(
msgs_rx.clone(),
user_response_rx.clone(),
CryptoRng::from_seed(rng.generate_fixed()),
))
.await
{
if e.is_panic() {
error!("Cloud P2P runner panicked");
} else {
break;
}
}
}
});
Ok(Self { msgs_tx })
}
/// Asks the device with the given connection ID for permission for the current device to join
/// the sync group
///
/// # Panics
/// Will panic if the actor channel is closed, which should never happen
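///
/// A hedged calling-side sketch; `devices_in_group` and `req` are assumed to have been
/// built from the sync group data:
///
/// ```ignore
/// let (tx, rx) = tokio::sync::oneshot::channel();
/// cloud_p2p
/// .request_join_sync_group(devices_in_group, req, tx)
/// .await;
/// // Resolves once an authorizor device accepts and sends back the library data
/// let JoinedLibraryCreateArgs { pub_id, name, description } = rx.await?;
/// ```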
pub async fn request_join_sync_group(
&self,
devices_in_group: Vec<(devices::PubId, NodeId)>,
req: authorize_new_device_in_sync_group::Request,
tx: oneshot::Sender<JoinedLibraryCreateArgs>,
) {
self.msgs_tx
.send_async(runner::Message::Request(runner::Request::JoinSyncGroup {
req,
devices_in_group,
tx,
}))
.await
.expect("Channel closed");
}
/// Register a notifier for the desired sync group, which will notify the receiver actor when
/// new sync messages arrive through cloud p2p notification requests.
///
/// # Panics
/// Will panic if the actor channel is closed, which should never happen
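///
/// A minimal sketch, assuming `notifiers` came from the cloud sync actors declaration:
///
/// ```ignore
/// cloud_p2p
/// .register_sync_messages_receiver_notifier(sync_group_pub_id, Arc::clone(&notifiers))
/// .await;
/// ```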
pub async fn register_sync_messages_receiver_notifier(
&self,
sync_group_pub_id: groups::PubId,
notifier: Arc<ReceiveAndIngestNotifiers>,
) {
self.msgs_tx
.send_async(runner::Message::RegisterSyncMessageNotifier((
sync_group_pub_id,
notifier,
)))
.await
.expect("Channel closed");
}
/// Emit a notification that new sync messages were sent to cloud, so other devices should pull
/// them as soon as possible.
///
/// # Panics
/// Will panic if the actor channel is closed, which should never happen
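///
/// For example (sketch; typically called by the sender actor after a successful push):
///
/// ```ignore
/// cloud_p2p.notify_new_sync_messages(group_pub_id).await;
/// ```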
pub async fn notify_new_sync_messages(&self, group_pub_id: groups::PubId) {
self.msgs_tx
.send_async(runner::Message::NotifyPeersSyncMessages(group_pub_id))
.await
.expect("Channel closed");
}
}
impl Drop for CloudP2P {
fn drop(&mut self) {
self.msgs_tx.send(runner::Message::Stop).ok();
}
}

View File

@@ -0,0 +1,156 @@
use crate::{token_refresher::TokenRefresher, Error};
use sd_cloud_schema::{
cloud_p2p::{Client, CloudP2PALPN, Service},
devices,
sync::groups,
};
use std::time::Duration;
use futures_concurrency::future::Join;
use iroh_net::{Endpoint, NodeId};
use quic_rpc::{transport::quinn::QuinnConnection, RpcClient};
use tokio::time::Instant;
use tracing::{debug, error, instrument, warn};
use super::runner::Message;
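// How long cached device connection ids stay fresh before being refetched from the cloud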
const CACHED_MAX_DURATION: Duration = Duration::from_secs(60 * 5);
pub async fn dispatch_notifier(
group_pub_id: groups::PubId,
device_pub_id: devices::PubId,
devices: Option<(Instant, Vec<(devices::PubId, NodeId)>)>,
msgs_tx: flume::Sender<Message>,
cloud_services: sd_cloud_schema::Client<
QuinnConnection<sd_cloud_schema::Service>,
sd_cloud_schema::Service,
>,
token_refresher: TokenRefresher,
endpoint: Endpoint,
) {
match notify_peers(
group_pub_id,
device_pub_id,
devices,
cloud_services,
token_refresher,
endpoint,
)
.await
{
Ok((true, devices)) => {
if msgs_tx
.send_async(Message::UpdateCachedDevices((group_pub_id, devices)))
.await
.is_err()
{
warn!("Failed to send update cached devices message to update cached devices");
}
}
Ok((false, _)) => {}
Err(e) => {
error!(?e, "Failed to notify peers");
}
}
}
#[instrument(skip(cloud_services, token_refresher, endpoint))]
async fn notify_peers(
group_pub_id: groups::PubId,
device_pub_id: devices::PubId,
devices: Option<(Instant, Vec<(devices::PubId, NodeId)>)>,
cloud_services: sd_cloud_schema::Client<
QuinnConnection<sd_cloud_schema::Service>,
sd_cloud_schema::Service,
>,
token_refresher: TokenRefresher,
endpoint: Endpoint,
) -> Result<(bool, Vec<(devices::PubId, NodeId)>), Error> {
let (devices, update_cache) = match devices {
Some((when, devices)) if when.elapsed() < CACHED_MAX_DURATION => (devices, false),
_ => {
debug!("Fetching devices connection ids for group");
let groups::get::Response(groups::get::ResponseKind::DevicesConnectionIds(devices)) =
cloud_services
.sync()
.groups()
.get(groups::get::Request {
access_token: token_refresher.get_access_token().await?,
pub_id: group_pub_id,
kind: groups::get::RequestKind::DevicesConnectionIds,
})
.await??
else {
unreachable!("Only DevicesConnectionIds response is expected, as we requested it");
};
(devices, true)
}
};
send_notifications(group_pub_id, device_pub_id, &devices, &endpoint).await;
Ok((update_cache, devices))
}
async fn send_notifications(
group_pub_id: groups::PubId,
device_pub_id: devices::PubId,
devices: &[(devices::PubId, NodeId)],
endpoint: &Endpoint,
) {
devices
.iter()
.filter(|(peer_device_pub_id, _)| *peer_device_pub_id != device_pub_id)
.map(|(peer_device_pub_id, connection_id)| async move {
if let Err(e) =
connect_and_send_notification(group_pub_id, device_pub_id, connection_id, endpoint)
.await
{
// Using just a debug log here because we don't want to spam the logs with every
// single notification failure, as this is more of a nice-to-have feature than a
// critical one
debug!(?e, %peer_device_pub_id, "Failed to send new sync messages notification to peer");
} else {
debug!(%peer_device_pub_id, "Sent new sync messages notification to peer");
}
})
.collect::<Vec<_>>()
.join()
.await;
}
async fn connect_and_send_notification(
group_pub_id: groups::PubId,
device_pub_id: devices::PubId,
connection_id: &NodeId,
endpoint: &Endpoint,
) -> Result<(), Error> {
let client = Client::new(RpcClient::new(QuinnConnection::<Service>::from_connection(
endpoint
.connect(*connection_id, CloudP2PALPN::LATEST)
.await
.map_err(Error::ConnectToCloudP2PNode)?,
)));
if let Err(e) = client
.notify_new_sync_messages(
sd_cloud_schema::cloud_p2p::notify_new_sync_messages::Request {
sync_group_pub_id: group_pub_id,
device_pub_id,
},
)
.await?
{
warn!(
?e,
"This route shouldn't return an error, it's just a notification",
);
};
Ok(())
}

View File

@@ -0,0 +1,651 @@
use crate::{
p2p::JoinSyncGroupError, sync::ReceiveAndIngestNotifiers, token_refresher::TokenRefresher,
CloudServices, Error, KeyManager,
};
use sd_cloud_schema::{
cloud_p2p::{
self, authorize_new_device_in_sync_group, notify_new_sync_messages, Client, CloudP2PALPN,
CloudP2PError, Service,
},
devices::{self, Device},
sync::groups,
};
use sd_crypto::{CryptoRng, SeedableRng};
use std::{
collections::HashMap,
pin::pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Duration,
};
use dashmap::DashMap;
use flume::SendError;
use futures::StreamExt;
use futures_concurrency::stream::Merge;
use iroh_net::{Endpoint, NodeId};
use quic_rpc::{
server::{Accepting, RpcChannel, RpcServerError},
transport::quinn::{QuinnConnection, QuinnServerEndpoint},
RpcClient, RpcServer,
};
use tokio::{
spawn,
sync::{oneshot, Mutex},
task::JoinHandle,
time::{interval, Instant, MissedTickBehavior},
};
use tokio_stream::wrappers::IntervalStream;
use tracing::{debug, error, warn};
use super::{
new_sync_messages_notifier::dispatch_notifier, BasicLibraryCreationArgs, JoinSyncGroupResponse,
JoinedLibraryCreateArgs, NotifyUser, Ticket, UserResponse,
};
const TEN_SECONDS: Duration = Duration::from_secs(10);
const FIVE_MINUTES: Duration = Duration::from_secs(60 * 5);
#[allow(clippy::large_enum_variant)] // Ignoring because the enum Stop variant will only happen a single time ever
pub enum Message {
Request(Request),
RegisterSyncMessageNotifier((groups::PubId, Arc<ReceiveAndIngestNotifiers>)),
NotifyPeersSyncMessages(groups::PubId),
UpdateCachedDevices((groups::PubId, Vec<(devices::PubId, NodeId)>)),
Stop,
}
pub enum Request {
JoinSyncGroup {
req: authorize_new_device_in_sync_group::Request,
devices_in_group: Vec<(devices::PubId, NodeId)>,
tx: oneshot::Sender<JoinedLibraryCreateArgs>,
},
}
/// We use internal mutability here, but don't worry because there will always be a single
/// [`Runner`] running at a time, so the lock is never contended
pub struct Runner {
current_device_pub_id: devices::PubId,
token_refresher: TokenRefresher,
cloud_services: sd_cloud_schema::Client<
QuinnConnection<sd_cloud_schema::Service>,
sd_cloud_schema::Service,
>,
msgs_tx: flume::Sender<Message>,
endpoint: Endpoint,
key_manager: Arc<KeyManager>,
ticketer: Arc<AtomicU64>,
notify_user_tx: flume::Sender<NotifyUser>,
sync_messages_receiver_notifiers_map:
Arc<DashMap<groups::PubId, Arc<ReceiveAndIngestNotifiers>>>,
pending_sync_group_join_requests: Arc<Mutex<HashMap<Ticket, PendingSyncGroupJoin>>>,
cached_devices_per_group: HashMap<groups::PubId, (Instant, Vec<(devices::PubId, NodeId)>)>,
timeout_checker_buffer: Vec<(Ticket, PendingSyncGroupJoin)>,
}
impl Clone for Runner {
fn clone(&self) -> Self {
Self {
current_device_pub_id: self.current_device_pub_id,
token_refresher: self.token_refresher.clone(),
cloud_services: self.cloud_services.clone(),
msgs_tx: self.msgs_tx.clone(),
endpoint: self.endpoint.clone(),
key_manager: Arc::clone(&self.key_manager),
ticketer: Arc::clone(&self.ticketer),
notify_user_tx: self.notify_user_tx.clone(),
sync_messages_receiver_notifiers_map: Arc::clone(
&self.sync_messages_receiver_notifiers_map,
),
pending_sync_group_join_requests: Arc::clone(&self.pending_sync_group_join_requests),
// Just cache the devices and their node_ids per group
cached_devices_per_group: HashMap::new(),
// This one is a temporary buffer only used for timeout checker
timeout_checker_buffer: vec![],
}
}
}
struct PendingSyncGroupJoin {
channel: RpcChannel<Service, QuinnServerEndpoint<Service>>,
request: authorize_new_device_in_sync_group::Request,
this_device: Device,
since: Instant,
}
impl Runner {
pub async fn new(
current_device_pub_id: devices::PubId,
cloud_services: &CloudServices,
msgs_tx: flume::Sender<Message>,
endpoint: Endpoint,
) -> Result<Self, Error> {
Ok(Self {
current_device_pub_id,
token_refresher: cloud_services.token_refresher.clone(),
cloud_services: cloud_services.client().await?,
msgs_tx,
endpoint,
key_manager: cloud_services.key_manager().await?,
ticketer: Arc::default(),
notify_user_tx: cloud_services.notify_user_tx.clone(),
sync_messages_receiver_notifiers_map: Arc::default(),
pending_sync_group_join_requests: Arc::default(),
cached_devices_per_group: HashMap::new(),
timeout_checker_buffer: vec![],
})
}
pub async fn run(
mut self,
msgs_rx: flume::Receiver<Message>,
user_response_rx: flume::Receiver<UserResponse>,
mut rng: CryptoRng,
) {
// Ignoring because this is only used internally, and boxing would likely be more expensive
// than wasting some extra bytes on the smaller variants
#[allow(clippy::large_enum_variant)]
enum StreamMessage {
AcceptResult(
Result<
Accepting<Service, QuinnServerEndpoint<Service>>,
RpcServerError<QuinnServerEndpoint<Service>>,
>,
),
Message(Message),
UserResponse(UserResponse),
Tick,
}
let mut ticker = interval(TEN_SECONDS);
ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
// FIXME(@fogodev): Update this function to use iroh-net transport instead of quinn
// when it's implemented
let (server, server_handle) = setup_server_endpoint(self.endpoint.clone());
let mut msg_stream = pin!((
async_stream::stream! {
loop {
yield StreamMessage::AcceptResult(server.accept().await);
}
},
msgs_rx.stream().map(StreamMessage::Message),
user_response_rx.stream().map(StreamMessage::UserResponse),
IntervalStream::new(ticker).map(|_| StreamMessage::Tick),
)
.merge());
while let Some(msg) = msg_stream.next().await {
match msg {
StreamMessage::AcceptResult(Ok(accepting)) => {
let Ok((request, channel)) = accepting.read_first().await.map_err(|e| {
error!(?e, "Failed to read first request from a new connection;");
}) else {
continue;
};
self.handle_request(request, channel).await;
}
StreamMessage::AcceptResult(Err(e)) => {
// TODO(@fogodev): Maybe report this error to the user on a toast?
error!(?e, "Error accepting connection;");
}
StreamMessage::Message(Message::Request(Request::JoinSyncGroup {
req,
devices_in_group,
tx,
})) => self.dispatch_join_requests(req, devices_in_group, &mut rng, tx),
StreamMessage::Message(Message::RegisterSyncMessageNotifier((
group_pub_id,
notifier,
))) => {
self.sync_messages_receiver_notifiers_map
.insert(group_pub_id, notifier);
}
StreamMessage::Message(Message::NotifyPeersSyncMessages(group_pub_id)) => {
spawn(dispatch_notifier(
group_pub_id,
self.current_device_pub_id,
self.cached_devices_per_group.get(&group_pub_id).cloned(),
self.msgs_tx.clone(),
self.cloud_services.clone(),
self.token_refresher.clone(),
self.endpoint.clone(),
));
}
StreamMessage::Message(Message::UpdateCachedDevices((
group_pub_id,
devices_connections_ids,
))) => {
self.cached_devices_per_group
.insert(group_pub_id, (Instant::now(), devices_connections_ids));
}
StreamMessage::UserResponse(UserResponse::AcceptDeviceInSyncGroup {
ticket,
accepted,
}) => {
self.handle_join_response(ticket, accepted).await;
}
StreamMessage::Tick => self.tick().await,
StreamMessage::Message(Message::Stop) => {
server_handle.abort();
break;
}
}
}
}
fn dispatch_join_requests(
&self,
req: authorize_new_device_in_sync_group::Request,
devices_in_group: Vec<(devices::PubId, NodeId)>,
rng: &mut CryptoRng,
tx: oneshot::Sender<JoinedLibraryCreateArgs>,
) {
async fn inner(
key_manager: Arc<KeyManager>,
endpoint: Endpoint,
mut rng: CryptoRng,
req: authorize_new_device_in_sync_group::Request,
devices_in_group: Vec<(devices::PubId, NodeId)>,
tx: oneshot::Sender<JoinedLibraryCreateArgs>,
) -> Result<JoinSyncGroupResponse, Error> {
let group_pub_id = req.sync_group.pub_id;
loop {
let client =
match connect_to_first_available_client(&endpoint, &devices_in_group).await {
Ok(client) => client,
Err(e) => {
return Ok(JoinSyncGroupResponse::Failed(e));
}
};
match client
.authorize_new_device_in_sync_group(req.clone())
.await?
{
Ok(authorize_new_device_in_sync_group::Response {
authorizor_device,
keys,
library_pub_id,
library_name,
library_description,
}) => {
debug!(
device_pub_id = %authorizor_device.pub_id,
%group_pub_id,
keys_count = keys.len(),
%library_pub_id,
library_name,
"Received join sync group response"
);
key_manager
.add_many_keys(
group_pub_id,
keys.into_iter().map(|key| {
key.as_slice()
.try_into()
.expect("critical error, backend has invalid secret keys")
}),
&mut rng,
)
.await?;
if tx
.send(JoinedLibraryCreateArgs {
pub_id: library_pub_id,
name: library_name,
description: library_description,
})
.is_err()
{
error!("Failed to handle library creation locally from received library data");
return Ok(JoinSyncGroupResponse::CriticalError);
}
return Ok(JoinSyncGroupResponse::Accepted { authorizor_device });
}
// In case of timeout, we will try again
Err(CloudP2PError::TimedOut) => continue,
Err(e) => return Ok(JoinSyncGroupResponse::Failed(e)),
}
}
}
spawn({
let endpoint = self.endpoint.clone();
let notify_user_tx = self.notify_user_tx.clone();
let key_manager = Arc::clone(&self.key_manager);
let rng = CryptoRng::from_seed(rng.generate_fixed());
async move {
let sync_group = req.sync_group.clone();
if let Err(SendError(response)) = notify_user_tx
.send_async(NotifyUser::ReceivedJoinSyncGroupResponse {
response: inner(key_manager, endpoint, rng, req, devices_in_group, tx)
.await
.unwrap_or_else(|e| {
error!(
?e,
"Failed to issue authorize new device in sync group request;"
);
JoinSyncGroupResponse::CriticalError
}),
sync_group,
})
.await
{
error!(?response, "Failed to send response to user;");
}
}
});
}
async fn handle_request(
&self,
request: cloud_p2p::Request,
channel: RpcChannel<Service, QuinnServerEndpoint<Service>>,
) {
match request {
cloud_p2p::Request::AuthorizeNewDeviceInSyncGroup(
authorize_new_device_in_sync_group::Request {
sync_group,
asking_device,
},
) => {
let ticket = Ticket(self.ticketer.fetch_add(1, Ordering::Relaxed));
let this_device = sync_group
.devices
.iter()
.find(|device| device.pub_id == self.current_device_pub_id)
.expect(
"current device must be in the sync group, otherwise we wouldn't be here",
)
.clone();
self.notify_user_tx
.send_async(NotifyUser::ReceivedJoinSyncGroupRequest {
ticket,
asking_device: asking_device.clone(),
sync_group: sync_group.clone(),
})
.await
.expect("notify_user_tx must never closes!");
self.pending_sync_group_join_requests.lock().await.insert(
ticket,
PendingSyncGroupJoin {
channel,
request: authorize_new_device_in_sync_group::Request {
sync_group,
asking_device,
},
this_device,
since: Instant::now(),
},
);
}
cloud_p2p::Request::NotifyNewSyncMessages(req) => {
if let Err(e) = channel
.rpc(
req,
(),
|(),
notify_new_sync_messages::Request {
sync_group_pub_id,
device_pub_id,
}| async move {
debug!(%sync_group_pub_id, %device_pub_id, "Received new sync messages notification");
if let Some(notifier) = self
.sync_messages_receiver_notifiers_map
.get(&sync_group_pub_id)
{
notifier.notify_receiver();
} else {
warn!("Received new sync messages notification for unknown sync group");
}
Ok(notify_new_sync_messages::Response)
},
)
.await
{
error!(
?e,
"Failed to reply to new sync messages notification request"
);
}
}
}
}
async fn handle_join_response(
&self,
ticket: Ticket,
accepted: Option<BasicLibraryCreationArgs>,
) {
let Some(PendingSyncGroupJoin {
channel,
request,
this_device,
..
}) = self
.pending_sync_group_join_requests
.lock()
.await
.remove(&ticket)
else {
warn!("Received join response for unknown ticket; We probably timed out this request already");
return;
};
let sync_group = request.sync_group.clone();
let asking_device_pub_id = request.asking_device.pub_id;
let was_accepted = accepted.is_some();
let response = if let Some(BasicLibraryCreationArgs {
id: library_pub_id,
name: library_name,
description: library_description,
}) = accepted
{
Ok(authorize_new_device_in_sync_group::Response {
authorizor_device: this_device,
keys: self
.key_manager
.get_group_keys(request.sync_group.pub_id)
.await
.into_iter()
.map(Into::into)
.collect(),
library_pub_id,
library_name,
library_description,
})
} else {
Err(CloudP2PError::Rejected)
};
if let Err(e) = channel
.rpc(request, (), |(), _req| async move { response })
.await
{
error!(?e, "Failed to send response to user;");
self.notify_join_error(sync_group, JoinSyncGroupError::Communication)
.await;
return;
}
if was_accepted {
let Ok(access_token) = self
.token_refresher
.get_access_token()
.await
.map_err(|e| error!(?e, "Failed to get access token;"))
else {
self.notify_join_error(sync_group, JoinSyncGroupError::Auth)
.await;
return;
};
match self
.cloud_services
.sync()
.groups()
.reply_join_request(groups::reply_join_request::Request {
access_token,
group_pub_id: sync_group.pub_id,
authorized_device_pub_id: asking_device_pub_id,
authorizor_device_pub_id: self.current_device_pub_id,
})
.await
{
Ok(Ok(groups::reply_join_request::Response)) => {
// Everything is Awesome!
}
Ok(Err(e)) => {
error!(?e, "Failed to reply to join request");
self.notify_join_error(sync_group, JoinSyncGroupError::InternalServer)
.await;
}
Err(e) => {
error!(?e, "Failed to send reply to join request");
self.notify_join_error(sync_group, JoinSyncGroupError::Communication)
.await;
}
}
}
}
async fn notify_join_error(
&self,
sync_group: groups::GroupWithDevices,
error: JoinSyncGroupError,
) {
self.notify_user_tx
.send_async(NotifyUser::SendingJoinSyncGroupResponseError { error, sync_group })
.await
.expect("notify_user_tx must never closes!");
}
async fn tick(&mut self) {
self.timeout_checker_buffer.clear();
let mut pending_sync_group_join_requests =
self.pending_sync_group_join_requests.lock().await;
for (ticket, pending_sync_group_join) in pending_sync_group_join_requests.drain() {
if pending_sync_group_join.since.elapsed() > FIVE_MINUTES {
let PendingSyncGroupJoin {
channel, request, ..
} = pending_sync_group_join;
let asking_device = request.asking_device.clone();
let notify_message = match channel
.rpc(request, (), |(), _req| async move {
Err(CloudP2PError::TimedOut)
})
.await
{
Ok(()) => NotifyUser::TimedOutJoinRequest {
device: asking_device,
succeeded: true,
},
Err(e) => {
error!(?e, "Failed to send timed out response to user;");
NotifyUser::TimedOutJoinRequest {
device: asking_device,
succeeded: false,
}
}
};
self.notify_user_tx
.send_async(notify_message)
.await
.expect("notify_user_tx must never closes!");
} else {
self.timeout_checker_buffer
.push((ticket, pending_sync_group_join));
}
}
pending_sync_group_join_requests.extend(self.timeout_checker_buffer.drain(..));
}
}
async fn connect_to_first_available_client(
endpoint: &Endpoint,
devices_in_group: &[(devices::PubId, NodeId)],
) -> Result<Client<QuinnConnection<Service>, Service>, CloudP2PError> {
for (device_pub_id, device_connection_id) in devices_in_group {
if let Ok(connection) = endpoint
.connect(*device_connection_id, CloudP2PALPN::LATEST)
.await
.map_err(
|e| error!(?e, %device_pub_id, "Failed to connect to authorizor device candidate"),
) {
debug!(%device_pub_id, "Connected to authorizor device candidate");
return Ok(Client::new(RpcClient::new(
QuinnConnection::<Service>::from_connection(connection),
)));
}
}
Err(CloudP2PError::UnableToConnect)
}
fn setup_server_endpoint(
endpoint: Endpoint,
) -> (
RpcServer<Service, QuinnServerEndpoint<Service>>,
JoinHandle<()>,
) {
let local_addr = {
let (ipv4_addr, maybe_ipv6_addr) = endpoint.bound_sockets();
// Trying to give preference to IPv6 addresses because it's 2024
maybe_ipv6_addr.unwrap_or(ipv4_addr)
};
let (connections_tx, connections_rx) = flume::bounded(16);
(
RpcServer::new(QuinnServerEndpoint::<Service>::handle_connections(
connections_rx,
local_addr,
)),
spawn(async move {
while let Some(connecting) = endpoint.accept().await {
if let Ok(connection) = connecting.await.map_err(|e| {
warn!(?e, "Cloud P2P failed to accept connection");
}) {
if connections_tx.send_async(connection).await.is_err() {
warn!("Connection receiver dropped");
break;
}
}
}
}),
)
}

View File

@@ -0,0 +1,121 @@
use crate::Error;
use sd_core_sync::SyncManager;
use sd_actors::{Actor, Stopper};
use std::{
future::IntoFuture,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use futures::FutureExt;
use futures_concurrency::future::Race;
use tokio::{
sync::Notify,
time::{sleep, Instant},
};
use tracing::{debug, error};
use super::{ReceiveAndIngestNotifiers, SyncActors, ONE_MINUTE};
/// Responsible for taking sync operations received from the cloud,
/// and applying them to the local database via the sync system's ingest actor.
pub struct Ingester {
sync: SyncManager,
notifiers: Arc<ReceiveAndIngestNotifiers>,
active: Arc<AtomicBool>,
active_notify: Arc<Notify>,
}
impl Actor<SyncActors> for Ingester {
const IDENTIFIER: SyncActors = SyncActors::Ingester;
async fn run(&mut self, stop: Stopper) {
enum Race {
Notified,
Stopped,
}
loop {
self.active.store(true, Ordering::Relaxed);
self.active_notify.notify_waiters();
if let Err(e) = self.run_loop_iteration().await {
error!(?e, "Error during cloud sync ingester actor iteration");
sleep(ONE_MINUTE).await;
continue;
}
self.active.store(false, Ordering::Relaxed);
self.active_notify.notify_waiters();
if matches!(
(
self.notifiers
.wait_notification_to_ingest()
.map(|()| Race::Notified),
stop.into_future().map(|()| Race::Stopped),
)
.race()
.await,
Race::Stopped
) {
break;
}
}
}
}
impl Ingester {
pub const fn new(
sync: SyncManager,
notifiers: Arc<ReceiveAndIngestNotifiers>,
active: Arc<AtomicBool>,
active_notify: Arc<Notify>,
) -> Self {
Self {
sync,
notifiers,
active,
active_notify,
}
}
async fn run_loop_iteration(&self) -> Result<(), Error> {
let start = Instant::now();
let operations_to_ingest_count = self
.sync
.db
.cloud_crdt_operation()
.count(vec![])
.exec()
.await
.map_err(sd_core_sync::Error::from)?;
if operations_to_ingest_count == 0 {
debug!("Nothing to ingest, early finishing ingester loop");
return Ok(());
}
debug!(
operations_to_ingest_count,
"Starting sync messages cloud ingestion loop"
);
let ingested_count = self.sync.ingest_ops().await?;
debug!(
ingested_count,
"Finished sync messages cloud ingestion loop in {:?}",
start.elapsed()
);
Ok(())
}
}

View File

@@ -0,0 +1,136 @@
use crate::{CloudServices, Error};
use sd_core_sync::SyncManager;
use sd_actors::{ActorsCollection, IntoActor};
use sd_cloud_schema::sync::groups;
use sd_crypto::CryptoRng;
use std::{
fmt,
path::Path,
sync::{atomic::AtomicBool, Arc},
time::Duration,
};
use futures_concurrency::future::TryJoin;
use tokio::sync::Notify;
mod ingest;
mod receive;
mod send;
use ingest::Ingester;
use receive::Receiver;
use send::Sender;
const ONE_MINUTE: Duration = Duration::from_secs(60);
#[derive(Default)]
pub struct SyncActorsState {
pub send_active: Arc<AtomicBool>,
pub receive_active: Arc<AtomicBool>,
pub ingest_active: Arc<AtomicBool>,
pub state_change_notifier: Arc<Notify>,
receiver_and_ingester_notifiers: Arc<ReceiveAndIngestNotifiers>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, specta::Type)]
#[specta(rename = "CloudSyncActors")]
pub enum SyncActors {
Ingester,
Sender,
Receiver,
}
impl fmt::Display for SyncActors {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Ingester => write!(f, "Cloud Sync Ingester"),
Self::Sender => write!(f, "Cloud Sync Sender"),
Self::Receiver => write!(f, "Cloud Sync Receiver"),
}
}
}
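/// Pair of wakers connecting the cloud sync actors: the `receiver` side is notified by Cloud
/// P2P when peers report new messages, and the `ingester` side is notified by the receiver
/// actor after each messages collection is persisted locally.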
#[derive(Debug, Default)]
pub struct ReceiveAndIngestNotifiers {
ingester: Notify,
receiver: Notify,
}
impl ReceiveAndIngestNotifiers {
pub fn notify_receiver(&self) {
self.receiver.notify_one();
}
async fn wait_notification_to_receive(&self) {
self.receiver.notified().await;
}
fn notify_ingester(&self) {
self.ingester.notify_one();
}
async fn wait_notification_to_ingest(&self) {
self.ingester.notified().await;
}
}
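/// Wires up the three cloud sync actors for a sync group: the sender pushes local CRDT
/// operations to the cloud, the receiver pulls and stores remote ones, and the ingester
/// applies them to the local database. Also registers the returned notifiers with Cloud P2P
/// so incoming "new sync messages" notifications wake the receiver.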
pub async fn declare_actors(
data_dir: Box<Path>,
cloud_services: Arc<CloudServices>,
actors: &ActorsCollection<SyncActors>,
actors_state: &SyncActorsState,
sync_group_pub_id: groups::PubId,
sync: SyncManager,
rng: CryptoRng,
) -> Result<Arc<ReceiveAndIngestNotifiers>, Error> {
let (sender, receiver) = (
Sender::new(
sync_group_pub_id,
sync.clone(),
Arc::clone(&cloud_services),
Arc::clone(&actors_state.send_active),
Arc::clone(&actors_state.state_change_notifier),
rng,
),
Receiver::new(
data_dir,
sync_group_pub_id,
cloud_services.clone(),
sync.clone(),
Arc::clone(&actors_state.receiver_and_ingester_notifiers),
Arc::clone(&actors_state.receive_active),
Arc::clone(&actors_state.state_change_notifier),
),
)
.try_join()
.await?;
let ingester = Ingester::new(
sync,
Arc::clone(&actors_state.receiver_and_ingester_notifiers),
Arc::clone(&actors_state.ingest_active),
Arc::clone(&actors_state.state_change_notifier),
);
actors
.declare_many_boxed([
sender.into_actor(),
receiver.into_actor(),
ingester.into_actor(),
])
.await;
cloud_services
.cloud_p2p()
.await?
.register_sync_messages_receiver_notifier(
sync_group_pub_id,
Arc::clone(&actors_state.receiver_and_ingester_notifiers),
)
.await;
Ok(Arc::clone(&actors_state.receiver_and_ingester_notifiers))
}

View File

@@ -0,0 +1,356 @@
use crate::{CloudServices, Error, KeyManager};
use sd_cloud_schema::{
devices,
sync::{
groups,
messages::{pull, MessagesCollection},
},
Client, Service,
};
use sd_core_sync::{
cloud_crdt_op_db, CRDTOperation, CompressedCRDTOperationsPerModel, SyncManager,
};
use sd_actors::{Actor, Stopper};
use sd_crypto::{
cloud::{OneShotDecryption, SecretKey, StreamDecryption},
primitives::{EncryptedBlock, StreamNonce},
};
use sd_prisma::prisma::PrismaClient;
use std::{
collections::{hash_map::Entry, HashMap},
future::IntoFuture,
path::Path,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use chrono::{DateTime, Utc};
use futures::{FutureExt, StreamExt};
use futures_concurrency::future::{Race, TryJoin};
use quic_rpc::transport::quinn::QuinnConnection;
use serde::{Deserialize, Serialize};
use tokio::{fs, io, sync::Notify, time::sleep};
use tracing::{debug, error, instrument, warn};
use uuid::Uuid;
use super::{ReceiveAndIngestNotifiers, SyncActors, ONE_MINUTE};
const CLOUD_SYNC_DATA_KEEPER_DIRECTORY: &str = "cloud_sync_data_keeper";
/// Responsible for downloading sync operations from the cloud to be processed by the ingester
pub struct Receiver {
keeper: LastTimestampKeeper,
sync_group_pub_id: groups::PubId,
device_pub_id: devices::PubId,
cloud_services: Arc<CloudServices>,
cloud_client: Client<QuinnConnection<Service>>,
key_manager: Arc<KeyManager>,
sync: SyncManager,
notifiers: Arc<ReceiveAndIngestNotifiers>,
active: Arc<AtomicBool>,
active_notifier: Arc<Notify>,
}
impl Actor<SyncActors> for Receiver {
const IDENTIFIER: SyncActors = SyncActors::Receiver;
async fn run(&mut self, stop: Stopper) {
enum Race {
Continue,
Stop,
}
loop {
self.active.store(true, Ordering::Relaxed);
self.active_notifier.notify_waiters();
let res = self.run_loop_iteration().await;
self.active.store(false, Ordering::Relaxed);
if let Err(e) = res {
error!(?e, "Error during cloud sync receiver actor iteration");
sleep(ONE_MINUTE).await;
continue;
}
self.active_notifier.notify_waiters();
if matches!(
(
sleep(ONE_MINUTE).map(|()| Race::Continue),
self.notifiers
.wait_notification_to_receive()
.map(|()| Race::Continue),
stop.into_future().map(|()| Race::Stop),
)
.race()
.await,
Race::Stop
) {
break;
}
}
}
}
impl Receiver {
pub async fn new(
data_dir: impl AsRef<Path> + Send,
sync_group_pub_id: groups::PubId,
cloud_services: Arc<CloudServices>,
sync: SyncManager,
notifiers: Arc<ReceiveAndIngestNotifiers>,
active: Arc<AtomicBool>,
active_notify: Arc<Notify>,
) -> Result<Self, Error> {
let (keeper, cloud_client, key_manager) = (
LastTimestampKeeper::load(data_dir.as_ref(), sync_group_pub_id),
cloud_services.client(),
cloud_services.key_manager(),
)
.try_join()
.await?;
Ok(Self {
keeper,
sync_group_pub_id,
device_pub_id: devices::PubId(Uuid::from(&sync.device_pub_id)),
cloud_services,
cloud_client,
key_manager,
sync,
notifiers,
active,
active_notifier: active_notify,
})
}
async fn run_loop_iteration(&mut self) -> Result<(), Error> {
let mut responses_stream = self
.cloud_client
.sync()
.messages()
.pull(pull::Request {
access_token: self
.cloud_services
.token_refresher
.get_access_token()
.await?,
group_pub_id: self.sync_group_pub_id,
current_device_pub_id: self.device_pub_id,
start_time_per_device: self
.keeper
.timestamps
.iter()
.map(|(device_pub_id, timestamp)| (*device_pub_id, *timestamp))
.collect(),
})
.await?;
while let Some(new_messages_res) = responses_stream.next().await {
let pull::Response(new_messages) = new_messages_res??;
if new_messages.is_empty() {
break;
}
self.handle_new_messages(new_messages).await?;
}
debug!("Finished sync messages receiver actor iteration");
self.keeper.save().await
}
async fn handle_new_messages(
&mut self,
new_messages: Vec<MessagesCollection>,
) -> Result<(), Error> {
debug!(
new_messages_collections_count = new_messages.len(),
start_time = ?new_messages.first().map(|c| c.start_time),
end_time = ?new_messages.first().map(|c| c.end_time),
"Handling new sync messages collections",
);
for message in new_messages.into_iter().filter(|message| {
if message.original_device_pub_id == self.device_pub_id {
warn!("Received sync message from the current device, need to check backend, this is a bug!");
false
} else {
true
}
}) {
debug!(
new_messages_count = message.operations_count,
start_time = ?message.start_time,
end_time = ?message.end_time,
"Handling new sync messages",
);
let (device_pub_id, timestamp) = handle_single_message(
self.sync_group_pub_id,
message,
&self.key_manager,
&self.sync,
)
.await?;
match self.keeper.timestamps.entry(device_pub_id) {
Entry::Occupied(mut entry) => {
if entry.get() < &timestamp {
*entry.get_mut() = timestamp;
}
}
Entry::Vacant(entry) => {
entry.insert(timestamp);
}
}
// We notify the ingester after each sync message collection is received; the messages
// MUST be downloaded and stored SEQUENTIALLY, otherwise we might ingest them out of
// order due to parallel downloads
self.notifiers.notify_ingester();
}
Ok(())
}
}
#[instrument(
skip_all,
fields(%sync_group_pub_id, %original_device_pub_id, operations_count, ?key_hash, %end_time),
)]
async fn handle_single_message(
sync_group_pub_id: groups::PubId,
MessagesCollection {
original_device_pub_id,
end_time,
operations_count,
key_hash,
encrypted_messages,
..
}: MessagesCollection,
key_manager: &KeyManager,
sync: &SyncManager,
) -> Result<(devices::PubId, DateTime<Utc>), Error> {
// FIXME(@fogodev): If we don't have the key hash, we need to fetch it from another device in the group if possible
let Some(secret_key) = key_manager.get_key(sync_group_pub_id, &key_hash).await else {
return Err(Error::MissingKeyHash);
};
debug!(
size = encrypted_messages.len(),
"Received encrypted sync messages collection"
);
let crdt_ops = decrypt_messages(encrypted_messages, secret_key, original_device_pub_id).await?;
assert_eq!(
crdt_ops.len(),
operations_count as usize,
"Sync messages count mismatch"
);
write_cloud_ops_to_db(crdt_ops, &sync.db).await?;
Ok((original_device_pub_id, end_time))
}
#[instrument(skip(encrypted_messages, secret_key), fields(messages_size = %encrypted_messages.len()), err)]
async fn decrypt_messages(
encrypted_messages: Vec<u8>,
secret_key: SecretKey,
devices::PubId(device_pub_id): devices::PubId,
) -> Result<Vec<CRDTOperation>, Error> {
let plain_text = if encrypted_messages.len() <= EncryptedBlock::CIPHER_TEXT_SIZE {
OneShotDecryption::decrypt(&secret_key, encrypted_messages.as_slice().into())
.map_err(Error::Decrypt)?
} else {
let (nonce, cipher_text) = encrypted_messages.split_at(size_of::<StreamNonce>());
let mut plain_text = Vec::with_capacity(cipher_text.len());
StreamDecryption::decrypt(
&secret_key,
nonce.try_into().expect("we split the correct amount"),
cipher_text,
&mut plain_text,
)
.await
.map_err(Error::Decrypt)?;
plain_text
};
rmp_serde::from_slice::<CompressedCRDTOperationsPerModel>(&plain_text)
.map(|compressed_ops| compressed_ops.into_ops(device_pub_id))
.map_err(Error::DeserializationFailureToPullSyncMessages)
}
#[instrument(skip_all, err)]
pub async fn write_cloud_ops_to_db(
ops: Vec<CRDTOperation>,
db: &PrismaClient,
) -> Result<(), sd_core_sync::Error> {
db._batch(
ops.into_iter()
.map(|op| cloud_crdt_op_db(&op).map(|op| op.to_query(db)))
.collect::<Result<Vec<_>, _>>()?,
)
.await?;
Ok(())
}
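/// Persists, per device in the sync group, the timestamp of the latest sync message already
/// pulled from the cloud, so the next pull can resume from where the previous one stopped.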
#[derive(Serialize, Deserialize, Debug)]
struct LastTimestampKeeper {
timestamps: HashMap<devices::PubId, DateTime<Utc>>,
file_path: Box<Path>,
}
impl LastTimestampKeeper {
async fn load(data_dir: &Path, sync_group_pub_id: groups::PubId) -> Result<Self, Error> {
let cloud_sync_data_directory = data_dir.join(CLOUD_SYNC_DATA_KEEPER_DIRECTORY);
fs::create_dir_all(&cloud_sync_data_directory)
.await
.map_err(Error::FailedToCreateTimestampKeepersDirectory)?;
let file_path = cloud_sync_data_directory
.join(format!("{sync_group_pub_id}.bin"))
.into_boxed_path();
match fs::read(&file_path).await {
Ok(bytes) => Ok(Self {
timestamps: rmp_serde::from_slice(&bytes)
.map_err(Error::LastTimestampKeeperDeserialization)?,
file_path,
}),
Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(Self {
timestamps: HashMap::new(),
file_path,
}),
Err(e) => Err(Error::FailedToReadLastTimestampKeeper(e)),
}
}
async fn save(&self) -> Result<(), Error> {
fs::write(
&self.file_path,
&rmp_serde::to_vec_named(&self.timestamps)
.map_err(Error::LastTimestampKeeperSerialization)?,
)
.await
.map_err(Error::FailedToWriteLastTimestampKeeper)
}
}

View File

@@ -0,0 +1,337 @@
use crate::{CloudServices, Error, KeyManager};
use sd_core_sync::{CompressedCRDTOperationsPerModelPerDevice, SyncEvent, SyncManager, NTP64};
use sd_actors::{Actor, Stopper};
use sd_cloud_schema::{
devices,
error::{ClientSideError, NotFoundError},
sync::{groups, messages},
Client, Service,
};
use sd_crypto::{
cloud::{OneShotEncryption, SecretKey, StreamEncryption},
primitives::EncryptedBlock,
CryptoRng, SeedableRng,
};
use sd_utils::{datetime_to_timestamp, timestamp_to_datetime};
use std::{
future::IntoFuture,
pin::pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, UNIX_EPOCH},
};
use chrono::{DateTime, Utc};
use futures::{FutureExt, StreamExt, TryStreamExt};
use futures_concurrency::future::{Race, TryJoin};
use quic_rpc::transport::quinn::QuinnConnection;
use tokio::{
sync::{broadcast, Notify},
time::sleep,
};
use tracing::{debug, error};
use uuid::Uuid;
use super::{SyncActors, ONE_MINUTE};
const TEN_SECONDS: Duration = Duration::from_secs(10);
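// Maximum number of CRDT operations bundled into a single encrypted messages collection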
const MESSAGES_COLLECTION_SIZE: u32 = 10_000;
enum RaceNotifiedOrStopped {
Notified,
Stopped,
}
enum LoopStatus {
SentMessages,
Idle,
}
type LatestTimestamp = NTP64;
#[derive(Debug)]
pub struct Sender {
sync_group_pub_id: groups::PubId,
sync: SyncManager,
cloud_services: Arc<CloudServices>,
cloud_client: Client<QuinnConnection<Service>>,
key_manager: Arc<KeyManager>,
is_active: Arc<AtomicBool>,
state_notify: Arc<Notify>,
rng: CryptoRng,
maybe_latest_timestamp: Option<LatestTimestamp>,
}
impl Actor<SyncActors> for Sender {
const IDENTIFIER: SyncActors = SyncActors::Sender;
async fn run(&mut self, stop: Stopper) {
loop {
self.is_active.store(true, Ordering::Relaxed);
self.state_notify.notify_waiters();
let res = self.run_loop_iteration().await;
self.is_active.store(false, Ordering::Relaxed);
match res {
Ok(LoopStatus::SentMessages) => {
if let Ok(cloud_p2p) = self.cloud_services.cloud_p2p().await.map_err(|e| {
error!(?e, "Failed to get cloud p2p client on sender actor");
}) {
cloud_p2p
.notify_new_sync_messages(self.sync_group_pub_id)
.await;
}
}
Ok(LoopStatus::Idle) => {}
Err(e) => {
error!(?e, "Error during cloud sync sender actor iteration");
sleep(ONE_MINUTE).await;
continue;
}
}
self.state_notify.notify_waiters();
if matches!(
(
// recreate subscription each time so that existing messages are dropped
wait_notification(self.sync.subscribe()),
stop.into_future().map(|()| RaceNotifiedOrStopped::Stopped),
)
.race()
.await,
RaceNotifiedOrStopped::Stopped
) {
break;
}
sleep(TEN_SECONDS).await;
}
}
}
impl Sender {
pub async fn new(
sync_group_pub_id: groups::PubId,
sync: SyncManager,
cloud_services: Arc<CloudServices>,
is_active: Arc<AtomicBool>,
state_notify: Arc<Notify>,
rng: CryptoRng,
) -> Result<Self, Error> {
let (cloud_client, key_manager) = (cloud_services.client(), cloud_services.key_manager())
.try_join()
.await?;
Ok(Self {
sync_group_pub_id,
sync,
cloud_services,
cloud_client,
key_manager,
is_active,
state_notify,
rng,
maybe_latest_timestamp: None,
})
}
async fn run_loop_iteration(&mut self) -> Result<LoopStatus, Error> {
debug!("Starting cloud sender actor loop iteration");
let current_device_pub_id = devices::PubId(Uuid::from(&self.sync.device_pub_id));
let (key_hash, secret_key) = self
.key_manager
.get_latest_key(self.sync_group_pub_id)
.await
.ok_or(Error::MissingSyncGroupKey(self.sync_group_pub_id))?;
let current_latest_timestamp = self.get_latest_timestamp(current_device_pub_id).await?;
let mut crdt_ops_stream = pin!(self.sync.stream_device_ops(
&self.sync.device_pub_id,
MESSAGES_COLLECTION_SIZE,
current_latest_timestamp
));
let mut status = LoopStatus::Idle;
let mut new_latest_timestamp = current_latest_timestamp;
debug!(
chunk_size = MESSAGES_COLLECTION_SIZE,
"Trying to fetch chunk of sync messages from the database"
);
while let Some(ops_res) = crdt_ops_stream.next().await {
let ops = ops_res?;
let (Some(first), Some(last)) = (ops.first(), ops.last()) else {
break;
};
debug!("Got first and last sync messages");
#[allow(clippy::cast_possible_truncation)]
let operations_count = ops.len() as u32;
debug!(operations_count, "Got chunk of sync messages");
new_latest_timestamp = last.timestamp;
let start_time = timestamp_to_datetime(first.timestamp);
let end_time = timestamp_to_datetime(last.timestamp);
// Ignoring this device_pub_id here as we already know it
let (_device_pub_id, compressed_ops) =
CompressedCRDTOperationsPerModelPerDevice::new_single_device(ops);
let messages_bytes = rmp_serde::to_vec_named(&compressed_ops)
.map_err(Error::SerializationFailureToPushSyncMessages)?;
let encrypted_messages =
encrypt_messages(&secret_key, &mut self.rng, messages_bytes).await?;
let encrypted_messages_size = encrypted_messages.len();
debug!(
operations_count,
encrypted_messages_size, "Sending sync messages to cloud",
);
self.cloud_client
.sync()
.messages()
.push(messages::push::Request {
access_token: self
.cloud_services
.token_refresher
.get_access_token()
.await?,
group_pub_id: self.sync_group_pub_id,
device_pub_id: current_device_pub_id,
key_hash: key_hash.clone(),
operations_count,
time_range: (start_time, end_time),
encrypted_messages,
})
.await??;
debug!(
operations_count,
encrypted_messages_size, "Sent sync messages to cloud",
);
status = LoopStatus::SentMessages;
}
self.maybe_latest_timestamp = Some(new_latest_timestamp);
debug!("Finished cloud sender actor loop iteration");
Ok(status)
}
async fn get_latest_timestamp(
&self,
current_device_pub_id: devices::PubId,
) -> Result<LatestTimestamp, Error> {
if let Some(latest_timestamp) = &self.maybe_latest_timestamp {
Ok(*latest_timestamp)
} else {
let latest_time = match self
.cloud_client
.sync()
.messages()
.get_latest_time(messages::get_latest_time::Request {
access_token: self
.cloud_services
.token_refresher
.get_access_token()
.await?,
group_pub_id: self.sync_group_pub_id,
kind: messages::get_latest_time::Kind::ForCurrentDevice(current_device_pub_id),
})
.await?
{
Ok(messages::get_latest_time::Response {
latest_time,
latest_device_pub_id,
}) => {
assert_eq!(latest_device_pub_id, current_device_pub_id);
latest_time
}
Err(sd_cloud_schema::Error::Client(ClientSideError::NotFound(
NotFoundError::LatestSyncMessageTime,
))) => DateTime::<Utc>::from(UNIX_EPOCH),
Err(e) => return Err(e.into()),
};
Ok(datetime_to_timestamp(latest_time))
}
}
}
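/// Encrypts a serialized batch of sync messages, choosing one-shot encryption for payloads
/// that fit in a single [`EncryptedBlock`] and stream encryption otherwise. Either way the
/// output layout is `nonce || cipher_text`, which is what the receiver side splits apart
/// before decrypting.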
async fn encrypt_messages(
secret_key: &SecretKey,
rng: &mut CryptoRng,
messages_bytes: Vec<u8>,
) -> Result<Vec<u8>, Error> {
if messages_bytes.len() <= EncryptedBlock::PLAIN_TEXT_SIZE {
let mut nonce_and_cipher_text = Vec::with_capacity(OneShotEncryption::cipher_text_size(
secret_key,
messages_bytes.len(),
));
let EncryptedBlock { nonce, cipher_text } =
OneShotEncryption::encrypt(secret_key, messages_bytes.as_slice(), rng)
.map_err(Error::Encrypt)?;
nonce_and_cipher_text.extend_from_slice(nonce.as_slice());
nonce_and_cipher_text.extend(&cipher_text);
Ok(nonce_and_cipher_text)
} else {
let mut rng = CryptoRng::from_seed(rng.generate_fixed());
let mut nonce_and_cipher_text = Vec::with_capacity(StreamEncryption::cipher_text_size(
secret_key,
messages_bytes.len(),
));
let (nonce, cipher_stream) =
StreamEncryption::encrypt(secret_key, messages_bytes.as_slice(), &mut rng);
nonce_and_cipher_text.extend_from_slice(nonce.as_slice());
let mut cipher_stream = pin!(cipher_stream);
while let Some(ciphered_chunk) = cipher_stream.try_next().await.map_err(Error::Encrypt)? {
nonce_and_cipher_text.extend(ciphered_chunk);
}
Ok(nonce_and_cipher_text)
}
}
async fn wait_notification(mut rx: broadcast::Receiver<SyncEvent>) -> RaceNotifiedOrStopped {
// wait until Created message comes in
loop {
if matches!(rx.recv().await, Ok(SyncEvent::Created)) {
break;
};
}
RaceNotifiedOrStopped::Notified
}

View File

@@ -0,0 +1,468 @@
use sd_cloud_schema::auth::{AccessToken, RefreshToken};
use std::{pin::pin, time::Duration};
use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD};
use chrono::{DateTime, Utc};
use futures::StreamExt;
use futures_concurrency::stream::Merge;
use reqwest::Url;
use reqwest_middleware::{reqwest::header, ClientWithMiddleware};
use tokio::{
spawn,
sync::oneshot,
time::{interval, sleep, MissedTickBehavior},
};
use tokio_stream::wrappers::IntervalStream;
use tracing::{error, warn};
use super::{Error, GetTokenError};
const ONE_MINUTE: Duration = Duration::from_secs(60);
const TEN_SECONDS: Duration = Duration::from_secs(10);
enum Message {
Init(
(
AccessToken,
RefreshToken,
oneshot::Sender<Result<(), Error>>,
),
),
CheckInitialization(oneshot::Sender<Result<(), GetTokenError>>),
RequestToken(oneshot::Sender<Result<AccessToken, GetTokenError>>),
RefreshTime,
Tick,
}
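/// Handle to the background task that keeps the SuperTokens access token fresh, refreshing
/// it roughly one minute before it expires.
///
/// A minimal sketch, assuming an HTTP client and auth server URL are already available:
///
/// ```ignore
/// let refresher = TokenRefresher::new(http_client, auth_server_url);
/// refresher.init(access_token, refresh_token).await?;
/// let token = refresher.get_access_token().await?;
/// ```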
#[derive(Debug, Clone)]
pub struct TokenRefresher {
tx: flume::Sender<Message>,
}
impl TokenRefresher {
pub(crate) fn new(http_client: ClientWithMiddleware, auth_server_url: Url) -> Self {
let (tx, rx) = flume::bounded(8);
spawn(async move {
let refresh_url = auth_server_url
.join("/api/auth/session/refresh")
.expect("hardcoded refresh url path");
while let Err(e) = spawn(Runner::run(
http_client.clone(),
refresh_url.clone(),
rx.clone(),
))
.await
{
if e.is_panic() {
if let Some(msg) = e.into_panic().downcast_ref::<&str>() {
error!(?msg, "Panic in request handler!");
} else {
error!("Some unknown panic in request handler!");
}
}
}
});
Self { tx }
}
pub async fn init(
&self,
access_token: AccessToken,
refresh_token: RefreshToken,
) -> Result<(), Error> {
let (tx, rx) = oneshot::channel();
self.tx
.send_async(Message::Init((access_token, refresh_token, tx)))
.await
.expect("Token refresher channel closed");
rx.await.expect("Token refresher channel closed")
}
pub async fn check_initialization(&self) -> Result<(), GetTokenError> {
let (tx, rx) = oneshot::channel();
self.tx
.send_async(Message::CheckInitialization(tx))
.await
.expect("Token refresher channel closed");
rx.await.expect("Token refresher channel closed")
}
pub async fn get_access_token(&self) -> Result<AccessToken, GetTokenError> {
let (tx, rx) = oneshot::channel();
self.tx
.send_async(Message::RequestToken(tx))
.await
.expect("Token refresher channel closed");
rx.await.expect("Token refresher channel closed")
}
}
struct Runner {
initialized: bool,
http_client: ClientWithMiddleware,
refresh_url: Url,
current_token: Option<AccessToken>,
current_refresh_token: Option<RefreshToken>,
token_decoding_buffer: Vec<u8>,
refresh_tx: flume::Sender<Message>,
}
impl Runner {
async fn run(
http_client: ClientWithMiddleware,
refresh_url: Url,
msgs_rx: flume::Receiver<Message>,
) {
let (refresh_tx, refresh_rx) = flume::bounded(1);
let mut ticker = interval(TEN_SECONDS);
ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
let mut msg_stream = pin!((
msgs_rx.into_stream(),
refresh_rx.into_stream(),
IntervalStream::new(ticker).map(|_| Message::Tick)
)
.merge());
let mut runner = Self {
initialized: false,
http_client,
refresh_url,
current_token: None,
current_refresh_token: None,
token_decoding_buffer: Vec::new(),
refresh_tx,
};
while let Some(msg) = msg_stream.next().await {
match msg {
Message::Init((access_token, refresh_token, ack)) => {
if ack
.send(runner.init(access_token, refresh_token).await)
.is_err()
{
error!("Failed to send init token refresher response, receiver dropped;");
}
}
Message::CheckInitialization(ack) => runner.check_initialization(ack),
Message::RequestToken(ack) => runner.reply_token(ack),
Message::RefreshTime => {
if let Err(e) = runner.refresh().await {
error!(?e, "Failed to refresh token: {e}");
}
}
Message::Tick => runner.tick().await,
}
}
}
async fn init(
&mut self,
access_token: AccessToken,
refresh_token: RefreshToken,
) -> Result<(), Error> {
let access_token_duration =
Self::extract_access_token_duration(&mut self.token_decoding_buffer, &access_token)?;
self.initialized = true;
self.current_token = Some(access_token);
self.current_refresh_token = Some(refresh_token);
// If the token has an expiration smaller than a minute, we need to refresh it immediately.
if access_token_duration < ONE_MINUTE {
self.refresh_tx
.send_async(Message::RefreshTime)
.await
.expect("refresh channel never closes");
} else {
// This task will be mostly parked, waiting on a sleep
spawn(Self::schedule_refresh(
self.refresh_tx.clone(),
access_token_duration - ONE_MINUTE,
));
}
Ok(())
}
fn reply_token(&self, ack: oneshot::Sender<Result<AccessToken, GetTokenError>>) {
if ack
.send(self.current_token.clone().ok_or({
if self.initialized {
GetTokenError::FailedToRefresh
} else {
GetTokenError::RefresherNotInitialized
}
}))
.is_err()
{
warn!("Failed to send access token response, receiver dropped;");
}
}
async fn refresh(&mut self) -> Result<(), Error> {
let RefreshToken(refresh_token) = self
.current_refresh_token
.clone()
.expect("refresh token is set otherwise we wouldn't be here");
let response = self
.http_client
.post(self.refresh_url.clone())
.header("rid", "session")
.header(header::AUTHORIZATION, format!("Bearer {refresh_token}"))
.send()
.await
.map_err(Error::RefreshTokenRequest)?
.error_for_status()
.map_err(Error::AuthServerError)?;
if let (Some(access_token), Some(refresh_token)) = (
response.headers().get("st-access-token"),
response.headers().get("st-refresh-token"),
) {
// Only set values if we can parse both of them to strings
let (access_token, refresh_token) = (
Self::token_header_value_to_string(access_token)?,
Self::token_header_value_to_string(refresh_token)?,
);
self.current_token = Some(AccessToken(access_token));
self.current_refresh_token = Some(RefreshToken(refresh_token));
} else {
return Err(Error::MissingTokensOnRefreshResponse);
}
Ok(())
}
fn extract_access_token_duration(
token_decoding_buffer: &mut Vec<u8>,
AccessToken(token): &AccessToken,
) -> Result<Duration, Error> {
#[derive(serde::Deserialize)]
struct Token {
#[serde(with = "chrono::serde::ts_seconds")]
exp: DateTime<Utc>,
}
token_decoding_buffer.clear();
// The format of a JWT token is simple:
// "<base64-encoded header>.<base64-encoded claims>.<signature>"
BASE64_URL_SAFE_NO_PAD.decode_vec(
token.split('.').nth(1).ok_or(Error::MissingClaims)?,
token_decoding_buffer,
)?;
serde_json::from_slice::<Token>(token_decoding_buffer)?
.exp
.signed_duration_since(Utc::now())
.to_std()
.map_err(|_| Error::TokenExpired)
}
async fn schedule_refresh(refresh_tx: flume::Sender<Message>, wait_time: Duration) {
sleep(wait_time).await;
refresh_tx
.send_async(Message::RefreshTime)
.await
.expect("Refresh channel closed");
}
fn token_header_value_to_string(token: &header::HeaderValue) -> Result<String, Error> {
token.to_str().map(str::to_string).map_err(Into::into)
}
fn check_initialization(&self, ack: oneshot::Sender<Result<(), GetTokenError>>) {
if ack
.send(if self.initialized {
Ok(())
} else {
Err(GetTokenError::RefresherNotInitialized)
})
.is_err()
{
warn!("Failed to send access token response, receiver dropped;");
}
}
/// This method is a safeguard to make sure we keep trying to refresh tokens even if they
/// have already expired, as the refresh token has a longer expiration than the access token.
async fn tick(&mut self) {
if let Some(access_token) = &self.current_token {
if matches!(
Self::extract_access_token_duration(&mut self.token_decoding_buffer, access_token),
Err(Error::TokenExpired)
) {
if let Err(e) = self.refresh().await {
error!(?e, "Failed to refresh expired token on tick method;");
}
}
}
}
}
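For orientation, here is a minimal sketch of how a consumer might ask the runner for the current token, assuming tokio's oneshot and the Message/flume wiring above; msgs_tx is a hypothetical handle to the runner's inbox, not an identifier from this PR.

// Illustrative only; `msgs_tx` is a hypothetical sender for the runner's inbox.
async fn request_access_token(
    msgs_tx: &flume::Sender<Message>,
) -> Result<AccessToken, GetTokenError> {
    let (tx, rx) = oneshot::channel();
    msgs_tx
        .send_async(Message::RequestToken(tx))
        .await
        .expect("token refresher channel closed");
    rx.await.expect("token refresher dropped the response sender")
}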
/// These tests are here for documentation purposes only; they are not meant to be run.
/// They're just examples of how to sign up/sign in and refresh tokens.
#[cfg(test)]
mod tests {
use reqwest::header;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde_json::json;
use crate::AUTH_SERVER_URL;
use super::*;
async fn get_tokens() -> (AccessToken, RefreshToken) {
let client = reqwest::Client::new();
let req_body = json!({
"formFields": [
{
"id": "email",
"value": "johndoe@gmail.com"
},
{
"id": "password",
"value": "testPass123"
}
]
});
let response = client
.post(format!("{AUTH_SERVER_URL}/api/auth/public/signup"))
.header("rid", "emailpassword")
.header("st-auth-mode", "header")
.json(&req_body)
.send()
.await
.unwrap();
assert_eq!(response.status(), 200);
if let (Some(access_token), Some(refresh_token)) = (
response.headers().get("st-access-token"),
response.headers().get("st-refresh-token"),
) {
(
AccessToken(access_token.to_str().unwrap().to_string()),
RefreshToken(refresh_token.to_str().unwrap().to_string()),
)
} else {
let response = client
.post(format!("{AUTH_SERVER_URL}/api/auth/public/signin"))
.header("rid", "emailpassword")
.header("st-auth-mode", "header")
.json(&req_body)
.send()
.await
.unwrap();
assert_eq!(response.status(), 200);
(
AccessToken(
response
.headers()
.get("st-access-token")
.unwrap()
.to_str()
.unwrap()
.to_string(),
),
RefreshToken(
response
.headers()
.get("st-refresh-token")
.unwrap()
.to_str()
.unwrap()
.to_string(),
),
)
}
}
#[ignore = "Documentation only"]
#[tokio::test]
async fn test_refresh_token() {
let (AccessToken(access_token), RefreshToken(refresh_token)) = get_tokens().await;
let client = reqwest::Client::new();
let response = client
.post(format!("{AUTH_SERVER_URL}/api/auth/session/refresh"))
.header("rid", "session")
.header(header::AUTHORIZATION, format!("Bearer {refresh_token}"))
.send()
.await
.unwrap();
assert_eq!(response.status(), 200);
assert_ne!(
response
.headers()
.get("st-access-token")
.unwrap()
.to_str()
.unwrap(),
access_token.as_str()
);
assert_ne!(
response
.headers()
.get("st-refresh-token")
.unwrap()
.to_str()
.unwrap(),
refresh_token.as_str()
);
}
#[ignore = "Needs an actual SuperTokens auth server running"]
#[tokio::test]
async fn test_refresher_runner() {
let http_client_builder = reqwest::Client::builder().timeout(Duration::from_secs(3));
let http_client = ClientBuilder::new(http_client_builder.build().unwrap())
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder().build_with_max_retries(3),
))
.build();
let (refresh_tx, _refresh_rx) = flume::bounded(1);
let mut runner = Runner {
initialized: false,
http_client,
refresh_url: Url::parse(&format!("{AUTH_SERVER_URL}/api/auth/session/refresh"))
.unwrap(),
current_token: None,
current_refresh_token: None,
token_decoding_buffer: Vec::new(),
refresh_tx,
};
let (access_token, refresh_token) = get_tokens().await;
runner.init(access_token, refresh_token).await.unwrap();
runner.refresh().await.unwrap();
}
}

View File

@@ -2,7 +2,7 @@ use sd_core_prisma_helpers::{
file_path_for_file_identifier, file_path_for_media_processor, file_path_for_object_validator,
file_path_to_full_path, file_path_to_handle_custom_uri, file_path_to_handle_p2p_serve_file,
file_path_to_isolate, file_path_to_isolate_with_id, file_path_to_isolate_with_pub_id,
file_path_walker, file_path_with_object,
file_path_walker, file_path_watcher_remove, file_path_with_object,
};
use sd_prisma::prisma::{file_path, location};
@@ -506,7 +506,8 @@ impl_from_db!(
file_path_to_isolate_with_pub_id,
file_path_walker,
file_path_to_isolate_with_id,
file_path_with_object
file_path_with_object,
file_path_watcher_remove
);
impl_from_db_without_location_id!(

View File

@@ -14,7 +14,11 @@ use crate::{
use sd_core_file_path_helper::IsolatedFilePathData;
use sd_core_prisma_helpers::{file_path_for_file_identifier, CasId};
use sd_prisma::prisma::{file_path, location, SortOrder};
use sd_prisma::{
prisma::{device, file_path, location, SortOrder},
prisma_sync,
};
use sd_sync::{sync_db_not_null_entry, OperationFactory};
use sd_task_system::{
AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId,
TaskOutput, TaskStatus,
@@ -128,14 +132,14 @@ impl Job for FileIdentifier {
match task_kind {
TaskKind::Identifier => tasks::Identifier::deserialize(
&task_bytes,
(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
(Arc::clone(ctx.db()), ctx.sync().clone()),
)
.await
.map(IntoTask::into_task),
TaskKind::ObjectProcessor => tasks::ObjectProcessor::deserialize(
&task_bytes,
(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
(Arc::clone(ctx.db()), ctx.sync().clone()),
)
.await
.map(IntoTask::into_task),
@@ -173,8 +177,21 @@ impl Job for FileIdentifier {
) -> Result<ReturnStatus, Error> {
let mut pending_running_tasks = FuturesUnordered::new();
let device_pub_id = &ctx.sync().device_pub_id;
let device_id = ctx
.db()
.device()
.find_unique(device::pub_id::equals(device_pub_id.to_db()))
.exec()
.await
.map_err(file_identifier::Error::from)?
.ok_or(file_identifier::Error::DeviceNotFound(
device_pub_id.clone(),
))?
.id;
match self
.init_or_resume(&mut pending_running_tasks, &ctx, &dispatcher)
.init_or_resume(&mut pending_running_tasks, &ctx, device_id, &dispatcher)
.await
{
Ok(()) => { /* Everything is awesome! */ }
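The device-id lookup just above (resolving the local device row from sync().device_pub_id) recurs in several jobs touched by this PR, including the indexer and the shallow variants. As a reading aid only, the pattern amounts to a helper along these lines; local_device_id is hypothetical and not part of the PR.

// Hypothetical helper, not in the PR: mirrors the inline device lookups used
// by the file identifier and indexer jobs in this changeset.
async fn local_device_id(
    db: &PrismaClient,
    device_pub_id: &DevicePubId,
) -> Result<device::id::Type, file_identifier::Error> {
    db.device()
        .find_unique(device::pub_id::equals(device_pub_id.to_db()))
        .exec()
        .await
        .map_err(file_identifier::Error::from)?
        .map(|device| device.id)
        .ok_or_else(|| file_identifier::Error::DeviceNotFound(device_pub_id.clone()))
}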
@@ -201,7 +218,7 @@ impl Job for FileIdentifier {
match task {
Ok(TaskStatus::Done((task_id, TaskOutput::Out(out)))) => {
match self
.process_task_output(task_id, out, &ctx, &dispatcher)
.process_task_output(task_id, out, &ctx, device_id, &dispatcher)
.await
{
Ok(tasks) => pending_running_tasks.extend(tasks),
@@ -254,15 +271,25 @@ impl Job for FileIdentifier {
..
} = self;
ctx.db()
.location()
.update(
location::id::equals(location.id),
vec![location::scan_state::set(
LocationScanState::FilesIdentified as i32,
)],
let (sync_param, db_param) = sync_db_not_null_entry!(
LocationScanState::FilesIdentified as i32,
location::scan_state
);
ctx.sync()
.write_op(
ctx.db(),
ctx.sync().shared_update(
prisma_sync::location::SyncId {
pub_id: location.pub_id.clone(),
},
[sync_param],
),
ctx.db()
.location()
.update(location::id::equals(location.id), vec![db_param])
.select(location::select!({ id })),
)
.exec()
.await
.map_err(file_identifier::Error::from)?;
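For readers new to the sync layer: sync_db_not_null_entry! evaluates to a pair (a CRDT sync entry plus the matching Prisma set param), so a single value feeds both the shared_update op and the database update committed together by write_op. Below is a hand-expanded approximation for intuition only; the real macro in sd-sync also handles value serialization.

// Rough, hand-expanded approximation of sync_db_not_null_entry!; not the
// actual macro output.
let value = LocationScanState::FilesIdentified as i32;
let sync_param = sync_entry!(value, location::scan_state); // CRDT side
let db_param = location::scan_state::set(value); // database side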
@@ -302,6 +329,7 @@ impl FileIdentifier {
&mut self,
pending_running_tasks: &mut FuturesUnordered<TaskHandle<Error>>,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
dispatcher: &JobTaskDispatcher,
) -> Result<(), JobErrorOrDispatcherError<file_identifier::Error>> {
// if we don't have any pending task, then this is a fresh job
@@ -335,6 +363,7 @@ impl FileIdentifier {
.as_ref()
.unwrap_or(&location_root_iso_file_path),
ctx,
device_id,
dispatcher,
pending_running_tasks,
)
@@ -345,8 +374,9 @@ impl FileIdentifier {
self.last_orphan_file_path_id = None;
self.dispatch_deep_identifier_tasks(
&maybe_sub_iso_file_path,
maybe_sub_iso_file_path.as_ref(),
ctx,
device_id,
dispatcher,
pending_running_tasks,
)
@@ -378,6 +408,7 @@ impl FileIdentifier {
.as_ref()
.unwrap_or(&location_root_iso_file_path),
ctx,
device_id,
dispatcher,
pending_running_tasks,
)
@@ -388,8 +419,9 @@ impl FileIdentifier {
self.last_orphan_file_path_id = None;
self.dispatch_deep_identifier_tasks(
&maybe_sub_iso_file_path,
maybe_sub_iso_file_path.as_ref(),
ctx,
device_id,
dispatcher,
pending_running_tasks,
)
@@ -401,8 +433,9 @@ impl FileIdentifier {
Phase::SearchingOrphans => {
self.dispatch_deep_identifier_tasks(
&maybe_sub_iso_file_path,
maybe_sub_iso_file_path.as_ref(),
ctx,
device_id,
dispatcher,
pending_running_tasks,
)
@@ -447,6 +480,7 @@ impl FileIdentifier {
task_id: TaskId,
any_task_output: Box<dyn AnyTaskOutput>,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
dispatcher: &JobTaskDispatcher,
) -> Result<Vec<TaskHandle<Error>>, DispatcherError> {
if any_task_output.is::<identifier::Output>() {
@@ -457,6 +491,7 @@ impl FileIdentifier {
.downcast::<identifier::Output>()
.expect("just checked"),
ctx,
device_id,
dispatcher,
)
.await;
@@ -501,6 +536,7 @@ impl FileIdentifier {
errors,
}: identifier::Output,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
dispatcher: &JobTaskDispatcher,
) -> Result<Vec<TaskHandle<Error>>, DispatcherError> {
self.metadata.mean_extract_metadata_time += extract_metadata_time;
@@ -548,6 +584,7 @@ impl FileIdentifier {
let (tasks_count, res) = match dispatch_object_processor_tasks(
self.file_paths_accumulator.drain(),
ctx,
device_id,
dispatcher,
false,
)
@@ -636,6 +673,7 @@ impl FileIdentifier {
&mut self,
sub_iso_file_path: &IsolatedFilePathData<'static>,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
dispatcher: &JobTaskDispatcher,
pending_running_tasks: &FuturesUnordered<TaskHandle<Error>>,
) -> Result<(), JobErrorOrDispatcherError<file_identifier::Error>> {
@@ -702,7 +740,8 @@ impl FileIdentifier {
orphan_paths,
true,
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
device_id,
))
.await?,
);
@@ -713,8 +752,9 @@ impl FileIdentifier {
async fn dispatch_deep_identifier_tasks<OuterCtx: OuterContext>(
&mut self,
maybe_sub_iso_file_path: &Option<IsolatedFilePathData<'static>>,
maybe_sub_iso_file_path: Option<&IsolatedFilePathData<'static>>,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
dispatcher: &JobTaskDispatcher,
pending_running_tasks: &FuturesUnordered<TaskHandle<Error>>,
) -> Result<(), JobErrorOrDispatcherError<file_identifier::Error>> {
@@ -785,7 +825,8 @@ impl FileIdentifier {
orphan_paths,
false,
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
device_id,
))
.await?,
);

View File

@@ -2,9 +2,10 @@ use crate::{utils::sub_path, OuterContext};
use sd_core_file_path_helper::{FilePathError, IsolatedFilePathData};
use sd_core_prisma_helpers::CasId;
use sd_core_sync::DevicePubId;
use sd_file_ext::{extensions::Extension, kind::ObjectKind};
use sd_prisma::prisma::{file_path, location};
use sd_prisma::prisma::{device, file_path, location};
use sd_task_system::{TaskDispatcher, TaskHandle};
use sd_utils::{db::MissingFieldError, error::FileIOError};
@@ -41,6 +42,8 @@ const CHUNK_SIZE: usize = 100;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("device not found: <device_pub_id='{0}'")]
DeviceNotFound(DevicePubId),
#[error("missing field on database: {0}")]
MissingField(#[from] MissingFieldError),
#[error("failed to deserialized stored tasks for job resume: {0}")]
@@ -173,7 +176,7 @@ fn orphan_path_filters_shallow(
fn orphan_path_filters_deep(
location_id: location::id::Type,
file_path_id: Option<file_path::id::Type>,
maybe_sub_iso_file_path: &Option<IsolatedFilePathData<'_>>,
maybe_sub_iso_file_path: Option<&IsolatedFilePathData<'_>>,
) -> Vec<file_path::WhereParam> {
sd_utils::chain_optional_iter(
[
@@ -197,6 +200,7 @@ fn orphan_path_filters_deep(
async fn dispatch_object_processor_tasks<Iter, Dispatcher>(
file_paths_by_cas_id: Iter,
ctx: &impl OuterContext,
device_id: device::id::Type,
dispatcher: &Dispatcher,
with_priority: bool,
) -> Result<Vec<TaskHandle<crate::Error>>, Dispatcher::DispatchError>
@@ -217,7 +221,8 @@ where
.dispatch(tasks::ObjectProcessor::new(
HashMap::from([(cas_id, objects_to_create_or_link)]),
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
device_id,
with_priority,
))
.await?,
@@ -239,7 +244,8 @@ where
.dispatch(tasks::ObjectProcessor::new(
mem::take(&mut current_batch),
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
device_id,
with_priority,
))
.await?,
@@ -256,7 +262,8 @@ where
.dispatch(tasks::ObjectProcessor::new(
current_batch,
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
device_id,
with_priority,
))
.await?,

View File

@@ -6,7 +6,7 @@ use crate::{
use sd_core_file_path_helper::IsolatedFilePathData;
use sd_core_prisma_helpers::file_path_for_file_identifier;
use sd_prisma::prisma::{file_path, location, SortOrder};
use sd_prisma::prisma::{device, file_path, location, SortOrder};
use sd_task_system::{
BaseTaskDispatcher, CancelTaskOnDrop, TaskDispatcher, TaskHandle, TaskOutput, TaskStatus,
};
@@ -66,6 +66,19 @@ pub async fn shallow(
Ok,
)?;
let device_pub_id = &ctx.sync().device_pub_id;
let device_id = ctx
.db()
.device()
.find_unique(device::pub_id::equals(device_pub_id.to_db()))
.exec()
.await
.map_err(file_identifier::Error::from)?
.ok_or(file_identifier::Error::DeviceNotFound(
device_pub_id.clone(),
))?
.id;
let mut orphans_count = 0;
let mut last_orphan_file_path_id = None;
@@ -103,7 +116,8 @@ pub async fn shallow(
orphan_paths,
true,
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
device_id,
))
.await
else {
@@ -119,13 +133,14 @@ pub async fn shallow(
return Ok(vec![]);
}
process_tasks(identifier_tasks, dispatcher, ctx).await
process_tasks(identifier_tasks, dispatcher, ctx, device_id).await
}
async fn process_tasks(
identifier_tasks: Vec<TaskHandle<Error>>,
dispatcher: &BaseTaskDispatcher<Error>,
ctx: &impl OuterContext,
device_id: device::id::Type,
) -> Result<Vec<NonCriticalError>, Error> {
let total_identifier_tasks = identifier_tasks.len();
@@ -169,6 +184,7 @@ async fn process_tasks(
let Ok(tasks) = dispatch_object_processor_tasks(
file_paths_accumulator.drain(),
ctx,
device_id,
dispatcher,
true,
)

View File

@@ -5,18 +5,18 @@ use crate::{
use sd_core_file_path_helper::IsolatedFilePathData;
use sd_core_prisma_helpers::{file_path_for_file_identifier, CasId, FilePathPubId};
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::SyncManager;
use sd_file_ext::kind::ObjectKind;
use sd_prisma::{
prisma::{file_path, location, PrismaClient},
prisma::{device, file_path, location, PrismaClient},
prisma_sync,
};
use sd_sync::OperationFactory;
use sd_sync::{sync_db_entry, OperationFactory};
use sd_task_system::{
ExecStatus, Interrupter, InterruptionKind, IntoAnyTaskOutput, SerializableTask, Task, TaskId,
};
use sd_utils::{error::FileIOError, msgpack};
use sd_utils::error::FileIOError;
use std::{
collections::HashMap, convert::identity, future::IntoFuture, mem, path::PathBuf, pin::pin,
@@ -64,6 +64,7 @@ pub struct Identifier {
file_paths_by_id: HashMap<FilePathPubId, file_path_for_file_identifier::Data>,
// Inner state
device_id: device::id::Type,
identified_files: HashMap<FilePathPubId, IdentifiedFile>,
file_paths_without_cas_id: Vec<FilePathToCreateOrLinkObject>,
@@ -72,7 +73,7 @@ pub struct Identifier {
// Dependencies
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
}
/// Output from the `[Identifier]` task
@@ -135,6 +136,7 @@ impl Task<Error> for Identifier {
let Self {
location,
location_path,
device_id,
file_paths_by_id,
file_paths_without_cas_id,
identified_files,
@@ -255,6 +257,7 @@ impl Task<Error> for Identifier {
file_paths_without_cas_id.drain(..),
&self.db,
&self.sync,
*device_id,
),
)
.try_join()
@@ -301,6 +304,7 @@ impl Task<Error> for Identifier {
file_paths_without_cas_id.drain(..),
&self.db,
&self.sync,
*device_id,
)
.await?;
@@ -324,7 +328,8 @@ impl Identifier {
file_paths: Vec<file_path_for_file_identifier::Data>,
with_priority: bool,
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
device_id: device::id::Type,
) -> Self {
let mut output = Output::default();
@@ -377,6 +382,7 @@ impl Identifier {
id: TaskId::new_v4(),
location,
location_path,
device_id,
identified_files: HashMap::with_capacity(file_paths_count - directories_count),
file_paths_without_cas_id,
file_paths_by_id,
@@ -394,33 +400,31 @@ async fn assign_cas_id_to_file_paths(
db: &PrismaClient,
sync: &SyncManager,
) -> Result<(), file_identifier::Error> {
// Assign cas_id to each file path
sync.write_ops(
db,
identified_files
.iter()
.map(|(pub_id, IdentifiedFile { cas_id, .. })| {
(
sync.shared_update(
prisma_sync::file_path::SyncId {
pub_id: pub_id.to_db(),
},
file_path::cas_id::NAME,
msgpack!(cas_id),
),
db.file_path()
.update(
file_path::pub_id::equals(pub_id.to_db()),
vec![file_path::cas_id::set(cas_id.into())],
)
// We don't need any data here; selecting just the id avoids receiving the entire object,
// as we can't pass an empty select macro call
.select(file_path::select!({ id })),
)
})
.unzip::<_, _, _, Vec<_>>(),
)
.await?;
let (ops, queries) = identified_files
.iter()
.map(|(pub_id, IdentifiedFile { cas_id, .. })| {
let (sync_param, db_param) = sync_db_entry!(cas_id, file_path::cas_id);
(
sync.shared_update(
prisma_sync::file_path::SyncId {
pub_id: pub_id.to_db(),
},
[sync_param],
),
db.file_path()
.update(file_path::pub_id::equals(pub_id.to_db()), vec![db_param])
// We don't need any data here; selecting just the id avoids receiving the entire object,
// as we can't pass an empty select macro call
.select(file_path::select!({ id })),
)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
if !ops.is_empty() && !queries.is_empty() {
// Assign cas_id to each file path
sync.write_ops(db, (ops, queries)).await?;
}
Ok(())
}
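A behavioral detail of this rewrite, repeated at several call sites below: batched writes are now guarded so that an empty batch skips the sync/database round-trip entirely. Reduced to its essentials, with build_batch standing in for whatever produces the paired ops and queries:

// Illustrative reduction of the guard pattern; build_batch is a hypothetical
// producer of the paired CRDT ops and Prisma queries.
let (ops, queries) = build_batch();
if !ops.is_empty() && !queries.is_empty() {
    sync.write_ops(db, (ops, queries)).await?;
}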
@@ -500,6 +504,7 @@ struct SaveState {
id: TaskId,
location: Arc<location::Data>,
location_path: Arc<PathBuf>,
device_id: device::id::Type,
file_paths_by_id: HashMap<FilePathPubId, file_path_for_file_identifier::Data>,
identified_files: HashMap<FilePathPubId, IdentifiedFile>,
file_paths_without_cas_id: Vec<FilePathToCreateOrLinkObject>,
@@ -512,13 +517,14 @@ impl SerializableTask<Error> for Identifier {
type DeserializeError = rmp_serde::decode::Error;
type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
let Self {
id,
location,
location_path,
device_id,
file_paths_by_id,
identified_files,
file_paths_without_cas_id,
@@ -530,6 +536,7 @@ impl SerializableTask<Error> for Identifier {
id,
location,
location_path,
device_id,
file_paths_by_id,
identified_files,
file_paths_without_cas_id,
@@ -547,6 +554,7 @@ impl SerializableTask<Error> for Identifier {
id,
location,
location_path,
device_id,
file_paths_by_id,
identified_files,
file_paths_without_cas_id,
@@ -558,6 +566,7 @@ impl SerializableTask<Error> for Identifier {
location,
location_path,
file_paths_by_id,
device_id,
identified_files,
file_paths_without_cas_id,
output,

View File

@@ -1,15 +1,15 @@
use crate::file_identifier;
use sd_core_prisma_helpers::{file_path_id, FilePathPubId, ObjectPubId};
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::SyncManager;
use sd_file_ext::kind::ObjectKind;
use sd_prisma::{
prisma::{file_path, object, PrismaClient},
prisma::{device, file_path, object, PrismaClient},
prisma_sync,
};
use sd_sync::{CRDTOperation, OperationFactory};
use sd_utils::msgpack;
use sd_sync::{option_sync_db_entry, sync_db_entry, sync_entry, CRDTOperation, OperationFactory};
use sd_utils::chain_optional_iter;
use std::collections::{HashMap, HashSet};
@@ -47,10 +47,12 @@ fn connect_file_path_to_object<'db>(
prisma_sync::file_path::SyncId {
pub_id: file_path_pub_id.to_db(),
},
file_path::object::NAME,
msgpack!(prisma_sync::object::SyncId {
pub_id: object_pub_id.to_db(),
}),
[sync_entry!(
prisma_sync::object::SyncId {
pub_id: object_pub_id.to_db(),
},
file_path::object
)],
),
db.file_path()
.update(
@@ -69,6 +71,7 @@ async fn create_objects_and_update_file_paths(
files_and_kinds: impl IntoIterator<Item = FilePathToCreateOrLinkObject> + Send,
db: &PrismaClient,
sync: &SyncManager,
device_id: device::id::Type,
) -> Result<HashMap<file_path::id::Type, ObjectPubId>, file_identifier::Error> {
trace!("Preparing objects");
let (object_create_args, file_path_args) = files_and_kinds
@@ -84,16 +87,23 @@ async fn create_objects_and_update_file_paths(
let kind = kind as i32;
let (sync_params, db_params) = [
(
(object::date_created::NAME, msgpack!(created_at)),
object::date_created::set(created_at),
),
(
(object::kind::NAME, msgpack!(kind)),
object::kind::set(Some(kind)),
),
]
let device_pub_id = sync.device_pub_id.to_db();
let (sync_params, db_params) = chain_optional_iter(
[
(
sync_entry!(
prisma_sync::device::SyncId {
pub_id: device_pub_id,
},
object::device
),
object::device_id::set(Some(device_id)),
),
sync_db_entry!(kind, object::kind),
],
[option_sync_db_entry!(created_at, object::date_created)],
)
.into_iter()
.unzip::<_, _, Vec<_>, Vec<_>>();
@@ -121,51 +131,57 @@ async fn create_objects_and_update_file_paths(
.unzip::<_, _, HashMap<_, _>, Vec<_>>(
);
trace!(
new_objects_count = object_create_args.len(),
"Creating new Objects!;",
);
let new_objects_count = object_create_args.len();
if new_objects_count > 0 {
trace!(new_objects_count, "Creating new Objects!;");
// create new object records with assembled values
let created_objects_count = sync
.write_ops(db, {
let (sync, db_params) = object_create_args
.into_iter()
.unzip::<_, _, Vec<_>, Vec<_>>();
(
sync.into_iter().flatten().collect(),
db.object().create_many(db_params),
)
})
.await?;
trace!(%created_objects_count, "Created new Objects;");
if created_objects_count > 0 {
trace!("Updating file paths with created objects");
let updated_file_path_ids = sync
.write_ops(
db,
file_path_update_args
// create new object records with assembled values
let created_objects_count = sync
.write_ops(db, {
let (sync, db_params) = object_create_args
.into_iter()
.unzip::<_, _, Vec<_>, Vec<_>>(),
)
.await
.map(|file_paths| {
file_paths
.into_iter()
.map(|file_path_id::Data { id }| id)
.collect::<HashSet<_>>()
})?;
.unzip::<_, _, Vec<_>, Vec<_>>();
object_pub_id_by_file_path_id
.retain(|file_path_id, _| updated_file_path_ids.contains(file_path_id));
(sync, db.object().create_many(db_params))
})
.await?;
Ok(object_pub_id_by_file_path_id)
trace!(%created_objects_count, "Created new Objects;");
if created_objects_count > 0 {
let file_paths_to_update_count = file_path_update_args.len();
if file_paths_to_update_count > 0 {
trace!(
file_paths_to_update_count,
"Updating file paths with created objects"
);
let updated_file_path_ids = sync
.write_ops(
db,
file_path_update_args
.into_iter()
.unzip::<_, _, Vec<_>, Vec<_>>(),
)
.await
.map(|file_paths| {
file_paths
.into_iter()
.map(|file_path_id::Data { id }| id)
.collect::<HashSet<_>>()
})?;
object_pub_id_by_file_path_id
.retain(|file_path_id, _| updated_file_path_ids.contains(file_path_id));
}
Ok(object_pub_id_by_file_path_id)
} else {
trace!("No objects created, skipping file path updates");
Ok(HashMap::new())
}
} else {
trace!("No objects created, skipping file path updates");
trace!("No objects to create, skipping file path updates");
Ok(HashMap::new())
}
}

View File

@@ -1,9 +1,9 @@
use crate::{file_identifier, Error};
use sd_core_prisma_helpers::{file_path_id, object_for_file_identifier, CasId, ObjectPubId};
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::SyncManager;
use sd_prisma::prisma::{file_path, object, PrismaClient};
use sd_prisma::prisma::{device, file_path, object, PrismaClient};
use sd_task_system::{
check_interruption, ExecStatus, Interrupter, IntoAnyTaskOutput, SerializableTask, Task, TaskId,
};
@@ -29,13 +29,14 @@ pub struct ObjectProcessor {
// Inner state
stage: Stage,
device_id: device::id::Type,
// Out collector
output: Output,
// Dependencies
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -93,6 +94,7 @@ impl Task<Error> for ObjectProcessor {
let Self {
db,
sync,
device_id,
file_paths_by_cas_id,
stage,
output:
@@ -167,8 +169,13 @@ impl Task<Error> for ObjectProcessor {
);
let start = Instant::now();
let (more_file_paths_with_new_object, more_linked_objects_count) =
assign_objects_to_duplicated_orphans(file_paths_by_cas_id, db, sync)
.await?;
assign_objects_to_duplicated_orphans(
file_paths_by_cas_id,
db,
sync,
*device_id,
)
.await?;
*create_object_time = start.elapsed();
file_path_ids_with_new_object.extend(more_file_paths_with_new_object);
*linked_objects_count += more_linked_objects_count;
@@ -194,7 +201,8 @@ impl ObjectProcessor {
pub fn new(
file_paths_by_cas_id: HashMap<CasId<'static>, Vec<FilePathToCreateOrLinkObject>>,
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
device_id: device::id::Type,
with_priority: bool,
) -> Self {
Self {
@@ -202,6 +210,7 @@ impl ObjectProcessor {
db,
sync,
file_paths_by_cas_id,
device_id,
stage: Stage::Starting,
output: Output::default(),
with_priority,
@@ -270,45 +279,44 @@ async fn assign_existing_objects_to_file_paths(
db: &PrismaClient,
sync: &SyncManager,
) -> Result<Vec<file_path::id::Type>, file_identifier::Error> {
sync.write_ops(
db,
objects_by_cas_id
.iter()
.flat_map(|(cas_id, object_pub_id)| {
file_paths_by_cas_id
.remove(cas_id)
.map(|file_paths| {
file_paths.into_iter().map(
|FilePathToCreateOrLinkObject {
file_path_pub_id, ..
}| {
connect_file_path_to_object(
&file_path_pub_id,
object_pub_id,
db,
sync,
)
},
)
})
.expect("must be here")
})
.unzip::<_, _, Vec<_>, Vec<_>>(),
)
.await
.map(|file_paths| {
file_paths
.into_iter()
.map(|file_path_id::Data { id }| id)
.collect()
})
.map_err(Into::into)
let (ops, queries) = objects_by_cas_id
.iter()
.flat_map(|(cas_id, object_pub_id)| {
file_paths_by_cas_id
.remove(cas_id)
.map(|file_paths| {
file_paths.into_iter().map(
|FilePathToCreateOrLinkObject {
file_path_pub_id, ..
}| {
connect_file_path_to_object(&file_path_pub_id, object_pub_id, db, sync)
},
)
})
.expect("must be here")
})
.unzip::<_, _, Vec<_>, Vec<_>>();
if ops.is_empty() && queries.is_empty() {
return Ok(vec![]);
}
sync.write_ops(db, (ops, queries))
.await
.map(|file_paths| {
file_paths
.into_iter()
.map(|file_path_id::Data { id }| id)
.collect()
})
.map_err(Into::into)
}
async fn assign_objects_to_duplicated_orphans(
file_paths_by_cas_id: &mut HashMap<CasId<'static>, Vec<FilePathToCreateOrLinkObject>>,
db: &PrismaClient,
sync: &SyncManager,
device_id: device::id::Type,
) -> Result<(Vec<file_path::id::Type>, u64), file_identifier::Error> {
// at least 1 file path per cas_id
let mut selected_file_paths = Vec::with_capacity(file_paths_by_cas_id.len());
@@ -327,7 +335,7 @@ async fn assign_objects_to_duplicated_orphans(
});
let (mut file_paths_with_new_object, objects_by_cas_id) =
create_objects_and_update_file_paths(selected_file_paths, db, sync)
create_objects_and_update_file_paths(selected_file_paths, db, sync, device_id)
.await?
.into_iter()
.map(|(file_path_id, object_pub_id)| {
@@ -365,6 +373,7 @@ async fn assign_objects_to_duplicated_orphans(
pub struct SaveState {
id: TaskId,
file_paths_by_cas_id: HashMap<CasId<'static>, Vec<FilePathToCreateOrLinkObject>>,
device_id: device::id::Type,
stage: Stage,
output: Output,
with_priority: bool,
@@ -375,12 +384,13 @@ impl SerializableTask<Error> for ObjectProcessor {
type DeserializeError = rmp_serde::decode::Error;
type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
let Self {
id,
file_paths_by_cas_id,
device_id,
stage,
output,
with_priority,
@@ -390,6 +400,7 @@ impl SerializableTask<Error> for ObjectProcessor {
rmp_serde::to_vec_named(&SaveState {
id,
file_paths_by_cas_id,
device_id,
stage,
output,
with_priority,
@@ -404,6 +415,7 @@ impl SerializableTask<Error> for ObjectProcessor {
|SaveState {
id,
file_paths_by_cas_id,
device_id,
stage,
output,
with_priority,
@@ -412,6 +424,7 @@ impl SerializableTask<Error> for ObjectProcessor {
with_priority,
file_paths_by_cas_id,
stage,
device_id,
output,
db,
sync,

View File

@@ -16,7 +16,11 @@ use sd_core_file_path_helper::IsolatedFilePathData;
use sd_core_indexer_rules::{IndexerRule, IndexerRuler};
use sd_core_prisma_helpers::location_with_indexer_rules;
use sd_prisma::prisma::location;
use sd_prisma::{
prisma::{device, location},
prisma_sync,
};
use sd_sync::{sync_db_not_null_entry, OperationFactory};
use sd_task_system::{
AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId,
TaskOutput, TaskStatus,
@@ -116,13 +120,13 @@ impl Job for Indexer {
TaskKind::Save => tasks::Saver::deserialize(
&task_bytes,
(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
(Arc::clone(ctx.db()), ctx.sync().clone()),
)
.await
.map(IntoTask::into_task),
TaskKind::Update => tasks::Updater::deserialize(
&task_bytes,
(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
(Arc::clone(ctx.db()), ctx.sync().clone()),
)
.await
.map(IntoTask::into_task),
@@ -161,6 +165,17 @@ impl Job for Indexer {
) -> Result<ReturnStatus, Error> {
let mut pending_running_tasks = FuturesUnordered::new();
let device_pub_id = &ctx.sync().device_pub_id;
let device_id = ctx
.db()
.device()
.find_unique(device::pub_id::equals(device_pub_id.to_db()))
.exec()
.await
.map_err(indexer::Error::from)?
.ok_or(indexer::Error::DeviceNotFound(device_pub_id.clone()))?
.id;
match self
.init_or_resume(&mut pending_running_tasks, &ctx, &dispatcher)
.await
@@ -191,21 +206,26 @@ impl Job for Indexer {
}
if let Some(res) = self
.process_handles(&mut pending_running_tasks, &ctx, &dispatcher)
.process_handles(&mut pending_running_tasks, &ctx, device_id, &dispatcher)
.await
{
return res;
}
if let Some(res) = self
.dispatch_last_save_and_update_tasks(&mut pending_running_tasks, &ctx, &dispatcher)
.dispatch_last_save_and_update_tasks(
&mut pending_running_tasks,
&ctx,
device_id,
&dispatcher,
)
.await
{
return res;
}
if let Some(res) = self
.index_pending_ancestors(&mut pending_running_tasks, &ctx, &dispatcher)
.index_pending_ancestors(&mut pending_running_tasks, &ctx, device_id, &dispatcher)
.await
{
return res;
@@ -253,7 +273,7 @@ impl Job for Indexer {
.await?;
}
update_location_size(location.id, ctx.db(), &ctx).await?;
update_location_size(location.id, location.pub_id.clone(), &ctx).await?;
metadata.mean_db_write_time += start_size_update_time.elapsed();
}
@@ -271,13 +291,23 @@ impl Job for Indexer {
"all tasks must be completed here"
);
ctx.db()
.location()
.update(
location::id::equals(location.id),
vec![location::scan_state::set(LocationScanState::Indexed as i32)],
let (sync_param, db_param) =
sync_db_not_null_entry!(LocationScanState::Indexed as i32, location::scan_state);
ctx.sync()
.write_op(
ctx.db(),
ctx.sync().shared_update(
prisma_sync::location::SyncId {
pub_id: location.pub_id.clone(),
},
[sync_param],
),
ctx.db()
.location()
.update(location::id::equals(location.id), vec![db_param])
.select(location::select!({ id })),
)
.exec()
.await
.map_err(indexer::Error::from)?;
@@ -338,6 +368,7 @@ impl Indexer {
task_id: TaskId,
any_task_output: Box<dyn AnyTaskOutput>,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
dispatcher: &JobTaskDispatcher,
) -> Result<Vec<TaskHandle<Error>>, JobErrorOrDispatcherError<indexer::Error>> {
self.metadata.completed_tasks += 1;
@@ -349,6 +380,7 @@ impl Indexer {
.downcast::<walker::Output<WalkerDBProxy, IsoFilePathFactory>>()
.expect("just checked"),
ctx,
device_id,
dispatcher,
)
.await;
@@ -403,6 +435,7 @@ impl Indexer {
..
}: walker::Output<WalkerDBProxy, IsoFilePathFactory>,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
dispatcher: &JobTaskDispatcher,
) -> Result<Vec<TaskHandle<Error>>, JobErrorOrDispatcherError<indexer::Error>> {
self.metadata.mean_scan_read_time += scan_time;
@@ -465,7 +498,7 @@ impl Indexer {
self.metadata.mean_db_write_time += db_delete_time.elapsed();
}
let (save_tasks, update_tasks) =
self.prepare_save_and_update_tasks(to_create, to_update, ctx);
self.prepare_save_and_update_tasks(to_create, to_update, ctx, device_id);
ctx.progress(vec![
ProgressUpdate::TaskCount(self.metadata.total_tasks),
@@ -552,13 +585,14 @@ impl Indexer {
&mut self,
pending_running_tasks: &mut FuturesUnordered<TaskHandle<Error>>,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
dispatcher: &JobTaskDispatcher,
) -> Option<Result<ReturnStatus, Error>> {
while let Some(task) = pending_running_tasks.next().await {
match task {
Ok(TaskStatus::Done((task_id, TaskOutput::Out(out)))) => {
match self
.process_task_output(task_id, out, ctx, dispatcher)
.process_task_output(task_id, out, ctx, device_id, dispatcher)
.await
{
Ok(more_handles) => pending_running_tasks.extend(more_handles),
@@ -666,6 +700,7 @@ impl Indexer {
&mut self,
pending_running_tasks: &mut FuturesUnordered<TaskHandle<Error>>,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
dispatcher: &JobTaskDispatcher,
) -> Option<Result<ReturnStatus, Error>> {
if !self.to_create_buffer.is_empty() || !self.to_update_buffer.is_empty() {
@@ -687,7 +722,8 @@ impl Indexer {
self.location.pub_id.clone(),
self.to_create_buffer.drain(..).collect(),
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
device_id,
)
.into_task(),
);
@@ -707,7 +743,7 @@ impl Indexer {
tasks::Updater::new_deep(
self.to_update_buffer.drain(..).collect(),
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
)
.into_task(),
);
@@ -726,7 +762,7 @@ impl Indexer {
}
}
self.process_handles(pending_running_tasks, ctx, dispatcher)
self.process_handles(pending_running_tasks, ctx, device_id, dispatcher)
.await
} else {
None
@@ -737,6 +773,7 @@ impl Indexer {
&mut self,
pending_running_tasks: &mut FuturesUnordered<TaskHandle<Error>>,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
dispatcher: &JobTaskDispatcher,
) -> Option<Result<ReturnStatus, Error>> {
if self.ancestors_needing_indexing.is_empty() {
@@ -759,7 +796,8 @@ impl Indexer {
self.location.pub_id.clone(),
chunked_saves,
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
device_id,
)
})
.collect::<Vec<_>>();
@@ -776,7 +814,7 @@ impl Indexer {
}
}
self.process_handles(pending_running_tasks, ctx, dispatcher)
self.process_handles(pending_running_tasks, ctx, device_id, dispatcher)
.await
}
@@ -785,6 +823,7 @@ impl Indexer {
to_create: Vec<WalkedEntry>,
to_update: Vec<WalkedEntry>,
ctx: &impl JobContext<OuterCtx>,
device_id: device::id::Type,
) -> (Vec<tasks::Saver>, Vec<tasks::Updater>) {
if self.processing_first_directory {
// If we are processing the first directory, we dispatch shallow tasks with higher priority
@@ -806,7 +845,8 @@ impl Indexer {
self.location.pub_id.clone(),
chunked_saves,
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
device_id,
)
})
.collect::<Vec<_>>();
@@ -824,7 +864,7 @@ impl Indexer {
tasks::Updater::new_shallow(
chunked_updates,
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
)
})
.collect::<Vec<_>>();
@@ -851,7 +891,8 @@ impl Indexer {
self.location.pub_id.clone(),
chunked_saves,
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
device_id,
));
}
save_tasks
@@ -878,7 +919,7 @@ impl Indexer {
update_tasks.push(tasks::Updater::new_deep(
chunked_updates,
Arc::clone(ctx.db()),
Arc::clone(ctx.sync()),
ctx.sync().clone(),
));
}
update_tasks

View File

@@ -4,17 +4,17 @@ use sd_core_file_path_helper::{FilePathError, IsolatedFilePathData};
use sd_core_prisma_helpers::{
file_path_pub_and_cas_ids, file_path_to_isolate_with_pub_id, file_path_walker,
};
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::{DevicePubId, SyncManager};
use sd_prisma::{
prisma::{file_path, indexer_rule, location, PrismaClient, SortOrder},
prisma_sync,
};
use sd_sync::OperationFactory;
use sd_sync::{sync_db_entry, OperationFactory};
use sd_utils::{
db::{size_in_bytes_from_db, size_in_bytes_to_db, MissingFieldError},
error::{FileIOError, NonUtf8PathError},
from_bytes_to_uuid, msgpack,
from_bytes_to_uuid,
};
use std::{
@@ -50,6 +50,8 @@ pub enum Error {
IndexerRuleNotFound(indexer_rule::id::Type),
#[error(transparent)]
SubPath(#[from] sub_path::Error),
#[error("device not found: <device_pub_id='{0}'")]
DeviceNotFound(DevicePubId),
// Internal Errors
#[error("database error: {0}")]
@@ -136,7 +138,7 @@ async fn update_directory_sizes(
db: &PrismaClient,
sync: &SyncManager,
) -> Result<(), Error> {
let to_sync_and_update = db
let (ops, queries) = db
._batch(chunk_db_queries(iso_paths_and_sizes.keys(), db))
.await?
.into_iter()
@@ -144,22 +146,20 @@ async fn update_directory_sizes(
.map(|file_path| {
let size_bytes = iso_paths_and_sizes
.get(&IsolatedFilePathData::try_from(&file_path)?)
.map(|size| size.to_be_bytes().to_vec())
.map(|size| size_in_bytes_to_db(*size))
.expect("must be here");
let (sync_param, db_param) = sync_db_entry!(size_bytes, file_path::size_in_bytes_bytes);
Ok((
sync.shared_update(
prisma_sync::file_path::SyncId {
pub_id: file_path.pub_id.clone(),
},
file_path::size_in_bytes_bytes::NAME,
msgpack!(size_bytes),
[sync_param],
),
db.file_path()
.update(
file_path::pub_id::equals(file_path.pub_id),
vec![file_path::size_in_bytes_bytes::set(Some(size_bytes))],
)
.update(file_path::pub_id::equals(file_path.pub_id), vec![db_param])
.select(file_path::select!({ id })),
))
})
@@ -167,42 +167,54 @@ async fn update_directory_sizes(
.into_iter()
.unzip::<_, _, Vec<_>, Vec<_>>();
sync.write_ops(db, to_sync_and_update).await?;
if !ops.is_empty() && !queries.is_empty() {
sync.write_ops(db, (ops, queries)).await?;
}
Ok(())
}
async fn update_location_size(
location_id: location::id::Type,
db: &PrismaClient,
location_pub_id: location::pub_id::Type,
ctx: &impl OuterContext,
) -> Result<(), Error> {
let total_size = db
.file_path()
.find_many(vec![
file_path::location_id::equals(Some(location_id)),
file_path::materialized_path::equals(Some("/".to_string())),
])
.select(file_path::select!({ size_in_bytes_bytes }))
.exec()
.await?
.into_iter()
.filter_map(|file_path| {
file_path
.size_in_bytes_bytes
.map(|size_in_bytes_bytes| size_in_bytes_from_db(&size_in_bytes_bytes))
})
.sum::<u64>();
let db = ctx.db();
let sync = ctx.sync();
db.location()
.update(
location::id::equals(location_id),
vec![location::size_in_bytes::set(Some(
total_size.to_be_bytes().to_vec(),
))],
)
.exec()
.await?;
let total_size = size_in_bytes_to_db(
db.file_path()
.find_many(vec![
file_path::location_id::equals(Some(location_id)),
file_path::materialized_path::equals(Some("/".to_string())),
])
.select(file_path::select!({ size_in_bytes_bytes }))
.exec()
.await?
.into_iter()
.filter_map(|file_path| {
file_path
.size_in_bytes_bytes
.map(|size_in_bytes_bytes| size_in_bytes_from_db(&size_in_bytes_bytes))
})
.sum::<u64>(),
);
let (sync_param, db_param) = sync_db_entry!(total_size, location::size_in_bytes);
sync.write_op(
db,
sync.shared_update(
prisma_sync::location::SyncId {
pub_id: location_pub_id,
},
[sync_param],
),
db.location()
.update(location::id::equals(location_id), vec![db_param])
.select(location::select!({ id })),
)
.await?;
ctx.invalidate_query("locations.list");
ctx.invalidate_query("locations.get");
@@ -213,7 +225,7 @@ async fn update_location_size(
async fn remove_non_existing_file_paths(
to_remove: Vec<file_path_pub_and_cas_ids::Data>,
db: &PrismaClient,
sync: &sd_core_sync::Manager,
sync: &SyncManager,
) -> Result<u64, Error> {
#[allow(clippy::cast_sign_loss)]
let (sync_params, db_params): (Vec<_>, Vec<_>) = to_remove
@@ -228,6 +240,10 @@ async fn remove_non_existing_file_paths(
})
.unzip();
if sync_params.is_empty() {
return Ok(0);
}
sync.write_ops(
db,
(
@@ -318,7 +334,7 @@ pub async fn reverse_update_directories_sizes(
)
.await?;
let to_sync_and_update = ancestors
let (sync_ops, update_queries) = ancestors
.into_values()
.filter_map(|materialized_path| {
if let Some((pub_id, size)) =
@@ -326,18 +342,19 @@ pub async fn reverse_update_directories_sizes(
{
let size_bytes = size_in_bytes_to_db(size);
let (sync_param, db_param) =
sync_db_entry!(size_bytes, file_path::size_in_bytes_bytes);
Some((
sync.shared_update(
prisma_sync::file_path::SyncId {
pub_id: pub_id.clone(),
},
file_path::size_in_bytes_bytes::NAME,
msgpack!(size_bytes),
),
db.file_path().update(
file_path::pub_id::equals(pub_id),
vec![file_path::size_in_bytes_bytes::set(Some(size_bytes))],
[sync_param],
),
db.file_path()
.update(file_path::pub_id::equals(pub_id), vec![db_param])
.select(file_path::select!({ id })),
))
} else {
warn!("Got a missing ancestor for a file_path in the database, ignoring...");
@@ -346,7 +363,9 @@ pub async fn reverse_update_directories_sizes(
})
.unzip::<_, _, Vec<_>, Vec<_>>();
sync.write_ops(db, to_sync_and_update).await?;
if !sync_ops.is_empty() && !update_queries.is_empty() {
sync.write_ops(db, (sync_ops, update_queries)).await?;
}
Ok(())
}

View File

@@ -4,9 +4,9 @@ use crate::{
use sd_core_indexer_rules::{IndexerRule, IndexerRuler};
use sd_core_prisma_helpers::location_with_indexer_rules;
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::SyncManager;
use sd_prisma::prisma::PrismaClient;
use sd_prisma::prisma::{device, PrismaClient};
use sd_task_system::{BaseTaskDispatcher, CancelTaskOnDrop, IntoTask, TaskDispatcher, TaskOutput};
use sd_utils::db::maybe_missing;
@@ -62,6 +62,17 @@ pub async fn shallow(
.await?,
);
let device_pub_id = &ctx.sync().device_pub_id;
let device_id = ctx
.db()
.device()
.find_unique(device::pub_id::equals(device_pub_id.to_db()))
.exec()
.await
.map_err(indexer::Error::from)?
.ok_or(indexer::Error::DeviceNotFound(device_pub_id.clone()))?
.id;
let Some(walker::Output {
to_create,
to_update,
@@ -96,7 +107,8 @@ pub async fn shallow(
to_create,
to_update,
Arc::clone(db),
Arc::clone(sync),
sync.clone(),
device_id,
dispatcher,
)
.await?
@@ -124,7 +136,7 @@ pub async fn shallow(
.await?;
}
update_location_size(location.id, db, ctx).await?;
update_location_size(location.id, location.pub_id, ctx).await?;
}
if indexed_count > 0 || removed_count > 0 {
@@ -203,7 +215,8 @@ async fn save_and_update(
to_create: Vec<WalkedEntry>,
to_update: Vec<WalkedEntry>,
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
device_id: device::id::Type,
dispatcher: &BaseTaskDispatcher<Error>,
) -> Result<Option<Metadata>, Error> {
let save_and_update_tasks = to_create
@@ -216,7 +229,8 @@ async fn save_and_update(
location.pub_id.clone(),
chunk.collect::<Vec<_>>(),
Arc::clone(&db),
Arc::clone(&sync),
sync.clone(),
device_id,
)
})
.map(IntoTask::into_task)
@@ -229,7 +243,7 @@ async fn save_and_update(
tasks::Updater::new_shallow(
chunk.collect::<Vec<_>>(),
Arc::clone(&db),
Arc::clone(&sync),
sync.clone(),
)
})
.map(IntoTask::into_task),

View File

@@ -1,18 +1,15 @@
use crate::{indexer, Error};
use sd_core_file_path_helper::{FilePathMetadata, IsolatedFilePathDataParts};
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::SyncManager;
use sd_prisma::{
prisma::{file_path, location, PrismaClient},
prisma::{device, file_path, location, PrismaClient},
prisma_sync,
};
use sd_sync::{sync_db_entry, OperationFactory};
use sd_sync::{sync_db_entry, sync_entry, OperationFactory};
use sd_task_system::{ExecStatus, Interrupter, IntoAnyTaskOutput, SerializableTask, Task, TaskId};
use sd_utils::{
db::{inode_to_db, size_in_bytes_to_db},
msgpack,
};
use sd_utils::db::{inode_to_db, size_in_bytes_to_db};
use std::{sync::Arc, time::Duration};
@@ -32,11 +29,12 @@ pub struct Saver {
// Received input args
location_id: location::id::Type,
location_pub_id: location::pub_id::Type,
device_id: device::id::Type,
walked_entries: Vec<WalkedEntry>,
// Dependencies
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
}
/// [`Save`] Task output
@@ -73,8 +71,9 @@ impl Task<Error> for Saver {
#[allow(clippy::blocks_in_conditions)] // Due to `err` on `instrument` macro above
async fn run(&mut self, _: &Interrupter) -> Result<ExecStatus, Error> {
use file_path::{
create_unchecked, date_created, date_indexed, date_modified, extension, hidden, inode,
is_dir, location, location_id, materialized_path, name, size_in_bytes_bytes,
create_unchecked, date_created, date_indexed, date_modified, device, device_id,
extension, hidden, inode, is_dir, location, location_id, materialized_path, name,
size_in_bytes_bytes,
};
let start_time = Instant::now();
@@ -82,13 +81,14 @@ impl Task<Error> for Saver {
let Self {
location_id,
location_pub_id,
device_id,
walked_entries,
db,
sync,
..
} = self;
let (sync_stuff, paths): (Vec<_>, Vec<_>) = walked_entries
let (create_crdt_ops, paths): (Vec<_>, Vec<_>) = walked_entries
.drain(..)
.map(
|WalkedEntry {
@@ -118,13 +118,13 @@ impl Task<Error> for Saver {
new file_paths and they were not identified yet"
);
let (sync_params, db_params): (Vec<_>, Vec<_>) = [
let (sync_params, db_params) = [
(
(
location::NAME,
msgpack!(prisma_sync::location::SyncId {
sync_entry!(
prisma_sync::location::SyncId {
pub_id: location_pub_id.clone()
}),
},
location
),
location_id::set(Some(*location_id)),
),
@@ -138,9 +138,18 @@ impl Task<Error> for Saver {
sync_db_entry!(modified_at, date_modified),
sync_db_entry!(Utc::now(), date_indexed),
sync_db_entry!(hidden, hidden),
(
sync_entry!(
prisma_sync::device::SyncId {
pub_id: sync.device_pub_id.to_db(),
},
device
),
device_id::set(Some(*device_id)),
),
]
.into_iter()
.unzip();
.unzip::<_, _, Vec<_>, Vec<_>>();
(
sync.shared_create(
@@ -155,12 +164,22 @@ impl Task<Error> for Saver {
)
.unzip();
if create_crdt_ops.is_empty() && paths.is_empty() {
return Ok(ExecStatus::Done(
Output {
saved_count: 0,
save_duration: Duration::ZERO,
}
.into_output(),
));
}
#[allow(clippy::cast_sign_loss)]
let saved_count = sync
.write_ops(
db,
(
sync_stuff.into_iter().flatten().collect(),
create_crdt_ops,
db.file_path().create_many(paths).skip_duplicates(),
),
)
@@ -188,12 +207,14 @@ impl Saver {
location_pub_id: location::pub_id::Type,
walked_entries: Vec<WalkedEntry>,
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
device_id: device::id::Type,
) -> Self {
Self {
id: TaskId::new_v4(),
location_id,
location_pub_id,
device_id,
walked_entries,
db,
sync,
@@ -207,12 +228,14 @@ impl Saver {
location_pub_id: location::pub_id::Type,
walked_entries: Vec<WalkedEntry>,
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
device_id: device::id::Type,
) -> Self {
Self {
id: TaskId::new_v4(),
location_id,
location_pub_id,
device_id,
walked_entries,
db,
sync,
@@ -228,6 +251,7 @@ struct SaveState {
location_id: location::id::Type,
location_pub_id: location::pub_id::Type,
device_id: device::id::Type,
walked_entries: Vec<WalkedEntry>,
}
@@ -236,7 +260,7 @@ impl SerializableTask<Error> for Saver {
type DeserializeError = rmp_serde::decode::Error;
type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
let Self {
@@ -244,6 +268,7 @@ impl SerializableTask<Error> for Saver {
is_shallow,
location_id,
location_pub_id,
device_id,
walked_entries,
..
} = self;
@@ -252,6 +277,7 @@ impl SerializableTask<Error> for Saver {
is_shallow,
location_id,
location_pub_id,
device_id,
walked_entries,
})
}
@@ -266,12 +292,14 @@ impl SerializableTask<Error> for Saver {
is_shallow,
location_id,
location_pub_id,
device_id,
walked_entries,
}| Self {
id,
is_shallow,
location_id,
location_pub_id,
device_id,
walked_entries,
db,
sync,

View File

@@ -1,7 +1,7 @@
use crate::{indexer, Error};
use sd_core_file_path_helper::{FilePathMetadata, IsolatedFilePathDataParts};
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::SyncManager;
use sd_prisma::{
prisma::{file_path, object, PrismaClient},
@@ -39,7 +39,7 @@ pub struct Updater {
// Dependencies
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
}
/// [`Update`] Task output
@@ -93,7 +93,7 @@ impl Task<Error> for Updater {
check_interruption!(interrupter);
let (sync_stuff, paths_to_update) = walked_entries
let (crdt_ops, paths_to_update) = walked_entries
.drain(..)
.map(
|WalkedEntry {
@@ -138,18 +138,12 @@ impl Task<Error> for Updater {
.unzip::<_, _, Vec<_>, Vec<_>>();
(
sync_params
.into_iter()
.map(|(field, value)| {
sync.shared_update(
prisma_sync::file_path::SyncId {
pub_id: pub_id.to_db(),
},
field,
value,
)
})
.collect::<Vec<_>>(),
sync.shared_update(
prisma_sync::file_path::SyncId {
pub_id: pub_id.to_db(),
},
sync_params,
),
db.file_path()
.update(file_path::pub_id::equals(pub_id.into()), db_params)
// selecting id to avoid fetching whole object from database
@@ -159,11 +153,18 @@ impl Task<Error> for Updater {
)
.unzip::<_, _, Vec<_>, Vec<_>>();
if crdt_ops.is_empty() && paths_to_update.is_empty() {
return Ok(ExecStatus::Done(
Output {
updated_count: 0,
update_duration: Duration::ZERO,
}
.into_output(),
));
}
let updated = sync
.write_ops(
db,
(sync_stuff.into_iter().flatten().collect(), paths_to_update),
)
.write_ops(db, (crdt_ops, paths_to_update))
.await
.map_err(indexer::Error::from)?;
@@ -186,7 +187,7 @@ impl Updater {
pub fn new_deep(
walked_entries: Vec<WalkedEntry>,
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
) -> Self {
Self {
id: TaskId::new_v4(),
@@ -202,7 +203,7 @@ impl Updater {
pub fn new_shallow(
walked_entries: Vec<WalkedEntry>,
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
) -> Self {
Self {
id: TaskId::new_v4(),
@@ -264,7 +265,7 @@ impl SerializableTask<Error> for Updater {
type DeserializeError = rmp_serde::decode::Error;
type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
let Self {

View File

@@ -1,6 +1,6 @@
use crate::{Error, NonCriticalError, UpdateEvent};
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::SyncManager;
use sd_prisma::prisma::PrismaClient;
use sd_task_system::{
@@ -98,7 +98,7 @@ impl ProgressUpdate {
pub trait OuterContext: Send + Sync + Clone + 'static {
fn id(&self) -> Uuid;
fn db(&self) -> &Arc<PrismaClient>;
fn sync(&self) -> &Arc<SyncManager>;
fn sync(&self) -> &SyncManager;
fn invalidate_query(&self, query: &'static str);
fn query_invalidator(&self) -> impl Fn(&'static str) + Send + Sync;
fn report_update(&self, update: UpdateEvent);
@@ -158,7 +158,7 @@ where
JobCtx: JobContext<OuterCtx>,
{
fn into_job(self) -> Box<dyn DynJob<OuterCtx, JobCtx>> {
let id = JobId::new_v4();
let id = JobId::now_v7();
Box::new(JobHolder {
id,
@@ -333,7 +333,7 @@ where
}
pub fn new(job: J) -> Self {
let id = JobId::new_v4();
let id = JobId::now_v7();
Self {
id,
job,
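The switch from JobId::new_v4() to JobId::now_v7() deserves a note: UUIDv7 embeds a millisecond timestamp in its high bits, so job ids become sortable by creation time. A small illustration, assuming the v7 feature of the uuid crate is enabled:

use uuid::Uuid;

let earlier = Uuid::now_v7();
std::thread::sleep(std::time::Duration::from_millis(2));
let later = Uuid::now_v7();
// Ids from a later millisecond always compare greater; within the same
// millisecond the remaining random bits decide, so ordering is only coarse.
assert!(earlier < later);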

View File

@@ -290,6 +290,7 @@ impl Report {
.map(|id| job::parent::connect(job::id::equals(id.as_bytes().to_vec())))],
),
)
.select(job::select!({ id }))
.exec()
.await
.map_err(ReportError::Create)?;
@@ -300,7 +301,7 @@ impl Report {
Ok(())
}
pub async fn update(&mut self, db: &PrismaClient) -> Result<(), ReportError> {
pub async fn update(&self, db: &PrismaClient) -> Result<(), ReportError> {
db.job()
.update(
job::id::equals(self.id.as_bytes().to_vec()),
@@ -318,6 +319,7 @@ impl Report {
job::date_completed::set(self.completed_at.map(Into::into)),
],
)
.select(job::select!({ id }))
.exec()
.await
.map_err(ReportError::Update)?;

View File

@@ -313,7 +313,7 @@ impl<OuterCtx: OuterContext, JobCtx: JobContext<OuterCtx>> JobSystemRunner<Outer
Ok(Some(serialized_job)) => {
let name = {
let db = handle.ctx.db();
let mut report = handle.ctx.report_mut().await;
let report = handle.ctx.report().await;
if let Err(e) = report.update(db).await {
error!(?e, "Failed to update report on job shutdown;");
}

View File

@@ -1,15 +1,15 @@
use crate::media_processor::{self, media_data_extractor};
use sd_core_prisma_helpers::ObjectPubId;
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::{DevicePubId, SyncManager};
use sd_file_ext::extensions::{Extension, ImageExtension, ALL_IMAGE_EXTENSIONS};
use sd_media_metadata::ExifMetadata;
use sd_prisma::{
prisma::{exif_data, object, PrismaClient},
prisma::{device, exif_data, object, PrismaClient},
prisma_sync,
};
use sd_sync::{option_sync_db_entry, OperationFactory};
use sd_sync::{option_sync_db_entry, sync_entry, OperationFactory};
use sd_utils::chain_optional_iter;
use std::{path::Path, sync::LazyLock};
@@ -51,9 +51,20 @@ fn to_query(
exif_version,
}: ExifMetadata,
object_id: exif_data::object_id::Type,
device_pub_id: &DevicePubId,
) -> (Vec<(&'static str, rmpv::Value)>, exif_data::Create) {
let device_pub_id = device_pub_id.to_db();
let (sync_params, db_params) = chain_optional_iter(
[],
[(
sync_entry!(
prisma_sync::device::SyncId {
pub_id: device_pub_id.clone()
},
exif_data::device
),
exif_data::device::connect(device::pub_id::equals(device_pub_id)),
)],
[
option_sync_db_entry!(
serde_json::to_vec(&camera_data).ok(),
@@ -109,24 +120,22 @@ pub async fn save(
exif_datas
.into_iter()
.map(|(exif_data, object_id, object_pub_id)| async move {
let (sync_params, create) = to_query(exif_data, object_id);
let (sync_params, create) = to_query(exif_data, object_id, &sync.device_pub_id);
let db_params = create._params.clone();
sync.write_ops(
sync.write_op(
db,
(
sync.shared_create(
prisma_sync::exif_data::SyncId {
object: prisma_sync::object::SyncId {
pub_id: object_pub_id.into(),
},
sync.shared_create(
prisma_sync::exif_data::SyncId {
object: prisma_sync::object::SyncId {
pub_id: object_pub_id.into(),
},
sync_params,
),
db.exif_data()
.upsert(exif_data::object_id::equals(object_id), create, db_params)
.select(exif_data::select!({ id })),
},
sync_params,
),
db.exif_data()
.upsert(exif_data::object_id::equals(object_id), create, db_params)
.select(exif_data::select!({ id })),
)
.await
})

View File

@@ -14,7 +14,11 @@ use sd_core_file_path_helper::IsolatedFilePathData;
use sd_core_prisma_helpers::file_path_for_media_processor;
use sd_file_ext::extensions::Extension;
use sd_prisma::prisma::{location, PrismaClient};
use sd_prisma::{
prisma::{location, PrismaClient},
prisma_sync,
};
use sd_sync::{sync_db_not_null_entry, OperationFactory};
use sd_task_system::{
AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId,
TaskOutput, TaskStatus, TaskSystemError,
@@ -125,7 +129,7 @@ impl Job for MediaProcessor {
TaskKind::MediaDataExtractor => {
tasks::MediaDataExtractor::deserialize(
&task_bytes,
(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
(Arc::clone(ctx.db()), ctx.sync().clone()),
)
.await
.map(IntoTask::into_task)
@@ -214,15 +218,23 @@ impl Job for MediaProcessor {
..
} = self;
ctx.db()
.location()
.update(
location::id::equals(location.id),
vec![location::scan_state::set(
LocationScanState::Completed as i32,
)],
let (sync_param, db_param) =
sync_db_not_null_entry!(LocationScanState::Completed as i32, location::scan_state);
ctx.sync()
.write_op(
ctx.db(),
ctx.sync().shared_update(
prisma_sync::location::SyncId {
pub_id: location.pub_id.clone(),
},
[sync_param],
),
ctx.db()
.location()
.update(location::id::equals(location.id), vec![db_param])
.select(location::select!({ id })),
)
.exec()
.await
.map_err(media_processor::Error::from)?;
@@ -632,7 +644,7 @@ impl MediaProcessor {
parent_iso_file_path.location_id(),
Arc::clone(&self.location_path),
Arc::clone(db),
Arc::clone(sync),
sync.clone(),
)
})
.map(IntoTask::into_task)
@@ -648,7 +660,7 @@ impl MediaProcessor {
parent_iso_file_path.location_id(),
Arc::clone(&self.location_path),
Arc::clone(db),
Arc::clone(sync),
sync.clone(),
)
})
.map(IntoTask::into_task),

View File

@@ -4,7 +4,7 @@ use crate::{
};
use sd_core_file_path_helper::IsolatedFilePathData;
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::SyncManager;
use sd_prisma::prisma::{location, PrismaClient};
use sd_task_system::{
@@ -154,7 +154,7 @@ pub async fn shallow(
async fn dispatch_media_data_extractor_tasks(
db: &Arc<PrismaClient>,
sync: &Arc<SyncManager>,
sync: &SyncManager,
parent_iso_file_path: &IsolatedFilePathData<'_>,
location_path: &Arc<PathBuf>,
dispatcher: &BaseTaskDispatcher<Error>,
@@ -185,7 +185,7 @@ async fn dispatch_media_data_extractor_tasks(
parent_iso_file_path.location_id(),
Arc::clone(location_path),
Arc::clone(db),
Arc::clone(sync),
sync.clone(),
)
})
.map(IntoTask::into_task)
@@ -201,7 +201,7 @@ async fn dispatch_media_data_extractor_tasks(
parent_iso_file_path.location_id(),
Arc::clone(location_path),
Arc::clone(db),
Arc::clone(sync),
sync.clone(),
)
})
.map(IntoTask::into_task),
@@ -220,7 +220,7 @@ async fn dispatch_media_data_extractor_tasks(
async fn dispatch_thumbnailer_tasks(
parent_iso_file_path: &IsolatedFilePathData<'_>,
should_regenerate: bool,
location_path: &PathBuf,
location_path: &Path,
dispatcher: &BaseTaskDispatcher<Error>,
ctx: &impl OuterContext,
) -> Result<Vec<TaskHandle<Error>>, Error> {

View File

@@ -8,7 +8,7 @@ use crate::{
use sd_core_file_path_helper::IsolatedFilePathData;
use sd_core_prisma_helpers::{file_path_for_media_processor, ObjectPubId};
use sd_core_sync::Manager as SyncManager;
use sd_core_sync::SyncManager;
use sd_media_metadata::{ExifMetadata, FFmpegMetadata};
use sd_prisma::prisma::{exif_data, ffmpeg_data, file_path, location, object, PrismaClient};
@@ -69,7 +69,7 @@ pub struct MediaDataExtractor {
// Dependencies
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -275,7 +275,7 @@ impl MediaDataExtractor {
location_id: location::id::Type,
location_path: Arc<PathBuf>,
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
) -> Self {
let mut output = Output::default();
@@ -316,7 +316,7 @@ impl MediaDataExtractor {
location_id: location::id::Type,
location_path: Arc<PathBuf>,
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
) -> Self {
Self::new(Kind::Exif, file_paths, location_id, location_path, db, sync)
}
@@ -327,7 +327,7 @@ impl MediaDataExtractor {
location_id: location::id::Type,
location_path: Arc<PathBuf>,
db: Arc<PrismaClient>,
sync: Arc<SyncManager>,
sync: SyncManager,
) -> Self {
Self::new(
Kind::FFmpeg,
@@ -550,7 +550,7 @@ impl SerializableTask<Error> for MediaDataExtractor {
type DeserializeError = rmp_serde::decode::Error;
type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
let Self {

View File

@@ -60,7 +60,7 @@ impl<'de> Deserialize<'de> for RulePerKind {
struct FieldsVisitor;
impl<'de> de::Visitor<'de> for FieldsVisitor {
impl de::Visitor<'_> for FieldsVisitor {
type Value = Fields;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {

View File

@@ -9,8 +9,9 @@ repository.workspace = true
[dependencies]
# Spacedrive Sub-crates
sd-prisma = { path = "../../../crates/prisma" }
sd-utils = { path = "../../../crates/utils" }
sd-cloud-schema = { workspace = true }
sd-prisma = { path = "../../../crates/prisma" }
sd-utils = { path = "../../../crates/utils" }
# Workspace dependencies
prisma-client-rust = { workspace = true }

View File

@@ -34,7 +34,6 @@ use sd_utils::{from_bytes_to_uuid, uuid_to_bytes};
use std::{borrow::Cow, fmt};
use serde::{Deserialize, Serialize};
use specta::Type;
use uuid::Uuid;
// File Path selectables!
@@ -75,6 +74,20 @@ file_path::select!(file_path_for_media_processor {
pub_id
}
});
file_path::select!(file_path_watcher_remove {
id
pub_id
location_id
materialized_path
is_dir
name
extension
object: select {
id
pub_id
}
});
file_path::select!(file_path_to_isolate {
location_id
materialized_path
@@ -244,7 +257,7 @@ job::select!(job_without_data {
location::select!(location_ids_and_path {
id
pub_id
instance_id
device: select { pub_id }
path
});
@@ -259,6 +272,7 @@ impl From<location_with_indexer_rules::Data> for location::Data {
id: data.id,
pub_id: data.pub_id,
path: data.path,
device_id: data.device_id,
instance_id: data.instance_id,
name: data.name,
total_capacity: data.total_capacity,
@@ -272,6 +286,7 @@ impl From<location_with_indexer_rules::Data> for location::Data {
scan_state: data.scan_state,
file_paths: None,
indexer_rules: None,
device: None,
instance: None,
}
}
@@ -283,6 +298,7 @@ impl From<&location_with_indexer_rules::Data> for location::Data {
id: data.id,
pub_id: data.pub_id.clone(),
path: data.path.clone(),
device_id: data.device_id,
instance_id: data.instance_id,
name: data.name.clone(),
total_capacity: data.total_capacity,
@@ -296,6 +312,7 @@ impl From<&location_with_indexer_rules::Data> for location::Data {
scan_state: data.scan_state,
file_paths: None,
indexer_rules: None,
device: None,
instance: None,
}
}
@@ -311,7 +328,7 @@ label::include!((take: i64) => label_with_objects {
}
});
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Type)]
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, specta::Type)]
#[serde(transparent)]
pub struct CasId<'cas_id>(Cow<'cas_id, str>);
@@ -321,7 +338,7 @@ impl Clone for CasId<'_> {
}
}
impl<'cas_id> CasId<'cas_id> {
impl CasId<'_> {
#[must_use]
pub fn as_str(&self) -> &str {
self.0.as_ref()
@@ -374,17 +391,32 @@ impl From<&CasId<'_>> for String {
}
}
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone)]
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
#[serde(transparent)]
#[repr(transparent)]
#[specta(rename = "CoreDevicePubId")]
pub struct DevicePubId(PubId);
impl From<DevicePubId> for sd_cloud_schema::devices::PubId {
fn from(DevicePubId(pub_id): DevicePubId) -> Self {
Self(pub_id.into())
}
}
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
#[serde(transparent)]
#[repr(transparent)]
#[specta(rename = "CoreFilePathPubId")]
pub struct FilePathPubId(PubId);
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone)]
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
#[serde(transparent)]
#[repr(transparent)]
#[specta(rename = "CoreObjectPubId")]
pub struct ObjectPubId(PubId);
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone)]
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
#[specta(rename = "CorePubId")]
enum PubId {
Uuid(Uuid),
Vec(Vec<u8>),
@@ -392,7 +424,7 @@ enum PubId {
impl PubId {
fn new() -> Self {
Self::Uuid(Uuid::new_v4())
Self::Uuid(Uuid::now_v7())
}
fn to_db(&self) -> Vec<u8> {
@@ -451,6 +483,15 @@ impl From<PubId> for Uuid {
}
}
impl From<&PubId> for Uuid {
fn from(pub_id: &PubId) -> Self {
match pub_id {
PubId::Uuid(uuid) => *uuid,
PubId::Vec(bytes) => from_bytes_to_uuid(bytes),
}
}
}
impl fmt::Display for PubId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
@@ -499,6 +540,12 @@ macro_rules! delegate_pub_id {
}
}
impl From<&$type_name> for ::uuid::Uuid {
fn from(pub_id: &$type_name) -> Self {
(&pub_id.0).into()
}
}
impl ::std::fmt::Display for $type_name {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
write!(f, "{}", self.0)
@@ -526,4 +573,4 @@ macro_rules! delegate_pub_id {
};
}
delegate_pub_id!(FilePathPubId, ObjectPubId);
delegate_pub_id!(FilePathPubId, ObjectPubId, DevicePubId);
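
`PubId::new` above switches from `Uuid::new_v4` to `Uuid::now_v7`. Assuming the `uuid` crate with the `v4` and `v7` features enabled, the difference is that v7 ids carry a millisecond timestamp prefix, so fresh ids sort roughly by creation time:

use uuid::Uuid; // uuid = { version = "1", features = ["v4", "v7"] }

fn main() {
    let random = Uuid::new_v4(); // fully random, no ordering guarantees
    let ordered = Uuid::now_v7(); // timestamp-prefixed, roughly monotonic
    println!("v4: {random}\nv7: {ordered}");
}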

View File

@@ -9,6 +9,8 @@ default = []
[dependencies]
# Spacedrive Sub-crates
sd-core-prisma-helpers = { path = "../prisma-helpers" }
sd-actors = { path = "../../../crates/actors" }
sd-prisma = { path = "../../../crates/prisma" }
sd-sync = { path = "../../../crates/sync" }
@@ -16,8 +18,11 @@ sd-utils = { path = "../../../crates/utils" }
# Workspace dependencies
async-channel = { workspace = true }
async-stream = { workspace = true }
chrono = { workspace = true }
futures = { workspace = true }
futures-concurrency = { workspace = true }
itertools = { workspace = true }
prisma-client-rust = { workspace = true, features = ["rspc"] }
rmp-serde = { workspace = true }
rmpv = { workspace = true }

View File

@@ -1,39 +0,0 @@
use async_channel as chan;
pub trait ActorTypes {
type Event: Send;
type Request: Send;
type Handler;
}
pub struct ActorIO<T: ActorTypes> {
pub event_rx: chan::Receiver<T::Event>,
pub req_tx: chan::Sender<T::Request>,
}
impl<T: ActorTypes> Clone for ActorIO<T> {
fn clone(&self) -> Self {
Self {
event_rx: self.event_rx.clone(),
req_tx: self.req_tx.clone(),
}
}
}
impl<T: ActorTypes> ActorIO<T> {
pub async fn send(&self, value: T::Request) -> Result<(), chan::SendError<T::Request>> {
self.req_tx.send(value).await
}
}
pub struct HandlerIO<T: ActorTypes> {
pub event_tx: chan::Sender<T::Event>,
pub req_rx: chan::Receiver<T::Request>,
}
pub fn create_actor_io<T: ActorTypes>() -> (ActorIO<T>, HandlerIO<T>) {
let (req_tx, req_rx) = chan::bounded(32);
let (event_tx, event_rx) = chan::bounded(32);
(ActorIO { event_rx, req_tx }, HandlerIO { event_tx, req_rx })
}
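
The deleted `ActorIO`/`HandlerIO` pair above is just two bounded channels wired in opposite directions. The same shape with std's bounded channels, as a sketch:

use std::sync::mpsc::sync_channel;

fn main() {
    let (req_tx, req_rx) = sync_channel::<&str>(32); // actor -> handler
    let (event_tx, event_rx) = sync_channel::<&str>(32); // handler -> actor
    req_tx.send("messages, please").unwrap();
    println!("handler got: {}", req_rx.recv().unwrap());
    event_tx.send("notification").unwrap();
    println!("actor got: {}", event_rx.recv().unwrap());
}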

View File

@@ -1,7 +1,7 @@
use sd_prisma::{
prisma::{
crdt_operation, exif_data, file_path, instance, label, label_on_object, location, object,
tag, tag_on_object, PrismaClient, SortOrder,
crdt_operation, device, exif_data, file_path, label, label_on_object, location, object,
storage_statistics, tag, tag_on_object, PrismaClient, SortOrder,
},
prisma_sync,
};
@@ -10,49 +10,141 @@ use sd_utils::chain_optional_iter;
use std::future::Future;
use futures_concurrency::future::TryJoin;
use tokio::time::Instant;
use tracing::{debug, instrument};
use super::{crdt_op_unchecked_db, Error};
use super::{crdt_op_unchecked_db, Error, SyncManager};
/// Takes all the syncable data in the database and generates [`CRDTOperation`]s for it.
/// This is a requirement before the library can sync.
pub async fn backfill_operations(
db: &PrismaClient,
sync: &crate::Manager,
instance_id: instance::id::Type,
) -> Result<(), Error> {
let lock = sync.timestamp_lock.lock().await;
pub async fn backfill_operations(sync: &SyncManager) -> Result<(), Error> {
let _lock_guard = sync.sync_lock.lock().await;
let res = db
._transaction()
let db = &sync.db;
let local_device = db
.device()
.find_unique(device::pub_id::equals(sync.device_pub_id.to_db()))
.exec()
.await?
.ok_or(Error::DeviceNotFound(sync.device_pub_id.clone()))?;
let local_device_id = local_device.id;
db._transaction()
.with_timeout(9_999_999_999)
.run(|db| async move {
debug!("backfill started");
let start = Instant::now();
db.crdt_operation()
.delete_many(vec![crdt_operation::instance_id::equals(instance_id)])
.delete_many(vec![crdt_operation::device_pub_id::equals(
sync.device_pub_id.to_db(),
)])
.exec()
.await?;
paginate_tags(&db, sync, instance_id).await?;
paginate_locations(&db, sync, instance_id).await?;
paginate_objects(&db, sync, instance_id).await?;
paginate_exif_datas(&db, sync, instance_id).await?;
paginate_file_paths(&db, sync, instance_id).await?;
paginate_tags_on_objects(&db, sync, instance_id).await?;
paginate_labels(&db, sync, instance_id).await?;
paginate_labels_on_objects(&db, sync, instance_id).await?;
backfill_device(&db, sync, local_device).await?;
(
backfill_storage_statistics(&db, sync, local_device_id),
paginate_tags(&db, sync),
paginate_locations(&db, sync, local_device_id),
paginate_objects(&db, sync, local_device_id),
paginate_labels(&db, sync),
)
.try_join()
.await?;
(
paginate_exif_datas(&db, sync, local_device_id),
paginate_file_paths(&db, sync, local_device_id),
paginate_tags_on_objects(&db, sync, local_device_id),
paginate_labels_on_objects(&db, sync, local_device_id),
)
.try_join()
.await?;
debug!(elapsed = ?start.elapsed(), "backfill ended");
Ok(())
})
.await;
.await
}
drop(lock);
#[instrument(skip(db, sync), err)]
async fn backfill_device(
db: &PrismaClient,
sync: &SyncManager,
local_device: device::Data,
) -> Result<(), Error> {
db.crdt_operation()
.create_many(vec![crdt_op_unchecked_db(&sync.shared_create(
prisma_sync::device::SyncId {
pub_id: local_device.pub_id,
},
chain_optional_iter(
[],
[
option_sync_entry!(local_device.name, device::name),
option_sync_entry!(local_device.os, device::os),
option_sync_entry!(local_device.hardware_model, device::hardware_model),
option_sync_entry!(local_device.timestamp, device::timestamp),
option_sync_entry!(local_device.date_created, device::date_created),
option_sync_entry!(local_device.date_deleted, device::date_deleted),
],
),
))?])
.exec()
.await?;
res
Ok(())
}
#[instrument(skip(db, sync), err)]
async fn backfill_storage_statistics(
db: &PrismaClient,
sync: &SyncManager,
device_id: device::id::Type,
) -> Result<(), Error> {
let Some(stats) = db
.storage_statistics()
.find_first(vec![storage_statistics::device_id::equals(Some(device_id))])
.include(storage_statistics::include!({device: select { pub_id }}))
.exec()
.await?
else {
// Nothing to do
return Ok(());
};
db.crdt_operation()
.create_many(vec![crdt_op_unchecked_db(&sync.shared_create(
prisma_sync::storage_statistics::SyncId {
pub_id: stats.pub_id,
},
chain_optional_iter(
[
sync_entry!(stats.total_capacity, storage_statistics::total_capacity),
sync_entry!(
stats.available_capacity,
storage_statistics::available_capacity
),
],
[option_sync_entry!(
stats.device.map(|device| {
prisma_sync::device::SyncId {
pub_id: device.pub_id,
}
}),
storage_statistics::device
)],
),
))?])
.exec()
.await?;
Ok(())
}
async fn paginate<T, E1, E2, E3, GetterFut, OperationsFut>(
@@ -112,38 +204,32 @@ where
}
#[instrument(skip(db, sync), err)]
async fn paginate_tags(
db: &PrismaClient,
sync: &crate::Manager,
instance_id: instance::id::Type,
) -> Result<(), Error> {
use tag::{color, date_created, date_modified, id, name};
async fn paginate_tags(db: &PrismaClient, sync: &SyncManager) -> Result<(), Error> {
paginate(
|cursor| {
db.tag()
.find_many(vec![id::gt(cursor)])
.order_by(id::order(SortOrder::Asc))
.find_many(vec![tag::id::gt(cursor)])
.order_by(tag::id::order(SortOrder::Asc))
.exec()
},
|tag| tag.id,
|tags| {
tags.into_iter()
.flat_map(|t| {
.map(|t| {
sync.shared_create(
prisma_sync::tag::SyncId { pub_id: t.pub_id },
chain_optional_iter(
[],
[
option_sync_entry!(t.name, name),
option_sync_entry!(t.color, color),
option_sync_entry!(t.date_created, date_created),
option_sync_entry!(t.date_modified, date_modified),
option_sync_entry!(t.name, tag::name),
option_sync_entry!(t.color, tag::color),
option_sync_entry!(t.date_created, tag::date_created),
option_sync_entry!(t.date_modified, tag::date_modified),
],
),
)
})
.map(|o| crdt_op_unchecked_db(&o, instance_id))
.map(|o| crdt_op_unchecked_db(&o))
.collect::<Result<Vec<_>, _>>()
.map(|creates| db.crdt_operation().create_many(creates).exec())
},
@@ -154,25 +240,20 @@ async fn paginate_tags(
#[instrument(skip(db, sync), err)]
async fn paginate_locations(
db: &PrismaClient,
sync: &crate::Manager,
instance_id: instance::id::Type,
sync: &SyncManager,
device_id: device::id::Type,
) -> Result<(), Error> {
use location::{
available_capacity, date_created, generate_preview_media, hidden, id, include, instance,
is_archived, name, path, size_in_bytes, sync_preview_media, total_capacity,
};
paginate(
|cursor| {
db.location()
.find_many(vec![id::gt(cursor)])
.order_by(id::order(SortOrder::Asc))
.find_many(vec![
location::id::gt(cursor),
location::device_id::equals(Some(device_id)),
])
.order_by(location::id::order(SortOrder::Asc))
.take(1000)
.include(include!({
instance: select {
id
pub_id
}
.include(location::include!({
device: select { pub_id }
}))
.exec()
},
@@ -180,36 +261,44 @@ async fn paginate_locations(
|locations| {
locations
.into_iter()
.flat_map(|l| {
.map(|l| {
sync.shared_create(
prisma_sync::location::SyncId { pub_id: l.pub_id },
chain_optional_iter(
[],
[
option_sync_entry!(l.name, name),
option_sync_entry!(l.path, path),
option_sync_entry!(l.total_capacity, total_capacity),
option_sync_entry!(l.available_capacity, available_capacity),
option_sync_entry!(l.size_in_bytes, size_in_bytes),
option_sync_entry!(l.is_archived, is_archived),
option_sync_entry!(l.name, location::name),
option_sync_entry!(l.path, location::path),
option_sync_entry!(l.total_capacity, location::total_capacity),
option_sync_entry!(
l.available_capacity,
location::available_capacity
),
option_sync_entry!(l.size_in_bytes, location::size_in_bytes),
option_sync_entry!(l.is_archived, location::is_archived),
option_sync_entry!(
l.generate_preview_media,
generate_preview_media
location::generate_preview_media
),
option_sync_entry!(l.sync_preview_media, sync_preview_media),
option_sync_entry!(l.hidden, hidden),
option_sync_entry!(l.date_created, date_created),
option_sync_entry!(
l.instance.map(|i| {
prisma_sync::instance::SyncId { pub_id: i.pub_id }
l.sync_preview_media,
location::sync_preview_media
),
option_sync_entry!(l.hidden, location::hidden),
option_sync_entry!(l.date_created, location::date_created),
option_sync_entry!(
l.device.map(|device| {
prisma_sync::device::SyncId {
pub_id: device.pub_id,
}
}),
instance
location::device
),
],
),
)
})
.map(|o| crdt_op_unchecked_db(&o, instance_id))
.map(|o| crdt_op_unchecked_db(&o))
.collect::<Result<Vec<_>, _>>()
.map(|creates| db.crdt_operation().create_many(creates).exec())
},
@@ -220,41 +309,53 @@ async fn paginate_locations(
#[instrument(skip(db, sync), err)]
async fn paginate_objects(
db: &PrismaClient,
sync: &crate::Manager,
instance_id: instance::id::Type,
sync: &SyncManager,
device_id: device::id::Type,
) -> Result<(), Error> {
use object::{date_accessed, date_created, favorite, hidden, id, important, kind, note};
paginate(
|cursor| {
db.object()
.find_many(vec![id::gt(cursor)])
.order_by(id::order(SortOrder::Asc))
.find_many(vec![
object::id::gt(cursor),
object::device_id::equals(Some(device_id)),
])
.order_by(object::id::order(SortOrder::Asc))
.take(1000)
.include(object::include!({
device: select { pub_id }
}))
.exec()
},
|object| object.id,
|objects| {
objects
.into_iter()
.flat_map(|o| {
.map(|o| {
sync.shared_create(
prisma_sync::object::SyncId { pub_id: o.pub_id },
chain_optional_iter(
[],
[
option_sync_entry!(o.kind, kind),
option_sync_entry!(o.hidden, hidden),
option_sync_entry!(o.favorite, favorite),
option_sync_entry!(o.important, important),
option_sync_entry!(o.note, note),
option_sync_entry!(o.date_created, date_created),
option_sync_entry!(o.date_accessed, date_accessed),
option_sync_entry!(o.kind, object::kind),
option_sync_entry!(o.hidden, object::hidden),
option_sync_entry!(o.favorite, object::favorite),
option_sync_entry!(o.important, object::important),
option_sync_entry!(o.note, object::note),
option_sync_entry!(o.date_created, object::date_created),
option_sync_entry!(o.date_accessed, object::date_accessed),
option_sync_entry!(
o.device.map(|device| {
prisma_sync::device::SyncId {
pub_id: device.pub_id,
}
}),
object::device
),
],
),
)
})
.map(|o| crdt_op_unchecked_db(&o, instance_id))
.map(|o| crdt_op_unchecked_db(&o))
.collect::<Result<Vec<_>, _>>()
.map(|creates| db.crdt_operation().create_many(creates).exec())
},
@@ -265,22 +366,21 @@ async fn paginate_objects(
#[instrument(skip(db, sync), err)]
async fn paginate_exif_datas(
db: &PrismaClient,
sync: &crate::Manager,
instance_id: instance::id::Type,
sync: &SyncManager,
device_id: device::id::Type,
) -> Result<(), Error> {
use exif_data::{
artist, camera_data, copyright, description, epoch_time, exif_version, id, include,
media_date, media_location, resolution,
};
paginate(
|cursor| {
db.exif_data()
.find_many(vec![id::gt(cursor)])
.order_by(id::order(SortOrder::Asc))
.find_many(vec![
exif_data::id::gt(cursor),
exif_data::device_id::equals(Some(device_id)),
])
.order_by(exif_data::id::order(SortOrder::Asc))
.take(1000)
.include(include!({
.include(exif_data::include!({
object: select { pub_id }
device: select { pub_id }
}))
.exec()
},
@@ -288,7 +388,7 @@ async fn paginate_exif_datas(
|exif_datas| {
exif_datas
.into_iter()
.flat_map(|ed| {
.map(|ed| {
sync.shared_create(
prisma_sync::exif_data::SyncId {
object: prisma_sync::object::SyncId {
@@ -298,20 +398,28 @@ async fn paginate_exif_datas(
chain_optional_iter(
[],
[
option_sync_entry!(ed.resolution, resolution),
option_sync_entry!(ed.media_date, media_date),
option_sync_entry!(ed.media_location, media_location),
option_sync_entry!(ed.camera_data, camera_data),
option_sync_entry!(ed.artist, artist),
option_sync_entry!(ed.description, description),
option_sync_entry!(ed.copyright, copyright),
option_sync_entry!(ed.exif_version, exif_version),
option_sync_entry!(ed.epoch_time, epoch_time),
option_sync_entry!(ed.resolution, exif_data::resolution),
option_sync_entry!(ed.media_date, exif_data::media_date),
option_sync_entry!(ed.media_location, exif_data::media_location),
option_sync_entry!(ed.camera_data, exif_data::camera_data),
option_sync_entry!(ed.artist, exif_data::artist),
option_sync_entry!(ed.description, exif_data::description),
option_sync_entry!(ed.copyright, exif_data::copyright),
option_sync_entry!(ed.exif_version, exif_data::exif_version),
option_sync_entry!(ed.epoch_time, exif_data::epoch_time),
option_sync_entry!(
ed.device.map(|device| {
prisma_sync::device::SyncId {
pub_id: device.pub_id,
}
}),
exif_data::device
),
],
),
)
})
.map(|o| crdt_op_unchecked_db(&o, instance_id))
.map(|o| crdt_op_unchecked_db(&o))
.collect::<Result<Vec<_>, _>>()
.map(|creates| db.crdt_operation().create_many(creates).exec())
},
@@ -322,22 +430,21 @@ async fn paginate_exif_datas(
#[instrument(skip(db, sync), err)]
async fn paginate_file_paths(
db: &PrismaClient,
sync: &crate::Manager,
instance_id: instance::id::Type,
sync: &SyncManager,
device_id: device::id::Type,
) -> Result<(), Error> {
use file_path::{
cas_id, date_created, date_indexed, date_modified, extension, hidden, id, include, inode,
integrity_checksum, is_dir, location, materialized_path, name, object, size_in_bytes_bytes,
};
paginate(
|cursor| {
db.file_path()
.find_many(vec![id::gt(cursor)])
.order_by(id::order(SortOrder::Asc))
.include(include!({
.find_many(vec![
file_path::id::gt(cursor),
file_path::device_id::equals(Some(device_id)),
])
.order_by(file_path::id::order(SortOrder::Asc))
.include(file_path::include!({
location: select { pub_id }
object: select { pub_id }
device: select { pub_id }
}))
.exec()
},
@@ -345,41 +452,58 @@ async fn paginate_file_paths(
|file_paths| {
file_paths
.into_iter()
.flat_map(|fp| {
.map(|fp| {
sync.shared_create(
prisma_sync::file_path::SyncId { pub_id: fp.pub_id },
chain_optional_iter(
[],
[
option_sync_entry!(fp.is_dir, is_dir),
option_sync_entry!(fp.cas_id, cas_id),
option_sync_entry!(fp.integrity_checksum, integrity_checksum),
option_sync_entry!(fp.is_dir, file_path::is_dir),
option_sync_entry!(fp.cas_id, file_path::cas_id),
option_sync_entry!(
fp.integrity_checksum,
file_path::integrity_checksum
),
option_sync_entry!(
fp.location.map(|l| {
prisma_sync::location::SyncId { pub_id: l.pub_id }
}),
location
file_path::location
),
option_sync_entry!(
fp.object.map(|o| {
prisma_sync::object::SyncId { pub_id: o.pub_id }
}),
object
file_path::object
),
option_sync_entry!(
fp.materialized_path,
file_path::materialized_path
),
option_sync_entry!(fp.name, file_path::name),
option_sync_entry!(fp.extension, file_path::extension),
option_sync_entry!(fp.hidden, file_path::hidden),
option_sync_entry!(
fp.size_in_bytes_bytes,
file_path::size_in_bytes_bytes
),
option_sync_entry!(fp.inode, file_path::inode),
option_sync_entry!(fp.date_created, file_path::date_created),
option_sync_entry!(fp.date_modified, file_path::date_modified),
option_sync_entry!(fp.date_indexed, file_path::date_indexed),
option_sync_entry!(
fp.device.map(|device| {
prisma_sync::device::SyncId {
pub_id: device.pub_id,
}
}),
file_path::device
),
option_sync_entry!(fp.materialized_path, materialized_path),
option_sync_entry!(fp.name, name),
option_sync_entry!(fp.extension, extension),
option_sync_entry!(fp.hidden, hidden),
option_sync_entry!(fp.size_in_bytes_bytes, size_in_bytes_bytes),
option_sync_entry!(fp.inode, inode),
option_sync_entry!(fp.date_created, date_created),
option_sync_entry!(fp.date_modified, date_modified),
option_sync_entry!(fp.date_indexed, date_indexed),
],
),
)
})
.map(|o| crdt_op_unchecked_db(&o, instance_id))
.map(|o| crdt_op_unchecked_db(&o))
.collect::<Result<Vec<_>, _>>()
.map(|creates| db.crdt_operation().create_many(creates).exec())
},
@@ -390,20 +514,23 @@ async fn paginate_file_paths(
#[instrument(skip(db, sync), err)]
async fn paginate_tags_on_objects(
db: &PrismaClient,
sync: &crate::Manager,
instance_id: instance::id::Type,
sync: &SyncManager,
device_id: device::id::Type,
) -> Result<(), Error> {
use tag_on_object::{date_created, include, object_id, tag_id};
paginate_relation(
|group_id, item_id| {
db.tag_on_object()
.find_many(vec![tag_id::gt(group_id), object_id::gt(item_id)])
.order_by(tag_id::order(SortOrder::Asc))
.order_by(object_id::order(SortOrder::Asc))
.include(include!({
.find_many(vec![
tag_on_object::tag_id::gt(group_id),
tag_on_object::object_id::gt(item_id),
tag_on_object::device_id::equals(Some(device_id)),
])
.order_by(tag_on_object::tag_id::order(SortOrder::Asc))
.order_by(tag_on_object::object_id::order(SortOrder::Asc))
.include(tag_on_object::include!({
tag: select { pub_id }
object: select { pub_id }
device: select { pub_id }
}))
.exec()
},
@@ -411,7 +538,7 @@ async fn paginate_tags_on_objects(
|tag_on_objects| {
tag_on_objects
.into_iter()
.flat_map(|t_o| {
.map(|t_o| {
sync.relation_create(
prisma_sync::tag_on_object::SyncId {
tag: prisma_sync::tag::SyncId {
@@ -423,11 +550,21 @@ async fn paginate_tags_on_objects(
},
chain_optional_iter(
[],
[option_sync_entry!(t_o.date_created, date_created)],
[
option_sync_entry!(t_o.date_created, tag_on_object::date_created),
option_sync_entry!(
t_o.device.map(|device| {
prisma_sync::device::SyncId {
pub_id: device.pub_id,
}
}),
tag_on_object::device
),
],
),
)
})
.map(|o| crdt_op_unchecked_db(&o, instance_id))
.map(|o| crdt_op_unchecked_db(&o))
.collect::<Result<Vec<_>, _>>()
.map(|creates| db.crdt_operation().create_many(creates).exec())
},
@@ -436,37 +573,31 @@ async fn paginate_tags_on_objects(
}
#[instrument(skip(db, sync), err)]
async fn paginate_labels(
db: &PrismaClient,
sync: &crate::Manager,
instance_id: instance::id::Type,
) -> Result<(), Error> {
use label::{date_created, date_modified, id};
async fn paginate_labels(db: &PrismaClient, sync: &SyncManager) -> Result<(), Error> {
paginate(
|cursor| {
db.label()
.find_many(vec![id::gt(cursor)])
.order_by(id::order(SortOrder::Asc))
.find_many(vec![label::id::gt(cursor)])
.order_by(label::id::order(SortOrder::Asc))
.exec()
},
|label| label.id,
|labels| {
labels
.into_iter()
.flat_map(|l| {
.map(|l| {
sync.shared_create(
prisma_sync::label::SyncId { name: l.name },
chain_optional_iter(
[],
[
option_sync_entry!(l.date_created, date_created),
option_sync_entry!(l.date_modified, date_modified),
option_sync_entry!(l.date_created, label::date_created),
option_sync_entry!(l.date_modified, label::date_modified),
],
),
)
})
.map(|o| crdt_op_unchecked_db(&o, instance_id))
.map(|o| crdt_op_unchecked_db(&o))
.collect::<Result<Vec<_>, _>>()
.map(|creates| db.crdt_operation().create_many(creates).exec())
},
@@ -477,20 +608,23 @@ async fn paginate_labels(
#[instrument(skip(db, sync), err)]
async fn paginate_labels_on_objects(
db: &PrismaClient,
sync: &crate::Manager,
instance_id: instance::id::Type,
sync: &SyncManager,
device_id: device::id::Type,
) -> Result<(), Error> {
use label_on_object::{date_created, include, label_id, object_id};
paginate_relation(
|group_id, item_id| {
db.label_on_object()
.find_many(vec![label_id::gt(group_id), object_id::gt(item_id)])
.order_by(label_id::order(SortOrder::Asc))
.order_by(object_id::order(SortOrder::Asc))
.include(include!({
.find_many(vec![
label_on_object::label_id::gt(group_id),
label_on_object::object_id::gt(item_id),
label_on_object::device_id::equals(Some(device_id)),
])
.order_by(label_on_object::label_id::order(SortOrder::Asc))
.order_by(label_on_object::object_id::order(SortOrder::Asc))
.include(label_on_object::include!({
object: select { pub_id }
label: select { name }
device: select { pub_id }
}))
.exec()
},
@@ -498,7 +632,7 @@ async fn paginate_labels_on_objects(
|label_on_objects| {
label_on_objects
.into_iter()
.flat_map(|l_o| {
.map(|l_o| {
sync.relation_create(
prisma_sync::label_on_object::SyncId {
label: prisma_sync::label::SyncId {
@@ -508,10 +642,20 @@ async fn paginate_labels_on_objects(
pub_id: l_o.object.pub_id,
},
},
[sync_entry!(l_o.date_created, date_created)],
chain_optional_iter(
[sync_entry!(l_o.date_created, label_on_object::date_created)],
[option_sync_entry!(
l_o.device.map(|device| {
prisma_sync::device::SyncId {
pub_id: device.pub_id,
}
}),
label_on_object::device
)],
),
)
})
.map(|o| crdt_op_unchecked_db(&o, instance_id))
.map(|o| crdt_op_unchecked_db(&o))
.collect::<Result<Vec<_>, _>>()
.map(|creates| db.crdt_operation().create_many(creates).exec())
},

View File

@@ -1,82 +1,14 @@
use sd_prisma::prisma::{cloud_crdt_operation, crdt_operation, instance, PrismaClient};
use sd_core_prisma_helpers::DevicePubId;
use sd_prisma::prisma::{cloud_crdt_operation, crdt_operation, PrismaClient};
use sd_sync::CRDTOperation;
use sd_utils::from_bytes_to_uuid;
use sd_utils::uuid_to_bytes;
use tracing::instrument;
use uhlc::NTP64;
use uuid::Uuid;
use super::Error;
crdt_operation::include!(crdt_with_instance {
instance: select { pub_id }
});
cloud_crdt_operation::include!(cloud_crdt_with_instance {
instance: select { pub_id }
});
impl crdt_with_instance::Data {
#[allow(clippy::cast_sign_loss)] // SAFETY: we had to store using i64 due to SQLite limitations
pub const fn timestamp(&self) -> NTP64 {
NTP64(self.timestamp as u64)
}
pub fn instance(&self) -> Uuid {
from_bytes_to_uuid(&self.instance.pub_id)
}
pub fn into_operation(self) -> Result<CRDTOperation, Error> {
Ok(CRDTOperation {
instance: self.instance(),
timestamp: self.timestamp(),
record_id: rmp_serde::from_slice(&self.record_id)?,
model: {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
// SAFETY: we will not have more than 2^16 models and we had to store using signed
// integers due to SQLite limitations
{
self.model as u16
}
},
data: rmp_serde::from_slice(&self.data)?,
})
}
}
impl cloud_crdt_with_instance::Data {
#[allow(clippy::cast_sign_loss)] // SAFETY: we had to store using i64 due to SQLite limitations
pub const fn timestamp(&self) -> NTP64 {
NTP64(self.timestamp as u64)
}
pub fn instance(&self) -> Uuid {
from_bytes_to_uuid(&self.instance.pub_id)
}
#[instrument(skip(self), err)]
pub fn into_operation(self) -> Result<(i32, CRDTOperation), Error> {
Ok((
self.id,
CRDTOperation {
instance: self.instance(),
timestamp: self.timestamp(),
record_id: rmp_serde::from_slice(&self.record_id)?,
model: {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
// SAFETY: we will not have more than 2^16 models and we had to store using signed
// integers due to SQLite limitations
{
self.model as u16
}
},
data: rmp_serde::from_slice(&self.data)?,
},
))
}
}
#[instrument(skip(op, db), err)]
pub async fn write_crdt_op_to_db(op: &CRDTOperation, db: &PrismaClient) -> Result<(), Error> {
crdt_operation::Create {
@@ -87,16 +19,85 @@ pub async fn write_crdt_op_to_db(op: &CRDTOperation, db: &PrismaClient) -> Resul
op.timestamp.0 as i64
}
},
instance: instance::pub_id::equals(op.instance.as_bytes().to_vec()),
device_pub_id: uuid_to_bytes(&op.device_pub_id),
kind: op.kind().to_string(),
data: rmp_serde::to_vec(&op.data)?,
model: i32::from(op.model),
model: i32::from(op.model_id),
record_id: rmp_serde::to_vec(&op.record_id)?,
_params: vec![],
}
.to_query(db)
.select(crdt_operation::select!({ id })) // Select only the id to avoid fetching the whole object for nothing
.exec()
.await
.map_or_else(|e| Err(e.into()), |_| Ok(()))
.await?;
Ok(())
}
pub fn from_crdt_ops(
crdt_operation::Data {
timestamp,
model,
record_id,
data,
device_pub_id,
..
}: crdt_operation::Data,
) -> Result<CRDTOperation, Error> {
Ok(CRDTOperation {
device_pub_id: DevicePubId::from(device_pub_id).into(),
timestamp: {
#[allow(clippy::cast_sign_loss)]
{
// SAFETY: we had to store using i64 due to SQLite limitations
NTP64(timestamp as u64)
}
},
model_id: {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
{
// SAFETY: we will not have more than 2^16 models and we had to store using signed
// integers due to SQLite limitations
model as u16
}
},
record_id: rmp_serde::from_slice(&record_id)?,
data: rmp_serde::from_slice(&data)?,
})
}
pub fn from_cloud_crdt_ops(
cloud_crdt_operation::Data {
id,
timestamp,
model,
record_id,
data,
device_pub_id,
..
}: cloud_crdt_operation::Data,
) -> Result<(cloud_crdt_operation::id::Type, CRDTOperation), Error> {
Ok((
id,
CRDTOperation {
device_pub_id: DevicePubId::from(device_pub_id).into(),
timestamp: {
#[allow(clippy::cast_sign_loss)]
{
// SAFETY: we had to store using i64 due to SQLite limitations
NTP64(timestamp as u64)
}
},
model_id: {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
{
// SAFETY: we will not have more than 2^16 models and we had to store using signed
// integers due to SQLite limitations
model as u16
}
},
record_id: rmp_serde::from_slice(&record_id)?,
data: rmp_serde::from_slice(&data)?,
},
))
}
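
Both converters above rely on the fact that storing a `u64` NTP64 timestamp in SQLite's signed 64-bit integer column is lossless: the `as` casts reinterpret the bit pattern in both directions.

fn main() {
    let ts: u64 = u64::MAX - 42; // pretend NTP64 value past i64::MAX
    let stored = ts as i64; // what goes into SQLite (may be negative)
    let restored = stored as u64; // what from_crdt_ops reads back
    assert_eq!(ts, restored); // the round trip is exact
    println!("stored {stored}, restored {restored}");
}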

View File

@@ -1,641 +0,0 @@
use sd_prisma::{
prisma::{crdt_operation, PrismaClient, SortOrder},
prisma_sync::ModelSyncData,
};
use sd_sync::{
CRDTOperation, CRDTOperationData, CompressedCRDTOperation, CompressedCRDTOperations,
OperationKind,
};
use std::{
collections::BTreeMap,
future::IntoFuture,
num::NonZeroU128,
ops::Deref,
pin::pin,
sync::{atomic::Ordering, Arc},
time::SystemTime,
};
use async_channel as chan;
use futures::{stream, FutureExt, StreamExt};
use futures_concurrency::{
future::{Race, TryJoin},
stream::Merge,
};
use prisma_client_rust::chrono::{DateTime, Utc};
use tokio::sync::oneshot;
use tracing::{debug, error, instrument, trace, warn};
use uhlc::{Timestamp, NTP64};
use uuid::Uuid;
use super::{
actor::{create_actor_io, ActorIO, ActorTypes, HandlerIO},
db_operation::write_crdt_op_to_db,
Error, SharedState,
};
#[derive(Debug)]
#[must_use]
/// Stuff that can be handled outside the actor
pub enum Request {
Messages {
timestamps: Vec<(Uuid, NTP64)>,
tx: oneshot::Sender<()>,
},
FinishedIngesting,
}
/// Stuff that the actor consumes
#[derive(Debug)]
pub enum Event {
Notification,
Messages(MessagesEvent),
}
#[derive(Debug, Default)]
pub enum State {
#[default]
WaitingForNotification,
RetrievingMessages,
Ingesting(MessagesEvent),
}
/// The single entrypoint for sync operation ingestion.
/// Requests sync operations in a given timestamp range,
/// and attempts to write them to the sync operations table along with
/// the actual cell that the operation points to.
///
/// If this actor stops running, no sync operations will
/// be applied to the database, independent of whether systems like p2p
/// or cloud are exchanging messages.
pub struct Actor {
state: Option<State>,
shared: Arc<SharedState>,
io: ActorIO<Self>,
}
impl Actor {
#[instrument(skip(self), fields(old_state = ?self.state))]
async fn tick(&mut self) {
let state = match self
.state
.take()
.expect("ingest actor in inconsistent state")
{
State::WaitingForNotification => self.waiting_for_notification_state_transition().await,
State::RetrievingMessages => self.retrieving_messages_state_transition().await,
State::Ingesting(event) => self.ingesting_state_transition(event).await,
};
trace!(?state, "Actor state transitioned;");
self.state = Some(state);
}
async fn waiting_for_notification_state_transition(&self) -> State {
self.shared.active.store(false, Ordering::Relaxed);
self.shared.active_notify.notify_waiters();
loop {
match self
.io
.event_rx
.recv()
.await
.expect("sync actor receiver unexpectedly closed")
{
Event::Notification => {
trace!("Received notification");
break;
}
Event::Messages(event) => {
trace!(
?event,
"Ignored event message as we're waiting for a `Event::Notification`"
);
}
}
}
self.shared.active.store(true, Ordering::Relaxed);
self.shared.active_notify.notify_waiters();
State::RetrievingMessages
}
async fn retrieving_messages_state_transition(&self) -> State {
enum StreamMessage {
NewEvent(Event),
AckedRequest(Result<(), oneshot::error::RecvError>),
}
let (tx, rx) = oneshot::channel::<()>();
let timestamps = self
.timestamps
.read()
.await
.iter()
.map(|(&uid, &timestamp)| (uid, timestamp))
.collect();
if self
.io
.send(Request::Messages { timestamps, tx })
.await
.is_err()
{
warn!("Failed to send messages request");
}
let mut msg_stream = pin!((
self.io.event_rx.clone().map(StreamMessage::NewEvent),
stream::once(rx.map(StreamMessage::AckedRequest)),
)
.merge());
loop {
if let Some(msg) = msg_stream.next().await {
match msg {
StreamMessage::NewEvent(event) => {
if let Event::Messages(messages) = event {
trace!(?messages, "Received messages;");
break State::Ingesting(messages);
}
}
StreamMessage::AckedRequest(res) => {
if res.is_err() {
debug!("messages request ignored");
break State::WaitingForNotification;
}
}
}
} else {
break State::WaitingForNotification;
}
}
}
async fn ingesting_state_transition(&mut self, event: MessagesEvent) -> State {
debug!(
messages_count = event.messages.len(),
first_message = ?DateTime::<Utc>::from(
event.messages
.first()
.map_or(SystemTime::UNIX_EPOCH, |m| m.3.timestamp.to_system_time())
),
last_message = ?DateTime::<Utc>::from(
event.messages
.last()
.map_or(SystemTime::UNIX_EPOCH, |m| m.3.timestamp.to_system_time())
),
"Ingesting operations;",
);
for (instance, data) in event.messages.0 {
for (model, data) in data {
for (record, ops) in data {
if let Err(e) = self
.process_crdt_operations(instance, model, record, ops)
.await
{
error!(?e, "Failed to ingest CRDT operations;");
}
}
}
}
if let Some(tx) = event.wait_tx {
if tx.send(()).is_err() {
warn!("Failed to send wait_tx signal");
}
}
if event.has_more {
State::RetrievingMessages
} else {
{
if self.io.send(Request::FinishedIngesting).await.is_err() {
error!("Failed to send finished ingesting request");
}
State::WaitingForNotification
}
}
}
pub async fn declare(shared: Arc<SharedState>) -> Handler {
let (io, HandlerIO { event_tx, req_rx }) = create_actor_io::<Self>();
shared
.actors
.declare(
"Sync Ingest",
{
let shared = Arc::clone(&shared);
move |stop| async move {
enum Race {
Ticked,
Stopped,
}
let mut this = Self {
state: Some(State::default()),
io,
shared,
};
while matches!(
(
this.tick().map(|()| Race::Ticked),
stop.into_future().map(|()| Race::Stopped),
)
.race()
.await,
Race::Ticked
) { /* Everything is Awesome! */ }
}
},
true,
)
.await;
Handler { event_tx, req_rx }
}
// where the magic happens
#[instrument(skip(self, ops), fields(operations_count = %ops.len()), err)]
async fn process_crdt_operations(
&mut self,
instance: Uuid,
model: u16,
record_id: rmpv::Value,
mut ops: Vec<CompressedCRDTOperation>,
) -> Result<(), Error> {
let db = &self.db;
ops.sort_by_key(|op| op.timestamp);
let new_timestamp = ops.last().expect("Empty ops array").timestamp;
// first, we update the HLC's timestamp with the incoming one.
// this involves a drift check + sets the last time of the clock
self.clock
.update_with_timestamp(&Timestamp::new(
new_timestamp,
uhlc::ID::from(NonZeroU128::new(instance.to_u128_le()).expect("Non zero id")),
))
.expect("timestamp has too much drift!");
// read the timestamp for the operation's instance, or insert one if it doesn't exist
let timestamp = self.timestamps.read().await.get(&instance).copied();
// Delete - ignores all other messages
if let Some(delete_op) = ops
.iter()
.rev()
.find(|op| matches!(op.data, CRDTOperationData::Delete))
{
trace!("Deleting operation");
handle_crdt_deletion(db, instance, model, record_id, delete_op).await?;
}
// Create + > 0 Update - overwrites the create's data with the updates
else if let Some(timestamp) = ops
.iter()
.rev()
.find_map(|op| matches!(&op.data, CRDTOperationData::Create(_)).then_some(op.timestamp))
{
trace!("Create + Updates operations");
// conflict resolution
let delete = db
.crdt_operation()
.find_first(vec![
crdt_operation::model::equals(i32::from(model)),
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
crdt_operation::kind::equals(OperationKind::Delete.to_string()),
])
.order_by(crdt_operation::timestamp::order(SortOrder::Desc))
.exec()
.await?;
if delete.is_some() {
debug!("Found a previous delete operation with the same SyncId, will ignore these operations");
return Ok(());
}
handle_crdt_create_and_updates(db, instance, model, record_id, ops, timestamp).await?;
}
// > 0 Update - batches updates with a fake Create op
else {
trace!("Updates operation");
let mut data = BTreeMap::new();
for op in ops.into_iter().rev() {
let CRDTOperationData::Update { field, value } = op.data else {
unreachable!("Create + Delete should be filtered out!");
};
data.insert(field, (value, op.timestamp));
}
// conflict resolution
let (create, updates) = db
._batch((
db.crdt_operation()
.find_first(vec![
crdt_operation::model::equals(i32::from(model)),
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
crdt_operation::kind::equals(OperationKind::Create.to_string()),
])
.order_by(crdt_operation::timestamp::order(SortOrder::Desc)),
data.iter()
.map(|(k, (_, timestamp))| {
Ok(db
.crdt_operation()
.find_first(vec![
crdt_operation::timestamp::gt({
#[allow(clippy::cast_possible_wrap)]
// SAFETY: we had to store using i64 due to SQLite limitations
{
timestamp.as_u64() as i64
}
}),
crdt_operation::model::equals(i32::from(model)),
crdt_operation::record_id::equals(rmp_serde::to_vec(
&record_id,
)?),
crdt_operation::kind::equals(
OperationKind::Update(k).to_string(),
),
])
.order_by(crdt_operation::timestamp::order(SortOrder::Desc)))
})
.collect::<Result<Vec<_>, Error>>()?,
))
.await?;
if create.is_none() {
warn!("Failed to find a previous create operation with the same SyncId");
return Ok(());
}
handle_crdt_updates(db, instance, model, record_id, data, updates).await?;
}
// update the stored timestamp for this instance - will be derived from the crdt operations table on restart
let new_ts = NTP64::max(timestamp.unwrap_or_default(), new_timestamp);
self.timestamps.write().await.insert(instance, new_ts);
Ok(())
}
}
async fn handle_crdt_updates(
db: &PrismaClient,
instance: Uuid,
model: u16,
record_id: rmpv::Value,
mut data: BTreeMap<String, (rmpv::Value, NTP64)>,
updates: Vec<Option<crdt_operation::Data>>,
) -> Result<(), Error> {
let keys = data.keys().cloned().collect::<Vec<_>>();
// does the same thing as processing ops one-by-one and returning early if a newer op was found
for (update, key) in updates.into_iter().zip(keys) {
if update.is_some() {
data.remove(&key);
}
}
db._transaction()
.with_timeout(30 * 1000)
.run(|db| async move {
// fake operation to batch them all at once
ModelSyncData::from_op(CRDTOperation {
instance,
model,
record_id: record_id.clone(),
timestamp: NTP64(0),
data: CRDTOperationData::Create(
data.iter()
.map(|(k, (data, _))| (k.clone(), data.clone()))
.collect(),
),
})
.ok_or(Error::InvalidModelId(model))?
.exec(&db)
.await?;
// need to only apply ops that haven't been filtered out
data.into_iter()
.map(|(field, (value, timestamp))| {
let record_id = record_id.clone();
let db = &db;
async move {
write_crdt_op_to_db(
&CRDTOperation {
instance,
model,
record_id,
timestamp,
data: CRDTOperationData::Update { field, value },
},
db,
)
.await
}
})
.collect::<Vec<_>>()
.try_join()
.await
.map(|_| ())
})
.await
}
async fn handle_crdt_create_and_updates(
db: &PrismaClient,
instance: Uuid,
model: u16,
record_id: rmpv::Value,
ops: Vec<CompressedCRDTOperation>,
timestamp: NTP64,
) -> Result<(), Error> {
let mut data = BTreeMap::new();
let mut applied_ops = vec![];
// search for all Updates until a Create is found
for op in ops.iter().rev() {
match &op.data {
CRDTOperationData::Delete => unreachable!("Delete can't exist here!"),
CRDTOperationData::Create(create_data) => {
for (k, v) in create_data {
data.entry(k).or_insert(v);
}
applied_ops.push(op);
break;
}
CRDTOperationData::Update { field, value } => {
applied_ops.push(op);
data.insert(field, value);
}
}
}
db._transaction()
.with_timeout(30 * 1000)
.run(|db| async move {
// fake a create with a bunch of data rather than individual insert
ModelSyncData::from_op(CRDTOperation {
instance,
model,
record_id: record_id.clone(),
timestamp,
data: CRDTOperationData::Create(
data.into_iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
),
})
.ok_or(Error::InvalidModelId(model))?
.exec(&db)
.await?;
applied_ops
.into_iter()
.map(|op| {
let record_id = record_id.clone();
let db = &db;
async move {
let operation = CRDTOperation {
instance,
model,
record_id,
timestamp: op.timestamp,
data: op.data.clone(),
};
write_crdt_op_to_db(&operation, db).await
}
})
.collect::<Vec<_>>()
.try_join()
.await
.map(|_| ())
})
.await
}
async fn handle_crdt_deletion(
db: &PrismaClient,
instance: Uuid,
model: u16,
record_id: rmpv::Value,
delete_op: &CompressedCRDTOperation,
) -> Result<(), Error> {
// deletes are the be all and end all, no need to check anything
let op = CRDTOperation {
instance,
model,
record_id,
timestamp: delete_op.timestamp,
data: CRDTOperationData::Delete,
};
db._transaction()
.with_timeout(30 * 1000)
.run(|db| async move {
ModelSyncData::from_op(op.clone())
.ok_or(Error::InvalidModelId(model))?
.exec(&db)
.await?;
write_crdt_op_to_db(&op, &db).await
})
.await
}
impl Deref for Actor {
type Target = SharedState;
fn deref(&self) -> &Self::Target {
&self.shared
}
}
pub struct Handler {
pub event_tx: chan::Sender<Event>,
pub req_rx: chan::Receiver<Request>,
}
#[derive(Debug)]
pub struct MessagesEvent {
pub instance_id: Uuid,
pub messages: CompressedCRDTOperations,
pub has_more: bool,
pub wait_tx: Option<oneshot::Sender<()>>,
}
impl ActorTypes for Actor {
type Event = Event;
type Request = Request;
type Handler = Handler;
}
#[cfg(test)]
mod test {
use std::{sync::atomic::AtomicBool, time::Duration};
use tokio::sync::Notify;
use uhlc::HLCBuilder;
use super::*;
async fn new_actor() -> (Handler, Arc<SharedState>) {
let instance = Uuid::new_v4();
let shared = Arc::new(SharedState {
db: sd_prisma::test_db().await,
instance,
clock: HLCBuilder::new()
.with_id(uhlc::ID::from(
NonZeroU128::new(instance.to_u128_le()).expect("Non zero id"),
))
.build(),
timestamps: Arc::default(),
emit_messages_flag: Arc::new(AtomicBool::new(true)),
active: AtomicBool::default(),
active_notify: Notify::default(),
actors: Arc::default(),
});
(Actor::declare(Arc::clone(&shared)).await, shared)
}
/// If messages tx is dropped, actor should reset and assume no further messages
/// will be sent
#[tokio::test]
async fn messages_request_drop() -> Result<(), ()> {
let (ingest, _) = new_actor().await;
for _ in 0..10 {
ingest.event_tx.send(Event::Notification).await.unwrap();
let Ok(Request::Messages { .. }) = ingest.req_rx.recv().await else {
panic!("expected a Request::Messages")
};
// yield so the actor task gets scheduled; without this the test hangs
tokio::time::sleep(Duration::from_millis(0)).await;
}
Ok(())
}
}
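
The removed actor cycled `WaitingForNotification -> RetrievingMessages -> Ingesting` and back. A distilled, synchronous model of that loop (the channels and real event source are elided):

#[derive(Debug)]
enum State {
    WaitingForNotification,
    RetrievingMessages,
    Ingesting(Vec<u8>),
}

fn main() {
    let mut pending = vec![vec![1, 2, 3]]; // stand-in for incoming batches
    let mut state = State::WaitingForNotification;
    for _ in 0..4 {
        state = match state {
            // a notification wakes the actor up
            State::WaitingForNotification => State::RetrievingMessages,
            // ask the transport for messages; fall back to waiting if none arrive
            State::RetrievingMessages => match pending.pop() {
                Some(batch) => State::Ingesting(batch),
                None => State::WaitingForNotification,
            },
            // apply the batch, then wait for the next notification
            State::Ingesting(batch) => {
                println!("ingested {} ops", batch.len());
                State::WaitingForNotification
            }
        };
        println!("-> {state:?}");
    }
}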

View File

@@ -0,0 +1,517 @@
use sd_core_prisma_helpers::DevicePubId;
use sd_prisma::{
prisma::{crdt_operation, PrismaClient},
prisma_sync::ModelSyncData,
};
use sd_sync::{
CRDTOperation, CRDTOperationData, CompressedCRDTOperation, ModelId, OperationKind, RecordId,
};
use std::{collections::BTreeMap, num::NonZeroU128, sync::Arc};
use futures_concurrency::future::TryJoin;
use tokio::sync::Mutex;
use tracing::{debug, instrument, trace, warn};
use uhlc::{Timestamp, HLC, NTP64};
use uuid::Uuid;
use super::{db_operation::write_crdt_op_to_db, Error, TimestampPerDevice};
crdt_operation::select!(crdt_operation_id { id });
// where the magic happens
#[instrument(skip(clock, ops), fields(operations_count = %ops.len()), err)]
pub async fn process_crdt_operations(
clock: &HLC,
timestamp_per_device: &TimestampPerDevice,
sync_lock: Arc<Mutex<()>>,
db: &PrismaClient,
device_pub_id: DevicePubId,
model_id: ModelId,
(record_id, mut ops): (RecordId, Vec<CompressedCRDTOperation>),
) -> Result<(), Error> {
ops.sort_by_key(|op| op.timestamp);
let new_timestamp = ops.last().expect("Empty ops array").timestamp;
update_clock(clock, new_timestamp, &device_pub_id);
// Delete - ignores all other messages
if let Some(delete_op) = ops
.iter()
.rev()
.find(|op| matches!(op.data, CRDTOperationData::Delete))
{
trace!("Deleting operation");
handle_crdt_deletion(
db,
&sync_lock,
&device_pub_id,
model_id,
record_id,
delete_op,
)
.await?;
}
// Create + > 0 Update - overwrites the create's data with the updates
else if let Some(timestamp) = ops
.iter()
.rev()
.find_map(|op| matches!(&op.data, CRDTOperationData::Create(_)).then_some(op.timestamp))
{
trace!("Create + Updates operations");
// conflict resolution
let delete_count = db
.crdt_operation()
.count(vec![
crdt_operation::model::equals(i32::from(model_id)),
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
crdt_operation::kind::equals(OperationKind::Delete.to_string()),
])
.exec()
.await?;
if delete_count > 0 {
debug!("Found a previous delete operation with the same SyncId, will ignore these operations");
return Ok(());
}
handle_crdt_create_and_updates(
db,
&sync_lock,
&device_pub_id,
model_id,
record_id,
ops,
timestamp,
)
.await?;
}
// > 0 Update - batches updates with a fake Create op
else {
trace!("Updates operation");
let mut data = BTreeMap::new();
for op in ops.into_iter().rev() {
let CRDTOperationData::Update(fields_and_values) = op.data else {
unreachable!("Create + Delete should be filtered out!");
};
for (field, value) in fields_and_values {
data.insert(field, (value, op.timestamp));
}
}
let earlier_time = data.values().fold(
NTP64(u64::from(u32::MAX)),
|earlier_time, (_, timestamp)| {
if timestamp.0 < earlier_time.0 {
*timestamp
} else {
earlier_time
}
},
);
// conflict resolution
let (create, possible_newer_updates) = db
._batch((
db.crdt_operation().count(vec![
crdt_operation::model::equals(i32::from(model_id)),
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
crdt_operation::kind::equals(OperationKind::Create.to_string()),
]),
// Fetching all update operations newer than our current earlier timestamp
db.crdt_operation()
.find_many(vec![
crdt_operation::timestamp::gt({
#[allow(clippy::cast_possible_wrap)]
// SAFETY: we had to store using i64 due to SQLite limitations
{
earlier_time.as_u64() as i64
}
}),
crdt_operation::model::equals(i32::from(model_id)),
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
crdt_operation::kind::starts_with("u".to_string()),
])
.select(crdt_operation::select!({ kind timestamp })),
))
.await?;
if create == 0 {
warn!("Failed to find a previous create operation with the same SyncId");
return Ok(());
}
for candidate in possible_newer_updates {
// The first element is "u" meaning that this is an update, so we skip it
for key in candidate
.kind
.split(':')
.filter(|field| !field.is_empty())
.skip(1)
{
// remove entries if we possess locally more recent updates for this field
if data.get(key).is_some_and(|(_, new_timestamp)| {
#[allow(clippy::cast_sign_loss)]
{
// we need to store as i64 due to SQLite limitations
*new_timestamp < NTP64(candidate.timestamp as u64)
}
}) {
data.remove(key);
}
}
if data.is_empty() {
break;
}
}
handle_crdt_updates(db, &sync_lock, &device_pub_id, model_id, record_id, data).await?;
}
update_timestamp_per_device(timestamp_per_device, device_pub_id, new_timestamp).await;
Ok(())
}
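
The conflict-resolution pass above inspects the `kind` column of stored update operations to recover which fields they touched. Assuming a kind string shaped like `"u:name:color"`, the decomposition works out to:

fn main() {
    let kind = "u:name:color"; // hypothetical stored update kind
    let fields: Vec<&str> = kind
        .split(':')
        .filter(|segment| !segment.is_empty())
        .skip(1) // drop the leading "u" marker
        .collect();
    assert_eq!(fields, ["name", "color"]);
}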
pub async fn bulk_ingest_create_only_ops(
clock: &HLC,
timestamp_per_device: &TimestampPerDevice,
db: &PrismaClient,
device_pub_id: DevicePubId,
model_id: ModelId,
ops: Vec<(RecordId, CompressedCRDTOperation)>,
sync_lock: Arc<Mutex<()>>,
) -> Result<(), Error> {
let latest_timestamp = ops.iter().fold(NTP64(0), |latest, (_, op)| {
if latest < op.timestamp {
op.timestamp
} else {
latest
}
});
update_clock(clock, latest_timestamp, &device_pub_id);
let ops = ops
.into_iter()
.map(|(record_id, op)| {
rmp_serde::to_vec(&record_id)
.map(|serialized_record_id| (record_id, serialized_record_id, op))
})
.collect::<Result<Vec<_>, _>>()?;
// conflict resolution
let delete_counts = db
._batch(
ops.iter()
.map(|(_, serialized_record_id, _)| {
db.crdt_operation().count(vec![
crdt_operation::model::equals(i32::from(model_id)),
crdt_operation::record_id::equals(serialized_record_id.clone()),
crdt_operation::kind::equals(OperationKind::Delete.to_string()),
])
})
.collect::<Vec<_>>(),
)
.await?;
let lock_guard = sync_lock.lock().await;
db._transaction()
.with_timeout(30 * 10000)
.with_max_wait(30 * 10000)
.run(|db| {
let device_pub_id = device_pub_id.clone();
async move {
// complying with borrowck
let device_pub_id = &device_pub_id;
let (crdt_creates, model_sync_data) = ops
.into_iter()
.zip(delete_counts)
.filter_map(|(data, delete_count)| (delete_count == 0).then_some(data))
.map(
|(
record_id,
serialized_record_id,
CompressedCRDTOperation { timestamp, data },
)| {
let crdt_create = crdt_operation::CreateUnchecked {
timestamp: {
#[allow(clippy::cast_possible_wrap)]
// SAFETY: we have to store using i64 due to SQLite limitations
{
timestamp.0 as i64
}
},
model: i32::from(model_id),
record_id: serialized_record_id,
kind: "c".to_string(),
data: rmp_serde::to_vec(&data)?,
device_pub_id: device_pub_id.to_db(),
_params: vec![],
};
// NOTE(@fogodev): I wish I could do a create many here instead of creating each
// entry separately, but it's not supported by PCR
let model_sync_data = ModelSyncData::from_op(CRDTOperation {
device_pub_id: Uuid::from(device_pub_id),
model_id,
record_id,
timestamp,
data,
})?
.exec(&db);
Ok::<_, Error>((crdt_create, model_sync_data))
},
)
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.unzip::<_, _, Vec<_>, Vec<_>>();
model_sync_data.try_join().await?;
db.crdt_operation().create_many(crdt_creates).exec().await?;
Ok::<_, Error>(())
}
})
.await?;
drop(lock_guard);
update_timestamp_per_device(timestamp_per_device, device_pub_id, latest_timestamp).await;
Ok(())
}
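
The zip/filter_map step above pairs each incoming create with its tombstone count and keeps only records that were never deleted locally. In isolation:

fn main() {
    let ops = ["a", "b", "c"]; // stand-ins for (record_id, op) tuples
    let delete_counts = [0i64, 2, 0]; // results of the count queries
    let survivors: Vec<&str> = ops
        .iter()
        .zip(delete_counts)
        .filter_map(|(&op, deletes)| (deletes == 0).then_some(op))
        .collect();
    assert_eq!(survivors, ["a", "c"]); // "b" already has a delete tombstone
}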
#[instrument(skip_all, err)]
async fn handle_crdt_updates(
db: &PrismaClient,
sync_lock: &Mutex<()>,
device_pub_id: &DevicePubId,
model_id: ModelId,
record_id: rmpv::Value,
data: BTreeMap<String, (rmpv::Value, NTP64)>,
) -> Result<(), Error> {
let device_pub_id = sd_sync::DevicePubId::from(device_pub_id);
let _lock_guard = sync_lock.lock().await;
db._transaction()
.with_timeout(30 * 10000)
.with_max_wait(30 * 10000)
.run(|db| async move {
// fake operation to batch them all at once, inserting the latest data on the appropriate table
ModelSyncData::from_op(CRDTOperation {
device_pub_id,
model_id,
record_id: record_id.clone(),
timestamp: NTP64(0),
data: CRDTOperationData::Create(
data.iter()
.map(|(k, (data, _))| (k.clone(), data.clone()))
.collect(),
),
})?
.exec(&db)
.await?;
let (fields_and_values, latest_timestamp) = data.into_iter().fold(
(BTreeMap::new(), NTP64::default()),
|(mut fields_and_values, mut latest_time_stamp), (field, (value, timestamp))| {
fields_and_values.insert(field, value);
if timestamp > latest_time_stamp {
latest_time_stamp = timestamp;
}
(fields_and_values, latest_time_stamp)
},
);
write_crdt_op_to_db(
&CRDTOperation {
device_pub_id,
model_id,
record_id,
timestamp: latest_timestamp,
data: CRDTOperationData::Update(fields_and_values),
},
&db,
)
.await
})
.await
}
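
The fold at the end of `handle_crdt_updates` collapses the per-field map into a single `Update` payload while tracking the newest timestamp across all folded fields. The same fold over plain types:

use std::collections::BTreeMap;

fn main() {
    let data = BTreeMap::from([("name", ("a", 10u64)), ("color", ("red", 7))]);
    let (fields, latest) = data.into_iter().fold(
        (BTreeMap::new(), 0u64),
        |(mut fields, latest), (field, (value, ts))| {
            fields.insert(field, value); // merge the field's latest value
            (fields, latest.max(ts)) // and keep the newest timestamp seen
        },
    );
    assert_eq!(latest, 10);
    println!("{fields:?} @ {latest}");
}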
#[instrument(skip_all, err)]
async fn handle_crdt_create_and_updates(
db: &PrismaClient,
sync_lock: &Mutex<()>,
device_pub_id: &DevicePubId,
model_id: ModelId,
record_id: rmpv::Value,
ops: Vec<CompressedCRDTOperation>,
timestamp: NTP64,
) -> Result<(), Error> {
let mut data = BTreeMap::new();
let device_pub_id = sd_sync::DevicePubId::from(device_pub_id);
let mut applied_ops = vec![];
// search for all Updates until a Create is found
for op in ops.into_iter().rev() {
match &op.data {
CRDTOperationData::Delete => unreachable!("Delete can't exist here!"),
CRDTOperationData::Create(create_data) => {
for (k, v) in create_data {
data.entry(k.clone()).or_insert_with(|| v.clone());
}
applied_ops.push(op);
break;
}
CRDTOperationData::Update(fields_and_values) => {
for (field, value) in fields_and_values {
data.insert(field.clone(), value.clone());
}
applied_ops.push(op);
}
}
}
let _lock_guard = sync_lock.lock().await;
db._transaction()
.with_timeout(30 * 10000)
.with_max_wait(30 * 10000)
.run(|db| async move {
// fake a create with a bunch of data rather than individual insert
ModelSyncData::from_op(CRDTOperation {
device_pub_id,
model_id,
record_id: record_id.clone(),
timestamp,
data: CRDTOperationData::Create(data),
})?
.exec(&db)
.await?;
applied_ops
.into_iter()
.map(|CompressedCRDTOperation { timestamp, data }| {
let record_id = record_id.clone();
let db = &db;
async move {
write_crdt_op_to_db(
&CRDTOperation {
device_pub_id,
timestamp,
model_id,
record_id,
data,
},
db,
)
.await
}
})
.collect::<Vec<_>>()
.try_join()
.await
.map(|_| ())
})
.await
}
#[instrument(skip_all, err)]
async fn handle_crdt_deletion(
db: &PrismaClient,
sync_lock: &Mutex<()>,
device_pub_id: &DevicePubId,
model: u16,
record_id: rmpv::Value,
delete_op: &CompressedCRDTOperation,
) -> Result<(), Error> {
// deletes are the be all and end all, except if we never created the object to begin with;
// in that case there is nothing to delete
if db
.crdt_operation()
.count(vec![
crdt_operation::model::equals(i32::from(model)),
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
])
.exec()
.await?
== 0
{
// This means that on the other device this entry was created and deleted before this
// device could even take notice of it, so there is nothing to do here.
return Ok(());
}
let op = CRDTOperation {
device_pub_id: device_pub_id.into(),
model_id: model,
record_id,
timestamp: delete_op.timestamp,
data: CRDTOperationData::Delete,
};
let _lock_guard = sync_lock.lock().await;
db._transaction()
.with_timeout(30 * 10000)
.with_max_wait(30 * 10000)
.run(|db| async move {
ModelSyncData::from_op(op.clone())?.exec(&db).await?;
write_crdt_op_to_db(&op, &db).await
})
.await
}
fn update_clock(clock: &HLC, latest_timestamp: NTP64, device_pub_id: &DevicePubId) {
// first, we update the HLC's timestamp with the incoming one.
// this involves a drift check + sets the last time of the clock
clock
.update_with_timestamp(&Timestamp::new(
latest_timestamp,
uhlc::ID::from(
NonZeroU128::new(Uuid::from(device_pub_id).to_u128_le()).expect("Non zero id"),
),
))
.expect("timestamp has too much drift!");
}
async fn update_timestamp_per_device(
timestamp_per_device: &TimestampPerDevice,
device_pub_id: DevicePubId,
latest_timestamp: NTP64,
) {
// read the timestamp for the operation's device, or insert one if it doesn't exist
let current_last_timestamp = timestamp_per_device
.read()
.await
.get(&device_pub_id)
.copied();
// update the stored timestamp for this device - will be derived from the crdt operations table on restart
let new_ts = NTP64::max(current_last_timestamp.unwrap_or_default(), latest_timestamp);
timestamp_per_device
.write()
.await
.insert(device_pub_id, new_ts);
}
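The read-then-write pair above implements a plain max-merge, so an out-of-order message with an older timestamp can never regress the stored high-water mark. In miniature (hypothetical `u64` stand-ins for `DevicePubId` and `NTP64`):

use std::collections::HashMap;

fn main() {
    let mut per_device: HashMap<u64, u64> = HashMap::new();
    per_device.insert(1, 10); // device 1 was last seen at timestamp 10

    // An older timestamp arrives late...
    let incoming = 7;
    let current = per_device.get(&1).copied().unwrap_or_default();
    per_device.insert(1, current.max(incoming));

    // ...and the stored high-water mark is unchanged.
    assert_eq!(per_device[&1], 10);
}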

View File

@@ -27,45 +27,39 @@
#![forbid(deprecated_in_future)]
#![allow(clippy::missing_errors_doc, clippy::module_name_repetitions)]
use sd_prisma::prisma::{crdt_operation, instance, PrismaClient};
use sd_sync::CRDTOperation;
use std::{
collections::HashMap,
sync::{atomic::AtomicBool, Arc},
use sd_prisma::{
prisma::{cloud_crdt_operation, crdt_operation},
prisma_sync,
};
use sd_utils::uuid_to_bytes;
use tokio::sync::{Notify, RwLock};
use uuid::Uuid;
use std::{collections::HashMap, sync::Arc};
use tokio::{sync::RwLock, task::JoinError};
mod actor;
pub mod backfill;
mod db_operation;
pub mod ingest;
mod ingest_utils;
mod manager;
pub use ingest::*;
pub use manager::*;
pub use db_operation::{from_cloud_crdt_ops, from_crdt_ops, write_crdt_op_to_db};
pub use manager::Manager as SyncManager;
pub use uhlc::NTP64;
#[derive(Clone, Debug)]
pub enum SyncMessage {
pub enum SyncEvent {
Ingested,
Created,
}
pub type Timestamps = Arc<RwLock<HashMap<Uuid, NTP64>>>;
pub use sd_core_prisma_helpers::DevicePubId;
pub use sd_sync::{
CRDTOperation, CompressedCRDTOperation, CompressedCRDTOperationsPerModel,
CompressedCRDTOperationsPerModelPerDevice, ModelId, OperationFactory, RecordId, RelationSyncId,
RelationSyncModel, SharedSyncModel, SyncId, SyncModel,
};
pub struct SharedState {
pub db: Arc<PrismaClient>,
pub emit_messages_flag: Arc<AtomicBool>,
pub instance: Uuid,
pub timestamps: Timestamps,
pub clock: uhlc::HLC,
pub active: AtomicBool,
pub active_notify: Notify,
pub actors: Arc<sd_actors::Actors>,
}
pub type TimestampPerDevice = Arc<RwLock<HashMap<DevicePubId, NTP64>>>;
#[derive(thiserror::Error, Debug)]
pub enum Error {
@@ -75,8 +69,16 @@ pub enum Error {
Deserialization(#[from] rmp_serde::decode::Error),
#[error("database error: {0}")]
Database(#[from] prisma_client_rust::QueryError),
#[error("PrismaSync error: {0}")]
PrismaSync(#[from] prisma_sync::Error),
#[error("invalid model id: {0}")]
InvalidModelId(u16),
InvalidModelId(ModelId),
#[error("tried to write an empty operations list")]
EmptyOperations,
#[error("device not found: {0}")]
DeviceNotFound(DevicePubId),
#[error("processes crdt task panicked")]
ProcessCrdtPanic(JoinError),
}
impl From<Error> for rspc::Error {
@@ -105,19 +107,16 @@ pub fn crdt_op_db(op: &CRDTOperation) -> Result<crdt_operation::Create, Error> {
op.timestamp.as_u64() as i64
}
},
instance: instance::pub_id::equals(op.instance.as_bytes().to_vec()),
device_pub_id: uuid_to_bytes(&op.device_pub_id),
kind: op.kind().to_string(),
data: rmp_serde::to_vec(&op.data)?,
model: i32::from(op.model),
model: i32::from(op.model_id),
record_id: rmp_serde::to_vec(&op.record_id)?,
_params: vec![],
})
}
pub fn crdt_op_unchecked_db(
op: &CRDTOperation,
instance_id: i32,
) -> Result<crdt_operation::CreateUnchecked, Error> {
pub fn crdt_op_unchecked_db(op: &CRDTOperation) -> Result<crdt_operation::CreateUnchecked, Error> {
Ok(crdt_operation::CreateUnchecked {
timestamp: {
#[allow(clippy::cast_possible_wrap)]
@@ -126,10 +125,28 @@ pub fn crdt_op_unchecked_db(
op.timestamp.as_u64() as i64
}
},
instance_id,
device_pub_id: uuid_to_bytes(&op.device_pub_id),
kind: op.kind().to_string(),
data: rmp_serde::to_vec(&op.data)?,
model: i32::from(op.model),
model: i32::from(op.model_id),
record_id: rmp_serde::to_vec(&op.record_id)?,
_params: vec![],
})
}
pub fn cloud_crdt_op_db(op: &CRDTOperation) -> Result<cloud_crdt_operation::Create, Error> {
Ok(cloud_crdt_operation::Create {
timestamp: {
#[allow(clippy::cast_possible_wrap)]
// SAFETY: we had to store using i64 due to SQLite limitations
{
op.timestamp.as_u64() as i64
}
},
device_pub_id: uuid_to_bytes(&op.device_pub_id),
kind: op.data.as_kind().to_string(),
data: rmp_serde::to_vec(&op.data)?,
model: i32::from(op.model_id),
record_id: rmp_serde::to_vec(&op.record_id)?,
_params: vec![],
})
}
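All three constructors above persist the `u64` NTP64 timestamp through the same wrapping cast, because SQLite only stores signed 64-bit integers. The round-trip is lossless, which a quick sketch confirms:

fn main() {
    let original: u64 = u64::MAX - 42; // a timestamp with the high bit set
    let stored = original as i64;      // shows up negative inside SQLite
    let recovered = stored as u64;     // the sign-loss cast on the read path
    assert!(stored < 0);
    assert_eq!(recovered, original);
}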

View File

@@ -1,35 +1,60 @@
use sd_prisma::prisma::{cloud_crdt_operation, crdt_operation, instance, PrismaClient, SortOrder};
use sd_sync::{CRDTOperation, OperationFactory};
use sd_utils::{from_bytes_to_uuid, uuid_to_bytes};
use tracing::warn;
use sd_core_prisma_helpers::DevicePubId;
use sd_prisma::{
prisma::{cloud_crdt_operation, crdt_operation, device, PrismaClient, SortOrder},
prisma_sync,
};
use sd_sync::{
CRDTOperation, CRDTOperationData, CompressedCRDTOperation, ModelId, OperationFactory, RecordId,
};
use sd_utils::timestamp_to_datetime;
use std::{
cmp, fmt,
collections::{hash_map::Entry, BTreeMap, HashMap},
fmt, mem,
num::NonZeroU128,
ops::Deref,
sync::{
atomic::{self, AtomicBool},
Arc,
},
time::{Duration, SystemTime},
};
use prisma_client_rust::{and, operator::or};
use tokio::sync::{broadcast, Mutex, Notify, RwLock};
use async_stream::stream;
use futures::{stream::FuturesUnordered, Stream, TryStreamExt};
use futures_concurrency::future::TryJoin;
use itertools::Itertools;
use tokio::{
spawn,
sync::{broadcast, Mutex, Notify, RwLock},
time::Instant,
};
use tracing::{debug, instrument, warn};
use uhlc::{HLCBuilder, HLC};
use uuid::Uuid;
use super::{
crdt_op_db,
db_operation::{cloud_crdt_with_instance, crdt_with_instance},
ingest, Error, SharedState, SyncMessage, NTP64,
db_operation::{from_cloud_crdt_ops, from_crdt_ops},
ingest_utils::{bulk_ingest_create_only_ops, process_crdt_operations},
Error, SyncEvent, TimestampPerDevice, NTP64,
};
const INGESTION_BATCH_SIZE: i64 = 10_000;
/// Wrapper that spawns the ingest actor and provides utilities for reading and writing sync operations.
#[derive(Clone)]
pub struct Manager {
pub tx: broadcast::Sender<SyncMessage>,
pub ingest: ingest::Handler,
pub shared: Arc<SharedState>,
pub timestamp_lock: Mutex<()>,
pub tx: broadcast::Sender<SyncEvent>,
pub db: Arc<PrismaClient>,
pub emit_messages_flag: Arc<AtomicBool>,
pub device_pub_id: DevicePubId,
pub timestamp_per_device: TimestampPerDevice,
pub clock: Arc<HLC>,
pub active: Arc<AtomicBool>,
pub active_notify: Arc<Notify>,
pub(crate) sync_lock: Arc<Mutex<()>>,
pub(crate) available_parallelism: usize,
}
impl fmt::Debug for Manager {
@@ -38,29 +63,21 @@ impl fmt::Debug for Manager {
}
}
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq)]
pub struct GetOpsArgs {
pub clocks: Vec<(Uuid, NTP64)>,
pub count: u32,
}
impl Manager {
/// Creates a new manager that can be used to read and write CRDT operations.
/// Sync messages are received on the returned [`broadcast::Receiver<SyncEvent>`].
pub async fn new(
db: Arc<PrismaClient>,
current_instance_uuid: Uuid,
current_device_pub_id: &DevicePubId,
emit_messages_flag: Arc<AtomicBool>,
actors: Arc<sd_actors::Actors>,
) -> Result<(Self, broadcast::Receiver<SyncMessage>), Error> {
let existing_instances = db.instance().find_many(vec![]).exec().await?;
) -> Result<(Self, broadcast::Receiver<SyncEvent>), Error> {
let existing_devices = db.device().find_many(vec![]).exec().await?;
Self::with_existing_instances(
Self::with_existing_devices(
db,
current_instance_uuid,
current_device_pub_id,
emit_messages_flag,
&existing_instances,
actors,
&existing_devices,
)
.await
}
@@ -69,33 +86,34 @@ impl Manager {
/// Sync messages are received on the returned [`broadcast::Receiver<SyncEvent>`].
///
/// # Panics
/// Panics if the `current_instance_id` UUID is zeroed.
pub async fn with_existing_instances(
/// Panics if the `current_device_pub_id` UUID is zeroed, which will never happen in practice: we use
/// `UUIDv7` for the device pub id, and this version embeds a timestamp instead of being fully random,
/// so the only way to get an all-zero `UUIDv7` is to go back in time to 1970.
pub async fn with_existing_devices(
db: Arc<PrismaClient>,
current_instance_uuid: Uuid,
current_device_pub_id: &DevicePubId,
emit_messages_flag: Arc<AtomicBool>,
existing_instances: &[instance::Data],
actors: Arc<sd_actors::Actors>,
) -> Result<(Self, broadcast::Receiver<SyncMessage>), Error> {
let timestamps = db
existing_devices: &[device::Data],
) -> Result<(Self, broadcast::Receiver<SyncEvent>), Error> {
let latest_timestamp_per_device = db
._batch(
existing_instances
existing_devices
.iter()
.map(|i| {
.map(|device| {
db.crdt_operation()
.find_first(vec![crdt_operation::instance::is(vec![
instance::id::equals(i.id),
])])
.find_first(vec![crdt_operation::device_pub_id::equals(
device.pub_id.clone(),
)])
.order_by(crdt_operation::timestamp::order(SortOrder::Desc))
})
.collect::<Vec<_>>(),
)
.await?
.into_iter()
.zip(existing_instances)
.map(|(op, i)| {
.zip(existing_devices)
.map(|(op, device)| {
(
from_bytes_to_uuid(&i.pub_id),
DevicePubId::from(&device.pub_id),
#[allow(clippy::cast_sign_loss)]
// SAFETY: we had to store using i64 due to SQLite limitations
NTP64(op.map(|o| o.timestamp).unwrap_or_default() as u64),
@@ -105,54 +123,303 @@ impl Manager {
let (tx, rx) = broadcast::channel(64);
let clock = HLCBuilder::new()
.with_id(uhlc::ID::from(
NonZeroU128::new(current_instance_uuid.to_u128_le()).expect("Non zero id"),
))
.build();
let shared = Arc::new(SharedState {
db,
instance: current_instance_uuid,
clock,
timestamps: Arc::new(RwLock::new(timestamps)),
emit_messages_flag,
active: AtomicBool::default(),
active_notify: Notify::default(),
actors,
});
let ingest = ingest::Actor::declare(shared.clone()).await;
Ok((
Self {
tx,
ingest,
shared,
timestamp_lock: Mutex::default(),
db,
device_pub_id: current_device_pub_id.clone(),
clock: Arc::new(
HLCBuilder::new()
.with_id(uhlc::ID::from(
NonZeroU128::new(Uuid::from(current_device_pub_id).to_u128_le())
.expect("Non zero id"),
))
.build(),
),
timestamp_per_device: Arc::new(RwLock::new(latest_timestamp_per_device)),
emit_messages_flag,
active: Arc::default(),
active_notify: Arc::default(),
sync_lock: Arc::new(Mutex::default()),
available_parallelism: std::thread::available_parallelism()
.map_or(1, std::num::NonZero::get),
},
rx,
))
}
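A small sketch of why the `expect` on the clock ID is safe under that `UUIDv7` assumption (it relies on the `uuid` crate with the `v7` feature enabled, which is an assumption about the build rather than part of this diff):

use std::num::NonZeroU128;
use uuid::Uuid;

fn main() {
    // UUIDv7 embeds a millisecond Unix timestamp in its leading bits, so any
    // id generated after 1970 is non-zero when viewed as a 128-bit integer.
    let device_pub_id = Uuid::now_v7();
    let _clock_id = uhlc::ID::from(
        NonZeroU128::new(device_pub_id.to_u128_le()).expect("Non zero id"),
    );
}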
pub fn subscribe(&self) -> broadcast::Receiver<SyncMessage> {
async fn fetch_cloud_crdt_ops(
&self,
model_id: ModelId,
batch_size: i64,
) -> Result<(Vec<cloud_crdt_operation::id::Type>, Vec<CRDTOperation>), Error> {
self.db
.cloud_crdt_operation()
.find_many(vec![cloud_crdt_operation::model::equals(i32::from(
model_id,
))])
.take(batch_size)
.order_by(cloud_crdt_operation::timestamp::order(SortOrder::Asc))
.exec()
.await?
.into_iter()
.map(from_cloud_crdt_ops)
.collect::<Result<(Vec<_>, Vec<_>), _>>()
}
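The `collect::<Result<(Vec<_>, Vec<_>), _>>()` above leans on two `FromIterator` impls at once: collecting into `Result` short-circuits on the first error, while collecting pairs into a tuple of `Vec`s unzips them (stable since Rust 1.79, within this workspace's 1.81 floor). A hypothetical miniature:

fn main() {
    let rows: Vec<Result<(i32, &str), String>> = vec![Ok((1, "a")), Ok((2, "b"))];

    // One pass: unzip the pairs, stopping at the first Err, if any.
    let split: Result<(Vec<i32>, Vec<&str>), String> = rows.into_iter().collect();

    assert_eq!(split, Ok((vec![1, 2], vec!["a", "b"])));
}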
#[instrument(skip(self))]
async fn ingest_by_model(&self, model_id: ModelId) -> Result<usize, Error> {
let mut total_count = 0;
let mut buckets = (0..self.available_parallelism)
.map(|_| FuturesUnordered::new())
.collect::<Vec<_>>();
let mut total_fetch_time = Duration::ZERO;
let mut total_compression_time = Duration::ZERO;
let mut total_work_distribution_time = Duration::ZERO;
let mut total_process_time = Duration::ZERO;
loop {
let fetching_start = Instant::now();
let (ops_ids, ops) = self
.fetch_cloud_crdt_ops(model_id, INGESTION_BATCH_SIZE)
.await?;
if ops_ids.is_empty() {
break;
}
total_fetch_time += fetching_start.elapsed();
let messages_count = ops.len();
debug!(
messages_count,
first_message = ?ops
.first()
.map_or_else(|| SystemTime::UNIX_EPOCH.into(), |op| timestamp_to_datetime(op.timestamp)),
last_message = ?ops
.last()
.map_or_else(|| SystemTime::UNIX_EPOCH.into(), |op| timestamp_to_datetime(op.timestamp)),
"Messages by model to ingest",
);
let compression_start = Instant::now();
let mut compressed_map =
BTreeMap::<Uuid, HashMap<Vec<u8>, (RecordId, Vec<CompressedCRDTOperation>)>>::new();
for CRDTOperation {
device_pub_id,
timestamp,
model_id: _, // Ignoring model_id as we know it already
record_id,
data,
} in ops
{
let records = compressed_map.entry(device_pub_id).or_default();
// Can't use RecordId as a key because rmpv::Value doesn't implement Hash + Eq,
// so we use its serialized bytes as the key instead.
let record_id_bytes =
rmp_serde::to_vec_named(&record_id).expect("already serialized to Value");
match records.entry(record_id_bytes) {
Entry::Occupied(mut entry) => {
entry
.get_mut()
.1
.push(CompressedCRDTOperation { timestamp, data });
}
Entry::Vacant(entry) => {
entry
.insert((record_id, vec![CompressedCRDTOperation { timestamp, data }]));
}
}
}
// Now that we have separated all operations by their record_ids, we can apply an optimization:
// records that possess only a single create operation are batched together and processed in bulk
let mut create_only_ops: BTreeMap<Uuid, Vec<(RecordId, CompressedCRDTOperation)>> =
BTreeMap::new();
for (device_pub_id, records) in &mut compressed_map {
for (record_id, ops) in records.values_mut() {
if ops.len() == 1 && matches!(ops[0].data, CRDTOperationData::Create(_)) {
create_only_ops
.entry(*device_pub_id)
.or_default()
.push((mem::replace(record_id, rmpv::Value::Nil), ops.remove(0)));
}
}
}
total_count += bulk_process_of_create_only_ops(
self.available_parallelism,
Arc::clone(&self.clock),
Arc::clone(&self.timestamp_per_device),
Arc::clone(&self.db),
Arc::clone(&self.sync_lock),
model_id,
create_only_ops,
)
.await?;
total_compression_time += compression_start.elapsed();
let work_distribution_start = Instant::now();
compressed_map
.into_iter()
.flat_map(|(device_pub_id, records)| {
records.into_values().filter_map(move |(record_id, ops)| {
if record_id.is_nil() {
return None;
}
// We can process each record in parallel as they are independent
let clock = Arc::clone(&self.clock);
let timestamp_per_device = Arc::clone(&self.timestamp_per_device);
let db = Arc::clone(&self.db);
let device_pub_id = device_pub_id.into();
let sync_lock = Arc::clone(&self.sync_lock);
Some(async move {
let count = ops.len();
process_crdt_operations(
&clock,
&timestamp_per_device,
sync_lock,
&db,
device_pub_id,
model_id,
(record_id, ops),
)
.await
.map(|()| count)
})
})
})
.enumerate()
.for_each(|(idx, fut)| buckets[idx % self.available_parallelism].push(fut));
total_work_distribution_time += work_distribution_start.elapsed();
let processing_start = Instant::now();
let handles = buckets
.iter_mut()
.enumerate()
.filter(|(_idx, bucket)| !bucket.is_empty())
.map(|(idx, bucket)| {
let mut bucket = mem::take(bucket);
spawn(async move {
let mut ops_count = 0;
let processing_start = Instant::now();
while let Some(count) = bucket.try_next().await? {
ops_count += count;
}
debug!(
"Ingested {ops_count} operations in {:?}",
processing_start.elapsed()
);
Ok::<_, Error>((ops_count, idx, bucket))
})
})
.collect::<Vec<_>>();
let results = handles.try_join().await.map_err(Error::ProcessCrdtPanic)?;
total_process_time += processing_start.elapsed();
for res in results {
let (count, idx, bucket) = res?;
buckets[idx] = bucket;
total_count += count;
}
self.db
.cloud_crdt_operation()
.delete_many(vec![cloud_crdt_operation::id::in_vec(ops_ids)])
.exec()
.await?;
}
debug!(
total_count,
?total_fetch_time,
?total_compression_time,
?total_work_distribution_time,
?total_process_time,
"Ingested all operations of this model"
);
Ok(total_count)
}
pub async fn ingest_ops(&self) -> Result<usize, Error> {
let mut total_count = 0;
// WARN: sync messages MUST be processed in this exact order
// due to relationship dependencies between these tables.
total_count += self.ingest_by_model(prisma_sync::device::MODEL_ID).await?;
total_count += [
self.ingest_by_model(prisma_sync::storage_statistics::MODEL_ID),
self.ingest_by_model(prisma_sync::tag::MODEL_ID),
self.ingest_by_model(prisma_sync::location::MODEL_ID),
self.ingest_by_model(prisma_sync::object::MODEL_ID),
self.ingest_by_model(prisma_sync::label::MODEL_ID),
]
.try_join()
.await?
.into_iter()
.sum::<usize>();
total_count += [
self.ingest_by_model(prisma_sync::exif_data::MODEL_ID),
self.ingest_by_model(prisma_sync::file_path::MODEL_ID),
self.ingest_by_model(prisma_sync::tag_on_object::MODEL_ID),
self.ingest_by_model(prisma_sync::label_on_object::MODEL_ID),
]
.try_join()
.await?
.into_iter()
.sum::<usize>();
if self.tx.send(SyncEvent::Ingested).is_err() {
warn!("failed to send ingested message on `ingest_ops`");
}
Ok(total_count)
}
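A sketch of the grouping pattern above (using `futures-concurrency` as pinned in the workspace): models within a group have no dependencies on each other, so each group is joined concurrently, while the groups themselves run strictly in sequence. The model names here are placeholders, not the real tables.

use futures_concurrency::future::TryJoin;

async fn ingest(model: &'static str) -> Result<usize, String> {
    Ok(model.len()) // hypothetical stand-in for ingest_by_model
}

async fn demo() -> Result<usize, String> {
    // Parents first, on their own...
    let mut total = ingest("device").await?;
    // ...then an independent group, concurrently.
    total += [ingest("tag"), ingest("location"), ingest("object")]
        .try_join()
        .await?
        .into_iter()
        .sum::<usize>();
    Ok(total)
}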
#[must_use]
pub fn subscribe(&self) -> broadcast::Receiver<SyncEvent> {
self.tx.subscribe()
}
pub async fn write_ops<'item, Q>(
&self,
tx: &PrismaClient,
(mut ops, queries): (Vec<CRDTOperation>, Q),
(ops, queries): (Vec<CRDTOperation>, Q),
) -> Result<Q::ReturnValue, Error>
where
Q: prisma_client_rust::BatchItem<'item, ReturnValue: Send> + Send,
{
let ret = if self.emit_messages_flag.load(atomic::Ordering::Relaxed) {
let lock = self.timestamp_lock.lock().await;
if ops.is_empty() {
return Err(Error::EmptyOperations);
}
for op in &mut ops {
op.timestamp = *self.get_clock().new_timestamp().get_time();
}
let ret = if self.emit_messages_flag.load(atomic::Ordering::Relaxed) {
let lock_guard = self.sync_lock.lock().await;
let (res, _) = tx
._batch((
@@ -164,18 +431,17 @@ impl Manager {
.await?;
if let Some(last) = ops.last() {
self.shared
.timestamps
self.timestamp_per_device
.write()
.await
.insert(self.instance, last.timestamp);
.insert(self.device_pub_id.clone(), last.timestamp);
}
if self.tx.send(SyncMessage::Created).is_err() {
if self.tx.send(SyncEvent::Created).is_err() {
warn!("failed to send created message on `write_ops`");
}
drop(lock);
drop(lock_guard);
res
} else {
@@ -188,160 +454,289 @@ impl Manager {
pub async fn write_op<'item, Q>(
&self,
tx: &PrismaClient,
mut op: CRDTOperation,
op: CRDTOperation,
query: Q,
) -> Result<Q::ReturnValue, Error>
where
Q: prisma_client_rust::BatchItem<'item, ReturnValue: Send> + Send,
{
let ret = if self.emit_messages_flag.load(atomic::Ordering::Relaxed) {
let lock = self.timestamp_lock.lock().await;
op.timestamp = *self.get_clock().new_timestamp().get_time();
let lock_guard = self.sync_lock.lock().await;
let ret = tx._batch((crdt_op_db(&op)?.to_query(tx), query)).await?.1;
if self.tx.send(SyncMessage::Created).is_err() {
if self.tx.send(SyncEvent::Created).is_err() {
warn!("failed to send created message on `write_op`");
}
drop(lock);
drop(lock_guard);
ret
} else {
tx._batch(vec![query]).await?.remove(0)
};
self.shared
.timestamps
self.timestamp_per_device
.write()
.await
.insert(self.instance, op.timestamp);
.insert(self.device_pub_id.clone(), op.timestamp);
Ok(ret)
}
pub async fn get_instance_ops(
&self,
count: u32,
instance_uuid: Uuid,
timestamp: NTP64,
) -> Result<Vec<CRDTOperation>, Error> {
self.db
.crdt_operation()
.find_many(vec![
crdt_operation::instance::is(vec![instance::pub_id::equals(uuid_to_bytes(
&instance_uuid,
))]),
#[allow(clippy::cast_possible_wrap)]
crdt_operation::timestamp::gt(timestamp.as_u64() as i64),
])
.take(i64::from(count))
.order_by(crdt_operation::timestamp::order(SortOrder::Asc))
.include(crdt_with_instance::include())
.exec()
.await?
// pub async fn get_device_ops(
// &self,
// count: u32,
// device_pub_id: DevicePubId,
// timestamp: NTP64,
// ) -> Result<Vec<CRDTOperation>, Error> {
// self.db
// .crdt_operation()
// .find_many(vec![
// crdt_operation::device_pub_id::equals(device_pub_id.into()),
// #[allow(clippy::cast_possible_wrap)]
// crdt_operation::timestamp::gt(timestamp.as_u64() as i64),
// ])
// .take(i64::from(count))
// .order_by(crdt_operation::timestamp::order(SortOrder::Asc))
// .exec()
// .await?
// .into_iter()
// .map(from_crdt_ops)
// .collect()
// }
pub fn stream_device_ops<'a>(
&'a self,
device_pub_id: &'a DevicePubId,
chunk_size: u32,
initial_timestamp: NTP64,
) -> impl Stream<Item = Result<Vec<CRDTOperation>, Error>> + Send + 'a {
stream! {
let mut current_initial_timestamp = initial_timestamp;
loop {
match self.db.crdt_operation()
.find_many(vec![
crdt_operation::device_pub_id::equals(device_pub_id.to_db()),
#[allow(clippy::cast_possible_wrap)]
crdt_operation::timestamp::gt(current_initial_timestamp.as_u64() as i64),
])
.take(i64::from(chunk_size))
.order_by(crdt_operation::timestamp::order(SortOrder::Asc))
.exec()
.await
{
Ok(ops) if ops.is_empty() => break,
Ok(ops) => match ops
.into_iter()
.map(from_crdt_ops)
.collect::<Result<Vec<_>, _>>()
{
Ok(ops) => {
debug!(
start_datetime = ?ops
.first()
.map(|op| timestamp_to_datetime(op.timestamp)),
end_datetime = ?ops
.last()
.map(|op| timestamp_to_datetime(op.timestamp)),
count = ops.len(),
"Streaming crdt ops",
);
if let Some(last_op) = ops.last() {
current_initial_timestamp = last_op.timestamp;
}
yield Ok(ops);
}
Err(e) => return yield Err(e),
}
Err(e) => return yield Err(e.into())
}
}
}
}
// pub async fn get_ops(
// &self,
// count: u32,
// timestamp_per_device: Vec<(DevicePubId, NTP64)>,
// ) -> Result<Vec<CRDTOperation>, Error> {
// let mut ops = self
// .db
// .crdt_operation()
// .find_many(vec![or(timestamp_per_device
// .iter()
// .map(|(device_pub_id, timestamp)| {
// and![
// crdt_operation::device_pub_id::equals(device_pub_id.to_db()),
// crdt_operation::timestamp::gt({
// #[allow(clippy::cast_possible_wrap)]
// // SAFETY: we had to store using i64 due to SQLite limitations
// {
// timestamp.as_u64() as i64
// }
// })
// ]
// })
// .chain([crdt_operation::device_pub_id::not_in_vec(
// timestamp_per_device
// .iter()
// .map(|(device_pub_id, _)| device_pub_id.to_db())
// .collect(),
// )])
// .collect())])
// .take(i64::from(count))
// .order_by(crdt_operation::timestamp::order(SortOrder::Asc))
// .exec()
// .await?;
// ops.sort_by(|a, b| match a.timestamp.cmp(&b.timestamp) {
// cmp::Ordering::Equal => {
// from_bytes_to_uuid(&a.device_pub_id).cmp(&from_bytes_to_uuid(&b.device_pub_id))
// }
// o => o,
// });
// ops.into_iter()
// .take(count as usize)
// .map(from_crdt_ops)
// .collect()
// }
// pub async fn get_cloud_ops(
// &self,
// count: u32,
// timestamp_per_device: Vec<(DevicePubId, NTP64)>,
// ) -> Result<Vec<(cloud_crdt_operation::id::Type, CRDTOperation)>, Error> {
// let mut ops = self
// .db
// .cloud_crdt_operation()
// .find_many(vec![or(timestamp_per_device
// .iter()
// .map(|(device_pub_id, timestamp)| {
// and![
// cloud_crdt_operation::device_pub_id::equals(device_pub_id.to_db()),
// cloud_crdt_operation::timestamp::gt({
// #[allow(clippy::cast_possible_wrap)]
// // SAFETY: we had to store using i64 due to SQLite limitations
// {
// timestamp.as_u64() as i64
// }
// })
// ]
// })
// .chain([cloud_crdt_operation::device_pub_id::not_in_vec(
// timestamp_per_device
// .iter()
// .map(|(device_pub_id, _)| device_pub_id.to_db())
// .collect(),
// )])
// .collect())])
// .take(i64::from(count))
// .order_by(cloud_crdt_operation::timestamp::order(SortOrder::Asc))
// .exec()
// .await?;
// ops.sort_by(|a, b| match a.timestamp.cmp(&b.timestamp) {
// cmp::Ordering::Equal => {
// from_bytes_to_uuid(&a.device_pub_id).cmp(&from_bytes_to_uuid(&b.device_pub_id))
// }
// o => o,
// });
// ops.into_iter()
// .take(count as usize)
// .map(from_cloud_crdt_ops)
// .collect()
// }
}
async fn bulk_process_of_create_only_ops(
available_parallelism: usize,
clock: Arc<HLC>,
timestamp_per_device: TimestampPerDevice,
db: Arc<PrismaClient>,
sync_lock: Arc<Mutex<()>>,
model_id: ModelId,
create_only_ops: BTreeMap<Uuid, Vec<(RecordId, CompressedCRDTOperation)>>,
) -> Result<usize, Error> {
let buckets = (0..available_parallelism)
.map(|_| FuturesUnordered::new())
.collect::<Vec<_>>();
let mut bucket_idx = 0;
for (device_pub_id, records) in create_only_ops {
records
.into_iter()
.map(crdt_with_instance::Data::into_operation)
.collect()
.chunks(100)
.into_iter()
.for_each(|chunk| {
let ops = chunk.collect::<Vec<_>>();
buckets[bucket_idx % available_parallelism].push({
let clock = Arc::clone(&clock);
let timestamp_per_device = Arc::clone(&timestamp_per_device);
let db = Arc::clone(&db);
let device_pub_id = device_pub_id.into();
let sync_lock = Arc::clone(&sync_lock);
async move {
let count = ops.len();
bulk_ingest_create_only_ops(
&clock,
&timestamp_per_device,
&db,
device_pub_id,
model_id,
ops,
sync_lock,
)
.await
.map(|()| count)
}
});
bucket_idx += 1;
});
}
pub async fn get_ops(&self, args: GetOpsArgs) -> Result<Vec<CRDTOperation>, Error> {
let mut ops = self
.db
.crdt_operation()
.find_many(vec![or(args
.clocks
.iter()
.map(|(instance_id, timestamp)| {
and![
crdt_operation::instance::is(vec![instance::pub_id::equals(
uuid_to_bytes(instance_id)
)]),
crdt_operation::timestamp::gt({
#[allow(clippy::cast_possible_wrap)]
// SAFETY: we had to store using i64 due to SQLite limitations
{
timestamp.as_u64() as i64
}
})
]
})
.chain([crdt_operation::instance::is_not(vec![
instance::pub_id::in_vec(
args.clocks
.iter()
.map(|(instance_id, _)| uuid_to_bytes(instance_id))
.collect(),
),
])])
.collect())])
.take(i64::from(args.count))
.order_by(crdt_operation::timestamp::order(SortOrder::Asc))
.include(crdt_with_instance::include())
.exec()
.await?;
let handles = buckets
.into_iter()
.map(|mut bucket| {
spawn(async move {
let mut total_count = 0;
ops.sort_by(|a, b| match a.timestamp().cmp(&b.timestamp()) {
cmp::Ordering::Equal => a.instance().cmp(&b.instance()),
o => o,
});
let process_creates_batch_start = Instant::now();
ops.into_iter()
.take(args.count as usize)
.map(crdt_with_instance::Data::into_operation)
.collect()
}
while let Some(count) = bucket.try_next().await? {
total_count += count;
}
pub async fn get_cloud_ops(
&self,
args: GetOpsArgs,
) -> Result<Vec<(i32, CRDTOperation)>, Error> {
let mut ops = self
.db
.cloud_crdt_operation()
.find_many(vec![or(args
.clocks
.iter()
.map(|(instance_id, timestamp)| {
and![
cloud_crdt_operation::instance::is(vec![instance::pub_id::equals(
uuid_to_bytes(instance_id)
)]),
cloud_crdt_operation::timestamp::gt({
#[allow(clippy::cast_possible_wrap)]
// SAFETY: we had to store using i64 due to SQLite limitations
{
timestamp.as_u64() as i64
}
})
]
})
.chain([cloud_crdt_operation::instance::is_not(vec![
instance::pub_id::in_vec(
args.clocks
.iter()
.map(|(instance_id, _)| uuid_to_bytes(instance_id))
.collect(),
),
])])
.collect())])
.take(i64::from(args.count))
.order_by(cloud_crdt_operation::timestamp::order(SortOrder::Asc))
.include(cloud_crdt_with_instance::include())
.exec()
.await?;
debug!(
"Processed {total_count} creates in {:?}",
process_creates_batch_start.elapsed()
);
ops.sort_by(|a, b| match a.timestamp().cmp(&b.timestamp()) {
cmp::Ordering::Equal => a.instance().cmp(&b.instance()),
o => o,
});
Ok::<_, Error>(total_count)
})
})
.collect::<Vec<_>>();
ops.into_iter()
.take(args.count as usize)
.map(cloud_crdt_with_instance::Data::into_operation)
.collect()
}
Ok(handles
.try_join()
.await
.map_err(Error::ProcessCrdtPanic)?
.into_iter()
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.sum())
}
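The bucketed fan-out used both in `ingest_by_model` and in this bulk create path boils down to three steps: deal work round-robin into `available_parallelism` `FuturesUnordered` queues, drain each queue on its own spawned task, and sum the per-task counts. A self-contained sketch under those assumptions:

use futures::{stream::FuturesUnordered, TryStreamExt};
use futures_concurrency::future::TryJoin;
use tokio::{spawn, task::JoinError};

async fn fan_out(jobs: Vec<usize>, parallelism: usize) -> Result<usize, JoinError> {
    let buckets = (0..parallelism)
        .map(|_| FuturesUnordered::new())
        .collect::<Vec<_>>();

    // Round-robin distribution keeps the buckets evenly loaded.
    for (idx, job) in jobs.into_iter().enumerate() {
        buckets[idx % parallelism].push(async move { Ok::<_, JoinError>(job) });
    }

    // One task per bucket drains its queue and counts what it processed.
    let handles = buckets
        .into_iter()
        .map(|mut bucket| {
            spawn(async move {
                let mut total = 0;
                while let Some(count) = bucket.try_next().await? {
                    total += count;
                }
                Ok::<_, JoinError>(total)
            })
        })
        .collect::<Vec<_>>();

    // try_join surfaces panics; the inner Results carry per-bucket errors.
    handles
        .try_join()
        .await?
        .into_iter()
        .try_fold(0, |acc, res| res.map(|count| acc + count))
}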
impl OperationFactory for Manager {
@@ -349,15 +744,7 @@ impl OperationFactory for Manager {
&self.clock
}
fn get_instance(&self) -> Uuid {
self.instance
}
}
impl Deref for Manager {
type Target = SharedState;
fn deref(&self) -> &Self::Target {
&self.shared
fn get_device_pub_id(&self) -> sd_sync::DevicePubId {
sd_sync::DevicePubId::from(&self.device_pub_id)
}
}

View File

@@ -1,247 +0,0 @@
mod mock_instance;
use sd_core_sync::*;
use sd_prisma::{prisma::location, prisma_sync};
use sd_sync::*;
use sd_utils::{msgpack, uuid_to_bytes};
use mock_instance::Instance;
use tracing::info;
use tracing_test::traced_test;
use uuid::Uuid;
const MOCK_LOCATION_NAME: &str = "Location 0";
const MOCK_LOCATION_PATH: &str = "/User/Anon/Documents";
async fn write_test_location(instance: &Instance) -> location::Data {
let location_pub_id = Uuid::new_v4();
let location = instance
.sync
.write_ops(&instance.db, {
let (sync_ops, db_ops): (Vec<_>, Vec<_>) = [
sync_db_entry!(MOCK_LOCATION_NAME, location::name),
sync_db_entry!(MOCK_LOCATION_PATH, location::path),
]
.into_iter()
.unzip();
(
instance.sync.shared_create(
prisma_sync::location::SyncId {
pub_id: uuid_to_bytes(&location_pub_id),
},
sync_ops,
),
instance
.db
.location()
.create(uuid_to_bytes(&location_pub_id), db_ops),
)
})
.await
.expect("failed to create mock location");
instance
.sync
.write_ops(&instance.db, {
let (sync_ops, db_ops): (Vec<_>, Vec<_>) = [
sync_db_entry!(1024, location::total_capacity),
sync_db_entry!(512, location::available_capacity),
]
.into_iter()
.unzip();
(
sync_ops
.into_iter()
.map(|(k, v)| {
instance.sync.shared_update(
prisma_sync::location::SyncId {
pub_id: uuid_to_bytes(&location_pub_id),
},
k,
v,
)
})
.collect::<Vec<_>>(),
instance
.db
.location()
.update(location::id::equals(location.id), db_ops),
)
})
.await
.expect("failed to create mock location");
location
}
#[tokio::test]
#[traced_test]
async fn writes_operations_and_rows_together() -> Result<(), Box<dyn std::error::Error>> {
let instance = Instance::new(Uuid::new_v4()).await;
write_test_location(&instance).await;
let operations = instance
.db
.crdt_operation()
.find_many(vec![])
.exec()
.await?;
// 1 create, 2 update
assert_eq!(operations.len(), 3);
assert_eq!(operations[0].model, prisma_sync::location::MODEL_ID as i32);
let out = instance
.sync
.get_ops(GetOpsArgs {
clocks: vec![],
count: 100,
})
.await?;
assert_eq!(out.len(), 3);
let locations = instance.db.location().find_many(vec![]).exec().await?;
assert_eq!(locations.len(), 1);
let location = locations.first().unwrap();
assert_eq!(location.name.as_deref(), Some(MOCK_LOCATION_NAME));
assert_eq!(location.path.as_deref(), Some(MOCK_LOCATION_PATH));
Ok(())
}
#[tokio::test]
#[traced_test]
async fn operations_send_and_ingest() -> Result<(), Box<dyn std::error::Error>> {
let instance1 = Instance::new(Uuid::new_v4()).await;
let instance2 = Instance::new(Uuid::new_v4()).await;
let mut instance2_sync_rx = instance2.sync_rx.resubscribe();
info!("Created instances!");
Instance::pair(&instance1, &instance2).await;
info!("Paired instances!");
write_test_location(&instance1).await;
info!("Created mock location!");
assert!(matches!(
instance2_sync_rx.recv().await?,
SyncMessage::Ingested
));
let out = instance2
.sync
.get_ops(GetOpsArgs {
clocks: vec![],
count: 100,
})
.await?;
assert_locations_equality(
&instance1.db.location().find_many(vec![]).exec().await?[0],
&instance2.db.location().find_many(vec![]).exec().await?[0],
);
assert_eq!(out.len(), 3);
instance1.teardown().await;
instance2.teardown().await;
Ok(())
}
#[tokio::test]
async fn no_update_after_delete() -> Result<(), Box<dyn std::error::Error>> {
let instance1 = Instance::new(Uuid::new_v4()).await;
let instance2 = Instance::new(Uuid::new_v4()).await;
let mut instance2_sync_rx = instance2.sync_rx.resubscribe();
Instance::pair(&instance1, &instance2).await;
let location = write_test_location(&instance1).await;
assert!(matches!(
instance2_sync_rx.recv().await?,
SyncMessage::Ingested
));
instance2
.sync
.write_op(
&instance2.db,
instance2.sync.shared_delete(prisma_sync::location::SyncId {
pub_id: location.pub_id.clone(),
}),
instance2.db.location().delete_many(vec![]),
)
.await?;
assert!(matches!(
instance1.sync_rx.resubscribe().recv().await?,
SyncMessage::Ingested
));
instance1
.sync
.write_op(
&instance1.db,
instance1.sync.shared_update(
prisma_sync::location::SyncId {
pub_id: location.pub_id.clone(),
},
"name",
msgpack!("New Location"),
),
instance1.db.location().find_many(vec![]),
)
.await?;
// one spare update operation that actually gets ignored by instance 2
assert_eq!(instance1.db.crdt_operation().count(vec![]).exec().await?, 5);
assert_eq!(instance2.db.crdt_operation().count(vec![]).exec().await?, 4);
assert_eq!(instance1.db.location().count(vec![]).exec().await?, 0);
// the whole point of the test: the update (which is ingested as an upsert) should be ignored
assert_eq!(instance2.db.location().count(vec![]).exec().await?, 0);
instance1.teardown().await;
instance2.teardown().await;
Ok(())
}
fn assert_locations_equality(l1: &location::Data, l2: &location::Data) {
assert_eq!(l1.pub_id, l2.pub_id, "pub id");
assert_eq!(l1.name, l2.name, "name");
assert_eq!(l1.path, l2.path, "path");
assert_eq!(l1.total_capacity, l2.total_capacity, "total capacity");
assert_eq!(
l1.available_capacity, l2.available_capacity,
"available capacity"
);
assert_eq!(l1.size_in_bytes, l2.size_in_bytes, "size in bytes");
assert_eq!(l1.is_archived, l2.is_archived, "is archived");
assert_eq!(
l1.generate_preview_media, l2.generate_preview_media,
"generate preview media"
);
assert_eq!(
l1.sync_preview_media, l2.sync_preview_media,
"sync preview media"
);
assert_eq!(l1.hidden, l2.hidden, "hidden");
assert_eq!(l1.date_created, l2.date_created, "date created");
assert_eq!(l1.scan_state, l2.scan_state, "scan state");
assert_eq!(l1.instance_id, l2.instance_id, "instance id");
}

View File

@@ -1,161 +0,0 @@
use sd_core_sync::*;
use sd_prisma::prisma;
use sd_sync::CompressedCRDTOperations;
use sd_utils::uuid_to_bytes;
use std::sync::{atomic::AtomicBool, Arc};
use prisma_client_rust::chrono::Utc;
use tokio::{fs, spawn, sync::broadcast};
use tracing::{info, instrument, warn, Instrument};
use uuid::Uuid;
fn db_path(id: Uuid) -> String {
format!("/tmp/test-{id}.db")
}
#[derive(Clone)]
pub struct Instance {
pub id: Uuid,
pub db: Arc<prisma::PrismaClient>,
pub sync: Arc<sd_core_sync::Manager>,
pub sync_rx: Arc<broadcast::Receiver<SyncMessage>>,
}
impl Instance {
pub async fn new(id: Uuid) -> Arc<Self> {
let url = format!("file:{}", db_path(id));
let db = Arc::new(
prisma::PrismaClient::_builder()
.with_url(url.to_string())
.build()
.await
.unwrap(),
);
db._db_push().await.unwrap();
db.instance()
.create(
uuid_to_bytes(&id),
vec![],
vec![],
Utc::now().into(),
Utc::now().into(),
vec![],
)
.exec()
.await
.unwrap();
let (sync, sync_rx) = sd_core_sync::Manager::new(
Arc::clone(&db),
id,
Arc::new(AtomicBool::new(true)),
Default::default(),
)
.await
.expect("failed to create sync manager");
Arc::new(Self {
id,
db,
sync: Arc::new(sync),
sync_rx: Arc::new(sync_rx),
})
}
pub async fn teardown(&self) {
fs::remove_file(db_path(self.id)).await.unwrap();
}
pub async fn pair(instance1: &Arc<Self>, instance2: &Arc<Self>) {
#[instrument(skip(left, right))]
async fn half(left: &Arc<Instance>, right: &Arc<Instance>, context: &'static str) {
left.db
.instance()
.create(
uuid_to_bytes(&right.id),
vec![],
vec![],
Utc::now().into(),
Utc::now().into(),
vec![],
)
.exec()
.await
.unwrap();
spawn({
let mut sync_rx_left = left.sync_rx.resubscribe();
let right = Arc::clone(right);
async move {
while let Ok(msg) = sync_rx_left.recv().await {
info!(?msg, "sync_rx_left received message");
if matches!(msg, SyncMessage::Created) {
right
.sync
.ingest
.event_tx
.send(ingest::Event::Notification)
.await
.unwrap();
info!("sent notification to instance 2");
}
}
}
.in_current_span()
});
spawn({
let left = Arc::clone(left);
let right = Arc::clone(right);
async move {
while let Ok(msg) = right.sync.ingest.req_rx.recv().await {
info!(?msg, "right instance received request");
match msg {
ingest::Request::Messages { timestamps, tx } => {
let messages = left
.sync
.get_ops(GetOpsArgs {
clocks: timestamps,
count: 100,
})
.await
.unwrap();
let ingest = &right.sync.ingest;
ingest
.event_tx
.send(ingest::Event::Messages(ingest::MessagesEvent {
messages: CompressedCRDTOperations::new(messages),
has_more: false,
instance_id: left.id,
wait_tx: None,
}))
.await
.unwrap();
if tx.send(()).is_err() {
warn!("failed to send ack to instance 1");
}
}
ingest::Request::FinishedIngesting => {
right.sync.tx.send(SyncMessage::Ingested).unwrap();
}
}
}
}
.in_current_span()
});
}
half(instance1, instance2, "instance1 -> instance2").await;
half(instance2, instance1, "instance2 -> instance1").await;
}
}

View File

@@ -0,0 +1,201 @@
/*
Warnings:
- You are about to drop the `node` table. If the table is not empty, all the data it contains will be lost.
- You are about to drop the column `instance_id` on the `cloud_crdt_operation` table. All the data in the column will be lost.
- You are about to drop the column `instance_id` on the `crdt_operation` table. All the data in the column will be lost.
- You are about to drop the column `instance_pub_id` on the `storage_statistics` table. All the data in the column will be lost.
- Added the required column `device_pub_id` to the `cloud_crdt_operation` table without a default value. This is not possible if the table is not empty.
- Added the required column `device_pub_id` to the `crdt_operation` table without a default value. This is not possible if the table is not empty.
*/
-- DropIndex
DROP INDEX "node_pub_id_key";
-- DropTable
PRAGMA foreign_keys=off;
DROP TABLE "node";
PRAGMA foreign_keys=on;
-- CreateTable
CREATE TABLE "device" (
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
"pub_id" BLOB NOT NULL,
"name" TEXT,
"os" INTEGER,
"hardware_model" INTEGER,
"timestamp" BIGINT,
"date_created" DATETIME,
"date_deleted" DATETIME
);
-- RedefineTables
PRAGMA defer_foreign_keys=ON;
PRAGMA foreign_keys=OFF;
CREATE TABLE "new_cloud_crdt_operation" (
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
"timestamp" BIGINT NOT NULL,
"model" INTEGER NOT NULL,
"record_id" BLOB NOT NULL,
"kind" TEXT NOT NULL,
"data" BLOB NOT NULL,
"device_pub_id" BLOB NOT NULL,
CONSTRAINT "cloud_crdt_operation_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE RESTRICT ON UPDATE CASCADE
);
INSERT INTO "new_cloud_crdt_operation" ("data", "id", "kind", "model", "record_id", "timestamp") SELECT "data", "id", "kind", "model", "record_id", "timestamp" FROM "cloud_crdt_operation";
DROP TABLE "cloud_crdt_operation";
ALTER TABLE "new_cloud_crdt_operation" RENAME TO "cloud_crdt_operation";
CREATE INDEX "cloud_crdt_operation_timestamp_idx" ON "cloud_crdt_operation"("timestamp");
CREATE TABLE "new_crdt_operation" (
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
"timestamp" BIGINT NOT NULL,
"model" INTEGER NOT NULL,
"record_id" BLOB NOT NULL,
"kind" TEXT NOT NULL,
"data" BLOB NOT NULL,
"device_pub_id" BLOB NOT NULL,
CONSTRAINT "crdt_operation_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE RESTRICT ON UPDATE CASCADE
);
INSERT INTO "new_crdt_operation" ("data", "id", "kind", "model", "record_id", "timestamp") SELECT "data", "id", "kind", "model", "record_id", "timestamp" FROM "crdt_operation";
DROP TABLE "crdt_operation";
ALTER TABLE "new_crdt_operation" RENAME TO "crdt_operation";
CREATE INDEX "crdt_operation_timestamp_idx" ON "crdt_operation"("timestamp");
CREATE TABLE "new_exif_data" (
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
"resolution" BLOB,
"media_date" BLOB,
"media_location" BLOB,
"camera_data" BLOB,
"artist" TEXT,
"description" TEXT,
"copyright" TEXT,
"exif_version" TEXT,
"epoch_time" BIGINT,
"object_id" INTEGER NOT NULL,
"device_pub_id" BLOB,
CONSTRAINT "exif_data_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE CASCADE ON UPDATE CASCADE,
CONSTRAINT "exif_data_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
);
INSERT INTO "new_exif_data" ("artist", "camera_data", "copyright", "description", "epoch_time", "exif_version", "id", "media_date", "media_location", "object_id", "resolution") SELECT "artist", "camera_data", "copyright", "description", "epoch_time", "exif_version", "id", "media_date", "media_location", "object_id", "resolution" FROM "exif_data";
DROP TABLE "exif_data";
ALTER TABLE "new_exif_data" RENAME TO "exif_data";
CREATE UNIQUE INDEX "exif_data_object_id_key" ON "exif_data"("object_id");
CREATE TABLE "new_file_path" (
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
"pub_id" BLOB NOT NULL,
"is_dir" BOOLEAN,
"cas_id" TEXT,
"integrity_checksum" TEXT,
"location_id" INTEGER,
"materialized_path" TEXT,
"name" TEXT,
"extension" TEXT,
"hidden" BOOLEAN,
"size_in_bytes" TEXT,
"size_in_bytes_bytes" BLOB,
"inode" BLOB,
"object_id" INTEGER,
"key_id" INTEGER,
"date_created" DATETIME,
"date_modified" DATETIME,
"date_indexed" DATETIME,
"device_pub_id" BLOB,
CONSTRAINT "file_path_location_id_fkey" FOREIGN KEY ("location_id") REFERENCES "location" ("id") ON DELETE SET NULL ON UPDATE CASCADE,
CONSTRAINT "file_path_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE SET NULL ON UPDATE CASCADE,
CONSTRAINT "file_path_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
);
INSERT INTO "new_file_path" ("cas_id", "date_created", "date_indexed", "date_modified", "extension", "hidden", "id", "inode", "integrity_checksum", "is_dir", "key_id", "location_id", "materialized_path", "name", "object_id", "pub_id", "size_in_bytes", "size_in_bytes_bytes") SELECT "cas_id", "date_created", "date_indexed", "date_modified", "extension", "hidden", "id", "inode", "integrity_checksum", "is_dir", "key_id", "location_id", "materialized_path", "name", "object_id", "pub_id", "size_in_bytes", "size_in_bytes_bytes" FROM "file_path";
DROP TABLE "file_path";
ALTER TABLE "new_file_path" RENAME TO "file_path";
CREATE UNIQUE INDEX "file_path_pub_id_key" ON "file_path"("pub_id");
CREATE INDEX "file_path_location_id_idx" ON "file_path"("location_id");
CREATE INDEX "file_path_location_id_materialized_path_idx" ON "file_path"("location_id", "materialized_path");
CREATE UNIQUE INDEX "file_path_location_id_materialized_path_name_extension_key" ON "file_path"("location_id", "materialized_path", "name", "extension");
CREATE UNIQUE INDEX "file_path_location_id_inode_key" ON "file_path"("location_id", "inode");
CREATE TABLE "new_label_on_object" (
"date_created" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"object_id" INTEGER NOT NULL,
"label_id" INTEGER NOT NULL,
"device_pub_id" BLOB,
PRIMARY KEY ("label_id", "object_id"),
CONSTRAINT "label_on_object_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "label_on_object_label_id_fkey" FOREIGN KEY ("label_id") REFERENCES "label" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "label_on_object_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
);
INSERT INTO "new_label_on_object" ("date_created", "label_id", "object_id") SELECT "date_created", "label_id", "object_id" FROM "label_on_object";
DROP TABLE "label_on_object";
ALTER TABLE "new_label_on_object" RENAME TO "label_on_object";
CREATE TABLE "new_location" (
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
"pub_id" BLOB NOT NULL,
"name" TEXT,
"path" TEXT,
"total_capacity" INTEGER,
"available_capacity" INTEGER,
"size_in_bytes" BLOB,
"is_archived" BOOLEAN,
"generate_preview_media" BOOLEAN,
"sync_preview_media" BOOLEAN,
"hidden" BOOLEAN,
"date_created" DATETIME,
"scan_state" INTEGER NOT NULL DEFAULT 0,
"device_pub_id" BLOB,
"instance_id" INTEGER,
CONSTRAINT "location_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE,
CONSTRAINT "location_instance_id_fkey" FOREIGN KEY ("instance_id") REFERENCES "instance" ("id") ON DELETE SET NULL ON UPDATE CASCADE
);
INSERT INTO "new_location" ("available_capacity", "date_created", "generate_preview_media", "hidden", "id", "instance_id", "is_archived", "name", "path", "pub_id", "scan_state", "size_in_bytes", "sync_preview_media", "total_capacity") SELECT "available_capacity", "date_created", "generate_preview_media", "hidden", "id", "instance_id", "is_archived", "name", "path", "pub_id", "scan_state", "size_in_bytes", "sync_preview_media", "total_capacity" FROM "location";
DROP TABLE "location";
ALTER TABLE "new_location" RENAME TO "location";
CREATE UNIQUE INDEX "location_pub_id_key" ON "location"("pub_id");
CREATE TABLE "new_object" (
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
"pub_id" BLOB NOT NULL,
"kind" INTEGER,
"key_id" INTEGER,
"hidden" BOOLEAN,
"favorite" BOOLEAN,
"important" BOOLEAN,
"note" TEXT,
"date_created" DATETIME,
"date_accessed" DATETIME,
"device_pub_id" BLOB,
CONSTRAINT "object_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
);
INSERT INTO "new_object" ("date_accessed", "date_created", "favorite", "hidden", "id", "important", "key_id", "kind", "note", "pub_id") SELECT "date_accessed", "date_created", "favorite", "hidden", "id", "important", "key_id", "kind", "note", "pub_id" FROM "object";
DROP TABLE "object";
ALTER TABLE "new_object" RENAME TO "object";
CREATE UNIQUE INDEX "object_pub_id_key" ON "object"("pub_id");
CREATE TABLE "new_storage_statistics" (
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
"pub_id" BLOB NOT NULL,
"total_capacity" BIGINT NOT NULL DEFAULT 0,
"available_capacity" BIGINT NOT NULL DEFAULT 0,
"device_pub_id" BLOB,
CONSTRAINT "storage_statistics_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
);
INSERT INTO "new_storage_statistics" ("available_capacity", "id", "pub_id", "total_capacity") SELECT "available_capacity", "id", "pub_id", "total_capacity" FROM "storage_statistics";
DROP TABLE "storage_statistics";
ALTER TABLE "new_storage_statistics" RENAME TO "storage_statistics";
CREATE UNIQUE INDEX "storage_statistics_pub_id_key" ON "storage_statistics"("pub_id");
CREATE UNIQUE INDEX "storage_statistics_device_pub_id_key" ON "storage_statistics"("device_pub_id");
CREATE TABLE "new_tag_on_object" (
"object_id" INTEGER NOT NULL,
"tag_id" INTEGER NOT NULL,
"date_created" DATETIME,
"device_pub_id" BLOB,
PRIMARY KEY ("tag_id", "object_id"),
CONSTRAINT "tag_on_object_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "tag_on_object_tag_id_fkey" FOREIGN KEY ("tag_id") REFERENCES "tag" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "tag_on_object_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
);
INSERT INTO "new_tag_on_object" ("date_created", "object_id", "tag_id") SELECT "date_created", "object_id", "tag_id" FROM "tag_on_object";
DROP TABLE "tag_on_object";
ALTER TABLE "new_tag_on_object" RENAME TO "tag_on_object";
PRAGMA foreign_keys=ON;
PRAGMA defer_foreign_keys=OFF;
-- CreateIndex
CREATE UNIQUE INDEX "device_pub_id_key" ON "device"("pub_id");

View File

@@ -28,9 +28,10 @@ model CRDTOperation {
kind String
data Bytes
instance_id Int
instance Instance @relation(fields: [instance_id], references: [id])
// We just need the actual device_pub_id here; we don't need it as an actual relation
device_pub_id Bytes
@@index([timestamp])
@@map("crdt_operation")
}
@@ -46,24 +47,41 @@ model CloudCRDTOperation {
kind String
data Bytes
instance_id Int
instance Instance @relation(fields: [instance_id], references: [id])
// We just need the actual device_pub_id here; we don't need it as an actual relation
device_pub_id Bytes
@@index([timestamp])
@@map("cloud_crdt_operation")
}
/// @deprecated: This model has to exist solely for backwards compatibility.
/// @local
model Node {
id Int @id @default(autoincrement())
pub_id Bytes @unique
name String
// Enum: sd_core::node::Platform
platform Int
date_created DateTime
identity Bytes? // TODO: Change to required field in future
/// Devices are the owner machines connected to this library
/// @shared(id: pub_id, modelId: 12)
model Device {
id Int @id @default(autoincrement())
// uuid v7
pub_id Bytes @unique
name String? // Not actually NULLABLE, but we have to comply with current sync implementation BS
@@map("node")
// Enum: sd_cloud_schema::device::DeviceOS
os Int? // Not actually NULLABLE, but we have to comply with current sync implementation BS
// Enum: sd_cloud_schema::device::HardwareModel
hardware_model Int? // Not actually NULLABLE, but we have to comply with current sync implementation BS
// clock timestamp for sync
timestamp BigInt?
date_created DateTime? // Not actually NULLABLE, but we have to comply with current sync implementation BS
date_deleted DateTime?
StorageStatistics StorageStatistics?
Location Location[]
FilePath FilePath[]
Object Object[]
ExifData ExifData[]
TagOnObject TagOnObject[]
LabelOnObject LabelOnObject[]
@@map("device")
}
// represents a single `.db` file (SQLite DB) that is paired to the current library.
@@ -88,11 +106,7 @@ model Instance {
// clock timestamp for sync
timestamp BigInt?
locations Location[]
CRDTOperation CRDTOperation[]
CloudCRDTOperation CloudCRDTOperation[]
storage_statistics StorageStatistics?
Location Location[]
@@map("instance")
}
@@ -158,6 +172,9 @@ model Location {
scan_state Int @default(0) // Enum: sd_core::location::ScanState
device_id Int?
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
// this should just be a local-only cache but it's too much effort to broadcast online locations rn (@brendan)
instance_id Int?
instance Instance? @relation(fields: [instance_id], references: [id], onDelete: SetNull)
@@ -208,6 +225,9 @@ model FilePath {
date_modified DateTime?
date_indexed DateTime?
device_id Int?
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
// key Key? @relation(fields: [key_id], references: [id])
@@unique([location_id, materialized_path, name, extension])
@@ -255,6 +275,9 @@ model Object {
exif_data ExifData?
ffmpeg_data FfmpegData?
device_id Int?
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
// key Key? @relation(fields: [key_id], references: [id])
@@map("object")
@@ -323,6 +346,9 @@ model ExifData {
object_id Int @unique
object Object @relation(fields: [object_id], references: [id], onDelete: Cascade)
device_id Int?
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
@@map("exif_data")
}
@@ -509,6 +535,9 @@ model TagOnObject {
date_created DateTime?
device_id Int?
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
@@id([tag_id, object_id])
@@map("tag_on_object")
}
@@ -537,6 +566,9 @@ model LabelOnObject {
label_id Int
label Label @relation(fields: [label_id], references: [id], onDelete: Restrict)
device_id Int?
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
@@id([label_id, object_id])
@@map("label_on_object")
}
@@ -576,8 +608,8 @@ model StorageStatistics {
total_capacity BigInt @default(0)
available_capacity BigInt @default(0)
instance_pub_id Bytes? @unique
instance Instance? @relation(fields: [instance_pub_id], references: [pub_id], onDelete: Cascade)
device_id Int? @unique
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
@@map("storage_statistics")
}

View File

@@ -1,153 +0,0 @@
use std::time::Duration;
use reqwest::StatusCode;
use rspc::alpha::AlphaRouter;
use serde::{Deserialize, Serialize};
use specta::Type;
use super::{Ctx, R};
pub(crate) fn mount() -> AlphaRouter<Ctx> {
R.router()
.procedure("loginSession", {
#[derive(Serialize, Type)]
#[specta(inline)]
enum Response {
Start {
user_code: String,
verification_url: String,
verification_url_complete: String,
},
Complete,
Error(String),
}
R.subscription(|node, _: ()| async move {
#[derive(Deserialize, Type)]
struct DeviceAuthorizationResponse {
device_code: String,
user_code: String,
verification_url: String,
verification_uri_complete: String,
}
async_stream::stream! {
let device_type = if cfg!(target_arch = "wasm32") {
"web".to_string()
} else if cfg!(target_os = "ios") || cfg!(target_os = "android") {
"mobile".to_string()
} else {
"desktop".to_string()
};
let auth_response = match match node
.http
.post(format!(
"{}/login/device/code",
&node.env.api_url.lock().await
))
.form(&[("client_id", &node.env.client_id), ("device", &device_type)])
.send()
.await
.map_err(|e| e.to_string())
{
Ok(r) => r.json::<DeviceAuthorizationResponse>().await.map_err(|e| e.to_string()),
Err(e) => {
yield Response::Error(e.to_string());
return
},
} {
Ok(v) => v,
Err(e) => {
yield Response::Error(e.to_string());
return
},
};
yield Response::Start {
user_code: auth_response.user_code.clone(),
verification_url: auth_response.verification_url.clone(),
verification_url_complete: auth_response.verification_uri_complete.clone(),
};
yield loop {
tokio::time::sleep(Duration::from_secs(5)).await;
let token_resp = match node.http
.post(format!("{}/login/oauth/access_token", &node.env.api_url.lock().await))
.form(&[
("grant_type", sd_cloud_api::auth::DEVICE_CODE_URN),
("device_code", &auth_response.device_code),
("client_id", &node.env.client_id)
])
.send()
.await {
Ok(v) => v,
Err(e) => break Response::Error(e.to_string())
};
match token_resp.status() {
StatusCode::OK => {
let token = match token_resp.json().await {
Ok(v) => v,
Err(e) => break Response::Error(e.to_string())
};
if let Err(e) = node.config
.write(|c| c.auth_token = Some(token))
.await {
break Response::Error(e.to_string());
};
break Response::Complete;
},
StatusCode::BAD_REQUEST => {
#[derive(Debug, Deserialize)]
struct OAuth400 {
error: String
}
let resp = match token_resp.json::<OAuth400>().await {
Ok(v) => v,
Err(e) => break Response::Error(e.to_string())
};
match resp.error.as_str() {
"authorization_pending" => continue,
e => {
break Response::Error(e.to_string())
}
}
},
s => {
break Response::Error(s.to_string());
}
}
}
}
})
})
.procedure(
"logout",
R.mutation(|node, _: ()| async move {
node.config
.write(|c| c.auth_token = None)
.await
.map(|_| ())
.map_err(|_| {
rspc::Error::new(
rspc::ErrorCode::InternalServerError,
"Failed to write config".to_string(),
)
})
}),
)
.procedure("me", {
R.query(|node, _: ()| async move {
let resp = sd_cloud_api::user::me(node.cloud_api_config().await).await?;
Ok(resp)
})
})
}

View File

@@ -381,6 +381,7 @@ async fn restore_backup(node: &Arc<Node>, path: impl AsRef<Path>) -> Result<Head
db_restored_path,
library_config_restored_path,
None,
None,
true,
node,
)

View File

@@ -1,218 +0,0 @@
use crate::{api::libraries::LibraryConfigWrapped, invalidate_query, library::LibraryName};
use reqwest::Response;
use rspc::alpha::AlphaRouter;
use serde::de::DeserializeOwned;
use uuid::Uuid;
use super::{utils::library, Ctx, R};
#[allow(unused)]
async fn parse_json_body<T: DeserializeOwned>(response: Response) -> Result<T, rspc::Error> {
response.json().await.map_err(|_| {
rspc::Error::new(
rspc::ErrorCode::InternalServerError,
"JSON conversion failed".to_string(),
)
})
}
pub(crate) fn mount() -> AlphaRouter<Ctx> {
R.router()
.merge("library.", library::mount())
.merge("locations.", locations::mount())
.procedure("getApiOrigin", {
R.query(|node, _: ()| async move { Ok(node.env.api_url.lock().await.to_string()) })
})
.procedure("setApiOrigin", {
R.mutation(|node, origin: String| async move {
let mut origin_env = node.env.api_url.lock().await;
origin_env.clone_from(&origin);
node.config
.write(|c| {
c.auth_token = None;
c.sd_api_origin = Some(origin);
})
.await
.ok();
Ok(())
})
})
}
mod library {
use std::str::FromStr;
use sd_p2p::RemoteIdentity;
use crate::util::MaybeUndefined;
use super::*;
pub fn mount() -> AlphaRouter<Ctx> {
R.router()
.procedure("get", {
R.with2(library())
.query(|(node, library), _: ()| async move {
Ok(
sd_cloud_api::library::get(node.cloud_api_config().await, library.id)
.await?,
)
})
})
.procedure("list", {
R.query(|node, _: ()| async move {
Ok(sd_cloud_api::library::list(node.cloud_api_config().await).await?)
})
})
.procedure("create", {
R.with2(library())
.mutation(|(node, library), _: ()| async move {
let node_config = node.config.get().await;
let cloud_library = sd_cloud_api::library::create(
node.cloud_api_config().await,
library.id,
&library.config().await.name,
library.instance_uuid,
library.identity.to_remote_identity(),
node_config.id,
node_config.identity.to_remote_identity(),
&node.p2p.peer_metadata(),
)
.await?;
node.libraries
.edit(
library.id,
None,
MaybeUndefined::Undefined,
MaybeUndefined::Value(cloud_library.id),
None,
)
.await?;
invalidate_query!(library, "cloud.library.get");
Ok(())
})
})
.procedure("join", {
R.mutation(|node, library_id: Uuid| async move {
let Some(cloud_library) =
sd_cloud_api::library::get(node.cloud_api_config().await, library_id)
.await?
else {
return Err(rspc::Error::new(
rspc::ErrorCode::NotFound,
"Library not found".to_string(),
));
};
let library = node
.libraries
.create_with_uuid(
library_id,
LibraryName::new(cloud_library.name).map_err(|e| {
rspc::Error::new(
rspc::ErrorCode::InternalServerError,
e.to_string(),
)
})?,
None,
false,
None,
&node,
true,
)
.await?;
node.libraries
.edit(
library.id,
None,
MaybeUndefined::Undefined,
MaybeUndefined::Value(cloud_library.id),
None,
)
.await?;
let node_config = node.config.get().await;
let instances = sd_cloud_api::library::join(
node.cloud_api_config().await,
library_id,
library.instance_uuid,
library.identity.to_remote_identity(),
node_config.id,
node_config.identity.to_remote_identity(),
node.p2p.peer_metadata(),
)
.await?;
for instance in instances {
crate::cloud::sync::receive::upsert_instance(
library.id,
&library.db,
&library.sync,
&node.libraries,
&instance.uuid,
instance.identity,
&instance.node_id,
RemoteIdentity::from_str(&instance.node_remote_identity)
.expect("malformed remote identity in the DB"),
instance.metadata,
)
.await?;
}
invalidate_query!(library, "cloud.library.get");
invalidate_query!(library, "cloud.library.list");
Ok(LibraryConfigWrapped::from_library(&library).await)
})
})
.procedure("sync", {
R.with2(library())
.mutation(|(_, library), _: ()| async move {
library.do_cloud_sync();
Ok(())
})
})
}
}
mod locations {
use super::*;
use serde::{Deserialize, Serialize};
use specta::Type;
#[derive(Type, Serialize, Deserialize)]
pub struct CloudLocation {
id: String,
name: String,
}
pub fn mount() -> AlphaRouter<Ctx> {
R.router()
.procedure("list", {
R.query(|node, _: ()| async move {
sd_cloud_api::locations::list(node.cloud_api_config().await)
.await
.map_err(Into::into)
})
})
.procedure("create", {
R.mutation(|node, name: String| async move {
sd_cloud_api::locations::create(node.cloud_api_config().await, name)
.await
.map_err(Into::into)
})
})
.procedure("remove", {
R.mutation(|node, id: String| async move {
// `remove` must call the delete endpoint, not `create`; this assumes the
// legacy cloud API exposes a matching `locations::delete`.
sd_cloud_api::locations::delete(node.cloud_api_config().await, id)
.await
.map_err(Into::into)
})
})
}
}

View File

@@ -0,0 +1,397 @@
use crate::api::{Ctx, R};
use sd_core_cloud_services::QuinnConnection;
use sd_cloud_schema::{
auth::AccessToken,
devices::{self, DeviceOS, HardwareModel, PubId},
opaque_ke::{
ClientLogin, ClientLoginFinishParameters, ClientLoginFinishResult, ClientLoginStartResult,
ClientRegistration, ClientRegistrationFinishParameters, ClientRegistrationFinishResult,
ClientRegistrationStartResult,
},
Client, NodeId, Service, SpacedriveCipherSuite,
};
use sd_crypto::{cloud::secret_key::SecretKey, CryptoRng};
use blake3::Hash;
use futures::{FutureExt, SinkExt, StreamExt};
use futures_concurrency::future::TryJoin;
use rspc::alpha::AlphaRouter;
use serde::Deserialize;
use tracing::{debug, error};
pub fn mount() -> AlphaRouter<Ctx> {
R.router()
.procedure("get", {
R.query(|node, pub_id: devices::PubId| async move {
use devices::get::{Request, Response};
let (client, access_token) = super::get_client_and_access_token(&node).await?;
let Response(device) = super::handle_comm_error(
client
.devices()
.get(Request {
pub_id,
access_token,
})
.await,
"Failed to get device;",
)??;
debug!(?device, "Got device");
Ok(device)
})
})
.procedure("list", {
R.query(|node, _: ()| async move {
use devices::list::{Request, Response};
let ((client, access_token), pub_id) = (
super::get_client_and_access_token(&node),
node.config.get().map(|config| Ok(config.id.into())),
)
.try_join()
.await?;
let Response(mut devices) = super::handle_comm_error(
client.devices().list(Request { access_token }).await,
"Failed to list devices;",
)??;
// Filter out the local device by matching pub_id
devices.retain(|device| device.pub_id != pub_id);
debug!(?devices, "Listed devices");
Ok(devices)
})
})
.procedure("get_current_device", {
R.query(|node, _: ()| async move {
use devices::get::{Request, Response};
let ((client, access_token), pub_id) = (
super::get_client_and_access_token(&node),
node.config.get().map(|config| Ok(config.id.into())),
)
.try_join()
.await?;
let Response(device) = super::handle_comm_error(
client
.devices()
.get(Request {
pub_id,
access_token,
})
.await,
"Failed to get current device;",
)??;
Ok(device)
})
})
.procedure("delete", {
R.mutation(|node, pub_id: devices::PubId| async move {
use devices::delete::Request;
let (client, access_token) = super::get_client_and_access_token(&node).await?;
super::handle_comm_error(
client
.devices()
.delete(Request {
pub_id,
access_token,
})
.await,
"Failed to delete device;",
)??;
debug!("Deleted device");
Ok(())
})
})
.procedure("update", {
#[derive(Deserialize, specta::Type)]
struct CloudUpdateDeviceArgs {
pub_id: devices::PubId,
name: String,
}
R.mutation(
|node, CloudUpdateDeviceArgs { pub_id, name }: CloudUpdateDeviceArgs| async move {
use devices::update::Request;
let (client, access_token) = super::get_client_and_access_token(&node).await?;
super::handle_comm_error(
client
.devices()
.update(Request {
access_token,
pub_id,
name,
})
.await,
"Failed to update device;",
)??;
debug!("Updated device");
Ok(())
},
)
})
}
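/// Runs the OPAQUE login ("hello") flow for a device that is already
/// registered with the cloud services.
///
/// The hashed device pub_id acts as the OPAQUE password: the client starts a
/// login, streams the login message to the server, finishes the exchange with
/// the server's credential response, and expects a final `State::End`. On
/// success, the OPAQUE export key (which the server never learns) is returned
/// as the device's cloud `SecretKey`.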
pub async fn hello(
client: &Client<QuinnConnection<Service>, Service>,
access_token: AccessToken,
device_pub_id: PubId,
hashed_pub_id: Hash,
rng: &mut CryptoRng,
) -> Result<SecretKey, rspc::Error> {
use devices::hello::{Request, RequestUpdate, Response, State};
let ClientLoginStartResult { message, state } =
ClientLogin::<SpacedriveCipherSuite>::start(rng, hashed_pub_id.as_bytes().as_slice())
.map_err(|e| {
error!(?e, "OPAQUE error initializing device hello request;");
rspc::Error::new(
rspc::ErrorCode::InternalServerError,
"Failed to initialize device login".into(),
)
})?;
let (mut hello_continuation, mut res_stream) = super::handle_comm_error(
client
.devices()
.hello(Request {
access_token,
pub_id: device_pub_id,
opaque_login_message: Box::new(message),
})
.await,
"Failed to send device hello request;",
)?;
let Some(res) = res_stream.next().await else {
let message = "Server did not send a device hello response;";
error!("{message}");
return Err(rspc::Error::new(
rspc::ErrorCode::InternalServerError,
message.to_string(),
));
};
let credential_response = match super::handle_comm_error(
res,
"Communication error on device hello response;",
)? {
Ok(Response(State::LoginResponse(credential_response))) => credential_response,
Ok(Response(State::End)) => {
unreachable!("Device hello response MUST not be End here, this is a serious bug and should crash;");
}
Err(e) => {
error!(?e, "Device hello response error;");
return Err(e.into());
}
};
let ClientLoginFinishResult {
message,
export_key,
..
} = state
.finish(
hashed_pub_id.as_bytes().as_slice(),
*credential_response,
ClientLoginFinishParameters::default(),
)
.map_err(|e| {
error!(?e, "Device hello finish error;");
rspc::Error::new(
rspc::ErrorCode::InternalServerError,
"Failed to finish device login".into(),
)
})?;
hello_continuation
.send(RequestUpdate {
opaque_login_finish: Box::new(message),
})
.await
.map_err(|e| {
error!(?e, "Failed to send device hello request continuation;");
rspc::Error::new(
rspc::ErrorCode::InternalServerError,
"Failed to finish device login procedure;".into(),
)
})?;
let Some(res) = res_stream.next().await else {
let message = "Server did not send a device hello END response;";
error!("{message}");
return Err(rspc::Error::new(
rspc::ErrorCode::InternalServerError,
message.to_string(),
));
};
match super::handle_comm_error(res, "Communication error on device hello response;")? {
Ok(Response(State::LoginResponse(_))) => {
unreachable!("Device hello final response MUST be End here, this is a serious bug and should crash;");
}
Ok(Response(State::End)) => {
// Protocol completed successfully
Ok(SecretKey::from(export_key))
}
Err(e) => {
error!(?e, "Device hello final response error;");
Err(e.into())
}
}
}
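/// Data describing the local device when registering it with the cloud services.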
pub struct DeviceRegisterData {
pub pub_id: PubId,
pub name: String,
pub os: DeviceOS,
pub hardware_model: HardwareModel,
pub connection_id: NodeId,
}
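/// Runs the OPAQUE registration flow for a device the cloud does not know yet.
///
/// Mirrors [`hello`]: the client starts an OPAQUE registration keyed on the
/// hashed device pub_id, finishes it with the server's registration response,
/// and expects a final `State::End`. The resulting export key becomes the same
/// cloud `SecretKey` that later `hello` logins will derive.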
pub async fn register(
client: &Client<QuinnConnection<Service>, Service>,
access_token: AccessToken,
DeviceRegisterData {
pub_id,
name,
os,
hardware_model,
connection_id,
}: DeviceRegisterData,
hashed_pub_id: Hash,
rng: &mut CryptoRng,
) -> Result<SecretKey, rspc::Error> {
use devices::register::{Request, RequestUpdate, Response, State};
let ClientRegistrationStartResult { message, state } =
ClientRegistration::<SpacedriveCipherSuite>::start(
rng,
hashed_pub_id.as_bytes().as_slice(),
)
.map_err(|e| {
error!(?e, "OPAQUE error initializing device register request;");
rspc::Error::new(
rspc::ErrorCode::InternalServerError,
"Failed to initialize device register".into(),
)
})?;
let (mut register_continuation, mut res_stream) = super::handle_comm_error(
client
.devices()
.register(Request {
access_token,
pub_id,
name,
os,
hardware_model,
connection_id,
opaque_register_message: Box::new(message),
})
.await,
"Failed to send device register request;",
)?;
let Some(res) = res_stream.next().await else {
let message = "Server did not send a device register response;";
error!("{message}");
return Err(rspc::Error::new(
rspc::ErrorCode::InternalServerError,
message.to_string(),
));
};
let registration_response = match super::handle_comm_error(
res,
"Communication error on device register response;",
)? {
Ok(Response(State::RegistrationResponse(res))) => res,
Ok(Response(State::End)) => {
unreachable!("Device hello response MUST not be End here, this is a serious bug and should crash;");
}
Err(e) => {
error!(?e, "Device hello response error;");
return Err(e.into());
}
};
let ClientRegistrationFinishResult {
message,
export_key,
..
} = state
.finish(
rng,
hashed_pub_id.as_bytes().as_slice(),
*registration_response,
ClientRegistrationFinishParameters::default(),
)
.map_err(|e| {
error!(?e, "Device register finish error;");
rspc::Error::new(
rspc::ErrorCode::InternalServerError,
"Failed to finish device register".into(),
)
})?;
register_continuation
.send(RequestUpdate {
opaque_registration_finish: Box::new(message),
})
.await
.map_err(|e| {
error!(?e, "Failed to send device register request continuation;");
rspc::Error::new(
rspc::ErrorCode::InternalServerError,
"Failed to finish device register procedure;".into(),
)
})?;
let Some(res) = res_stream.next().await else {
let message = "Server did not send a device register END response;";
error!("{message}");
return Err(rspc::Error::new(
rspc::ErrorCode::InternalServerError,
message.to_string(),
));
};
match super::handle_comm_error(res, "Communication error on device register response;")? {
Ok(Response(State::RegistrationResponse(_))) => {
unreachable!("Device register final response MUST be End here, this is a serious bug and should crash;");
}
Ok(Response(State::End)) => {
// Protocol completed successfully
Ok(SecretKey::from(export_key))
}
Err(e) => {
error!(?e, "Device register final response error;");
Err(e.into())
}
}
}

View File

@@ -0,0 +1,140 @@
use crate::api::{utils::library, Ctx, R};
use sd_cloud_schema::libraries;
use futures::FutureExt;
use futures_concurrency::future::TryJoin;
use rspc::alpha::AlphaRouter;
use serde::Deserialize;
use tracing::debug;
pub fn mount() -> AlphaRouter<Ctx> {
R.router()
.procedure("get", {
#[derive(Deserialize, specta::Type)]
struct CloudGetLibraryArgs {
pub_id: libraries::PubId,
with_device: bool,
}
R.query(
|node,
CloudGetLibraryArgs {
pub_id,
with_device,
}: CloudGetLibraryArgs| async move {
use libraries::get::{Request, Response};
let (client, access_token) = super::get_client_and_access_token(&node).await?;
let Response(library) = super::handle_comm_error(
client
.libraries()
.get(Request {
access_token,
pub_id,
with_device,
})
.await,
"Failed to get library;",
)??;
debug!(?library, "Got library");
Ok(library)
},
)
})
.procedure("list", {
R.query(|node, with_device: bool| async move {
use libraries::list::{Request, Response};
let (client, access_token) = super::get_client_and_access_token(&node).await?;
let Response(libraries) = super::handle_comm_error(
client
.libraries()
.list(Request {
access_token,
with_device,
})
.await,
"Failed to list libraries;",
)??;
debug!(?libraries, "Listed libraries");
Ok(libraries)
})
})
.procedure("create", {
R.with2(library())
.mutation(|(node, library), _: ()| async move {
let ((client, access_token), name, device_pub_id) = (
super::get_client_and_access_token(&node),
library.config().map(|config| Ok(config.name.to_string())),
node.config.get().map(|config| Ok(config.id.into())),
)
.try_join()
.await?;
super::handle_comm_error(
client
.libraries()
.create(libraries::create::Request {
name,
access_token,
pub_id: libraries::PubId(library.id),
device_pub_id,
})
.await,
"Failed to create library;",
)??;
Ok(())
})
})
.procedure("delete", {
R.with2(library())
.mutation(|(node, library), _: ()| async move {
let (client, access_token) = super::get_client_and_access_token(&node).await?;
super::handle_comm_error(
client
.libraries()
.delete(libraries::delete::Request {
access_token,
pub_id: libraries::PubId(library.id),
})
.await,
"Failed to delete library;",
)??;
debug!("Deleted library");
Ok(())
})
})
.procedure("update", {
R.with2(library())
.mutation(|(node, library), name: String| async move {
let (client, access_token) = super::get_client_and_access_token(&node).await?;
super::handle_comm_error(
client
.libraries()
.update(libraries::update::Request {
access_token,
pub_id: libraries::PubId(library.id),
name,
})
.await,
"Failed to update library;",
)??;
debug!("Updated library");
Ok(())
})
})
}

View File

@@ -0,0 +1,112 @@
use crate::api::{Ctx, R};
use sd_cloud_schema::{devices, libraries, locations};
use rspc::alpha::AlphaRouter;
use serde::Deserialize;
use tracing::debug;
pub fn mount() -> AlphaRouter<Ctx> {
R.router()
.procedure("list", {
#[derive(Deserialize, specta::Type)]
struct CloudListLocationsArgs {
pub library_pub_id: libraries::PubId,
pub with_library: bool,
pub with_device: bool,
}
R.query(
|node,
CloudListLocationsArgs {
library_pub_id,
with_library,
with_device,
}: CloudListLocationsArgs| async move {
use locations::list::{Request, Response};
let (client, access_token) = super::get_client_and_access_token(&node).await?;
let Response(locations) = super::handle_comm_error(
client
.locations()
.list(Request {
access_token,
library_pub_id,
with_library,
with_device,
})
.await,
"Failed to list locations;",
)??;
debug!(?locations, "Got locations");
Ok(locations)
},
)
})
.procedure("create", {
#[derive(Deserialize, specta::Type)]
struct CloudCreateLocationArgs {
pub pub_id: locations::PubId,
pub name: String,
pub library_pub_id: libraries::PubId,
pub device_pub_id: devices::PubId,
}
R.mutation(
|node,
CloudCreateLocationArgs {
pub_id,
name,
library_pub_id,
device_pub_id,
}: CloudCreateLocationArgs| async move {
use locations::create::Request;
let (client, access_token) = super::get_client_and_access_token(&node).await?;
super::handle_comm_error(
client
.locations()
.create(Request {
access_token,
pub_id,
name,
library_pub_id,
device_pub_id,
})
.await,
"Failed to list locations;",
)??;
debug!("Created cloud location");
Ok(())
},
)
})
.procedure("delete", {
R.mutation(|node, pub_id: locations::PubId| async move {
use locations::delete::Request;
let (client, access_token) = super::get_client_and_access_token(&node).await?;
super::handle_comm_error(
client
.locations()
.delete(Request {
access_token,
pub_id,
})
.await,
"Failed to list locations;",
)??;
debug!("Created cloud location");
Ok(())
})
})
}

316
core/src/api/cloud/mod.rs Normal file
View File

@@ -0,0 +1,316 @@
use crate::{
library::LibraryManagerError,
node::{config::NodeConfig, HardwareModel},
Node,
};
use sd_core_cloud_services::{CloudP2P, KeyManager, QuinnConnection, UserResponse};
use sd_cloud_schema::{
auth,
error::{ClientSideError, Error},
sync::groups,
users, Client, SecretKey as IrohSecretKey, Service,
};
use sd_crypto::{CryptoRng, SeedableRng};
use sd_utils::error::report_error;
use std::pin::pin;
use async_stream::stream;
use futures::{FutureExt, StreamExt};
use futures_concurrency::future::TryJoin;
use rspc::alpha::AlphaRouter;
use tracing::{debug, error, instrument};
use super::{Ctx, R};
mod devices;
mod libraries;
mod locations;
mod sync_groups;
async fn try_get_cloud_services_client(
node: &Node,
) -> Result<Client<QuinnConnection<Service>, Service>, sd_core_cloud_services::Error> {
node.cloud_services
.client()
.await
.map_err(report_error("Failed to get cloud services client"))
}
pub(crate) fn mount() -> AlphaRouter<Ctx> {
R.router()
.merge("libraries.", libraries::mount())
.merge("locations.", locations::mount())
.merge("devices.", devices::mount())
.merge("syncGroups.", sync_groups::mount())
.procedure("bootstrap", {
R.mutation(
|node, (access_token, refresh_token): (auth::AccessToken, auth::RefreshToken)| async move {
use sd_cloud_schema::devices;
// Only allow a single bootstrap request in flight at a time
let mut has_bootstrapped_lock = node
.cloud_services
.has_bootstrapped
.try_lock()
.map_err(|_| {
rspc::Error::new(
rspc::ErrorCode::Conflict,
String::from("Bootstrap in progress"),
)
})?;
if *has_bootstrapped_lock {
return Err(rspc::Error::new(
rspc::ErrorCode::Conflict,
String::from("Already bootstrapped"),
));
}
node.cloud_services
.token_refresher
.init(access_token, refresh_token)
.await?;
let client = try_get_cloud_services_client(&node).await?;
let data_directory = node.config.data_directory();
let mut rng =
CryptoRng::from_seed(node.master_rng.lock().await.generate_fixed());
// The create-user route is idempotent, so it's safe to call it repeatedly for the same user
handle_comm_error(
client
.users()
.create(users::create::Request {
access_token: node
.cloud_services
.token_refresher
.get_access_token()
.await?,
})
.await,
"Failed to create user;",
)??;
let (device_pub_id, name, os) = {
let NodeConfig { id, name, os, .. } = node.config.get().await;
(devices::PubId(id.into()), name, os)
};
let hashed_pub_id = blake3::hash(device_pub_id.0.as_bytes().as_slice());
let key_manager = match handle_comm_error(
client
.devices()
.get(devices::get::Request {
access_token: node
.cloud_services
.token_refresher
.get_access_token()
.await?,
pub_id: device_pub_id,
})
.await,
"Failed to get device on cloud bootstrap;",
)? {
Ok(_) => {
// Device registered, we execute a device hello flow
let master_key = self::devices::hello(
&client,
node.cloud_services
.token_refresher
.get_access_token()
.await?,
device_pub_id,
hashed_pub_id,
&mut rng,
)
.await?;
debug!("Device hello successful");
KeyManager::load(master_key, data_directory).await?
}
Err(Error::Client(ClientSideError::NotFound(_))) => {
// Device not registered, we execute a device register flow
let iroh_secret_key = IrohSecretKey::generate_with_rng(&mut rng);
let hardware_model = Into::into(
HardwareModel::try_get().unwrap_or(HardwareModel::Other),
);
let master_key = self::devices::register(
&client,
node.cloud_services
.token_refresher
.get_access_token()
.await?,
self::devices::DeviceRegisterData {
pub_id: device_pub_id,
name,
os,
hardware_model,
connection_id: iroh_secret_key.public(),
},
hashed_pub_id,
&mut rng,
)
.await?;
debug!("Device registered successfully");
KeyManager::new(master_key, iroh_secret_key, data_directory, &mut rng)
.await?
}
Err(e) => return Err(e.into()),
};
let iroh_secret_key = key_manager.iroh_secret_key().await;
node.cloud_services.set_key_manager(key_manager).await;
node.cloud_services
.set_cloud_p2p(
CloudP2P::new(
device_pub_id,
&node.cloud_services,
rng,
iroh_secret_key,
node.cloud_services.cloud_p2p_dns_origin_name.clone(),
node.cloud_services.cloud_p2p_dns_pkarr_url.clone(),
node.cloud_services.cloud_p2p_relay_url.clone(),
)
.await?,
)
.await;
let groups::list::Response(groups) = handle_comm_error(
client
.sync()
.groups()
.list(groups::list::Request {
access_token: node
.cloud_services
.token_refresher
.get_access_token()
.await?,
})
.await,
"Failed to list sync groups on bootstrap",
)??;
groups
.into_iter()
.map(
|groups::GroupBaseData {
pub_id,
library,
// TODO(@fogodev): We can use this latest key hash to check whether we
// already hold the latest key for this group locally, issuing an
// ask-for-key-hash request to other devices if we don't
latest_key_hash: _latest_key_hash,
..
}| {
let node = &node;
async move {
match initialize_cloud_sync(pub_id, library, node).await {
// If we don't have this library locally, we haven't joined this group yet
Ok(()) | Err(LibraryManagerError::LibraryNotFound) => {
Ok(())
}
Err(e) => Err(e),
}
}
},
)
.collect::<Vec<_>>()
.try_join()
.await?;
*has_bootstrapped_lock = true;
Ok(())
},
)
})
.procedure(
"listenCloudServicesNotifications",
R.subscription(|node, _: ()| async move {
stream! {
let mut notifications_stream =
pin!(node.cloud_services.stream_user_notifications());
while let Some(notification) = notifications_stream.next().await {
yield notification;
}
}
}),
)
.procedure(
"userResponse",
R.mutation(|node, response: UserResponse| async move {
node.cloud_services.send_user_response(response).await;
Ok(())
}),
)
.procedure(
"hasBootstrapped",
R.query(|node, _: ()| async move {
// If we can't lock immediately, a bootstrap is in progress,
// so we haven't bootstrapped yet
Ok(node
.cloud_services
.has_bootstrapped
.try_lock()
.map(|lock| *lock)
.unwrap_or(false))
}),
)
}
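/// Maps transport-level failures from cloud services calls into `rspc` errors.
///
/// Call sites use a double `??`: the outer `?` handles the communication
/// error mapped here, while the inner `?` propagates the service's own
/// application-level error.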
fn handle_comm_error<T, E: std::error::Error + std::fmt::Debug + Send + Sync + 'static>(
res: Result<T, E>,
message: &'static str,
) -> Result<T, rspc::Error> {
res.map_err(|e| {
error!(?e, "Communication with cloud services error: {message}");
rspc::Error::with_cause(rspc::ErrorCode::InternalServerError, message.into(), e)
})
}
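/// Resumes cloud sync for a sync group's library, if that library exists
/// locally; callers treat `LibraryNotFound` as "this device hasn't joined
/// the group yet".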
#[instrument(skip_all, fields(%group_pub_id, %library_pub_id), err)]
async fn initialize_cloud_sync(
group_pub_id: groups::PubId,
sd_cloud_schema::libraries::Library {
pub_id: sd_cloud_schema::libraries::PubId(library_pub_id),
..
}: sd_cloud_schema::libraries::Library,
node: &Node,
) -> Result<(), LibraryManagerError> {
let library = node
.libraries
.get_library(&library_pub_id)
.await
.ok_or(LibraryManagerError::LibraryNotFound)?;
library.init_cloud_sync(node, group_pub_id).await
}
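/// Fetches the cloud services client and a fresh access token concurrently,
/// since every authenticated procedure needs both.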
async fn get_client_and_access_token(
node: &Node,
) -> Result<(Client<QuinnConnection<Service>, Service>, auth::AccessToken), rspc::Error> {
(
try_get_cloud_services_client(node),
node.cloud_services
.token_refresher
.get_access_token()
.map(|res| res.map_err(Into::into)),
)
.try_join()
.await
.map_err(Into::into)
}

View File

@@ -0,0 +1,404 @@
use crate::{
api::{utils::library, Ctx, R},
library::LibraryName,
Node,
};
use sd_core_cloud_services::JoinedLibraryCreateArgs;
use sd_cloud_schema::{
cloud_p2p, devices, libraries,
sync::{groups, KeyHash},
};
use sd_crypto::{cloud::secret_key::SecretKey, CryptoRng, SeedableRng};
use std::sync::Arc;
use futures::FutureExt;
use futures_concurrency::future::TryJoin;
use rspc::alpha::AlphaRouter;
use serde::{Deserialize, Serialize};
use tokio::{spawn, sync::oneshot};
use tracing::{debug, error};
pub fn mount() -> AlphaRouter<Ctx> {
R.router()
.procedure("create", {
R.with2(library())
.mutation(|(node, library), _: ()| async move {
use groups::create::{Request, Response};
let ((client, access_token), device_pub_id, mut rng, key_manager) = (
super::get_client_and_access_token(&node),
node.config.get().map(|config| Ok(config.id.into())),
node.master_rng
.lock()
.map(|mut rng| Ok(CryptoRng::from_seed(rng.generate_fixed()))),
node.cloud_services
.key_manager()
.map(|res| res.map_err(Into::into)),
)
.try_join()
.await?;
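// Generate the group's secret key locally; only its blake3 hash (as hex)
// is shared with the cloud, so the server never sees the key itself.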
let new_key = SecretKey::generate(&mut rng);
let key_hash = KeyHash(blake3::hash(new_key.as_ref()).to_hex().to_string());
let Response(group_pub_id) = super::handle_comm_error(
client
.sync()
.groups()
.create(Request {
access_token: access_token.clone(),
key_hash: key_hash.clone(),
library_pub_id: libraries::PubId(library.id),
device_pub_id,
})
.await,
"Failed to create sync group;",
)??;
if let Err(e) = key_manager
.add_key_with_hash(group_pub_id, new_key, key_hash.clone(), &mut rng)
.await
{
super::handle_comm_error(
client
.sync()
.groups()
.delete(groups::delete::Request {
access_token,
pub_id: group_pub_id,
})
.await,
"Failed to delete sync group after we failed to store secret key in key manager;",
)??;
return Err(e.into());
}
library.init_cloud_sync(&node, group_pub_id).await?;
debug!(%group_pub_id, ?key_hash, "Created sync group");
Ok(())
})
})
.procedure("delete", {
R.mutation(|node, pub_id: groups::PubId| async move {
use groups::delete::Request;
let (client, access_token) = super::get_client_and_access_token(&node).await?;
super::handle_comm_error(
client
.sync()
.groups()
.delete(Request {
access_token,
pub_id,
})
.await,
"Failed to delete sync group;",
)??;
debug!(%pub_id, "Deleted sync group");
Ok(())
})
})
.procedure("get", {
#[derive(Deserialize, specta::Type)]
struct CloudGetSyncGroupArgs {
pub pub_id: groups::PubId,
pub kind: groups::get::RequestKind,
}
// This is a compatibility layer: quic-rpc uses bincode for serialization,
// and bincode doesn't support serde's tagged enums, which we need when
// serializing for the frontend
#[derive(Debug, Serialize, specta::Type)]
#[serde(tag = "kind", content = "data")]
pub enum CloudSyncGroupGetResponseKind {
WithDevices(groups::GroupWithDevices),
FullData(groups::Group),
}
impl From<groups::get::ResponseKind> for CloudSyncGroupGetResponseKind {
fn from(kind: groups::get::ResponseKind) -> Self {
match kind {
groups::get::ResponseKind::WithDevices(data) => {
CloudSyncGroupGetResponseKind::WithDevices(data)
}
groups::get::ResponseKind::FullData(data) => {
CloudSyncGroupGetResponseKind::FullData(data)
}
groups::get::ResponseKind::DevicesConnectionIds(_) => {
unreachable!(
"DevicesConnectionIds response is not expected, as we never request it here"
);
}
}
}
}
R.query(
|node, CloudGetSyncGroupArgs { pub_id, kind }: CloudGetSyncGroupArgs| async move {
use groups::get::{Request, Response};
let (client, access_token) = super::get_client_and_access_token(&node).await?;
if matches!(kind, groups::get::RequestKind::DevicesConnectionIds) {
return Err(rspc::Error::new(
rspc::ErrorCode::PreconditionFailed,
"This request isn't allowed here".into(),
));
}
let Response(response_kind) = super::handle_comm_error(
client
.sync()
.groups()
.get(Request {
access_token,
pub_id,
kind,
})
.await,
"Failed to get sync group;",
)??;
debug!(?response_kind, "Got sync group");
Ok(CloudSyncGroupGetResponseKind::from(response_kind))
},
)
})
.procedure("leave", {
R.query(|node, pub_id: groups::PubId| async move {
let ((client, access_token), current_device_pub_id, mut rng, key_manager) = (
super::get_client_and_access_token(&node),
node.config.get().map(|config| Ok(config.id.into())),
node.master_rng
.lock()
.map(|mut rng| Ok(CryptoRng::from_seed(rng.generate_fixed()))),
node.cloud_services
.key_manager()
.map(|res| res.map_err(Into::into)),
)
.try_join()
.await?;
super::handle_comm_error(
client
.sync()
.groups()
.leave(groups::leave::Request {
access_token,
pub_id,
current_device_pub_id,
})
.await,
"Failed to leave sync group;",
)??;
key_manager.remove_group(pub_id, &mut rng).await?;
debug!(%pub_id, "Left sync group");
Ok(())
})
})
.procedure("list", {
R.query(|node, _: ()| async move {
use groups::list::{Request, Response};
let (client, access_token) = super::get_client_and_access_token(&node).await?;
let Response(groups) = super::handle_comm_error(
client.sync().groups().list(Request { access_token }).await,
"Failed to list groups;",
)??;
debug!(?groups, "Listed sync groups");
Ok(groups)
})
})
.procedure("remove_device", {
#[derive(Deserialize, specta::Type)]
struct CloudSyncGroupsRemoveDeviceArgs {
group_pub_id: groups::PubId,
to_remove_device_pub_id: devices::PubId,
}
R.query(
|node,
CloudSyncGroupsRemoveDeviceArgs {
group_pub_id,
to_remove_device_pub_id,
}: CloudSyncGroupsRemoveDeviceArgs| async move {
use groups::remove_device::Request;
let ((client, access_token), current_device_pub_id, mut rng, key_manager) = (
super::get_client_and_access_token(&node),
node.config.get().map(|config| Ok(config.id.into())),
node.master_rng
.lock()
.map(|mut rng| Ok(CryptoRng::from_seed(rng.generate_fixed()))),
node.cloud_services
.key_manager()
.map(|res| res.map_err(Into::into)),
)
.try_join()
.await?;
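// Rotate the group key before removing the device, so the removed
// device cannot decrypt anything encrypted after it leaves.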
let new_key = SecretKey::generate(&mut rng);
let new_key_hash = KeyHash(blake3::hash(new_key.as_ref()).to_hex().to_string());
key_manager
.add_key_with_hash(group_pub_id, new_key, new_key_hash.clone(), &mut rng)
.await?;
super::handle_comm_error(
client
.sync()
.groups()
.remove_device(Request {
access_token,
group_pub_id,
new_key_hash,
current_device_pub_id,
to_remove_device_pub_id,
})
.await,
"Failed to remove device from sync group;",
)??;
debug!(%to_remove_device_pub_id, %group_pub_id, "Removed device");
Ok(())
},
)
})
.procedure("request_join", {
#[derive(Deserialize, specta::Type)]
struct SyncGroupsRequestJoinArgs {
sync_group: groups::GroupWithDevices,
asking_device: devices::Device,
}
R.mutation(
|node,
SyncGroupsRequestJoinArgs {
sync_group,
asking_device,
}: SyncGroupsRequestJoinArgs| async move {
let ((client, access_token), current_device_pub_id, cloud_p2p) = (
super::get_client_and_access_token(&node),
node.config.get().map(|config| Ok(config.id.into())),
node.cloud_services
.cloud_p2p()
.map(|res| res.map_err(Into::into)),
)
.try_join()
.await?;
let group_pub_id = sync_group.pub_id;
debug!("My pub id: {:?}", current_device_pub_id);
debug!("Asking device pub id: {:?}", asking_device.pub_id);
if asking_device.pub_id != current_device_pub_id {
return Err(rspc::Error::new(
rspc::ErrorCode::BadRequest,
String::from("Asking device must be the current device"),
));
}
let groups::request_join::Response(existing_devices) =
super::handle_comm_error(
client
.sync()
.groups()
.request_join(groups::request_join::Request {
access_token,
group_pub_id,
current_device_pub_id,
})
.await,
"Failed to update library;",
)??;
let (tx, rx) = oneshot::channel();
cloud_p2p
.request_join_sync_group(
existing_devices,
cloud_p2p::authorize_new_device_in_sync_group::Request {
sync_group,
asking_device,
},
tx,
)
.await;
JoinedSyncGroupReceiver {
node,
group_pub_id,
rx,
}
.dispatch();
debug!(%group_pub_id, "Requested to join sync group");
Ok(())
},
)
})
}
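/// Waits in the background for an existing device to authorize our join
/// request, then creates the corresponding local library and starts cloud
/// sync for it.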
struct JoinedSyncGroupReceiver {
node: Arc<Node>,
group_pub_id: groups::PubId,
rx: oneshot::Receiver<JoinedLibraryCreateArgs>,
}
impl JoinedSyncGroupReceiver {
fn dispatch(self) {
spawn(async move {
let Self {
node,
group_pub_id,
rx,
} = self;
if let Ok(JoinedLibraryCreateArgs {
pub_id: libraries::PubId(pub_id),
name,
description,
}) = rx.await
{
let Ok(name) =
LibraryName::new(name).map_err(|e| error!(?e, "Invalid library name"))
else {
return;
};
let Ok(library) = node
.libraries
.create_with_uuid(pub_id, name, description, true, None, &node)
.await
.map_err(|e| {
error!(?e, "Failed to create library from sync group join response")
})
else {
return;
};
if let Err(e) = library.init_cloud_sync(&node, group_pub_id).await {
error!(?e, "Failed to initialize cloud sync for library");
}
}
});
}
}

Some files were not shown because too many files have changed in this diff