mirror of
https://github.com/spacedriveapp/spacedrive.git
synced 2026-05-01 11:53:36 -04:00
Merge pull request #2633 from spacedriveapp/eng-1828-migration-to-new-cloud-api-system
[ENG-1828] Migrate to new Cloud Services API
This commit is contained in:
BIN
Cargo.lock
generated
BIN
Cargo.lock
generated
Binary file not shown.
16
Cargo.toml
16
Cargo.toml
@@ -19,6 +19,9 @@ repository = "https://github.com/spacedriveapp/spacedrive"
|
||||
rust-version = "1.81"
|
||||
|
||||
[workspace.dependencies]
|
||||
# First party dependencies
|
||||
sd-cloud-schema = { git = "https://github.com/spacedriveapp/cloud-services-schema", rev = "bbc69c5cb2" }
|
||||
|
||||
# Third party dependencies used by one or more of our crates
|
||||
async-channel = "2.3"
|
||||
async-stream = "0.3.6"
|
||||
@@ -26,23 +29,25 @@ async-trait = "0.1.83"
|
||||
axum = "0.7.7"
|
||||
axum-extra = "0.9.4"
|
||||
base64 = "0.22.1"
|
||||
blake3 = "1.5"
|
||||
blake3 = "1.5.4"
|
||||
bytes = "1.7.1" # Update blocked by hyper
|
||||
chrono = "0.4.38"
|
||||
ed25519-dalek = "2.1"
|
||||
flume = "0.11.0"
|
||||
futures = "0.3.31"
|
||||
futures-concurrency = "7.6"
|
||||
globset = "0.4.15"
|
||||
http = "1.1"
|
||||
hyper = "1.5"
|
||||
image = "0.24.9" # Update blocked due to https://github.com/image-rs/image/issues/2230
|
||||
image = "0.25.4"
|
||||
itertools = "0.13.0"
|
||||
lending-stream = "1.0"
|
||||
libc = "0.2"
|
||||
libc = "0.2.159"
|
||||
mimalloc = "0.1.43"
|
||||
normpath = "1.3"
|
||||
pin-project-lite = "0.2.14"
|
||||
rand = "0.9.0-alpha.2"
|
||||
regex = "1"
|
||||
regex = "1.11"
|
||||
reqwest = { version = "0.12.8", default-features = false }
|
||||
rmp = "0.8.14"
|
||||
rmp-serde = "1.3"
|
||||
@@ -62,7 +67,8 @@ tracing-subscriber = "0.3.18"
|
||||
tracing-test = "0.2.5"
|
||||
uhlc = "0.8.0" # Must follow version used by specta
|
||||
uuid = "1.10" # Must follow version used by specta
|
||||
webp = "0.2.6" # Update blocked by image
|
||||
webp = "0.3.0"
|
||||
zeroize = "1.8"
|
||||
|
||||
[workspace.dependencies.rspc]
|
||||
git = "https://github.com/spacedriveapp/rspc.git"
|
||||
|
||||
@@ -20,26 +20,28 @@
|
||||
"@sd/ui": "workspace:*",
|
||||
"@t3-oss/env-core": "^0.7.1",
|
||||
"@tanstack/react-query": "^5.59",
|
||||
"@tauri-apps/api": "=2.0.2",
|
||||
"@tauri-apps/plugin-dialog": "2.0.0",
|
||||
"@tauri-apps/api": "=2.0.3",
|
||||
"@tauri-apps/plugin-dialog": "2.0.1",
|
||||
"@tauri-apps/plugin-http": "2.0.1",
|
||||
"@tauri-apps/plugin-os": "2.0.0",
|
||||
"@tauri-apps/plugin-shell": "2.0.0",
|
||||
"@tauri-apps/plugin-shell": "2.0.1",
|
||||
"consistent-hash": "^1.2.2",
|
||||
"immer": "^10.0.3",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-router-dom": "=6.20.1",
|
||||
"sonner": "^1.0.3"
|
||||
"sonner": "^1.0.3",
|
||||
"supertokens-web-js": "^0.13.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@sd/config": "workspace:*",
|
||||
"@sentry/vite-plugin": "^2.16.0",
|
||||
"@tauri-apps/cli": "2.0.1",
|
||||
"@tauri-apps/cli": "2.0.4",
|
||||
"@types/react": "^18.2.67",
|
||||
"@types/react-dom": "^18.2.22",
|
||||
"sass": "^1.72.0",
|
||||
"typescript": "^5.6.2",
|
||||
"vite": "^5.2.0",
|
||||
"vite-tsconfig-paths": "^4.3.2"
|
||||
"vite": "^5.4.9",
|
||||
"vite-tsconfig-paths": "^5.0.1"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,12 +35,15 @@ uuid = { workspace = true, features = ["serde"] }
|
||||
|
||||
# Specific Desktop dependencies
|
||||
# WARNING: Do NOT enable default features, as that vendors dbus (see below)
|
||||
opener = { version = "0.7.1", features = ["reveal"], default-features = false }
|
||||
specta-typescript = "=0.0.7"
|
||||
tauri-plugin-dialog = "=2.0.2"
|
||||
tauri-plugin-os = "=2.0.1"
|
||||
tauri-plugin-shell = "=2.0.2"
|
||||
tauri-plugin-updater = "=2.0.2"
|
||||
opener = { version = "0.7.1", features = ["reveal"], default-features = false }
|
||||
specta-typescript = "=0.0.7"
|
||||
tauri-plugin-clipboard-manager = "=2.0.1"
|
||||
tauri-plugin-deep-link = "=2.0.1"
|
||||
tauri-plugin-dialog = "=2.0.3"
|
||||
tauri-plugin-http = "=2.0.3"
|
||||
tauri-plugin-os = "=2.0.1"
|
||||
tauri-plugin-shell = "=2.0.2"
|
||||
tauri-plugin-updater = "=2.0.2"
|
||||
|
||||
# memory allocator
|
||||
mimalloc = { workspace = true }
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
"dialog:allow-open",
|
||||
"dialog:allow-save",
|
||||
"dialog:allow-confirm",
|
||||
"deep-link:default",
|
||||
"os:allow-os-type",
|
||||
"core:window:allow-close",
|
||||
"core:window:allow-create",
|
||||
@@ -24,6 +25,32 @@
|
||||
"core:window:allow-minimize",
|
||||
"core:window:allow-toggle-maximize",
|
||||
"core:window:allow-start-dragging",
|
||||
"core:webview:allow-internal-toggle-devtools"
|
||||
"core:webview:allow-internal-toggle-devtools",
|
||||
{
|
||||
"identifier": "http:default",
|
||||
"allow": [
|
||||
{
|
||||
"url": "http://ipc.localhost"
|
||||
},
|
||||
{
|
||||
"url": "http://asset.localhost"
|
||||
},
|
||||
{
|
||||
"url": "http://localhost:8001"
|
||||
},
|
||||
{
|
||||
"url": "http://tauri.localhost"
|
||||
},
|
||||
{
|
||||
"url": "http://localhost:9420"
|
||||
},
|
||||
{
|
||||
"url": "https://auth.spacedrive.com"
|
||||
},
|
||||
{
|
||||
"url": "https://plausible.io"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -14,13 +14,13 @@ use sd_core::{Node, NodeError};
|
||||
use sd_fda::DiskAccess;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use specta_typescript::Typescript;
|
||||
use tauri::Emitter;
|
||||
use tauri::{async_runtime::block_on, webview::PlatformWebview, AppHandle, Manager, WindowEvent};
|
||||
use tauri::{Emitter, Listener};
|
||||
use tauri_plugins::{sd_error_plugin, sd_server_plugin};
|
||||
use tauri_specta::{collect_events, Builder};
|
||||
use tokio::task::block_in_place;
|
||||
use tokio::time::sleep;
|
||||
use tracing::error;
|
||||
use tracing::{debug, error};
|
||||
|
||||
mod file;
|
||||
mod menu;
|
||||
@@ -179,7 +179,11 @@ pub enum DragAndDropEvent {
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
const CLIENT_ID: &str = "2abb241e-40b8-4517-a3e3-5594375c8fbb";
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, specta::Type, tauri_specta::Event)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DeepLinkEvent {
|
||||
data: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> tauri::Result<()> {
|
||||
@@ -221,9 +225,20 @@ async fn main() -> tauri::Result<()> {
|
||||
|
||||
tauri::Builder::default()
|
||||
.invoke_handler(builder.invoke_handler())
|
||||
.plugin(tauri_plugin_deep_link::init())
|
||||
.setup(move |app| {
|
||||
// We need a the app handle to determine the data directory now.
|
||||
// This means all the setup code has to be within `setup`, however it doesn't support async so we `block_on`.
|
||||
let handle = app.handle().clone();
|
||||
app.listen("deep-link://new-url", move |event| {
|
||||
let deep_link_event = DeepLinkEvent {
|
||||
data: event.payload().to_string(),
|
||||
};
|
||||
println!("Deep link event={:?}", deep_link_event);
|
||||
|
||||
handle.emit("deeplink", deep_link_event).unwrap();
|
||||
});
|
||||
|
||||
block_in_place(|| {
|
||||
block_on(async move {
|
||||
builder.mount_events(app);
|
||||
@@ -239,10 +254,7 @@ async fn main() -> tauri::Result<()> {
|
||||
|
||||
// The `_guard` must be assigned to variable for flushing remaining logs on main exit through Drop
|
||||
let (_guard, result) = match Node::init_logger(&data_dir) {
|
||||
Ok(guard) => (
|
||||
Some(guard),
|
||||
Node::new(data_dir, sd_core::Env::new(CLIENT_ID)).await,
|
||||
),
|
||||
Ok(guard) => (Some(guard), Node::new(data_dir).await),
|
||||
Err(err) => (None, Err(NodeError::Logger(err))),
|
||||
};
|
||||
|
||||
@@ -256,7 +268,7 @@ async fn main() -> tauri::Result<()> {
|
||||
}
|
||||
};
|
||||
|
||||
let should_clear_localstorage = node.libraries.get_all().await.is_empty();
|
||||
let should_clear_local_storage = node.libraries.get_all().await.is_empty();
|
||||
|
||||
handle.plugin(rspc::integrations::tauri::plugin(router, {
|
||||
let node = node.clone();
|
||||
@@ -266,8 +278,8 @@ async fn main() -> tauri::Result<()> {
|
||||
handle.manage(node.clone());
|
||||
|
||||
handle.windows().iter().for_each(|(_, window)| {
|
||||
if should_clear_localstorage {
|
||||
println!("cleaning localStorage");
|
||||
if should_clear_local_storage {
|
||||
debug!("cleaning localStorage");
|
||||
for webview in window.webviews() {
|
||||
webview.eval("localStorage.clear();").ok();
|
||||
}
|
||||
@@ -344,6 +356,7 @@ async fn main() -> tauri::Result<()> {
|
||||
.plugin(tauri_plugin_dialog::init())
|
||||
.plugin(tauri_plugin_os::init())
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
.plugin(tauri_plugin_http::init())
|
||||
// TODO: Bring back Tauri Plugin Window State - it was buggy so we removed it.
|
||||
.plugin(tauri_plugin_updater::Builder::new().build())
|
||||
.plugin(updater::plugin())
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"$schema": "https://raw.githubusercontent.com/tauri-apps/tauri/tauri-v2.0.0-rc.2/core/tauri-config-schema/schema.json",
|
||||
"$schema": "https://raw.githubusercontent.com/tauri-apps/tauri/tauri-v2.0.0-rc.8/crates/tauri-cli/tauri.config.schema.json",
|
||||
"productName": "Spacedrive",
|
||||
"identifier": "com.spacedrive.desktop",
|
||||
"build": {
|
||||
@@ -36,7 +36,12 @@
|
||||
}
|
||||
],
|
||||
"security": {
|
||||
"csp": "default-src webkit-pdfjs-viewer: asset: https://asset.localhost blob: data: filesystem: ws: wss: http: https: tauri: 'unsafe-eval' 'unsafe-inline' 'self' img-src: 'self'"
|
||||
"csp": {
|
||||
"default-src": "'self' webkit-pdfjs-viewer: asset: http://asset.localhost blob: data: filesystem: http: https: tauri:",
|
||||
"connect-src": "'self' ipc: http://ipc.localhost ws: wss: http: https: tauri:",
|
||||
"img-src": "'self' asset: http://asset.localhost blob: data: filesystem: http: https: tauri:",
|
||||
"style-src": "'self' 'unsafe-inline' http: https: tauri:"
|
||||
}
|
||||
}
|
||||
},
|
||||
"bundle": {
|
||||
@@ -100,6 +105,12 @@
|
||||
"endpoints": [
|
||||
"https://spacedrive.com/api/releases/tauri/{{version}}/{{target}}/{{arch}}"
|
||||
]
|
||||
},
|
||||
"deep-link": {
|
||||
"mobile": [],
|
||||
"desktop": {
|
||||
"schemes": ["spacedrive"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,9 +3,10 @@ import { QueryClientProvider } from '@tanstack/react-query';
|
||||
import { listen } from '@tauri-apps/api/event';
|
||||
import { PropsWithChildren, startTransition, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
import { RspcProvider } from '@sd/client';
|
||||
import { RspcProvider, useBridgeMutation } from '@sd/client';
|
||||
import {
|
||||
createRoutes,
|
||||
DeeplinkEvent,
|
||||
ErrorPage,
|
||||
KeybindEvent,
|
||||
PlatformProvider,
|
||||
@@ -17,14 +18,11 @@ import { RouteTitleContext } from '@sd/interface/hooks/useRouteTitle';
|
||||
|
||||
import '@sd/ui/style/style.scss';
|
||||
|
||||
import { useLocale } from '@sd/interface/hooks';
|
||||
|
||||
import { commands } from './commands';
|
||||
import { platform } from './platform';
|
||||
import { queryClient } from './query';
|
||||
import { createMemoryRouterWithHistory } from './router';
|
||||
import { createUpdater } from './updater';
|
||||
|
||||
import SuperTokens from 'supertokens-web-js';
|
||||
import EmailPassword from 'supertokens-web-js/recipe/emailpassword';
|
||||
import Passwordless from 'supertokens-web-js/recipe/passwordless';
|
||||
import Session from 'supertokens-web-js/recipe/session';
|
||||
import ThirdParty from 'supertokens-web-js/recipe/thirdparty';
|
||||
// TODO: Bring this back once upstream is fixed up.
|
||||
// const client = hooks.createClient({
|
||||
// links: [
|
||||
@@ -34,6 +32,32 @@ import { createUpdater } from './updater';
|
||||
// tauriLink()
|
||||
// ]
|
||||
// });
|
||||
import getCookieHandler from '@sd/interface/app/$libraryId/settings/client/account/handlers/cookieHandler';
|
||||
import getWindowHandler from '@sd/interface/app/$libraryId/settings/client/account/handlers/windowHandler';
|
||||
import { useLocale } from '@sd/interface/hooks';
|
||||
import { AUTH_SERVER_URL, getTokens } from '@sd/interface/util';
|
||||
|
||||
import { commands } from './commands';
|
||||
import { platform } from './platform';
|
||||
import { queryClient } from './query';
|
||||
import { createMemoryRouterWithHistory } from './router';
|
||||
import { createUpdater } from './updater';
|
||||
|
||||
SuperTokens.init({
|
||||
appInfo: {
|
||||
apiDomain: AUTH_SERVER_URL,
|
||||
apiBasePath: '/api/auth',
|
||||
appName: 'Spacedrive Auth Service'
|
||||
},
|
||||
cookieHandler: getCookieHandler,
|
||||
windowHandler: getWindowHandler,
|
||||
recipeList: [
|
||||
Session.init({ tokenTransferMethod: 'header' }),
|
||||
EmailPassword.init(),
|
||||
ThirdParty.init(),
|
||||
Passwordless.init()
|
||||
]
|
||||
});
|
||||
|
||||
const startupError = (window as any).__SD_ERROR__ as string | undefined;
|
||||
|
||||
@@ -41,15 +65,31 @@ export default function App() {
|
||||
useEffect(() => {
|
||||
// This tells Tauri to show the current window because it's finished loading
|
||||
commands.appReady();
|
||||
// .then(() => {
|
||||
// if (import.meta.env.PROD) window.fetch = fetch;
|
||||
// });
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const keybindListener = listen('keybind', (input) => {
|
||||
document.dispatchEvent(new KeybindEvent(input.payload as string));
|
||||
});
|
||||
const deeplinkListener = listen('deeplink', async (data) => {
|
||||
const payload = (data.payload as any).data as string;
|
||||
if (!payload) return;
|
||||
const json = JSON.parse(payload)[0];
|
||||
if (!json) return;
|
||||
//json output: "spacedrive://-/URL"
|
||||
if (typeof json !== 'string') return;
|
||||
if (!json.startsWith('spacedrive://-')) return;
|
||||
const url = (json as string).split('://-/')[1];
|
||||
if (!url) return;
|
||||
document.dispatchEvent(new DeeplinkEvent(url));
|
||||
});
|
||||
|
||||
return () => {
|
||||
keybindListener.then((unlisten) => unlisten());
|
||||
deeplinkListener.then((unlisten) => unlisten());
|
||||
};
|
||||
}, []);
|
||||
|
||||
@@ -79,6 +119,15 @@ type RedirectPath = { pathname: string; search: string | undefined };
|
||||
function AppInner() {
|
||||
const [tabs, setTabs] = useState(() => [createTab()]);
|
||||
const [selectedTabIndex, setSelectedTabIndex] = useState(0);
|
||||
const tokens = getTokens();
|
||||
const cloudBootstrap = useBridgeMutation('cloud.bootstrap');
|
||||
|
||||
useEffect(() => {
|
||||
// If the access token and/or refresh token are missing, we need to skip the cloud bootstrap
|
||||
if (tokens.accessToken.length === 0 || tokens.refreshToken.length === 0) return;
|
||||
cloudBootstrap.mutate([tokens.accessToken, tokens.refreshToken]);
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
const selectedTab = tabs[selectedTabIndex]!;
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
"declarationDir": "dist",
|
||||
"paths": {
|
||||
"~/*": ["./src/*"]
|
||||
}
|
||||
},
|
||||
"moduleResolution": "bundler"
|
||||
},
|
||||
"include": ["src"],
|
||||
"references": [
|
||||
|
||||
@@ -19,18 +19,6 @@ const Pre: FC<{ children: React.ReactNode }> = ({ children }) => {
|
||||
|
||||
return (
|
||||
<div ref={textInput} className="relative">
|
||||
{/* <button
|
||||
aria-label="Copy code"
|
||||
type="button"
|
||||
className="absolute right-2 top-2 z-10 rounded-md bg-app-box p-3 text-white/60 transition-colors duration-200 ease-in-out hover:bg-app-darkBox"
|
||||
onClick={onCopy}
|
||||
>
|
||||
{copied ? (
|
||||
<Check size={18} className="text-green-400" />
|
||||
) : (
|
||||
<Copy size={18} className="text-white opacity-70" />
|
||||
)}
|
||||
</button> */}
|
||||
<Button
|
||||
size="md"
|
||||
rounding="both"
|
||||
|
||||
2
apps/mobile/.gitignore
vendored
2
apps/mobile/.gitignore
vendored
@@ -312,6 +312,7 @@ buck-out/
|
||||
.fakebuckversion
|
||||
|
||||
### ReactNative.Xcode Stack ###
|
||||
ios/
|
||||
|
||||
## Xcode 8 and earlier
|
||||
|
||||
@@ -319,6 +320,7 @@ buck-out/
|
||||
.gradle
|
||||
**/build/
|
||||
!src/**/build/
|
||||
android/
|
||||
|
||||
# Ignore Gradle GUI config
|
||||
gradle-app.setting
|
||||
|
||||
@@ -15,17 +15,17 @@ err() {
|
||||
|
||||
if [ -z "${HOME:-}" ]; then
|
||||
case "$(uname)" in
|
||||
"Darwin")
|
||||
HOME="$(CDPATH='' cd -- "$(osascript -e 'set output to (POSIX path of (path to home folder))')" && pwd -P)"
|
||||
;;
|
||||
"Linux")
|
||||
HOME="$(CDPATH='' cd -- "$(getent passwd "$(id -un)" | cut -d: -f6)" && pwd -P)"
|
||||
;;
|
||||
*)
|
||||
err "Your OS ($(uname)) is not supported by this script." \
|
||||
'We would welcome a PR or some help adding your OS to this script.' \
|
||||
'https://github.com/spacedriveapp/spacedrive/issues'
|
||||
;;
|
||||
"Darwin")
|
||||
HOME="$(CDPATH='' cd -- "$(osascript -e 'set output to (POSIX path of (path to home folder))')" && pwd -P)"
|
||||
;;
|
||||
"Linux")
|
||||
HOME="$(CDPATH='' cd -- "$(getent passwd "$(id -un)" | cut -d: -f6)" && pwd -P)"
|
||||
;;
|
||||
*)
|
||||
err "Your OS ($(uname)) is not supported by this script." \
|
||||
'We would welcome a PR or some help adding your OS to this script.' \
|
||||
'https://github.com/spacedriveapp/spacedrive/issues'
|
||||
;;
|
||||
esac
|
||||
|
||||
export HOME
|
||||
@@ -47,18 +47,19 @@ export PATH="${CARGO_HOME:-"${HOME}/.cargo"}/bin:$PATH"
|
||||
if [ "${CI:-}" = "true" ]; then
|
||||
# TODO: This need to be adjusted for future mobile release CI
|
||||
case "$(uname -m)" in
|
||||
"arm64" | "aarch64")
|
||||
ANDROID_BUILD_TARGET_LIST="arm64-v8a"
|
||||
;;
|
||||
"x86_64")
|
||||
ANDROID_BUILD_TARGET_LIST="x86_64"
|
||||
;;
|
||||
*)
|
||||
err 'Unsupported architecture for CI build.'
|
||||
;;
|
||||
"arm64" | "aarch64")
|
||||
ANDROID_BUILD_TARGET_LIST="arm64-v8a"
|
||||
;;
|
||||
"x86_64")
|
||||
ANDROID_BUILD_TARGET_LIST="x86_64"
|
||||
;;
|
||||
*)
|
||||
err 'Unsupported architecture for CI build.'
|
||||
;;
|
||||
esac
|
||||
else
|
||||
ANDROID_BUILD_TARGET_LIST="arm64-v8a armeabi-v7a x86_64"
|
||||
# ANDROID_BUILD_TARGET_LIST="arm64-v8a armeabi-v7a x86_64"
|
||||
ANDROID_BUILD_TARGET_LIST="arm64-v8a"
|
||||
fi
|
||||
|
||||
# Configure build targets CLI arg for `cargo ndk`
|
||||
@@ -69,4 +70,5 @@ for _target in $ANDROID_BUILD_TARGET_LIST; do
|
||||
done
|
||||
|
||||
cd "${__dirname}/crate"
|
||||
cargo ndk --platform 34 "$@" -o "$OUTPUT_DIRECTORY" build --release
|
||||
cargo ndk --platform 34 "$@" -o "$OUTPUT_DIRECTORY" build
|
||||
# \ --release
|
||||
|
||||
@@ -37,9 +37,7 @@ pub extern "system" fn Java_com_spacedrive_core_SDCoreModule_registerCoreEventLi
|
||||
if let Err(err) = result {
|
||||
// TODO: Send rspc error or something here so we can show this in the UI.
|
||||
// TODO: Maybe reinitialise the core cause it could be in an invalid state?
|
||||
println!(
|
||||
"Error in Java_com_spacedrive_core_SDCoreModule_registerCoreEventListener: {err:?}"
|
||||
);
|
||||
error!("Error in Java_com_spacedrive_core_SDCoreModule_registerCoreEventListener: {err:?}");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ license.workspace = true
|
||||
repository.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
|
||||
# Spacedrive Sub-crates
|
||||
[target.'cfg(target_os = "ios")'.dependencies]
|
||||
sd-core = { default-features = false, features = [
|
||||
|
||||
@@ -34,8 +34,6 @@ pub static SUBSCRIPTIONS: LazyLock<
|
||||
|
||||
pub static EVENT_SENDER: OnceLock<mpsc::Sender<Response>> = OnceLock::new();
|
||||
|
||||
pub const CLIENT_ID: &str = "d068776a-05b6-4aaa-9001-4d01734e1944";
|
||||
|
||||
pub struct MobileSender<'a> {
|
||||
resp: &'a mut Option<Response>,
|
||||
}
|
||||
@@ -76,7 +74,7 @@ pub fn handle_core_msg(
|
||||
None => {
|
||||
let _guard = Node::init_logger(&data_dir);
|
||||
|
||||
let new_node = match Node::new(data_dir, sd_core::Env::new(CLIENT_ID)).await {
|
||||
let new_node = match Node::new(data_dir).await {
|
||||
Ok(node) => node,
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to initialize node;");
|
||||
|
||||
@@ -37,7 +37,8 @@ Pod::Spec.new do |s|
|
||||
ffmpeg_frameworks = [
|
||||
"-framework AudioToolbox",
|
||||
"-framework VideoToolbox",
|
||||
"-framework AVFoundation"
|
||||
"-framework AVFoundation",
|
||||
"-framework SystemConfiguration",
|
||||
].join(' ')
|
||||
|
||||
s.xcconfig = {
|
||||
|
||||
@@ -48,10 +48,10 @@ mkdir -p "$TARGET_DIRECTORY"
|
||||
TARGET_DIRECTORY="$(CDPATH='' cd -- "$TARGET_DIRECTORY" && pwd -P)"
|
||||
|
||||
TARGET_CONFIG=debug
|
||||
if [ "${CONFIGURATION:-}" = "Release" ]; then
|
||||
set -- --release
|
||||
TARGET_CONFIG=release
|
||||
fi
|
||||
# if [ "${CONFIGURATION:-}" = "Release" ]; then
|
||||
# set -- --release
|
||||
# TARGET_CONFIG=release
|
||||
# fi
|
||||
|
||||
trap 'if [ -e "${CARGO_CONFIG}.bak" ]; then mv "${CARGO_CONFIG}.bak" "$CARGO_CONFIG"; fi' EXIT
|
||||
|
||||
@@ -59,21 +59,21 @@ trap 'if [ -e "${CARGO_CONFIG}.bak" ]; then mv "${CARGO_CONFIG}.bak" "$CARGO_CON
|
||||
RUST_PATH="${CARGO_HOME:-"${HOME}/.cargo"}/bin:$(brew --prefix)/bin:$(env -i /bin/bash --noprofile --norc -c 'echo $PATH')"
|
||||
if [ "${PLATFORM_NAME:-}" = "iphonesimulator" ]; then
|
||||
case "$(uname -m)" in
|
||||
"arm64" | "aarch64") # M series
|
||||
sed -i.bak "s|FFMPEG_DIR = { force = true, value = \".*\" }|FFMPEG_DIR = { force = true, value = \"${DEPS}/aarch64-apple-ios-sim\" }|" "$CARGO_CONFIG"
|
||||
env CARGO_FEATURE_STATIC=1 PATH="$RUST_PATH" cargo build -p sd-mobile-ios --target aarch64-apple-ios-sim "$@"
|
||||
lipo -create -output "$TARGET_DIRECTORY"/libsd_mobile_iossim.a "${TARGET_DIRECTORY}/aarch64-apple-ios-sim/${TARGET_CONFIG}/libsd_mobile_ios.a"
|
||||
symlink_libs "${DEPS}/aarch64-apple-ios-sim/lib" "$TARGET_DIRECTORY"
|
||||
;;
|
||||
"x86_64") # Intel
|
||||
sed -i.bak "s|FFMPEG_DIR = { force = true, value = \".*\" }|FFMPEG_DIR = { force = true, value = \"${DEPS}/x86_64-apple-ios\" }|" "$CARGO_CONFIG"
|
||||
env CARGO_FEATURE_STATIC=1 PATH="$RUST_PATH" cargo build -p sd-mobile-ios --target x86_64-apple-ios "$@"
|
||||
lipo -create -output "$TARGET_DIRECTORY"/libsd_mobile_iossim.a "${TARGET_DIRECTORY}/x86_64-apple-ios/${TARGET_CONFIG}/libsd_mobile_ios.a"
|
||||
symlink_libs "${DEPS}/x86_64-apple-ios/lib" "$TARGET_DIRECTORY"
|
||||
;;
|
||||
*)
|
||||
err 'Unsupported architecture.'
|
||||
;;
|
||||
"arm64" | "aarch64") # M series
|
||||
sed -i.bak "s|FFMPEG_DIR = { force = true, value = \".*\" }|FFMPEG_DIR = { force = true, value = \"${DEPS}/aarch64-apple-ios-sim\" }|" "$CARGO_CONFIG"
|
||||
env CARGO_FEATURE_STATIC=1 PATH="$RUST_PATH" cargo build -p sd-mobile-ios --target aarch64-apple-ios-sim "$@"
|
||||
lipo -create -output "$TARGET_DIRECTORY"/libsd_mobile_iossim.a "${TARGET_DIRECTORY}/aarch64-apple-ios-sim/${TARGET_CONFIG}/libsd_mobile_ios.a"
|
||||
symlink_libs "${DEPS}/aarch64-apple-ios-sim/lib" "$TARGET_DIRECTORY"
|
||||
;;
|
||||
"x86_64") # Intel
|
||||
sed -i.bak "s|FFMPEG_DIR = { force = true, value = \".*\" }|FFMPEG_DIR = { force = true, value = \"${DEPS}/x86_64-apple-ios\" }|" "$CARGO_CONFIG"
|
||||
env CARGO_FEATURE_STATIC=1 PATH="$RUST_PATH" cargo build -p sd-mobile-ios --target x86_64-apple-ios "$@"
|
||||
lipo -create -output "$TARGET_DIRECTORY"/libsd_mobile_iossim.a "${TARGET_DIRECTORY}/x86_64-apple-ios/${TARGET_CONFIG}/libsd_mobile_ios.a"
|
||||
symlink_libs "${DEPS}/x86_64-apple-ios/lib" "$TARGET_DIRECTORY"
|
||||
;;
|
||||
*)
|
||||
err 'Unsupported architecture.'
|
||||
;;
|
||||
esac
|
||||
else
|
||||
sed -i.bak "s|FFMPEG_DIR = { force = true, value = \".*\" }|FFMPEG_DIR = { force = true, value = \"${DEPS}/aarch64-apple-ios\" }|" "$CARGO_CONFIG"
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"start": "expo start --dev-client",
|
||||
"android": "expo run:android",
|
||||
"android": "java -version 2>&1 | grep -q 'version \"17' && expo run:android || echo 'Java version 17 is required to be the running version. Please switch or uninstall the current version (Version '$(java -version 2>&1 | awk -F '\"' '/version/ {print $2}')').' && exit 1",
|
||||
"ios": "expo run:ios",
|
||||
"prebuild": "expo prebuild",
|
||||
"xcode": "open ios/Spacedrive.xcworkspace",
|
||||
@@ -22,11 +22,10 @@
|
||||
"@gorhom/bottom-sheet": "^4.6.1",
|
||||
"@hookform/resolvers": "^3.1.0",
|
||||
"@spacedrive/rspc-client": "github:spacedriveapp/rspc#path:packages/client&6a77167495",
|
||||
"@spacedrive/rspc-react": "github:spacedriveapp/rspc#path:packages/react&6a77167495",
|
||||
"@react-native-async-storage/async-storage": "~1.23.1",
|
||||
"@react-native-masked-view/masked-view": "^0.3.1",
|
||||
"@react-navigation/bottom-tabs": "^6.5.19",
|
||||
"@react-navigation/drawer": "^6.6.14",
|
||||
"@react-navigation/drawer": "^6.6.15",
|
||||
"@react-navigation/native": "^6.1.16",
|
||||
"@react-navigation/native-stack": "^6.9.25",
|
||||
"@sd/assets": "workspace:*",
|
||||
@@ -37,12 +36,12 @@
|
||||
"class-variance-authority": "^0.7.0",
|
||||
"dayjs": "^1.11.10",
|
||||
"event-target-polyfill": "^0.0.4",
|
||||
"expo": "~51.0.28",
|
||||
"expo-av": "^14.0.6",
|
||||
"expo": "~51.0.32",
|
||||
"expo-av": "^14.0.7",
|
||||
"expo-blur": "^13.0.2",
|
||||
"expo-build-properties": "~0.12.5",
|
||||
"expo-haptics": "~13.0.1",
|
||||
"expo-image": "^1.12.13",
|
||||
"expo-image": "^1.12.15",
|
||||
"expo-linking": "~6.3.1",
|
||||
"expo-media-library": "~16.0.4",
|
||||
"expo-splash-screen": "~0.27.5",
|
||||
@@ -71,6 +70,7 @@
|
||||
"react-native-wheel-color-picker": "^1.2.0",
|
||||
"rive-react-native": "^6.2.3",
|
||||
"solid-js": "^1.8.8",
|
||||
"supertokens-react-native": "^5.1.2",
|
||||
"twrnc": "^4.1.0",
|
||||
"use-count-up": "^3.0.1",
|
||||
"use-debounce": "^9.0.4",
|
||||
|
||||
@@ -17,6 +17,7 @@ import { Alert, LogBox, Permission, PermissionsAndroid, Platform } from 'react-n
|
||||
import { GestureHandlerRootView } from 'react-native-gesture-handler';
|
||||
import { MenuProvider } from 'react-native-popup-menu';
|
||||
import { SafeAreaProvider } from 'react-native-safe-area-context';
|
||||
import SuperTokens from 'supertokens-react-native';
|
||||
import { useSnapshot } from 'valtio';
|
||||
import {
|
||||
ClientContextProvider,
|
||||
@@ -24,7 +25,9 @@ import {
|
||||
LibraryContextProvider,
|
||||
P2PContextProvider,
|
||||
RspcProvider,
|
||||
useBridgeMutation,
|
||||
useBridgeQuery,
|
||||
useBridgeSubscription,
|
||||
useClientContext,
|
||||
useInvalidateQuery,
|
||||
usePlausibleEvent,
|
||||
@@ -33,12 +36,13 @@ import {
|
||||
} from '@sd/client';
|
||||
|
||||
import { GlobalModals } from './components/modal/GlobalModals';
|
||||
import { Toast, toastConfig } from './components/primitive/Toast';
|
||||
import { toast, Toast, toastConfig } from './components/primitive/Toast';
|
||||
import { useTheme } from './hooks/useTheme';
|
||||
import { changeTwTheme, tw } from './lib/tailwind';
|
||||
import RootNavigator from './navigation';
|
||||
import OnboardingNavigator from './navigation/OnboardingNavigator';
|
||||
import { P2P } from './screens/p2p/P2P';
|
||||
import { AUTH_SERVER_URL } from './utils';
|
||||
import { currentLibraryStore } from './utils/nav';
|
||||
|
||||
LogBox.ignoreLogs(['Sending `onAnimatedValueUpdate` with no listeners registered.']);
|
||||
@@ -129,6 +133,41 @@ function AppContainer() {
|
||||
useInvalidateQuery();
|
||||
|
||||
const { id } = useSnapshot(currentLibraryStore);
|
||||
const userResponse = useBridgeMutation('cloud.userResponse');
|
||||
|
||||
useBridgeSubscription(['cloud.listenCloudServicesNotifications'], {
|
||||
onData: (d) => {
|
||||
console.log('Received cloud service notification', d);
|
||||
switch (d.kind) {
|
||||
case 'ReceivedJoinSyncGroupRequest':
|
||||
// WARNING: This is a debug solution to accept the device into the sync group. THIS SHOULD NOT MAKE IT TO PRODUCTION
|
||||
userResponse.mutate({
|
||||
kind: 'AcceptDeviceInSyncGroup',
|
||||
data: {
|
||||
ticket: d.data.ticket,
|
||||
accepted: {
|
||||
id: d.data.sync_group.library.pub_id,
|
||||
name: d.data.sync_group.library.name,
|
||||
description: null
|
||||
}
|
||||
}
|
||||
});
|
||||
// TODO: Move the code above into the dialog below (@Rocky43007)
|
||||
// dialogManager.create((dp) => (
|
||||
// <RequestAddDialog
|
||||
// device_model={'MacBookPro'}
|
||||
// device_name={"Arnab's Macbook"}
|
||||
// library_name={"Arnab's Library"}
|
||||
// {...dp}
|
||||
// />
|
||||
// ));
|
||||
break;
|
||||
default:
|
||||
toast.info(`Cloud Service Notification: ${d.kind}`);
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return (
|
||||
<SafeAreaProvider style={tw`flex-1 bg-black`}>
|
||||
@@ -156,6 +195,10 @@ export default function App() {
|
||||
useEffect(() => {
|
||||
global.Intl = require('intl');
|
||||
require('intl/locale-data/jsonp/en'); //TODO(@Rocky43007): Setup a way to import all the languages we support, once we add localization on mobile.
|
||||
SuperTokens.init({
|
||||
apiDomain: AUTH_SERVER_URL,
|
||||
apiBasePath: '/api/auth'
|
||||
});
|
||||
SplashScreen.hideAsync();
|
||||
if (Platform.OS === 'android') {
|
||||
(async () => {
|
||||
|
||||
@@ -8,12 +8,13 @@ import { tw, twStyle } from '~/lib/tailwind';
|
||||
type Props = {
|
||||
route?: RouteProp<any, any>; // supporting title from the options object of navigation
|
||||
navBack?: boolean; // whether to show the back icon
|
||||
navBackTo?: string; // route to go back to
|
||||
search?: boolean; // whether to show the search icon
|
||||
title?: string; // in some cases - we want to override the route title
|
||||
};
|
||||
|
||||
// Default header with search bar and button to open drawer
|
||||
export default function Header({ route, navBack, title, search = false }: Props) {
|
||||
export default function Header({ route, navBack, title, navBackTo, search = false }: Props) {
|
||||
const navigation = useNavigation<DrawerNavigationHelpers>();
|
||||
const headerHeight = useSafeAreaInsets().top;
|
||||
const isAndroid = Platform.OS === 'android';
|
||||
@@ -28,7 +29,13 @@ export default function Header({ route, navBack, title, search = false }: Props)
|
||||
<View style={tw`w-full flex-row items-center justify-between`}>
|
||||
<View style={tw`flex-row items-center gap-3`}>
|
||||
{navBack ? (
|
||||
<Pressable hitSlop={24} onPress={() => navigation.goBack()}>
|
||||
<Pressable
|
||||
hitSlop={24}
|
||||
onPress={() => {
|
||||
if (navBackTo) return navigation.navigate(navBackTo);
|
||||
navigation.goBack();
|
||||
}}
|
||||
>
|
||||
<ArrowLeft size={24} color={tw.color('ink')} />
|
||||
</Pressable>
|
||||
) : (
|
||||
|
||||
@@ -2,13 +2,7 @@ import { BottomSheetFlatList } from '@gorhom/bottom-sheet';
|
||||
import { NavigationProp, useNavigation } from '@react-navigation/native';
|
||||
import { forwardRef } from 'react';
|
||||
import { ActivityIndicator, Text, View } from 'react-native';
|
||||
import {
|
||||
CloudLibrary,
|
||||
useBridgeMutation,
|
||||
useBridgeQuery,
|
||||
useClientContext,
|
||||
useRspcContext
|
||||
} from '@sd/client';
|
||||
import { useBridgeMutation, useBridgeQuery, useClientContext, useRspcContext } from '@sd/client';
|
||||
import { Modal, ModalRef } from '~/components/layout/Modal';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import useForwardedRef from '~/hooks/useForwardedRef';
|
||||
@@ -25,9 +19,9 @@ const ImportModalLibrary = forwardRef<ModalRef, unknown>((_, ref) => {
|
||||
|
||||
const { libraries } = useClientContext();
|
||||
|
||||
const cloudLibraries = useBridgeQuery(['cloud.library.list']);
|
||||
const cloudLibraries = useBridgeQuery(['cloud.libraries.list', true]);
|
||||
const cloudLibrariesData = cloudLibraries.data?.filter(
|
||||
(cloudLibrary) => !libraries.data?.find((l) => l.uuid === cloudLibrary.uuid)
|
||||
(cloudLibrary) => !libraries.data?.find((l) => l.uuid === cloudLibrary.pub_id)
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -63,11 +57,11 @@ const ImportModalLibrary = forwardRef<ModalRef, unknown>((_, ref) => {
|
||||
description="No cloud libraries available to join"
|
||||
/>
|
||||
}
|
||||
keyExtractor={(item) => item.uuid}
|
||||
keyExtractor={(item) => item.pub_id}
|
||||
showsVerticalScrollIndicator={false}
|
||||
renderItem={({ item }) => (
|
||||
<CloudLibraryCard
|
||||
data={item}
|
||||
// data={item}
|
||||
navigation={navigation}
|
||||
modalRef={modalRef}
|
||||
/>
|
||||
@@ -81,38 +75,37 @@ const ImportModalLibrary = forwardRef<ModalRef, unknown>((_, ref) => {
|
||||
});
|
||||
|
||||
interface Props {
|
||||
data: CloudLibrary;
|
||||
// data: CloudLibrary;
|
||||
modalRef: React.RefObject<ModalRef>;
|
||||
navigation: NavigationProp<RootStackParamList>;
|
||||
}
|
||||
|
||||
const CloudLibraryCard = ({ data, modalRef, navigation }: Props) => {
|
||||
const CloudLibraryCard = ({ modalRef, navigation }: Props) => {
|
||||
const rspc = useRspcContext().queryClient;
|
||||
const joinLibrary = useBridgeMutation(['cloud.library.join']);
|
||||
// const joinLibrary = useBridgeMutation(['cloud.library.join']);
|
||||
return (
|
||||
<View
|
||||
key={data.uuid}
|
||||
style={tw`flex flex-row items-center justify-between gap-2 rounded-md border border-app-box bg-app p-2`}
|
||||
>
|
||||
<Text numberOfLines={1} style={tw`max-w-[80%] text-sm font-bold text-ink`}>
|
||||
{data.name}
|
||||
{'BOB'}
|
||||
</Text>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="accent"
|
||||
disabled={joinLibrary.isPending}
|
||||
// disabled={joinLibrary.isPending}
|
||||
onPress={async () => {
|
||||
const library = await joinLibrary.mutateAsync(data.uuid);
|
||||
// const library = await joinLibrary.mutateAsync(data.uuid);
|
||||
|
||||
rspc.setQueryData(['library.list'], (libraries: any) => {
|
||||
// The invalidation system beat us to it
|
||||
if ((libraries || []).find((l: any) => l.uuid === library.uuid))
|
||||
return libraries;
|
||||
// rspc.setQueryData(['library.list'], (libraries: any) => {
|
||||
// // The invalidation system beat us to it
|
||||
// if ((libraries || []).find((l: any) => l.uuid === library.uuid))
|
||||
// return libraries;
|
||||
|
||||
return [...(libraries || []), library];
|
||||
});
|
||||
// return [...(libraries || []), library];
|
||||
// });
|
||||
|
||||
currentLibraryStore.id = library.uuid;
|
||||
// currentLibraryStore.id = library.uuid;
|
||||
|
||||
navigation.navigate('Root', {
|
||||
screen: 'Home',
|
||||
@@ -128,9 +121,10 @@ const CloudLibraryCard = ({ data, modalRef, navigation }: Props) => {
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-sm font-medium text-white`}>
|
||||
{joinLibrary.isPending && joinLibrary.variables === data.uuid
|
||||
{/* {joinLibrary.isPending && joinLibrary.variables === data.uuid
|
||||
? 'Joining...'
|
||||
: 'Join'}
|
||||
: 'Join'} */}
|
||||
THIS FILE NEEDS TO BE UPDATED TO USE THE NEW LIBRARY SYSTEM IN THE FUTURE
|
||||
</Text>
|
||||
</Button>
|
||||
</View>
|
||||
|
||||
66
apps/mobile/src/components/modal/sync/JoinRequestModal.tsx
Normal file
66
apps/mobile/src/components/modal/sync/JoinRequestModal.tsx
Normal file
@@ -0,0 +1,66 @@
|
||||
import { ArrowRight } from 'phosphor-react-native';
|
||||
import React, { forwardRef } from 'react';
|
||||
import { Text, View } from 'react-native';
|
||||
import { HardwareModel } from '@sd/client';
|
||||
import { Icon } from '~/components/icons/Icon';
|
||||
import { Modal, ModalRef } from '~/components/layout/Modal';
|
||||
import { hardwareModelToIcon } from '~/components/overview/Devices';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import useForwardedRef from '~/hooks/useForwardedRef';
|
||||
import { twStyle } from '~/lib/tailwind';
|
||||
|
||||
interface Props {
|
||||
device_name: string;
|
||||
device_model: HardwareModel;
|
||||
library_name: string;
|
||||
}
|
||||
|
||||
const JoinRequestModal = forwardRef<ModalRef, Props>((props, ref) => {
|
||||
const modalRef = useForwardedRef(ref);
|
||||
return (
|
||||
<Modal ref={modalRef} snapPoints={['36']} title="Sync request">
|
||||
<View style={twStyle('px-6')}>
|
||||
<Text style={twStyle('mx-auto mt-2 text-center text-ink-dull')}>
|
||||
A device is requesting to join one of your libraries. Please review the device
|
||||
and the library it is requesting to join below.
|
||||
</Text>
|
||||
<View style={twStyle('my-7 flex-row items-center justify-center gap-10')}>
|
||||
<View style={twStyle('flex flex-col items-center justify-center gap-2')}>
|
||||
<Icon
|
||||
// once backend endpoint is populated need to check if this is working correctly i.e fetching correct icons for devices
|
||||
name={hardwareModelToIcon(props.device_model)}
|
||||
alt="Device icon"
|
||||
size={48}
|
||||
/>
|
||||
<Text style={twStyle('text-sm font-bold text-ink')}>
|
||||
{props.device_name}
|
||||
</Text>
|
||||
</View>
|
||||
<ArrowRight weight="bold" color="#ABACBA" size={18} />
|
||||
{/* library */}
|
||||
<View style={twStyle('flex flex-col items-center justify-center gap-2')}>
|
||||
<Icon
|
||||
// once backend endpoint is populated need to check if this is working correctly i.e fetching correct icons for devices
|
||||
name={'Book'}
|
||||
alt="Device icon"
|
||||
size={48}
|
||||
/>
|
||||
<Text style={twStyle('text-sm font-bold text-ink')}>
|
||||
{props.library_name}
|
||||
</Text>
|
||||
</View>
|
||||
</View>
|
||||
<View style={twStyle('mx-auto flex-row justify-center gap-5')}>
|
||||
<Button style={twStyle('flex-1')} variant="gray">
|
||||
<Text style={twStyle('font-bold text-ink-dull')}>Cancel</Text>
|
||||
</Button>
|
||||
<Button style={twStyle('flex-1')} variant="accent">
|
||||
<Text style={twStyle('font-bold text-ink')}>Accept</Text>
|
||||
</Button>
|
||||
</View>
|
||||
</View>
|
||||
</Modal>
|
||||
);
|
||||
});
|
||||
|
||||
export default JoinRequestModal;
|
||||
@@ -5,8 +5,9 @@ import React, { useEffect, useState } from 'react';
|
||||
import { Platform, Text, View } from 'react-native';
|
||||
import DeviceInfo from 'react-native-device-info';
|
||||
import { ScrollView } from 'react-native-gesture-handler';
|
||||
import { HardwareModel, NodeState, StatisticsResponse } from '@sd/client';
|
||||
import { HardwareModel, NodeState, StatisticsResponse, useBridgeQuery } from '@sd/client';
|
||||
import { tw, twStyle } from '~/lib/tailwind';
|
||||
import { getTokens } from '~/utils';
|
||||
|
||||
import Fade from '../layout/Fade';
|
||||
import { Button } from '../primitive/Button';
|
||||
@@ -44,6 +45,23 @@ const Devices = ({ node, stats }: Props) => {
|
||||
Omit<RNFS.FSInfoResultT, 'totalSpaceEx' | 'freeSpaceEx'>
|
||||
>({ freeSpace: 0, totalSpace: 0 });
|
||||
const [deviceName, setDeviceName] = useState<string>('');
|
||||
const [accessToken, setAccessToken] = useState<string>('');
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
const at = await getTokens();
|
||||
setAccessToken(at.accessToken);
|
||||
})();
|
||||
}, []);
|
||||
|
||||
const devices = useBridgeQuery(['cloud.devices.list']);
|
||||
|
||||
// Refetch devices every 10 seconds
|
||||
useEffect(() => {
|
||||
const interval = setInterval(async () => {
|
||||
await devices.refetch();
|
||||
}, 10000);
|
||||
return () => clearInterval(interval);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const getFSInfo = async () => {
|
||||
@@ -74,7 +92,7 @@ const Devices = ({ node, stats }: Props) => {
|
||||
}, [node]);
|
||||
|
||||
return (
|
||||
<OverviewSection title="Devices" count={node ? 1 : 0}>
|
||||
<OverviewSection title="Devices" count={node ? 1 + (devices.data?.length ?? 0) : 0}>
|
||||
<View>
|
||||
<Fade height={'100%'} width={30} color="black">
|
||||
<ScrollView
|
||||
@@ -93,6 +111,18 @@ const Devices = ({ node, stats }: Props) => {
|
||||
connectionType={null}
|
||||
/>
|
||||
)}
|
||||
{devices.data?.map((device) => (
|
||||
<StatCard
|
||||
key={device.pub_id}
|
||||
name={device.name}
|
||||
// TODO (Optional): Use Brand Type for Different Android Models/iOS Models using DeviceInfo.getBrand()
|
||||
icon={hardwareModelToIcon(device.hardware_model)}
|
||||
totalSpace={'0'}
|
||||
freeSpace={'0'}
|
||||
color="#0362FF"
|
||||
connectionType={'cloud'}
|
||||
/>
|
||||
))}
|
||||
<NewCard
|
||||
icons={['Laptop', 'Server', 'SilverBox', 'Tablet']}
|
||||
text="Spacedrive works best on all your devices."
|
||||
|
||||
@@ -4,6 +4,8 @@ import { CompositeScreenProps } from '@react-navigation/native';
|
||||
import { createNativeStackNavigator, NativeStackScreenProps } from '@react-navigation/native-stack';
|
||||
import Header from '~/components/header/Header';
|
||||
import SearchHeader from '~/components/header/SearchHeader';
|
||||
import AccountLogin from '~/screens/settings/client/AccountSettings/AccountLogin';
|
||||
import AccountProfile from '~/screens/settings/client/AccountSettings/AccountProfile';
|
||||
import AppearanceSettingsScreen from '~/screens/settings/client/AppearanceSettings';
|
||||
import ExtensionsSettingsScreen from '~/screens/settings/client/ExtensionsSettings';
|
||||
import GeneralSettingsScreen from '~/screens/settings/client/GeneralSettings';
|
||||
@@ -12,12 +14,10 @@ import PrivacySettingsScreen from '~/screens/settings/client/PrivacySettings';
|
||||
import AboutScreen from '~/screens/settings/info/About';
|
||||
import DebugScreen from '~/screens/settings/info/Debug';
|
||||
import SupportScreen from '~/screens/settings/info/Support';
|
||||
import CloudSettings from '~/screens/settings/library/CloudSettings/CloudSettings';
|
||||
import EditLocationSettingsScreen from '~/screens/settings/library/EditLocationSettings';
|
||||
import LibraryGeneralSettingsScreen from '~/screens/settings/library/LibraryGeneralSettings';
|
||||
import LocationSettingsScreen from '~/screens/settings/library/LocationSettings';
|
||||
import NodesSettingsScreen from '~/screens/settings/library/NodesSettings';
|
||||
import SyncSettingsScreen from '~/screens/settings/library/SyncSettings';
|
||||
import TagsSettingsScreen from '~/screens/settings/library/TagsSettings';
|
||||
import SettingsScreen from '~/screens/settings/Settings';
|
||||
|
||||
@@ -46,6 +46,16 @@ export default function SettingsStack() {
|
||||
component={GeneralSettingsScreen}
|
||||
options={{ header: () => <Header navBack title="General" /> }}
|
||||
/>
|
||||
<Stack.Screen
|
||||
name="AccountLogin"
|
||||
component={AccountLogin}
|
||||
options={{ header: () => <Header navBackTo="Settings" navBack title="Account" /> }}
|
||||
/>
|
||||
<Stack.Screen
|
||||
name="AccountProfile"
|
||||
component={AccountProfile}
|
||||
options={{ header: () => <Header navBackTo="Settings" navBack title="Account" /> }}
|
||||
/>
|
||||
<Stack.Screen
|
||||
name="LibrarySettings"
|
||||
component={LibrarySettingsScreen}
|
||||
@@ -94,16 +104,6 @@ export default function SettingsStack() {
|
||||
component={TagsSettingsScreen}
|
||||
options={{ header: () => <Header navBack title="Tags" /> }}
|
||||
/>
|
||||
<Stack.Screen
|
||||
name="SyncSettings"
|
||||
component={SyncSettingsScreen}
|
||||
options={{ header: () => <Header navBack title="Sync" /> }}
|
||||
/>
|
||||
<Stack.Screen
|
||||
name="CloudSettings"
|
||||
component={CloudSettings}
|
||||
options={{ header: () => <Header navBack title="Cloud" /> }}
|
||||
/>
|
||||
{/* <Stack.Screen
|
||||
name="KeysSettings"
|
||||
component={KeysSettingsScreen}
|
||||
@@ -134,6 +134,8 @@ export type SettingsStackParamList = {
|
||||
Settings: undefined;
|
||||
// Client
|
||||
GeneralSettings: undefined;
|
||||
AccountLogin: undefined;
|
||||
AccountProfile: undefined;
|
||||
LibrarySettings: undefined;
|
||||
AppearanceSettings: undefined;
|
||||
PrivacySettings: undefined;
|
||||
|
||||
@@ -12,9 +12,9 @@ import {
|
||||
PuzzlePiece,
|
||||
ShareNetwork,
|
||||
ShieldCheck,
|
||||
TagSimple
|
||||
TagSimple,
|
||||
UserCircle
|
||||
} from 'phosphor-react-native';
|
||||
import React from 'react';
|
||||
import { Platform, SectionList, Text, TouchableWithoutFeedback, View } from 'react-native';
|
||||
import { DebugState, useDebugState, useDebugStateEnabler, useLibraryQuery } from '@sd/client';
|
||||
import ScreenContainer from '~/components/layout/ScreenContainer';
|
||||
@@ -22,6 +22,7 @@ import { SettingsItem } from '~/components/settings/SettingsItem';
|
||||
import { useEnableDrawer } from '~/hooks/useEnableDrawer';
|
||||
import { tw, twStyle } from '~/lib/tailwind';
|
||||
import { SettingsStackParamList, SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
|
||||
import { useUserStore } from '~/stores/userStore';
|
||||
|
||||
type SectionType = {
|
||||
title: string;
|
||||
@@ -34,7 +35,10 @@ type SectionType = {
|
||||
}[];
|
||||
};
|
||||
|
||||
const sections: (debugState: DebugState) => SectionType[] = (debugState) => [
|
||||
const sections: (
|
||||
debugState: DebugState,
|
||||
userInfo: ReturnType<typeof useUserStore>['userInfo']
|
||||
) => SectionType[] = (debugState, userInfo) => [
|
||||
{
|
||||
title: 'Client',
|
||||
data: [
|
||||
@@ -44,6 +48,21 @@ const sections: (debugState: DebugState) => SectionType[] = (debugState) => [
|
||||
title: 'General',
|
||||
rounded: 'top'
|
||||
},
|
||||
...(userInfo
|
||||
? ([
|
||||
{
|
||||
icon: UserCircle,
|
||||
navigateTo: 'AccountProfile',
|
||||
title: 'Account'
|
||||
}
|
||||
] as const)
|
||||
: ([
|
||||
{
|
||||
icon: UserCircle,
|
||||
navigateTo: 'AccountLogin',
|
||||
title: 'Account'
|
||||
}
|
||||
] as const)),
|
||||
{
|
||||
icon: Books,
|
||||
navigateTo: 'LibrarySettings',
|
||||
@@ -158,12 +177,16 @@ function renderSectionHeader({ section }: { section: { title: string } }) {
|
||||
export default function SettingsScreen({ navigation }: SettingsStackScreenProps<'Settings'>) {
|
||||
const debugState = useDebugState();
|
||||
const syncEnabled = useLibraryQuery(['sync.enabled']);
|
||||
const userInfo = useUserStore().userInfo;
|
||||
|
||||
// Enables the drawer from react-navigation
|
||||
useEnableDrawer();
|
||||
|
||||
return (
|
||||
<ScreenContainer tabHeight={false} style={tw`gap-0 px-5 py-0`}>
|
||||
<SectionList
|
||||
contentContainerStyle={tw`py-6`}
|
||||
sections={sections(debugState)}
|
||||
sections={sections(debugState, userInfo)}
|
||||
renderItem={({ item }) => (
|
||||
<SettingsItem
|
||||
syncEnabled={syncEnabled.data}
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
import { MotiView } from 'moti';
|
||||
import { AppleLogo, GithubLogo, GoogleLogo, IconProps } from 'phosphor-react-native';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Text, View } from 'react-native';
|
||||
import { LinearTransition } from 'react-native-reanimated';
|
||||
import Card from '~/components/layout/Card';
|
||||
import ScreenContainer from '~/components/layout/ScreenContainer';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import { tw, twStyle } from '~/lib/tailwind';
|
||||
import { getUserStore, useUserStore } from '~/stores/userStore';
|
||||
import { AUTH_SERVER_URL } from '~/utils';
|
||||
|
||||
import Login from './Login';
|
||||
import Register from './Register';
|
||||
|
||||
const AccountTabs = ['Login', 'Register'] as const;
|
||||
|
||||
type SocialLogin = {
|
||||
name: 'Github' | 'Google' | 'Apple';
|
||||
icon: React.FC<IconProps>;
|
||||
};
|
||||
|
||||
const SocialLogins: SocialLogin[] = [
|
||||
{ name: 'Github', icon: GithubLogo },
|
||||
{ name: 'Google', icon: GoogleLogo },
|
||||
{ name: 'Apple', icon: AppleLogo }
|
||||
];
|
||||
|
||||
const AccountLogin = () => {
|
||||
const [activeTab, setActiveTab] = useState<'Login' | 'Register'>('Login');
|
||||
const userInfo = useUserStore().userInfo;
|
||||
|
||||
useEffect(() => {
|
||||
if (userInfo) return; //no need to check if user info is already present
|
||||
async function _() {
|
||||
const user_data = await fetch(`${AUTH_SERVER_URL}/api/user`, {
|
||||
method: 'GET'
|
||||
});
|
||||
const data = await user_data.json();
|
||||
if (data.message !== 'unauthorised') {
|
||||
getUserStore().userInfo = data;
|
||||
}
|
||||
}
|
||||
_();
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
// FIXME: Currently opens in App.
|
||||
// const socialLoginHandlers = (name: SocialLogin['name']) => {
|
||||
// return {
|
||||
// Github: async () => {
|
||||
// try {
|
||||
// const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({
|
||||
// thirdPartyId: 'github',
|
||||
|
||||
// // This is where Github should redirect the user back after login or error.
|
||||
// frontendRedirectURI: 'http://localhost:9420/api/auth/callback/github'
|
||||
// });
|
||||
|
||||
// // we redirect the user to Github for auth.
|
||||
// window.location.assign(authUrl);
|
||||
// } catch (err: any) {
|
||||
// if (err.isSuperTokensGeneralError === true) {
|
||||
// // this may be a custom error message sent from the API by you.
|
||||
// toast.error(err.message);
|
||||
// } else {
|
||||
// toast.error('Oops! Something went wrong.');
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// Google: async () => {
|
||||
// try {
|
||||
// const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({
|
||||
// thirdPartyId: 'google',
|
||||
|
||||
// // This is where Google should redirect the user back after login or error.
|
||||
// // This URL goes on the Google's dashboard as well.
|
||||
// frontendRedirectURI: 'http://localhost:9420/api/auth/callback/google'
|
||||
// });
|
||||
|
||||
// /*
|
||||
// Example value of authUrl: https://accounts.google.com/o/oauth2/v2/auth/oauthchooseaccount?scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&access_type=offline&include_granted_scopes=true&response_type=code&client_id=1060725074195-kmeum4crr01uirfl2op9kd5acmi9jutn.apps.googleusercontent.com&state=5a489996a28cafc83ddff&redirect_uri=https%3A%2F%2Fsupertokens.io%2Fdev%2Foauth%2Fredirect-to-app&flowName=GeneralOAuthFlow
|
||||
// */
|
||||
|
||||
// // we redirect the user to google for auth.
|
||||
// window.location.assign(authUrl);
|
||||
// } catch (err: any) {
|
||||
// if (err.isSuperTokensGeneralError === true) {
|
||||
// // this may be a custom error message sent from the API by you.
|
||||
// toast.error(err.message);
|
||||
// } else {
|
||||
// toast.error('Oops! Something went wrong.');
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// Apple: async () => {
|
||||
// try {
|
||||
// const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({
|
||||
// thirdPartyId: 'apple',
|
||||
|
||||
// // This is where Apple should redirect the user back after login or error.
|
||||
// frontendRedirectURI: 'http://localhost:9420/api/auth/callback/apple'
|
||||
// });
|
||||
|
||||
// // we redirect the user to Apple for auth.
|
||||
// window.location.assign(authUrl);
|
||||
// } catch (err: any) {
|
||||
// if (err.isSuperTokensGeneralError === true) {
|
||||
// // this may be a custom error message sent from the API by you.
|
||||
// toast.error(err.message);
|
||||
// } else {
|
||||
// toast.error('Oops! Something went wrong.');
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }[name]();
|
||||
// };
|
||||
|
||||
return (
|
||||
<ScreenContainer scrollview={false} style={tw`gap-2 px-6`}>
|
||||
<View style={tw`flex flex-col justify-between gap-5 lg:flex-row`}>
|
||||
<Card style={tw`relative flex w-full flex-col items-center justify-center`}>
|
||||
<View style={tw`flex w-full flex-row gap-x-1.5`}>
|
||||
{AccountTabs.map((text) => (
|
||||
<Button
|
||||
key={text}
|
||||
onPress={() => {
|
||||
setActiveTab(text);
|
||||
}}
|
||||
style={twStyle(
|
||||
'relative flex-1 border-b border-app-line/50 p-2 text-center',
|
||||
text === 'Login' ? 'rounded-tl-md' : 'rounded-tr-md'
|
||||
)}
|
||||
>
|
||||
<Text
|
||||
style={twStyle(
|
||||
'relative z-10 text-sm',
|
||||
text === activeTab ? 'font-bold text-ink' : 'text-ink-faint'
|
||||
)}
|
||||
>
|
||||
{text}
|
||||
</Text>
|
||||
{text === activeTab && (
|
||||
<MotiView
|
||||
animate={{
|
||||
borderRadius: text === 'Login' ? 0.3 : 0
|
||||
}}
|
||||
layout={LinearTransition.duration(200)}
|
||||
style={tw`absolute inset-x-0 top-0 z-0 bg-app-line/60`}
|
||||
/>
|
||||
)}
|
||||
</Button>
|
||||
))}
|
||||
</View>
|
||||
<View style={tw`mt-3 flex w-full flex-col justify-center gap-1.5`}>
|
||||
{activeTab === 'Login' ? <Login /> : <Register />}
|
||||
{/* Disabled for now */}
|
||||
{/* <View style={tw`flex items-center w-full gap-3 my-2`}>
|
||||
<Divider />
|
||||
<Text style={tw`text-xs text-ink-faint`}>OR</Text>
|
||||
<Divider />
|
||||
</View>
|
||||
<View style={tw`flex justify-center gap-3`}>
|
||||
{SocialLogins.map((social) => (
|
||||
<Button
|
||||
variant="outline"
|
||||
onPress={async () => await socialLoginHandlers(social.name)}
|
||||
key={social.name}
|
||||
style={tw`p-3 border rounded-full border-app-line bg-app-input`}
|
||||
>
|
||||
<social.icon style={tw`text-white`} weight="bold" />
|
||||
</Button>
|
||||
))}
|
||||
</View> */}
|
||||
</View>
|
||||
</Card>
|
||||
</View>
|
||||
</ScreenContainer>
|
||||
);
|
||||
};
|
||||
export default AccountLogin;
|
||||
@@ -0,0 +1,179 @@
|
||||
import { useNavigation } from '@react-navigation/native';
|
||||
import { Envelope } from 'phosphor-react-native';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Text, View } from 'react-native';
|
||||
import {
|
||||
SyncStatus,
|
||||
useBridgeMutation,
|
||||
useBridgeQuery,
|
||||
useLibraryMutation,
|
||||
useLibrarySubscription
|
||||
} from '@sd/client';
|
||||
import Card from '~/components/layout/Card';
|
||||
import ScreenContainer from '~/components/layout/ScreenContainer';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import { tw, twStyle } from '~/lib/tailwind';
|
||||
import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
|
||||
import { getUserStore, useUserStore } from '~/stores/userStore';
|
||||
import { AUTH_SERVER_URL, getTokens } from '~/utils';
|
||||
|
||||
const AccountProfile = () => {
|
||||
const userInfo = useUserStore().userInfo;
|
||||
|
||||
const emailName = userInfo ? userInfo.email.split('@')[0] : '';
|
||||
const capitalizedEmailName = (emailName?.charAt(0).toUpperCase() ?? '') + emailName?.slice(1);
|
||||
const navigator = useNavigation<SettingsStackScreenProps<'AccountLogin'>['navigation']>();
|
||||
|
||||
const cloudBootstrap = useBridgeMutation('cloud.bootstrap');
|
||||
const devices = useBridgeQuery(['cloud.devices.list']);
|
||||
const addLibraryToCloud = useLibraryMutation('cloud.libraries.create');
|
||||
const listLibraries = useBridgeQuery(['cloud.libraries.list', true]);
|
||||
const createSyncGroup = useLibraryMutation('cloud.syncGroups.create');
|
||||
const listSyncGroups = useBridgeQuery(['cloud.syncGroups.list']);
|
||||
const requestJoinSyncGroup = useBridgeMutation('cloud.syncGroups.request_join');
|
||||
const currentDevice = useBridgeQuery(['cloud.devices.get_current_device']);
|
||||
const [{ accessToken, refreshToken }, setTokens] = useState<{
|
||||
accessToken: string;
|
||||
refreshToken: string;
|
||||
}>({
|
||||
accessToken: '',
|
||||
refreshToken: ''
|
||||
});
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
const { accessToken, refreshToken } = await getTokens();
|
||||
setTokens({ accessToken, refreshToken });
|
||||
})();
|
||||
}, []);
|
||||
const [syncStatus, setSyncStatus] = useState<SyncStatus | null>(null);
|
||||
useLibrarySubscription(['sync.active'], {
|
||||
onData: (data) => {
|
||||
console.log('sync activity', data);
|
||||
setSyncStatus(data);
|
||||
}
|
||||
});
|
||||
|
||||
async function signOut() {
|
||||
await fetch(`${AUTH_SERVER_URL}/api/auth/signout`, {
|
||||
method: 'POST'
|
||||
});
|
||||
navigator.navigate('AccountLogin');
|
||||
getUserStore().userInfo = undefined;
|
||||
}
|
||||
|
||||
return (
|
||||
<ScreenContainer scrollview={false} style={tw`gap-2 px-6`}>
|
||||
<View style={tw`flex flex-col justify-between gap-5 lg:flex-row`}>
|
||||
<Card
|
||||
style={tw`relative flex w-full flex-col items-center justify-center lg:max-w-[320px]`}
|
||||
>
|
||||
<View style={tw`w-full`}>
|
||||
<Text style={tw`mx-auto mt-3 text-lg text-white`}>
|
||||
Welcome{' '}
|
||||
<Text style={tw`font-bold text-white`}>{capitalizedEmailName}</Text>
|
||||
</Text>
|
||||
<Card
|
||||
style={tw`mt-4 flex-row items-center gap-2 overflow-hidden border-app-inputborder bg-app-input`}
|
||||
>
|
||||
<Envelope weight="fill" size={20} color="white" />
|
||||
<Text numberOfLines={1} style={tw`max-w-[90%] text-white`}>
|
||||
{userInfo ? userInfo.email : ''}
|
||||
</Text>
|
||||
</Card>
|
||||
|
||||
<Button variant="danger" style={tw`mt-3`} onPress={signOut}>
|
||||
<Text style={tw`font-bold text-white`}>Sign out</Text>
|
||||
</Button>
|
||||
</View>
|
||||
</Card>
|
||||
{/* Sync activity */}
|
||||
<View style={tw`mt-5 flex flex-col`}>
|
||||
<Text style={tw`mb-2 text-md font-semibold`}>Sync Activity</Text>
|
||||
<View style={tw`flex flex-row gap-2`}>
|
||||
{Object.keys(syncStatus ?? {}).map((status, index) => (
|
||||
<Card key={index} style="flex w-full items-center p-4">
|
||||
<View
|
||||
style={twStyle(
|
||||
'mr-2 size-[15px] rounded-full bg-app-box',
|
||||
syncStatus?.[status as keyof SyncStatus]
|
||||
? 'bg-accent'
|
||||
: 'bg-app-input'
|
||||
)}
|
||||
/>
|
||||
<Text style={tw`text-sm font-semibold`}>{status}</Text>
|
||||
</Card>
|
||||
))}
|
||||
</View>
|
||||
</View>
|
||||
|
||||
{/* Automatically list libraries */}
|
||||
<View style={tw`mt-5 flex flex-col gap-3`}>
|
||||
<Text style={tw`text-md font-semibold text-white`}>Cloud Libraries</Text>
|
||||
{listLibraries.data?.map((library) => (
|
||||
<Card key={library.pub_id} style={tw`p-41 w-full`}>
|
||||
<Text style={tw`text-sm font-semibold text-white`}>{library.name}</Text>
|
||||
</Card>
|
||||
)) || <Text style={tw`text-white`}>No libraries found.</Text>}
|
||||
</View>
|
||||
|
||||
{/* Debug buttons */}
|
||||
<Card style={tw`flex gap-2 text-white`}>
|
||||
<Button
|
||||
variant="gray"
|
||||
onPress={async () => {
|
||||
cloudBootstrap.mutate([accessToken.trim(), refreshToken.trim()]);
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-white`}>Start Cloud Bootstrap</Text>
|
||||
</Button>
|
||||
<Button
|
||||
variant="gray"
|
||||
onPress={async () => {
|
||||
addLibraryToCloud.mutate(null);
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-white`}>Add Library to Cloud</Text>
|
||||
</Button>
|
||||
<Button
|
||||
variant="gray"
|
||||
onPress={async () => {
|
||||
createSyncGroup.mutate(null);
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-white`}>Create Sync Group</Text>
|
||||
</Button>
|
||||
</Card>
|
||||
|
||||
<View style={tw`mt-5 flex flex-col gap-3 text-white`}>
|
||||
<Text style={tw`text-md font-semibold`}>Library Sync Groups</Text>
|
||||
{listSyncGroups.data?.map((group) => (
|
||||
<Card key={group.pub_id} style="w-full p-4">
|
||||
<Text style={tw`text-sm font-semibold text-white`}>
|
||||
{group.library.name}
|
||||
</Text>
|
||||
<Button
|
||||
style={tw`mt-2`}
|
||||
onPress={async () => {
|
||||
if (!currentDevice.data) await currentDevice.refetch();
|
||||
if (currentDevice.data && devices.data) {
|
||||
requestJoinSyncGroup.mutate({
|
||||
asking_device: currentDevice.data,
|
||||
sync_group: {
|
||||
devices: devices.data,
|
||||
...group
|
||||
}
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-white`}>Join Sync Group</Text>
|
||||
</Button>
|
||||
</Card>
|
||||
)) || <Text style={tw`text-white`}>No sync groups found.</Text>}
|
||||
</View>
|
||||
</View>
|
||||
</ScreenContainer>
|
||||
);
|
||||
};
|
||||
|
||||
export default AccountProfile;
|
||||
@@ -0,0 +1,195 @@
|
||||
import AsyncStorage from '@react-native-async-storage/async-storage';
|
||||
import { useNavigation } from '@react-navigation/native';
|
||||
import { RSPCError } from '@spacedrive/rspc-client';
|
||||
import { UseMutationResult } from '@tanstack/react-query';
|
||||
import { useState } from 'react';
|
||||
import { Controller } from 'react-hook-form';
|
||||
import { Text, View } from 'react-native';
|
||||
import { z } from 'zod';
|
||||
import { useBridgeMutation, useZodForm } from '@sd/client';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import { Input } from '~/components/primitive/Input';
|
||||
import { toast } from '~/components/primitive/Toast';
|
||||
import { tw, twStyle } from '~/lib/tailwind';
|
||||
import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
|
||||
import { getUserStore } from '~/stores/userStore';
|
||||
import { AUTH_SERVER_URL } from '~/utils';
|
||||
|
||||
import ShowPassword from './ShowPassword';
|
||||
|
||||
const LoginSchema = z.object({
|
||||
email: z.string().email({
|
||||
message: 'Email is required'
|
||||
}),
|
||||
password: z.string().min(6, {
|
||||
message: 'Password must be at least 6 characters'
|
||||
})
|
||||
});
|
||||
|
||||
const Login = () => {
|
||||
const [showPassword, setShowPassword] = useState(false);
|
||||
const form = useZodForm({
|
||||
schema: LoginSchema,
|
||||
defaultValues: {
|
||||
email: '',
|
||||
password: ''
|
||||
}
|
||||
});
|
||||
const updateUserStore = getUserStore();
|
||||
const navigator = useNavigation<SettingsStackScreenProps<'AccountProfile'>['navigation']>();
|
||||
const cloudBootstrap = useBridgeMutation('cloud.bootstrap');
|
||||
|
||||
return (
|
||||
<View>
|
||||
<View style={tw`flex flex-col gap-1.5`}>
|
||||
<Controller
|
||||
control={form.control}
|
||||
name="email"
|
||||
render={({ field }) => (
|
||||
<View style={tw`relative flex items-start`}>
|
||||
<Input
|
||||
{...field}
|
||||
placeholder="Email"
|
||||
style={twStyle(
|
||||
`w-full`,
|
||||
form.formState.errors.email && 'border-red-500'
|
||||
)}
|
||||
onChangeText={field.onChange}
|
||||
/>
|
||||
{form.formState.errors.email && (
|
||||
<Text style={tw`my-1 text-xs text-red-500`}>
|
||||
{form.formState.errors.email.message}
|
||||
</Text>
|
||||
)}
|
||||
</View>
|
||||
)}
|
||||
/>
|
||||
<Controller
|
||||
control={form.control}
|
||||
name="password"
|
||||
render={({ field }) => (
|
||||
<View style={tw`relative flex items-start`}>
|
||||
<Input
|
||||
{...field}
|
||||
placeholder="Password"
|
||||
style={twStyle(
|
||||
`w-full`,
|
||||
form.formState.errors.password && 'border-red-500'
|
||||
)}
|
||||
onChangeText={field.onChange}
|
||||
secureTextEntry={!showPassword}
|
||||
/>
|
||||
{form.formState.errors.password && (
|
||||
<Text style={tw`my-1 text-xs text-red-500`}>
|
||||
{form.formState.errors.password.message}
|
||||
</Text>
|
||||
)}
|
||||
<ShowPassword
|
||||
showPassword={showPassword}
|
||||
setShowPassword={setShowPassword}
|
||||
/>
|
||||
</View>
|
||||
)}
|
||||
/>
|
||||
<Button
|
||||
style={tw`mx-auto mt-2 w-full`}
|
||||
variant="accent"
|
||||
onPress={form.handleSubmit(async (data) => {
|
||||
await signInClicked(
|
||||
data.email,
|
||||
data.password,
|
||||
navigator,
|
||||
cloudBootstrap,
|
||||
updateUserStore
|
||||
);
|
||||
})}
|
||||
disabled={form.formState.isSubmitting}
|
||||
>
|
||||
<Text style={tw`font-bold text-white`}>Submit</Text>
|
||||
</Button>
|
||||
</View>
|
||||
</View>
|
||||
);
|
||||
};
|
||||
|
||||
async function signInClicked(
|
||||
email: string,
|
||||
password: string,
|
||||
navigator: SettingsStackScreenProps<'AccountProfile'>['navigation'],
|
||||
cloudBootstrap: UseMutationResult<null, RSPCError, [string, string], unknown>, // Cloud bootstrap mutation
|
||||
updateUserStore: ReturnType<typeof getUserStore>
|
||||
) {
|
||||
try {
|
||||
const req = await fetch(`${AUTH_SERVER_URL}/api/auth/signin`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json; charset=utf-8'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
formFields: [
|
||||
{
|
||||
id: 'email',
|
||||
value: email
|
||||
},
|
||||
{
|
||||
id: 'password',
|
||||
value: password
|
||||
}
|
||||
]
|
||||
})
|
||||
});
|
||||
|
||||
const response: {
|
||||
status: string;
|
||||
reason?: string;
|
||||
user?: {
|
||||
id: string;
|
||||
email: string;
|
||||
timeJoined: number;
|
||||
tenantIds: string[];
|
||||
};
|
||||
} = await req.json();
|
||||
|
||||
if (response.status === 'FIELD_ERROR') {
|
||||
// response.reason?.forEach((formField) => {
|
||||
// if (formField.id === 'email') {
|
||||
// // Email validation failed (for example incorrect email syntax).
|
||||
// toast.error(formField.error);
|
||||
// }
|
||||
// });
|
||||
console.error('Field error: ', response.reason);
|
||||
} else if (response.status === 'WRONG_CREDENTIALS_ERROR') {
|
||||
toast.error('Email & password combination is incorrect.');
|
||||
} else if (response.status === 'SIGN_IN_NOT_ALLOWED') {
|
||||
// the reason string is a user friendly message
|
||||
// about what went wrong. It can also contain a support code which users
|
||||
// can tell you so you know why their sign in was not allowed.
|
||||
toast.error(response.reason!);
|
||||
} else {
|
||||
// sign in successful. The session tokens are automatically handled by
|
||||
// the frontend SDK.
|
||||
cloudBootstrap.mutate([
|
||||
req.headers.get('st-access-token')!,
|
||||
req.headers.get('st-refresh-token')!
|
||||
]);
|
||||
toast.success('Sign in successful');
|
||||
// Update the user store with the user info
|
||||
updateUserStore.userInfo = response.user;
|
||||
// Save the access token to AsyncStorage, because SuperTokens doesn't store it correctly. Thanks to the React Native SDK.
|
||||
await AsyncStorage.setItem('access_token', req.headers.get('st-access-token')!);
|
||||
await AsyncStorage.setItem('refresh_token', req.headers.get('st-refresh-token')!);
|
||||
// Refresh the page to show the user is logged in
|
||||
navigator.navigate('AccountProfile');
|
||||
}
|
||||
} catch (err: any) {
|
||||
if (err.isSuperTokensGeneralError === true) {
|
||||
// this may be a custom error message sent from the API by you.
|
||||
toast.error(err.message);
|
||||
} else {
|
||||
console.error(err);
|
||||
toast.error('Oops! Something went wrong.');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default Login;
|
||||
@@ -0,0 +1,191 @@
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import { useNavigation } from '@react-navigation/native';
|
||||
import { useState } from 'react';
|
||||
import { Controller, useForm } from 'react-hook-form';
|
||||
import { Text, View } from 'react-native';
|
||||
import { z } from 'zod';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import { Input } from '~/components/primitive/Input';
|
||||
import { toast } from '~/components/primitive/Toast';
|
||||
import { tw, twStyle } from '~/lib/tailwind';
|
||||
import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
|
||||
import { AUTH_SERVER_URL } from '~/utils';
|
||||
|
||||
import ShowPassword from './ShowPassword';
|
||||
|
||||
const RegisterSchema = z
|
||||
.object({
|
||||
email: z.string().email({
|
||||
message: 'Email is required'
|
||||
}),
|
||||
password: z.string().min(6, {
|
||||
message: 'Password must be at least 6 characters'
|
||||
}),
|
||||
confirmPassword: z.string().min(6, {
|
||||
message: 'Password must be at least 6 characters'
|
||||
})
|
||||
})
|
||||
.refine((data) => data.password === data.confirmPassword, {
|
||||
message: 'Passwords do not match',
|
||||
path: ['confirmPassword']
|
||||
});
|
||||
type RegisterType = z.infer<typeof RegisterSchema>;
|
||||
|
||||
const Register = () => {
|
||||
const [showPassword, setShowPassword] = useState(false);
|
||||
// useZodForm seems to be out-dated or needs
|
||||
//fixing as it does not support the schema using zod.refine
|
||||
const form = useForm<RegisterType>({
|
||||
resolver: zodResolver(RegisterSchema),
|
||||
defaultValues: {
|
||||
email: '',
|
||||
password: '',
|
||||
confirmPassword: ''
|
||||
}
|
||||
});
|
||||
|
||||
const navigator = useNavigation<SettingsStackScreenProps<'AccountProfile'>['navigation']>();
|
||||
return (
|
||||
<View style={tw`flex flex-col gap-1.5`}>
|
||||
<Controller
|
||||
control={form.control}
|
||||
name="email"
|
||||
render={({ field }) => (
|
||||
<Input
|
||||
{...field}
|
||||
style={twStyle(`w-full`, form.formState.errors.email && 'border-red-500')}
|
||||
placeholder="Email"
|
||||
onChangeText={field.onChange}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
{form.formState.errors.email && (
|
||||
<Text style={tw`text-xs text-red-500`}>{form.formState.errors.email.message}</Text>
|
||||
)}
|
||||
<Controller
|
||||
control={form.control}
|
||||
name="password"
|
||||
render={({ field }) => (
|
||||
<View style={tw`relative flex items-center justify-center`}>
|
||||
<Input
|
||||
{...field}
|
||||
placeholder="Password"
|
||||
style={twStyle(
|
||||
`w-full`,
|
||||
form.formState.errors.password && 'border-red-500'
|
||||
)}
|
||||
onChangeText={field.onChange}
|
||||
secureTextEntry={!showPassword}
|
||||
/>
|
||||
</View>
|
||||
)}
|
||||
/>
|
||||
{form.formState.errors.password && (
|
||||
<Text style={tw`text-xs text-red-500`}>
|
||||
{form.formState.errors.password.message}
|
||||
</Text>
|
||||
)}
|
||||
<Controller
|
||||
control={form.control}
|
||||
name="confirmPassword"
|
||||
render={({ field }) => (
|
||||
<View style={tw`relative flex items-start`}>
|
||||
<Input
|
||||
{...field}
|
||||
placeholder="Confirm Password"
|
||||
style={twStyle(
|
||||
`w-full`,
|
||||
form.formState.errors.confirmPassword && 'border-red-500'
|
||||
)}
|
||||
onChangeText={field.onChange}
|
||||
secureTextEntry={!showPassword}
|
||||
/>
|
||||
{form.formState.errors.confirmPassword && (
|
||||
<Text style={tw`my-1 text-xs text-red-500`}>
|
||||
{form.formState.errors.confirmPassword.message}
|
||||
</Text>
|
||||
)}
|
||||
<ShowPassword
|
||||
showPassword={showPassword}
|
||||
setShowPassword={setShowPassword}
|
||||
plural={true}
|
||||
/>
|
||||
</View>
|
||||
)}
|
||||
/>
|
||||
<Button
|
||||
style={tw`mx-auto mt-2 w-full`}
|
||||
variant="accent"
|
||||
onPress={form.handleSubmit(
|
||||
async (data) => await signUpClicked(data.email, data.password, navigator)
|
||||
)}
|
||||
disabled={form.formState.isSubmitting}
|
||||
>
|
||||
<Text style={tw`font-bold text-white`}>Submit</Text>
|
||||
</Button>
|
||||
</View>
|
||||
);
|
||||
};
|
||||
|
||||
async function signUpClicked(
|
||||
email: string,
|
||||
password: string,
|
||||
navigator: SettingsStackScreenProps<'AccountProfile'>['navigation']
|
||||
) {
|
||||
try {
|
||||
const req = await fetch(`${AUTH_SERVER_URL}/api/auth/signup`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json; charset=utf-8'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
formFields: [
|
||||
{
|
||||
id: 'email',
|
||||
value: email
|
||||
},
|
||||
{
|
||||
id: 'password',
|
||||
value: password
|
||||
}
|
||||
]
|
||||
})
|
||||
});
|
||||
|
||||
const response: {
|
||||
status: string;
|
||||
reason?: string;
|
||||
user?: {
|
||||
id: string;
|
||||
email: string;
|
||||
timeJoined: number;
|
||||
tenantIds: string[];
|
||||
};
|
||||
} = await req.json();
|
||||
|
||||
if (response.status === 'FIELD_ERROR') {
|
||||
// one of the input formFields failed validaiton
|
||||
console.error('Field error: ', response.reason);
|
||||
} else if (response.status === 'SIGN_UP_NOT_ALLOWED') {
|
||||
// the reason string is a user friendly message
|
||||
// about what went wrong. It can also contain a support code which users
|
||||
// can tell you so you know why their sign up was not allowed.
|
||||
toast.error(response.reason!);
|
||||
} else {
|
||||
// sign up successful. The session tokens are automatically handled by
|
||||
// the frontend SDK.
|
||||
toast.success('Sign up successful');
|
||||
navigator.navigate('AccountProfile');
|
||||
}
|
||||
} catch (err: any) {
|
||||
if (err.isSuperTokensGeneralError === true) {
|
||||
// this may be a custom error message sent from the API by you.
|
||||
toast.error(err.message);
|
||||
} else {
|
||||
console.error(err);
|
||||
toast.error('Oops! Something went wrong.');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default Register;
|
||||
@@ -0,0 +1,29 @@
|
||||
import { Eye, EyeClosed } from 'phosphor-react-native';
|
||||
import { Text } from 'react-native';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import { tw } from '~/lib/tailwind';
|
||||
|
||||
interface Props {
|
||||
showPassword: boolean;
|
||||
setShowPassword: (value: boolean) => void;
|
||||
plural?: boolean;
|
||||
}
|
||||
|
||||
const ShowPassword = ({ showPassword, setShowPassword, plural }: Props) => {
|
||||
return (
|
||||
<Button
|
||||
variant="gray"
|
||||
style={tw`mt-1.5 flex w-full flex-row items-center justify-center gap-2`}
|
||||
onPressIn={() => setShowPassword(!showPassword)}
|
||||
>
|
||||
{!showPassword ? (
|
||||
<EyeClosed size={14} color="white" />
|
||||
) : (
|
||||
<Eye size={14} color="white" />
|
||||
)}
|
||||
<Text style={tw`font-bold text-ink`}>Show Password{plural ? 's' : ''}</Text>
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
export default ShowPassword;
|
||||
@@ -1,26 +1,52 @@
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import React from 'react';
|
||||
import { Text, View } from 'react-native';
|
||||
import {
|
||||
auth,
|
||||
toggleFeatureFlag,
|
||||
useBridgeMutation,
|
||||
useBridgeQuery,
|
||||
useDebugState,
|
||||
useFeatureFlags
|
||||
useFeatureFlags,
|
||||
useLibraryMutation
|
||||
} from '@sd/client';
|
||||
import Card from '~/components/layout/Card';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import { tw } from '~/lib/tailwind';
|
||||
import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
|
||||
import { getTokens } from '~/utils';
|
||||
|
||||
const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => {
|
||||
const debugState = useDebugState();
|
||||
const featureFlags = useFeatureFlags();
|
||||
const origin = useBridgeQuery(['cloud.getApiOrigin']);
|
||||
const setOrigin = useBridgeMutation(['cloud.setApiOrigin']);
|
||||
const [tokens, setTokens] = React.useState({ accessToken: '', refreshToken: '' });
|
||||
const accessToken = tokens.accessToken;
|
||||
const refreshToken = tokens.refreshToken;
|
||||
// const origin = useBridgeQuery(['cloud.getApiOrigin']);
|
||||
// const setOrigin = useBridgeMutation(['cloud.setApiOrigin']);
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
React.useEffect(() => {
|
||||
async function _() {
|
||||
const _a = await getTokens();
|
||||
setTokens({ accessToken: _a.accessToken, refreshToken: _a.refreshToken });
|
||||
}
|
||||
_();
|
||||
}, []);
|
||||
|
||||
const cloudBootstrap = useBridgeMutation(['cloud.bootstrap']);
|
||||
const addLibraryToCloud = useLibraryMutation('cloud.libraries.create');
|
||||
const requestJoinSyncGroup = useBridgeMutation('cloud.syncGroups.request_join');
|
||||
const getGroup = useBridgeQuery([
|
||||
'cloud.syncGroups.get',
|
||||
{
|
||||
pub_id: '01924497-a1be-76e3-b62f-9582ea15463a',
|
||||
// pub_id: '01924a25-966b-7c00-a582-9eed3aadd2cd',
|
||||
kind: 'WithDevices'
|
||||
}
|
||||
]);
|
||||
// console.log(getGroup.data);
|
||||
const currentDevice = useBridgeQuery(['cloud.devices.get_current_device']);
|
||||
// console.log('Current Device: ', currentDevice.data);
|
||||
const createSyncGroup = useLibraryMutation('cloud.syncGroups.create');
|
||||
|
||||
// const queryClient = useQueryClient();
|
||||
|
||||
return (
|
||||
<View style={tw`flex-1 p-4`}>
|
||||
@@ -31,7 +57,7 @@ const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => {
|
||||
</Button>
|
||||
<Text style={tw`text-ink`}>{JSON.stringify(featureFlags)}</Text>
|
||||
<Text style={tw`text-ink`}>{JSON.stringify(debugState)}</Text>
|
||||
<Button
|
||||
{/* <Button
|
||||
onPress={() => {
|
||||
navigation.popToTop();
|
||||
navigation.replace('Settings');
|
||||
@@ -39,13 +65,13 @@ const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => {
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-ink`}>Disable Debug Mode</Text>
|
||||
</Button>
|
||||
<Button
|
||||
</Button> */}
|
||||
{/* <Button
|
||||
onPress={() => {
|
||||
const url =
|
||||
origin.data === 'https://app.spacedrive.com'
|
||||
origin.data === 'https://api.spacedrive.com'
|
||||
? 'http://localhost:3000'
|
||||
: 'https://app.spacedrive.com';
|
||||
: 'https://api.spacedrive.com';
|
||||
setOrigin.mutateAsync(url).then(async () => {
|
||||
await auth.logout();
|
||||
await queryClient.invalidateQueries();
|
||||
@@ -53,7 +79,7 @@ const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => {
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-ink`}>Toggle API Route ({origin.data})</Text>
|
||||
</Button>
|
||||
</Button> */}
|
||||
<Button
|
||||
onPress={() => {
|
||||
navigation.popToTop();
|
||||
@@ -64,12 +90,53 @@ const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => {
|
||||
>
|
||||
<Text style={tw`text-ink`}>Go to Backfill Waiting Page</Text>
|
||||
</Button>
|
||||
<Button
|
||||
{/* <Button
|
||||
onPress={async () => {
|
||||
await auth.logout();
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-ink`}>Logout</Text>
|
||||
</Button> */}
|
||||
<Button
|
||||
onPress={async () => {
|
||||
const tokens = await getTokens();
|
||||
cloudBootstrap.mutate([tokens.accessToken, tokens.refreshToken]);
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-ink`}>Cloud Bootstrap</Text>
|
||||
</Button>
|
||||
<Button
|
||||
onPress={async () => {
|
||||
addLibraryToCloud.mutate(null);
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-ink`}>Add Library to Cloud</Text>
|
||||
</Button>
|
||||
<Button
|
||||
onPress={async () => {
|
||||
createSyncGroup.mutate(null);
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-ink`}>Create Sync Group</Text>
|
||||
</Button>
|
||||
<Button
|
||||
onPress={async () => {
|
||||
if (
|
||||
currentDevice.data &&
|
||||
getGroup.data &&
|
||||
getGroup.data.kind === 'WithDevices'
|
||||
) {
|
||||
currentDevice.refetch();
|
||||
console.log('Current Device: ', currentDevice.data);
|
||||
console.log('Get Group: ', getGroup.data.data);
|
||||
requestJoinSyncGroup.mutate({
|
||||
sync_group: getGroup.data.data,
|
||||
asking_device: currentDevice.data
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Text style={tw`text-ink`}>Request Join Sync Group</Text>
|
||||
</Button>
|
||||
</Card>
|
||||
</View>
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
import { useMemo } from 'react';
|
||||
import { ActivityIndicator, FlatList, Text, View } from 'react-native';
|
||||
import { useLibraryContext, useLibraryMutation, useLibraryQuery } from '@sd/client';
|
||||
import { Icon } from '~/components/icons/Icon';
|
||||
import Card from '~/components/layout/Card';
|
||||
import Empty from '~/components/layout/Empty';
|
||||
import ScreenContainer from '~/components/layout/ScreenContainer';
|
||||
import VirtualizedListWrapper from '~/components/layout/VirtualizedListWrapper';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import { Divider } from '~/components/primitive/Divider';
|
||||
import { styled, tw, twStyle } from '~/lib/tailwind';
|
||||
import { useAuthStateSnapshot } from '~/stores/auth';
|
||||
|
||||
import Instance from './Instance';
|
||||
import Library from './Library';
|
||||
import Login from './Login';
|
||||
import ThisInstance from './ThisInstance';
|
||||
|
||||
export const InfoBox = styled(View, 'rounded-md border gap-1 border-app bg-transparent p-2');
|
||||
|
||||
const CloudSettings = () => {
|
||||
return (
|
||||
<ScreenContainer scrollview={false} style={tw`gap-0 px-6 py-0`}>
|
||||
<AuthSensitiveChild />
|
||||
</ScreenContainer>
|
||||
);
|
||||
};
|
||||
|
||||
const AuthSensitiveChild = () => {
|
||||
const authState = useAuthStateSnapshot();
|
||||
if (authState.status === 'loggedIn') return <Authenticated />;
|
||||
if (authState.status === 'notLoggedIn' || authState.status === 'loggingIn') return <Login />;
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
const Authenticated = () => {
|
||||
const { library } = useLibraryContext();
|
||||
const cloudLibrary = useLibraryQuery(['cloud.library.get'], { retry: false });
|
||||
const createLibrary = useLibraryMutation(['cloud.library.create']);
|
||||
|
||||
const cloudInstances = useMemo(
|
||||
() =>
|
||||
cloudLibrary.data?.instances.filter(
|
||||
(instance) => instance.uuid !== library.instance_id
|
||||
),
|
||||
[cloudLibrary.data, library.instance_id]
|
||||
);
|
||||
|
||||
if (cloudLibrary.isLoading) {
|
||||
return (
|
||||
<View style={tw`flex-1 items-center justify-center`}>
|
||||
<ActivityIndicator size="small" />
|
||||
</View>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<ScreenContainer
|
||||
scrollview={Boolean(cloudLibrary.data)}
|
||||
style={tw`gap-0`}
|
||||
tabHeight={false}
|
||||
>
|
||||
{cloudLibrary.data ? (
|
||||
<View style={tw`flex-col items-start gap-5`}>
|
||||
<Library cloudLibrary={cloudLibrary.data} />
|
||||
<ThisInstance cloudLibrary={cloudLibrary.data} />
|
||||
<Card style={tw`w-full`}>
|
||||
<View style={tw`flex-row items-center gap-2`}>
|
||||
<View
|
||||
style={tw`self-start rounded border border-app-lightborder bg-app-highlight px-1.5 py-[2px]`}
|
||||
>
|
||||
<Text style={tw`text-xs font-semibold text-ink`}>
|
||||
{cloudInstances?.length}
|
||||
</Text>
|
||||
</View>
|
||||
<Text style={tw`font-semibold text-ink`}>Instances</Text>
|
||||
</View>
|
||||
<Divider style={tw`mb-4 mt-2`} />
|
||||
<VirtualizedListWrapper
|
||||
scrollEnabled={false}
|
||||
contentContainerStyle={tw`flex-1`}
|
||||
horizontal
|
||||
>
|
||||
<FlatList
|
||||
data={cloudInstances}
|
||||
scrollEnabled={false}
|
||||
ListEmptyComponent={
|
||||
<Empty textStyle={tw`my-0`} description="No instances found" />
|
||||
}
|
||||
contentContainerStyle={twStyle(
|
||||
cloudInstances?.length === 0 && 'flex-row'
|
||||
)}
|
||||
showsHorizontalScrollIndicator={false}
|
||||
ItemSeparatorComponent={() => <View style={tw`h-2`} />}
|
||||
renderItem={({ item }) => <Instance data={item} />}
|
||||
keyExtractor={(item) => item.id}
|
||||
numColumns={1}
|
||||
/>
|
||||
</VirtualizedListWrapper>
|
||||
</Card>
|
||||
</View>
|
||||
) : (
|
||||
<View style={tw`flex-1 justify-center`}>
|
||||
<Card style={tw`relative p-6`}>
|
||||
<Icon style={tw`mx-auto mb-2`} name="CloudSync" size={64} />
|
||||
<Text style={tw`mx-auto text-center text-sm text-ink`}>
|
||||
Uploading your library to the cloud will allow you to access your
|
||||
library from other devices using your account & importing.
|
||||
</Text>
|
||||
<Button
|
||||
variant={'accent'}
|
||||
style={tw`mx-auto mt-4 max-w-[82%]`}
|
||||
disabled={createLibrary.isPending}
|
||||
onPress={async () => await createLibrary.mutateAsync(null)}
|
||||
>
|
||||
{createLibrary.isPending ? (
|
||||
<Text style={tw`text-ink`}>Connecting library...</Text>
|
||||
) : (
|
||||
<Text style={tw`font-medium text-ink`}>Connect library</Text>
|
||||
)}
|
||||
</Button>
|
||||
</Card>
|
||||
</View>
|
||||
)}
|
||||
</ScreenContainer>
|
||||
);
|
||||
};
|
||||
|
||||
export default CloudSettings;
|
||||
@@ -1,64 +0,0 @@
|
||||
import { Text, View } from 'react-native';
|
||||
import { CloudInstance, HardwareModel } from '@sd/client';
|
||||
import { Icon } from '~/components/icons/Icon';
|
||||
import { hardwareModelToIcon } from '~/components/overview/Devices';
|
||||
import { tw } from '~/lib/tailwind';
|
||||
|
||||
import { InfoBox } from './CloudSettings';
|
||||
|
||||
interface Props {
|
||||
data: CloudInstance;
|
||||
}
|
||||
|
||||
const Instance = ({ data }: Props) => {
|
||||
return (
|
||||
<InfoBox style={tw`w-full gap-2`}>
|
||||
<View>
|
||||
<View style={tw`mx-auto my-2`}>
|
||||
<Icon
|
||||
name={
|
||||
hardwareModelToIcon(data.metadata.device_model as HardwareModel) as any
|
||||
}
|
||||
size={60}
|
||||
/>
|
||||
</View>
|
||||
<Text
|
||||
numberOfLines={1}
|
||||
style={tw`mb-3 px-1 text-center text-sm font-medium text-ink`}
|
||||
>
|
||||
{data.metadata.name}
|
||||
</Text>
|
||||
<InfoBox>
|
||||
<View style={tw`flex-row items-center gap-1`}>
|
||||
<Text style={tw`text-sm font-medium text-ink`}>Id:</Text>
|
||||
<Text numberOfLines={1} style={tw`max-w-[250px] text-ink-dull`}>
|
||||
{data.id}
|
||||
</Text>
|
||||
</View>
|
||||
</InfoBox>
|
||||
</View>
|
||||
<View>
|
||||
<InfoBox>
|
||||
<View style={tw`flex-row items-center gap-1`}>
|
||||
<Text style={tw`text-sm font-medium text-ink`}>UUID:</Text>
|
||||
<Text numberOfLines={1} style={tw`max-w-[85%] text-ink-dull`}>
|
||||
{data.uuid}
|
||||
</Text>
|
||||
</View>
|
||||
</InfoBox>
|
||||
</View>
|
||||
<View>
|
||||
<InfoBox>
|
||||
<View style={tw`flex-row items-center gap-1`}>
|
||||
<Text style={tw`text-sm font-medium text-ink`}>Public key:</Text>
|
||||
<Text numberOfLines={1} style={tw`max-w-3/4 text-ink-dull`}>
|
||||
{data.identity}
|
||||
</Text>
|
||||
</View>
|
||||
</InfoBox>
|
||||
</View>
|
||||
</InfoBox>
|
||||
);
|
||||
};
|
||||
|
||||
export default Instance;
|
||||
@@ -1,66 +0,0 @@
|
||||
import { CheckCircle, XCircle } from 'phosphor-react-native';
|
||||
import { useMemo } from 'react';
|
||||
import { Text, View } from 'react-native';
|
||||
import { CloudLibrary, useLibraryContext, useLibraryMutation } from '@sd/client';
|
||||
import Card from '~/components/layout/Card';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import { Divider } from '~/components/primitive/Divider';
|
||||
import { SettingsTitle } from '~/components/settings/SettingsContainer';
|
||||
import { tw } from '~/lib/tailwind';
|
||||
import { logout, useAuthStateSnapshot } from '~/stores/auth';
|
||||
|
||||
import { InfoBox } from './CloudSettings';
|
||||
|
||||
interface LibraryProps {
|
||||
cloudLibrary?: CloudLibrary;
|
||||
}
|
||||
|
||||
const Library = ({ cloudLibrary }: LibraryProps) => {
|
||||
const authState = useAuthStateSnapshot();
|
||||
const { library } = useLibraryContext();
|
||||
const syncLibrary = useLibraryMutation(['cloud.library.sync']);
|
||||
const thisInstance = useMemo(
|
||||
() => cloudLibrary?.instances.find((instance) => instance.uuid === library.instance_id),
|
||||
[cloudLibrary, library.instance_id]
|
||||
);
|
||||
|
||||
return (
|
||||
<Card style={tw`w-full`}>
|
||||
<View style={tw`flex-row items-center justify-between`}>
|
||||
<Text style={tw`font-medium text-ink`}>Library</Text>
|
||||
{authState.status === 'loggedIn' && (
|
||||
<Button variant="gray" size="sm" onPress={logout}>
|
||||
<Text style={tw`text-xs font-semibold text-ink`}>Logout</Text>
|
||||
</Button>
|
||||
)}
|
||||
</View>
|
||||
<Divider style={tw`mb-4 mt-2`} />
|
||||
<SettingsTitle style={tw`mb-2`}>Name</SettingsTitle>
|
||||
<InfoBox>
|
||||
<Text style={tw`text-ink`}>{cloudLibrary?.name}</Text>
|
||||
</InfoBox>
|
||||
<Button
|
||||
disabled={syncLibrary.isPending || thisInstance !== undefined}
|
||||
variant="gray"
|
||||
onPress={() => syncLibrary.mutate(null)}
|
||||
style={tw`mt-2 flex-row gap-1 py-2`}
|
||||
>
|
||||
{thisInstance ? (
|
||||
<CheckCircle size={16} weight="fill" color={tw.color('green-400')} />
|
||||
) : (
|
||||
<XCircle
|
||||
style={tw`rounded-full`}
|
||||
size={16}
|
||||
weight="fill"
|
||||
color={tw.color('red-500')}
|
||||
/>
|
||||
)}
|
||||
<Text style={tw`text-sm font-semibold text-ink`}>
|
||||
{thisInstance !== undefined ? 'Library synced' : 'Library not synced'}
|
||||
</Text>
|
||||
</Button>
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
|
||||
export default Library;
|
||||
@@ -1,45 +0,0 @@
|
||||
import { Text, View } from 'react-native';
|
||||
import { Icon } from '~/components/icons/Icon';
|
||||
import Card from '~/components/layout/Card';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import { tw } from '~/lib/tailwind';
|
||||
import { cancel, login, useAuthStateSnapshot } from '~/stores/auth';
|
||||
|
||||
const Login = () => {
|
||||
const authState = useAuthStateSnapshot();
|
||||
const buttonText = {
|
||||
notLoggedIn: 'Login',
|
||||
loggingIn: 'Cancel'
|
||||
};
|
||||
return (
|
||||
<View style={tw`flex-1 flex-col items-center justify-center gap-2`}>
|
||||
<Card style={tw`w-full items-center justify-center gap-2 p-6`}>
|
||||
<View style={tw`flex-col items-center gap-2`}>
|
||||
<Icon name="CloudSync" size={64} />
|
||||
<Text style={tw`text-center text-sm text-ink`}>
|
||||
Cloud Sync will upload your library to the cloud so you can access your
|
||||
library from other devices by importing it from the cloud.
|
||||
</Text>
|
||||
</View>
|
||||
{(authState.status === 'notLoggedIn' || authState.status === 'loggingIn') && (
|
||||
<Button
|
||||
variant="accent"
|
||||
style={tw`mx-auto mt-4 max-w-[50%]`}
|
||||
onPress={async (e) => {
|
||||
e.preventDefault();
|
||||
if (authState.status === 'loggingIn') {
|
||||
await cancel();
|
||||
} else {
|
||||
await login();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Text style={tw`font-medium text-ink`}>{buttonText[authState.status]}</Text>
|
||||
</Button>
|
||||
)}
|
||||
</Card>
|
||||
</View>
|
||||
);
|
||||
};
|
||||
|
||||
export default Login;
|
||||
@@ -1,76 +0,0 @@
|
||||
import { useMemo } from 'react';
|
||||
import { Text, View } from 'react-native';
|
||||
import { CloudLibrary, HardwareModel, useLibraryContext } from '@sd/client';
|
||||
import { Icon } from '~/components/icons/Icon';
|
||||
import Card from '~/components/layout/Card';
|
||||
import { hardwareModelToIcon } from '~/components/overview/Devices';
|
||||
import { Divider } from '~/components/primitive/Divider';
|
||||
import { tw } from '~/lib/tailwind';
|
||||
|
||||
import { InfoBox } from './CloudSettings';
|
||||
|
||||
interface ThisInstanceProps {
|
||||
cloudLibrary?: CloudLibrary;
|
||||
}
|
||||
|
||||
const ThisInstance = ({ cloudLibrary }: ThisInstanceProps) => {
|
||||
const { library } = useLibraryContext();
|
||||
const thisInstance = useMemo(
|
||||
() => cloudLibrary?.instances.find((instance) => instance.uuid === library.instance_id),
|
||||
[cloudLibrary, library.instance_id]
|
||||
);
|
||||
|
||||
if (!thisInstance) return null;
|
||||
|
||||
return (
|
||||
<Card style={tw`w-full gap-2`}>
|
||||
<View>
|
||||
<Text style={tw`mb-1 font-semibold text-ink`}>This Instance</Text>
|
||||
<Divider />
|
||||
</View>
|
||||
<View style={tw`mx-auto my-2 items-center`}>
|
||||
<Icon
|
||||
name={
|
||||
hardwareModelToIcon(
|
||||
thisInstance.metadata.device_model as HardwareModel
|
||||
) as any
|
||||
}
|
||||
size={60}
|
||||
/>
|
||||
<Text numberOfLines={1} style={tw`px-1 font-semibold text-ink`}>
|
||||
{thisInstance.metadata.name}
|
||||
</Text>
|
||||
</View>
|
||||
<View>
|
||||
<InfoBox>
|
||||
<View style={tw`flex-row items-center gap-1`}>
|
||||
<Text style={tw`text-sm font-medium text-ink`}>Id:</Text>
|
||||
<Text style={tw`max-w-[250px] text-ink-dull`}>{thisInstance.id}</Text>
|
||||
</View>
|
||||
</InfoBox>
|
||||
</View>
|
||||
<View>
|
||||
<InfoBox>
|
||||
<View style={tw`flex-row items-center gap-1`}>
|
||||
<Text style={tw`text-sm font-medium text-ink`}>UUID:</Text>
|
||||
<Text numberOfLines={1} style={tw`max-w-[85%] text-ink-dull`}>
|
||||
{thisInstance.uuid}
|
||||
</Text>
|
||||
</View>
|
||||
</InfoBox>
|
||||
</View>
|
||||
<View>
|
||||
<InfoBox>
|
||||
<View style={tw`flex-row items-center gap-1`}>
|
||||
<Text style={tw`text-sm font-medium text-ink`}>Publc Key:</Text>
|
||||
<Text numberOfLines={1} style={tw`max-w-3/4 text-ink-dull`}>
|
||||
{thisInstance.identity}
|
||||
</Text>
|
||||
</View>
|
||||
</InfoBox>
|
||||
</View>
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
|
||||
export default ThisInstance;
|
||||
@@ -1,158 +0,0 @@
|
||||
import { useIsFocused } from '@react-navigation/native';
|
||||
import { inferSubscriptionResult } from '@spacedrive/rspc-client';
|
||||
import { MotiView } from 'moti';
|
||||
import { Circle } from 'phosphor-react-native';
|
||||
import React, { useEffect, useRef, useState } from 'react';
|
||||
import { Text, View } from 'react-native';
|
||||
import {
|
||||
Procedures,
|
||||
useLibraryMutation,
|
||||
useLibraryQuery,
|
||||
useLibrarySubscription
|
||||
} from '@sd/client';
|
||||
import { Icon } from '~/components/icons/Icon';
|
||||
import Card from '~/components/layout/Card';
|
||||
import { ModalRef } from '~/components/layout/Modal';
|
||||
import ScreenContainer from '~/components/layout/ScreenContainer';
|
||||
import CloudModal from '~/components/modal/cloud/CloudModal';
|
||||
import { Button } from '~/components/primitive/Button';
|
||||
import { tw } from '~/lib/tailwind';
|
||||
import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
|
||||
|
||||
const SyncSettingsScreen = ({ navigation }: SettingsStackScreenProps<'SyncSettings'>) => {
|
||||
const syncEnabled = useLibraryQuery(['sync.enabled']);
|
||||
const [data, setData] = useState<inferSubscriptionResult<Procedures, 'library.actors'>>({});
|
||||
const modalRef = useRef<ModalRef>(null);
|
||||
|
||||
const [startBackfill, setStart] = useState(false);
|
||||
const pageFocused = useIsFocused();
|
||||
const [showCloudModal, setShowCloudModal] = useState(false);
|
||||
|
||||
useLibrarySubscription(['library.actors'], { onData: setData });
|
||||
|
||||
useEffect(() => {
|
||||
if (startBackfill === true) {
|
||||
navigation.navigate('BackfillWaitingStack', {
|
||||
screen: 'BackfillWaiting'
|
||||
});
|
||||
setTimeout(() => setShowCloudModal(true), 1000);
|
||||
}
|
||||
}, [startBackfill, navigation]);
|
||||
|
||||
useEffect(() => {
|
||||
if (pageFocused && showCloudModal) modalRef.current?.present();
|
||||
return () => {
|
||||
if (showCloudModal) setShowCloudModal(false);
|
||||
};
|
||||
}, [pageFocused, showCloudModal]);
|
||||
|
||||
return (
|
||||
<ScreenContainer scrollview={false} style={tw`gap-0 px-6`}>
|
||||
{syncEnabled.data === false ? (
|
||||
<View style={tw`flex-1 justify-center`}>
|
||||
<Card style={tw`relative flex-col items-center gap-5 p-6`}>
|
||||
<View style={tw`flex-col items-center gap-2`}>
|
||||
<Icon name="Sync" size={72} style={tw`mb-2`} />
|
||||
<Text style={tw`text-center leading-5 text-ink`}>
|
||||
With Sync, you can share your library with other devices using P2P
|
||||
technology.
|
||||
</Text>
|
||||
<Text style={tw`text-center leading-5 text-ink`}>
|
||||
Additionally, allowing you to enable Cloud services to upload your
|
||||
library to the cloud, making it accessible on any of your devices.
|
||||
</Text>
|
||||
</View>
|
||||
<Button
|
||||
variant={'accent'}
|
||||
style={tw`mx-auto max-w-[82%]`}
|
||||
onPress={() => setStart(true)}
|
||||
>
|
||||
<Text style={tw`font-medium text-white`}>Start</Text>
|
||||
</Button>
|
||||
</Card>
|
||||
</View>
|
||||
) : (
|
||||
<View style={tw`flex-row flex-wrap gap-2`}>
|
||||
{Object.keys(data).map((key) => {
|
||||
return (
|
||||
<Card style={tw`w-[48%]`} key={key}>
|
||||
<OnlineIndicator online={data[key] ?? false} />
|
||||
<Text
|
||||
key={key}
|
||||
style={tw`mb-3 mt-1 flex-col items-center justify-center text-left text-xs text-white`}
|
||||
>
|
||||
{key}
|
||||
</Text>
|
||||
{data[key] ? <StopButton name={key} /> : <StartButton name={key} />}
|
||||
</Card>
|
||||
);
|
||||
})}
|
||||
</View>
|
||||
)}
|
||||
<CloudModal ref={modalRef} />
|
||||
</ScreenContainer>
|
||||
);
|
||||
};
|
||||
|
||||
export default SyncSettingsScreen;
|
||||
|
||||
function OnlineIndicator({ online }: { online: boolean }) {
|
||||
const size = 6;
|
||||
return (
|
||||
<View
|
||||
style={tw`mb-1 h-6 w-6 items-center justify-center rounded-full border border-app-inputborder bg-app-input p-2`}
|
||||
>
|
||||
{online ? (
|
||||
<View style={tw`relative items-center justify-center`}>
|
||||
<MotiView
|
||||
from={{ scale: 0, opacity: 1 }}
|
||||
animate={{ scale: 3, opacity: 0 }}
|
||||
transition={{
|
||||
type: 'timing',
|
||||
duration: 1500,
|
||||
loop: true,
|
||||
repeatReverse: false,
|
||||
delay: 1000
|
||||
}}
|
||||
style={tw`absolute z-10 h-2 w-2 items-center justify-center rounded-full bg-green-500`}
|
||||
/>
|
||||
<View style={tw`h-2 w-2 rounded-full bg-green-500`} />
|
||||
</View>
|
||||
) : (
|
||||
<Circle size={size} color={tw.color('red-400')} weight="fill" />
|
||||
)}
|
||||
</View>
|
||||
);
|
||||
}
|
||||
|
||||
function StartButton({ name }: { name: string }) {
|
||||
const startActor = useLibraryMutation(['library.startActor']);
|
||||
return (
|
||||
<Button
|
||||
variant="accent"
|
||||
size="sm"
|
||||
disabled={startActor.isPending}
|
||||
onPress={() => startActor.mutate(name)}
|
||||
>
|
||||
<Text style={tw`text-xs font-medium text-ink`}>
|
||||
{startActor.isPending ? 'Starting' : 'Start'}
|
||||
</Text>
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
||||
function StopButton({ name }: { name: string }) {
|
||||
const stopActor = useLibraryMutation(['library.stopActor']);
|
||||
return (
|
||||
<Button
|
||||
variant="accent"
|
||||
size="sm"
|
||||
disabled={stopActor.isPending}
|
||||
onPress={() => stopActor.mutate(name)}
|
||||
>
|
||||
<Text style={tw`text-xs font-medium text-ink`}>
|
||||
{stopActor.isPending ? 'Stopping' : 'Stop'}
|
||||
</Text>
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
@@ -18,16 +18,16 @@ export function useAuthStateSnapshot() {
|
||||
return useSolidStore(store).state;
|
||||
}
|
||||
|
||||
nonLibraryClient
|
||||
.query(['auth.me'])
|
||||
.then(() => (store.state = { status: 'loggedIn' }))
|
||||
.catch((e) => {
|
||||
if (e instanceof RSPCError && e.code === 401) {
|
||||
// TODO: handle error?
|
||||
console.error('error', e);
|
||||
}
|
||||
store.state = { status: 'notLoggedIn' };
|
||||
});
|
||||
// nonLibraryClient
|
||||
// .query(['auth.me'])
|
||||
// .then(() => (store.state = { status: 'loggedIn' }))
|
||||
// .catch((e) => {
|
||||
// if (e instanceof RSPCError && e.code === 401) {
|
||||
// // TODO: handle error?
|
||||
// console.error('error', e);
|
||||
// }
|
||||
// store.state = { status: 'notLoggedIn' };
|
||||
// });
|
||||
|
||||
type CallbackStatus = 'success' | { error: string } | 'cancel';
|
||||
const loginCallbacks = new Set<(status: CallbackStatus) => void>();
|
||||
@@ -41,29 +41,29 @@ export function login() {
|
||||
|
||||
store.state = { status: 'loggingIn' };
|
||||
|
||||
let authCleanup = nonLibraryClient.addSubscription(['auth.loginSession'], {
|
||||
onData(data) {
|
||||
if (data === 'Complete') {
|
||||
loginCallbacks.forEach((cb) => cb('success'));
|
||||
} else if ('Error' in data) {
|
||||
console.error('[auth] error: ', data.Error);
|
||||
onError(data.Error);
|
||||
} else {
|
||||
console.log('[auth] verification url: ', data.Start.verification_url_complete);
|
||||
Promise.resolve()
|
||||
.then(() => Linking.openURL(data.Start.verification_url_complete))
|
||||
.then(
|
||||
(res) => {
|
||||
authCleanup = res;
|
||||
},
|
||||
(e) => onError(e.message)
|
||||
);
|
||||
}
|
||||
},
|
||||
onError(e) {
|
||||
onError(e.message);
|
||||
}
|
||||
});
|
||||
// let authCleanup = nonLibraryClient.addSubscription(['auth.loginSession'], {
|
||||
// onData(data) {
|
||||
// if (data === 'Complete') {
|
||||
// loginCallbacks.forEach((cb) => cb('success'));
|
||||
// } else if ('Error' in data) {
|
||||
// console.error('[auth] error: ', data.Error);
|
||||
// onError(data.Error);
|
||||
// } else {
|
||||
// console.log('[auth] verification url: ', data.Start.verification_url_complete);
|
||||
// Promise.resolve()
|
||||
// .then(() => Linking.openURL(data.Start.verification_url_complete))
|
||||
// .then(
|
||||
// (res) => {
|
||||
// authCleanup = res;
|
||||
// },
|
||||
// (e) => onError(e.message)
|
||||
// );
|
||||
// }
|
||||
// },
|
||||
// onError(e) {
|
||||
// onError(e.message);
|
||||
// }
|
||||
// });
|
||||
|
||||
return new Promise<void>((res, rej) => {
|
||||
const cb = async (status: CallbackStatus) => {
|
||||
@@ -71,7 +71,7 @@ export function login() {
|
||||
|
||||
if (status === 'success') {
|
||||
store.state = { status: 'loggedIn' };
|
||||
nonLibraryClient.query(['auth.me']);
|
||||
// nonLibraryClient.query(['auth.me']);
|
||||
res();
|
||||
} else {
|
||||
store.state = { status: 'notLoggedIn' };
|
||||
@@ -88,8 +88,8 @@ export function set_logged_in() {
|
||||
|
||||
export function logout() {
|
||||
store.state = { status: 'loggingOut' };
|
||||
nonLibraryClient.mutation(['auth.logout']);
|
||||
nonLibraryClient.query(['auth.me']);
|
||||
// nonLibraryClient.mutation(['auth.logout']);
|
||||
// nonLibraryClient.query(['auth.me']);
|
||||
store.state = { status: 'notLoggedIn' };
|
||||
}
|
||||
|
||||
|
||||
22
apps/mobile/src/stores/userStore.ts
Normal file
22
apps/mobile/src/stores/userStore.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { proxy, useSnapshot } from 'valtio';
|
||||
|
||||
export type User = {
|
||||
id: string;
|
||||
email: string;
|
||||
timeJoined: number;
|
||||
tenantIds: string[];
|
||||
};
|
||||
|
||||
const state = {
|
||||
userInfo: undefined as User | undefined
|
||||
};
|
||||
|
||||
const store = proxy({
|
||||
...state
|
||||
});
|
||||
|
||||
// for reading
|
||||
export const useUserStore = () => useSnapshot(store);
|
||||
|
||||
// for writing
|
||||
export const getUserStore = () => store;
|
||||
13
apps/mobile/src/utils/index.ts
Normal file
13
apps/mobile/src/utils/index.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
import AsyncStorage from '@react-native-async-storage/async-storage';
|
||||
|
||||
export async function getTokens() {
|
||||
const fetchedToken = await AsyncStorage.getItem('access_token');
|
||||
const fetchedRefreshToken = await AsyncStorage.getItem('refresh_token');
|
||||
return {
|
||||
accessToken: fetchedToken ?? '',
|
||||
refreshToken: fetchedRefreshToken ?? ''
|
||||
};
|
||||
}
|
||||
|
||||
// export const AUTH_SERVER_URL = __DEV__ ? 'http://localhost:9420' : 'https://auth.spacedrive.com';
|
||||
export const AUTH_SERVER_URL = 'https://auth.spacedrive.com';
|
||||
@@ -144,19 +144,7 @@ async fn main() {
|
||||
|
||||
let state = AppState { auth };
|
||||
|
||||
let (node, router) = match Node::new(
|
||||
data_dir,
|
||||
sd_core::Env {
|
||||
api_url: tokio::sync::Mutex::new(
|
||||
std::env::var("SD_API_URL")
|
||||
.unwrap_or_else(|_| "https://app.spacedrive.com".to_string()),
|
||||
),
|
||||
client_id: std::env::var("SD_CLIENT_ID")
|
||||
.unwrap_or_else(|_| "04701823-a498-406e-aef9-22081c1dae34".to_string()),
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
let (node, router) = match Node::new(data_dir).await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
panic!("{}", e.to_string())
|
||||
|
||||
@@ -30,6 +30,6 @@
|
||||
"storybook": "^8.0.1",
|
||||
"tailwindcss": "^3.4.10",
|
||||
"typescript": "^5.6.2",
|
||||
"vite": "^5.2.0"
|
||||
"vite": "^5.4.9"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
"rollup-plugin-visualizer": "^5.12.0",
|
||||
"start-server-and-test": "^2.0.3",
|
||||
"typescript": "^5.6.2",
|
||||
"vite": "^5.2.0",
|
||||
"vite-tsconfig-paths": "^4.3.2"
|
||||
"vite": "^5.4.9",
|
||||
"vite-tsconfig-paths": "^5.0.1"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ heif = ["sd-images/heif"]
|
||||
|
||||
[dependencies]
|
||||
# Inner Core Sub-crates
|
||||
sd-core-cloud-services = { path = "./crates/cloud-services" }
|
||||
sd-core-file-path-helper = { path = "./crates/file-path-helper" }
|
||||
sd-core-heavy-lifting = { path = "./crates/heavy-lifting" }
|
||||
sd-core-indexer-rules = { path = "./crates/indexer-rules" }
|
||||
@@ -29,7 +30,8 @@ sd-core-sync = { path = "./crates/sync" }
|
||||
# Spacedrive Sub-crates
|
||||
sd-actors = { path = "../crates/actors" }
|
||||
sd-ai = { path = "../crates/ai", optional = true }
|
||||
sd-cloud-api = { path = "../crates/cloud-api" }
|
||||
sd-crypto = { path = "../crates/crypto" }
|
||||
sd-ffmpeg = { path = "../crates/ffmpeg", optional = true }
|
||||
sd-file-ext = { path = "../crates/file-ext" }
|
||||
sd-images = { path = "../crates/images", features = ["rspc", "serde", "specta"] }
|
||||
sd-media-metadata = { path = "../crates/media-metadata" }
|
||||
@@ -49,6 +51,7 @@ async-trait = { workspace = true }
|
||||
axum = { workspace = true, features = ["ws"] }
|
||||
base64 = { workspace = true }
|
||||
blake3 = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
futures = { workspace = true }
|
||||
futures-concurrency = { workspace = true }
|
||||
@@ -64,6 +67,7 @@ reqwest = { workspace = true, features = ["json", "native-tls-vendor
|
||||
rmp-serde = { workspace = true }
|
||||
rmpv = { workspace = true }
|
||||
rspc = { workspace = true, features = ["alpha", "axum", "chrono", "unstable", "uuid"] }
|
||||
sd-cloud-schema = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive", "rc"] }
|
||||
serde_json = { workspace = true }
|
||||
specta = { workspace = true }
|
||||
@@ -75,12 +79,11 @@ tokio-stream = { workspace = true, features = ["fs"] }
|
||||
tokio-util = { workspace = true, features = ["io"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter"] }
|
||||
uuid = { workspace = true, features = ["serde", "v4"] }
|
||||
uuid = { workspace = true, features = ["serde", "v4", "v7"] }
|
||||
|
||||
# Specific Core dependencies
|
||||
async-recursion = "1.1"
|
||||
base91 = "0.1.0"
|
||||
bytes = "1.6"
|
||||
ctor = "0.2.8"
|
||||
directories = "5.0"
|
||||
flate2 = "1.0"
|
||||
@@ -98,6 +101,7 @@ sysinfo = "0.29.11" # Update blocked
|
||||
tar = "0.4.41"
|
||||
tower-service = "0.3.2"
|
||||
tracing-appender = "0.2.3"
|
||||
whoami = "1.5.2"
|
||||
|
||||
[dependencies.tokio]
|
||||
features = ["io-util", "macros", "process", "rt-multi-thread", "sync", "time"]
|
||||
|
||||
55
core/crates/cloud-services/Cargo.toml
Normal file
55
core/crates/cloud-services/Cargo.toml
Normal file
@@ -0,0 +1,55 @@
|
||||
[package]
|
||||
name = "sd-core-cloud-services"
|
||||
version = "0.1.0"
|
||||
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
# Core Spacedrive Sub-crates
|
||||
sd-core-sync = { path = "../sync" }
|
||||
|
||||
# Spacedrive Sub-crates
|
||||
sd-actors = { path = "../../../crates/actors" }
|
||||
sd-cloud-schema = { workspace = true }
|
||||
sd-crypto = { path = "../../../crates/crypto" }
|
||||
sd-prisma = { path = "../../../crates/prisma" }
|
||||
sd-utils = { path = "../../../crates/utils" }
|
||||
|
||||
# Workspace dependencies
|
||||
async-stream = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
blake3 = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
flume = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
futures-concurrency = { workspace = true }
|
||||
rmp-serde = { workspace = true }
|
||||
rspc = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
specta = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["sync", "time"] }
|
||||
tokio-stream = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
uuid = { workspace = true, features = ["serde"] }
|
||||
zeroize = { workspace = true }
|
||||
|
||||
# External dependencies
|
||||
anyhow = "1.0.86"
|
||||
dashmap = "6.1.0"
|
||||
iroh-net = { version = "0.27", features = ["discovery-local-network", "iroh-relay"] }
|
||||
paste = "=1.0.15"
|
||||
quic-rpc = { version = "0.12.1", features = ["quinn-transport"] }
|
||||
quinn = { package = "iroh-quinn", version = "0.11" }
|
||||
# Using whatever version of reqwest that reqwest-middleware uses, just putting here to enable some features
|
||||
reqwest = { version = "0.12", features = ["json", "native-tls-vendored", "stream"] }
|
||||
reqwest-middleware = { version = "0.3", features = ["json"] }
|
||||
reqwest-retry = "0.6"
|
||||
rustls = { version = "=0.23.15", default-features = false, features = ["brotli", "ring", "std"] }
|
||||
rustls-platform-verifier = "0.3.3"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["rt", "sync", "time"] }
|
||||
358
core/crates/cloud-services/src/client.rs
Normal file
358
core/crates/cloud-services/src/client.rs
Normal file
@@ -0,0 +1,358 @@
|
||||
use crate::p2p::{NotifyUser, UserResponse};
|
||||
|
||||
use sd_cloud_schema::{Client, Service, ServicesALPN};
|
||||
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
|
||||
use futures::Stream;
|
||||
use iroh_net::relay::RelayUrl;
|
||||
use quic_rpc::{transport::quinn::QuinnConnection, RpcClient};
|
||||
use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint};
|
||||
use reqwest::{IntoUrl, Url};
|
||||
use reqwest_middleware::{reqwest, ClientBuilder, ClientWithMiddleware};
|
||||
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tracing::warn;
|
||||
|
||||
use super::{
|
||||
error::Error, key_manager::KeyManager, p2p::CloudP2P, token_refresher::TokenRefresher,
|
||||
};
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
enum ClientState {
|
||||
#[default]
|
||||
NotConnected,
|
||||
Connected(Client<QuinnConnection<Service>, Service>),
|
||||
}
|
||||
|
||||
/// Cloud services are a optional feature that allows you to interact with the cloud services
|
||||
/// of Spacedrive.
|
||||
/// They're optional in two different ways:
|
||||
/// - The cloud services depends on a user being logged in with our server.
|
||||
/// - The user being connected to the internet to begin with.
|
||||
///
|
||||
/// As we don't want to force the user to be connected to the internet, we have to make sure
|
||||
/// that core can always operate without the cloud services.
|
||||
#[derive(Debug)]
|
||||
pub struct CloudServices {
|
||||
client_state: Arc<RwLock<ClientState>>,
|
||||
get_cloud_api_address: Url,
|
||||
http_client: ClientWithMiddleware,
|
||||
domain_name: String,
|
||||
pub cloud_p2p_dns_origin_name: String,
|
||||
pub cloud_p2p_relay_url: RelayUrl,
|
||||
pub cloud_p2p_dns_pkarr_url: Url,
|
||||
pub token_refresher: TokenRefresher,
|
||||
key_manager: Arc<RwLock<Option<Arc<KeyManager>>>>,
|
||||
cloud_p2p: Arc<RwLock<Option<Arc<CloudP2P>>>>,
|
||||
pub(crate) notify_user_tx: flume::Sender<NotifyUser>,
|
||||
notify_user_rx: flume::Receiver<NotifyUser>,
|
||||
user_response_tx: flume::Sender<UserResponse>,
|
||||
pub(crate) user_response_rx: flume::Receiver<UserResponse>,
|
||||
pub has_bootstrapped: Arc<Mutex<bool>>,
|
||||
}
|
||||
|
||||
impl CloudServices {
|
||||
/// Creates a new cloud services client that can be used to interact with the cloud services.
|
||||
/// The client will try to connect to the cloud services on a best effort basis, as the user
|
||||
/// might not be connected to the internet.
|
||||
/// If the client fails to connect, it will try again the next time it's used.
|
||||
pub async fn new(
|
||||
get_cloud_api_address: impl IntoUrl + Send,
|
||||
cloud_p2p_relay_url: impl IntoUrl + Send,
|
||||
cloud_p2p_dns_pkarr_url: impl IntoUrl + Send,
|
||||
cloud_p2p_dns_origin_name: String,
|
||||
domain_name: String,
|
||||
) -> Result<Self, Error> {
|
||||
let http_client_builder = reqwest::Client::builder().timeout(Duration::from_secs(3));
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
{
|
||||
http_client_builder = http_client_builder.https_only(true);
|
||||
}
|
||||
|
||||
let cloud_p2p_relay_url = cloud_p2p_relay_url
|
||||
.into_url()
|
||||
.map_err(Error::InvalidUrl)?
|
||||
.into();
|
||||
|
||||
let cloud_p2p_dns_pkarr_url = cloud_p2p_dns_pkarr_url
|
||||
.into_url()
|
||||
.map_err(Error::InvalidUrl)?;
|
||||
|
||||
let http_client =
|
||||
ClientBuilder::new(http_client_builder.build().map_err(Error::HttpClientInit)?)
|
||||
.with(RetryTransientMiddleware::new_with_policy(
|
||||
ExponentialBackoff::builder().build_with_max_retries(3),
|
||||
))
|
||||
.build();
|
||||
let get_cloud_api_address = get_cloud_api_address
|
||||
.into_url()
|
||||
.map_err(Error::InvalidUrl)?;
|
||||
|
||||
let client_state = match Self::init_client(
|
||||
&http_client,
|
||||
get_cloud_api_address.clone(),
|
||||
domain_name.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(client) => Arc::new(RwLock::new(ClientState::Connected(client))),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
?e,
|
||||
"Failed to initialize cloud services client; \
|
||||
This is a best effort and we will continue in Not Connected mode"
|
||||
);
|
||||
Arc::new(RwLock::new(ClientState::NotConnected))
|
||||
}
|
||||
};
|
||||
|
||||
let (notify_user_tx, notify_user_rx) = flume::bounded(16);
|
||||
let (user_response_tx, user_response_rx) = flume::bounded(16);
|
||||
|
||||
Ok(Self {
|
||||
client_state,
|
||||
token_refresher: TokenRefresher::new(
|
||||
http_client.clone(),
|
||||
get_cloud_api_address.clone(),
|
||||
),
|
||||
get_cloud_api_address,
|
||||
http_client,
|
||||
cloud_p2p_dns_origin_name,
|
||||
cloud_p2p_relay_url,
|
||||
cloud_p2p_dns_pkarr_url,
|
||||
domain_name,
|
||||
key_manager: Arc::default(),
|
||||
cloud_p2p: Arc::default(),
|
||||
notify_user_tx,
|
||||
notify_user_rx,
|
||||
user_response_tx,
|
||||
user_response_rx,
|
||||
has_bootstrapped: Arc::default(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn stream_user_notifications(&self) -> impl Stream<Item = NotifyUser> + '_ {
|
||||
self.notify_user_rx.stream()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn http_client(&self) -> &ClientWithMiddleware {
|
||||
&self.http_client
|
||||
}
|
||||
|
||||
/// Send back a user response to the Cloud P2P actor
|
||||
///
|
||||
/// # Panics
|
||||
/// Will panic if the channel is closed, which should never happen
|
||||
pub async fn send_user_response(&self, response: UserResponse) {
|
||||
self.user_response_tx
|
||||
.send_async(response)
|
||||
.await
|
||||
.expect("user response channel must never close");
|
||||
}
|
||||
|
||||
async fn init_client(
|
||||
http_client: &ClientWithMiddleware,
|
||||
get_cloud_api_address: Url,
|
||||
domain_name: String,
|
||||
) -> Result<Client<QuinnConnection<Service>, Service>, Error> {
|
||||
let cloud_api_address = http_client
|
||||
.get(get_cloud_api_address)
|
||||
.send()
|
||||
.await
|
||||
.map_err(Error::FailedToRequestApiAddress)?
|
||||
.error_for_status()
|
||||
.map_err(Error::AuthServerError)?
|
||||
.text()
|
||||
.await
|
||||
.map_err(Error::FailedToExtractApiAddress)?
|
||||
.parse::<SocketAddr>()?;
|
||||
|
||||
let mut crypto_config = {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
#[derive(Debug)]
|
||||
struct SkipServerVerification;
|
||||
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
|
||||
{
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
|
||||
{
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA1,
|
||||
rustls::SignatureScheme::ECDSA_SHA1_Legacy,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
||||
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::ED448,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(SkipServerVerification))
|
||||
.with_no_client_auth()
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
{
|
||||
rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(
|
||||
rustls_platform_verifier::Verifier::new(),
|
||||
))
|
||||
.with_no_client_auth()
|
||||
}
|
||||
};
|
||||
|
||||
crypto_config
|
||||
.alpn_protocols
|
||||
.extend([ServicesALPN::LATEST.to_vec()]);
|
||||
|
||||
let client_config = ClientConfig::new(Arc::new(
|
||||
QuicClientConfig::try_from(crypto_config)
|
||||
.expect("misconfigured TLS client config, this is a bug and should crash"),
|
||||
));
|
||||
|
||||
let mut endpoint = Endpoint::client("[::]:0".parse().expect("hardcoded address"))
|
||||
.map_err(Error::FailedToCreateEndpoint)?;
|
||||
endpoint.set_default_client_config(client_config);
|
||||
|
||||
// TODO(@fogodev): It's possible that we can't keep the connection alive all the time,
|
||||
// and need to use single shot connections. I will only be sure when we have
|
||||
// actually battle-tested the cloud services in core.
|
||||
Ok(Client::new(RpcClient::new(QuinnConnection::new(
|
||||
endpoint,
|
||||
cloud_api_address,
|
||||
domain_name,
|
||||
))))
|
||||
}
|
||||
|
||||
/// Returns a client to the cloud services.
|
||||
///
|
||||
/// If the client is not connected, it will try to connect to the cloud services.
|
||||
/// Available routes documented in
|
||||
/// [`sd_cloud_schema::Service`](https://github.com/spacedriveapp/cloud-services-schema).
|
||||
pub async fn client(&self) -> Result<Client<QuinnConnection<Service>, Service>, Error> {
|
||||
if let ClientState::Connected(client) = { self.client_state.read().await.clone() } {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
// If we're not connected, we need to try to connect.
|
||||
let client = Self::init_client(
|
||||
&self.http_client,
|
||||
self.get_cloud_api_address.clone(),
|
||||
self.domain_name.clone(),
|
||||
)
|
||||
.await?;
|
||||
*self.client_state.write().await = ClientState::Connected(client.clone());
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
pub async fn set_key_manager(&self, key_manager: KeyManager) {
|
||||
self.key_manager
|
||||
.write()
|
||||
.await
|
||||
.replace(Arc::new(key_manager));
|
||||
}
|
||||
|
||||
pub async fn key_manager(&self) -> Result<Arc<KeyManager>, Error> {
|
||||
self.key_manager
|
||||
.read()
|
||||
.await
|
||||
.as_ref()
|
||||
.map_or(Err(Error::KeyManagerNotInitialized), |key_manager| {
|
||||
Ok(Arc::clone(key_manager))
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn set_cloud_p2p(&self, cloud_p2p: CloudP2P) {
|
||||
self.cloud_p2p.write().await.replace(Arc::new(cloud_p2p));
|
||||
}
|
||||
|
||||
pub async fn cloud_p2p(&self) -> Result<Arc<CloudP2P>, Error> {
|
||||
self.cloud_p2p
|
||||
.read()
|
||||
.await
|
||||
.as_ref()
|
||||
.map_or(Err(Error::CloudP2PNotInitialized), |cloud_p2p| {
|
||||
Ok(Arc::clone(cloud_p2p))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use sd_cloud_schema::{auth, devices};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[ignore]
|
||||
#[tokio::test]
|
||||
async fn test_client() {
|
||||
let response = CloudServices::new(
|
||||
"http://localhost:9420/cloud-api-address",
|
||||
"http://relay.localhost:9999/",
|
||||
"http://pkarr.localhost:9999/",
|
||||
"dns.localhost:9999".to_string(),
|
||||
"localhost".to_string(),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.client()
|
||||
.await
|
||||
.unwrap()
|
||||
.devices()
|
||||
.list(devices::list::Request {
|
||||
access_token: auth::AccessToken("invalid".to_string()),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(
|
||||
response,
|
||||
Err(sd_cloud_schema::Error::Client(
|
||||
sd_cloud_schema::error::ClientSideError::Unauthorized
|
||||
))
|
||||
));
|
||||
}
|
||||
}
|
||||
170
core/crates/cloud-services/src/error.rs
Normal file
170
core/crates/cloud-services/src/error.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use sd_cloud_schema::{cloud_p2p, sync::groups, Service};
|
||||
use sd_utils::error::FileIOError;
|
||||
|
||||
use std::{io, net::AddrParseError};
|
||||
|
||||
use quic_rpc::{
|
||||
pattern::{bidi_streaming, rpc, server_streaming},
|
||||
transport::quinn::QuinnConnection,
|
||||
};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
// Setup errors
|
||||
#[error("Couldn't parse Cloud Services API address URL: {0}")]
|
||||
InvalidUrl(reqwest::Error),
|
||||
#[error("Failed to parse Cloud Services API address URL")]
|
||||
FailedToParseRelayUrl,
|
||||
#[error("Failed to initialize http client: {0}")]
|
||||
HttpClientInit(reqwest::Error),
|
||||
#[error("Failed to request Cloud Services API address from Auth Server route: {0}")]
|
||||
FailedToRequestApiAddress(reqwest_middleware::Error),
|
||||
#[error("Auth Server's Cloud Services API address route returned an error: {0}")]
|
||||
AuthServerError(reqwest::Error),
|
||||
#[error(
|
||||
"Failed to extract response body from Auth Server's Cloud Services API address route: {0}"
|
||||
)]
|
||||
FailedToExtractApiAddress(reqwest::Error),
|
||||
#[error("Failed to parse auth server's Cloud Services API address: {0}")]
|
||||
FailedToParseApiAddress(#[from] AddrParseError),
|
||||
#[error("Failed to create endpoint: {0}")]
|
||||
FailedToCreateEndpoint(io::Error),
|
||||
|
||||
// Token refresher errors
|
||||
#[error("Invalid token format, missing claims")]
|
||||
MissingClaims,
|
||||
#[error("Failed to decode access token data: {0}")]
|
||||
DecodeAccessTokenData(#[from] base64::DecodeError),
|
||||
#[error("Failed to deserialize access token json data: {0}")]
|
||||
DeserializeAccessTokenData(#[from] serde_json::Error),
|
||||
#[error("Token expired")]
|
||||
TokenExpired,
|
||||
#[error("Failed to request refresh token: {0}")]
|
||||
RefreshTokenRequest(reqwest_middleware::Error),
|
||||
#[error("Missing tokens on refresh response")]
|
||||
MissingTokensOnRefreshResponse,
|
||||
#[error("Failed to parse token header value to string: {0}")]
|
||||
FailedToParseTokenHeaderValueToString(#[from] reqwest::header::ToStrError),
|
||||
|
||||
// Key Manager errors
|
||||
#[error("Failed to handle File on KeyManager: {0}")]
|
||||
FileIO(#[from] FileIOError),
|
||||
#[error("Failed to handle key store serialization: {0}")]
|
||||
KeyStoreSerialization(rmp_serde::encode::Error),
|
||||
#[error("Failed to handle key store deserialization: {0}")]
|
||||
KeyStoreDeserialization(rmp_serde::decode::Error),
|
||||
#[error("Key store encryption related error: {{context: \"{context}\", source: {source}}}")]
|
||||
KeyStoreCrypto {
|
||||
#[source]
|
||||
source: sd_crypto::Error,
|
||||
context: &'static str,
|
||||
},
|
||||
#[error("Key manager not initialized")]
|
||||
KeyManagerNotInitialized,
|
||||
|
||||
// Cloud P2P errors
|
||||
#[error("Failed to create Cloud P2P endpoint: {0}")]
|
||||
CreateCloudP2PEndpoint(anyhow::Error),
|
||||
#[error("Failed to connect to Cloud P2P node: {0}")]
|
||||
ConnectToCloudP2PNode(anyhow::Error),
|
||||
#[error("Communication error with Cloud P2P node: {0}")]
|
||||
CloudP2PRpcCommunication(#[from] rpc::Error<QuinnConnection<cloud_p2p::Service>>),
|
||||
#[error("Cloud P2P not initialized")]
|
||||
CloudP2PNotInitialized,
|
||||
#[error("Failed to initialize LocalSwarmDiscovery: {0}")]
|
||||
LocalSwarmDiscoveryInit(anyhow::Error),
|
||||
#[error("Failed to initialize DhtDiscovery: {0}")]
|
||||
DhtDiscoveryInit(anyhow::Error),
|
||||
|
||||
// Communication errors
|
||||
#[error("Failed to communicate with RPC backend: {0}")]
|
||||
RpcCommunication(#[from] rpc::Error<QuinnConnection<Service>>),
|
||||
#[error("Failed to communicate with Server Streaming RPC backend: {0}")]
|
||||
ServerStreamCommunication(#[from] server_streaming::Error<QuinnConnection<Service>>),
|
||||
#[error("Failed to receive next response from Server Streaming RPC backend: {0}")]
|
||||
ServerStreamRecv(#[from] server_streaming::ItemError<QuinnConnection<Service>>),
|
||||
#[error("Failed to communicate with Bidi Streaming RPC backend: {0}")]
|
||||
BidiStreamCommunication(#[from] bidi_streaming::Error<QuinnConnection<Service>>),
|
||||
#[error("Failed to receive next response from Bidi Streaming RPC backend: {0}")]
|
||||
BidiStreamRecv(#[from] bidi_streaming::ItemError<QuinnConnection<Service>>),
|
||||
#[error("Error from backend: {0}")]
|
||||
Backend(#[from] sd_cloud_schema::Error),
|
||||
#[error("Failed to get access token from refresher: {0}")]
|
||||
GetToken(#[from] GetTokenError),
|
||||
#[error("Unexpected empty response from backend, context: {0}")]
|
||||
EmptyResponse(&'static str),
|
||||
#[error("Unexpected response from backend, context: {0}")]
|
||||
UnexpectedResponse(&'static str),
|
||||
|
||||
// Sync error
|
||||
#[error("Sync error: {0}")]
|
||||
Sync(#[from] sd_core_sync::Error),
|
||||
#[error("Tried to sync messages with a group without having needed key")]
|
||||
MissingSyncGroupKey(groups::PubId),
|
||||
#[error("Failed to encrypt sync messages: {0}")]
|
||||
Encrypt(sd_crypto::Error),
|
||||
#[error("Failed to decrypt sync messages: {0}")]
|
||||
Decrypt(sd_crypto::Error),
|
||||
#[error("Failed to upload sync messages: {0}")]
|
||||
UploadSyncMessages(reqwest_middleware::Error),
|
||||
#[error("Failed to download sync messages: {0}")]
|
||||
DownloadSyncMessages(reqwest_middleware::Error),
|
||||
#[error("Received an error response from uploading sync messages: {0}")]
|
||||
ErrorResponseUploadSyncMessages(reqwest::Error),
|
||||
#[error("Received an error response from downloading sync messages: {0}")]
|
||||
ErrorResponseDownloadSyncMessages(reqwest::Error),
|
||||
#[error(
|
||||
"Received an error response from downloading sync messages while reading its bytes: {0}"
|
||||
)]
|
||||
ErrorResponseDownloadReadBytesSyncMessages(reqwest::Error),
|
||||
#[error("Critical error while uploading sync messages")]
|
||||
CriticalErrorWhileUploadingSyncMessages,
|
||||
#[error("Failed to send End update to push sync messages")]
|
||||
EndUpdatePushSyncMessages(io::Error),
|
||||
#[error("Unexpected end of stream while encrypting sync messages")]
|
||||
UnexpectedEndOfStream,
|
||||
#[error("Failed to create directory to store timestamp keeper files")]
|
||||
FailedToCreateTimestampKeepersDirectory(io::Error),
|
||||
#[error("Failed to read last timestamp keeper for pulling sync messages: {0}")]
|
||||
FailedToReadLastTimestampKeeper(io::Error),
|
||||
#[error("Failed to handle last timestamp keeper serialization: {0}")]
|
||||
LastTimestampKeeperSerialization(rmp_serde::encode::Error),
|
||||
#[error("Failed to handle last timestamp keeper deserialization: {0}")]
|
||||
LastTimestampKeeperDeserialization(rmp_serde::decode::Error),
|
||||
#[error("Failed to write last timestamp keeper for pulling sync messages: {0}")]
|
||||
FailedToWriteLastTimestampKeeper(io::Error),
|
||||
#[error("Sync messages download and decrypt task panicked")]
|
||||
SyncMessagesDownloadAndDecryptTaskPanicked,
|
||||
#[error("Serialization failure to push sync messages: {0}")]
|
||||
SerializationFailureToPushSyncMessages(rmp_serde::encode::Error),
|
||||
#[error("Deserialization failure to pull sync messages: {0}")]
|
||||
DeserializationFailureToPullSyncMessages(rmp_serde::decode::Error),
|
||||
#[error("Read nonce stream decryption: {0}")]
|
||||
ReadNonceStreamDecryption(io::Error),
|
||||
#[error("Incomplete download bytes sync messages")]
|
||||
IncompleteDownloadBytesSyncMessages,
|
||||
|
||||
// Temporary errors
|
||||
#[error("Device missing secret key for decrypting sync messages")]
|
||||
MissingKeyHash,
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum GetTokenError {
|
||||
#[error("Token refresher not initialized")]
|
||||
RefresherNotInitialized,
|
||||
#[error("Token refresher failed to refresh and need to be initialized again")]
|
||||
FailedToRefresh,
|
||||
}
|
||||
|
||||
impl From<Error> for rspc::Error {
|
||||
fn from(e: Error) -> Self {
|
||||
Self::with_cause(rspc::ErrorCode::InternalServerError, e.to_string(), e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GetTokenError> for rspc::Error {
|
||||
fn from(e: GetTokenError) -> Self {
|
||||
Self::with_cause(rspc::ErrorCode::InternalServerError, e.to_string(), e)
|
||||
}
|
||||
}
|
||||
331
core/crates/cloud-services/src/key_manager/key_store.rs
Normal file
331
core/crates/cloud-services/src/key_manager/key_store.rs
Normal file
@@ -0,0 +1,331 @@
|
||||
use crate::Error;
|
||||
|
||||
use sd_cloud_schema::{
|
||||
sync::{groups, KeyHash},
|
||||
NodeId, SecretKey as IrohSecretKey,
|
||||
};
|
||||
use sd_crypto::{
|
||||
cloud::{decrypt, encrypt, secret_key::SecretKey},
|
||||
primitives::{EncryptedBlock, OneShotNonce, StreamNonce},
|
||||
CryptoRng,
|
||||
};
|
||||
use sd_utils::error::FileIOError;
|
||||
use tracing::debug;
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, VecDeque},
|
||||
fs::Metadata,
|
||||
path::PathBuf,
|
||||
pin::pin,
|
||||
};
|
||||
|
||||
use futures::StreamExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::{
|
||||
fs,
|
||||
io::{AsyncReadExt, AsyncWriteExt, BufWriter},
|
||||
};
|
||||
use zeroize::{Zeroize, ZeroizeOnDrop};
|
||||
|
||||
type KeyStack = VecDeque<(KeyHash, SecretKey)>;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct KeyStore {
|
||||
iroh_secret_key: IrohSecretKey,
|
||||
keys: BTreeMap<groups::PubId, KeyStack>,
|
||||
}
|
||||
|
||||
impl KeyStore {
|
||||
pub const fn new(iroh_secret_key: IrohSecretKey) -> Self {
|
||||
Self {
|
||||
iroh_secret_key,
|
||||
keys: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_key(&mut self, group_pub_id: groups::PubId, key: SecretKey) {
|
||||
self.keys.entry(group_pub_id).or_default().push_front((
|
||||
KeyHash(blake3::hash(key.as_ref()).to_hex().to_string()),
|
||||
key,
|
||||
));
|
||||
}
|
||||
|
||||
pub fn add_key_with_hash(
|
||||
&mut self,
|
||||
group_pub_id: groups::PubId,
|
||||
key: SecretKey,
|
||||
key_hash: KeyHash,
|
||||
) {
|
||||
debug!(
|
||||
key_hash = key_hash.0,
|
||||
?group_pub_id,
|
||||
"Added single cloud sync key to key manager"
|
||||
);
|
||||
|
||||
self.keys
|
||||
.entry(group_pub_id)
|
||||
.or_default()
|
||||
.push_front((key_hash, key));
|
||||
}
|
||||
|
||||
pub fn add_many_keys(
|
||||
&mut self,
|
||||
group_pub_id: groups::PubId,
|
||||
keys: impl IntoIterator<Item = SecretKey, IntoIter = impl DoubleEndedIterator<Item = SecretKey>>,
|
||||
) {
|
||||
let group_entry = self.keys.entry(group_pub_id).or_default();
|
||||
|
||||
// We reverse the secret keys as a implementation detail to
|
||||
// keep the keys in the same order as they were added as a stack
|
||||
for key in keys.into_iter().rev() {
|
||||
let key_hash = blake3::hash(key.as_ref()).to_hex().to_string();
|
||||
|
||||
debug!(
|
||||
key_hash,
|
||||
?group_pub_id,
|
||||
"Added cloud sync key to key manager"
|
||||
);
|
||||
|
||||
group_entry.push_front((KeyHash(key_hash), key));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_group(&mut self, group_pub_id: groups::PubId) {
|
||||
self.keys.remove(&group_pub_id);
|
||||
}
|
||||
|
||||
pub fn iroh_secret_key(&self) -> IrohSecretKey {
|
||||
self.iroh_secret_key.clone()
|
||||
}
|
||||
|
||||
pub fn node_id(&self) -> NodeId {
|
||||
self.iroh_secret_key.public()
|
||||
}
|
||||
|
||||
pub fn get_key(&self, group_pub_id: groups::PubId, hash: &KeyHash) -> Option<SecretKey> {
|
||||
self.keys.get(&group_pub_id).and_then(|group| {
|
||||
group
|
||||
.iter()
|
||||
.find_map(|(key_hash, key)| (key_hash == hash).then(|| key.clone()))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_latest_key(&self, group_pub_id: groups::PubId) -> Option<(KeyHash, SecretKey)> {
|
||||
self.keys
|
||||
.get(&group_pub_id)
|
||||
.and_then(|group| group.front().cloned())
|
||||
}
|
||||
|
||||
pub fn get_group_keys(&self, group_pub_id: groups::PubId) -> Vec<SecretKey> {
|
||||
self.keys
|
||||
.get(&group_pub_id)
|
||||
.map(|group| group.iter().map(|(_key_hash, key)| key.clone()).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub async fn encrypt(
|
||||
&self,
|
||||
key: &SecretKey,
|
||||
rng: &mut CryptoRng,
|
||||
keys_file_path: &PathBuf,
|
||||
) -> Result<(), Error> {
|
||||
let plain_text_bytes =
|
||||
rmp_serde::to_vec_named(self).map_err(Error::KeyStoreSerialization)?;
|
||||
let mut file = BufWriter::with_capacity(
|
||||
EncryptedBlock::CIPHER_TEXT_SIZE,
|
||||
fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.truncate(true)
|
||||
.open(&keys_file_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
FileIOError::from((
|
||||
&keys_file_path,
|
||||
e,
|
||||
"Failed to open space keys file to encrypt",
|
||||
))
|
||||
})?,
|
||||
);
|
||||
|
||||
if plain_text_bytes.len() < EncryptedBlock::PLAIN_TEXT_SIZE {
|
||||
use encrypt::OneShotEncryption;
|
||||
|
||||
let EncryptedBlock { nonce, cipher_text } = key
|
||||
.encrypt(&plain_text_bytes, rng)
|
||||
.map_err(|e| Error::KeyStoreCrypto {
|
||||
source: e,
|
||||
context: "Failed to oneshot encrypt key store",
|
||||
})?;
|
||||
|
||||
file.write_all(nonce.as_slice()).await.map_err(|e| {
|
||||
FileIOError::from((
|
||||
&keys_file_path,
|
||||
e,
|
||||
"Failed to write space keys file oneshot nonce",
|
||||
))
|
||||
})?;
|
||||
|
||||
file.write_all(cipher_text.as_slice()).await.map_err(|e| {
|
||||
FileIOError::from((
|
||||
&keys_file_path,
|
||||
e,
|
||||
"Failed to write space keys file oneshot cipher text",
|
||||
))
|
||||
})?;
|
||||
} else {
|
||||
use encrypt::StreamEncryption;
|
||||
|
||||
let (nonce, stream) = key.encrypt(plain_text_bytes.as_slice(), rng);
|
||||
|
||||
file.write_all(nonce.as_slice()).await.map_err(|e| {
|
||||
FileIOError::from((
|
||||
&keys_file_path,
|
||||
e,
|
||||
"Failed to write space keys file stream nonce",
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut stream = pin!(stream);
|
||||
while let Some(res) = stream.next().await {
|
||||
file.write_all(&res.map_err(|e| Error::KeyStoreCrypto {
|
||||
source: e,
|
||||
context: "Failed to stream encrypt key store",
|
||||
})?)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
FileIOError::from((
|
||||
&keys_file_path,
|
||||
e,
|
||||
"Failed to write space keys file stream cipher text",
|
||||
))
|
||||
})?;
|
||||
}
|
||||
};
|
||||
|
||||
file.flush().await.map_err(|e| {
|
||||
FileIOError::from((&keys_file_path, e, "Failed to flush space keys file")).into()
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn decrypt(
|
||||
key: &SecretKey,
|
||||
metadata: Metadata,
|
||||
keys_file_path: &PathBuf,
|
||||
) -> Result<Self, Error> {
|
||||
let mut file = fs::File::open(&keys_file_path).await.map_err(|e| {
|
||||
FileIOError::from((
|
||||
keys_file_path,
|
||||
e,
|
||||
"Failed to open space keys file to decrypt",
|
||||
))
|
||||
})?;
|
||||
|
||||
let usize_file_len =
|
||||
usize::try_from(metadata.len()).expect("Failed to convert metadata length to usize");
|
||||
|
||||
let key_store_bytes =
|
||||
if usize_file_len <= EncryptedBlock::CIPHER_TEXT_SIZE + size_of::<OneShotNonce>() {
|
||||
use decrypt::OneShotDecryption;
|
||||
|
||||
let mut nonce = OneShotNonce::default();
|
||||
|
||||
file.read_exact(&mut nonce).await.map_err(|e| {
|
||||
FileIOError::from((
|
||||
keys_file_path,
|
||||
e,
|
||||
"Failed to read space keys file oneshot nonce",
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut cipher_text = vec![0u8; usize_file_len - size_of::<OneShotNonce>()];
|
||||
|
||||
file.read_exact(&mut cipher_text).await.map_err(|e| {
|
||||
FileIOError::from((
|
||||
keys_file_path,
|
||||
e,
|
||||
"Failed to read space keys file oneshot cipher text",
|
||||
))
|
||||
})?;
|
||||
|
||||
key.decrypt_owned(&EncryptedBlock { nonce, cipher_text })
|
||||
.map_err(|e| Error::KeyStoreCrypto {
|
||||
source: e,
|
||||
context: "Failed to oneshot decrypt space keys file",
|
||||
})?
|
||||
} else {
|
||||
use decrypt::StreamDecryption;
|
||||
|
||||
let mut nonce = StreamNonce::default();
|
||||
|
||||
let mut key_store_bytes = Vec::with_capacity(
|
||||
(usize_file_len - size_of::<StreamNonce>()) / EncryptedBlock::CIPHER_TEXT_SIZE
|
||||
* EncryptedBlock::PLAIN_TEXT_SIZE,
|
||||
);
|
||||
|
||||
file.read_exact(&mut nonce).await.map_err(|e| {
|
||||
FileIOError::from((
|
||||
keys_file_path,
|
||||
e,
|
||||
"Failed to read space keys file stream nonce",
|
||||
))
|
||||
})?;
|
||||
|
||||
key.decrypt(&nonce, &mut file, &mut key_store_bytes)
|
||||
.await
|
||||
.map_err(|e| Error::KeyStoreCrypto {
|
||||
source: e,
|
||||
context: "Failed to stream decrypt space keys file",
|
||||
})?;
|
||||
|
||||
key_store_bytes
|
||||
};
|
||||
|
||||
let this = rmp_serde::from_slice::<Self>(&key_store_bytes)
|
||||
.map_err(Error::KeyStoreDeserialization)?;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
use std::fmt::Write;
|
||||
let mut key_hashes_log = String::new();
|
||||
|
||||
this.keys.iter().for_each(|(group_pub_id, key_stack)| {
|
||||
writeln!(
|
||||
key_hashes_log,
|
||||
"Group: {group_pub_id:?} => KeyHashes: {:?}",
|
||||
key_stack
|
||||
.iter()
|
||||
.map(|(KeyHash(key_hash), _)| key_hash)
|
||||
.collect::<Vec<_>>()
|
||||
)
|
||||
.expect("Failed to write to key hashes log");
|
||||
});
|
||||
|
||||
tracing::info!("Loaded key hashes: {key_hashes_log}");
|
||||
}
|
||||
|
||||
Ok(this)
|
||||
}
|
||||
}
|
||||
|
||||
/// Zeroize our secret keys and scrambles up iroh's secret key that doesn't implement zeroize
|
||||
impl Zeroize for KeyStore {
|
||||
fn zeroize(&mut self) {
|
||||
self.iroh_secret_key = IrohSecretKey::generate();
|
||||
self.keys.values_mut().for_each(|group| {
|
||||
group
|
||||
.iter_mut()
|
||||
.map(|(_key_hash, key)| key)
|
||||
.for_each(Zeroize::zeroize);
|
||||
});
|
||||
self.keys = BTreeMap::new();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for KeyStore {
|
||||
fn drop(&mut self) {
|
||||
self.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
impl ZeroizeOnDrop for KeyStore {}
|
||||
183
core/crates/cloud-services/src/key_manager/mod.rs
Normal file
183
core/crates/cloud-services/src/key_manager/mod.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use crate::Error;
|
||||
|
||||
use sd_cloud_schema::{
|
||||
sync::{groups, KeyHash},
|
||||
NodeId, SecretKey as IrohSecretKey,
|
||||
};
|
||||
use sd_crypto::{cloud::secret_key::SecretKey, CryptoRng};
|
||||
use sd_utils::error::FileIOError;
|
||||
|
||||
use std::{
|
||||
fmt,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use tokio::{fs, sync::RwLock};
|
||||
|
||||
mod key_store;
|
||||
|
||||
use key_store::KeyStore;
|
||||
|
||||
const KEY_FILE_NAME: &str = "space.keys";
|
||||
|
||||
pub struct KeyManager {
|
||||
master_key: SecretKey,
|
||||
keys_file_path: PathBuf,
|
||||
store: RwLock<KeyStore>,
|
||||
}
|
||||
|
||||
impl KeyManager {
|
||||
pub async fn new(
|
||||
master_key: SecretKey,
|
||||
iroh_secret_key: IrohSecretKey,
|
||||
data_directory: impl AsRef<Path> + Send,
|
||||
rng: &mut CryptoRng,
|
||||
) -> Result<Self, Error> {
|
||||
async fn inner(
|
||||
master_key: SecretKey,
|
||||
iroh_secret_key: IrohSecretKey,
|
||||
keys_file_path: PathBuf,
|
||||
rng: &mut CryptoRng,
|
||||
) -> Result<KeyManager, Error> {
|
||||
let store = KeyStore::new(iroh_secret_key);
|
||||
store.encrypt(&master_key, rng, &keys_file_path).await?;
|
||||
|
||||
Ok(KeyManager {
|
||||
master_key,
|
||||
keys_file_path,
|
||||
store: RwLock::new(store),
|
||||
})
|
||||
}
|
||||
|
||||
inner(
|
||||
master_key,
|
||||
iroh_secret_key,
|
||||
data_directory.as_ref().join(KEY_FILE_NAME),
|
||||
rng,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn load(
|
||||
master_key: SecretKey,
|
||||
data_directory: impl AsRef<Path> + Send,
|
||||
) -> Result<Self, Error> {
|
||||
async fn inner(
|
||||
master_key: SecretKey,
|
||||
keys_file_path: PathBuf,
|
||||
) -> Result<KeyManager, Error> {
|
||||
Ok(KeyManager {
|
||||
store: RwLock::new(
|
||||
KeyStore::decrypt(
|
||||
&master_key,
|
||||
fs::metadata(&keys_file_path).await.map_err(|e| {
|
||||
FileIOError::from((
|
||||
&keys_file_path,
|
||||
e,
|
||||
"Failed to read space keys file",
|
||||
))
|
||||
})?,
|
||||
&keys_file_path,
|
||||
)
|
||||
.await?,
|
||||
),
|
||||
master_key,
|
||||
keys_file_path,
|
||||
})
|
||||
}
|
||||
|
||||
inner(master_key, data_directory.as_ref().join(KEY_FILE_NAME)).await
|
||||
}
|
||||
|
||||
pub async fn iroh_secret_key(&self) -> IrohSecretKey {
|
||||
self.store.read().await.iroh_secret_key()
|
||||
}
|
||||
|
||||
pub async fn node_id(&self) -> NodeId {
|
||||
self.store.read().await.node_id()
|
||||
}
|
||||
|
||||
pub async fn add_key(
|
||||
&self,
|
||||
group_pub_id: groups::PubId,
|
||||
key: SecretKey,
|
||||
rng: &mut CryptoRng,
|
||||
) -> Result<(), Error> {
|
||||
let mut store = self.store.write().await;
|
||||
store.add_key(group_pub_id, key);
|
||||
// Keeping the write lock here, this way we ensure that we can't corrupt the file
|
||||
store
|
||||
.encrypt(&self.master_key, rng, &self.keys_file_path)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn add_key_with_hash(
|
||||
&self,
|
||||
group_pub_id: groups::PubId,
|
||||
key: SecretKey,
|
||||
key_hash: KeyHash,
|
||||
rng: &mut CryptoRng,
|
||||
) -> Result<(), Error> {
|
||||
let mut store = self.store.write().await;
|
||||
store.add_key_with_hash(group_pub_id, key, key_hash);
|
||||
// Keeping the write lock here, this way we ensure that we can't corrupt the file
|
||||
store
|
||||
.encrypt(&self.master_key, rng, &self.keys_file_path)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn remove_group(
|
||||
&self,
|
||||
group_pub_id: groups::PubId,
|
||||
rng: &mut CryptoRng,
|
||||
) -> Result<(), Error> {
|
||||
let mut store = self.store.write().await;
|
||||
store.remove_group(group_pub_id);
|
||||
// Keeping the write lock here, this way we ensure that we can't corrupt the file
|
||||
store
|
||||
.encrypt(&self.master_key, rng, &self.keys_file_path)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn add_many_keys(
|
||||
&self,
|
||||
group_pub_id: groups::PubId,
|
||||
keys: impl IntoIterator<
|
||||
Item = SecretKey,
|
||||
IntoIter = impl DoubleEndedIterator<Item = SecretKey> + Send,
|
||||
> + Send,
|
||||
rng: &mut CryptoRng,
|
||||
) -> Result<(), Error> {
|
||||
let mut store = self.store.write().await;
|
||||
store.add_many_keys(group_pub_id, keys);
|
||||
// Keeping the write lock here, this way we ensure that we can't corrupt the file
|
||||
store
|
||||
.encrypt(&self.master_key, rng, &self.keys_file_path)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_latest_key(
|
||||
&self,
|
||||
group_pub_id: groups::PubId,
|
||||
) -> Option<(KeyHash, SecretKey)> {
|
||||
self.store.read().await.get_latest_key(group_pub_id)
|
||||
}
|
||||
|
||||
pub async fn get_key(&self, group_pub_id: groups::PubId, hash: &KeyHash) -> Option<SecretKey> {
|
||||
self.store.read().await.get_key(group_pub_id, hash)
|
||||
}
|
||||
|
||||
pub async fn get_group_keys(&self, group_pub_id: groups::PubId) -> Vec<SecretKey> {
|
||||
self.store.read().await.get_group_keys(group_pub_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for KeyManager {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("KeyManager")
|
||||
.field("master_key", &"[REDACTED]")
|
||||
.field("keys_file_path", &self.keys_file_path)
|
||||
.field("store", &"[REDACTED]")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
55
core/crates/cloud-services/src/lib.rs
Normal file
55
core/crates/cloud-services/src/lib.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
#![recursion_limit = "256"]
|
||||
#![warn(
|
||||
clippy::all,
|
||||
clippy::pedantic,
|
||||
clippy::correctness,
|
||||
clippy::perf,
|
||||
clippy::style,
|
||||
clippy::suspicious,
|
||||
clippy::complexity,
|
||||
clippy::nursery,
|
||||
clippy::unwrap_used,
|
||||
unused_qualifications,
|
||||
rust_2018_idioms,
|
||||
trivial_casts,
|
||||
trivial_numeric_casts,
|
||||
unused_allocation,
|
||||
clippy::unnecessary_cast,
|
||||
clippy::cast_lossless,
|
||||
clippy::cast_possible_truncation,
|
||||
clippy::cast_possible_wrap,
|
||||
clippy::cast_precision_loss,
|
||||
clippy::cast_sign_loss,
|
||||
clippy::dbg_macro,
|
||||
clippy::deprecated_cfg_attr,
|
||||
clippy::separated_literal_suffix,
|
||||
deprecated
|
||||
)]
|
||||
#![forbid(deprecated_in_future)]
|
||||
#![allow(clippy::missing_errors_doc, clippy::module_name_repetitions)]
|
||||
|
||||
mod error;
|
||||
|
||||
mod client;
|
||||
mod key_manager;
|
||||
mod p2p;
|
||||
mod sync;
|
||||
mod token_refresher;
|
||||
|
||||
pub use client::CloudServices;
|
||||
pub use error::{Error, GetTokenError};
|
||||
pub use key_manager::KeyManager;
|
||||
pub use p2p::{
|
||||
CloudP2P, JoinSyncGroupResponse, JoinedLibraryCreateArgs, NotifyUser, Ticket, UserResponse,
|
||||
};
|
||||
pub use sync::{
|
||||
declare_actors as declare_cloud_sync, SyncActors as CloudSyncActors,
|
||||
SyncActorsState as CloudSyncActorsState,
|
||||
};
|
||||
|
||||
// Re-exports
|
||||
pub use quic_rpc::transport::quinn::QuinnConnection;
|
||||
|
||||
// Export URL for the auth server
|
||||
pub const AUTH_SERVER_URL: &str = "https://auth.spacedrive.com";
|
||||
// pub const AUTH_SERVER_URL: &str = "http://localhost:9420";
|
||||
242
core/crates/cloud-services/src/p2p/mod.rs
Normal file
242
core/crates/cloud-services/src/p2p/mod.rs
Normal file
@@ -0,0 +1,242 @@
|
||||
use crate::{sync::ReceiveAndIngestNotifiers, CloudServices, Error};
|
||||
|
||||
use sd_cloud_schema::{
|
||||
cloud_p2p::{authorize_new_device_in_sync_group, CloudP2PALPN, CloudP2PError},
|
||||
devices::{self, Device},
|
||||
libraries,
|
||||
sync::groups::{self, GroupWithDevices},
|
||||
SecretKey as IrohSecretKey,
|
||||
};
|
||||
use sd_crypto::{CryptoRng, SeedableRng};
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use iroh_net::{
|
||||
discovery::{
|
||||
dns::DnsDiscovery, local_swarm_discovery::LocalSwarmDiscovery, pkarr::dht::DhtDiscovery,
|
||||
ConcurrentDiscovery, Discovery,
|
||||
},
|
||||
relay::{RelayMap, RelayMode, RelayUrl},
|
||||
Endpoint, NodeId,
|
||||
};
|
||||
use reqwest::Url;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::{spawn, sync::oneshot, time::sleep};
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
mod new_sync_messages_notifier;
|
||||
mod runner;
|
||||
|
||||
use runner::Runner;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct JoinedLibraryCreateArgs {
|
||||
pub pub_id: libraries::PubId,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, specta::Type)]
|
||||
#[serde(transparent)]
|
||||
#[repr(transparent)]
|
||||
#[specta(rename = "CloudP2PTicket")]
|
||||
pub struct Ticket(u64);
|
||||
|
||||
#[derive(Debug, Serialize, specta::Type)]
|
||||
#[serde(tag = "kind", content = "data")]
|
||||
#[specta(rename = "CloudP2PNotifyUser")]
|
||||
pub enum NotifyUser {
|
||||
ReceivedJoinSyncGroupRequest {
|
||||
ticket: Ticket,
|
||||
asking_device: Device,
|
||||
sync_group: GroupWithDevices,
|
||||
},
|
||||
ReceivedJoinSyncGroupResponse {
|
||||
response: JoinSyncGroupResponse,
|
||||
sync_group: GroupWithDevices,
|
||||
},
|
||||
SendingJoinSyncGroupResponseError {
|
||||
error: JoinSyncGroupError,
|
||||
sync_group: GroupWithDevices,
|
||||
},
|
||||
TimedOutJoinRequest {
|
||||
device: Device,
|
||||
succeeded: bool,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, specta::Type)]
|
||||
pub enum JoinSyncGroupError {
|
||||
Communication,
|
||||
InternalServer,
|
||||
Auth,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, specta::Type)]
|
||||
pub enum JoinSyncGroupResponse {
|
||||
Accepted { authorizor_device: Device },
|
||||
Failed(CloudP2PError),
|
||||
CriticalError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, specta::Type)]
|
||||
pub struct BasicLibraryCreationArgs {
|
||||
pub id: libraries::PubId,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, specta::Type)]
|
||||
#[serde(tag = "kind", content = "data")]
|
||||
#[specta(rename = "CloudP2PUserResponse")]
|
||||
pub enum UserResponse {
|
||||
AcceptDeviceInSyncGroup {
|
||||
ticket: Ticket,
|
||||
accepted: Option<BasicLibraryCreationArgs>,
|
||||
},
|
||||
}
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CloudP2P {
|
||||
msgs_tx: flume::Sender<runner::Message>,
|
||||
}
|
||||
|
||||
impl CloudP2P {
|
||||
pub async fn new(
|
||||
current_device_pub_id: devices::PubId,
|
||||
cloud_services: &CloudServices,
|
||||
mut rng: CryptoRng,
|
||||
iroh_secret_key: IrohSecretKey,
|
||||
dns_origin_domain: String,
|
||||
dns_pkarr_url: Url,
|
||||
relay_url: RelayUrl,
|
||||
) -> Result<Self, Error> {
|
||||
let dht_discovery = DhtDiscovery::builder()
|
||||
.secret_key(iroh_secret_key.clone())
|
||||
.pkarr_relay(dns_pkarr_url)
|
||||
.build()
|
||||
.map_err(Error::DhtDiscoveryInit)?;
|
||||
|
||||
let endpoint = Endpoint::builder()
|
||||
.alpns(vec![CloudP2PALPN::LATEST.to_vec()])
|
||||
.discovery(Box::new(ConcurrentDiscovery::from_services(vec![
|
||||
Box::new(DnsDiscovery::new(dns_origin_domain)),
|
||||
Box::new(
|
||||
LocalSwarmDiscovery::new(iroh_secret_key.public())
|
||||
.map_err(Error::LocalSwarmDiscoveryInit)?,
|
||||
),
|
||||
Box::new(dht_discovery.clone()),
|
||||
])))
|
||||
.secret_key(iroh_secret_key)
|
||||
.relay_mode(RelayMode::Custom(RelayMap::from_url(relay_url)))
|
||||
.bind()
|
||||
.await
|
||||
.map_err(Error::CreateCloudP2PEndpoint)?;
|
||||
|
||||
spawn({
|
||||
let endpoint = endpoint.clone();
|
||||
async move {
|
||||
loop {
|
||||
let Ok(node_addr) = endpoint.node_addr().await.map_err(|e| {
|
||||
warn!(?e, "Failed to get direct addresses to force publish on DHT");
|
||||
}) else {
|
||||
sleep(Duration::from_secs(5)).await;
|
||||
continue;
|
||||
};
|
||||
|
||||
debug!("Force publishing peer on DHT");
|
||||
return dht_discovery.publish(&node_addr.info);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let (msgs_tx, msgs_rx) = flume::bounded(16);
|
||||
|
||||
spawn({
|
||||
let runner = Runner::new(
|
||||
current_device_pub_id,
|
||||
cloud_services,
|
||||
msgs_tx.clone(),
|
||||
endpoint,
|
||||
)
|
||||
.await?;
|
||||
let user_response_rx = cloud_services.user_response_rx.clone();
|
||||
|
||||
async move {
|
||||
// All cloned runners share a single state with internal mutability
|
||||
while let Err(e) = spawn(runner.clone().run(
|
||||
msgs_rx.clone(),
|
||||
user_response_rx.clone(),
|
||||
CryptoRng::from_seed(rng.generate_fixed()),
|
||||
))
|
||||
.await
|
||||
{
|
||||
if e.is_panic() {
|
||||
error!("Cloud P2P runner panicked");
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self { msgs_tx })
|
||||
}
|
||||
|
||||
/// Requests the device with the given connection ID asking for permission to the current device
|
||||
/// to join the sync group
|
||||
///
|
||||
/// # Panics
|
||||
/// Will panic if the actor channel is closed, which should never happen
|
||||
pub async fn request_join_sync_group(
|
||||
&self,
|
||||
devices_in_group: Vec<(devices::PubId, NodeId)>,
|
||||
req: authorize_new_device_in_sync_group::Request,
|
||||
tx: oneshot::Sender<JoinedLibraryCreateArgs>,
|
||||
) {
|
||||
self.msgs_tx
|
||||
.send_async(runner::Message::Request(runner::Request::JoinSyncGroup {
|
||||
req,
|
||||
devices_in_group,
|
||||
tx,
|
||||
}))
|
||||
.await
|
||||
.expect("Channel closed");
|
||||
}
|
||||
|
||||
/// Register a notifier for the desired sync group, which will notify the receiver actor when
|
||||
/// new sync messages arrive through cloud p2p notification requests.
|
||||
///
|
||||
/// # Panics
|
||||
/// Will panic if the actor channel is closed, which should never happen
|
||||
pub async fn register_sync_messages_receiver_notifier(
|
||||
&self,
|
||||
sync_group_pub_id: groups::PubId,
|
||||
notifier: Arc<ReceiveAndIngestNotifiers>,
|
||||
) {
|
||||
self.msgs_tx
|
||||
.send_async(runner::Message::RegisterSyncMessageNotifier((
|
||||
sync_group_pub_id,
|
||||
notifier,
|
||||
)))
|
||||
.await
|
||||
.expect("Channel closed");
|
||||
}
|
||||
|
||||
/// Emit a notification that new sync messages were sent to cloud, so other devices should pull
|
||||
/// them as soon as possible.
|
||||
///
|
||||
/// # Panics
|
||||
/// Will panic if the actor channel is closed, which should never happen
|
||||
pub async fn notify_new_sync_messages(&self, group_pub_id: groups::PubId) {
|
||||
self.msgs_tx
|
||||
.send_async(runner::Message::NotifyPeersSyncMessages(group_pub_id))
|
||||
.await
|
||||
.expect("Channel closed");
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CloudP2P {
|
||||
fn drop(&mut self) {
|
||||
self.msgs_tx.send(runner::Message::Stop).ok();
|
||||
}
|
||||
}
|
||||
156
core/crates/cloud-services/src/p2p/new_sync_messages_notifier.rs
Normal file
156
core/crates/cloud-services/src/p2p/new_sync_messages_notifier.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use crate::{token_refresher::TokenRefresher, Error};
|
||||
|
||||
use sd_cloud_schema::{
|
||||
cloud_p2p::{Client, CloudP2PALPN, Service},
|
||||
devices,
|
||||
sync::groups,
|
||||
};
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use futures_concurrency::future::Join;
|
||||
use iroh_net::{Endpoint, NodeId};
|
||||
use quic_rpc::{transport::quinn::QuinnConnection, RpcClient};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, error, instrument, warn};
|
||||
|
||||
use super::runner::Message;
|
||||
|
||||
const CACHED_MAX_DURATION: Duration = Duration::from_secs(60 * 5);
|
||||
|
||||
pub async fn dispatch_notifier(
|
||||
group_pub_id: groups::PubId,
|
||||
device_pub_id: devices::PubId,
|
||||
devices: Option<(Instant, Vec<(devices::PubId, NodeId)>)>,
|
||||
msgs_tx: flume::Sender<Message>,
|
||||
cloud_services: sd_cloud_schema::Client<
|
||||
QuinnConnection<sd_cloud_schema::Service>,
|
||||
sd_cloud_schema::Service,
|
||||
>,
|
||||
token_refresher: TokenRefresher,
|
||||
endpoint: Endpoint,
|
||||
) {
|
||||
match notify_peers(
|
||||
group_pub_id,
|
||||
device_pub_id,
|
||||
devices,
|
||||
cloud_services,
|
||||
token_refresher,
|
||||
endpoint,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok((true, devices)) => {
|
||||
if msgs_tx
|
||||
.send_async(Message::UpdateCachedDevices((group_pub_id, devices)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
warn!("Failed to send update cached devices message to update cached devices");
|
||||
}
|
||||
}
|
||||
|
||||
Ok((false, _)) => {}
|
||||
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to notify peers");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(cloud_services, token_refresher, endpoint))]
|
||||
async fn notify_peers(
|
||||
group_pub_id: groups::PubId,
|
||||
device_pub_id: devices::PubId,
|
||||
devices: Option<(Instant, Vec<(devices::PubId, NodeId)>)>,
|
||||
cloud_services: sd_cloud_schema::Client<
|
||||
QuinnConnection<sd_cloud_schema::Service>,
|
||||
sd_cloud_schema::Service,
|
||||
>,
|
||||
token_refresher: TokenRefresher,
|
||||
endpoint: Endpoint,
|
||||
) -> Result<(bool, Vec<(devices::PubId, NodeId)>), Error> {
|
||||
let (devices, update_cache) = match devices {
|
||||
Some((when, devices)) if when.elapsed() < CACHED_MAX_DURATION => (devices, false),
|
||||
_ => {
|
||||
debug!("Fetching devices connection ids for group");
|
||||
let groups::get::Response(groups::get::ResponseKind::DevicesConnectionIds(devices)) =
|
||||
cloud_services
|
||||
.sync()
|
||||
.groups()
|
||||
.get(groups::get::Request {
|
||||
access_token: token_refresher.get_access_token().await?,
|
||||
pub_id: group_pub_id,
|
||||
kind: groups::get::RequestKind::DevicesConnectionIds,
|
||||
})
|
||||
.await??
|
||||
else {
|
||||
unreachable!("Only DevicesConnectionIds response is expected, as we requested it");
|
||||
};
|
||||
|
||||
(devices, true)
|
||||
}
|
||||
};
|
||||
|
||||
send_notifications(group_pub_id, device_pub_id, &devices, &endpoint).await;
|
||||
|
||||
Ok((update_cache, devices))
|
||||
}
|
||||
|
||||
async fn send_notifications(
|
||||
group_pub_id: groups::PubId,
|
||||
device_pub_id: devices::PubId,
|
||||
devices: &[(devices::PubId, NodeId)],
|
||||
endpoint: &Endpoint,
|
||||
) {
|
||||
devices
|
||||
.iter()
|
||||
.filter(|(peer_device_pub_id, _)| *peer_device_pub_id != device_pub_id)
|
||||
.map(|(peer_device_pub_id, connection_id)| async move {
|
||||
if let Err(e) =
|
||||
connect_and_send_notification(group_pub_id, device_pub_id, connection_id, endpoint)
|
||||
.await
|
||||
{
|
||||
// Using just a debug log here because we don't want to spam the logs with
|
||||
// every single notification failure, as this is more a nice to have feature than a
|
||||
// critical one
|
||||
debug!(?e, %peer_device_pub_id, "Failed to send new sync messages notification to peer");
|
||||
} else {
|
||||
debug!(%peer_device_pub_id, "Sent new sync messages notification to peer");
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join()
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn connect_and_send_notification(
|
||||
group_pub_id: groups::PubId,
|
||||
device_pub_id: devices::PubId,
|
||||
connection_id: &NodeId,
|
||||
endpoint: &Endpoint,
|
||||
) -> Result<(), Error> {
|
||||
let client = Client::new(RpcClient::new(QuinnConnection::<Service>::from_connection(
|
||||
endpoint
|
||||
.connect(*connection_id, CloudP2PALPN::LATEST)
|
||||
.await
|
||||
.map_err(Error::ConnectToCloudP2PNode)?,
|
||||
)));
|
||||
|
||||
if let Err(e) = client
|
||||
.notify_new_sync_messages(
|
||||
sd_cloud_schema::cloud_p2p::notify_new_sync_messages::Request {
|
||||
sync_group_pub_id: group_pub_id,
|
||||
device_pub_id,
|
||||
},
|
||||
)
|
||||
.await?
|
||||
{
|
||||
warn!(
|
||||
?e,
|
||||
"This route shouldn't return an error, it's just a notification",
|
||||
);
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
651
core/crates/cloud-services/src/p2p/runner.rs
Normal file
651
core/crates/cloud-services/src/p2p/runner.rs
Normal file
@@ -0,0 +1,651 @@
|
||||
use crate::{
|
||||
p2p::JoinSyncGroupError, sync::ReceiveAndIngestNotifiers, token_refresher::TokenRefresher,
|
||||
CloudServices, Error, KeyManager,
|
||||
};
|
||||
|
||||
use sd_cloud_schema::{
|
||||
cloud_p2p::{
|
||||
self, authorize_new_device_in_sync_group, notify_new_sync_messages, Client, CloudP2PALPN,
|
||||
CloudP2PError, Service,
|
||||
},
|
||||
devices::{self, Device},
|
||||
sync::groups,
|
||||
};
|
||||
use sd_crypto::{CryptoRng, SeedableRng};
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
pin::pin,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use flume::SendError;
|
||||
use futures::StreamExt;
|
||||
use futures_concurrency::stream::Merge;
|
||||
use iroh_net::{Endpoint, NodeId};
|
||||
use quic_rpc::{
|
||||
server::{Accepting, RpcChannel, RpcServerError},
|
||||
transport::quinn::{QuinnConnection, QuinnServerEndpoint},
|
||||
RpcClient, RpcServer,
|
||||
};
|
||||
use tokio::{
|
||||
spawn,
|
||||
sync::{oneshot, Mutex},
|
||||
task::JoinHandle,
|
||||
time::{interval, Instant, MissedTickBehavior},
|
||||
};
|
||||
use tokio_stream::wrappers::IntervalStream;
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
use super::{
|
||||
new_sync_messages_notifier::dispatch_notifier, BasicLibraryCreationArgs, JoinSyncGroupResponse,
|
||||
JoinedLibraryCreateArgs, NotifyUser, Ticket, UserResponse,
|
||||
};
|
||||
|
||||
const TEN_SECONDS: Duration = Duration::from_secs(10);
|
||||
const FIVE_MINUTES: Duration = Duration::from_secs(60 * 5);
|
||||
|
||||
#[allow(clippy::large_enum_variant)] // Ignoring because the enum Stop variant will only happen a single time ever
|
||||
pub enum Message {
|
||||
Request(Request),
|
||||
RegisterSyncMessageNotifier((groups::PubId, Arc<ReceiveAndIngestNotifiers>)),
|
||||
NotifyPeersSyncMessages(groups::PubId),
|
||||
UpdateCachedDevices((groups::PubId, Vec<(devices::PubId, NodeId)>)),
|
||||
Stop,
|
||||
}
|
||||
|
||||
pub enum Request {
|
||||
JoinSyncGroup {
|
||||
req: authorize_new_device_in_sync_group::Request,
|
||||
devices_in_group: Vec<(devices::PubId, NodeId)>,
|
||||
tx: oneshot::Sender<JoinedLibraryCreateArgs>,
|
||||
},
|
||||
}
|
||||
|
||||
/// We use internal mutability here, but don't worry because there will always be a single
|
||||
/// [`Runner`] running at a time, so the lock is never contended
|
||||
pub struct Runner {
|
||||
current_device_pub_id: devices::PubId,
|
||||
token_refresher: TokenRefresher,
|
||||
cloud_services: sd_cloud_schema::Client<
|
||||
QuinnConnection<sd_cloud_schema::Service>,
|
||||
sd_cloud_schema::Service,
|
||||
>,
|
||||
msgs_tx: flume::Sender<Message>,
|
||||
endpoint: Endpoint,
|
||||
key_manager: Arc<KeyManager>,
|
||||
ticketer: Arc<AtomicU64>,
|
||||
notify_user_tx: flume::Sender<NotifyUser>,
|
||||
sync_messages_receiver_notifiers_map:
|
||||
Arc<DashMap<groups::PubId, Arc<ReceiveAndIngestNotifiers>>>,
|
||||
pending_sync_group_join_requests: Arc<Mutex<HashMap<Ticket, PendingSyncGroupJoin>>>,
|
||||
cached_devices_per_group: HashMap<groups::PubId, (Instant, Vec<(devices::PubId, NodeId)>)>,
|
||||
timeout_checker_buffer: Vec<(Ticket, PendingSyncGroupJoin)>,
|
||||
}
|
||||
|
||||
impl Clone for Runner {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
current_device_pub_id: self.current_device_pub_id,
|
||||
token_refresher: self.token_refresher.clone(),
|
||||
cloud_services: self.cloud_services.clone(),
|
||||
msgs_tx: self.msgs_tx.clone(),
|
||||
endpoint: self.endpoint.clone(),
|
||||
key_manager: Arc::clone(&self.key_manager),
|
||||
ticketer: Arc::clone(&self.ticketer),
|
||||
notify_user_tx: self.notify_user_tx.clone(),
|
||||
sync_messages_receiver_notifiers_map: Arc::clone(
|
||||
&self.sync_messages_receiver_notifiers_map,
|
||||
),
|
||||
pending_sync_group_join_requests: Arc::clone(&self.pending_sync_group_join_requests),
|
||||
// Just cache the devices and their node_ids per group
|
||||
cached_devices_per_group: HashMap::new(),
|
||||
// This one is a temporary buffer only used for timeout checker
|
||||
timeout_checker_buffer: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PendingSyncGroupJoin {
|
||||
channel: RpcChannel<Service, QuinnServerEndpoint<Service>>,
|
||||
request: authorize_new_device_in_sync_group::Request,
|
||||
this_device: Device,
|
||||
since: Instant,
|
||||
}
|
||||
|
||||
impl Runner {
|
||||
pub async fn new(
|
||||
current_device_pub_id: devices::PubId,
|
||||
cloud_services: &CloudServices,
|
||||
msgs_tx: flume::Sender<Message>,
|
||||
endpoint: Endpoint,
|
||||
) -> Result<Self, Error> {
|
||||
Ok(Self {
|
||||
current_device_pub_id,
|
||||
token_refresher: cloud_services.token_refresher.clone(),
|
||||
cloud_services: cloud_services.client().await?,
|
||||
msgs_tx,
|
||||
endpoint,
|
||||
key_manager: cloud_services.key_manager().await?,
|
||||
ticketer: Arc::default(),
|
||||
notify_user_tx: cloud_services.notify_user_tx.clone(),
|
||||
sync_messages_receiver_notifiers_map: Arc::default(),
|
||||
pending_sync_group_join_requests: Arc::default(),
|
||||
cached_devices_per_group: HashMap::new(),
|
||||
timeout_checker_buffer: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
mut self,
|
||||
msgs_rx: flume::Receiver<Message>,
|
||||
user_response_rx: flume::Receiver<UserResponse>,
|
||||
mut rng: CryptoRng,
|
||||
) {
|
||||
// Ignoring because this is only used internally and I think that boxing will be more expensive than wasting
|
||||
// some extra bytes for smaller variants
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
enum StreamMessage {
|
||||
AcceptResult(
|
||||
Result<
|
||||
Accepting<Service, QuinnServerEndpoint<Service>>,
|
||||
RpcServerError<QuinnServerEndpoint<Service>>,
|
||||
>,
|
||||
),
|
||||
Message(Message),
|
||||
UserResponse(UserResponse),
|
||||
Tick,
|
||||
}
|
||||
|
||||
let mut ticker = interval(TEN_SECONDS);
|
||||
ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
|
||||
|
||||
// FIXME(@fogodev): Update this function to use iroh-net transport instead of quinn
|
||||
// when it's implemented
|
||||
let (server, server_handle) = setup_server_endpoint(self.endpoint.clone());
|
||||
|
||||
let mut msg_stream = pin!((
|
||||
async_stream::stream! {
|
||||
loop {
|
||||
yield StreamMessage::AcceptResult(server.accept().await);
|
||||
}
|
||||
},
|
||||
msgs_rx.stream().map(StreamMessage::Message),
|
||||
user_response_rx.stream().map(StreamMessage::UserResponse),
|
||||
IntervalStream::new(ticker).map(|_| StreamMessage::Tick),
|
||||
)
|
||||
.merge());
|
||||
|
||||
while let Some(msg) = msg_stream.next().await {
|
||||
match msg {
|
||||
StreamMessage::AcceptResult(Ok(accepting)) => {
|
||||
let Ok((request, channel)) = accepting.read_first().await.map_err(|e| {
|
||||
error!(?e, "Failed to read first request from a new connection;");
|
||||
}) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
self.handle_request(request, channel).await;
|
||||
}
|
||||
|
||||
StreamMessage::AcceptResult(Err(e)) => {
|
||||
// TODO(@fogodev): Maybe report this error to the user on a toast?
|
||||
error!(?e, "Error accepting connection;");
|
||||
}
|
||||
|
||||
StreamMessage::Message(Message::Request(Request::JoinSyncGroup {
|
||||
req,
|
||||
devices_in_group,
|
||||
tx,
|
||||
})) => self.dispatch_join_requests(req, devices_in_group, &mut rng, tx),
|
||||
|
||||
StreamMessage::Message(Message::RegisterSyncMessageNotifier((
|
||||
group_pub_id,
|
||||
notifier,
|
||||
))) => {
|
||||
self.sync_messages_receiver_notifiers_map
|
||||
.insert(group_pub_id, notifier);
|
||||
}
|
||||
|
||||
StreamMessage::Message(Message::NotifyPeersSyncMessages(group_pub_id)) => {
|
||||
spawn(dispatch_notifier(
|
||||
group_pub_id,
|
||||
self.current_device_pub_id,
|
||||
self.cached_devices_per_group.get(&group_pub_id).cloned(),
|
||||
self.msgs_tx.clone(),
|
||||
self.cloud_services.clone(),
|
||||
self.token_refresher.clone(),
|
||||
self.endpoint.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
StreamMessage::Message(Message::UpdateCachedDevices((
|
||||
group_pub_id,
|
||||
devices_connections_ids,
|
||||
))) => {
|
||||
self.cached_devices_per_group
|
||||
.insert(group_pub_id, (Instant::now(), devices_connections_ids));
|
||||
}
|
||||
|
||||
StreamMessage::UserResponse(UserResponse::AcceptDeviceInSyncGroup {
|
||||
ticket,
|
||||
accepted,
|
||||
}) => {
|
||||
self.handle_join_response(ticket, accepted).await;
|
||||
}
|
||||
|
||||
StreamMessage::Tick => self.tick().await,
|
||||
|
||||
StreamMessage::Message(Message::Stop) => {
|
||||
server_handle.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch_join_requests(
|
||||
&self,
|
||||
req: authorize_new_device_in_sync_group::Request,
|
||||
devices_in_group: Vec<(devices::PubId, NodeId)>,
|
||||
rng: &mut CryptoRng,
|
||||
tx: oneshot::Sender<JoinedLibraryCreateArgs>,
|
||||
) {
|
||||
async fn inner(
|
||||
key_manager: Arc<KeyManager>,
|
||||
endpoint: Endpoint,
|
||||
mut rng: CryptoRng,
|
||||
req: authorize_new_device_in_sync_group::Request,
|
||||
devices_in_group: Vec<(devices::PubId, NodeId)>,
|
||||
tx: oneshot::Sender<JoinedLibraryCreateArgs>,
|
||||
) -> Result<JoinSyncGroupResponse, Error> {
|
||||
let group_pub_id = req.sync_group.pub_id;
|
||||
loop {
|
||||
let client =
|
||||
match connect_to_first_available_client(&endpoint, &devices_in_group).await {
|
||||
Ok(client) => client,
|
||||
Err(e) => {
|
||||
return Ok(JoinSyncGroupResponse::Failed(e));
|
||||
}
|
||||
};
|
||||
|
||||
match client
|
||||
.authorize_new_device_in_sync_group(req.clone())
|
||||
.await?
|
||||
{
|
||||
Ok(authorize_new_device_in_sync_group::Response {
|
||||
authorizor_device,
|
||||
keys,
|
||||
library_pub_id,
|
||||
library_name,
|
||||
library_description,
|
||||
}) => {
|
||||
debug!(
|
||||
device_pub_id = %authorizor_device.pub_id,
|
||||
%group_pub_id,
|
||||
keys_count = keys.len(),
|
||||
%library_pub_id,
|
||||
library_name,
|
||||
"Received join sync group response"
|
||||
);
|
||||
|
||||
key_manager
|
||||
.add_many_keys(
|
||||
group_pub_id,
|
||||
keys.into_iter().map(|key| {
|
||||
key.as_slice()
|
||||
.try_into()
|
||||
.expect("critical error, backend has invalid secret keys")
|
||||
}),
|
||||
&mut rng,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if tx
|
||||
.send(JoinedLibraryCreateArgs {
|
||||
pub_id: library_pub_id,
|
||||
name: library_name,
|
||||
description: library_description,
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
error!("Failed to handle library creation locally from received library data");
|
||||
return Ok(JoinSyncGroupResponse::CriticalError);
|
||||
}
|
||||
|
||||
return Ok(JoinSyncGroupResponse::Accepted { authorizor_device });
|
||||
}
|
||||
|
||||
// In case of timeout, we will try again
|
||||
Err(CloudP2PError::TimedOut) => continue,
|
||||
|
||||
Err(e) => return Ok(JoinSyncGroupResponse::Failed(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
spawn({
|
||||
let endpoint = self.endpoint.clone();
|
||||
let notify_user_tx = self.notify_user_tx.clone();
|
||||
let key_manager = Arc::clone(&self.key_manager);
|
||||
let rng = CryptoRng::from_seed(rng.generate_fixed());
|
||||
async move {
|
||||
let sync_group = req.sync_group.clone();
|
||||
|
||||
if let Err(SendError(response)) = notify_user_tx
|
||||
.send_async(NotifyUser::ReceivedJoinSyncGroupResponse {
|
||||
response: inner(key_manager, endpoint, rng, req, devices_in_group, tx)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
error!(
|
||||
?e,
|
||||
"Failed to issue authorize new device in sync group request;"
|
||||
);
|
||||
JoinSyncGroupResponse::CriticalError
|
||||
}),
|
||||
sync_group,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!(?response, "Failed to send response to user;");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
&self,
|
||||
request: cloud_p2p::Request,
|
||||
channel: RpcChannel<Service, QuinnServerEndpoint<Service>>,
|
||||
) {
|
||||
match request {
|
||||
cloud_p2p::Request::AuthorizeNewDeviceInSyncGroup(
|
||||
authorize_new_device_in_sync_group::Request {
|
||||
sync_group,
|
||||
asking_device,
|
||||
},
|
||||
) => {
|
||||
let ticket = Ticket(self.ticketer.fetch_add(1, Ordering::Relaxed));
|
||||
let this_device = sync_group
|
||||
.devices
|
||||
.iter()
|
||||
.find(|device| device.pub_id == self.current_device_pub_id)
|
||||
.expect(
|
||||
"current device must be in the sync group, otherwise we wouldn't be here",
|
||||
)
|
||||
.clone();
|
||||
|
||||
self.notify_user_tx
|
||||
.send_async(NotifyUser::ReceivedJoinSyncGroupRequest {
|
||||
ticket,
|
||||
asking_device: asking_device.clone(),
|
||||
sync_group: sync_group.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("notify_user_tx must never closes!");
|
||||
|
||||
self.pending_sync_group_join_requests.lock().await.insert(
|
||||
ticket,
|
||||
PendingSyncGroupJoin {
|
||||
channel,
|
||||
request: authorize_new_device_in_sync_group::Request {
|
||||
sync_group,
|
||||
asking_device,
|
||||
},
|
||||
this_device,
|
||||
since: Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
cloud_p2p::Request::NotifyNewSyncMessages(req) => {
|
||||
if let Err(e) = channel
|
||||
.rpc(
|
||||
req,
|
||||
(),
|
||||
|(),
|
||||
notify_new_sync_messages::Request {
|
||||
sync_group_pub_id,
|
||||
device_pub_id,
|
||||
}| async move {
|
||||
debug!(%sync_group_pub_id, %device_pub_id, "Received new sync messages notification");
|
||||
if let Some(notifier) = self
|
||||
.sync_messages_receiver_notifiers_map
|
||||
.get(&sync_group_pub_id)
|
||||
{
|
||||
notifier.notify_receiver();
|
||||
} else {
|
||||
warn!("Received new sync messages notification for unknown sync group");
|
||||
}
|
||||
|
||||
Ok(notify_new_sync_messages::Response)
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!(
|
||||
?e,
|
||||
"Failed to reply to new sync messages notification request"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_join_response(
|
||||
&self,
|
||||
ticket: Ticket,
|
||||
accepted: Option<BasicLibraryCreationArgs>,
|
||||
) {
|
||||
let Some(PendingSyncGroupJoin {
|
||||
channel,
|
||||
request,
|
||||
this_device,
|
||||
..
|
||||
}) = self
|
||||
.pending_sync_group_join_requests
|
||||
.lock()
|
||||
.await
|
||||
.remove(&ticket)
|
||||
else {
|
||||
warn!("Received join response for unknown ticket; We probably timed out this request already");
|
||||
return;
|
||||
};
|
||||
|
||||
let sync_group = request.sync_group.clone();
|
||||
let asking_device_pub_id = request.asking_device.pub_id;
|
||||
|
||||
let was_accepted = accepted.is_some();
|
||||
|
||||
let response = if let Some(BasicLibraryCreationArgs {
|
||||
id: library_pub_id,
|
||||
name: library_name,
|
||||
description: library_description,
|
||||
}) = accepted
|
||||
{
|
||||
Ok(authorize_new_device_in_sync_group::Response {
|
||||
authorizor_device: this_device,
|
||||
keys: self
|
||||
.key_manager
|
||||
.get_group_keys(request.sync_group.pub_id)
|
||||
.await
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect(),
|
||||
library_pub_id,
|
||||
library_name,
|
||||
library_description,
|
||||
})
|
||||
} else {
|
||||
Err(CloudP2PError::Rejected)
|
||||
};
|
||||
|
||||
if let Err(e) = channel
|
||||
.rpc(request, (), |(), _req| async move { response })
|
||||
.await
|
||||
{
|
||||
error!(?e, "Failed to send response to user;");
|
||||
self.notify_join_error(sync_group, JoinSyncGroupError::Communication)
|
||||
.await;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if was_accepted {
|
||||
let Ok(access_token) = self
|
||||
.token_refresher
|
||||
.get_access_token()
|
||||
.await
|
||||
.map_err(|e| error!(?e, "Failed to get access token;"))
|
||||
else {
|
||||
self.notify_join_error(sync_group, JoinSyncGroupError::Auth)
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
|
||||
match self
|
||||
.cloud_services
|
||||
.sync()
|
||||
.groups()
|
||||
.reply_join_request(groups::reply_join_request::Request {
|
||||
access_token,
|
||||
group_pub_id: sync_group.pub_id,
|
||||
authorized_device_pub_id: asking_device_pub_id,
|
||||
authorizor_device_pub_id: self.current_device_pub_id,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(Ok(groups::reply_join_request::Response)) => {
|
||||
// Everything is Awesome!
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
error!(?e, "Failed to reply to join request");
|
||||
self.notify_join_error(sync_group, JoinSyncGroupError::InternalServer)
|
||||
.await;
|
||||
}
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to send reply to join request");
|
||||
self.notify_join_error(sync_group, JoinSyncGroupError::Communication)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn notify_join_error(
|
||||
&self,
|
||||
sync_group: groups::GroupWithDevices,
|
||||
error: JoinSyncGroupError,
|
||||
) {
|
||||
self.notify_user_tx
|
||||
.send_async(NotifyUser::SendingJoinSyncGroupResponseError { error, sync_group })
|
||||
.await
|
||||
.expect("notify_user_tx must never closes!");
|
||||
}
|
||||
|
||||
async fn tick(&mut self) {
|
||||
self.timeout_checker_buffer.clear();
|
||||
|
||||
let mut pending_sync_group_join_requests =
|
||||
self.pending_sync_group_join_requests.lock().await;
|
||||
|
||||
for (ticket, pending_sync_group_join) in pending_sync_group_join_requests.drain() {
|
||||
if pending_sync_group_join.since.elapsed() > FIVE_MINUTES {
|
||||
let PendingSyncGroupJoin {
|
||||
channel, request, ..
|
||||
} = pending_sync_group_join;
|
||||
|
||||
let asking_device = request.asking_device.clone();
|
||||
|
||||
let notify_message = match channel
|
||||
.rpc(request, (), |(), _req| async move {
|
||||
Err(CloudP2PError::TimedOut)
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(()) => NotifyUser::TimedOutJoinRequest {
|
||||
device: asking_device,
|
||||
succeeded: true,
|
||||
},
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to send timed out response to user;");
|
||||
NotifyUser::TimedOutJoinRequest {
|
||||
device: asking_device,
|
||||
succeeded: false,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
self.notify_user_tx
|
||||
.send_async(notify_message)
|
||||
.await
|
||||
.expect("notify_user_tx must never closes!");
|
||||
} else {
|
||||
self.timeout_checker_buffer
|
||||
.push((ticket, pending_sync_group_join));
|
||||
}
|
||||
}
|
||||
|
||||
pending_sync_group_join_requests.extend(self.timeout_checker_buffer.drain(..));
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_to_first_available_client(
|
||||
endpoint: &Endpoint,
|
||||
devices_in_group: &[(devices::PubId, NodeId)],
|
||||
) -> Result<Client<QuinnConnection<Service>, Service>, CloudP2PError> {
|
||||
for (device_pub_id, device_connection_id) in devices_in_group {
|
||||
if let Ok(connection) = endpoint
|
||||
.connect(*device_connection_id, CloudP2PALPN::LATEST)
|
||||
.await
|
||||
.map_err(
|
||||
|e| error!(?e, %device_pub_id, "Failed to connect to authorizor device candidate"),
|
||||
) {
|
||||
debug!(%device_pub_id, "Connected to authorizor device candidate");
|
||||
return Ok(Client::new(RpcClient::new(
|
||||
QuinnConnection::<Service>::from_connection(connection),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Err(CloudP2PError::UnableToConnect)
|
||||
}
|
||||
|
||||
fn setup_server_endpoint(
|
||||
endpoint: Endpoint,
|
||||
) -> (
|
||||
RpcServer<Service, QuinnServerEndpoint<Service>>,
|
||||
JoinHandle<()>,
|
||||
) {
|
||||
let local_addr = {
|
||||
let (ipv4_addr, maybe_ipv6_addr) = endpoint.bound_sockets();
|
||||
// Trying to give preference to IPv6 addresses because it's 2024
|
||||
maybe_ipv6_addr.unwrap_or(ipv4_addr)
|
||||
};
|
||||
|
||||
let (connections_tx, connections_rx) = flume::bounded(16);
|
||||
|
||||
(
|
||||
RpcServer::new(QuinnServerEndpoint::<Service>::handle_connections(
|
||||
connections_rx,
|
||||
local_addr,
|
||||
)),
|
||||
spawn(async move {
|
||||
while let Some(connecting) = endpoint.accept().await {
|
||||
if let Ok(connection) = connecting.await.map_err(|e| {
|
||||
warn!(?e, "Cloud P2P failed to accept connection");
|
||||
}) {
|
||||
if connections_tx.send_async(connection).await.is_err() {
|
||||
warn!("Connection receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
121
core/crates/cloud-services/src/sync/ingest.rs
Normal file
121
core/crates/cloud-services/src/sync/ingest.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use crate::Error;
|
||||
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_actors::{Actor, Stopper};
|
||||
|
||||
use std::{
|
||||
future::IntoFuture,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use futures::FutureExt;
|
||||
use futures_concurrency::future::Race;
|
||||
use tokio::{
|
||||
sync::Notify,
|
||||
time::{sleep, Instant},
|
||||
};
|
||||
use tracing::{debug, error};
|
||||
|
||||
use super::{ReceiveAndIngestNotifiers, SyncActors, ONE_MINUTE};
|
||||
|
||||
/// Responsible for taking sync operations received from the cloud,
|
||||
/// and applying them to the local database via the sync system's ingest actor.
|
||||
|
||||
pub struct Ingester {
|
||||
sync: SyncManager,
|
||||
notifiers: Arc<ReceiveAndIngestNotifiers>,
|
||||
active: Arc<AtomicBool>,
|
||||
active_notify: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl Actor<SyncActors> for Ingester {
|
||||
const IDENTIFIER: SyncActors = SyncActors::Ingester;
|
||||
|
||||
async fn run(&mut self, stop: Stopper) {
|
||||
enum Race {
|
||||
Notified,
|
||||
Stopped,
|
||||
}
|
||||
|
||||
loop {
|
||||
self.active.store(true, Ordering::Relaxed);
|
||||
self.active_notify.notify_waiters();
|
||||
|
||||
if let Err(e) = self.run_loop_iteration().await {
|
||||
error!(?e, "Error during cloud sync ingester actor iteration");
|
||||
sleep(ONE_MINUTE).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
self.active.store(false, Ordering::Relaxed);
|
||||
self.active_notify.notify_waiters();
|
||||
|
||||
if matches!(
|
||||
(
|
||||
self.notifiers
|
||||
.wait_notification_to_ingest()
|
||||
.map(|()| Race::Notified),
|
||||
stop.into_future().map(|()| Race::Stopped),
|
||||
)
|
||||
.race()
|
||||
.await,
|
||||
Race::Stopped
|
||||
) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Ingester {
|
||||
pub const fn new(
|
||||
sync: SyncManager,
|
||||
notifiers: Arc<ReceiveAndIngestNotifiers>,
|
||||
active: Arc<AtomicBool>,
|
||||
active_notify: Arc<Notify>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sync,
|
||||
notifiers,
|
||||
active,
|
||||
active_notify,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_loop_iteration(&self) -> Result<(), Error> {
|
||||
let start = Instant::now();
|
||||
|
||||
let operations_to_ingest_count = self
|
||||
.sync
|
||||
.db
|
||||
.cloud_crdt_operation()
|
||||
.count(vec![])
|
||||
.exec()
|
||||
.await
|
||||
.map_err(sd_core_sync::Error::from)?;
|
||||
|
||||
if operations_to_ingest_count == 0 {
|
||||
debug!("Nothing to ingest, early finishing ingester loop");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
debug!(
|
||||
operations_to_ingest_count,
|
||||
"Starting sync messages cloud ingestion loop"
|
||||
);
|
||||
|
||||
let ingested_count = self.sync.ingest_ops().await?;
|
||||
|
||||
debug!(
|
||||
ingested_count,
|
||||
"Finished sync messages cloud ingestion loop in {:?}",
|
||||
start.elapsed()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
136
core/crates/cloud-services/src/sync/mod.rs
Normal file
136
core/crates/cloud-services/src/sync/mod.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
use crate::{CloudServices, Error};
|
||||
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_actors::{ActorsCollection, IntoActor};
|
||||
use sd_cloud_schema::sync::groups;
|
||||
use sd_crypto::CryptoRng;
|
||||
|
||||
use std::{
|
||||
fmt,
|
||||
path::Path,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use futures_concurrency::future::TryJoin;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
mod ingest;
|
||||
mod receive;
|
||||
mod send;
|
||||
|
||||
use ingest::Ingester;
|
||||
use receive::Receiver;
|
||||
use send::Sender;
|
||||
|
||||
const ONE_MINUTE: Duration = Duration::from_secs(60);
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SyncActorsState {
|
||||
pub send_active: Arc<AtomicBool>,
|
||||
pub receive_active: Arc<AtomicBool>,
|
||||
pub ingest_active: Arc<AtomicBool>,
|
||||
pub state_change_notifier: Arc<Notify>,
|
||||
receiver_and_ingester_notifiers: Arc<ReceiveAndIngestNotifiers>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, specta::Type)]
|
||||
#[specta(rename = "CloudSyncActors")]
|
||||
pub enum SyncActors {
|
||||
Ingester,
|
||||
Sender,
|
||||
Receiver,
|
||||
}
|
||||
|
||||
impl fmt::Display for SyncActors {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Ingester => write!(f, "Cloud Sync Ingester"),
|
||||
Self::Sender => write!(f, "Cloud Sync Sender"),
|
||||
Self::Receiver => write!(f, "Cloud Sync Receiver"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ReceiveAndIngestNotifiers {
|
||||
ingester: Notify,
|
||||
receiver: Notify,
|
||||
}
|
||||
|
||||
impl ReceiveAndIngestNotifiers {
|
||||
pub fn notify_receiver(&self) {
|
||||
self.receiver.notify_one();
|
||||
}
|
||||
|
||||
async fn wait_notification_to_receive(&self) {
|
||||
self.receiver.notified().await;
|
||||
}
|
||||
|
||||
fn notify_ingester(&self) {
|
||||
self.ingester.notify_one();
|
||||
}
|
||||
|
||||
async fn wait_notification_to_ingest(&self) {
|
||||
self.ingester.notified().await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn declare_actors(
|
||||
data_dir: Box<Path>,
|
||||
cloud_services: Arc<CloudServices>,
|
||||
actors: &ActorsCollection<SyncActors>,
|
||||
actors_state: &SyncActorsState,
|
||||
sync_group_pub_id: groups::PubId,
|
||||
sync: SyncManager,
|
||||
rng: CryptoRng,
|
||||
) -> Result<Arc<ReceiveAndIngestNotifiers>, Error> {
|
||||
let (sender, receiver) = (
|
||||
Sender::new(
|
||||
sync_group_pub_id,
|
||||
sync.clone(),
|
||||
Arc::clone(&cloud_services),
|
||||
Arc::clone(&actors_state.send_active),
|
||||
Arc::clone(&actors_state.state_change_notifier),
|
||||
rng,
|
||||
),
|
||||
Receiver::new(
|
||||
data_dir,
|
||||
sync_group_pub_id,
|
||||
cloud_services.clone(),
|
||||
sync.clone(),
|
||||
Arc::clone(&actors_state.receiver_and_ingester_notifiers),
|
||||
Arc::clone(&actors_state.receive_active),
|
||||
Arc::clone(&actors_state.state_change_notifier),
|
||||
),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
let ingester = Ingester::new(
|
||||
sync,
|
||||
Arc::clone(&actors_state.receiver_and_ingester_notifiers),
|
||||
Arc::clone(&actors_state.ingest_active),
|
||||
Arc::clone(&actors_state.state_change_notifier),
|
||||
);
|
||||
|
||||
actors
|
||||
.declare_many_boxed([
|
||||
sender.into_actor(),
|
||||
receiver.into_actor(),
|
||||
ingester.into_actor(),
|
||||
])
|
||||
.await;
|
||||
|
||||
cloud_services
|
||||
.cloud_p2p()
|
||||
.await?
|
||||
.register_sync_messages_receiver_notifier(
|
||||
sync_group_pub_id,
|
||||
Arc::clone(&actors_state.receiver_and_ingester_notifiers),
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Arc::clone(&actors_state.receiver_and_ingester_notifiers))
|
||||
}
|
||||
356
core/crates/cloud-services/src/sync/receive.rs
Normal file
356
core/crates/cloud-services/src/sync/receive.rs
Normal file
@@ -0,0 +1,356 @@
|
||||
use crate::{CloudServices, Error, KeyManager};
|
||||
|
||||
use sd_cloud_schema::{
|
||||
devices,
|
||||
sync::{
|
||||
groups,
|
||||
messages::{pull, MessagesCollection},
|
||||
},
|
||||
Client, Service,
|
||||
};
|
||||
use sd_core_sync::{
|
||||
cloud_crdt_op_db, CRDTOperation, CompressedCRDTOperationsPerModel, SyncManager,
|
||||
};
|
||||
|
||||
use sd_actors::{Actor, Stopper};
|
||||
use sd_crypto::{
|
||||
cloud::{OneShotDecryption, SecretKey, StreamDecryption},
|
||||
primitives::{EncryptedBlock, StreamNonce},
|
||||
};
|
||||
use sd_prisma::prisma::PrismaClient;
|
||||
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
future::IntoFuture,
|
||||
path::Path,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::{FutureExt, StreamExt};
|
||||
use futures_concurrency::future::{Race, TryJoin};
|
||||
use quic_rpc::transport::quinn::QuinnConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::{fs, io, sync::Notify, time::sleep};
|
||||
use tracing::{debug, error, instrument, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{ReceiveAndIngestNotifiers, SyncActors, ONE_MINUTE};
|
||||
|
||||
const CLOUD_SYNC_DATA_KEEPER_DIRECTORY: &str = "cloud_sync_data_keeper";
|
||||
|
||||
/// Responsible for downloading sync operations from the cloud to be processed by the ingester
|
||||
|
||||
pub struct Receiver {
|
||||
keeper: LastTimestampKeeper,
|
||||
sync_group_pub_id: groups::PubId,
|
||||
device_pub_id: devices::PubId,
|
||||
cloud_services: Arc<CloudServices>,
|
||||
cloud_client: Client<QuinnConnection<Service>>,
|
||||
key_manager: Arc<KeyManager>,
|
||||
sync: SyncManager,
|
||||
notifiers: Arc<ReceiveAndIngestNotifiers>,
|
||||
active: Arc<AtomicBool>,
|
||||
active_notifier: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl Actor<SyncActors> for Receiver {
|
||||
const IDENTIFIER: SyncActors = SyncActors::Receiver;
|
||||
|
||||
async fn run(&mut self, stop: Stopper) {
|
||||
enum Race {
|
||||
Continue,
|
||||
Stop,
|
||||
}
|
||||
|
||||
loop {
|
||||
self.active.store(true, Ordering::Relaxed);
|
||||
self.active_notifier.notify_waiters();
|
||||
|
||||
let res = self.run_loop_iteration().await;
|
||||
|
||||
self.active.store(false, Ordering::Relaxed);
|
||||
|
||||
if let Err(e) = res {
|
||||
error!(?e, "Error during cloud sync receiver actor iteration");
|
||||
sleep(ONE_MINUTE).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
self.active_notifier.notify_waiters();
|
||||
|
||||
if matches!(
|
||||
(
|
||||
sleep(ONE_MINUTE).map(|()| Race::Continue),
|
||||
self.notifiers
|
||||
.wait_notification_to_receive()
|
||||
.map(|()| Race::Continue),
|
||||
stop.into_future().map(|()| Race::Stop),
|
||||
)
|
||||
.race()
|
||||
.await,
|
||||
Race::Stop
|
||||
) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Receiver {
|
||||
pub async fn new(
|
||||
data_dir: impl AsRef<Path> + Send,
|
||||
sync_group_pub_id: groups::PubId,
|
||||
cloud_services: Arc<CloudServices>,
|
||||
sync: SyncManager,
|
||||
notifiers: Arc<ReceiveAndIngestNotifiers>,
|
||||
active: Arc<AtomicBool>,
|
||||
active_notify: Arc<Notify>,
|
||||
) -> Result<Self, Error> {
|
||||
let (keeper, cloud_client, key_manager) = (
|
||||
LastTimestampKeeper::load(data_dir.as_ref(), sync_group_pub_id),
|
||||
cloud_services.client(),
|
||||
cloud_services.key_manager(),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
keeper,
|
||||
sync_group_pub_id,
|
||||
device_pub_id: devices::PubId(Uuid::from(&sync.device_pub_id)),
|
||||
cloud_services,
|
||||
cloud_client,
|
||||
key_manager,
|
||||
sync,
|
||||
notifiers,
|
||||
active,
|
||||
active_notifier: active_notify,
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_loop_iteration(&mut self) -> Result<(), Error> {
|
||||
let mut responses_stream = self
|
||||
.cloud_client
|
||||
.sync()
|
||||
.messages()
|
||||
.pull(pull::Request {
|
||||
access_token: self
|
||||
.cloud_services
|
||||
.token_refresher
|
||||
.get_access_token()
|
||||
.await?,
|
||||
group_pub_id: self.sync_group_pub_id,
|
||||
current_device_pub_id: self.device_pub_id,
|
||||
start_time_per_device: self
|
||||
.keeper
|
||||
.timestamps
|
||||
.iter()
|
||||
.map(|(device_pub_id, timestamp)| (*device_pub_id, *timestamp))
|
||||
.collect(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
while let Some(new_messages_res) = responses_stream.next().await {
|
||||
let pull::Response(new_messages) = new_messages_res??;
|
||||
if new_messages.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
self.handle_new_messages(new_messages).await?;
|
||||
}
|
||||
|
||||
debug!("Finished sync messages receiver actor iteration");
|
||||
|
||||
self.keeper.save().await
|
||||
}
|
||||
|
||||
async fn handle_new_messages(
|
||||
&mut self,
|
||||
new_messages: Vec<MessagesCollection>,
|
||||
) -> Result<(), Error> {
|
||||
debug!(
|
||||
new_messages_collections_count = new_messages.len(),
|
||||
start_time = ?new_messages.first().map(|c| c.start_time),
|
||||
end_time = ?new_messages.first().map(|c| c.end_time),
|
||||
"Handling new sync messages collections",
|
||||
);
|
||||
|
||||
for message in new_messages.into_iter().filter(|message| {
|
||||
if message.original_device_pub_id == self.device_pub_id {
|
||||
warn!("Received sync message from the current device, need to check backend, this is a bug!");
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}) {
|
||||
debug!(
|
||||
new_messages_count = message.operations_count,
|
||||
start_time = ?message.start_time,
|
||||
end_time = ?message.end_time,
|
||||
"Handling new sync messages",
|
||||
);
|
||||
|
||||
let (device_pub_id, timestamp) = handle_single_message(
|
||||
self.sync_group_pub_id,
|
||||
message,
|
||||
&self.key_manager,
|
||||
&self.sync,
|
||||
)
|
||||
.await?;
|
||||
|
||||
match self.keeper.timestamps.entry(device_pub_id) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
if entry.get() < ×tamp {
|
||||
*entry.get_mut() = timestamp;
|
||||
}
|
||||
}
|
||||
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
// To ingest after each sync message collection is received, we MUST download and
|
||||
// store the messages SEQUENTIALLY, otherwise we might ingest messages out of order
|
||||
// due to parallel downloads
|
||||
self.notifiers.notify_ingester();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(
|
||||
skip_all,
|
||||
fields(%sync_group_pub_id, %original_device_pub_id, operations_count, ?key_hash, %end_time),
|
||||
)]
|
||||
async fn handle_single_message(
|
||||
sync_group_pub_id: groups::PubId,
|
||||
MessagesCollection {
|
||||
original_device_pub_id,
|
||||
end_time,
|
||||
operations_count,
|
||||
key_hash,
|
||||
encrypted_messages,
|
||||
..
|
||||
}: MessagesCollection,
|
||||
key_manager: &KeyManager,
|
||||
sync: &SyncManager,
|
||||
) -> Result<(devices::PubId, DateTime<Utc>), Error> {
|
||||
// FIXME(@fogodev): If we don't have the key hash, we need to fetch it from another device in the group if possible
|
||||
let Some(secret_key) = key_manager.get_key(sync_group_pub_id, &key_hash).await else {
|
||||
return Err(Error::MissingKeyHash);
|
||||
};
|
||||
|
||||
debug!(
|
||||
size = encrypted_messages.len(),
|
||||
"Received encrypted sync messages collection"
|
||||
);
|
||||
|
||||
let crdt_ops = decrypt_messages(encrypted_messages, secret_key, original_device_pub_id).await?;
|
||||
|
||||
assert_eq!(
|
||||
crdt_ops.len(),
|
||||
operations_count as usize,
|
||||
"Sync messages count mismatch"
|
||||
);
|
||||
|
||||
write_cloud_ops_to_db(crdt_ops, &sync.db).await?;
|
||||
|
||||
Ok((original_device_pub_id, end_time))
|
||||
}
|
||||
|
||||
#[instrument(skip(encrypted_messages, secret_key), fields(messages_size = %encrypted_messages.len()), err)]
|
||||
async fn decrypt_messages(
|
||||
encrypted_messages: Vec<u8>,
|
||||
secret_key: SecretKey,
|
||||
devices::PubId(device_pub_id): devices::PubId,
|
||||
) -> Result<Vec<CRDTOperation>, Error> {
|
||||
let plain_text = if encrypted_messages.len() <= EncryptedBlock::CIPHER_TEXT_SIZE {
|
||||
OneShotDecryption::decrypt(&secret_key, encrypted_messages.as_slice().into())
|
||||
.map_err(Error::Decrypt)?
|
||||
} else {
|
||||
let (nonce, cipher_text) = encrypted_messages.split_at(size_of::<StreamNonce>());
|
||||
|
||||
let mut plain_text = Vec::with_capacity(cipher_text.len());
|
||||
|
||||
StreamDecryption::decrypt(
|
||||
&secret_key,
|
||||
nonce.try_into().expect("we split the correct amount"),
|
||||
cipher_text,
|
||||
&mut plain_text,
|
||||
)
|
||||
.await
|
||||
.map_err(Error::Decrypt)?;
|
||||
|
||||
plain_text
|
||||
};
|
||||
|
||||
rmp_serde::from_slice::<CompressedCRDTOperationsPerModel>(&plain_text)
|
||||
.map(|compressed_ops| compressed_ops.into_ops(device_pub_id))
|
||||
.map_err(Error::DeserializationFailureToPullSyncMessages)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, err)]
|
||||
pub async fn write_cloud_ops_to_db(
|
||||
ops: Vec<CRDTOperation>,
|
||||
db: &PrismaClient,
|
||||
) -> Result<(), sd_core_sync::Error> {
|
||||
db._batch(
|
||||
ops.into_iter()
|
||||
.map(|op| cloud_crdt_op_db(&op).map(|op| op.to_query(db)))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct LastTimestampKeeper {
|
||||
timestamps: HashMap<devices::PubId, DateTime<Utc>>,
|
||||
file_path: Box<Path>,
|
||||
}
|
||||
|
||||
impl LastTimestampKeeper {
|
||||
async fn load(data_dir: &Path, sync_group_pub_id: groups::PubId) -> Result<Self, Error> {
|
||||
let cloud_sync_data_directory = data_dir.join(CLOUD_SYNC_DATA_KEEPER_DIRECTORY);
|
||||
|
||||
fs::create_dir_all(&cloud_sync_data_directory)
|
||||
.await
|
||||
.map_err(Error::FailedToCreateTimestampKeepersDirectory)?;
|
||||
|
||||
let file_path = cloud_sync_data_directory
|
||||
.join(format!("{sync_group_pub_id}.bin"))
|
||||
.into_boxed_path();
|
||||
|
||||
match fs::read(&file_path).await {
|
||||
Ok(bytes) => Ok(Self {
|
||||
timestamps: rmp_serde::from_slice(&bytes)
|
||||
.map_err(Error::LastTimestampKeeperDeserialization)?,
|
||||
file_path,
|
||||
}),
|
||||
|
||||
Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(Self {
|
||||
timestamps: HashMap::new(),
|
||||
file_path,
|
||||
}),
|
||||
|
||||
Err(e) => Err(Error::FailedToReadLastTimestampKeeper(e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn save(&self) -> Result<(), Error> {
|
||||
fs::write(
|
||||
&self.file_path,
|
||||
&rmp_serde::to_vec_named(&self.timestamps)
|
||||
.map_err(Error::LastTimestampKeeperSerialization)?,
|
||||
)
|
||||
.await
|
||||
.map_err(Error::FailedToWriteLastTimestampKeeper)
|
||||
}
|
||||
}
|
||||
337
core/crates/cloud-services/src/sync/send.rs
Normal file
337
core/crates/cloud-services/src/sync/send.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
use crate::{CloudServices, Error, KeyManager};
|
||||
|
||||
use sd_core_sync::{CompressedCRDTOperationsPerModelPerDevice, SyncEvent, SyncManager, NTP64};
|
||||
|
||||
use sd_actors::{Actor, Stopper};
|
||||
use sd_cloud_schema::{
|
||||
devices,
|
||||
error::{ClientSideError, NotFoundError},
|
||||
sync::{groups, messages},
|
||||
Client, Service,
|
||||
};
|
||||
use sd_crypto::{
|
||||
cloud::{OneShotEncryption, SecretKey, StreamEncryption},
|
||||
primitives::EncryptedBlock,
|
||||
CryptoRng, SeedableRng,
|
||||
};
|
||||
use sd_utils::{datetime_to_timestamp, timestamp_to_datetime};
|
||||
|
||||
use std::{
|
||||
future::IntoFuture,
|
||||
pin::pin,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::{Duration, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::{FutureExt, StreamExt, TryStreamExt};
|
||||
use futures_concurrency::future::{Race, TryJoin};
|
||||
use quic_rpc::transport::quinn::QuinnConnection;
|
||||
use tokio::{
|
||||
sync::{broadcast, Notify},
|
||||
time::sleep,
|
||||
};
|
||||
use tracing::{debug, error};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{SyncActors, ONE_MINUTE};
|
||||
|
||||
const TEN_SECONDS: Duration = Duration::from_secs(10);
|
||||
|
||||
const MESSAGES_COLLECTION_SIZE: u32 = 10_000;
|
||||
|
||||
enum RaceNotifiedOrStopped {
|
||||
Notified,
|
||||
Stopped,
|
||||
}
|
||||
|
||||
enum LoopStatus {
|
||||
SentMessages,
|
||||
Idle,
|
||||
}
|
||||
|
||||
type LatestTimestamp = NTP64;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Sender {
|
||||
sync_group_pub_id: groups::PubId,
|
||||
sync: SyncManager,
|
||||
cloud_services: Arc<CloudServices>,
|
||||
cloud_client: Client<QuinnConnection<Service>>,
|
||||
key_manager: Arc<KeyManager>,
|
||||
is_active: Arc<AtomicBool>,
|
||||
state_notify: Arc<Notify>,
|
||||
rng: CryptoRng,
|
||||
maybe_latest_timestamp: Option<LatestTimestamp>,
|
||||
}
|
||||
|
||||
impl Actor<SyncActors> for Sender {
|
||||
const IDENTIFIER: SyncActors = SyncActors::Sender;
|
||||
|
||||
async fn run(&mut self, stop: Stopper) {
|
||||
loop {
|
||||
self.is_active.store(true, Ordering::Relaxed);
|
||||
self.state_notify.notify_waiters();
|
||||
|
||||
let res = self.run_loop_iteration().await;
|
||||
|
||||
self.is_active.store(false, Ordering::Relaxed);
|
||||
|
||||
match res {
|
||||
Ok(LoopStatus::SentMessages) => {
|
||||
if let Ok(cloud_p2p) = self.cloud_services.cloud_p2p().await.map_err(|e| {
|
||||
error!(?e, "Failed to get cloud p2p client on sender actor");
|
||||
}) {
|
||||
cloud_p2p
|
||||
.notify_new_sync_messages(self.sync_group_pub_id)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(LoopStatus::Idle) => {}
|
||||
|
||||
Err(e) => {
|
||||
error!(?e, "Error during cloud sync sender actor iteration");
|
||||
sleep(ONE_MINUTE).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
self.state_notify.notify_waiters();
|
||||
|
||||
if matches!(
|
||||
(
|
||||
// recreate subscription each time so that existing messages are dropped
|
||||
wait_notification(self.sync.subscribe()),
|
||||
stop.into_future().map(|()| RaceNotifiedOrStopped::Stopped),
|
||||
)
|
||||
.race()
|
||||
.await,
|
||||
RaceNotifiedOrStopped::Stopped
|
||||
) {
|
||||
break;
|
||||
}
|
||||
|
||||
sleep(TEN_SECONDS).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Sender {
|
||||
pub async fn new(
|
||||
sync_group_pub_id: groups::PubId,
|
||||
sync: SyncManager,
|
||||
cloud_services: Arc<CloudServices>,
|
||||
is_active: Arc<AtomicBool>,
|
||||
state_notify: Arc<Notify>,
|
||||
rng: CryptoRng,
|
||||
) -> Result<Self, Error> {
|
||||
let (cloud_client, key_manager) = (cloud_services.client(), cloud_services.key_manager())
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
sync_group_pub_id,
|
||||
sync,
|
||||
cloud_services,
|
||||
cloud_client,
|
||||
key_manager,
|
||||
is_active,
|
||||
state_notify,
|
||||
rng,
|
||||
maybe_latest_timestamp: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_loop_iteration(&mut self) -> Result<LoopStatus, Error> {
|
||||
debug!("Starting cloud sender actor loop iteration");
|
||||
|
||||
let current_device_pub_id = devices::PubId(Uuid::from(&self.sync.device_pub_id));
|
||||
|
||||
let (key_hash, secret_key) = self
|
||||
.key_manager
|
||||
.get_latest_key(self.sync_group_pub_id)
|
||||
.await
|
||||
.ok_or(Error::MissingSyncGroupKey(self.sync_group_pub_id))?;
|
||||
|
||||
let current_latest_timestamp = self.get_latest_timestamp(current_device_pub_id).await?;
|
||||
|
||||
let mut crdt_ops_stream = pin!(self.sync.stream_device_ops(
|
||||
&self.sync.device_pub_id,
|
||||
MESSAGES_COLLECTION_SIZE,
|
||||
current_latest_timestamp
|
||||
));
|
||||
|
||||
let mut status = LoopStatus::Idle;
|
||||
|
||||
let mut new_latest_timestamp = current_latest_timestamp;
|
||||
|
||||
debug!(
|
||||
chunk_size = MESSAGES_COLLECTION_SIZE,
|
||||
"Trying to fetch chunk of sync messages from the database"
|
||||
);
|
||||
while let Some(ops_res) = crdt_ops_stream.next().await {
|
||||
let ops = ops_res?;
|
||||
|
||||
let (Some(first), Some(last)) = (ops.first(), ops.last()) else {
|
||||
break;
|
||||
};
|
||||
|
||||
debug!("Got first and last sync messages");
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let operations_count = ops.len() as u32;
|
||||
|
||||
debug!(operations_count, "Got chunk of sync messages");
|
||||
|
||||
new_latest_timestamp = last.timestamp;
|
||||
|
||||
let start_time = timestamp_to_datetime(first.timestamp);
|
||||
let end_time = timestamp_to_datetime(last.timestamp);
|
||||
|
||||
// Ignoring this device_pub_id here as we already know it
|
||||
let (_device_pub_id, compressed_ops) =
|
||||
CompressedCRDTOperationsPerModelPerDevice::new_single_device(ops);
|
||||
|
||||
let messages_bytes = rmp_serde::to_vec_named(&compressed_ops)
|
||||
.map_err(Error::SerializationFailureToPushSyncMessages)?;
|
||||
|
||||
let encrypted_messages =
|
||||
encrypt_messages(&secret_key, &mut self.rng, messages_bytes).await?;
|
||||
|
||||
let encrypted_messages_size = encrypted_messages.len();
|
||||
|
||||
debug!(
|
||||
operations_count,
|
||||
encrypted_messages_size, "Sending sync messages to cloud",
|
||||
);
|
||||
|
||||
self.cloud_client
|
||||
.sync()
|
||||
.messages()
|
||||
.push(messages::push::Request {
|
||||
access_token: self
|
||||
.cloud_services
|
||||
.token_refresher
|
||||
.get_access_token()
|
||||
.await?,
|
||||
group_pub_id: self.sync_group_pub_id,
|
||||
device_pub_id: current_device_pub_id,
|
||||
key_hash: key_hash.clone(),
|
||||
operations_count,
|
||||
time_range: (start_time, end_time),
|
||||
encrypted_messages,
|
||||
})
|
||||
.await??;
|
||||
|
||||
debug!(
|
||||
operations_count,
|
||||
encrypted_messages_size, "Sent sync messages to cloud",
|
||||
);
|
||||
|
||||
status = LoopStatus::SentMessages;
|
||||
}
|
||||
|
||||
self.maybe_latest_timestamp = Some(new_latest_timestamp);
|
||||
|
||||
debug!("Finished cloud sender actor loop iteration");
|
||||
|
||||
Ok(status)
|
||||
}
|
||||
|
||||
async fn get_latest_timestamp(
|
||||
&self,
|
||||
current_device_pub_id: devices::PubId,
|
||||
) -> Result<LatestTimestamp, Error> {
|
||||
if let Some(latest_timestamp) = &self.maybe_latest_timestamp {
|
||||
Ok(*latest_timestamp)
|
||||
} else {
|
||||
let latest_time = match self
|
||||
.cloud_client
|
||||
.sync()
|
||||
.messages()
|
||||
.get_latest_time(messages::get_latest_time::Request {
|
||||
access_token: self
|
||||
.cloud_services
|
||||
.token_refresher
|
||||
.get_access_token()
|
||||
.await?,
|
||||
group_pub_id: self.sync_group_pub_id,
|
||||
kind: messages::get_latest_time::Kind::ForCurrentDevice(current_device_pub_id),
|
||||
})
|
||||
.await?
|
||||
{
|
||||
Ok(messages::get_latest_time::Response {
|
||||
latest_time,
|
||||
latest_device_pub_id,
|
||||
}) => {
|
||||
assert_eq!(latest_device_pub_id, current_device_pub_id);
|
||||
latest_time
|
||||
}
|
||||
|
||||
Err(sd_cloud_schema::Error::Client(ClientSideError::NotFound(
|
||||
NotFoundError::LatestSyncMessageTime,
|
||||
))) => DateTime::<Utc>::from(UNIX_EPOCH),
|
||||
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
Ok(datetime_to_timestamp(latest_time))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn encrypt_messages(
|
||||
secret_key: &SecretKey,
|
||||
rng: &mut CryptoRng,
|
||||
messages_bytes: Vec<u8>,
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
if messages_bytes.len() <= EncryptedBlock::PLAIN_TEXT_SIZE {
|
||||
let mut nonce_and_cipher_text = Vec::with_capacity(OneShotEncryption::cipher_text_size(
|
||||
secret_key,
|
||||
messages_bytes.len(),
|
||||
));
|
||||
|
||||
let EncryptedBlock { nonce, cipher_text } =
|
||||
OneShotEncryption::encrypt(secret_key, messages_bytes.as_slice(), rng)
|
||||
.map_err(Error::Encrypt)?;
|
||||
|
||||
nonce_and_cipher_text.extend_from_slice(nonce.as_slice());
|
||||
nonce_and_cipher_text.extend(&cipher_text);
|
||||
|
||||
Ok(nonce_and_cipher_text)
|
||||
} else {
|
||||
let mut rng = CryptoRng::from_seed(rng.generate_fixed());
|
||||
let mut nonce_and_cipher_text = Vec::with_capacity(StreamEncryption::cipher_text_size(
|
||||
secret_key,
|
||||
messages_bytes.len(),
|
||||
));
|
||||
|
||||
let (nonce, cipher_stream) =
|
||||
StreamEncryption::encrypt(secret_key, messages_bytes.as_slice(), &mut rng);
|
||||
|
||||
nonce_and_cipher_text.extend_from_slice(nonce.as_slice());
|
||||
|
||||
let mut cipher_stream = pin!(cipher_stream);
|
||||
|
||||
while let Some(ciphered_chunk) = cipher_stream.try_next().await.map_err(Error::Encrypt)? {
|
||||
nonce_and_cipher_text.extend(ciphered_chunk);
|
||||
}
|
||||
|
||||
Ok(nonce_and_cipher_text)
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_notification(mut rx: broadcast::Receiver<SyncEvent>) -> RaceNotifiedOrStopped {
|
||||
// wait until Created message comes in
|
||||
loop {
|
||||
if matches!(rx.recv().await, Ok(SyncEvent::Created)) {
|
||||
break;
|
||||
};
|
||||
}
|
||||
|
||||
RaceNotifiedOrStopped::Notified
|
||||
}
|
||||
468
core/crates/cloud-services/src/token_refresher.rs
Normal file
468
core/crates/cloud-services/src/token_refresher.rs
Normal file
@@ -0,0 +1,468 @@
|
||||
use sd_cloud_schema::auth::{AccessToken, RefreshToken};
|
||||
|
||||
use std::{pin::pin, time::Duration};
|
||||
|
||||
use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD};
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::StreamExt;
|
||||
use futures_concurrency::stream::Merge;
|
||||
use reqwest::Url;
|
||||
use reqwest_middleware::{reqwest::header, ClientWithMiddleware};
|
||||
use tokio::{
|
||||
spawn,
|
||||
sync::oneshot,
|
||||
time::{interval, sleep, MissedTickBehavior},
|
||||
};
|
||||
use tokio_stream::wrappers::IntervalStream;
|
||||
use tracing::{error, warn};
|
||||
|
||||
use super::{Error, GetTokenError};
|
||||
|
||||
const ONE_MINUTE: Duration = Duration::from_secs(60);
|
||||
const TEN_SECONDS: Duration = Duration::from_secs(10);
|
||||
|
||||
enum Message {
|
||||
Init(
|
||||
(
|
||||
AccessToken,
|
||||
RefreshToken,
|
||||
oneshot::Sender<Result<(), Error>>,
|
||||
),
|
||||
),
|
||||
CheckInitialization(oneshot::Sender<Result<(), GetTokenError>>),
|
||||
RequestToken(oneshot::Sender<Result<AccessToken, GetTokenError>>),
|
||||
RefreshTime,
|
||||
Tick,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TokenRefresher {
|
||||
tx: flume::Sender<Message>,
|
||||
}
|
||||
|
||||
impl TokenRefresher {
|
||||
pub(crate) fn new(http_client: ClientWithMiddleware, auth_server_url: Url) -> Self {
|
||||
let (tx, rx) = flume::bounded(8);
|
||||
|
||||
spawn(async move {
|
||||
let refresh_url = auth_server_url
|
||||
.join("/api/auth/session/refresh")
|
||||
.expect("hardcoded refresh url path");
|
||||
|
||||
while let Err(e) = spawn(Runner::run(
|
||||
http_client.clone(),
|
||||
refresh_url.clone(),
|
||||
rx.clone(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
if e.is_panic() {
|
||||
if let Some(msg) = e.into_panic().downcast_ref::<&str>() {
|
||||
error!(?msg, "Panic in request handler!");
|
||||
} else {
|
||||
error!("Some unknown panic in request handler!");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self { tx }
|
||||
}
|
||||
|
||||
pub async fn init(
|
||||
&self,
|
||||
access_token: AccessToken,
|
||||
refresh_token: RefreshToken,
|
||||
) -> Result<(), Error> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send_async(Message::Init((access_token, refresh_token, tx)))
|
||||
.await
|
||||
.expect("Token refresher channel closed");
|
||||
|
||||
rx.await.expect("Token refresher channel closed")
|
||||
}
|
||||
|
||||
pub async fn check_initialization(&self) -> Result<(), GetTokenError> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send_async(Message::CheckInitialization(tx))
|
||||
.await
|
||||
.expect("Token refresher channel closed");
|
||||
|
||||
rx.await.expect("Token refresher channel closed")
|
||||
}
|
||||
|
||||
pub async fn get_access_token(&self) -> Result<AccessToken, GetTokenError> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send_async(Message::RequestToken(tx))
|
||||
.await
|
||||
.expect("Token refresher channel closed");
|
||||
|
||||
rx.await.expect("Token refresher channel closed")
|
||||
}
|
||||
}
|
||||
|
||||
struct Runner {
|
||||
initialized: bool,
|
||||
http_client: ClientWithMiddleware,
|
||||
refresh_url: Url,
|
||||
current_token: Option<AccessToken>,
|
||||
current_refresh_token: Option<RefreshToken>,
|
||||
token_decoding_buffer: Vec<u8>,
|
||||
refresh_tx: flume::Sender<Message>,
|
||||
}
|
||||
|
||||
impl Runner {
|
||||
async fn run(
|
||||
http_client: ClientWithMiddleware,
|
||||
refresh_url: Url,
|
||||
msgs_rx: flume::Receiver<Message>,
|
||||
) {
|
||||
let (refresh_tx, refresh_rx) = flume::bounded(1);
|
||||
|
||||
let mut ticker = interval(TEN_SECONDS);
|
||||
ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
|
||||
|
||||
let mut msg_stream = pin!((
|
||||
msgs_rx.into_stream(),
|
||||
refresh_rx.into_stream(),
|
||||
IntervalStream::new(ticker).map(|_| Message::Tick)
|
||||
)
|
||||
.merge());
|
||||
|
||||
let mut runner = Self {
|
||||
initialized: false,
|
||||
http_client,
|
||||
refresh_url,
|
||||
current_token: None,
|
||||
current_refresh_token: None,
|
||||
token_decoding_buffer: Vec::new(),
|
||||
refresh_tx,
|
||||
};
|
||||
|
||||
while let Some(msg) = msg_stream.next().await {
|
||||
match msg {
|
||||
Message::Init((access_token, refresh_token, ack)) => {
|
||||
if ack
|
||||
.send(runner.init(access_token, refresh_token).await)
|
||||
.is_err()
|
||||
{
|
||||
error!("Failed to send init token refresher response, receiver dropped;");
|
||||
}
|
||||
}
|
||||
|
||||
Message::CheckInitialization(ack) => runner.check_initialization(ack),
|
||||
|
||||
Message::RequestToken(ack) => runner.reply_token(ack),
|
||||
|
||||
Message::RefreshTime => {
|
||||
if let Err(e) = runner.refresh().await {
|
||||
error!(?e, "Failed to refresh token: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
Message::Tick => runner.tick().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn init(
|
||||
&mut self,
|
||||
access_token: AccessToken,
|
||||
refresh_token: RefreshToken,
|
||||
) -> Result<(), Error> {
|
||||
let access_token_duration =
|
||||
Self::extract_access_token_duration(&mut self.token_decoding_buffer, &access_token)?;
|
||||
|
||||
self.initialized = true;
|
||||
self.current_token = Some(access_token);
|
||||
self.current_refresh_token = Some(refresh_token);
|
||||
|
||||
// If the token has an expiration smaller than a minute, we need to refresh it immediately.
|
||||
if access_token_duration < ONE_MINUTE {
|
||||
self.refresh_tx
|
||||
.send_async(Message::RefreshTime)
|
||||
.await
|
||||
.expect("refresh channel never closes");
|
||||
} else {
|
||||
// This task will be mostly parked waiting a sleep
|
||||
spawn(Self::schedule_refresh(
|
||||
self.refresh_tx.clone(),
|
||||
access_token_duration - ONE_MINUTE,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn reply_token(&self, ack: oneshot::Sender<Result<AccessToken, GetTokenError>>) {
|
||||
if ack
|
||||
.send(self.current_token.clone().ok_or({
|
||||
if self.initialized {
|
||||
GetTokenError::FailedToRefresh
|
||||
} else {
|
||||
GetTokenError::RefresherNotInitialized
|
||||
}
|
||||
}))
|
||||
.is_err()
|
||||
{
|
||||
warn!("Failed to send access token response, receiver dropped;");
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh(&mut self) -> Result<(), Error> {
|
||||
let RefreshToken(refresh_token) = self
|
||||
.current_refresh_token
|
||||
.clone()
|
||||
.expect("refresh token is set otherwise we wouldn't be here");
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.post(self.refresh_url.clone())
|
||||
.header("rid", "session")
|
||||
.header(header::AUTHORIZATION, format!("Bearer {refresh_token}"))
|
||||
.send()
|
||||
.await
|
||||
.map_err(Error::RefreshTokenRequest)?
|
||||
.error_for_status()
|
||||
.map_err(Error::AuthServerError)?;
|
||||
|
||||
if let (Some(access_token), Some(refresh_token)) = (
|
||||
response.headers().get("st-access-token"),
|
||||
response.headers().get("st-refresh-token"),
|
||||
) {
|
||||
// Only set values if we can parse both of them to strings
|
||||
let (access_token, refresh_token) = (
|
||||
Self::token_header_value_to_string(access_token)?,
|
||||
Self::token_header_value_to_string(refresh_token)?,
|
||||
);
|
||||
|
||||
self.current_token = Some(AccessToken(access_token));
|
||||
self.current_refresh_token = Some(RefreshToken(refresh_token));
|
||||
} else {
|
||||
return Err(Error::MissingTokensOnRefreshResponse);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn extract_access_token_duration(
|
||||
token_decoding_buffer: &mut Vec<u8>,
|
||||
AccessToken(token): &AccessToken,
|
||||
) -> Result<Duration, Error> {
|
||||
#[derive(serde::Deserialize)]
|
||||
struct Token {
|
||||
#[serde(with = "chrono::serde::ts_seconds")]
|
||||
exp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
token_decoding_buffer.clear();
|
||||
|
||||
// The format of a JWT token is simple:
|
||||
// "<base64-encoded header>.<base64-encoded claims>.<signature>"
|
||||
BASE64_URL_SAFE_NO_PAD.decode_vec(
|
||||
token.split('.').nth(1).ok_or(Error::MissingClaims)?,
|
||||
token_decoding_buffer,
|
||||
)?;
|
||||
|
||||
serde_json::from_slice::<Token>(token_decoding_buffer)?
|
||||
.exp
|
||||
.signed_duration_since(Utc::now())
|
||||
.to_std()
|
||||
.map_err(|_| Error::TokenExpired)
|
||||
}
|
||||
|
||||
async fn schedule_refresh(refresh_tx: flume::Sender<Message>, wait_time: Duration) {
|
||||
sleep(wait_time).await;
|
||||
refresh_tx
|
||||
.send_async(Message::RefreshTime)
|
||||
.await
|
||||
.expect("Refresh channel closed");
|
||||
}
|
||||
|
||||
fn token_header_value_to_string(token: &header::HeaderValue) -> Result<String, Error> {
|
||||
token.to_str().map(str::to_string).map_err(Into::into)
|
||||
}
|
||||
|
||||
fn check_initialization(&self, ack: oneshot::Sender<Result<(), GetTokenError>>) {
|
||||
if ack
|
||||
.send(if self.initialized {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(GetTokenError::RefresherNotInitialized)
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
warn!("Failed to send access token response, receiver dropped;");
|
||||
}
|
||||
}
|
||||
|
||||
/// This method is a safeguard to make sure we try to keep refreshing tokens even if they
|
||||
/// already expired, as the refresh token has a bigger expiration than the access token.
|
||||
async fn tick(&mut self) {
|
||||
if let Some(access_token) = &self.current_token {
|
||||
if matches!(
|
||||
Self::extract_access_token_duration(&mut self.token_decoding_buffer, access_token),
|
||||
Err(Error::TokenExpired)
|
||||
) {
|
||||
if let Err(e) = self.refresh().await {
|
||||
error!(?e, "Failed to refresh expired token on tick method;");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This test is here for documentation purposes only, they are not meant to be run.
|
||||
/// They're just examples of how to sign-up/sign-in and refresh tokens
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use reqwest::header;
|
||||
use reqwest_middleware::ClientBuilder;
|
||||
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::AUTH_SERVER_URL;
|
||||
|
||||
use super::*;
|
||||
|
||||
async fn get_tokens() -> (AccessToken, RefreshToken) {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let req_body = json!({
|
||||
"formFields": [
|
||||
{
|
||||
"id": "email",
|
||||
"value": "johndoe@gmail.com"
|
||||
},
|
||||
{
|
||||
"id": "password",
|
||||
"value": "testPass123"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(format!("{AUTH_SERVER_URL}/api/auth/public/signup"))
|
||||
.header("rid", "emailpassword")
|
||||
.header("st-auth-mode", "header")
|
||||
.json(&req_body)
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
if let (Some(access_token), Some(refresh_token)) = (
|
||||
response.headers().get("st-access-token"),
|
||||
response.headers().get("st-refresh-token"),
|
||||
) {
|
||||
(
|
||||
AccessToken(access_token.to_str().unwrap().to_string()),
|
||||
RefreshToken(refresh_token.to_str().unwrap().to_string()),
|
||||
)
|
||||
} else {
|
||||
let response = client
|
||||
.post(format!("{AUTH_SERVER_URL}/api/auth/public/signin"))
|
||||
.header("rid", "emailpassword")
|
||||
.header("st-auth-mode", "header")
|
||||
.json(&req_body)
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
(
|
||||
AccessToken(
|
||||
response
|
||||
.headers()
|
||||
.get("st-access-token")
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
),
|
||||
RefreshToken(
|
||||
response
|
||||
.headers()
|
||||
.get("st-refresh-token")
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[ignore = "Documentation only"]
|
||||
#[tokio::test]
|
||||
async fn test_refresh_token() {
|
||||
let (AccessToken(access_token), RefreshToken(refresh_token)) = get_tokens().await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.post(format!("{AUTH_SERVER_URL}/api/auth/session/refresh"))
|
||||
.header("rid", "session")
|
||||
.header(header::AUTHORIZATION, format!("Bearer {refresh_token}"))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
assert_ne!(
|
||||
response
|
||||
.headers()
|
||||
.get("st-access-token")
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap(),
|
||||
access_token.as_str()
|
||||
);
|
||||
|
||||
assert_ne!(
|
||||
response
|
||||
.headers()
|
||||
.get("st-refresh-token")
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap(),
|
||||
refresh_token.as_str()
|
||||
);
|
||||
}
|
||||
|
||||
#[ignore = "Needs an actual SuperTokens auth server running"]
|
||||
#[tokio::test]
|
||||
async fn test_refresher_runner() {
|
||||
let http_client_builder = reqwest::Client::builder().timeout(Duration::from_secs(3));
|
||||
|
||||
let http_client = ClientBuilder::new(http_client_builder.build().unwrap())
|
||||
.with(RetryTransientMiddleware::new_with_policy(
|
||||
ExponentialBackoff::builder().build_with_max_retries(3),
|
||||
))
|
||||
.build();
|
||||
|
||||
let (refresh_tx, _refresh_rx) = flume::bounded(1);
|
||||
|
||||
let mut runner = Runner {
|
||||
initialized: false,
|
||||
http_client,
|
||||
refresh_url: Url::parse(&format!("{AUTH_SERVER_URL}/api/auth/session/refresh"))
|
||||
.unwrap(),
|
||||
current_token: None,
|
||||
current_refresh_token: None,
|
||||
token_decoding_buffer: Vec::new(),
|
||||
refresh_tx,
|
||||
};
|
||||
|
||||
let (access_token, refresh_token) = get_tokens().await;
|
||||
|
||||
runner.init(access_token, refresh_token).await.unwrap();
|
||||
|
||||
runner.refresh().await.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ use sd_core_prisma_helpers::{
|
||||
file_path_for_file_identifier, file_path_for_media_processor, file_path_for_object_validator,
|
||||
file_path_to_full_path, file_path_to_handle_custom_uri, file_path_to_handle_p2p_serve_file,
|
||||
file_path_to_isolate, file_path_to_isolate_with_id, file_path_to_isolate_with_pub_id,
|
||||
file_path_walker, file_path_with_object,
|
||||
file_path_walker, file_path_watcher_remove, file_path_with_object,
|
||||
};
|
||||
|
||||
use sd_prisma::prisma::{file_path, location};
|
||||
@@ -506,7 +506,8 @@ impl_from_db!(
|
||||
file_path_to_isolate_with_pub_id,
|
||||
file_path_walker,
|
||||
file_path_to_isolate_with_id,
|
||||
file_path_with_object
|
||||
file_path_with_object,
|
||||
file_path_watcher_remove
|
||||
);
|
||||
|
||||
impl_from_db_without_location_id!(
|
||||
|
||||
@@ -14,7 +14,11 @@ use crate::{
|
||||
use sd_core_file_path_helper::IsolatedFilePathData;
|
||||
use sd_core_prisma_helpers::{file_path_for_file_identifier, CasId};
|
||||
|
||||
use sd_prisma::prisma::{file_path, location, SortOrder};
|
||||
use sd_prisma::{
|
||||
prisma::{device, file_path, location, SortOrder},
|
||||
prisma_sync,
|
||||
};
|
||||
use sd_sync::{sync_db_not_null_entry, OperationFactory};
|
||||
use sd_task_system::{
|
||||
AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId,
|
||||
TaskOutput, TaskStatus,
|
||||
@@ -128,14 +132,14 @@ impl Job for FileIdentifier {
|
||||
match task_kind {
|
||||
TaskKind::Identifier => tasks::Identifier::deserialize(
|
||||
&task_bytes,
|
||||
(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
|
||||
(Arc::clone(ctx.db()), ctx.sync().clone()),
|
||||
)
|
||||
.await
|
||||
.map(IntoTask::into_task),
|
||||
|
||||
TaskKind::ObjectProcessor => tasks::ObjectProcessor::deserialize(
|
||||
&task_bytes,
|
||||
(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
|
||||
(Arc::clone(ctx.db()), ctx.sync().clone()),
|
||||
)
|
||||
.await
|
||||
.map(IntoTask::into_task),
|
||||
@@ -173,8 +177,21 @@ impl Job for FileIdentifier {
|
||||
) -> Result<ReturnStatus, Error> {
|
||||
let mut pending_running_tasks = FuturesUnordered::new();
|
||||
|
||||
let device_pub_id = &ctx.sync().device_pub_id;
|
||||
let device_id = ctx
|
||||
.db()
|
||||
.device()
|
||||
.find_unique(device::pub_id::equals(device_pub_id.to_db()))
|
||||
.exec()
|
||||
.await
|
||||
.map_err(file_identifier::Error::from)?
|
||||
.ok_or(file_identifier::Error::DeviceNotFound(
|
||||
device_pub_id.clone(),
|
||||
))?
|
||||
.id;
|
||||
|
||||
match self
|
||||
.init_or_resume(&mut pending_running_tasks, &ctx, &dispatcher)
|
||||
.init_or_resume(&mut pending_running_tasks, &ctx, device_id, &dispatcher)
|
||||
.await
|
||||
{
|
||||
Ok(()) => { /* Everything is awesome! */ }
|
||||
@@ -201,7 +218,7 @@ impl Job for FileIdentifier {
|
||||
match task {
|
||||
Ok(TaskStatus::Done((task_id, TaskOutput::Out(out)))) => {
|
||||
match self
|
||||
.process_task_output(task_id, out, &ctx, &dispatcher)
|
||||
.process_task_output(task_id, out, &ctx, device_id, &dispatcher)
|
||||
.await
|
||||
{
|
||||
Ok(tasks) => pending_running_tasks.extend(tasks),
|
||||
@@ -254,15 +271,25 @@ impl Job for FileIdentifier {
|
||||
..
|
||||
} = self;
|
||||
|
||||
ctx.db()
|
||||
.location()
|
||||
.update(
|
||||
location::id::equals(location.id),
|
||||
vec![location::scan_state::set(
|
||||
LocationScanState::FilesIdentified as i32,
|
||||
)],
|
||||
let (sync_param, db_param) = sync_db_not_null_entry!(
|
||||
LocationScanState::FilesIdentified as i32,
|
||||
location::scan_state
|
||||
);
|
||||
|
||||
ctx.sync()
|
||||
.write_op(
|
||||
ctx.db(),
|
||||
ctx.sync().shared_update(
|
||||
prisma_sync::location::SyncId {
|
||||
pub_id: location.pub_id.clone(),
|
||||
},
|
||||
[sync_param],
|
||||
),
|
||||
ctx.db()
|
||||
.location()
|
||||
.update(location::id::equals(location.id), vec![db_param])
|
||||
.select(location::select!({ id })),
|
||||
)
|
||||
.exec()
|
||||
.await
|
||||
.map_err(file_identifier::Error::from)?;
|
||||
|
||||
@@ -302,6 +329,7 @@ impl FileIdentifier {
|
||||
&mut self,
|
||||
pending_running_tasks: &mut FuturesUnordered<TaskHandle<Error>>,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &JobTaskDispatcher,
|
||||
) -> Result<(), JobErrorOrDispatcherError<file_identifier::Error>> {
|
||||
// if we don't have any pending task, then this is a fresh job
|
||||
@@ -335,6 +363,7 @@ impl FileIdentifier {
|
||||
.as_ref()
|
||||
.unwrap_or(&location_root_iso_file_path),
|
||||
ctx,
|
||||
device_id,
|
||||
dispatcher,
|
||||
pending_running_tasks,
|
||||
)
|
||||
@@ -345,8 +374,9 @@ impl FileIdentifier {
|
||||
self.last_orphan_file_path_id = None;
|
||||
|
||||
self.dispatch_deep_identifier_tasks(
|
||||
&maybe_sub_iso_file_path,
|
||||
maybe_sub_iso_file_path.as_ref(),
|
||||
ctx,
|
||||
device_id,
|
||||
dispatcher,
|
||||
pending_running_tasks,
|
||||
)
|
||||
@@ -378,6 +408,7 @@ impl FileIdentifier {
|
||||
.as_ref()
|
||||
.unwrap_or(&location_root_iso_file_path),
|
||||
ctx,
|
||||
device_id,
|
||||
dispatcher,
|
||||
pending_running_tasks,
|
||||
)
|
||||
@@ -388,8 +419,9 @@ impl FileIdentifier {
|
||||
self.last_orphan_file_path_id = None;
|
||||
|
||||
self.dispatch_deep_identifier_tasks(
|
||||
&maybe_sub_iso_file_path,
|
||||
maybe_sub_iso_file_path.as_ref(),
|
||||
ctx,
|
||||
device_id,
|
||||
dispatcher,
|
||||
pending_running_tasks,
|
||||
)
|
||||
@@ -401,8 +433,9 @@ impl FileIdentifier {
|
||||
|
||||
Phase::SearchingOrphans => {
|
||||
self.dispatch_deep_identifier_tasks(
|
||||
&maybe_sub_iso_file_path,
|
||||
maybe_sub_iso_file_path.as_ref(),
|
||||
ctx,
|
||||
device_id,
|
||||
dispatcher,
|
||||
pending_running_tasks,
|
||||
)
|
||||
@@ -447,6 +480,7 @@ impl FileIdentifier {
|
||||
task_id: TaskId,
|
||||
any_task_output: Box<dyn AnyTaskOutput>,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &JobTaskDispatcher,
|
||||
) -> Result<Vec<TaskHandle<Error>>, DispatcherError> {
|
||||
if any_task_output.is::<identifier::Output>() {
|
||||
@@ -457,6 +491,7 @@ impl FileIdentifier {
|
||||
.downcast::<identifier::Output>()
|
||||
.expect("just checked"),
|
||||
ctx,
|
||||
device_id,
|
||||
dispatcher,
|
||||
)
|
||||
.await;
|
||||
@@ -501,6 +536,7 @@ impl FileIdentifier {
|
||||
errors,
|
||||
}: identifier::Output,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &JobTaskDispatcher,
|
||||
) -> Result<Vec<TaskHandle<Error>>, DispatcherError> {
|
||||
self.metadata.mean_extract_metadata_time += extract_metadata_time;
|
||||
@@ -548,6 +584,7 @@ impl FileIdentifier {
|
||||
let (tasks_count, res) = match dispatch_object_processor_tasks(
|
||||
self.file_paths_accumulator.drain(),
|
||||
ctx,
|
||||
device_id,
|
||||
dispatcher,
|
||||
false,
|
||||
)
|
||||
@@ -636,6 +673,7 @@ impl FileIdentifier {
|
||||
&mut self,
|
||||
sub_iso_file_path: &IsolatedFilePathData<'static>,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &JobTaskDispatcher,
|
||||
pending_running_tasks: &FuturesUnordered<TaskHandle<Error>>,
|
||||
) -> Result<(), JobErrorOrDispatcherError<file_identifier::Error>> {
|
||||
@@ -702,7 +740,8 @@ impl FileIdentifier {
|
||||
orphan_paths,
|
||||
true,
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
device_id,
|
||||
))
|
||||
.await?,
|
||||
);
|
||||
@@ -713,8 +752,9 @@ impl FileIdentifier {
|
||||
|
||||
async fn dispatch_deep_identifier_tasks<OuterCtx: OuterContext>(
|
||||
&mut self,
|
||||
maybe_sub_iso_file_path: &Option<IsolatedFilePathData<'static>>,
|
||||
maybe_sub_iso_file_path: Option<&IsolatedFilePathData<'static>>,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &JobTaskDispatcher,
|
||||
pending_running_tasks: &FuturesUnordered<TaskHandle<Error>>,
|
||||
) -> Result<(), JobErrorOrDispatcherError<file_identifier::Error>> {
|
||||
@@ -785,7 +825,8 @@ impl FileIdentifier {
|
||||
orphan_paths,
|
||||
false,
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
device_id,
|
||||
))
|
||||
.await?,
|
||||
);
|
||||
|
||||
@@ -2,9 +2,10 @@ use crate::{utils::sub_path, OuterContext};
|
||||
|
||||
use sd_core_file_path_helper::{FilePathError, IsolatedFilePathData};
|
||||
use sd_core_prisma_helpers::CasId;
|
||||
use sd_core_sync::DevicePubId;
|
||||
|
||||
use sd_file_ext::{extensions::Extension, kind::ObjectKind};
|
||||
use sd_prisma::prisma::{file_path, location};
|
||||
use sd_prisma::prisma::{device, file_path, location};
|
||||
use sd_task_system::{TaskDispatcher, TaskHandle};
|
||||
use sd_utils::{db::MissingFieldError, error::FileIOError};
|
||||
|
||||
@@ -41,6 +42,8 @@ const CHUNK_SIZE: usize = 100;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("device not found: <device_pub_id='{0}'")]
|
||||
DeviceNotFound(DevicePubId),
|
||||
#[error("missing field on database: {0}")]
|
||||
MissingField(#[from] MissingFieldError),
|
||||
#[error("failed to deserialized stored tasks for job resume: {0}")]
|
||||
@@ -173,7 +176,7 @@ fn orphan_path_filters_shallow(
|
||||
fn orphan_path_filters_deep(
|
||||
location_id: location::id::Type,
|
||||
file_path_id: Option<file_path::id::Type>,
|
||||
maybe_sub_iso_file_path: &Option<IsolatedFilePathData<'_>>,
|
||||
maybe_sub_iso_file_path: Option<&IsolatedFilePathData<'_>>,
|
||||
) -> Vec<file_path::WhereParam> {
|
||||
sd_utils::chain_optional_iter(
|
||||
[
|
||||
@@ -197,6 +200,7 @@ fn orphan_path_filters_deep(
|
||||
async fn dispatch_object_processor_tasks<Iter, Dispatcher>(
|
||||
file_paths_by_cas_id: Iter,
|
||||
ctx: &impl OuterContext,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &Dispatcher,
|
||||
with_priority: bool,
|
||||
) -> Result<Vec<TaskHandle<crate::Error>>, Dispatcher::DispatchError>
|
||||
@@ -217,7 +221,8 @@ where
|
||||
.dispatch(tasks::ObjectProcessor::new(
|
||||
HashMap::from([(cas_id, objects_to_create_or_link)]),
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
device_id,
|
||||
with_priority,
|
||||
))
|
||||
.await?,
|
||||
@@ -239,7 +244,8 @@ where
|
||||
.dispatch(tasks::ObjectProcessor::new(
|
||||
mem::take(&mut current_batch),
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
device_id,
|
||||
with_priority,
|
||||
))
|
||||
.await?,
|
||||
@@ -256,7 +262,8 @@ where
|
||||
.dispatch(tasks::ObjectProcessor::new(
|
||||
current_batch,
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
device_id,
|
||||
with_priority,
|
||||
))
|
||||
.await?,
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::{
|
||||
use sd_core_file_path_helper::IsolatedFilePathData;
|
||||
use sd_core_prisma_helpers::file_path_for_file_identifier;
|
||||
|
||||
use sd_prisma::prisma::{file_path, location, SortOrder};
|
||||
use sd_prisma::prisma::{device, file_path, location, SortOrder};
|
||||
use sd_task_system::{
|
||||
BaseTaskDispatcher, CancelTaskOnDrop, TaskDispatcher, TaskHandle, TaskOutput, TaskStatus,
|
||||
};
|
||||
@@ -66,6 +66,19 @@ pub async fn shallow(
|
||||
Ok,
|
||||
)?;
|
||||
|
||||
let device_pub_id = &ctx.sync().device_pub_id;
|
||||
let device_id = ctx
|
||||
.db()
|
||||
.device()
|
||||
.find_unique(device::pub_id::equals(device_pub_id.to_db()))
|
||||
.exec()
|
||||
.await
|
||||
.map_err(file_identifier::Error::from)?
|
||||
.ok_or(file_identifier::Error::DeviceNotFound(
|
||||
device_pub_id.clone(),
|
||||
))?
|
||||
.id;
|
||||
|
||||
let mut orphans_count = 0;
|
||||
let mut last_orphan_file_path_id = None;
|
||||
|
||||
@@ -103,7 +116,8 @@ pub async fn shallow(
|
||||
orphan_paths,
|
||||
true,
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
device_id,
|
||||
))
|
||||
.await
|
||||
else {
|
||||
@@ -119,13 +133,14 @@ pub async fn shallow(
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
process_tasks(identifier_tasks, dispatcher, ctx).await
|
||||
process_tasks(identifier_tasks, dispatcher, ctx, device_id).await
|
||||
}
|
||||
|
||||
async fn process_tasks(
|
||||
identifier_tasks: Vec<TaskHandle<Error>>,
|
||||
dispatcher: &BaseTaskDispatcher<Error>,
|
||||
ctx: &impl OuterContext,
|
||||
device_id: device::id::Type,
|
||||
) -> Result<Vec<NonCriticalError>, Error> {
|
||||
let total_identifier_tasks = identifier_tasks.len();
|
||||
|
||||
@@ -169,6 +184,7 @@ async fn process_tasks(
|
||||
let Ok(tasks) = dispatch_object_processor_tasks(
|
||||
file_paths_accumulator.drain(),
|
||||
ctx,
|
||||
device_id,
|
||||
dispatcher,
|
||||
true,
|
||||
)
|
||||
|
||||
@@ -5,18 +5,18 @@ use crate::{
|
||||
|
||||
use sd_core_file_path_helper::IsolatedFilePathData;
|
||||
use sd_core_prisma_helpers::{file_path_for_file_identifier, CasId, FilePathPubId};
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_file_ext::kind::ObjectKind;
|
||||
use sd_prisma::{
|
||||
prisma::{file_path, location, PrismaClient},
|
||||
prisma::{device, file_path, location, PrismaClient},
|
||||
prisma_sync,
|
||||
};
|
||||
use sd_sync::OperationFactory;
|
||||
use sd_sync::{sync_db_entry, OperationFactory};
|
||||
use sd_task_system::{
|
||||
ExecStatus, Interrupter, InterruptionKind, IntoAnyTaskOutput, SerializableTask, Task, TaskId,
|
||||
};
|
||||
use sd_utils::{error::FileIOError, msgpack};
|
||||
use sd_utils::error::FileIOError;
|
||||
|
||||
use std::{
|
||||
collections::HashMap, convert::identity, future::IntoFuture, mem, path::PathBuf, pin::pin,
|
||||
@@ -64,6 +64,7 @@ pub struct Identifier {
|
||||
file_paths_by_id: HashMap<FilePathPubId, file_path_for_file_identifier::Data>,
|
||||
|
||||
// Inner state
|
||||
device_id: device::id::Type,
|
||||
identified_files: HashMap<FilePathPubId, IdentifiedFile>,
|
||||
file_paths_without_cas_id: Vec<FilePathToCreateOrLinkObject>,
|
||||
|
||||
@@ -72,7 +73,7 @@ pub struct Identifier {
|
||||
|
||||
// Dependencies
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
}
|
||||
|
||||
/// Output from the `[Identifier]` task
|
||||
@@ -135,6 +136,7 @@ impl Task<Error> for Identifier {
|
||||
let Self {
|
||||
location,
|
||||
location_path,
|
||||
device_id,
|
||||
file_paths_by_id,
|
||||
file_paths_without_cas_id,
|
||||
identified_files,
|
||||
@@ -255,6 +257,7 @@ impl Task<Error> for Identifier {
|
||||
file_paths_without_cas_id.drain(..),
|
||||
&self.db,
|
||||
&self.sync,
|
||||
*device_id,
|
||||
),
|
||||
)
|
||||
.try_join()
|
||||
@@ -301,6 +304,7 @@ impl Task<Error> for Identifier {
|
||||
file_paths_without_cas_id.drain(..),
|
||||
&self.db,
|
||||
&self.sync,
|
||||
*device_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -324,7 +328,8 @@ impl Identifier {
|
||||
file_paths: Vec<file_path_for_file_identifier::Data>,
|
||||
with_priority: bool,
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Self {
|
||||
let mut output = Output::default();
|
||||
|
||||
@@ -377,6 +382,7 @@ impl Identifier {
|
||||
id: TaskId::new_v4(),
|
||||
location,
|
||||
location_path,
|
||||
device_id,
|
||||
identified_files: HashMap::with_capacity(file_paths_count - directories_count),
|
||||
file_paths_without_cas_id,
|
||||
file_paths_by_id,
|
||||
@@ -394,33 +400,31 @@ async fn assign_cas_id_to_file_paths(
|
||||
db: &PrismaClient,
|
||||
sync: &SyncManager,
|
||||
) -> Result<(), file_identifier::Error> {
|
||||
// Assign cas_id to each file path
|
||||
sync.write_ops(
|
||||
db,
|
||||
identified_files
|
||||
.iter()
|
||||
.map(|(pub_id, IdentifiedFile { cas_id, .. })| {
|
||||
(
|
||||
sync.shared_update(
|
||||
prisma_sync::file_path::SyncId {
|
||||
pub_id: pub_id.to_db(),
|
||||
},
|
||||
file_path::cas_id::NAME,
|
||||
msgpack!(cas_id),
|
||||
),
|
||||
db.file_path()
|
||||
.update(
|
||||
file_path::pub_id::equals(pub_id.to_db()),
|
||||
vec![file_path::cas_id::set(cas_id.into())],
|
||||
)
|
||||
// We don't need any data here, just the id avoids receiving the entire object
|
||||
// as we can't pass an empty select macro call
|
||||
.select(file_path::select!({ id })),
|
||||
)
|
||||
})
|
||||
.unzip::<_, _, _, Vec<_>>(),
|
||||
)
|
||||
.await?;
|
||||
let (ops, queries) = identified_files
|
||||
.iter()
|
||||
.map(|(pub_id, IdentifiedFile { cas_id, .. })| {
|
||||
let (sync_param, db_param) = sync_db_entry!(cas_id, file_path::cas_id);
|
||||
|
||||
(
|
||||
sync.shared_update(
|
||||
prisma_sync::file_path::SyncId {
|
||||
pub_id: pub_id.to_db(),
|
||||
},
|
||||
[sync_param],
|
||||
),
|
||||
db.file_path()
|
||||
.update(file_path::pub_id::equals(pub_id.to_db()), vec![db_param])
|
||||
// We don't need any data here, just the id avoids receiving the entire object
|
||||
// as we can't pass an empty select macro call
|
||||
.select(file_path::select!({ id })),
|
||||
)
|
||||
})
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
if !ops.is_empty() && !queries.is_empty() {
|
||||
// Assign cas_id to each file path
|
||||
sync.write_ops(db, (ops, queries)).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -500,6 +504,7 @@ struct SaveState {
|
||||
id: TaskId,
|
||||
location: Arc<location::Data>,
|
||||
location_path: Arc<PathBuf>,
|
||||
device_id: device::id::Type,
|
||||
file_paths_by_id: HashMap<FilePathPubId, file_path_for_file_identifier::Data>,
|
||||
identified_files: HashMap<FilePathPubId, IdentifiedFile>,
|
||||
file_paths_without_cas_id: Vec<FilePathToCreateOrLinkObject>,
|
||||
@@ -512,13 +517,14 @@ impl SerializableTask<Error> for Identifier {
|
||||
|
||||
type DeserializeError = rmp_serde::decode::Error;
|
||||
|
||||
type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
|
||||
type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
|
||||
|
||||
async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
|
||||
let Self {
|
||||
id,
|
||||
location,
|
||||
location_path,
|
||||
device_id,
|
||||
file_paths_by_id,
|
||||
identified_files,
|
||||
file_paths_without_cas_id,
|
||||
@@ -530,6 +536,7 @@ impl SerializableTask<Error> for Identifier {
|
||||
id,
|
||||
location,
|
||||
location_path,
|
||||
device_id,
|
||||
file_paths_by_id,
|
||||
identified_files,
|
||||
file_paths_without_cas_id,
|
||||
@@ -547,6 +554,7 @@ impl SerializableTask<Error> for Identifier {
|
||||
id,
|
||||
location,
|
||||
location_path,
|
||||
device_id,
|
||||
file_paths_by_id,
|
||||
identified_files,
|
||||
file_paths_without_cas_id,
|
||||
@@ -558,6 +566,7 @@ impl SerializableTask<Error> for Identifier {
|
||||
location,
|
||||
location_path,
|
||||
file_paths_by_id,
|
||||
device_id,
|
||||
identified_files,
|
||||
file_paths_without_cas_id,
|
||||
output,
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
use crate::file_identifier;
|
||||
|
||||
use sd_core_prisma_helpers::{file_path_id, FilePathPubId, ObjectPubId};
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_file_ext::kind::ObjectKind;
|
||||
use sd_prisma::{
|
||||
prisma::{file_path, object, PrismaClient},
|
||||
prisma::{device, file_path, object, PrismaClient},
|
||||
prisma_sync,
|
||||
};
|
||||
use sd_sync::{CRDTOperation, OperationFactory};
|
||||
use sd_utils::msgpack;
|
||||
use sd_sync::{option_sync_db_entry, sync_db_entry, sync_entry, CRDTOperation, OperationFactory};
|
||||
use sd_utils::chain_optional_iter;
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
@@ -47,10 +47,12 @@ fn connect_file_path_to_object<'db>(
|
||||
prisma_sync::file_path::SyncId {
|
||||
pub_id: file_path_pub_id.to_db(),
|
||||
},
|
||||
file_path::object::NAME,
|
||||
msgpack!(prisma_sync::object::SyncId {
|
||||
pub_id: object_pub_id.to_db(),
|
||||
}),
|
||||
[sync_entry!(
|
||||
prisma_sync::object::SyncId {
|
||||
pub_id: object_pub_id.to_db(),
|
||||
},
|
||||
file_path::object
|
||||
)],
|
||||
),
|
||||
db.file_path()
|
||||
.update(
|
||||
@@ -69,6 +71,7 @@ async fn create_objects_and_update_file_paths(
|
||||
files_and_kinds: impl IntoIterator<Item = FilePathToCreateOrLinkObject> + Send,
|
||||
db: &PrismaClient,
|
||||
sync: &SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Result<HashMap<file_path::id::Type, ObjectPubId>, file_identifier::Error> {
|
||||
trace!("Preparing objects");
|
||||
let (object_create_args, file_path_args) = files_and_kinds
|
||||
@@ -84,16 +87,23 @@ async fn create_objects_and_update_file_paths(
|
||||
|
||||
let kind = kind as i32;
|
||||
|
||||
let (sync_params, db_params) = [
|
||||
(
|
||||
(object::date_created::NAME, msgpack!(created_at)),
|
||||
object::date_created::set(created_at),
|
||||
),
|
||||
(
|
||||
(object::kind::NAME, msgpack!(kind)),
|
||||
object::kind::set(Some(kind)),
|
||||
),
|
||||
]
|
||||
let device_pub_id = sync.device_pub_id.to_db();
|
||||
|
||||
let (sync_params, db_params) = chain_optional_iter(
|
||||
[
|
||||
(
|
||||
sync_entry!(
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: device_pub_id,
|
||||
},
|
||||
object::device
|
||||
),
|
||||
object::device_id::set(Some(device_id)),
|
||||
),
|
||||
sync_db_entry!(kind, object::kind),
|
||||
],
|
||||
[option_sync_db_entry!(created_at, object::date_created)],
|
||||
)
|
||||
.into_iter()
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
@@ -121,51 +131,57 @@ async fn create_objects_and_update_file_paths(
|
||||
.unzip::<_, _, HashMap<_, _>, Vec<_>>(
|
||||
);
|
||||
|
||||
trace!(
|
||||
new_objects_count = object_create_args.len(),
|
||||
"Creating new Objects!;",
|
||||
);
|
||||
let new_objects_count = object_create_args.len();
|
||||
if new_objects_count > 0 {
|
||||
trace!(new_objects_count, "Creating new Objects!;",);
|
||||
|
||||
// create new object records with assembled values
|
||||
let created_objects_count = sync
|
||||
.write_ops(db, {
|
||||
let (sync, db_params) = object_create_args
|
||||
.into_iter()
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
(
|
||||
sync.into_iter().flatten().collect(),
|
||||
db.object().create_many(db_params),
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
trace!(%created_objects_count, "Created new Objects;");
|
||||
|
||||
if created_objects_count > 0 {
|
||||
trace!("Updating file paths with created objects");
|
||||
|
||||
let updated_file_path_ids = sync
|
||||
.write_ops(
|
||||
db,
|
||||
file_path_update_args
|
||||
// create new object records with assembled values
|
||||
let created_objects_count = sync
|
||||
.write_ops(db, {
|
||||
let (sync, db_params) = object_create_args
|
||||
.into_iter()
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>(),
|
||||
)
|
||||
.await
|
||||
.map(|file_paths| {
|
||||
file_paths
|
||||
.into_iter()
|
||||
.map(|file_path_id::Data { id }| id)
|
||||
.collect::<HashSet<_>>()
|
||||
})?;
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
object_pub_id_by_file_path_id
|
||||
.retain(|file_path_id, _| updated_file_path_ids.contains(file_path_id));
|
||||
(sync, db.object().create_many(db_params))
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(object_pub_id_by_file_path_id)
|
||||
trace!(%created_objects_count, "Created new Objects;");
|
||||
|
||||
if created_objects_count > 0 {
|
||||
let file_paths_to_update_count = file_path_update_args.len();
|
||||
if file_paths_to_update_count > 0 {
|
||||
trace!(
|
||||
file_paths_to_update_count,
|
||||
"Updating file paths with created objects"
|
||||
);
|
||||
|
||||
let updated_file_path_ids = sync
|
||||
.write_ops(
|
||||
db,
|
||||
file_path_update_args
|
||||
.into_iter()
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>(),
|
||||
)
|
||||
.await
|
||||
.map(|file_paths| {
|
||||
file_paths
|
||||
.into_iter()
|
||||
.map(|file_path_id::Data { id }| id)
|
||||
.collect::<HashSet<_>>()
|
||||
})?;
|
||||
|
||||
object_pub_id_by_file_path_id
|
||||
.retain(|file_path_id, _| updated_file_path_ids.contains(file_path_id));
|
||||
}
|
||||
|
||||
Ok(object_pub_id_by_file_path_id)
|
||||
} else {
|
||||
trace!("No objects created, skipping file path updates");
|
||||
Ok(HashMap::new())
|
||||
}
|
||||
} else {
|
||||
trace!("No objects created, skipping file path updates");
|
||||
trace!("No objects to create, skipping file path updates");
|
||||
Ok(HashMap::new())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use crate::{file_identifier, Error};
|
||||
|
||||
use sd_core_prisma_helpers::{file_path_id, object_for_file_identifier, CasId, ObjectPubId};
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_prisma::prisma::{file_path, object, PrismaClient};
|
||||
use sd_prisma::prisma::{device, file_path, object, PrismaClient};
|
||||
use sd_task_system::{
|
||||
check_interruption, ExecStatus, Interrupter, IntoAnyTaskOutput, SerializableTask, Task, TaskId,
|
||||
};
|
||||
@@ -29,13 +29,14 @@ pub struct ObjectProcessor {
|
||||
|
||||
// Inner state
|
||||
stage: Stage,
|
||||
device_id: device::id::Type,
|
||||
|
||||
// Out collector
|
||||
output: Output,
|
||||
|
||||
// Dependencies
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -93,6 +94,7 @@ impl Task<Error> for ObjectProcessor {
|
||||
let Self {
|
||||
db,
|
||||
sync,
|
||||
device_id,
|
||||
file_paths_by_cas_id,
|
||||
stage,
|
||||
output:
|
||||
@@ -167,8 +169,13 @@ impl Task<Error> for ObjectProcessor {
|
||||
);
|
||||
let start = Instant::now();
|
||||
let (more_file_paths_with_new_object, more_linked_objects_count) =
|
||||
assign_objects_to_duplicated_orphans(file_paths_by_cas_id, db, sync)
|
||||
.await?;
|
||||
assign_objects_to_duplicated_orphans(
|
||||
file_paths_by_cas_id,
|
||||
db,
|
||||
sync,
|
||||
*device_id,
|
||||
)
|
||||
.await?;
|
||||
*create_object_time = start.elapsed();
|
||||
file_path_ids_with_new_object.extend(more_file_paths_with_new_object);
|
||||
*linked_objects_count += more_linked_objects_count;
|
||||
@@ -194,7 +201,8 @@ impl ObjectProcessor {
|
||||
pub fn new(
|
||||
file_paths_by_cas_id: HashMap<CasId<'static>, Vec<FilePathToCreateOrLinkObject>>,
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
device_id: device::id::Type,
|
||||
with_priority: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
@@ -202,6 +210,7 @@ impl ObjectProcessor {
|
||||
db,
|
||||
sync,
|
||||
file_paths_by_cas_id,
|
||||
device_id,
|
||||
stage: Stage::Starting,
|
||||
output: Output::default(),
|
||||
with_priority,
|
||||
@@ -270,45 +279,44 @@ async fn assign_existing_objects_to_file_paths(
|
||||
db: &PrismaClient,
|
||||
sync: &SyncManager,
|
||||
) -> Result<Vec<file_path::id::Type>, file_identifier::Error> {
|
||||
sync.write_ops(
|
||||
db,
|
||||
objects_by_cas_id
|
||||
.iter()
|
||||
.flat_map(|(cas_id, object_pub_id)| {
|
||||
file_paths_by_cas_id
|
||||
.remove(cas_id)
|
||||
.map(|file_paths| {
|
||||
file_paths.into_iter().map(
|
||||
|FilePathToCreateOrLinkObject {
|
||||
file_path_pub_id, ..
|
||||
}| {
|
||||
connect_file_path_to_object(
|
||||
&file_path_pub_id,
|
||||
object_pub_id,
|
||||
db,
|
||||
sync,
|
||||
)
|
||||
},
|
||||
)
|
||||
})
|
||||
.expect("must be here")
|
||||
})
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>(),
|
||||
)
|
||||
.await
|
||||
.map(|file_paths| {
|
||||
file_paths
|
||||
.into_iter()
|
||||
.map(|file_path_id::Data { id }| id)
|
||||
.collect()
|
||||
})
|
||||
.map_err(Into::into)
|
||||
let (ops, queries) = objects_by_cas_id
|
||||
.iter()
|
||||
.flat_map(|(cas_id, object_pub_id)| {
|
||||
file_paths_by_cas_id
|
||||
.remove(cas_id)
|
||||
.map(|file_paths| {
|
||||
file_paths.into_iter().map(
|
||||
|FilePathToCreateOrLinkObject {
|
||||
file_path_pub_id, ..
|
||||
}| {
|
||||
connect_file_path_to_object(&file_path_pub_id, object_pub_id, db, sync)
|
||||
},
|
||||
)
|
||||
})
|
||||
.expect("must be here")
|
||||
})
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
if ops.is_empty() && queries.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
sync.write_ops(db, (ops, queries))
|
||||
.await
|
||||
.map(|file_paths| {
|
||||
file_paths
|
||||
.into_iter()
|
||||
.map(|file_path_id::Data { id }| id)
|
||||
.collect()
|
||||
})
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn assign_objects_to_duplicated_orphans(
|
||||
file_paths_by_cas_id: &mut HashMap<CasId<'static>, Vec<FilePathToCreateOrLinkObject>>,
|
||||
db: &PrismaClient,
|
||||
sync: &SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Result<(Vec<file_path::id::Type>, u64), file_identifier::Error> {
|
||||
// at least 1 file path per cas_id
|
||||
let mut selected_file_paths = Vec::with_capacity(file_paths_by_cas_id.len());
|
||||
@@ -327,7 +335,7 @@ async fn assign_objects_to_duplicated_orphans(
|
||||
});
|
||||
|
||||
let (mut file_paths_with_new_object, objects_by_cas_id) =
|
||||
create_objects_and_update_file_paths(selected_file_paths, db, sync)
|
||||
create_objects_and_update_file_paths(selected_file_paths, db, sync, device_id)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|(file_path_id, object_pub_id)| {
|
||||
@@ -365,6 +373,7 @@ async fn assign_objects_to_duplicated_orphans(
|
||||
pub struct SaveState {
|
||||
id: TaskId,
|
||||
file_paths_by_cas_id: HashMap<CasId<'static>, Vec<FilePathToCreateOrLinkObject>>,
|
||||
device_id: device::id::Type,
|
||||
stage: Stage,
|
||||
output: Output,
|
||||
with_priority: bool,
|
||||
@@ -375,12 +384,13 @@ impl SerializableTask<Error> for ObjectProcessor {
|
||||
|
||||
type DeserializeError = rmp_serde::decode::Error;
|
||||
|
||||
type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
|
||||
type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
|
||||
|
||||
async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
|
||||
let Self {
|
||||
id,
|
||||
file_paths_by_cas_id,
|
||||
device_id,
|
||||
stage,
|
||||
output,
|
||||
with_priority,
|
||||
@@ -390,6 +400,7 @@ impl SerializableTask<Error> for ObjectProcessor {
|
||||
rmp_serde::to_vec_named(&SaveState {
|
||||
id,
|
||||
file_paths_by_cas_id,
|
||||
device_id,
|
||||
stage,
|
||||
output,
|
||||
with_priority,
|
||||
@@ -404,6 +415,7 @@ impl SerializableTask<Error> for ObjectProcessor {
|
||||
|SaveState {
|
||||
id,
|
||||
file_paths_by_cas_id,
|
||||
device_id,
|
||||
stage,
|
||||
output,
|
||||
with_priority,
|
||||
@@ -412,6 +424,7 @@ impl SerializableTask<Error> for ObjectProcessor {
|
||||
with_priority,
|
||||
file_paths_by_cas_id,
|
||||
stage,
|
||||
device_id,
|
||||
output,
|
||||
db,
|
||||
sync,
|
||||
|
||||
@@ -16,7 +16,11 @@ use sd_core_file_path_helper::IsolatedFilePathData;
|
||||
use sd_core_indexer_rules::{IndexerRule, IndexerRuler};
|
||||
use sd_core_prisma_helpers::location_with_indexer_rules;
|
||||
|
||||
use sd_prisma::prisma::location;
|
||||
use sd_prisma::{
|
||||
prisma::{device, location},
|
||||
prisma_sync,
|
||||
};
|
||||
use sd_sync::{sync_db_not_null_entry, OperationFactory};
|
||||
use sd_task_system::{
|
||||
AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId,
|
||||
TaskOutput, TaskStatus,
|
||||
@@ -116,13 +120,13 @@ impl Job for Indexer {
|
||||
|
||||
TaskKind::Save => tasks::Saver::deserialize(
|
||||
&task_bytes,
|
||||
(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
|
||||
(Arc::clone(ctx.db()), ctx.sync().clone()),
|
||||
)
|
||||
.await
|
||||
.map(IntoTask::into_task),
|
||||
TaskKind::Update => tasks::Updater::deserialize(
|
||||
&task_bytes,
|
||||
(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
|
||||
(Arc::clone(ctx.db()), ctx.sync().clone()),
|
||||
)
|
||||
.await
|
||||
.map(IntoTask::into_task),
|
||||
@@ -161,6 +165,17 @@ impl Job for Indexer {
|
||||
) -> Result<ReturnStatus, Error> {
|
||||
let mut pending_running_tasks = FuturesUnordered::new();
|
||||
|
||||
let device_pub_id = &ctx.sync().device_pub_id;
|
||||
let device_id = ctx
|
||||
.db()
|
||||
.device()
|
||||
.find_unique(device::pub_id::equals(device_pub_id.to_db()))
|
||||
.exec()
|
||||
.await
|
||||
.map_err(indexer::Error::from)?
|
||||
.ok_or(indexer::Error::DeviceNotFound(device_pub_id.clone()))?
|
||||
.id;
|
||||
|
||||
match self
|
||||
.init_or_resume(&mut pending_running_tasks, &ctx, &dispatcher)
|
||||
.await
|
||||
@@ -191,21 +206,26 @@ impl Job for Indexer {
|
||||
}
|
||||
|
||||
if let Some(res) = self
|
||||
.process_handles(&mut pending_running_tasks, &ctx, &dispatcher)
|
||||
.process_handles(&mut pending_running_tasks, &ctx, device_id, &dispatcher)
|
||||
.await
|
||||
{
|
||||
return res;
|
||||
}
|
||||
|
||||
if let Some(res) = self
|
||||
.dispatch_last_save_and_update_tasks(&mut pending_running_tasks, &ctx, &dispatcher)
|
||||
.dispatch_last_save_and_update_tasks(
|
||||
&mut pending_running_tasks,
|
||||
&ctx,
|
||||
device_id,
|
||||
&dispatcher,
|
||||
)
|
||||
.await
|
||||
{
|
||||
return res;
|
||||
}
|
||||
|
||||
if let Some(res) = self
|
||||
.index_pending_ancestors(&mut pending_running_tasks, &ctx, &dispatcher)
|
||||
.index_pending_ancestors(&mut pending_running_tasks, &ctx, device_id, &dispatcher)
|
||||
.await
|
||||
{
|
||||
return res;
|
||||
@@ -253,7 +273,7 @@ impl Job for Indexer {
|
||||
.await?;
|
||||
}
|
||||
|
||||
update_location_size(location.id, ctx.db(), &ctx).await?;
|
||||
update_location_size(location.id, location.pub_id.clone(), &ctx).await?;
|
||||
|
||||
metadata.mean_db_write_time += start_size_update_time.elapsed();
|
||||
}
|
||||
@@ -271,13 +291,23 @@ impl Job for Indexer {
|
||||
"all tasks must be completed here"
|
||||
);
|
||||
|
||||
ctx.db()
|
||||
.location()
|
||||
.update(
|
||||
location::id::equals(location.id),
|
||||
vec![location::scan_state::set(LocationScanState::Indexed as i32)],
|
||||
let (sync_param, db_param) =
|
||||
sync_db_not_null_entry!(LocationScanState::Indexed as i32, location::scan_state);
|
||||
|
||||
ctx.sync()
|
||||
.write_op(
|
||||
ctx.db(),
|
||||
ctx.sync().shared_update(
|
||||
prisma_sync::location::SyncId {
|
||||
pub_id: location.pub_id.clone(),
|
||||
},
|
||||
[sync_param],
|
||||
),
|
||||
ctx.db()
|
||||
.location()
|
||||
.update(location::id::equals(location.id), vec![db_param])
|
||||
.select(location::select!({ id })),
|
||||
)
|
||||
.exec()
|
||||
.await
|
||||
.map_err(indexer::Error::from)?;
|
||||
|
||||
@@ -338,6 +368,7 @@ impl Indexer {
|
||||
task_id: TaskId,
|
||||
any_task_output: Box<dyn AnyTaskOutput>,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &JobTaskDispatcher,
|
||||
) -> Result<Vec<TaskHandle<Error>>, JobErrorOrDispatcherError<indexer::Error>> {
|
||||
self.metadata.completed_tasks += 1;
|
||||
@@ -349,6 +380,7 @@ impl Indexer {
|
||||
.downcast::<walker::Output<WalkerDBProxy, IsoFilePathFactory>>()
|
||||
.expect("just checked"),
|
||||
ctx,
|
||||
device_id,
|
||||
dispatcher,
|
||||
)
|
||||
.await;
|
||||
@@ -403,6 +435,7 @@ impl Indexer {
|
||||
..
|
||||
}: walker::Output<WalkerDBProxy, IsoFilePathFactory>,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &JobTaskDispatcher,
|
||||
) -> Result<Vec<TaskHandle<Error>>, JobErrorOrDispatcherError<indexer::Error>> {
|
||||
self.metadata.mean_scan_read_time += scan_time;
|
||||
@@ -465,7 +498,7 @@ impl Indexer {
|
||||
self.metadata.mean_db_write_time += db_delete_time.elapsed();
|
||||
}
|
||||
let (save_tasks, update_tasks) =
|
||||
self.prepare_save_and_update_tasks(to_create, to_update, ctx);
|
||||
self.prepare_save_and_update_tasks(to_create, to_update, ctx, device_id);
|
||||
|
||||
ctx.progress(vec![
|
||||
ProgressUpdate::TaskCount(self.metadata.total_tasks),
|
||||
@@ -552,13 +585,14 @@ impl Indexer {
|
||||
&mut self,
|
||||
pending_running_tasks: &mut FuturesUnordered<TaskHandle<Error>>,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &JobTaskDispatcher,
|
||||
) -> Option<Result<ReturnStatus, Error>> {
|
||||
while let Some(task) = pending_running_tasks.next().await {
|
||||
match task {
|
||||
Ok(TaskStatus::Done((task_id, TaskOutput::Out(out)))) => {
|
||||
match self
|
||||
.process_task_output(task_id, out, ctx, dispatcher)
|
||||
.process_task_output(task_id, out, ctx, device_id, dispatcher)
|
||||
.await
|
||||
{
|
||||
Ok(more_handles) => pending_running_tasks.extend(more_handles),
|
||||
@@ -666,6 +700,7 @@ impl Indexer {
|
||||
&mut self,
|
||||
pending_running_tasks: &mut FuturesUnordered<TaskHandle<Error>>,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &JobTaskDispatcher,
|
||||
) -> Option<Result<ReturnStatus, Error>> {
|
||||
if !self.to_create_buffer.is_empty() || !self.to_update_buffer.is_empty() {
|
||||
@@ -687,7 +722,8 @@ impl Indexer {
|
||||
self.location.pub_id.clone(),
|
||||
self.to_create_buffer.drain(..).collect(),
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
device_id,
|
||||
)
|
||||
.into_task(),
|
||||
);
|
||||
@@ -707,7 +743,7 @@ impl Indexer {
|
||||
tasks::Updater::new_deep(
|
||||
self.to_update_buffer.drain(..).collect(),
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
)
|
||||
.into_task(),
|
||||
);
|
||||
@@ -726,7 +762,7 @@ impl Indexer {
|
||||
}
|
||||
}
|
||||
|
||||
self.process_handles(pending_running_tasks, ctx, dispatcher)
|
||||
self.process_handles(pending_running_tasks, ctx, device_id, dispatcher)
|
||||
.await
|
||||
} else {
|
||||
None
|
||||
@@ -737,6 +773,7 @@ impl Indexer {
|
||||
&mut self,
|
||||
pending_running_tasks: &mut FuturesUnordered<TaskHandle<Error>>,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &JobTaskDispatcher,
|
||||
) -> Option<Result<ReturnStatus, Error>> {
|
||||
if self.ancestors_needing_indexing.is_empty() {
|
||||
@@ -759,7 +796,8 @@ impl Indexer {
|
||||
self.location.pub_id.clone(),
|
||||
chunked_saves,
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
device_id,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
@@ -776,7 +814,7 @@ impl Indexer {
|
||||
}
|
||||
}
|
||||
|
||||
self.process_handles(pending_running_tasks, ctx, dispatcher)
|
||||
self.process_handles(pending_running_tasks, ctx, device_id, dispatcher)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -785,6 +823,7 @@ impl Indexer {
|
||||
to_create: Vec<WalkedEntry>,
|
||||
to_update: Vec<WalkedEntry>,
|
||||
ctx: &impl JobContext<OuterCtx>,
|
||||
device_id: device::id::Type,
|
||||
) -> (Vec<tasks::Saver>, Vec<tasks::Updater>) {
|
||||
if self.processing_first_directory {
|
||||
// If we are processing the first directory, we dispatch shallow tasks with higher priority
|
||||
@@ -806,7 +845,8 @@ impl Indexer {
|
||||
self.location.pub_id.clone(),
|
||||
chunked_saves,
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
device_id,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
@@ -824,7 +864,7 @@ impl Indexer {
|
||||
tasks::Updater::new_shallow(
|
||||
chunked_updates,
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
@@ -851,7 +891,8 @@ impl Indexer {
|
||||
self.location.pub_id.clone(),
|
||||
chunked_saves,
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
device_id,
|
||||
));
|
||||
}
|
||||
save_tasks
|
||||
@@ -878,7 +919,7 @@ impl Indexer {
|
||||
update_tasks.push(tasks::Updater::new_deep(
|
||||
chunked_updates,
|
||||
Arc::clone(ctx.db()),
|
||||
Arc::clone(ctx.sync()),
|
||||
ctx.sync().clone(),
|
||||
));
|
||||
}
|
||||
update_tasks
|
||||
|
||||
@@ -4,17 +4,17 @@ use sd_core_file_path_helper::{FilePathError, IsolatedFilePathData};
|
||||
use sd_core_prisma_helpers::{
|
||||
file_path_pub_and_cas_ids, file_path_to_isolate_with_pub_id, file_path_walker,
|
||||
};
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::{DevicePubId, SyncManager};
|
||||
|
||||
use sd_prisma::{
|
||||
prisma::{file_path, indexer_rule, location, PrismaClient, SortOrder},
|
||||
prisma_sync,
|
||||
};
|
||||
use sd_sync::OperationFactory;
|
||||
use sd_sync::{sync_db_entry, OperationFactory};
|
||||
use sd_utils::{
|
||||
db::{size_in_bytes_from_db, size_in_bytes_to_db, MissingFieldError},
|
||||
error::{FileIOError, NonUtf8PathError},
|
||||
from_bytes_to_uuid, msgpack,
|
||||
from_bytes_to_uuid,
|
||||
};
|
||||
|
||||
use std::{
|
||||
@@ -50,6 +50,8 @@ pub enum Error {
|
||||
IndexerRuleNotFound(indexer_rule::id::Type),
|
||||
#[error(transparent)]
|
||||
SubPath(#[from] sub_path::Error),
|
||||
#[error("device not found: <device_pub_id='{0}'")]
|
||||
DeviceNotFound(DevicePubId),
|
||||
|
||||
// Internal Errors
|
||||
#[error("database error: {0}")]
|
||||
@@ -136,7 +138,7 @@ async fn update_directory_sizes(
|
||||
db: &PrismaClient,
|
||||
sync: &SyncManager,
|
||||
) -> Result<(), Error> {
|
||||
let to_sync_and_update = db
|
||||
let (ops, queries) = db
|
||||
._batch(chunk_db_queries(iso_paths_and_sizes.keys(), db))
|
||||
.await?
|
||||
.into_iter()
|
||||
@@ -144,22 +146,20 @@ async fn update_directory_sizes(
|
||||
.map(|file_path| {
|
||||
let size_bytes = iso_paths_and_sizes
|
||||
.get(&IsolatedFilePathData::try_from(&file_path)?)
|
||||
.map(|size| size.to_be_bytes().to_vec())
|
||||
.map(|size| size_in_bytes_to_db(*size))
|
||||
.expect("must be here");
|
||||
|
||||
let (sync_param, db_param) = sync_db_entry!(size_bytes, file_path::size_in_bytes_bytes);
|
||||
|
||||
Ok((
|
||||
sync.shared_update(
|
||||
prisma_sync::file_path::SyncId {
|
||||
pub_id: file_path.pub_id.clone(),
|
||||
},
|
||||
file_path::size_in_bytes_bytes::NAME,
|
||||
msgpack!(size_bytes),
|
||||
[sync_param],
|
||||
),
|
||||
db.file_path()
|
||||
.update(
|
||||
file_path::pub_id::equals(file_path.pub_id),
|
||||
vec![file_path::size_in_bytes_bytes::set(Some(size_bytes))],
|
||||
)
|
||||
.update(file_path::pub_id::equals(file_path.pub_id), vec![db_param])
|
||||
.select(file_path::select!({ id })),
|
||||
))
|
||||
})
|
||||
@@ -167,42 +167,54 @@ async fn update_directory_sizes(
|
||||
.into_iter()
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
sync.write_ops(db, to_sync_and_update).await?;
|
||||
if !ops.is_empty() && !queries.is_empty() {
|
||||
sync.write_ops(db, (ops, queries)).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_location_size(
|
||||
location_id: location::id::Type,
|
||||
db: &PrismaClient,
|
||||
location_pub_id: location::pub_id::Type,
|
||||
ctx: &impl OuterContext,
|
||||
) -> Result<(), Error> {
|
||||
let total_size = db
|
||||
.file_path()
|
||||
.find_many(vec![
|
||||
file_path::location_id::equals(Some(location_id)),
|
||||
file_path::materialized_path::equals(Some("/".to_string())),
|
||||
])
|
||||
.select(file_path::select!({ size_in_bytes_bytes }))
|
||||
.exec()
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|file_path| {
|
||||
file_path
|
||||
.size_in_bytes_bytes
|
||||
.map(|size_in_bytes_bytes| size_in_bytes_from_db(&size_in_bytes_bytes))
|
||||
})
|
||||
.sum::<u64>();
|
||||
let db = ctx.db();
|
||||
let sync = ctx.sync();
|
||||
|
||||
db.location()
|
||||
.update(
|
||||
location::id::equals(location_id),
|
||||
vec![location::size_in_bytes::set(Some(
|
||||
total_size.to_be_bytes().to_vec(),
|
||||
))],
|
||||
)
|
||||
.exec()
|
||||
.await?;
|
||||
let total_size = size_in_bytes_to_db(
|
||||
db.file_path()
|
||||
.find_many(vec![
|
||||
file_path::location_id::equals(Some(location_id)),
|
||||
file_path::materialized_path::equals(Some("/".to_string())),
|
||||
])
|
||||
.select(file_path::select!({ size_in_bytes_bytes }))
|
||||
.exec()
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|file_path| {
|
||||
file_path
|
||||
.size_in_bytes_bytes
|
||||
.map(|size_in_bytes_bytes| size_in_bytes_from_db(&size_in_bytes_bytes))
|
||||
})
|
||||
.sum::<u64>(),
|
||||
);
|
||||
|
||||
let (sync_param, db_param) = sync_db_entry!(total_size, location::size_in_bytes);
|
||||
|
||||
sync.write_op(
|
||||
db,
|
||||
sync.shared_update(
|
||||
prisma_sync::location::SyncId {
|
||||
pub_id: location_pub_id,
|
||||
},
|
||||
[sync_param],
|
||||
),
|
||||
db.location()
|
||||
.update(location::id::equals(location_id), vec![db_param])
|
||||
.select(location::select!({ id })),
|
||||
)
|
||||
.await?;
|
||||
|
||||
ctx.invalidate_query("locations.list");
|
||||
ctx.invalidate_query("locations.get");
|
||||
@@ -213,7 +225,7 @@ async fn update_location_size(
|
||||
async fn remove_non_existing_file_paths(
|
||||
to_remove: Vec<file_path_pub_and_cas_ids::Data>,
|
||||
db: &PrismaClient,
|
||||
sync: &sd_core_sync::Manager,
|
||||
sync: &SyncManager,
|
||||
) -> Result<u64, Error> {
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
let (sync_params, db_params): (Vec<_>, Vec<_>) = to_remove
|
||||
@@ -228,6 +240,10 @@ async fn remove_non_existing_file_paths(
|
||||
})
|
||||
.unzip();
|
||||
|
||||
if sync_params.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
sync.write_ops(
|
||||
db,
|
||||
(
|
||||
@@ -318,7 +334,7 @@ pub async fn reverse_update_directories_sizes(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let to_sync_and_update = ancestors
|
||||
let (sync_ops, update_queries) = ancestors
|
||||
.into_values()
|
||||
.filter_map(|materialized_path| {
|
||||
if let Some((pub_id, size)) =
|
||||
@@ -326,18 +342,19 @@ pub async fn reverse_update_directories_sizes(
|
||||
{
|
||||
let size_bytes = size_in_bytes_to_db(size);
|
||||
|
||||
let (sync_param, db_param) =
|
||||
sync_db_entry!(size_bytes, file_path::size_in_bytes_bytes);
|
||||
|
||||
Some((
|
||||
sync.shared_update(
|
||||
prisma_sync::file_path::SyncId {
|
||||
pub_id: pub_id.clone(),
|
||||
},
|
||||
file_path::size_in_bytes_bytes::NAME,
|
||||
msgpack!(size_bytes),
|
||||
),
|
||||
db.file_path().update(
|
||||
file_path::pub_id::equals(pub_id),
|
||||
vec![file_path::size_in_bytes_bytes::set(Some(size_bytes))],
|
||||
[sync_param],
|
||||
),
|
||||
db.file_path()
|
||||
.update(file_path::pub_id::equals(pub_id), vec![db_param])
|
||||
.select(file_path::select!({ id })),
|
||||
))
|
||||
} else {
|
||||
warn!("Got a missing ancestor for a file_path in the database, ignoring...");
|
||||
@@ -346,7 +363,9 @@ pub async fn reverse_update_directories_sizes(
|
||||
})
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
sync.write_ops(db, to_sync_and_update).await?;
|
||||
if !sync_ops.is_empty() && !update_queries.is_empty() {
|
||||
sync.write_ops(db, (sync_ops, update_queries)).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@ use crate::{
|
||||
|
||||
use sd_core_indexer_rules::{IndexerRule, IndexerRuler};
|
||||
use sd_core_prisma_helpers::location_with_indexer_rules;
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_prisma::prisma::PrismaClient;
|
||||
use sd_prisma::prisma::{device, PrismaClient};
|
||||
use sd_task_system::{BaseTaskDispatcher, CancelTaskOnDrop, IntoTask, TaskDispatcher, TaskOutput};
|
||||
use sd_utils::db::maybe_missing;
|
||||
|
||||
@@ -62,6 +62,17 @@ pub async fn shallow(
|
||||
.await?,
|
||||
);
|
||||
|
||||
let device_pub_id = &ctx.sync().device_pub_id;
|
||||
let device_id = ctx
|
||||
.db()
|
||||
.device()
|
||||
.find_unique(device::pub_id::equals(device_pub_id.to_db()))
|
||||
.exec()
|
||||
.await
|
||||
.map_err(indexer::Error::from)?
|
||||
.ok_or(indexer::Error::DeviceNotFound(device_pub_id.clone()))?
|
||||
.id;
|
||||
|
||||
let Some(walker::Output {
|
||||
to_create,
|
||||
to_update,
|
||||
@@ -96,7 +107,8 @@ pub async fn shallow(
|
||||
to_create,
|
||||
to_update,
|
||||
Arc::clone(db),
|
||||
Arc::clone(sync),
|
||||
sync.clone(),
|
||||
device_id,
|
||||
dispatcher,
|
||||
)
|
||||
.await?
|
||||
@@ -124,7 +136,7 @@ pub async fn shallow(
|
||||
.await?;
|
||||
}
|
||||
|
||||
update_location_size(location.id, db, ctx).await?;
|
||||
update_location_size(location.id, location.pub_id, ctx).await?;
|
||||
}
|
||||
|
||||
if indexed_count > 0 || removed_count > 0 {
|
||||
@@ -203,7 +215,8 @@ async fn save_and_update(
|
||||
to_create: Vec<WalkedEntry>,
|
||||
to_update: Vec<WalkedEntry>,
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
device_id: device::id::Type,
|
||||
dispatcher: &BaseTaskDispatcher<Error>,
|
||||
) -> Result<Option<Metadata>, Error> {
|
||||
let save_and_update_tasks = to_create
|
||||
@@ -216,7 +229,8 @@ async fn save_and_update(
|
||||
location.pub_id.clone(),
|
||||
chunk.collect::<Vec<_>>(),
|
||||
Arc::clone(&db),
|
||||
Arc::clone(&sync),
|
||||
sync.clone(),
|
||||
device_id,
|
||||
)
|
||||
})
|
||||
.map(IntoTask::into_task)
|
||||
@@ -229,7 +243,7 @@ async fn save_and_update(
|
||||
tasks::Updater::new_shallow(
|
||||
chunk.collect::<Vec<_>>(),
|
||||
Arc::clone(&db),
|
||||
Arc::clone(&sync),
|
||||
sync.clone(),
|
||||
)
|
||||
})
|
||||
.map(IntoTask::into_task),
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
use crate::{indexer, Error};
|
||||
|
||||
use sd_core_file_path_helper::{FilePathMetadata, IsolatedFilePathDataParts};
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_prisma::{
|
||||
prisma::{file_path, location, PrismaClient},
|
||||
prisma::{device, file_path, location, PrismaClient},
|
||||
prisma_sync,
|
||||
};
|
||||
use sd_sync::{sync_db_entry, OperationFactory};
|
||||
use sd_sync::{sync_db_entry, sync_entry, OperationFactory};
|
||||
use sd_task_system::{ExecStatus, Interrupter, IntoAnyTaskOutput, SerializableTask, Task, TaskId};
|
||||
use sd_utils::{
|
||||
db::{inode_to_db, size_in_bytes_to_db},
|
||||
msgpack,
|
||||
};
|
||||
use sd_utils::db::{inode_to_db, size_in_bytes_to_db};
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
@@ -32,11 +29,12 @@ pub struct Saver {
|
||||
// Received input args
|
||||
location_id: location::id::Type,
|
||||
location_pub_id: location::pub_id::Type,
|
||||
device_id: device::id::Type,
|
||||
walked_entries: Vec<WalkedEntry>,
|
||||
|
||||
// Dependencies
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
}
|
||||
|
||||
/// [`Save`] Task output
|
||||
@@ -73,8 +71,9 @@ impl Task<Error> for Saver {
|
||||
#[allow(clippy::blocks_in_conditions)] // Due to `err` on `instrument` macro above
|
||||
async fn run(&mut self, _: &Interrupter) -> Result<ExecStatus, Error> {
|
||||
use file_path::{
|
||||
create_unchecked, date_created, date_indexed, date_modified, extension, hidden, inode,
|
||||
is_dir, location, location_id, materialized_path, name, size_in_bytes_bytes,
|
||||
create_unchecked, date_created, date_indexed, date_modified, device, device_id,
|
||||
extension, hidden, inode, is_dir, location, location_id, materialized_path, name,
|
||||
size_in_bytes_bytes,
|
||||
};
|
||||
|
||||
let start_time = Instant::now();
|
||||
@@ -82,13 +81,14 @@ impl Task<Error> for Saver {
|
||||
let Self {
|
||||
location_id,
|
||||
location_pub_id,
|
||||
device_id,
|
||||
walked_entries,
|
||||
db,
|
||||
sync,
|
||||
..
|
||||
} = self;
|
||||
|
||||
let (sync_stuff, paths): (Vec<_>, Vec<_>) = walked_entries
|
||||
let (create_crdt_ops, paths): (Vec<_>, Vec<_>) = walked_entries
|
||||
.drain(..)
|
||||
.map(
|
||||
|WalkedEntry {
|
||||
@@ -118,13 +118,13 @@ impl Task<Error> for Saver {
|
||||
new file_paths and they were not identified yet"
|
||||
);
|
||||
|
||||
let (sync_params, db_params): (Vec<_>, Vec<_>) = [
|
||||
let (sync_params, db_params) = [
|
||||
(
|
||||
(
|
||||
location::NAME,
|
||||
msgpack!(prisma_sync::location::SyncId {
|
||||
sync_entry!(
|
||||
prisma_sync::location::SyncId {
|
||||
pub_id: location_pub_id.clone()
|
||||
}),
|
||||
},
|
||||
location
|
||||
),
|
||||
location_id::set(Some(*location_id)),
|
||||
),
|
||||
@@ -138,9 +138,18 @@ impl Task<Error> for Saver {
|
||||
sync_db_entry!(modified_at, date_modified),
|
||||
sync_db_entry!(Utc::now(), date_indexed),
|
||||
sync_db_entry!(hidden, hidden),
|
||||
(
|
||||
sync_entry!(
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: sync.device_pub_id.to_db(),
|
||||
},
|
||||
device
|
||||
),
|
||||
device_id::set(Some(*device_id)),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.unzip();
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
(
|
||||
sync.shared_create(
|
||||
@@ -155,12 +164,22 @@ impl Task<Error> for Saver {
|
||||
)
|
||||
.unzip();
|
||||
|
||||
if create_crdt_ops.is_empty() && paths.is_empty() {
|
||||
return Ok(ExecStatus::Done(
|
||||
Output {
|
||||
saved_count: 0,
|
||||
save_duration: Duration::ZERO,
|
||||
}
|
||||
.into_output(),
|
||||
));
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
let saved_count = sync
|
||||
.write_ops(
|
||||
db,
|
||||
(
|
||||
sync_stuff.into_iter().flatten().collect(),
|
||||
create_crdt_ops,
|
||||
db.file_path().create_many(paths).skip_duplicates(),
|
||||
),
|
||||
)
|
||||
@@ -188,12 +207,14 @@ impl Saver {
|
||||
location_pub_id: location::pub_id::Type,
|
||||
walked_entries: Vec<WalkedEntry>,
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: TaskId::new_v4(),
|
||||
location_id,
|
||||
location_pub_id,
|
||||
device_id,
|
||||
walked_entries,
|
||||
db,
|
||||
sync,
|
||||
@@ -207,12 +228,14 @@ impl Saver {
|
||||
location_pub_id: location::pub_id::Type,
|
||||
walked_entries: Vec<WalkedEntry>,
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: TaskId::new_v4(),
|
||||
location_id,
|
||||
location_pub_id,
|
||||
device_id,
|
||||
walked_entries,
|
||||
db,
|
||||
sync,
|
||||
@@ -228,6 +251,7 @@ struct SaveState {
|
||||
|
||||
location_id: location::id::Type,
|
||||
location_pub_id: location::pub_id::Type,
|
||||
device_id: device::id::Type,
|
||||
walked_entries: Vec<WalkedEntry>,
|
||||
}
|
||||
|
||||
@@ -236,7 +260,7 @@ impl SerializableTask<Error> for Saver {
|
||||
|
||||
type DeserializeError = rmp_serde::decode::Error;
|
||||
|
||||
type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
|
||||
type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
|
||||
|
||||
async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
|
||||
let Self {
|
||||
@@ -244,6 +268,7 @@ impl SerializableTask<Error> for Saver {
|
||||
is_shallow,
|
||||
location_id,
|
||||
location_pub_id,
|
||||
device_id,
|
||||
walked_entries,
|
||||
..
|
||||
} = self;
|
||||
@@ -252,6 +277,7 @@ impl SerializableTask<Error> for Saver {
|
||||
is_shallow,
|
||||
location_id,
|
||||
location_pub_id,
|
||||
device_id,
|
||||
walked_entries,
|
||||
})
|
||||
}
|
||||
@@ -266,12 +292,14 @@ impl SerializableTask<Error> for Saver {
|
||||
is_shallow,
|
||||
location_id,
|
||||
location_pub_id,
|
||||
device_id,
|
||||
walked_entries,
|
||||
}| Self {
|
||||
id,
|
||||
is_shallow,
|
||||
location_id,
|
||||
location_pub_id,
|
||||
device_id,
|
||||
walked_entries,
|
||||
db,
|
||||
sync,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{indexer, Error};
|
||||
|
||||
use sd_core_file_path_helper::{FilePathMetadata, IsolatedFilePathDataParts};
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_prisma::{
|
||||
prisma::{file_path, object, PrismaClient},
|
||||
@@ -39,7 +39,7 @@ pub struct Updater {
|
||||
|
||||
// Dependencies
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
}
|
||||
|
||||
/// [`Update`] Task output
|
||||
@@ -93,7 +93,7 @@ impl Task<Error> for Updater {
|
||||
|
||||
check_interruption!(interrupter);
|
||||
|
||||
let (sync_stuff, paths_to_update) = walked_entries
|
||||
let (crdt_ops, paths_to_update) = walked_entries
|
||||
.drain(..)
|
||||
.map(
|
||||
|WalkedEntry {
|
||||
@@ -138,18 +138,12 @@ impl Task<Error> for Updater {
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
(
|
||||
sync_params
|
||||
.into_iter()
|
||||
.map(|(field, value)| {
|
||||
sync.shared_update(
|
||||
prisma_sync::file_path::SyncId {
|
||||
pub_id: pub_id.to_db(),
|
||||
},
|
||||
field,
|
||||
value,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
sync.shared_update(
|
||||
prisma_sync::file_path::SyncId {
|
||||
pub_id: pub_id.to_db(),
|
||||
},
|
||||
sync_params,
|
||||
),
|
||||
db.file_path()
|
||||
.update(file_path::pub_id::equals(pub_id.into()), db_params)
|
||||
// selecting id to avoid fetching whole object from database
|
||||
@@ -159,11 +153,18 @@ impl Task<Error> for Updater {
|
||||
)
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
if crdt_ops.is_empty() && paths_to_update.is_empty() {
|
||||
return Ok(ExecStatus::Done(
|
||||
Output {
|
||||
updated_count: 0,
|
||||
update_duration: Duration::ZERO,
|
||||
}
|
||||
.into_output(),
|
||||
));
|
||||
}
|
||||
|
||||
let updated = sync
|
||||
.write_ops(
|
||||
db,
|
||||
(sync_stuff.into_iter().flatten().collect(), paths_to_update),
|
||||
)
|
||||
.write_ops(db, (crdt_ops, paths_to_update))
|
||||
.await
|
||||
.map_err(indexer::Error::from)?;
|
||||
|
||||
@@ -186,7 +187,7 @@ impl Updater {
|
||||
pub fn new_deep(
|
||||
walked_entries: Vec<WalkedEntry>,
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: TaskId::new_v4(),
|
||||
@@ -202,7 +203,7 @@ impl Updater {
|
||||
pub fn new_shallow(
|
||||
walked_entries: Vec<WalkedEntry>,
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: TaskId::new_v4(),
|
||||
@@ -264,7 +265,7 @@ impl SerializableTask<Error> for Updater {
|
||||
|
||||
type DeserializeError = rmp_serde::decode::Error;
|
||||
|
||||
type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
|
||||
type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
|
||||
|
||||
async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
|
||||
let Self {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{Error, NonCriticalError, UpdateEvent};
|
||||
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_prisma::prisma::PrismaClient;
|
||||
use sd_task_system::{
|
||||
@@ -98,7 +98,7 @@ impl ProgressUpdate {
|
||||
pub trait OuterContext: Send + Sync + Clone + 'static {
|
||||
fn id(&self) -> Uuid;
|
||||
fn db(&self) -> &Arc<PrismaClient>;
|
||||
fn sync(&self) -> &Arc<SyncManager>;
|
||||
fn sync(&self) -> &SyncManager;
|
||||
fn invalidate_query(&self, query: &'static str);
|
||||
fn query_invalidator(&self) -> impl Fn(&'static str) + Send + Sync;
|
||||
fn report_update(&self, update: UpdateEvent);
|
||||
@@ -158,7 +158,7 @@ where
|
||||
JobCtx: JobContext<OuterCtx>,
|
||||
{
|
||||
fn into_job(self) -> Box<dyn DynJob<OuterCtx, JobCtx>> {
|
||||
let id = JobId::new_v4();
|
||||
let id = JobId::now_v7();
|
||||
|
||||
Box::new(JobHolder {
|
||||
id,
|
||||
@@ -333,7 +333,7 @@ where
|
||||
}
|
||||
|
||||
pub fn new(job: J) -> Self {
|
||||
let id = JobId::new_v4();
|
||||
let id = JobId::now_v7();
|
||||
Self {
|
||||
id,
|
||||
job,
|
||||
|
||||
@@ -290,6 +290,7 @@ impl Report {
|
||||
.map(|id| job::parent::connect(job::id::equals(id.as_bytes().to_vec())))],
|
||||
),
|
||||
)
|
||||
.select(job::select!({ id }))
|
||||
.exec()
|
||||
.await
|
||||
.map_err(ReportError::Create)?;
|
||||
@@ -300,7 +301,7 @@ impl Report {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update(&mut self, db: &PrismaClient) -> Result<(), ReportError> {
|
||||
pub async fn update(&self, db: &PrismaClient) -> Result<(), ReportError> {
|
||||
db.job()
|
||||
.update(
|
||||
job::id::equals(self.id.as_bytes().to_vec()),
|
||||
@@ -318,6 +319,7 @@ impl Report {
|
||||
job::date_completed::set(self.completed_at.map(Into::into)),
|
||||
],
|
||||
)
|
||||
.select(job::select!({ id }))
|
||||
.exec()
|
||||
.await
|
||||
.map_err(ReportError::Update)?;
|
||||
|
||||
@@ -313,7 +313,7 @@ impl<OuterCtx: OuterContext, JobCtx: JobContext<OuterCtx>> JobSystemRunner<Outer
|
||||
Ok(Some(serialized_job)) => {
|
||||
let name = {
|
||||
let db = handle.ctx.db();
|
||||
let mut report = handle.ctx.report_mut().await;
|
||||
let report = handle.ctx.report().await;
|
||||
if let Err(e) = report.update(db).await {
|
||||
error!(?e, "Failed to update report on job shutdown;");
|
||||
}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
use crate::media_processor::{self, media_data_extractor};
|
||||
|
||||
use sd_core_prisma_helpers::ObjectPubId;
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::{DevicePubId, SyncManager};
|
||||
|
||||
use sd_file_ext::extensions::{Extension, ImageExtension, ALL_IMAGE_EXTENSIONS};
|
||||
use sd_media_metadata::ExifMetadata;
|
||||
use sd_prisma::{
|
||||
prisma::{exif_data, object, PrismaClient},
|
||||
prisma::{device, exif_data, object, PrismaClient},
|
||||
prisma_sync,
|
||||
};
|
||||
use sd_sync::{option_sync_db_entry, OperationFactory};
|
||||
use sd_sync::{option_sync_db_entry, sync_entry, OperationFactory};
|
||||
use sd_utils::chain_optional_iter;
|
||||
|
||||
use std::{path::Path, sync::LazyLock};
|
||||
@@ -51,9 +51,20 @@ fn to_query(
|
||||
exif_version,
|
||||
}: ExifMetadata,
|
||||
object_id: exif_data::object_id::Type,
|
||||
device_pub_id: &DevicePubId,
|
||||
) -> (Vec<(&'static str, rmpv::Value)>, exif_data::Create) {
|
||||
let device_pub_id = device_pub_id.to_db();
|
||||
|
||||
let (sync_params, db_params) = chain_optional_iter(
|
||||
[],
|
||||
[(
|
||||
sync_entry!(
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: device_pub_id.clone()
|
||||
},
|
||||
exif_data::device
|
||||
),
|
||||
exif_data::device::connect(device::pub_id::equals(device_pub_id)),
|
||||
)],
|
||||
[
|
||||
option_sync_db_entry!(
|
||||
serde_json::to_vec(&camera_data).ok(),
|
||||
@@ -109,24 +120,22 @@ pub async fn save(
|
||||
exif_datas
|
||||
.into_iter()
|
||||
.map(|(exif_data, object_id, object_pub_id)| async move {
|
||||
let (sync_params, create) = to_query(exif_data, object_id);
|
||||
let (sync_params, create) = to_query(exif_data, object_id, &sync.device_pub_id);
|
||||
let db_params = create._params.clone();
|
||||
|
||||
sync.write_ops(
|
||||
sync.write_op(
|
||||
db,
|
||||
(
|
||||
sync.shared_create(
|
||||
prisma_sync::exif_data::SyncId {
|
||||
object: prisma_sync::object::SyncId {
|
||||
pub_id: object_pub_id.into(),
|
||||
},
|
||||
sync.shared_create(
|
||||
prisma_sync::exif_data::SyncId {
|
||||
object: prisma_sync::object::SyncId {
|
||||
pub_id: object_pub_id.into(),
|
||||
},
|
||||
sync_params,
|
||||
),
|
||||
db.exif_data()
|
||||
.upsert(exif_data::object_id::equals(object_id), create, db_params)
|
||||
.select(exif_data::select!({ id })),
|
||||
},
|
||||
sync_params,
|
||||
),
|
||||
db.exif_data()
|
||||
.upsert(exif_data::object_id::equals(object_id), create, db_params)
|
||||
.select(exif_data::select!({ id })),
|
||||
)
|
||||
.await
|
||||
})
|
||||
|
||||
@@ -14,7 +14,11 @@ use sd_core_file_path_helper::IsolatedFilePathData;
|
||||
use sd_core_prisma_helpers::file_path_for_media_processor;
|
||||
|
||||
use sd_file_ext::extensions::Extension;
|
||||
use sd_prisma::prisma::{location, PrismaClient};
|
||||
use sd_prisma::{
|
||||
prisma::{location, PrismaClient},
|
||||
prisma_sync,
|
||||
};
|
||||
use sd_sync::{sync_db_not_null_entry, OperationFactory};
|
||||
use sd_task_system::{
|
||||
AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId,
|
||||
TaskOutput, TaskStatus, TaskSystemError,
|
||||
@@ -125,7 +129,7 @@ impl Job for MediaProcessor {
|
||||
TaskKind::MediaDataExtractor => {
|
||||
tasks::MediaDataExtractor::deserialize(
|
||||
&task_bytes,
|
||||
(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
|
||||
(Arc::clone(ctx.db()), ctx.sync().clone()),
|
||||
)
|
||||
.await
|
||||
.map(IntoTask::into_task)
|
||||
@@ -214,15 +218,23 @@ impl Job for MediaProcessor {
|
||||
..
|
||||
} = self;
|
||||
|
||||
ctx.db()
|
||||
.location()
|
||||
.update(
|
||||
location::id::equals(location.id),
|
||||
vec![location::scan_state::set(
|
||||
LocationScanState::Completed as i32,
|
||||
)],
|
||||
let (sync_param, db_param) =
|
||||
sync_db_not_null_entry!(LocationScanState::Completed as i32, location::scan_state);
|
||||
|
||||
ctx.sync()
|
||||
.write_op(
|
||||
ctx.db(),
|
||||
ctx.sync().shared_update(
|
||||
prisma_sync::location::SyncId {
|
||||
pub_id: location.pub_id.clone(),
|
||||
},
|
||||
[sync_param],
|
||||
),
|
||||
ctx.db()
|
||||
.location()
|
||||
.update(location::id::equals(location.id), vec![db_param])
|
||||
.select(location::select!({ id })),
|
||||
)
|
||||
.exec()
|
||||
.await
|
||||
.map_err(media_processor::Error::from)?;
|
||||
|
||||
@@ -632,7 +644,7 @@ impl MediaProcessor {
|
||||
parent_iso_file_path.location_id(),
|
||||
Arc::clone(&self.location_path),
|
||||
Arc::clone(db),
|
||||
Arc::clone(sync),
|
||||
sync.clone(),
|
||||
)
|
||||
})
|
||||
.map(IntoTask::into_task)
|
||||
@@ -648,7 +660,7 @@ impl MediaProcessor {
|
||||
parent_iso_file_path.location_id(),
|
||||
Arc::clone(&self.location_path),
|
||||
Arc::clone(db),
|
||||
Arc::clone(sync),
|
||||
sync.clone(),
|
||||
)
|
||||
})
|
||||
.map(IntoTask::into_task),
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::{
|
||||
};
|
||||
|
||||
use sd_core_file_path_helper::IsolatedFilePathData;
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_prisma::prisma::{location, PrismaClient};
|
||||
use sd_task_system::{
|
||||
@@ -154,7 +154,7 @@ pub async fn shallow(
|
||||
|
||||
async fn dispatch_media_data_extractor_tasks(
|
||||
db: &Arc<PrismaClient>,
|
||||
sync: &Arc<SyncManager>,
|
||||
sync: &SyncManager,
|
||||
parent_iso_file_path: &IsolatedFilePathData<'_>,
|
||||
location_path: &Arc<PathBuf>,
|
||||
dispatcher: &BaseTaskDispatcher<Error>,
|
||||
@@ -185,7 +185,7 @@ async fn dispatch_media_data_extractor_tasks(
|
||||
parent_iso_file_path.location_id(),
|
||||
Arc::clone(location_path),
|
||||
Arc::clone(db),
|
||||
Arc::clone(sync),
|
||||
sync.clone(),
|
||||
)
|
||||
})
|
||||
.map(IntoTask::into_task)
|
||||
@@ -201,7 +201,7 @@ async fn dispatch_media_data_extractor_tasks(
|
||||
parent_iso_file_path.location_id(),
|
||||
Arc::clone(location_path),
|
||||
Arc::clone(db),
|
||||
Arc::clone(sync),
|
||||
sync.clone(),
|
||||
)
|
||||
})
|
||||
.map(IntoTask::into_task),
|
||||
@@ -220,7 +220,7 @@ async fn dispatch_media_data_extractor_tasks(
|
||||
async fn dispatch_thumbnailer_tasks(
|
||||
parent_iso_file_path: &IsolatedFilePathData<'_>,
|
||||
should_regenerate: bool,
|
||||
location_path: &PathBuf,
|
||||
location_path: &Path,
|
||||
dispatcher: &BaseTaskDispatcher<Error>,
|
||||
ctx: &impl OuterContext,
|
||||
) -> Result<Vec<TaskHandle<Error>>, Error> {
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::{
|
||||
|
||||
use sd_core_file_path_helper::IsolatedFilePathData;
|
||||
use sd_core_prisma_helpers::{file_path_for_media_processor, ObjectPubId};
|
||||
use sd_core_sync::Manager as SyncManager;
|
||||
use sd_core_sync::SyncManager;
|
||||
|
||||
use sd_media_metadata::{ExifMetadata, FFmpegMetadata};
|
||||
use sd_prisma::prisma::{exif_data, ffmpeg_data, file_path, location, object, PrismaClient};
|
||||
@@ -69,7 +69,7 @@ pub struct MediaDataExtractor {
|
||||
|
||||
// Dependencies
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -275,7 +275,7 @@ impl MediaDataExtractor {
|
||||
location_id: location::id::Type,
|
||||
location_path: Arc<PathBuf>,
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
) -> Self {
|
||||
let mut output = Output::default();
|
||||
|
||||
@@ -316,7 +316,7 @@ impl MediaDataExtractor {
|
||||
location_id: location::id::Type,
|
||||
location_path: Arc<PathBuf>,
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
) -> Self {
|
||||
Self::new(Kind::Exif, file_paths, location_id, location_path, db, sync)
|
||||
}
|
||||
@@ -327,7 +327,7 @@ impl MediaDataExtractor {
|
||||
location_id: location::id::Type,
|
||||
location_path: Arc<PathBuf>,
|
||||
db: Arc<PrismaClient>,
|
||||
sync: Arc<SyncManager>,
|
||||
sync: SyncManager,
|
||||
) -> Self {
|
||||
Self::new(
|
||||
Kind::FFmpeg,
|
||||
@@ -550,7 +550,7 @@ impl SerializableTask<Error> for MediaDataExtractor {
|
||||
|
||||
type DeserializeError = rmp_serde::decode::Error;
|
||||
|
||||
type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
|
||||
type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
|
||||
|
||||
async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
|
||||
let Self {
|
||||
|
||||
@@ -60,7 +60,7 @@ impl<'de> Deserialize<'de> for RulePerKind {
|
||||
|
||||
struct FieldsVisitor;
|
||||
|
||||
impl<'de> de::Visitor<'de> for FieldsVisitor {
|
||||
impl de::Visitor<'_> for FieldsVisitor {
|
||||
type Value = Fields;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
|
||||
@@ -9,8 +9,9 @@ repository.workspace = true
|
||||
|
||||
[dependencies]
|
||||
# Spacedrive Sub-crates
|
||||
sd-prisma = { path = "../../../crates/prisma" }
|
||||
sd-utils = { path = "../../../crates/utils" }
|
||||
sd-cloud-schema = { workspace = true }
|
||||
sd-prisma = { path = "../../../crates/prisma" }
|
||||
sd-utils = { path = "../../../crates/utils" }
|
||||
|
||||
# Workspace dependencies
|
||||
prisma-client-rust = { workspace = true }
|
||||
|
||||
@@ -34,7 +34,6 @@ use sd_utils::{from_bytes_to_uuid, uuid_to_bytes};
|
||||
use std::{borrow::Cow, fmt};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use specta::Type;
|
||||
use uuid::Uuid;
|
||||
|
||||
// File Path selectables!
|
||||
@@ -75,6 +74,20 @@ file_path::select!(file_path_for_media_processor {
|
||||
pub_id
|
||||
}
|
||||
});
|
||||
file_path::select!(file_path_watcher_remove {
|
||||
id
|
||||
pub_id
|
||||
location_id
|
||||
materialized_path
|
||||
is_dir
|
||||
name
|
||||
extension
|
||||
object: select {
|
||||
id
|
||||
pub_id
|
||||
}
|
||||
|
||||
});
|
||||
file_path::select!(file_path_to_isolate {
|
||||
location_id
|
||||
materialized_path
|
||||
@@ -244,7 +257,7 @@ job::select!(job_without_data {
|
||||
location::select!(location_ids_and_path {
|
||||
id
|
||||
pub_id
|
||||
instance_id
|
||||
device: select { pub_id }
|
||||
path
|
||||
});
|
||||
|
||||
@@ -259,6 +272,7 @@ impl From<location_with_indexer_rules::Data> for location::Data {
|
||||
id: data.id,
|
||||
pub_id: data.pub_id,
|
||||
path: data.path,
|
||||
device_id: data.device_id,
|
||||
instance_id: data.instance_id,
|
||||
name: data.name,
|
||||
total_capacity: data.total_capacity,
|
||||
@@ -272,6 +286,7 @@ impl From<location_with_indexer_rules::Data> for location::Data {
|
||||
scan_state: data.scan_state,
|
||||
file_paths: None,
|
||||
indexer_rules: None,
|
||||
device: None,
|
||||
instance: None,
|
||||
}
|
||||
}
|
||||
@@ -283,6 +298,7 @@ impl From<&location_with_indexer_rules::Data> for location::Data {
|
||||
id: data.id,
|
||||
pub_id: data.pub_id.clone(),
|
||||
path: data.path.clone(),
|
||||
device_id: data.device_id,
|
||||
instance_id: data.instance_id,
|
||||
name: data.name.clone(),
|
||||
total_capacity: data.total_capacity,
|
||||
@@ -296,6 +312,7 @@ impl From<&location_with_indexer_rules::Data> for location::Data {
|
||||
scan_state: data.scan_state,
|
||||
file_paths: None,
|
||||
indexer_rules: None,
|
||||
device: None,
|
||||
instance: None,
|
||||
}
|
||||
}
|
||||
@@ -311,7 +328,7 @@ label::include!((take: i64) => label_with_objects {
|
||||
}
|
||||
});
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Type)]
|
||||
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, specta::Type)]
|
||||
#[serde(transparent)]
|
||||
pub struct CasId<'cas_id>(Cow<'cas_id, str>);
|
||||
|
||||
@@ -321,7 +338,7 @@ impl Clone for CasId<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'cas_id> CasId<'cas_id> {
|
||||
impl CasId<'_> {
|
||||
#[must_use]
|
||||
pub fn as_str(&self) -> &str {
|
||||
self.0.as_ref()
|
||||
@@ -374,17 +391,32 @@ impl From<&CasId<'_>> for String {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
|
||||
#[serde(transparent)]
|
||||
#[repr(transparent)]
|
||||
#[specta(rename = "CoreDevicePubId")]
|
||||
pub struct DevicePubId(PubId);
|
||||
|
||||
impl From<DevicePubId> for sd_cloud_schema::devices::PubId {
|
||||
fn from(DevicePubId(pub_id): DevicePubId) -> Self {
|
||||
Self(pub_id.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
|
||||
#[serde(transparent)]
|
||||
#[repr(transparent)]
|
||||
#[specta(rename = "CoreFilePathPubId")]
|
||||
pub struct FilePathPubId(PubId);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
|
||||
#[serde(transparent)]
|
||||
#[repr(transparent)]
|
||||
#[specta(rename = "CoreObjectPubId")]
|
||||
pub struct ObjectPubId(PubId);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
|
||||
#[specta(rename = "CorePubId")]
|
||||
enum PubId {
|
||||
Uuid(Uuid),
|
||||
Vec(Vec<u8>),
|
||||
@@ -392,7 +424,7 @@ enum PubId {
|
||||
|
||||
impl PubId {
|
||||
fn new() -> Self {
|
||||
Self::Uuid(Uuid::new_v4())
|
||||
Self::Uuid(Uuid::now_v7())
|
||||
}
|
||||
|
||||
fn to_db(&self) -> Vec<u8> {
|
||||
@@ -451,6 +483,15 @@ impl From<PubId> for Uuid {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&PubId> for Uuid {
|
||||
fn from(pub_id: &PubId) -> Self {
|
||||
match pub_id {
|
||||
PubId::Uuid(uuid) => *uuid,
|
||||
PubId::Vec(bytes) => from_bytes_to_uuid(bytes),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PubId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
@@ -499,6 +540,12 @@ macro_rules! delegate_pub_id {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&$type_name> for ::uuid::Uuid {
|
||||
fn from(pub_id: &$type_name) -> Self {
|
||||
(&pub_id.0).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl ::std::fmt::Display for $type_name {
|
||||
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
@@ -526,4 +573,4 @@ macro_rules! delegate_pub_id {
|
||||
};
|
||||
}
|
||||
|
||||
delegate_pub_id!(FilePathPubId, ObjectPubId);
|
||||
delegate_pub_id!(FilePathPubId, ObjectPubId, DevicePubId);
|
||||
|
||||
@@ -9,6 +9,8 @@ default = []
|
||||
|
||||
[dependencies]
|
||||
# Spacedrive Sub-crates
|
||||
sd-core-prisma-helpers = { path = "../prisma-helpers" }
|
||||
|
||||
sd-actors = { path = "../../../crates/actors" }
|
||||
sd-prisma = { path = "../../../crates/prisma" }
|
||||
sd-sync = { path = "../../../crates/sync" }
|
||||
@@ -16,8 +18,11 @@ sd-utils = { path = "../../../crates/utils" }
|
||||
|
||||
# Workspace dependencies
|
||||
async-channel = { workspace = true }
|
||||
async-stream = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
futures-concurrency = { workspace = true }
|
||||
itertools = { workspace = true }
|
||||
prisma-client-rust = { workspace = true, features = ["rspc"] }
|
||||
rmp-serde = { workspace = true }
|
||||
rmpv = { workspace = true }
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
use async_channel as chan;
|
||||
|
||||
pub trait ActorTypes {
|
||||
type Event: Send;
|
||||
type Request: Send;
|
||||
type Handler;
|
||||
}
|
||||
|
||||
pub struct ActorIO<T: ActorTypes> {
|
||||
pub event_rx: chan::Receiver<T::Event>,
|
||||
pub req_tx: chan::Sender<T::Request>,
|
||||
}
|
||||
|
||||
impl<T: ActorTypes> Clone for ActorIO<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
event_rx: self.event_rx.clone(),
|
||||
req_tx: self.req_tx.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ActorTypes> ActorIO<T> {
|
||||
pub async fn send(&self, value: T::Request) -> Result<(), chan::SendError<T::Request>> {
|
||||
self.req_tx.send(value).await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct HandlerIO<T: ActorTypes> {
|
||||
pub event_tx: chan::Sender<T::Event>,
|
||||
pub req_rx: chan::Receiver<T::Request>,
|
||||
}
|
||||
|
||||
pub fn create_actor_io<T: ActorTypes>() -> (ActorIO<T>, HandlerIO<T>) {
|
||||
let (req_tx, req_rx) = chan::bounded(32);
|
||||
let (event_tx, event_rx) = chan::bounded(32);
|
||||
|
||||
(ActorIO { event_rx, req_tx }, HandlerIO { event_tx, req_rx })
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
use sd_prisma::{
|
||||
prisma::{
|
||||
crdt_operation, exif_data, file_path, instance, label, label_on_object, location, object,
|
||||
tag, tag_on_object, PrismaClient, SortOrder,
|
||||
crdt_operation, device, exif_data, file_path, label, label_on_object, location, object,
|
||||
storage_statistics, tag, tag_on_object, PrismaClient, SortOrder,
|
||||
},
|
||||
prisma_sync,
|
||||
};
|
||||
@@ -10,49 +10,141 @@ use sd_utils::chain_optional_iter;
|
||||
|
||||
use std::future::Future;
|
||||
|
||||
use futures_concurrency::future::TryJoin;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
use super::{crdt_op_unchecked_db, Error};
|
||||
use super::{crdt_op_unchecked_db, Error, SyncManager};
|
||||
|
||||
/// Takes all the syncable data in the database and generates [`CRDTOperations`] for it.
|
||||
/// This is a requirement before the library can sync.
|
||||
pub async fn backfill_operations(
|
||||
db: &PrismaClient,
|
||||
sync: &crate::Manager,
|
||||
instance_id: instance::id::Type,
|
||||
) -> Result<(), Error> {
|
||||
let lock = sync.timestamp_lock.lock().await;
|
||||
pub async fn backfill_operations(sync: &SyncManager) -> Result<(), Error> {
|
||||
let _lock_guard = sync.sync_lock.lock().await;
|
||||
|
||||
let res = db
|
||||
._transaction()
|
||||
let db = &sync.db;
|
||||
|
||||
let local_device = db
|
||||
.device()
|
||||
.find_unique(device::pub_id::equals(sync.device_pub_id.to_db()))
|
||||
.exec()
|
||||
.await?
|
||||
.ok_or(Error::DeviceNotFound(sync.device_pub_id.clone()))?;
|
||||
|
||||
let local_device_id = local_device.id;
|
||||
|
||||
db._transaction()
|
||||
.with_timeout(9_999_999_999)
|
||||
.run(|db| async move {
|
||||
debug!("backfill started");
|
||||
let start = Instant::now();
|
||||
db.crdt_operation()
|
||||
.delete_many(vec![crdt_operation::instance_id::equals(instance_id)])
|
||||
.delete_many(vec![crdt_operation::device_pub_id::equals(
|
||||
sync.device_pub_id.to_db(),
|
||||
)])
|
||||
.exec()
|
||||
.await?;
|
||||
|
||||
paginate_tags(&db, sync, instance_id).await?;
|
||||
paginate_locations(&db, sync, instance_id).await?;
|
||||
paginate_objects(&db, sync, instance_id).await?;
|
||||
paginate_exif_datas(&db, sync, instance_id).await?;
|
||||
paginate_file_paths(&db, sync, instance_id).await?;
|
||||
paginate_tags_on_objects(&db, sync, instance_id).await?;
|
||||
paginate_labels(&db, sync, instance_id).await?;
|
||||
paginate_labels_on_objects(&db, sync, instance_id).await?;
|
||||
backfill_device(&db, sync, local_device).await?;
|
||||
|
||||
(
|
||||
backfill_storage_statistics(&db, sync, local_device_id),
|
||||
paginate_tags(&db, sync),
|
||||
paginate_locations(&db, sync, local_device_id),
|
||||
paginate_objects(&db, sync, local_device_id),
|
||||
paginate_labels(&db, sync),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
(
|
||||
paginate_exif_datas(&db, sync, local_device_id),
|
||||
paginate_file_paths(&db, sync, local_device_id),
|
||||
paginate_tags_on_objects(&db, sync, local_device_id),
|
||||
paginate_labels_on_objects(&db, sync, local_device_id),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
debug!(elapsed = ?start.elapsed(), "backfill ended");
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
.await
|
||||
}
|
||||
|
||||
drop(lock);
|
||||
#[instrument(skip(db, sync), err)]
|
||||
async fn backfill_device(
|
||||
db: &PrismaClient,
|
||||
sync: &SyncManager,
|
||||
local_device: device::Data,
|
||||
) -> Result<(), Error> {
|
||||
db.crdt_operation()
|
||||
.create_many(vec![crdt_op_unchecked_db(&sync.shared_create(
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: local_device.pub_id,
|
||||
},
|
||||
chain_optional_iter(
|
||||
[],
|
||||
[
|
||||
option_sync_entry!(local_device.name, device::name),
|
||||
option_sync_entry!(local_device.os, device::os),
|
||||
option_sync_entry!(local_device.hardware_model, device::hardware_model),
|
||||
option_sync_entry!(local_device.timestamp, device::timestamp),
|
||||
option_sync_entry!(local_device.date_created, device::date_created),
|
||||
option_sync_entry!(local_device.date_deleted, device::date_deleted),
|
||||
],
|
||||
),
|
||||
))?])
|
||||
.exec()
|
||||
.await?;
|
||||
|
||||
res
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(db, sync), err)]
|
||||
async fn backfill_storage_statistics(
|
||||
db: &PrismaClient,
|
||||
sync: &SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Result<(), Error> {
|
||||
let Some(stats) = db
|
||||
.storage_statistics()
|
||||
.find_first(vec![storage_statistics::device_id::equals(Some(device_id))])
|
||||
.include(storage_statistics::include!({device: select { pub_id }}))
|
||||
.exec()
|
||||
.await?
|
||||
else {
|
||||
// Nothing to do
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
db.crdt_operation()
|
||||
.create_many(vec![crdt_op_unchecked_db(&sync.shared_create(
|
||||
prisma_sync::storage_statistics::SyncId {
|
||||
pub_id: stats.pub_id,
|
||||
},
|
||||
chain_optional_iter(
|
||||
[
|
||||
sync_entry!(stats.total_capacity, storage_statistics::total_capacity),
|
||||
sync_entry!(
|
||||
stats.available_capacity,
|
||||
storage_statistics::available_capacity
|
||||
),
|
||||
],
|
||||
[option_sync_entry!(
|
||||
stats.device.map(|device| {
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: device.pub_id,
|
||||
}
|
||||
}),
|
||||
storage_statistics::device
|
||||
)],
|
||||
),
|
||||
))?])
|
||||
.exec()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn paginate<T, E1, E2, E3, GetterFut, OperationsFut>(
|
||||
@@ -112,38 +204,32 @@ where
|
||||
}
|
||||
|
||||
#[instrument(skip(db, sync), err)]
|
||||
async fn paginate_tags(
|
||||
db: &PrismaClient,
|
||||
sync: &crate::Manager,
|
||||
instance_id: instance::id::Type,
|
||||
) -> Result<(), Error> {
|
||||
use tag::{color, date_created, date_modified, id, name};
|
||||
|
||||
async fn paginate_tags(db: &PrismaClient, sync: &SyncManager) -> Result<(), Error> {
|
||||
paginate(
|
||||
|cursor| {
|
||||
db.tag()
|
||||
.find_many(vec![id::gt(cursor)])
|
||||
.order_by(id::order(SortOrder::Asc))
|
||||
.find_many(vec![tag::id::gt(cursor)])
|
||||
.order_by(tag::id::order(SortOrder::Asc))
|
||||
.exec()
|
||||
},
|
||||
|tag| tag.id,
|
||||
|tags| {
|
||||
tags.into_iter()
|
||||
.flat_map(|t| {
|
||||
.map(|t| {
|
||||
sync.shared_create(
|
||||
prisma_sync::tag::SyncId { pub_id: t.pub_id },
|
||||
chain_optional_iter(
|
||||
[],
|
||||
[
|
||||
option_sync_entry!(t.name, name),
|
||||
option_sync_entry!(t.color, color),
|
||||
option_sync_entry!(t.date_created, date_created),
|
||||
option_sync_entry!(t.date_modified, date_modified),
|
||||
option_sync_entry!(t.name, tag::name),
|
||||
option_sync_entry!(t.color, tag::color),
|
||||
option_sync_entry!(t.date_created, tag::date_created),
|
||||
option_sync_entry!(t.date_modified, tag::date_modified),
|
||||
],
|
||||
),
|
||||
)
|
||||
})
|
||||
.map(|o| crdt_op_unchecked_db(&o, instance_id))
|
||||
.map(|o| crdt_op_unchecked_db(&o))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|creates| db.crdt_operation().create_many(creates).exec())
|
||||
},
|
||||
@@ -154,25 +240,20 @@ async fn paginate_tags(
|
||||
#[instrument(skip(db, sync), err)]
|
||||
async fn paginate_locations(
|
||||
db: &PrismaClient,
|
||||
sync: &crate::Manager,
|
||||
instance_id: instance::id::Type,
|
||||
sync: &SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Result<(), Error> {
|
||||
use location::{
|
||||
available_capacity, date_created, generate_preview_media, hidden, id, include, instance,
|
||||
is_archived, name, path, size_in_bytes, sync_preview_media, total_capacity,
|
||||
};
|
||||
|
||||
paginate(
|
||||
|cursor| {
|
||||
db.location()
|
||||
.find_many(vec![id::gt(cursor)])
|
||||
.order_by(id::order(SortOrder::Asc))
|
||||
.find_many(vec![
|
||||
location::id::gt(cursor),
|
||||
location::device_id::equals(Some(device_id)),
|
||||
])
|
||||
.order_by(location::id::order(SortOrder::Asc))
|
||||
.take(1000)
|
||||
.include(include!({
|
||||
instance: select {
|
||||
id
|
||||
pub_id
|
||||
}
|
||||
.include(location::include!({
|
||||
device: select { pub_id }
|
||||
}))
|
||||
.exec()
|
||||
},
|
||||
@@ -180,36 +261,44 @@ async fn paginate_locations(
|
||||
|locations| {
|
||||
locations
|
||||
.into_iter()
|
||||
.flat_map(|l| {
|
||||
.map(|l| {
|
||||
sync.shared_create(
|
||||
prisma_sync::location::SyncId { pub_id: l.pub_id },
|
||||
chain_optional_iter(
|
||||
[],
|
||||
[
|
||||
option_sync_entry!(l.name, name),
|
||||
option_sync_entry!(l.path, path),
|
||||
option_sync_entry!(l.total_capacity, total_capacity),
|
||||
option_sync_entry!(l.available_capacity, available_capacity),
|
||||
option_sync_entry!(l.size_in_bytes, size_in_bytes),
|
||||
option_sync_entry!(l.is_archived, is_archived),
|
||||
option_sync_entry!(l.name, location::name),
|
||||
option_sync_entry!(l.path, location::path),
|
||||
option_sync_entry!(l.total_capacity, location::total_capacity),
|
||||
option_sync_entry!(
|
||||
l.available_capacity,
|
||||
location::available_capacity
|
||||
),
|
||||
option_sync_entry!(l.size_in_bytes, location::size_in_bytes),
|
||||
option_sync_entry!(l.is_archived, location::is_archived),
|
||||
option_sync_entry!(
|
||||
l.generate_preview_media,
|
||||
generate_preview_media
|
||||
location::generate_preview_media
|
||||
),
|
||||
option_sync_entry!(l.sync_preview_media, sync_preview_media),
|
||||
option_sync_entry!(l.hidden, hidden),
|
||||
option_sync_entry!(l.date_created, date_created),
|
||||
option_sync_entry!(
|
||||
l.instance.map(|i| {
|
||||
prisma_sync::instance::SyncId { pub_id: i.pub_id }
|
||||
l.sync_preview_media,
|
||||
location::sync_preview_media
|
||||
),
|
||||
option_sync_entry!(l.hidden, location::hidden),
|
||||
option_sync_entry!(l.date_created, location::date_created),
|
||||
option_sync_entry!(
|
||||
l.device.map(|device| {
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: device.pub_id,
|
||||
}
|
||||
}),
|
||||
instance
|
||||
location::device
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
})
|
||||
.map(|o| crdt_op_unchecked_db(&o, instance_id))
|
||||
.map(|o| crdt_op_unchecked_db(&o))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|creates| db.crdt_operation().create_many(creates).exec())
|
||||
},
|
||||
@@ -220,41 +309,53 @@ async fn paginate_locations(
|
||||
#[instrument(skip(db, sync), err)]
|
||||
async fn paginate_objects(
|
||||
db: &PrismaClient,
|
||||
sync: &crate::Manager,
|
||||
instance_id: instance::id::Type,
|
||||
sync: &SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Result<(), Error> {
|
||||
use object::{date_accessed, date_created, favorite, hidden, id, important, kind, note};
|
||||
|
||||
paginate(
|
||||
|cursor| {
|
||||
db.object()
|
||||
.find_many(vec![id::gt(cursor)])
|
||||
.order_by(id::order(SortOrder::Asc))
|
||||
.find_many(vec![
|
||||
object::id::gt(cursor),
|
||||
object::device_id::equals(Some(device_id)),
|
||||
])
|
||||
.order_by(object::id::order(SortOrder::Asc))
|
||||
.take(1000)
|
||||
.include(object::include!({
|
||||
device: select { pub_id }
|
||||
}))
|
||||
.exec()
|
||||
},
|
||||
|object| object.id,
|
||||
|objects| {
|
||||
objects
|
||||
.into_iter()
|
||||
.flat_map(|o| {
|
||||
.map(|o| {
|
||||
sync.shared_create(
|
||||
prisma_sync::object::SyncId { pub_id: o.pub_id },
|
||||
chain_optional_iter(
|
||||
[],
|
||||
[
|
||||
option_sync_entry!(o.kind, kind),
|
||||
option_sync_entry!(o.hidden, hidden),
|
||||
option_sync_entry!(o.favorite, favorite),
|
||||
option_sync_entry!(o.important, important),
|
||||
option_sync_entry!(o.note, note),
|
||||
option_sync_entry!(o.date_created, date_created),
|
||||
option_sync_entry!(o.date_accessed, date_accessed),
|
||||
option_sync_entry!(o.kind, object::kind),
|
||||
option_sync_entry!(o.hidden, object::hidden),
|
||||
option_sync_entry!(o.favorite, object::favorite),
|
||||
option_sync_entry!(o.important, object::important),
|
||||
option_sync_entry!(o.note, object::note),
|
||||
option_sync_entry!(o.date_created, object::date_created),
|
||||
option_sync_entry!(o.date_accessed, object::date_accessed),
|
||||
option_sync_entry!(
|
||||
o.device.map(|device| {
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: device.pub_id,
|
||||
}
|
||||
}),
|
||||
object::device
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
})
|
||||
.map(|o| crdt_op_unchecked_db(&o, instance_id))
|
||||
.map(|o| crdt_op_unchecked_db(&o))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|creates| db.crdt_operation().create_many(creates).exec())
|
||||
},
|
||||
@@ -265,22 +366,21 @@ async fn paginate_objects(
|
||||
#[instrument(skip(db, sync), err)]
|
||||
async fn paginate_exif_datas(
|
||||
db: &PrismaClient,
|
||||
sync: &crate::Manager,
|
||||
instance_id: instance::id::Type,
|
||||
sync: &SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Result<(), Error> {
|
||||
use exif_data::{
|
||||
artist, camera_data, copyright, description, epoch_time, exif_version, id, include,
|
||||
media_date, media_location, resolution,
|
||||
};
|
||||
|
||||
paginate(
|
||||
|cursor| {
|
||||
db.exif_data()
|
||||
.find_many(vec![id::gt(cursor)])
|
||||
.order_by(id::order(SortOrder::Asc))
|
||||
.find_many(vec![
|
||||
exif_data::id::gt(cursor),
|
||||
exif_data::device_id::equals(Some(device_id)),
|
||||
])
|
||||
.order_by(exif_data::id::order(SortOrder::Asc))
|
||||
.take(1000)
|
||||
.include(include!({
|
||||
.include(exif_data::include!({
|
||||
object: select { pub_id }
|
||||
device: select { pub_id }
|
||||
}))
|
||||
.exec()
|
||||
},
|
||||
@@ -288,7 +388,7 @@ async fn paginate_exif_datas(
|
||||
|exif_datas| {
|
||||
exif_datas
|
||||
.into_iter()
|
||||
.flat_map(|ed| {
|
||||
.map(|ed| {
|
||||
sync.shared_create(
|
||||
prisma_sync::exif_data::SyncId {
|
||||
object: prisma_sync::object::SyncId {
|
||||
@@ -298,20 +398,28 @@ async fn paginate_exif_datas(
|
||||
chain_optional_iter(
|
||||
[],
|
||||
[
|
||||
option_sync_entry!(ed.resolution, resolution),
|
||||
option_sync_entry!(ed.media_date, media_date),
|
||||
option_sync_entry!(ed.media_location, media_location),
|
||||
option_sync_entry!(ed.camera_data, camera_data),
|
||||
option_sync_entry!(ed.artist, artist),
|
||||
option_sync_entry!(ed.description, description),
|
||||
option_sync_entry!(ed.copyright, copyright),
|
||||
option_sync_entry!(ed.exif_version, exif_version),
|
||||
option_sync_entry!(ed.epoch_time, epoch_time),
|
||||
option_sync_entry!(ed.resolution, exif_data::resolution),
|
||||
option_sync_entry!(ed.media_date, exif_data::media_date),
|
||||
option_sync_entry!(ed.media_location, exif_data::media_location),
|
||||
option_sync_entry!(ed.camera_data, exif_data::camera_data),
|
||||
option_sync_entry!(ed.artist, exif_data::artist),
|
||||
option_sync_entry!(ed.description, exif_data::description),
|
||||
option_sync_entry!(ed.copyright, exif_data::copyright),
|
||||
option_sync_entry!(ed.exif_version, exif_data::exif_version),
|
||||
option_sync_entry!(ed.epoch_time, exif_data::epoch_time),
|
||||
option_sync_entry!(
|
||||
ed.device.map(|device| {
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: device.pub_id,
|
||||
}
|
||||
}),
|
||||
exif_data::device
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
})
|
||||
.map(|o| crdt_op_unchecked_db(&o, instance_id))
|
||||
.map(|o| crdt_op_unchecked_db(&o))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|creates| db.crdt_operation().create_many(creates).exec())
|
||||
},
|
||||
@@ -322,22 +430,21 @@ async fn paginate_exif_datas(
|
||||
#[instrument(skip(db, sync), err)]
|
||||
async fn paginate_file_paths(
|
||||
db: &PrismaClient,
|
||||
sync: &crate::Manager,
|
||||
instance_id: instance::id::Type,
|
||||
sync: &SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Result<(), Error> {
|
||||
use file_path::{
|
||||
cas_id, date_created, date_indexed, date_modified, extension, hidden, id, include, inode,
|
||||
integrity_checksum, is_dir, location, materialized_path, name, object, size_in_bytes_bytes,
|
||||
};
|
||||
|
||||
paginate(
|
||||
|cursor| {
|
||||
db.file_path()
|
||||
.find_many(vec![id::gt(cursor)])
|
||||
.order_by(id::order(SortOrder::Asc))
|
||||
.include(include!({
|
||||
.find_many(vec![
|
||||
file_path::id::gt(cursor),
|
||||
file_path::device_id::equals(Some(device_id)),
|
||||
])
|
||||
.order_by(file_path::id::order(SortOrder::Asc))
|
||||
.include(file_path::include!({
|
||||
location: select { pub_id }
|
||||
object: select { pub_id }
|
||||
device: select { pub_id }
|
||||
}))
|
||||
.exec()
|
||||
},
|
||||
@@ -345,41 +452,58 @@ async fn paginate_file_paths(
|
||||
|file_paths| {
|
||||
file_paths
|
||||
.into_iter()
|
||||
.flat_map(|fp| {
|
||||
.map(|fp| {
|
||||
sync.shared_create(
|
||||
prisma_sync::file_path::SyncId { pub_id: fp.pub_id },
|
||||
chain_optional_iter(
|
||||
[],
|
||||
[
|
||||
option_sync_entry!(fp.is_dir, is_dir),
|
||||
option_sync_entry!(fp.cas_id, cas_id),
|
||||
option_sync_entry!(fp.integrity_checksum, integrity_checksum),
|
||||
option_sync_entry!(fp.is_dir, file_path::is_dir),
|
||||
option_sync_entry!(fp.cas_id, file_path::cas_id),
|
||||
option_sync_entry!(
|
||||
fp.integrity_checksum,
|
||||
file_path::integrity_checksum
|
||||
),
|
||||
option_sync_entry!(
|
||||
fp.location.map(|l| {
|
||||
prisma_sync::location::SyncId { pub_id: l.pub_id }
|
||||
}),
|
||||
location
|
||||
file_path::location
|
||||
),
|
||||
option_sync_entry!(
|
||||
fp.object.map(|o| {
|
||||
prisma_sync::object::SyncId { pub_id: o.pub_id }
|
||||
}),
|
||||
object
|
||||
file_path::object
|
||||
),
|
||||
option_sync_entry!(
|
||||
fp.materialized_path,
|
||||
file_path::materialized_path
|
||||
),
|
||||
option_sync_entry!(fp.name, file_path::name),
|
||||
option_sync_entry!(fp.extension, file_path::extension),
|
||||
option_sync_entry!(fp.hidden, file_path::hidden),
|
||||
option_sync_entry!(
|
||||
fp.size_in_bytes_bytes,
|
||||
file_path::size_in_bytes_bytes
|
||||
),
|
||||
option_sync_entry!(fp.inode, file_path::inode),
|
||||
option_sync_entry!(fp.date_created, file_path::date_created),
|
||||
option_sync_entry!(fp.date_modified, file_path::date_modified),
|
||||
option_sync_entry!(fp.date_indexed, file_path::date_indexed),
|
||||
option_sync_entry!(
|
||||
fp.device.map(|device| {
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: device.pub_id,
|
||||
}
|
||||
}),
|
||||
file_path::device
|
||||
),
|
||||
option_sync_entry!(fp.materialized_path, materialized_path),
|
||||
option_sync_entry!(fp.name, name),
|
||||
option_sync_entry!(fp.extension, extension),
|
||||
option_sync_entry!(fp.hidden, hidden),
|
||||
option_sync_entry!(fp.size_in_bytes_bytes, size_in_bytes_bytes),
|
||||
option_sync_entry!(fp.inode, inode),
|
||||
option_sync_entry!(fp.date_created, date_created),
|
||||
option_sync_entry!(fp.date_modified, date_modified),
|
||||
option_sync_entry!(fp.date_indexed, date_indexed),
|
||||
],
|
||||
),
|
||||
)
|
||||
})
|
||||
.map(|o| crdt_op_unchecked_db(&o, instance_id))
|
||||
.map(|o| crdt_op_unchecked_db(&o))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|creates| db.crdt_operation().create_many(creates).exec())
|
||||
},
|
||||
@@ -390,20 +514,23 @@ async fn paginate_file_paths(
|
||||
#[instrument(skip(db, sync), err)]
|
||||
async fn paginate_tags_on_objects(
|
||||
db: &PrismaClient,
|
||||
sync: &crate::Manager,
|
||||
instance_id: instance::id::Type,
|
||||
sync: &SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Result<(), Error> {
|
||||
use tag_on_object::{date_created, include, object_id, tag_id};
|
||||
|
||||
paginate_relation(
|
||||
|group_id, item_id| {
|
||||
db.tag_on_object()
|
||||
.find_many(vec![tag_id::gt(group_id), object_id::gt(item_id)])
|
||||
.order_by(tag_id::order(SortOrder::Asc))
|
||||
.order_by(object_id::order(SortOrder::Asc))
|
||||
.include(include!({
|
||||
.find_many(vec![
|
||||
tag_on_object::tag_id::gt(group_id),
|
||||
tag_on_object::object_id::gt(item_id),
|
||||
tag_on_object::device_id::equals(Some(device_id)),
|
||||
])
|
||||
.order_by(tag_on_object::tag_id::order(SortOrder::Asc))
|
||||
.order_by(tag_on_object::object_id::order(SortOrder::Asc))
|
||||
.include(tag_on_object::include!({
|
||||
tag: select { pub_id }
|
||||
object: select { pub_id }
|
||||
device: select { pub_id }
|
||||
}))
|
||||
.exec()
|
||||
},
|
||||
@@ -411,7 +538,7 @@ async fn paginate_tags_on_objects(
|
||||
|tag_on_objects| {
|
||||
tag_on_objects
|
||||
.into_iter()
|
||||
.flat_map(|t_o| {
|
||||
.map(|t_o| {
|
||||
sync.relation_create(
|
||||
prisma_sync::tag_on_object::SyncId {
|
||||
tag: prisma_sync::tag::SyncId {
|
||||
@@ -423,11 +550,21 @@ async fn paginate_tags_on_objects(
|
||||
},
|
||||
chain_optional_iter(
|
||||
[],
|
||||
[option_sync_entry!(t_o.date_created, date_created)],
|
||||
[
|
||||
option_sync_entry!(t_o.date_created, tag_on_object::date_created),
|
||||
option_sync_entry!(
|
||||
t_o.device.map(|device| {
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: device.pub_id,
|
||||
}
|
||||
}),
|
||||
tag_on_object::device
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
})
|
||||
.map(|o| crdt_op_unchecked_db(&o, instance_id))
|
||||
.map(|o| crdt_op_unchecked_db(&o))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|creates| db.crdt_operation().create_many(creates).exec())
|
||||
},
|
||||
@@ -436,37 +573,31 @@ async fn paginate_tags_on_objects(
|
||||
}
|
||||
|
||||
#[instrument(skip(db, sync), err)]
|
||||
async fn paginate_labels(
|
||||
db: &PrismaClient,
|
||||
sync: &crate::Manager,
|
||||
instance_id: instance::id::Type,
|
||||
) -> Result<(), Error> {
|
||||
use label::{date_created, date_modified, id};
|
||||
|
||||
async fn paginate_labels(db: &PrismaClient, sync: &SyncManager) -> Result<(), Error> {
|
||||
paginate(
|
||||
|cursor| {
|
||||
db.label()
|
||||
.find_many(vec![id::gt(cursor)])
|
||||
.order_by(id::order(SortOrder::Asc))
|
||||
.find_many(vec![label::id::gt(cursor)])
|
||||
.order_by(label::id::order(SortOrder::Asc))
|
||||
.exec()
|
||||
},
|
||||
|label| label.id,
|
||||
|labels| {
|
||||
labels
|
||||
.into_iter()
|
||||
.flat_map(|l| {
|
||||
.map(|l| {
|
||||
sync.shared_create(
|
||||
prisma_sync::label::SyncId { name: l.name },
|
||||
chain_optional_iter(
|
||||
[],
|
||||
[
|
||||
option_sync_entry!(l.date_created, date_created),
|
||||
option_sync_entry!(l.date_modified, date_modified),
|
||||
option_sync_entry!(l.date_created, label::date_created),
|
||||
option_sync_entry!(l.date_modified, label::date_modified),
|
||||
],
|
||||
),
|
||||
)
|
||||
})
|
||||
.map(|o| crdt_op_unchecked_db(&o, instance_id))
|
||||
.map(|o| crdt_op_unchecked_db(&o))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|creates| db.crdt_operation().create_many(creates).exec())
|
||||
},
|
||||
@@ -477,20 +608,23 @@ async fn paginate_labels(
|
||||
#[instrument(skip(db, sync), err)]
|
||||
async fn paginate_labels_on_objects(
|
||||
db: &PrismaClient,
|
||||
sync: &crate::Manager,
|
||||
instance_id: instance::id::Type,
|
||||
sync: &SyncManager,
|
||||
device_id: device::id::Type,
|
||||
) -> Result<(), Error> {
|
||||
use label_on_object::{date_created, include, label_id, object_id};
|
||||
|
||||
paginate_relation(
|
||||
|group_id, item_id| {
|
||||
db.label_on_object()
|
||||
.find_many(vec![label_id::gt(group_id), object_id::gt(item_id)])
|
||||
.order_by(label_id::order(SortOrder::Asc))
|
||||
.order_by(object_id::order(SortOrder::Asc))
|
||||
.include(include!({
|
||||
.find_many(vec![
|
||||
label_on_object::label_id::gt(group_id),
|
||||
label_on_object::object_id::gt(item_id),
|
||||
label_on_object::device_id::equals(Some(device_id)),
|
||||
])
|
||||
.order_by(label_on_object::label_id::order(SortOrder::Asc))
|
||||
.order_by(label_on_object::object_id::order(SortOrder::Asc))
|
||||
.include(label_on_object::include!({
|
||||
object: select { pub_id }
|
||||
label: select { name }
|
||||
device: select { pub_id }
|
||||
}))
|
||||
.exec()
|
||||
},
|
||||
@@ -498,7 +632,7 @@ async fn paginate_labels_on_objects(
|
||||
|label_on_objects| {
|
||||
label_on_objects
|
||||
.into_iter()
|
||||
.flat_map(|l_o| {
|
||||
.map(|l_o| {
|
||||
sync.relation_create(
|
||||
prisma_sync::label_on_object::SyncId {
|
||||
label: prisma_sync::label::SyncId {
|
||||
@@ -508,10 +642,20 @@ async fn paginate_labels_on_objects(
|
||||
pub_id: l_o.object.pub_id,
|
||||
},
|
||||
},
|
||||
[sync_entry!(l_o.date_created, date_created)],
|
||||
chain_optional_iter(
|
||||
[sync_entry!(l_o.date_created, label_on_object::date_created)],
|
||||
[option_sync_entry!(
|
||||
l_o.device.map(|device| {
|
||||
prisma_sync::device::SyncId {
|
||||
pub_id: device.pub_id,
|
||||
}
|
||||
}),
|
||||
label_on_object::device
|
||||
)],
|
||||
),
|
||||
)
|
||||
})
|
||||
.map(|o| crdt_op_unchecked_db(&o, instance_id))
|
||||
.map(|o| crdt_op_unchecked_db(&o))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|creates| db.crdt_operation().create_many(creates).exec())
|
||||
},
|
||||
|
||||
@@ -1,82 +1,14 @@
|
||||
use sd_prisma::prisma::{cloud_crdt_operation, crdt_operation, instance, PrismaClient};
|
||||
use sd_core_prisma_helpers::DevicePubId;
|
||||
|
||||
use sd_prisma::prisma::{cloud_crdt_operation, crdt_operation, PrismaClient};
|
||||
use sd_sync::CRDTOperation;
|
||||
use sd_utils::from_bytes_to_uuid;
|
||||
use sd_utils::uuid_to_bytes;
|
||||
|
||||
use tracing::instrument;
|
||||
use uhlc::NTP64;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::Error;
|
||||
|
||||
crdt_operation::include!(crdt_with_instance {
|
||||
instance: select { pub_id }
|
||||
});
|
||||
|
||||
cloud_crdt_operation::include!(cloud_crdt_with_instance {
|
||||
instance: select { pub_id }
|
||||
});
|
||||
|
||||
impl crdt_with_instance::Data {
|
||||
#[allow(clippy::cast_sign_loss)] // SAFETY: we had to store using i64 due to SQLite limitations
|
||||
pub const fn timestamp(&self) -> NTP64 {
|
||||
NTP64(self.timestamp as u64)
|
||||
}
|
||||
|
||||
pub fn instance(&self) -> Uuid {
|
||||
from_bytes_to_uuid(&self.instance.pub_id)
|
||||
}
|
||||
|
||||
pub fn into_operation(self) -> Result<CRDTOperation, Error> {
|
||||
Ok(CRDTOperation {
|
||||
instance: self.instance(),
|
||||
timestamp: self.timestamp(),
|
||||
record_id: rmp_serde::from_slice(&self.record_id)?,
|
||||
|
||||
model: {
|
||||
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
|
||||
// SAFETY: we will not have more than 2^16 models and we had to store using signed
|
||||
// integers due to SQLite limitations
|
||||
{
|
||||
self.model as u16
|
||||
}
|
||||
},
|
||||
data: rmp_serde::from_slice(&self.data)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl cloud_crdt_with_instance::Data {
|
||||
#[allow(clippy::cast_sign_loss)] // SAFETY: we had to store using i64 due to SQLite limitations
|
||||
pub const fn timestamp(&self) -> NTP64 {
|
||||
NTP64(self.timestamp as u64)
|
||||
}
|
||||
|
||||
pub fn instance(&self) -> Uuid {
|
||||
from_bytes_to_uuid(&self.instance.pub_id)
|
||||
}
|
||||
|
||||
#[instrument(skip(self), err)]
|
||||
pub fn into_operation(self) -> Result<(i32, CRDTOperation), Error> {
|
||||
Ok((
|
||||
self.id,
|
||||
CRDTOperation {
|
||||
instance: self.instance(),
|
||||
timestamp: self.timestamp(),
|
||||
record_id: rmp_serde::from_slice(&self.record_id)?,
|
||||
model: {
|
||||
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
|
||||
// SAFETY: we will not have more than 2^16 models and we had to store using signed
|
||||
// integers due to SQLite limitations
|
||||
{
|
||||
self.model as u16
|
||||
}
|
||||
},
|
||||
data: rmp_serde::from_slice(&self.data)?,
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(op, db), err)]
|
||||
pub async fn write_crdt_op_to_db(op: &CRDTOperation, db: &PrismaClient) -> Result<(), Error> {
|
||||
crdt_operation::Create {
|
||||
@@ -87,16 +19,85 @@ pub async fn write_crdt_op_to_db(op: &CRDTOperation, db: &PrismaClient) -> Resul
|
||||
op.timestamp.0 as i64
|
||||
}
|
||||
},
|
||||
instance: instance::pub_id::equals(op.instance.as_bytes().to_vec()),
|
||||
device_pub_id: uuid_to_bytes(&op.device_pub_id),
|
||||
kind: op.kind().to_string(),
|
||||
data: rmp_serde::to_vec(&op.data)?,
|
||||
model: i32::from(op.model),
|
||||
model: i32::from(op.model_id),
|
||||
record_id: rmp_serde::to_vec(&op.record_id)?,
|
||||
_params: vec![],
|
||||
}
|
||||
.to_query(db)
|
||||
.select(crdt_operation::select!({ id })) // To don't fetch the whole object for nothing
|
||||
.exec()
|
||||
.await
|
||||
.map_or_else(|e| Err(e.into()), |_| Ok(()))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn from_crdt_ops(
|
||||
crdt_operation::Data {
|
||||
timestamp,
|
||||
model,
|
||||
record_id,
|
||||
data,
|
||||
device_pub_id,
|
||||
..
|
||||
}: crdt_operation::Data,
|
||||
) -> Result<CRDTOperation, Error> {
|
||||
Ok(CRDTOperation {
|
||||
device_pub_id: DevicePubId::from(device_pub_id).into(),
|
||||
timestamp: {
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
{
|
||||
// SAFETY: we had to store using i64 due to SQLite limitations
|
||||
NTP64(timestamp as u64)
|
||||
}
|
||||
},
|
||||
model_id: {
|
||||
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
|
||||
{
|
||||
// SAFETY: we will not have more than 2^16 models and we had to store using signed
|
||||
// integers due to SQLite limitations
|
||||
model as u16
|
||||
}
|
||||
},
|
||||
record_id: rmp_serde::from_slice(&record_id)?,
|
||||
data: rmp_serde::from_slice(&data)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_cloud_crdt_ops(
|
||||
cloud_crdt_operation::Data {
|
||||
id,
|
||||
timestamp,
|
||||
model,
|
||||
record_id,
|
||||
data,
|
||||
device_pub_id,
|
||||
..
|
||||
}: cloud_crdt_operation::Data,
|
||||
) -> Result<(cloud_crdt_operation::id::Type, CRDTOperation), Error> {
|
||||
Ok((
|
||||
id,
|
||||
CRDTOperation {
|
||||
device_pub_id: DevicePubId::from(device_pub_id).into(),
|
||||
timestamp: {
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
{
|
||||
// SAFETY: we had to store using i64 due to SQLite limitations
|
||||
NTP64(timestamp as u64)
|
||||
}
|
||||
},
|
||||
model_id: {
|
||||
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
|
||||
{
|
||||
// SAFETY: we will not have more than 2^16 models and we had to store using signed
|
||||
// integers due to SQLite limitations
|
||||
model as u16
|
||||
}
|
||||
},
|
||||
record_id: rmp_serde::from_slice(&record_id)?,
|
||||
data: rmp_serde::from_slice(&data)?,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
@@ -1,641 +0,0 @@
|
||||
use sd_prisma::{
|
||||
prisma::{crdt_operation, PrismaClient, SortOrder},
|
||||
prisma_sync::ModelSyncData,
|
||||
};
|
||||
use sd_sync::{
|
||||
CRDTOperation, CRDTOperationData, CompressedCRDTOperation, CompressedCRDTOperations,
|
||||
OperationKind,
|
||||
};
|
||||
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
future::IntoFuture,
|
||||
num::NonZeroU128,
|
||||
ops::Deref,
|
||||
pin::pin,
|
||||
sync::{atomic::Ordering, Arc},
|
||||
time::SystemTime,
|
||||
};
|
||||
|
||||
use async_channel as chan;
|
||||
use futures::{stream, FutureExt, StreamExt};
|
||||
use futures_concurrency::{
|
||||
future::{Race, TryJoin},
|
||||
stream::Merge,
|
||||
};
|
||||
use prisma_client_rust::chrono::{DateTime, Utc};
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{debug, error, instrument, trace, warn};
|
||||
use uhlc::{Timestamp, NTP64};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
actor::{create_actor_io, ActorIO, ActorTypes, HandlerIO},
|
||||
db_operation::write_crdt_op_to_db,
|
||||
Error, SharedState,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[must_use]
|
||||
/// Stuff that can be handled outside the actor
|
||||
pub enum Request {
|
||||
Messages {
|
||||
timestamps: Vec<(Uuid, NTP64)>,
|
||||
tx: oneshot::Sender<()>,
|
||||
},
|
||||
FinishedIngesting,
|
||||
}
|
||||
|
||||
/// Stuff that the actor consumes
|
||||
#[derive(Debug)]
|
||||
pub enum Event {
|
||||
Notification,
|
||||
Messages(MessagesEvent),
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub enum State {
|
||||
#[default]
|
||||
WaitingForNotification,
|
||||
RetrievingMessages,
|
||||
Ingesting(MessagesEvent),
|
||||
}
|
||||
|
||||
/// The single entrypoint for sync operation ingestion.
|
||||
/// Requests sync operations in a given timestamp range,
|
||||
/// and attempts to write them to the sync operations table along with
|
||||
/// the actual cell that the operation points to.
|
||||
///
|
||||
/// If this actor stops running, no sync operations will
|
||||
/// be applied to the database, independent of whether systems like p2p
|
||||
/// or cloud are exchanging messages.
|
||||
pub struct Actor {
|
||||
state: Option<State>,
|
||||
shared: Arc<SharedState>,
|
||||
io: ActorIO<Self>,
|
||||
}
|
||||
|
||||
impl Actor {
|
||||
#[instrument(skip(self), fields(old_state = ?self.state))]
|
||||
async fn tick(&mut self) {
|
||||
let state = match self
|
||||
.state
|
||||
.take()
|
||||
.expect("ingest actor in inconsistent state")
|
||||
{
|
||||
State::WaitingForNotification => self.waiting_for_notification_state_transition().await,
|
||||
State::RetrievingMessages => self.retrieving_messages_state_transition().await,
|
||||
State::Ingesting(event) => self.ingesting_state_transition(event).await,
|
||||
};
|
||||
|
||||
trace!(?state, "Actor state transitioned;");
|
||||
|
||||
self.state = Some(state);
|
||||
}
|
||||
|
||||
async fn waiting_for_notification_state_transition(&self) -> State {
|
||||
self.shared.active.store(false, Ordering::Relaxed);
|
||||
self.shared.active_notify.notify_waiters();
|
||||
|
||||
loop {
|
||||
match self
|
||||
.io
|
||||
.event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("sync actor receiver unexpectedly closed")
|
||||
{
|
||||
Event::Notification => {
|
||||
trace!("Received notification");
|
||||
break;
|
||||
}
|
||||
Event::Messages(event) => {
|
||||
trace!(
|
||||
?event,
|
||||
"Ignored event message as we're waiting for a `Event::Notification`"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.shared.active.store(true, Ordering::Relaxed);
|
||||
self.shared.active_notify.notify_waiters();
|
||||
|
||||
State::RetrievingMessages
|
||||
}
|
||||
|
||||
async fn retrieving_messages_state_transition(&self) -> State {
|
||||
enum StreamMessage {
|
||||
NewEvent(Event),
|
||||
AckedRequest(Result<(), oneshot::error::RecvError>),
|
||||
}
|
||||
|
||||
let (tx, rx) = oneshot::channel::<()>();
|
||||
|
||||
let timestamps = self
|
||||
.timestamps
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(&uid, ×tamp)| (uid, timestamp))
|
||||
.collect();
|
||||
|
||||
if self
|
||||
.io
|
||||
.send(Request::Messages { timestamps, tx })
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
warn!("Failed to send messages request");
|
||||
}
|
||||
|
||||
let mut msg_stream = pin!((
|
||||
self.io.event_rx.clone().map(StreamMessage::NewEvent),
|
||||
stream::once(rx.map(StreamMessage::AckedRequest)),
|
||||
)
|
||||
.merge());
|
||||
|
||||
loop {
|
||||
if let Some(msg) = msg_stream.next().await {
|
||||
match msg {
|
||||
StreamMessage::NewEvent(event) => {
|
||||
if let Event::Messages(messages) = event {
|
||||
trace!(?messages, "Received messages;");
|
||||
break State::Ingesting(messages);
|
||||
}
|
||||
}
|
||||
StreamMessage::AckedRequest(res) => {
|
||||
if res.is_err() {
|
||||
debug!("messages request ignored");
|
||||
break State::WaitingForNotification;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break State::WaitingForNotification;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn ingesting_state_transition(&mut self, event: MessagesEvent) -> State {
|
||||
debug!(
|
||||
messages_count = event.messages.len(),
|
||||
first_message = ?DateTime::<Utc>::from(
|
||||
event.messages
|
||||
.first()
|
||||
.map_or(SystemTime::UNIX_EPOCH, |m| m.3.timestamp.to_system_time())
|
||||
),
|
||||
last_message = ?DateTime::<Utc>::from(
|
||||
event.messages
|
||||
.last()
|
||||
.map_or(SystemTime::UNIX_EPOCH, |m| m.3.timestamp.to_system_time())
|
||||
),
|
||||
"Ingesting operations;",
|
||||
);
|
||||
|
||||
for (instance, data) in event.messages.0 {
|
||||
for (model, data) in data {
|
||||
for (record, ops) in data {
|
||||
if let Err(e) = self
|
||||
.process_crdt_operations(instance, model, record, ops)
|
||||
.await
|
||||
{
|
||||
error!(?e, "Failed to ingest CRDT operations;");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tx) = event.wait_tx {
|
||||
if tx.send(()).is_err() {
|
||||
warn!("Failed to send wait_tx signal");
|
||||
}
|
||||
}
|
||||
|
||||
if event.has_more {
|
||||
State::RetrievingMessages
|
||||
} else {
|
||||
{
|
||||
if self.io.send(Request::FinishedIngesting).await.is_err() {
|
||||
error!("Failed to send finished ingesting request");
|
||||
}
|
||||
|
||||
State::WaitingForNotification
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn declare(shared: Arc<SharedState>) -> Handler {
|
||||
let (io, HandlerIO { event_tx, req_rx }) = create_actor_io::<Self>();
|
||||
|
||||
shared
|
||||
.actors
|
||||
.declare(
|
||||
"Sync Ingest",
|
||||
{
|
||||
let shared = Arc::clone(&shared);
|
||||
move |stop| async move {
|
||||
enum Race {
|
||||
Ticked,
|
||||
Stopped,
|
||||
}
|
||||
|
||||
let mut this = Self {
|
||||
state: Some(State::default()),
|
||||
io,
|
||||
shared,
|
||||
};
|
||||
|
||||
while matches!(
|
||||
(
|
||||
this.tick().map(|()| Race::Ticked),
|
||||
stop.into_future().map(|()| Race::Stopped),
|
||||
)
|
||||
.race()
|
||||
.await,
|
||||
Race::Ticked
|
||||
) { /* Everything is Awesome! */ }
|
||||
}
|
||||
},
|
||||
true,
|
||||
)
|
||||
.await;
|
||||
|
||||
Handler { event_tx, req_rx }
|
||||
}
|
||||
|
||||
// where the magic happens
|
||||
#[instrument(skip(self, ops), fields(operations_count = %ops.len()), err)]
|
||||
async fn process_crdt_operations(
|
||||
&mut self,
|
||||
instance: Uuid,
|
||||
model: u16,
|
||||
record_id: rmpv::Value,
|
||||
mut ops: Vec<CompressedCRDTOperation>,
|
||||
) -> Result<(), Error> {
|
||||
let db = &self.db;
|
||||
|
||||
ops.sort_by_key(|op| op.timestamp);
|
||||
|
||||
let new_timestamp = ops.last().expect("Empty ops array").timestamp;
|
||||
|
||||
// first, we update the HLC's timestamp with the incoming one.
|
||||
// this involves a drift check + sets the last time of the clock
|
||||
self.clock
|
||||
.update_with_timestamp(&Timestamp::new(
|
||||
new_timestamp,
|
||||
uhlc::ID::from(NonZeroU128::new(instance.to_u128_le()).expect("Non zero id")),
|
||||
))
|
||||
.expect("timestamp has too much drift!");
|
||||
|
||||
// read the timestamp for the operation's instance, or insert one if it doesn't exist
|
||||
let timestamp = self.timestamps.read().await.get(&instance).copied();
|
||||
|
||||
// Delete - ignores all other messages
|
||||
if let Some(delete_op) = ops
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|op| matches!(op.data, CRDTOperationData::Delete))
|
||||
{
|
||||
trace!("Deleting operation");
|
||||
handle_crdt_deletion(db, instance, model, record_id, delete_op).await?;
|
||||
}
|
||||
// Create + > 0 Update - overwrites the create's data with the updates
|
||||
else if let Some(timestamp) = ops
|
||||
.iter()
|
||||
.rev()
|
||||
.find_map(|op| matches!(&op.data, CRDTOperationData::Create(_)).then_some(op.timestamp))
|
||||
{
|
||||
trace!("Create + Updates operations");
|
||||
|
||||
// conflict resolution
|
||||
let delete = db
|
||||
.crdt_operation()
|
||||
.find_first(vec![
|
||||
crdt_operation::model::equals(i32::from(model)),
|
||||
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
|
||||
crdt_operation::kind::equals(OperationKind::Delete.to_string()),
|
||||
])
|
||||
.order_by(crdt_operation::timestamp::order(SortOrder::Desc))
|
||||
.exec()
|
||||
.await?;
|
||||
|
||||
if delete.is_some() {
|
||||
debug!("Found a previous delete operation with the same SyncId, will ignore these operations");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
handle_crdt_create_and_updates(db, instance, model, record_id, ops, timestamp).await?;
|
||||
}
|
||||
// > 0 Update - batches updates with a fake Create op
|
||||
else {
|
||||
trace!("Updates operation");
|
||||
|
||||
let mut data = BTreeMap::new();
|
||||
|
||||
for op in ops.into_iter().rev() {
|
||||
let CRDTOperationData::Update { field, value } = op.data else {
|
||||
unreachable!("Create + Delete should be filtered out!");
|
||||
};
|
||||
|
||||
data.insert(field, (value, op.timestamp));
|
||||
}
|
||||
|
||||
// conflict resolution
|
||||
let (create, updates) = db
|
||||
._batch((
|
||||
db.crdt_operation()
|
||||
.find_first(vec![
|
||||
crdt_operation::model::equals(i32::from(model)),
|
||||
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
|
||||
crdt_operation::kind::equals(OperationKind::Create.to_string()),
|
||||
])
|
||||
.order_by(crdt_operation::timestamp::order(SortOrder::Desc)),
|
||||
data.iter()
|
||||
.map(|(k, (_, timestamp))| {
|
||||
Ok(db
|
||||
.crdt_operation()
|
||||
.find_first(vec![
|
||||
crdt_operation::timestamp::gt({
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
// SAFETY: we had to store using i64 due to SQLite limitations
|
||||
{
|
||||
timestamp.as_u64() as i64
|
||||
}
|
||||
}),
|
||||
crdt_operation::model::equals(i32::from(model)),
|
||||
crdt_operation::record_id::equals(rmp_serde::to_vec(
|
||||
&record_id,
|
||||
)?),
|
||||
crdt_operation::kind::equals(
|
||||
OperationKind::Update(k).to_string(),
|
||||
),
|
||||
])
|
||||
.order_by(crdt_operation::timestamp::order(SortOrder::Desc)))
|
||||
})
|
||||
.collect::<Result<Vec<_>, Error>>()?,
|
||||
))
|
||||
.await?;
|
||||
|
||||
if create.is_none() {
|
||||
warn!("Failed to find a previous create operation with the same SyncId");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
handle_crdt_updates(db, instance, model, record_id, data, updates).await?;
|
||||
}
|
||||
|
||||
// update the stored timestamp for this instance - will be derived from the crdt operations table on restart
|
||||
let new_ts = NTP64::max(timestamp.unwrap_or_default(), new_timestamp);
|
||||
|
||||
self.timestamps.write().await.insert(instance, new_ts);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_crdt_updates(
|
||||
db: &PrismaClient,
|
||||
instance: Uuid,
|
||||
model: u16,
|
||||
record_id: rmpv::Value,
|
||||
mut data: BTreeMap<String, (rmpv::Value, NTP64)>,
|
||||
updates: Vec<Option<crdt_operation::Data>>,
|
||||
) -> Result<(), Error> {
|
||||
let keys = data.keys().cloned().collect::<Vec<_>>();
|
||||
|
||||
// does the same thing as processing ops one-by-one and returning early if a newer op was found
|
||||
for (update, key) in updates.into_iter().zip(keys) {
|
||||
if update.is_some() {
|
||||
data.remove(&key);
|
||||
}
|
||||
}
|
||||
|
||||
db._transaction()
|
||||
.with_timeout(30 * 1000)
|
||||
.run(|db| async move {
|
||||
// fake operation to batch them all at once
|
||||
ModelSyncData::from_op(CRDTOperation {
|
||||
instance,
|
||||
model,
|
||||
record_id: record_id.clone(),
|
||||
timestamp: NTP64(0),
|
||||
data: CRDTOperationData::Create(
|
||||
data.iter()
|
||||
.map(|(k, (data, _))| (k.clone(), data.clone()))
|
||||
.collect(),
|
||||
),
|
||||
})
|
||||
.ok_or(Error::InvalidModelId(model))?
|
||||
.exec(&db)
|
||||
.await?;
|
||||
|
||||
// need to only apply ops that haven't been filtered out
|
||||
data.into_iter()
|
||||
.map(|(field, (value, timestamp))| {
|
||||
let record_id = record_id.clone();
|
||||
let db = &db;
|
||||
|
||||
async move {
|
||||
write_crdt_op_to_db(
|
||||
&CRDTOperation {
|
||||
instance,
|
||||
model,
|
||||
record_id,
|
||||
timestamp,
|
||||
data: CRDTOperationData::Update { field, value },
|
||||
},
|
||||
db,
|
||||
)
|
||||
.await
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_join()
|
||||
.await
|
||||
.map(|_| ())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn handle_crdt_create_and_updates(
|
||||
db: &PrismaClient,
|
||||
instance: Uuid,
|
||||
model: u16,
|
||||
record_id: rmpv::Value,
|
||||
ops: Vec<CompressedCRDTOperation>,
|
||||
timestamp: NTP64,
|
||||
) -> Result<(), Error> {
|
||||
let mut data = BTreeMap::new();
|
||||
|
||||
let mut applied_ops = vec![];
|
||||
|
||||
// search for all Updates until a Create is found
|
||||
for op in ops.iter().rev() {
|
||||
match &op.data {
|
||||
CRDTOperationData::Delete => unreachable!("Delete can't exist here!"),
|
||||
CRDTOperationData::Create(create_data) => {
|
||||
for (k, v) in create_data {
|
||||
data.entry(k).or_insert(v);
|
||||
}
|
||||
|
||||
applied_ops.push(op);
|
||||
|
||||
break;
|
||||
}
|
||||
CRDTOperationData::Update { field, value } => {
|
||||
applied_ops.push(op);
|
||||
data.insert(field, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db._transaction()
|
||||
.with_timeout(30 * 1000)
|
||||
.run(|db| async move {
|
||||
// fake a create with a bunch of data rather than individual insert
|
||||
ModelSyncData::from_op(CRDTOperation {
|
||||
instance,
|
||||
model,
|
||||
record_id: record_id.clone(),
|
||||
timestamp,
|
||||
data: CRDTOperationData::Create(
|
||||
data.into_iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect(),
|
||||
),
|
||||
})
|
||||
.ok_or(Error::InvalidModelId(model))?
|
||||
.exec(&db)
|
||||
.await?;
|
||||
|
||||
applied_ops
|
||||
.into_iter()
|
||||
.map(|op| {
|
||||
let record_id = record_id.clone();
|
||||
let db = &db;
|
||||
async move {
|
||||
let operation = CRDTOperation {
|
||||
instance,
|
||||
model,
|
||||
record_id,
|
||||
timestamp: op.timestamp,
|
||||
data: op.data.clone(),
|
||||
};
|
||||
|
||||
write_crdt_op_to_db(&operation, db).await
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_join()
|
||||
.await
|
||||
.map(|_| ())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn handle_crdt_deletion(
|
||||
db: &PrismaClient,
|
||||
instance: Uuid,
|
||||
model: u16,
|
||||
record_id: rmpv::Value,
|
||||
delete_op: &CompressedCRDTOperation,
|
||||
) -> Result<(), Error> {
|
||||
// deletes are the be all and end all, no need to check anything
|
||||
let op = CRDTOperation {
|
||||
instance,
|
||||
model,
|
||||
record_id,
|
||||
timestamp: delete_op.timestamp,
|
||||
data: CRDTOperationData::Delete,
|
||||
};
|
||||
|
||||
db._transaction()
|
||||
.with_timeout(30 * 1000)
|
||||
.run(|db| async move {
|
||||
ModelSyncData::from_op(op.clone())
|
||||
.ok_or(Error::InvalidModelId(model))?
|
||||
.exec(&db)
|
||||
.await?;
|
||||
|
||||
write_crdt_op_to_db(&op, &db).await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
impl Deref for Actor {
|
||||
type Target = SharedState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.shared
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Handler {
|
||||
pub event_tx: chan::Sender<Event>,
|
||||
pub req_rx: chan::Receiver<Request>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MessagesEvent {
|
||||
pub instance_id: Uuid,
|
||||
pub messages: CompressedCRDTOperations,
|
||||
pub has_more: bool,
|
||||
pub wait_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl ActorTypes for Actor {
|
||||
type Event = Event;
|
||||
type Request = Request;
|
||||
type Handler = Handler;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::{sync::atomic::AtomicBool, time::Duration};
|
||||
|
||||
use tokio::sync::Notify;
|
||||
use uhlc::HLCBuilder;
|
||||
|
||||
use super::*;
|
||||
|
||||
async fn new_actor() -> (Handler, Arc<SharedState>) {
|
||||
let instance = Uuid::new_v4();
|
||||
let shared = Arc::new(SharedState {
|
||||
db: sd_prisma::test_db().await,
|
||||
instance,
|
||||
clock: HLCBuilder::new()
|
||||
.with_id(uhlc::ID::from(
|
||||
NonZeroU128::new(instance.to_u128_le()).expect("Non zero id"),
|
||||
))
|
||||
.build(),
|
||||
timestamps: Arc::default(),
|
||||
emit_messages_flag: Arc::new(AtomicBool::new(true)),
|
||||
active: AtomicBool::default(),
|
||||
active_notify: Notify::default(),
|
||||
actors: Arc::default(),
|
||||
});
|
||||
|
||||
(Actor::declare(Arc::clone(&shared)).await, shared)
|
||||
}
|
||||
|
||||
/// If messages tx is dropped, actor should reset and assume no further messages
|
||||
/// will be sent
|
||||
#[tokio::test]
|
||||
async fn messages_request_drop() -> Result<(), ()> {
|
||||
let (ingest, _) = new_actor().await;
|
||||
|
||||
for _ in 0..10 {
|
||||
ingest.event_tx.send(Event::Notification).await.unwrap();
|
||||
|
||||
let Ok(Request::Messages { .. }) = ingest.req_rx.recv().await else {
|
||||
panic!("bruh")
|
||||
};
|
||||
|
||||
// without this the test hangs, idk
|
||||
tokio::time::sleep(Duration::from_millis(0)).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
517
core/crates/sync/src/ingest_utils.rs
Normal file
517
core/crates/sync/src/ingest_utils.rs
Normal file
@@ -0,0 +1,517 @@
|
||||
use sd_core_prisma_helpers::DevicePubId;
|
||||
|
||||
use sd_prisma::{
|
||||
prisma::{crdt_operation, PrismaClient},
|
||||
prisma_sync::ModelSyncData,
|
||||
};
|
||||
use sd_sync::{
|
||||
CRDTOperation, CRDTOperationData, CompressedCRDTOperation, ModelId, OperationKind, RecordId,
|
||||
};
|
||||
|
||||
use std::{collections::BTreeMap, num::NonZeroU128, sync::Arc};
|
||||
|
||||
use futures_concurrency::future::TryJoin;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, instrument, trace, warn};
|
||||
use uhlc::{Timestamp, HLC, NTP64};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{db_operation::write_crdt_op_to_db, Error, TimestampPerDevice};
|
||||
|
||||
crdt_operation::select!(crdt_operation_id { id });
|
||||
|
||||
// where the magic happens
|
||||
#[instrument(skip(clock, ops), fields(operations_count = %ops.len()), err)]
|
||||
pub async fn process_crdt_operations(
|
||||
clock: &HLC,
|
||||
timestamp_per_device: &TimestampPerDevice,
|
||||
sync_lock: Arc<Mutex<()>>,
|
||||
db: &PrismaClient,
|
||||
device_pub_id: DevicePubId,
|
||||
model_id: ModelId,
|
||||
(record_id, mut ops): (RecordId, Vec<CompressedCRDTOperation>),
|
||||
) -> Result<(), Error> {
|
||||
ops.sort_by_key(|op| op.timestamp);
|
||||
|
||||
let new_timestamp = ops.last().expect("Empty ops array").timestamp;
|
||||
|
||||
update_clock(clock, new_timestamp, &device_pub_id);
|
||||
|
||||
// Delete - ignores all other messages
|
||||
if let Some(delete_op) = ops
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|op| matches!(op.data, CRDTOperationData::Delete))
|
||||
{
|
||||
trace!("Deleting operation");
|
||||
handle_crdt_deletion(
|
||||
db,
|
||||
&sync_lock,
|
||||
&device_pub_id,
|
||||
model_id,
|
||||
record_id,
|
||||
delete_op,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
// Create + > 0 Update - overwrites the create's data with the updates
|
||||
else if let Some(timestamp) = ops
|
||||
.iter()
|
||||
.rev()
|
||||
.find_map(|op| matches!(&op.data, CRDTOperationData::Create(_)).then_some(op.timestamp))
|
||||
{
|
||||
trace!("Create + Updates operations");
|
||||
|
||||
// conflict resolution
|
||||
let delete_count = db
|
||||
.crdt_operation()
|
||||
.count(vec![
|
||||
crdt_operation::model::equals(i32::from(model_id)),
|
||||
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
|
||||
crdt_operation::kind::equals(OperationKind::Delete.to_string()),
|
||||
])
|
||||
.exec()
|
||||
.await?;
|
||||
|
||||
if delete_count > 0 {
|
||||
debug!("Found a previous delete operation with the same SyncId, will ignore these operations");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
handle_crdt_create_and_updates(
|
||||
db,
|
||||
&sync_lock,
|
||||
&device_pub_id,
|
||||
model_id,
|
||||
record_id,
|
||||
ops,
|
||||
timestamp,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
// > 0 Update - batches updates with a fake Create op
|
||||
else {
|
||||
trace!("Updates operation");
|
||||
|
||||
let mut data = BTreeMap::new();
|
||||
|
||||
for op in ops.into_iter().rev() {
|
||||
let CRDTOperationData::Update(fields_and_values) = op.data else {
|
||||
unreachable!("Create + Delete should be filtered out!");
|
||||
};
|
||||
|
||||
for (field, value) in fields_and_values {
|
||||
data.insert(field, (value, op.timestamp));
|
||||
}
|
||||
}
|
||||
|
||||
let earlier_time = data.values().fold(
|
||||
NTP64(u64::from(u32::MAX)),
|
||||
|earlier_time, (_, timestamp)| {
|
||||
if timestamp.0 < earlier_time.0 {
|
||||
*timestamp
|
||||
} else {
|
||||
earlier_time
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
// conflict resolution
|
||||
let (create, possible_newer_updates_count) = db
|
||||
._batch((
|
||||
db.crdt_operation().count(vec![
|
||||
crdt_operation::model::equals(i32::from(model_id)),
|
||||
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
|
||||
crdt_operation::kind::equals(OperationKind::Create.to_string()),
|
||||
]),
|
||||
// Fetching all update operations newer than our current earlier timestamp
|
||||
db.crdt_operation()
|
||||
.find_many(vec![
|
||||
crdt_operation::timestamp::gt({
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
// SAFETY: we had to store using i64 due to SQLite limitations
|
||||
{
|
||||
earlier_time.as_u64() as i64
|
||||
}
|
||||
}),
|
||||
crdt_operation::model::equals(i32::from(model_id)),
|
||||
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
|
||||
crdt_operation::kind::starts_with("u".to_string()),
|
||||
])
|
||||
.select(crdt_operation::select!({ kind timestamp })),
|
||||
))
|
||||
.await?;
|
||||
|
||||
if create == 0 {
|
||||
warn!("Failed to find a previous create operation with the same SyncId");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for candidate in possible_newer_updates_count {
|
||||
// The first element is "u" meaning that this is an update, so we skip it
|
||||
for key in candidate
|
||||
.kind
|
||||
.split(':')
|
||||
.filter(|field| !field.is_empty())
|
||||
.skip(1)
|
||||
{
|
||||
// remove entries if we possess locally more recent updates for this field
|
||||
if data.get(key).is_some_and(|(_, new_timestamp)| {
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
{
|
||||
// we need to store as i64 due to SQLite limitations
|
||||
*new_timestamp < NTP64(candidate.timestamp as u64)
|
||||
}
|
||||
}) {
|
||||
data.remove(key);
|
||||
}
|
||||
}
|
||||
|
||||
if data.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
handle_crdt_updates(db, &sync_lock, &device_pub_id, model_id, record_id, data).await?;
|
||||
}
|
||||
|
||||
update_timestamp_per_device(timestamp_per_device, device_pub_id, new_timestamp).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn bulk_ingest_create_only_ops(
|
||||
clock: &HLC,
|
||||
timestamp_per_device: &TimestampPerDevice,
|
||||
db: &PrismaClient,
|
||||
device_pub_id: DevicePubId,
|
||||
model_id: ModelId,
|
||||
ops: Vec<(RecordId, CompressedCRDTOperation)>,
|
||||
sync_lock: Arc<Mutex<()>>,
|
||||
) -> Result<(), Error> {
|
||||
let latest_timestamp = ops.iter().fold(NTP64(0), |latest, (_, op)| {
|
||||
if latest < op.timestamp {
|
||||
op.timestamp
|
||||
} else {
|
||||
latest
|
||||
}
|
||||
});
|
||||
|
||||
update_clock(clock, latest_timestamp, &device_pub_id);
|
||||
|
||||
let ops = ops
|
||||
.into_iter()
|
||||
.map(|(record_id, op)| {
|
||||
rmp_serde::to_vec(&record_id)
|
||||
.map(|serialized_record_id| (record_id, serialized_record_id, op))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// conflict resolution
|
||||
let delete_counts = db
|
||||
._batch(
|
||||
ops.iter()
|
||||
.map(|(_, serialized_record_id, _)| {
|
||||
db.crdt_operation().count(vec![
|
||||
crdt_operation::model::equals(i32::from(model_id)),
|
||||
crdt_operation::record_id::equals(serialized_record_id.clone()),
|
||||
crdt_operation::kind::equals(OperationKind::Delete.to_string()),
|
||||
])
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let lock_guard = sync_lock.lock().await;
|
||||
|
||||
db._transaction()
|
||||
.with_timeout(30 * 10000)
|
||||
.with_max_wait(30 * 10000)
|
||||
.run(|db| {
|
||||
let device_pub_id = device_pub_id.clone();
|
||||
|
||||
async move {
|
||||
// complying with borrowck
|
||||
let device_pub_id = &device_pub_id;
|
||||
|
||||
let (crdt_creates, model_sync_data) = ops
|
||||
.into_iter()
|
||||
.zip(delete_counts)
|
||||
.filter_map(|(data, delete_count)| (delete_count == 0).then_some(data))
|
||||
.map(
|
||||
|(
|
||||
record_id,
|
||||
serialized_record_id,
|
||||
CompressedCRDTOperation { timestamp, data },
|
||||
)| {
|
||||
let crdt_create = crdt_operation::CreateUnchecked {
|
||||
timestamp: {
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
// SAFETY: we have to store using i64 due to SQLite limitations
|
||||
{
|
||||
timestamp.0 as i64
|
||||
}
|
||||
},
|
||||
model: i32::from(model_id),
|
||||
record_id: serialized_record_id,
|
||||
kind: "c".to_string(),
|
||||
data: rmp_serde::to_vec(&data)?,
|
||||
device_pub_id: device_pub_id.to_db(),
|
||||
_params: vec![],
|
||||
};
|
||||
|
||||
// NOTE(@fogodev): I wish I could do a create many here instead of creating separately each
|
||||
// entry, but it's not supported by PCR
|
||||
let model_sync_data = ModelSyncData::from_op(CRDTOperation {
|
||||
device_pub_id: Uuid::from(device_pub_id),
|
||||
model_id,
|
||||
record_id,
|
||||
timestamp,
|
||||
data,
|
||||
})?
|
||||
.exec(&db);
|
||||
|
||||
Ok::<_, Error>((crdt_create, model_sync_data))
|
||||
},
|
||||
)
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.unzip::<_, _, Vec<_>, Vec<_>>();
|
||||
|
||||
model_sync_data.try_join().await?;
|
||||
|
||||
db.crdt_operation().create_many(crdt_creates).exec().await?;
|
||||
|
||||
Ok::<_, Error>(())
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
drop(lock_guard);
|
||||
|
||||
update_timestamp_per_device(timestamp_per_device, device_pub_id, latest_timestamp).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all, err)]
|
||||
async fn handle_crdt_updates(
|
||||
db: &PrismaClient,
|
||||
sync_lock: &Mutex<()>,
|
||||
device_pub_id: &DevicePubId,
|
||||
model_id: ModelId,
|
||||
record_id: rmpv::Value,
|
||||
data: BTreeMap<String, (rmpv::Value, NTP64)>,
|
||||
) -> Result<(), Error> {
|
||||
let device_pub_id = sd_sync::DevicePubId::from(device_pub_id);
|
||||
|
||||
let _lock_guard = sync_lock.lock().await;
|
||||
|
||||
db._transaction()
|
||||
.with_timeout(30 * 10000)
|
||||
.with_max_wait(30 * 10000)
|
||||
.run(|db| async move {
|
||||
// fake operation to batch them all at once, inserting the latest data on appropriate table
|
||||
ModelSyncData::from_op(CRDTOperation {
|
||||
device_pub_id,
|
||||
model_id,
|
||||
record_id: record_id.clone(),
|
||||
timestamp: NTP64(0),
|
||||
data: CRDTOperationData::Create(
|
||||
data.iter()
|
||||
.map(|(k, (data, _))| (k.clone(), data.clone()))
|
||||
.collect(),
|
||||
),
|
||||
})?
|
||||
.exec(&db)
|
||||
.await?;
|
||||
|
||||
let (fields_and_values, latest_timestamp) = data.into_iter().fold(
|
||||
(BTreeMap::new(), NTP64::default()),
|
||||
|(mut fields_and_values, mut latest_time_stamp), (field, (value, timestamp))| {
|
||||
fields_and_values.insert(field, value);
|
||||
if timestamp > latest_time_stamp {
|
||||
latest_time_stamp = timestamp;
|
||||
}
|
||||
(fields_and_values, latest_time_stamp)
|
||||
},
|
||||
);
|
||||
|
||||
write_crdt_op_to_db(
|
||||
&CRDTOperation {
|
||||
device_pub_id,
|
||||
model_id,
|
||||
record_id,
|
||||
timestamp: latest_timestamp,
|
||||
data: CRDTOperationData::Update(fields_and_values),
|
||||
},
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[instrument(skip_all, err)]
|
||||
async fn handle_crdt_create_and_updates(
|
||||
db: &PrismaClient,
|
||||
sync_lock: &Mutex<()>,
|
||||
device_pub_id: &DevicePubId,
|
||||
model_id: ModelId,
|
||||
record_id: rmpv::Value,
|
||||
ops: Vec<CompressedCRDTOperation>,
|
||||
timestamp: NTP64,
|
||||
) -> Result<(), Error> {
|
||||
let mut data = BTreeMap::new();
|
||||
let device_pub_id = sd_sync::DevicePubId::from(device_pub_id);
|
||||
|
||||
let mut applied_ops = vec![];
|
||||
|
||||
// search for all Updates until a Create is found
|
||||
for op in ops.into_iter().rev() {
|
||||
match &op.data {
|
||||
CRDTOperationData::Delete => unreachable!("Delete can't exist here!"),
|
||||
CRDTOperationData::Create(create_data) => {
|
||||
for (k, v) in create_data {
|
||||
data.entry(k.clone()).or_insert_with(|| v.clone());
|
||||
}
|
||||
|
||||
applied_ops.push(op);
|
||||
|
||||
break;
|
||||
}
|
||||
CRDTOperationData::Update(fields_and_values) => {
|
||||
for (field, value) in fields_and_values {
|
||||
data.insert(field.clone(), value.clone());
|
||||
}
|
||||
|
||||
applied_ops.push(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _lock_guard = sync_lock.lock().await;
|
||||
|
||||
db._transaction()
|
||||
.with_timeout(30 * 10000)
|
||||
.with_max_wait(30 * 10000)
|
||||
.run(|db| async move {
|
||||
// fake a create with a bunch of data rather than individual insert
|
||||
ModelSyncData::from_op(CRDTOperation {
|
||||
device_pub_id,
|
||||
model_id,
|
||||
record_id: record_id.clone(),
|
||||
timestamp,
|
||||
data: CRDTOperationData::Create(data),
|
||||
})?
|
||||
.exec(&db)
|
||||
.await?;
|
||||
|
||||
applied_ops
|
||||
.into_iter()
|
||||
.map(|CompressedCRDTOperation { timestamp, data }| {
|
||||
let record_id = record_id.clone();
|
||||
let db = &db;
|
||||
async move {
|
||||
write_crdt_op_to_db(
|
||||
&CRDTOperation {
|
||||
device_pub_id,
|
||||
timestamp,
|
||||
model_id,
|
||||
record_id,
|
||||
data,
|
||||
},
|
||||
db,
|
||||
)
|
||||
.await
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_join()
|
||||
.await
|
||||
.map(|_| ())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[instrument(skip_all, err)]
|
||||
async fn handle_crdt_deletion(
|
||||
db: &PrismaClient,
|
||||
sync_lock: &Mutex<()>,
|
||||
device_pub_id: &DevicePubId,
|
||||
model: u16,
|
||||
record_id: rmpv::Value,
|
||||
delete_op: &CompressedCRDTOperation,
|
||||
) -> Result<(), Error> {
|
||||
// deletes are the be all and end all, except if we never created the object to begin with
|
||||
// in this case we don't need to delete anything
|
||||
|
||||
if db
|
||||
.crdt_operation()
|
||||
.count(vec![
|
||||
crdt_operation::model::equals(i32::from(model)),
|
||||
crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?),
|
||||
])
|
||||
.exec()
|
||||
.await?
|
||||
== 0
|
||||
{
|
||||
// This means that in the other device this entry was created and deleted, before this
|
||||
// device here could even take notice of it. So we don't need to do anything here.
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let op = CRDTOperation {
|
||||
device_pub_id: device_pub_id.into(),
|
||||
model_id: model,
|
||||
record_id,
|
||||
timestamp: delete_op.timestamp,
|
||||
data: CRDTOperationData::Delete,
|
||||
};
|
||||
|
||||
let _lock_guard = sync_lock.lock().await;
|
||||
|
||||
db._transaction()
|
||||
.with_timeout(30 * 10000)
|
||||
.with_max_wait(30 * 10000)
|
||||
.run(|db| async move {
|
||||
ModelSyncData::from_op(op.clone())?.exec(&db).await?;
|
||||
|
||||
write_crdt_op_to_db(&op, &db).await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
fn update_clock(clock: &HLC, latest_timestamp: NTP64, device_pub_id: &DevicePubId) {
|
||||
// first, we update the HLC's timestamp with the incoming one.
|
||||
// this involves a drift check + sets the last time of the clock
|
||||
clock
|
||||
.update_with_timestamp(&Timestamp::new(
|
||||
latest_timestamp,
|
||||
uhlc::ID::from(
|
||||
NonZeroU128::new(Uuid::from(device_pub_id).to_u128_le()).expect("Non zero id"),
|
||||
),
|
||||
))
|
||||
.expect("timestamp has too much drift!");
|
||||
}
|
||||
|
||||
async fn update_timestamp_per_device(
|
||||
timestamp_per_device: &TimestampPerDevice,
|
||||
device_pub_id: DevicePubId,
|
||||
latest_timestamp: NTP64,
|
||||
) {
|
||||
// read the timestamp for the operation's device, or insert one if it doesn't exist
|
||||
let current_last_timestamp = timestamp_per_device
|
||||
.read()
|
||||
.await
|
||||
.get(&device_pub_id)
|
||||
.copied();
|
||||
|
||||
// update the stored timestamp for this device - will be derived from the crdt operations table on restart
|
||||
let new_ts = NTP64::max(current_last_timestamp.unwrap_or_default(), latest_timestamp);
|
||||
|
||||
timestamp_per_device
|
||||
.write()
|
||||
.await
|
||||
.insert(device_pub_id, new_ts);
|
||||
}
|
||||
@@ -27,45 +27,39 @@
|
||||
#![forbid(deprecated_in_future)]
|
||||
#![allow(clippy::missing_errors_doc, clippy::module_name_repetitions)]
|
||||
|
||||
use sd_prisma::prisma::{crdt_operation, instance, PrismaClient};
|
||||
use sd_sync::CRDTOperation;
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
use sd_prisma::{
|
||||
prisma::{cloud_crdt_operation, crdt_operation},
|
||||
prisma_sync,
|
||||
};
|
||||
use sd_utils::uuid_to_bytes;
|
||||
|
||||
use tokio::sync::{Notify, RwLock};
|
||||
use uuid::Uuid;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use tokio::{sync::RwLock, task::JoinError};
|
||||
|
||||
mod actor;
|
||||
pub mod backfill;
|
||||
mod db_operation;
|
||||
pub mod ingest;
|
||||
mod ingest_utils;
|
||||
mod manager;
|
||||
|
||||
pub use ingest::*;
|
||||
pub use manager::*;
|
||||
pub use db_operation::{from_cloud_crdt_ops, from_crdt_ops, write_crdt_op_to_db};
|
||||
pub use manager::Manager as SyncManager;
|
||||
pub use uhlc::NTP64;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum SyncMessage {
|
||||
pub enum SyncEvent {
|
||||
Ingested,
|
||||
Created,
|
||||
}
|
||||
|
||||
pub type Timestamps = Arc<RwLock<HashMap<Uuid, NTP64>>>;
|
||||
pub use sd_core_prisma_helpers::DevicePubId;
|
||||
pub use sd_sync::{
|
||||
CRDTOperation, CompressedCRDTOperation, CompressedCRDTOperationsPerModel,
|
||||
CompressedCRDTOperationsPerModelPerDevice, ModelId, OperationFactory, RecordId, RelationSyncId,
|
||||
RelationSyncModel, SharedSyncModel, SyncId, SyncModel,
|
||||
};
|
||||
|
||||
pub struct SharedState {
|
||||
pub db: Arc<PrismaClient>,
|
||||
pub emit_messages_flag: Arc<AtomicBool>,
|
||||
pub instance: Uuid,
|
||||
pub timestamps: Timestamps,
|
||||
pub clock: uhlc::HLC,
|
||||
pub active: AtomicBool,
|
||||
pub active_notify: Notify,
|
||||
pub actors: Arc<sd_actors::Actors>,
|
||||
}
|
||||
pub type TimestampPerDevice = Arc<RwLock<HashMap<DevicePubId, NTP64>>>;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
@@ -75,8 +69,16 @@ pub enum Error {
|
||||
Deserialization(#[from] rmp_serde::decode::Error),
|
||||
#[error("database error: {0}")]
|
||||
Database(#[from] prisma_client_rust::QueryError),
|
||||
#[error("PrismaSync error: {0}")]
|
||||
PrismaSync(#[from] prisma_sync::Error),
|
||||
#[error("invalid model id: {0}")]
|
||||
InvalidModelId(u16),
|
||||
InvalidModelId(ModelId),
|
||||
#[error("tried to write an empty operations list")]
|
||||
EmptyOperations,
|
||||
#[error("device not found: {0}")]
|
||||
DeviceNotFound(DevicePubId),
|
||||
#[error("processes crdt task panicked")]
|
||||
ProcessCrdtPanic(JoinError),
|
||||
}
|
||||
|
||||
impl From<Error> for rspc::Error {
|
||||
@@ -105,19 +107,16 @@ pub fn crdt_op_db(op: &CRDTOperation) -> Result<crdt_operation::Create, Error> {
|
||||
op.timestamp.as_u64() as i64
|
||||
}
|
||||
},
|
||||
instance: instance::pub_id::equals(op.instance.as_bytes().to_vec()),
|
||||
device_pub_id: uuid_to_bytes(&op.device_pub_id),
|
||||
kind: op.kind().to_string(),
|
||||
data: rmp_serde::to_vec(&op.data)?,
|
||||
model: i32::from(op.model),
|
||||
model: i32::from(op.model_id),
|
||||
record_id: rmp_serde::to_vec(&op.record_id)?,
|
||||
_params: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
pub fn crdt_op_unchecked_db(
|
||||
op: &CRDTOperation,
|
||||
instance_id: i32,
|
||||
) -> Result<crdt_operation::CreateUnchecked, Error> {
|
||||
pub fn crdt_op_unchecked_db(op: &CRDTOperation) -> Result<crdt_operation::CreateUnchecked, Error> {
|
||||
Ok(crdt_operation::CreateUnchecked {
|
||||
timestamp: {
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
@@ -126,10 +125,28 @@ pub fn crdt_op_unchecked_db(
|
||||
op.timestamp.as_u64() as i64
|
||||
}
|
||||
},
|
||||
instance_id,
|
||||
device_pub_id: uuid_to_bytes(&op.device_pub_id),
|
||||
kind: op.kind().to_string(),
|
||||
data: rmp_serde::to_vec(&op.data)?,
|
||||
model: i32::from(op.model),
|
||||
model: i32::from(op.model_id),
|
||||
record_id: rmp_serde::to_vec(&op.record_id)?,
|
||||
_params: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cloud_crdt_op_db(op: &CRDTOperation) -> Result<cloud_crdt_operation::Create, Error> {
|
||||
Ok(cloud_crdt_operation::Create {
|
||||
timestamp: {
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
// SAFETY: we had to store using i64 due to SQLite limitations
|
||||
{
|
||||
op.timestamp.as_u64() as i64
|
||||
}
|
||||
},
|
||||
device_pub_id: uuid_to_bytes(&op.device_pub_id),
|
||||
kind: op.data.as_kind().to_string(),
|
||||
data: rmp_serde::to_vec(&op.data)?,
|
||||
model: i32::from(op.model_id),
|
||||
record_id: rmp_serde::to_vec(&op.record_id)?,
|
||||
_params: vec![],
|
||||
})
|
||||
|
||||
@@ -1,35 +1,60 @@
|
||||
use sd_prisma::prisma::{cloud_crdt_operation, crdt_operation, instance, PrismaClient, SortOrder};
|
||||
use sd_sync::{CRDTOperation, OperationFactory};
|
||||
use sd_utils::{from_bytes_to_uuid, uuid_to_bytes};
|
||||
use tracing::warn;
|
||||
use sd_core_prisma_helpers::DevicePubId;
|
||||
|
||||
use sd_prisma::{
|
||||
prisma::{cloud_crdt_operation, crdt_operation, device, PrismaClient, SortOrder},
|
||||
prisma_sync,
|
||||
};
|
||||
use sd_sync::{
|
||||
CRDTOperation, CRDTOperationData, CompressedCRDTOperation, ModelId, OperationFactory, RecordId,
|
||||
};
|
||||
use sd_utils::timestamp_to_datetime;
|
||||
|
||||
use std::{
|
||||
cmp, fmt,
|
||||
collections::{hash_map::Entry, BTreeMap, HashMap},
|
||||
fmt, mem,
|
||||
num::NonZeroU128,
|
||||
ops::Deref,
|
||||
sync::{
|
||||
atomic::{self, AtomicBool},
|
||||
Arc,
|
||||
},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use prisma_client_rust::{and, operator::or};
|
||||
use tokio::sync::{broadcast, Mutex, Notify, RwLock};
|
||||
use async_stream::stream;
|
||||
use futures::{stream::FuturesUnordered, Stream, TryStreamExt};
|
||||
use futures_concurrency::future::TryJoin;
|
||||
use itertools::Itertools;
|
||||
use tokio::{
|
||||
spawn,
|
||||
sync::{broadcast, Mutex, Notify, RwLock},
|
||||
time::Instant,
|
||||
};
|
||||
use tracing::{debug, instrument, warn};
|
||||
use uhlc::{HLCBuilder, HLC};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
crdt_op_db,
|
||||
db_operation::{cloud_crdt_with_instance, crdt_with_instance},
|
||||
ingest, Error, SharedState, SyncMessage, NTP64,
|
||||
db_operation::{from_cloud_crdt_ops, from_crdt_ops},
|
||||
ingest_utils::{bulk_ingest_create_only_ops, process_crdt_operations},
|
||||
Error, SyncEvent, TimestampPerDevice, NTP64,
|
||||
};
|
||||
|
||||
const INGESTION_BATCH_SIZE: i64 = 10_000;
|
||||
|
||||
/// Wrapper that spawns the ingest actor and provides utilities for reading and writing sync operations.
|
||||
#[derive(Clone)]
|
||||
pub struct Manager {
|
||||
pub tx: broadcast::Sender<SyncMessage>,
|
||||
pub ingest: ingest::Handler,
|
||||
pub shared: Arc<SharedState>,
|
||||
pub timestamp_lock: Mutex<()>,
|
||||
pub tx: broadcast::Sender<SyncEvent>,
|
||||
pub db: Arc<PrismaClient>,
|
||||
pub emit_messages_flag: Arc<AtomicBool>,
|
||||
pub device_pub_id: DevicePubId,
|
||||
pub timestamp_per_device: TimestampPerDevice,
|
||||
pub clock: Arc<HLC>,
|
||||
pub active: Arc<AtomicBool>,
|
||||
pub active_notify: Arc<Notify>,
|
||||
pub(crate) sync_lock: Arc<Mutex<()>>,
|
||||
pub(crate) available_parallelism: usize,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Manager {
|
||||
@@ -38,29 +63,21 @@ impl fmt::Debug for Manager {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq)]
|
||||
pub struct GetOpsArgs {
|
||||
pub clocks: Vec<(Uuid, NTP64)>,
|
||||
pub count: u32,
|
||||
}
|
||||
|
||||
impl Manager {
|
||||
/// Creates a new manager that can be used to read and write CRDT operations.
|
||||
/// Sync messages are received on the returned [`broadcast::Receiver<SyncMessage>`].
|
||||
pub async fn new(
|
||||
db: Arc<PrismaClient>,
|
||||
current_instance_uuid: Uuid,
|
||||
current_device_pub_id: &DevicePubId,
|
||||
emit_messages_flag: Arc<AtomicBool>,
|
||||
actors: Arc<sd_actors::Actors>,
|
||||
) -> Result<(Self, broadcast::Receiver<SyncMessage>), Error> {
|
||||
let existing_instances = db.instance().find_many(vec![]).exec().await?;
|
||||
) -> Result<(Self, broadcast::Receiver<SyncEvent>), Error> {
|
||||
let existing_devices = db.device().find_many(vec![]).exec().await?;
|
||||
|
||||
Self::with_existing_instances(
|
||||
Self::with_existing_devices(
|
||||
db,
|
||||
current_instance_uuid,
|
||||
current_device_pub_id,
|
||||
emit_messages_flag,
|
||||
&existing_instances,
|
||||
actors,
|
||||
&existing_devices,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -69,33 +86,34 @@ impl Manager {
|
||||
/// Sync messages are received on the returned [`broadcast::Receiver<SyncMessage>`].
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if the `current_instance_id` UUID is zeroed.
|
||||
pub async fn with_existing_instances(
|
||||
/// Panics if the `current_device_pub_id` UUID is zeroed, which will never happen as we use `UUIDv7` for the
|
||||
/// device pub id. As this version have a timestamp part, instead of being totally random. So the only
|
||||
/// possible way to get zero from a `UUIDv7` is to go back in time to 1970
|
||||
pub async fn with_existing_devices(
|
||||
db: Arc<PrismaClient>,
|
||||
current_instance_uuid: Uuid,
|
||||
current_device_pub_id: &DevicePubId,
|
||||
emit_messages_flag: Arc<AtomicBool>,
|
||||
existing_instances: &[instance::Data],
|
||||
actors: Arc<sd_actors::Actors>,
|
||||
) -> Result<(Self, broadcast::Receiver<SyncMessage>), Error> {
|
||||
let timestamps = db
|
||||
existing_devices: &[device::Data],
|
||||
) -> Result<(Self, broadcast::Receiver<SyncEvent>), Error> {
|
||||
let latest_timestamp_per_device = db
|
||||
._batch(
|
||||
existing_instances
|
||||
existing_devices
|
||||
.iter()
|
||||
.map(|i| {
|
||||
.map(|device| {
|
||||
db.crdt_operation()
|
||||
.find_first(vec![crdt_operation::instance::is(vec![
|
||||
instance::id::equals(i.id),
|
||||
])])
|
||||
.find_first(vec![crdt_operation::device_pub_id::equals(
|
||||
device.pub_id.clone(),
|
||||
)])
|
||||
.order_by(crdt_operation::timestamp::order(SortOrder::Desc))
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.await?
|
||||
.into_iter()
|
||||
.zip(existing_instances)
|
||||
.map(|(op, i)| {
|
||||
.zip(existing_devices)
|
||||
.map(|(op, device)| {
|
||||
(
|
||||
from_bytes_to_uuid(&i.pub_id),
|
||||
DevicePubId::from(&device.pub_id),
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
// SAFETY: we had to store using i64 due to SQLite limitations
|
||||
NTP64(op.map(|o| o.timestamp).unwrap_or_default() as u64),
|
||||
@@ -105,54 +123,303 @@ impl Manager {
|
||||
|
||||
let (tx, rx) = broadcast::channel(64);
|
||||
|
||||
let clock = HLCBuilder::new()
|
||||
.with_id(uhlc::ID::from(
|
||||
NonZeroU128::new(current_instance_uuid.to_u128_le()).expect("Non zero id"),
|
||||
))
|
||||
.build();
|
||||
|
||||
let shared = Arc::new(SharedState {
|
||||
db,
|
||||
instance: current_instance_uuid,
|
||||
clock,
|
||||
timestamps: Arc::new(RwLock::new(timestamps)),
|
||||
emit_messages_flag,
|
||||
active: AtomicBool::default(),
|
||||
active_notify: Notify::default(),
|
||||
actors,
|
||||
});
|
||||
|
||||
let ingest = ingest::Actor::declare(shared.clone()).await;
|
||||
|
||||
Ok((
|
||||
Self {
|
||||
tx,
|
||||
ingest,
|
||||
shared,
|
||||
timestamp_lock: Mutex::default(),
|
||||
db,
|
||||
device_pub_id: current_device_pub_id.clone(),
|
||||
clock: Arc::new(
|
||||
HLCBuilder::new()
|
||||
.with_id(uhlc::ID::from(
|
||||
NonZeroU128::new(Uuid::from(current_device_pub_id).to_u128_le())
|
||||
.expect("Non zero id"),
|
||||
))
|
||||
.build(),
|
||||
),
|
||||
timestamp_per_device: Arc::new(RwLock::new(latest_timestamp_per_device)),
|
||||
emit_messages_flag,
|
||||
active: Arc::default(),
|
||||
active_notify: Arc::default(),
|
||||
sync_lock: Arc::new(Mutex::default()),
|
||||
available_parallelism: std::thread::available_parallelism()
|
||||
.map_or(1, std::num::NonZero::get),
|
||||
},
|
||||
rx,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<SyncMessage> {
|
||||
async fn fetch_cloud_crdt_ops(
|
||||
&self,
|
||||
model_id: ModelId,
|
||||
batch_size: i64,
|
||||
) -> Result<(Vec<cloud_crdt_operation::id::Type>, Vec<CRDTOperation>), Error> {
|
||||
self.db
|
||||
.cloud_crdt_operation()
|
||||
.find_many(vec![cloud_crdt_operation::model::equals(i32::from(
|
||||
model_id,
|
||||
))])
|
||||
.take(batch_size)
|
||||
.order_by(cloud_crdt_operation::timestamp::order(SortOrder::Asc))
|
||||
.exec()
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(from_cloud_crdt_ops)
|
||||
.collect::<Result<(Vec<_>, Vec<_>), _>>()
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn ingest_by_model(&self, model_id: ModelId) -> Result<usize, Error> {
|
||||
let mut total_count = 0;
|
||||
|
||||
let mut buckets = (0..self.available_parallelism)
|
||||
.map(|_| FuturesUnordered::new())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut total_fetch_time = Duration::ZERO;
|
||||
let mut total_compression_time = Duration::ZERO;
|
||||
let mut total_work_distribution_time = Duration::ZERO;
|
||||
let mut total_process_time = Duration::ZERO;
|
||||
|
||||
loop {
|
||||
let fetching_start = Instant::now();
|
||||
|
||||
let (ops_ids, ops) = self
|
||||
.fetch_cloud_crdt_ops(model_id, INGESTION_BATCH_SIZE)
|
||||
.await?;
|
||||
if ops_ids.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
total_fetch_time += fetching_start.elapsed();
|
||||
|
||||
let messages_count = ops.len();
|
||||
|
||||
debug!(
|
||||
messages_count,
|
||||
first_message = ?ops
|
||||
.first()
|
||||
.map_or_else(|| SystemTime::UNIX_EPOCH.into(), |op| timestamp_to_datetime(op.timestamp)),
|
||||
last_message = ?ops
|
||||
.last()
|
||||
.map_or_else(|| SystemTime::UNIX_EPOCH.into(), |op| timestamp_to_datetime(op.timestamp)),
|
||||
"Messages by model to ingest",
|
||||
);
|
||||
|
||||
let compression_start = Instant::now();
|
||||
|
||||
let mut compressed_map =
|
||||
BTreeMap::<Uuid, HashMap<Vec<u8>, (RecordId, Vec<CompressedCRDTOperation>)>>::new();
|
||||
|
||||
for CRDTOperation {
|
||||
device_pub_id,
|
||||
timestamp,
|
||||
model_id: _, // Ignoring model_id as we know it already
|
||||
record_id,
|
||||
data,
|
||||
} in ops
|
||||
{
|
||||
let records = compressed_map.entry(device_pub_id).or_default();
|
||||
|
||||
// Can't use RecordId as a key because rmpv::Value doesn't implement Hash + Eq.
|
||||
// So we use it's serialized bytes as a key.
|
||||
let record_id_bytes =
|
||||
rmp_serde::to_vec_named(&record_id).expect("already serialized to Value");
|
||||
|
||||
match records.entry(record_id_bytes) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
entry
|
||||
.get_mut()
|
||||
.1
|
||||
.push(CompressedCRDTOperation { timestamp, data });
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry
|
||||
.insert((record_id, vec![CompressedCRDTOperation { timestamp, data }]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we separated all operations by their record_ids, we can do an optimization
|
||||
// to process all records that only posses a single create operation, batching them together
|
||||
let mut create_only_ops: BTreeMap<Uuid, Vec<(RecordId, CompressedCRDTOperation)>> =
|
||||
BTreeMap::new();
|
||||
for (device_pub_id, records) in &mut compressed_map {
|
||||
for (record_id, ops) in records.values_mut() {
|
||||
if ops.len() == 1 && matches!(ops[0].data, CRDTOperationData::Create(_)) {
|
||||
create_only_ops
|
||||
.entry(*device_pub_id)
|
||||
.or_default()
|
||||
.push((mem::replace(record_id, rmpv::Value::Nil), ops.remove(0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
total_count += bulk_process_of_create_only_ops(
|
||||
self.available_parallelism,
|
||||
Arc::clone(&self.clock),
|
||||
Arc::clone(&self.timestamp_per_device),
|
||||
Arc::clone(&self.db),
|
||||
Arc::clone(&self.sync_lock),
|
||||
model_id,
|
||||
create_only_ops,
|
||||
)
|
||||
.await?;
|
||||
|
||||
total_compression_time += compression_start.elapsed();
|
||||
|
||||
let work_distribution_start = Instant::now();
|
||||
|
||||
compressed_map
|
||||
.into_iter()
|
||||
.flat_map(|(device_pub_id, records)| {
|
||||
records.into_values().filter_map(move |(record_id, ops)| {
|
||||
if record_id.is_nil() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// We can process each record in parallel as they are independent
|
||||
|
||||
let clock = Arc::clone(&self.clock);
|
||||
let timestamp_per_device = Arc::clone(&self.timestamp_per_device);
|
||||
let db = Arc::clone(&self.db);
|
||||
let device_pub_id = device_pub_id.into();
|
||||
let sync_lock = Arc::clone(&self.sync_lock);
|
||||
|
||||
Some(async move {
|
||||
let count = ops.len();
|
||||
|
||||
process_crdt_operations(
|
||||
&clock,
|
||||
×tamp_per_device,
|
||||
sync_lock,
|
||||
&db,
|
||||
device_pub_id,
|
||||
model_id,
|
||||
(record_id, ops),
|
||||
)
|
||||
.await
|
||||
.map(|()| count)
|
||||
})
|
||||
})
|
||||
})
|
||||
.enumerate()
|
||||
.for_each(|(idx, fut)| buckets[idx % self.available_parallelism].push(fut));
|
||||
|
||||
total_work_distribution_time += work_distribution_start.elapsed();
|
||||
|
||||
let processing_start = Instant::now();
|
||||
|
||||
let handles = buckets
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.filter(|(_idx, bucket)| !bucket.is_empty())
|
||||
.map(|(idx, bucket)| {
|
||||
let mut bucket = mem::take(bucket);
|
||||
|
||||
spawn(async move {
|
||||
let mut ops_count = 0;
|
||||
let processing_start = Instant::now();
|
||||
while let Some(count) = bucket.try_next().await? {
|
||||
ops_count += count;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Ingested {ops_count} operations in {:?}",
|
||||
processing_start.elapsed()
|
||||
);
|
||||
|
||||
Ok::<_, Error>((ops_count, idx, bucket))
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let results = handles.try_join().await.map_err(Error::ProcessCrdtPanic)?;
|
||||
|
||||
total_process_time += processing_start.elapsed();
|
||||
|
||||
for res in results {
|
||||
let (count, idx, bucket) = res?;
|
||||
|
||||
buckets[idx] = bucket;
|
||||
|
||||
total_count += count;
|
||||
}
|
||||
|
||||
self.db
|
||||
.cloud_crdt_operation()
|
||||
.delete_many(vec![cloud_crdt_operation::id::in_vec(ops_ids)])
|
||||
.exec()
|
||||
.await?;
|
||||
}
|
||||
|
||||
debug!(
|
||||
total_count,
|
||||
?total_fetch_time,
|
||||
?total_compression_time,
|
||||
?total_work_distribution_time,
|
||||
?total_process_time,
|
||||
"Ingested all operations of this model"
|
||||
);
|
||||
|
||||
Ok(total_count)
|
||||
}
|
||||
|
||||
pub async fn ingest_ops(&self) -> Result<usize, Error> {
|
||||
let mut total_count = 0;
|
||||
|
||||
// WARN: this order here exists because sync messages MUST be processed in this exact order
|
||||
// due to relationship dependencies between these tables.
|
||||
total_count += self.ingest_by_model(prisma_sync::device::MODEL_ID).await?;
|
||||
|
||||
total_count += [
|
||||
self.ingest_by_model(prisma_sync::storage_statistics::MODEL_ID),
|
||||
self.ingest_by_model(prisma_sync::tag::MODEL_ID),
|
||||
self.ingest_by_model(prisma_sync::location::MODEL_ID),
|
||||
self.ingest_by_model(prisma_sync::object::MODEL_ID),
|
||||
self.ingest_by_model(prisma_sync::label::MODEL_ID),
|
||||
]
|
||||
.try_join()
|
||||
.await?
|
||||
.into_iter()
|
||||
.sum::<usize>();
|
||||
|
||||
total_count += [
|
||||
self.ingest_by_model(prisma_sync::exif_data::MODEL_ID),
|
||||
self.ingest_by_model(prisma_sync::file_path::MODEL_ID),
|
||||
self.ingest_by_model(prisma_sync::tag_on_object::MODEL_ID),
|
||||
self.ingest_by_model(prisma_sync::label_on_object::MODEL_ID),
|
||||
]
|
||||
.try_join()
|
||||
.await?
|
||||
.into_iter()
|
||||
.sum::<usize>();
|
||||
|
||||
if self.tx.send(SyncEvent::Ingested).is_err() {
|
||||
warn!("failed to send ingested message on `ingest_ops`");
|
||||
}
|
||||
|
||||
Ok(total_count)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<SyncEvent> {
|
||||
self.tx.subscribe()
|
||||
}
|
||||
|
||||
pub async fn write_ops<'item, Q>(
|
||||
&self,
|
||||
tx: &PrismaClient,
|
||||
(mut ops, queries): (Vec<CRDTOperation>, Q),
|
||||
(ops, queries): (Vec<CRDTOperation>, Q),
|
||||
) -> Result<Q::ReturnValue, Error>
|
||||
where
|
||||
Q: prisma_client_rust::BatchItem<'item, ReturnValue: Send> + Send,
|
||||
{
|
||||
let ret = if self.emit_messages_flag.load(atomic::Ordering::Relaxed) {
|
||||
let lock = self.timestamp_lock.lock().await;
|
||||
if ops.is_empty() {
|
||||
return Err(Error::EmptyOperations);
|
||||
}
|
||||
|
||||
for op in &mut ops {
|
||||
op.timestamp = *self.get_clock().new_timestamp().get_time();
|
||||
}
|
||||
let ret = if self.emit_messages_flag.load(atomic::Ordering::Relaxed) {
|
||||
let lock_guard = self.sync_lock.lock().await;
|
||||
|
||||
let (res, _) = tx
|
||||
._batch((
|
||||
@@ -164,18 +431,17 @@ impl Manager {
|
||||
.await?;
|
||||
|
||||
if let Some(last) = ops.last() {
|
||||
self.shared
|
||||
.timestamps
|
||||
self.timestamp_per_device
|
||||
.write()
|
||||
.await
|
||||
.insert(self.instance, last.timestamp);
|
||||
.insert(self.device_pub_id.clone(), last.timestamp);
|
||||
}
|
||||
|
||||
if self.tx.send(SyncMessage::Created).is_err() {
|
||||
if self.tx.send(SyncEvent::Created).is_err() {
|
||||
warn!("failed to send created message on `write_ops`");
|
||||
}
|
||||
|
||||
drop(lock);
|
||||
drop(lock_guard);
|
||||
|
||||
res
|
||||
} else {
|
||||
@@ -188,160 +454,289 @@ impl Manager {
|
||||
pub async fn write_op<'item, Q>(
|
||||
&self,
|
||||
tx: &PrismaClient,
|
||||
mut op: CRDTOperation,
|
||||
op: CRDTOperation,
|
||||
query: Q,
|
||||
) -> Result<Q::ReturnValue, Error>
|
||||
where
|
||||
Q: prisma_client_rust::BatchItem<'item, ReturnValue: Send> + Send,
|
||||
{
|
||||
let ret = if self.emit_messages_flag.load(atomic::Ordering::Relaxed) {
|
||||
let lock = self.timestamp_lock.lock().await;
|
||||
|
||||
op.timestamp = *self.get_clock().new_timestamp().get_time();
|
||||
let lock_guard = self.sync_lock.lock().await;
|
||||
|
||||
let ret = tx._batch((crdt_op_db(&op)?.to_query(tx), query)).await?.1;
|
||||
|
||||
if self.tx.send(SyncMessage::Created).is_err() {
|
||||
if self.tx.send(SyncEvent::Created).is_err() {
|
||||
warn!("failed to send created message on `write_op`");
|
||||
}
|
||||
|
||||
drop(lock);
|
||||
drop(lock_guard);
|
||||
|
||||
ret
|
||||
} else {
|
||||
tx._batch(vec![query]).await?.remove(0)
|
||||
};
|
||||
|
||||
self.shared
|
||||
.timestamps
|
||||
self.timestamp_per_device
|
||||
.write()
|
||||
.await
|
||||
.insert(self.instance, op.timestamp);
|
||||
.insert(self.device_pub_id.clone(), op.timestamp);
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
pub async fn get_instance_ops(
|
||||
&self,
|
||||
count: u32,
|
||||
instance_uuid: Uuid,
|
||||
timestamp: NTP64,
|
||||
) -> Result<Vec<CRDTOperation>, Error> {
|
||||
self.db
|
||||
.crdt_operation()
|
||||
.find_many(vec![
|
||||
crdt_operation::instance::is(vec![instance::pub_id::equals(uuid_to_bytes(
|
||||
&instance_uuid,
|
||||
))]),
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
crdt_operation::timestamp::gt(timestamp.as_u64() as i64),
|
||||
])
|
||||
.take(i64::from(count))
|
||||
.order_by(crdt_operation::timestamp::order(SortOrder::Asc))
|
||||
.include(crdt_with_instance::include())
|
||||
.exec()
|
||||
.await?
|
||||
// pub async fn get_device_ops(
|
||||
// &self,
|
||||
// count: u32,
|
||||
// device_pub_id: DevicePubId,
|
||||
// timestamp: NTP64,
|
||||
// ) -> Result<Vec<CRDTOperation>, Error> {
|
||||
// self.db
|
||||
// .crdt_operation()
|
||||
// .find_many(vec![
|
||||
// crdt_operation::device_pub_id::equals(device_pub_id.into()),
|
||||
// #[allow(clippy::cast_possible_wrap)]
|
||||
// crdt_operation::timestamp::gt(timestamp.as_u64() as i64),
|
||||
// ])
|
||||
// .take(i64::from(count))
|
||||
// .order_by(crdt_operation::timestamp::order(SortOrder::Asc))
|
||||
// .exec()
|
||||
// .await?
|
||||
// .into_iter()
|
||||
// .map(from_crdt_ops)
|
||||
// .collect()
|
||||
// }
|
||||
|
||||
pub fn stream_device_ops<'a>(
|
||||
&'a self,
|
||||
device_pub_id: &'a DevicePubId,
|
||||
chunk_size: u32,
|
||||
initial_timestamp: NTP64,
|
||||
) -> impl Stream<Item = Result<Vec<CRDTOperation>, Error>> + Send + 'a {
|
||||
stream! {
|
||||
let mut current_initial_timestamp = initial_timestamp;
|
||||
|
||||
loop {
|
||||
match self.db.crdt_operation()
|
||||
.find_many(vec![
|
||||
crdt_operation::device_pub_id::equals(device_pub_id.to_db()),
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
crdt_operation::timestamp::gt(current_initial_timestamp.as_u64() as i64),
|
||||
])
|
||||
.take(i64::from(chunk_size))
|
||||
.order_by(crdt_operation::timestamp::order(SortOrder::Asc))
|
||||
.exec()
|
||||
.await
|
||||
{
|
||||
Ok(ops) if ops.is_empty() => break,
|
||||
|
||||
Ok(ops) => match ops
|
||||
.into_iter()
|
||||
.map(from_crdt_ops)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
{
|
||||
Ok(ops) => {
|
||||
debug!(
|
||||
start_datetime = ?ops
|
||||
.first()
|
||||
.map(|op| timestamp_to_datetime(op.timestamp)),
|
||||
end_datetime = ?ops
|
||||
.last()
|
||||
.map(|op| timestamp_to_datetime(op.timestamp)),
|
||||
count = ops.len(),
|
||||
"Streaming crdt ops",
|
||||
);
|
||||
|
||||
if let Some(last_op) = ops.last() {
|
||||
current_initial_timestamp = last_op.timestamp;
|
||||
}
|
||||
|
||||
yield Ok(ops);
|
||||
}
|
||||
|
||||
Err(e) => return yield Err(e),
|
||||
}
|
||||
|
||||
Err(e) => return yield Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pub async fn get_ops(
|
||||
// &self,
|
||||
// count: u32,
|
||||
// timestamp_per_device: Vec<(DevicePubId, NTP64)>,
|
||||
// ) -> Result<Vec<CRDTOperation>, Error> {
|
||||
// let mut ops = self
|
||||
// .db
|
||||
// .crdt_operation()
|
||||
// .find_many(vec![or(timestamp_per_device
|
||||
// .iter()
|
||||
// .map(|(device_pub_id, timestamp)| {
|
||||
// and![
|
||||
// crdt_operation::device_pub_id::equals(device_pub_id.to_db()),
|
||||
// crdt_operation::timestamp::gt({
|
||||
// #[allow(clippy::cast_possible_wrap)]
|
||||
// // SAFETY: we had to store using i64 due to SQLite limitations
|
||||
// {
|
||||
// timestamp.as_u64() as i64
|
||||
// }
|
||||
// })
|
||||
// ]
|
||||
// })
|
||||
// .chain([crdt_operation::device_pub_id::not_in_vec(
|
||||
// timestamp_per_device
|
||||
// .iter()
|
||||
// .map(|(device_pub_id, _)| device_pub_id.to_db())
|
||||
// .collect(),
|
||||
// )])
|
||||
// .collect())])
|
||||
// .take(i64::from(count))
|
||||
// .order_by(crdt_operation::timestamp::order(SortOrder::Asc))
|
||||
// .exec()
|
||||
// .await?;
|
||||
|
||||
// ops.sort_by(|a, b| match a.timestamp.cmp(&b.timestamp) {
|
||||
// cmp::Ordering::Equal => {
|
||||
// from_bytes_to_uuid(&a.device_pub_id).cmp(&from_bytes_to_uuid(&b.device_pub_id))
|
||||
// }
|
||||
// o => o,
|
||||
// });
|
||||
|
||||
// ops.into_iter()
|
||||
// .take(count as usize)
|
||||
// .map(from_crdt_ops)
|
||||
// .collect()
|
||||
// }
|
||||
|
||||
// pub async fn get_cloud_ops(
|
||||
// &self,
|
||||
// count: u32,
|
||||
// timestamp_per_device: Vec<(DevicePubId, NTP64)>,
|
||||
// ) -> Result<Vec<(cloud_crdt_operation::id::Type, CRDTOperation)>, Error> {
|
||||
// let mut ops = self
|
||||
// .db
|
||||
// .cloud_crdt_operation()
|
||||
// .find_many(vec![or(timestamp_per_device
|
||||
// .iter()
|
||||
// .map(|(device_pub_id, timestamp)| {
|
||||
// and![
|
||||
// cloud_crdt_operation::device_pub_id::equals(device_pub_id.to_db()),
|
||||
// cloud_crdt_operation::timestamp::gt({
|
||||
// #[allow(clippy::cast_possible_wrap)]
|
||||
// // SAFETY: we had to store using i64 due to SQLite limitations
|
||||
// {
|
||||
// timestamp.as_u64() as i64
|
||||
// }
|
||||
// })
|
||||
// ]
|
||||
// })
|
||||
// .chain([cloud_crdt_operation::device_pub_id::not_in_vec(
|
||||
// timestamp_per_device
|
||||
// .iter()
|
||||
// .map(|(device_pub_id, _)| device_pub_id.to_db())
|
||||
// .collect(),
|
||||
// )])
|
||||
// .collect())])
|
||||
// .take(i64::from(count))
|
||||
// .order_by(cloud_crdt_operation::timestamp::order(SortOrder::Asc))
|
||||
// .exec()
|
||||
// .await?;
|
||||
|
||||
// ops.sort_by(|a, b| match a.timestamp.cmp(&b.timestamp) {
|
||||
// cmp::Ordering::Equal => {
|
||||
// from_bytes_to_uuid(&a.device_pub_id).cmp(&from_bytes_to_uuid(&b.device_pub_id))
|
||||
// }
|
||||
// o => o,
|
||||
// });
|
||||
|
||||
// ops.into_iter()
|
||||
// .take(count as usize)
|
||||
// .map(from_cloud_crdt_ops)
|
||||
// .collect()
|
||||
// }
|
||||
}
|
||||
|
||||
async fn bulk_process_of_create_only_ops(
|
||||
available_parallelism: usize,
|
||||
clock: Arc<HLC>,
|
||||
timestamp_per_device: TimestampPerDevice,
|
||||
db: Arc<PrismaClient>,
|
||||
sync_lock: Arc<Mutex<()>>,
|
||||
model_id: ModelId,
|
||||
create_only_ops: BTreeMap<Uuid, Vec<(RecordId, CompressedCRDTOperation)>>,
|
||||
) -> Result<usize, Error> {
|
||||
let buckets = (0..available_parallelism)
|
||||
.map(|_| FuturesUnordered::new())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut bucket_idx = 0;
|
||||
|
||||
for (device_pub_id, records) in create_only_ops {
|
||||
records
|
||||
.into_iter()
|
||||
.map(crdt_with_instance::Data::into_operation)
|
||||
.collect()
|
||||
.chunks(100)
|
||||
.into_iter()
|
||||
.for_each(|chunk| {
|
||||
let ops = chunk.collect::<Vec<_>>();
|
||||
|
||||
buckets[bucket_idx % available_parallelism].push({
|
||||
let clock = Arc::clone(&clock);
|
||||
let timestamp_per_device = Arc::clone(×tamp_per_device);
|
||||
let db = Arc::clone(&db);
|
||||
let device_pub_id = device_pub_id.into();
|
||||
let sync_lock = Arc::clone(&sync_lock);
|
||||
|
||||
async move {
|
||||
let count = ops.len();
|
||||
bulk_ingest_create_only_ops(
|
||||
&clock,
|
||||
×tamp_per_device,
|
||||
&db,
|
||||
device_pub_id,
|
||||
model_id,
|
||||
ops,
|
||||
sync_lock,
|
||||
)
|
||||
.await
|
||||
.map(|()| count)
|
||||
}
|
||||
});
|
||||
|
||||
bucket_idx += 1;
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn get_ops(&self, args: GetOpsArgs) -> Result<Vec<CRDTOperation>, Error> {
|
||||
let mut ops = self
|
||||
.db
|
||||
.crdt_operation()
|
||||
.find_many(vec![or(args
|
||||
.clocks
|
||||
.iter()
|
||||
.map(|(instance_id, timestamp)| {
|
||||
and![
|
||||
crdt_operation::instance::is(vec![instance::pub_id::equals(
|
||||
uuid_to_bytes(instance_id)
|
||||
)]),
|
||||
crdt_operation::timestamp::gt({
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
// SAFETY: we had to store using i64 due to SQLite limitations
|
||||
{
|
||||
timestamp.as_u64() as i64
|
||||
}
|
||||
})
|
||||
]
|
||||
})
|
||||
.chain([crdt_operation::instance::is_not(vec![
|
||||
instance::pub_id::in_vec(
|
||||
args.clocks
|
||||
.iter()
|
||||
.map(|(instance_id, _)| uuid_to_bytes(instance_id))
|
||||
.collect(),
|
||||
),
|
||||
])])
|
||||
.collect())])
|
||||
.take(i64::from(args.count))
|
||||
.order_by(crdt_operation::timestamp::order(SortOrder::Asc))
|
||||
.include(crdt_with_instance::include())
|
||||
.exec()
|
||||
.await?;
|
||||
let handles = buckets
|
||||
.into_iter()
|
||||
.map(|mut bucket| {
|
||||
spawn(async move {
|
||||
let mut total_count = 0;
|
||||
|
||||
ops.sort_by(|a, b| match a.timestamp().cmp(&b.timestamp()) {
|
||||
cmp::Ordering::Equal => a.instance().cmp(&b.instance()),
|
||||
o => o,
|
||||
});
|
||||
let process_creates_batch_start = Instant::now();
|
||||
|
||||
ops.into_iter()
|
||||
.take(args.count as usize)
|
||||
.map(crdt_with_instance::Data::into_operation)
|
||||
.collect()
|
||||
}
|
||||
while let Some(count) = bucket.try_next().await? {
|
||||
total_count += count;
|
||||
}
|
||||
|
||||
pub async fn get_cloud_ops(
|
||||
&self,
|
||||
args: GetOpsArgs,
|
||||
) -> Result<Vec<(i32, CRDTOperation)>, Error> {
|
||||
let mut ops = self
|
||||
.db
|
||||
.cloud_crdt_operation()
|
||||
.find_many(vec![or(args
|
||||
.clocks
|
||||
.iter()
|
||||
.map(|(instance_id, timestamp)| {
|
||||
and![
|
||||
cloud_crdt_operation::instance::is(vec![instance::pub_id::equals(
|
||||
uuid_to_bytes(instance_id)
|
||||
)]),
|
||||
cloud_crdt_operation::timestamp::gt({
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
// SAFETY: we had to store using i64 due to SQLite limitations
|
||||
{
|
||||
timestamp.as_u64() as i64
|
||||
}
|
||||
})
|
||||
]
|
||||
})
|
||||
.chain([cloud_crdt_operation::instance::is_not(vec![
|
||||
instance::pub_id::in_vec(
|
||||
args.clocks
|
||||
.iter()
|
||||
.map(|(instance_id, _)| uuid_to_bytes(instance_id))
|
||||
.collect(),
|
||||
),
|
||||
])])
|
||||
.collect())])
|
||||
.take(i64::from(args.count))
|
||||
.order_by(cloud_crdt_operation::timestamp::order(SortOrder::Asc))
|
||||
.include(cloud_crdt_with_instance::include())
|
||||
.exec()
|
||||
.await?;
|
||||
debug!(
|
||||
"Processed {total_count} creates in {:?}",
|
||||
process_creates_batch_start.elapsed()
|
||||
);
|
||||
|
||||
ops.sort_by(|a, b| match a.timestamp().cmp(&b.timestamp()) {
|
||||
cmp::Ordering::Equal => a.instance().cmp(&b.instance()),
|
||||
o => o,
|
||||
});
|
||||
Ok::<_, Error>(total_count)
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ops.into_iter()
|
||||
.take(args.count as usize)
|
||||
.map(cloud_crdt_with_instance::Data::into_operation)
|
||||
.collect()
|
||||
}
|
||||
Ok(handles
|
||||
.try_join()
|
||||
.await
|
||||
.map_err(Error::ProcessCrdtPanic)?
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.sum())
|
||||
}
|
||||
|
||||
impl OperationFactory for Manager {
|
||||
@@ -349,15 +744,7 @@ impl OperationFactory for Manager {
|
||||
&self.clock
|
||||
}
|
||||
|
||||
fn get_instance(&self) -> Uuid {
|
||||
self.instance
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Manager {
|
||||
type Target = SharedState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.shared
|
||||
fn get_device_pub_id(&self) -> sd_sync::DevicePubId {
|
||||
sd_sync::DevicePubId::from(&self.device_pub_id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,247 +0,0 @@
|
||||
mod mock_instance;
|
||||
|
||||
use sd_core_sync::*;
|
||||
|
||||
use sd_prisma::{prisma::location, prisma_sync};
|
||||
use sd_sync::*;
|
||||
use sd_utils::{msgpack, uuid_to_bytes};
|
||||
|
||||
use mock_instance::Instance;
|
||||
use tracing::info;
|
||||
use tracing_test::traced_test;
|
||||
use uuid::Uuid;
|
||||
|
||||
const MOCK_LOCATION_NAME: &str = "Location 0";
|
||||
const MOCK_LOCATION_PATH: &str = "/User/Anon/Documents";
|
||||
|
||||
async fn write_test_location(instance: &Instance) -> location::Data {
|
||||
let location_pub_id = Uuid::new_v4();
|
||||
|
||||
let location = instance
|
||||
.sync
|
||||
.write_ops(&instance.db, {
|
||||
let (sync_ops, db_ops): (Vec<_>, Vec<_>) = [
|
||||
sync_db_entry!(MOCK_LOCATION_NAME, location::name),
|
||||
sync_db_entry!(MOCK_LOCATION_PATH, location::path),
|
||||
]
|
||||
.into_iter()
|
||||
.unzip();
|
||||
|
||||
(
|
||||
instance.sync.shared_create(
|
||||
prisma_sync::location::SyncId {
|
||||
pub_id: uuid_to_bytes(&location_pub_id),
|
||||
},
|
||||
sync_ops,
|
||||
),
|
||||
instance
|
||||
.db
|
||||
.location()
|
||||
.create(uuid_to_bytes(&location_pub_id), db_ops),
|
||||
)
|
||||
})
|
||||
.await
|
||||
.expect("failed to create mock location");
|
||||
|
||||
instance
|
||||
.sync
|
||||
.write_ops(&instance.db, {
|
||||
let (sync_ops, db_ops): (Vec<_>, Vec<_>) = [
|
||||
sync_db_entry!(1024, location::total_capacity),
|
||||
sync_db_entry!(512, location::available_capacity),
|
||||
]
|
||||
.into_iter()
|
||||
.unzip();
|
||||
|
||||
(
|
||||
sync_ops
|
||||
.into_iter()
|
||||
.map(|(k, v)| {
|
||||
instance.sync.shared_update(
|
||||
prisma_sync::location::SyncId {
|
||||
pub_id: uuid_to_bytes(&location_pub_id),
|
||||
},
|
||||
k,
|
||||
v,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
instance
|
||||
.db
|
||||
.location()
|
||||
.update(location::id::equals(location.id), db_ops),
|
||||
)
|
||||
})
|
||||
.await
|
||||
.expect("failed to create mock location");
|
||||
|
||||
location
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[traced_test]
|
||||
async fn writes_operations_and_rows_together() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let instance = Instance::new(Uuid::new_v4()).await;
|
||||
|
||||
write_test_location(&instance).await;
|
||||
|
||||
let operations = instance
|
||||
.db
|
||||
.crdt_operation()
|
||||
.find_many(vec![])
|
||||
.exec()
|
||||
.await?;
|
||||
|
||||
// 1 create, 2 update
|
||||
assert_eq!(operations.len(), 3);
|
||||
assert_eq!(operations[0].model, prisma_sync::location::MODEL_ID as i32);
|
||||
|
||||
let out = instance
|
||||
.sync
|
||||
.get_ops(GetOpsArgs {
|
||||
clocks: vec![],
|
||||
count: 100,
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert_eq!(out.len(), 3);
|
||||
|
||||
let locations = instance.db.location().find_many(vec![]).exec().await?;
|
||||
|
||||
assert_eq!(locations.len(), 1);
|
||||
let location = locations.first().unwrap();
|
||||
assert_eq!(location.name.as_deref(), Some(MOCK_LOCATION_NAME));
|
||||
assert_eq!(location.path.as_deref(), Some(MOCK_LOCATION_PATH));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[traced_test]
|
||||
async fn operations_send_and_ingest() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let instance1 = Instance::new(Uuid::new_v4()).await;
|
||||
let instance2 = Instance::new(Uuid::new_v4()).await;
|
||||
|
||||
let mut instance2_sync_rx = instance2.sync_rx.resubscribe();
|
||||
|
||||
info!("Created instances!");
|
||||
|
||||
Instance::pair(&instance1, &instance2).await;
|
||||
|
||||
info!("Paired instances!");
|
||||
|
||||
write_test_location(&instance1).await;
|
||||
|
||||
info!("Created mock location!");
|
||||
|
||||
assert!(matches!(
|
||||
instance2_sync_rx.recv().await?,
|
||||
SyncMessage::Ingested
|
||||
));
|
||||
|
||||
let out = instance2
|
||||
.sync
|
||||
.get_ops(GetOpsArgs {
|
||||
clocks: vec![],
|
||||
count: 100,
|
||||
})
|
||||
.await?;
|
||||
|
||||
assert_locations_equality(
|
||||
&instance1.db.location().find_many(vec![]).exec().await?[0],
|
||||
&instance2.db.location().find_many(vec![]).exec().await?[0],
|
||||
);
|
||||
|
||||
assert_eq!(out.len(), 3);
|
||||
|
||||
instance1.teardown().await;
|
||||
instance2.teardown().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn no_update_after_delete() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let instance1 = Instance::new(Uuid::new_v4()).await;
|
||||
let instance2 = Instance::new(Uuid::new_v4()).await;
|
||||
|
||||
let mut instance2_sync_rx = instance2.sync_rx.resubscribe();
|
||||
|
||||
Instance::pair(&instance1, &instance2).await;
|
||||
|
||||
let location = write_test_location(&instance1).await;
|
||||
|
||||
assert!(matches!(
|
||||
instance2_sync_rx.recv().await?,
|
||||
SyncMessage::Ingested
|
||||
));
|
||||
|
||||
instance2
|
||||
.sync
|
||||
.write_op(
|
||||
&instance2.db,
|
||||
instance2.sync.shared_delete(prisma_sync::location::SyncId {
|
||||
pub_id: location.pub_id.clone(),
|
||||
}),
|
||||
instance2.db.location().delete_many(vec![]),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert!(matches!(
|
||||
instance1.sync_rx.resubscribe().recv().await?,
|
||||
SyncMessage::Ingested
|
||||
));
|
||||
|
||||
instance1
|
||||
.sync
|
||||
.write_op(
|
||||
&instance1.db,
|
||||
instance1.sync.shared_update(
|
||||
prisma_sync::location::SyncId {
|
||||
pub_id: location.pub_id.clone(),
|
||||
},
|
||||
"name",
|
||||
msgpack!("New Location"),
|
||||
),
|
||||
instance1.db.location().find_many(vec![]),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// one spare update operation that actually gets ignored by instance 2
|
||||
assert_eq!(instance1.db.crdt_operation().count(vec![]).exec().await?, 5);
|
||||
assert_eq!(instance2.db.crdt_operation().count(vec![]).exec().await?, 4);
|
||||
|
||||
assert_eq!(instance1.db.location().count(vec![]).exec().await?, 0);
|
||||
// the whole point of the test - the update (which is ingested as an upsert) should be ignored
|
||||
assert_eq!(instance2.db.location().count(vec![]).exec().await?, 0);
|
||||
|
||||
instance1.teardown().await;
|
||||
instance2.teardown().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn assert_locations_equality(l1: &location::Data, l2: &location::Data) {
|
||||
assert_eq!(l1.pub_id, l2.pub_id, "pub id");
|
||||
assert_eq!(l1.name, l2.name, "name");
|
||||
assert_eq!(l1.path, l2.path, "path");
|
||||
assert_eq!(l1.total_capacity, l2.total_capacity, "total capacity");
|
||||
assert_eq!(
|
||||
l1.available_capacity, l2.available_capacity,
|
||||
"available capacity"
|
||||
);
|
||||
assert_eq!(l1.size_in_bytes, l2.size_in_bytes, "size in bytes");
|
||||
assert_eq!(l1.is_archived, l2.is_archived, "is archived");
|
||||
assert_eq!(
|
||||
l1.generate_preview_media, l2.generate_preview_media,
|
||||
"generate preview media"
|
||||
);
|
||||
assert_eq!(
|
||||
l1.sync_preview_media, l2.sync_preview_media,
|
||||
"sync preview media"
|
||||
);
|
||||
assert_eq!(l1.hidden, l2.hidden, "hidden");
|
||||
assert_eq!(l1.date_created, l2.date_created, "date created");
|
||||
assert_eq!(l1.scan_state, l2.scan_state, "scan state");
|
||||
assert_eq!(l1.instance_id, l2.instance_id, "instance id");
|
||||
}
|
||||
@@ -1,161 +0,0 @@
|
||||
use sd_core_sync::*;
|
||||
|
||||
use sd_prisma::prisma;
|
||||
use sd_sync::CompressedCRDTOperations;
|
||||
use sd_utils::uuid_to_bytes;
|
||||
|
||||
use std::sync::{atomic::AtomicBool, Arc};
|
||||
|
||||
use prisma_client_rust::chrono::Utc;
|
||||
use tokio::{fs, spawn, sync::broadcast};
|
||||
use tracing::{info, instrument, warn, Instrument};
|
||||
use uuid::Uuid;
|
||||
|
||||
fn db_path(id: Uuid) -> String {
|
||||
format!("/tmp/test-{id}.db")
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Instance {
|
||||
pub id: Uuid,
|
||||
pub db: Arc<prisma::PrismaClient>,
|
||||
pub sync: Arc<sd_core_sync::Manager>,
|
||||
pub sync_rx: Arc<broadcast::Receiver<SyncMessage>>,
|
||||
}
|
||||
|
||||
impl Instance {
|
||||
pub async fn new(id: Uuid) -> Arc<Self> {
|
||||
let url = format!("file:{}", db_path(id));
|
||||
|
||||
let db = Arc::new(
|
||||
prisma::PrismaClient::_builder()
|
||||
.with_url(url.to_string())
|
||||
.build()
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
db._db_push().await.unwrap();
|
||||
|
||||
db.instance()
|
||||
.create(
|
||||
uuid_to_bytes(&id),
|
||||
vec![],
|
||||
vec![],
|
||||
Utc::now().into(),
|
||||
Utc::now().into(),
|
||||
vec![],
|
||||
)
|
||||
.exec()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (sync, sync_rx) = sd_core_sync::Manager::new(
|
||||
Arc::clone(&db),
|
||||
id,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
.expect("failed to create sync manager");
|
||||
|
||||
Arc::new(Self {
|
||||
id,
|
||||
db,
|
||||
sync: Arc::new(sync),
|
||||
sync_rx: Arc::new(sync_rx),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn teardown(&self) {
|
||||
fs::remove_file(db_path(self.id)).await.unwrap();
|
||||
}
|
||||
|
||||
pub async fn pair(instance1: &Arc<Self>, instance2: &Arc<Self>) {
|
||||
#[instrument(skip(left, right))]
|
||||
async fn half(left: &Arc<Instance>, right: &Arc<Instance>, context: &'static str) {
|
||||
left.db
|
||||
.instance()
|
||||
.create(
|
||||
uuid_to_bytes(&right.id),
|
||||
vec![],
|
||||
vec![],
|
||||
Utc::now().into(),
|
||||
Utc::now().into(),
|
||||
vec![],
|
||||
)
|
||||
.exec()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
spawn({
|
||||
let mut sync_rx_left = left.sync_rx.resubscribe();
|
||||
let right = Arc::clone(right);
|
||||
|
||||
async move {
|
||||
while let Ok(msg) = sync_rx_left.recv().await {
|
||||
info!(?msg, "sync_rx_left received message");
|
||||
if matches!(msg, SyncMessage::Created) {
|
||||
right
|
||||
.sync
|
||||
.ingest
|
||||
.event_tx
|
||||
.send(ingest::Event::Notification)
|
||||
.await
|
||||
.unwrap();
|
||||
info!("sent notification to instance 2");
|
||||
}
|
||||
}
|
||||
}
|
||||
.in_current_span()
|
||||
});
|
||||
|
||||
spawn({
|
||||
let left = Arc::clone(left);
|
||||
let right = Arc::clone(right);
|
||||
|
||||
async move {
|
||||
while let Ok(msg) = right.sync.ingest.req_rx.recv().await {
|
||||
info!(?msg, "right instance received request");
|
||||
match msg {
|
||||
ingest::Request::Messages { timestamps, tx } => {
|
||||
let messages = left
|
||||
.sync
|
||||
.get_ops(GetOpsArgs {
|
||||
clocks: timestamps,
|
||||
count: 100,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ingest = &right.sync.ingest;
|
||||
|
||||
ingest
|
||||
.event_tx
|
||||
.send(ingest::Event::Messages(ingest::MessagesEvent {
|
||||
messages: CompressedCRDTOperations::new(messages),
|
||||
has_more: false,
|
||||
instance_id: left.id,
|
||||
wait_tx: None,
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if tx.send(()).is_err() {
|
||||
warn!("failed to send ack to instance 1");
|
||||
}
|
||||
}
|
||||
ingest::Request::FinishedIngesting => {
|
||||
right.sync.tx.send(SyncMessage::Ingested).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.in_current_span()
|
||||
});
|
||||
}
|
||||
|
||||
half(instance1, instance2, "instance1 -> instance2").await;
|
||||
half(instance2, instance1, "instance2 -> instance1").await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the `node` table. If the table is not empty, all the data it contains will be lost.
|
||||
- You are about to drop the column `instance_id` on the `cloud_crdt_operation` table. All the data in the column will be lost.
|
||||
- You are about to drop the column `instance_id` on the `crdt_operation` table. All the data in the column will be lost.
|
||||
- You are about to drop the column `instance_pub_id` on the `storage_statistics` table. All the data in the column will be lost.
|
||||
- Added the required column `device_pub_id` to the `cloud_crdt_operation` table without a default value. This is not possible if the table is not empty.
|
||||
- Added the required column `device_pub_id` to the `crdt_operation` table without a default value. This is not possible if the table is not empty.
|
||||
|
||||
*/
|
||||
-- DropIndex
|
||||
DROP INDEX "node_pub_id_key";
|
||||
|
||||
-- DropTable
|
||||
PRAGMA foreign_keys=off;
|
||||
DROP TABLE "node";
|
||||
PRAGMA foreign_keys=on;
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "device" (
|
||||
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"pub_id" BLOB NOT NULL,
|
||||
"name" TEXT,
|
||||
"os" INTEGER,
|
||||
"hardware_model" INTEGER,
|
||||
"timestamp" BIGINT,
|
||||
"date_created" DATETIME,
|
||||
"date_deleted" DATETIME
|
||||
);
|
||||
|
||||
-- RedefineTables
|
||||
PRAGMA defer_foreign_keys=ON;
|
||||
PRAGMA foreign_keys=OFF;
|
||||
CREATE TABLE "new_cloud_crdt_operation" (
|
||||
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"timestamp" BIGINT NOT NULL,
|
||||
"model" INTEGER NOT NULL,
|
||||
"record_id" BLOB NOT NULL,
|
||||
"kind" TEXT NOT NULL,
|
||||
"data" BLOB NOT NULL,
|
||||
"device_pub_id" BLOB NOT NULL,
|
||||
CONSTRAINT "cloud_crdt_operation_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_cloud_crdt_operation" ("data", "id", "kind", "model", "record_id", "timestamp") SELECT "data", "id", "kind", "model", "record_id", "timestamp" FROM "cloud_crdt_operation";
|
||||
DROP TABLE "cloud_crdt_operation";
|
||||
ALTER TABLE "new_cloud_crdt_operation" RENAME TO "cloud_crdt_operation";
|
||||
CREATE INDEX "cloud_crdt_operation_timestamp_idx" ON "cloud_crdt_operation"("timestamp");
|
||||
CREATE TABLE "new_crdt_operation" (
|
||||
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"timestamp" BIGINT NOT NULL,
|
||||
"model" INTEGER NOT NULL,
|
||||
"record_id" BLOB NOT NULL,
|
||||
"kind" TEXT NOT NULL,
|
||||
"data" BLOB NOT NULL,
|
||||
"device_pub_id" BLOB NOT NULL,
|
||||
CONSTRAINT "crdt_operation_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_crdt_operation" ("data", "id", "kind", "model", "record_id", "timestamp") SELECT "data", "id", "kind", "model", "record_id", "timestamp" FROM "crdt_operation";
|
||||
DROP TABLE "crdt_operation";
|
||||
ALTER TABLE "new_crdt_operation" RENAME TO "crdt_operation";
|
||||
CREATE INDEX "crdt_operation_timestamp_idx" ON "crdt_operation"("timestamp");
|
||||
CREATE TABLE "new_exif_data" (
|
||||
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"resolution" BLOB,
|
||||
"media_date" BLOB,
|
||||
"media_location" BLOB,
|
||||
"camera_data" BLOB,
|
||||
"artist" TEXT,
|
||||
"description" TEXT,
|
||||
"copyright" TEXT,
|
||||
"exif_version" TEXT,
|
||||
"epoch_time" BIGINT,
|
||||
"object_id" INTEGER NOT NULL,
|
||||
"device_pub_id" BLOB,
|
||||
CONSTRAINT "exif_data_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
CONSTRAINT "exif_data_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_exif_data" ("artist", "camera_data", "copyright", "description", "epoch_time", "exif_version", "id", "media_date", "media_location", "object_id", "resolution") SELECT "artist", "camera_data", "copyright", "description", "epoch_time", "exif_version", "id", "media_date", "media_location", "object_id", "resolution" FROM "exif_data";
|
||||
DROP TABLE "exif_data";
|
||||
ALTER TABLE "new_exif_data" RENAME TO "exif_data";
|
||||
CREATE UNIQUE INDEX "exif_data_object_id_key" ON "exif_data"("object_id");
|
||||
CREATE TABLE "new_file_path" (
|
||||
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"pub_id" BLOB NOT NULL,
|
||||
"is_dir" BOOLEAN,
|
||||
"cas_id" TEXT,
|
||||
"integrity_checksum" TEXT,
|
||||
"location_id" INTEGER,
|
||||
"materialized_path" TEXT,
|
||||
"name" TEXT,
|
||||
"extension" TEXT,
|
||||
"hidden" BOOLEAN,
|
||||
"size_in_bytes" TEXT,
|
||||
"size_in_bytes_bytes" BLOB,
|
||||
"inode" BLOB,
|
||||
"object_id" INTEGER,
|
||||
"key_id" INTEGER,
|
||||
"date_created" DATETIME,
|
||||
"date_modified" DATETIME,
|
||||
"date_indexed" DATETIME,
|
||||
"device_pub_id" BLOB,
|
||||
CONSTRAINT "file_path_location_id_fkey" FOREIGN KEY ("location_id") REFERENCES "location" ("id") ON DELETE SET NULL ON UPDATE CASCADE,
|
||||
CONSTRAINT "file_path_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE SET NULL ON UPDATE CASCADE,
|
||||
CONSTRAINT "file_path_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_file_path" ("cas_id", "date_created", "date_indexed", "date_modified", "extension", "hidden", "id", "inode", "integrity_checksum", "is_dir", "key_id", "location_id", "materialized_path", "name", "object_id", "pub_id", "size_in_bytes", "size_in_bytes_bytes") SELECT "cas_id", "date_created", "date_indexed", "date_modified", "extension", "hidden", "id", "inode", "integrity_checksum", "is_dir", "key_id", "location_id", "materialized_path", "name", "object_id", "pub_id", "size_in_bytes", "size_in_bytes_bytes" FROM "file_path";
|
||||
DROP TABLE "file_path";
|
||||
ALTER TABLE "new_file_path" RENAME TO "file_path";
|
||||
CREATE UNIQUE INDEX "file_path_pub_id_key" ON "file_path"("pub_id");
|
||||
CREATE INDEX "file_path_location_id_idx" ON "file_path"("location_id");
|
||||
CREATE INDEX "file_path_location_id_materialized_path_idx" ON "file_path"("location_id", "materialized_path");
|
||||
CREATE UNIQUE INDEX "file_path_location_id_materialized_path_name_extension_key" ON "file_path"("location_id", "materialized_path", "name", "extension");
|
||||
CREATE UNIQUE INDEX "file_path_location_id_inode_key" ON "file_path"("location_id", "inode");
|
||||
CREATE TABLE "new_label_on_object" (
|
||||
"date_created" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"object_id" INTEGER NOT NULL,
|
||||
"label_id" INTEGER NOT NULL,
|
||||
"device_pub_id" BLOB,
|
||||
|
||||
PRIMARY KEY ("label_id", "object_id"),
|
||||
CONSTRAINT "label_on_object_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
CONSTRAINT "label_on_object_label_id_fkey" FOREIGN KEY ("label_id") REFERENCES "label" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
CONSTRAINT "label_on_object_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_label_on_object" ("date_created", "label_id", "object_id") SELECT "date_created", "label_id", "object_id" FROM "label_on_object";
|
||||
DROP TABLE "label_on_object";
|
||||
ALTER TABLE "new_label_on_object" RENAME TO "label_on_object";
|
||||
CREATE TABLE "new_location" (
|
||||
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"pub_id" BLOB NOT NULL,
|
||||
"name" TEXT,
|
||||
"path" TEXT,
|
||||
"total_capacity" INTEGER,
|
||||
"available_capacity" INTEGER,
|
||||
"size_in_bytes" BLOB,
|
||||
"is_archived" BOOLEAN,
|
||||
"generate_preview_media" BOOLEAN,
|
||||
"sync_preview_media" BOOLEAN,
|
||||
"hidden" BOOLEAN,
|
||||
"date_created" DATETIME,
|
||||
"scan_state" INTEGER NOT NULL DEFAULT 0,
|
||||
"device_pub_id" BLOB,
|
||||
"instance_id" INTEGER,
|
||||
CONSTRAINT "location_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
CONSTRAINT "location_instance_id_fkey" FOREIGN KEY ("instance_id") REFERENCES "instance" ("id") ON DELETE SET NULL ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_location" ("available_capacity", "date_created", "generate_preview_media", "hidden", "id", "instance_id", "is_archived", "name", "path", "pub_id", "scan_state", "size_in_bytes", "sync_preview_media", "total_capacity") SELECT "available_capacity", "date_created", "generate_preview_media", "hidden", "id", "instance_id", "is_archived", "name", "path", "pub_id", "scan_state", "size_in_bytes", "sync_preview_media", "total_capacity" FROM "location";
|
||||
DROP TABLE "location";
|
||||
ALTER TABLE "new_location" RENAME TO "location";
|
||||
CREATE UNIQUE INDEX "location_pub_id_key" ON "location"("pub_id");
|
||||
CREATE TABLE "new_object" (
|
||||
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"pub_id" BLOB NOT NULL,
|
||||
"kind" INTEGER,
|
||||
"key_id" INTEGER,
|
||||
"hidden" BOOLEAN,
|
||||
"favorite" BOOLEAN,
|
||||
"important" BOOLEAN,
|
||||
"note" TEXT,
|
||||
"date_created" DATETIME,
|
||||
"date_accessed" DATETIME,
|
||||
"device_pub_id" BLOB,
|
||||
CONSTRAINT "object_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_object" ("date_accessed", "date_created", "favorite", "hidden", "id", "important", "key_id", "kind", "note", "pub_id") SELECT "date_accessed", "date_created", "favorite", "hidden", "id", "important", "key_id", "kind", "note", "pub_id" FROM "object";
|
||||
DROP TABLE "object";
|
||||
ALTER TABLE "new_object" RENAME TO "object";
|
||||
CREATE UNIQUE INDEX "object_pub_id_key" ON "object"("pub_id");
|
||||
CREATE TABLE "new_storage_statistics" (
|
||||
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"pub_id" BLOB NOT NULL,
|
||||
"total_capacity" BIGINT NOT NULL DEFAULT 0,
|
||||
"available_capacity" BIGINT NOT NULL DEFAULT 0,
|
||||
"device_pub_id" BLOB,
|
||||
CONSTRAINT "storage_statistics_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_storage_statistics" ("available_capacity", "id", "pub_id", "total_capacity") SELECT "available_capacity", "id", "pub_id", "total_capacity" FROM "storage_statistics";
|
||||
DROP TABLE "storage_statistics";
|
||||
ALTER TABLE "new_storage_statistics" RENAME TO "storage_statistics";
|
||||
CREATE UNIQUE INDEX "storage_statistics_pub_id_key" ON "storage_statistics"("pub_id");
|
||||
CREATE UNIQUE INDEX "storage_statistics_device_pub_id_key" ON "storage_statistics"("device_pub_id");
|
||||
CREATE TABLE "new_tag_on_object" (
|
||||
"object_id" INTEGER NOT NULL,
|
||||
"tag_id" INTEGER NOT NULL,
|
||||
"date_created" DATETIME,
|
||||
"device_pub_id" BLOB,
|
||||
|
||||
PRIMARY KEY ("tag_id", "object_id"),
|
||||
CONSTRAINT "tag_on_object_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
CONSTRAINT "tag_on_object_tag_id_fkey" FOREIGN KEY ("tag_id") REFERENCES "tag" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
CONSTRAINT "tag_on_object_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_tag_on_object" ("date_created", "object_id", "tag_id") SELECT "date_created", "object_id", "tag_id" FROM "tag_on_object";
|
||||
DROP TABLE "tag_on_object";
|
||||
ALTER TABLE "new_tag_on_object" RENAME TO "tag_on_object";
|
||||
PRAGMA foreign_keys=ON;
|
||||
PRAGMA defer_foreign_keys=OFF;
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "device_pub_id_key" ON "device"("pub_id");
|
||||
@@ -28,9 +28,10 @@ model CRDTOperation {
|
||||
kind String
|
||||
data Bytes
|
||||
|
||||
instance_id Int
|
||||
instance Instance @relation(fields: [instance_id], references: [id])
|
||||
// We just need the actual device_pub_id here, but we don't need as an actual relation
|
||||
device_pub_id Bytes
|
||||
|
||||
@@index([timestamp])
|
||||
@@map("crdt_operation")
|
||||
}
|
||||
|
||||
@@ -46,24 +47,41 @@ model CloudCRDTOperation {
|
||||
kind String
|
||||
data Bytes
|
||||
|
||||
instance_id Int
|
||||
instance Instance @relation(fields: [instance_id], references: [id])
|
||||
// We just need the actual device_pub_id here, but we don't need as an actual relation
|
||||
device_pub_id Bytes
|
||||
|
||||
@@index([timestamp])
|
||||
@@map("cloud_crdt_operation")
|
||||
}
|
||||
|
||||
/// @deprecated: This model has to exist solely for backwards compatibility.
|
||||
/// @local
|
||||
model Node {
|
||||
id Int @id @default(autoincrement())
|
||||
pub_id Bytes @unique
|
||||
name String
|
||||
// Enum: sd_core::node::Platform
|
||||
platform Int
|
||||
date_created DateTime
|
||||
identity Bytes? // TODO: Change to required field in future
|
||||
/// Devices are the owner machines connected to this library
|
||||
/// @shared(id: pub_id, modelId: 12)
|
||||
model Device {
|
||||
id Int @id @default(autoincrement())
|
||||
// uuid v7
|
||||
pub_id Bytes @unique
|
||||
name String? // Not actually NULLABLE, but we have to comply with current sync implementation BS
|
||||
|
||||
@@map("node")
|
||||
// Enum: sd_cloud_schema::device::DeviceOS
|
||||
os Int? // Not actually NULLABLE, but we have to comply with current sync implementation BS
|
||||
// Enum: sd_cloud_schema::device::HardwareModel
|
||||
hardware_model Int? // Not actually NULLABLE, but we have to comply with current sync implementation BS
|
||||
|
||||
// clock timestamp for sync
|
||||
timestamp BigInt?
|
||||
|
||||
date_created DateTime? // Not actually NULLABLE, but we have to comply with current sync implementation BS
|
||||
date_deleted DateTime?
|
||||
|
||||
StorageStatistics StorageStatistics?
|
||||
Location Location[]
|
||||
FilePath FilePath[]
|
||||
Object Object[]
|
||||
ExifData ExifData[]
|
||||
TagOnObject TagOnObject[]
|
||||
LabelOnObject LabelOnObject[]
|
||||
|
||||
@@map("device")
|
||||
}
|
||||
|
||||
// represents a single `.db` file (SQLite DB) that is paired to the current library.
|
||||
@@ -88,11 +106,7 @@ model Instance {
|
||||
|
||||
// clock timestamp for sync
|
||||
timestamp BigInt?
|
||||
|
||||
locations Location[]
|
||||
CRDTOperation CRDTOperation[]
|
||||
CloudCRDTOperation CloudCRDTOperation[]
|
||||
storage_statistics StorageStatistics?
|
||||
Location Location[]
|
||||
|
||||
@@map("instance")
|
||||
}
|
||||
@@ -158,6 +172,9 @@ model Location {
|
||||
|
||||
scan_state Int @default(0) // Enum: sd_core::location::ScanState
|
||||
|
||||
device_id Int?
|
||||
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
|
||||
|
||||
// this should just be a local-only cache but it's too much effort to broadcast online locations rn (@brendan)
|
||||
instance_id Int?
|
||||
instance Instance? @relation(fields: [instance_id], references: [id], onDelete: SetNull)
|
||||
@@ -208,6 +225,9 @@ model FilePath {
|
||||
date_modified DateTime?
|
||||
date_indexed DateTime?
|
||||
|
||||
device_id Int?
|
||||
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
|
||||
|
||||
// key Key? @relation(fields: [key_id], references: [id])
|
||||
|
||||
@@unique([location_id, materialized_path, name, extension])
|
||||
@@ -255,6 +275,9 @@ model Object {
|
||||
exif_data ExifData?
|
||||
ffmpeg_data FfmpegData?
|
||||
|
||||
device_id Int?
|
||||
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
|
||||
|
||||
// key Key? @relation(fields: [key_id], references: [id])
|
||||
|
||||
@@map("object")
|
||||
@@ -323,6 +346,9 @@ model ExifData {
|
||||
object_id Int @unique
|
||||
object Object @relation(fields: [object_id], references: [id], onDelete: Cascade)
|
||||
|
||||
device_id Int?
|
||||
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
|
||||
|
||||
@@map("exif_data")
|
||||
}
|
||||
|
||||
@@ -509,6 +535,9 @@ model TagOnObject {
|
||||
|
||||
date_created DateTime?
|
||||
|
||||
device_id Int?
|
||||
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
|
||||
|
||||
@@id([tag_id, object_id])
|
||||
@@map("tag_on_object")
|
||||
}
|
||||
@@ -537,6 +566,9 @@ model LabelOnObject {
|
||||
label_id Int
|
||||
label Label @relation(fields: [label_id], references: [id], onDelete: Restrict)
|
||||
|
||||
device_id Int?
|
||||
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
|
||||
|
||||
@@id([label_id, object_id])
|
||||
@@map("label_on_object")
|
||||
}
|
||||
@@ -576,8 +608,8 @@ model StorageStatistics {
|
||||
total_capacity BigInt @default(0)
|
||||
available_capacity BigInt @default(0)
|
||||
|
||||
instance_pub_id Bytes? @unique
|
||||
instance Instance? @relation(fields: [instance_pub_id], references: [pub_id], onDelete: Cascade)
|
||||
device_id Int? @unique
|
||||
device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade)
|
||||
|
||||
@@map("storage_statistics")
|
||||
}
|
||||
|
||||
@@ -1,153 +0,0 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use reqwest::StatusCode;
|
||||
use rspc::alpha::AlphaRouter;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use specta::Type;
|
||||
|
||||
use super::{Ctx, R};
|
||||
|
||||
pub(crate) fn mount() -> AlphaRouter<Ctx> {
|
||||
R.router()
|
||||
.procedure("loginSession", {
|
||||
#[derive(Serialize, Type)]
|
||||
#[specta(inline)]
|
||||
enum Response {
|
||||
Start {
|
||||
user_code: String,
|
||||
verification_url: String,
|
||||
verification_url_complete: String,
|
||||
},
|
||||
Complete,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
R.subscription(|node, _: ()| async move {
|
||||
#[derive(Deserialize, Type)]
|
||||
struct DeviceAuthorizationResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_url: String,
|
||||
verification_uri_complete: String,
|
||||
}
|
||||
|
||||
async_stream::stream! {
|
||||
let device_type = if cfg!(target_arch = "wasm32") {
|
||||
"web".to_string()
|
||||
} else if cfg!(target_os = "ios") || cfg!(target_os = "android") {
|
||||
"mobile".to_string()
|
||||
} else {
|
||||
"desktop".to_string()
|
||||
};
|
||||
|
||||
let auth_response = match match node
|
||||
.http
|
||||
.post(format!(
|
||||
"{}/login/device/code",
|
||||
&node.env.api_url.lock().await
|
||||
))
|
||||
.form(&[("client_id", &node.env.client_id), ("device", &device_type)])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
{
|
||||
Ok(r) => r.json::<DeviceAuthorizationResponse>().await.map_err(|e| e.to_string()),
|
||||
Err(e) => {
|
||||
yield Response::Error(e.to_string());
|
||||
return
|
||||
},
|
||||
} {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
yield Response::Error(e.to_string());
|
||||
return
|
||||
},
|
||||
};
|
||||
|
||||
yield Response::Start {
|
||||
user_code: auth_response.user_code.clone(),
|
||||
verification_url: auth_response.verification_url.clone(),
|
||||
verification_url_complete: auth_response.verification_uri_complete.clone(),
|
||||
};
|
||||
|
||||
yield loop {
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
|
||||
let token_resp = match node.http
|
||||
.post(format!("{}/login/oauth/access_token", &node.env.api_url.lock().await))
|
||||
.form(&[
|
||||
("grant_type", sd_cloud_api::auth::DEVICE_CODE_URN),
|
||||
("device_code", &auth_response.device_code),
|
||||
("client_id", &node.env.client_id)
|
||||
])
|
||||
.send()
|
||||
.await {
|
||||
Ok(v) => v,
|
||||
Err(e) => break Response::Error(e.to_string())
|
||||
};
|
||||
|
||||
match token_resp.status() {
|
||||
StatusCode::OK => {
|
||||
let token = match token_resp.json().await {
|
||||
Ok(v) => v,
|
||||
Err(e) => break Response::Error(e.to_string())
|
||||
};
|
||||
|
||||
if let Err(e) = node.config
|
||||
.write(|c| c.auth_token = Some(token))
|
||||
.await {
|
||||
break Response::Error(e.to_string());
|
||||
};
|
||||
|
||||
|
||||
break Response::Complete;
|
||||
},
|
||||
StatusCode::BAD_REQUEST => {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OAuth400 {
|
||||
error: String
|
||||
}
|
||||
|
||||
let resp = match token_resp.json::<OAuth400>().await {
|
||||
Ok(v) => v,
|
||||
Err(e) => break Response::Error(e.to_string())
|
||||
};
|
||||
|
||||
match resp.error.as_str() {
|
||||
"authorization_pending" => continue,
|
||||
e => {
|
||||
break Response::Error(e.to_string())
|
||||
}
|
||||
}
|
||||
},
|
||||
s => {
|
||||
break Response::Error(s.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
.procedure(
|
||||
"logout",
|
||||
R.mutation(|node, _: ()| async move {
|
||||
node.config
|
||||
.write(|c| c.auth_token = None)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|_| {
|
||||
rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
"Failed to write config".to_string(),
|
||||
)
|
||||
})
|
||||
}),
|
||||
)
|
||||
.procedure("me", {
|
||||
R.query(|node, _: ()| async move {
|
||||
let resp = sd_cloud_api::user::me(node.cloud_api_config().await).await?;
|
||||
|
||||
Ok(resp)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -381,6 +381,7 @@ async fn restore_backup(node: &Arc<Node>, path: impl AsRef<Path>) -> Result<Head
|
||||
db_restored_path,
|
||||
library_config_restored_path,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
node,
|
||||
)
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
use crate::{api::libraries::LibraryConfigWrapped, invalidate_query, library::LibraryName};
|
||||
|
||||
use reqwest::Response;
|
||||
use rspc::alpha::AlphaRouter;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{utils::library, Ctx, R};
|
||||
|
||||
#[allow(unused)]
|
||||
async fn parse_json_body<T: DeserializeOwned>(response: Response) -> Result<T, rspc::Error> {
|
||||
response.json().await.map_err(|_| {
|
||||
rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
"JSON conversion failed".to_string(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn mount() -> AlphaRouter<Ctx> {
|
||||
R.router()
|
||||
.merge("library.", library::mount())
|
||||
.merge("locations.", locations::mount())
|
||||
.procedure("getApiOrigin", {
|
||||
R.query(|node, _: ()| async move { Ok(node.env.api_url.lock().await.to_string()) })
|
||||
})
|
||||
.procedure("setApiOrigin", {
|
||||
R.mutation(|node, origin: String| async move {
|
||||
let mut origin_env = node.env.api_url.lock().await;
|
||||
origin_env.clone_from(&origin);
|
||||
|
||||
node.config
|
||||
.write(|c| {
|
||||
c.auth_token = None;
|
||||
c.sd_api_origin = Some(origin);
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
mod library {
|
||||
use std::str::FromStr;
|
||||
|
||||
use sd_p2p::RemoteIdentity;
|
||||
|
||||
use crate::util::MaybeUndefined;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub fn mount() -> AlphaRouter<Ctx> {
|
||||
R.router()
|
||||
.procedure("get", {
|
||||
R.with2(library())
|
||||
.query(|(node, library), _: ()| async move {
|
||||
Ok(
|
||||
sd_cloud_api::library::get(node.cloud_api_config().await, library.id)
|
||||
.await?,
|
||||
)
|
||||
})
|
||||
})
|
||||
.procedure("list", {
|
||||
R.query(|node, _: ()| async move {
|
||||
Ok(sd_cloud_api::library::list(node.cloud_api_config().await).await?)
|
||||
})
|
||||
})
|
||||
.procedure("create", {
|
||||
R.with2(library())
|
||||
.mutation(|(node, library), _: ()| async move {
|
||||
let node_config = node.config.get().await;
|
||||
let cloud_library = sd_cloud_api::library::create(
|
||||
node.cloud_api_config().await,
|
||||
library.id,
|
||||
&library.config().await.name,
|
||||
library.instance_uuid,
|
||||
library.identity.to_remote_identity(),
|
||||
node_config.id,
|
||||
node_config.identity.to_remote_identity(),
|
||||
&node.p2p.peer_metadata(),
|
||||
)
|
||||
.await?;
|
||||
node.libraries
|
||||
.edit(
|
||||
library.id,
|
||||
None,
|
||||
MaybeUndefined::Undefined,
|
||||
MaybeUndefined::Value(cloud_library.id),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
invalidate_query!(library, "cloud.library.get");
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.procedure("join", {
|
||||
R.mutation(|node, library_id: Uuid| async move {
|
||||
let Some(cloud_library) =
|
||||
sd_cloud_api::library::get(node.cloud_api_config().await, library_id)
|
||||
.await?
|
||||
else {
|
||||
return Err(rspc::Error::new(
|
||||
rspc::ErrorCode::NotFound,
|
||||
"Library not found".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
let library = node
|
||||
.libraries
|
||||
.create_with_uuid(
|
||||
library_id,
|
||||
LibraryName::new(cloud_library.name).map_err(|e| {
|
||||
rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
e.to_string(),
|
||||
)
|
||||
})?,
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
&node,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
node.libraries
|
||||
.edit(
|
||||
library.id,
|
||||
None,
|
||||
MaybeUndefined::Undefined,
|
||||
MaybeUndefined::Value(cloud_library.id),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let node_config = node.config.get().await;
|
||||
let instances = sd_cloud_api::library::join(
|
||||
node.cloud_api_config().await,
|
||||
library_id,
|
||||
library.instance_uuid,
|
||||
library.identity.to_remote_identity(),
|
||||
node_config.id,
|
||||
node_config.identity.to_remote_identity(),
|
||||
node.p2p.peer_metadata(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
for instance in instances {
|
||||
crate::cloud::sync::receive::upsert_instance(
|
||||
library.id,
|
||||
&library.db,
|
||||
&library.sync,
|
||||
&node.libraries,
|
||||
&instance.uuid,
|
||||
instance.identity,
|
||||
&instance.node_id,
|
||||
RemoteIdentity::from_str(&instance.node_remote_identity)
|
||||
.expect("malformed remote identity in the DB"),
|
||||
instance.metadata,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
invalidate_query!(library, "cloud.library.get");
|
||||
invalidate_query!(library, "cloud.library.list");
|
||||
|
||||
Ok(LibraryConfigWrapped::from_library(&library).await)
|
||||
})
|
||||
})
|
||||
.procedure("sync", {
|
||||
R.with2(library())
|
||||
.mutation(|(_, library), _: ()| async move {
|
||||
library.do_cloud_sync();
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
mod locations {
|
||||
use super::*;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use specta::Type;
|
||||
#[derive(Type, Serialize, Deserialize)]
|
||||
pub struct CloudLocation {
|
||||
id: String,
|
||||
name: String,
|
||||
}
|
||||
|
||||
pub fn mount() -> AlphaRouter<Ctx> {
|
||||
R.router()
|
||||
.procedure("list", {
|
||||
R.query(|node, _: ()| async move {
|
||||
sd_cloud_api::locations::list(node.cloud_api_config().await)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
})
|
||||
})
|
||||
.procedure("create", {
|
||||
R.mutation(|node, name: String| async move {
|
||||
sd_cloud_api::locations::create(node.cloud_api_config().await, name)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
})
|
||||
})
|
||||
.procedure("remove", {
|
||||
R.mutation(|node, id: String| async move {
|
||||
sd_cloud_api::locations::create(node.cloud_api_config().await, id)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
397
core/src/api/cloud/devices.rs
Normal file
397
core/src/api/cloud/devices.rs
Normal file
@@ -0,0 +1,397 @@
|
||||
use crate::api::{Ctx, R};
|
||||
|
||||
use sd_core_cloud_services::QuinnConnection;
|
||||
|
||||
use sd_cloud_schema::{
|
||||
auth::AccessToken,
|
||||
devices::{self, DeviceOS, HardwareModel, PubId},
|
||||
opaque_ke::{
|
||||
ClientLogin, ClientLoginFinishParameters, ClientLoginFinishResult, ClientLoginStartResult,
|
||||
ClientRegistration, ClientRegistrationFinishParameters, ClientRegistrationFinishResult,
|
||||
ClientRegistrationStartResult,
|
||||
},
|
||||
Client, NodeId, Service, SpacedriveCipherSuite,
|
||||
};
|
||||
use sd_crypto::{cloud::secret_key::SecretKey, CryptoRng};
|
||||
|
||||
use blake3::Hash;
|
||||
use futures::{FutureExt, SinkExt, StreamExt};
|
||||
use futures_concurrency::future::TryJoin;
|
||||
use rspc::alpha::AlphaRouter;
|
||||
use serde::Deserialize;
|
||||
use tracing::{debug, error};
|
||||
|
||||
pub fn mount() -> AlphaRouter<Ctx> {
|
||||
R.router()
|
||||
.procedure("get", {
|
||||
R.query(|node, pub_id: devices::PubId| async move {
|
||||
use devices::get::{Request, Response};
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
let Response(device) = super::handle_comm_error(
|
||||
client
|
||||
.devices()
|
||||
.get(Request {
|
||||
pub_id,
|
||||
access_token,
|
||||
})
|
||||
.await,
|
||||
"Failed to get device;",
|
||||
)??;
|
||||
|
||||
debug!(?device, "Got device");
|
||||
|
||||
Ok(device)
|
||||
})
|
||||
})
|
||||
.procedure("list", {
|
||||
R.query(|node, _: ()| async move {
|
||||
use devices::list::{Request, Response};
|
||||
|
||||
let ((client, access_token), pub_id) = (
|
||||
super::get_client_and_access_token(&node),
|
||||
node.config.get().map(|config| Ok(config.id.into())),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
let Response(mut devices) = super::handle_comm_error(
|
||||
client.devices().list(Request { access_token }).await,
|
||||
"Failed to list devices;",
|
||||
)??;
|
||||
|
||||
// Filter out the local device by matching pub_id
|
||||
devices.retain(|device| device.pub_id != pub_id);
|
||||
|
||||
debug!(?devices, "Listed devices");
|
||||
|
||||
Ok(devices)
|
||||
})
|
||||
})
|
||||
.procedure("get_current_device", {
|
||||
R.query(|node, _: ()| async move {
|
||||
use devices::get::{Request, Response};
|
||||
|
||||
let ((client, access_token), pub_id) = (
|
||||
super::get_client_and_access_token(&node),
|
||||
node.config.get().map(|config| Ok(config.id.into())),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
let Response(device) = super::handle_comm_error(
|
||||
client
|
||||
.devices()
|
||||
.get(Request {
|
||||
pub_id,
|
||||
access_token,
|
||||
})
|
||||
.await,
|
||||
"Failed to get current device;",
|
||||
)??;
|
||||
Ok(device)
|
||||
})
|
||||
})
|
||||
.procedure("delete", {
|
||||
R.mutation(|node, pub_id: devices::PubId| async move {
|
||||
use devices::delete::Request;
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.devices()
|
||||
.delete(Request {
|
||||
pub_id,
|
||||
access_token,
|
||||
})
|
||||
.await,
|
||||
"Failed to delete device;",
|
||||
)??;
|
||||
|
||||
debug!("Deleted device");
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.procedure("update", {
|
||||
#[derive(Deserialize, specta::Type)]
|
||||
struct CloudUpdateDeviceArgs {
|
||||
pub_id: devices::PubId,
|
||||
name: String,
|
||||
}
|
||||
|
||||
R.mutation(
|
||||
|node, CloudUpdateDeviceArgs { pub_id, name }: CloudUpdateDeviceArgs| async move {
|
||||
use devices::update::Request;
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.devices()
|
||||
.update(Request {
|
||||
access_token,
|
||||
pub_id,
|
||||
name,
|
||||
})
|
||||
.await,
|
||||
"Failed to update device;",
|
||||
)??;
|
||||
|
||||
debug!("Updated device");
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn hello(
|
||||
client: &Client<QuinnConnection<Service>, Service>,
|
||||
access_token: AccessToken,
|
||||
device_pub_id: PubId,
|
||||
hashed_pub_id: Hash,
|
||||
rng: &mut CryptoRng,
|
||||
) -> Result<SecretKey, rspc::Error> {
|
||||
use devices::hello::{Request, RequestUpdate, Response, State};
|
||||
|
||||
let ClientLoginStartResult { message, state } =
|
||||
ClientLogin::<SpacedriveCipherSuite>::start(rng, hashed_pub_id.as_bytes().as_slice())
|
||||
.map_err(|e| {
|
||||
error!(?e, "OPAQUE error initializing device hello request;");
|
||||
rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
"Failed to initialize device login".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (mut hello_continuation, mut res_stream) = super::handle_comm_error(
|
||||
client
|
||||
.devices()
|
||||
.hello(Request {
|
||||
access_token,
|
||||
pub_id: device_pub_id,
|
||||
opaque_login_message: Box::new(message),
|
||||
})
|
||||
.await,
|
||||
"Failed to send device hello request;",
|
||||
)?;
|
||||
|
||||
let Some(res) = res_stream.next().await else {
|
||||
let message = "Server did not send a device hello response;";
|
||||
error!("{message}");
|
||||
return Err(rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
message.to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
let credential_response = match super::handle_comm_error(
|
||||
res,
|
||||
"Communication error on device hello response;",
|
||||
)? {
|
||||
Ok(Response(State::LoginResponse(credential_response))) => credential_response,
|
||||
|
||||
Ok(Response(State::End)) => {
|
||||
unreachable!("Device hello response MUST not be End here, this is a serious bug and should crash;");
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
error!(?e, "Device hello response error;");
|
||||
return Err(e.into());
|
||||
}
|
||||
};
|
||||
|
||||
let ClientLoginFinishResult {
|
||||
message,
|
||||
export_key,
|
||||
..
|
||||
} = state
|
||||
.finish(
|
||||
hashed_pub_id.as_bytes().as_slice(),
|
||||
*credential_response,
|
||||
ClientLoginFinishParameters::default(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
error!(?e, "Device hello finish error;");
|
||||
rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
"Failed to finish device login".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
hello_continuation
|
||||
.send(RequestUpdate {
|
||||
opaque_login_finish: Box::new(message),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(?e, "Failed to send device hello request continuation;");
|
||||
rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
"Failed to finish device login procedure;".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let Some(res) = res_stream.next().await else {
|
||||
let message = "Server did not send a device hello END response;";
|
||||
error!("{message}");
|
||||
return Err(rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
message.to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
match super::handle_comm_error(res, "Communication error on device hello response;")? {
|
||||
Ok(Response(State::LoginResponse(_))) => {
|
||||
unreachable!("Device hello final response MUST be End here, this is a serious bug and should crash;");
|
||||
}
|
||||
|
||||
Ok(Response(State::End)) => {
|
||||
// Protocol completed successfully
|
||||
Ok(SecretKey::from(export_key))
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
error!(?e, "Device hello final response error;");
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DeviceRegisterData {
|
||||
pub pub_id: PubId,
|
||||
pub name: String,
|
||||
pub os: DeviceOS,
|
||||
pub hardware_model: HardwareModel,
|
||||
pub connection_id: NodeId,
|
||||
}
|
||||
|
||||
pub async fn register(
|
||||
client: &Client<QuinnConnection<Service>, Service>,
|
||||
access_token: AccessToken,
|
||||
DeviceRegisterData {
|
||||
pub_id,
|
||||
name,
|
||||
os,
|
||||
hardware_model,
|
||||
connection_id,
|
||||
}: DeviceRegisterData,
|
||||
hashed_pub_id: Hash,
|
||||
rng: &mut CryptoRng,
|
||||
) -> Result<SecretKey, rspc::Error> {
|
||||
use devices::register::{Request, RequestUpdate, Response, State};
|
||||
|
||||
let ClientRegistrationStartResult { message, state } =
|
||||
ClientRegistration::<SpacedriveCipherSuite>::start(
|
||||
rng,
|
||||
hashed_pub_id.as_bytes().as_slice(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
error!(?e, "OPAQUE error initializing device register request;");
|
||||
rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
"Failed to initialize device register".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (mut register_continuation, mut res_stream) = super::handle_comm_error(
|
||||
client
|
||||
.devices()
|
||||
.register(Request {
|
||||
access_token,
|
||||
pub_id,
|
||||
name,
|
||||
os,
|
||||
hardware_model,
|
||||
connection_id,
|
||||
opaque_register_message: Box::new(message),
|
||||
})
|
||||
.await,
|
||||
"Failed to send device register request;",
|
||||
)?;
|
||||
|
||||
let Some(res) = res_stream.next().await else {
|
||||
let message = "Server did not send a device register response;";
|
||||
error!("{message}");
|
||||
return Err(rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
message.to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
let registration_response = match super::handle_comm_error(
|
||||
res,
|
||||
"Communication error on device register response;",
|
||||
)? {
|
||||
Ok(Response(State::RegistrationResponse(res))) => res,
|
||||
|
||||
Ok(Response(State::End)) => {
|
||||
unreachable!("Device hello response MUST not be End here, this is a serious bug and should crash;");
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
error!(?e, "Device hello response error;");
|
||||
return Err(e.into());
|
||||
}
|
||||
};
|
||||
|
||||
let ClientRegistrationFinishResult {
|
||||
message,
|
||||
export_key,
|
||||
..
|
||||
} = state
|
||||
.finish(
|
||||
rng,
|
||||
hashed_pub_id.as_bytes().as_slice(),
|
||||
*registration_response,
|
||||
ClientRegistrationFinishParameters::default(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
error!(?e, "Device register finish error;");
|
||||
rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
"Failed to finish device register".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
register_continuation
|
||||
.send(RequestUpdate {
|
||||
opaque_registration_finish: Box::new(message),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(?e, "Failed to send device register request continuation;");
|
||||
rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
"Failed to finish device register procedure;".into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let Some(res) = res_stream.next().await else {
|
||||
let message = "Server did not send a device register END response;";
|
||||
error!("{message}");
|
||||
return Err(rspc::Error::new(
|
||||
rspc::ErrorCode::InternalServerError,
|
||||
message.to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
match super::handle_comm_error(res, "Communication error on device register response;")? {
|
||||
Ok(Response(State::RegistrationResponse(_))) => {
|
||||
unreachable!("Device register final response MUST be End here, this is a serious bug and should crash;");
|
||||
}
|
||||
|
||||
Ok(Response(State::End)) => {
|
||||
// Protocol completed successfully
|
||||
Ok(SecretKey::from(export_key))
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
error!(?e, "Device register final response error;");
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
140
core/src/api/cloud/libraries.rs
Normal file
140
core/src/api/cloud/libraries.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
use crate::api::{utils::library, Ctx, R};
|
||||
|
||||
use sd_cloud_schema::libraries;
|
||||
|
||||
use futures::FutureExt;
|
||||
use futures_concurrency::future::TryJoin;
|
||||
use rspc::alpha::AlphaRouter;
|
||||
use serde::Deserialize;
|
||||
use tracing::debug;
|
||||
|
||||
pub fn mount() -> AlphaRouter<Ctx> {
|
||||
R.router()
|
||||
.procedure("get", {
|
||||
#[derive(Deserialize, specta::Type)]
|
||||
struct CloudGetLibraryArgs {
|
||||
pub_id: libraries::PubId,
|
||||
with_device: bool,
|
||||
}
|
||||
|
||||
R.query(
|
||||
|node,
|
||||
CloudGetLibraryArgs {
|
||||
pub_id,
|
||||
with_device,
|
||||
}: CloudGetLibraryArgs| async move {
|
||||
use libraries::get::{Request, Response};
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
let Response(library) = super::handle_comm_error(
|
||||
client
|
||||
.libraries()
|
||||
.get(Request {
|
||||
access_token,
|
||||
pub_id,
|
||||
with_device,
|
||||
})
|
||||
.await,
|
||||
"Failed to get library;",
|
||||
)??;
|
||||
|
||||
debug!(?library, "Got library");
|
||||
|
||||
Ok(library)
|
||||
},
|
||||
)
|
||||
})
|
||||
.procedure("list", {
|
||||
R.query(|node, with_device: bool| async move {
|
||||
use libraries::list::{Request, Response};
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
let Response(libraries) = super::handle_comm_error(
|
||||
client
|
||||
.libraries()
|
||||
.list(Request {
|
||||
access_token,
|
||||
with_device,
|
||||
})
|
||||
.await,
|
||||
"Failed to list libraries;",
|
||||
)??;
|
||||
|
||||
debug!(?libraries, "Listed libraries");
|
||||
|
||||
Ok(libraries)
|
||||
})
|
||||
})
|
||||
.procedure("create", {
|
||||
R.with2(library())
|
||||
.mutation(|(node, library), _: ()| async move {
|
||||
let ((client, access_token), name, device_pub_id) = (
|
||||
super::get_client_and_access_token(&node),
|
||||
library.config().map(|config| Ok(config.name.to_string())),
|
||||
node.config.get().map(|config| Ok(config.id.into())),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.libraries()
|
||||
.create(libraries::create::Request {
|
||||
name,
|
||||
access_token,
|
||||
pub_id: libraries::PubId(library.id),
|
||||
device_pub_id,
|
||||
})
|
||||
.await,
|
||||
"Failed to create library;",
|
||||
)??;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.procedure("delete", {
|
||||
R.with2(library())
|
||||
.mutation(|(node, library), _: ()| async move {
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.libraries()
|
||||
.delete(libraries::delete::Request {
|
||||
access_token,
|
||||
pub_id: libraries::PubId(library.id),
|
||||
})
|
||||
.await,
|
||||
"Failed to delete library;",
|
||||
)??;
|
||||
|
||||
debug!("Deleted library");
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.procedure("update", {
|
||||
R.with2(library())
|
||||
.mutation(|(node, library), name: String| async move {
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.libraries()
|
||||
.update(libraries::update::Request {
|
||||
access_token,
|
||||
pub_id: libraries::PubId(library.id),
|
||||
name,
|
||||
})
|
||||
.await,
|
||||
"Failed to update library;",
|
||||
)??;
|
||||
|
||||
debug!("Updated library");
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
}
|
||||
112
core/src/api/cloud/locations.rs
Normal file
112
core/src/api/cloud/locations.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use crate::api::{Ctx, R};
|
||||
|
||||
use sd_cloud_schema::{devices, libraries, locations};
|
||||
|
||||
use rspc::alpha::AlphaRouter;
|
||||
use serde::Deserialize;
|
||||
use tracing::debug;
|
||||
|
||||
pub fn mount() -> AlphaRouter<Ctx> {
|
||||
R.router()
|
||||
.procedure("list", {
|
||||
#[derive(Deserialize, specta::Type)]
|
||||
struct CloudListLocationsArgs {
|
||||
pub library_pub_id: libraries::PubId,
|
||||
pub with_library: bool,
|
||||
pub with_device: bool,
|
||||
}
|
||||
|
||||
R.query(
|
||||
|node,
|
||||
CloudListLocationsArgs {
|
||||
library_pub_id,
|
||||
with_library,
|
||||
with_device,
|
||||
}: CloudListLocationsArgs| async move {
|
||||
use locations::list::{Request, Response};
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
let Response(locations) = super::handle_comm_error(
|
||||
client
|
||||
.locations()
|
||||
.list(Request {
|
||||
access_token,
|
||||
library_pub_id,
|
||||
with_library,
|
||||
with_device,
|
||||
})
|
||||
.await,
|
||||
"Failed to list locations;",
|
||||
)??;
|
||||
|
||||
debug!(?locations, "Got locations");
|
||||
|
||||
Ok(locations)
|
||||
},
|
||||
)
|
||||
})
|
||||
.procedure("create", {
|
||||
#[derive(Deserialize, specta::Type)]
|
||||
struct CloudCreateLocationArgs {
|
||||
pub pub_id: locations::PubId,
|
||||
pub name: String,
|
||||
pub library_pub_id: libraries::PubId,
|
||||
pub device_pub_id: devices::PubId,
|
||||
}
|
||||
|
||||
R.mutation(
|
||||
|node,
|
||||
CloudCreateLocationArgs {
|
||||
pub_id,
|
||||
name,
|
||||
library_pub_id,
|
||||
device_pub_id,
|
||||
}: CloudCreateLocationArgs| async move {
|
||||
use locations::create::Request;
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.locations()
|
||||
.create(Request {
|
||||
access_token,
|
||||
pub_id,
|
||||
name,
|
||||
library_pub_id,
|
||||
device_pub_id,
|
||||
})
|
||||
.await,
|
||||
"Failed to list locations;",
|
||||
)??;
|
||||
|
||||
debug!("Created cloud location");
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
})
|
||||
.procedure("delete", {
|
||||
R.mutation(|node, pub_id: locations::PubId| async move {
|
||||
use locations::delete::Request;
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.locations()
|
||||
.delete(Request {
|
||||
access_token,
|
||||
pub_id,
|
||||
})
|
||||
.await,
|
||||
"Failed to list locations;",
|
||||
)??;
|
||||
|
||||
debug!("Created cloud location");
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
}
|
||||
316
core/src/api/cloud/mod.rs
Normal file
316
core/src/api/cloud/mod.rs
Normal file
@@ -0,0 +1,316 @@
|
||||
use crate::{
|
||||
library::LibraryManagerError,
|
||||
node::{config::NodeConfig, HardwareModel},
|
||||
Node,
|
||||
};
|
||||
|
||||
use sd_core_cloud_services::{CloudP2P, KeyManager, QuinnConnection, UserResponse};
|
||||
|
||||
use sd_cloud_schema::{
|
||||
auth,
|
||||
error::{ClientSideError, Error},
|
||||
sync::groups,
|
||||
users, Client, SecretKey as IrohSecretKey, Service,
|
||||
};
|
||||
use sd_crypto::{CryptoRng, SeedableRng};
|
||||
use sd_utils::error::report_error;
|
||||
|
||||
use std::pin::pin;
|
||||
|
||||
use async_stream::stream;
|
||||
use futures::{FutureExt, StreamExt};
|
||||
use futures_concurrency::future::TryJoin;
|
||||
use rspc::alpha::AlphaRouter;
|
||||
use tracing::{debug, error, instrument};
|
||||
|
||||
use super::{Ctx, R};
|
||||
|
||||
mod devices;
|
||||
mod libraries;
|
||||
mod locations;
|
||||
mod sync_groups;
|
||||
|
||||
async fn try_get_cloud_services_client(
|
||||
node: &Node,
|
||||
) -> Result<Client<QuinnConnection<Service>, Service>, sd_core_cloud_services::Error> {
|
||||
node.cloud_services
|
||||
.client()
|
||||
.await
|
||||
.map_err(report_error("Failed to get cloud services client"))
|
||||
}
|
||||
|
||||
pub(crate) fn mount() -> AlphaRouter<Ctx> {
|
||||
R.router()
|
||||
.merge("libraries.", libraries::mount())
|
||||
.merge("locations.", locations::mount())
|
||||
.merge("devices.", devices::mount())
|
||||
.merge("syncGroups.", sync_groups::mount())
|
||||
.procedure("bootstrap", {
|
||||
R.mutation(
|
||||
|node, (access_token, refresh_token): (auth::AccessToken, auth::RefreshToken)| async move {
|
||||
use sd_cloud_schema::devices;
|
||||
|
||||
// Only allow a single bootstrap request in flight at a time
|
||||
let mut has_bootstrapped_lock = node
|
||||
.cloud_services
|
||||
.has_bootstrapped
|
||||
.try_lock()
|
||||
.map_err(|_| {
|
||||
rspc::Error::new(
|
||||
rspc::ErrorCode::Conflict,
|
||||
String::from("Bootstrap in progress"),
|
||||
)
|
||||
})?;
|
||||
|
||||
if *has_bootstrapped_lock {
|
||||
return Err(rspc::Error::new(
|
||||
rspc::ErrorCode::Conflict,
|
||||
String::from("Already bootstrapped"),
|
||||
));
|
||||
}
|
||||
|
||||
node.cloud_services
|
||||
.token_refresher
|
||||
.init(access_token, refresh_token)
|
||||
.await?;
|
||||
|
||||
let client = try_get_cloud_services_client(&node).await?;
|
||||
let data_directory = node.config.data_directory();
|
||||
|
||||
let mut rng =
|
||||
CryptoRng::from_seed(node.master_rng.lock().await.generate_fixed());
|
||||
|
||||
// create user route is idempotent, so we can safely keep creating the same user over and over
|
||||
handle_comm_error(
|
||||
client
|
||||
.users()
|
||||
.create(users::create::Request {
|
||||
access_token: node
|
||||
.cloud_services
|
||||
.token_refresher
|
||||
.get_access_token()
|
||||
.await?,
|
||||
})
|
||||
.await,
|
||||
"Failed to create user;",
|
||||
)??;
|
||||
|
||||
let (device_pub_id, name, os) = {
|
||||
let NodeConfig { id, name, os, .. } = node.config.get().await;
|
||||
(devices::PubId(id.into()), name, os)
|
||||
};
|
||||
|
||||
let hashed_pub_id = blake3::hash(device_pub_id.0.as_bytes().as_slice());
|
||||
|
||||
let key_manager = match handle_comm_error(
|
||||
client
|
||||
.devices()
|
||||
.get(devices::get::Request {
|
||||
access_token: node
|
||||
.cloud_services
|
||||
.token_refresher
|
||||
.get_access_token()
|
||||
.await?,
|
||||
pub_id: device_pub_id,
|
||||
})
|
||||
.await,
|
||||
"Failed to get device on cloud bootstrap;",
|
||||
)? {
|
||||
Ok(_) => {
|
||||
// Device registered, we execute a device hello flow
|
||||
let master_key = self::devices::hello(
|
||||
&client,
|
||||
node.cloud_services
|
||||
.token_refresher
|
||||
.get_access_token()
|
||||
.await?,
|
||||
device_pub_id,
|
||||
hashed_pub_id,
|
||||
&mut rng,
|
||||
)
|
||||
.await?;
|
||||
|
||||
debug!("Device hello successful");
|
||||
|
||||
KeyManager::load(master_key, data_directory).await?
|
||||
}
|
||||
Err(Error::Client(ClientSideError::NotFound(_))) => {
|
||||
// Device not registered, we execute a device register flow
|
||||
let iroh_secret_key = IrohSecretKey::generate_with_rng(&mut rng);
|
||||
let hardware_model = Into::into(
|
||||
HardwareModel::try_get().unwrap_or(HardwareModel::Other),
|
||||
);
|
||||
|
||||
let master_key = self::devices::register(
|
||||
&client,
|
||||
node.cloud_services
|
||||
.token_refresher
|
||||
.get_access_token()
|
||||
.await?,
|
||||
self::devices::DeviceRegisterData {
|
||||
pub_id: device_pub_id,
|
||||
name,
|
||||
os,
|
||||
hardware_model,
|
||||
connection_id: iroh_secret_key.public(),
|
||||
},
|
||||
hashed_pub_id,
|
||||
&mut rng,
|
||||
)
|
||||
.await?;
|
||||
|
||||
debug!("Device registered successfully");
|
||||
|
||||
KeyManager::new(master_key, iroh_secret_key, data_directory, &mut rng)
|
||||
.await?
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
let iroh_secret_key = key_manager.iroh_secret_key().await;
|
||||
|
||||
node.cloud_services.set_key_manager(key_manager).await;
|
||||
|
||||
node.cloud_services
|
||||
.set_cloud_p2p(
|
||||
CloudP2P::new(
|
||||
device_pub_id,
|
||||
&node.cloud_services,
|
||||
rng,
|
||||
iroh_secret_key,
|
||||
node.cloud_services.cloud_p2p_dns_origin_name.clone(),
|
||||
node.cloud_services.cloud_p2p_dns_pkarr_url.clone(),
|
||||
node.cloud_services.cloud_p2p_relay_url.clone(),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
.await;
|
||||
|
||||
let groups::list::Response(groups) = handle_comm_error(
|
||||
client
|
||||
.sync()
|
||||
.groups()
|
||||
.list(groups::list::Request {
|
||||
access_token: node
|
||||
.cloud_services
|
||||
.token_refresher
|
||||
.get_access_token()
|
||||
.await?,
|
||||
})
|
||||
.await,
|
||||
"Failed to list sync groups on bootstrap",
|
||||
)??;
|
||||
|
||||
groups
|
||||
.into_iter()
|
||||
.map(
|
||||
|groups::GroupBaseData {
|
||||
pub_id,
|
||||
library,
|
||||
// TODO(@fogodev): We can use this latest key hash to check if we
|
||||
// already have the latest key hash for this group locally
|
||||
// issuing a ask for key hash request for other devices if we don't
|
||||
latest_key_hash: _latest_key_hash,
|
||||
..
|
||||
}| {
|
||||
let node = &node;
|
||||
|
||||
async move {
|
||||
match initialize_cloud_sync(pub_id, library, node).await {
|
||||
// If we don't have this library locally, we didn't joined this group yet
|
||||
Ok(()) | Err(LibraryManagerError::LibraryNotFound) => {
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
.collect::<Vec<_>>()
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
*has_bootstrapped_lock = true;
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
})
|
||||
.procedure(
|
||||
"listenCloudServicesNotifications",
|
||||
R.subscription(|node, _: ()| async move {
|
||||
stream! {
|
||||
let mut notifications_stream =
|
||||
pin!(node.cloud_services.stream_user_notifications());
|
||||
|
||||
while let Some(notification) = notifications_stream.next().await {
|
||||
yield notification;
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
.procedure(
|
||||
"userResponse",
|
||||
R.mutation(|node, response: UserResponse| async move {
|
||||
node.cloud_services.send_user_response(response).await;
|
||||
|
||||
Ok(())
|
||||
}),
|
||||
)
|
||||
.procedure(
|
||||
"hasBootstrapped",
|
||||
R.query(|node, _: ()| async move {
|
||||
// If we can't lock immediately, it means that there is a bootstrap in progress
|
||||
// so we didn't bootstrapped yet
|
||||
Ok(node
|
||||
.cloud_services
|
||||
.has_bootstrapped
|
||||
.try_lock()
|
||||
.map(|lock| *lock)
|
||||
.unwrap_or(false))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn handle_comm_error<T, E: std::error::Error + std::fmt::Debug + Send + Sync + 'static>(
|
||||
res: Result<T, E>,
|
||||
message: &'static str,
|
||||
) -> Result<T, rspc::Error> {
|
||||
res.map_err(|e| {
|
||||
error!(?e, "Communication with cloud services error: {message}");
|
||||
rspc::Error::with_cause(rspc::ErrorCode::InternalServerError, message.into(), e)
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(%group_pub_id, %library_pub_id), err)]
|
||||
async fn initialize_cloud_sync(
|
||||
group_pub_id: groups::PubId,
|
||||
sd_cloud_schema::libraries::Library {
|
||||
pub_id: sd_cloud_schema::libraries::PubId(library_pub_id),
|
||||
..
|
||||
}: sd_cloud_schema::libraries::Library,
|
||||
node: &Node,
|
||||
) -> Result<(), LibraryManagerError> {
|
||||
let library = node
|
||||
.libraries
|
||||
.get_library(&library_pub_id)
|
||||
.await
|
||||
.ok_or(LibraryManagerError::LibraryNotFound)?;
|
||||
|
||||
library.init_cloud_sync(node, group_pub_id).await
|
||||
}
|
||||
|
||||
async fn get_client_and_access_token(
|
||||
node: &Node,
|
||||
) -> Result<(Client<QuinnConnection<Service>, Service>, auth::AccessToken), rspc::Error> {
|
||||
(
|
||||
try_get_cloud_services_client(node),
|
||||
node.cloud_services
|
||||
.token_refresher
|
||||
.get_access_token()
|
||||
.map(|res| res.map_err(Into::into)),
|
||||
)
|
||||
.try_join()
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
404
core/src/api/cloud/sync_groups.rs
Normal file
404
core/src/api/cloud/sync_groups.rs
Normal file
@@ -0,0 +1,404 @@
|
||||
use crate::{
|
||||
api::{utils::library, Ctx, R},
|
||||
library::LibraryName,
|
||||
Node,
|
||||
};
|
||||
|
||||
use sd_core_cloud_services::JoinedLibraryCreateArgs;
|
||||
|
||||
use sd_cloud_schema::{
|
||||
cloud_p2p, devices, libraries,
|
||||
sync::{groups, KeyHash},
|
||||
};
|
||||
use sd_crypto::{cloud::secret_key::SecretKey, CryptoRng, SeedableRng};
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::FutureExt;
|
||||
use futures_concurrency::future::TryJoin;
|
||||
use rspc::alpha::AlphaRouter;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::{spawn, sync::oneshot};
|
||||
use tracing::{debug, error};
|
||||
|
||||
pub fn mount() -> AlphaRouter<Ctx> {
|
||||
R.router()
|
||||
.procedure("create", {
|
||||
R.with2(library())
|
||||
.mutation(|(node, library), _: ()| async move {
|
||||
use groups::create::{Request, Response};
|
||||
|
||||
let ((client, access_token), device_pub_id, mut rng, key_manager) = (
|
||||
super::get_client_and_access_token(&node),
|
||||
node.config.get().map(|config| Ok(config.id.into())),
|
||||
node.master_rng
|
||||
.lock()
|
||||
.map(|mut rng| Ok(CryptoRng::from_seed(rng.generate_fixed()))),
|
||||
node.cloud_services
|
||||
.key_manager()
|
||||
.map(|res| res.map_err(Into::into)),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
let new_key = SecretKey::generate(&mut rng);
|
||||
let key_hash = KeyHash(blake3::hash(new_key.as_ref()).to_hex().to_string());
|
||||
|
||||
let Response(group_pub_id) = super::handle_comm_error(
|
||||
client
|
||||
.sync()
|
||||
.groups()
|
||||
.create(Request {
|
||||
access_token: access_token.clone(),
|
||||
key_hash: key_hash.clone(),
|
||||
library_pub_id: libraries::PubId(library.id),
|
||||
device_pub_id,
|
||||
})
|
||||
.await,
|
||||
"Failed to create sync group;",
|
||||
)??;
|
||||
|
||||
if let Err(e) = key_manager
|
||||
.add_key_with_hash(group_pub_id, new_key, key_hash.clone(), &mut rng)
|
||||
.await
|
||||
{
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.sync()
|
||||
.groups()
|
||||
.delete(groups::delete::Request {
|
||||
access_token,
|
||||
pub_id: group_pub_id,
|
||||
})
|
||||
.await,
|
||||
"Failed to delete sync group after we failed to store secret key in key manager;",
|
||||
)??;
|
||||
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
library.init_cloud_sync(&node, group_pub_id).await?;
|
||||
|
||||
debug!(%group_pub_id, ?key_hash, "Created sync group");
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.procedure("delete", {
|
||||
R.mutation(|node, pub_id: groups::PubId| async move {
|
||||
use groups::delete::Request;
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.sync()
|
||||
.groups()
|
||||
.delete(Request {
|
||||
access_token,
|
||||
pub_id,
|
||||
})
|
||||
.await,
|
||||
"Failed to delete sync group;",
|
||||
)??;
|
||||
|
||||
debug!(%pub_id, "Deleted sync group");
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.procedure("get", {
|
||||
#[derive(Deserialize, specta::Type)]
|
||||
struct CloudGetSyncGroupArgs {
|
||||
pub pub_id: groups::PubId,
|
||||
pub kind: groups::get::RequestKind,
|
||||
}
|
||||
|
||||
// This is a compatibility layer because quic-rpc uses bincode for serialization
|
||||
// and bincode doesn't support serde's tagged enums, and we need them for serializing
|
||||
// to frontend
|
||||
#[derive(Debug, Serialize, specta::Type)]
|
||||
#[serde(tag = "kind", content = "data")]
|
||||
pub enum CloudSyncGroupGetResponseKind {
|
||||
WithDevices(groups::GroupWithDevices),
|
||||
FullData(groups::Group),
|
||||
}
|
||||
|
||||
impl From<groups::get::ResponseKind> for CloudSyncGroupGetResponseKind {
|
||||
fn from(kind: groups::get::ResponseKind) -> Self {
|
||||
match kind {
|
||||
groups::get::ResponseKind::WithDevices(data) => {
|
||||
CloudSyncGroupGetResponseKind::WithDevices(data)
|
||||
}
|
||||
|
||||
groups::get::ResponseKind::FullData(data) => {
|
||||
CloudSyncGroupGetResponseKind::FullData(data)
|
||||
}
|
||||
groups::get::ResponseKind::DevicesConnectionIds(_) => {
|
||||
unreachable!(
|
||||
"DevicesConnectionIds response is not expected, as we requested it"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
R.query(
|
||||
|node, CloudGetSyncGroupArgs { pub_id, kind }: CloudGetSyncGroupArgs| async move {
|
||||
use groups::get::{Request, Response};
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
if matches!(kind, groups::get::RequestKind::DevicesConnectionIds) {
|
||||
return Err(rspc::Error::new(
|
||||
rspc::ErrorCode::PreconditionFailed,
|
||||
"This request isn't allowed here".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let Response(response_kind) = super::handle_comm_error(
|
||||
client
|
||||
.sync()
|
||||
.groups()
|
||||
.get(Request {
|
||||
access_token,
|
||||
pub_id,
|
||||
kind,
|
||||
})
|
||||
.await,
|
||||
"Failed to get sync group;",
|
||||
)??;
|
||||
|
||||
debug!(?response_kind, "Got sync group");
|
||||
|
||||
Ok(CloudSyncGroupGetResponseKind::from(response_kind))
|
||||
},
|
||||
)
|
||||
})
|
||||
.procedure("leave", {
|
||||
R.query(|node, pub_id: groups::PubId| async move {
|
||||
let ((client, access_token), current_device_pub_id, mut rng, key_manager) = (
|
||||
super::get_client_and_access_token(&node),
|
||||
node.config.get().map(|config| Ok(config.id.into())),
|
||||
node.master_rng
|
||||
.lock()
|
||||
.map(|mut rng| Ok(CryptoRng::from_seed(rng.generate_fixed()))),
|
||||
node.cloud_services
|
||||
.key_manager()
|
||||
.map(|res| res.map_err(Into::into)),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.sync()
|
||||
.groups()
|
||||
.leave(groups::leave::Request {
|
||||
access_token,
|
||||
pub_id,
|
||||
current_device_pub_id,
|
||||
})
|
||||
.await,
|
||||
"Failed to leave sync group;",
|
||||
)??;
|
||||
|
||||
key_manager.remove_group(pub_id, &mut rng).await?;
|
||||
|
||||
debug!(%pub_id, "Left sync group");
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.procedure("list", {
|
||||
R.query(|node, _: ()| async move {
|
||||
use groups::list::{Request, Response};
|
||||
|
||||
let (client, access_token) = super::get_client_and_access_token(&node).await?;
|
||||
|
||||
let Response(groups) = super::handle_comm_error(
|
||||
client.sync().groups().list(Request { access_token }).await,
|
||||
"Failed to list groups;",
|
||||
)??;
|
||||
|
||||
debug!(?groups, "Listed sync groups");
|
||||
|
||||
Ok(groups)
|
||||
})
|
||||
})
|
||||
.procedure("remove_device", {
|
||||
#[derive(Deserialize, specta::Type)]
|
||||
struct CloudSyncGroupsRemoveDeviceArgs {
|
||||
group_pub_id: groups::PubId,
|
||||
to_remove_device_pub_id: devices::PubId,
|
||||
}
|
||||
R.query(
|
||||
|node,
|
||||
CloudSyncGroupsRemoveDeviceArgs {
|
||||
group_pub_id,
|
||||
to_remove_device_pub_id,
|
||||
}: CloudSyncGroupsRemoveDeviceArgs| async move {
|
||||
use groups::remove_device::Request;
|
||||
|
||||
let ((client, access_token), current_device_pub_id, mut rng, key_manager) = (
|
||||
super::get_client_and_access_token(&node),
|
||||
node.config.get().map(|config| Ok(config.id.into())),
|
||||
node.master_rng
|
||||
.lock()
|
||||
.map(|mut rng| Ok(CryptoRng::from_seed(rng.generate_fixed()))),
|
||||
node.cloud_services
|
||||
.key_manager()
|
||||
.map(|res| res.map_err(Into::into)),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
let new_key = SecretKey::generate(&mut rng);
|
||||
let new_key_hash = KeyHash(blake3::hash(new_key.as_ref()).to_hex().to_string());
|
||||
|
||||
key_manager
|
||||
.add_key_with_hash(group_pub_id, new_key, new_key_hash.clone(), &mut rng)
|
||||
.await?;
|
||||
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.sync()
|
||||
.groups()
|
||||
.remove_device(Request {
|
||||
access_token,
|
||||
group_pub_id,
|
||||
new_key_hash,
|
||||
current_device_pub_id,
|
||||
to_remove_device_pub_id,
|
||||
})
|
||||
.await,
|
||||
"Failed to remove device from sync group;",
|
||||
)??;
|
||||
|
||||
debug!(%to_remove_device_pub_id, %group_pub_id, "Removed device");
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
})
|
||||
.procedure("request_join", {
|
||||
#[derive(Deserialize, specta::Type)]
|
||||
struct SyncGroupsRequestJoinArgs {
|
||||
sync_group: groups::GroupWithDevices,
|
||||
asking_device: devices::Device,
|
||||
}
|
||||
|
||||
R.mutation(
|
||||
|node,
|
||||
SyncGroupsRequestJoinArgs {
|
||||
sync_group,
|
||||
asking_device,
|
||||
}: SyncGroupsRequestJoinArgs| async move {
|
||||
let ((client, access_token), current_device_pub_id, cloud_p2p) = (
|
||||
super::get_client_and_access_token(&node),
|
||||
node.config.get().map(|config| Ok(config.id.into())),
|
||||
node.cloud_services
|
||||
.cloud_p2p()
|
||||
.map(|res| res.map_err(Into::into)),
|
||||
)
|
||||
.try_join()
|
||||
.await?;
|
||||
|
||||
let group_pub_id = sync_group.pub_id;
|
||||
|
||||
debug!("My pub id: {:?}", current_device_pub_id);
|
||||
debug!("Asking device pub id: {:?}", asking_device.pub_id);
|
||||
if asking_device.pub_id != current_device_pub_id {
|
||||
return Err(rspc::Error::new(
|
||||
rspc::ErrorCode::BadRequest,
|
||||
String::from("Asking device must be the current device"),
|
||||
));
|
||||
}
|
||||
|
||||
let groups::request_join::Response(existing_devices) =
|
||||
super::handle_comm_error(
|
||||
client
|
||||
.sync()
|
||||
.groups()
|
||||
.request_join(groups::request_join::Request {
|
||||
access_token,
|
||||
group_pub_id,
|
||||
current_device_pub_id,
|
||||
})
|
||||
.await,
|
||||
"Failed to update library;",
|
||||
)??;
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
cloud_p2p
|
||||
.request_join_sync_group(
|
||||
existing_devices,
|
||||
cloud_p2p::authorize_new_device_in_sync_group::Request {
|
||||
sync_group,
|
||||
asking_device,
|
||||
},
|
||||
tx,
|
||||
)
|
||||
.await;
|
||||
|
||||
JoinedSyncGroupReceiver {
|
||||
node,
|
||||
group_pub_id,
|
||||
rx,
|
||||
}
|
||||
.dispatch();
|
||||
|
||||
debug!(%group_pub_id, "Requested to join sync group");
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
struct JoinedSyncGroupReceiver {
|
||||
node: Arc<Node>,
|
||||
group_pub_id: groups::PubId,
|
||||
rx: oneshot::Receiver<JoinedLibraryCreateArgs>,
|
||||
}
|
||||
|
||||
impl JoinedSyncGroupReceiver {
|
||||
fn dispatch(self) {
|
||||
spawn(async move {
|
||||
let Self {
|
||||
node,
|
||||
group_pub_id,
|
||||
rx,
|
||||
} = self;
|
||||
|
||||
if let Ok(JoinedLibraryCreateArgs {
|
||||
pub_id: libraries::PubId(pub_id),
|
||||
name,
|
||||
description,
|
||||
}) = rx.await
|
||||
{
|
||||
let Ok(name) =
|
||||
LibraryName::new(name).map_err(|e| error!(?e, "Invalid library name"))
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(library) = node
|
||||
.libraries
|
||||
.create_with_uuid(pub_id, name, description, true, None, &node)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(?e, "Failed to create library from sync group join response")
|
||||
})
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Err(e) = library.init_cloud_sync(&node, group_pub_id).await {
|
||||
error!(?e, "Failed to initialize cloud sync for library");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user