diff --git a/Cargo.lock b/Cargo.lock index a453654f5..5395ee416 100644 Binary files a/Cargo.lock and b/Cargo.lock differ diff --git a/Cargo.toml b/Cargo.toml index 2a215dfbf..9ea80477d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,9 @@ repository = "https://github.com/spacedriveapp/spacedrive" rust-version = "1.81" [workspace.dependencies] +# First party dependencies +sd-cloud-schema = { git = "https://github.com/spacedriveapp/cloud-services-schema", rev = "bbc69c5cb2" } + # Third party dependencies used by one or more of our crates async-channel = "2.3" async-stream = "0.3.6" @@ -26,23 +29,25 @@ async-trait = "0.1.83" axum = "0.7.7" axum-extra = "0.9.4" base64 = "0.22.1" -blake3 = "1.5" +blake3 = "1.5.4" +bytes = "1.7.1" # Update blocked by hyper chrono = "0.4.38" ed25519-dalek = "2.1" +flume = "0.11.0" futures = "0.3.31" futures-concurrency = "7.6" globset = "0.4.15" http = "1.1" hyper = "1.5" -image = "0.24.9" # Update blocked due to https://github.com/image-rs/image/issues/2230 +image = "0.25.4" itertools = "0.13.0" lending-stream = "1.0" -libc = "0.2" +libc = "0.2.159" mimalloc = "0.1.43" normpath = "1.3" pin-project-lite = "0.2.14" rand = "0.9.0-alpha.2" -regex = "1" +regex = "1.11" reqwest = { version = "0.12.8", default-features = false } rmp = "0.8.14" rmp-serde = "1.3" @@ -62,7 +67,8 @@ tracing-subscriber = "0.3.18" tracing-test = "0.2.5" uhlc = "0.8.0" # Must follow version used by specta uuid = "1.10" # Must follow version used by specta -webp = "0.2.6" # Update blocked by image +webp = "0.3.0" +zeroize = "1.8" [workspace.dependencies.rspc] git = "https://github.com/spacedriveapp/rspc.git" diff --git a/apps/desktop/package.json b/apps/desktop/package.json index 999a28bae..31b5cdee0 100644 --- a/apps/desktop/package.json +++ b/apps/desktop/package.json @@ -20,26 +20,28 @@ "@sd/ui": "workspace:*", "@t3-oss/env-core": "^0.7.1", "@tanstack/react-query": "^5.59", - "@tauri-apps/api": "=2.0.2", - "@tauri-apps/plugin-dialog": "2.0.0", + "@tauri-apps/api": "=2.0.3", + "@tauri-apps/plugin-dialog": "2.0.1", + "@tauri-apps/plugin-http": "2.0.1", "@tauri-apps/plugin-os": "2.0.0", - "@tauri-apps/plugin-shell": "2.0.0", + "@tauri-apps/plugin-shell": "2.0.1", "consistent-hash": "^1.2.2", "immer": "^10.0.3", "react": "^18.2.0", "react-dom": "^18.2.0", "react-router-dom": "=6.20.1", - "sonner": "^1.0.3" + "sonner": "^1.0.3", + "supertokens-web-js": "^0.13.0" }, "devDependencies": { "@sd/config": "workspace:*", "@sentry/vite-plugin": "^2.16.0", - "@tauri-apps/cli": "2.0.1", + "@tauri-apps/cli": "2.0.4", "@types/react": "^18.2.67", "@types/react-dom": "^18.2.22", "sass": "^1.72.0", "typescript": "^5.6.2", - "vite": "^5.2.0", - "vite-tsconfig-paths": "^4.3.2" + "vite": "^5.4.9", + "vite-tsconfig-paths": "^5.0.1" } } diff --git a/apps/desktop/src-tauri/Cargo.toml b/apps/desktop/src-tauri/Cargo.toml index 2d0572fb0..6963addb8 100644 --- a/apps/desktop/src-tauri/Cargo.toml +++ b/apps/desktop/src-tauri/Cargo.toml @@ -35,12 +35,15 @@ uuid = { workspace = true, features = ["serde"] } # Specific Desktop dependencies # WARNING: Do NOT enable default features, as that vendors dbus (see below) -opener = { version = "0.7.1", features = ["reveal"], default-features = false } -specta-typescript = "=0.0.7" -tauri-plugin-dialog = "=2.0.2" -tauri-plugin-os = "=2.0.1" -tauri-plugin-shell = "=2.0.2" -tauri-plugin-updater = "=2.0.2" +opener = { version = "0.7.1", features = ["reveal"], default-features = false } +specta-typescript = "=0.0.7" +tauri-plugin-clipboard-manager = "=2.0.1" +tauri-plugin-deep-link 
= "=2.0.1" +tauri-plugin-dialog = "=2.0.3" +tauri-plugin-http = "=2.0.3" +tauri-plugin-os = "=2.0.1" +tauri-plugin-shell = "=2.0.2" +tauri-plugin-updater = "=2.0.2" # memory allocator mimalloc = { workspace = true } diff --git a/apps/desktop/src-tauri/capabilities/default.json b/apps/desktop/src-tauri/capabilities/default.json index cc710d277..5b68a580b 100644 --- a/apps/desktop/src-tauri/capabilities/default.json +++ b/apps/desktop/src-tauri/capabilities/default.json @@ -17,6 +17,7 @@ "dialog:allow-open", "dialog:allow-save", "dialog:allow-confirm", + "deep-link:default", "os:allow-os-type", "core:window:allow-close", "core:window:allow-create", @@ -24,6 +25,32 @@ "core:window:allow-minimize", "core:window:allow-toggle-maximize", "core:window:allow-start-dragging", - "core:webview:allow-internal-toggle-devtools" + "core:webview:allow-internal-toggle-devtools", + { + "identifier": "http:default", + "allow": [ + { + "url": "http://ipc.localhost" + }, + { + "url": "http://asset.localhost" + }, + { + "url": "http://localhost:8001" + }, + { + "url": "http://tauri.localhost" + }, + { + "url": "http://localhost:9420" + }, + { + "url": "https://auth.spacedrive.com" + }, + { + "url": "https://plausible.io" + } + ] + } ] } diff --git a/apps/desktop/src-tauri/src/main.rs b/apps/desktop/src-tauri/src/main.rs index a1701893e..ffa95903b 100644 --- a/apps/desktop/src-tauri/src/main.rs +++ b/apps/desktop/src-tauri/src/main.rs @@ -14,13 +14,13 @@ use sd_core::{Node, NodeError}; use sd_fda::DiskAccess; use serde::{Deserialize, Serialize}; use specta_typescript::Typescript; -use tauri::Emitter; use tauri::{async_runtime::block_on, webview::PlatformWebview, AppHandle, Manager, WindowEvent}; +use tauri::{Emitter, Listener}; use tauri_plugins::{sd_error_plugin, sd_server_plugin}; use tauri_specta::{collect_events, Builder}; use tokio::task::block_in_place; use tokio::time::sleep; -use tracing::error; +use tracing::{debug, error}; mod file; mod menu; @@ -179,7 +179,11 @@ pub enum DragAndDropEvent { Cancelled, } -const CLIENT_ID: &str = "2abb241e-40b8-4517-a3e3-5594375c8fbb"; +#[derive(Debug, Clone, Serialize, Deserialize, specta::Type, tauri_specta::Event)] +#[serde(rename_all = "camelCase")] +pub struct DeepLinkEvent { + data: String, +} #[tokio::main] async fn main() -> tauri::Result<()> { @@ -221,9 +225,20 @@ async fn main() -> tauri::Result<()> { tauri::Builder::default() .invoke_handler(builder.invoke_handler()) + .plugin(tauri_plugin_deep_link::init()) .setup(move |app| { // We need a the app handle to determine the data directory now. // This means all the setup code has to be within `setup`, however it doesn't support async so we `block_on`. 
+ let handle = app.handle().clone(); + app.listen("deep-link://new-url", move |event| { + let deep_link_event = DeepLinkEvent { + data: event.payload().to_string(), + }; + println!("Deep link event={:?}", deep_link_event); + + handle.emit("deeplink", deep_link_event).unwrap(); + }); + block_in_place(|| { block_on(async move { builder.mount_events(app); @@ -239,10 +254,7 @@ async fn main() -> tauri::Result<()> { // The `_guard` must be assigned to variable for flushing remaining logs on main exit through Drop let (_guard, result) = match Node::init_logger(&data_dir) { - Ok(guard) => ( - Some(guard), - Node::new(data_dir, sd_core::Env::new(CLIENT_ID)).await, - ), + Ok(guard) => (Some(guard), Node::new(data_dir).await), Err(err) => (None, Err(NodeError::Logger(err))), }; @@ -256,7 +268,7 @@ async fn main() -> tauri::Result<()> { } }; - let should_clear_localstorage = node.libraries.get_all().await.is_empty(); + let should_clear_local_storage = node.libraries.get_all().await.is_empty(); handle.plugin(rspc::integrations::tauri::plugin(router, { let node = node.clone(); @@ -266,8 +278,8 @@ async fn main() -> tauri::Result<()> { handle.manage(node.clone()); handle.windows().iter().for_each(|(_, window)| { - if should_clear_localstorage { - println!("cleaning localStorage"); + if should_clear_local_storage { + debug!("cleaning localStorage"); for webview in window.webviews() { webview.eval("localStorage.clear();").ok(); } @@ -344,6 +356,7 @@ async fn main() -> tauri::Result<()> { .plugin(tauri_plugin_dialog::init()) .plugin(tauri_plugin_os::init()) .plugin(tauri_plugin_shell::init()) + .plugin(tauri_plugin_http::init()) // TODO: Bring back Tauri Plugin Window State - it was buggy so we removed it. .plugin(tauri_plugin_updater::Builder::new().build()) .plugin(updater::plugin()) diff --git a/apps/desktop/src-tauri/tauri.conf.json b/apps/desktop/src-tauri/tauri.conf.json index 85c60bc0c..1bf5fa3b9 100644 --- a/apps/desktop/src-tauri/tauri.conf.json +++ b/apps/desktop/src-tauri/tauri.conf.json @@ -1,5 +1,5 @@ { - "$schema": "https://raw.githubusercontent.com/tauri-apps/tauri/tauri-v2.0.0-rc.2/core/tauri-config-schema/schema.json", + "$schema": "https://raw.githubusercontent.com/tauri-apps/tauri/tauri-v2.0.0-rc.8/crates/tauri-cli/tauri.config.schema.json", "productName": "Spacedrive", "identifier": "com.spacedrive.desktop", "build": { @@ -36,7 +36,12 @@ } ], "security": { - "csp": "default-src webkit-pdfjs-viewer: asset: https://asset.localhost blob: data: filesystem: ws: wss: http: https: tauri: 'unsafe-eval' 'unsafe-inline' 'self' img-src: 'self'" + "csp": { + "default-src": "'self' webkit-pdfjs-viewer: asset: http://asset.localhost blob: data: filesystem: http: https: tauri:", + "connect-src": "'self' ipc: http://ipc.localhost ws: wss: http: https: tauri:", + "img-src": "'self' asset: http://asset.localhost blob: data: filesystem: http: https: tauri:", + "style-src": "'self' 'unsafe-inline' http: https: tauri:" + } } }, "bundle": { @@ -100,6 +105,12 @@ "endpoints": [ "https://spacedrive.com/api/releases/tauri/{{version}}/{{target}}/{{arch}}" ] + }, + "deep-link": { + "mobile": [], + "desktop": { + "schemes": ["spacedrive"] + } } } } diff --git a/apps/desktop/src/App.tsx b/apps/desktop/src/App.tsx index 141458057..097240280 100644 --- a/apps/desktop/src/App.tsx +++ b/apps/desktop/src/App.tsx @@ -3,9 +3,10 @@ import { QueryClientProvider } from '@tanstack/react-query'; import { listen } from '@tauri-apps/api/event'; import { PropsWithChildren, startTransition, useEffect, useMemo, useRef, useState } 
from 'react'; import { createPortal } from 'react-dom'; -import { RspcProvider } from '@sd/client'; +import { RspcProvider, useBridgeMutation } from '@sd/client'; import { createRoutes, + DeeplinkEvent, ErrorPage, KeybindEvent, PlatformProvider, @@ -17,14 +18,11 @@ import { RouteTitleContext } from '@sd/interface/hooks/useRouteTitle'; import '@sd/ui/style/style.scss'; -import { useLocale } from '@sd/interface/hooks'; - -import { commands } from './commands'; -import { platform } from './platform'; -import { queryClient } from './query'; -import { createMemoryRouterWithHistory } from './router'; -import { createUpdater } from './updater'; - +import SuperTokens from 'supertokens-web-js'; +import EmailPassword from 'supertokens-web-js/recipe/emailpassword'; +import Passwordless from 'supertokens-web-js/recipe/passwordless'; +import Session from 'supertokens-web-js/recipe/session'; +import ThirdParty from 'supertokens-web-js/recipe/thirdparty'; // TODO: Bring this back once upstream is fixed up. // const client = hooks.createClient({ // links: [ @@ -34,6 +32,32 @@ import { createUpdater } from './updater'; // tauriLink() // ] // }); +import getCookieHandler from '@sd/interface/app/$libraryId/settings/client/account/handlers/cookieHandler'; +import getWindowHandler from '@sd/interface/app/$libraryId/settings/client/account/handlers/windowHandler'; +import { useLocale } from '@sd/interface/hooks'; +import { AUTH_SERVER_URL, getTokens } from '@sd/interface/util'; + +import { commands } from './commands'; +import { platform } from './platform'; +import { queryClient } from './query'; +import { createMemoryRouterWithHistory } from './router'; +import { createUpdater } from './updater'; + +SuperTokens.init({ + appInfo: { + apiDomain: AUTH_SERVER_URL, + apiBasePath: '/api/auth', + appName: 'Spacedrive Auth Service' + }, + cookieHandler: getCookieHandler, + windowHandler: getWindowHandler, + recipeList: [ + Session.init({ tokenTransferMethod: 'header' }), + EmailPassword.init(), + ThirdParty.init(), + Passwordless.init() + ] +}); const startupError = (window as any).__SD_ERROR__ as string | undefined; @@ -41,15 +65,31 @@ export default function App() { useEffect(() => { // This tells Tauri to show the current window because it's finished loading commands.appReady(); + // .then(() => { + // if (import.meta.env.PROD) window.fetch = fetch; + // }); }, []); useEffect(() => { const keybindListener = listen('keybind', (input) => { document.dispatchEvent(new KeybindEvent(input.payload as string)); }); + const deeplinkListener = listen('deeplink', async (data) => { + const payload = (data.payload as any).data as string; + if (!payload) return; + const json = JSON.parse(payload)[0]; + if (!json) return; + //json output: "spacedrive://-/URL" + if (typeof json !== 'string') return; + if (!json.startsWith('spacedrive://-')) return; + const url = (json as string).split('://-/')[1]; + if (!url) return; + document.dispatchEvent(new DeeplinkEvent(url)); + }); return () => { keybindListener.then((unlisten) => unlisten()); + deeplinkListener.then((unlisten) => unlisten()); }; }, []); @@ -79,6 +119,15 @@ type RedirectPath = { pathname: string; search: string | undefined }; function AppInner() { const [tabs, setTabs] = useState(() => [createTab()]); const [selectedTabIndex, setSelectedTabIndex] = useState(0); + const tokens = getTokens(); + const cloudBootstrap = useBridgeMutation('cloud.bootstrap'); + + useEffect(() => { + // If the access token and/or refresh token are missing, we need to skip the cloud bootstrap + if 
(tokens.accessToken.length === 0 || tokens.refreshToken.length === 0) return; + cloudBootstrap.mutate([tokens.accessToken, tokens.refreshToken]); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); const selectedTab = tabs[selectedTabIndex]!; diff --git a/apps/desktop/tsconfig.json b/apps/desktop/tsconfig.json index d3855c6fb..11d32a210 100644 --- a/apps/desktop/tsconfig.json +++ b/apps/desktop/tsconfig.json @@ -5,7 +5,8 @@ "declarationDir": "dist", "paths": { "~/*": ["./src/*"] - } + }, + "moduleResolution": "bundler" }, "include": ["src"], "references": [ diff --git a/apps/landing/src/components/mdx/Pre.tsx b/apps/landing/src/components/mdx/Pre.tsx index 9400737dc..b90328a70 100644 --- a/apps/landing/src/components/mdx/Pre.tsx +++ b/apps/landing/src/components/mdx/Pre.tsx @@ -19,18 +19,6 @@ const Pre: FC<{ children: React.ReactNode }> = ({ children }) => { return (
- {/* */} diff --git a/apps/mobile/src/components/modal/sync/JoinRequestModal.tsx b/apps/mobile/src/components/modal/sync/JoinRequestModal.tsx new file mode 100644 index 000000000..c15f6074a --- /dev/null +++ b/apps/mobile/src/components/modal/sync/JoinRequestModal.tsx @@ -0,0 +1,66 @@ +import { ArrowRight } from 'phosphor-react-native'; +import React, { forwardRef } from 'react'; +import { Text, View } from 'react-native'; +import { HardwareModel } from '@sd/client'; +import { Icon } from '~/components/icons/Icon'; +import { Modal, ModalRef } from '~/components/layout/Modal'; +import { hardwareModelToIcon } from '~/components/overview/Devices'; +import { Button } from '~/components/primitive/Button'; +import useForwardedRef from '~/hooks/useForwardedRef'; +import { twStyle } from '~/lib/tailwind'; + +interface Props { + device_name: string; + device_model: HardwareModel; + library_name: string; +} + +const JoinRequestModal = forwardRef((props, ref) => { + const modalRef = useForwardedRef(ref); + return ( + + + + A device is requesting to join one of your libraries. Please review the device + and the library it is requesting to join below. + + + + + + {props.device_name} + + + + {/* library */} + + + + {props.library_name} + + + + + + + + + + ); +}); + +export default JoinRequestModal; diff --git a/apps/mobile/src/components/overview/Devices.tsx b/apps/mobile/src/components/overview/Devices.tsx index 1f65407b9..fdd69ddb3 100644 --- a/apps/mobile/src/components/overview/Devices.tsx +++ b/apps/mobile/src/components/overview/Devices.tsx @@ -5,8 +5,9 @@ import React, { useEffect, useState } from 'react'; import { Platform, Text, View } from 'react-native'; import DeviceInfo from 'react-native-device-info'; import { ScrollView } from 'react-native-gesture-handler'; -import { HardwareModel, NodeState, StatisticsResponse } from '@sd/client'; +import { HardwareModel, NodeState, StatisticsResponse, useBridgeQuery } from '@sd/client'; import { tw, twStyle } from '~/lib/tailwind'; +import { getTokens } from '~/utils'; import Fade from '../layout/Fade'; import { Button } from '../primitive/Button'; @@ -44,6 +45,23 @@ const Devices = ({ node, stats }: Props) => { Omit >({ freeSpace: 0, totalSpace: 0 }); const [deviceName, setDeviceName] = useState(''); + const [accessToken, setAccessToken] = useState(''); + useEffect(() => { + (async () => { + const at = await getTokens(); + setAccessToken(at.accessToken); + })(); + }, []); + + const devices = useBridgeQuery(['cloud.devices.list']); + + // Refetch devices every 10 seconds + useEffect(() => { + const interval = setInterval(async () => { + await devices.refetch(); + }, 10000); + return () => clearInterval(interval); + }, []); useEffect(() => { const getFSInfo = async () => { @@ -74,7 +92,7 @@ const Devices = ({ node, stats }: Props) => { }, [node]); return ( - + { connectionType={null} /> )} + {devices.data?.map((device) => ( + + ))}
}} /> +
}} + /> +
}} + />
}} /> -
}} - /> -
}} - /> {/* SectionType[] = (debugState) => [ +const sections: ( + debugState: DebugState, + userInfo: ReturnType['userInfo'] +) => SectionType[] = (debugState, userInfo) => [ { title: 'Client', data: [ @@ -44,6 +48,21 @@ const sections: (debugState: DebugState) => SectionType[] = (debugState) => [ title: 'General', rounded: 'top' }, + ...(userInfo + ? ([ + { + icon: UserCircle, + navigateTo: 'AccountProfile', + title: 'Account' + } + ] as const) + : ([ + { + icon: UserCircle, + navigateTo: 'AccountLogin', + title: 'Account' + } + ] as const)), { icon: Books, navigateTo: 'LibrarySettings', @@ -158,12 +177,16 @@ function renderSectionHeader({ section }: { section: { title: string } }) { export default function SettingsScreen({ navigation }: SettingsStackScreenProps<'Settings'>) { const debugState = useDebugState(); const syncEnabled = useLibraryQuery(['sync.enabled']); + const userInfo = useUserStore().userInfo; + + // Enables the drawer from react-navigation useEnableDrawer(); + return ( ( ; +}; + +const SocialLogins: SocialLogin[] = [ + { name: 'Github', icon: GithubLogo }, + { name: 'Google', icon: GoogleLogo }, + { name: 'Apple', icon: AppleLogo } +]; + +const AccountLogin = () => { + const [activeTab, setActiveTab] = useState<'Login' | 'Register'>('Login'); + const userInfo = useUserStore().userInfo; + + useEffect(() => { + if (userInfo) return; //no need to check if user info is already present + async function _() { + const user_data = await fetch(`${AUTH_SERVER_URL}/api/user`, { + method: 'GET' + }); + const data = await user_data.json(); + if (data.message !== 'unauthorised') { + getUserStore().userInfo = data; + } + } + _(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + // FIXME: Currently opens in App. + // const socialLoginHandlers = (name: SocialLogin['name']) => { + // return { + // Github: async () => { + // try { + // const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({ + // thirdPartyId: 'github', + + // // This is where Github should redirect the user back after login or error. + // frontendRedirectURI: 'http://localhost:9420/api/auth/callback/github' + // }); + + // // we redirect the user to Github for auth. + // window.location.assign(authUrl); + // } catch (err: any) { + // if (err.isSuperTokensGeneralError === true) { + // // this may be a custom error message sent from the API by you. + // toast.error(err.message); + // } else { + // toast.error('Oops! Something went wrong.'); + // } + // } + // }, + // Google: async () => { + // try { + // const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({ + // thirdPartyId: 'google', + + // // This is where Google should redirect the user back after login or error. + // // This URL goes on the Google's dashboard as well. + // frontendRedirectURI: 'http://localhost:9420/api/auth/callback/google' + // }); + + // /* + // Example value of authUrl: https://accounts.google.com/o/oauth2/v2/auth/oauthchooseaccount?scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&access_type=offline&include_granted_scopes=true&response_type=code&client_id=1060725074195-kmeum4crr01uirfl2op9kd5acmi9jutn.apps.googleusercontent.com&state=5a489996a28cafc83ddff&redirect_uri=https%3A%2F%2Fsupertokens.io%2Fdev%2Foauth%2Fredirect-to-app&flowName=GeneralOAuthFlow + // */ + + // // we redirect the user to google for auth. 
+ // window.location.assign(authUrl); + // } catch (err: any) { + // if (err.isSuperTokensGeneralError === true) { + // // this may be a custom error message sent from the API by you. + // toast.error(err.message); + // } else { + // toast.error('Oops! Something went wrong.'); + // } + // } + // }, + // Apple: async () => { + // try { + // const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({ + // thirdPartyId: 'apple', + + // // This is where Apple should redirect the user back after login or error. + // frontendRedirectURI: 'http://localhost:9420/api/auth/callback/apple' + // }); + + // // we redirect the user to Apple for auth. + // window.location.assign(authUrl); + // } catch (err: any) { + // if (err.isSuperTokensGeneralError === true) { + // // this may be a custom error message sent from the API by you. + // toast.error(err.message); + // } else { + // toast.error('Oops! Something went wrong.'); + // } + // } + // } + // }[name](); + // }; + + return ( + + + + + {AccountTabs.map((text) => ( + + ))} + + + {activeTab === 'Login' ? : } + {/* Disabled for now */} + {/* + + OR + + + + {SocialLogins.map((social) => ( + + ))} + */} + + + + + ); +}; +export default AccountLogin; diff --git a/apps/mobile/src/screens/settings/client/AccountSettings/AccountProfile.tsx b/apps/mobile/src/screens/settings/client/AccountSettings/AccountProfile.tsx new file mode 100644 index 000000000..3734e63a8 --- /dev/null +++ b/apps/mobile/src/screens/settings/client/AccountSettings/AccountProfile.tsx @@ -0,0 +1,179 @@ +import { useNavigation } from '@react-navigation/native'; +import { Envelope } from 'phosphor-react-native'; +import { useEffect, useState } from 'react'; +import { Text, View } from 'react-native'; +import { + SyncStatus, + useBridgeMutation, + useBridgeQuery, + useLibraryMutation, + useLibrarySubscription +} from '@sd/client'; +import Card from '~/components/layout/Card'; +import ScreenContainer from '~/components/layout/ScreenContainer'; +import { Button } from '~/components/primitive/Button'; +import { tw, twStyle } from '~/lib/tailwind'; +import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack'; +import { getUserStore, useUserStore } from '~/stores/userStore'; +import { AUTH_SERVER_URL, getTokens } from '~/utils'; + +const AccountProfile = () => { + const userInfo = useUserStore().userInfo; + + const emailName = userInfo ? userInfo.email.split('@')[0] : ''; + const capitalizedEmailName = (emailName?.charAt(0).toUpperCase() ?? 
'') + emailName?.slice(1); + const navigator = useNavigation['navigation']>(); + + const cloudBootstrap = useBridgeMutation('cloud.bootstrap'); + const devices = useBridgeQuery(['cloud.devices.list']); + const addLibraryToCloud = useLibraryMutation('cloud.libraries.create'); + const listLibraries = useBridgeQuery(['cloud.libraries.list', true]); + const createSyncGroup = useLibraryMutation('cloud.syncGroups.create'); + const listSyncGroups = useBridgeQuery(['cloud.syncGroups.list']); + const requestJoinSyncGroup = useBridgeMutation('cloud.syncGroups.request_join'); + const currentDevice = useBridgeQuery(['cloud.devices.get_current_device']); + const [{ accessToken, refreshToken }, setTokens] = useState<{ + accessToken: string; + refreshToken: string; + }>({ + accessToken: '', + refreshToken: '' + }); + useEffect(() => { + (async () => { + const { accessToken, refreshToken } = await getTokens(); + setTokens({ accessToken, refreshToken }); + })(); + }, []); + const [syncStatus, setSyncStatus] = useState(null); + useLibrarySubscription(['sync.active'], { + onData: (data) => { + console.log('sync activity', data); + setSyncStatus(data); + } + }); + + async function signOut() { + await fetch(`${AUTH_SERVER_URL}/api/auth/signout`, { + method: 'POST' + }); + navigator.navigate('AccountLogin'); + getUserStore().userInfo = undefined; + } + + return ( + + + + + + Welcome{' '} + {capitalizedEmailName} + + + + + {userInfo ? userInfo.email : ''} + + + + + + + {/* Sync activity */} + + Sync Activity + + {Object.keys(syncStatus ?? {}).map((status, index) => ( + + + {status} + + ))} + + + + {/* Automatically list libraries */} + + Cloud Libraries + {listLibraries.data?.map((library) => ( + + {library.name} + + )) || No libraries found.} + + + {/* Debug buttons */} + + + + + + + + Library Sync Groups + {listSyncGroups.data?.map((group) => ( + + + {group.library.name} + + + + )) || No sync groups found.} + + + + ); +}; + +export default AccountProfile; diff --git a/apps/mobile/src/screens/settings/client/AccountSettings/Login.tsx b/apps/mobile/src/screens/settings/client/AccountSettings/Login.tsx new file mode 100644 index 000000000..fddd07041 --- /dev/null +++ b/apps/mobile/src/screens/settings/client/AccountSettings/Login.tsx @@ -0,0 +1,195 @@ +import AsyncStorage from '@react-native-async-storage/async-storage'; +import { useNavigation } from '@react-navigation/native'; +import { RSPCError } from '@spacedrive/rspc-client'; +import { UseMutationResult } from '@tanstack/react-query'; +import { useState } from 'react'; +import { Controller } from 'react-hook-form'; +import { Text, View } from 'react-native'; +import { z } from 'zod'; +import { useBridgeMutation, useZodForm } from '@sd/client'; +import { Button } from '~/components/primitive/Button'; +import { Input } from '~/components/primitive/Input'; +import { toast } from '~/components/primitive/Toast'; +import { tw, twStyle } from '~/lib/tailwind'; +import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack'; +import { getUserStore } from '~/stores/userStore'; +import { AUTH_SERVER_URL } from '~/utils'; + +import ShowPassword from './ShowPassword'; + +const LoginSchema = z.object({ + email: z.string().email({ + message: 'Email is required' + }), + password: z.string().min(6, { + message: 'Password must be at least 6 characters' + }) +}); + +const Login = () => { + const [showPassword, setShowPassword] = useState(false); + const form = useZodForm({ + schema: LoginSchema, + defaultValues: { + email: '', + password: '' + } + }); + const 
updateUserStore = getUserStore(); + const navigator = useNavigation['navigation']>(); + const cloudBootstrap = useBridgeMutation('cloud.bootstrap'); + + return ( + + + ( + + + {form.formState.errors.email && ( + + {form.formState.errors.email.message} + + )} + + )} + /> + ( + + + {form.formState.errors.password && ( + + {form.formState.errors.password.message} + + )} + + + )} + /> + + + + ); +}; + +async function signInClicked( + email: string, + password: string, + navigator: SettingsStackScreenProps<'AccountProfile'>['navigation'], + cloudBootstrap: UseMutationResult, // Cloud bootstrap mutation + updateUserStore: ReturnType +) { + try { + const req = await fetch(`${AUTH_SERVER_URL}/api/auth/signin`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json; charset=utf-8' + }, + body: JSON.stringify({ + formFields: [ + { + id: 'email', + value: email + }, + { + id: 'password', + value: password + } + ] + }) + }); + + const response: { + status: string; + reason?: string; + user?: { + id: string; + email: string; + timeJoined: number; + tenantIds: string[]; + }; + } = await req.json(); + + if (response.status === 'FIELD_ERROR') { + // response.reason?.forEach((formField) => { + // if (formField.id === 'email') { + // // Email validation failed (for example incorrect email syntax). + // toast.error(formField.error); + // } + // }); + console.error('Field error: ', response.reason); + } else if (response.status === 'WRONG_CREDENTIALS_ERROR') { + toast.error('Email & password combination is incorrect.'); + } else if (response.status === 'SIGN_IN_NOT_ALLOWED') { + // the reason string is a user friendly message + // about what went wrong. It can also contain a support code which users + // can tell you so you know why their sign in was not allowed. + toast.error(response.reason!); + } else { + // sign in successful. The session tokens are automatically handled by + // the frontend SDK. + cloudBootstrap.mutate([ + req.headers.get('st-access-token')!, + req.headers.get('st-refresh-token')! + ]); + toast.success('Sign in successful'); + // Update the user store with the user info + updateUserStore.userInfo = response.user; + // Save the access token to AsyncStorage, because SuperTokens doesn't store it correctly. Thanks to the React Native SDK. + await AsyncStorage.setItem('access_token', req.headers.get('st-access-token')!); + await AsyncStorage.setItem('refresh_token', req.headers.get('st-refresh-token')!); + // Refresh the page to show the user is logged in + navigator.navigate('AccountProfile'); + } + } catch (err: any) { + if (err.isSuperTokensGeneralError === true) { + // this may be a custom error message sent from the API by you. + toast.error(err.message); + } else { + console.error(err); + toast.error('Oops! 
Something went wrong.'); } } } export default Login;
diff --git a/apps/mobile/src/screens/settings/client/AccountSettings/Register.tsx b/apps/mobile/src/screens/settings/client/AccountSettings/Register.tsx
new file mode 100644
index 000000000..00ee4eb27
--- /dev/null
+++ b/apps/mobile/src/screens/settings/client/AccountSettings/Register.tsx
@@ -0,0 +1,191 @@
+import { zodResolver } from '@hookform/resolvers/zod';
+import { useNavigation } from '@react-navigation/native';
+import { useState } from 'react';
+import { Controller, useForm } from 'react-hook-form';
+import { Text, View } from 'react-native';
+import { z } from 'zod';
+import { Button } from '~/components/primitive/Button';
+import { Input } from '~/components/primitive/Input';
+import { toast } from '~/components/primitive/Toast';
+import { tw, twStyle } from '~/lib/tailwind';
+import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack';
+import { AUTH_SERVER_URL } from '~/utils';
+
+import ShowPassword from './ShowPassword';
+
+const RegisterSchema = z
+	.object({
+		email: z.string().email({
+			message: 'Email is required'
+		}),
+		password: z.string().min(6, {
+			message: 'Password must be at least 6 characters'
+		}),
+		confirmPassword: z.string().min(6, {
+			message: 'Password must be at least 6 characters'
+		})
+	})
+	.refine((data) => data.password === data.confirmPassword, {
+		message: 'Passwords do not match',
+		path: ['confirmPassword']
+	});
+type RegisterType = z.infer<typeof RegisterSchema>;
+
+const Register = () => {
+	const [showPassword, setShowPassword] = useState(false);
+	// useZodForm seems to be outdated or needs
+	// fixing, as it does not support the schema using zod.refine
+	const form = useForm<RegisterType>({
+		resolver: zodResolver(RegisterSchema),
+		defaultValues: {
+			email: '',
+			password: '',
+			confirmPassword: ''
+		}
+	});
+
+	const navigator = useNavigation<SettingsStackScreenProps<'AccountProfile'>['navigation']>();
+	return (
+
+		(
+
+		)}
+		/>
+		{form.formState.errors.email && (
+			{form.formState.errors.email.message}
+		)}
+		(
+
+
+
+		)}
+		/>
+		{form.formState.errors.password && (
+
+			{form.formState.errors.password.message}
+
+		)}
+		(
+
+
+		{form.formState.errors.confirmPassword && (
+
+			{form.formState.errors.confirmPassword.message}
+
+		)}
+
+
+		)}
+		/>
+
+
+	);
+};
+
+async function signUpClicked(
+	email: string,
+	password: string,
+	navigator: SettingsStackScreenProps<'AccountProfile'>['navigation']
+) {
+	try {
+		const req = await fetch(`${AUTH_SERVER_URL}/api/auth/signup`, {
+			method: 'POST',
+			headers: {
+				'Content-Type': 'application/json; charset=utf-8'
+			},
+			body: JSON.stringify({
+				formFields: [
+					{
+						id: 'email',
+						value: email
+					},
+					{
+						id: 'password',
+						value: password
+					}
+				]
+			})
+		});
+
+		const response: {
+			status: string;
+			reason?: string;
+			user?: {
+				id: string;
+				email: string;
+				timeJoined: number;
+				tenantIds: string[];
+			};
+		} = await req.json();
+
+		if (response.status === 'FIELD_ERROR') {
+			// one of the input formFields failed validation
+			console.error('Field error: ', response.reason);
+		} else if (response.status === 'SIGN_UP_NOT_ALLOWED') {
+			// the reason string is a user friendly message
+			// about what went wrong. It can also contain a support code which users
+			// can tell you so you know why their sign up was not allowed.
+			toast.error(response.reason!);
+		} else {
+			// sign up successful. The session tokens are automatically handled by
+			// the frontend SDK.
+ toast.success('Sign up successful'); + navigator.navigate('AccountProfile'); + } + } catch (err: any) { + if (err.isSuperTokensGeneralError === true) { + // this may be a custom error message sent from the API by you. + toast.error(err.message); + } else { + console.error(err); + toast.error('Oops! Something went wrong.'); + } + } +} + +export default Register; diff --git a/apps/mobile/src/screens/settings/client/AccountSettings/ShowPassword.tsx b/apps/mobile/src/screens/settings/client/AccountSettings/ShowPassword.tsx new file mode 100644 index 000000000..ea33539dd --- /dev/null +++ b/apps/mobile/src/screens/settings/client/AccountSettings/ShowPassword.tsx @@ -0,0 +1,29 @@ +import { Eye, EyeClosed } from 'phosphor-react-native'; +import { Text } from 'react-native'; +import { Button } from '~/components/primitive/Button'; +import { tw } from '~/lib/tailwind'; + +interface Props { + showPassword: boolean; + setShowPassword: (value: boolean) => void; + plural?: boolean; +} + +const ShowPassword = ({ showPassword, setShowPassword, plural }: Props) => { + return ( + + ); +}; + +export default ShowPassword; diff --git a/apps/mobile/src/screens/settings/info/Debug.tsx b/apps/mobile/src/screens/settings/info/Debug.tsx index 61afef43f..a83285f74 100644 --- a/apps/mobile/src/screens/settings/info/Debug.tsx +++ b/apps/mobile/src/screens/settings/info/Debug.tsx @@ -1,26 +1,52 @@ -import { useQueryClient } from '@tanstack/react-query'; import React from 'react'; import { Text, View } from 'react-native'; import { - auth, - toggleFeatureFlag, useBridgeMutation, useBridgeQuery, useDebugState, - useFeatureFlags + useFeatureFlags, + useLibraryMutation } from '@sd/client'; import Card from '~/components/layout/Card'; import { Button } from '~/components/primitive/Button'; import { tw } from '~/lib/tailwind'; import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack'; +import { getTokens } from '~/utils'; const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => { const debugState = useDebugState(); const featureFlags = useFeatureFlags(); - const origin = useBridgeQuery(['cloud.getApiOrigin']); - const setOrigin = useBridgeMutation(['cloud.setApiOrigin']); + const [tokens, setTokens] = React.useState({ accessToken: '', refreshToken: '' }); + const accessToken = tokens.accessToken; + const refreshToken = tokens.refreshToken; + // const origin = useBridgeQuery(['cloud.getApiOrigin']); + // const setOrigin = useBridgeMutation(['cloud.setApiOrigin']); - const queryClient = useQueryClient(); + React.useEffect(() => { + async function _() { + const _a = await getTokens(); + setTokens({ accessToken: _a.accessToken, refreshToken: _a.refreshToken }); + } + _(); + }, []); + + const cloudBootstrap = useBridgeMutation(['cloud.bootstrap']); + const addLibraryToCloud = useLibraryMutation('cloud.libraries.create'); + const requestJoinSyncGroup = useBridgeMutation('cloud.syncGroups.request_join'); + const getGroup = useBridgeQuery([ + 'cloud.syncGroups.get', + { + pub_id: '01924497-a1be-76e3-b62f-9582ea15463a', + // pub_id: '01924a25-966b-7c00-a582-9eed3aadd2cd', + kind: 'WithDevices' + } + ]); + // console.log(getGroup.data); + const currentDevice = useBridgeQuery(['cloud.devices.get_current_device']); + // console.log('Current Device: ', currentDevice.data); + const createSyncGroup = useLibraryMutation('cloud.syncGroups.create'); + + // const queryClient = useQueryClient(); return ( @@ -31,7 +57,7 @@ const DebugScreen = ({ navigation }: SettingsStackScreenProps<'Debug'>) => { 
{JSON.stringify(featureFlags)} {JSON.stringify(debugState)} - - + */} - */} + + + + diff --git a/apps/mobile/src/screens/settings/library/CloudSettings/CloudSettings.tsx b/apps/mobile/src/screens/settings/library/CloudSettings/CloudSettings.tsx deleted file mode 100644 index 892556d0e..000000000 --- a/apps/mobile/src/screens/settings/library/CloudSettings/CloudSettings.tsx +++ /dev/null @@ -1,130 +0,0 @@ -import { useMemo } from 'react'; -import { ActivityIndicator, FlatList, Text, View } from 'react-native'; -import { useLibraryContext, useLibraryMutation, useLibraryQuery } from '@sd/client'; -import { Icon } from '~/components/icons/Icon'; -import Card from '~/components/layout/Card'; -import Empty from '~/components/layout/Empty'; -import ScreenContainer from '~/components/layout/ScreenContainer'; -import VirtualizedListWrapper from '~/components/layout/VirtualizedListWrapper'; -import { Button } from '~/components/primitive/Button'; -import { Divider } from '~/components/primitive/Divider'; -import { styled, tw, twStyle } from '~/lib/tailwind'; -import { useAuthStateSnapshot } from '~/stores/auth'; - -import Instance from './Instance'; -import Library from './Library'; -import Login from './Login'; -import ThisInstance from './ThisInstance'; - -export const InfoBox = styled(View, 'rounded-md border gap-1 border-app bg-transparent p-2'); - -const CloudSettings = () => { - return ( - - - - ); -}; - -const AuthSensitiveChild = () => { - const authState = useAuthStateSnapshot(); - if (authState.status === 'loggedIn') return ; - if (authState.status === 'notLoggedIn' || authState.status === 'loggingIn') return ; - - return null; -}; - -const Authenticated = () => { - const { library } = useLibraryContext(); - const cloudLibrary = useLibraryQuery(['cloud.library.get'], { retry: false }); - const createLibrary = useLibraryMutation(['cloud.library.create']); - - const cloudInstances = useMemo( - () => - cloudLibrary.data?.instances.filter( - (instance) => instance.uuid !== library.instance_id - ), - [cloudLibrary.data, library.instance_id] - ); - - if (cloudLibrary.isLoading) { - return ( - - - - ); - } - - return ( - - {cloudLibrary.data ? ( - - - - - - - - {cloudInstances?.length} - - - Instances - - - - - } - contentContainerStyle={twStyle( - cloudInstances?.length === 0 && 'flex-row' - )} - showsHorizontalScrollIndicator={false} - ItemSeparatorComponent={() => } - renderItem={({ item }) => } - keyExtractor={(item) => item.id} - numColumns={1} - /> - - - - ) : ( - - - - - Uploading your library to the cloud will allow you to access your - library from other devices using your account & importing. 
- - - - - )} - - ); -}; - -export default CloudSettings; diff --git a/apps/mobile/src/screens/settings/library/CloudSettings/Instance.tsx b/apps/mobile/src/screens/settings/library/CloudSettings/Instance.tsx deleted file mode 100644 index dbac4a60a..000000000 --- a/apps/mobile/src/screens/settings/library/CloudSettings/Instance.tsx +++ /dev/null @@ -1,64 +0,0 @@ -import { Text, View } from 'react-native'; -import { CloudInstance, HardwareModel } from '@sd/client'; -import { Icon } from '~/components/icons/Icon'; -import { hardwareModelToIcon } from '~/components/overview/Devices'; -import { tw } from '~/lib/tailwind'; - -import { InfoBox } from './CloudSettings'; - -interface Props { - data: CloudInstance; -} - -const Instance = ({ data }: Props) => { - return ( - - - - - - - {data.metadata.name} - - - - Id: - - {data.id} - - - - - - - - UUID: - - {data.uuid} - - - - - - - - Public key: - - {data.identity} - - - - - - ); -}; - -export default Instance; diff --git a/apps/mobile/src/screens/settings/library/CloudSettings/Library.tsx b/apps/mobile/src/screens/settings/library/CloudSettings/Library.tsx deleted file mode 100644 index 9f848173f..000000000 --- a/apps/mobile/src/screens/settings/library/CloudSettings/Library.tsx +++ /dev/null @@ -1,66 +0,0 @@ -import { CheckCircle, XCircle } from 'phosphor-react-native'; -import { useMemo } from 'react'; -import { Text, View } from 'react-native'; -import { CloudLibrary, useLibraryContext, useLibraryMutation } from '@sd/client'; -import Card from '~/components/layout/Card'; -import { Button } from '~/components/primitive/Button'; -import { Divider } from '~/components/primitive/Divider'; -import { SettingsTitle } from '~/components/settings/SettingsContainer'; -import { tw } from '~/lib/tailwind'; -import { logout, useAuthStateSnapshot } from '~/stores/auth'; - -import { InfoBox } from './CloudSettings'; - -interface LibraryProps { - cloudLibrary?: CloudLibrary; -} - -const Library = ({ cloudLibrary }: LibraryProps) => { - const authState = useAuthStateSnapshot(); - const { library } = useLibraryContext(); - const syncLibrary = useLibraryMutation(['cloud.library.sync']); - const thisInstance = useMemo( - () => cloudLibrary?.instances.find((instance) => instance.uuid === library.instance_id), - [cloudLibrary, library.instance_id] - ); - - return ( - - - Library - {authState.status === 'loggedIn' && ( - - )} - - - Name - - {cloudLibrary?.name} - - - - ); -}; - -export default Library; diff --git a/apps/mobile/src/screens/settings/library/CloudSettings/Login.tsx b/apps/mobile/src/screens/settings/library/CloudSettings/Login.tsx deleted file mode 100644 index 88738c329..000000000 --- a/apps/mobile/src/screens/settings/library/CloudSettings/Login.tsx +++ /dev/null @@ -1,45 +0,0 @@ -import { Text, View } from 'react-native'; -import { Icon } from '~/components/icons/Icon'; -import Card from '~/components/layout/Card'; -import { Button } from '~/components/primitive/Button'; -import { tw } from '~/lib/tailwind'; -import { cancel, login, useAuthStateSnapshot } from '~/stores/auth'; - -const Login = () => { - const authState = useAuthStateSnapshot(); - const buttonText = { - notLoggedIn: 'Login', - loggingIn: 'Cancel' - }; - return ( - - - - - - Cloud Sync will upload your library to the cloud so you can access your - library from other devices by importing it from the cloud. 
- - - {(authState.status === 'notLoggedIn' || authState.status === 'loggingIn') && ( - - )} - - - ); -}; - -export default Login; diff --git a/apps/mobile/src/screens/settings/library/CloudSettings/ThisInstance.tsx b/apps/mobile/src/screens/settings/library/CloudSettings/ThisInstance.tsx deleted file mode 100644 index 041d6591c..000000000 --- a/apps/mobile/src/screens/settings/library/CloudSettings/ThisInstance.tsx +++ /dev/null @@ -1,76 +0,0 @@ -import { useMemo } from 'react'; -import { Text, View } from 'react-native'; -import { CloudLibrary, HardwareModel, useLibraryContext } from '@sd/client'; -import { Icon } from '~/components/icons/Icon'; -import Card from '~/components/layout/Card'; -import { hardwareModelToIcon } from '~/components/overview/Devices'; -import { Divider } from '~/components/primitive/Divider'; -import { tw } from '~/lib/tailwind'; - -import { InfoBox } from './CloudSettings'; - -interface ThisInstanceProps { - cloudLibrary?: CloudLibrary; -} - -const ThisInstance = ({ cloudLibrary }: ThisInstanceProps) => { - const { library } = useLibraryContext(); - const thisInstance = useMemo( - () => cloudLibrary?.instances.find((instance) => instance.uuid === library.instance_id), - [cloudLibrary, library.instance_id] - ); - - if (!thisInstance) return null; - - return ( - - - This Instance - - - - - - {thisInstance.metadata.name} - - - - - - Id: - {thisInstance.id} - - - - - - - UUID: - - {thisInstance.uuid} - - - - - - - - Publc Key: - - {thisInstance.identity} - - - - - - ); -}; - -export default ThisInstance; diff --git a/apps/mobile/src/screens/settings/library/SyncSettings.tsx b/apps/mobile/src/screens/settings/library/SyncSettings.tsx deleted file mode 100644 index 0097625c2..000000000 --- a/apps/mobile/src/screens/settings/library/SyncSettings.tsx +++ /dev/null @@ -1,158 +0,0 @@ -import { useIsFocused } from '@react-navigation/native'; -import { inferSubscriptionResult } from '@spacedrive/rspc-client'; -import { MotiView } from 'moti'; -import { Circle } from 'phosphor-react-native'; -import React, { useEffect, useRef, useState } from 'react'; -import { Text, View } from 'react-native'; -import { - Procedures, - useLibraryMutation, - useLibraryQuery, - useLibrarySubscription -} from '@sd/client'; -import { Icon } from '~/components/icons/Icon'; -import Card from '~/components/layout/Card'; -import { ModalRef } from '~/components/layout/Modal'; -import ScreenContainer from '~/components/layout/ScreenContainer'; -import CloudModal from '~/components/modal/cloud/CloudModal'; -import { Button } from '~/components/primitive/Button'; -import { tw } from '~/lib/tailwind'; -import { SettingsStackScreenProps } from '~/navigation/tabs/SettingsStack'; - -const SyncSettingsScreen = ({ navigation }: SettingsStackScreenProps<'SyncSettings'>) => { - const syncEnabled = useLibraryQuery(['sync.enabled']); - const [data, setData] = useState>({}); - const modalRef = useRef(null); - - const [startBackfill, setStart] = useState(false); - const pageFocused = useIsFocused(); - const [showCloudModal, setShowCloudModal] = useState(false); - - useLibrarySubscription(['library.actors'], { onData: setData }); - - useEffect(() => { - if (startBackfill === true) { - navigation.navigate('BackfillWaitingStack', { - screen: 'BackfillWaiting' - }); - setTimeout(() => setShowCloudModal(true), 1000); - } - }, [startBackfill, navigation]); - - useEffect(() => { - if (pageFocused && showCloudModal) modalRef.current?.present(); - return () => { - if (showCloudModal) setShowCloudModal(false); - }; - }, 
[pageFocused, showCloudModal]); - - return ( - - {syncEnabled.data === false ? ( - - - - - - With Sync, you can share your library with other devices using P2P - technology. - - - Additionally, allowing you to enable Cloud services to upload your - library to the cloud, making it accessible on any of your devices. - - - - - - ) : ( - - {Object.keys(data).map((key) => { - return ( - - - - {key} - - {data[key] ? : } - - ); - })} - - )} - - - ); -}; - -export default SyncSettingsScreen; - -function OnlineIndicator({ online }: { online: boolean }) { - const size = 6; - return ( - - {online ? ( - - - - - ) : ( - - )} - - ); -} - -function StartButton({ name }: { name: string }) { - const startActor = useLibraryMutation(['library.startActor']); - return ( - - ); -} - -function StopButton({ name }: { name: string }) { - const stopActor = useLibraryMutation(['library.stopActor']); - return ( - - ); -} diff --git a/apps/mobile/src/stores/auth.ts b/apps/mobile/src/stores/auth.ts index 336b3ff22..3f99024dd 100644 --- a/apps/mobile/src/stores/auth.ts +++ b/apps/mobile/src/stores/auth.ts @@ -18,16 +18,16 @@ export function useAuthStateSnapshot() { return useSolidStore(store).state; } -nonLibraryClient - .query(['auth.me']) - .then(() => (store.state = { status: 'loggedIn' })) - .catch((e) => { - if (e instanceof RSPCError && e.code === 401) { - // TODO: handle error? - console.error('error', e); - } - store.state = { status: 'notLoggedIn' }; - }); +// nonLibraryClient +// .query(['auth.me']) +// .then(() => (store.state = { status: 'loggedIn' })) +// .catch((e) => { +// if (e instanceof RSPCError && e.code === 401) { +// // TODO: handle error? +// console.error('error', e); +// } +// store.state = { status: 'notLoggedIn' }; +// }); type CallbackStatus = 'success' | { error: string } | 'cancel'; const loginCallbacks = new Set<(status: CallbackStatus) => void>(); @@ -41,29 +41,29 @@ export function login() { store.state = { status: 'loggingIn' }; - let authCleanup = nonLibraryClient.addSubscription(['auth.loginSession'], { - onData(data) { - if (data === 'Complete') { - loginCallbacks.forEach((cb) => cb('success')); - } else if ('Error' in data) { - console.error('[auth] error: ', data.Error); - onError(data.Error); - } else { - console.log('[auth] verification url: ', data.Start.verification_url_complete); - Promise.resolve() - .then(() => Linking.openURL(data.Start.verification_url_complete)) - .then( - (res) => { - authCleanup = res; - }, - (e) => onError(e.message) - ); - } - }, - onError(e) { - onError(e.message); - } - }); + // let authCleanup = nonLibraryClient.addSubscription(['auth.loginSession'], { + // onData(data) { + // if (data === 'Complete') { + // loginCallbacks.forEach((cb) => cb('success')); + // } else if ('Error' in data) { + // console.error('[auth] error: ', data.Error); + // onError(data.Error); + // } else { + // console.log('[auth] verification url: ', data.Start.verification_url_complete); + // Promise.resolve() + // .then(() => Linking.openURL(data.Start.verification_url_complete)) + // .then( + // (res) => { + // authCleanup = res; + // }, + // (e) => onError(e.message) + // ); + // } + // }, + // onError(e) { + // onError(e.message); + // } + // }); return new Promise((res, rej) => { const cb = async (status: CallbackStatus) => { @@ -71,7 +71,7 @@ export function login() { if (status === 'success') { store.state = { status: 'loggedIn' }; - nonLibraryClient.query(['auth.me']); + // nonLibraryClient.query(['auth.me']); res(); } else { store.state = { status: 'notLoggedIn' }; 
@@ -88,8 +88,8 @@ export function set_logged_in() { export function logout() { store.state = { status: 'loggingOut' }; - nonLibraryClient.mutation(['auth.logout']); - nonLibraryClient.query(['auth.me']); + // nonLibraryClient.mutation(['auth.logout']); + // nonLibraryClient.query(['auth.me']); store.state = { status: 'notLoggedIn' }; } diff --git a/apps/mobile/src/stores/userStore.ts b/apps/mobile/src/stores/userStore.ts new file mode 100644 index 000000000..8a1bfe6a7 --- /dev/null +++ b/apps/mobile/src/stores/userStore.ts @@ -0,0 +1,22 @@ +import { proxy, useSnapshot } from 'valtio'; + +export type User = { + id: string; + email: string; + timeJoined: number; + tenantIds: string[]; +}; + +const state = { + userInfo: undefined as User | undefined +}; + +const store = proxy({ + ...state +}); + +// for reading +export const useUserStore = () => useSnapshot(store); + +// for writing +export const getUserStore = () => store; diff --git a/apps/mobile/src/utils/index.ts b/apps/mobile/src/utils/index.ts new file mode 100644 index 000000000..8d4e4077b --- /dev/null +++ b/apps/mobile/src/utils/index.ts @@ -0,0 +1,13 @@ +import AsyncStorage from '@react-native-async-storage/async-storage'; + +export async function getTokens() { + const fetchedToken = await AsyncStorage.getItem('access_token'); + const fetchedRefreshToken = await AsyncStorage.getItem('refresh_token'); + return { + accessToken: fetchedToken ?? '', + refreshToken: fetchedRefreshToken ?? '' + }; +} + +// export const AUTH_SERVER_URL = __DEV__ ? 'http://localhost:9420' : 'https://auth.spacedrive.com'; +export const AUTH_SERVER_URL = 'https://auth.spacedrive.com'; diff --git a/apps/server/src/main.rs b/apps/server/src/main.rs index 5a6304bf3..4a4246312 100644 --- a/apps/server/src/main.rs +++ b/apps/server/src/main.rs @@ -144,19 +144,7 @@ async fn main() { let state = AppState { auth }; - let (node, router) = match Node::new( - data_dir, - sd_core::Env { - api_url: tokio::sync::Mutex::new( - std::env::var("SD_API_URL") - .unwrap_or_else(|_| "https://app.spacedrive.com".to_string()), - ), - client_id: std::env::var("SD_CLIENT_ID") - .unwrap_or_else(|_| "04701823-a498-406e-aef9-22081c1dae34".to_string()), - }, - ) - .await - { + let (node, router) = match Node::new(data_dir).await { Ok(d) => d, Err(e) => { panic!("{}", e.to_string()) diff --git a/apps/storybook/package.json b/apps/storybook/package.json index 0f5786227..5dfe9f15f 100644 --- a/apps/storybook/package.json +++ b/apps/storybook/package.json @@ -30,6 +30,6 @@ "storybook": "^8.0.1", "tailwindcss": "^3.4.10", "typescript": "^5.6.2", - "vite": "^5.2.0" + "vite": "^5.4.9" } } diff --git a/apps/web/package.json b/apps/web/package.json index ad2a51ebf..85ab05a77 100644 --- a/apps/web/package.json +++ b/apps/web/package.json @@ -41,7 +41,7 @@ "rollup-plugin-visualizer": "^5.12.0", "start-server-and-test": "^2.0.3", "typescript": "^5.6.2", - "vite": "^5.2.0", - "vite-tsconfig-paths": "^4.3.2" + "vite": "^5.4.9", + "vite-tsconfig-paths": "^5.0.1" } } diff --git a/core/Cargo.toml b/core/Cargo.toml index 3d3762464..d2880a0e8 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -20,6 +20,7 @@ heif = ["sd-images/heif"] [dependencies] # Inner Core Sub-crates +sd-core-cloud-services = { path = "./crates/cloud-services" } sd-core-file-path-helper = { path = "./crates/file-path-helper" } sd-core-heavy-lifting = { path = "./crates/heavy-lifting" } sd-core-indexer-rules = { path = "./crates/indexer-rules" } @@ -29,7 +30,8 @@ sd-core-sync = { path = "./crates/sync" } # Spacedrive Sub-crates 
sd-actors = { path = "../crates/actors" }
sd-ai = { path = "../crates/ai", optional = true }
-sd-cloud-api = { path = "../crates/cloud-api" }
+sd-crypto = { path = "../crates/crypto" }
+sd-ffmpeg = { path = "../crates/ffmpeg", optional = true }
sd-file-ext = { path = "../crates/file-ext" }
sd-images = { path = "../crates/images", features = ["rspc", "serde", "specta"] }
sd-media-metadata = { path = "../crates/media-metadata" }
@@ -49,6 +51,7 @@ async-trait = { workspace = true }
axum = { workspace = true, features = ["ws"] }
base64 = { workspace = true }
blake3 = { workspace = true }
+bytes = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
futures = { workspace = true }
futures-concurrency = { workspace = true }
@@ -64,6 +67,7 @@ reqwest = { workspace = true, features = ["json", "native-tls-vendored"] }
rmp-serde = { workspace = true }
rmpv = { workspace = true }
rspc = { workspace = true, features = ["alpha", "axum", "chrono", "unstable", "uuid"] }
+sd-cloud-schema = { workspace = true }
serde = { workspace = true, features = ["derive", "rc"] }
serde_json = { workspace = true }
specta = { workspace = true }
@@ -75,12 +79,11 @@ tokio-stream = { workspace = true, features = ["fs"] }
tokio-util = { workspace = true, features = ["io"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
-uuid = { workspace = true, features = ["serde", "v4"] }
+uuid = { workspace = true, features = ["serde", "v4", "v7"] }

# Specific Core dependencies
async-recursion = "1.1"
base91 = "0.1.0"
-bytes = "1.6"
ctor = "0.2.8"
directories = "5.0"
flate2 = "1.0"
@@ -98,6 +101,7 @@ sysinfo = "0.29.11" # Update blocked
tar = "0.4.41"
tower-service = "0.3.2"
tracing-appender = "0.2.3"
+whoami = "1.5.2"

[dependencies.tokio]
features = ["io-util", "macros", "process", "rt-multi-thread", "sync", "time"]
diff --git a/core/crates/cloud-services/Cargo.toml b/core/crates/cloud-services/Cargo.toml
new file mode 100644
index 000000000..baffe812d
--- /dev/null
+++ b/core/crates/cloud-services/Cargo.toml
@@ -0,0 +1,55 @@
+[package]
+name = "sd-core-cloud-services"
+version = "0.1.0"
+
+edition = "2021"
+
+[dependencies]
+# Core Spacedrive Sub-crates
+sd-core-sync = { path = "../sync" }
+
+# Spacedrive Sub-crates
+sd-actors = { path = "../../../crates/actors" }
+sd-cloud-schema = { workspace = true }
+sd-crypto = { path = "../../../crates/crypto" }
+sd-prisma = { path = "../../../crates/prisma" }
+sd-utils = { path = "../../../crates/utils" }
+
+# Workspace dependencies
+async-stream = { workspace = true }
+base64 = { workspace = true }
+blake3 = { workspace = true }
+chrono = { workspace = true, features = ["serde"] }
+flume = { workspace = true }
+futures = { workspace = true }
+futures-concurrency = { workspace = true }
+rmp-serde = { workspace = true }
+rspc = { workspace = true }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
+specta = { workspace = true }
+thiserror = { workspace = true }
+tokio = { workspace = true, features = ["sync", "time"] }
+tokio-stream = { workspace = true }
+tokio-util = { workspace = true }
+tracing = { workspace = true }
+uuid = { workspace = true, features = ["serde"] }
+zeroize = { workspace = true }
+
+# External dependencies
+anyhow = "1.0.86"
+dashmap = "6.1.0"
+iroh-net = { version = "0.27", features = ["discovery-local-network", "iroh-relay"] }
+paste = "=1.0.15"
+quic-rpc = { version = "0.12.1", features = ["quinn-transport"] }
+quinn = { package = "iroh-quinn", version = "0.11" }
+# Using whatever version of reqwest that reqwest-middleware uses, just putting it here to enable some features
+reqwest = { version = "0.12", features = ["json", "native-tls-vendored", "stream"] }
+reqwest-middleware = { version = "0.3", features = ["json"] }
+reqwest-retry = "0.6"
+rustls = { version = "=0.23.15", default-features = false, features = ["brotli", "ring", "std"] }
+rustls-platform-verifier = "0.3.3"
+
+
+[dev-dependencies]
+tokio = { workspace = true, features = ["rt", "sync", "time"] }
diff --git a/core/crates/cloud-services/src/client.rs b/core/crates/cloud-services/src/client.rs
new file mode 100644
index 000000000..d9ec361e1
--- /dev/null
+++ b/core/crates/cloud-services/src/client.rs
@@ -0,0 +1,358 @@
+use crate::p2p::{NotifyUser, UserResponse};
+
+use sd_cloud_schema::{Client, Service, ServicesALPN};
+
+use std::{net::SocketAddr, sync::Arc, time::Duration};
+
+use futures::Stream;
+use iroh_net::relay::RelayUrl;
+use quic_rpc::{transport::quinn::QuinnConnection, RpcClient};
+use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint};
+use reqwest::{IntoUrl, Url};
+use reqwest_middleware::{reqwest, ClientBuilder, ClientWithMiddleware};
+use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
+use tokio::sync::{Mutex, RwLock};
+use tracing::warn;
+
+use super::{
+	error::Error, key_manager::KeyManager, p2p::CloudP2P, token_refresher::TokenRefresher,
+};
+
+#[derive(Debug, Default, Clone)]
+enum ClientState {
+	#[default]
+	NotConnected,
+	Connected(Client<QuinnConnection<Service>, Service>),
+}
+
+/// Cloud services are an optional feature that allows you to interact with the cloud services
+/// of Spacedrive.
+/// They're optional in two different ways:
+/// - The cloud services depend on a user being logged in to our server.
+/// - The user needs to be connected to the internet to begin with.
+///
+/// As we don't want to force the user to be connected to the internet, we have to make sure
+/// that the core can always operate without the cloud services.
+#[derive(Debug)]
+pub struct CloudServices {
+	client_state: Arc<RwLock<ClientState>>,
+	get_cloud_api_address: Url,
+	http_client: ClientWithMiddleware,
+	domain_name: String,
+	pub cloud_p2p_dns_origin_name: String,
+	pub cloud_p2p_relay_url: RelayUrl,
+	pub cloud_p2p_dns_pkarr_url: Url,
+	pub token_refresher: TokenRefresher,
+	key_manager: Arc<Mutex<Option<KeyManager>>>,
+	cloud_p2p: Arc<Mutex<Option<CloudP2P>>>,
+	pub(crate) notify_user_tx: flume::Sender<NotifyUser>,
+	notify_user_rx: flume::Receiver<NotifyUser>,
+	user_response_tx: flume::Sender<UserResponse>,
+	pub(crate) user_response_rx: flume::Receiver<UserResponse>,
+	pub has_bootstrapped: Arc<RwLock<bool>>,
+}
+
+impl CloudServices {
+	/// Creates a new cloud services client that can be used to interact with the cloud services.
+	/// The client will try to connect to the cloud services on a best-effort basis, as the user
+	/// might not be connected to the internet.
+	/// If the client fails to connect, it will try again the next time it's used.
+	pub async fn new(
+		get_cloud_api_address: impl IntoUrl + Send,
+		cloud_p2p_relay_url: impl IntoUrl + Send,
+		cloud_p2p_dns_pkarr_url: impl IntoUrl + Send,
+		cloud_p2p_dns_origin_name: String,
+		domain_name: String,
+	) -> Result<Self, Error> {
+		let mut http_client_builder = reqwest::Client::builder().timeout(Duration::from_secs(3));
+
+		#[cfg(not(debug_assertions))]
+		{
+			http_client_builder = http_client_builder.https_only(true);
+		}
+
+		let cloud_p2p_relay_url = cloud_p2p_relay_url
+			.into_url()
+			.map_err(Error::InvalidUrl)?
impl CloudServices {
	/// Creates a new cloud services client that can be used to interact with the cloud services.
	/// The client will try to connect to the cloud services on a best effort basis, as the user
	/// might not be connected to the internet.
	/// If the client fails to connect, it will try again the next time it's used.
	pub async fn new(
		get_cloud_api_address: impl IntoUrl + Send,
		cloud_p2p_relay_url: impl IntoUrl + Send,
		cloud_p2p_dns_pkarr_url: impl IntoUrl + Send,
		cloud_p2p_dns_origin_name: String,
		domain_name: String,
	) -> Result<Self, Error> {
		let mut http_client_builder = reqwest::Client::builder().timeout(Duration::from_secs(3));

		#[cfg(not(debug_assertions))]
		{
			http_client_builder = http_client_builder.https_only(true);
		}

		let cloud_p2p_relay_url = cloud_p2p_relay_url
			.into_url()
			.map_err(Error::InvalidUrl)?
			.into();

		let cloud_p2p_dns_pkarr_url = cloud_p2p_dns_pkarr_url
			.into_url()
			.map_err(Error::InvalidUrl)?;

		let http_client =
			ClientBuilder::new(http_client_builder.build().map_err(Error::HttpClientInit)?)
				.with(RetryTransientMiddleware::new_with_policy(
					ExponentialBackoff::builder().build_with_max_retries(3),
				))
				.build();

		let get_cloud_api_address = get_cloud_api_address
			.into_url()
			.map_err(Error::InvalidUrl)?;

		let client_state = match Self::init_client(
			&http_client,
			get_cloud_api_address.clone(),
			domain_name.clone(),
		)
		.await
		{
			Ok(client) => Arc::new(RwLock::new(ClientState::Connected(client))),
			Err(e) => {
				warn!(
					?e,
					"Failed to initialize cloud services client; \
					this is best effort and we will continue in NotConnected mode"
				);
				Arc::new(RwLock::new(ClientState::NotConnected))
			}
		};

		let (notify_user_tx, notify_user_rx) = flume::bounded(16);
		let (user_response_tx, user_response_rx) = flume::bounded(16);

		Ok(Self {
			client_state,
			token_refresher: TokenRefresher::new(
				http_client.clone(),
				get_cloud_api_address.clone(),
			),
			get_cloud_api_address,
			http_client,
			cloud_p2p_dns_origin_name,
			cloud_p2p_relay_url,
			cloud_p2p_dns_pkarr_url,
			domain_name,
			key_manager: Arc::default(),
			cloud_p2p: Arc::default(),
			notify_user_tx,
			notify_user_rx,
			user_response_tx,
			user_response_rx,
			has_bootstrapped: Arc::default(),
		})
	}

	pub fn stream_user_notifications(&self) -> impl Stream<Item = NotifyUser> + '_ {
		self.notify_user_rx.stream()
	}

	#[must_use]
	pub const fn http_client(&self) -> &ClientWithMiddleware {
		&self.http_client
	}

	/// Send back a user response to the Cloud P2P actor
	///
	/// # Panics
	/// Will panic if the channel is closed, which should never happen
	pub async fn send_user_response(&self, response: UserResponse) {
		self.user_response_tx
			.send_async(response)
			.await
			.expect("user response channel must never close");
	}
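The retry stack built in new() is worth seeing in isolation: reqwest-middleware wraps a plain reqwest client, and reqwest-retry re-issues requests that fail transiently. A self-contained sketch of the same construction (the three-second timeout and three retries mirror the values above):

use std::time::Duration;

use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};

fn build_http_client() -> reqwest::Result<ClientWithMiddleware> {
	let inner = reqwest::Client::builder()
		.timeout(Duration::from_secs(3))
		.build()?;

	// Transient failures (timeouts, connection resets, 5xx) are retried up
	// to three times with exponentially growing delays; everything else
	// surfaces immediately.
	Ok(ClientBuilder::new(inner)
		.with(RetryTransientMiddleware::new_with_policy(
			ExponentialBackoff::builder().build_with_max_retries(3),
		))
		.build())
}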
	async fn init_client(
		http_client: &ClientWithMiddleware,
		get_cloud_api_address: Url,
		domain_name: String,
	) -> Result<Client<QuinnConnection<Service>, Service>, Error> {
		let cloud_api_address = http_client
			.get(get_cloud_api_address)
			.send()
			.await
			.map_err(Error::FailedToRequestApiAddress)?
			.error_for_status()
			.map_err(Error::AuthServerError)?
			.text()
			.await
			.map_err(Error::FailedToExtractApiAddress)?
			.parse::<SocketAddr>()?;

		let mut crypto_config = {
			#[cfg(debug_assertions)]
			{
				#[derive(Debug)]
				struct SkipServerVerification;

				impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
					fn verify_server_cert(
						&self,
						_end_entity: &rustls::pki_types::CertificateDer<'_>,
						_intermediates: &[rustls::pki_types::CertificateDer<'_>],
						_server_name: &rustls::pki_types::ServerName<'_>,
						_ocsp_response: &[u8],
						_now: rustls::pki_types::UnixTime,
					) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
						Ok(rustls::client::danger::ServerCertVerified::assertion())
					}

					fn verify_tls12_signature(
						&self,
						_message: &[u8],
						_cert: &rustls::pki_types::CertificateDer<'_>,
						_dss: &rustls::DigitallySignedStruct,
					) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
					{
						Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
					}

					fn verify_tls13_signature(
						&self,
						_message: &[u8],
						_cert: &rustls::pki_types::CertificateDer<'_>,
						_dss: &rustls::DigitallySignedStruct,
					) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
					{
						Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
					}

					fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
						vec![
							rustls::SignatureScheme::RSA_PKCS1_SHA1,
							rustls::SignatureScheme::ECDSA_SHA1_Legacy,
							rustls::SignatureScheme::RSA_PKCS1_SHA256,
							rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
							rustls::SignatureScheme::RSA_PKCS1_SHA384,
							rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
							rustls::SignatureScheme::RSA_PKCS1_SHA512,
							rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
							rustls::SignatureScheme::RSA_PSS_SHA256,
							rustls::SignatureScheme::RSA_PSS_SHA384,
							rustls::SignatureScheme::RSA_PSS_SHA512,
							rustls::SignatureScheme::ED25519,
							rustls::SignatureScheme::ED448,
						]
					}
				}

				rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
					.dangerous()
					.with_custom_certificate_verifier(Arc::new(SkipServerVerification))
					.with_no_client_auth()
			}

			#[cfg(not(debug_assertions))]
			{
				rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
					.dangerous()
					.with_custom_certificate_verifier(Arc::new(
						rustls_platform_verifier::Verifier::new(),
					))
					.with_no_client_auth()
			}
		};

		crypto_config
			.alpn_protocols
			.extend([ServicesALPN::LATEST.to_vec()]);

		let client_config = ClientConfig::new(Arc::new(
			QuicClientConfig::try_from(crypto_config)
				.expect("misconfigured TLS client config, this is a bug and should crash"),
		));

		let mut endpoint = Endpoint::client("[::]:0".parse().expect("hardcoded address"))
			.map_err(Error::FailedToCreateEndpoint)?;
		endpoint.set_default_client_config(client_config);

		// TODO(@fogodev): It's possible that we can't keep the connection alive all the time,
		// and need to use single shot connections. I will only be sure when we have
		// actually battle-tested the cloud services in core.
		Ok(Client::new(RpcClient::new(QuinnConnection::new(
			endpoint,
			cloud_api_address,
			domain_name,
		))))
	}
	/// Returns a client to the cloud services.
	///
	/// If the client is not connected, it will try to connect to the cloud services.
	/// Available routes are documented in
	/// [`sd_cloud_schema::Service`](https://github.com/spacedriveapp/cloud-services-schema).
	pub async fn client(&self) -> Result<Client<QuinnConnection<Service>, Service>, Error> {
		if let ClientState::Connected(client) = { self.client_state.read().await.clone() } {
			return Ok(client);
		}

		// If we're not connected, we need to try to connect.
		let client = Self::init_client(
			&self.http_client,
			self.get_cloud_api_address.clone(),
			self.domain_name.clone(),
		)
		.await?;
		*self.client_state.write().await = ClientState::Connected(client.clone());

		Ok(client)
	}

	pub async fn set_key_manager(&self, key_manager: KeyManager) {
		self.key_manager
			.write()
			.await
			.replace(Arc::new(key_manager));
	}

	pub async fn key_manager(&self) -> Result<Arc<KeyManager>, Error> {
		self.key_manager
			.read()
			.await
			.as_ref()
			.map_or(Err(Error::KeyManagerNotInitialized), |key_manager| {
				Ok(Arc::clone(key_manager))
			})
	}

	pub async fn set_cloud_p2p(&self, cloud_p2p: CloudP2P) {
		self.cloud_p2p.write().await.replace(Arc::new(cloud_p2p));
	}

	pub async fn cloud_p2p(&self) -> Result<Arc<CloudP2P>, Error> {
		self.cloud_p2p
			.read()
			.await
			.as_ref()
			.map_or(Err(Error::CloudP2PNotInitialized), |cloud_p2p| {
				Ok(Arc::clone(cloud_p2p))
			})
	}
}

#[cfg(test)]
mod tests {
	use sd_cloud_schema::{auth, devices};

	use super::*;

	#[ignore]
	#[tokio::test]
	async fn test_client() {
		let response = CloudServices::new(
			"http://localhost:9420/cloud-api-address",
			"http://relay.localhost:9999/",
			"http://pkarr.localhost:9999/",
			"dns.localhost:9999".to_string(),
			"localhost".to_string(),
		)
		.await
		.unwrap()
		.client()
		.await
		.unwrap()
		.devices()
		.list(devices::list::Request {
			access_token: auth::AccessToken("invalid".to_string()),
		})
		.await
		.unwrap();

		assert!(matches!(
			response,
			Err(sd_cloud_schema::Error::Client(
				sd_cloud_schema::error::ClientSideError::Unauthorized
			))
		));
	}
}
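One thing the test above makes visible: every quic-rpc call yields two nested Results, a transport-level one and the backend's application-level reply, which is why call sites elsewhere in this PR use `.await??`. The shape, reduced to plain std types (TransportError and BackendError are illustrative stand-ins):

#[derive(Debug)]
struct TransportError;

#[derive(Debug)]
enum BackendError {
	Unauthorized,
}

// What a quic-rpc call conceptually returns.
async fn list_devices() -> Result<Result<Vec<String>, BackendError>, TransportError> {
	Ok(Err(BackendError::Unauthorized)) // what an invalid token provokes
}

async fn demo() {
	match list_devices().await {
		Err(TransportError) => eprintln!("couldn't even reach the server"),
		Ok(Err(BackendError::Unauthorized)) => eprintln!("server rejected the token"),
		Ok(Ok(devices)) => println!("{devices:?}"),
	}
}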
response")] + MissingTokensOnRefreshResponse, + #[error("Failed to parse token header value to string: {0}")] + FailedToParseTokenHeaderValueToString(#[from] reqwest::header::ToStrError), + + // Key Manager errors + #[error("Failed to handle File on KeyManager: {0}")] + FileIO(#[from] FileIOError), + #[error("Failed to handle key store serialization: {0}")] + KeyStoreSerialization(rmp_serde::encode::Error), + #[error("Failed to handle key store deserialization: {0}")] + KeyStoreDeserialization(rmp_serde::decode::Error), + #[error("Key store encryption related error: {{context: \"{context}\", source: {source}}}")] + KeyStoreCrypto { + #[source] + source: sd_crypto::Error, + context: &'static str, + }, + #[error("Key manager not initialized")] + KeyManagerNotInitialized, + + // Cloud P2P errors + #[error("Failed to create Cloud P2P endpoint: {0}")] + CreateCloudP2PEndpoint(anyhow::Error), + #[error("Failed to connect to Cloud P2P node: {0}")] + ConnectToCloudP2PNode(anyhow::Error), + #[error("Communication error with Cloud P2P node: {0}")] + CloudP2PRpcCommunication(#[from] rpc::Error>), + #[error("Cloud P2P not initialized")] + CloudP2PNotInitialized, + #[error("Failed to initialize LocalSwarmDiscovery: {0}")] + LocalSwarmDiscoveryInit(anyhow::Error), + #[error("Failed to initialize DhtDiscovery: {0}")] + DhtDiscoveryInit(anyhow::Error), + + // Communication errors + #[error("Failed to communicate with RPC backend: {0}")] + RpcCommunication(#[from] rpc::Error>), + #[error("Failed to communicate with Server Streaming RPC backend: {0}")] + ServerStreamCommunication(#[from] server_streaming::Error>), + #[error("Failed to receive next response from Server Streaming RPC backend: {0}")] + ServerStreamRecv(#[from] server_streaming::ItemError>), + #[error("Failed to communicate with Bidi Streaming RPC backend: {0}")] + BidiStreamCommunication(#[from] bidi_streaming::Error>), + #[error("Failed to receive next response from Bidi Streaming RPC backend: {0}")] + BidiStreamRecv(#[from] bidi_streaming::ItemError>), + #[error("Error from backend: {0}")] + Backend(#[from] sd_cloud_schema::Error), + #[error("Failed to get access token from refresher: {0}")] + GetToken(#[from] GetTokenError), + #[error("Unexpected empty response from backend, context: {0}")] + EmptyResponse(&'static str), + #[error("Unexpected response from backend, context: {0}")] + UnexpectedResponse(&'static str), + + // Sync error + #[error("Sync error: {0}")] + Sync(#[from] sd_core_sync::Error), + #[error("Tried to sync messages with a group without having needed key")] + MissingSyncGroupKey(groups::PubId), + #[error("Failed to encrypt sync messages: {0}")] + Encrypt(sd_crypto::Error), + #[error("Failed to decrypt sync messages: {0}")] + Decrypt(sd_crypto::Error), + #[error("Failed to upload sync messages: {0}")] + UploadSyncMessages(reqwest_middleware::Error), + #[error("Failed to download sync messages: {0}")] + DownloadSyncMessages(reqwest_middleware::Error), + #[error("Received an error response from uploading sync messages: {0}")] + ErrorResponseUploadSyncMessages(reqwest::Error), + #[error("Received an error response from downloading sync messages: {0}")] + ErrorResponseDownloadSyncMessages(reqwest::Error), + #[error( + "Received an error response from downloading sync messages while reading its bytes: {0}" + )] + ErrorResponseDownloadReadBytesSyncMessages(reqwest::Error), + #[error("Critical error while uploading sync messages")] + CriticalErrorWhileUploadingSyncMessages, + #[error("Failed to send End update to push sync 
messages")] + EndUpdatePushSyncMessages(io::Error), + #[error("Unexpected end of stream while encrypting sync messages")] + UnexpectedEndOfStream, + #[error("Failed to create directory to store timestamp keeper files")] + FailedToCreateTimestampKeepersDirectory(io::Error), + #[error("Failed to read last timestamp keeper for pulling sync messages: {0}")] + FailedToReadLastTimestampKeeper(io::Error), + #[error("Failed to handle last timestamp keeper serialization: {0}")] + LastTimestampKeeperSerialization(rmp_serde::encode::Error), + #[error("Failed to handle last timestamp keeper deserialization: {0}")] + LastTimestampKeeperDeserialization(rmp_serde::decode::Error), + #[error("Failed to write last timestamp keeper for pulling sync messages: {0}")] + FailedToWriteLastTimestampKeeper(io::Error), + #[error("Sync messages download and decrypt task panicked")] + SyncMessagesDownloadAndDecryptTaskPanicked, + #[error("Serialization failure to push sync messages: {0}")] + SerializationFailureToPushSyncMessages(rmp_serde::encode::Error), + #[error("Deserialization failure to pull sync messages: {0}")] + DeserializationFailureToPullSyncMessages(rmp_serde::decode::Error), + #[error("Read nonce stream decryption: {0}")] + ReadNonceStreamDecryption(io::Error), + #[error("Incomplete download bytes sync messages")] + IncompleteDownloadBytesSyncMessages, + + // Temporary errors + #[error("Device missing secret key for decrypting sync messages")] + MissingKeyHash, +} + +#[derive(thiserror::Error, Debug)] +pub enum GetTokenError { + #[error("Token refresher not initialized")] + RefresherNotInitialized, + #[error("Token refresher failed to refresh and need to be initialized again")] + FailedToRefresh, +} + +impl From for rspc::Error { + fn from(e: Error) -> Self { + Self::with_cause(rspc::ErrorCode::InternalServerError, e.to_string(), e) + } +} + +impl From for rspc::Error { + fn from(e: GetTokenError) -> Self { + Self::with_cause(rspc::ErrorCode::InternalServerError, e.to_string(), e) + } +} diff --git a/core/crates/cloud-services/src/key_manager/key_store.rs b/core/crates/cloud-services/src/key_manager/key_store.rs new file mode 100644 index 000000000..acf97dad9 --- /dev/null +++ b/core/crates/cloud-services/src/key_manager/key_store.rs @@ -0,0 +1,331 @@ +use crate::Error; + +use sd_cloud_schema::{ + sync::{groups, KeyHash}, + NodeId, SecretKey as IrohSecretKey, +}; +use sd_crypto::{ + cloud::{decrypt, encrypt, secret_key::SecretKey}, + primitives::{EncryptedBlock, OneShotNonce, StreamNonce}, + CryptoRng, +}; +use sd_utils::error::FileIOError; +use tracing::debug; + +use std::{ + collections::{BTreeMap, VecDeque}, + fs::Metadata, + path::PathBuf, + pin::pin, +}; + +use futures::StreamExt; +use serde::{Deserialize, Serialize}; +use tokio::{ + fs, + io::{AsyncReadExt, AsyncWriteExt, BufWriter}, +}; +use zeroize::{Zeroize, ZeroizeOnDrop}; + +type KeyStack = VecDeque<(KeyHash, SecretKey)>; + +#[derive(Serialize, Deserialize)] +pub struct KeyStore { + iroh_secret_key: IrohSecretKey, + keys: BTreeMap, +} + +impl KeyStore { + pub const fn new(iroh_secret_key: IrohSecretKey) -> Self { + Self { + iroh_secret_key, + keys: BTreeMap::new(), + } + } + + pub fn add_key(&mut self, group_pub_id: groups::PubId, key: SecretKey) { + self.keys.entry(group_pub_id).or_default().push_front(( + KeyHash(blake3::hash(key.as_ref()).to_hex().to_string()), + key, + )); + } + + pub fn add_key_with_hash( + &mut self, + group_pub_id: groups::PubId, + key: SecretKey, + key_hash: KeyHash, + ) { + debug!( + key_hash = key_hash.0, + 
diff --git a/core/crates/cloud-services/src/key_manager/key_store.rs b/core/crates/cloud-services/src/key_manager/key_store.rs
new file mode 100644
index 000000000..acf97dad9
--- /dev/null
+++ b/core/crates/cloud-services/src/key_manager/key_store.rs
@@ -0,0 +1,331 @@
use crate::Error;

use sd_cloud_schema::{
	sync::{groups, KeyHash},
	NodeId, SecretKey as IrohSecretKey,
};
use sd_crypto::{
	cloud::{decrypt, encrypt, secret_key::SecretKey},
	primitives::{EncryptedBlock, OneShotNonce, StreamNonce},
	CryptoRng,
};
use sd_utils::error::FileIOError;
use tracing::debug;

use std::{
	collections::{BTreeMap, VecDeque},
	fs::Metadata,
	path::PathBuf,
	pin::pin,
};

use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::{
	fs,
	io::{AsyncReadExt, AsyncWriteExt, BufWriter},
};
use zeroize::{Zeroize, ZeroizeOnDrop};

type KeyStack = VecDeque<(KeyHash, SecretKey)>;

#[derive(Serialize, Deserialize)]
pub struct KeyStore {
	iroh_secret_key: IrohSecretKey,
	keys: BTreeMap<groups::PubId, KeyStack>,
}

impl KeyStore {
	pub const fn new(iroh_secret_key: IrohSecretKey) -> Self {
		Self {
			iroh_secret_key,
			keys: BTreeMap::new(),
		}
	}

	pub fn add_key(&mut self, group_pub_id: groups::PubId, key: SecretKey) {
		self.keys.entry(group_pub_id).or_default().push_front((
			KeyHash(blake3::hash(key.as_ref()).to_hex().to_string()),
			key,
		));
	}

	pub fn add_key_with_hash(
		&mut self,
		group_pub_id: groups::PubId,
		key: SecretKey,
		key_hash: KeyHash,
	) {
		debug!(
			key_hash = key_hash.0,
			?group_pub_id,
			"Added single cloud sync key to key manager"
		);

		self.keys
			.entry(group_pub_id)
			.or_default()
			.push_front((key_hash, key));
	}

	pub fn add_many_keys(
		&mut self,
		group_pub_id: groups::PubId,
		keys: impl IntoIterator<Item = SecretKey, IntoIter = impl DoubleEndedIterator<Item = SecretKey>>,
	) {
		let group_entry = self.keys.entry(group_pub_id).or_default();

		// We reverse the secret keys as an implementation detail to
		// keep the keys in the same order as they were added, as a stack
		for key in keys.into_iter().rev() {
			let key_hash = blake3::hash(key.as_ref()).to_hex().to_string();

			debug!(
				key_hash,
				?group_pub_id,
				"Added cloud sync key to key manager"
			);

			group_entry.push_front((KeyHash(key_hash), key));
		}
	}

	pub fn remove_group(&mut self, group_pub_id: groups::PubId) {
		self.keys.remove(&group_pub_id);
	}

	pub fn iroh_secret_key(&self) -> IrohSecretKey {
		self.iroh_secret_key.clone()
	}

	pub fn node_id(&self) -> NodeId {
		self.iroh_secret_key.public()
	}

	pub fn get_key(&self, group_pub_id: groups::PubId, hash: &KeyHash) -> Option<SecretKey> {
		self.keys.get(&group_pub_id).and_then(|group| {
			group
				.iter()
				.find_map(|(key_hash, key)| (key_hash == hash).then(|| key.clone()))
		})
	}

	pub fn get_latest_key(&self, group_pub_id: groups::PubId) -> Option<(KeyHash, SecretKey)> {
		self.keys
			.get(&group_pub_id)
			.and_then(|group| group.front().cloned())
	}

	pub fn get_group_keys(&self, group_pub_id: groups::PubId) -> Vec<SecretKey> {
		self.keys
			.get(&group_pub_id)
			.map(|group| group.iter().map(|(_key_hash, key)| key.clone()).collect())
			.unwrap_or_default()
	}
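The key hashes used for lookup above are plain BLAKE3 digests of the raw key bytes, rendered as lowercase hex. Standalone, with a local KeyHash stand-in for the schema type:

struct KeyHash(String);

fn key_hash(key_bytes: &[u8]) -> KeyHash {
	// Same derivation as add_key: hash the key material, keep the hex form.
	KeyHash(blake3::hash(key_bytes).to_hex().to_string())
}

fn main() {
	let KeyHash(hex) = key_hash(b"example key material");
	assert_eq!(hex.len(), 64); // 32-byte digest -> 64 hex characters
	println!("{hex}");
}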
	pub async fn encrypt(
		&self,
		key: &SecretKey,
		rng: &mut CryptoRng,
		keys_file_path: &PathBuf,
	) -> Result<(), Error> {
		let plain_text_bytes =
			rmp_serde::to_vec_named(self).map_err(Error::KeyStoreSerialization)?;

		let mut file = BufWriter::with_capacity(
			EncryptedBlock::CIPHER_TEXT_SIZE,
			fs::OpenOptions::new()
				.create(true)
				.write(true)
				.truncate(true)
				.open(&keys_file_path)
				.await
				.map_err(|e| {
					FileIOError::from((
						&keys_file_path,
						e,
						"Failed to open space keys file to encrypt",
					))
				})?,
		);

		if plain_text_bytes.len() < EncryptedBlock::PLAIN_TEXT_SIZE {
			use encrypt::OneShotEncryption;

			let EncryptedBlock { nonce, cipher_text } = key
				.encrypt(&plain_text_bytes, rng)
				.map_err(|e| Error::KeyStoreCrypto {
					source: e,
					context: "Failed to oneshot encrypt key store",
				})?;

			file.write_all(nonce.as_slice()).await.map_err(|e| {
				FileIOError::from((
					&keys_file_path,
					e,
					"Failed to write space keys file oneshot nonce",
				))
			})?;

			file.write_all(cipher_text.as_slice()).await.map_err(|e| {
				FileIOError::from((
					&keys_file_path,
					e,
					"Failed to write space keys file oneshot cipher text",
				))
			})?;
		} else {
			use encrypt::StreamEncryption;

			let (nonce, stream) = key.encrypt(plain_text_bytes.as_slice(), rng);

			file.write_all(nonce.as_slice()).await.map_err(|e| {
				FileIOError::from((
					&keys_file_path,
					e,
					"Failed to write space keys file stream nonce",
				))
			})?;

			let mut stream = pin!(stream);
			while let Some(res) = stream.next().await {
				file.write_all(&res.map_err(|e| Error::KeyStoreCrypto {
					source: e,
					context: "Failed to stream encrypt key store",
				})?)
				.await
				.map_err(|e| {
					FileIOError::from((
						&keys_file_path,
						e,
						"Failed to write space keys file stream cipher text",
					))
				})?;
			}
		};

		file.flush().await.map_err(|e| {
			FileIOError::from((&keys_file_path, e, "Failed to flush space keys file")).into()
		})
	}

	pub async fn decrypt(
		key: &SecretKey,
		metadata: Metadata,
		keys_file_path: &PathBuf,
	) -> Result<Self, Error> {
		let mut file = fs::File::open(&keys_file_path).await.map_err(|e| {
			FileIOError::from((
				keys_file_path,
				e,
				"Failed to open space keys file to decrypt",
			))
		})?;

		let usize_file_len =
			usize::try_from(metadata.len()).expect("Failed to convert metadata length to usize");

		let key_store_bytes =
			if usize_file_len <= EncryptedBlock::CIPHER_TEXT_SIZE + size_of::<OneShotNonce>() {
				use decrypt::OneShotDecryption;

				let mut nonce = OneShotNonce::default();

				file.read_exact(&mut nonce).await.map_err(|e| {
					FileIOError::from((
						keys_file_path,
						e,
						"Failed to read space keys file oneshot nonce",
					))
				})?;

				let mut cipher_text = vec![0u8; usize_file_len - size_of::<OneShotNonce>()];

				file.read_exact(&mut cipher_text).await.map_err(|e| {
					FileIOError::from((
						keys_file_path,
						e,
						"Failed to read space keys file oneshot cipher text",
					))
				})?;

				key.decrypt_owned(&EncryptedBlock { nonce, cipher_text })
					.map_err(|e| Error::KeyStoreCrypto {
						source: e,
						context: "Failed to oneshot decrypt space keys file",
					})?
			} else {
				use decrypt::StreamDecryption;

				let mut nonce = StreamNonce::default();

				let mut key_store_bytes = Vec::with_capacity(
					(usize_file_len - size_of::<StreamNonce>()) / EncryptedBlock::CIPHER_TEXT_SIZE
						* EncryptedBlock::PLAIN_TEXT_SIZE,
				);

				file.read_exact(&mut nonce).await.map_err(|e| {
					FileIOError::from((
						keys_file_path,
						e,
						"Failed to read space keys file stream nonce",
					))
				})?;

				key.decrypt(&nonce, &mut file, &mut key_store_bytes)
					.await
					.map_err(|e| Error::KeyStoreCrypto {
						source: e,
						context: "Failed to stream decrypt space keys file",
					})?;

				key_store_bytes
			};

		let this = rmp_serde::from_slice::<Self>(&key_store_bytes)
			.map_err(Error::KeyStoreDeserialization)?;

		#[cfg(debug_assertions)]
		{
			use std::fmt::Write;

			let mut key_hashes_log = String::new();

			this.keys.iter().for_each(|(group_pub_id, key_stack)| {
				writeln!(
					key_hashes_log,
					"Group: {group_pub_id:?} => KeyHashes: {:?}",
					key_stack
						.iter()
						.map(|(KeyHash(key_hash), _)| key_hash)
						.collect::<Vec<_>>()
				)
				.expect("Failed to write to key hashes log");
			});

			tracing::info!("Loaded key hashes: {key_hashes_log}");
		}

		Ok(this)
	}
}

/// Zeroizes our secret keys and scrambles iroh's secret key, which doesn't implement zeroize
impl Zeroize for KeyStore {
	fn zeroize(&mut self) {
		self.iroh_secret_key = IrohSecretKey::generate();
		self.keys.values_mut().for_each(|group| {
			group
				.iter_mut()
				.map(|(_key_hash, key)| key)
				.for_each(Zeroize::zeroize);
		});
		self.keys = BTreeMap::new();
	}
}

impl Drop for KeyStore {
	fn drop(&mut self) {
		self.zeroize();
	}
}

impl ZeroizeOnDrop for KeyStore {}
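encrypt() and decrypt() above agree on the file layout implicitly: small plaintexts are written as a single one-shot block, larger ones as a nonce followed by a sequence of stream blocks, and the decrypt side recovers which case it is looking at purely from the file length. A sketch of that decision, with illustrative constants in place of sd_crypto's real sizes:

// Assumed sizes, for illustration only; the real values come from
// sd_crypto's EncryptedBlock and nonce types.
const PLAIN_TEXT_SIZE: usize = 4096;
const CIPHER_TEXT_SIZE: usize = PLAIN_TEXT_SIZE + 16; // block + auth tag
const ONESHOT_NONCE_SIZE: usize = 24;

#[derive(Debug, PartialEq)]
enum Layout {
	OneShot, // [oneshot nonce][single cipher block]
	Stream,  // [stream nonce][cipher block]...[cipher block]
}

// The writer picks the layout from the plaintext length...
fn layout_for_plaintext(len: usize) -> Layout {
	if len < PLAIN_TEXT_SIZE {
		Layout::OneShot
	} else {
		Layout::Stream
	}
}

// ...and the reader recovers the same decision from the file length,
// so no format flag ever needs to be stored.
fn layout_for_file(file_len: usize) -> Layout {
	if file_len <= CIPHER_TEXT_SIZE + ONESHOT_NONCE_SIZE {
		Layout::OneShot
	} else {
		Layout::Stream
	}
}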
diff --git a/core/crates/cloud-services/src/key_manager/mod.rs b/core/crates/cloud-services/src/key_manager/mod.rs
new file mode 100644
index 000000000..64007a190
--- /dev/null
+++ b/core/crates/cloud-services/src/key_manager/mod.rs
@@ -0,0 +1,183 @@
use crate::Error;

use sd_cloud_schema::{
	sync::{groups, KeyHash},
	NodeId, SecretKey as IrohSecretKey,
};
use sd_crypto::{cloud::secret_key::SecretKey, CryptoRng};
use sd_utils::error::FileIOError;

use std::{
	fmt,
	path::{Path, PathBuf},
};

use tokio::{fs, sync::RwLock};

mod key_store;

use key_store::KeyStore;

const KEY_FILE_NAME: &str = "space.keys";

pub struct KeyManager {
	master_key: SecretKey,
	keys_file_path: PathBuf,
	store: RwLock<KeyStore>,
}

impl KeyManager {
	pub async fn new(
		master_key: SecretKey,
		iroh_secret_key: IrohSecretKey,
		data_directory: impl AsRef<Path> + Send,
		rng: &mut CryptoRng,
	) -> Result<Self, Error> {
		async fn inner(
			master_key: SecretKey,
			iroh_secret_key: IrohSecretKey,
			keys_file_path: PathBuf,
			rng: &mut CryptoRng,
		) -> Result<KeyManager, Error> {
			let store = KeyStore::new(iroh_secret_key);
			store.encrypt(&master_key, rng, &keys_file_path).await?;

			Ok(KeyManager {
				master_key,
				keys_file_path,
				store: RwLock::new(store),
			})
		}

		inner(
			master_key,
			iroh_secret_key,
			data_directory.as_ref().join(KEY_FILE_NAME),
			rng,
		)
		.await
	}

	pub async fn load(
		master_key: SecretKey,
		data_directory: impl AsRef<Path> + Send,
	) -> Result<Self, Error> {
		async fn inner(
			master_key: SecretKey,
			keys_file_path: PathBuf,
		) -> Result<KeyManager, Error> {
			Ok(KeyManager {
				store: RwLock::new(
					KeyStore::decrypt(
						&master_key,
						fs::metadata(&keys_file_path).await.map_err(|e| {
							FileIOError::from((
								&keys_file_path,
								e,
								"Failed to read space keys file",
							))
						})?,
						&keys_file_path,
					)
					.await?,
				),
				master_key,
				keys_file_path,
			})
		}

		inner(master_key, data_directory.as_ref().join(KEY_FILE_NAME)).await
	}

	pub async fn iroh_secret_key(&self) -> IrohSecretKey {
		self.store.read().await.iroh_secret_key()
	}

	pub async fn node_id(&self) -> NodeId {
		self.store.read().await.node_id()
	}

	pub async fn add_key(
		&self,
		group_pub_id: groups::PubId,
		key: SecretKey,
		rng: &mut CryptoRng,
	) -> Result<(), Error> {
		let mut store = self.store.write().await;
		store.add_key(group_pub_id, key);
		// Keep holding the write lock while persisting, so we can't corrupt the file
		store
			.encrypt(&self.master_key, rng, &self.keys_file_path)
			.await
	}

	pub async fn add_key_with_hash(
		&self,
		group_pub_id: groups::PubId,
		key: SecretKey,
		key_hash: KeyHash,
		rng: &mut CryptoRng,
	) -> Result<(), Error> {
		let mut store = self.store.write().await;
		store.add_key_with_hash(group_pub_id, key, key_hash);
		// Keep holding the write lock while persisting, so we can't corrupt the file
		store
			.encrypt(&self.master_key, rng, &self.keys_file_path)
			.await
	}

	pub async fn remove_group(
		&self,
		group_pub_id: groups::PubId,
		rng: &mut CryptoRng,
	) -> Result<(), Error> {
		let mut store = self.store.write().await;
		store.remove_group(group_pub_id);
		// Keep holding the write lock while persisting, so we can't corrupt the file
		store
			.encrypt(&self.master_key, rng, &self.keys_file_path)
			.await
	}

	pub async fn add_many_keys(
		&self,
		group_pub_id: groups::PubId,
		keys: impl IntoIterator<
			Item = SecretKey,
			IntoIter = impl DoubleEndedIterator<Item = SecretKey> + Send,
		> + Send,
		rng: &mut CryptoRng,
	) -> Result<(), Error> {
		let mut store = self.store.write().await;
		store.add_many_keys(group_pub_id, keys);
		// Keep holding the write lock while persisting, so we can't corrupt the file
		store
			.encrypt(&self.master_key, rng, &self.keys_file_path)
			.await
	}

	pub async fn get_latest_key(
		&self,
		group_pub_id: groups::PubId,
	) -> Option<(KeyHash, SecretKey)> {
		self.store.read().await.get_latest_key(group_pub_id)
	}

	pub async fn get_key(&self, group_pub_id: groups::PubId, hash: &KeyHash) -> Option<SecretKey> {
		self.store.read().await.get_key(group_pub_id, hash)
	}

	pub async fn get_group_keys(&self, group_pub_id: groups::PubId) -> Vec<SecretKey> {
		self.store.read().await.get_group_keys(group_pub_id)
	}
}

impl fmt::Debug for KeyManager {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		f.debug_struct("KeyManager")
			.field("master_key", &"[REDACTED]")
			.field("keys_file_path", &self.keys_file_path)
			.field("store", &"[REDACTED]")
			.finish()
	}
}

diff --git a/core/crates/cloud-services/src/lib.rs b/core/crates/cloud-services/src/lib.rs
new file mode 100644
index 000000000..615d5397d
--- /dev/null
+++ b/core/crates/cloud-services/src/lib.rs
@@ -0,0 +1,55 @@
#![recursion_limit = "256"]
#![warn(
	clippy::all,
	clippy::pedantic,
	clippy::correctness,
	clippy::perf,
	clippy::style,
	clippy::suspicious,
	clippy::complexity,
	clippy::nursery,
	clippy::unwrap_used,
	unused_qualifications,
	rust_2018_idioms,
	trivial_casts,
	trivial_numeric_casts,
	unused_allocation,
	clippy::unnecessary_cast,
	clippy::cast_lossless,
	clippy::cast_possible_truncation,
	clippy::cast_possible_wrap,
	clippy::cast_precision_loss,
	clippy::cast_sign_loss,
	clippy::dbg_macro,
	clippy::deprecated_cfg_attr,
	clippy::separated_literal_suffix,
	deprecated
)]
#![forbid(deprecated_in_future)]
#![allow(clippy::missing_errors_doc, clippy::module_name_repetitions)]

mod error;

mod client;
mod key_manager;
mod p2p;
mod sync;
mod token_refresher;

pub use client::CloudServices;
pub use error::{Error, GetTokenError};
pub use key_manager::KeyManager;
pub use p2p::{
	CloudP2P, JoinSyncGroupResponse, JoinedLibraryCreateArgs, NotifyUser, Ticket, UserResponse,
};
pub use sync::{
	declare_actors as declare_cloud_sync, SyncActors as CloudSyncActors,
	SyncActorsState as CloudSyncActorsState,
};

// Re-exports
pub use quic_rpc::transport::quinn::QuinnConnection;

// Export URL for the auth server
pub const AUTH_SERVER_URL: &str = "https://auth.spacedrive.com";
// pub const AUTH_SERVER_URL: &str = "http://localhost:9420";

diff --git a/core/crates/cloud-services/src/p2p/mod.rs b/core/crates/cloud-services/src/p2p/mod.rs
new file mode 100644
index 000000000..0f31f977c
--- /dev/null
+++ b/core/crates/cloud-services/src/p2p/mod.rs
@@ -0,0 +1,242 @@
use crate::{sync::ReceiveAndIngestNotifiers, CloudServices, Error};

use sd_cloud_schema::{
	cloud_p2p::{authorize_new_device_in_sync_group, CloudP2PALPN, CloudP2PError},
	devices::{self, Device},
	libraries,
	sync::groups::{self, GroupWithDevices},
	SecretKey as IrohSecretKey,
};
use sd_crypto::{CryptoRng, SeedableRng};

use std::{sync::Arc, time::Duration};

use iroh_net::{
	discovery::{
		dns::DnsDiscovery, local_swarm_discovery::LocalSwarmDiscovery, pkarr::dht::DhtDiscovery,
		ConcurrentDiscovery, Discovery,
	},
	relay::{RelayMap, RelayMode, RelayUrl},
	Endpoint, NodeId,
};
use reqwest::Url;
use serde::{Deserialize, Serialize};
use tokio::{spawn, sync::oneshot, time::sleep};
use tracing::{debug, error, warn};

mod new_sync_messages_notifier;
mod runner;

use runner::Runner;

#[derive(Debug)]
pub struct JoinedLibraryCreateArgs {
	pub pub_id: libraries::PubId,
	pub name: String,
	pub description: Option<String>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, specta::Type)]
#[serde(transparent)]
#[repr(transparent)]
#[specta(rename = "CloudP2PTicket")]
pub struct Ticket(u64);

#[derive(Debug, Serialize, specta::Type)]
#[serde(tag = "kind", content = "data")]
#[specta(rename = "CloudP2PNotifyUser")]
pub enum NotifyUser {
	ReceivedJoinSyncGroupRequest {
		ticket: Ticket,
		asking_device: Device,
		sync_group: GroupWithDevices,
	},
	ReceivedJoinSyncGroupResponse {
		response: JoinSyncGroupResponse,
		sync_group: GroupWithDevices,
	},
	SendingJoinSyncGroupResponseError {
		error: JoinSyncGroupError,
		sync_group: GroupWithDevices,
	},
	TimedOutJoinRequest {
		device: Device,
		succeeded: bool,
	},
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, specta::Type)]
pub enum JoinSyncGroupError {
	Communication,
	InternalServer,
	Auth,
}

#[derive(Debug, Serialize, specta::Type)]
pub enum JoinSyncGroupResponse {
	Accepted { authorizor_device: Device },
	Failed(CloudP2PError),
	CriticalError,
}

#[derive(Debug, Clone, Serialize, Deserialize, specta::Type)]
pub struct BasicLibraryCreationArgs {
	pub id: libraries::PubId,
	pub name: String,
	pub description: Option<String>,
}

#[derive(Debug, Deserialize, specta::Type)]
#[serde(tag = "kind", content = "data")]
#[specta(rename = "CloudP2PUserResponse")]
pub enum UserResponse {
	AcceptDeviceInSyncGroup {
		ticket: Ticket,
		accepted: Option<BasicLibraryCreationArgs>,
	},
}

#[derive(Debug, Clone)]
pub struct CloudP2P {
	msgs_tx: flume::Sender<runner::Message>,
}

impl CloudP2P {
	pub async fn new(
		current_device_pub_id: devices::PubId,
		cloud_services: &CloudServices,
		mut rng: CryptoRng,
		iroh_secret_key: IrohSecretKey,
		dns_origin_domain: String,
		dns_pkarr_url: Url,
		relay_url: RelayUrl,
	) -> Result<Self, Error> {
		let dht_discovery = DhtDiscovery::builder()
			.secret_key(iroh_secret_key.clone())
			.pkarr_relay(dns_pkarr_url)
			.build()
			.map_err(Error::DhtDiscoveryInit)?;

		let endpoint = Endpoint::builder()
			.alpns(vec![CloudP2PALPN::LATEST.to_vec()])
			.discovery(Box::new(ConcurrentDiscovery::from_services(vec![
				Box::new(DnsDiscovery::new(dns_origin_domain)),
				Box::new(
					LocalSwarmDiscovery::new(iroh_secret_key.public())
						.map_err(Error::LocalSwarmDiscoveryInit)?,
				),
				Box::new(dht_discovery.clone()),
			])))
			.secret_key(iroh_secret_key)
			.relay_mode(RelayMode::Custom(RelayMap::from_url(relay_url)))
			.bind()
			.await
			.map_err(Error::CreateCloudP2PEndpoint)?;

		spawn({
			let endpoint = endpoint.clone();
			async move {
				loop {
					let Ok(node_addr) = endpoint.node_addr().await.map_err(|e| {
						warn!(?e, "Failed to get direct addresses to force publish on DHT");
					}) else {
						sleep(Duration::from_secs(5)).await;
						continue;
					};

					debug!("Force publishing peer on DHT");
					return dht_discovery.publish(&node_addr.info);
				}
			}
		});

		let (msgs_tx, msgs_rx) = flume::bounded(16);

		spawn({
			let runner = Runner::new(
				current_device_pub_id,
				cloud_services,
				msgs_tx.clone(),
				endpoint,
			)
			.await?;
			let user_response_rx = cloud_services.user_response_rx.clone();

			async move {
				// All cloned runners share a single state with internal mutability
				while let Err(e) = spawn(runner.clone().run(
					msgs_rx.clone(),
					user_response_rx.clone(),
					CryptoRng::from_seed(rng.generate_fixed()),
				))
				.await
				{
					if e.is_panic() {
						error!("Cloud P2P runner panicked");
					} else {
						break;
					}
				}
			}
		});

		Ok(Self { msgs_tx })
	}
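CloudP2P itself is a thin handle: the state lives in the spawned Runner, the handle only owns a channel sender, and (as the Drop impl further down shows) dropping the handle pushes a Stop message so the loop winds down. The same shape in miniature, with Msg standing in for runner::Message:

use tokio::task;

enum Msg {
	Work(u32),
	Stop,
}

#[derive(Clone)]
struct Handle {
	tx: flume::Sender<Msg>,
}

impl Handle {
	fn spawn() -> Self {
		let (tx, rx) = flume::bounded::<Msg>(16);
		task::spawn(async move {
			// The task owns all the state; the handle can only send messages.
			while let Ok(msg) = rx.recv_async().await {
				match msg {
					Msg::Work(n) => println!("working on {n}"),
					Msg::Stop => break,
				}
			}
		});
		Self { tx }
	}
}

impl Drop for Handle {
	fn drop(&mut self) {
		// Best effort: the loop may already be gone, hence the ignored error.
		self.tx.send(Msg::Stop).ok();
	}
}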
	/// Asks a device in the sync group (the first one we can connect to) for permission for the
	/// current device to join the group
	///
	/// # Panics
	/// Will panic if the actor channel is closed, which should never happen
	pub async fn request_join_sync_group(
		&self,
		devices_in_group: Vec<(devices::PubId, NodeId)>,
		req: authorize_new_device_in_sync_group::Request,
		tx: oneshot::Sender<JoinedLibraryCreateArgs>,
	) {
		self.msgs_tx
			.send_async(runner::Message::Request(runner::Request::JoinSyncGroup {
				req,
				devices_in_group,
				tx,
			}))
			.await
			.expect("Channel closed");
	}

	/// Register a notifier for the desired sync group, which will notify the receiver actor when
	/// new sync messages arrive through cloud p2p notification requests.
	///
	/// # Panics
	/// Will panic if the actor channel is closed, which should never happen
	pub async fn register_sync_messages_receiver_notifier(
		&self,
		sync_group_pub_id: groups::PubId,
		notifier: Arc<ReceiveAndIngestNotifiers>,
	) {
		self.msgs_tx
			.send_async(runner::Message::RegisterSyncMessageNotifier((
				sync_group_pub_id,
				notifier,
			)))
			.await
			.expect("Channel closed");
	}

	/// Emit a notification that new sync messages were sent to the cloud, so other devices should
	/// pull them as soon as possible.
	///
	/// # Panics
	/// Will panic if the actor channel is closed, which should never happen
	pub async fn notify_new_sync_messages(&self, group_pub_id: groups::PubId) {
		self.msgs_tx
			.send_async(runner::Message::NotifyPeersSyncMessages(group_pub_id))
			.await
			.expect("Channel closed");
	}
}

impl Drop for CloudP2P {
	fn drop(&mut self) {
		self.msgs_tx.send(runner::Message::Stop).ok();
	}
}

diff --git a/core/crates/cloud-services/src/p2p/new_sync_messages_notifier.rs b/core/crates/cloud-services/src/p2p/new_sync_messages_notifier.rs
new file mode 100644
index 000000000..f4d0a3751
--- /dev/null
+++ b/core/crates/cloud-services/src/p2p/new_sync_messages_notifier.rs
@@ -0,0 +1,156 @@
use crate::{token_refresher::TokenRefresher, Error};

use sd_cloud_schema::{
	cloud_p2p::{Client, CloudP2PALPN, Service},
	devices,
	sync::groups,
};

use std::time::Duration;

use futures_concurrency::future::Join;
use iroh_net::{Endpoint, NodeId};
use quic_rpc::{transport::quinn::QuinnConnection, RpcClient};
use tokio::time::Instant;
use tracing::{debug, error, instrument, warn};

use super::runner::Message;

const CACHED_MAX_DURATION: Duration = Duration::from_secs(60 * 5);

pub async fn dispatch_notifier(
	group_pub_id: groups::PubId,
	device_pub_id: devices::PubId,
	devices: Option<(Instant, Vec<(devices::PubId, NodeId)>)>,
	msgs_tx: flume::Sender<Message>,
	cloud_services: sd_cloud_schema::Client<
		QuinnConnection<sd_cloud_schema::Service>,
		sd_cloud_schema::Service,
	>,
	token_refresher: TokenRefresher,
	endpoint: Endpoint,
) {
	match notify_peers(
		group_pub_id,
		device_pub_id,
		devices,
		cloud_services,
		token_refresher,
		endpoint,
	)
	.await
	{
		Ok((true, devices)) => {
			if msgs_tx
				.send_async(Message::UpdateCachedDevices((group_pub_id, devices)))
				.await
				.is_err()
			{
				warn!("Failed to send update cached devices message to update cached devices");
			}
		}

		Ok((false, _)) => {}

		Err(e) => {
			error!(?e, "Failed to notify peers");
		}
	}
}
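dispatch_notifier threads an optional cached device list through and reports back whether the cache should be refreshed; the actual staleness rule lives in notify_peers below. The decision in isolation (u64 stands in for the device/node pairs):

use std::time::Duration;
use tokio::time::Instant;

const CACHED_MAX_DURATION: Duration = Duration::from_secs(60 * 5);

// Returns the devices to notify plus whether the caller should overwrite
// its cache, mirroring notify_peers' (bool, Vec<...>) return value.
fn resolve_devices(
	cached: Option<(Instant, Vec<u64>)>,
	fetch: impl FnOnce() -> Vec<u64>,
) -> (bool, Vec<u64>) {
	match cached {
		// Fresh enough: reuse, and tell the caller not to update the cache.
		Some((when, devices)) if when.elapsed() < CACHED_MAX_DURATION => (false, devices),
		// Missing or stale: refetch, and ask the caller to cache the result.
		_ => (true, fetch()),
	}
}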
#[instrument(skip(cloud_services, token_refresher, endpoint))]
async fn notify_peers(
	group_pub_id: groups::PubId,
	device_pub_id: devices::PubId,
	devices: Option<(Instant, Vec<(devices::PubId, NodeId)>)>,
	cloud_services: sd_cloud_schema::Client<
		QuinnConnection<sd_cloud_schema::Service>,
		sd_cloud_schema::Service,
	>,
	token_refresher: TokenRefresher,
	endpoint: Endpoint,
) -> Result<(bool, Vec<(devices::PubId, NodeId)>), Error> {
	let (devices, update_cache) = match devices {
		Some((when, devices)) if when.elapsed() < CACHED_MAX_DURATION => (devices, false),
		_ => {
			debug!("Fetching devices connection ids for group");
			let groups::get::Response(groups::get::ResponseKind::DevicesConnectionIds(devices)) =
				cloud_services
					.sync()
					.groups()
					.get(groups::get::Request {
						access_token: token_refresher.get_access_token().await?,
						pub_id: group_pub_id,
						kind: groups::get::RequestKind::DevicesConnectionIds,
					})
					.await??
			else {
				unreachable!("Only DevicesConnectionIds response is expected, as we requested it");
			};

			(devices, true)
		}
	};

	send_notifications(group_pub_id, device_pub_id, &devices, &endpoint).await;

	Ok((update_cache, devices))
}

async fn send_notifications(
	group_pub_id: groups::PubId,
	device_pub_id: devices::PubId,
	devices: &[(devices::PubId, NodeId)],
	endpoint: &Endpoint,
) {
	devices
		.iter()
		.filter(|(peer_device_pub_id, _)| *peer_device_pub_id != device_pub_id)
		.map(|(peer_device_pub_id, connection_id)| async move {
			if let Err(e) =
				connect_and_send_notification(group_pub_id, device_pub_id, connection_id, endpoint)
					.await
			{
				// Using just a debug log here because we don't want to spam the logs with
				// every single notification failure, as this is more a nice-to-have feature than a
				// critical one
				debug!(?e, %peer_device_pub_id, "Failed to send new sync messages notification to peer");
			} else {
				debug!(%peer_device_pub_id, "Sent new sync messages notification to peer");
			}
		})
		.collect::<Vec<_>>()
		.join()
		.await;
}

async fn connect_and_send_notification(
	group_pub_id: groups::PubId,
	device_pub_id: devices::PubId,
	connection_id: &NodeId,
	endpoint: &Endpoint,
) -> Result<(), Error> {
	let client = Client::new(RpcClient::new(QuinnConnection::<Service>::from_connection(
		endpoint
			.connect(*connection_id, CloudP2PALPN::LATEST)
			.await
			.map_err(Error::ConnectToCloudP2PNode)?,
	)));

	if let Err(e) = client
		.notify_new_sync_messages(
			sd_cloud_schema::cloud_p2p::notify_new_sync_messages::Request {
				sync_group_pub_id: group_pub_id,
				device_pub_id,
			},
		)
		.await?
	{
		warn!(
			?e,
			"This route shouldn't return an error, it's just a notification",
		);
	};

	Ok(())
}
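send_notifications above fans one future out per peer and drives them all concurrently with futures-concurrency's Join, logging failures per peer instead of aborting the batch. The combinator on its own:

use futures_concurrency::future::Join;

async fn notify_all(peers: &[u64]) {
	peers
		.iter()
		.map(|peer| async move {
			// Each peer gets its own future; errors would be handled (and
			// only logged) in here, so one bad peer can't cancel the rest.
			println!("notifying peer {peer}");
		})
		.collect::<Vec<_>>()
		.join()
		.await;
}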
diff --git a/core/crates/cloud-services/src/p2p/runner.rs b/core/crates/cloud-services/src/p2p/runner.rs
new file mode 100644
index 000000000..3dfc33be2
--- /dev/null
+++ b/core/crates/cloud-services/src/p2p/runner.rs
@@ -0,0 +1,651 @@
use crate::{
	p2p::JoinSyncGroupError, sync::ReceiveAndIngestNotifiers, token_refresher::TokenRefresher,
	CloudServices, Error, KeyManager,
};

use sd_cloud_schema::{
	cloud_p2p::{
		self, authorize_new_device_in_sync_group, notify_new_sync_messages, Client, CloudP2PALPN,
		CloudP2PError, Service,
	},
	devices::{self, Device},
	sync::groups,
};
use sd_crypto::{CryptoRng, SeedableRng};

use std::{
	collections::HashMap,
	pin::pin,
	sync::{
		atomic::{AtomicU64, Ordering},
		Arc,
	},
	time::Duration,
};

use dashmap::DashMap;
use flume::SendError;
use futures::StreamExt;
use futures_concurrency::stream::Merge;
use iroh_net::{Endpoint, NodeId};
use quic_rpc::{
	server::{Accepting, RpcChannel, RpcServerError},
	transport::quinn::{QuinnConnection, QuinnServerEndpoint},
	RpcClient, RpcServer,
};
use tokio::{
	spawn,
	sync::{oneshot, Mutex},
	task::JoinHandle,
	time::{interval, Instant, MissedTickBehavior},
};
use tokio_stream::wrappers::IntervalStream;
use tracing::{debug, error, warn};

use super::{
	new_sync_messages_notifier::dispatch_notifier, BasicLibraryCreationArgs, JoinSyncGroupResponse,
	JoinedLibraryCreateArgs, NotifyUser, Ticket, UserResponse,
};

const TEN_SECONDS: Duration = Duration::from_secs(10);
const FIVE_MINUTES: Duration = Duration::from_secs(60 * 5);

#[allow(clippy::large_enum_variant)] // Ignoring because the Stop variant will only ever happen a single time
pub enum Message {
	Request(Request),
	RegisterSyncMessageNotifier((groups::PubId, Arc<ReceiveAndIngestNotifiers>)),
	NotifyPeersSyncMessages(groups::PubId),
	UpdateCachedDevices((groups::PubId, Vec<(devices::PubId, NodeId)>)),
	Stop,
}

pub enum Request {
	JoinSyncGroup {
		req: authorize_new_device_in_sync_group::Request,
		devices_in_group: Vec<(devices::PubId, NodeId)>,
		tx: oneshot::Sender<JoinedLibraryCreateArgs>,
	},
}

/// We use internal mutability here, but don't worry: there will always be a single
/// [`Runner`] running at a time, so the lock is never contended
pub struct Runner {
	current_device_pub_id: devices::PubId,
	token_refresher: TokenRefresher,
	cloud_services: sd_cloud_schema::Client<
		QuinnConnection<sd_cloud_schema::Service>,
		sd_cloud_schema::Service,
	>,
	msgs_tx: flume::Sender<Message>,
	endpoint: Endpoint,
	key_manager: Arc<KeyManager>,
	ticketer: Arc<AtomicU64>,
	notify_user_tx: flume::Sender<NotifyUser>,
	sync_messages_receiver_notifiers_map:
		Arc<DashMap<groups::PubId, Arc<ReceiveAndIngestNotifiers>>>,
	pending_sync_group_join_requests: Arc<Mutex<HashMap<Ticket, PendingSyncGroupJoin>>>,
	cached_devices_per_group: HashMap<groups::PubId, (Instant, Vec<(devices::PubId, NodeId)>)>,
	timeout_checker_buffer: Vec<(Ticket, PendingSyncGroupJoin)>,
}
impl Clone for Runner {
	fn clone(&self) -> Self {
		Self {
			current_device_pub_id: self.current_device_pub_id,
			token_refresher: self.token_refresher.clone(),
			cloud_services: self.cloud_services.clone(),
			msgs_tx: self.msgs_tx.clone(),
			endpoint: self.endpoint.clone(),
			key_manager: Arc::clone(&self.key_manager),
			ticketer: Arc::clone(&self.ticketer),
			notify_user_tx: self.notify_user_tx.clone(),
			sync_messages_receiver_notifiers_map: Arc::clone(
				&self.sync_messages_receiver_notifiers_map,
			),
			pending_sync_group_join_requests: Arc::clone(&self.pending_sync_group_join_requests),
			// Just a per-clone cache of the devices and their node_ids per group
			cached_devices_per_group: HashMap::new(),
			// This one is a temporary buffer only used by the timeout checker
			timeout_checker_buffer: vec![],
		}
	}
}

struct PendingSyncGroupJoin {
	channel: RpcChannel<Service, QuinnServerEndpoint<Service>>,
	request: authorize_new_device_in_sync_group::Request,
	this_device: Device,
	since: Instant,
}

impl Runner {
	pub async fn new(
		current_device_pub_id: devices::PubId,
		cloud_services: &CloudServices,
		msgs_tx: flume::Sender<Message>,
		endpoint: Endpoint,
	) -> Result<Self, Error> {
		Ok(Self {
			current_device_pub_id,
			token_refresher: cloud_services.token_refresher.clone(),
			cloud_services: cloud_services.client().await?,
			msgs_tx,
			endpoint,
			key_manager: cloud_services.key_manager().await?,
			ticketer: Arc::default(),
			notify_user_tx: cloud_services.notify_user_tx.clone(),
			sync_messages_receiver_notifiers_map: Arc::default(),
			pending_sync_group_join_requests: Arc::default(),
			cached_devices_per_group: HashMap::new(),
			timeout_checker_buffer: vec![],
		})
	}

	pub async fn run(
		mut self,
		msgs_rx: flume::Receiver<Message>,
		user_response_rx: flume::Receiver<UserResponse>,
		mut rng: CryptoRng,
	) {
		// Ignoring because this is only used internally, and boxing would probably be more
		// expensive than wasting some extra bytes on the smaller variants
		#[allow(clippy::large_enum_variant)]
		enum StreamMessage {
			AcceptResult(
				Result<
					Accepting<Service, QuinnServerEndpoint<Service>>,
					RpcServerError<QuinnServerEndpoint<Service>>,
				>,
			),
			Message(Message),
			UserResponse(UserResponse),
			Tick,
		}

		let mut ticker = interval(TEN_SECONDS);
		ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);

		// FIXME(@fogodev): Update this function to use iroh-net transport instead of quinn
		// when it's implemented
		let (server, server_handle) = setup_server_endpoint(self.endpoint.clone());

		let mut msg_stream = pin!((
			async_stream::stream! {
				loop {
					yield StreamMessage::AcceptResult(server.accept().await);
				}
			},
			msgs_rx.stream().map(StreamMessage::Message),
			user_response_rx.stream().map(StreamMessage::UserResponse),
			IntervalStream::new(ticker).map(|_| StreamMessage::Tick),
		)
			.merge());
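msg_stream above multiplexes four inputs, accepted connections, actor messages, user responses, and a ticker, into a single stream with futures-concurrency's Merge; the while-let that consumes it follows. A reduced version of the same setup with two inputs:

use std::{pin::pin, time::Duration};

use futures::StreamExt;
use futures_concurrency::stream::Merge;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;

enum Event {
	Msg(u32),
	Tick,
}

async fn event_loop(rx: flume::Receiver<u32>) {
	// One pinned stream carries every input the actor reacts to, so a plain
	// `while let` replaces a hand-rolled select! ladder.
	let mut events = pin!((
		rx.stream().map(Event::Msg),
		IntervalStream::new(interval(Duration::from_secs(10))).map(|_| Event::Tick),
	)
		.merge());

	while let Some(event) = events.next().await {
		match event {
			Event::Msg(n) => println!("message {n}"),
			Event::Tick => println!("periodic upkeep"),
		}
	}
}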
		while let Some(msg) = msg_stream.next().await {
			match msg {
				StreamMessage::AcceptResult(Ok(accepting)) => {
					let Ok((request, channel)) = accepting.read_first().await.map_err(|e| {
						error!(?e, "Failed to read first request from a new connection;");
					}) else {
						continue;
					};

					self.handle_request(request, channel).await;
				}

				StreamMessage::AcceptResult(Err(e)) => {
					// TODO(@fogodev): Maybe report this error to the user in a toast?
					error!(?e, "Error accepting connection;");
				}

				StreamMessage::Message(Message::Request(Request::JoinSyncGroup {
					req,
					devices_in_group,
					tx,
				})) => self.dispatch_join_requests(req, devices_in_group, &mut rng, tx),

				StreamMessage::Message(Message::RegisterSyncMessageNotifier((
					group_pub_id,
					notifier,
				))) => {
					self.sync_messages_receiver_notifiers_map
						.insert(group_pub_id, notifier);
				}

				StreamMessage::Message(Message::NotifyPeersSyncMessages(group_pub_id)) => {
					spawn(dispatch_notifier(
						group_pub_id,
						self.current_device_pub_id,
						self.cached_devices_per_group.get(&group_pub_id).cloned(),
						self.msgs_tx.clone(),
						self.cloud_services.clone(),
						self.token_refresher.clone(),
						self.endpoint.clone(),
					));
				}

				StreamMessage::Message(Message::UpdateCachedDevices((
					group_pub_id,
					devices_connections_ids,
				))) => {
					self.cached_devices_per_group
						.insert(group_pub_id, (Instant::now(), devices_connections_ids));
				}

				StreamMessage::UserResponse(UserResponse::AcceptDeviceInSyncGroup {
					ticket,
					accepted,
				}) => {
					self.handle_join_response(ticket, accepted).await;
				}

				StreamMessage::Tick => self.tick().await,

				StreamMessage::Message(Message::Stop) => {
					server_handle.abort();
					break;
				}
			}
		}
	}
	fn dispatch_join_requests(
		&self,
		req: authorize_new_device_in_sync_group::Request,
		devices_in_group: Vec<(devices::PubId, NodeId)>,
		rng: &mut CryptoRng,
		tx: oneshot::Sender<JoinedLibraryCreateArgs>,
	) {
		async fn inner(
			key_manager: Arc<KeyManager>,
			endpoint: Endpoint,
			mut rng: CryptoRng,
			req: authorize_new_device_in_sync_group::Request,
			devices_in_group: Vec<(devices::PubId, NodeId)>,
			tx: oneshot::Sender<JoinedLibraryCreateArgs>,
		) -> Result<JoinSyncGroupResponse, Error> {
			let group_pub_id = req.sync_group.pub_id;
			loop {
				let client =
					match connect_to_first_available_client(&endpoint, &devices_in_group).await {
						Ok(client) => client,
						Err(e) => {
							return Ok(JoinSyncGroupResponse::Failed(e));
						}
					};

				match client
					.authorize_new_device_in_sync_group(req.clone())
					.await?
				{
					Ok(authorize_new_device_in_sync_group::Response {
						authorizor_device,
						keys,
						library_pub_id,
						library_name,
						library_description,
					}) => {
						debug!(
							device_pub_id = %authorizor_device.pub_id,
							%group_pub_id,
							keys_count = keys.len(),
							%library_pub_id,
							library_name,
							"Received join sync group response"
						);

						key_manager
							.add_many_keys(
								group_pub_id,
								keys.into_iter().map(|key| {
									key.as_slice()
										.try_into()
										.expect("critical error, backend has invalid secret keys")
								}),
								&mut rng,
							)
							.await?;

						if tx
							.send(JoinedLibraryCreateArgs {
								pub_id: library_pub_id,
								name: library_name,
								description: library_description,
							})
							.is_err()
						{
							error!("Failed to handle library creation locally from received library data");
							return Ok(JoinSyncGroupResponse::CriticalError);
						}

						return Ok(JoinSyncGroupResponse::Accepted { authorizor_device });
					}

					// In case of timeout, we will try again
					Err(CloudP2PError::TimedOut) => continue,

					Err(e) => return Ok(JoinSyncGroupResponse::Failed(e)),
				}
			}
		}

		spawn({
			let endpoint = self.endpoint.clone();
			let notify_user_tx = self.notify_user_tx.clone();
			let key_manager = Arc::clone(&self.key_manager);
			let rng = CryptoRng::from_seed(rng.generate_fixed());
			async move {
				let sync_group = req.sync_group.clone();

				if let Err(SendError(response)) = notify_user_tx
					.send_async(NotifyUser::ReceivedJoinSyncGroupResponse {
						response: inner(key_manager, endpoint, rng, req, devices_in_group, tx)
							.await
							.unwrap_or_else(|e| {
								error!(
									?e,
									"Failed to issue authorize new device in sync group request;"
								);
								JoinSyncGroupResponse::CriticalError
							}),
						sync_group,
					})
					.await
				{
					error!(?response, "Failed to send response to user;");
				}
			}
		});
	}

	async fn handle_request(
		&self,
		request: cloud_p2p::Request,
		channel: RpcChannel<Service, QuinnServerEndpoint<Service>>,
	) {
		match request {
			cloud_p2p::Request::AuthorizeNewDeviceInSyncGroup(
				authorize_new_device_in_sync_group::Request {
					sync_group,
					asking_device,
				},
			) => {
				let ticket = Ticket(self.ticketer.fetch_add(1, Ordering::Relaxed));
				let this_device = sync_group
					.devices
					.iter()
					.find(|device| device.pub_id == self.current_device_pub_id)
					.expect(
						"current device must be in the sync group, otherwise we wouldn't be here",
					)
					.clone();

				self.notify_user_tx
					.send_async(NotifyUser::ReceivedJoinSyncGroupRequest {
						ticket,
						asking_device: asking_device.clone(),
						sync_group: sync_group.clone(),
					})
					.await
					.expect("notify_user_tx must never close!");

				self.pending_sync_group_join_requests.lock().await.insert(
					ticket,
					PendingSyncGroupJoin {
						channel,
						request: authorize_new_device_in_sync_group::Request {
							sync_group,
							asking_device,
						},
						this_device,
						since: Instant::now(),
					},
				);
			}

			cloud_p2p::Request::NotifyNewSyncMessages(req) => {
				if let Err(e) = channel
					.rpc(
						req,
						(),
						|(),
						 notify_new_sync_messages::Request {
							sync_group_pub_id,
							device_pub_id,
						 }| async move {
							debug!(%sync_group_pub_id, %device_pub_id, "Received new sync messages notification");
							if let Some(notifier) = self
								.sync_messages_receiver_notifiers_map
								.get(&sync_group_pub_id)
							{
								notifier.notify_receiver();
							} else {
								warn!("Received new sync messages notification for unknown sync group");
							}

							Ok(notify_new_sync_messages::Response)
						},
					)
					.await
				{
					error!(
						?e,
						"Failed to reply to new sync messages notification request"
					);
				}
			}
		}
	}
	async fn handle_join_response(
		&self,
		ticket: Ticket,
		accepted: Option<BasicLibraryCreationArgs>,
	) {
		let Some(PendingSyncGroupJoin {
			channel,
			request,
			this_device,
			..
		}) = self
			.pending_sync_group_join_requests
			.lock()
			.await
			.remove(&ticket)
		else {
			warn!("Received join response for unknown ticket; we probably timed out this request already");
			return;
		};

		let sync_group = request.sync_group.clone();
		let asking_device_pub_id = request.asking_device.pub_id;

		let was_accepted = accepted.is_some();

		let response = if let Some(BasicLibraryCreationArgs {
			id: library_pub_id,
			name: library_name,
			description: library_description,
		}) = accepted
		{
			Ok(authorize_new_device_in_sync_group::Response {
				authorizor_device: this_device,
				keys: self
					.key_manager
					.get_group_keys(request.sync_group.pub_id)
					.await
					.into_iter()
					.map(Into::into)
					.collect(),
				library_pub_id,
				library_name,
				library_description,
			})
		} else {
			Err(CloudP2PError::Rejected)
		};

		if let Err(e) = channel
			.rpc(request, (), |(), _req| async move { response })
			.await
		{
			error!(?e, "Failed to send response to user;");
			self.notify_join_error(sync_group, JoinSyncGroupError::Communication)
				.await;

			return;
		}

		if was_accepted {
			let Ok(access_token) = self
				.token_refresher
				.get_access_token()
				.await
				.map_err(|e| error!(?e, "Failed to get access token;"))
			else {
				self.notify_join_error(sync_group, JoinSyncGroupError::Auth)
					.await;
				return;
			};

			match self
				.cloud_services
				.sync()
				.groups()
				.reply_join_request(groups::reply_join_request::Request {
					access_token,
					group_pub_id: sync_group.pub_id,
					authorized_device_pub_id: asking_device_pub_id,
					authorizor_device_pub_id: self.current_device_pub_id,
				})
				.await
			{
				Ok(Ok(groups::reply_join_request::Response)) => {
					// Everything is Awesome!
				}
				Ok(Err(e)) => {
					error!(?e, "Failed to reply to join request");
					self.notify_join_error(sync_group, JoinSyncGroupError::InternalServer)
						.await;
				}
				Err(e) => {
					error!(?e, "Failed to send reply to join request");
					self.notify_join_error(sync_group, JoinSyncGroupError::Communication)
						.await;
				}
			}
		}
	}

	async fn notify_join_error(
		&self,
		sync_group: groups::GroupWithDevices,
		error: JoinSyncGroupError,
	) {
		self.notify_user_tx
			.send_async(NotifyUser::SendingJoinSyncGroupResponseError { error, sync_group })
			.await
			.expect("notify_user_tx must never close!");
	}
	async fn tick(&mut self) {
		self.timeout_checker_buffer.clear();

		let mut pending_sync_group_join_requests =
			self.pending_sync_group_join_requests.lock().await;

		for (ticket, pending_sync_group_join) in pending_sync_group_join_requests.drain() {
			if pending_sync_group_join.since.elapsed() > FIVE_MINUTES {
				let PendingSyncGroupJoin {
					channel, request, ..
				} = pending_sync_group_join;

				let asking_device = request.asking_device.clone();

				let notify_message = match channel
					.rpc(request, (), |(), _req| async move {
						Err(CloudP2PError::TimedOut)
					})
					.await
				{
					Ok(()) => NotifyUser::TimedOutJoinRequest {
						device: asking_device,
						succeeded: true,
					},
					Err(e) => {
						error!(?e, "Failed to send timed out response to user;");
						NotifyUser::TimedOutJoinRequest {
							device: asking_device,
							succeeded: false,
						}
					}
				};

				self.notify_user_tx
					.send_async(notify_message)
					.await
					.expect("notify_user_tx must never close!");
			} else {
				self.timeout_checker_buffer
					.push((ticket, pending_sync_group_join));
			}
		}

		pending_sync_group_join_requests.extend(self.timeout_checker_buffer.drain(..));
	}
}

async fn connect_to_first_available_client(
	endpoint: &Endpoint,
	devices_in_group: &[(devices::PubId, NodeId)],
) -> Result<Client<QuinnConnection<Service>, Service>, CloudP2PError> {
	for (device_pub_id, device_connection_id) in devices_in_group {
		if let Ok(connection) = endpoint
			.connect(*device_connection_id, CloudP2PALPN::LATEST)
			.await
			.map_err(
				|e| error!(?e, %device_pub_id, "Failed to connect to authorizor device candidate"),
			) {
			debug!(%device_pub_id, "Connected to authorizor device candidate");
			return Ok(Client::new(RpcClient::new(
				QuinnConnection::<Service>::from_connection(connection),
			)));
		}
	}

	Err(CloudP2PError::UnableToConnect)
}

fn setup_server_endpoint(
	endpoint: Endpoint,
) -> (
	RpcServer<Service, QuinnServerEndpoint<Service>>,
	JoinHandle<()>,
) {
	let local_addr = {
		let (ipv4_addr, maybe_ipv6_addr) = endpoint.bound_sockets();
		// Trying to give preference to IPv6 addresses because it's 2024
		maybe_ipv6_addr.unwrap_or(ipv4_addr)
	};

	let (connections_tx, connections_rx) = flume::bounded(16);

	(
		RpcServer::new(QuinnServerEndpoint::<Service>::handle_connections(
			connections_rx,
			local_addr,
		)),
		spawn(async move {
			while let Some(connecting) = endpoint.accept().await {
				if let Ok(connection) = connecting.await.map_err(|e| {
					warn!(?e, "Cloud P2P failed to accept connection");
				}) {
					if connections_tx.send_async(connection).await.is_err() {
						warn!("Connection receiver dropped");
						break;
					}
				}
			}
		}),
	)
}
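tick() above empties the pending map on every pass and re-inserts the entries that have not expired yet, using a scratch buffer the Runner keeps around between ticks. The sweep in isolation:

use std::collections::HashMap;
use std::time::{Duration, Instant};

const MAX_AGE: Duration = Duration::from_secs(60 * 5);

struct Pending {
	since: Instant,
}

// Same shape as tick(): drain everything, time out the old entries, and
// push the still-young ones into a scratch buffer that is re-inserted
// afterwards, so the map is never mutated while being iterated.
fn sweep(pending: &mut HashMap<u64, Pending>, buffer: &mut Vec<(u64, Pending)>) {
	buffer.clear();
	for (ticket, entry) in pending.drain() {
		if entry.since.elapsed() > MAX_AGE {
			println!("ticket {ticket} timed out");
		} else {
			buffer.push((ticket, entry));
		}
	}
	pending.extend(buffer.drain(..));
}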
diff --git a/core/crates/cloud-services/src/sync/ingest.rs b/core/crates/cloud-services/src/sync/ingest.rs
new file mode 100644
index 000000000..a7dd65af3
--- /dev/null
+++ b/core/crates/cloud-services/src/sync/ingest.rs
@@ -0,0 +1,121 @@
use crate::Error;

use sd_core_sync::SyncManager;

use sd_actors::{Actor, Stopper};

use std::{
	future::IntoFuture,
	sync::{
		atomic::{AtomicBool, Ordering},
		Arc,
	},
};

use futures::FutureExt;
use futures_concurrency::future::Race;
use tokio::{
	sync::Notify,
	time::{sleep, Instant},
};
use tracing::{debug, error};

use super::{ReceiveAndIngestNotifiers, SyncActors, ONE_MINUTE};

/// Responsible for taking sync operations received from the cloud
/// and applying them to the local database via the sync system's ingest actor.
pub struct Ingester {
	sync: SyncManager,
	notifiers: Arc<ReceiveAndIngestNotifiers>,
	active: Arc<AtomicBool>,
	active_notify: Arc<Notify>,
}

impl Actor for Ingester {
	const IDENTIFIER: SyncActors = SyncActors::Ingester;

	async fn run(&mut self, stop: Stopper) {
		enum Race {
			Notified,
			Stopped,
		}

		loop {
			self.active.store(true, Ordering::Relaxed);
			self.active_notify.notify_waiters();

			if let Err(e) = self.run_loop_iteration().await {
				error!(?e, "Error during cloud sync ingester actor iteration");
				sleep(ONE_MINUTE).await;
				continue;
			}

			self.active.store(false, Ordering::Relaxed);
			self.active_notify.notify_waiters();

			if matches!(
				(
					self.notifiers
						.wait_notification_to_ingest()
						.map(|()| Race::Notified),
					stop.into_future().map(|()| Race::Stopped),
				)
					.race()
					.await,
				Race::Stopped
			) {
				break;
			}
		}
	}
}

impl Ingester {
	pub const fn new(
		sync: SyncManager,
		notifiers: Arc<ReceiveAndIngestNotifiers>,
		active: Arc<AtomicBool>,
		active_notify: Arc<Notify>,
	) -> Self {
		Self {
			sync,
			notifiers,
			active,
			active_notify,
		}
	}

	async fn run_loop_iteration(&self) -> Result<(), Error> {
		let start = Instant::now();

		let operations_to_ingest_count = self
			.sync
			.db
			.cloud_crdt_operation()
			.count(vec![])
			.exec()
			.await
			.map_err(sd_core_sync::Error::from)?;

		if operations_to_ingest_count == 0 {
			debug!("Nothing to ingest, finishing ingester loop early");
			return Ok(());
		}

		debug!(
			operations_to_ingest_count,
			"Starting sync messages cloud ingestion loop"
		);

		let ingested_count = self.sync.ingest_ops().await?;

		debug!(
			ingested_count,
			"Finished sync messages cloud ingestion loop in {:?}",
			start.elapsed()
		);

		Ok(())
	}
}

diff --git a/core/crates/cloud-services/src/sync/mod.rs b/core/crates/cloud-services/src/sync/mod.rs
new file mode 100644
index 000000000..b694befb4
--- /dev/null
+++ b/core/crates/cloud-services/src/sync/mod.rs
@@ -0,0 +1,136 @@
use crate::{CloudServices, Error};

use sd_core_sync::SyncManager;

use sd_actors::{ActorsCollection, IntoActor};
use sd_cloud_schema::sync::groups;
use sd_crypto::CryptoRng;

use std::{
	fmt,
	path::Path,
	sync::{atomic::AtomicBool, Arc},
	time::Duration,
};

use futures_concurrency::future::TryJoin;
use tokio::sync::Notify;

mod ingest;
mod receive;
mod send;

use ingest::Ingester;
use receive::Receiver;
use send::Sender;

const ONE_MINUTE: Duration = Duration::from_secs(60);

#[derive(Default)]
pub struct SyncActorsState {
	pub send_active: Arc<AtomicBool>,
	pub receive_active: Arc<AtomicBool>,
	pub ingest_active: Arc<AtomicBool>,
	pub state_change_notifier: Arc<Notify>,
	receiver_and_ingester_notifiers: Arc<ReceiveAndIngestNotifiers>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, specta::Type)]
#[specta(rename = "CloudSyncActors")]
pub enum SyncActors {
	Ingester,
	Sender,
	Receiver,
}

impl fmt::Display for SyncActors {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		match self {
			Self::Ingester => write!(f, "Cloud Sync Ingester"),
			Self::Sender => write!(f, "Cloud Sync Sender"),
			Self::Receiver => write!(f, "Cloud Sync Receiver"),
		}
	}
}

#[derive(Debug, Default)]
pub struct ReceiveAndIngestNotifiers {
	ingester: Notify,
	receiver: Notify,
}

impl ReceiveAndIngestNotifiers {
	pub fn notify_receiver(&self) {
		self.receiver.notify_one();
	}

	async fn wait_notification_to_receive(&self) {
		self.receiver.notified().await;
	}

	fn notify_ingester(&self) {
		self.ingester.notify_one();
	}

	async fn wait_notification_to_ingest(&self) {
		self.ingester.notified().await;
	}
}
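The two Notify pairs above lean on a useful tokio detail: notify_one() stores a permit when nobody is parked yet, so the receiver and ingester can't miss a wake-up that lands between loop iterations, while notify_waiters() (used for the *_active flags) only wakes tasks that are already waiting. A small demonstration of the permit behavior:

use std::sync::Arc;
use tokio::sync::Notify;

#[tokio::main]
async fn main() {
	let bell = Arc::new(Notify::new());

	// The permit is stored *before* anyone waits...
	bell.notify_one();

	// ...and the later notified().await consumes it immediately instead of
	// parking forever.
	let waiter = {
		let bell = Arc::clone(&bell);
		tokio::spawn(async move {
			bell.notified().await;
			println!("woken by the stored permit");
		})
	};

	waiter.await.expect("waiter task panicked");
}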
+pub async fn declare_actors(
+	data_dir: Box<Path>,
+	cloud_services: Arc<CloudServices>,
+	actors: &ActorsCollection<SyncActors>,
+	actors_state: &SyncActorsState,
+	sync_group_pub_id: groups::PubId,
+	sync: SyncManager,
+	rng: CryptoRng,
+) -> Result<Arc<ReceiveAndIngestNotifiers>, Error> {
+	let (sender, receiver) = (
+		Sender::new(
+			sync_group_pub_id,
+			sync.clone(),
+			Arc::clone(&cloud_services),
+			Arc::clone(&actors_state.send_active),
+			Arc::clone(&actors_state.state_change_notifier),
+			rng,
+		),
+		Receiver::new(
+			data_dir,
+			sync_group_pub_id,
+			cloud_services.clone(),
+			sync.clone(),
+			Arc::clone(&actors_state.receiver_and_ingester_notifiers),
+			Arc::clone(&actors_state.receive_active),
+			Arc::clone(&actors_state.state_change_notifier),
+		),
+	)
+		.try_join()
+		.await?;
+
+	let ingester = Ingester::new(
+		sync,
+		Arc::clone(&actors_state.receiver_and_ingester_notifiers),
+		Arc::clone(&actors_state.ingest_active),
+		Arc::clone(&actors_state.state_change_notifier),
+	);
+
+	actors
+		.declare_many_boxed([
+			sender.into_actor(),
+			receiver.into_actor(),
+			ingester.into_actor(),
+		])
+		.await;
+
+	cloud_services
+		.cloud_p2p()
+		.await?
+		.register_sync_messages_receiver_notifier(
+			sync_group_pub_id,
+			Arc::clone(&actors_state.receiver_and_ingester_notifiers),
+		)
+		.await;
+
+	Ok(Arc::clone(&actors_state.receiver_and_ingester_notifiers))
+}
diff --git a/core/crates/cloud-services/src/sync/receive.rs b/core/crates/cloud-services/src/sync/receive.rs
new file mode 100644
index 000000000..f4db4b4c5
--- /dev/null
+++ b/core/crates/cloud-services/src/sync/receive.rs
@@ -0,0 +1,356 @@
+use crate::{CloudServices, Error, KeyManager};
+
+use sd_cloud_schema::{
+	devices,
+	sync::{
+		groups,
+		messages::{pull, MessagesCollection},
+	},
+	Client, Service,
+};
+use sd_core_sync::{
+	cloud_crdt_op_db, CRDTOperation, CompressedCRDTOperationsPerModel, SyncManager,
+};
+
+use sd_actors::{Actor, Stopper};
+use sd_crypto::{
+	cloud::{OneShotDecryption, SecretKey, StreamDecryption},
+	primitives::{EncryptedBlock, StreamNonce},
+};
+use sd_prisma::prisma::PrismaClient;
+
+use std::{
+	collections::{hash_map::Entry, HashMap},
+	future::IntoFuture,
+	path::Path,
+	sync::{
+		atomic::{AtomicBool, Ordering},
+		Arc,
+	},
+};
+
+use chrono::{DateTime, Utc};
+use futures::{FutureExt, StreamExt};
+use futures_concurrency::future::{Race, TryJoin};
+use quic_rpc::transport::quinn::QuinnConnection;
+use serde::{Deserialize, Serialize};
+use tokio::{fs, io, sync::Notify, time::sleep};
+use tracing::{debug, error, instrument, warn};
+use uuid::Uuid;
+
+use super::{ReceiveAndIngestNotifiers, SyncActors, ONE_MINUTE};
+
+const CLOUD_SYNC_DATA_KEEPER_DIRECTORY: &str = "cloud_sync_data_keeper";
+
+/// Responsible for downloading sync operations from the cloud to be processed by the ingester
+pub struct Receiver {
+	keeper: LastTimestampKeeper,
+	sync_group_pub_id: groups::PubId,
+	device_pub_id: devices::PubId,
+	cloud_services: Arc<CloudServices>,
+	cloud_client: Client<QuinnConnection<Service>>,
+	key_manager: Arc<KeyManager>,
+	sync: SyncManager,
+	notifiers: Arc<ReceiveAndIngestNotifiers>,
+	active: Arc<AtomicBool>,
+	active_notifier: Arc<Notify>,
+}
+
+impl Actor<SyncActors> for Receiver {
+	const IDENTIFIER: SyncActors = SyncActors::Receiver;
+
+	async fn run(&mut self, stop: Stopper) {
+		enum Race {
+			Continue,
+			Stop,
+		}
+
+		loop {
+			self.active.store(true, Ordering::Relaxed);
+			self.active_notifier.notify_waiters();
+
+			let res = self.run_loop_iteration().await;
+
+			self.active.store(false, Ordering::Relaxed);
+
+			if let Err(e) = res {
+				error!(?e, "Error during cloud sync receiver actor iteration");
+				sleep(ONE_MINUTE).await;
+				continue;
+			}
+
+			self.active_notifier.notify_waiters();
+
+			if matches!(
+				(
+					sleep(ONE_MINUTE).map(|()| Race::Continue),
+					self.notifiers
+						.wait_notification_to_receive()
+						.map(|()| Race::Continue),
+					stop.into_future().map(|()| Race::Stop),
+				)
+					.race()
+					.await,
+				Race::Stop
+			) {
+				break;
+			}
+		}
+	}
+}
+
+impl Receiver {
+	pub async fn new(
+		data_dir: impl AsRef<Path> + Send,
+		sync_group_pub_id: groups::PubId,
+		cloud_services: Arc<CloudServices>,
+		sync: SyncManager,
+		notifiers: Arc<ReceiveAndIngestNotifiers>,
+		active: Arc<AtomicBool>,
+		active_notify: Arc<Notify>,
+	) -> Result<Self, Error> {
+		let (keeper, cloud_client, key_manager) = (
+			LastTimestampKeeper::load(data_dir.as_ref(), sync_group_pub_id),
+			cloud_services.client(),
+			cloud_services.key_manager(),
+		)
+			.try_join()
+			.await?;
+
+		Ok(Self {
+			keeper,
+			sync_group_pub_id,
+			device_pub_id: devices::PubId(Uuid::from(&sync.device_pub_id)),
+			cloud_services,
+			cloud_client,
+			key_manager,
+			sync,
+			notifiers,
+			active,
+			active_notifier: active_notify,
+		})
+	}
+
+	async fn run_loop_iteration(&mut self) -> Result<(), Error> {
+		let mut responses_stream = self
+			.cloud_client
+			.sync()
+			.messages()
+			.pull(pull::Request {
+				access_token: self
+					.cloud_services
+					.token_refresher
+					.get_access_token()
+					.await?,
+				group_pub_id: self.sync_group_pub_id,
+				current_device_pub_id: self.device_pub_id,
+				start_time_per_device: self
+					.keeper
+					.timestamps
+					.iter()
+					.map(|(device_pub_id, timestamp)| (*device_pub_id, *timestamp))
+					.collect(),
+			})
+			.await?;
+
+		while let Some(new_messages_res) = responses_stream.next().await {
+			let pull::Response(new_messages) = new_messages_res??;
+			if new_messages.is_empty() {
+				break;
+			}
+
+			self.handle_new_messages(new_messages).await?;
+		}
+
+		debug!("Finished sync messages receiver actor iteration");
+
+		self.keeper.save().await
+	}
+
+	async fn handle_new_messages(
+		&mut self,
+		new_messages: Vec<MessagesCollection>,
+	) -> Result<(), Error> {
+		debug!(
+			new_messages_collections_count = new_messages.len(),
+			start_time = ?new_messages.first().map(|c| c.start_time),
+			end_time = ?new_messages.first().map(|c| c.end_time),
+			"Handling new sync messages collections",
+		);
+
+		for message in new_messages.into_iter().filter(|message| {
+			if message.original_device_pub_id == self.device_pub_id {
+				warn!("Received sync message from the current device, need to check backend, this is a bug!");
+				false
+			} else {
+				true
+			}
+		}) {
+			debug!(
+				new_messages_count = message.operations_count,
+				start_time = ?message.start_time,
+				end_time = ?message.end_time,
+				"Handling new sync messages",
+			);
+
+			let (device_pub_id, timestamp) = handle_single_message(
+				self.sync_group_pub_id,
+				message,
+				&self.key_manager,
+				&self.sync,
+			)
+			.await?;
+
+			match self.keeper.timestamps.entry(device_pub_id) {
+				Entry::Occupied(mut entry) => {
+					if entry.get() < &timestamp {
+						*entry.get_mut() = timestamp;
+					}
+				}
+
+				Entry::Vacant(entry) => {
+					entry.insert(timestamp);
+				}
+			}
+
+			// To ingest after each sync message collection is received, we MUST download and
+			// store the messages SEQUENTIALLY, otherwise we might ingest messages out of order
+			// due to parallel downloads
+			self.notifiers.notify_ingester();
+		}
+
+		Ok(())
+	}
+}
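The `Entry` match above maintains a per-device high-water mark: a device's cursor only ever moves forward, so a re-received collection cannot rewind it. A compact, self-contained sketch of the same merge, with `Uuid` standing in for `devices::PubId`:

```rust
use std::collections::{hash_map::Entry, HashMap};

use chrono::{DateTime, Utc};
use uuid::Uuid;

fn merge_timestamp(
    timestamps: &mut HashMap<Uuid, DateTime<Utc>>,
    device: Uuid,
    end_time: DateTime<Utc>,
) {
    match timestamps.entry(device) {
        // Only move the cursor forward; older collections must not rewind it.
        Entry::Occupied(mut entry) => {
            if *entry.get() < end_time {
                *entry.get_mut() = end_time;
            }
        }
        Entry::Vacant(entry) => {
            entry.insert(end_time);
        }
    }
}
```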
+
+#[instrument(
+	skip_all,
+	fields(%sync_group_pub_id, %original_device_pub_id, operations_count, ?key_hash, %end_time),
+)]
+async fn handle_single_message(
+	sync_group_pub_id: groups::PubId,
+	MessagesCollection {
+		original_device_pub_id,
+		end_time,
+		operations_count,
+		key_hash,
+		encrypted_messages,
+		..
+	}: MessagesCollection,
+	key_manager: &KeyManager,
+	sync: &SyncManager,
+) -> Result<(devices::PubId, DateTime<Utc>), Error> {
+	// FIXME(@fogodev): If we don't have the key hash, we need to fetch it from another device in the group if possible
+	let Some(secret_key) = key_manager.get_key(sync_group_pub_id, &key_hash).await else {
+		return Err(Error::MissingKeyHash);
+	};
+
+	debug!(
+		size = encrypted_messages.len(),
+		"Received encrypted sync messages collection"
+	);
+
+	let crdt_ops = decrypt_messages(encrypted_messages, secret_key, original_device_pub_id).await?;
+
+	assert_eq!(
+		crdt_ops.len(),
+		operations_count as usize,
+		"Sync messages count mismatch"
+	);
+
+	write_cloud_ops_to_db(crdt_ops, &sync.db).await?;
+
+	Ok((original_device_pub_id, end_time))
+}
+
+#[instrument(skip(encrypted_messages, secret_key), fields(messages_size = %encrypted_messages.len()), err)]
+async fn decrypt_messages(
+	encrypted_messages: Vec<u8>,
+	secret_key: SecretKey,
+	devices::PubId(device_pub_id): devices::PubId,
+) -> Result<Vec<CRDTOperation>, Error> {
+	let plain_text = if encrypted_messages.len() <= EncryptedBlock::CIPHER_TEXT_SIZE {
+		OneShotDecryption::decrypt(&secret_key, encrypted_messages.as_slice().into())
+			.map_err(Error::Decrypt)?
+	} else {
+		let (nonce, cipher_text) = encrypted_messages.split_at(size_of::<StreamNonce>());
+
+		let mut plain_text = Vec::with_capacity(cipher_text.len());
+
+		StreamDecryption::decrypt(
+			&secret_key,
+			nonce.try_into().expect("we split the correct amount"),
+			cipher_text,
+			&mut plain_text,
+		)
+		.await
+		.map_err(Error::Decrypt)?;
+
+		plain_text
+	};
+
+	rmp_serde::from_slice::<CompressedCRDTOperationsPerModel>(&plain_text)
+		.map(|compressed_ops| compressed_ops.into_ops(device_pub_id))
+		.map_err(Error::DeserializationFailureToPullSyncMessages)
+}
+
+#[instrument(skip_all, err)]
+pub async fn write_cloud_ops_to_db(
+	ops: Vec<CRDTOperation>,
+	db: &PrismaClient,
+) -> Result<(), sd_core_sync::Error> {
+	db._batch(
+		ops.into_iter()
+			.map(|op| cloud_crdt_op_db(&op).map(|op| op.to_query(db)))
+			.collect::<Result<Vec<_>, _>>()?,
+	)
+	.await?;
+
+	Ok(())
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+struct LastTimestampKeeper {
+	timestamps: HashMap<devices::PubId, DateTime<Utc>>,
+	file_path: Box<Path>,
+}
+
+impl LastTimestampKeeper {
+	async fn load(data_dir: &Path, sync_group_pub_id: groups::PubId) -> Result<Self, Error> {
+		let cloud_sync_data_directory = data_dir.join(CLOUD_SYNC_DATA_KEEPER_DIRECTORY);
+
+		fs::create_dir_all(&cloud_sync_data_directory)
+			.await
+			.map_err(Error::FailedToCreateTimestampKeepersDirectory)?;
+
+		let file_path = cloud_sync_data_directory
+			.join(format!("{sync_group_pub_id}.bin"))
+			.into_boxed_path();
+
+		match fs::read(&file_path).await {
+			Ok(bytes) => Ok(Self {
+				timestamps: rmp_serde::from_slice(&bytes)
+					.map_err(Error::LastTimestampKeeperDeserialization)?,
+				file_path,
+			}),
+
+			Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(Self {
+				timestamps: HashMap::new(),
+				file_path,
+			}),
+
+			Err(e) => Err(Error::FailedToReadLastTimestampKeeper(e)),
+		}
+	}
+
+	async fn save(&self) -> Result<(), Error> {
+		fs::write(
+			&self.file_path,
+			&rmp_serde::to_vec_named(&self.timestamps)
+				.map_err(Error::LastTimestampKeeperSerialization)?,
+		)
+		.await
+		.map_err(Error::FailedToWriteLastTimestampKeeper)
+	}
+}
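Both directions agree on one wire layout: payloads that fit in a single encrypted block use the one-shot form, while anything larger is sent as the stream nonce followed by the ciphertext. A sketch of the receiving split; the nonce size here is a made-up placeholder, not the real `size_of::<StreamNonce>()`:

```rust
// Assumption: a stand-in for size_of::<StreamNonce>() from sd_crypto.
const NONCE_SIZE: usize = 20;

/// Splits `nonce || ciphertext` at a fixed offset, as the receiver does before
/// streaming the ciphertext through decryption. Returns None when the buffer
/// is too short to even contain a nonce.
fn split_nonce_and_cipher_text(bytes: &[u8]) -> Option<(&[u8], &[u8])> {
    (bytes.len() > NONCE_SIZE).then(|| bytes.split_at(NONCE_SIZE))
}
```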
diff --git a/core/crates/cloud-services/src/sync/send.rs b/core/crates/cloud-services/src/sync/send.rs
new file mode 100644
index 000000000..c0ab06e88
--- /dev/null
+++ b/core/crates/cloud-services/src/sync/send.rs
@@ -0,0 +1,337 @@
+use crate::{CloudServices, Error, KeyManager};
+
+use sd_core_sync::{CompressedCRDTOperationsPerModelPerDevice, SyncEvent, SyncManager, NTP64};
+
+use sd_actors::{Actor, Stopper};
+use sd_cloud_schema::{
+	devices,
+	error::{ClientSideError, NotFoundError},
+	sync::{groups, messages},
+	Client, Service,
+};
+use sd_crypto::{
+	cloud::{OneShotEncryption, SecretKey, StreamEncryption},
+	primitives::EncryptedBlock,
+	CryptoRng, SeedableRng,
+};
+use sd_utils::{datetime_to_timestamp, timestamp_to_datetime};
+
+use std::{
+	future::IntoFuture,
+	pin::pin,
+	sync::{
+		atomic::{AtomicBool, Ordering},
+		Arc,
+	},
+	time::{Duration, UNIX_EPOCH},
+};
+
+use chrono::{DateTime, Utc};
+use futures::{FutureExt, StreamExt, TryStreamExt};
+use futures_concurrency::future::{Race, TryJoin};
+use quic_rpc::transport::quinn::QuinnConnection;
+use tokio::{
+	sync::{broadcast, Notify},
+	time::sleep,
+};
+use tracing::{debug, error};
+use uuid::Uuid;
+
+use super::{SyncActors, ONE_MINUTE};
+
+const TEN_SECONDS: Duration = Duration::from_secs(10);
+
+const MESSAGES_COLLECTION_SIZE: u32 = 10_000;
+
+enum RaceNotifiedOrStopped {
+	Notified,
+	Stopped,
+}
+
+enum LoopStatus {
+	SentMessages,
+	Idle,
+}
+
+type LatestTimestamp = NTP64;
+
+#[derive(Debug)]
+pub struct Sender {
+	sync_group_pub_id: groups::PubId,
+	sync: SyncManager,
+	cloud_services: Arc<CloudServices>,
+	cloud_client: Client<QuinnConnection<Service>>,
+	key_manager: Arc<KeyManager>,
+	is_active: Arc<AtomicBool>,
+	state_notify: Arc<Notify>,
+	rng: CryptoRng,
+	maybe_latest_timestamp: Option<LatestTimestamp>,
+}
+
+impl Actor<SyncActors> for Sender {
+	const IDENTIFIER: SyncActors = SyncActors::Sender;
+
+	async fn run(&mut self, stop: Stopper) {
+		loop {
+			self.is_active.store(true, Ordering::Relaxed);
+			self.state_notify.notify_waiters();
+
+			let res = self.run_loop_iteration().await;
+
+			self.is_active.store(false, Ordering::Relaxed);
+
+			match res {
+				Ok(LoopStatus::SentMessages) => {
+					if let Ok(cloud_p2p) = self.cloud_services.cloud_p2p().await.map_err(|e| {
+						error!(?e, "Failed to get cloud p2p client on sender actor");
+					}) {
+						cloud_p2p
+							.notify_new_sync_messages(self.sync_group_pub_id)
+							.await;
+					}
+				}
+
+				Ok(LoopStatus::Idle) => {}
+
+				Err(e) => {
+					error!(?e, "Error during cloud sync sender actor iteration");
+					sleep(ONE_MINUTE).await;
+					continue;
+				}
+			}
+
+			self.state_notify.notify_waiters();
+
+			if matches!(
+				(
+					// recreate subscription each time so that existing messages are dropped
+					wait_notification(self.sync.subscribe()),
+					stop.into_future().map(|()| RaceNotifiedOrStopped::Stopped),
+				)
+					.race()
+					.await,
+				RaceNotifiedOrStopped::Stopped
+			) {
+				break;
+			}
+
+			sleep(TEN_SECONDS).await;
+		}
+	}
+}
+
+impl Sender {
+	pub async fn new(
+		sync_group_pub_id: groups::PubId,
+		sync: SyncManager,
+		cloud_services: Arc<CloudServices>,
+		is_active: Arc<AtomicBool>,
+		state_notify: Arc<Notify>,
+		rng: CryptoRng,
+	) -> Result<Self, Error> {
+		let (cloud_client, key_manager) = (cloud_services.client(), cloud_services.key_manager())
+			.try_join()
+			.await?;
+
+		Ok(Self {
+			sync_group_pub_id,
+			sync,
+			cloud_services,
+			cloud_client,
+			key_manager,
+			is_active,
+			state_notify,
+			rng,
+			maybe_latest_timestamp: None,
+		})
+	}
+
+	async fn run_loop_iteration(&mut self) -> Result<LoopStatus, Error> {
+		debug!("Starting cloud sender actor loop iteration");
+
+		let current_device_pub_id = devices::PubId(Uuid::from(&self.sync.device_pub_id));
+
+		let (key_hash, secret_key) = self
+			.key_manager
+			.get_latest_key(self.sync_group_pub_id)
+			.await
+			.ok_or(Error::MissingSyncGroupKey(self.sync_group_pub_id))?;
+
+		let current_latest_timestamp = self.get_latest_timestamp(current_device_pub_id).await?;
+
+		let mut crdt_ops_stream = pin!(self.sync.stream_device_ops(
+			&self.sync.device_pub_id,
+			MESSAGES_COLLECTION_SIZE,
+			current_latest_timestamp
+		));
+
+		let mut status = LoopStatus::Idle;
+
+		let mut new_latest_timestamp = current_latest_timestamp;
+
+		debug!(
+			chunk_size = MESSAGES_COLLECTION_SIZE,
+			"Trying to fetch chunk of sync messages from the database"
+		);
+		while let Some(ops_res) = crdt_ops_stream.next().await {
+			let ops = ops_res?;
+
+			let (Some(first), Some(last)) = (ops.first(), ops.last()) else {
+				break;
+			};
+
+			debug!("Got first and last sync messages");
+
+			#[allow(clippy::cast_possible_truncation)]
+			let operations_count = ops.len() as u32;
+
+			debug!(operations_count, "Got chunk of sync messages");
+
+			new_latest_timestamp = last.timestamp;
+
+			let start_time = timestamp_to_datetime(first.timestamp);
+			let end_time = timestamp_to_datetime(last.timestamp);
+
+			// Ignoring this device_pub_id here as we already know it
+			let (_device_pub_id, compressed_ops) =
+				CompressedCRDTOperationsPerModelPerDevice::new_single_device(ops);
+
+			let messages_bytes = rmp_serde::to_vec_named(&compressed_ops)
+				.map_err(Error::SerializationFailureToPushSyncMessages)?;
+
+			let encrypted_messages =
+				encrypt_messages(&secret_key, &mut self.rng, messages_bytes).await?;
+
+			let encrypted_messages_size = encrypted_messages.len();
+
+			debug!(
+				operations_count,
+				encrypted_messages_size, "Sending sync messages to cloud",
+			);
+
+			self.cloud_client
+				.sync()
+				.messages()
+				.push(messages::push::Request {
+					access_token: self
+						.cloud_services
+						.token_refresher
+						.get_access_token()
+						.await?,
+					group_pub_id: self.sync_group_pub_id,
+					device_pub_id: current_device_pub_id,
+					key_hash: key_hash.clone(),
+					operations_count,
+					time_range: (start_time, end_time),
+					encrypted_messages,
+				})
+				.await??;
+
+			debug!(
+				operations_count,
+				encrypted_messages_size, "Sent sync messages to cloud",
+			);
+
+			status = LoopStatus::SentMessages;
+		}
+
+		self.maybe_latest_timestamp = Some(new_latest_timestamp);
+
+		debug!("Finished cloud sender actor loop iteration");
+
+		Ok(status)
+	}
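The sender behaves as a cursor over the local CRDT log: operations newer than the last pushed timestamp are drained in fixed-size collections, and the cursor only advances to the last operation of each chunk that actually went out. A toy version of that bookkeeping, with `Op` and a plain `u64` standing in for the real operation type and `NTP64`:

```rust
// Stand-in for a CRDT operation carrying a logical timestamp.
struct Op {
    timestamp: u64,
}

/// Pushes `ops` in fixed-size chunks and returns the advanced cursor.
/// A later iteration would resume streaming from the returned value.
fn drain_in_chunks(ops: &[Op], chunk_size: usize, mut cursor: u64) -> u64 {
    for chunk in ops.chunks(chunk_size) {
        if let Some(last) = chunk.last() {
            // ...the chunk would be encrypted and pushed to the cloud here...
            cursor = last.timestamp;
        }
    }
    cursor
}
```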
+
+	async fn get_latest_timestamp(
+		&self,
+		current_device_pub_id: devices::PubId,
+	) -> Result<LatestTimestamp, Error> {
+		if let Some(latest_timestamp) = &self.maybe_latest_timestamp {
+			Ok(*latest_timestamp)
+		} else {
+			let latest_time = match self
+				.cloud_client
+				.sync()
+				.messages()
+				.get_latest_time(messages::get_latest_time::Request {
+					access_token: self
+						.cloud_services
+						.token_refresher
+						.get_access_token()
+						.await?,
+					group_pub_id: self.sync_group_pub_id,
+					kind: messages::get_latest_time::Kind::ForCurrentDevice(current_device_pub_id),
+				})
+				.await?
+			{
+				Ok(messages::get_latest_time::Response {
+					latest_time,
+					latest_device_pub_id,
+				}) => {
+					assert_eq!(latest_device_pub_id, current_device_pub_id);
+					latest_time
+				}
+
+				Err(sd_cloud_schema::Error::Client(ClientSideError::NotFound(
+					NotFoundError::LatestSyncMessageTime,
+				))) => DateTime::<Utc>::from(UNIX_EPOCH),
+
+				Err(e) => return Err(e.into()),
+			};
+
+			Ok(datetime_to_timestamp(latest_time))
+		}
+	}
+}
+
+async fn encrypt_messages(
+	secret_key: &SecretKey,
+	rng: &mut CryptoRng,
+	messages_bytes: Vec<u8>,
+) -> Result<Vec<u8>, Error> {
+	if messages_bytes.len() <= EncryptedBlock::PLAIN_TEXT_SIZE {
+		let mut nonce_and_cipher_text = Vec::with_capacity(OneShotEncryption::cipher_text_size(
+			secret_key,
+			messages_bytes.len(),
+		));
+
+		let EncryptedBlock { nonce, cipher_text } =
+			OneShotEncryption::encrypt(secret_key, messages_bytes.as_slice(), rng)
+				.map_err(Error::Encrypt)?;
+
+		nonce_and_cipher_text.extend_from_slice(nonce.as_slice());
+		nonce_and_cipher_text.extend(&cipher_text);
+
+		Ok(nonce_and_cipher_text)
+	} else {
+		let mut rng = CryptoRng::from_seed(rng.generate_fixed());
+		let mut nonce_and_cipher_text = Vec::with_capacity(StreamEncryption::cipher_text_size(
+			secret_key,
+			messages_bytes.len(),
+		));
+
+		let (nonce, cipher_stream) =
+			StreamEncryption::encrypt(secret_key, messages_bytes.as_slice(), &mut rng);
+
+		nonce_and_cipher_text.extend_from_slice(nonce.as_slice());
+
+		let mut cipher_stream = pin!(cipher_stream);
+
+		while let Some(ciphered_chunk) = cipher_stream.try_next().await.map_err(Error::Encrypt)? {
+			nonce_and_cipher_text.extend(ciphered_chunk);
+		}
+
+		Ok(nonce_and_cipher_text)
+	}
+}
+
+async fn wait_notification(mut rx: broadcast::Receiver<SyncEvent>) -> RaceNotifiedOrStopped {
+	// wait until Created message comes in
+	loop {
+		if matches!(rx.recv().await, Ok(SyncEvent::Created)) {
+			break;
+		};
+	}
+
+	RaceNotifiedOrStopped::Notified
+}
diff --git a/core/crates/cloud-services/src/token_refresher.rs b/core/crates/cloud-services/src/token_refresher.rs
new file mode 100644
index 000000000..ae11e15db
--- /dev/null
+++ b/core/crates/cloud-services/src/token_refresher.rs
@@ -0,0 +1,468 @@
+use sd_cloud_schema::auth::{AccessToken, RefreshToken};
+
+use std::{pin::pin, time::Duration};
+
+use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD};
+use chrono::{DateTime, Utc};
+use futures::StreamExt;
+use futures_concurrency::stream::Merge;
+use reqwest::Url;
+use reqwest_middleware::{reqwest::header, ClientWithMiddleware};
+use tokio::{
+	spawn,
+	sync::oneshot,
+	time::{interval, sleep, MissedTickBehavior},
+};
+use tokio_stream::wrappers::IntervalStream;
+use tracing::{error, warn};
+
+use super::{Error, GetTokenError};
+
+const ONE_MINUTE: Duration = Duration::from_secs(60);
+const TEN_SECONDS: Duration = Duration::from_secs(10);
+
+enum Message {
+	Init(
+		(
+			AccessToken,
+			RefreshToken,
+			oneshot::Sender<Result<(), Error>>,
+		),
+	),
+	CheckInitialization(oneshot::Sender<Result<(), GetTokenError>>),
+	RequestToken(oneshot::Sender<Result<AccessToken, GetTokenError>>),
+	RefreshTime,
+	Tick,
+}
+
+#[derive(Debug, Clone)]
+pub struct TokenRefresher {
+	tx: flume::Sender<Message>,
+}
+
+impl TokenRefresher {
+	pub(crate) fn new(http_client: ClientWithMiddleware, auth_server_url: Url) -> Self {
+		let (tx, rx) = flume::bounded(8);
+
+		spawn(async move {
+			let refresh_url = auth_server_url
+				.join("/api/auth/session/refresh")
+				.expect("hardcoded refresh url path");
+
+			while let Err(e) = spawn(Runner::run(
+				http_client.clone(),
+				refresh_url.clone(),
+				rx.clone(),
+			))
+			.await
+			{
+				if e.is_panic() {
+					if let Some(msg) = e.into_panic().downcast_ref::<&str>() {
+						error!(?msg, "Panic in request handler!");
+					} else {
+						error!("Some unknown panic in request handler!");
+					}
+				}
+			}
+		});
+
+		Self { tx }
+	}
+
+	pub async fn init(
+		&self,
+		access_token: AccessToken,
+		refresh_token: RefreshToken,
+	) -> Result<(), Error> {
+		let (tx, rx) = oneshot::channel();
+		self.tx
+			.send_async(Message::Init((access_token, refresh_token, tx)))
+			.await
+			.expect("Token refresher channel closed");
+
+		rx.await.expect("Token refresher channel closed")
+	}
+
+	pub async fn check_initialization(&self) -> Result<(), GetTokenError> {
+		let (tx, rx) = oneshot::channel();
+		self.tx
+			.send_async(Message::CheckInitialization(tx))
+			.await
+			.expect("Token refresher channel closed");
+
+		rx.await.expect("Token refresher channel closed")
+	}
+
+	pub async fn get_access_token(&self) -> Result<AccessToken, GetTokenError> {
+		let (tx, rx) = oneshot::channel();
+		self.tx
+			.send_async(Message::RequestToken(tx))
+			.await
+			.expect("Token refresher channel closed");
+
+		rx.await.expect("Token refresher channel closed")
+	}
+}
+
+struct Runner {
+	initialized: bool,
+	http_client: ClientWithMiddleware,
+	refresh_url: Url,
+	current_token: Option<AccessToken>,
+	current_refresh_token: Option<RefreshToken>,
+	token_decoding_buffer: Vec<u8>,
+	refresh_tx: flume::Sender<Message>,
+}
+
+impl Runner {
+	async fn run(
+		http_client: ClientWithMiddleware,
+		refresh_url: Url,
+		msgs_rx: flume::Receiver<Message>,
+	) {
+		let (refresh_tx, refresh_rx) = flume::bounded(1);
+
+		let mut ticker = interval(TEN_SECONDS);
+		ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
+
+		let mut msg_stream = pin!((
+			msgs_rx.into_stream(),
+			refresh_rx.into_stream(),
+			IntervalStream::new(ticker).map(|_| Message::Tick)
+		)
+			.merge());
+
+		let mut runner = Self {
+			initialized: false,
+			http_client,
+			refresh_url,
+			current_token: None,
+			current_refresh_token: None,
+			token_decoding_buffer: Vec::new(),
+			refresh_tx,
+		};
+
+		while let Some(msg) = msg_stream.next().await {
+			match msg {
+				Message::Init((access_token, refresh_token, ack)) => {
+					if ack
+						.send(runner.init(access_token, refresh_token).await)
+						.is_err()
+					{
+						error!("Failed to send init token refresher response, receiver dropped;");
+					}
+				}
+
+				Message::CheckInitialization(ack) => runner.check_initialization(ack),
+
+				Message::RequestToken(ack) => runner.reply_token(ack),
+
+				Message::RefreshTime => {
+					if let Err(e) = runner.refresh().await {
+						error!(?e, "Failed to refresh token: {e}");
+					}
+				}
+
+				Message::Tick => runner.tick().await,
+			}
+		}
+	}
+
+	async fn init(
+		&mut self,
+		access_token: AccessToken,
+		refresh_token: RefreshToken,
+	) -> Result<(), Error> {
+		let access_token_duration =
+			Self::extract_access_token_duration(&mut self.token_decoding_buffer, &access_token)?;
+
+		self.initialized = true;
+		self.current_token = Some(access_token);
+		self.current_refresh_token = Some(refresh_token);
+
+		// If the token has an expiration smaller than a minute, we need to refresh it immediately.
+		if access_token_duration < ONE_MINUTE {
+			self.refresh_tx
+				.send_async(Message::RefreshTime)
+				.await
+				.expect("refresh channel never closes");
+		} else {
+			// This task will be mostly parked waiting a sleep
+			spawn(Self::schedule_refresh(
+				self.refresh_tx.clone(),
+				access_token_duration - ONE_MINUTE,
+			));
+		}
+
+		Ok(())
+	}
+
+	fn reply_token(&self, ack: oneshot::Sender<Result<AccessToken, GetTokenError>>) {
+		if ack
+			.send(self.current_token.clone().ok_or({
+				if self.initialized {
+					GetTokenError::FailedToRefresh
+				} else {
+					GetTokenError::RefresherNotInitialized
+				}
+			}))
+			.is_err()
+		{
+			warn!("Failed to send access token response, receiver dropped;");
+		}
+	}
+
+	async fn refresh(&mut self) -> Result<(), Error> {
+		let RefreshToken(refresh_token) = self
+			.current_refresh_token
+			.clone()
+			.expect("refresh token is set otherwise we wouldn't be here");
+
+		let response = self
+			.http_client
+			.post(self.refresh_url.clone())
+			.header("rid", "session")
+			.header(header::AUTHORIZATION, format!("Bearer {refresh_token}"))
+			.send()
+			.await
+			.map_err(Error::RefreshTokenRequest)?
+			.error_for_status()
+			.map_err(Error::AuthServerError)?;
+
+		if let (Some(access_token), Some(refresh_token)) = (
+			response.headers().get("st-access-token"),
+			response.headers().get("st-refresh-token"),
+		) {
+			// Only set values if we can parse both of them to strings
+			let (access_token, refresh_token) = (
+				Self::token_header_value_to_string(access_token)?,
+				Self::token_header_value_to_string(refresh_token)?,
+			);
+
+			self.current_token = Some(AccessToken(access_token));
+			self.current_refresh_token = Some(RefreshToken(refresh_token));
+		} else {
+			return Err(Error::MissingTokensOnRefreshResponse);
+		}
+
+		Ok(())
+	}
+
+	fn extract_access_token_duration(
+		token_decoding_buffer: &mut Vec<u8>,
+		AccessToken(token): &AccessToken,
+	) -> Result<Duration, Error> {
+		#[derive(serde::Deserialize)]
+		struct Token {
+			#[serde(with = "chrono::serde::ts_seconds")]
+			exp: DateTime<Utc>,
+		}
+
+		token_decoding_buffer.clear();
+
+		// The format of a JWT token is simple:
+		// "<header>.<claims>.<signature>"
+		BASE64_URL_SAFE_NO_PAD.decode_vec(
+			token.split('.').nth(1).ok_or(Error::MissingClaims)?,
+			token_decoding_buffer,
+		)?;
+
+		serde_json::from_slice::<Token>(token_decoding_buffer)?
+			.exp
+			.signed_duration_since(Utc::now())
+			.to_std()
+			.map_err(|_| Error::TokenExpired)
+	}
+
+	async fn schedule_refresh(refresh_tx: flume::Sender<Message>, wait_time: Duration) {
+		sleep(wait_time).await;
+		refresh_tx
+			.send_async(Message::RefreshTime)
+			.await
+			.expect("Refresh channel closed");
+	}
+
+	fn token_header_value_to_string(token: &header::HeaderValue) -> Result<String, Error> {
+		token.to_str().map(str::to_string).map_err(Into::into)
+	}
+
+	fn check_initialization(&self, ack: oneshot::Sender<Result<(), GetTokenError>>) {
+		if ack
+			.send(if self.initialized {
+				Ok(())
+			} else {
+				Err(GetTokenError::RefresherNotInitialized)
+			})
+			.is_err()
+		{
+			warn!("Failed to send access token response, receiver dropped;");
+		}
+	}
+
+	/// This method is a safeguard to make sure we try to keep refreshing tokens even if they
+	/// already expired, as the refresh token has a bigger expiration than the access token.
+	async fn tick(&mut self) {
+		if let Some(access_token) = &self.current_token {
+			if matches!(
+				Self::extract_access_token_duration(&mut self.token_decoding_buffer, access_token),
+				Err(Error::TokenExpired)
+			) {
+				if let Err(e) = self.refresh().await {
+					error!(?e, "Failed to refresh expired token on tick method;");
+				}
+			}
+		}
+	}
+}
+
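`extract_access_token_duration` leans on the JWT layout: the expiry claim lives in the base64url-encoded middle segment. A minimal sketch of just that decode step, with error handling collapsed to `Option`:

```rust
use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD};

/// A JWT is "<header>.<claims>.<signature>"; the claims are the second
/// dot-separated segment, base64url-encoded without padding.
fn decode_claims(token: &str) -> Option<Vec<u8>> {
    let claims_b64 = token.split('.').nth(1)?;
    BASE64_URL_SAFE_NO_PAD.decode(claims_b64).ok()
}
```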
+/// These tests are here for documentation purposes only; they are not meant to be run.
+/// They're just examples of how to sign up/sign in and refresh tokens.
+#[cfg(test)]
+mod tests {
+	use reqwest::header;
+	use reqwest_middleware::ClientBuilder;
+	use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
+	use serde_json::json;
+
+	use crate::AUTH_SERVER_URL;
+
+	use super::*;
+
+	async fn get_tokens() -> (AccessToken, RefreshToken) {
+		let client = reqwest::Client::new();
+
+		let req_body = json!({
+			"formFields": [
+				{
+					"id": "email",
+					"value": "johndoe@gmail.com"
+				},
+				{
+					"id": "password",
+					"value": "testPass123"
+				}
+			]
+		});
+
+		let response = client
+			.post(format!("{AUTH_SERVER_URL}/api/auth/public/signup"))
+			.header("rid", "emailpassword")
+			.header("st-auth-mode", "header")
+			.json(&req_body)
+			.send()
+			.await
+			.unwrap();
+
+		assert_eq!(response.status(), 200);
+
+		if let (Some(access_token), Some(refresh_token)) = (
+			response.headers().get("st-access-token"),
+			response.headers().get("st-refresh-token"),
+		) {
+			(
+				AccessToken(access_token.to_str().unwrap().to_string()),
+				RefreshToken(refresh_token.to_str().unwrap().to_string()),
+			)
+		} else {
+			let response = client
+				.post(format!("{AUTH_SERVER_URL}/api/auth/public/signin"))
+				.header("rid", "emailpassword")
+				.header("st-auth-mode", "header")
+				.json(&req_body)
+				.send()
+				.await
+				.unwrap();
+
+			assert_eq!(response.status(), 200);
+
+			(
+				AccessToken(
+					response
+						.headers()
+						.get("st-access-token")
+						.unwrap()
+						.to_str()
+						.unwrap()
+						.to_string(),
+				),
+				RefreshToken(
+					response
+						.headers()
+						.get("st-refresh-token")
+						.unwrap()
+						.to_str()
+						.unwrap()
+						.to_string(),
+				),
+			)
+		}
+	}
+
+	#[ignore = "Documentation only"]
+	#[tokio::test]
+	async fn test_refresh_token() {
+		let (AccessToken(access_token), RefreshToken(refresh_token)) = get_tokens().await;
+
+		let client = reqwest::Client::new();
+		let response = client
+			.post(format!("{AUTH_SERVER_URL}/api/auth/session/refresh"))
+			.header("rid", "session")
+			.header(header::AUTHORIZATION, format!("Bearer {refresh_token}"))
+			.send()
+			.await
+			.unwrap();
+
+		assert_eq!(response.status(), 200);
+
+		assert_ne!(
+			response
+				.headers()
+				.get("st-access-token")
+				.unwrap()
+				.to_str()
+				.unwrap(),
+			access_token.as_str()
+		);
+
+		assert_ne!(
+			response
+				.headers()
+				.get("st-refresh-token")
+				.unwrap()
+				.to_str()
+				.unwrap(),
+			refresh_token.as_str()
+		);
+	}
+
+	#[ignore = "Needs an actual SuperTokens auth server running"]
+	#[tokio::test]
+	async fn test_refresher_runner() {
+		let http_client_builder = reqwest::Client::builder().timeout(Duration::from_secs(3));
+
+		let http_client = ClientBuilder::new(http_client_builder.build().unwrap())
+			.with(RetryTransientMiddleware::new_with_policy(
+				ExponentialBackoff::builder().build_with_max_retries(3),
+			))
+			.build();
+
+		let (refresh_tx, _refresh_rx) = flume::bounded(1);
+
+		let mut runner = Runner {
+			initialized: false,
+			http_client,
+			refresh_url: Url::parse(&format!("{AUTH_SERVER_URL}/api/auth/session/refresh"))
+				.unwrap(),
+			current_token: None,
+			current_refresh_token: None,
+			token_decoding_buffer: Vec::new(),
+			refresh_tx,
+		};
+
+		let (access_token, refresh_token) = get_tokens().await;
+
+		runner.init(access_token, refresh_token).await.unwrap();
+
+		runner.refresh().await.unwrap();
+	}
+}
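For reference, the refresh round trip those tests exercise boils down to one header-based request: post the refresh token as a Bearer credential and read the rotated pair back from the SuperTokens response headers. A hedged sketch; the URL is a placeholder and error handling is simplified:

```rust
use reqwest::header::AUTHORIZATION;

async fn refresh(refresh_token: &str) -> reqwest::Result<Option<(String, String)>> {
    let response = reqwest::Client::new()
        // Placeholder host; the real code joins the path onto the auth server URL.
        .post("https://auth.example.com/api/auth/session/refresh")
        .header("rid", "session")
        .header(AUTHORIZATION, format!("Bearer {refresh_token}"))
        .send()
        .await?
        .error_for_status()?;

    let get = |name: &str| {
        response
            .headers()
            .get(name)
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned)
    };

    // Both headers must be present for the rotation to count.
    Ok(get("st-access-token").zip(get("st-refresh-token")))
}
```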
diff --git a/core/crates/file-path-helper/src/isolated_file_path_data.rs b/core/crates/file-path-helper/src/isolated_file_path_data.rs
index 3e89cce0f..fe83bbee9 100644
--- a/core/crates/file-path-helper/src/isolated_file_path_data.rs
+++ b/core/crates/file-path-helper/src/isolated_file_path_data.rs
@@ -2,7 +2,7 @@ use sd_core_prisma_helpers::{
 	file_path_for_file_identifier, file_path_for_media_processor, file_path_for_object_validator,
 	file_path_to_full_path, file_path_to_handle_custom_uri, file_path_to_handle_p2p_serve_file,
 	file_path_to_isolate, file_path_to_isolate_with_id, file_path_to_isolate_with_pub_id,
-	file_path_walker, file_path_with_object,
+	file_path_walker, file_path_watcher_remove, file_path_with_object,
 };
 
 use sd_prisma::prisma::{file_path, location};
@@ -506,7 +508,8 @@ impl_from_db!(
 	file_path_to_isolate_with_pub_id,
 	file_path_walker,
 	file_path_to_isolate_with_id,
-	file_path_with_object
+	file_path_with_object,
+	file_path_watcher_remove
 );
 
 impl_from_db_without_location_id!(
diff --git a/core/crates/heavy-lifting/src/file_identifier/job.rs b/core/crates/heavy-lifting/src/file_identifier/job.rs
index a90c2ea6a..dc2d6866c 100644
--- a/core/crates/heavy-lifting/src/file_identifier/job.rs
+++ b/core/crates/heavy-lifting/src/file_identifier/job.rs
@@ -14,7 +14,11 @@ use crate::{
 use sd_core_file_path_helper::IsolatedFilePathData;
 use sd_core_prisma_helpers::{file_path_for_file_identifier, CasId};
 
-use sd_prisma::prisma::{file_path, location, SortOrder};
+use sd_prisma::{
+	prisma::{device, file_path, location, SortOrder},
+	prisma_sync,
+};
+use sd_sync::{sync_db_not_null_entry, OperationFactory};
 use sd_task_system::{
 	AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId,
 	TaskOutput, TaskStatus,
@@ -128,14 +132,14 @@ impl Job for FileIdentifier {
 		match task_kind {
 			TaskKind::Identifier => tasks::Identifier::deserialize(
 				&task_bytes,
-				(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
+				(Arc::clone(ctx.db()), ctx.sync().clone()),
 			)
 			.await
 			.map(IntoTask::into_task),
 
 			TaskKind::ObjectProcessor => tasks::ObjectProcessor::deserialize(
 				&task_bytes,
-				(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
+				(Arc::clone(ctx.db()), ctx.sync().clone()),
 			)
 			.await
 			.map(IntoTask::into_task),
@@ -173,8 +177,21 @@ impl Job for FileIdentifier {
 	) -> Result<ReturnStatus, Error> {
 		let mut pending_running_tasks = FuturesUnordered::new();
 
+		let device_pub_id = &ctx.sync().device_pub_id;
+		let device_id = ctx
+			.db()
+			.device()
+			.find_unique(device::pub_id::equals(device_pub_id.to_db()))
+			.exec()
+			.await
+			.map_err(file_identifier::Error::from)?
+			.ok_or(file_identifier::Error::DeviceNotFound(
+				device_pub_id.clone(),
+			))?
+			.id;
+
 		match self
-			.init_or_resume(&mut pending_running_tasks, &ctx, &dispatcher)
+			.init_or_resume(&mut pending_running_tasks, &ctx, device_id, &dispatcher)
 			.await
 		{
 			Ok(()) => { /* Everything is awesome! */ }
@@ -201,7 +218,7 @@ impl Job for FileIdentifier {
 			match task {
 				Ok(TaskStatus::Done((task_id, TaskOutput::Out(out)))) => {
 					match self
-						.process_task_output(task_id, out, &ctx, &dispatcher)
+						.process_task_output(task_id, out, &ctx, device_id, &dispatcher)
 						.await
 					{
 						Ok(tasks) => pending_running_tasks.extend(tasks),
@@ -254,15 +271,25 @@ impl Job for FileIdentifier {
 			..
 		} = self;
 
-		ctx.db()
-			.location()
-			.update(
-				location::id::equals(location.id),
-				vec![location::scan_state::set(
-					LocationScanState::FilesIdentified as i32,
-				)],
+		let (sync_param, db_param) = sync_db_not_null_entry!(
+			LocationScanState::FilesIdentified as i32,
+			location::scan_state
+		);
+
+		ctx.sync()
+			.write_op(
+				ctx.db(),
+				ctx.sync().shared_update(
+					prisma_sync::location::SyncId {
+						pub_id: location.pub_id.clone(),
+					},
+					[sync_param],
+				),
+				ctx.db()
+					.location()
+					.update(location::id::equals(location.id), vec![db_param])
+					.select(location::select!({ id })),
 			)
-			.exec()
 			.await
 			.map_err(file_identifier::Error::from)?;
 
@@ -302,6 +329,7 @@ impl FileIdentifier {
 		&mut self,
 		pending_running_tasks: &mut FuturesUnordered<TaskHandle<Error>>,
 		ctx: &impl JobContext,
+		device_id: device::id::Type,
 		dispatcher: &JobTaskDispatcher,
 	) -> Result<(), JobErrorOrDispatcherError<file_identifier::Error>> {
 		// if we don't have any pending task, then this is a fresh job
@@ -335,6 +363,7 @@ impl FileIdentifier {
 				.as_ref()
 				.unwrap_or(&location_root_iso_file_path),
 			ctx,
+			device_id,
 			dispatcher,
 			pending_running_tasks,
 		)
@@ -345,8 +374,9 @@ impl FileIdentifier {
 			self.last_orphan_file_path_id = None;
 
 			self.dispatch_deep_identifier_tasks(
-				&maybe_sub_iso_file_path,
+				maybe_sub_iso_file_path.as_ref(),
 				ctx,
+				device_id,
 				dispatcher,
 				pending_running_tasks,
 			)
@@ -378,6 +408,7 @@ impl FileIdentifier {
 				.as_ref()
 				.unwrap_or(&location_root_iso_file_path),
 			ctx,
+			device_id,
 			dispatcher,
 			pending_running_tasks,
 		)
@@ -388,8 +419,9 @@ impl FileIdentifier {
 			self.last_orphan_file_path_id = None;
 
 			self.dispatch_deep_identifier_tasks(
-				&maybe_sub_iso_file_path,
+				maybe_sub_iso_file_path.as_ref(),
 				ctx,
+				device_id,
 				dispatcher,
 				pending_running_tasks,
 			)
@@ -401,8 +433,9 @@ impl FileIdentifier {
 
 			Phase::SearchingOrphans => {
 				self.dispatch_deep_identifier_tasks(
-					&maybe_sub_iso_file_path,
+					maybe_sub_iso_file_path.as_ref(),
 					ctx,
+					device_id,
 					dispatcher,
 					pending_running_tasks,
 				)
@@ -447,6 +480,7 @@ impl FileIdentifier {
 		task_id: TaskId,
 		any_task_output: Box<dyn AnyTaskOutput>,
 		ctx: &impl JobContext,
+		device_id: device::id::Type,
 		dispatcher: &JobTaskDispatcher,
 	) -> Result<Vec<TaskHandle<Error>>, DispatcherError> {
 		if any_task_output.is::<identifier::Output>() {
@@ -457,6 +491,7 @@ impl FileIdentifier {
 					.downcast::<identifier::Output>()
 					.expect("just checked"),
 				ctx,
+				device_id,
 				dispatcher,
 			)
 			.await;
@@ -501,6 +536,7 @@ impl FileIdentifier {
 			errors,
 		}: identifier::Output,
 		ctx: &impl JobContext,
+		device_id: device::id::Type,
 		dispatcher: &JobTaskDispatcher,
 	) -> Result<Vec<TaskHandle<Error>>, DispatcherError> {
 		self.metadata.mean_extract_metadata_time += extract_metadata_time;
@@ -548,6 +584,7 @@ impl FileIdentifier {
 		let (tasks_count, res) = match dispatch_object_processor_tasks(
 			self.file_paths_accumulator.drain(),
 			ctx,
+			device_id,
 			dispatcher,
 			false,
 		)
@@ -636,6 +673,7 @@ impl FileIdentifier {
 		&mut self,
 		sub_iso_file_path: &IsolatedFilePathData<'static>,
 		ctx: &impl JobContext,
+		device_id: device::id::Type,
 		dispatcher: &JobTaskDispatcher,
 		pending_running_tasks: &FuturesUnordered<TaskHandle<Error>>,
 	) -> Result<(), JobErrorOrDispatcherError<file_identifier::Error>> {
@@ -702,7 +740,8 @@ impl FileIdentifier {
 					orphan_paths,
 					true,
 					Arc::clone(ctx.db()),
-					Arc::clone(ctx.sync()),
+					ctx.sync().clone(),
+					device_id,
 				))
 				.await?,
 			);
@@ -713,8 +752,9 @@ impl FileIdentifier {
 
 	async fn dispatch_deep_identifier_tasks(
 		&mut self,
-		maybe_sub_iso_file_path: &Option<IsolatedFilePathData<'static>>,
+		maybe_sub_iso_file_path: Option<&IsolatedFilePathData<'static>>,
 		ctx: &impl JobContext,
+		device_id: device::id::Type,
 		dispatcher: &JobTaskDispatcher,
 		pending_running_tasks: &FuturesUnordered<TaskHandle<Error>>,
 	) -> Result<(), JobErrorOrDispatcherError<file_identifier::Error>> {
@@ -785,7 +825,8 @@ impl FileIdentifier {
 					orphan_paths,
 					false,
 					Arc::clone(ctx.db()),
-					Arc::clone(ctx.sync()),
+					ctx.sync().clone(),
+					device_id,
 				))
 				.await?,
 			);
diff --git a/core/crates/heavy-lifting/src/file_identifier/mod.rs b/core/crates/heavy-lifting/src/file_identifier/mod.rs
index 9d7d2833a..f777c118d 100644
--- a/core/crates/heavy-lifting/src/file_identifier/mod.rs
+++ b/core/crates/heavy-lifting/src/file_identifier/mod.rs
@@ -2,9 +2,10 @@ use crate::{utils::sub_path, OuterContext};
 
 use sd_core_file_path_helper::{FilePathError, IsolatedFilePathData};
 use sd_core_prisma_helpers::CasId;
+use sd_core_sync::DevicePubId;
 use sd_file_ext::{extensions::Extension, kind::ObjectKind};
 
-use sd_prisma::prisma::{file_path, location};
+use sd_prisma::prisma::{device, file_path, location};
 use sd_task_system::{TaskDispatcher, TaskHandle};
 use sd_utils::{db::MissingFieldError, error::FileIOError};
 
@@ -41,6 +42,8 @@ const CHUNK_SIZE: usize = 100;
 
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
+	#[error("device not found: <device_pub_id='{0}'>")]
+	DeviceNotFound(DevicePubId),
-	maybe_sub_iso_file_path: &Option<IsolatedFilePathData<'_>>,
+	maybe_sub_iso_file_path: Option<&IsolatedFilePathData<'_>>,
 ) -> Vec<file_path::WhereParam> {
 	sd_utils::chain_optional_iter(
 		[
@@ -197,6 +200,7 @@ fn orphan_path_filters_deep(
 
 async fn dispatch_object_processor_tasks<Iter, Dispatcher>(
 	file_paths_by_cas_id: Iter,
 	ctx: &impl OuterContext,
+	device_id: device::id::Type,
 	dispatcher: &Dispatcher,
 	with_priority: bool,
 ) -> Result<Vec<TaskHandle<Error>>, Dispatcher::DispatchError>
 where
@@ -217,7 +221,8 @@
 				.dispatch(tasks::ObjectProcessor::new(
 					HashMap::from([(cas_id, objects_to_create_or_link)]),
 					Arc::clone(ctx.db()),
-					Arc::clone(ctx.sync()),
+					ctx.sync().clone(),
+					device_id,
 					with_priority,
 				))
 				.await?,
@@ -239,7 +244,8 @@
 					.dispatch(tasks::ObjectProcessor::new(
 						mem::take(&mut current_batch),
 						Arc::clone(ctx.db()),
-						Arc::clone(ctx.sync()),
+						ctx.sync().clone(),
+						device_id,
 						with_priority,
 					))
 					.await?,
@@ -256,7 +262,8 @@
 				.dispatch(tasks::ObjectProcessor::new(
 					current_batch,
 					Arc::clone(ctx.db()),
-					Arc::clone(ctx.sync()),
+					ctx.sync().clone(),
+					device_id,
 					with_priority,
 				))
 				.await?,
diff --git a/core/crates/heavy-lifting/src/file_identifier/shallow.rs b/core/crates/heavy-lifting/src/file_identifier/shallow.rs
index cd165867d..4c00882da 100644
--- a/core/crates/heavy-lifting/src/file_identifier/shallow.rs
+++ b/core/crates/heavy-lifting/src/file_identifier/shallow.rs
@@ -6,7 +6,7 @@ use crate::{
 use sd_core_file_path_helper::IsolatedFilePathData;
 use sd_core_prisma_helpers::file_path_for_file_identifier;
 
-use sd_prisma::prisma::{file_path, location, SortOrder};
+use sd_prisma::prisma::{device, file_path, location, SortOrder};
 use sd_task_system::{
 	BaseTaskDispatcher, CancelTaskOnDrop, TaskDispatcher, TaskHandle, TaskOutput, TaskStatus,
 };
@@ -66,6 +66,19 @@ pub async fn shallow(
 		Ok,
 	)?;
 
+	let device_pub_id = &ctx.sync().device_pub_id;
+	let device_id = ctx
+		.db()
+		.device()
+		.find_unique(device::pub_id::equals(device_pub_id.to_db()))
+		.exec()
+		.await
+		.map_err(file_identifier::Error::from)?
+		.ok_or(file_identifier::Error::DeviceNotFound(
+			device_pub_id.clone(),
+		))?
+		.id;
+
 	let mut orphans_count = 0;
 	let mut last_orphan_file_path_id = None;
 
@@ -103,7 +116,8 @@ pub async fn shallow(
 			orphan_paths,
 			true,
 			Arc::clone(ctx.db()),
-			Arc::clone(ctx.sync()),
+			ctx.sync().clone(),
+			device_id,
 		))
 		.await
 	else {
@@ -119,13 +133,14 @@ pub async fn shallow(
 		return Ok(vec![]);
 	}
 
-	process_tasks(identifier_tasks, dispatcher, ctx).await
+	process_tasks(identifier_tasks, dispatcher, ctx, device_id).await
 }
 
 async fn process_tasks(
 	identifier_tasks: Vec<TaskHandle<Error>>,
 	dispatcher: &BaseTaskDispatcher<Error>,
 	ctx: &impl OuterContext,
+	device_id: device::id::Type,
 ) -> Result<Vec<NonCriticalError>, Error> {
 	let total_identifier_tasks = identifier_tasks.len();
 
@@ -169,6 +184,7 @@ async fn process_tasks(
 			let Ok(tasks) = dispatch_object_processor_tasks(
 				file_paths_accumulator.drain(),
 				ctx,
+				device_id,
 				dispatcher,
 				true,
 			)
diff --git a/core/crates/heavy-lifting/src/file_identifier/tasks/identifier.rs b/core/crates/heavy-lifting/src/file_identifier/tasks/identifier.rs
index 11fc8a753..125a72713 100644
--- a/core/crates/heavy-lifting/src/file_identifier/tasks/identifier.rs
+++ b/core/crates/heavy-lifting/src/file_identifier/tasks/identifier.rs
@@ -5,18 +5,18 @@ use crate::{
 use sd_core_file_path_helper::IsolatedFilePathData;
 use sd_core_prisma_helpers::{file_path_for_file_identifier, CasId, FilePathPubId};
-use sd_core_sync::Manager as SyncManager;
+use sd_core_sync::SyncManager;
 use sd_file_ext::kind::ObjectKind;
 use sd_prisma::{
-	prisma::{file_path, location, PrismaClient},
+	prisma::{device, file_path, location, PrismaClient},
 	prisma_sync,
 };
-use sd_sync::OperationFactory;
+use sd_sync::{sync_db_entry, OperationFactory};
 use sd_task_system::{
 	ExecStatus, Interrupter, InterruptionKind, IntoAnyTaskOutput, SerializableTask, Task, TaskId,
 };
-use sd_utils::{error::FileIOError, msgpack};
+use sd_utils::error::FileIOError;
 
 use std::{
 	collections::HashMap, convert::identity, future::IntoFuture, mem, path::PathBuf, pin::pin,
@@ -64,6 +64,7 @@ pub struct Identifier {
 	file_paths_by_id: HashMap<file_path::id::Type, file_path_for_file_identifier::Data>,
 
 	// Inner state
+	device_id: device::id::Type,
 	identified_files: HashMap<FilePathPubId, IdentifiedFile>,
 	file_paths_without_cas_id: Vec<FilePathToCreateOrLinkObject>,
 
@@ -72,7 +73,7 @@ pub struct Identifier {
 
 	// Dependencies
 	db: Arc<PrismaClient>,
-	sync: Arc<SyncManager>,
+	sync: SyncManager,
 }
 
 /// Output from the `[Identifier]` task
@@ -135,6 +136,7 @@ impl Task for Identifier {
 		let Self {
 			location,
 			location_path,
+			device_id,
 			file_paths_by_id,
 			file_paths_without_cas_id,
 			identified_files,
@@ -255,6 +257,7 @@ impl Task for Identifier {
 						file_paths_without_cas_id.drain(..),
 						&self.db,
 						&self.sync,
+						*device_id,
 					),
 				)
 					.try_join()
@@ -301,6 +304,7 @@ impl Task for Identifier {
 				file_paths_without_cas_id.drain(..),
 				&self.db,
 				&self.sync,
+				*device_id,
 			)
 			.await?;
 
@@ -324,7 +328,8 @@ impl Identifier {
 		file_paths: Vec<file_path_for_file_identifier::Data>,
 		with_priority: bool,
 		db: Arc<PrismaClient>,
-		sync: Arc<SyncManager>,
+		sync: SyncManager,
+		device_id: device::id::Type,
 	) -> Self {
 		let mut output = Output::default();
 
@@ -377,6 +382,7 @@ impl Identifier {
 			id: TaskId::new_v4(),
 			location,
 			location_path,
+			device_id,
 			identified_files: HashMap::with_capacity(file_paths_count - directories_count),
 			file_paths_without_cas_id,
 			file_paths_by_id,
@@ -394,33 +400,31 @@ async fn assign_cas_id_to_file_paths(
 	db: &PrismaClient,
 	sync: &SyncManager,
 ) -> Result<(), file_identifier::Error> {
-	// Assign cas_id to each file path
-	sync.write_ops(
-		db,
-		identified_files
-			.iter()
-			.map(|(pub_id, IdentifiedFile { cas_id, ..
-			})| {
-				(
-					sync.shared_update(
-						prisma_sync::file_path::SyncId {
-							pub_id: pub_id.to_db(),
-						},
-						file_path::cas_id::NAME,
-						msgpack!(cas_id),
-					),
-					db.file_path()
-						.update(
-							file_path::pub_id::equals(pub_id.to_db()),
-							vec![file_path::cas_id::set(cas_id.into())],
-						)
-						// We don't need any data here, just the id avoids receiving the entire object
-						// as we can't pass an empty select macro call
-						.select(file_path::select!({ id })),
-				)
-			})
-			.unzip::<_, _, _, Vec<_>>(),
-	)
-	.await?;
+	let (ops, queries) = identified_files
+		.iter()
+		.map(|(pub_id, IdentifiedFile { cas_id, .. })| {
+			let (sync_param, db_param) = sync_db_entry!(cas_id, file_path::cas_id);
+
+			(
+				sync.shared_update(
+					prisma_sync::file_path::SyncId {
+						pub_id: pub_id.to_db(),
+					},
+					[sync_param],
+				),
+				db.file_path()
+					.update(file_path::pub_id::equals(pub_id.to_db()), vec![db_param])
+					// We don't need any data here, just the id avoids receiving the entire object
+					// as we can't pass an empty select macro call
+					.select(file_path::select!({ id })),
+			)
+		})
+		.unzip::<_, _, Vec<_>, Vec<_>>();
+
+	if !ops.is_empty() && !queries.is_empty() {
+		// Assign cas_id to each file path
+		sync.write_ops(db, (ops, queries)).await?;
+	}
 
 	Ok(())
 }
@@ -500,6 +504,7 @@ struct SaveState {
 	id: TaskId,
 	location: Arc<location::Data>,
 	location_path: Arc<PathBuf>,
+	device_id: device::id::Type,
 	file_paths_by_id: HashMap<file_path::id::Type, file_path_for_file_identifier::Data>,
 	identified_files: HashMap<FilePathPubId, IdentifiedFile>,
 	file_paths_without_cas_id: Vec<FilePathToCreateOrLinkObject>,
@@ -512,13 +517,14 @@ impl SerializableTask for Identifier {
 
 	type DeserializeError = rmp_serde::decode::Error;
 
-	type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
+	type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
 
 	async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
 		let Self {
 			id,
 			location,
 			location_path,
+			device_id,
 			file_paths_by_id,
 			identified_files,
 			file_paths_without_cas_id,
@@ -530,6 +536,7 @@ impl SerializableTask for Identifier {
 			id,
 			location,
 			location_path,
+			device_id,
 			file_paths_by_id,
 			identified_files,
 			file_paths_without_cas_id,
@@ -547,6 +554,7 @@ impl SerializableTask for Identifier {
 			id,
 			location,
 			location_path,
+			device_id,
 			file_paths_by_id,
 			identified_files,
 			file_paths_without_cas_id,
@@ -558,6 +566,7 @@ impl SerializableTask for Identifier {
 			location,
 			location_path,
 			file_paths_by_id,
+			device_id,
 			identified_files,
 			file_paths_without_cas_id,
 			output,
diff --git a/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs b/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs
index f74a03b4a..59f75d0a9 100644
--- a/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs
+++ b/core/crates/heavy-lifting/src/file_identifier/tasks/mod.rs
@@ -1,15 +1,15 @@
 use crate::file_identifier;
 
 use sd_core_prisma_helpers::{file_path_id, FilePathPubId, ObjectPubId};
-use sd_core_sync::Manager as SyncManager;
+use sd_core_sync::SyncManager;
 use sd_file_ext::kind::ObjectKind;
 use sd_prisma::{
-	prisma::{file_path, object, PrismaClient},
+	prisma::{device, file_path, object, PrismaClient},
 	prisma_sync,
 };
-use sd_sync::{CRDTOperation, OperationFactory};
-use sd_utils::msgpack;
+use sd_sync::{option_sync_db_entry, sync_db_entry, sync_entry, CRDTOperation, OperationFactory};
+use sd_utils::chain_optional_iter;
 
 use std::collections::{HashMap, HashSet};
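A recurring shape in these task rewrites: build matching (sync operation, database query) pairs, unzip them, and skip `write_ops` entirely when there is nothing to write, since an empty transaction is wasted work. A generic sketch of that guard; the element types are stand-ins for the real sd_sync operations and Prisma queries:

```rust
/// Unzips (op, query) pairs and returns None when there is nothing to write,
/// mirroring the `!ops.is_empty() && !queries.is_empty()` checks above.
fn unzip_pairs<O, Q>(pairs: Vec<(O, Q)>) -> Option<(Vec<O>, Vec<Q>)> {
    let (ops, queries): (Vec<_>, Vec<_>) = pairs.into_iter().unzip();

    // Both sides have the same length by construction, so one check suffices.
    (!ops.is_empty()).then_some((ops, queries))
}
```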
@@ -47,10 +47,12 @@ fn connect_file_path_to_object<'db>(
 			prisma_sync::file_path::SyncId {
 				pub_id: file_path_pub_id.to_db(),
 			},
-			file_path::object::NAME,
-			msgpack!(prisma_sync::object::SyncId {
-				pub_id: object_pub_id.to_db(),
-			}),
+			[sync_entry!(
+				prisma_sync::object::SyncId {
+					pub_id: object_pub_id.to_db(),
+				},
+				file_path::object
+			)],
 		),
 		db.file_path()
 			.update(
@@ -69,6 +71,7 @@ async fn create_objects_and_update_file_paths(
 	files_and_kinds: impl IntoIterator<Item = FilePathToCreateOrLinkObject> + Send,
 	db: &PrismaClient,
 	sync: &SyncManager,
+	device_id: device::id::Type,
 ) -> Result<HashMap<file_path::id::Type, ObjectPubId>, file_identifier::Error> {
 	trace!("Preparing objects");
 	let (object_create_args, file_path_args) = files_and_kinds
@@ -84,16 +87,23 @@ async fn create_objects_and_update_file_paths(
 
 				let kind = kind as i32;
 
-				let (sync_params, db_params) = [
-					(
-						(object::date_created::NAME, msgpack!(created_at)),
-						object::date_created::set(created_at),
-					),
-					(
-						(object::kind::NAME, msgpack!(kind)),
-						object::kind::set(Some(kind)),
-					),
-				]
+				let device_pub_id = sync.device_pub_id.to_db();
+
+				let (sync_params, db_params) = chain_optional_iter(
+					[
+						(
+							sync_entry!(
+								prisma_sync::device::SyncId {
+									pub_id: device_pub_id,
+								},
+								object::device
+							),
+							object::device_id::set(Some(device_id)),
+						),
+						sync_db_entry!(kind, object::kind),
+					],
+					[option_sync_db_entry!(created_at, object::date_created)],
+				)
 				.into_iter()
 				.unzip::<_, _, Vec<_>, Vec<_>>();
@@ -121,51 +131,57 @@ async fn create_objects_and_update_file_paths(
 		.unzip::<_, _, HashMap<_, _>, Vec<_>>();
 
-	trace!(
-		new_objects_count = object_create_args.len(),
-		"Creating new Objects!;",
-	);
+	let new_objects_count = object_create_args.len();
+	if new_objects_count > 0 {
+		trace!(new_objects_count, "Creating new Objects!;",);
 
-	// create new object records with assembled values
-	let created_objects_count = sync
-		.write_ops(db, {
-			let (sync, db_params) = object_create_args
-				.into_iter()
-				.unzip::<_, _, Vec<_>, Vec<_>>();
-
-			(
-				sync.into_iter().flatten().collect(),
-				db.object().create_many(db_params),
-			)
-		})
-		.await?;
-
-	trace!(%created_objects_count, "Created new Objects;");
-
-	if created_objects_count > 0 {
-		trace!("Updating file paths with created objects");
-
-		let updated_file_path_ids = sync
-			.write_ops(
-				db,
-				file_path_update_args
-					.into_iter()
-					.unzip::<_, _, Vec<_>, Vec<_>>(),
-			)
-			.await
-			.map(|file_paths| {
-				file_paths
-					.into_iter()
-					.map(|file_path_id::Data { id }| id)
-					.collect::<HashSet<_>>()
-			})?;
-
-		object_pub_id_by_file_path_id
-			.retain(|file_path_id, _| updated_file_path_ids.contains(file_path_id));
-
-		Ok(object_pub_id_by_file_path_id)
+		// create new object records with assembled values
+		let created_objects_count = sync
+			.write_ops(db, {
+				let (sync, db_params) = object_create_args
+					.into_iter()
+					.unzip::<_, _, Vec<_>, Vec<_>>();
+
+				(sync, db.object().create_many(db_params))
+			})
+			.await?;
+
+		trace!(%created_objects_count, "Created new Objects;");
+
+		if created_objects_count > 0 {
+			let file_paths_to_update_count = file_path_update_args.len();
+			if file_paths_to_update_count > 0 {
+				trace!(
+					file_paths_to_update_count,
+					"Updating file paths with created objects"
+				);
+
+				let updated_file_path_ids = sync
+					.write_ops(
+						db,
+						file_path_update_args
+							.into_iter()
+							.unzip::<_, _, Vec<_>, Vec<_>>(),
+					)
+					.await
+					.map(|file_paths| {
+						file_paths
+							.into_iter()
+							.map(|file_path_id::Data { id }| id)
+							.collect::<HashSet<_>>()
+					})?;
+
+				object_pub_id_by_file_path_id
+					.retain(|file_path_id, _| updated_file_path_ids.contains(file_path_id));
+			}
+
+			Ok(object_pub_id_by_file_path_id)
+		} else {
+			trace!("No objects created, skipping file path updates");
+			Ok(HashMap::new())
+		}
 	} else {
-		trace!("No objects created, skipping file path updates");
+		trace!("No objects to create, skipping file path updates");
 		Ok(HashMap::new())
 	}
 }
diff --git a/core/crates/heavy-lifting/src/file_identifier/tasks/object_processor.rs b/core/crates/heavy-lifting/src/file_identifier/tasks/object_processor.rs
index 9569c1563..a99d89d8d 100644
--- a/core/crates/heavy-lifting/src/file_identifier/tasks/object_processor.rs
+++ b/core/crates/heavy-lifting/src/file_identifier/tasks/object_processor.rs
@@ -1,9 +1,9 @@
 use crate::{file_identifier, Error};
 
 use sd_core_prisma_helpers::{file_path_id, object_for_file_identifier, CasId, ObjectPubId};
-use sd_core_sync::Manager as SyncManager;
+use sd_core_sync::SyncManager;
 
-use sd_prisma::prisma::{file_path, object, PrismaClient};
+use sd_prisma::prisma::{device, file_path, object, PrismaClient};
 use sd_task_system::{
 	check_interruption, ExecStatus, Interrupter, IntoAnyTaskOutput, SerializableTask, Task, TaskId,
 };
@@ -29,13 +29,14 @@ pub struct ObjectProcessor {
 
 	// Inner state
 	stage: Stage,
+	device_id: device::id::Type,
 
 	// Out collector
 	output: Output,
 
 	// Dependencies
 	db: Arc<PrismaClient>,
-	sync: Arc<SyncManager>,
+	sync: SyncManager,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -93,6 +94,7 @@ impl Task for ObjectProcessor {
 		let Self {
 			db,
 			sync,
+			device_id,
 			file_paths_by_cas_id,
 			stage,
 			output:
@@ -167,8 +169,13 @@ impl Task for ObjectProcessor {
 					);
 					let start = Instant::now();
 					let (more_file_paths_with_new_object, more_linked_objects_count) =
-						assign_objects_to_duplicated_orphans(file_paths_by_cas_id, db, sync)
-							.await?;
+						assign_objects_to_duplicated_orphans(
+							file_paths_by_cas_id,
+							db,
+							sync,
+							*device_id,
+						)
+						.await?;
 					*create_object_time = start.elapsed();
 					file_path_ids_with_new_object.extend(more_file_paths_with_new_object);
 					*linked_objects_count += more_linked_objects_count;
@@ -194,7 +201,8 @@ impl ObjectProcessor {
 	pub fn new(
 		file_paths_by_cas_id: HashMap<CasId<'static>, Vec<FilePathToCreateOrLinkObject>>,
 		db: Arc<PrismaClient>,
-		sync: Arc<SyncManager>,
+		sync: SyncManager,
+		device_id: device::id::Type,
 		with_priority: bool,
 	) -> Self {
 		Self {
@@ -202,6 +210,7 @@ impl ObjectProcessor {
 			db,
 			sync,
 			file_paths_by_cas_id,
+			device_id,
 			stage: Stage::Starting,
 			output: Output::default(),
 			with_priority,
@@ -270,45 +279,44 @@ async fn assign_existing_objects_to_file_paths(
 	db: &PrismaClient,
 	sync: &SyncManager,
 ) -> Result<Vec<file_path::id::Type>, file_identifier::Error> {
-	sync.write_ops(
-		db,
-		objects_by_cas_id
-			.iter()
-			.flat_map(|(cas_id, object_pub_id)| {
-				file_paths_by_cas_id
-					.remove(cas_id)
-					.map(|file_paths| {
-						file_paths.into_iter().map(
-							|FilePathToCreateOrLinkObject {
-								file_path_pub_id, ..
-							}| {
-								connect_file_path_to_object(
-									&file_path_pub_id,
-									object_pub_id,
-									db,
-									sync,
-								)
-							},
-						)
-					})
-					.expect("must be here")
-			})
-			.unzip::<_, _, Vec<_>, Vec<_>>(),
-	)
-	.await
-	.map(|file_paths| {
-		file_paths
-			.into_iter()
-			.map(|file_path_id::Data { id }| id)
-			.collect()
-	})
-	.map_err(Into::into)
+	let (ops, queries) = objects_by_cas_id
+		.iter()
+		.flat_map(|(cas_id, object_pub_id)| {
+			file_paths_by_cas_id
+				.remove(cas_id)
+				.map(|file_paths| {
+					file_paths.into_iter().map(
+						|FilePathToCreateOrLinkObject {
+							file_path_pub_id, ..
+						}| {
+							connect_file_path_to_object(&file_path_pub_id, object_pub_id, db, sync)
+						},
+					)
+				})
+				.expect("must be here")
+		})
+		.unzip::<_, _, Vec<_>, Vec<_>>();
+
+	if ops.is_empty() && queries.is_empty() {
+		return Ok(vec![]);
+	}
+
+	sync.write_ops(db, (ops, queries))
+		.await
+		.map(|file_paths| {
+			file_paths
+				.into_iter()
+				.map(|file_path_id::Data { id }| id)
+				.collect()
+		})
+		.map_err(Into::into)
 }
 
 async fn assign_objects_to_duplicated_orphans(
 	file_paths_by_cas_id: &mut HashMap<CasId<'static>, Vec<FilePathToCreateOrLinkObject>>,
 	db: &PrismaClient,
 	sync: &SyncManager,
+	device_id: device::id::Type,
 ) -> Result<(Vec<file_path::id::Type>, u64), file_identifier::Error> {
 	// at least 1 file path per cas_id
 	let mut selected_file_paths = Vec::with_capacity(file_paths_by_cas_id.len());
@@ -327,7 +335,7 @@ async fn assign_objects_to_duplicated_orphans(
 	});
 
 	let (mut file_paths_with_new_object, objects_by_cas_id) =
-		create_objects_and_update_file_paths(selected_file_paths, db, sync)
+		create_objects_and_update_file_paths(selected_file_paths, db, sync, device_id)
 			.await?
 			.into_iter()
 			.map(|(file_path_id, object_pub_id)| {
@@ -365,6 +373,7 @@ async fn assign_objects_to_duplicated_orphans(
 pub struct SaveState {
 	id: TaskId,
 	file_paths_by_cas_id: HashMap<CasId<'static>, Vec<FilePathToCreateOrLinkObject>>,
+	device_id: device::id::Type,
 	stage: Stage,
 	output: Output,
 	with_priority: bool,
@@ -375,12 +384,13 @@ impl SerializableTask for ObjectProcessor {
 
 	type DeserializeError = rmp_serde::decode::Error;
 
-	type DeserializeCtx = (Arc<PrismaClient>, Arc<SyncManager>);
+	type DeserializeCtx = (Arc<PrismaClient>, SyncManager);
 
 	async fn serialize(self) -> Result<Vec<u8>, Self::SerializeError> {
 		let Self {
 			id,
 			file_paths_by_cas_id,
+			device_id,
 			stage,
 			output,
 			with_priority,
@@ -390,6 +400,7 @@ impl SerializableTask for ObjectProcessor {
 		rmp_serde::to_vec_named(&SaveState {
 			id,
 			file_paths_by_cas_id,
+			device_id,
 			stage,
 			output,
 			with_priority,
@@ -404,6 +415,7 @@ impl SerializableTask for ObjectProcessor {
 			|SaveState {
 				id,
 				file_paths_by_cas_id,
+				device_id,
 				stage,
 				output,
 				with_priority,
@@ -412,6 +424,7 @@ impl SerializableTask for ObjectProcessor {
 				with_priority,
 				file_paths_by_cas_id,
 				stage,
+				device_id,
 				output,
 				db,
 				sync,
diff --git a/core/crates/heavy-lifting/src/indexer/job.rs b/core/crates/heavy-lifting/src/indexer/job.rs
index 22546950e..cf19fbb90 100644
--- a/core/crates/heavy-lifting/src/indexer/job.rs
+++ b/core/crates/heavy-lifting/src/indexer/job.rs
@@ -16,7 +16,11 @@ use crate::{
 use sd_core_file_path_helper::IsolatedFilePathData;
 use sd_core_indexer_rules::{IndexerRule, IndexerRuler};
 use sd_core_prisma_helpers::location_with_indexer_rules;
 
-use sd_prisma::prisma::location;
+use sd_prisma::{
+	prisma::{device, location},
+	prisma_sync,
+};
+use sd_sync::{sync_db_not_null_entry, OperationFactory};
 use sd_task_system::{
 	AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId,
 	TaskOutput, TaskStatus,
@@ -116,13 +120,13 @@ impl Job for Indexer {
 
 			TaskKind::Save => tasks::Saver::deserialize(
 				&task_bytes,
-				(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
+				(Arc::clone(ctx.db()), ctx.sync().clone()),
 			)
 			.await
 			.map(IntoTask::into_task),
 
 			TaskKind::Update => tasks::Updater::deserialize(
 				&task_bytes,
-				(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
+				(Arc::clone(ctx.db()), ctx.sync().clone()),
 			)
 			.await
 			.map(IntoTask::into_task),
@@ -161,6 +165,17 @@ impl Job for Indexer {
 	) -> Result<ReturnStatus, Error> {
 		let mut pending_running_tasks = FuturesUnordered::new();
 
+		let device_pub_id = &ctx.sync().device_pub_id;
+		let device_id = ctx
+			.db()
+			.device()
+			.find_unique(device::pub_id::equals(device_pub_id.to_db()))
+			.exec()
+			.await
+			.map_err(indexer::Error::from)?
+			.ok_or(indexer::Error::DeviceNotFound(device_pub_id.clone()))?
+			.id;
+
 		match self
 			.init_or_resume(&mut pending_running_tasks, &ctx, &dispatcher)
 			.await
@@ -191,21 +206,26 @@ impl Job for Indexer {
 			}
 
 			if let Some(res) = self
-				.process_handles(&mut pending_running_tasks, &ctx, &dispatcher)
+				.process_handles(&mut pending_running_tasks, &ctx, device_id, &dispatcher)
 				.await
 			{
 				return res;
 			}
 
 			if let Some(res) = self
-				.dispatch_last_save_and_update_tasks(&mut pending_running_tasks, &ctx, &dispatcher)
+				.dispatch_last_save_and_update_tasks(
+					&mut pending_running_tasks,
+					&ctx,
+					device_id,
+					&dispatcher,
+				)
 				.await
 			{
 				return res;
 			}
 
 			if let Some(res) = self
-				.index_pending_ancestors(&mut pending_running_tasks, &ctx, &dispatcher)
+				.index_pending_ancestors(&mut pending_running_tasks, &ctx, device_id, &dispatcher)
 				.await
 			{
 				return res;
@@ -253,7 +273,7 @@ impl Job for Indexer {
 					.await?;
 			}
 
-			update_location_size(location.id, ctx.db(), &ctx).await?;
+			update_location_size(location.id, location.pub_id.clone(), &ctx).await?;
 
 			metadata.mean_db_write_time += start_size_update_time.elapsed();
 		}
@@ -271,13 +291,23 @@ impl Job for Indexer {
 			"all tasks must be completed here"
 		);
 
-		ctx.db()
-			.location()
-			.update(
-				location::id::equals(location.id),
-				vec![location::scan_state::set(LocationScanState::Indexed as i32)],
+		let (sync_param, db_param) =
+			sync_db_not_null_entry!(LocationScanState::Indexed as i32, location::scan_state);
+
+		ctx.sync()
+			.write_op(
+				ctx.db(),
+				ctx.sync().shared_update(
+					prisma_sync::location::SyncId {
+						pub_id: location.pub_id.clone(),
+					},
+					[sync_param],
+				),
+				ctx.db()
+					.location()
+					.update(location::id::equals(location.id), vec![db_param])
+					.select(location::select!({ id })),
 			)
-			.exec()
 			.await
 			.map_err(indexer::Error::from)?;
 
@@ -338,6 +368,7 @@ impl Indexer {
 		task_id: TaskId,
 		any_task_output: Box<dyn AnyTaskOutput>,
 		ctx: &impl JobContext,
+		device_id: device::id::Type,
 		dispatcher: &JobTaskDispatcher,
 	) -> Result<Vec<TaskHandle<Error>>, JobErrorOrDispatcherError<indexer::Error>> {
 		self.metadata.completed_tasks += 1;
@@ -349,6 +380,7 @@ impl Indexer {
 					.downcast::<walker::Output<WalkerDBProxy, IsoFilePathFactory>>()
 					.expect("just checked"),
 				ctx,
+				device_id,
 				dispatcher,
 			)
 			.await;
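Both the indexer and file-identifier jobs now resolve the local device row once, up front, and thread the numeric id through every task they dispatch. A sketch of that lookup under stand-in types (`Db` here is hypothetical, not the Prisma client API):

```rust
// Stand-in for the Prisma client; the real call is
// db.device().find_unique(device::pub_id::equals(..)).exec().await.
struct Db;

#[derive(Clone, Debug)]
struct DevicePubId(Vec<u8>);

#[derive(Debug)]
enum JobError {
    DeviceNotFound(DevicePubId),
}

impl Db {
    fn find_device_row_id(&self, _pub_id: &DevicePubId) -> Option<i32> {
        None // placeholder; the real lookup hits the database
    }
}

/// Maps the sync manager's device pub_id to the numeric row id, failing the
/// job early when the device row is missing.
fn resolve_device_id(db: &Db, pub_id: &DevicePubId) -> Result<i32, JobError> {
    db.find_device_row_id(pub_id)
        .ok_or_else(|| JobError::DeviceNotFound(pub_id.clone()))
}
```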
}: walker::Output, ctx: &impl JobContext, + device_id: device::id::Type, dispatcher: &JobTaskDispatcher, ) -> Result>, JobErrorOrDispatcherError> { self.metadata.mean_scan_read_time += scan_time; @@ -465,7 +498,7 @@ impl Indexer { self.metadata.mean_db_write_time += db_delete_time.elapsed(); } let (save_tasks, update_tasks) = - self.prepare_save_and_update_tasks(to_create, to_update, ctx); + self.prepare_save_and_update_tasks(to_create, to_update, ctx, device_id); ctx.progress(vec![ ProgressUpdate::TaskCount(self.metadata.total_tasks), @@ -552,13 +585,14 @@ impl Indexer { &mut self, pending_running_tasks: &mut FuturesUnordered>, ctx: &impl JobContext, + device_id: device::id::Type, dispatcher: &JobTaskDispatcher, ) -> Option> { while let Some(task) = pending_running_tasks.next().await { match task { Ok(TaskStatus::Done((task_id, TaskOutput::Out(out)))) => { match self - .process_task_output(task_id, out, ctx, dispatcher) + .process_task_output(task_id, out, ctx, device_id, dispatcher) .await { Ok(more_handles) => pending_running_tasks.extend(more_handles), @@ -666,6 +700,7 @@ impl Indexer { &mut self, pending_running_tasks: &mut FuturesUnordered>, ctx: &impl JobContext, + device_id: device::id::Type, dispatcher: &JobTaskDispatcher, ) -> Option> { if !self.to_create_buffer.is_empty() || !self.to_update_buffer.is_empty() { @@ -687,7 +722,8 @@ impl Indexer { self.location.pub_id.clone(), self.to_create_buffer.drain(..).collect(), Arc::clone(ctx.db()), - Arc::clone(ctx.sync()), + ctx.sync().clone(), + device_id, ) .into_task(), ); @@ -707,7 +743,7 @@ impl Indexer { tasks::Updater::new_deep( self.to_update_buffer.drain(..).collect(), Arc::clone(ctx.db()), - Arc::clone(ctx.sync()), + ctx.sync().clone(), ) .into_task(), ); @@ -726,7 +762,7 @@ impl Indexer { } } - self.process_handles(pending_running_tasks, ctx, dispatcher) + self.process_handles(pending_running_tasks, ctx, device_id, dispatcher) .await } else { None @@ -737,6 +773,7 @@ impl Indexer { &mut self, pending_running_tasks: &mut FuturesUnordered>, ctx: &impl JobContext, + device_id: device::id::Type, dispatcher: &JobTaskDispatcher, ) -> Option> { if self.ancestors_needing_indexing.is_empty() { @@ -759,7 +796,8 @@ impl Indexer { self.location.pub_id.clone(), chunked_saves, Arc::clone(ctx.db()), - Arc::clone(ctx.sync()), + ctx.sync().clone(), + device_id, ) }) .collect::>(); @@ -776,7 +814,7 @@ impl Indexer { } } - self.process_handles(pending_running_tasks, ctx, dispatcher) + self.process_handles(pending_running_tasks, ctx, device_id, dispatcher) .await } @@ -785,6 +823,7 @@ impl Indexer { to_create: Vec, to_update: Vec, ctx: &impl JobContext, + device_id: device::id::Type, ) -> (Vec, Vec) { if self.processing_first_directory { // If we are processing the first directory, we dispatch shallow tasks with higher priority @@ -806,7 +845,8 @@ impl Indexer { self.location.pub_id.clone(), chunked_saves, Arc::clone(ctx.db()), - Arc::clone(ctx.sync()), + ctx.sync().clone(), + device_id, ) }) .collect::>(); @@ -824,7 +864,7 @@ impl Indexer { tasks::Updater::new_shallow( chunked_updates, Arc::clone(ctx.db()), - Arc::clone(ctx.sync()), + ctx.sync().clone(), ) }) .collect::>(); @@ -851,7 +891,8 @@ impl Indexer { self.location.pub_id.clone(), chunked_saves, Arc::clone(ctx.db()), - Arc::clone(ctx.sync()), + ctx.sync().clone(), + device_id, )); } save_tasks @@ -878,7 +919,7 @@ impl Indexer { update_tasks.push(tasks::Updater::new_deep( chunked_updates, Arc::clone(ctx.db()), - Arc::clone(ctx.sync()), + ctx.sync().clone(), )); } update_tasks diff --git 
a/core/crates/heavy-lifting/src/indexer/mod.rs b/core/crates/heavy-lifting/src/indexer/mod.rs index 0fa7ce732..6880e6d91 100644 --- a/core/crates/heavy-lifting/src/indexer/mod.rs +++ b/core/crates/heavy-lifting/src/indexer/mod.rs @@ -4,17 +4,17 @@ use sd_core_file_path_helper::{FilePathError, IsolatedFilePathData}; use sd_core_prisma_helpers::{ file_path_pub_and_cas_ids, file_path_to_isolate_with_pub_id, file_path_walker, }; -use sd_core_sync::Manager as SyncManager; +use sd_core_sync::{DevicePubId, SyncManager}; use sd_prisma::{ prisma::{file_path, indexer_rule, location, PrismaClient, SortOrder}, prisma_sync, }; -use sd_sync::OperationFactory; +use sd_sync::{sync_db_entry, OperationFactory}; use sd_utils::{ db::{size_in_bytes_from_db, size_in_bytes_to_db, MissingFieldError}, error::{FileIOError, NonUtf8PathError}, - from_bytes_to_uuid, msgpack, + from_bytes_to_uuid, }; use std::{ @@ -50,6 +50,8 @@ pub enum Error { IndexerRuleNotFound(indexer_rule::id::Type), #[error(transparent)] SubPath(#[from] sub_path::Error), + #[error("device not found: Result<(), Error> { - let to_sync_and_update = db + let (ops, queries) = db ._batch(chunk_db_queries(iso_paths_and_sizes.keys(), db)) .await? .into_iter() @@ -144,22 +146,20 @@ async fn update_directory_sizes( .map(|file_path| { let size_bytes = iso_paths_and_sizes .get(&IsolatedFilePathData::try_from(&file_path)?) - .map(|size| size.to_be_bytes().to_vec()) + .map(|size| size_in_bytes_to_db(*size)) .expect("must be here"); + let (sync_param, db_param) = sync_db_entry!(size_bytes, file_path::size_in_bytes_bytes); + Ok(( sync.shared_update( prisma_sync::file_path::SyncId { pub_id: file_path.pub_id.clone(), }, - file_path::size_in_bytes_bytes::NAME, - msgpack!(size_bytes), + [sync_param], ), db.file_path() - .update( - file_path::pub_id::equals(file_path.pub_id), - vec![file_path::size_in_bytes_bytes::set(Some(size_bytes))], - ) + .update(file_path::pub_id::equals(file_path.pub_id), vec![db_param]) .select(file_path::select!({ id })), )) }) @@ -167,42 +167,54 @@ async fn update_directory_sizes( .into_iter() .unzip::<_, _, Vec<_>, Vec<_>>(); - sync.write_ops(db, to_sync_and_update).await?; + if !ops.is_empty() && !queries.is_empty() { + sync.write_ops(db, (ops, queries)).await?; + } Ok(()) } async fn update_location_size( location_id: location::id::Type, - db: &PrismaClient, + location_pub_id: location::pub_id::Type, ctx: &impl OuterContext, ) -> Result<(), Error> { - let total_size = db - .file_path() - .find_many(vec![ - file_path::location_id::equals(Some(location_id)), - file_path::materialized_path::equals(Some("/".to_string())), - ]) - .select(file_path::select!({ size_in_bytes_bytes })) - .exec() - .await? - .into_iter() - .filter_map(|file_path| { - file_path - .size_in_bytes_bytes - .map(|size_in_bytes_bytes| size_in_bytes_from_db(&size_in_bytes_bytes)) - }) - .sum::(); + let db = ctx.db(); + let sync = ctx.sync(); - db.location() - .update( - location::id::equals(location_id), - vec![location::size_in_bytes::set(Some( - total_size.to_be_bytes().to_vec(), - ))], - ) - .exec() - .await?; + let total_size = size_in_bytes_to_db( + db.file_path() + .find_many(vec![ + file_path::location_id::equals(Some(location_id)), + file_path::materialized_path::equals(Some("/".to_string())), + ]) + .select(file_path::select!({ size_in_bytes_bytes })) + .exec() + .await? 
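The pattern above recurs throughout this change set: every local write to a synced column is paired with a CRDT operation and submitted together through the sync manager, instead of a bare Prisma update. Condensed from the `update_location_size` hunk, assuming `sync_db_entry!` expands to a `(sync param, Prisma set param)` pair as the call sites imply:

	// One CRDT shared_update plus the matching DB update, committed as a unit;
	// `.select(... { id })` keeps the query from fetching the whole row back.
	let (sync_param, db_param) = sync_db_entry!(total_size, location::size_in_bytes);
	sync.write_op(
		db,
		sync.shared_update(
			prisma_sync::location::SyncId { pub_id: location_pub_id },
			[sync_param],
		),
		db.location()
			.update(location::id::equals(location_id), vec![db_param])
			.select(location::select!({ id })),
	)
	.await?;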
@@ -213,7 +225,7 @@
 async fn remove_non_existing_file_paths(
 	to_remove: Vec,
 	db: &PrismaClient,
-	sync: &sd_core_sync::Manager,
+	sync: &SyncManager,
 ) -> Result {
 	#[allow(clippy::cast_sign_loss)]
 	let (sync_params, db_params): (Vec<_>, Vec<_>) = to_remove
@@ -228,6 +240,10 @@ async fn remove_non_existing_file_paths(
 		})
 		.unzip();
 
+	if sync_params.is_empty() {
+		return Ok(0);
+	}
+
 	sync.write_ops(
 		db,
 		(
@@ -318,7 +334,7 @@ pub async fn reverse_update_directories_sizes(
 	)
 	.await?;
 
-	let to_sync_and_update = ancestors
+	let (sync_ops, update_queries) = ancestors
 		.into_values()
 		.filter_map(|materialized_path| {
 			if let Some((pub_id, size)) =
@@ -326,18 +342,19 @@ pub async fn reverse_update_directories_sizes(
 			{
 				let size_bytes = size_in_bytes_to_db(size);
 
+				let (sync_param, db_param) =
+					sync_db_entry!(size_bytes, file_path::size_in_bytes_bytes);
+
 				Some((
 					sync.shared_update(
 						prisma_sync::file_path::SyncId {
 							pub_id: pub_id.clone(),
 						},
-						file_path::size_in_bytes_bytes::NAME,
-						msgpack!(size_bytes),
-					),
-					db.file_path().update(
-						file_path::pub_id::equals(pub_id),
-						vec![file_path::size_in_bytes_bytes::set(Some(size_bytes))],
+						[sync_param],
 					),
+					db.file_path()
+						.update(file_path::pub_id::equals(pub_id), vec![db_param])
+						.select(file_path::select!({ id })),
 				))
 			} else {
 				warn!("Got a missing ancestor for a file_path in the database, ignoring...");
@@ -346,7 +363,9 @@
 		})
 		.unzip::<_, _, Vec<_>, Vec<_>>();
 
-	sync.write_ops(db, to_sync_and_update).await?;
+	if !sync_ops.is_empty() && !update_queries.is_empty() {
+		sync.write_ops(db, (sync_ops, update_queries)).await?;
+	}
 
 	Ok(())
 }
diff --git a/core/crates/heavy-lifting/src/indexer/shallow.rs b/core/crates/heavy-lifting/src/indexer/shallow.rs
index c57993840..1bc55b556 100644
--- a/core/crates/heavy-lifting/src/indexer/shallow.rs
+++ b/core/crates/heavy-lifting/src/indexer/shallow.rs
@@ -4,9 +4,9 @@ use crate::{
 use sd_core_indexer_rules::{IndexerRule, IndexerRuler};
 use sd_core_prisma_helpers::location_with_indexer_rules;
-use sd_core_sync::Manager as SyncManager;
+use sd_core_sync::SyncManager;
 
-use sd_prisma::prisma::PrismaClient;
+use sd_prisma::prisma::{device, PrismaClient};
 use sd_task_system::{BaseTaskDispatcher, CancelTaskOnDrop, IntoTask, TaskDispatcher, TaskOutput};
 use sd_utils::db::maybe_missing;
@@ -62,6 +62,17 @@ pub async fn shallow(
 			.await?,
 	);
 
+	let device_pub_id = &ctx.sync().device_pub_id;
+	let device_id = ctx
+		.db()
+		.device()
+		.find_unique(device::pub_id::equals(device_pub_id.to_db()))
+		.exec()
+		.await
+		.map_err(indexer::Error::from)?
+		.ok_or(indexer::Error::DeviceNotFound(device_pub_id.clone()))?
+		.id;
+
 	let Some(walker::Output {
 		to_create,
 		to_update,
@@ -96,7 +107,8 @@ pub async fn shallow(
 		to_create,
 		to_update,
 		Arc::clone(db),
-		Arc::clone(sync),
+		sync.clone(),
+		device_id,
 		dispatcher,
 	)
 	.await?
@@ -124,7 +136,7 @@ pub async fn shallow(
 		.await?;
 	}
 
-	update_location_size(location.id, db, ctx).await?;
+	update_location_size(location.id, location.pub_id, ctx).await?;
 	}
 
 	if indexed_count > 0 || removed_count > 0 {
@@ -203,7 +215,8 @@ async fn save_and_update(
 	to_create: Vec,
 	to_update: Vec,
 	db: Arc,
-	sync: Arc,
+	sync: SyncManager,
+	device_id: device::id::Type,
 	dispatcher: &BaseTaskDispatcher,
 ) -> Result, Error> {
 	let save_and_update_tasks = to_create
@@ -216,7 +229,8 @@ async fn save_and_update(
 				location.pub_id.clone(),
 				chunk.collect::>(),
 				Arc::clone(&db),
-				Arc::clone(&sync),
+				sync.clone(),
+				device_id,
 			)
 		})
 		.map(IntoTask::into_task)
@@ -229,7 +243,7 @@ async fn save_and_update(
 				tasks::Updater::new_shallow(
 					chunk.collect::>(),
 					Arc::clone(&db),
-					Arc::clone(&sync),
+					sync.clone(),
 				)
 			})
 			.map(IntoTask::into_task),
diff --git a/core/crates/heavy-lifting/src/indexer/tasks/saver.rs b/core/crates/heavy-lifting/src/indexer/tasks/saver.rs
index 31fdf8d9a..c5d0951d0 100644
--- a/core/crates/heavy-lifting/src/indexer/tasks/saver.rs
+++ b/core/crates/heavy-lifting/src/indexer/tasks/saver.rs
@@ -1,18 +1,15 @@
 use crate::{indexer, Error};
 
 use sd_core_file_path_helper::{FilePathMetadata, IsolatedFilePathDataParts};
-use sd_core_sync::Manager as SyncManager;
+use sd_core_sync::SyncManager;
 use sd_prisma::{
-	prisma::{file_path, location, PrismaClient},
+	prisma::{device, file_path, location, PrismaClient},
 	prisma_sync,
 };
-use sd_sync::{sync_db_entry, OperationFactory};
+use sd_sync::{sync_db_entry, sync_entry, OperationFactory};
 use sd_task_system::{ExecStatus, Interrupter, IntoAnyTaskOutput, SerializableTask, Task, TaskId};
-use sd_utils::{
-	db::{inode_to_db, size_in_bytes_to_db},
-	msgpack,
-};
+use sd_utils::db::{inode_to_db, size_in_bytes_to_db};
 
 use std::{sync::Arc, time::Duration};
 
@@ -32,11 +29,12 @@ pub struct Saver {
 	// Received input args
 	location_id: location::id::Type,
 	location_pub_id: location::pub_id::Type,
+	device_id: device::id::Type,
 	walked_entries: Vec,
 
 	// Dependencies
 	db: Arc,
-	sync: Arc,
+	sync: SyncManager,
 }
 
 /// [`Save`] Task output
@@ -73,8 +71,9 @@ impl Task for Saver {
 	#[allow(clippy::blocks_in_conditions)] // Due to `err` on `instrument` macro above
 	async fn run(&mut self, _: &Interrupter) -> Result {
 		use file_path::{
-			create_unchecked, date_created, date_indexed, date_modified, extension, hidden, inode,
-			is_dir, location, location_id, materialized_path, name, size_in_bytes_bytes,
+			create_unchecked, date_created, date_indexed, date_modified, device, device_id,
+			extension, hidden, inode, is_dir, location, location_id, materialized_path, name,
+			size_in_bytes_bytes,
 		};
 
 		let start_time = Instant::now();
@@ -82,13 +81,14 @@ impl Task for Saver {
 		let Self {
 			location_id,
 			location_pub_id,
+			device_id,
 			walked_entries,
 			db,
 			sync,
 			..
 		} = self;
 
-		let (sync_stuff, paths): (Vec<_>, Vec<_>) = walked_entries
+		let (create_crdt_ops, paths): (Vec<_>, Vec<_>) = walked_entries
 			.drain(..)
 			.map(
 				|WalkedEntry {
@@ -118,13 +118,13 @@ impl Task for Saver {
 						new file_paths and they were not identified yet"
 					);
 
-					let (sync_params, db_params): (Vec<_>, Vec<_>) = [
+					let (sync_params, db_params) = [
 						(
-							(
-								location::NAME,
-								msgpack!(prisma_sync::location::SyncId {
+							sync_entry!(
+								prisma_sync::location::SyncId {
 									pub_id: location_pub_id.clone()
-								}),
+								},
+								location
 							),
 							location_id::set(Some(*location_id)),
 						),
@@ -138,9 +138,18 @@ impl Task for Saver {
 						sync_db_entry!(modified_at, date_modified),
 						sync_db_entry!(Utc::now(), date_indexed),
 						sync_db_entry!(hidden, hidden),
+						(
+							sync_entry!(
+								prisma_sync::device::SyncId {
+									pub_id: sync.device_pub_id.to_db(),
+								},
+								device
+							),
+							device_id::set(Some(*device_id)),
+						),
 					]
 					.into_iter()
-					.unzip();
+					.unzip::<_, _, Vec<_>, Vec<_>>();
 
 					(
 						sync.shared_create(
@@ -155,12 +164,22 @@ impl Task for Saver {
 			)
 			.unzip();
 
+		if create_crdt_ops.is_empty() && paths.is_empty() {
+			return Ok(ExecStatus::Done(
+				Output {
+					saved_count: 0,
+					save_duration: Duration::ZERO,
+				}
+				.into_output(),
+			));
+		}
+
 		#[allow(clippy::cast_sign_loss)]
 		let saved_count = sync
 			.write_ops(
 				db,
 				(
-					sync_stuff.into_iter().flatten().collect(),
+					create_crdt_ops,
 					db.file_path().create_many(paths).skip_duplicates(),
 				),
 			)
@@ -188,12 +207,14 @@ impl Saver {
 		location_pub_id: location::pub_id::Type,
 		walked_entries: Vec,
 		db: Arc,
-		sync: Arc,
+		sync: SyncManager,
+		device_id: device::id::Type,
 	) -> Self {
 		Self {
 			id: TaskId::new_v4(),
 			location_id,
 			location_pub_id,
+			device_id,
 			walked_entries,
 			db,
 			sync,
@@ -207,12 +228,14 @@ impl Saver {
 		location_pub_id: location::pub_id::Type,
 		walked_entries: Vec,
 		db: Arc,
-		sync: Arc,
+		sync: SyncManager,
+		device_id: device::id::Type,
 	) -> Self {
 		Self {
 			id: TaskId::new_v4(),
 			location_id,
 			location_pub_id,
+			device_id,
 			walked_entries,
 			db,
 			sync,
@@ -228,6 +251,7 @@
 struct SaveState {
 	location_id: location::id::Type,
 	location_pub_id: location::pub_id::Type,
+	device_id: device::id::Type,
 	walked_entries: Vec,
 }
 
@@ -236,7 +260,7 @@ impl SerializableTask for Saver {
 	type DeserializeError = rmp_serde::decode::Error;
 
-	type DeserializeCtx = (Arc, Arc);
+	type DeserializeCtx = (Arc, SyncManager);
 
 	async fn serialize(self) -> Result, Self::SerializeError> {
 		let Self {
@@ -244,6 +268,7 @@ impl SerializableTask for Saver {
 			is_shallow,
 			location_id,
 			location_pub_id,
+			device_id,
 			walked_entries,
 			..
 		} = self;
@@ -252,6 +277,7 @@ impl SerializableTask for Saver {
 			is_shallow,
 			location_id,
 			location_pub_id,
+			device_id,
 			walked_entries,
 		})
 	}
@@ -266,12 +292,14 @@ impl SerializableTask for Saver {
 				is_shallow,
 				location_id,
 				location_pub_id,
+				device_id,
 				walked_entries,
 			}| Self {
 				id,
 				is_shallow,
 				location_id,
 				location_pub_id,
+				device_id,
 				walked_entries,
 				db,
 				sync,
diff --git a/core/crates/heavy-lifting/src/indexer/tasks/updater.rs b/core/crates/heavy-lifting/src/indexer/tasks/updater.rs
index c103397ec..80cf3d6f4 100644
--- a/core/crates/heavy-lifting/src/indexer/tasks/updater.rs
+++ b/core/crates/heavy-lifting/src/indexer/tasks/updater.rs
@@ -1,7 +1,7 @@
 use crate::{indexer, Error};
 
 use sd_core_file_path_helper::{FilePathMetadata, IsolatedFilePathDataParts};
-use sd_core_sync::Manager as SyncManager;
+use sd_core_sync::SyncManager;
 
 use sd_prisma::{
 	prisma::{file_path, object, PrismaClient},
@@ -39,7 +39,7 @@ pub struct Updater {
 
 	// Dependencies
 	db: Arc,
-	sync: Arc,
+	sync: SyncManager,
 }
 
 /// [`Update`] Task output
@@ -93,7 +93,7 @@ impl Task for Updater {
 
 		check_interruption!(interrupter);
 
-		let (sync_stuff, paths_to_update) = walked_entries
+		let (crdt_ops, paths_to_update) = walked_entries
 			.drain(..)
 			.map(
 				|WalkedEntry {
@@ -138,18 +138,12 @@ impl Task for Updater {
 					.unzip::<_, _, Vec<_>, Vec<_>>();
 
 					(
-						sync_params
-							.into_iter()
-							.map(|(field, value)| {
-								sync.shared_update(
-									prisma_sync::file_path::SyncId {
-										pub_id: pub_id.to_db(),
-									},
-									field,
-									value,
-								)
-							})
-							.collect::>(),
+						sync.shared_update(
+							prisma_sync::file_path::SyncId {
+								pub_id: pub_id.to_db(),
+							},
+							sync_params,
+						),
 						db.file_path()
 							.update(file_path::pub_id::equals(pub_id.into()), db_params)
 							// selecting id to avoid fetching whole object from database
@@ -159,11 +153,18 @@ impl Task for Updater {
 			)
 			.unzip::<_, _, Vec<_>, Vec<_>>();
 
+		if crdt_ops.is_empty() && paths_to_update.is_empty() {
+			return Ok(ExecStatus::Done(
+				Output {
+					updated_count: 0,
+					update_duration: Duration::ZERO,
+				}
+				.into_output(),
+			));
+		}
+
 		let updated = sync
-			.write_ops(
-				db,
-				(sync_stuff.into_iter().flatten().collect(), paths_to_update),
-			)
+			.write_ops(db, (crdt_ops, paths_to_update))
 			.await
 			.map_err(indexer::Error::from)?;
 
@@ -186,7 +187,7 @@ impl Updater {
 	pub fn new_deep(
 		walked_entries: Vec,
 		db: Arc,
-		sync: Arc,
+		sync: SyncManager,
 	) -> Self {
 		Self {
 			id: TaskId::new_v4(),
@@ -202,7 +203,7 @@ impl Updater {
 	pub fn new_shallow(
 		walked_entries: Vec,
 		db: Arc,
-		sync: Arc,
+		sync: SyncManager,
 	) -> Self {
 		Self {
 			id: TaskId::new_v4(),
@@ -264,7 +265,7 @@ impl SerializableTask for Updater {
 	type DeserializeError = rmp_serde::decode::Error;
 
-	type DeserializeCtx = (Arc, Arc);
+	type DeserializeCtx = (Arc, SyncManager);
 
 	async fn serialize(self) -> Result, Self::SerializeError> {
 		let Self {
diff --git a/core/crates/heavy-lifting/src/job_system/job.rs b/core/crates/heavy-lifting/src/job_system/job.rs
index fe8694c72..ec664c327 100644
--- a/core/crates/heavy-lifting/src/job_system/job.rs
+++ b/core/crates/heavy-lifting/src/job_system/job.rs
@@ -1,6 +1,6 @@
 use crate::{Error, NonCriticalError, UpdateEvent};
 
-use sd_core_sync::Manager as SyncManager;
+use sd_core_sync::SyncManager;
 use sd_prisma::prisma::PrismaClient;
 use sd_task_system::{
@@ -98,7 +98,7 @@ impl ProgressUpdate {
 pub trait OuterContext: Send + Sync + Clone + 'static {
 	fn id(&self) -> Uuid;
 	fn db(&self) -> &Arc;
-	fn sync(&self) -> &Arc;
+	fn sync(&self) -> &SyncManager;
 	fn invalidate_query(&self, query: &'static str);
 	fn query_invalidator(&self) -> impl Fn(&'static str) + Send + Sync;
 	fn report_update(&self, update: UpdateEvent);
@@ -158,7 +158,7 @@ where
 	JobCtx: JobContext,
 {
 	fn into_job(self) -> Box> {
-		let id = JobId::new_v4();
+		let id = JobId::now_v7();
 
 		Box::new(JobHolder {
 			id,
@@ -333,7 +333,7 @@ where
 	}
 
 	pub fn new(job: J) -> Self {
-		let id = JobId::new_v4();
+		let id = JobId::now_v7();
 		Self {
 			id,
 			job,
diff --git a/core/crates/heavy-lifting/src/job_system/report.rs b/core/crates/heavy-lifting/src/job_system/report.rs
index 9f87b1eb4..b747b8195 100644
--- a/core/crates/heavy-lifting/src/job_system/report.rs
+++ b/core/crates/heavy-lifting/src/job_system/report.rs
@@ -290,6 +290,7 @@ impl Report {
 					.map(|id| job::parent::connect(job::id::equals(id.as_bytes().to_vec())))],
 				),
 			)
+			.select(job::select!({ id }))
 			.exec()
 			.await
 			.map_err(ReportError::Create)?;
@@ -300,7 +301,7 @@ impl Report {
 		Ok(())
 	}
 
-	pub async fn update(&mut self, db: &PrismaClient) -> Result<(), ReportError> {
+	pub async fn update(&self, db: &PrismaClient) -> Result<(), ReportError> {
 		db.job()
 			.update(
 				job::id::equals(self.id.as_bytes().to_vec()),
@@ -318,6 +319,7 @@ impl Report {
 					job::date_completed::set(self.completed_at.map(Into::into)),
 				],
 			)
+			.select(job::select!({ id }))
 			.exec()
 			.await
 			.map_err(ReportError::Update)?;
diff --git a/core/crates/heavy-lifting/src/job_system/runner.rs b/core/crates/heavy-lifting/src/job_system/runner.rs
index ae067fb0b..57c237ead 100644
--- a/core/crates/heavy-lifting/src/job_system/runner.rs
+++ b/core/crates/heavy-lifting/src/job_system/runner.rs
@@ -313,7 +313,7 @@ impl> JobSystemRunner {
 		let name = {
 			let db = handle.ctx.db();
-			let mut report = handle.ctx.report_mut().await;
+			let report = handle.ctx.report().await;
 			if let Err(e) = report.update(db).await {
 				error!(?e, "Failed to update report on job shutdown;");
 			}
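Job IDs move from `Uuid::new_v4()` to `Uuid::now_v7()`. UUIDv7 puts a Unix-millisecond timestamp in the most significant bits, so newly minted jobs sort roughly in creation order, which gives better B-tree index locality than fully random v4 IDs. A small illustration, assuming the `uuid` crate with its `v7` feature enabled:

	use std::{thread::sleep, time::Duration};
	use uuid::Uuid;

	fn main() {
		let a = Uuid::now_v7();
		sleep(Duration::from_millis(2));
		let b = Uuid::now_v7();
		// IDs minted in a later millisecond always compare greater; ordering
		// within the same millisecond is not guaranteed by plain now_v7().
		assert!(a < b);
	}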
diff --git a/core/crates/heavy-lifting/src/media_processor/helpers/exif_media_data.rs b/core/crates/heavy-lifting/src/media_processor/helpers/exif_media_data.rs
index 5e1ea5ce7..854ba314c 100644
--- a/core/crates/heavy-lifting/src/media_processor/helpers/exif_media_data.rs
+++ b/core/crates/heavy-lifting/src/media_processor/helpers/exif_media_data.rs
@@ -1,15 +1,15 @@
 use crate::media_processor::{self, media_data_extractor};
 
 use sd_core_prisma_helpers::ObjectPubId;
-use sd_core_sync::Manager as SyncManager;
+use sd_core_sync::{DevicePubId, SyncManager};
 
 use sd_file_ext::extensions::{Extension, ImageExtension, ALL_IMAGE_EXTENSIONS};
 use sd_media_metadata::ExifMetadata;
 use sd_prisma::{
-	prisma::{exif_data, object, PrismaClient},
+	prisma::{device, exif_data, object, PrismaClient},
 	prisma_sync,
 };
-use sd_sync::{option_sync_db_entry, OperationFactory};
+use sd_sync::{option_sync_db_entry, sync_entry, OperationFactory};
 use sd_utils::chain_optional_iter;
 
 use std::{path::Path, sync::LazyLock};
@@ -51,9 +51,20 @@ fn to_query(
 		exif_version,
 	}: ExifMetadata,
 	object_id: exif_data::object_id::Type,
+	device_pub_id: &DevicePubId,
 ) -> (Vec<(&'static str, rmpv::Value)>, exif_data::Create) {
+	let device_pub_id = device_pub_id.to_db();
+
 	let (sync_params, db_params) = chain_optional_iter(
-		[],
+		[(
+			sync_entry!(
+				prisma_sync::device::SyncId {
+					pub_id: device_pub_id.clone()
+				},
+				exif_data::device
+			),
+			exif_data::device::connect(device::pub_id::equals(device_pub_id)),
+		)],
 		[
 			option_sync_db_entry!(
 				serde_json::to_vec(&camera_data).ok(),
@@ -109,24 +120,22 @@ pub async fn save(
 	exif_datas
 		.into_iter()
 		.map(|(exif_data, object_id, object_pub_id)| async move {
-			let (sync_params, create) = to_query(exif_data, object_id);
+			let (sync_params, create) = to_query(exif_data, object_id, &sync.device_pub_id);
 			let db_params = create._params.clone();
 
-			sync.write_ops(
+			sync.write_op(
 				db,
-				(
-					sync.shared_create(
-						prisma_sync::exif_data::SyncId {
-							object: prisma_sync::object::SyncId {
-								pub_id: object_pub_id.into(),
-							},
+				sync.shared_create(
+					prisma_sync::exif_data::SyncId {
+						object: prisma_sync::object::SyncId {
+							pub_id: object_pub_id.into(),
 						},
-						sync_params,
-					),
-					db.exif_data()
-						.upsert(exif_data::object_id::equals(object_id), create, db_params)
-						.select(exif_data::select!({ id })),
+					},
+					sync_params,
 				),
+				db.exif_data()
+					.upsert(exif_data::object_id::equals(object_id), create, db_params)
+					.select(exif_data::select!({ id })),
 			)
 			.await
 		})
diff --git a/core/crates/heavy-lifting/src/media_processor/job.rs b/core/crates/heavy-lifting/src/media_processor/job.rs
index bab8e506c..fb622e162 100644
--- a/core/crates/heavy-lifting/src/media_processor/job.rs
+++ b/core/crates/heavy-lifting/src/media_processor/job.rs
@@ -14,7 +14,11 @@ use sd_core_file_path_helper::IsolatedFilePathData;
 use sd_core_prisma_helpers::file_path_for_media_processor;
 
 use sd_file_ext::extensions::Extension;
-use sd_prisma::prisma::{location, PrismaClient};
+use sd_prisma::{
+	prisma::{location, PrismaClient},
+	prisma_sync,
+};
+use sd_sync::{sync_db_not_null_entry, OperationFactory};
 use sd_task_system::{
 	AnyTaskOutput, IntoTask, SerializableTask, Task, TaskDispatcher, TaskHandle, TaskId,
 	TaskOutput, TaskStatus, TaskSystemError,
@@ -125,7 +129,7 @@ impl Job for MediaProcessor {
 			TaskKind::MediaDataExtractor => {
 				tasks::MediaDataExtractor::deserialize(
 					&task_bytes,
-					(Arc::clone(ctx.db()), Arc::clone(ctx.sync())),
+					(Arc::clone(ctx.db()), ctx.sync().clone()),
 				)
 				.await
 				.map(IntoTask::into_task)
@@ -214,15 +218,23 @@ impl Job for MediaProcessor {
 			..
 		} = self;
 
-		ctx.db()
-			.location()
-			.update(
-				location::id::equals(location.id),
-				vec![location::scan_state::set(
-					LocationScanState::Completed as i32,
-				)],
+		let (sync_param, db_param) =
+			sync_db_not_null_entry!(LocationScanState::Completed as i32, location::scan_state);
+
+		ctx.sync()
+			.write_op(
+				ctx.db(),
+				ctx.sync().shared_update(
+					prisma_sync::location::SyncId {
+						pub_id: location.pub_id.clone(),
+					},
+					[sync_param],
+				),
+				ctx.db()
+					.location()
+					.update(location::id::equals(location.id), vec![db_param])
+					.select(location::select!({ id })),
 			)
-			.exec()
 			.await
 			.map_err(media_processor::Error::from)?;
@@ -632,7 +644,7 @@ impl MediaProcessor {
 					parent_iso_file_path.location_id(),
 					Arc::clone(&self.location_path),
 					Arc::clone(db),
-					Arc::clone(sync),
+					sync.clone(),
 				)
 			})
 			.map(IntoTask::into_task)
@@ -648,7 +660,7 @@ impl MediaProcessor {
 					parent_iso_file_path.location_id(),
 					Arc::clone(&self.location_path),
 					Arc::clone(db),
-					Arc::clone(sync),
+					sync.clone(),
 				)
 			})
 			.map(IntoTask::into_task),
diff --git a/core/crates/heavy-lifting/src/media_processor/shallow.rs b/core/crates/heavy-lifting/src/media_processor/shallow.rs
index b74c8c063..675dcd791 100644
--- a/core/crates/heavy-lifting/src/media_processor/shallow.rs
+++ b/core/crates/heavy-lifting/src/media_processor/shallow.rs
@@ -4,7 +4,7 @@ use crate::{
 };
 
 use sd_core_file_path_helper::IsolatedFilePathData;
-use sd_core_sync::Manager as SyncManager;
+use sd_core_sync::SyncManager;
 
 use sd_prisma::prisma::{location, PrismaClient};
 use sd_task_system::{
@@ -154,7 +154,7 @@ pub async fn shallow(
 
 async fn dispatch_media_data_extractor_tasks(
 	db: &Arc,
-	sync: &Arc,
+	sync: &SyncManager,
 	parent_iso_file_path: &IsolatedFilePathData<'_>,
 	location_path: &Arc,
 	dispatcher: &BaseTaskDispatcher,
@@ -185,7 +185,7 @@ async fn dispatch_media_data_extractor_tasks(
 				parent_iso_file_path.location_id(),
 				Arc::clone(location_path),
 				Arc::clone(db),
-				Arc::clone(sync),
+				sync.clone(),
 			)
 		})
 		.map(IntoTask::into_task)
@@ -201,7 +201,7 @@ async fn dispatch_media_data_extractor_tasks(
 				parent_iso_file_path.location_id(),
 				Arc::clone(location_path),
 				Arc::clone(db),
-				Arc::clone(sync),
+				sync.clone(),
 			)
 		})
 		.map(IntoTask::into_task),
@@ -220,7 +220,7 @@ async fn dispatch_media_data_extractor_tasks(
 async fn dispatch_thumbnailer_tasks(
 	parent_iso_file_path: &IsolatedFilePathData<'_>,
 	should_regenerate: bool,
-	location_path: &PathBuf,
+	location_path: &Path,
 	dispatcher: &BaseTaskDispatcher,
 	ctx: &impl OuterContext,
 ) -> Result>, Error> {
diff --git a/core/crates/heavy-lifting/src/media_processor/tasks/media_data_extractor.rs b/core/crates/heavy-lifting/src/media_processor/tasks/media_data_extractor.rs
index 30072b1c1..cd1c962da 100644
--- a/core/crates/heavy-lifting/src/media_processor/tasks/media_data_extractor.rs
+++ b/core/crates/heavy-lifting/src/media_processor/tasks/media_data_extractor.rs
@@ -8,7 +8,7 @@ use crate::{
 
 use sd_core_file_path_helper::IsolatedFilePathData;
 use sd_core_prisma_helpers::{file_path_for_media_processor, ObjectPubId};
-use sd_core_sync::Manager as SyncManager;
+use sd_core_sync::SyncManager;
 
 use sd_media_metadata::{ExifMetadata, FFmpegMetadata};
 use sd_prisma::prisma::{exif_data, ffmpeg_data, file_path, location, object, PrismaClient};
@@ -69,7 +69,7 @@ pub struct MediaDataExtractor {
 
 	// Dependencies
 	db: Arc,
-	sync: Arc,
+	sync: SyncManager,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -275,7 +275,7 @@ impl MediaDataExtractor {
 		location_id: location::id::Type,
 		location_path: Arc,
 		db: Arc,
-		sync: Arc,
+		sync: SyncManager,
 	) -> Self {
 		let mut output = Output::default();
@@ -316,7 +316,7 @@ impl MediaDataExtractor {
 		location_id: location::id::Type,
 		location_path: Arc,
 		db: Arc,
-		sync: Arc,
+		sync: SyncManager,
 	) -> Self {
 		Self::new(Kind::Exif, file_paths, location_id, location_path, db, sync)
 	}
@@ -327,7 +327,7 @@ impl MediaDataExtractor {
 		location_id: location::id::Type,
 		location_path: Arc,
 		db: Arc,
-		sync: Arc,
+		sync: SyncManager,
 	) -> Self {
 		Self::new(
 			Kind::FFmpeg,
@@ -550,7 +550,7 @@ impl SerializableTask for MediaDataExtractor {
 	type DeserializeError = rmp_serde::decode::Error;
 
-	type DeserializeCtx = (Arc, Arc);
+	type DeserializeCtx = (Arc, SyncManager);
 
 	async fn serialize(self) -> Result, Self::SerializeError> {
 		let Self {
diff --git a/core/crates/indexer-rules/src/serde_impl.rs b/core/crates/indexer-rules/src/serde_impl.rs
index a0b24dd23..461630669 100644
--- a/core/crates/indexer-rules/src/serde_impl.rs
+++ b/core/crates/indexer-rules/src/serde_impl.rs
@@ -60,7 +60,7 @@ impl<'de> Deserialize<'de> for RulePerKind {
 
 		struct FieldsVisitor;
 
-		impl<'de> de::Visitor<'de> for FieldsVisitor {
+		impl de::Visitor<'_> for FieldsVisitor {
 			type Value = Fields;
 
 			fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
diff --git a/core/crates/prisma-helpers/Cargo.toml b/core/crates/prisma-helpers/Cargo.toml
index 6a3a47a4c..8a4b490ea 100644
--- a/core/crates/prisma-helpers/Cargo.toml
+++ b/core/crates/prisma-helpers/Cargo.toml
@@ -9,8 +9,9 @@ repository.workspace = true
 
 [dependencies]
 # Spacedrive Sub-crates
-sd-prisma = { path = "../../../crates/prisma" }
-sd-utils = { path = "../../../crates/utils" }
+sd-cloud-schema = { workspace = true }
+sd-prisma = { path = "../../../crates/prisma" }
+sd-utils = { path = "../../../crates/utils" }
 
 # Workspace dependencies
 prisma-client-rust = { workspace = true }
diff --git a/core/crates/prisma-helpers/src/lib.rs b/core/crates/prisma-helpers/src/lib.rs
index 2d5abddd9..311d81947 100644
--- a/core/crates/prisma-helpers/src/lib.rs
+++ b/core/crates/prisma-helpers/src/lib.rs
@@ -34,7 +34,6 @@ use sd_utils::{from_bytes_to_uuid, uuid_to_bytes};
 use std::{borrow::Cow, fmt};
 
 use serde::{Deserialize, Serialize};
-use specta::Type;
 use uuid::Uuid;
 
 // File Path selectables!
@@ -75,6 +74,20 @@ file_path::select!(file_path_for_media_processor {
 		pub_id
 	}
 });
+file_path::select!(file_path_watcher_remove {
+	id
+	pub_id
+	location_id
+	materialized_path
+	is_dir
+	name
+	extension
+	object: select {
+		id
+		pub_id
+	}
+});
 file_path::select!(file_path_to_isolate {
 	location_id
 	materialized_path
@@ -244,7 +257,7 @@ job::select!(job_without_data {
 location::select!(location_ids_and_path {
 	id
 	pub_id
-	instance_id
+	device: select { pub_id }
 	path
 });
@@ -259,6 +272,7 @@
 		id: data.id,
 		pub_id: data.pub_id,
 		path: data.path,
+		device_id: data.device_id,
 		instance_id: data.instance_id,
 		name: data.name,
 		total_capacity: data.total_capacity,
@@ -272,6 +286,7 @@ impl From for location::Data {
 		scan_state: data.scan_state,
 		file_paths: None,
 		indexer_rules: None,
+		device: None,
 		instance: None,
 	}
 }
@@ -283,6 +298,7 @@ impl From<&location_with_indexer_rules::Data> for location::Data {
 		id: data.id,
 		pub_id: data.pub_id.clone(),
 		path: data.path.clone(),
+		device_id: data.device_id,
 		instance_id: data.instance_id,
 		name: data.name.clone(),
 		total_capacity: data.total_capacity,
@@ -296,6 +312,7 @@ impl From<&location_with_indexer_rules::Data> for location::Data {
 		scan_state: data.scan_state,
 		file_paths: None,
 		indexer_rules: None,
+		device: None,
 		instance: None,
 	}
 }
@@ -311,7 +328,7 @@ label::include!((take: i64) => label_with_objects {
 	}
 });
 
-#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Type)]
+#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, specta::Type)]
 #[serde(transparent)]
 pub struct CasId<'cas_id>(Cow<'cas_id, str>);
 
@@ -321,7 +338,7 @@ impl Clone for CasId<'_> {
 	}
 }
 
-impl<'cas_id> CasId<'cas_id> {
+impl CasId<'_> {
 	#[must_use]
 	pub fn as_str(&self) -> &str {
 		self.0.as_ref()
@@ -374,17 +391,32 @@ impl From<&CasId<'_>> for String {
 	}
 }
 
-#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone)]
+#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
 #[serde(transparent)]
 #[repr(transparent)]
+#[specta(rename = "CoreDevicePubId")]
+pub struct DevicePubId(PubId);
+
+impl From for sd_cloud_schema::devices::PubId {
+	fn from(DevicePubId(pub_id): DevicePubId) -> Self {
+		Self(pub_id.into())
+	}
+}
+
+#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
+#[serde(transparent)]
+#[repr(transparent)]
+#[specta(rename = "CoreFilePathPubId")]
 pub struct FilePathPubId(PubId);
 
-#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone)]
+#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
 #[serde(transparent)]
 #[repr(transparent)]
+#[specta(rename = "CoreObjectPubId")]
 pub struct ObjectPubId(PubId);
 
-#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone)]
+#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, specta::Type)]
+#[specta(rename = "CorePubId")]
 enum PubId {
 	Uuid(Uuid),
 	Vec(Vec),
@@ -392,7 +424,7 @@ enum PubId {
 
 impl PubId {
 	fn new() -> Self {
-		Self::Uuid(Uuid::new_v4())
+		Self::Uuid(Uuid::now_v7())
 	}
 
 	fn to_db(&self) -> Vec {
@@ -451,6 +483,15 @@ impl From for Uuid {
 	}
 }
 
+impl From<&PubId> for Uuid {
+	fn from(pub_id: &PubId) -> Self {
+		match pub_id {
+			PubId::Uuid(uuid) => *uuid,
+			PubId::Vec(bytes) => from_bytes_to_uuid(bytes),
+		}
+	}
+}
+
 impl fmt::Display for PubId {
 	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 		match self {
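`DevicePubId` joins `FilePathPubId` and `ObjectPubId` as newtypes over the private `PubId` enum, which holds either a parsed `Uuid` or the raw bytes as they come out of SQLite; the new `From<&PubId> for Uuid` impl converts without cloning first. A sketch of the implied round-trip contract (treating `to_db` as the byte serializer is an assumption based on the call sites above):

	// Both variants should yield the same UUID and the same DB bytes, so
	// callers can hold whichever form they already have and convert lazily.
	let as_uuid = Uuid::now_v7();
	let from_uuid = PubId::Uuid(as_uuid);
	let from_db = PubId::Vec(uuid_to_bytes(&as_uuid));
	assert_eq!(Uuid::from(&from_uuid), Uuid::from(&from_db));
	assert_eq!(from_uuid.to_db(), from_db.to_db());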
delegate_pub_id { } } + impl From<&$type_name> for ::uuid::Uuid { + fn from(pub_id: &$type_name) -> Self { + (&pub_id.0).into() + } + } + impl ::std::fmt::Display for $type_name { fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { write!(f, "{}", self.0) @@ -526,4 +573,4 @@ macro_rules! delegate_pub_id { }; } -delegate_pub_id!(FilePathPubId, ObjectPubId); +delegate_pub_id!(FilePathPubId, ObjectPubId, DevicePubId); diff --git a/core/crates/sync/Cargo.toml b/core/crates/sync/Cargo.toml index 930c3cdd0..d2a7dfc1e 100644 --- a/core/crates/sync/Cargo.toml +++ b/core/crates/sync/Cargo.toml @@ -9,6 +9,8 @@ default = [] [dependencies] # Spacedrive Sub-crates +sd-core-prisma-helpers = { path = "../prisma-helpers" } + sd-actors = { path = "../../../crates/actors" } sd-prisma = { path = "../../../crates/prisma" } sd-sync = { path = "../../../crates/sync" } @@ -16,8 +18,11 @@ sd-utils = { path = "../../../crates/utils" } # Workspace dependencies async-channel = { workspace = true } +async-stream = { workspace = true } +chrono = { workspace = true } futures = { workspace = true } futures-concurrency = { workspace = true } +itertools = { workspace = true } prisma-client-rust = { workspace = true, features = ["rspc"] } rmp-serde = { workspace = true } rmpv = { workspace = true } diff --git a/core/crates/sync/src/actor.rs b/core/crates/sync/src/actor.rs deleted file mode 100644 index 27f8c7a9b..000000000 --- a/core/crates/sync/src/actor.rs +++ /dev/null @@ -1,39 +0,0 @@ -use async_channel as chan; - -pub trait ActorTypes { - type Event: Send; - type Request: Send; - type Handler; -} - -pub struct ActorIO { - pub event_rx: chan::Receiver, - pub req_tx: chan::Sender, -} - -impl Clone for ActorIO { - fn clone(&self) -> Self { - Self { - event_rx: self.event_rx.clone(), - req_tx: self.req_tx.clone(), - } - } -} - -impl ActorIO { - pub async fn send(&self, value: T::Request) -> Result<(), chan::SendError> { - self.req_tx.send(value).await - } -} - -pub struct HandlerIO { - pub event_tx: chan::Sender, - pub req_rx: chan::Receiver, -} - -pub fn create_actor_io() -> (ActorIO, HandlerIO) { - let (req_tx, req_rx) = chan::bounded(32); - let (event_tx, event_rx) = chan::bounded(32); - - (ActorIO { event_rx, req_tx }, HandlerIO { event_tx, req_rx }) -} diff --git a/core/crates/sync/src/backfill.rs b/core/crates/sync/src/backfill.rs index 77f16a575..970de0c0d 100644 --- a/core/crates/sync/src/backfill.rs +++ b/core/crates/sync/src/backfill.rs @@ -1,7 +1,7 @@ use sd_prisma::{ prisma::{ - crdt_operation, exif_data, file_path, instance, label, label_on_object, location, object, - tag, tag_on_object, PrismaClient, SortOrder, + crdt_operation, device, exif_data, file_path, label, label_on_object, location, object, + storage_statistics, tag, tag_on_object, PrismaClient, SortOrder, }, prisma_sync, }; @@ -10,49 +10,141 @@ use sd_utils::chain_optional_iter; use std::future::Future; +use futures_concurrency::future::TryJoin; use tokio::time::Instant; use tracing::{debug, instrument}; -use super::{crdt_op_unchecked_db, Error}; +use super::{crdt_op_unchecked_db, Error, SyncManager}; /// Takes all the syncable data in the database and generates [`CRDTOperations`] for it. /// This is a requirement before the library can sync. 
-pub async fn backfill_operations( - db: &PrismaClient, - sync: &crate::Manager, - instance_id: instance::id::Type, -) -> Result<(), Error> { - let lock = sync.timestamp_lock.lock().await; +pub async fn backfill_operations(sync: &SyncManager) -> Result<(), Error> { + let _lock_guard = sync.sync_lock.lock().await; - let res = db - ._transaction() + let db = &sync.db; + + let local_device = db + .device() + .find_unique(device::pub_id::equals(sync.device_pub_id.to_db())) + .exec() + .await? + .ok_or(Error::DeviceNotFound(sync.device_pub_id.clone()))?; + + let local_device_id = local_device.id; + + db._transaction() .with_timeout(9_999_999_999) .run(|db| async move { debug!("backfill started"); let start = Instant::now(); db.crdt_operation() - .delete_many(vec![crdt_operation::instance_id::equals(instance_id)]) + .delete_many(vec![crdt_operation::device_pub_id::equals( + sync.device_pub_id.to_db(), + )]) .exec() .await?; - paginate_tags(&db, sync, instance_id).await?; - paginate_locations(&db, sync, instance_id).await?; - paginate_objects(&db, sync, instance_id).await?; - paginate_exif_datas(&db, sync, instance_id).await?; - paginate_file_paths(&db, sync, instance_id).await?; - paginate_tags_on_objects(&db, sync, instance_id).await?; - paginate_labels(&db, sync, instance_id).await?; - paginate_labels_on_objects(&db, sync, instance_id).await?; + backfill_device(&db, sync, local_device).await?; + + ( + backfill_storage_statistics(&db, sync, local_device_id), + paginate_tags(&db, sync), + paginate_locations(&db, sync, local_device_id), + paginate_objects(&db, sync, local_device_id), + paginate_labels(&db, sync), + ) + .try_join() + .await?; + + ( + paginate_exif_datas(&db, sync, local_device_id), + paginate_file_paths(&db, sync, local_device_id), + paginate_tags_on_objects(&db, sync, local_device_id), + paginate_labels_on_objects(&db, sync, local_device_id), + ) + .try_join() + .await?; debug!(elapsed = ?start.elapsed(), "backfill ended"); Ok(()) }) - .await; + .await +} - drop(lock); +#[instrument(skip(db, sync), err)] +async fn backfill_device( + db: &PrismaClient, + sync: &SyncManager, + local_device: device::Data, +) -> Result<(), Error> { + db.crdt_operation() + .create_many(vec![crdt_op_unchecked_db(&sync.shared_create( + prisma_sync::device::SyncId { + pub_id: local_device.pub_id, + }, + chain_optional_iter( + [], + [ + option_sync_entry!(local_device.name, device::name), + option_sync_entry!(local_device.os, device::os), + option_sync_entry!(local_device.hardware_model, device::hardware_model), + option_sync_entry!(local_device.timestamp, device::timestamp), + option_sync_entry!(local_device.date_created, device::date_created), + option_sync_entry!(local_device.date_deleted, device::date_deleted), + ], + ), + ))?]) + .exec() + .await?; - res + Ok(()) +} + +#[instrument(skip(db, sync), err)] +async fn backfill_storage_statistics( + db: &PrismaClient, + sync: &SyncManager, + device_id: device::id::Type, +) -> Result<(), Error> { + let Some(stats) = db + .storage_statistics() + .find_first(vec![storage_statistics::device_id::equals(Some(device_id))]) + .include(storage_statistics::include!({device: select { pub_id }})) + .exec() + .await? 
+ else { + // Nothing to do + return Ok(()); + }; + + db.crdt_operation() + .create_many(vec![crdt_op_unchecked_db(&sync.shared_create( + prisma_sync::storage_statistics::SyncId { + pub_id: stats.pub_id, + }, + chain_optional_iter( + [ + sync_entry!(stats.total_capacity, storage_statistics::total_capacity), + sync_entry!( + stats.available_capacity, + storage_statistics::available_capacity + ), + ], + [option_sync_entry!( + stats.device.map(|device| { + prisma_sync::device::SyncId { + pub_id: device.pub_id, + } + }), + storage_statistics::device + )], + ), + ))?]) + .exec() + .await?; + + Ok(()) } async fn paginate( @@ -112,38 +204,32 @@ where } #[instrument(skip(db, sync), err)] -async fn paginate_tags( - db: &PrismaClient, - sync: &crate::Manager, - instance_id: instance::id::Type, -) -> Result<(), Error> { - use tag::{color, date_created, date_modified, id, name}; - +async fn paginate_tags(db: &PrismaClient, sync: &SyncManager) -> Result<(), Error> { paginate( |cursor| { db.tag() - .find_many(vec![id::gt(cursor)]) - .order_by(id::order(SortOrder::Asc)) + .find_many(vec![tag::id::gt(cursor)]) + .order_by(tag::id::order(SortOrder::Asc)) .exec() }, |tag| tag.id, |tags| { tags.into_iter() - .flat_map(|t| { + .map(|t| { sync.shared_create( prisma_sync::tag::SyncId { pub_id: t.pub_id }, chain_optional_iter( [], [ - option_sync_entry!(t.name, name), - option_sync_entry!(t.color, color), - option_sync_entry!(t.date_created, date_created), - option_sync_entry!(t.date_modified, date_modified), + option_sync_entry!(t.name, tag::name), + option_sync_entry!(t.color, tag::color), + option_sync_entry!(t.date_created, tag::date_created), + option_sync_entry!(t.date_modified, tag::date_modified), ], ), ) }) - .map(|o| crdt_op_unchecked_db(&o, instance_id)) + .map(|o| crdt_op_unchecked_db(&o)) .collect::, _>>() .map(|creates| db.crdt_operation().create_many(creates).exec()) }, @@ -154,25 +240,20 @@ async fn paginate_tags( #[instrument(skip(db, sync), err)] async fn paginate_locations( db: &PrismaClient, - sync: &crate::Manager, - instance_id: instance::id::Type, + sync: &SyncManager, + device_id: device::id::Type, ) -> Result<(), Error> { - use location::{ - available_capacity, date_created, generate_preview_media, hidden, id, include, instance, - is_archived, name, path, size_in_bytes, sync_preview_media, total_capacity, - }; - paginate( |cursor| { db.location() - .find_many(vec![id::gt(cursor)]) - .order_by(id::order(SortOrder::Asc)) + .find_many(vec![ + location::id::gt(cursor), + location::device_id::equals(Some(device_id)), + ]) + .order_by(location::id::order(SortOrder::Asc)) .take(1000) - .include(include!({ - instance: select { - id - pub_id - } + .include(location::include!({ + device: select { pub_id } })) .exec() }, @@ -180,36 +261,44 @@ async fn paginate_locations( |locations| { locations .into_iter() - .flat_map(|l| { + .map(|l| { sync.shared_create( prisma_sync::location::SyncId { pub_id: l.pub_id }, chain_optional_iter( [], [ - option_sync_entry!(l.name, name), - option_sync_entry!(l.path, path), - option_sync_entry!(l.total_capacity, total_capacity), - option_sync_entry!(l.available_capacity, available_capacity), - option_sync_entry!(l.size_in_bytes, size_in_bytes), - option_sync_entry!(l.is_archived, is_archived), + option_sync_entry!(l.name, location::name), + option_sync_entry!(l.path, location::path), + option_sync_entry!(l.total_capacity, location::total_capacity), + option_sync_entry!( + l.available_capacity, + location::available_capacity + ), + option_sync_entry!(l.size_in_bytes, 
location::size_in_bytes), + option_sync_entry!(l.is_archived, location::is_archived), option_sync_entry!( l.generate_preview_media, - generate_preview_media + location::generate_preview_media ), - option_sync_entry!(l.sync_preview_media, sync_preview_media), - option_sync_entry!(l.hidden, hidden), - option_sync_entry!(l.date_created, date_created), option_sync_entry!( - l.instance.map(|i| { - prisma_sync::instance::SyncId { pub_id: i.pub_id } + l.sync_preview_media, + location::sync_preview_media + ), + option_sync_entry!(l.hidden, location::hidden), + option_sync_entry!(l.date_created, location::date_created), + option_sync_entry!( + l.device.map(|device| { + prisma_sync::device::SyncId { + pub_id: device.pub_id, + } }), - instance + location::device ), ], ), ) }) - .map(|o| crdt_op_unchecked_db(&o, instance_id)) + .map(|o| crdt_op_unchecked_db(&o)) .collect::, _>>() .map(|creates| db.crdt_operation().create_many(creates).exec()) }, @@ -220,41 +309,53 @@ async fn paginate_locations( #[instrument(skip(db, sync), err)] async fn paginate_objects( db: &PrismaClient, - sync: &crate::Manager, - instance_id: instance::id::Type, + sync: &SyncManager, + device_id: device::id::Type, ) -> Result<(), Error> { - use object::{date_accessed, date_created, favorite, hidden, id, important, kind, note}; - paginate( |cursor| { db.object() - .find_many(vec![id::gt(cursor)]) - .order_by(id::order(SortOrder::Asc)) + .find_many(vec![ + object::id::gt(cursor), + object::device_id::equals(Some(device_id)), + ]) + .order_by(object::id::order(SortOrder::Asc)) .take(1000) + .include(object::include!({ + device: select { pub_id } + })) .exec() }, |object| object.id, |objects| { objects .into_iter() - .flat_map(|o| { + .map(|o| { sync.shared_create( prisma_sync::object::SyncId { pub_id: o.pub_id }, chain_optional_iter( [], [ - option_sync_entry!(o.kind, kind), - option_sync_entry!(o.hidden, hidden), - option_sync_entry!(o.favorite, favorite), - option_sync_entry!(o.important, important), - option_sync_entry!(o.note, note), - option_sync_entry!(o.date_created, date_created), - option_sync_entry!(o.date_accessed, date_accessed), + option_sync_entry!(o.kind, object::kind), + option_sync_entry!(o.hidden, object::hidden), + option_sync_entry!(o.favorite, object::favorite), + option_sync_entry!(o.important, object::important), + option_sync_entry!(o.note, object::note), + option_sync_entry!(o.date_created, object::date_created), + option_sync_entry!(o.date_accessed, object::date_accessed), + option_sync_entry!( + o.device.map(|device| { + prisma_sync::device::SyncId { + pub_id: device.pub_id, + } + }), + object::device + ), ], ), ) }) - .map(|o| crdt_op_unchecked_db(&o, instance_id)) + .map(|o| crdt_op_unchecked_db(&o)) .collect::, _>>() .map(|creates| db.crdt_operation().create_many(creates).exec()) }, @@ -265,22 +366,21 @@ async fn paginate_objects( #[instrument(skip(db, sync), err)] async fn paginate_exif_datas( db: &PrismaClient, - sync: &crate::Manager, - instance_id: instance::id::Type, + sync: &SyncManager, + device_id: device::id::Type, ) -> Result<(), Error> { - use exif_data::{ - artist, camera_data, copyright, description, epoch_time, exif_version, id, include, - media_date, media_location, resolution, - }; - paginate( |cursor| { db.exif_data() - .find_many(vec![id::gt(cursor)]) - .order_by(id::order(SortOrder::Asc)) + .find_many(vec![ + exif_data::id::gt(cursor), + exif_data::device_id::equals(Some(device_id)), + ]) + .order_by(exif_data::id::order(SortOrder::Asc)) .take(1000) - .include(include!({ + 
.include(exif_data::include!({ object: select { pub_id } + device: select { pub_id } })) .exec() }, @@ -288,7 +388,7 @@ async fn paginate_exif_datas( |exif_datas| { exif_datas .into_iter() - .flat_map(|ed| { + .map(|ed| { sync.shared_create( prisma_sync::exif_data::SyncId { object: prisma_sync::object::SyncId { @@ -298,20 +398,28 @@ async fn paginate_exif_datas( chain_optional_iter( [], [ - option_sync_entry!(ed.resolution, resolution), - option_sync_entry!(ed.media_date, media_date), - option_sync_entry!(ed.media_location, media_location), - option_sync_entry!(ed.camera_data, camera_data), - option_sync_entry!(ed.artist, artist), - option_sync_entry!(ed.description, description), - option_sync_entry!(ed.copyright, copyright), - option_sync_entry!(ed.exif_version, exif_version), - option_sync_entry!(ed.epoch_time, epoch_time), + option_sync_entry!(ed.resolution, exif_data::resolution), + option_sync_entry!(ed.media_date, exif_data::media_date), + option_sync_entry!(ed.media_location, exif_data::media_location), + option_sync_entry!(ed.camera_data, exif_data::camera_data), + option_sync_entry!(ed.artist, exif_data::artist), + option_sync_entry!(ed.description, exif_data::description), + option_sync_entry!(ed.copyright, exif_data::copyright), + option_sync_entry!(ed.exif_version, exif_data::exif_version), + option_sync_entry!(ed.epoch_time, exif_data::epoch_time), + option_sync_entry!( + ed.device.map(|device| { + prisma_sync::device::SyncId { + pub_id: device.pub_id, + } + }), + exif_data::device + ), ], ), ) }) - .map(|o| crdt_op_unchecked_db(&o, instance_id)) + .map(|o| crdt_op_unchecked_db(&o)) .collect::, _>>() .map(|creates| db.crdt_operation().create_many(creates).exec()) }, @@ -322,22 +430,21 @@ async fn paginate_exif_datas( #[instrument(skip(db, sync), err)] async fn paginate_file_paths( db: &PrismaClient, - sync: &crate::Manager, - instance_id: instance::id::Type, + sync: &SyncManager, + device_id: device::id::Type, ) -> Result<(), Error> { - use file_path::{ - cas_id, date_created, date_indexed, date_modified, extension, hidden, id, include, inode, - integrity_checksum, is_dir, location, materialized_path, name, object, size_in_bytes_bytes, - }; - paginate( |cursor| { db.file_path() - .find_many(vec![id::gt(cursor)]) - .order_by(id::order(SortOrder::Asc)) - .include(include!({ + .find_many(vec![ + file_path::id::gt(cursor), + file_path::device_id::equals(Some(device_id)), + ]) + .order_by(file_path::id::order(SortOrder::Asc)) + .include(file_path::include!({ location: select { pub_id } object: select { pub_id } + device: select { pub_id } })) .exec() }, @@ -345,41 +452,58 @@ async fn paginate_file_paths( |file_paths| { file_paths .into_iter() - .flat_map(|fp| { + .map(|fp| { sync.shared_create( prisma_sync::file_path::SyncId { pub_id: fp.pub_id }, chain_optional_iter( [], [ - option_sync_entry!(fp.is_dir, is_dir), - option_sync_entry!(fp.cas_id, cas_id), - option_sync_entry!(fp.integrity_checksum, integrity_checksum), + option_sync_entry!(fp.is_dir, file_path::is_dir), + option_sync_entry!(fp.cas_id, file_path::cas_id), + option_sync_entry!( + fp.integrity_checksum, + file_path::integrity_checksum + ), option_sync_entry!( fp.location.map(|l| { prisma_sync::location::SyncId { pub_id: l.pub_id } }), - location + file_path::location ), option_sync_entry!( fp.object.map(|o| { prisma_sync::object::SyncId { pub_id: o.pub_id } }), - object + file_path::object + ), + option_sync_entry!( + fp.materialized_path, + file_path::materialized_path + ), + option_sync_entry!(fp.name, 
file_path::name), + option_sync_entry!(fp.extension, file_path::extension), + option_sync_entry!(fp.hidden, file_path::hidden), + option_sync_entry!( + fp.size_in_bytes_bytes, + file_path::size_in_bytes_bytes + ), + option_sync_entry!(fp.inode, file_path::inode), + option_sync_entry!(fp.date_created, file_path::date_created), + option_sync_entry!(fp.date_modified, file_path::date_modified), + option_sync_entry!(fp.date_indexed, file_path::date_indexed), + option_sync_entry!( + fp.device.map(|device| { + prisma_sync::device::SyncId { + pub_id: device.pub_id, + } + }), + file_path::device ), - option_sync_entry!(fp.materialized_path, materialized_path), - option_sync_entry!(fp.name, name), - option_sync_entry!(fp.extension, extension), - option_sync_entry!(fp.hidden, hidden), - option_sync_entry!(fp.size_in_bytes_bytes, size_in_bytes_bytes), - option_sync_entry!(fp.inode, inode), - option_sync_entry!(fp.date_created, date_created), - option_sync_entry!(fp.date_modified, date_modified), - option_sync_entry!(fp.date_indexed, date_indexed), ], ), ) }) - .map(|o| crdt_op_unchecked_db(&o, instance_id)) + .map(|o| crdt_op_unchecked_db(&o)) .collect::, _>>() .map(|creates| db.crdt_operation().create_many(creates).exec()) }, @@ -390,20 +514,23 @@ async fn paginate_file_paths( #[instrument(skip(db, sync), err)] async fn paginate_tags_on_objects( db: &PrismaClient, - sync: &crate::Manager, - instance_id: instance::id::Type, + sync: &SyncManager, + device_id: device::id::Type, ) -> Result<(), Error> { - use tag_on_object::{date_created, include, object_id, tag_id}; - paginate_relation( |group_id, item_id| { db.tag_on_object() - .find_many(vec![tag_id::gt(group_id), object_id::gt(item_id)]) - .order_by(tag_id::order(SortOrder::Asc)) - .order_by(object_id::order(SortOrder::Asc)) - .include(include!({ + .find_many(vec![ + tag_on_object::tag_id::gt(group_id), + tag_on_object::object_id::gt(item_id), + tag_on_object::device_id::equals(Some(device_id)), + ]) + .order_by(tag_on_object::tag_id::order(SortOrder::Asc)) + .order_by(tag_on_object::object_id::order(SortOrder::Asc)) + .include(tag_on_object::include!({ tag: select { pub_id } object: select { pub_id } + device: select { pub_id } })) .exec() }, @@ -411,7 +538,7 @@ async fn paginate_tags_on_objects( |tag_on_objects| { tag_on_objects .into_iter() - .flat_map(|t_o| { + .map(|t_o| { sync.relation_create( prisma_sync::tag_on_object::SyncId { tag: prisma_sync::tag::SyncId { @@ -423,11 +550,21 @@ async fn paginate_tags_on_objects( }, chain_optional_iter( [], - [option_sync_entry!(t_o.date_created, date_created)], + [ + option_sync_entry!(t_o.date_created, tag_on_object::date_created), + option_sync_entry!( + t_o.device.map(|device| { + prisma_sync::device::SyncId { + pub_id: device.pub_id, + } + }), + tag_on_object::device + ), + ], ), ) }) - .map(|o| crdt_op_unchecked_db(&o, instance_id)) + .map(|o| crdt_op_unchecked_db(&o)) .collect::, _>>() .map(|creates| db.crdt_operation().create_many(creates).exec()) }, @@ -436,37 +573,31 @@ async fn paginate_tags_on_objects( } #[instrument(skip(db, sync), err)] -async fn paginate_labels( - db: &PrismaClient, - sync: &crate::Manager, - instance_id: instance::id::Type, -) -> Result<(), Error> { - use label::{date_created, date_modified, id}; - +async fn paginate_labels(db: &PrismaClient, sync: &SyncManager) -> Result<(), Error> { paginate( |cursor| { db.label() - .find_many(vec![id::gt(cursor)]) - .order_by(id::order(SortOrder::Asc)) + .find_many(vec![label::id::gt(cursor)]) + .order_by(label::id::order(SortOrder::Asc)) 
.exec() }, |label| label.id, |labels| { labels .into_iter() - .flat_map(|l| { + .map(|l| { sync.shared_create( prisma_sync::label::SyncId { name: l.name }, chain_optional_iter( [], [ - option_sync_entry!(l.date_created, date_created), - option_sync_entry!(l.date_modified, date_modified), + option_sync_entry!(l.date_created, label::date_created), + option_sync_entry!(l.date_modified, label::date_modified), ], ), ) }) - .map(|o| crdt_op_unchecked_db(&o, instance_id)) + .map(|o| crdt_op_unchecked_db(&o)) .collect::, _>>() .map(|creates| db.crdt_operation().create_many(creates).exec()) }, @@ -477,20 +608,23 @@ async fn paginate_labels( #[instrument(skip(db, sync), err)] async fn paginate_labels_on_objects( db: &PrismaClient, - sync: &crate::Manager, - instance_id: instance::id::Type, + sync: &SyncManager, + device_id: device::id::Type, ) -> Result<(), Error> { - use label_on_object::{date_created, include, label_id, object_id}; - paginate_relation( |group_id, item_id| { db.label_on_object() - .find_many(vec![label_id::gt(group_id), object_id::gt(item_id)]) - .order_by(label_id::order(SortOrder::Asc)) - .order_by(object_id::order(SortOrder::Asc)) - .include(include!({ + .find_many(vec![ + label_on_object::label_id::gt(group_id), + label_on_object::object_id::gt(item_id), + label_on_object::device_id::equals(Some(device_id)), + ]) + .order_by(label_on_object::label_id::order(SortOrder::Asc)) + .order_by(label_on_object::object_id::order(SortOrder::Asc)) + .include(label_on_object::include!({ object: select { pub_id } label: select { name } + device: select { pub_id } })) .exec() }, @@ -498,7 +632,7 @@ async fn paginate_labels_on_objects( |label_on_objects| { label_on_objects .into_iter() - .flat_map(|l_o| { + .map(|l_o| { sync.relation_create( prisma_sync::label_on_object::SyncId { label: prisma_sync::label::SyncId { @@ -508,10 +642,20 @@ async fn paginate_labels_on_objects( pub_id: l_o.object.pub_id, }, }, - [sync_entry!(l_o.date_created, date_created)], + chain_optional_iter( + [sync_entry!(l_o.date_created, label_on_object::date_created)], + [option_sync_entry!( + l_o.device.map(|device| { + prisma_sync::device::SyncId { + pub_id: device.pub_id, + } + }), + label_on_object::device + )], + ), ) }) - .map(|o| crdt_op_unchecked_db(&o, instance_id)) + .map(|o| crdt_op_unchecked_db(&o)) .collect::, _>>() .map(|creates| db.crdt_operation().create_many(creates).exec()) }, diff --git a/core/crates/sync/src/db_operation.rs b/core/crates/sync/src/db_operation.rs index ff49d32b3..1bdb2422c 100644 --- a/core/crates/sync/src/db_operation.rs +++ b/core/crates/sync/src/db_operation.rs @@ -1,82 +1,14 @@ -use sd_prisma::prisma::{cloud_crdt_operation, crdt_operation, instance, PrismaClient}; +use sd_core_prisma_helpers::DevicePubId; + +use sd_prisma::prisma::{cloud_crdt_operation, crdt_operation, PrismaClient}; use sd_sync::CRDTOperation; -use sd_utils::from_bytes_to_uuid; +use sd_utils::uuid_to_bytes; use tracing::instrument; use uhlc::NTP64; -use uuid::Uuid; use super::Error; -crdt_operation::include!(crdt_with_instance { - instance: select { pub_id } -}); - -cloud_crdt_operation::include!(cloud_crdt_with_instance { - instance: select { pub_id } -}); - -impl crdt_with_instance::Data { - #[allow(clippy::cast_sign_loss)] // SAFETY: we had to store using i64 due to SQLite limitations - pub const fn timestamp(&self) -> NTP64 { - NTP64(self.timestamp as u64) - } - - pub fn instance(&self) -> Uuid { - from_bytes_to_uuid(&self.instance.pub_id) - } - - pub fn into_operation(self) -> Result { - Ok(CRDTOperation { - 
instance: self.instance(),
-			timestamp: self.timestamp(),
-			record_id: rmp_serde::from_slice(&self.record_id)?,
-
-			model: {
-				#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
-				// SAFETY: we will not have more than 2^16 models and we had to store using signed
-				// integers due to SQLite limitations
-				{
-					self.model as u16
-				}
-			},
-			data: rmp_serde::from_slice(&self.data)?,
-		})
-	}
-}
-
-impl cloud_crdt_with_instance::Data {
-	#[allow(clippy::cast_sign_loss)] // SAFETY: we had to store using i64 due to SQLite limitations
-	pub const fn timestamp(&self) -> NTP64 {
-		NTP64(self.timestamp as u64)
-	}
-
-	pub fn instance(&self) -> Uuid {
-		from_bytes_to_uuid(&self.instance.pub_id)
-	}
-
-	#[instrument(skip(self), err)]
-	pub fn into_operation(self) -> Result<(i32, CRDTOperation), Error> {
-		Ok((
-			self.id,
-			CRDTOperation {
-				instance: self.instance(),
-				timestamp: self.timestamp(),
-				record_id: rmp_serde::from_slice(&self.record_id)?,
-				model: {
-					#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
-					// SAFETY: we will not have more than 2^16 models and we had to store using signed
-					// integers due to SQLite limitations
-					{
-						self.model as u16
-					}
-				},
-				data: rmp_serde::from_slice(&self.data)?,
-			},
-		))
-	}
-}
-
 #[instrument(skip(op, db), err)]
 pub async fn write_crdt_op_to_db(op: &CRDTOperation, db: &PrismaClient) -> Result<(), Error> {
 	crdt_operation::Create {
@@ -87,16 +19,85 @@ pub async fn write_crdt_op_to_db(op: &CRDTOperation, db: &PrismaClient) -> Resul
 				op.timestamp.0 as i64
 			}
 		},
-		instance: instance::pub_id::equals(op.instance.as_bytes().to_vec()),
+		device_pub_id: uuid_to_bytes(&op.device_pub_id),
 		kind: op.kind().to_string(),
 		data: rmp_serde::to_vec(&op.data)?,
-		model: i32::from(op.model),
+		model: i32::from(op.model_id),
 		record_id: rmp_serde::to_vec(&op.record_id)?,
 		_params: vec![],
 	}
 	.to_query(db)
 	.select(crdt_operation::select!({ id })) // To avoid fetching the whole object for nothing
 	.exec()
-	.await
-	.map_or_else(|e| Err(e.into()), |_| Ok(()))
+	.await?;
+
+	Ok(())
+}
+
+pub fn from_crdt_ops(
+	crdt_operation::Data {
+		timestamp,
+		model,
+		record_id,
+		data,
+		device_pub_id,
+		..
+	}: crdt_operation::Data,
+) -> Result<CRDTOperation, Error> {
+	Ok(CRDTOperation {
+		device_pub_id: DevicePubId::from(device_pub_id).into(),
+		timestamp: {
+			#[allow(clippy::cast_sign_loss)]
+			{
+				// SAFETY: we had to store using i64 due to SQLite limitations
+				NTP64(timestamp as u64)
+			}
+		},
+		model_id: {
+			#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
+			{
+				// SAFETY: we will not have more than 2^16 models and we had to store using signed
+				// integers due to SQLite limitations
+				model as u16
+			}
+		},
+		record_id: rmp_serde::from_slice(&record_id)?,
+		data: rmp_serde::from_slice(&data)?,
+	})
+}
+
+pub fn from_cloud_crdt_ops(
+	cloud_crdt_operation::Data {
+		id,
+		timestamp,
+		model,
+		record_id,
+		data,
+		device_pub_id,
+		..
+ }: cloud_crdt_operation::Data, +) -> Result<(cloud_crdt_operation::id::Type, CRDTOperation), Error> { + Ok(( + id, + CRDTOperation { + device_pub_id: DevicePubId::from(device_pub_id).into(), + timestamp: { + #[allow(clippy::cast_sign_loss)] + { + // SAFETY: we had to store using i64 due to SQLite limitations + NTP64(timestamp as u64) + } + }, + model_id: { + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + { + // SAFETY: we will not have more than 2^16 models and we had to store using signed + // integers due to SQLite limitations + model as u16 + } + }, + record_id: rmp_serde::from_slice(&record_id)?, + data: rmp_serde::from_slice(&data)?, + }, + )) } diff --git a/core/crates/sync/src/ingest.rs b/core/crates/sync/src/ingest.rs deleted file mode 100644 index d868f685d..000000000 --- a/core/crates/sync/src/ingest.rs +++ /dev/null @@ -1,641 +0,0 @@ -use sd_prisma::{ - prisma::{crdt_operation, PrismaClient, SortOrder}, - prisma_sync::ModelSyncData, -}; -use sd_sync::{ - CRDTOperation, CRDTOperationData, CompressedCRDTOperation, CompressedCRDTOperations, - OperationKind, -}; - -use std::{ - collections::BTreeMap, - future::IntoFuture, - num::NonZeroU128, - ops::Deref, - pin::pin, - sync::{atomic::Ordering, Arc}, - time::SystemTime, -}; - -use async_channel as chan; -use futures::{stream, FutureExt, StreamExt}; -use futures_concurrency::{ - future::{Race, TryJoin}, - stream::Merge, -}; -use prisma_client_rust::chrono::{DateTime, Utc}; -use tokio::sync::oneshot; -use tracing::{debug, error, instrument, trace, warn}; -use uhlc::{Timestamp, NTP64}; -use uuid::Uuid; - -use super::{ - actor::{create_actor_io, ActorIO, ActorTypes, HandlerIO}, - db_operation::write_crdt_op_to_db, - Error, SharedState, -}; - -#[derive(Debug)] -#[must_use] -/// Stuff that can be handled outside the actor -pub enum Request { - Messages { - timestamps: Vec<(Uuid, NTP64)>, - tx: oneshot::Sender<()>, - }, - FinishedIngesting, -} - -/// Stuff that the actor consumes -#[derive(Debug)] -pub enum Event { - Notification, - Messages(MessagesEvent), -} - -#[derive(Debug, Default)] -pub enum State { - #[default] - WaitingForNotification, - RetrievingMessages, - Ingesting(MessagesEvent), -} - -/// The single entrypoint for sync operation ingestion. -/// Requests sync operations in a given timestamp range, -/// and attempts to write them to the sync operations table along with -/// the actual cell that the operation points to. -/// -/// If this actor stops running, no sync operations will -/// be applied to the database, independent of whether systems like p2p -/// or cloud are exchanging messages. 
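
Stepping back from the `ingest.rs` deletion that begins above: both conversion functions in the new `db_operation.rs` (`from_crdt_ops` and `from_cloud_crdt_ops`) undo the same storage compromise. SQLite has no unsigned 64-bit column, so the `uhlc` NTP64 timestamp is written as `i64` and cast back on read. The cast pair is a lossless bit-for-bit round trip, which this standalone check illustrates:

```rust
// Stand-in names for the cast pair used on the write side (`crdt_op_db`) and
// the read side (`from_crdt_ops`); these helpers are not part of the diff.
fn ntp64_to_db(ticks: u64) -> i64 {
    ticks as i64 // large values wrap to negative i64s, but no bits are lost
}

fn ntp64_from_db(raw: i64) -> u64 {
    raw as u64
}

fn main() {
    for ticks in [0, 1, u64::from(u32::MAX), u64::MAX / 2, u64::MAX] {
        assert_eq!(ntp64_from_db(ntp64_to_db(ticks)), ticks);
    }
    println!("all NTP64 bit patterns survive the i64 detour");
}
```

The clippy `cast_sign_loss` / `cast_possible_wrap` allowances scattered through the diff document exactly this trade.
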
-pub struct Actor { - state: Option, - shared: Arc, - io: ActorIO, -} - -impl Actor { - #[instrument(skip(self), fields(old_state = ?self.state))] - async fn tick(&mut self) { - let state = match self - .state - .take() - .expect("ingest actor in inconsistent state") - { - State::WaitingForNotification => self.waiting_for_notification_state_transition().await, - State::RetrievingMessages => self.retrieving_messages_state_transition().await, - State::Ingesting(event) => self.ingesting_state_transition(event).await, - }; - - trace!(?state, "Actor state transitioned;"); - - self.state = Some(state); - } - - async fn waiting_for_notification_state_transition(&self) -> State { - self.shared.active.store(false, Ordering::Relaxed); - self.shared.active_notify.notify_waiters(); - - loop { - match self - .io - .event_rx - .recv() - .await - .expect("sync actor receiver unexpectedly closed") - { - Event::Notification => { - trace!("Received notification"); - break; - } - Event::Messages(event) => { - trace!( - ?event, - "Ignored event message as we're waiting for a `Event::Notification`" - ); - } - } - } - - self.shared.active.store(true, Ordering::Relaxed); - self.shared.active_notify.notify_waiters(); - - State::RetrievingMessages - } - - async fn retrieving_messages_state_transition(&self) -> State { - enum StreamMessage { - NewEvent(Event), - AckedRequest(Result<(), oneshot::error::RecvError>), - } - - let (tx, rx) = oneshot::channel::<()>(); - - let timestamps = self - .timestamps - .read() - .await - .iter() - .map(|(&uid, ×tamp)| (uid, timestamp)) - .collect(); - - if self - .io - .send(Request::Messages { timestamps, tx }) - .await - .is_err() - { - warn!("Failed to send messages request"); - } - - let mut msg_stream = pin!(( - self.io.event_rx.clone().map(StreamMessage::NewEvent), - stream::once(rx.map(StreamMessage::AckedRequest)), - ) - .merge()); - - loop { - if let Some(msg) = msg_stream.next().await { - match msg { - StreamMessage::NewEvent(event) => { - if let Event::Messages(messages) = event { - trace!(?messages, "Received messages;"); - break State::Ingesting(messages); - } - } - StreamMessage::AckedRequest(res) => { - if res.is_err() { - debug!("messages request ignored"); - break State::WaitingForNotification; - } - } - } - } else { - break State::WaitingForNotification; - } - } - } - - async fn ingesting_state_transition(&mut self, event: MessagesEvent) -> State { - debug!( - messages_count = event.messages.len(), - first_message = ?DateTime::::from( - event.messages - .first() - .map_or(SystemTime::UNIX_EPOCH, |m| m.3.timestamp.to_system_time()) - ), - last_message = ?DateTime::::from( - event.messages - .last() - .map_or(SystemTime::UNIX_EPOCH, |m| m.3.timestamp.to_system_time()) - ), - "Ingesting operations;", - ); - - for (instance, data) in event.messages.0 { - for (model, data) in data { - for (record, ops) in data { - if let Err(e) = self - .process_crdt_operations(instance, model, record, ops) - .await - { - error!(?e, "Failed to ingest CRDT operations;"); - } - } - } - } - - if let Some(tx) = event.wait_tx { - if tx.send(()).is_err() { - warn!("Failed to send wait_tx signal"); - } - } - - if event.has_more { - State::RetrievingMessages - } else { - { - if self.io.send(Request::FinishedIngesting).await.is_err() { - error!("Failed to send finished ingesting request"); - } - - State::WaitingForNotification - } - } - } - - pub async fn declare(shared: Arc) -> Handler { - let (io, HandlerIO { event_tx, req_rx }) = create_actor_io::(); - - shared - .actors - .declare( - "Sync 
Ingest", - { - let shared = Arc::clone(&shared); - move |stop| async move { - enum Race { - Ticked, - Stopped, - } - - let mut this = Self { - state: Some(State::default()), - io, - shared, - }; - - while matches!( - ( - this.tick().map(|()| Race::Ticked), - stop.into_future().map(|()| Race::Stopped), - ) - .race() - .await, - Race::Ticked - ) { /* Everything is Awesome! */ } - } - }, - true, - ) - .await; - - Handler { event_tx, req_rx } - } - - // where the magic happens - #[instrument(skip(self, ops), fields(operations_count = %ops.len()), err)] - async fn process_crdt_operations( - &mut self, - instance: Uuid, - model: u16, - record_id: rmpv::Value, - mut ops: Vec, - ) -> Result<(), Error> { - let db = &self.db; - - ops.sort_by_key(|op| op.timestamp); - - let new_timestamp = ops.last().expect("Empty ops array").timestamp; - - // first, we update the HLC's timestamp with the incoming one. - // this involves a drift check + sets the last time of the clock - self.clock - .update_with_timestamp(&Timestamp::new( - new_timestamp, - uhlc::ID::from(NonZeroU128::new(instance.to_u128_le()).expect("Non zero id")), - )) - .expect("timestamp has too much drift!"); - - // read the timestamp for the operation's instance, or insert one if it doesn't exist - let timestamp = self.timestamps.read().await.get(&instance).copied(); - - // Delete - ignores all other messages - if let Some(delete_op) = ops - .iter() - .rev() - .find(|op| matches!(op.data, CRDTOperationData::Delete)) - { - trace!("Deleting operation"); - handle_crdt_deletion(db, instance, model, record_id, delete_op).await?; - } - // Create + > 0 Update - overwrites the create's data with the updates - else if let Some(timestamp) = ops - .iter() - .rev() - .find_map(|op| matches!(&op.data, CRDTOperationData::Create(_)).then_some(op.timestamp)) - { - trace!("Create + Updates operations"); - - // conflict resolution - let delete = db - .crdt_operation() - .find_first(vec![ - crdt_operation::model::equals(i32::from(model)), - crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), - crdt_operation::kind::equals(OperationKind::Delete.to_string()), - ]) - .order_by(crdt_operation::timestamp::order(SortOrder::Desc)) - .exec() - .await?; - - if delete.is_some() { - debug!("Found a previous delete operation with the same SyncId, will ignore these operations"); - return Ok(()); - } - - handle_crdt_create_and_updates(db, instance, model, record_id, ops, timestamp).await?; - } - // > 0 Update - batches updates with a fake Create op - else { - trace!("Updates operation"); - - let mut data = BTreeMap::new(); - - for op in ops.into_iter().rev() { - let CRDTOperationData::Update { field, value } = op.data else { - unreachable!("Create + Delete should be filtered out!"); - }; - - data.insert(field, (value, op.timestamp)); - } - - // conflict resolution - let (create, updates) = db - ._batch(( - db.crdt_operation() - .find_first(vec![ - crdt_operation::model::equals(i32::from(model)), - crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), - crdt_operation::kind::equals(OperationKind::Create.to_string()), - ]) - .order_by(crdt_operation::timestamp::order(SortOrder::Desc)), - data.iter() - .map(|(k, (_, timestamp))| { - Ok(db - .crdt_operation() - .find_first(vec![ - crdt_operation::timestamp::gt({ - #[allow(clippy::cast_possible_wrap)] - // SAFETY: we had to store using i64 due to SQLite limitations - { - timestamp.as_u64() as i64 - } - }), - crdt_operation::model::equals(i32::from(model)), - 
crdt_operation::record_id::equals(rmp_serde::to_vec( - &record_id, - )?), - crdt_operation::kind::equals( - OperationKind::Update(k).to_string(), - ), - ]) - .order_by(crdt_operation::timestamp::order(SortOrder::Desc))) - }) - .collect::, Error>>()?, - )) - .await?; - - if create.is_none() { - warn!("Failed to find a previous create operation with the same SyncId"); - return Ok(()); - } - - handle_crdt_updates(db, instance, model, record_id, data, updates).await?; - } - - // update the stored timestamp for this instance - will be derived from the crdt operations table on restart - let new_ts = NTP64::max(timestamp.unwrap_or_default(), new_timestamp); - - self.timestamps.write().await.insert(instance, new_ts); - - Ok(()) - } -} - -async fn handle_crdt_updates( - db: &PrismaClient, - instance: Uuid, - model: u16, - record_id: rmpv::Value, - mut data: BTreeMap, - updates: Vec>, -) -> Result<(), Error> { - let keys = data.keys().cloned().collect::>(); - - // does the same thing as processing ops one-by-one and returning early if a newer op was found - for (update, key) in updates.into_iter().zip(keys) { - if update.is_some() { - data.remove(&key); - } - } - - db._transaction() - .with_timeout(30 * 1000) - .run(|db| async move { - // fake operation to batch them all at once - ModelSyncData::from_op(CRDTOperation { - instance, - model, - record_id: record_id.clone(), - timestamp: NTP64(0), - data: CRDTOperationData::Create( - data.iter() - .map(|(k, (data, _))| (k.clone(), data.clone())) - .collect(), - ), - }) - .ok_or(Error::InvalidModelId(model))? - .exec(&db) - .await?; - - // need to only apply ops that haven't been filtered out - data.into_iter() - .map(|(field, (value, timestamp))| { - let record_id = record_id.clone(); - let db = &db; - - async move { - write_crdt_op_to_db( - &CRDTOperation { - instance, - model, - record_id, - timestamp, - data: CRDTOperationData::Update { field, value }, - }, - db, - ) - .await - } - }) - .collect::>() - .try_join() - .await - .map(|_| ()) - }) - .await -} - -async fn handle_crdt_create_and_updates( - db: &PrismaClient, - instance: Uuid, - model: u16, - record_id: rmpv::Value, - ops: Vec, - timestamp: NTP64, -) -> Result<(), Error> { - let mut data = BTreeMap::new(); - - let mut applied_ops = vec![]; - - // search for all Updates until a Create is found - for op in ops.iter().rev() { - match &op.data { - CRDTOperationData::Delete => unreachable!("Delete can't exist here!"), - CRDTOperationData::Create(create_data) => { - for (k, v) in create_data { - data.entry(k).or_insert(v); - } - - applied_ops.push(op); - - break; - } - CRDTOperationData::Update { field, value } => { - applied_ops.push(op); - data.insert(field, value); - } - } - } - - db._transaction() - .with_timeout(30 * 1000) - .run(|db| async move { - // fake a create with a bunch of data rather than individual insert - ModelSyncData::from_op(CRDTOperation { - instance, - model, - record_id: record_id.clone(), - timestamp, - data: CRDTOperationData::Create( - data.into_iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect(), - ), - }) - .ok_or(Error::InvalidModelId(model))? 
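
Underneath the database plumbing, the deleted `process_crdt_operations` (and its replacement in `ingest_utils.rs` below) applies one fixed resolution order per record: a Delete anywhere in the batch wins outright; otherwise a Create plus later Updates collapses into one snapshot; otherwise standalone Updates apply last-writer-wins per field. A pure sketch of just that order, with stand-in types in place of the crate's `CompressedCRDTOperation`:

```rust
use std::collections::BTreeMap;

#[derive(Clone, Debug)]
enum OpData {
    Create(BTreeMap<String, String>),
    Update { field: String, value: String },
    Delete,
}

struct Op {
    timestamp: u64,
    data: OpData,
}

// Resolution order for one record: any Delete wins outright; otherwise fold
// Create and Update ops in timestamp order, last writer winning per field.
fn resolve(mut ops: Vec<Op>) -> Option<BTreeMap<String, String>> {
    ops.sort_by_key(|op| op.timestamp);
    if ops.iter().any(|op| matches!(op.data, OpData::Delete)) {
        return None; // the record ends up deleted, everything else is moot
    }
    let mut state = BTreeMap::new();
    for op in ops {
        match op.data {
            OpData::Create(fields) => state.extend(fields),
            OpData::Update { field, value } => {
                state.insert(field, value);
            }
            OpData::Delete => unreachable!("filtered out above"),
        }
    }
    Some(state)
}

fn main() {
    let ops = vec![
        Op { timestamp: 1, data: OpData::Create(BTreeMap::from([("name".into(), "a".into())])) },
        Op { timestamp: 2, data: OpData::Update { field: "name".into(), value: "b".into() } },
    ];
    assert_eq!(resolve(ops).unwrap()["name"], "b");
}
```

The real code additionally consults the `crdt_operation` table so that operations already superseded locally are dropped before this fold runs.
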
- .exec(&db) - .await?; - - applied_ops - .into_iter() - .map(|op| { - let record_id = record_id.clone(); - let db = &db; - async move { - let operation = CRDTOperation { - instance, - model, - record_id, - timestamp: op.timestamp, - data: op.data.clone(), - }; - - write_crdt_op_to_db(&operation, db).await - } - }) - .collect::>() - .try_join() - .await - .map(|_| ()) - }) - .await -} - -async fn handle_crdt_deletion( - db: &PrismaClient, - instance: Uuid, - model: u16, - record_id: rmpv::Value, - delete_op: &CompressedCRDTOperation, -) -> Result<(), Error> { - // deletes are the be all and end all, no need to check anything - let op = CRDTOperation { - instance, - model, - record_id, - timestamp: delete_op.timestamp, - data: CRDTOperationData::Delete, - }; - - db._transaction() - .with_timeout(30 * 1000) - .run(|db| async move { - ModelSyncData::from_op(op.clone()) - .ok_or(Error::InvalidModelId(model))? - .exec(&db) - .await?; - - write_crdt_op_to_db(&op, &db).await - }) - .await -} - -impl Deref for Actor { - type Target = SharedState; - - fn deref(&self) -> &Self::Target { - &self.shared - } -} - -pub struct Handler { - pub event_tx: chan::Sender, - pub req_rx: chan::Receiver, -} - -#[derive(Debug)] -pub struct MessagesEvent { - pub instance_id: Uuid, - pub messages: CompressedCRDTOperations, - pub has_more: bool, - pub wait_tx: Option>, -} - -impl ActorTypes for Actor { - type Event = Event; - type Request = Request; - type Handler = Handler; -} - -#[cfg(test)] -mod test { - use std::{sync::atomic::AtomicBool, time::Duration}; - - use tokio::sync::Notify; - use uhlc::HLCBuilder; - - use super::*; - - async fn new_actor() -> (Handler, Arc) { - let instance = Uuid::new_v4(); - let shared = Arc::new(SharedState { - db: sd_prisma::test_db().await, - instance, - clock: HLCBuilder::new() - .with_id(uhlc::ID::from( - NonZeroU128::new(instance.to_u128_le()).expect("Non zero id"), - )) - .build(), - timestamps: Arc::default(), - emit_messages_flag: Arc::new(AtomicBool::new(true)), - active: AtomicBool::default(), - active_notify: Notify::default(), - actors: Arc::default(), - }); - - (Actor::declare(Arc::clone(&shared)).await, shared) - } - - /// If messages tx is dropped, actor should reset and assume no further messages - /// will be sent - #[tokio::test] - async fn messages_request_drop() -> Result<(), ()> { - let (ingest, _) = new_actor().await; - - for _ in 0..10 { - ingest.event_tx.send(Event::Notification).await.unwrap(); - - let Ok(Request::Messages { .. 
}) = ingest.req_rx.recv().await else { - panic!("bruh") - }; - - // without this the test hangs, idk - tokio::time::sleep(Duration::from_millis(0)).await; - } - - Ok(()) - } -} diff --git a/core/crates/sync/src/ingest_utils.rs b/core/crates/sync/src/ingest_utils.rs new file mode 100644 index 000000000..5f60ccfdf --- /dev/null +++ b/core/crates/sync/src/ingest_utils.rs @@ -0,0 +1,517 @@ +use sd_core_prisma_helpers::DevicePubId; + +use sd_prisma::{ + prisma::{crdt_operation, PrismaClient}, + prisma_sync::ModelSyncData, +}; +use sd_sync::{ + CRDTOperation, CRDTOperationData, CompressedCRDTOperation, ModelId, OperationKind, RecordId, +}; + +use std::{collections::BTreeMap, num::NonZeroU128, sync::Arc}; + +use futures_concurrency::future::TryJoin; +use tokio::sync::Mutex; +use tracing::{debug, instrument, trace, warn}; +use uhlc::{Timestamp, HLC, NTP64}; +use uuid::Uuid; + +use super::{db_operation::write_crdt_op_to_db, Error, TimestampPerDevice}; + +crdt_operation::select!(crdt_operation_id { id }); + +// where the magic happens +#[instrument(skip(clock, ops), fields(operations_count = %ops.len()), err)] +pub async fn process_crdt_operations( + clock: &HLC, + timestamp_per_device: &TimestampPerDevice, + sync_lock: Arc>, + db: &PrismaClient, + device_pub_id: DevicePubId, + model_id: ModelId, + (record_id, mut ops): (RecordId, Vec), +) -> Result<(), Error> { + ops.sort_by_key(|op| op.timestamp); + + let new_timestamp = ops.last().expect("Empty ops array").timestamp; + + update_clock(clock, new_timestamp, &device_pub_id); + + // Delete - ignores all other messages + if let Some(delete_op) = ops + .iter() + .rev() + .find(|op| matches!(op.data, CRDTOperationData::Delete)) + { + trace!("Deleting operation"); + handle_crdt_deletion( + db, + &sync_lock, + &device_pub_id, + model_id, + record_id, + delete_op, + ) + .await?; + } + // Create + > 0 Update - overwrites the create's data with the updates + else if let Some(timestamp) = ops + .iter() + .rev() + .find_map(|op| matches!(&op.data, CRDTOperationData::Create(_)).then_some(op.timestamp)) + { + trace!("Create + Updates operations"); + + // conflict resolution + let delete_count = db + .crdt_operation() + .count(vec![ + crdt_operation::model::equals(i32::from(model_id)), + crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), + crdt_operation::kind::equals(OperationKind::Delete.to_string()), + ]) + .exec() + .await?; + + if delete_count > 0 { + debug!("Found a previous delete operation with the same SyncId, will ignore these operations"); + return Ok(()); + } + + handle_crdt_create_and_updates( + db, + &sync_lock, + &device_pub_id, + model_id, + record_id, + ops, + timestamp, + ) + .await?; + } + // > 0 Update - batches updates with a fake Create op + else { + trace!("Updates operation"); + + let mut data = BTreeMap::new(); + + for op in ops.into_iter().rev() { + let CRDTOperationData::Update(fields_and_values) = op.data else { + unreachable!("Create + Delete should be filtered out!"); + }; + + for (field, value) in fields_and_values { + data.insert(field, (value, op.timestamp)); + } + } + + let earlier_time = data.values().fold( + NTP64(u64::from(u32::MAX)), + |earlier_time, (_, timestamp)| { + if timestamp.0 < earlier_time.0 { + *timestamp + } else { + earlier_time + } + }, + ); + + // conflict resolution + let (create, possible_newer_updates_count) = db + ._batch(( + db.crdt_operation().count(vec![ + crdt_operation::model::equals(i32::from(model_id)), + crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), + 
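
`update_clock` (defined near the end of the new file) is the HLC half of ingestion: before touching the database, the remote operation's timestamp is folded into the local hybrid logical clock, so any timestamp this device generates afterwards sorts after everything it has ingested. A self-contained sketch of that step against the `uhlc` 0.8 API pinned in the workspace:

```rust
use std::num::NonZeroU128;

use uhlc::{Timestamp, ID, NTP64};
use uuid::Uuid;

// Fold a remote device's HLC timestamp into the local clock. The uhlc ID is
// derived from the device UUID exactly as the diff does; the `expect`s mirror
// the diff's stance that a zeroed UUIDv7 and excessive drift cannot occur.
fn merge_remote_timestamp(clock: &uhlc::HLC, remote: NTP64, device: Uuid) {
    let id = ID::from(NonZeroU128::new(device.to_u128_le()).expect("non-zero device id"));
    clock
        .update_with_timestamp(&Timestamp::new(remote, id))
        .expect("timestamp has too much drift");
}
```
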
crdt_operation::kind::equals(OperationKind::Create.to_string()), + ]), + // Fetching all update operations newer than our current earlier timestamp + db.crdt_operation() + .find_many(vec![ + crdt_operation::timestamp::gt({ + #[allow(clippy::cast_possible_wrap)] + // SAFETY: we had to store using i64 due to SQLite limitations + { + earlier_time.as_u64() as i64 + } + }), + crdt_operation::model::equals(i32::from(model_id)), + crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), + crdt_operation::kind::starts_with("u".to_string()), + ]) + .select(crdt_operation::select!({ kind timestamp })), + )) + .await?; + + if create == 0 { + warn!("Failed to find a previous create operation with the same SyncId"); + return Ok(()); + } + + for candidate in possible_newer_updates_count { + // The first element is "u" meaning that this is an update, so we skip it + for key in candidate + .kind + .split(':') + .filter(|field| !field.is_empty()) + .skip(1) + { + // remove entries if we possess locally more recent updates for this field + if data.get(key).is_some_and(|(_, new_timestamp)| { + #[allow(clippy::cast_sign_loss)] + { + // we need to store as i64 due to SQLite limitations + *new_timestamp < NTP64(candidate.timestamp as u64) + } + }) { + data.remove(key); + } + } + + if data.is_empty() { + break; + } + } + + handle_crdt_updates(db, &sync_lock, &device_pub_id, model_id, record_id, data).await?; + } + + update_timestamp_per_device(timestamp_per_device, device_pub_id, new_timestamp).await; + + Ok(()) +} + +pub async fn bulk_ingest_create_only_ops( + clock: &HLC, + timestamp_per_device: &TimestampPerDevice, + db: &PrismaClient, + device_pub_id: DevicePubId, + model_id: ModelId, + ops: Vec<(RecordId, CompressedCRDTOperation)>, + sync_lock: Arc>, +) -> Result<(), Error> { + let latest_timestamp = ops.iter().fold(NTP64(0), |latest, (_, op)| { + if latest < op.timestamp { + op.timestamp + } else { + latest + } + }); + + update_clock(clock, latest_timestamp, &device_pub_id); + + let ops = ops + .into_iter() + .map(|(record_id, op)| { + rmp_serde::to_vec(&record_id) + .map(|serialized_record_id| (record_id, serialized_record_id, op)) + }) + .collect::, _>>()?; + + // conflict resolution + let delete_counts = db + ._batch( + ops.iter() + .map(|(_, serialized_record_id, _)| { + db.crdt_operation().count(vec![ + crdt_operation::model::equals(i32::from(model_id)), + crdt_operation::record_id::equals(serialized_record_id.clone()), + crdt_operation::kind::equals(OperationKind::Delete.to_string()), + ]) + }) + .collect::>(), + ) + .await?; + + let lock_guard = sync_lock.lock().await; + + db._transaction() + .with_timeout(30 * 10000) + .with_max_wait(30 * 10000) + .run(|db| { + let device_pub_id = device_pub_id.clone(); + + async move { + // complying with borrowck + let device_pub_id = &device_pub_id; + + let (crdt_creates, model_sync_data) = ops + .into_iter() + .zip(delete_counts) + .filter_map(|(data, delete_count)| (delete_count == 0).then_some(data)) + .map( + |( + record_id, + serialized_record_id, + CompressedCRDTOperation { timestamp, data }, + )| { + let crdt_create = crdt_operation::CreateUnchecked { + timestamp: { + #[allow(clippy::cast_possible_wrap)] + // SAFETY: we have to store using i64 due to SQLite limitations + { + timestamp.0 as i64 + } + }, + model: i32::from(model_id), + record_id: serialized_record_id, + kind: "c".to_string(), + data: rmp_serde::to_vec(&data)?, + device_pub_id: device_pub_id.to_db(), + _params: vec![], + }; + + // NOTE(@fogodev): I wish I could do a create many 
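
The conflict-resolution loop above recovers which fields a stored update touched by parsing the operation's `kind` string: a leading "u" tag followed by colon-separated field names. The exact encoding lives in `sd-sync`'s `OperationKind`; this sketch is merely the inverse of how the loop parses it:

```rust
// Hypothetical encoder for the "u:<field>:<field>..." kind strings the loop
// above splits apart; the real encoder is OperationKind's Display in sd-sync.
fn update_kind(fields: &[&str]) -> String {
    let mut kind = String::from("u");
    for field in fields {
        kind.push(':');
        kind.push_str(field);
    }
    kind
}

// Identical to the parsing in the diff: split on ':', drop empties, skip "u".
fn fields_of(kind: &str) -> impl Iterator<Item = &str> {
    kind.split(':').filter(|part| !part.is_empty()).skip(1)
}

fn main() {
    let kind = update_kind(&["name", "path"]);
    assert_eq!(kind, "u:name:path");
    assert_eq!(fields_of(&kind).collect::<Vec<_>>(), ["name", "path"]);
}
```
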
here instead of creating separately each + // entry, but it's not supported by PCR + let model_sync_data = ModelSyncData::from_op(CRDTOperation { + device_pub_id: Uuid::from(device_pub_id), + model_id, + record_id, + timestamp, + data, + })? + .exec(&db); + + Ok::<_, Error>((crdt_create, model_sync_data)) + }, + ) + .collect::, _>>()? + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + model_sync_data.try_join().await?; + + db.crdt_operation().create_many(crdt_creates).exec().await?; + + Ok::<_, Error>(()) + } + }) + .await?; + + drop(lock_guard); + + update_timestamp_per_device(timestamp_per_device, device_pub_id, latest_timestamp).await; + + Ok(()) +} + +#[instrument(skip_all, err)] +async fn handle_crdt_updates( + db: &PrismaClient, + sync_lock: &Mutex<()>, + device_pub_id: &DevicePubId, + model_id: ModelId, + record_id: rmpv::Value, + data: BTreeMap, +) -> Result<(), Error> { + let device_pub_id = sd_sync::DevicePubId::from(device_pub_id); + + let _lock_guard = sync_lock.lock().await; + + db._transaction() + .with_timeout(30 * 10000) + .with_max_wait(30 * 10000) + .run(|db| async move { + // fake operation to batch them all at once, inserting the latest data on appropriate table + ModelSyncData::from_op(CRDTOperation { + device_pub_id, + model_id, + record_id: record_id.clone(), + timestamp: NTP64(0), + data: CRDTOperationData::Create( + data.iter() + .map(|(k, (data, _))| (k.clone(), data.clone())) + .collect(), + ), + })? + .exec(&db) + .await?; + + let (fields_and_values, latest_timestamp) = data.into_iter().fold( + (BTreeMap::new(), NTP64::default()), + |(mut fields_and_values, mut latest_time_stamp), (field, (value, timestamp))| { + fields_and_values.insert(field, value); + if timestamp > latest_time_stamp { + latest_time_stamp = timestamp; + } + (fields_and_values, latest_time_stamp) + }, + ); + + write_crdt_op_to_db( + &CRDTOperation { + device_pub_id, + model_id, + record_id, + timestamp: latest_timestamp, + data: CRDTOperationData::Update(fields_and_values), + }, + &db, + ) + .await + }) + .await +} + +#[instrument(skip_all, err)] +async fn handle_crdt_create_and_updates( + db: &PrismaClient, + sync_lock: &Mutex<()>, + device_pub_id: &DevicePubId, + model_id: ModelId, + record_id: rmpv::Value, + ops: Vec, + timestamp: NTP64, +) -> Result<(), Error> { + let mut data = BTreeMap::new(); + let device_pub_id = sd_sync::DevicePubId::from(device_pub_id); + + let mut applied_ops = vec![]; + + // search for all Updates until a Create is found + for op in ops.into_iter().rev() { + match &op.data { + CRDTOperationData::Delete => unreachable!("Delete can't exist here!"), + CRDTOperationData::Create(create_data) => { + for (k, v) in create_data { + data.entry(k.clone()).or_insert_with(|| v.clone()); + } + + applied_ops.push(op); + + break; + } + CRDTOperationData::Update(fields_and_values) => { + for (field, value) in fields_and_values { + data.insert(field.clone(), value.clone()); + } + + applied_ops.push(op); + } + } + } + + let _lock_guard = sync_lock.lock().await; + + db._transaction() + .with_timeout(30 * 10000) + .with_max_wait(30 * 10000) + .run(|db| async move { + // fake a create with a bunch of data rather than individual insert + ModelSyncData::from_op(CRDTOperation { + device_pub_id, + model_id, + record_id: record_id.clone(), + timestamp, + data: CRDTOperationData::Create(data), + })? 
+ .exec(&db) + .await?; + + applied_ops + .into_iter() + .map(|CompressedCRDTOperation { timestamp, data }| { + let record_id = record_id.clone(); + let db = &db; + async move { + write_crdt_op_to_db( + &CRDTOperation { + device_pub_id, + timestamp, + model_id, + record_id, + data, + }, + db, + ) + .await + } + }) + .collect::>() + .try_join() + .await + .map(|_| ()) + }) + .await +} + +#[instrument(skip_all, err)] +async fn handle_crdt_deletion( + db: &PrismaClient, + sync_lock: &Mutex<()>, + device_pub_id: &DevicePubId, + model: u16, + record_id: rmpv::Value, + delete_op: &CompressedCRDTOperation, +) -> Result<(), Error> { + // deletes are the be all and end all, except if we never created the object to begin with + // in this case we don't need to delete anything + + if db + .crdt_operation() + .count(vec![ + crdt_operation::model::equals(i32::from(model)), + crdt_operation::record_id::equals(rmp_serde::to_vec(&record_id)?), + ]) + .exec() + .await? + == 0 + { + // This means that in the other device this entry was created and deleted, before this + // device here could even take notice of it. So we don't need to do anything here. + return Ok(()); + } + + let op = CRDTOperation { + device_pub_id: device_pub_id.into(), + model_id: model, + record_id, + timestamp: delete_op.timestamp, + data: CRDTOperationData::Delete, + }; + + let _lock_guard = sync_lock.lock().await; + + db._transaction() + .with_timeout(30 * 10000) + .with_max_wait(30 * 10000) + .run(|db| async move { + ModelSyncData::from_op(op.clone())?.exec(&db).await?; + + write_crdt_op_to_db(&op, &db).await + }) + .await +} + +fn update_clock(clock: &HLC, latest_timestamp: NTP64, device_pub_id: &DevicePubId) { + // first, we update the HLC's timestamp with the incoming one. + // this involves a drift check + sets the last time of the clock + clock + .update_with_timestamp(&Timestamp::new( + latest_timestamp, + uhlc::ID::from( + NonZeroU128::new(Uuid::from(device_pub_id).to_u128_le()).expect("Non zero id"), + ), + )) + .expect("timestamp has too much drift!"); +} + +async fn update_timestamp_per_device( + timestamp_per_device: &TimestampPerDevice, + device_pub_id: DevicePubId, + latest_timestamp: NTP64, +) { + // read the timestamp for the operation's device, or insert one if it doesn't exist + let current_last_timestamp = timestamp_per_device + .read() + .await + .get(&device_pub_id) + .copied(); + + // update the stored timestamp for this device - will be derived from the crdt operations table on restart + let new_ts = NTP64::max(current_last_timestamp.unwrap_or_default(), latest_timestamp); + + timestamp_per_device + .write() + .await + .insert(device_pub_id, new_ts); +} diff --git a/core/crates/sync/src/lib.rs b/core/crates/sync/src/lib.rs index d5c208668..5b8d90efe 100644 --- a/core/crates/sync/src/lib.rs +++ b/core/crates/sync/src/lib.rs @@ -27,45 +27,39 @@ #![forbid(deprecated_in_future)] #![allow(clippy::missing_errors_doc, clippy::module_name_repetitions)] -use sd_prisma::prisma::{crdt_operation, instance, PrismaClient}; -use sd_sync::CRDTOperation; - -use std::{ - collections::HashMap, - sync::{atomic::AtomicBool, Arc}, +use sd_prisma::{ + prisma::{cloud_crdt_operation, crdt_operation}, + prisma_sync, }; +use sd_utils::uuid_to_bytes; -use tokio::sync::{Notify, RwLock}; -use uuid::Uuid; +use std::{collections::HashMap, sync::Arc}; + +use tokio::{sync::RwLock, task::JoinError}; -mod actor; pub mod backfill; mod db_operation; -pub mod ingest; +mod ingest_utils; mod manager; -pub use ingest::*; -pub use manager::*; +pub use 
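
`update_timestamp_per_device` above is the bookkeeping tail of every ingestion path: keep, per remote device, the highest timestamp seen so far, so restarts and future fetches know where to resume. The same shape with the project types narrowed to `Uuid` and `u64`:

```rust
use std::{collections::HashMap, sync::Arc};

use tokio::sync::RwLock;
use uuid::Uuid;

type TimestampPerDevice = Arc<RwLock<HashMap<Uuid, u64>>>;

// Read the current watermark for this device, then store the max of the
// stored and incoming timestamps, mirroring the diff's read-then-write steps.
async fn bump(map: &TimestampPerDevice, device: Uuid, incoming: u64) {
    let current = map.read().await.get(&device).copied().unwrap_or_default();
    map.write().await.insert(device, current.max(incoming));
}
```

As in the diff, the read and the write are separate lock acquisitions, so readers are never blocked longer than a single `insert`.
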
db_operation::{from_cloud_crdt_ops, from_crdt_ops, write_crdt_op_to_db}; +pub use manager::Manager as SyncManager; pub use uhlc::NTP64; #[derive(Clone, Debug)] -pub enum SyncMessage { +pub enum SyncEvent { Ingested, Created, } -pub type Timestamps = Arc>>; +pub use sd_core_prisma_helpers::DevicePubId; +pub use sd_sync::{ + CRDTOperation, CompressedCRDTOperation, CompressedCRDTOperationsPerModel, + CompressedCRDTOperationsPerModelPerDevice, ModelId, OperationFactory, RecordId, RelationSyncId, + RelationSyncModel, SharedSyncModel, SyncId, SyncModel, +}; -pub struct SharedState { - pub db: Arc, - pub emit_messages_flag: Arc, - pub instance: Uuid, - pub timestamps: Timestamps, - pub clock: uhlc::HLC, - pub active: AtomicBool, - pub active_notify: Notify, - pub actors: Arc, -} +pub type TimestampPerDevice = Arc>>; #[derive(thiserror::Error, Debug)] pub enum Error { @@ -75,8 +69,16 @@ pub enum Error { Deserialization(#[from] rmp_serde::decode::Error), #[error("database error: {0}")] Database(#[from] prisma_client_rust::QueryError), + #[error("PrismaSync error: {0}")] + PrismaSync(#[from] prisma_sync::Error), #[error("invalid model id: {0}")] - InvalidModelId(u16), + InvalidModelId(ModelId), + #[error("tried to write an empty operations list")] + EmptyOperations, + #[error("device not found: {0}")] + DeviceNotFound(DevicePubId), + #[error("processes crdt task panicked")] + ProcessCrdtPanic(JoinError), } impl From for rspc::Error { @@ -105,19 +107,16 @@ pub fn crdt_op_db(op: &CRDTOperation) -> Result { op.timestamp.as_u64() as i64 } }, - instance: instance::pub_id::equals(op.instance.as_bytes().to_vec()), + device_pub_id: uuid_to_bytes(&op.device_pub_id), kind: op.kind().to_string(), data: rmp_serde::to_vec(&op.data)?, - model: i32::from(op.model), + model: i32::from(op.model_id), record_id: rmp_serde::to_vec(&op.record_id)?, _params: vec![], }) } -pub fn crdt_op_unchecked_db( - op: &CRDTOperation, - instance_id: i32, -) -> Result { +pub fn crdt_op_unchecked_db(op: &CRDTOperation) -> Result { Ok(crdt_operation::CreateUnchecked { timestamp: { #[allow(clippy::cast_possible_wrap)] @@ -126,10 +125,28 @@ pub fn crdt_op_unchecked_db( op.timestamp.as_u64() as i64 } }, - instance_id, + device_pub_id: uuid_to_bytes(&op.device_pub_id), kind: op.kind().to_string(), data: rmp_serde::to_vec(&op.data)?, - model: i32::from(op.model), + model: i32::from(op.model_id), + record_id: rmp_serde::to_vec(&op.record_id)?, + _params: vec![], + }) +} + +pub fn cloud_crdt_op_db(op: &CRDTOperation) -> Result { + Ok(cloud_crdt_operation::Create { + timestamp: { + #[allow(clippy::cast_possible_wrap)] + // SAFETY: we had to store using i64 due to SQLite limitations + { + op.timestamp.as_u64() as i64 + } + }, + device_pub_id: uuid_to_bytes(&op.device_pub_id), + kind: op.data.as_kind().to_string(), + data: rmp_serde::to_vec(&op.data)?, + model: i32::from(op.model_id), record_id: rmp_serde::to_vec(&op.record_id)?, _params: vec![], }) diff --git a/core/crates/sync/src/manager.rs b/core/crates/sync/src/manager.rs index d7f9562d9..47460afac 100644 --- a/core/crates/sync/src/manager.rs +++ b/core/crates/sync/src/manager.rs @@ -1,35 +1,60 @@ -use sd_prisma::prisma::{cloud_crdt_operation, crdt_operation, instance, PrismaClient, SortOrder}; -use sd_sync::{CRDTOperation, OperationFactory}; -use sd_utils::{from_bytes_to_uuid, uuid_to_bytes}; -use tracing::warn; +use sd_core_prisma_helpers::DevicePubId; + +use sd_prisma::{ + prisma::{cloud_crdt_operation, crdt_operation, device, PrismaClient, SortOrder}, + prisma_sync, +}; +use sd_sync::{ + 
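
All three row constructors in the new `lib.rs` (`crdt_op_db`, `crdt_op_unchecked_db`, `cloud_crdt_op_db`) persist `record_id` and `data` as msgpack blobs through `rmp-serde`, and `from_crdt_ops` reverses that with `from_slice`. A tiny round trip showing the encode/decode pair (assumes `serde` with the `derive` feature plus `rmp-serde` as dependencies, matching the workspace):

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Payload {
    field: String,
    value: i64,
}

fn main() {
    let payload = Payload { field: "name".into(), value: 42 };

    // Write side, as in `crdt_op_db`: serialize to a msgpack byte blob.
    let bytes = rmp_serde::to_vec(&payload).expect("encode");

    // Read side, as in `from_crdt_ops`: deserialize straight from the blob.
    let decoded: Payload = rmp_serde::from_slice(&bytes).expect("decode");
    assert_eq!(decoded, payload);
}
```
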
CRDTOperation, CRDTOperationData, CompressedCRDTOperation, ModelId, OperationFactory, RecordId,
+};
+use sd_utils::timestamp_to_datetime;

 use std::{
-	cmp, fmt,
+	collections::{hash_map::Entry, BTreeMap, HashMap},
+	fmt, mem,
 	num::NonZeroU128,
-	ops::Deref,
 	sync::{
 		atomic::{self, AtomicBool},
 		Arc,
 	},
+	time::{Duration, SystemTime},
 };

-use prisma_client_rust::{and, operator::or};
-use tokio::sync::{broadcast, Mutex, Notify, RwLock};
+use async_stream::stream;
+use futures::{stream::FuturesUnordered, Stream, TryStreamExt};
+use futures_concurrency::future::TryJoin;
+use itertools::Itertools;
+use tokio::{
+	spawn,
+	sync::{broadcast, Mutex, Notify, RwLock},
+	time::Instant,
+};
+use tracing::{debug, instrument, warn};
 use uhlc::{HLCBuilder, HLC};
 use uuid::Uuid;

 use super::{
 	crdt_op_db,
-	db_operation::{cloud_crdt_with_instance, crdt_with_instance},
-	ingest, Error, SharedState, SyncMessage, NTP64,
+	db_operation::{from_cloud_crdt_ops, from_crdt_ops},
+	ingest_utils::{bulk_ingest_create_only_ops, process_crdt_operations},
+	Error, SyncEvent, TimestampPerDevice, NTP64,
 };

+const INGESTION_BATCH_SIZE: i64 = 10_000;
+
 /// Wrapper that spawns the ingest actor and provides utilities for reading and writing sync operations.
+#[derive(Clone)]
 pub struct Manager {
-	pub tx: broadcast::Sender<SyncMessage>,
-	pub ingest: ingest::Handler,
-	pub shared: Arc<SharedState>,
-	pub timestamp_lock: Mutex<()>,
+	pub tx: broadcast::Sender<SyncEvent>,
+	pub db: Arc<PrismaClient>,
+	pub emit_messages_flag: Arc<AtomicBool>,
+	pub device_pub_id: DevicePubId,
+	pub timestamp_per_device: TimestampPerDevice,
+	pub clock: Arc<HLC>,
+	pub active: Arc<AtomicBool>,
+	pub active_notify: Arc<Notify>,
+	pub(crate) sync_lock: Arc<Mutex<()>>,
+	pub(crate) available_parallelism: usize,
 }

 impl fmt::Debug for Manager {
@@ -38,29 +63,21 @@ impl fmt::Debug for Manager {
 	}
 }

-#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq)]
-pub struct GetOpsArgs {
-	pub clocks: Vec<(Uuid, NTP64)>,
-	pub count: u32,
-}
-
 impl Manager {
 	/// Creates a new manager that can be used to read and write CRDT operations.
 	/// Sync messages are received on the returned [`broadcast::Receiver`].
 	pub async fn new(
 		db: Arc<PrismaClient>,
-		current_instance_uuid: Uuid,
+		current_device_pub_id: &DevicePubId,
 		emit_messages_flag: Arc<AtomicBool>,
-		actors: Arc<Actors>,
-	) -> Result<(Self, broadcast::Receiver<SyncMessage>), Error> {
-		let existing_instances = db.instance().find_many(vec![]).exec().await?;
+	) -> Result<(Self, broadcast::Receiver<SyncEvent>), Error> {
+		let existing_devices = db.device().find_many(vec![]).exec().await?;

-		Self::with_existing_instances(
+		Self::with_existing_devices(
 			db,
-			current_instance_uuid,
+			current_device_pub_id,
 			emit_messages_flag,
-			&existing_instances,
-			actors,
+			&existing_devices,
 		)
 		.await
 	}
@@ -69,33 +86,34 @@ impl Manager {
 	/// Sync messages are received on the returned [`broadcast::Receiver`].
 	///
 	/// # Panics
-	/// Panics if the `current_instance_id` UUID is zeroed.
-	pub async fn with_existing_instances(
+	/// Panics if the `current_device_pub_id` UUID is zeroed, which will never happen in practice: we use `UUIDv7` for the
+	/// device pub id. This version has a timestamp part instead of being totally random.
So the only + /// possible way to get zero from a `UUIDv7` is to go back in time to 1970 + pub async fn with_existing_devices( db: Arc, - current_instance_uuid: Uuid, + current_device_pub_id: &DevicePubId, emit_messages_flag: Arc, - existing_instances: &[instance::Data], - actors: Arc, - ) -> Result<(Self, broadcast::Receiver), Error> { - let timestamps = db + existing_devices: &[device::Data], + ) -> Result<(Self, broadcast::Receiver), Error> { + let latest_timestamp_per_device = db ._batch( - existing_instances + existing_devices .iter() - .map(|i| { + .map(|device| { db.crdt_operation() - .find_first(vec![crdt_operation::instance::is(vec![ - instance::id::equals(i.id), - ])]) + .find_first(vec![crdt_operation::device_pub_id::equals( + device.pub_id.clone(), + )]) .order_by(crdt_operation::timestamp::order(SortOrder::Desc)) }) .collect::>(), ) .await? .into_iter() - .zip(existing_instances) - .map(|(op, i)| { + .zip(existing_devices) + .map(|(op, device)| { ( - from_bytes_to_uuid(&i.pub_id), + DevicePubId::from(&device.pub_id), #[allow(clippy::cast_sign_loss)] // SAFETY: we had to store using i64 due to SQLite limitations NTP64(op.map(|o| o.timestamp).unwrap_or_default() as u64), @@ -105,54 +123,303 @@ impl Manager { let (tx, rx) = broadcast::channel(64); - let clock = HLCBuilder::new() - .with_id(uhlc::ID::from( - NonZeroU128::new(current_instance_uuid.to_u128_le()).expect("Non zero id"), - )) - .build(); - - let shared = Arc::new(SharedState { - db, - instance: current_instance_uuid, - clock, - timestamps: Arc::new(RwLock::new(timestamps)), - emit_messages_flag, - active: AtomicBool::default(), - active_notify: Notify::default(), - actors, - }); - - let ingest = ingest::Actor::declare(shared.clone()).await; - Ok(( Self { tx, - ingest, - shared, - timestamp_lock: Mutex::default(), + db, + device_pub_id: current_device_pub_id.clone(), + clock: Arc::new( + HLCBuilder::new() + .with_id(uhlc::ID::from( + NonZeroU128::new(Uuid::from(current_device_pub_id).to_u128_le()) + .expect("Non zero id"), + )) + .build(), + ), + timestamp_per_device: Arc::new(RwLock::new(latest_timestamp_per_device)), + emit_messages_flag, + active: Arc::default(), + active_notify: Arc::default(), + sync_lock: Arc::new(Mutex::default()), + available_parallelism: std::thread::available_parallelism() + .map_or(1, std::num::NonZero::get), }, rx, )) } - pub fn subscribe(&self) -> broadcast::Receiver { + async fn fetch_cloud_crdt_ops( + &self, + model_id: ModelId, + batch_size: i64, + ) -> Result<(Vec, Vec), Error> { + self.db + .cloud_crdt_operation() + .find_many(vec![cloud_crdt_operation::model::equals(i32::from( + model_id, + ))]) + .take(batch_size) + .order_by(cloud_crdt_operation::timestamp::order(SortOrder::Asc)) + .exec() + .await? 
+		.into_iter()
+		.map(from_cloud_crdt_ops)
+		.collect::<Result<(Vec<_>, Vec<_>), _>>()
+	}
+
+	#[instrument(skip(self))]
+	async fn ingest_by_model(&self, model_id: ModelId) -> Result<usize, Error> {
+		let mut total_count = 0;
+
+		let mut buckets = (0..self.available_parallelism)
+			.map(|_| FuturesUnordered::new())
+			.collect::<Vec<_>>();
+
+		let mut total_fetch_time = Duration::ZERO;
+		let mut total_compression_time = Duration::ZERO;
+		let mut total_work_distribution_time = Duration::ZERO;
+		let mut total_process_time = Duration::ZERO;
+
+		loop {
+			let fetching_start = Instant::now();
+
+			let (ops_ids, ops) = self
+				.fetch_cloud_crdt_ops(model_id, INGESTION_BATCH_SIZE)
+				.await?;
+			if ops_ids.is_empty() {
+				break;
+			}
+
+			total_fetch_time += fetching_start.elapsed();
+
+			let messages_count = ops.len();
+
+			debug!(
+				messages_count,
+				first_message = ?ops
+					.first()
+					.map_or_else(|| SystemTime::UNIX_EPOCH.into(), |op| timestamp_to_datetime(op.timestamp)),
+				last_message = ?ops
+					.last()
+					.map_or_else(|| SystemTime::UNIX_EPOCH.into(), |op| timestamp_to_datetime(op.timestamp)),
+				"Messages by model to ingest",
+			);
+
+			let compression_start = Instant::now();
+
+			let mut compressed_map =
+				BTreeMap::<Uuid, HashMap<Vec<u8>, (RecordId, Vec<CompressedCRDTOperation>)>>::new();
+
+			for CRDTOperation {
+				device_pub_id,
+				timestamp,
+				model_id: _, // Ignoring model_id as we know it already
+				record_id,
+				data,
+			} in ops
+			{
+				let records = compressed_map.entry(device_pub_id).or_default();
+
+				// Can't use RecordId as a key because rmpv::Value doesn't implement Hash + Eq.
+				// So we use its serialized bytes as a key.
+				let record_id_bytes =
+					rmp_serde::to_vec_named(&record_id).expect("already serialized to Value");
+
+				match records.entry(record_id_bytes) {
+					Entry::Occupied(mut entry) => {
+						entry
+							.get_mut()
+							.1
+							.push(CompressedCRDTOperation { timestamp, data });
+					}
+					Entry::Vacant(entry) => {
+						entry
+							.insert((record_id, vec![CompressedCRDTOperation { timestamp, data }]));
+					}
+				}
+			}
+
+			// Now that we separated all operations by their record_ids, we can do an optimization
+			// to process all records that only possess a single create operation, batching them together
+			let mut create_only_ops: BTreeMap<Uuid, Vec<(RecordId, CompressedCRDTOperation)>> =
+				BTreeMap::new();
+			for (device_pub_id, records) in &mut compressed_map {
+				for (record_id, ops) in records.values_mut() {
+					if ops.len() == 1 && matches!(ops[0].data, CRDTOperationData::Create(_)) {
+						create_only_ops
+							.entry(*device_pub_id)
+							.or_default()
+							.push((mem::replace(record_id, rmpv::Value::Nil), ops.remove(0)));
+					}
+				}
+			}
+
+			total_count += bulk_process_of_create_only_ops(
+				self.available_parallelism,
+				Arc::clone(&self.clock),
+				Arc::clone(&self.timestamp_per_device),
+				Arc::clone(&self.db),
+				Arc::clone(&self.sync_lock),
+				model_id,
+				create_only_ops,
+			)
+			.await?;
+
+			total_compression_time += compression_start.elapsed();
+
+			let work_distribution_start = Instant::now();
+
+			compressed_map
+				.into_iter()
+				.flat_map(|(device_pub_id, records)| {
+					records.into_values().filter_map(move |(record_id, ops)| {
+						if record_id.is_nil() {
+							return None;
+						}
+
+						// We can process each record in parallel as they are independent
+
+						let clock = Arc::clone(&self.clock);
+						let timestamp_per_device = Arc::clone(&self.timestamp_per_device);
+						let db = Arc::clone(&self.db);
+						let device_pub_id = device_pub_id.into();
+						let sync_lock = Arc::clone(&self.sync_lock);
+
+						Some(async move {
+							let count = ops.len();
+
+							process_crdt_operations(
+								&clock,
+								&timestamp_per_device,
+								sync_lock,
+								&db,
+								device_pub_id,
+								model_id,
+								(record_id, ops),
+							)
+							.await
+							.map(|()|
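
The grouping step in `ingest_by_model` above cannot key a map by `RecordId` directly, because `rmpv::Value` implements neither `Hash` nor `Eq`; the diff keys by the id's msgpack bytes and keeps the original value alongside the accumulated ops. The same trick in isolation:

```rust
use std::collections::BTreeMap;

use serde::Serialize;

// Group (record_id, op) pairs by the record id's msgpack encoding, keeping
// the original id next to the accumulated ops, as `compressed_map` does.
fn group_by_record<R: Serialize, O>(
    ops: impl IntoIterator<Item = (R, O)>,
) -> Result<BTreeMap<Vec<u8>, (R, Vec<O>)>, rmp_serde::encode::Error> {
    let mut map: BTreeMap<Vec<u8>, (R, Vec<O>)> = BTreeMap::new();
    for (record_id, op) in ops {
        let key = rmp_serde::to_vec_named(&record_id)?;
        map.entry(key).or_insert_with(|| (record_id, Vec::new())).1.push(op);
    }
    Ok(map)
}
```

The `entry` call collapses the `Occupied`/`Vacant` match written out explicitly in the diff.
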
count) + }) + }) + }) + .enumerate() + .for_each(|(idx, fut)| buckets[idx % self.available_parallelism].push(fut)); + + total_work_distribution_time += work_distribution_start.elapsed(); + + let processing_start = Instant::now(); + + let handles = buckets + .iter_mut() + .enumerate() + .filter(|(_idx, bucket)| !bucket.is_empty()) + .map(|(idx, bucket)| { + let mut bucket = mem::take(bucket); + + spawn(async move { + let mut ops_count = 0; + let processing_start = Instant::now(); + while let Some(count) = bucket.try_next().await? { + ops_count += count; + } + + debug!( + "Ingested {ops_count} operations in {:?}", + processing_start.elapsed() + ); + + Ok::<_, Error>((ops_count, idx, bucket)) + }) + }) + .collect::>(); + + let results = handles.try_join().await.map_err(Error::ProcessCrdtPanic)?; + + total_process_time += processing_start.elapsed(); + + for res in results { + let (count, idx, bucket) = res?; + + buckets[idx] = bucket; + + total_count += count; + } + + self.db + .cloud_crdt_operation() + .delete_many(vec![cloud_crdt_operation::id::in_vec(ops_ids)]) + .exec() + .await?; + } + + debug!( + total_count, + ?total_fetch_time, + ?total_compression_time, + ?total_work_distribution_time, + ?total_process_time, + "Ingested all operations of this model" + ); + + Ok(total_count) + } + + pub async fn ingest_ops(&self) -> Result { + let mut total_count = 0; + + // WARN: this order here exists because sync messages MUST be processed in this exact order + // due to relationship dependencies between these tables. + total_count += self.ingest_by_model(prisma_sync::device::MODEL_ID).await?; + + total_count += [ + self.ingest_by_model(prisma_sync::storage_statistics::MODEL_ID), + self.ingest_by_model(prisma_sync::tag::MODEL_ID), + self.ingest_by_model(prisma_sync::location::MODEL_ID), + self.ingest_by_model(prisma_sync::object::MODEL_ID), + self.ingest_by_model(prisma_sync::label::MODEL_ID), + ] + .try_join() + .await? + .into_iter() + .sum::(); + + total_count += [ + self.ingest_by_model(prisma_sync::exif_data::MODEL_ID), + self.ingest_by_model(prisma_sync::file_path::MODEL_ID), + self.ingest_by_model(prisma_sync::tag_on_object::MODEL_ID), + self.ingest_by_model(prisma_sync::label_on_object::MODEL_ID), + ] + .try_join() + .await? 
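
The bucket dance above (and again in `bulk_process_of_create_only_ops` further down) is a round-robin sharding of independent per-record futures across `available_parallelism` `FuturesUnordered` queues, each drained on its own spawned task. Stripped of the database work, the pattern looks like this:

```rust
use futures::{stream::FuturesUnordered, StreamExt};

// Shard independent tasks across `parallelism` buckets, drain each bucket on
// its own tokio task, and sum the results. Tasks here just return counts.
async fn process_in_buckets(parallelism: usize, counts: Vec<usize>) -> usize {
    let parallelism = parallelism.max(1); // guard the modulo below
    let mut buckets: Vec<FuturesUnordered<_>> =
        (0..parallelism).map(|_| FuturesUnordered::new()).collect();

    // Round-robin distribution, as in `ingest_by_model`.
    for (idx, count) in counts.into_iter().enumerate() {
        buckets[idx % parallelism].push(async move { count });
    }

    let handles: Vec<_> = buckets
        .into_iter()
        .map(|mut bucket| {
            tokio::spawn(async move {
                let mut total = 0;
                while let Some(count) = bucket.next().await {
                    total += count;
                }
                total
            })
        })
        .collect();

    let mut total = 0;
    for handle in handles {
        total += handle.await.expect("bucket task panicked");
    }
    total
}
```

A `FuturesUnordered` polls its futures concurrently on one task, so this yields two levels of concurrency: across buckets via `spawn`, and within a bucket via the unordered queue.
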
+ .into_iter() + .sum::(); + + if self.tx.send(SyncEvent::Ingested).is_err() { + warn!("failed to send ingested message on `ingest_ops`"); + } + + Ok(total_count) + } + + #[must_use] + pub fn subscribe(&self) -> broadcast::Receiver { self.tx.subscribe() } pub async fn write_ops<'item, Q>( &self, tx: &PrismaClient, - (mut ops, queries): (Vec, Q), + (ops, queries): (Vec, Q), ) -> Result where Q: prisma_client_rust::BatchItem<'item, ReturnValue: Send> + Send, { - let ret = if self.emit_messages_flag.load(atomic::Ordering::Relaxed) { - let lock = self.timestamp_lock.lock().await; + if ops.is_empty() { + return Err(Error::EmptyOperations); + } - for op in &mut ops { - op.timestamp = *self.get_clock().new_timestamp().get_time(); - } + let ret = if self.emit_messages_flag.load(atomic::Ordering::Relaxed) { + let lock_guard = self.sync_lock.lock().await; let (res, _) = tx ._batch(( @@ -164,18 +431,17 @@ impl Manager { .await?; if let Some(last) = ops.last() { - self.shared - .timestamps + self.timestamp_per_device .write() .await - .insert(self.instance, last.timestamp); + .insert(self.device_pub_id.clone(), last.timestamp); } - if self.tx.send(SyncMessage::Created).is_err() { + if self.tx.send(SyncEvent::Created).is_err() { warn!("failed to send created message on `write_ops`"); } - drop(lock); + drop(lock_guard); res } else { @@ -188,160 +454,289 @@ impl Manager { pub async fn write_op<'item, Q>( &self, tx: &PrismaClient, - mut op: CRDTOperation, + op: CRDTOperation, query: Q, ) -> Result where Q: prisma_client_rust::BatchItem<'item, ReturnValue: Send> + Send, { let ret = if self.emit_messages_flag.load(atomic::Ordering::Relaxed) { - let lock = self.timestamp_lock.lock().await; - - op.timestamp = *self.get_clock().new_timestamp().get_time(); + let lock_guard = self.sync_lock.lock().await; let ret = tx._batch((crdt_op_db(&op)?.to_query(tx), query)).await?.1; - if self.tx.send(SyncMessage::Created).is_err() { + if self.tx.send(SyncEvent::Created).is_err() { warn!("failed to send created message on `write_op`"); } - drop(lock); + drop(lock_guard); ret } else { tx._batch(vec![query]).await?.remove(0) }; - self.shared - .timestamps + self.timestamp_per_device .write() .await - .insert(self.instance, op.timestamp); + .insert(self.device_pub_id.clone(), op.timestamp); Ok(ret) } - pub async fn get_instance_ops( - &self, - count: u32, - instance_uuid: Uuid, - timestamp: NTP64, - ) -> Result, Error> { - self.db - .crdt_operation() - .find_many(vec![ - crdt_operation::instance::is(vec![instance::pub_id::equals(uuid_to_bytes( - &instance_uuid, - ))]), - #[allow(clippy::cast_possible_wrap)] - crdt_operation::timestamp::gt(timestamp.as_u64() as i64), - ]) - .take(i64::from(count)) - .order_by(crdt_operation::timestamp::order(SortOrder::Asc)) - .include(crdt_with_instance::include()) - .exec() - .await? + // pub async fn get_device_ops( + // &self, + // count: u32, + // device_pub_id: DevicePubId, + // timestamp: NTP64, + // ) -> Result, Error> { + // self.db + // .crdt_operation() + // .find_many(vec![ + // crdt_operation::device_pub_id::equals(device_pub_id.into()), + // #[allow(clippy::cast_possible_wrap)] + // crdt_operation::timestamp::gt(timestamp.as_u64() as i64), + // ]) + // .take(i64::from(count)) + // .order_by(crdt_operation::timestamp::order(SortOrder::Asc)) + // .exec() + // .await? 
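
`ingest_ops` finishes by broadcasting `SyncEvent::Ingested` on the same channel that `write_ops`/`write_op` use for `SyncEvent::Created`. A hypothetical consumer wiring, showing how subscribers observe sync activity through `subscribe()` (the consumer itself is illustrative, not part of the diff):

```rust
use tokio::sync::broadcast;

#[derive(Clone, Debug)]
enum SyncEvent {
    Ingested,
    Created,
}

#[tokio::main]
async fn main() {
    let (tx, _) = broadcast::channel::<SyncEvent>(64);
    let mut rx = tx.subscribe(); // what `Manager::subscribe` hands out

    tokio::spawn(async move {
        // e.g. the send at the end of `ingest_ops`
        let _ = tx.send(SyncEvent::Ingested);
    });

    while let Ok(event) = rx.recv().await {
        match event {
            SyncEvent::Ingested => {
                println!("remote ops applied, refresh queries");
                break;
            }
            SyncEvent::Created => println!("local ops created, push to peers"),
        }
    }
}
```
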
+ // .into_iter() + // .map(from_crdt_ops) + // .collect() + // } + + pub fn stream_device_ops<'a>( + &'a self, + device_pub_id: &'a DevicePubId, + chunk_size: u32, + initial_timestamp: NTP64, + ) -> impl Stream, Error>> + Send + 'a { + stream! { + let mut current_initial_timestamp = initial_timestamp; + + loop { + match self.db.crdt_operation() + .find_many(vec![ + crdt_operation::device_pub_id::equals(device_pub_id.to_db()), + #[allow(clippy::cast_possible_wrap)] + crdt_operation::timestamp::gt(current_initial_timestamp.as_u64() as i64), + ]) + .take(i64::from(chunk_size)) + .order_by(crdt_operation::timestamp::order(SortOrder::Asc)) + .exec() + .await + { + Ok(ops) if ops.is_empty() => break, + + Ok(ops) => match ops + .into_iter() + .map(from_crdt_ops) + .collect::, _>>() + { + Ok(ops) => { + debug!( + start_datetime = ?ops + .first() + .map(|op| timestamp_to_datetime(op.timestamp)), + end_datetime = ?ops + .last() + .map(|op| timestamp_to_datetime(op.timestamp)), + count = ops.len(), + "Streaming crdt ops", + ); + + if let Some(last_op) = ops.last() { + current_initial_timestamp = last_op.timestamp; + } + + yield Ok(ops); + } + + Err(e) => return yield Err(e), + } + + Err(e) => return yield Err(e.into()) + } + } + } + } + + // pub async fn get_ops( + // &self, + // count: u32, + // timestamp_per_device: Vec<(DevicePubId, NTP64)>, + // ) -> Result, Error> { + // let mut ops = self + // .db + // .crdt_operation() + // .find_many(vec![or(timestamp_per_device + // .iter() + // .map(|(device_pub_id, timestamp)| { + // and![ + // crdt_operation::device_pub_id::equals(device_pub_id.to_db()), + // crdt_operation::timestamp::gt({ + // #[allow(clippy::cast_possible_wrap)] + // // SAFETY: we had to store using i64 due to SQLite limitations + // { + // timestamp.as_u64() as i64 + // } + // }) + // ] + // }) + // .chain([crdt_operation::device_pub_id::not_in_vec( + // timestamp_per_device + // .iter() + // .map(|(device_pub_id, _)| device_pub_id.to_db()) + // .collect(), + // )]) + // .collect())]) + // .take(i64::from(count)) + // .order_by(crdt_operation::timestamp::order(SortOrder::Asc)) + // .exec() + // .await?; + + // ops.sort_by(|a, b| match a.timestamp.cmp(&b.timestamp) { + // cmp::Ordering::Equal => { + // from_bytes_to_uuid(&a.device_pub_id).cmp(&from_bytes_to_uuid(&b.device_pub_id)) + // } + // o => o, + // }); + + // ops.into_iter() + // .take(count as usize) + // .map(from_crdt_ops) + // .collect() + // } + + // pub async fn get_cloud_ops( + // &self, + // count: u32, + // timestamp_per_device: Vec<(DevicePubId, NTP64)>, + // ) -> Result, Error> { + // let mut ops = self + // .db + // .cloud_crdt_operation() + // .find_many(vec![or(timestamp_per_device + // .iter() + // .map(|(device_pub_id, timestamp)| { + // and![ + // cloud_crdt_operation::device_pub_id::equals(device_pub_id.to_db()), + // cloud_crdt_operation::timestamp::gt({ + // #[allow(clippy::cast_possible_wrap)] + // // SAFETY: we had to store using i64 due to SQLite limitations + // { + // timestamp.as_u64() as i64 + // } + // }) + // ] + // }) + // .chain([cloud_crdt_operation::device_pub_id::not_in_vec( + // timestamp_per_device + // .iter() + // .map(|(device_pub_id, _)| device_pub_id.to_db()) + // .collect(), + // )]) + // .collect())]) + // .take(i64::from(count)) + // .order_by(cloud_crdt_operation::timestamp::order(SortOrder::Asc)) + // .exec() + // .await?; + + // ops.sort_by(|a, b| match a.timestamp.cmp(&b.timestamp) { + // cmp::Ordering::Equal => { + // 
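
`stream_device_ops` above wraps cursor pagination in an `async_stream::stream!` generator: fetch a chunk past the cursor, advance the cursor to the last row's timestamp, yield, and stop on the first empty fetch (or after yielding an error). The same skeleton with the Prisma query swapped for a synchronous stand-in closure:

```rust
use async_stream::stream;
use futures::Stream;

// Cursor-following stream: each yielded chunk carries (timestamp, row) pairs
// strictly greater than the cursor, which advances to the last pair yielded.
fn paged<'a, T: 'a, E: 'a>(
    mut fetch: impl FnMut(u64) -> Result<Vec<(u64, T)>, E> + 'a,
) -> impl Stream<Item = Result<Vec<(u64, T)>, E>> + 'a {
    stream! {
        let mut cursor = 0;
        loop {
            match fetch(cursor) {
                Ok(chunk) if chunk.is_empty() => break,
                Ok(chunk) => {
                    // Start the next page after everything already yielded.
                    if let Some((ts, _)) = chunk.last() {
                        cursor = *ts;
                    }
                    yield Ok(chunk);
                }
                Err(e) => {
                    yield Err(e);
                    break;
                }
            }
        }
    }
}
```
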
+    // pub async fn get_ops(
+    //     &self,
+    //     count: u32,
+    //     timestamp_per_device: Vec<(DevicePubId, NTP64)>,
+    // ) -> Result<Vec<CRDTOperation>, Error> {
+    //     let mut ops = self
+    //         .db
+    //         .crdt_operation()
+    //         .find_many(vec![or(timestamp_per_device
+    //             .iter()
+    //             .map(|(device_pub_id, timestamp)| {
+    //                 and![
+    //                     crdt_operation::device_pub_id::equals(device_pub_id.to_db()),
+    //                     crdt_operation::timestamp::gt({
+    //                         #[allow(clippy::cast_possible_wrap)]
+    //                         // SAFETY: we had to store using i64 due to SQLite limitations
+    //                         {
+    //                             timestamp.as_u64() as i64
+    //                         }
+    //                     })
+    //                 ]
+    //             })
+    //             .chain([crdt_operation::device_pub_id::not_in_vec(
+    //                 timestamp_per_device
+    //                     .iter()
+    //                     .map(|(device_pub_id, _)| device_pub_id.to_db())
+    //                     .collect(),
+    //             )])
+    //             .collect())])
+    //         .take(i64::from(count))
+    //         .order_by(crdt_operation::timestamp::order(SortOrder::Asc))
+    //         .exec()
+    //         .await?;
+
+    //     ops.sort_by(|a, b| match a.timestamp.cmp(&b.timestamp) {
+    //         cmp::Ordering::Equal => {
+    //             from_bytes_to_uuid(&a.device_pub_id).cmp(&from_bytes_to_uuid(&b.device_pub_id))
+    //         }
+    //         o => o,
+    //     });
+
+    //     ops.into_iter()
+    //         .take(count as usize)
+    //         .map(from_crdt_ops)
+    //         .collect()
+    // }
+
+    // pub async fn get_cloud_ops(
+    //     &self,
+    //     count: u32,
+    //     timestamp_per_device: Vec<(DevicePubId, NTP64)>,
+    // ) -> Result<Vec<CRDTOperation>, Error> {
+    //     let mut ops = self
+    //         .db
+    //         .cloud_crdt_operation()
+    //         .find_many(vec![or(timestamp_per_device
+    //             .iter()
+    //             .map(|(device_pub_id, timestamp)| {
+    //                 and![
+    //                     cloud_crdt_operation::device_pub_id::equals(device_pub_id.to_db()),
+    //                     cloud_crdt_operation::timestamp::gt({
+    //                         #[allow(clippy::cast_possible_wrap)]
+    //                         // SAFETY: we had to store using i64 due to SQLite limitations
+    //                         {
+    //                             timestamp.as_u64() as i64
+    //                         }
+    //                     })
+    //                 ]
+    //             })
+    //             .chain([cloud_crdt_operation::device_pub_id::not_in_vec(
+    //                 timestamp_per_device
+    //                     .iter()
+    //                     .map(|(device_pub_id, _)| device_pub_id.to_db())
+    //                     .collect(),
+    //             )])
+    //             .collect())])
+    //         .take(i64::from(count))
+    //         .order_by(cloud_crdt_operation::timestamp::order(SortOrder::Asc))
+    //         .exec()
+    //         .await?;
+
+    //     ops.sort_by(|a, b| match a.timestamp.cmp(&b.timestamp) {
+    //         cmp::Ordering::Equal => {
+    //             from_bytes_to_uuid(&a.device_pub_id).cmp(&from_bytes_to_uuid(&b.device_pub_id))
+    //         }
+    //         o => o,
+    //     });
+
+    //     ops.into_iter()
+    //         .take(count as usize)
+    //         .map(from_cloud_crdt_ops)
+    //         .collect()
+    // }
+}
+
+async fn bulk_process_of_create_only_ops(
+    available_parallelism: usize,
+    clock: Arc<uhlc::HLC>,
+    timestamp_per_device: TimestampPerDevice,
+    db: Arc<PrismaClient>,
+    sync_lock: Arc<Mutex<()>>,
+    model_id: ModelId,
+    create_only_ops: BTreeMap<DevicePubId, Vec<CompressedCRDTOperation>>,
+) -> Result<usize, Error> {
+    let buckets = (0..available_parallelism)
+        .map(|_| FuturesUnordered::new())
+        .collect::<Vec<_>>();
+
+    let mut bucket_idx = 0;
+
+    for (device_pub_id, records) in create_only_ops {
+        records
             .into_iter()
-            .map(crdt_with_instance::Data::into_operation)
-            .collect()
+            .chunks(100)
+            .into_iter()
+            .for_each(|chunk| {
+                let ops = chunk.collect::<Vec<_>>();
+
+                buckets[bucket_idx % available_parallelism].push({
+                    let clock = Arc::clone(&clock);
+                    let timestamp_per_device = Arc::clone(&timestamp_per_device);
+                    let db = Arc::clone(&db);
+                    let device_pub_id = device_pub_id.into();
+                    let sync_lock = Arc::clone(&sync_lock);
+
+                    async move {
+                        let count = ops.len();
+                        bulk_ingest_create_only_ops(
+                            &clock,
+                            &timestamp_per_device,
+                            &db,
+                            device_pub_id,
+                            model_id,
+                            ops,
+                            sync_lock,
+                        )
+                        .await
+                        .map(|()| count)
+                    }
+                });
+
+                bucket_idx += 1;
+            });
     }
-    pub async fn get_ops(&self, args: GetOpsArgs) -> Result<Vec<CRDTOperation>, Error> {
-        let mut ops = self
-            .db
-            .crdt_operation()
-            .find_many(vec![or(args
-                .clocks
-                .iter()
-                .map(|(instance_id, timestamp)| {
-                    and![
-                        crdt_operation::instance::is(vec![instance::pub_id::equals(
-                            uuid_to_bytes(instance_id)
-                        )]),
-                        crdt_operation::timestamp::gt({
-                            #[allow(clippy::cast_possible_wrap)]
-                            // SAFETY: we had to store using i64 due to SQLite limitations
-                            {
-                                timestamp.as_u64() as i64
-                            }
-                        })
-                    ]
-                })
-                .chain([crdt_operation::instance::is_not(vec![
-                    instance::pub_id::in_vec(
-                        args.clocks
-                            .iter()
-                            .map(|(instance_id, _)| uuid_to_bytes(instance_id))
-                            .collect(),
-                    ),
-                ])])
-                .collect())])
-            .take(i64::from(args.count))
-            .order_by(crdt_operation::timestamp::order(SortOrder::Asc))
-            .include(crdt_with_instance::include())
-            .exec()
-            .await?;
+    let handles = buckets
+        .into_iter()
+        .map(|mut bucket| {
+            spawn(async move {
+                let mut total_count = 0;
-        ops.sort_by(|a, b| match a.timestamp().cmp(&b.timestamp()) {
-            cmp::Ordering::Equal => a.instance().cmp(&b.instance()),
-            o => o,
-        });
+                let process_creates_batch_start = Instant::now();
-        ops.into_iter()
-            .take(args.count as usize)
-            .map(crdt_with_instance::Data::into_operation)
-            .collect()
-    }
+                while let Some(count) = bucket.try_next().await? {
+                    total_count += count;
+                }
-    pub async fn get_cloud_ops(
-        &self,
-        args: GetOpsArgs,
-    ) -> Result<Vec<CRDTOperation>, Error> {
-        let mut ops = self
-            .db
-            .cloud_crdt_operation()
-            .find_many(vec![or(args
-                .clocks
-                .iter()
-                .map(|(instance_id, timestamp)| {
-                    and![
-                        cloud_crdt_operation::instance::is(vec![instance::pub_id::equals(
-                            uuid_to_bytes(instance_id)
-                        )]),
-                        cloud_crdt_operation::timestamp::gt({
-                            #[allow(clippy::cast_possible_wrap)]
-                            // SAFETY: we had to store using i64 due to SQLite limitations
-                            {
-                                timestamp.as_u64() as i64
-                            }
-                        })
-                    ]
-                })
-                .chain([cloud_crdt_operation::instance::is_not(vec![
-                    instance::pub_id::in_vec(
-                        args.clocks
-                            .iter()
-                            .map(|(instance_id, _)| uuid_to_bytes(instance_id))
-                            .collect(),
-                    ),
-                ])])
-                .collect())])
-            .take(i64::from(args.count))
-            .order_by(cloud_crdt_operation::timestamp::order(SortOrder::Asc))
-            .include(cloud_crdt_with_instance::include())
-            .exec()
-            .await?;
+                debug!(
+                    "Processed {total_count} creates in {:?}",
+                    process_creates_batch_start.elapsed()
+                );
-        ops.sort_by(|a, b| match a.timestamp().cmp(&b.timestamp()) {
-            cmp::Ordering::Equal => a.instance().cmp(&b.instance()),
-            o => o,
-        });
+                Ok::<_, Error>(total_count)
+            })
+        })
+        .collect::<Vec<_>>();
-        ops.into_iter()
-            .take(args.count as usize)
-            .map(cloud_crdt_with_instance::Data::into_operation)
-            .collect()
-    }
+    Ok(handles
+        .try_join()
+        .await
+        .map_err(Error::ProcessCrdtPanic)?
+        .into_iter()
+        .collect::<Result<Vec<_>, _>>()?
+        .into_iter()
+        .sum())
 }
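The function above bounds ingestion parallelism by dealing 100-op chunks round-robin into `available_parallelism` `FuturesUnordered` buckets and spawning one task per bucket, however many devices contributed operations. The same scheme in miniature, with integers standing in for chunks; everything except the bucketing rule is illustrative, and a Tokio runtime is assumed:

```rust
use futures::{stream::FuturesUnordered, TryStreamExt};
use tokio::spawn;

async fn bucketed_sum(items: Vec<u64>, parallelism: usize) -> u64 {
    // One bucket per worker, as in `(0..available_parallelism)` above.
    let buckets = (0..parallelism)
        .map(|_| FuturesUnordered::new())
        .collect::<Vec<_>>();

    for (idx, item) in items.into_iter().enumerate() {
        // Round-robin assignment, mirroring `bucket_idx % available_parallelism`.
        buckets[idx % parallelism].push(async move { Ok::<_, std::convert::Infallible>(item) });
    }

    // One task drains each bucket and reports a per-bucket total.
    let handles = buckets
        .into_iter()
        .map(|mut bucket| {
            spawn(async move {
                let mut total = 0;
                while let Some(n) = bucket.try_next().await.expect("infallible") {
                    total += n;
                }
                total
            })
        })
        .collect::<Vec<_>>();

    let mut sum = 0;
    for handle in handles {
        sum += handle.await.expect("worker task panicked");
    }
    sum
}
```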
 
 impl OperationFactory for Manager {
@@ -349,15 +744,7 @@ impl OperationFactory for Manager {
         &self.clock
     }
 
-    fn get_instance(&self) -> Uuid {
-        self.instance
-    }
-}
-
-impl Deref for Manager {
-    type Target = SharedState;
-
-    fn deref(&self) -> &Self::Target {
-        &self.shared
+    fn get_device_pub_id(&self) -> sd_sync::DevicePubId {
+        sd_sync::DevicePubId::from(&self.device_pub_id)
     }
 }
diff --git a/core/crates/sync/tests/lib.rs b/core/crates/sync/tests/lib.rs
deleted file mode 100644
index 604739ac8..000000000
--- a/core/crates/sync/tests/lib.rs
+++ /dev/null
@@ -1,247 +0,0 @@
-mod mock_instance;
-
-use sd_core_sync::*;
-
-use sd_prisma::{prisma::location, prisma_sync};
-use sd_sync::*;
-use sd_utils::{msgpack, uuid_to_bytes};
-
-use mock_instance::Instance;
-use tracing::info;
-use tracing_test::traced_test;
-use uuid::Uuid;
-
-const MOCK_LOCATION_NAME: &str = "Location 0";
-const MOCK_LOCATION_PATH: &str = "/User/Anon/Documents";
-
-async fn write_test_location(instance: &Instance) -> location::Data {
-    let location_pub_id = Uuid::new_v4();
-
-    let location = instance
-        .sync
-        .write_ops(&instance.db, {
-            let (sync_ops, db_ops): (Vec<_>, Vec<_>) = [
-                sync_db_entry!(MOCK_LOCATION_NAME, location::name),
-                sync_db_entry!(MOCK_LOCATION_PATH, location::path),
-            ]
-            .into_iter()
-            .unzip();
-
-            (
-                instance.sync.shared_create(
-                    prisma_sync::location::SyncId {
-                        pub_id: uuid_to_bytes(&location_pub_id),
-                    },
-                    sync_ops,
-                ),
-                instance
-                    .db
-                    .location()
-                    .create(uuid_to_bytes(&location_pub_id), db_ops),
-            )
-        })
-        .await
-        .expect("failed to create mock location");
-
-    instance
-        .sync
-        .write_ops(&instance.db, {
-            let (sync_ops, db_ops): (Vec<_>, Vec<_>) = [
-                sync_db_entry!(1024, location::total_capacity),
-                sync_db_entry!(512, location::available_capacity),
-            ]
-            .into_iter()
-            .unzip();
-
-            (
-                sync_ops
-                    .into_iter()
-                    .map(|(k, v)| {
-                        instance.sync.shared_update(
-                            prisma_sync::location::SyncId {
-                                pub_id: uuid_to_bytes(&location_pub_id),
-                            },
-                            k,
-                            v,
-                        )
-                    })
-                    .collect::<Vec<_>>(),
-                instance
-                    .db
-                    .location()
-                    .update(location::id::equals(location.id), db_ops),
-            )
-        })
-        .await
-        .expect("failed to create mock location");
-
-    location
-}
-
-#[tokio::test]
-#[traced_test]
-async fn writes_operations_and_rows_together() -> Result<(), Box<dyn std::error::Error>> {
-    let instance = Instance::new(Uuid::new_v4()).await;
-
-    write_test_location(&instance).await;
-
-    let operations = instance
-        .db
-        .crdt_operation()
-        .find_many(vec![])
-        .exec()
-        .await?;
-
-    // 1 create, 2 update
-    assert_eq!(operations.len(), 3);
-    assert_eq!(operations[0].model, prisma_sync::location::MODEL_ID as i32);
-
-    let out = instance
-        .sync
-        .get_ops(GetOpsArgs {
-            clocks: vec![],
-            count: 100,
-        })
-        .await?;
-
-    assert_eq!(out.len(), 3);
-
-    let locations = instance.db.location().find_many(vec![]).exec().await?;
-
-    assert_eq!(locations.len(), 1);
-    let location = locations.first().unwrap();
-    assert_eq!(location.name.as_deref(), Some(MOCK_LOCATION_NAME));
-    assert_eq!(location.path.as_deref(), Some(MOCK_LOCATION_PATH));
-
-    Ok(())
-}
-
-#[tokio::test]
-#[traced_test]
-async fn operations_send_and_ingest() -> Result<(), Box<dyn std::error::Error>> {
-    let instance1 = Instance::new(Uuid::new_v4()).await;
-    let instance2 = Instance::new(Uuid::new_v4()).await;
-
-    let mut instance2_sync_rx = instance2.sync_rx.resubscribe();
-
-    info!("Created instances!");
-
-    Instance::pair(&instance1, &instance2).await;
-
-    info!("Paired instances!");
-
-    write_test_location(&instance1).await;
-
-    info!("Created mock location!");
-
-    assert!(matches!(
-        instance2_sync_rx.recv().await?,
-        SyncMessage::Ingested
-    ));
-
-    let out = instance2
-        .sync
-        .get_ops(GetOpsArgs {
-            clocks: vec![],
-            count: 100,
-        })
-        .await?;
-
-    assert_locations_equality(
-        &instance1.db.location().find_many(vec![]).exec().await?[0],
-        &instance2.db.location().find_many(vec![]).exec().await?[0],
-    );
-
-    assert_eq!(out.len(), 3);
-
-    instance1.teardown().await;
-    instance2.teardown().await;
-
-    Ok(())
-}
-
-#[tokio::test]
-async fn no_update_after_delete() -> Result<(), Box<dyn std::error::Error>> {
-    let instance1 = Instance::new(Uuid::new_v4()).await;
-    let instance2 = Instance::new(Uuid::new_v4()).await;
-
-    let mut instance2_sync_rx = instance2.sync_rx.resubscribe();
-
-    Instance::pair(&instance1, &instance2).await;
-
-    let location = write_test_location(&instance1).await;
-
-    assert!(matches!(
-        instance2_sync_rx.recv().await?,
-        SyncMessage::Ingested
-    ));
-
-    instance2
-        .sync
-        .write_op(
-            &instance2.db,
-            instance2.sync.shared_delete(prisma_sync::location::SyncId {
-                pub_id: location.pub_id.clone(),
-            }),
-            instance2.db.location().delete_many(vec![]),
-        )
-        .await?;
-
-    assert!(matches!(
-        instance1.sync_rx.resubscribe().recv().await?,
-        SyncMessage::Ingested
-    ));
-
-    instance1
-        .sync
-        .write_op(
-            &instance1.db,
-            instance1.sync.shared_update(
-                prisma_sync::location::SyncId {
-                    pub_id: location.pub_id.clone(),
-                },
-                "name",
-                msgpack!("New Location"),
-            ),
-            instance1.db.location().find_many(vec![]),
-        )
-        .await?;
-
-    // one spare update operation that actually gets ignored by instance 2
-    assert_eq!(instance1.db.crdt_operation().count(vec![]).exec().await?, 5);
-    assert_eq!(instance2.db.crdt_operation().count(vec![]).exec().await?, 4);
-
-    assert_eq!(instance1.db.location().count(vec![]).exec().await?, 0);
-    // the whole point of the test - the update (which is ingested as an upsert) should be ignored
-    assert_eq!(instance2.db.location().count(vec![]).exec().await?, 0);
-
-    instance1.teardown().await;
-    instance2.teardown().await;
-
-    Ok(())
-}
-
-fn assert_locations_equality(l1: &location::Data, l2: &location::Data) {
-    assert_eq!(l1.pub_id, l2.pub_id, "pub id");
-    assert_eq!(l1.name, l2.name, "name");
-    assert_eq!(l1.path, l2.path, "path");
-    assert_eq!(l1.total_capacity, l2.total_capacity, "total capacity");
-    assert_eq!(
-        l1.available_capacity, l2.available_capacity,
-        "available capacity"
-    );
-    assert_eq!(l1.size_in_bytes, l2.size_in_bytes, "size in bytes");
-    assert_eq!(l1.is_archived, l2.is_archived, "is archived");
-    assert_eq!(
-        l1.generate_preview_media, l2.generate_preview_media,
-        "generate preview media"
-    );
-    assert_eq!(
-        l1.sync_preview_media, l2.sync_preview_media,
-        "sync preview media"
-    );
-    assert_eq!(l1.hidden, l2.hidden, "hidden");
-    assert_eq!(l1.date_created, l2.date_created, "date created");
-    assert_eq!(l1.scan_state, l2.scan_state, "scan state");
-    assert_eq!(l1.instance_id, l2.instance_id, "instance id");
-}
diff --git a/core/crates/sync/tests/mock_instance.rs b/core/crates/sync/tests/mock_instance.rs
deleted file mode 100644
index 807ccd4f6..000000000
--- a/core/crates/sync/tests/mock_instance.rs
+++ /dev/null
@@ -1,161 +0,0 @@
-use sd_core_sync::*;
-
-use sd_prisma::prisma;
-use sd_sync::CompressedCRDTOperations;
-use sd_utils::uuid_to_bytes;
-
-use std::sync::{atomic::AtomicBool, Arc};
-
-use prisma_client_rust::chrono::Utc;
-use tokio::{fs, spawn, sync::broadcast};
-use tracing::{info, instrument, warn, Instrument};
-use uuid::Uuid;
-
-fn db_path(id: Uuid) -> String {
-    format!("/tmp/test-{id}.db")
-}
-
-#[derive(Clone)]
-pub struct Instance {
-    pub id: Uuid,
-    pub db: Arc<prisma::PrismaClient>,
-    pub sync: Arc<Manager>,
-    pub sync_rx: Arc<broadcast::Receiver<SyncMessage>>,
-}
-
-impl Instance {
-    pub async fn new(id: Uuid) -> Arc<Self> {
-        let url = format!("file:{}", db_path(id));
-
-        let db = Arc::new(
-            prisma::PrismaClient::_builder()
-                .with_url(url.to_string())
-                .build()
-                .await
-                .unwrap(),
-        );
-
-        db._db_push().await.unwrap();
-
-        db.instance()
-            .create(
-                uuid_to_bytes(&id),
-                vec![],
-                vec![],
-                Utc::now().into(),
-                Utc::now().into(),
-                vec![],
-            )
-            .exec()
-            .await
-            .unwrap();
-
-        let (sync, sync_rx) = sd_core_sync::Manager::new(
-            Arc::clone(&db),
-            id,
-            Arc::new(AtomicBool::new(true)),
-            Default::default(),
-        )
-        .await
-        .expect("failed to create sync manager");
-
-        Arc::new(Self {
-            id,
-            db,
-            sync: Arc::new(sync),
-            sync_rx: Arc::new(sync_rx),
-        })
-    }
-
-    pub async fn teardown(&self) {
-        fs::remove_file(db_path(self.id)).await.unwrap();
-    }
-
-    pub async fn pair(instance1: &Arc<Self>, instance2: &Arc<Self>) {
-        #[instrument(skip(left, right))]
-        async fn half(left: &Arc<Instance>, right: &Arc<Instance>, context: &'static str) {
-            left.db
-                .instance()
-                .create(
-                    uuid_to_bytes(&right.id),
-                    vec![],
-                    vec![],
-                    Utc::now().into(),
-                    Utc::now().into(),
-                    vec![],
-                )
-                .exec()
-                .await
-                .unwrap();
-
-            spawn({
-                let mut sync_rx_left = left.sync_rx.resubscribe();
-                let right = Arc::clone(right);
-
-                async move {
-                    while let Ok(msg) = sync_rx_left.recv().await {
-                        info!(?msg, "sync_rx_left received message");
-                        if matches!(msg, SyncMessage::Created) {
-                            right
-                                .sync
-                                .ingest
-                                .event_tx
-                                .send(ingest::Event::Notification)
-                                .await
-                                .unwrap();
-                            info!("sent notification to instance 2");
-                        }
-                    }
-                }
-                .in_current_span()
-            });
-
-            spawn({
-                let left = Arc::clone(left);
-                let right = Arc::clone(right);
-
-                async move {
-                    while let Ok(msg) = right.sync.ingest.req_rx.recv().await {
-                        info!(?msg, "right instance received request");
-                        match msg {
-                            ingest::Request::Messages { timestamps, tx } => {
-                                let messages = left
-                                    .sync
-                                    .get_ops(GetOpsArgs {
-                                        clocks: timestamps,
-                                        count: 100,
-                                    })
-                                    .await
-                                    .unwrap();
-
-                                let ingest = &right.sync.ingest;
-
-                                ingest
-                                    .event_tx
-                                    .send(ingest::Event::Messages(ingest::MessagesEvent {
-                                        messages: CompressedCRDTOperations::new(messages),
-                                        has_more: false,
-                                        instance_id: left.id,
-                                        wait_tx: None,
-                                    }))
-                                    .await
-                                    .unwrap();
-
-                                if tx.send(()).is_err() {
-                                    warn!("failed to send ack to instance 1");
-                                }
-                            }
-                            ingest::Request::FinishedIngesting => {
-                                right.sync.tx.send(SyncMessage::Ingested).unwrap();
-                            }
-                        }
-                    }
-                }
-                .in_current_span()
-            });
-        }
-
-        half(instance1, instance2, "instance1 -> instance2").await;
-        half(instance2, instance1, "instance2 -> instance1").await;
-    }
-}
diff --git a/core/prisma/migrations/20240920032950_adding_devices/migration.sql b/core/prisma/migrations/20240920032950_adding_devices/migration.sql
new file mode 100644
index 000000000..ca9765728
--- /dev/null
+++ b/core/prisma/migrations/20240920032950_adding_devices/migration.sql
@@ -0,0 +1,201 @@
+/*
+  Warnings:
+
+  - You are about to drop the `node` table. If the table is not empty, all the data it contains will be lost.
+  - You are about to drop the column `instance_id` on the `cloud_crdt_operation` table. All the data in the column will be lost.
+  - You are about to drop the column `instance_id` on the `crdt_operation` table. All the data in the column will be lost.
+  - You are about to drop the column `instance_pub_id` on the `storage_statistics` table. All the data in the column will be lost.
+  - Added the required column `device_pub_id` to the `cloud_crdt_operation` table without a default value. This is not possible if the table is not empty.
+  - Added the required column `device_pub_id` to the `crdt_operation` table without a default value. This is not possible if the table is not empty.
+
+*/
+-- DropIndex
+DROP INDEX "node_pub_id_key";
+
+-- DropTable
+PRAGMA foreign_keys=off;
+DROP TABLE "node";
+PRAGMA foreign_keys=on;
+
+-- CreateTable
+CREATE TABLE "device" (
+    "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+    "pub_id" BLOB NOT NULL,
+    "name" TEXT,
+    "os" INTEGER,
+    "hardware_model" INTEGER,
+    "timestamp" BIGINT,
+    "date_created" DATETIME,
+    "date_deleted" DATETIME
+);
+
+-- RedefineTables
+PRAGMA defer_foreign_keys=ON;
+PRAGMA foreign_keys=OFF;
+CREATE TABLE "new_cloud_crdt_operation" (
+    "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+    "timestamp" BIGINT NOT NULL,
+    "model" INTEGER NOT NULL,
+    "record_id" BLOB NOT NULL,
+    "kind" TEXT NOT NULL,
+    "data" BLOB NOT NULL,
+    "device_pub_id" BLOB NOT NULL,
+    CONSTRAINT "cloud_crdt_operation_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE RESTRICT ON UPDATE CASCADE
+);
+INSERT INTO "new_cloud_crdt_operation" ("data", "id", "kind", "model", "record_id", "timestamp") SELECT "data", "id", "kind", "model", "record_id", "timestamp" FROM "cloud_crdt_operation";
+DROP TABLE "cloud_crdt_operation";
+ALTER TABLE "new_cloud_crdt_operation" RENAME TO "cloud_crdt_operation";
+CREATE INDEX "cloud_crdt_operation_timestamp_idx" ON "cloud_crdt_operation"("timestamp");
+CREATE TABLE "new_crdt_operation" (
+    "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+    "timestamp" BIGINT NOT NULL,
+    "model" INTEGER NOT NULL,
+    "record_id" BLOB NOT NULL,
+    "kind" TEXT NOT NULL,
+    "data" BLOB NOT NULL,
+    "device_pub_id" BLOB NOT NULL,
+    CONSTRAINT "crdt_operation_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE RESTRICT ON UPDATE CASCADE
+);
+INSERT INTO "new_crdt_operation" ("data", "id", "kind", "model", "record_id", "timestamp") SELECT "data", "id", "kind", "model", "record_id", "timestamp" FROM "crdt_operation";
+DROP TABLE "crdt_operation";
+ALTER TABLE "new_crdt_operation" RENAME TO "crdt_operation";
+CREATE INDEX "crdt_operation_timestamp_idx" ON "crdt_operation"("timestamp");
+CREATE TABLE "new_exif_data" (
+    "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+    "resolution" BLOB,
+    "media_date" BLOB,
+    "media_location" BLOB,
+    "camera_data" BLOB,
+    "artist" TEXT,
+    "description" TEXT,
+    "copyright" TEXT,
+    "exif_version" TEXT,
+    "epoch_time" BIGINT,
+    "object_id" INTEGER NOT NULL,
+    "device_pub_id" BLOB,
+    CONSTRAINT "exif_data_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE CASCADE ON UPDATE CASCADE,
+    CONSTRAINT "exif_data_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
+);
+INSERT INTO "new_exif_data" ("artist", "camera_data", "copyright", "description", "epoch_time", "exif_version", "id", "media_date", "media_location", "object_id", "resolution") SELECT "artist", "camera_data", "copyright", "description", "epoch_time", "exif_version", "id", "media_date", "media_location", "object_id", "resolution" FROM "exif_data";
+DROP TABLE "exif_data";
+ALTER TABLE "new_exif_data" RENAME TO "exif_data";
+CREATE UNIQUE INDEX "exif_data_object_id_key" ON "exif_data"("object_id");
+CREATE TABLE "new_file_path" (
+    "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+    "pub_id" BLOB NOT NULL,
+    "is_dir" BOOLEAN,
+    "cas_id" TEXT,
+    "integrity_checksum" TEXT,
+    "location_id" INTEGER,
+    "materialized_path" TEXT,
+    "name" TEXT,
+    "extension" TEXT,
+    "hidden" BOOLEAN,
+    "size_in_bytes" TEXT,
+    "size_in_bytes_bytes" BLOB,
+    "inode" BLOB,
+    "object_id" INTEGER,
+    "key_id" INTEGER,
+    "date_created" DATETIME,
+    "date_modified" DATETIME,
+    "date_indexed" DATETIME,
+    "device_pub_id" BLOB,
+    CONSTRAINT "file_path_location_id_fkey" FOREIGN KEY ("location_id") REFERENCES "location" ("id") ON DELETE SET NULL ON UPDATE CASCADE,
+    CONSTRAINT "file_path_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE SET NULL ON UPDATE CASCADE,
+    CONSTRAINT "file_path_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
+);
+INSERT INTO "new_file_path" ("cas_id", "date_created", "date_indexed", "date_modified", "extension", "hidden", "id", "inode", "integrity_checksum", "is_dir", "key_id", "location_id", "materialized_path", "name", "object_id", "pub_id", "size_in_bytes", "size_in_bytes_bytes") SELECT "cas_id", "date_created", "date_indexed", "date_modified", "extension", "hidden", "id", "inode", "integrity_checksum", "is_dir", "key_id", "location_id", "materialized_path", "name", "object_id", "pub_id", "size_in_bytes", "size_in_bytes_bytes" FROM "file_path";
+DROP TABLE "file_path";
+ALTER TABLE "new_file_path" RENAME TO "file_path";
+CREATE UNIQUE INDEX "file_path_pub_id_key" ON "file_path"("pub_id");
+CREATE INDEX "file_path_location_id_idx" ON "file_path"("location_id");
+CREATE INDEX "file_path_location_id_materialized_path_idx" ON "file_path"("location_id", "materialized_path");
+CREATE UNIQUE INDEX "file_path_location_id_materialized_path_name_extension_key" ON "file_path"("location_id", "materialized_path", "name", "extension");
+CREATE UNIQUE INDEX "file_path_location_id_inode_key" ON "file_path"("location_id", "inode");
+CREATE TABLE "new_label_on_object" (
+    "date_created" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    "object_id" INTEGER NOT NULL,
+    "label_id" INTEGER NOT NULL,
+    "device_pub_id" BLOB,
+
+    PRIMARY KEY ("label_id", "object_id"),
+    CONSTRAINT "label_on_object_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
+    CONSTRAINT "label_on_object_label_id_fkey" FOREIGN KEY ("label_id") REFERENCES "label" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
+    CONSTRAINT "label_on_object_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
+);
+INSERT INTO "new_label_on_object" ("date_created", "label_id", "object_id") SELECT "date_created", "label_id", "object_id" FROM "label_on_object";
+DROP TABLE "label_on_object";
+ALTER TABLE "new_label_on_object" RENAME TO "label_on_object";
+CREATE TABLE "new_location" (
+    "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+    "pub_id" BLOB NOT NULL,
+    "name" TEXT,
+    "path" TEXT,
+    "total_capacity" INTEGER,
+    "available_capacity" INTEGER,
+    "size_in_bytes" BLOB,
+    "is_archived" BOOLEAN,
+    "generate_preview_media" BOOLEAN,
+    "sync_preview_media" BOOLEAN,
+    "hidden" BOOLEAN,
+    "date_created" DATETIME,
+    "scan_state" INTEGER NOT NULL DEFAULT 0,
+    "device_pub_id" BLOB,
+    "instance_id" INTEGER,
+    CONSTRAINT "location_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE,
+    CONSTRAINT "location_instance_id_fkey" FOREIGN KEY ("instance_id") REFERENCES "instance" ("id") ON DELETE SET NULL ON UPDATE CASCADE
+);
+INSERT INTO "new_location" ("available_capacity", "date_created", "generate_preview_media", "hidden", "id", "instance_id", "is_archived", "name", "path", "pub_id", "scan_state", "size_in_bytes", "sync_preview_media", "total_capacity") SELECT "available_capacity", "date_created", "generate_preview_media", "hidden", "id", "instance_id", "is_archived", "name", "path", "pub_id", "scan_state", "size_in_bytes", "sync_preview_media", "total_capacity" FROM "location";
+DROP TABLE "location";
+ALTER TABLE "new_location" RENAME TO "location";
+CREATE UNIQUE INDEX "location_pub_id_key" ON "location"("pub_id");
+CREATE TABLE "new_object" (
+    "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+    "pub_id" BLOB NOT NULL,
+    "kind" INTEGER,
+    "key_id" INTEGER,
+    "hidden" BOOLEAN,
+    "favorite" BOOLEAN,
+    "important" BOOLEAN,
+    "note" TEXT,
+    "date_created" DATETIME,
+    "date_accessed" DATETIME,
+    "device_pub_id" BLOB,
+    CONSTRAINT "object_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
+);
+INSERT INTO "new_object" ("date_accessed", "date_created", "favorite", "hidden", "id", "important", "key_id", "kind", "note", "pub_id") SELECT "date_accessed", "date_created", "favorite", "hidden", "id", "important", "key_id", "kind", "note", "pub_id" FROM "object";
+DROP TABLE "object";
+ALTER TABLE "new_object" RENAME TO "object";
+CREATE UNIQUE INDEX "object_pub_id_key" ON "object"("pub_id");
+CREATE TABLE "new_storage_statistics" (
+    "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+    "pub_id" BLOB NOT NULL,
+    "total_capacity" BIGINT NOT NULL DEFAULT 0,
+    "available_capacity" BIGINT NOT NULL DEFAULT 0,
+    "device_pub_id" BLOB,
+    CONSTRAINT "storage_statistics_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE
+);
+INSERT INTO "new_storage_statistics" ("available_capacity", "id", "pub_id", "total_capacity") SELECT "available_capacity", "id", "pub_id", "total_capacity" FROM "storage_statistics";
+DROP TABLE "storage_statistics";
+ALTER TABLE "new_storage_statistics" RENAME TO "storage_statistics";
"storage_statistics"; +CREATE UNIQUE INDEX "storage_statistics_pub_id_key" ON "storage_statistics"("pub_id"); +CREATE UNIQUE INDEX "storage_statistics_device_pub_id_key" ON "storage_statistics"("device_pub_id"); +CREATE TABLE "new_tag_on_object" ( + "object_id" INTEGER NOT NULL, + "tag_id" INTEGER NOT NULL, + "date_created" DATETIME, + "device_pub_id" BLOB, + + PRIMARY KEY ("tag_id", "object_id"), + CONSTRAINT "tag_on_object_object_id_fkey" FOREIGN KEY ("object_id") REFERENCES "object" ("id") ON DELETE RESTRICT ON UPDATE CASCADE, + CONSTRAINT "tag_on_object_tag_id_fkey" FOREIGN KEY ("tag_id") REFERENCES "tag" ("id") ON DELETE RESTRICT ON UPDATE CASCADE, + CONSTRAINT "tag_on_object_device_pub_id_fkey" FOREIGN KEY ("device_pub_id") REFERENCES "device" ("pub_id") ON DELETE CASCADE ON UPDATE CASCADE +); +INSERT INTO "new_tag_on_object" ("date_created", "object_id", "tag_id") SELECT "date_created", "object_id", "tag_id" FROM "tag_on_object"; +DROP TABLE "tag_on_object"; +ALTER TABLE "new_tag_on_object" RENAME TO "tag_on_object"; +PRAGMA foreign_keys=ON; +PRAGMA defer_foreign_keys=OFF; + +-- CreateIndex +CREATE UNIQUE INDEX "device_pub_id_key" ON "device"("pub_id"); diff --git a/core/prisma/schema.prisma b/core/prisma/schema.prisma index 62da3559d..ccd012859 100644 --- a/core/prisma/schema.prisma +++ b/core/prisma/schema.prisma @@ -28,9 +28,10 @@ model CRDTOperation { kind String data Bytes - instance_id Int - instance Instance @relation(fields: [instance_id], references: [id]) + // We just need the actual device_pub_id here, but we don't need as an actual relation + device_pub_id Bytes + @@index([timestamp]) @@map("crdt_operation") } @@ -46,24 +47,41 @@ model CloudCRDTOperation { kind String data Bytes - instance_id Int - instance Instance @relation(fields: [instance_id], references: [id]) + // We just need the actual device_pub_id here, but we don't need as an actual relation + device_pub_id Bytes + @@index([timestamp]) @@map("cloud_crdt_operation") } -/// @deprecated: This model has to exist solely for backwards compatibility. -/// @local -model Node { - id Int @id @default(autoincrement()) - pub_id Bytes @unique - name String - // Enum: sd_core::node::Platform - platform Int - date_created DateTime - identity Bytes? // TODO: Change to required field in future +/// Devices are the owner machines connected to this library +/// @shared(id: pub_id, modelId: 12) +model Device { + id Int @id @default(autoincrement()) + // uuid v7 + pub_id Bytes @unique + name String? // Not actually NULLABLE, but we have to comply with current sync implementation BS - @@map("node") + // Enum: sd_cloud_schema::device::DeviceOS + os Int? // Not actually NULLABLE, but we have to comply with current sync implementation BS + // Enum: sd_cloud_schema::device::HardwareModel + hardware_model Int? // Not actually NULLABLE, but we have to comply with current sync implementation BS + + // clock timestamp for sync + timestamp BigInt? + + date_created DateTime? // Not actually NULLABLE, but we have to comply with current sync implementation BS + date_deleted DateTime? + + StorageStatistics StorageStatistics? + Location Location[] + FilePath FilePath[] + Object Object[] + ExifData ExifData[] + TagOnObject TagOnObject[] + LabelOnObject LabelOnObject[] + + @@map("device") } // represents a single `.db` file (SQLite DB) that is paired to the current library. @@ -88,11 +106,7 @@ model Instance { // clock timestamp for sync timestamp BigInt? 
- - locations Location[] - CRDTOperation CRDTOperation[] - CloudCRDTOperation CloudCRDTOperation[] - storage_statistics StorageStatistics? + Location Location[] @@map("instance") } @@ -158,6 +172,9 @@ model Location { scan_state Int @default(0) // Enum: sd_core::location::ScanState + device_id Int? + device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade) + // this should just be a local-only cache but it's too much effort to broadcast online locations rn (@brendan) instance_id Int? instance Instance? @relation(fields: [instance_id], references: [id], onDelete: SetNull) @@ -208,6 +225,9 @@ model FilePath { date_modified DateTime? date_indexed DateTime? + device_id Int? + device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade) + // key Key? @relation(fields: [key_id], references: [id]) @@unique([location_id, materialized_path, name, extension]) @@ -255,6 +275,9 @@ model Object { exif_data ExifData? ffmpeg_data FfmpegData? + device_id Int? + device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade) + // key Key? @relation(fields: [key_id], references: [id]) @@map("object") @@ -323,6 +346,9 @@ model ExifData { object_id Int @unique object Object @relation(fields: [object_id], references: [id], onDelete: Cascade) + device_id Int? + device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade) + @@map("exif_data") } @@ -509,6 +535,9 @@ model TagOnObject { date_created DateTime? + device_id Int? + device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade) + @@id([tag_id, object_id]) @@map("tag_on_object") } @@ -537,6 +566,9 @@ model LabelOnObject { label_id Int label Label @relation(fields: [label_id], references: [id], onDelete: Restrict) + device_id Int? + device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade) + @@id([label_id, object_id]) @@map("label_on_object") } @@ -576,8 +608,8 @@ model StorageStatistics { total_capacity BigInt @default(0) available_capacity BigInt @default(0) - instance_pub_id Bytes? @unique - instance Instance? @relation(fields: [instance_pub_id], references: [pub_id], onDelete: Cascade) + device_id Int? @unique + device Device? @relation(fields: [device_id], references: [id], onDelete: Cascade) @@map("storage_statistics") } diff --git a/core/src/api/auth.rs b/core/src/api/auth.rs deleted file mode 100644 index 994dc7474..000000000 --- a/core/src/api/auth.rs +++ /dev/null @@ -1,153 +0,0 @@ -use std::time::Duration; - -use reqwest::StatusCode; -use rspc::alpha::AlphaRouter; -use serde::{Deserialize, Serialize}; -use specta::Type; - -use super::{Ctx, R}; - -pub(crate) fn mount() -> AlphaRouter { - R.router() - .procedure("loginSession", { - #[derive(Serialize, Type)] - #[specta(inline)] - enum Response { - Start { - user_code: String, - verification_url: String, - verification_url_complete: String, - }, - Complete, - Error(String), - } - - R.subscription(|node, _: ()| async move { - #[derive(Deserialize, Type)] - struct DeviceAuthorizationResponse { - device_code: String, - user_code: String, - verification_url: String, - verification_uri_complete: String, - } - - async_stream::stream! 
-                    let device_type = if cfg!(target_arch = "wasm32") {
-                        "web".to_string()
-                    } else if cfg!(target_os = "ios") || cfg!(target_os = "android") {
-                        "mobile".to_string()
-                    } else {
-                        "desktop".to_string()
-                    };
-
-                    let auth_response = match match node
-                        .http
-                        .post(format!(
-                            "{}/login/device/code",
-                            &node.env.api_url.lock().await
-                        ))
-                        .form(&[("client_id", &node.env.client_id), ("device", &device_type)])
-                        .send()
-                        .await
-                        .map_err(|e| e.to_string())
-                    {
-                        Ok(r) => r.json::<DeviceAuthorizationResponse>().await.map_err(|e| e.to_string()),
-                        Err(e) => {
-                            yield Response::Error(e.to_string());
-                            return
-                        },
-                    } {
-                        Ok(v) => v,
-                        Err(e) => {
-                            yield Response::Error(e.to_string());
-                            return
-                        },
-                    };
-
-                    yield Response::Start {
-                        user_code: auth_response.user_code.clone(),
-                        verification_url: auth_response.verification_url.clone(),
-                        verification_url_complete: auth_response.verification_uri_complete.clone(),
-                    };
-
-                    yield loop {
-                        tokio::time::sleep(Duration::from_secs(5)).await;
-
-                        let token_resp = match node.http
-                            .post(format!("{}/login/oauth/access_token", &node.env.api_url.lock().await))
-                            .form(&[
-                                ("grant_type", sd_cloud_api::auth::DEVICE_CODE_URN),
-                                ("device_code", &auth_response.device_code),
-                                ("client_id", &node.env.client_id)
-                            ])
-                            .send()
-                            .await {
-                                Ok(v) => v,
-                                Err(e) => break Response::Error(e.to_string())
-                            };
-
-                        match token_resp.status() {
-                            StatusCode::OK => {
-                                let token = match token_resp.json().await {
-                                    Ok(v) => v,
-                                    Err(e) => break Response::Error(e.to_string())
-                                };
-
-                                if let Err(e) = node.config
-                                    .write(|c| c.auth_token = Some(token))
-                                    .await {
-                                    break Response::Error(e.to_string());
-                                };
-
-
-                                break Response::Complete;
-                            },
-                            StatusCode::BAD_REQUEST => {
-                                #[derive(Debug, Deserialize)]
-                                struct OAuth400 {
-                                    error: String
-                                }
-
-                                let resp = match token_resp.json::<OAuth400>().await {
-                                    Ok(v) => v,
-                                    Err(e) => break Response::Error(e.to_string())
-                                };
-
-                                match resp.error.as_str() {
-                                    "authorization_pending" => continue,
-                                    e => {
-                                        break Response::Error(e.to_string())
-                                    }
-                                }
-                            },
-                            s => {
-                                break Response::Error(s.to_string());
-                            }
-                        }
-                    }
-                }
-            })
-        })
-        .procedure(
-            "logout",
-            R.mutation(|node, _: ()| async move {
-                node.config
-                    .write(|c| c.auth_token = None)
-                    .await
-                    .map(|_| ())
-                    .map_err(|_| {
-                        rspc::Error::new(
-                            rspc::ErrorCode::InternalServerError,
-                            "Failed to write config".to_string(),
-                        )
-                    })
-            }),
-        )
-        .procedure("me", {
-            R.query(|node, _: ()| async move {
-                let resp = sd_cloud_api::user::me(node.cloud_api_config().await).await?;
-
-                Ok(resp)
-            })
-        })
-}
diff --git a/core/src/api/backups.rs b/core/src/api/backups.rs
index de3bb1deb..01e49ced0 100644
--- a/core/src/api/backups.rs
+++ b/core/src/api/backups.rs
@@ -381,6 +381,7 @@ async fn restore_backup(node: &Arc<Node>, path: impl AsRef<Path>) -> Result
diff --git a/core/src/api/cloud.rs b/core/src/api/cloud.rs
deleted file mode 100644
-async fn parse_json_body<T: DeserializeOwned>(response: Response) -> Result<T, rspc::Error> {
-    response.json().await.map_err(|_| {
-        rspc::Error::new(
-            rspc::ErrorCode::InternalServerError,
-            "JSON conversion failed".to_string(),
-        )
-    })
-}
-
-pub(crate) fn mount() -> AlphaRouter<Ctx> {
-    R.router()
-        .merge("library.", library::mount())
-        .merge("locations.", locations::mount())
-        .procedure("getApiOrigin", {
-            R.query(|node, _: ()| async move { Ok(node.env.api_url.lock().await.to_string()) })
-        })
-        .procedure("setApiOrigin", {
-            R.mutation(|node, origin: String| async move {
-                let mut origin_env = node.env.api_url.lock().await;
-                origin_env.clone_from(&origin);
-
-                node.config
-                    .write(|c| {
-                        c.auth_token = None;
-                        c.sd_api_origin = Some(origin);
-                    })
-                    .await
-                    .ok();
-
-                Ok(())
-            })
-        })
-}
-
-mod library {
-    use std::str::FromStr;
-
-    use sd_p2p::RemoteIdentity;
-
-    use crate::util::MaybeUndefined;
-
-    use super::*;
-
-    pub fn mount() -> AlphaRouter<Ctx> {
-        R.router()
-            .procedure("get", {
-                R.with2(library())
-                    .query(|(node, library), _: ()| async move {
-                        Ok(
-                            sd_cloud_api::library::get(node.cloud_api_config().await, library.id)
-                                .await?,
-                        )
-                    })
-            })
-            .procedure("list", {
-                R.query(|node, _: ()| async move {
-                    Ok(sd_cloud_api::library::list(node.cloud_api_config().await).await?)
-                })
-            })
-            .procedure("create", {
-                R.with2(library())
-                    .mutation(|(node, library), _: ()| async move {
-                        let node_config = node.config.get().await;
-                        let cloud_library = sd_cloud_api::library::create(
-                            node.cloud_api_config().await,
-                            library.id,
-                            &library.config().await.name,
-                            library.instance_uuid,
-                            library.identity.to_remote_identity(),
-                            node_config.id,
-                            node_config.identity.to_remote_identity(),
-                            &node.p2p.peer_metadata(),
-                        )
-                        .await?;
-                        node.libraries
-                            .edit(
-                                library.id,
-                                None,
-                                MaybeUndefined::Undefined,
-                                MaybeUndefined::Value(cloud_library.id),
-                                None,
-                            )
-                            .await?;
-
-                        invalidate_query!(library, "cloud.library.get");
-
-                        Ok(())
-                    })
-            })
-            .procedure("join", {
-                R.mutation(|node, library_id: Uuid| async move {
-                    let Some(cloud_library) =
-                        sd_cloud_api::library::get(node.cloud_api_config().await, library_id)
-                            .await?
-                    else {
-                        return Err(rspc::Error::new(
-                            rspc::ErrorCode::NotFound,
-                            "Library not found".to_string(),
-                        ));
-                    };
-
-                    let library = node
-                        .libraries
-                        .create_with_uuid(
-                            library_id,
-                            LibraryName::new(cloud_library.name).map_err(|e| {
-                                rspc::Error::new(
-                                    rspc::ErrorCode::InternalServerError,
-                                    e.to_string(),
-                                )
-                            })?,
-                            None,
-                            false,
-                            None,
-                            &node,
-                            true,
-                        )
-                        .await?;
-                    node.libraries
-                        .edit(
-                            library.id,
-                            None,
-                            MaybeUndefined::Undefined,
-                            MaybeUndefined::Value(cloud_library.id),
-                            None,
-                        )
-                        .await?;
-
-                    let node_config = node.config.get().await;
-                    let instances = sd_cloud_api::library::join(
-                        node.cloud_api_config().await,
-                        library_id,
-                        library.instance_uuid,
-                        library.identity.to_remote_identity(),
-                        node_config.id,
-                        node_config.identity.to_remote_identity(),
-                        node.p2p.peer_metadata(),
-                    )
-                    .await?;
-
-                    for instance in instances {
-                        crate::cloud::sync::receive::upsert_instance(
-                            library.id,
-                            &library.db,
-                            &library.sync,
-                            &node.libraries,
-                            &instance.uuid,
-                            instance.identity,
-                            &instance.node_id,
-                            RemoteIdentity::from_str(&instance.node_remote_identity)
-                                .expect("malformed remote identity in the DB"),
-                            instance.metadata,
-                        )
-                        .await?;
-                    }
-
-                    invalidate_query!(library, "cloud.library.get");
-                    invalidate_query!(library, "cloud.library.list");
-
-                    Ok(LibraryConfigWrapped::from_library(&library).await)
-                })
-            })
-            .procedure("sync", {
-                R.with2(library())
-                    .mutation(|(_, library), _: ()| async move {
-                        library.do_cloud_sync();
-                        Ok(())
-                    })
-            })
-    }
-}
-mod locations {
-    use super::*;
-
-    use serde::{Deserialize, Serialize};
-    use specta::Type;
-    #[derive(Type, Serialize, Deserialize)]
-    pub struct CloudLocation {
-        id: String,
-        name: String,
-    }
-
-    pub fn mount() -> AlphaRouter<Ctx> {
-        R.router()
-            .procedure("list", {
-                R.query(|node, _: ()| async move {
-                    sd_cloud_api::locations::list(node.cloud_api_config().await)
-                        .await
-                        .map_err(Into::into)
-                })
-            })
-            .procedure("create", {
-                R.mutation(|node, name: String| async move {
-                    sd_cloud_api::locations::create(node.cloud_api_config().await, name)
-                        .await
-                        .map_err(Into::into)
-                })
-            })
-            .procedure("remove", {
-                R.mutation(|node, id: String| async move {
-                    sd_cloud_api::locations::create(node.cloud_api_config().await, id)
-                        .await
-                        .map_err(Into::into)
-                })
-            })
-    }
-}
diff --git a/core/src/api/cloud/devices.rs b/core/src/api/cloud/devices.rs
new file mode 100644
index 000000000..ca1a81ad6
--- /dev/null
+++ b/core/src/api/cloud/devices.rs
@@ -0,0 +1,397 @@
+use crate::api::{Ctx, R};
+
+use sd_core_cloud_services::QuinnConnection;
+
+use sd_cloud_schema::{
+    auth::AccessToken,
+    devices::{self, DeviceOS, HardwareModel, PubId},
+    opaque_ke::{
+        ClientLogin, ClientLoginFinishParameters, ClientLoginFinishResult, ClientLoginStartResult,
+        ClientRegistration, ClientRegistrationFinishParameters, ClientRegistrationFinishResult,
+        ClientRegistrationStartResult,
+    },
+    Client, NodeId, Service, SpacedriveCipherSuite,
+};
+use sd_crypto::{cloud::secret_key::SecretKey, CryptoRng};
+
+use blake3::Hash;
+use futures::{FutureExt, SinkExt, StreamExt};
+use futures_concurrency::future::TryJoin;
+use rspc::alpha::AlphaRouter;
+use serde::Deserialize;
+use tracing::{debug, error};
+
+pub fn mount() -> AlphaRouter<Ctx> {
+    R.router()
+        .procedure("get", {
+            R.query(|node, pub_id: devices::PubId| async move {
+                use devices::get::{Request, Response};
+
+                let (client, access_token) = super::get_client_and_access_token(&node).await?;
+
+                let Response(device) = super::handle_comm_error(
+                    client
+                        .devices()
+                        .get(Request {
+                            pub_id,
+                            access_token,
+                        })
+                        .await,
+                    "Failed to get device;",
+                )??;
+
+                debug!(?device, "Got device");
+
+                Ok(device)
+            })
+        })
+        .procedure("list", {
+            R.query(|node, _: ()| async move {
+                use devices::list::{Request, Response};
+
+                let ((client, access_token), pub_id) = (
+                    super::get_client_and_access_token(&node),
+                    node.config.get().map(|config| Ok(config.id.into())),
+                )
+                    .try_join()
+                    .await?;
+
+                let Response(mut devices) = super::handle_comm_error(
+                    client.devices().list(Request { access_token }).await,
+                    "Failed to list devices;",
+                )??;
+
+                // Filter out the local device by matching pub_id
+                devices.retain(|device| device.pub_id != pub_id);
+
+                debug!(?devices, "Listed devices");
+
+                Ok(devices)
+            })
+        })
+        .procedure("get_current_device", {
+            R.query(|node, _: ()| async move {
+                use devices::get::{Request, Response};
+
+                let ((client, access_token), pub_id) = (
+                    super::get_client_and_access_token(&node),
+                    node.config.get().map(|config| Ok(config.id.into())),
+                )
+                    .try_join()
+                    .await?;
+
+                let Response(device) = super::handle_comm_error(
+                    client
+                        .devices()
+                        .get(Request {
+                            pub_id,
+                            access_token,
+                        })
+                        .await,
+                    "Failed to get current device;",
+                )??;
+
+                Ok(device)
+            })
+        })
+        .procedure("delete", {
+            R.mutation(|node, pub_id: devices::PubId| async move {
+                use devices::delete::Request;
+
+                let (client, access_token) = super::get_client_and_access_token(&node).await?;
+
+                super::handle_comm_error(
+                    client
+                        .devices()
+                        .delete(Request {
+                            pub_id,
+                            access_token,
+                        })
+                        .await,
+                    "Failed to delete device;",
+                )??;
+
+                debug!("Deleted device");
+
+                Ok(())
+            })
+        })
+        .procedure("update", {
+            #[derive(Deserialize, specta::Type)]
+            struct CloudUpdateDeviceArgs {
+                pub_id: devices::PubId,
+                name: String,
+            }
+
+            R.mutation(
+                |node, CloudUpdateDeviceArgs { pub_id, name }: CloudUpdateDeviceArgs| async move {
+                    use devices::update::Request;
+
+                    let (client, access_token) = super::get_client_and_access_token(&node).await?;
+
+                    super::handle_comm_error(
+                        client
+                            .devices()
+                            .update(Request {
+                                access_token,
+                                pub_id,
+                                name,
+                            })
+                            .await,
+                        "Failed to update device;",
+                    )??;
+
+                    debug!("Updated device");
+
+                    Ok(())
+                },
+            )
+        })
+}
+
+pub async fn hello(
+    client: &Client<QuinnConnection<Service>, Service>,
+    access_token: AccessToken,
+    device_pub_id: PubId,
+    hashed_pub_id: Hash,
+    rng: &mut CryptoRng,
+) -> Result<SecretKey, rspc::Error> {
+    use devices::hello::{Request, RequestUpdate, Response, State};
+
+    let ClientLoginStartResult { message, state } =
+        ClientLogin::<SpacedriveCipherSuite>::start(rng, hashed_pub_id.as_bytes().as_slice())
+            .map_err(|e| {
+                error!(?e, "OPAQUE error initializing device hello request;");
+                rspc::Error::new(
+                    rspc::ErrorCode::InternalServerError,
+                    "Failed to initialize device login".into(),
+                )
+            })?;
+
+    let (mut hello_continuation, mut res_stream) = super::handle_comm_error(
+        client
+            .devices()
+            .hello(Request {
+                access_token,
+                pub_id: device_pub_id,
+                opaque_login_message: Box::new(message),
+            })
+            .await,
+        "Failed to send device hello request;",
+    )?;
+
+    let Some(res) = res_stream.next().await else {
+        let message = "Server did not send a device hello response;";
+        error!("{message}");
+        return Err(rspc::Error::new(
+            rspc::ErrorCode::InternalServerError,
+            message.to_string(),
+        ));
+    };
+
+    let credential_response = match super::handle_comm_error(
+        res,
+        "Communication error on device hello response;",
+    )? {
+        Ok(Response(State::LoginResponse(credential_response))) => credential_response,
+
+        Ok(Response(State::End)) => {
+            unreachable!("Device hello response MUST not be End here, this is a serious bug and should crash;");
+        }
+
+        Err(e) => {
+            error!(?e, "Device hello response error;");
+            return Err(e.into());
+        }
+    };
+
+    let ClientLoginFinishResult {
+        message,
+        export_key,
+        ..
+    } = state
+        .finish(
+            hashed_pub_id.as_bytes().as_slice(),
+            *credential_response,
+            ClientLoginFinishParameters::default(),
+        )
+        .map_err(|e| {
+            error!(?e, "Device hello finish error;");
+            rspc::Error::new(
+                rspc::ErrorCode::InternalServerError,
+                "Failed to finish device login".into(),
+            )
+        })?;
+
+    hello_continuation
+        .send(RequestUpdate {
+            opaque_login_finish: Box::new(message),
+        })
+        .await
+        .map_err(|e| {
+            error!(?e, "Failed to send device hello request continuation;");
+            rspc::Error::new(
+                rspc::ErrorCode::InternalServerError,
+                "Failed to finish device login procedure;".into(),
+            )
+        })?;
+
+    let Some(res) = res_stream.next().await else {
+        let message = "Server did not send a device hello END response;";
+        error!("{message}");
+        return Err(rspc::Error::new(
+            rspc::ErrorCode::InternalServerError,
+            message.to_string(),
+        ));
+    };
+
+    match super::handle_comm_error(res, "Communication error on device hello response;")? {
+        Ok(Response(State::LoginResponse(_))) => {
+            unreachable!("Device hello final response MUST be End here, this is a serious bug and should crash;");
+        }
+
+        Ok(Response(State::End)) => {
+            // Protocol completed successfully
+            Ok(SecretKey::from(export_key))
+        }
+
+        Err(e) => {
+            error!(?e, "Device hello final response error;");
+            Err(e.into())
+        }
+    }
+}
+
+pub struct DeviceRegisterData {
+    pub pub_id: PubId,
+    pub name: String,
+    pub os: DeviceOS,
+    pub hardware_model: HardwareModel,
+    pub connection_id: NodeId,
+}
+
+pub async fn register(
+    client: &Client<QuinnConnection<Service>, Service>,
+    access_token: AccessToken,
+    DeviceRegisterData {
+        pub_id,
+        name,
+        os,
+        hardware_model,
+        connection_id,
+    }: DeviceRegisterData,
+    hashed_pub_id: Hash,
+    rng: &mut CryptoRng,
+) -> Result<SecretKey, rspc::Error> {
+    use devices::register::{Request, RequestUpdate, Response, State};
+
+    let ClientRegistrationStartResult { message, state } =
+        ClientRegistration::<SpacedriveCipherSuite>::start(
+            rng,
+            hashed_pub_id.as_bytes().as_slice(),
+        )
+        .map_err(|e| {
+            error!(?e, "OPAQUE error initializing device register request;");
+            rspc::Error::new(
+                rspc::ErrorCode::InternalServerError,
+                "Failed to initialize device register".into(),
+            )
+        })?;
+
+    let (mut register_continuation, mut res_stream) = super::handle_comm_error(
+        client
+            .devices()
+            .register(Request {
+                access_token,
+                pub_id,
+                name,
+                os,
+                hardware_model,
+                connection_id,
+                opaque_register_message: Box::new(message),
+            })
+            .await,
+        "Failed to send device register request;",
+    )?;
+
+    let Some(res) = res_stream.next().await else {
+        let message = "Server did not send a device register response;";
+        error!("{message}");
+        return Err(rspc::Error::new(
+            rspc::ErrorCode::InternalServerError,
+            message.to_string(),
+        ));
+    };
+
+    let registration_response = match super::handle_comm_error(
+        res,
+        "Communication error on device register response;",
+    )? {
+        Ok(Response(State::RegistrationResponse(res))) => res,
+
+        Ok(Response(State::End)) => {
+            unreachable!("Device register response MUST not be End here, this is a serious bug and should crash;");
+        }
+
+        Err(e) => {
+            error!(?e, "Device register response error;");
+            return Err(e.into());
+        }
+    };
+
+    let ClientRegistrationFinishResult {
+        message,
+        export_key,
+        ..
+    } = state
+        .finish(
+            rng,
+            hashed_pub_id.as_bytes().as_slice(),
+            *registration_response,
+            ClientRegistrationFinishParameters::default(),
+        )
+        .map_err(|e| {
+            error!(?e, "Device register finish error;");
+            rspc::Error::new(
+                rspc::ErrorCode::InternalServerError,
+                "Failed to finish device register".into(),
+            )
+        })?;
+
+    register_continuation
+        .send(RequestUpdate {
+            opaque_registration_finish: Box::new(message),
+        })
+        .await
+        .map_err(|e| {
+            error!(?e, "Failed to send device register request continuation;");
+            rspc::Error::new(
+                rspc::ErrorCode::InternalServerError,
+                "Failed to finish device register procedure;".into(),
+            )
+        })?;
+
+    let Some(res) = res_stream.next().await else {
+        let message = "Server did not send a device register END response;";
+        error!("{message}");
+        return Err(rspc::Error::new(
+            rspc::ErrorCode::InternalServerError,
+            message.to_string(),
+        ));
+    };
+
+    match super::handle_comm_error(res, "Communication error on device register response;")? {
+        Ok(Response(State::RegistrationResponse(_))) => {
+            unreachable!("Device register final response MUST be End here, this is a serious bug and should crash;");
+        }
+
+        Ok(Response(State::End)) => {
+            // Protocol completed successfully
+            Ok(SecretKey::from(export_key))
+        }
+
+        Err(e) => {
+            error!(?e, "Device register final response error;");
+            Err(e.into())
+        }
+    }
+}
diff --git a/core/src/api/cloud/libraries.rs b/core/src/api/cloud/libraries.rs
new file mode 100644
index 000000000..884e5e21b
--- /dev/null
+++ b/core/src/api/cloud/libraries.rs
@@ -0,0 +1,140 @@
+use crate::api::{utils::library, Ctx, R};
+
+use sd_cloud_schema::libraries;
+
+use futures::FutureExt;
+use futures_concurrency::future::TryJoin;
+use rspc::alpha::AlphaRouter;
+use serde::Deserialize;
+use tracing::debug;
+
+pub fn mount() -> AlphaRouter<Ctx> {
+    R.router()
+        .procedure("get", {
+            #[derive(Deserialize, specta::Type)]
+            struct CloudGetLibraryArgs {
+                pub_id: libraries::PubId,
+                with_device: bool,
+            }
+
+            R.query(
+                |node,
+                 CloudGetLibraryArgs {
+                     pub_id,
+                     with_device,
+                 }: CloudGetLibraryArgs| async move {
+                    use libraries::get::{Request, Response};
+
+                    let (client, access_token) = super::get_client_and_access_token(&node).await?;
+
+                    let Response(library) = super::handle_comm_error(
+                        client
+                            .libraries()
+                            .get(Request {
+                                access_token,
+                                pub_id,
+                                with_device,
+                            })
+                            .await,
+                        "Failed to get library;",
+                    )??;
+
+                    debug!(?library, "Got library");
+
+                    Ok(library)
+                },
+            )
+        })
+        .procedure("list", {
+            R.query(|node, with_device: bool| async move {
+                use libraries::list::{Request, Response};
+
+                let (client, access_token) = super::get_client_and_access_token(&node).await?;
+
+                let Response(libraries) = super::handle_comm_error(
+                    client
+                        .libraries()
+                        .list(Request {
+                            access_token,
+                            with_device,
+                        })
+                        .await,
+                    "Failed to list libraries;",
+                )??;
+
+                debug!(?libraries, "Listed libraries");
+
+                Ok(libraries)
+            })
+        })
+        .procedure("create", {
+            R.with2(library())
+                .mutation(|(node, library), _: ()| async move {
+                    let ((client, access_token), name, device_pub_id) = (
+                        super::get_client_and_access_token(&node),
+                        library.config().map(|config| Ok(config.name.to_string())),
+                        node.config.get().map(|config| Ok(config.id.into())),
+                    )
+                        .try_join()
+                        .await?;
+
+                    super::handle_comm_error(
+                        client
+                            .libraries()
+                            .create(libraries::create::Request {
+                                name,
+                                access_token,
+                                pub_id: libraries::PubId(library.id),
+                                device_pub_id,
+                            })
+                            .await,
+                        "Failed to create library;",
+                    )??;
+
+                    Ok(())
+                })
+        })
+        .procedure("delete", {
+            R.with2(library())
+                .mutation(|(node, library), _: ()| async move {
+                    let (client, access_token) = super::get_client_and_access_token(&node).await?;
+
+                    super::handle_comm_error(
+                        client
+                            .libraries()
+                            .delete(libraries::delete::Request {
+                                access_token,
+                                pub_id: libraries::PubId(library.id),
+                            })
+                            .await,
+                        "Failed to delete library;",
+                    )??;
+
+                    debug!("Deleted library");
+
+                    Ok(())
+                })
+        })
+        .procedure("update", {
+            R.with2(library())
+                .mutation(|(node, library), name: String| async move {
+                    let (client, access_token) = super::get_client_and_access_token(&node).await?;
+
+                    super::handle_comm_error(
+                        client
+                            .libraries()
+                            .update(libraries::update::Request {
+                                access_token,
+                                pub_id: libraries::PubId(library.id),
+                                name,
+                            })
+                            .await,
+                        "Failed to update library;",
+                    )??;
+
+                    debug!("Updated library");
+
+                    Ok(())
+                })
+        })
+}
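Like the `create` procedure above, most procedures in these routers gather their prerequisites with the tuple `TryJoin` from `futures-concurrency`: the futures are polled concurrently, results come back as a tuple, and the first `Err` short-circuits. A self-contained sketch of the pattern; the two fetch functions and the `String` error type are placeholders:

```rust
use futures_concurrency::future::TryJoin;

async fn fetch_token() -> Result<u32, String> {
    Ok(1)
}

async fn fetch_name() -> Result<&'static str, String> {
    Ok("library")
}

async fn demo() -> Result<(), String> {
    // Both futures run concurrently; an Err from either aborts the join.
    let (token, name) = (fetch_token(), fetch_name()).try_join().await?;
    assert_eq!((token, name), (1, "library"));
    Ok(())
}
```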
diff --git a/core/src/api/cloud/locations.rs b/core/src/api/cloud/locations.rs
new file mode 100644
index 000000000..e41e3a865
--- /dev/null
+++ b/core/src/api/cloud/locations.rs
@@ -0,0 +1,112 @@
+use crate::api::{Ctx, R};
+
+use sd_cloud_schema::{devices, libraries, locations};
+
+use rspc::alpha::AlphaRouter;
+use serde::Deserialize;
+use tracing::debug;
+
+pub fn mount() -> AlphaRouter<Ctx> {
+    R.router()
+        .procedure("list", {
+            #[derive(Deserialize, specta::Type)]
+            struct CloudListLocationsArgs {
+                pub library_pub_id: libraries::PubId,
+                pub with_library: bool,
+                pub with_device: bool,
+            }
+
+            R.query(
+                |node,
+                 CloudListLocationsArgs {
+                     library_pub_id,
+                     with_library,
+                     with_device,
+                 }: CloudListLocationsArgs| async move {
+                    use locations::list::{Request, Response};
+
+                    let (client, access_token) = super::get_client_and_access_token(&node).await?;
+
+                    let Response(locations) = super::handle_comm_error(
+                        client
+                            .locations()
+                            .list(Request {
+                                access_token,
+                                library_pub_id,
+                                with_library,
+                                with_device,
+                            })
+                            .await,
+                        "Failed to list locations;",
+                    )??;
+
+                    debug!(?locations, "Got locations");
+
+                    Ok(locations)
+                },
+            )
+        })
+        .procedure("create", {
+            #[derive(Deserialize, specta::Type)]
+            struct CloudCreateLocationArgs {
+                pub pub_id: locations::PubId,
+                pub name: String,
+                pub library_pub_id: libraries::PubId,
+                pub device_pub_id: devices::PubId,
+            }
+
+            R.mutation(
+                |node,
+                 CloudCreateLocationArgs {
+                     pub_id,
+                     name,
+                     library_pub_id,
+                     device_pub_id,
+                 }: CloudCreateLocationArgs| async move {
+                    use locations::create::Request;
+
+                    let (client, access_token) = super::get_client_and_access_token(&node).await?;
+
+                    super::handle_comm_error(
+                        client
+                            .locations()
+                            .create(Request {
+                                access_token,
+                                pub_id,
+                                name,
+                                library_pub_id,
+                                device_pub_id,
+                            })
+                            .await,
+                        "Failed to create location;",
+                    )??;
+
+                    debug!("Created cloud location");
+
+                    Ok(())
+                },
+            )
+        })
+        .procedure("delete", {
+            R.mutation(|node, pub_id: locations::PubId| async move {
+                use locations::delete::Request;
+
+                let (client, access_token) = super::get_client_and_access_token(&node).await?;
+
+                super::handle_comm_error(
+                    client
+                        .locations()
+                        .delete(Request {
+                            access_token,
+                            pub_id,
+                        })
+                        .await,
+                    "Failed to delete location;",
+                )??;
+
+                debug!("Deleted cloud location");
+
+                Ok(())
+            })
+        })
+}
diff --git a/core/src/api/cloud/mod.rs b/core/src/api/cloud/mod.rs
new file mode 100644
index 000000000..611fb0c31
--- /dev/null
+++ b/core/src/api/cloud/mod.rs
@@ -0,0 +1,316 @@
+use crate::{
+    library::LibraryManagerError,
+    node::{config::NodeConfig, HardwareModel},
+    Node,
+};
+
+use sd_core_cloud_services::{CloudP2P, KeyManager, QuinnConnection, UserResponse};
+
+use sd_cloud_schema::{
+    auth,
+    error::{ClientSideError, Error},
+    sync::groups,
+    users, Client, SecretKey as IrohSecretKey, Service,
+};
+use sd_crypto::{CryptoRng, SeedableRng};
+use sd_utils::error::report_error;
+
+use std::pin::pin;
+
+use async_stream::stream;
+use futures::{FutureExt, StreamExt};
+use futures_concurrency::future::TryJoin;
+use rspc::alpha::AlphaRouter;
+use tracing::{debug, error, instrument};
+
+use super::{Ctx, R};
+
+mod devices;
+mod libraries;
+mod locations;
+mod sync_groups;
+
+async fn try_get_cloud_services_client(
+    node: &Node,
+) -> Result<Client<QuinnConnection<Service>, Service>, sd_core_cloud_services::Error> {
+    node.cloud_services
+        .client()
+        .await
+        .map_err(report_error("Failed to get cloud services client"))
+}
+
+pub(crate) fn mount() -> AlphaRouter<Ctx> {
+    R.router()
+        .merge("libraries.", libraries::mount())
+        .merge("locations.", locations::mount())
+        .merge("devices.", devices::mount())
+        .merge("syncGroups.", sync_groups::mount())
+        .procedure("bootstrap", {
+            R.mutation(
+                |node, (access_token, refresh_token): (auth::AccessToken, auth::RefreshToken)| async move {
+                    use sd_cloud_schema::devices;
+
+                    // Only allow a single bootstrap request in flight at a time
+                    let mut has_bootstrapped_lock = node
+                        .cloud_services
+                        .has_bootstrapped
+                        .try_lock()
+                        .map_err(|_| {
+                            rspc::Error::new(
+                                rspc::ErrorCode::Conflict,
+                                String::from("Bootstrap in progress"),
+                            )
+                        })?;
+
+                    if *has_bootstrapped_lock {
+                        return Err(rspc::Error::new(
+                            rspc::ErrorCode::Conflict,
+                            String::from("Already bootstrapped"),
+                        ));
+                    }
+
+                    node.cloud_services
+                        .token_refresher
+                        .init(access_token, refresh_token)
+                        .await?;
+
+                    let client = try_get_cloud_services_client(&node).await?;
+                    let data_directory = node.config.data_directory();
+
+                    let mut rng =
+                        CryptoRng::from_seed(node.master_rng.lock().await.generate_fixed());
+
+                    // create user route is idempotent, so we can safely keep creating the same user over and over
+                    handle_comm_error(
+                        client
+                            .users()
+                            .create(users::create::Request {
+                                access_token: node
+                                    .cloud_services
+                                    .token_refresher
+                                    .get_access_token()
+                                    .await?,
+                            })
+                            .await,
+                        "Failed to create user;",
+                    )??;
+
+                    let (device_pub_id, name, os) = {
+                        let NodeConfig { id, name, os, .. } = node.config.get().await;
+                        (devices::PubId(id.into()), name, os)
+                    };
+
+                    let hashed_pub_id = blake3::hash(device_pub_id.0.as_bytes().as_slice());
+
+                    let key_manager = match handle_comm_error(
+                        client
+                            .devices()
+                            .get(devices::get::Request {
+                                access_token: node
+                                    .cloud_services
+                                    .token_refresher
+                                    .get_access_token()
+                                    .await?,
+                                pub_id: device_pub_id,
+                            })
+                            .await,
+                        "Failed to get device on cloud bootstrap;",
+                    )? {
+                        Ok(_) => {
+                            // Device registered, we execute a device hello flow
+                            let master_key = self::devices::hello(
+                                &client,
+                                node.cloud_services
+                                    .token_refresher
+                                    .get_access_token()
+                                    .await?,
+                                device_pub_id,
+                                hashed_pub_id,
+                                &mut rng,
+                            )
+                            .await?;
+
+                            debug!("Device hello successful");
+
+                            KeyManager::load(master_key, data_directory).await?
+                        }
+                        Err(Error::Client(ClientSideError::NotFound(_))) => {
+                            // Device not registered, we execute a device register flow
+                            let iroh_secret_key = IrohSecretKey::generate_with_rng(&mut rng);
+                            let hardware_model = Into::into(
+                                HardwareModel::try_get().unwrap_or(HardwareModel::Other),
+                            );
+
+                            let master_key = self::devices::register(
+                                &client,
+                                node.cloud_services
+                                    .token_refresher
+                                    .get_access_token()
+                                    .await?,
+                                self::devices::DeviceRegisterData {
+                                    pub_id: device_pub_id,
+                                    name,
+                                    os,
+                                    hardware_model,
+                                    connection_id: iroh_secret_key.public(),
+                                },
+                                hashed_pub_id,
+                                &mut rng,
+                            )
+                            .await?;
+
+                            debug!("Device registered successfully");
+
+                            KeyManager::new(master_key, iroh_secret_key, data_directory, &mut rng)
+                                .await?
+                        }
+                        Err(e) => return Err(e.into()),
+                    };
+
+                    let iroh_secret_key = key_manager.iroh_secret_key().await;
+
+                    node.cloud_services.set_key_manager(key_manager).await;
+
+                    node.cloud_services
+                        .set_cloud_p2p(
+                            CloudP2P::new(
+                                device_pub_id,
+                                &node.cloud_services,
+                                rng,
+                                iroh_secret_key,
+                                node.cloud_services.cloud_p2p_dns_origin_name.clone(),
+                                node.cloud_services.cloud_p2p_dns_pkarr_url.clone(),
+                                node.cloud_services.cloud_p2p_relay_url.clone(),
+                            )
+                            .await?,
+                        )
+                        .await;
+
+                    let groups::list::Response(groups) = handle_comm_error(
+                        client
+                            .sync()
+                            .groups()
+                            .list(groups::list::Request {
+                                access_token: node
+                                    .cloud_services
+                                    .token_refresher
+                                    .get_access_token()
+                                    .await?,
+                            })
+                            .await,
+                        "Failed to list sync groups on bootstrap",
+                    )??;
+
+                    groups
+                        .into_iter()
+                        .map(
+                            |groups::GroupBaseData {
+                                 pub_id,
+                                 library,
+                                 // TODO(@fogodev): We can use this latest key hash to check if we
+                                 // already have the latest key hash for this group locally,
+                                 // issuing an ask-for-key-hash request to other devices if we don't
+                                 latest_key_hash: _latest_key_hash,
+                                 ..
+                             }| {
+                                let node = &node;
+
+                                async move {
+                                    match initialize_cloud_sync(pub_id, library, node).await {
+                                        // If we don't have this library locally, we haven't joined this group yet
+                                        Ok(()) | Err(LibraryManagerError::LibraryNotFound) => {
+                                            Ok(())
+                                        }
+                                        Err(e) => Err(e),
+                                    }
+                                }
+                            },
+                        )
+                        .collect::<Vec<_>>()
+                        .try_join()
+                        .await?;
+
+                    *has_bootstrapped_lock = true;
+
+                    Ok(())
+                },
+            )
+        })
+        .procedure(
+            "listenCloudServicesNotifications",
+            R.subscription(|node, _: ()| async move {
+                stream! {
+                    let mut notifications_stream =
+                        pin!(node.cloud_services.stream_user_notifications());
+
+                    while let Some(notification) = notifications_stream.next().await {
+                        yield notification;
+                    }
+                }
+            }),
+        )
+        .procedure(
+            "userResponse",
+            R.mutation(|node, response: UserResponse| async move {
+                node.cloud_services.send_user_response(response).await;
+
+                Ok(())
+            }),
+        )
+        .procedure(
+            "hasBootstrapped",
+            R.query(|node, _: ()| async move {
+                // If we can't lock immediately, it means that there is a bootstrap in progress,
+                // so we haven't bootstrapped yet
+                Ok(node
+                    .cloud_services
+                    .has_bootstrapped
+                    .try_lock()
+                    .map(|lock| *lock)
+                    .unwrap_or(false))
+            }),
+        )
+}
+
+fn handle_comm_error<T, E: std::error::Error + Send + Sync + 'static>(
+    res: Result<T, E>,
+    message: &'static str,
+) -> Result<T, rspc::Error> {
+    res.map_err(|e| {
+        error!(?e, "Communication with cloud services error: {message}");
+        rspc::Error::with_cause(rspc::ErrorCode::InternalServerError, message.into(), e)
+    })
+}
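This helper explains the `??` seen at every call site: quic-rpc hands back a nested `Result`, where the outer layer is transport failure (mapped to `rspc::Error` by `handle_comm_error`) and the inner layer is the service's own error, unwrapped by the second `?`. A sketch of that shape under the signature as reconstructed above; the `std::io::Error` transport error and the literal value are stand-ins:

```rust
fn demo() -> Result<u32, rspc::Error> {
    // Outer Result: did the RPC round trip succeed? Inner Result: what did
    // the service answer? Each layer needs its own `?`.
    let res: Result<Result<u32, rspc::Error>, std::io::Error> = Ok(Ok(42));

    let value = handle_comm_error(res, "Failed to reach service;")??;
    Ok(value)
}
```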
+ }: sd_cloud_schema::libraries::Library, + node: &Node, +) -> Result<(), LibraryManagerError> { + let library = node + .libraries + .get_library(&library_pub_id) + .await + .ok_or(LibraryManagerError::LibraryNotFound)?; + + library.init_cloud_sync(node, group_pub_id).await +} + +async fn get_client_and_access_token( + node: &Node, +) -> Result<(Client<QuinnConnection<Service>, Service>, auth::AccessToken), rspc::Error> { + ( + try_get_cloud_services_client(node), + node.cloud_services + .token_refresher + .get_access_token() + .map(|res| res.map_err(Into::into)), + ) + .try_join() + .await + .map_err(Into::into) +} diff --git a/core/src/api/cloud/sync_groups.rs b/core/src/api/cloud/sync_groups.rs new file mode 100644 index 000000000..6095b01bb --- /dev/null +++ b/core/src/api/cloud/sync_groups.rs @@ -0,0 +1,404 @@ +use crate::{ + api::{utils::library, Ctx, R}, + library::LibraryName, + Node, +}; + +use sd_core_cloud_services::JoinedLibraryCreateArgs; + +use sd_cloud_schema::{ + cloud_p2p, devices, libraries, + sync::{groups, KeyHash}, +}; +use sd_crypto::{cloud::secret_key::SecretKey, CryptoRng, SeedableRng}; + +use std::sync::Arc; + +use futures::FutureExt; +use futures_concurrency::future::TryJoin; +use rspc::alpha::AlphaRouter; +use serde::{Deserialize, Serialize}; +use tokio::{spawn, sync::oneshot}; +use tracing::{debug, error}; + +pub fn mount() -> AlphaRouter<Ctx> { + R.router() + .procedure("create", { + R.with2(library()) + .mutation(|(node, library), _: ()| async move { + use groups::create::{Request, Response}; + + let ((client, access_token), device_pub_id, mut rng, key_manager) = ( + super::get_client_and_access_token(&node), + node.config.get().map(|config| Ok(config.id.into())), + node.master_rng + .lock() + .map(|mut rng| Ok(CryptoRng::from_seed(rng.generate_fixed()))), + node.cloud_services + .key_manager() + .map(|res| res.map_err(Into::into)), + ) + .try_join() + .await?; + + let new_key = SecretKey::generate(&mut rng); + let key_hash = KeyHash(blake3::hash(new_key.as_ref()).to_hex().to_string()); + + let Response(group_pub_id) = super::handle_comm_error( + client + .sync() + .groups() + .create(Request { + access_token: access_token.clone(), + key_hash: key_hash.clone(), + library_pub_id: libraries::PubId(library.id), + device_pub_id, + }) + .await, + "Failed to create sync group;", + )??; + + if let Err(e) = key_manager + .add_key_with_hash(group_pub_id, new_key, key_hash.clone(), &mut rng) + .await + { + super::handle_comm_error( + client + .sync() + .groups() + .delete(groups::delete::Request { + access_token, + pub_id: group_pub_id, + }) + .await, + "Failed to delete sync group after we failed to store secret key in key manager;", + )??; + + return Err(e.into()); + } + + library.init_cloud_sync(&node, group_pub_id).await?; + + debug!(%group_pub_id, ?key_hash, "Created sync group"); + + Ok(()) + }) + }) + .procedure("delete", { + R.mutation(|node, pub_id: groups::PubId| async move { + use groups::delete::Request; + + let (client, access_token) = super::get_client_and_access_token(&node).await?; + + super::handle_comm_error( + client + .sync() + .groups() + .delete(Request { + access_token, + pub_id, + }) + .await, + "Failed to delete sync group;", + )??; + + debug!(%pub_id, "Deleted sync group"); + + Ok(()) + }) + }) + .procedure("get", { + #[derive(Deserialize, specta::Type)] + struct CloudGetSyncGroupArgs { + pub pub_id: groups::PubId, + pub kind: groups::get::RequestKind, + } + + // This is a compatibility layer because quic-rpc uses bincode for serialization + // and bincode doesn't support
serde's tagged enums, and we need them for serializing + // to frontend + #[derive(Debug, Serialize, specta::Type)] + #[serde(tag = "kind", content = "data")] + pub enum CloudSyncGroupGetResponseKind { + WithDevices(groups::GroupWithDevices), + FullData(groups::Group), + } + + impl From<groups::get::ResponseKind> for CloudSyncGroupGetResponseKind { + fn from(kind: groups::get::ResponseKind) -> Self { + match kind { + groups::get::ResponseKind::WithDevices(data) => { + CloudSyncGroupGetResponseKind::WithDevices(data) + } + + groups::get::ResponseKind::FullData(data) => { + CloudSyncGroupGetResponseKind::FullData(data) + } + groups::get::ResponseKind::DevicesConnectionIds(_) => { + unreachable!( + "DevicesConnectionIds response is not expected, as we never request it here" + ); + } + } + } + } + + R.query( + |node, CloudGetSyncGroupArgs { pub_id, kind }: CloudGetSyncGroupArgs| async move { + use groups::get::{Request, Response}; + + let (client, access_token) = super::get_client_and_access_token(&node).await?; + + if matches!(kind, groups::get::RequestKind::DevicesConnectionIds) { + return Err(rspc::Error::new( + rspc::ErrorCode::PreconditionFailed, + "This request isn't allowed here".into(), + )); + } + + let Response(response_kind) = super::handle_comm_error( + client + .sync() + .groups() + .get(Request { + access_token, + pub_id, + kind, + }) + .await, + "Failed to get sync group;", + )??; + + debug!(?response_kind, "Got sync group"); + + Ok(CloudSyncGroupGetResponseKind::from(response_kind)) + }, + ) + }) + .procedure("leave", { + R.query(|node, pub_id: groups::PubId| async move { + let ((client, access_token), current_device_pub_id, mut rng, key_manager) = ( + super::get_client_and_access_token(&node), + node.config.get().map(|config| Ok(config.id.into())), + node.master_rng + .lock() + .map(|mut rng| Ok(CryptoRng::from_seed(rng.generate_fixed()))), + node.cloud_services + .key_manager() + .map(|res| res.map_err(Into::into)), + ) + .try_join() + .await?; + + super::handle_comm_error( + client + .sync() + .groups() + .leave(groups::leave::Request { + access_token, + pub_id, + current_device_pub_id, + }) + .await, + "Failed to leave sync group;", + )??; + + key_manager.remove_group(pub_id, &mut rng).await?; + + debug!(%pub_id, "Left sync group"); + + Ok(()) + }) + }) + .procedure("list", { + R.query(|node, _: ()| async move { + use groups::list::{Request, Response}; + + let (client, access_token) = super::get_client_and_access_token(&node).await?; + + let Response(groups) = super::handle_comm_error( + client.sync().groups().list(Request { access_token }).await, + "Failed to list groups;", + )??; + + debug!(?groups, "Listed sync groups"); + + Ok(groups) + }) + }) + .procedure("remove_device", { + #[derive(Deserialize, specta::Type)] + struct CloudSyncGroupsRemoveDeviceArgs { + group_pub_id: groups::PubId, + to_remove_device_pub_id: devices::PubId, + } + R.query( + |node, + CloudSyncGroupsRemoveDeviceArgs { + group_pub_id, + to_remove_device_pub_id, + }: CloudSyncGroupsRemoveDeviceArgs| async move { + use groups::remove_device::Request; + + let ((client, access_token), current_device_pub_id, mut rng, key_manager) = ( + super::get_client_and_access_token(&node), + node.config.get().map(|config| Ok(config.id.into())), + node.master_rng + .lock() + .map(|mut rng| Ok(CryptoRng::from_seed(rng.generate_fixed()))), + node.cloud_services + .key_manager() + .map(|res| res.map_err(Into::into)), + ) + .try_join() + .await?; + + let new_key = SecretKey::generate(&mut rng); + let new_key_hash =
KeyHash(blake3::hash(new_key.as_ref()).to_hex().to_string()); + + key_manager + .add_key_with_hash(group_pub_id, new_key, new_key_hash.clone(), &mut rng) + .await?; + + super::handle_comm_error( + client + .sync() + .groups() + .remove_device(Request { + access_token, + group_pub_id, + new_key_hash, + current_device_pub_id, + to_remove_device_pub_id, + }) + .await, + "Failed to remove device from sync group;", + )??; + + debug!(%to_remove_device_pub_id, %group_pub_id, "Removed device"); + + Ok(()) + }, + ) + }) + .procedure("request_join", { + #[derive(Deserialize, specta::Type)] + struct SyncGroupsRequestJoinArgs { + sync_group: groups::GroupWithDevices, + asking_device: devices::Device, + } + + R.mutation( + |node, + SyncGroupsRequestJoinArgs { + sync_group, + asking_device, + }: SyncGroupsRequestJoinArgs| async move { + let ((client, access_token), current_device_pub_id, cloud_p2p) = ( + super::get_client_and_access_token(&node), + node.config.get().map(|config| Ok(config.id.into())), + node.cloud_services + .cloud_p2p() + .map(|res| res.map_err(Into::into)), + ) + .try_join() + .await?; + + let group_pub_id = sync_group.pub_id; + + debug!("My pub id: {:?}", current_device_pub_id); + debug!("Asking device pub id: {:?}", asking_device.pub_id); + if asking_device.pub_id != current_device_pub_id { + return Err(rspc::Error::new( + rspc::ErrorCode::BadRequest, + String::from("Asking device must be the current device"), + )); + } + + let groups::request_join::Response(existing_devices) = + super::handle_comm_error( + client + .sync() + .groups() + .request_join(groups::request_join::Request { + access_token, + group_pub_id, + current_device_pub_id, + }) + .await, + "Failed to request to join sync group;", + )??; + + let (tx, rx) = oneshot::channel(); + + cloud_p2p + .request_join_sync_group( + existing_devices, + cloud_p2p::authorize_new_device_in_sync_group::Request { + sync_group, + asking_device, + }, + tx, + ) + .await; + + JoinedSyncGroupReceiver { + node, + group_pub_id, + rx, + } + .dispatch(); + + debug!(%group_pub_id, "Requested to join sync group"); + + Ok(()) + }, + ) + }) +} + +struct JoinedSyncGroupReceiver { + node: Arc<Node>, + group_pub_id: groups::PubId, + rx: oneshot::Receiver<JoinedLibraryCreateArgs>, +} + +impl JoinedSyncGroupReceiver { + fn dispatch(self) { + spawn(async move { + let Self { + node, + group_pub_id, + rx, + } = self; + + if let Ok(JoinedLibraryCreateArgs { + pub_id: libraries::PubId(pub_id), + name, + description, + }) = rx.await + { + let Ok(name) = + LibraryName::new(name).map_err(|e| error!(?e, "Invalid library name")) + else { + return; + }; + + let Ok(library) = node + .libraries + .create_with_uuid(pub_id, name, description, true, None, &node) + .await + .map_err(|e| { + error!(?e, "Failed to create library from sync group join response") + }) + else { + return; + }; + + if let Err(e) = library.init_cloud_sync(&node, group_pub_id).await { + error!(?e, "Failed to initialize cloud sync for library"); + } + } + }); + } +} diff --git a/core/src/api/ephemeral_files.rs b/core/src/api/ephemeral_files.rs index c2cc85a52..3acd7c573 100644 --- a/core/src/api/ephemeral_files.rs +++ b/core/src/api/ephemeral_files.rs @@ -23,7 +23,6 @@ use sd_utils::error::FileIOError; use std::{ffi::OsStr, path::PathBuf, str::FromStr}; -use async_recursion::async_recursion; use futures_concurrency::future::TryJoin; use regex::Regex; use rspc::{alpha::AlphaRouter, ErrorCode}; @@ -481,7 +480,6 @@ impl EphemeralFileSystemOps { Ok(()) } - #[async_recursion] async fn copy(self, library: &Library) -> Result<(), rspc::Error> {
self.check().await?; @@ -584,11 +582,13 @@ impl EphemeralFileSystemOps { .await?; if !more_files.is_empty() { - Self { - sources: more_files, - target_dir: target, - } - .copy(library) + Box::pin( + Self { + sources: more_files, + target_dir: target, + } + .copy(library), + ) .await } else { Ok(()) diff --git a/core/src/api/files.rs b/core/src/api/files.rs index 9a512fddc..155fd2884 100644 --- a/core/src/api/files.rs +++ b/core/src/api/files.rs @@ -28,8 +28,8 @@ use sd_prisma::{ prisma::{file_path, location, object}, prisma_sync, }; -use sd_sync::OperationFactory; -use sd_utils::{db::maybe_missing, error::FileIOError, msgpack}; +use sd_sync::{sync_db_entry, sync_db_nullable_entry, sync_entry, OperationFactory}; +use sd_utils::{db::maybe_missing, error::FileIOError}; use std::{ ffi::OsString, @@ -195,19 +195,19 @@ pub(crate) fn mount() -> AlphaRouter { ) })?; + let (sync_param, db_param) = sync_db_nullable_entry!(args.note, object::note); + sync.write_op( db, sync.shared_update( prisma_sync::object::SyncId { pub_id: object.pub_id, }, - object::note::NAME, - msgpack!(&args.note), - ), - db.object().update( - object::id::equals(args.id), - vec![object::note::set(args.note)], + [sync_param], ), + db.object() + .update(object::id::equals(args.id), vec![db_param]) + .select(object::select!({ id })), ) .await?; @@ -241,19 +241,19 @@ pub(crate) fn mount() -> AlphaRouter { ) })?; + let (sync_param, db_param) = sync_db_entry!(args.favorite, object::favorite); + sync.write_op( db, sync.shared_update( prisma_sync::object::SyncId { pub_id: object.pub_id, }, - object::favorite::NAME, - msgpack!(&args.favorite), - ), - db.object().update( - object::id::equals(args.id), - vec![object::favorite::set(Some(args.favorite))], + [sync_param], ), + db.object() + .update(object::id::equals(args.id), vec![db_param]) + .select(object::select!({ id })), ) .await?; @@ -346,34 +346,38 @@ pub(crate) fn mount() -> AlphaRouter { let date_accessed = Utc::now().into(); - let (sync_params, db_params): (Vec<_>, Vec<_>) = objects + let (ops, object_ids) = objects .into_iter() - .map(|d| { + .map(|object| { ( sync.shared_update( - prisma_sync::object::SyncId { pub_id: d.pub_id }, - object::date_accessed::NAME, - msgpack!(date_accessed), + prisma_sync::object::SyncId { + pub_id: object.pub_id, + }, + [sync_entry!(date_accessed, object::date_accessed)], ), - d.id, + object.id, ) }) - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); - sync.write_ops( - db, - ( - sync_params, - db.object().update_many( - vec![object::id::in_vec(db_params)], - vec![object::date_accessed::set(Some(date_accessed))], + if !ops.is_empty() && !object_ids.is_empty() { + sync.write_ops( + db, + ( + ops, + db.object().update_many( + vec![object::id::in_vec(object_ids)], + vec![object::date_accessed::set(Some(date_accessed))], + ), ), - ), - ) - .await?; + ) + .await?; + + invalidate_query!(library, "search.paths"); + invalidate_query!(library, "search.objects"); + } - invalidate_query!(library, "search.paths"); - invalidate_query!(library, "search.objects"); Ok(()) }) }) @@ -389,33 +393,38 @@ pub(crate) fn mount() -> AlphaRouter { .exec() .await?; - let (sync_params, db_params): (Vec<_>, Vec<_>) = objects + let (ops, object_ids) = objects .into_iter() - .map(|d| { + .map(|object| { ( sync.shared_update( - prisma_sync::object::SyncId { pub_id: d.pub_id }, - object::date_accessed::NAME, - msgpack!(nil), + prisma_sync::object::SyncId { + pub_id: object.pub_id, + }, + [sync_entry!(nil, object::date_accessed)], ), - d.id, + object.id, ) }) - .unzip(); - 
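The ephemeral-file change above swaps the `async_recursion` proc macro for a manual `Box::pin` around the recursive `copy` call. A recursive `async fn` would otherwise have an infinitely sized future, so one level of boxing is needed somewhere; doing it by hand drops a dependency. A minimal sketch of the same pattern (hypothetical function, not the Spacedrive code):

```rust
// Recursive async fn: without boxing, the compiler rejects this because the
// returned future would have to contain itself.
async fn walk(depth: u32) -> Result<(), std::io::Error> {
    if depth == 0 {
        return Ok(());
    }
    // Box::pin heap-allocates the nested future, breaking the infinite size --
    // the same fix the diff applies in EphemeralFileSystemOps::copy.
    Box::pin(walk(depth - 1)).await
}
```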
sync.write_ops( - db, - ( - sync_params, - db.object().update_many( - vec![object::id::in_vec(db_params)], - vec![object::date_accessed::set(None)], - ), - ), - ) - .await?; + .unzip::<_, _, Vec<_>, Vec<_>>(); + + if !ops.is_empty() && !object_ids.is_empty() { + sync.write_ops( + db, + ( + ops, + db.object().update_many( + vec![object::id::in_vec(object_ids)], + vec![object::date_accessed::set(None)], + ), + ), + ) + .await?; + + invalidate_query!(library, "search.objects"); + invalidate_query!(library, "search.paths"); + } - invalidate_query!(library, "search.objects"); - invalidate_query!(library, "search.paths"); Ok(()) }) }) @@ -480,11 +489,32 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> { path = %full_path.display(), "File not found in the file system, will remove from database;", ); - library + + let file_path_pub_id = library .db .file_path() - .delete(file_path::id::equals(args.file_path_ids[0])) + .find_unique(file_path::id::equals(args.file_path_ids[0])) + .select(file_path::select!({ pub_id })) .exec() + .await? + .ok_or(LocationError::FilePath(FilePathError::IdNotFound( + args.file_path_ids[0], + )))? + .pub_id; + + library + .sync + .write_op( + &library.db, + library.sync.shared_delete( + prisma_sync::file_path::SyncId { + pub_id: file_path_pub_id, + }, + ), + library.db.file_path().delete(file_path::id::equals( + args.file_path_ids[0], + )), + ) .await .map_err(LocationError::from)?; diff --git a/core/src/api/labels.rs b/core/src/api/labels.rs index 9aaaf30e3..eed08d8d3 100644 --- a/core/src/api/labels.rs +++ b/core/src/api/labels.rs @@ -116,7 +116,7 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> { .procedure( "delete", R.with2(library()) - .mutation(|(_, library), label_id: i32| async move { + .mutation(|(_, library), label_id: label::id::Type| async move { let Library { db, sync, .. } = library.as_ref(); let label = db @@ -131,6 +131,35 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> { ) })?; + let delete_ops = db + .label_on_object() + .find_many(vec![label_on_object::label_id::equals(label_id)]) + .select(label_on_object::select!({ object: select { pub_id } })) + .exec() + .await? + .into_iter() + .map(|label_on_object| { + sync.relation_delete(prisma_sync::label_on_object::SyncId { + label: prisma_sync::label::SyncId { + name: label.name.clone(), + }, + object: prisma_sync::object::SyncId { + pub_id: label_on_object.object.pub_id, + }, + }) + }) + .collect::<Vec<_>>(); + + sync.write_ops( + db, + ( + delete_ops, + db.label_on_object() + .delete_many(vec![label_on_object::label_id::equals(label_id)]), + ), + ) + .await?; + sync.write_op( db, sync.shared_delete(prisma_sync::label::SyncId { name: label.name }), diff --git a/core/src/api/libraries.rs b/core/src/api/libraries.rs index 32e5d36da..d7e5cfe6c 100644 --- a/core/src/api/libraries.rs +++ b/core/src/api/libraries.rs @@ -471,37 +471,19 @@ .procedure( "actors", R.with2(library()).subscription(|(_, library), _: ()| { - let mut rx = library.actors.invalidate_rx.resubscribe(); + let mut rx = library.cloud_sync_actors.invalidate_rx.resubscribe(); async_stream::stream!
{ - let actors = library.actors.get_state().await; + let actors = library.cloud_sync_actors.get_state().await; yield actors; while let Ok(()) = rx.recv().await { - let actors = library.actors.get_state().await; + let actors = library.cloud_sync_actors.get_state().await; yield actors; } } }), ) - .procedure( - "startActor", - R.with2(library()) - .mutation(|(_, library), name: String| async move { - library.actors.start(&name).await; - - Ok(()) - }), - ) - .procedure( - "stopActor", - R.with2(library()) - .mutation(|(_, library), name: String| async move { - library.actors.stop(&name).await; - - Ok(()) - }), - ) .procedure( "vacuumDb", R.with2(library()) diff --git a/core/src/api/mod.rs b/core/src/api/mod.rs index 8237afe16..7a1dd1597 100644 --- a/core/src/api/mod.rs +++ b/core/src/api/mod.rs @@ -3,25 +3,27 @@ use crate::{ library::LibraryId, node::{ config::{is_in_docker, NodeConfig, NodeConfigP2P, NodePreferences}, - get_hardware_model_name, HardwareModel, + HardwareModel, }, old_job::JobProgressEvent, Node, }; use sd_core_heavy_lifting::media_processor::ThumbKey; +use sd_core_sync::DevicePubId; + +use sd_cloud_schema::devices::DeviceOS; use sd_p2p::RemoteIdentity; use sd_prisma::prisma::file_path; -use std::sync::{atomic::Ordering, Arc}; +use std::sync::Arc; use itertools::Itertools; use rspc::{alpha::Rspc, Config, ErrorCode}; use serde::{Deserialize, Serialize}; use specta::Type; -use uuid::Uuid; +use tracing::warn; -mod auth; mod backups; mod cloud; mod ephemeral_files; @@ -70,35 +72,34 @@ pub enum CoreEvent { /// If you want a variant of this to show up on the frontend it must be added to `backendFeatures` in `useFeatureFlag.tsx` #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, Type)] #[serde(rename_all = "camelCase")] -pub enum BackendFeature { - CloudSync, -} +pub enum BackendFeature {} -impl BackendFeature { - pub fn restore(&self, node: &Node) { - match self { - BackendFeature::CloudSync => { - node.cloud_sync_flag.store(true, Ordering::Relaxed); - } - } - } -} +// impl BackendFeature { +// pub fn restore(&self, node: &Node) { +// match self { +// BackendFeature::CloudSync => { +// node.cloud_sync_flag.store(true, Ordering::Relaxed); +// } +// } +// } +// } -// A version of [NodeConfig] that is safe to share with the frontend +/// A version of [`NodeConfig`] that is safe to share with the frontend #[derive(Debug, Serialize, Deserialize, Clone, Type)] -pub struct SanitisedNodeConfig { +pub struct SanitizedNodeConfig { /// id is a unique identifier for the current node. Each node has a public identifier (this one) and is given a local id for each library (done within the library code). - pub id: Uuid, + pub id: DevicePubId, /// name is the display name of the current node. This is set by the user and is shown in the UI. 
// TODO: Length validation so it can fit in DNS record pub name: String, pub identity: RemoteIdentity, pub p2p: NodeConfigP2P, pub features: Vec<BackendFeature>, pub preferences: NodePreferences, - pub image_labeler_version: Option<String>, + pub os: DeviceOS, + pub hardware_model: HardwareModel, } -impl From<NodeConfig> for SanitisedNodeConfig { +impl From<NodeConfig> for SanitizedNodeConfig { fn from(value: NodeConfig) -> Self { Self { id: value.id, @@ -107,7 +108,8 @@ impl From<NodeConfig> for SanitisedNodeConfig { p2p: value.p2p, features: value.features, preferences: value.preferences, - image_labeler_version: value.image_labeler_version, + os: value.os, + hardware_model: value.hardware_model, } } } @@ -115,7 +117,7 @@ impl From<NodeConfig> for SanitisedNodeConfig { #[derive(Serialize, Debug, Type)] struct NodeState { #[serde(flatten)] - config: SanitisedNodeConfig, + config: SanitizedNodeConfig, data_path: String, device_model: Option<String>, is_in_docker: bool, @@ -140,12 +142,11 @@ pub(crate) fn mount() -> Arc<Router> { }) .procedure("nodeState", { R.query(|node, _: ()| async move { - let device_model = get_hardware_model_name() - .unwrap_or(HardwareModel::Other) - .to_string(); + let config = SanitizedNodeConfig::from(node.config.get().await); Ok(NodeState { - config: node.config.get().await.into(), + device_model: Some(config.hardware_model.to_string()), + config, // We are taking the assumption here that this value is only used on the frontend for display purposes data_path: node .config @@ -153,7 +154,6 @@ pub(crate) fn mount() -> Arc<Router> { .to_str() .expect("Found non-UTF-8 path") .to_string(), - device_model: Some(device_model), is_in_docker: is_in_docker(), }) }) @@ -179,11 +179,13 @@ pub(crate) fn mount() -> Arc<Router> { } .map_err(|e| rspc::Error::new(ErrorCode::InternalServerError, e.to_string()))?; - match feature { - BackendFeature::CloudSync => { - node.cloud_sync_flag.store(enabled, Ordering::Relaxed); - } - } + warn!("Feature {:?} is now {}", feature, enabled); + + // match feature { + // BackendFeature::CloudSync => { + // node.cloud_sync_flag.store(enabled, Ordering::Relaxed); + // } + // } invalidate_query!(node; node, "nodeState"); @@ -191,7 +193,6 @@ pub(crate) fn mount() -> Arc<Router> { }) }) .merge("api.", web_api::mount()) - .merge("auth.", auth::mount()) .merge("cloud.", cloud::mount()) .merge("search.", search::mount()) .merge("library.", libraries::mount()) diff --git a/core/src/api/nodes.rs b/core/src/api/nodes.rs index 0f422b593..9083d4e70 100644 --- a/core/src/api/nodes.rs +++ b/core/src/api/nodes.rs @@ -5,9 +5,10 @@ use crate::{ node::config::{P2PDiscoveryState, Port}, }; -use sd_prisma::prisma::{instance, location}; +use sd_prisma::prisma::{device, location}; use rspc::{alpha::AlphaRouter, ErrorCode}; +use sd_utils::uuid_to_bytes; use serde::Deserialize; use specta::Type; use tracing::error; @@ -28,8 +29,6 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> { pub p2p_discovery: Option<P2PDiscoveryState>, pub p2p_remote_access: Option<bool>, pub p2p_manual_peers: Option<Vec<String>>, - #[cfg(feature = "ai")] - pub image_labeler_version: Option<String>, } R.mutation(|node, args: ChangeNodeNameArgs| async move { if let Some(name) = &args.name { @@ -41,9 +40,6 @@ } } - #[cfg(feature = "ai")] - let mut new_model = None; - node.config .write(|config| { if let Some(name) = args.name { @@ -71,29 +67,6 @@ if let Some(manual_peers) = args.p2p_manual_peers { config.p2p.manual_peers = manual_peers; }; - - #[cfg(feature = "ai")] - if let Some(version) = args.image_labeler_version { - if config - .image_labeler_version - .as_ref() -
.map(|node_version| version != *node_version) - .unwrap_or(true) - { - new_model = sd_ai::old_image_labeler::YoloV8::model(Some(&version)) - .map_err(|e| { - error!( - %version, - ?e, - "Failed to crate image_detection model;", - ); - }) - .ok(); - if new_model.is_some() { - config.image_labeler_version = Some(version); - } - } - } }) .await .map_err(|e| { @@ -109,44 +82,6 @@ pub(crate) fn mount() -> AlphaRouter { invalidate_query!(node; node, "nodeState"); - #[cfg(feature = "ai")] - { - use super::notifications::{NotificationData, NotificationKind}; - - if let Some(model) = new_model { - let version = model.version().to_string(); - tokio::spawn(async move { - let notification = if let Some(image_labeller) = - node.old_image_labeller.as_ref() - { - if let Err(e) = image_labeller.change_model(model).await { - NotificationData { - title: String::from( - "Failed to change image detection model", - ), - content: format!("Error: {e}"), - kind: NotificationKind::Error, - } - } else { - NotificationData { - title: String::from("Model download completed"), - content: format!("Successfully loaded model: {version}"), - kind: NotificationKind::Success, - } - } - } else { - NotificationData { - title: String::from("Failed to change image detection model"), - content: "The AI system is disabled due to a previous error. Contact support for help.".to_string(), - kind: NotificationKind::Success, - } - }; - - node.emit_notification(notification, None).await; - }); - } - } - Ok(()) }) }) @@ -154,27 +89,18 @@ pub(crate) fn mount() -> AlphaRouter { .procedure("listLocations", { R.with2(library()) // TODO: I don't like this. `node_id` should probs be a machine hash or something cause `node_id` is dynamic in the context of P2P and what does it mean for removable media to be owned by a node? - .query(|(_, library), node_id: Option| async move { - // Be aware multiple instances can exist on a single node. This is generally an edge case but it's possible. - let instances = library - .db - .instance() - .find_many(vec![node_id - .map(|id| instance::node_id::equals(id.as_bytes().to_vec())) - .unwrap_or(instance::id::equals( - library.config().await.instance_id, - ))]) - .exec() - .await?; - + .query(|(_, library), device_pub_id: Option| async move { Ok(library .db .location() .find_many( - instances - .into_iter() - .map(|i| location::instance_id::equals(Some(i.id))) - .collect(), + device_pub_id + .map(|id| { + vec![location::device::is(vec![device::pub_id::equals( + uuid_to_bytes(&id), + )])] + }) + .unwrap_or_default(), ) .exec() .await? diff --git a/core/src/api/search/saved.rs b/core/src/api/search/saved.rs index e2e797765..957474c49 100644 --- a/core/src/api/search/saved.rs +++ b/core/src/api/search/saved.rs @@ -66,10 +66,10 @@ pub(crate) fn mount() -> AlphaRouter { |(_, library), args: Args| async move { let Library { db, sync, .. 
} = library.as_ref(); - let pub_id = Uuid::new_v4().as_bytes().to_vec(); + let pub_id = Uuid::now_v7().as_bytes().to_vec(); let date_created: DateTime = Utc::now().into(); - let (sync_params, db_params): (Vec<_>, Vec<_>) = chain_optional_iter( + let (sync_params, db_params) = chain_optional_iter( [ sync_db_entry!(date_created, saved_search::date_created), sync_db_entry!(args.name, saved_search::name), @@ -96,19 +96,19 @@ pub(crate) fn mount() -> AlphaRouter { ], ) .into_iter() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); - sync.write_ops( + sync.write_op( db, - ( - sync.shared_create( - prisma_sync::saved_search::SyncId { - pub_id: pub_id.clone(), - }, - sync_params, - ), - db.saved_search().create(pub_id, db_params), + sync.shared_create( + prisma_sync::saved_search::SyncId { + pub_id: pub_id.clone(), + }, + sync_params, ), + db.saved_search() + .create(pub_id, db_params) + .select(saved_search::select!({ id })), ) .await?; @@ -164,7 +164,7 @@ pub(crate) fn mount() -> AlphaRouter { rspc::Error::new(rspc::ErrorCode::NotFound, "search not found".into()) })?; - let (sync_params, db_params): (Vec<_>, Vec<_>) = chain_optional_iter( + let (sync_params, db_params) = chain_optional_iter( [sync_db_entry!(updated_at, saved_search::date_modified)], [ option_sync_db_entry!(args.name.flatten(), saved_search::name), @@ -175,27 +175,18 @@ pub(crate) fn mount() -> AlphaRouter { ], ) .into_iter() - .map(|((k, v), p)| { - ( - sync.shared_update( - prisma_sync::saved_search::SyncId { - pub_id: search.pub_id.clone(), - }, - k, - v, - ), - p, - ) - }) - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); - sync.write_ops( + sync.write_op( db, - ( + sync.shared_update( + prisma_sync::saved_search::SyncId { + pub_id: search.pub_id.clone(), + }, sync_params, - db.saved_search() - .update_unchecked(saved_search::id::equals(id), db_params), ), + db.saved_search() + .update_unchecked(saved_search::id::equals(id), db_params), ) .await?; diff --git a/core/src/api/sync.rs b/core/src/api/sync.rs index 50935a249..5880f8fa7 100644 --- a/core/src/api/sync.rs +++ b/core/src/api/sync.rs @@ -1,5 +1,4 @@ use rspc::alpha::AlphaRouter; -use sd_core_sync::GetOpsArgs; use std::sync::atomic::Ordering; use crate::util::MaybeUndefined; @@ -8,32 +7,6 @@ use super::{utils::library, Ctx, R}; pub(crate) fn mount() -> AlphaRouter { R.router() - .procedure("newMessage", { - R.with2(library()) - .subscription(|(_, library), _: ()| async move { - async_stream::stream! { - let mut rx = library.sync.subscribe(); - while let Ok(_msg) = rx.recv().await { - // let op = match msg { - // SyncMessage::Ingested => (), - // SyncMessage::Created => op - // }; - yield (); - } - } - }) - }) - .procedure("messages", { - R.with2(library()).query(|(_, library), _: ()| async move { - Ok(library - .sync - .get_ops(GetOpsArgs { - clocks: vec![], - count: 1000, - }) - .await?) - }) - }) .procedure("backfill", { R.with2(library()) .mutation(|(node, library), _: ()| async move { @@ -46,12 +19,7 @@ pub(crate) fn mount() -> AlphaRouter { return Ok(()); } - sd_core_sync::backfill::backfill_operations( - &library.db, - &library.sync, - library.config().await.instance_id, - ) - .await?; + sd_core_sync::backfill::backfill_operations(&library.sync).await?; node.libraries .edit( @@ -88,19 +56,19 @@ pub(crate) fn mount() -> AlphaRouter { } async_stream::stream! 
{ - let cloud_sync = &library.cloud.sync; - let sync = &library.sync.shared; + let cloud_sync_state = &library.cloud_sync_state; + let sync = &library.sync; loop { yield Data { ingest: sync.active.load(Ordering::Relaxed), - cloud_send: cloud_sync.send_active.load(Ordering::Relaxed), - cloud_receive: cloud_sync.receive_active.load(Ordering::Relaxed), - cloud_ingest: cloud_sync.ingest_active.load(Ordering::Relaxed), + cloud_send: cloud_sync_state.send_active.load(Ordering::Relaxed), + cloud_receive: cloud_sync_state.receive_active.load(Ordering::Relaxed), + cloud_ingest: cloud_sync_state.ingest_active.load(Ordering::Relaxed), }; tokio::select! { - _ = cloud_sync.notifier.notified() => {}, + _ = cloud_sync_state.state_change_notifier.notified() => {}, _ = sync.active_notify.notified() => {} } } diff --git a/core/src/api/tags.rs b/core/src/api/tags.rs index b951368f2..0035ea592 100644 --- a/core/src/api/tags.rs +++ b/core/src/api/tags.rs @@ -1,11 +1,10 @@ use crate::{invalidate_query, library::Library, object::tag::TagCreateArgs}; use sd_prisma::{ - prisma::{file_path, object, tag, tag_on_object}, + prisma::{device, file_path, object, tag, tag_on_object}, prisma_sync, }; -use sd_sync::{option_sync_db_entry, OperationFactory}; -use sd_utils::{msgpack, uuid_to_bytes}; +use sd_sync::{option_sync_db_entry, sync_db_entry, sync_entry, OperationFactory}; use std::collections::BTreeMap; @@ -14,7 +13,6 @@ use itertools::{Either, Itertools}; use rspc::{alpha::AlphaRouter, ErrorCode}; use serde::{Deserialize, Serialize}; use specta::Type; -use uuid::Uuid; use super::{utils::library, Ctx, R}; @@ -131,6 +129,21 @@ pub(crate) fn mount() -> AlphaRouter { .mutation(|(_, library), args: TagAssignArgs| async move { let Library { db, sync, .. } = library.as_ref(); + let device_id = library + .db + .device() + .find_unique(device::pub_id::equals(sync.device_pub_id.to_db())) + .select(device::select!({ id })) + .exec() + .await? + .ok_or_else(|| { + rspc::Error::new( + ErrorCode::NotFound, + "Local device not found".to_string(), + ) + })? + .id; + let tag = db .tag() .find_unique(tag::id::equals(args.tag_id)) @@ -170,17 +183,6 @@ pub(crate) fn mount() -> AlphaRouter { }) .await?; - macro_rules! 
sync_id { - ($pub_id:expr) => { - prisma_sync::tag_on_object::SyncId { - tag: prisma_sync::tag::SyncId { - pub_id: tag.pub_id.clone(), - }, - object: prisma_sync::object::SyncId { pub_id: $pub_id }, - } - }; - } - if args.unassign { let query = db.tag_on_object().delete_many(vec![ tag_on_object::tag_id::equals(args.tag_id), @@ -197,59 +199,28 @@ pub(crate) fn mount() -> AlphaRouter { ), ]); - sync.write_ops( - db, - ( - objects + let ops = objects + .into_iter() + .map(|o| o.pub_id) + .chain( + file_paths .into_iter() - .map(|o| o.pub_id) - .chain( - file_paths - .into_iter() - .filter_map(|fp| fp.object.map(|o| o.pub_id)), - ) - .map(|pub_id| sync.relation_delete(sync_id!(pub_id))) - .collect(), - query, - ), - ) - .await?; - } else { - let mut sync_params = vec![]; - - let db_params: (Vec<_>, Vec<_>) = file_paths - .iter() - .filter(|fp| fp.is_dir.unwrap_or_default() && fp.object.is_none()) - .map(|fp| { - let id = uuid_to_bytes(&Uuid::new_v4()); - - sync_params.extend(sync.shared_create( - prisma_sync::object::SyncId { pub_id: id.clone() }, - [], - )); - - sync_params.push(sync.shared_update( - prisma_sync::file_path::SyncId { - pub_id: fp.pub_id.clone(), + .filter_map(|fp| fp.object.map(|o| o.pub_id)), + ) + .map(|pub_id| { + sync.relation_delete(prisma_sync::tag_on_object::SyncId { + tag: prisma_sync::tag::SyncId { + pub_id: tag.pub_id.clone(), }, - file_path::object::NAME, - msgpack!(id), - )); - - ( - db.object().create(id.clone(), vec![]), - db.file_path().update( - file_path::id::equals(fp.id), - vec![file_path::object::connect(object::pub_id::equals( - id, - ))], - ), - ) + object: prisma_sync::object::SyncId { pub_id }, + }) }) - .unzip(); - - let (new_objects, _) = sync.write_ops(db, (sync_params, db_params)).await?; + .collect::>(); + if !ops.is_empty() { + sync.write_ops(db, (ops, query)).await?; + } + } else { let (sync_ops, db_creates) = objects .into_iter() .map(|o| (o.id, o.pub_id)) @@ -258,32 +229,46 @@ pub(crate) fn mount() -> AlphaRouter { .into_iter() .filter_map(|fp| fp.object.map(|o| (o.id, o.pub_id))), ) - .chain(new_objects.into_iter().map(|o| (o.id, o.pub_id))) - .fold( - (vec![], vec![]), - |(mut sync_ops, mut db_creates), (id, pub_id)| { - db_creates.push(tag_on_object::CreateUnchecked { + .map(|(id, pub_id)| { + ( + sync.relation_create( + prisma_sync::tag_on_object::SyncId { + tag: prisma_sync::tag::SyncId { + pub_id: tag.pub_id.clone(), + }, + object: prisma_sync::object::SyncId { pub_id }, + }, + [sync_entry!( + prisma_sync::device::SyncId { + pub_id: sync.device_pub_id.to_db(), + }, + tag_on_object::device + )], + ), + tag_on_object::CreateUnchecked { tag_id: args.tag_id, object_id: id, - _params: vec![tag_on_object::date_created::set(Some( - Utc::now().into(), - ))], - }); + _params: vec![ + tag_on_object::date_created::set(Some( + Utc::now().into(), + )), + tag_on_object::device_id::set(Some(device_id)), + ], + }, + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); - sync_ops.extend(sync.relation_create(sync_id!(pub_id), [])); - - (sync_ops, db_creates) - }, - ); - - sync.write_ops( - db, - ( - sync_ops, - db.tag_on_object().create_many(db_creates).skip_duplicates(), - ), - ) - .await?; + if !sync_ops.is_empty() && !db_creates.is_empty() { + sync.write_ops( + db, + ( + sync_ops, + db.tag_on_object().create_many(db_creates).skip_duplicates(), + ), + ) + .await?; + } } invalidate_query!(library, "tags.getForObject"); @@ -301,13 +286,17 @@ pub(crate) fn mount() -> AlphaRouter { pub color: Option, } - R.with2(library()) - .mutation(|(_, library), args: 
TagUpdateArgs| async move { + R.with2(library()).mutation( + |(_, library), TagUpdateArgs { id, name, color }: TagUpdateArgs| async move { + if name.is_none() && color.is_none() { + return Ok(()); + } + let Library { sync, db, .. } = library.as_ref(); let tag = db .tag() - .find_unique(tag::id::equals(args.id)) + .find_unique(tag::id::equals(id)) .select(tag::select!({ pub_id })) .exec() .await? @@ -316,64 +305,88 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> { "Error finding tag in db".into(), ))?; - db.tag() - .update( - tag::id::equals(args.id), - vec![tag::date_modified::set(Some(Utc::now().into()))], - ) - .exec() - .await?; - - let (sync_params, db_params): (Vec<_>, Vec<_>) = [ - option_sync_db_entry!(args.name, tag::name), - option_sync_db_entry!(args.color, tag::color), + let (sync_params, db_params) = [ + option_sync_db_entry!(name, tag::name), + option_sync_db_entry!(color, tag::color), + Some(sync_db_entry!(Utc::now(), tag::date_modified)), ] .into_iter() .flatten() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); - sync.write_ops( + sync.write_op( db, - ( - sync_params - .into_iter() - .map(|(k, v)| { - sync.shared_update( - prisma_sync::tag::SyncId { - pub_id: tag.pub_id.clone(), - }, - k, - v, - ) - }) - .collect(), - db.tag().update(tag::id::equals(args.id), db_params), + sync.shared_update( + prisma_sync::tag::SyncId { + pub_id: tag.pub_id.clone(), + }, + sync_params, ), + db.tag() + .update(tag::id::equals(id), db_params) + .select(tag::select!({ id })), ) .await?; invalidate_query!(library, "tags.list"); Ok(()) - }) + }, + ) }) .procedure( "delete", R.with2(library()) - .mutation(|(_, library), tag_id: i32| async move { - library - .db - .tag_on_object() - .delete_many(vec![tag_on_object::tag_id::equals(tag_id)]) - .exec() - .await?; + .mutation(|(_, library), tag_id: tag::id::Type| async move { + let Library { sync, db, .. } = &*library; - library - .db + let tag_pub_id = db .tag() - .delete(tag::id::equals(tag_id)) + .find_unique(tag::id::equals(tag_id)) + .select(tag::select!({ pub_id })) .exec() - .await?; + .await? + .ok_or(rspc::Error::new( + rspc::ErrorCode::NotFound, + "Tag not found".to_string(), + ))? + .pub_id; + + let delete_ops = db + .tag_on_object() + .find_many(vec![tag_on_object::tag_id::equals(tag_id)]) + .select(tag_on_object::select!({ object: select { pub_id } })) + .exec() + .await? + .into_iter() + .map(|tag_on_object| { + sync.relation_delete(prisma_sync::tag_on_object::SyncId { + tag: prisma_sync::tag::SyncId { + pub_id: tag_pub_id.clone(), + }, + object: prisma_sync::object::SyncId { + pub_id: tag_on_object.object.pub_id, + }, + }) + }) + .collect::<Vec<_>>(); + + sync.write_ops( + db, + ( + delete_ops, + db.tag_on_object() + .delete_many(vec![tag_on_object::tag_id::equals(tag_id)]), + ), + ) + .await?; + + sync.write_op( + db, + sync.shared_delete(prisma_sync::tag::SyncId { pub_id: tag_pub_id }), + db.tag().delete(tag::id::equals(tag_id)), + ) + .await?; invalidate_query!(library, "tags.list"); diff --git a/core/src/api/utils/invalidate.rs b/core/src/api/utils/invalidate.rs index 8df2eea6d..e888b08a2 100644 --- a/core/src/api/utils/invalidate.rs +++ b/core/src/api/utils/invalidate.rs @@ -121,6 +121,7 @@ impl InvalidRequests { } /// `invalidate_query` is a macro which stores a list of all of it's invocations so it can ensure all of the queries match the queries attached to the router. +/// /// This allows invalidate to be type-safe even when the router keys are stringly typed.
/// ```ignore /// invalidate_query!( diff --git a/core/src/api/utils/library.rs b/core/src/api/utils/library.rs index effdb89ba..001943f55 100644 --- a/core/src/api/utils/library.rs +++ b/core/src/api/utils/library.rs @@ -22,7 +22,10 @@ pub(crate) struct LibraryArgs<T> { pub(crate) struct LibraryArgsLike; impl MwArgMapper for LibraryArgsLike { - type Input<T> = LibraryArgs<T> where T: Type + DeserializeOwned + 'static; + type Input<T> + = LibraryArgs<T> + where + T: Type + DeserializeOwned + 'static; type State = Uuid; fn map( diff --git a/core/src/api/utils/mod.rs b/core/src/api/utils/mod.rs index 93b00c104..b12c1ec7e 100644 --- a/core/src/api/utils/mod.rs +++ b/core/src/api/utils/mod.rs @@ -1,5 +1,8 @@ use std::path::Path; + +// #[cfg(not(any(target_os = "ios", target_os = "android")))] +// use keyring::Entry; + use tokio::{fs, io}; mod invalidate; @@ -35,3 +38,72 @@ pub async fn get_size(path: impl AsRef<Path>) -> Result<u64, io::Error> { Ok(metadata.len()) } } + +// pub fn get_access_token() -> Result<String, rspc::Error> { +// // If target is ios or android, return an error as this function is not supported on those platforms +// if cfg!(any(target_os = "ios", target_os = "android")) { +// return Err(rspc::Error::new( +// rspc::ErrorCode::InternalServerError, +// "Function not supported on this platform".to_string(), +// )); +// } else { +// let username = whoami::username(); +// let entry = match Entry::new("spacedrive-auth-service", username.as_str()) { +// Ok(entry) => entry, +// Err(e) => { +// error!("Error creating entry: {}", e); +// return Err(rspc::Error::new( +// rspc::ErrorCode::InternalServerError, +// "Error creating entry".to_string(), +// )); +// } +// }; + +// let data = match entry.get_password() { +// Ok(key) => key, +// Err(e) => { +// error!("Error retrieving key: {}. Does the key exist yet?", e); +// return Err(rspc::Error::new( +// rspc::ErrorCode::InternalServerError, +// "Error retrieving key".to_string(), +// )); +// } +// }; + +// let re = match Regex::new(r#"st-access-token=([^;]+)"#) { +// Ok(re) => re, +// Err(e) => { +// error!("Error creating regex: {}", e); +// return Err(rspc::Error::new( +// rspc::ErrorCode::InternalServerError, +// "Error creating regex".to_string(), +// )); +// } +// }; + +// let token = match re.captures(&data) { +// Some(captures) => match captures.get(1) { +// Some(token) => token.as_str(), +// None => { +// error!("Error parsing Cookie String value: {}", "No token found"); +// return Err(rspc::Error::new( +// rspc::ErrorCode::InternalServerError, +// "Error parsing Cookie String value".to_string(), +// )); +// } +// }, +// None => { +// error!( +// "Error parsing Cookie String value: {}", +// "No token cookie string found" +// ); +// return Err(rspc::Error::new( +// rspc::ErrorCode::InternalServerError, +// "Error parsing Cookie String value".to_string(), +// )); +// } +// }; + +// Ok(token.to_string()) +// } +// } diff --git a/core/src/api/web_api.rs b/core/src/api/web_api.rs index 49802bb5b..81cb410d3 100644 --- a/core/src/api/web_api.rs +++ b/core/src/api/web_api.rs @@ -14,13 +14,13 @@ emoji: u8, } - |node, args: Feedback| async move { - sd_cloud_api::feedback::send( - node.cloud_api_config().await, - args.message, - args.emoji, - ) - .await?; + |_node, _args: Feedback| async move { + // sd_cloud_api::feedback::send( + // node.cloud_api_config().await, + // args.message, + // args.emoji, + // ) + // .await?; Ok(()) } diff --git a/core/src/cloud/mod.rs b/core/src/cloud/mod.rs deleted file mode 100644 index 529c18cd1..000000000 ---
a/core/src/cloud/mod.rs +++ /dev/null @@ -1,33 +0,0 @@ -use std::sync::Arc; - -use uuid::Uuid; - -use crate::Node; - -pub mod sync; - -#[derive(Default)] -pub struct State { - pub sync: sync::State, -} - -pub async fn start( - node: &Arc, - actors: &Arc, - library_id: Uuid, - instance_uuid: Uuid, - sync: &Arc, - db: &Arc, -) -> State { - let sync = sync::declare_actors( - node, - actors, - library_id, - instance_uuid, - sync.clone(), - db.clone(), - ) - .await; - - State { sync } -} diff --git a/core/src/cloud/sync/ingest.rs b/core/src/cloud/sync/ingest.rs deleted file mode 100644 index d41331dff..000000000 --- a/core/src/cloud/sync/ingest.rs +++ /dev/null @@ -1,127 +0,0 @@ -use crate::cloud::sync::err_break; - -use sd_actors::Stopper; -use sd_prisma::prisma::cloud_crdt_operation; -use sd_sync::CompressedCRDTOperations; - -use std::{ - future::IntoFuture, - pin::pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; - -use futures::{FutureExt, StreamExt}; -use futures_concurrency::future::Race; -use tokio::sync::Notify; -use tracing::debug; - -// Responsible for taking sync operations received from the cloud, -// and applying them to the local database via the sync system's ingest actor. - -pub async fn run_actor( - sync: Arc, - notify: Arc, - state: Arc, - state_notify: Arc, - stop: Stopper, -) { - enum Race { - Notified, - Stopped, - } - - loop { - state.store(true, Ordering::Relaxed); - state_notify.notify_waiters(); - - { - let mut rx = pin!(sync.ingest.req_rx.clone()); - - if sync - .ingest - .event_tx - .send(sd_core_sync::Event::Notification) - .await - .is_ok() - { - while let Some(req) = rx.next().await { - const OPS_PER_REQUEST: u32 = 1000; - - // FIXME: If there are exactly a multiple of OPS_PER_REQUEST operations, - // then this will bug, as we sent `has_more` as true, but we don't have - // more operations to send. - - use sd_core_sync::*; - - let timestamps = match req { - Request::FinishedIngesting => { - break; - } - Request::Messages { timestamps, .. 
} => timestamps, - }; - - let (ops_ids, ops): (Vec<_>, Vec<_>) = err_break!( - sync.get_cloud_ops(GetOpsArgs { - clocks: timestamps, - count: OPS_PER_REQUEST, - }) - .await - ) - .into_iter() - .unzip(); - - if ops.is_empty() { - break; - } - - debug!( - messages_count = ops.len(), - first_message = ?ops.first().map(|operation| operation.timestamp.as_u64()), - last_message = ?ops.last().map(|operation| operation.timestamp.as_u64()), - "Sending messages to ingester", - ); - - let (wait_tx, wait_rx) = tokio::sync::oneshot::channel::<()>(); - - err_break!( - sync.ingest - .event_tx - .send(sd_core_sync::Event::Messages(MessagesEvent { - instance_id: sync.instance, - has_more: ops.len() == OPS_PER_REQUEST as usize, - messages: CompressedCRDTOperations::new(ops), - wait_tx: Some(wait_tx) - })) - .await - ); - - err_break!(wait_rx.await); - - err_break!( - sync.db - .cloud_crdt_operation() - .delete_many(vec![cloud_crdt_operation::id::in_vec(ops_ids)]) - .exec() - .await - ); - } - } - } - - state.store(false, Ordering::Relaxed); - state_notify.notify_waiters(); - - if let Race::Stopped = ( - notify.notified().map(|()| Race::Notified), - stop.into_future().map(|()| Race::Stopped), - ) - .race() - .await - { - break; - } - } -} diff --git a/core/src/cloud/sync/mod.rs b/core/src/cloud/sync/mod.rs deleted file mode 100644 index 8a52025bb..000000000 --- a/core/src/cloud/sync/mod.rs +++ /dev/null @@ -1,109 +0,0 @@ -use sd_sync::*; -use std::sync::{ - atomic::{self, AtomicBool}, - Arc, -}; -use tokio::sync::Notify; -use uuid::Uuid; - -use crate::Node; - -pub mod ingest; -pub mod receive; -pub mod send; - -#[derive(Default)] -pub struct State { - pub send_active: Arc, - pub receive_active: Arc, - pub ingest_active: Arc, - pub notifier: Arc, -} - -pub async fn declare_actors( - node: &Arc, - actors: &Arc, - library_id: Uuid, - instance_uuid: Uuid, - sync: Arc, - db: Arc, -) -> State { - let ingest_notify = Arc::new(Notify::new()); - let state = State::default(); - - let autorun = node.cloud_sync_flag.load(atomic::Ordering::Relaxed); - - actors - .declare( - "Cloud Sync Sender", - { - let sync = sync.clone(); - let node = node.clone(); - let active = state.send_active.clone(); - let active_notifier = state.notifier.clone(); - - move |stop| send::run_actor(library_id, sync, node, active, active_notifier, stop) - }, - autorun, - ) - .await; - - actors - .declare( - "Cloud Sync Receiver", - { - let sync = sync.clone(); - let node = node.clone(); - let ingest_notify = ingest_notify.clone(); - let active_notifier = state.notifier.clone(); - let active = state.receive_active.clone(); - - move |stop| { - receive::run_actor( - node.libraries.clone(), - db.clone(), - library_id, - instance_uuid, - sync, - ingest_notify, - node, - active, - active_notifier, - stop, - ) - } - }, - autorun, - ) - .await; - - actors - .declare( - "Cloud Sync Ingest", - { - let active = state.ingest_active.clone(); - let active_notifier = state.notifier.clone(); - - move |stop| { - ingest::run_actor(sync.clone(), ingest_notify, active, active_notifier, stop) - } - }, - autorun, - ) - .await; - - state -} - -macro_rules! 
err_break { - ($e:expr) => { - match $e { - Ok(d) => d, - Err(e) => { - tracing::error!(?e); - break; - } - } - }; -} -pub(crate) use err_break; diff --git a/core/src/cloud/sync/receive.rs b/core/src/cloud/sync/receive.rs deleted file mode 100644 index a0ec93abf..000000000 --- a/core/src/cloud/sync/receive.rs +++ /dev/null @@ -1,304 +0,0 @@ -use crate::{library::Libraries, Node}; - -use futures::FutureExt; -use futures_concurrency::future::Race; -use sd_actors::Stopper; -use sd_cloud_api::{library::message_collections::get::InstanceTimestamp, RequestConfigProvider}; -use sd_p2p::RemoteIdentity; -use sd_prisma::prisma::{cloud_crdt_operation, instance, PrismaClient, SortOrder}; -use sd_sync::CRDTOperation; -use sd_utils::uuid_to_bytes; - -use std::{ - collections::{hash_map::Entry, HashMap}, - future::IntoFuture, - str::FromStr, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - time::Duration, -}; - -use base64::prelude::*; -use chrono::Utc; -use serde_json::to_vec; -use tokio::{sync::Notify, time::sleep}; -use tracing::{debug, info}; -use uuid::Uuid; - -use super::{err_break, CompressedCRDTOperations}; - -// Responsible for downloading sync operations from the cloud to be processed by the ingester - -#[allow(clippy::too_many_arguments)] -pub async fn run_actor( - libraries: Arc, - db: Arc, - library_id: Uuid, - instance_uuid: Uuid, - sync: Arc, - ingest_notify: Arc, - node: Arc, - active: Arc, - active_notify: Arc, - stop: Stopper, -) { - enum Race { - Continue, - Stop, - } - - loop { - active.store(true, Ordering::Relaxed); - active_notify.notify_waiters(); - - loop { - // We need to know the latest operations we should be retrieving - let mut cloud_timestamps = { - let timestamps = sync.timestamps.read().await; - - // looks up the most recent operation we've received (not ingested!) 
for each instance - let db_timestamps = err_break!( - db._batch( - timestamps - .keys() - .map(|id| { - db.cloud_crdt_operation() - .find_first(vec![cloud_crdt_operation::instance::is(vec![ - instance::pub_id::equals(uuid_to_bytes(id)), - ])]) - .order_by(cloud_crdt_operation::timestamp::order( - SortOrder::Desc, - )) - }) - .collect::>() - ) - .await - ); - - // compares the latest ingested timestamp with the latest received timestamp - // and picks the highest one for each instance - let mut cloud_timestamps = db_timestamps - .into_iter() - .zip(timestamps.iter()) - .map(|(d, (id, sync_timestamp))| { - let cloud_timestamp = d.map(|d| d.timestamp).unwrap_or_default() as u64; - - debug!( - instance_id = %id, - sync_timestamp = sync_timestamp.as_u64(), - %cloud_timestamp, - "Comparing sync timestamps", - ); - - let max_timestamp = Ord::max(cloud_timestamp, sync_timestamp.as_u64()); - - (*id, max_timestamp) - }) - .collect::>(); - - cloud_timestamps.remove(&instance_uuid); - - cloud_timestamps - }; - - let instance_timestamps: Vec = sync - .timestamps - .read() - .await - .keys() - .map( - |uuid| sd_cloud_api::library::message_collections::get::InstanceTimestamp { - instance_uuid: *uuid, - from_time: cloud_timestamps - .get(uuid) - .copied() - .unwrap_or_default() - .to_string(), - }, - ) - .collect(); - - let collections = err_break!( - sd_cloud_api::library::message_collections::get( - node.get_request_config().await, - library_id, - instance_uuid, - instance_timestamps, - ) - .await - ); - - info!( - collections_count = collections.len(), - "Received collections;", - ); - - if collections.is_empty() { - break; - } - - let mut cloud_library_data: Option> = None; - - for collection in collections { - if let Entry::Vacant(e) = cloud_timestamps.entry(collection.instance_uuid) { - let fetched_library = match &cloud_library_data { - None => { - let Some(fetched_library) = err_break!( - sd_cloud_api::library::get( - node.get_request_config().await, - library_id - ) - .await - ) else { - break; - }; - - cloud_library_data - .insert(Some(fetched_library)) - .as_ref() - .expect("error inserting fetched library") - } - Some(None) => { - break; - } - Some(Some(fetched_library)) => fetched_library, - }; - - let Some(instance) = fetched_library - .instances - .iter() - .find(|i| i.uuid == collection.instance_uuid) - else { - break; - }; - - err_break!( - upsert_instance( - library_id, - &db, - &sync, - &libraries, - &collection.instance_uuid, - instance.identity, - &instance.node_id, - RemoteIdentity::from_str(&instance.node_remote_identity) - .expect("malformed remote identity in the DB"), - node.p2p.peer_metadata(), - ) - .await - ); - - e.insert(0); - } - - let compressed_operations: CompressedCRDTOperations = err_break!( - rmp_serde::from_slice(err_break!(&BASE64_STANDARD.decode(collection.contents))) - ); - - let operations = compressed_operations.into_ops(); - - debug!( - instance_id = %collection.instance_uuid, - start = ?operations.first().map(|operation| operation.timestamp.as_u64()), - end = ?operations.last().map(|operation| operation.timestamp.as_u64()), - "Processing collection", - ); - - err_break!(write_cloud_ops_to_db(operations, &db).await); - - let collection_timestamp: u64 = - collection.end_time.parse().expect("unable to parse time"); - - let timestamp = cloud_timestamps - .entry(collection.instance_uuid) - .or_insert(collection_timestamp); - - if *timestamp < collection_timestamp { - *timestamp = collection_timestamp; - } - } - - ingest_notify.notify_waiters(); - } - - 
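The receiver being deleted here reconciled, per instance, the newest ingested timestamp with the newest received-but-not-yet-ingested one, and only asked the cloud for operations past that watermark. A compact sketch of that reconciliation step (a hypothetical helper, not the removed function itself):

```rust
use std::collections::HashMap;
use uuid::Uuid;

/// For each instance, keep the highest of (ingested clock, received-but-pending clock),
/// so the next fetch from the cloud starts strictly after what we already hold.
fn reconcile_watermarks(
    ingested: &HashMap<Uuid, u64>,
    received: &HashMap<Uuid, u64>,
) -> HashMap<Uuid, u64> {
    ingested
        .iter()
        .map(|(id, &sync_ts)| {
            let cloud_ts = received.get(id).copied().unwrap_or_default();
            (*id, Ord::max(cloud_ts, sync_ts))
        })
        .collect()
}
```

The new cloud-services stack replaces this polling loop, but the watermark idea is the same: never re-download operations older than the newest one already stored locally.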
active.store(false, Ordering::Relaxed); - active_notify.notify_waiters(); - - if let Race::Stop = ( - sleep(Duration::from_secs(60)).map(|()| Race::Continue), - stop.into_future().map(|()| Race::Stop), - ) - .race() - .await - { - break; - } - } -} - -async fn write_cloud_ops_to_db( - ops: Vec, - db: &PrismaClient, -) -> Result<(), prisma_client_rust::QueryError> { - db._batch(ops.into_iter().map(|op| crdt_op_db(&op).to_query(db))) - .await?; - - Ok(()) -} - -fn crdt_op_db(op: &CRDTOperation) -> cloud_crdt_operation::Create { - cloud_crdt_operation::Create { - timestamp: op.timestamp.0 as i64, - instance: instance::pub_id::equals(op.instance.as_bytes().to_vec()), - kind: op.data.as_kind().to_string(), - data: to_vec(&op.data).expect("unable to serialize data"), - model: op.model as i32, - record_id: rmp_serde::to_vec(&op.record_id).expect("unable to serialize record id"), - _params: vec![], - } -} - -#[allow(clippy::too_many_arguments)] -pub async fn upsert_instance( - library_id: Uuid, - db: &PrismaClient, - sync: &sd_core_sync::Manager, - libraries: &Libraries, - uuid: &Uuid, - identity: RemoteIdentity, - node_id: &Uuid, - node_remote_identity: RemoteIdentity, - metadata: HashMap, -) -> prisma_client_rust::Result<()> { - db.instance() - .upsert( - instance::pub_id::equals(uuid_to_bytes(uuid)), - instance::create( - uuid_to_bytes(uuid), - identity.get_bytes().to_vec(), - node_id.as_bytes().to_vec(), - Utc::now().into(), - Utc::now().into(), - vec![ - instance::node_remote_identity::set(Some( - node_remote_identity.get_bytes().to_vec(), - )), - instance::metadata::set(Some( - serde_json::to_vec(&metadata).expect("unable to serialize metadata"), - )), - ], - ), - vec![], - ) - .exec() - .await?; - - sync.timestamps.write().await.entry(*uuid).or_default(); - - // Called again so the new instances are picked up - libraries.update_instances_by_id(library_id).await; - - Ok(()) -} diff --git a/core/src/cloud/sync/send.rs b/core/src/cloud/sync/send.rs deleted file mode 100644 index 82a049f49..000000000 --- a/core/src/cloud/sync/send.rs +++ /dev/null @@ -1,157 +0,0 @@ -use futures::FutureExt; -use futures_concurrency::future::Race; -use sd_core_sync::{SyncMessage, NTP64}; - -use sd_actors::Stopper; -use sd_cloud_api::RequestConfigProvider; - -use std::{ - future::IntoFuture, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - time::Duration, -}; - -use tokio::{ - sync::{broadcast, Notify}, - time::sleep, -}; -use tracing::debug; -use uuid::Uuid; - -use super::{err_break, CompressedCRDTOperations}; - -enum RaceNotifiedOrStopped { - Notified, - Stopped, -} - -pub async fn run_actor( - library_id: Uuid, - sync: Arc, - cloud_api_config_provider: Arc, - state: Arc, - state_notify: Arc, - stop: Stopper, -) { - loop { - state.store(true, Ordering::Relaxed); - state_notify.notify_waiters(); - - loop { - // all available instances will have a default timestamp from create_instance - let instances = sync - .timestamps - .read() - .await - .keys() - .cloned() - .collect::>(); - - // obtains a lock on the timestamp collections for the instances we have - let req_adds = err_break!( - sd_cloud_api::library::message_collections::request_add( - cloud_api_config_provider.get_request_config().await, - library_id, - instances, - ) - .await - ); - - let mut instances = vec![]; - - use sd_cloud_api::library::message_collections::do_add; - - debug!( - total_operations = req_adds.len(), - "Preparing to send instance's operations to cloud;" - ); - - // gets new operations for each instance to send to cloud - for 
req_add in req_adds { - let ops = err_break!( - sync.get_instance_ops( - 1000, - req_add.instance_uuid, - NTP64( - req_add - .from_time - .unwrap_or_else(|| "0".to_string()) - .parse() - .expect("couldn't parse ntp64 value"), - ) - ) - .await - ); - - if ops.is_empty() { - continue; - } - - let start_time = ops[0].timestamp.0.to_string(); - let end_time = ops[ops.len() - 1].timestamp.0.to_string(); - - let ops_len = ops.len(); - - use base64::prelude::*; - - debug!(instance_id = %req_add.instance_uuid, %start_time, %end_time); - - instances.push(do_add::Input { - uuid: req_add.instance_uuid, - key: req_add.key, - start_time, - end_time, - contents: BASE64_STANDARD.encode( - rmp_serde::to_vec_named(&CompressedCRDTOperations::new(ops)) - .expect("CompressedCRDTOperation should serialize!"), - ), - ops_count: ops_len, - }) - } - - if instances.is_empty() { - break; - } - - // uses lock we acquired earlier to send the operations to the cloud - err_break!( - do_add( - cloud_api_config_provider.get_request_config().await, - library_id, - instances, - ) - .await - ); - } - - state.store(false, Ordering::Relaxed); - state_notify.notify_waiters(); - - if let RaceNotifiedOrStopped::Stopped = ( - // recreate subscription each time so that existing messages are dropped - wait_notification(sync.subscribe()), - stop.into_future().map(|()| RaceNotifiedOrStopped::Stopped), - ) - .race() - .await - { - break; - } - - sleep(Duration::from_millis(1000)).await; - } -} - -async fn wait_notification(mut rx: broadcast::Receiver) -> RaceNotifiedOrStopped { - // wait until Created message comes in - loop { - if let Ok(SyncMessage::Created) = rx.recv().await { - break; - }; - } - - RaceNotifiedOrStopped::Notified -} diff --git a/core/src/context.rs b/core/src/context.rs index 217acd54b..519034d30 100644 --- a/core/src/context.rs +++ b/core/src/context.rs @@ -4,6 +4,7 @@ use sd_core_heavy_lifting::{ job_system::report::{Report, Status}, OuterContext, ProgressUpdate, UpdateEvent, }; +use sd_core_sync::SyncManager; use std::{ ops::{Deref, DerefMut}, @@ -49,7 +50,7 @@ impl OuterContext for NodeContext { &self.library.db } - fn sync(&self) -> &Arc { + fn sync(&self) -> &SyncManager { &self.library.sync } @@ -96,7 +97,7 @@ impl OuterContext for JobContext &Arc { + fn sync(&self) -> &SyncManager { self.outer_ctx.sync() } @@ -191,7 +192,7 @@ impl sd_core_heavy_lifting::JobContext< spawn({ let db = Arc::clone(&library.db); - let mut report = report.clone(); + let report = report.clone(); async move { if let Err(e) = report.update(&db).await { error!( diff --git a/core/src/env.rs b/core/src/env.rs deleted file mode 100644 index 426c8fca9..000000000 --- a/core/src/env.rs +++ /dev/null @@ -1,15 +0,0 @@ -use tokio::sync::Mutex; - -pub struct Env { - pub api_url: Mutex, - pub client_id: String, -} - -impl Env { - pub fn new(client_id: &str) -> Self { - Self { - api_url: Mutex::new("https://app.spacedrive.com".to_string()), - client_id: client_id.to_string(), - } - } -} diff --git a/core/src/lib.rs b/core/src/lib.rs index 5be82ede2..ee49f8bac 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -6,27 +6,27 @@ use crate::{ location::LocationManagerError, }; +use sd_core_cloud_services::CloudServices; use sd_core_heavy_lifting::{media_processor::ThumbnailKind, JobSystem}; use sd_core_prisma_helpers::CasId; -#[cfg(feature = "ai")] -use sd_ai::old_image_labeler::{DownloadModelError, OldImageLabeler, YoloV8}; - +use sd_crypto::CryptoRng; use sd_task_system::TaskSystem; use sd_utils::error::FileIOError; -use 
volume::save_storage_statistics; use std::{ fmt, path::{Path, PathBuf}, - sync::{atomic::AtomicBool, Arc}, + sync::Arc, }; use chrono::{DateTime, Utc}; use futures_concurrency::future::Join; -use reqwest::{RequestBuilder, Response}; use thiserror::Error; -use tokio::{fs, io, sync::broadcast}; +use tokio::{ + fs, io, + sync::{broadcast, Mutex}, +}; use tracing::{error, info, warn}; use tracing_appender::{ non_blocking::{NonBlocking, WorkerGuard}, @@ -37,10 +37,8 @@ use tracing_subscriber::{ }; pub mod api; -mod cloud; mod context; pub mod custom_uri; -mod env; pub mod library; pub(crate) mod location; pub(crate) mod node; @@ -53,14 +51,12 @@ pub(crate) mod preferences; pub mod util; pub(crate) mod volume; -pub use env::Env; - use api::notifications::{Notification, NotificationData, NotificationId}; use context::{JobContext, NodeContext}; use node::config; use notifications::Notifications; - -pub(crate) use sd_core_sync as sync; +use sd_core_cloud_services::AUTH_SERVER_URL; +use volume::save_storage_statistics; /// Represents a single running instance of the Spacedrive core. /// Holds references to all the services that make up the Spacedrive core. @@ -73,13 +69,12 @@ pub struct Node { pub p2p: Arc, pub event_bus: (broadcast::Sender, broadcast::Receiver), pub notifications: Notifications, - pub cloud_sync_flag: Arc, - pub env: Arc, - pub http: reqwest::Client, pub task_system: TaskSystem, pub job_system: JobSystem>, - #[cfg(feature = "ai")] - pub old_image_labeller: Option, + pub cloud_services: Arc, + /// This should only be used to generate the seed of local instances of [`CryptoRng`]. + /// Don't use this as a common RNG; contending on this Mutex would hurt Core's performance. + pub master_rng: Arc>, } impl fmt::Debug for Node { @@ -91,16 +86,11 @@ impl Node { } impl Node { - pub async fn new( - data_dir: impl AsRef, - env: env::Env, - ) -> Result<(Arc, Arc), NodeError> { + pub async fn new(data_dir: impl AsRef) -> Result<(Arc, Arc), NodeError> { let data_dir = data_dir.as_ref(); info!(data_directory = %data_dir.display(), "Starting core;"); - let env = Arc::new(env); - #[cfg(debug_assertions)] let init_data = util::debug_initializer::InitConfig::load(data_dir).await?; @@ -112,20 +102,52 @@ impl Node { .await .map_err(NodeError::FailedToInitializeConfig)?; - if let Some(url) = config.get().await.sd_api_origin { - *env.api_url.lock().await = url; - } - - #[cfg(feature = "ai")] - let image_labeler_version = { - sd_ai::init()?; - config.get().await.image_labeler_version - }; - let (locations, locations_actor) = location::Locations::new(); let (old_jobs, jobs_actor) = old_job::OldJobs::new(); let libraries = library::Libraries::new(data_dir.join("libraries")).await?; + let ( + get_cloud_api_address, + cloud_p2p_relay_url, + cloud_p2p_dns_origin_name, + cloud_p2p_dns_pkarr_url, + cloud_services_domain_name, + ) = { + #[cfg(debug_assertions)] + { + ( + std::env::var("SD_CLOUD_API_ADDRESS_URL").unwrap_or_else(|_| { + format!("{AUTH_SERVER_URL}/cloud-api-address") + }), + std::env::var("SD_CLOUD_P2P_RELAY_URL") + // .unwrap_or_else(|_| "https://use1-1.relay.iroh.network/".to_string()), + // .unwrap_or_else(|_| "http://localhost:8081/".to_string()), + .unwrap_or_else(|_| "https://relay.spacedrive.com:4433/".to_string()), + std::env::var("SD_CLOUD_P2P_DNS_ORIGIN_NAME") + // .unwrap_or_else(|_| "dns.iroh.link/".to_string()), + // .unwrap_or_else(|_| "irohdns.localhost".to_string()), + .unwrap_or_else(|_| "irohdns.spacedrive.com".to_string()),
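+ // Pkarr endpoint for publishing/resolving P2P node discovery records; debug builds may override it via SD_CLOUD_P2P_DNS_PKARR_URL below.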
std::env::var("SD_CLOUD_P2P_DNS_PKARR_URL") + // .unwrap_or_else(|_| "https://dns.iroh.link/pkarr".to_string()), + // .unwrap_or_else(|_| "http://localhost:8080/pkarr".to_string()), + .unwrap_or_else(|_| "https://irohdns.spacedrive.com/pkarr".to_string()), + std::env::var("SD_CLOUD_API_DOMAIN_NAME") + // .unwrap_or_else(|_| "localhost".to_string()), + .unwrap_or_else(|_| "cloud.spacedrive.com".to_string()), + ) + } + #[cfg(not(debug_assertions))] + { + ( + "https://auth.spacedrive.com/cloud-api-address".to_string(), + "https://relay.spacedrive.com/".to_string(), + "irohdns.spacedrive.com".to_string(), + "irohdns.spacedrive.com/pkarr".to_string(), + "api.spacedrive.com".to_string(), + ) + } + }; + let task_system = TaskSystem::new(); let (p2p, start_p2p) = p2p::P2PManager::new(config.clone(), libraries.clone()) @@ -142,31 +164,19 @@ impl Node { config, event_bus, libraries, - cloud_sync_flag: Arc::new(AtomicBool::new( - cfg!(target_os = "ios") || cfg!(target_os = "android"), - )), - http: reqwest::Client::new(), - env, - #[cfg(feature = "ai")] - old_image_labeller: OldImageLabeler::new( - YoloV8::model(image_labeler_version)?, - data_dir, - ) - .await - .map_err(|e| { - error!( - ?e, - "Failed to initialize image labeller. AI features will be disabled;" - ); - }) - .ok(), + cloud_services: Arc::new( + CloudServices::new( + &get_cloud_api_address, + cloud_p2p_relay_url, + cloud_p2p_dns_pkarr_url, + cloud_p2p_dns_origin_name, + cloud_services_domain_name, + ) + .await?, + ), + master_rng: Arc::new(Mutex::new(CryptoRng::new()?)), }); - // Restore backend feature flags - for feature in node.config.get().await.features { - feature.restore(&node); - } - // Setup start actors that depend on the `Node` #[cfg(debug_assertions)] if let Some(init_data) = init_data { @@ -248,6 +258,7 @@ impl Node { "RUST_LOG", format!( "info,\ + iroh_net=info,\ sd_core={level},\ sd_p2p={level},\ sd_core_heavy_lifting={level},\ @@ -320,10 +331,6 @@ impl Node { .join() .await; - #[cfg(feature = "ai")] - if let Some(image_labeller) = &self.old_image_labeller { - image_labeller.shutdown().await; - } info!("Spacedrive Core shutdown successful!"); } @@ -368,55 +375,6 @@ impl Node { } } } - - pub async fn add_auth_header(&self, mut req: RequestBuilder) -> RequestBuilder { - if let Some(auth_token) = self.config.get().await.auth_token { - req = req.header("authorization", auth_token.to_header()); - }; - - req - } - - pub async fn authed_api_request(&self, req: RequestBuilder) -> Result { - let Some(auth_token) = self.config.get().await.auth_token else { - return Err(rspc::Error::new( - rspc::ErrorCode::Unauthorized, - "No auth token".to_string(), - )); - }; - - let req = req.header("authorization", auth_token.to_header()); - - req.send().await.map_err(|_| { - rspc::Error::new( - rspc::ErrorCode::InternalServerError, - "Request failed".to_string(), - ) - }) - } - - pub async fn api_request(&self, req: RequestBuilder) -> Result { - req.send().await.map_err(|_| { - rspc::Error::new( - rspc::ErrorCode::InternalServerError, - "Request failed".to_string(), - ) - }) - } - - pub async fn cloud_api_config(&self) -> sd_cloud_api::RequestConfig { - sd_cloud_api::RequestConfig { - client: self.http.clone(), - api_url: self.env.api_url.lock().await.clone(), - auth_token: self.config.get().await.auth_token, - } - } -} - -impl sd_cloud_api::RequestConfigProvider for Node { - async fn get_request_config(self: &Arc) -> sd_cloud_api::RequestConfig { - Node::cloud_api_config(self).await - } } /// Error type for Node related errors. 
@@ -439,11 +397,8 @@ pub enum NodeError { Logger(#[from] FromEnvError), #[error(transparent)] JobSystem(#[from] sd_core_heavy_lifting::JobSystemError), - - #[cfg(feature = "ai")] - #[error("ai error: {0}")] - AI(#[from] sd_ai::Error), - #[cfg(feature = "ai")] - #[error("Failed to download model: {0}")] - DownloadModel(#[from] DownloadModelError), + #[error(transparent)] + CloudServices(#[from] sd_core_cloud_services::Error), + #[error(transparent)] + Crypto(#[from] sd_crypto::Error), } diff --git a/core/src/library/config.rs b/core/src/library/config.rs index 863f744f9..20c245d10 100644 --- a/core/src/library/config.rs +++ b/core/src/library/config.rs @@ -4,15 +4,14 @@ use crate::{ }; use sd_p2p::{Identity, RemoteIdentity}; -use sd_prisma::prisma::{file_path, indexer_rule, instance, location, node, PrismaClient}; +use sd_prisma::prisma::{file_path, indexer_rule, instance, location, PrismaClient}; use sd_utils::{db::maybe_missing, error::FileIOError}; use std::{ - path::Path, + path::{Path, PathBuf}, sync::{atomic::AtomicBool, Arc}, }; -use chrono::Utc; use int_enum::IntEnum; use prisma_client_rust::not; use serde::{Deserialize, Serialize}; @@ -44,6 +43,9 @@ pub struct LibraryConfig { #[serde(default)] pub generate_sync_operations: Arc, version: LibraryConfigVersion, + + #[serde(skip, default)] + pub config_path: PathBuf, } #[derive( @@ -88,7 +90,6 @@ impl LibraryConfig { description: Option, instance_id: i32, path: impl AsRef, - generate_sync_operations: bool, ) -> Result { let this = Self { name, @@ -96,8 +97,8 @@ impl LibraryConfig { instance_id, version: Self::LATEST_VERSION, cloud_id: None, - // will always be `true` eventually - generate_sync_operations: Arc::new(AtomicBool::new(generate_sync_operations)), + generate_sync_operations: Arc::new(AtomicBool::new(false)), + config_path: path.as_ref().to_path_buf(), }; this.save(path).await.map(|()| this) @@ -105,12 +106,12 @@ impl LibraryConfig { pub(crate) async fn load( path: impl AsRef, - node_config: &NodeConfig, + _node_config: &NodeConfig, db: &PrismaClient, ) -> Result { let path = path.as_ref(); - VersionManager::::migrate_and_load( + let mut loaded_config = VersionManager::::migrate_and_load( path, |current, next| async move { match (current, next) { @@ -167,34 +168,8 @@ impl LibraryConfig { } (LibraryConfigVersion::V2, LibraryConfigVersion::V3) => { - // The fact I have to migrate this hurts my soul - if db.node().count(vec![]).exec().await? != 1 { - return Err(LibraryConfigError::TooManyNodes); - } - - db.node() - .update_many( - vec![], - vec![node::pub_id::set(node_config.id.as_bytes().to_vec())], - ) - .exec() - .await?; - - let mut config = serde_json::from_slice::>( - &fs::read(path).await.map_err(|e| { - VersionManagerError::FileIO(FileIOError::from((path, e))) - })?, - ) - .map_err(VersionManagerError::SerdeJson)?; - - config.insert(String::from("node_id"), json!(node_config.id.to_string())); - - fs::write( - path, - &serde_json::to_vec(&config).map_err(VersionManagerError::SerdeJson)?, - ) - .await - .map_err(|e| VersionManagerError::FileIO(FileIOError::from((path, e))))?; + // Removed, can't be automatically updated + return Err(LibraryConfigError::CriticalUpdateError); } (LibraryConfigVersion::V3, LibraryConfigVersion::V4) => { @@ -255,51 +230,8 @@ impl LibraryConfig { }, (LibraryConfigVersion::V5, LibraryConfigVersion::V6) => { - let nodes = db.node().find_many(vec![]).exec().await?; - if nodes.is_empty() { - error!("6 - No nodes found... How did you even get this far? 
but this is fine we can fix it."); - } else if nodes.len() > 1 { - error!("6 - More than one node found in the DB... This can't be automatically reconciled!"); - return Err(LibraryConfigError::TooManyNodes); - } - - let node = nodes.first(); - let now = Utc::now().fixed_offset(); - let instance_id = Uuid::new_v4(); - - instance::Create { - pub_id: instance_id.as_bytes().to_vec(), - // WARNING: At this stage in the migration this field *should* be an `Identity` not a `RemoteIdentityOrIdentity` (as that was introduced later on). - remote_identity: node - .and_then(|n| n.identity.clone()) - .unwrap_or_else(|| Identity::new().to_bytes()), - node_id: node_config.id.as_bytes().to_vec(), - last_seen: now, - date_created: node.map(|n| n.date_created).unwrap_or_else(|| now), - _params: vec![], - } - .to_query(db) - .exec() - .await?; - - let mut config = serde_json::from_slice::>( - &fs::read(path).await.map_err(|e| { - VersionManagerError::FileIO(FileIOError::from((path, e))) - })?, - ) - .map_err(VersionManagerError::SerdeJson)?; - - config.remove("node_id"); - config.remove("identity"); - - config.insert(String::from("instance_id"), json!(instance_id.to_string())); - - fs::write( - path, - &serde_json::to_vec(&config).map_err(VersionManagerError::SerdeJson)?, - ) - .await - .map_err(|e| VersionManagerError::FileIO(FileIOError::from((path, e))))?; + // Removed, can't be automatically updated + return Err(LibraryConfigError::CriticalUpdateError); } (LibraryConfigVersion::V6, LibraryConfigVersion::V7) => { @@ -344,7 +276,7 @@ impl LibraryConfig { } (LibraryConfigVersion::V7, LibraryConfigVersion::V8) => { - let instances = db.instance().find_many(vec![]).exec().await?; + let instances = db.device().find_many(vec![]).exec().await?; let Some(instance) = instances.first() else { error!("8 - No nodes found... 
How did you even get this far?!"); return Err(LibraryConfigError::MissingInstance); @@ -477,7 +409,11 @@ impl LibraryConfig { Ok(()) }, ) - .await + .await?; + + loaded_config.config_path = path.to_path_buf(); + + Ok(loaded_config) } pub(crate) async fn save(&self, path: impl AsRef) -> Result<(), LibraryConfigError> { @@ -498,6 +434,8 @@ pub enum LibraryConfigError { TooManyInstances, #[error("missing instances")] MissingInstance, + #[error("your library version can't be automatically updated, please recreate your library")] + CriticalUpdateError, #[error(transparent)] SerdeJson(#[from] serde_json::Error), diff --git a/core/src/library/library.rs b/core/src/library/library.rs index 795714ab8..f4e284f8a 100644 --- a/core/src/library/library.rs +++ b/core/src/library/library.rs @@ -1,9 +1,14 @@ -use crate::{api::CoreEvent, cloud, sync, Node}; +use crate::{api::CoreEvent, Node}; +use sd_core_cloud_services::{declare_cloud_sync, CloudSyncActors, CloudSyncActorsState}; use sd_core_file_path_helper::IsolatedFilePathData; use sd_core_heavy_lifting::media_processor::ThumbnailKind; use sd_core_prisma_helpers::{file_path_to_full_path, CasId}; +use sd_core_sync::{backfill::backfill_operations, SyncManager}; +use sd_actors::ActorsCollection; +use sd_cloud_schema::sync::groups; +use sd_crypto::{CryptoRng, SeedableRng}; use sd_p2p::Identity; use sd_prisma::prisma::{file_path, location, PrismaClient}; use sd_utils::{db::maybe_missing, error::FileIOError}; @@ -12,23 +17,16 @@ use std::{ collections::HashMap, fmt::{Debug, Formatter}, path::{Path, PathBuf}, - sync::Arc, + sync::{atomic::Ordering, Arc}, }; +use futures_concurrency::future::Join; use tokio::{fs, io, sync::broadcast, sync::RwLock}; -use tracing::warn; +use tracing::{debug, warn}; use uuid::Uuid; use super::{LibraryConfig, LibraryManagerError}; -// TODO: Finish this -// pub enum LibraryNew { -// InitialSync, -// Encrypted, -// Loaded(LoadedLibrary), -// Deleting, -// } - pub struct Library { /// id holds the ID of the current library. pub id: Uuid, @@ -37,8 +35,8 @@ pub struct Library { config: RwLock, /// db holds the database client for the current library. pub db: Arc, - pub sync: Arc, - pub cloud: cloud::State, + pub sync: SyncManager, + /// key manager that provides encryption keys to functions that require them // pub key_manager: Arc, /// p2p identity @@ -47,14 +45,12 @@ pub struct Library { // The UUID which matches `config.instance_id`'s primary key. pub instance_uuid: Uuid, - do_cloud_sync: broadcast::Sender<()>, - pub env: Arc, - // Look, I think this shouldn't be here but our current invalidation system needs it. // TODO(@Oscar): Get rid of this with the new invalidation system. 
event_bus_tx: broadcast::Sender, - pub actors: Arc, + pub cloud_sync_state: CloudSyncActorsState, + pub cloud_sync_actors: ActorsCollection, } impl Debug for Library { @@ -71,7 +67,6 @@ impl Debug for Library { } impl Library { - #[allow(clippy::too_many_arguments)] pub async fn new( id: Uuid, config: LibraryConfig, @@ -79,28 +74,66 @@ impl Library { identity: Arc, db: Arc, node: &Arc, - sync: Arc, - cloud: cloud::State, - do_cloud_sync: broadcast::Sender<()>, - actors: Arc, + sync: SyncManager, ) -> Arc { Arc::new(Self { id, config: RwLock::new(config), sync, - cloud, db: db.clone(), - // key_manager, identity, // orphan_remover: OrphanRemoverActor::spawn(db), instance_uuid, - do_cloud_sync, - env: node.env.clone(), event_bus_tx: node.event_bus.0.clone(), - actors, + cloud_sync_state: CloudSyncActorsState::default(), + cloud_sync_actors: ActorsCollection::default(), }) } + pub async fn init_cloud_sync( + &self, + node: &Node, + sync_group_pub_id: groups::PubId, + ) -> Result<(), LibraryManagerError> { + let rng = CryptoRng::from_seed(node.master_rng.lock().await.generate_fixed()); + + self.update_config(|config| { + config + .generate_sync_operations + .store(true, Ordering::Relaxed) + }) + .await?; + + // If this library doesn't have any sync operations, it means that it had sync activated + // for the first time, so we need to backfill the operations from existing db data + if self.db.crdt_operation().count(vec![]).exec().await? == 0 { + backfill_operations(&self.sync).await?; + } + + declare_cloud_sync( + node.data_dir.clone().into_boxed_path(), + node.cloud_services.clone(), + &self.cloud_sync_actors, + &self.cloud_sync_state, + sync_group_pub_id, + self.sync.clone(), + rng, + ) + .await?; + + ( + self.cloud_sync_actors.start(CloudSyncActors::Sender), + self.cloud_sync_actors.start(CloudSyncActors::Receiver), + self.cloud_sync_actors.start(CloudSyncActors::Ingester), + ) + .join() + .await; + + debug!(library_id = %self.id, "Started cloud sync actors"); + + Ok(()) + } + pub async fn config(&self) -> LibraryConfig { self.config.read().await.clone() } @@ -108,13 +141,12 @@ impl Library { pub async fn update_config( &self, update_fn: impl FnOnce(&mut LibraryConfig), - config_path: impl AsRef, ) -> Result<(), LibraryManagerError> { let mut config = self.config.write().await; update_fn(&mut config); - config.save(config_path).await.map_err(Into::into) + config.save(&config.config_path).await.map_err(Into::into) } // TODO: Remove this once we replace the old invalidation system @@ -183,10 +215,4 @@ impl Library { Ok(out) } - - pub fn do_cloud_sync(&self) { - if let Err(e) = self.do_cloud_sync.send(()) { - warn!(?e, "Error sending cloud resync message;"); - } - } } diff --git a/core/src/library/manager/error.rs b/core/src/library/manager/error.rs index 3541eabfd..4fc01dd4e 100644 --- a/core/src/library/manager/error.rs +++ b/core/src/library/manager/error.rs @@ -1,6 +1,7 @@ use crate::{library::LibraryConfigError, location::LocationManagerError}; use sd_core_indexer_rules::seed::SeederError; +use sd_core_sync::DevicePubId; use sd_p2p::IdentityErr; use sd_utils::{ @@ -8,10 +9,9 @@ use sd_utils::{ error::{FileIOError, NonUtf8PathError}, }; -use thiserror::Error; use tracing::error; -#[derive(Error, Debug)] +#[derive(thiserror::Error, Debug)] pub enum LibraryManagerError { #[error("error serializing or deserializing the JSON in the config file: {0}")] Json(#[from] serde_json::Error), @@ -23,8 +23,6 @@ pub enum LibraryManagerError { Uuid(#[from] uuid::Error), #[error("failed to run indexer rules 
seeder: {0}")] IndexerRulesSeeder(#[from] SeederError), - // #[error("failed to initialize the key manager: {0}")] - // KeyManager(#[from] sd_crypto::Error), #[error("error migrating the library: {0}")] MigrationError(#[from] db::MigrationError), #[error("invalid library configuration: {0}")] @@ -39,6 +37,8 @@ pub enum LibraryManagerError { InvalidIdentity, #[error("current instance with id '{0}' was not found in the database")] CurrentInstanceNotFound(String), + #[error("current device with pub id '{0}' was not found in the database")] + CurrentDeviceNotFound(DevicePubId), #[error("missing-field: {0}")] MissingField(#[from] MissingFieldError), @@ -47,6 +47,8 @@ pub enum LibraryManagerError { #[error(transparent)] LibraryConfig(#[from] LibraryConfigError), #[error(transparent)] + CloudServices(#[from] sd_core_cloud_services::Error), + #[error(transparent)] Sync(#[from] sd_core_sync::Error), } diff --git a/core/src/library/manager/mod.rs b/core/src/library/manager/mod.rs index 62f786bd6..2cd6652c3 100644 --- a/core/src/library/manager/mod.rs +++ b/core/src/library/manager/mod.rs @@ -1,16 +1,21 @@ use crate::{ api::{utils::InvalidateOperationEvent, CoreEvent}, - cloud, invalidate_query, + invalidate_query, location::metadata::{LocationMetadataError, SpacedriveLocationMetadataFile}, object::tag, - p2p, sync, + p2p, util::{mpscrr, MaybeUndefined}, Node, }; -use sd_core_sync::SyncMessage; +use sd_core_sync::{SyncEvent, SyncManager}; + use sd_p2p::{Identity, RemoteIdentity}; -use sd_prisma::prisma::{instance, location}; +use sd_prisma::{ + prisma::{self, device, instance, location, PrismaClient}, + prisma_sync, +}; +use sd_sync::ModelId; use sd_utils::{ db, error::{FileIOError, NonUtf8PathError}, @@ -24,26 +29,26 @@ use std::{ atomic::{AtomicBool, Ordering}, Arc, }, - time::Duration, }; use chrono::Utc; use futures_concurrency::future::{Join, TryJoin}; +use prisma_client_rust::Raw; use tokio::{ fs, io, spawn, sync::{broadcast, RwLock}, - time::sleep, }; use tracing::{debug, error, info, instrument, warn}; use uuid::Uuid; -use super::pragmas::configure_pragmas; use super::{Library, LibraryConfig, LibraryName}; mod error; pub mod pragmas; +use pragmas::configure_pragmas; + pub use error::*; /// Event that is emitted to subscribers of the library manager. 
@@ -136,7 +141,7 @@ impl Libraries { } let _library_arc = self - .load(library_id, &db_path, config_path, None, true, node) + .load(library_id, &db_path, config_path, None, None, true, node) .await?; // FIX-ME: Linux releases crashes with *** stack smashing detected *** if spawn_volume_watcher is enabled @@ -159,12 +164,11 @@ impl Libraries { description: Option, node: &Arc, ) -> Result, LibraryManagerError> { - self.create_with_uuid(Uuid::new_v4(), name, description, true, None, node, false) + self.create_with_uuid(Uuid::now_v7(), name, description, true, None, node) .await } #[instrument(skip(self, instance, node), err)] - #[allow(clippy::too_many_arguments)] pub(crate) async fn create_with_uuid( self: &Arc, id: Uuid, @@ -174,7 +178,6 @@ impl Libraries { // `None` will fallback to default as library must be created with at least one instance instance: Option, node: &Arc, - generate_sync_operations: bool, ) -> Result, LibraryManagerError> { if name.as_ref().is_empty() || name.as_ref().chars().all(|x| x.is_whitespace()) { return Err(LibraryManagerError::InvalidConfig( @@ -190,7 +193,6 @@ impl Libraries { // First instance will be zero 0, &config_path, - generate_sync_operations, ) .await?; @@ -206,12 +208,21 @@ impl Libraries { id, self.libraries_dir.join(format!("{id}.db")), config_path, + Some(device::Create { + pub_id: node_cfg.id.to_db(), + _params: vec![ + device::name::set(Some(node_cfg.name.clone())), + device::os::set(Some(node_cfg.os as i32)), + device::hardware_model::set(Some(node_cfg.hardware_model as i32)), + device::date_created::set(Some(now)), + ], + }), Some({ let identity = Identity::new(); let mut create = instance.unwrap_or_else(|| instance::Create { - pub_id: Uuid::new_v4().as_bytes().to_vec(), + pub_id: Uuid::now_v7().as_bytes().to_vec(), remote_identity: identity.to_remote_identity().get_bytes().to_vec(), - node_id: node_cfg.id.as_bytes().to_vec(), + node_id: node_cfg.id.to_db(), last_seen: now, date_created: now, _params: vec![ @@ -270,33 +281,28 @@ impl Libraries { ); library - .update_config( - |config| { - // update the library - if let Some(name) = name { - config.name = name; - } - match description { - MaybeUndefined::Undefined => {} - MaybeUndefined::Null => config.description = None, - MaybeUndefined::Value(description) => { - config.description = Some(description) - } - } - match cloud_id { - MaybeUndefined::Undefined => {} - MaybeUndefined::Null => config.cloud_id = None, - MaybeUndefined::Value(cloud_id) => config.cloud_id = Some(cloud_id), - } - match enable_sync { - None => {} - Some(value) => config - .generate_sync_operations - .store(value, Ordering::SeqCst), - } - }, - self.libraries_dir.join(format!("{id}.sdlibrary")), - ) + .update_config(|config| { + // update the library + if let Some(name) = name { + config.name = name; + } + match description { + MaybeUndefined::Undefined => {} + MaybeUndefined::Null => config.description = None, + MaybeUndefined::Value(description) => config.description = Some(description), + } + match cloud_id { + MaybeUndefined::Undefined => {} + MaybeUndefined::Null => config.cloud_id = None, + MaybeUndefined::Value(cloud_id) => config.cloud_id = Some(cloud_id), + } + match enable_sync { + None => {} + Some(value) => config + .generate_sync_operations + .store(value, Ordering::SeqCst), + } + }) .await?; self.tx @@ -425,6 +431,7 @@ impl Libraries { self.libraries.read().await.get(library_id).is_some() } + #[allow(clippy::too_many_arguments)] // TODO: remove this when we remove instance stuff #[instrument( skip_all, fields( 
@@ -441,7 +448,8 @@ impl Libraries { id: Uuid, db_path: impl AsRef, config_path: impl AsRef, - create: Option, + maybe_create_device: Option, + maybe_create_instance: Option, // Deprecated should_seed: bool, node: &Arc, ) -> Result, LibraryManagerError> { @@ -456,11 +464,21 @@ impl Libraries { ); let db = Arc::new(db::load_and_migrate(&db_url).await?); - if let Some(create) = create { + // Configure database + configure_pragmas(&db).await?; + special_sync_indexes(&db).await?; + + if let Some(create) = maybe_create_device { + create.to_query(&db).exec().await?; + } + + // TODO: remove instances from locations + if let Some(create) = maybe_create_instance { create.to_query(&db).exec().await?; } let node_config = node.config.get().await; + let device_pub_id = node_config.id.clone(); let config = LibraryConfig::load(config_path, &node_config, &db).await?; let instances = db.instance().find_many(vec![]).exec().await?; @@ -473,6 +491,16 @@ impl Libraries { })? .clone(); + let devices = db.device().find_many(vec![]).exec().await?; + + let device_pub_id_to_db = device_pub_id.to_db(); + if !devices + .iter() + .any(|device| device.pub_id == device_pub_id_to_db) + { + return Err(LibraryManagerError::CurrentDeviceNotFound(device_pub_id)); + } + let identity = match instance.identity.as_ref() { Some(b) => Arc::new(Identity::from_bytes(b)?), // We are not this instance, so we don't have the private key. @@ -489,7 +517,7 @@ impl Libraries { .node_remote_identity .as_ref() .and_then(|v| RemoteIdentity::from_bytes(v).ok()); - if instance_node_id != node_config.id + if instance_node_id != Uuid::from(&node_config.id) || instance_node_remote_identity != Some(node_config.identity.to_remote_identity()) || curr_metadata != Some(node.p2p.peer_metadata()) { @@ -505,7 +533,7 @@ impl Libraries { .update( instance::id::equals(instance.id), vec![ - instance::node_id::set(node_config.id.as_bytes().to_vec()), + instance::node_id::set(node_config.id.to_db()), instance::node_remote_identity::set(Some( node_config .identity @@ -519,47 +547,22 @@ impl Libraries { )), ], ) + .select(instance::select!({ id })) .exec() .await?; } // TODO: Move this reconciliation into P2P and do reconciliation of both local and remote nodes. - // let key_manager = Arc::new(KeyManager::new(vec![]).await?); - // seed_keymanager(&db, &key_manager).await?; - - let actors = Default::default(); - - let (sync, sync_rx) = sync::Manager::with_existing_instances( + let (sync, sync_rx) = SyncManager::with_existing_devices( Arc::clone(&db), - instance_id, + &device_pub_id, Arc::clone(&config.generate_sync_operations), - &instances, - Arc::clone(&actors), + &devices, ) .await?; - let sync_manager = Arc::new(sync); - // Configure database - configure_pragmas(&db).await?; - - let cloud = crate::cloud::start(node, &actors, id, instance_id, &sync_manager, &db).await; - - let (tx, mut rx) = broadcast::channel(10); - let library = Library::new( - id, - config, - instance_id, - identity, - // key_manager, - db, - node, - sync_manager, - cloud, - tx, - actors, - ) - .await; + let library = Library::new(id, config, instance_id, identity, db, node, sync).await; // This is an exception. Generally subscribe to this by `self.tx.subscribe`. 
spawn(sync_rx_actor(library.clone(), node.clone(), sync_rx)); @@ -597,127 +600,6 @@ impl Libraries { error!(?e, "Failed to resume jobs for library;"); } - spawn({ - let this = self.clone(); - let node = node.clone(); - let library = library.clone(); - async move { - loop { - debug!("Syncing library with cloud!"); - - if library.config().await.cloud_id.is_some() { - if let Ok(lib) = - sd_cloud_api::library::get(node.cloud_api_config().await, library.id) - .await - { - match lib { - Some(lib) => { - if let Some(this_instance) = lib - .instances - .iter() - .find(|i| i.uuid == library.instance_uuid) - { - let node_config = node.config.get().await; - let curr_metadata: Option> = - instance.metadata.as_ref().map(|metadata| { - serde_json::from_slice(metadata) - .expect("invalid metadata") - }); - let should_update = this_instance.node_id != node_config.id - || RemoteIdentity::from_str( - &this_instance.node_remote_identity, - ) - .ok() != Some( - node_config.identity.to_remote_identity(), - ) || curr_metadata - != Some(node.p2p.peer_metadata()); - - if should_update { - warn!("Library instance on cloud is outdated. Updating..."); - - if let Err(e) = sd_cloud_api::library::update_instance( - node.cloud_api_config().await, - library.id, - this_instance.uuid, - Some(node_config.id), - Some(node_config.identity.to_remote_identity()), - Some(node.p2p.peer_metadata()), - ) - .await - { - error!( - instance_uuid = %this_instance.uuid, - ?e, - "Failed to updating instance on cloud;", - ); - } - } - } - - if lib.name != *library.config().await.name { - warn!("Library name on cloud is outdated. Updating..."); - - if let Err(e) = sd_cloud_api::library::update( - node.cloud_api_config().await, - library.id, - Some(lib.name), - ) - .await - { - error!(?e, "Failed to update library name on cloud;"); - } - } - - for instance in lib.instances { - if let Err(e) = cloud::sync::receive::upsert_instance( - library.id, - &library.db, - &library.sync, - &node.libraries, - &instance.uuid, - instance.identity, - &instance.node_id, - RemoteIdentity::from_str( - &instance.node_remote_identity, - ) - .expect("malformed remote identity from API"), - instance.metadata, - ) - .await - { - error!(?e, "Failed to create instance on cloud;"); - } - } - } - None => { - warn!( - "Library not found on cloud. Removing from local node..." - ); - - let _ = this - .edit( - library.id, - None, - MaybeUndefined::Undefined, - MaybeUndefined::Null, - None, - ) - .await; - } - } - } - } - - tokio::select! { - // Update instances every 2 minutes - _ = sleep(Duration::from_secs(120)) => {} - // Or when asked by user - Ok(_) = rx.recv() => {} - }; - } - } - }); - Ok(library) } @@ -742,7 +624,7 @@ impl Libraries { async fn sync_rx_actor( library: Arc, node: Arc, - mut sync_rx: broadcast::Receiver, + mut sync_rx: broadcast::Receiver, ) { loop { let Ok(msg) = sync_rx.recv().await else { @@ -751,12 +633,63 @@ async fn sync_rx_actor( match msg { // TODO: Any sync event invalidates the entire React Query cache this is a hacky workaround until the new invalidation system. 
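+ // `SyncEvent` is the renamed `SyncMessage` (see the updated `sd_core_sync` import); the `Ingested` and `Created` arms below keep their previous behavior.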
- SyncMessage::Ingested => node.emit(CoreEvent::InvalidateOperation( + SyncEvent::Ingested => node.emit(CoreEvent::InvalidateOperation( InvalidateOperationEvent::all(), )), - SyncMessage::Created => { + SyncEvent::Created => { p2p::sync::originator(library.clone(), &library.sync, &node.p2p).await } } } } + +async fn special_sync_indexes(db: &PrismaClient) -> Result<(), LibraryManagerError> { + async fn create_index( + db: &PrismaClient, + model_id: ModelId, + model_name: &str, + ) -> Result<(), LibraryManagerError> { + db._execute_raw(Raw::new( + &format!( + "CREATE INDEX IF NOT EXISTS partial_index_model_{model_name} \ + ON crdt_operation(model,record_id,kind,timestamp) \ + WHERE model = {model_id} + " + ), + vec![], + )) + .exec() + .await?; + + debug!(model_name, "Created sync partial index"); + + Ok(()) + } + + for (model_id, model_name) in [ + (prisma_sync::device::MODEL_ID, prisma::device::NAME), + ( + prisma_sync::storage_statistics::MODEL_ID, + prisma::storage_statistics::NAME, + ), + (prisma_sync::tag::MODEL_ID, prisma::tag::NAME), + (prisma_sync::location::MODEL_ID, prisma::location::NAME), + (prisma_sync::object::MODEL_ID, prisma::object::NAME), + (prisma_sync::label::MODEL_ID, prisma::label::NAME), + (prisma_sync::exif_data::MODEL_ID, prisma::exif_data::NAME), + (prisma_sync::file_path::MODEL_ID, prisma::file_path::NAME), + ( + prisma_sync::tag_on_object::MODEL_ID, + prisma::tag_on_object::NAME, + ), + ( + prisma_sync::label_on_object::MODEL_ID, + prisma::label_on_object::NAME, + ), + ] { + // Creating indexes sequentially just in case + create_index(db, model_id, model_name).await?; + } + + Ok(()) +} diff --git a/core/src/location/manager/runner.rs b/core/src/location/manager/runner.rs index 1daa383ce..bf769c191 100644 --- a/core/src/location/manager/runner.rs +++ b/core/src/location/manager/runner.rs @@ -38,14 +38,16 @@ type LocationIdAndLibraryId = (location::id::Type, LibraryId); struct Runner { node: Arc, + device_pub_id_to_db: Vec, locations_to_check: HashMap>, locations_watched: HashMap, locations_unwatched: HashMap, forced_unwatch: HashSet, } impl Runner { - fn new(node: Arc) -> Self { + async fn new(node: Arc) -> Self { Self { + device_pub_id_to_db: node.config.get().await.id.to_db(), node, locations_to_check: HashMap::new(), locations_watched: HashMap::new(), @@ -54,13 +56,20 @@ impl Runner { } } + fn check_same_device(&self, location: &location_ids_and_path::Data) -> bool { + location + .device + .as_ref() + .is_some_and(|device| device.pub_id == self.device_pub_id_to_db) + } + async fn add_location( &mut self, location_id: i32, library: Arc, ) -> Result<(), LocationManagerError> { if let Some(location) = get_location(location_id, &library).await? { - check_online(&location, &self.node, &library) + check_online(&location, &self.node, &library, &self.device_pub_id_to_db) .await .and_then(|is_online| { LocationWatcher::new(location, Arc::clone(&library), Arc::clone(&self.node)) @@ -92,8 +101,7 @@ impl Runner { let key = (location_id, library.id); if let Some(location) = get_location(location_id, &library).await? { - // TODO(N): This isn't gonna work with removable media and this will likely permanently break if the DB is restored from a backup. 
- if location.instance_id == Some(library.config().await.instance_id) { + if self.check_same_device(&location) { self.unwatch_location(location, library.id); self.locations_unwatched.remove(&key); self.forced_unwatch.remove(&key); @@ -101,7 +109,7 @@ impl Runner { self.drop_location( location_id, library.id, - "Dropping location from location manager, because we don't have a `local_path` anymore", + "Dropping location from location manager, because it isn't from this device", ); } } else { @@ -298,9 +306,8 @@ impl Runner { let key = (location_id, library.id); if let Some(location) = get_location(location_id, &library).await? { - // TODO(N): This isn't gonna work with removable media and this will likely permanently break if the DB is restored from a backup. - if location.instance_id == Some(library.config().await.instance_id) { - if check_online(&location, &self.node, &library).await? + if self.check_same_device(&location) { + if check_online(&location, &self.node, &library, &self.device_pub_id_to_db).await? && !self.forced_unwatch.contains(&key) { self.watch_location(location, library.id); @@ -314,7 +321,7 @@ impl Runner { location_id, library.id, "Dropping location from location manager, because \ - it isn't a location in the current node", + it isn't a location in the current device", ); self.forced_unwatch.remove(&key); } @@ -344,7 +351,7 @@ pub(super) async fn run( let mut check_locations_interval = interval(Duration::from_secs(2)); check_locations_interval.set_missed_tick_behavior(MissedTickBehavior::Skip); - let mut runner = Runner::new(node); + let mut runner = Runner::new(node).await; let mut msg_stream = pin!(( location_management_rx.map(StreamMessage::LocationManagementMessage), @@ -410,20 +417,23 @@ async fn get_location( fields(%location_id, library_id = %library.id), err, )] -pub(super) async fn check_online( +async fn check_online( location_ids_and_path::Data { id: location_id, pub_id, - instance_id, + device, path, }: &location_ids_and_path::Data, node: &Node, library: &Library, + device_pub_id_to_db: &[u8], ) -> Result { let pub_id = Uuid::from_slice(pub_id)?; - // TODO(N): This isn't gonna work with removable media and this will likely permanently break if the DB is restored from a backup. 
- if *instance_id == Some(library.config().await.instance_id) { + if device + .as_ref() + .is_some_and(|device| device.pub_id == device_pub_id_to_db) + { match fs::metadata(maybe_missing(path, "location.path")?).await { Ok(_) => { node.locations.add_online(pub_id).await; diff --git a/core/src/location/manager/watcher/android.rs b/core/src/location/manager/watcher/android.rs index 01bd8a2a1..723f2e076 100644 --- a/core/src/location/manager/watcher/android.rs +++ b/core/src/location/manager/watcher/android.rs @@ -27,6 +27,7 @@ use super::{ #[derive(Debug)] pub(super) struct EventHandler { location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: Arc, node: Arc, last_events_eviction_check: Instant, @@ -40,9 +41,18 @@ pub(super) struct EventHandler { } impl super::EventHandler for EventHandler { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self { + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self + where + Self: Sized, + { Self { location_id, + location_pub_id, library, node, last_events_eviction_check: Instant::now(), @@ -182,6 +192,7 @@ impl super::EventHandler for EventHandler { &mut self.to_recalculate_size, &mut self.path_and_instant_buffer, self.location_id, + self.location_pub_id.clone(), &self.library, ) .await diff --git a/core/src/location/manager/watcher/ios.rs b/core/src/location/manager/watcher/ios.rs index 3a9c91500..25f0a49fd 100644 --- a/core/src/location/manager/watcher/ios.rs +++ b/core/src/location/manager/watcher/ios.rs @@ -33,6 +33,7 @@ use super::{ #[derive(Debug)] pub(super) struct EventHandler { location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: Arc, node: Arc, last_events_eviction_check: Instant, @@ -48,12 +49,18 @@ pub(super) struct EventHandler { } impl super::EventHandler for EventHandler { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self where Self: Sized, { Self { location_id, + location_pub_id, library, node, last_events_eviction_check: Instant::now(), @@ -183,6 +190,7 @@ impl super::EventHandler for EventHandler { &mut self.to_recalculate_size, &mut self.path_and_instant_buffer, self.location_id, + self.location_pub_id.clone(), &self.library, ) .await diff --git a/core/src/location/manager/watcher/linux.rs b/core/src/location/manager/watcher/linux.rs index 0ec459a3c..34d37ed15 100644 --- a/core/src/location/manager/watcher/linux.rs +++ b/core/src/location/manager/watcher/linux.rs @@ -32,6 +32,7 @@ use super::{ #[derive(Debug)] pub(super) struct EventHandler { location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: Arc, node: Arc, last_events_eviction_check: Instant, @@ -45,9 +46,18 @@ pub(super) struct EventHandler { } impl super::EventHandler for EventHandler { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self { + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self + where + Self: Sized, + { Self { location_id, + location_pub_id, library, node, last_events_eviction_check: Instant::now(), @@ -187,6 +197,7 @@ impl super::EventHandler for EventHandler { &mut self.to_recalculate_size, &mut self.path_and_instant_buffer, self.location_id, + self.location_pub_id.clone(), &self.library, ) .await diff --git 
a/core/src/location/manager/watcher/macos.rs b/core/src/location/manager/watcher/macos.rs index 11486cd20..4d3b1ffec 100644 --- a/core/src/location/manager/watcher/macos.rs +++ b/core/src/location/manager/watcher/macos.rs @@ -42,6 +42,7 @@ use super::{ #[derive(Debug)] pub(super) struct EventHandler { location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: Arc, node: Arc, last_events_eviction_check: Instant, @@ -57,12 +58,18 @@ pub(super) struct EventHandler { } impl super::EventHandler for EventHandler { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self where Self: Sized, { Self { location_id, + location_pub_id, library, node, last_events_eviction_check: Instant::now(), @@ -206,6 +213,7 @@ impl super::EventHandler for EventHandler { &mut self.to_recalculate_size, &mut self.path_and_instant_buffer, self.location_id, + self.location_pub_id.clone(), &self.library, ) .await diff --git a/core/src/location/manager/watcher/mod.rs b/core/src/location/manager/watcher/mod.rs index 48935f4d3..d63709740 100644 --- a/core/src/location/manager/watcher/mod.rs +++ b/core/src/location/manager/watcher/mod.rs @@ -4,7 +4,7 @@ use sd_core_indexer_rules::{IndexerRule, IndexerRuler}; use sd_core_prisma_helpers::{location_ids_and_path, location_with_indexer_rules}; use sd_prisma::prisma::{location, PrismaClient}; -use sd_utils::db::maybe_missing; +use sd_utils::{db::maybe_missing, uuid_to_bytes}; use std::{ collections::HashSet, @@ -67,6 +67,8 @@ type Handler = ios::EventHandler; pub(super) type IgnorePath = (PathBuf, bool); type INode = u64; + +#[cfg(any(target_os = "ios", target_os = "macos", target_os = "windows"))] type InstantAndPath = (Instant, PathBuf); const ONE_SECOND: Duration = Duration::from_secs(1); @@ -74,7 +76,12 @@ const THIRTY_SECONDS: Duration = Duration::from_secs(30); const HUNDRED_MILLIS: Duration = Duration::from_millis(100); trait EventHandler: 'static { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self where Self: Sized; @@ -198,7 +205,12 @@ impl LocationWatcher { Stop, } - let mut event_handler = Handler::new(location_id, Arc::clone(&library), Arc::clone(&node)); + let mut event_handler = Handler::new( + location_id, + uuid_to_bytes(&location_pub_id), + Arc::clone(&library), + Arc::clone(&node), + ); let mut last_event_at = Instant::now(); diff --git a/core/src/location/manager/watcher/utils.rs b/core/src/location/manager/watcher/utils.rs index e6380dd5e..88b065810 100644 --- a/core/src/location/manager/watcher/utils.rs +++ b/core/src/location/manager/watcher/utils.rs @@ -27,21 +27,23 @@ use sd_core_indexer_rules::{ seed::{GitIgnoreRules, GITIGNORE}, IndexerRuler, RulerDecision, }; -use sd_core_prisma_helpers::{file_path_with_object, object_ids, CasId, ObjectPubId}; +use sd_core_prisma_helpers::{ + file_path_watcher_remove, file_path_with_object, object_ids, CasId, ObjectPubId, +}; use sd_file_ext::{ extensions::{AudioExtension, ImageExtension, VideoExtension}, kind::ObjectKind, }; use sd_prisma::{ - prisma::{file_path, location, object}, + prisma::{device, file_path, location, object}, prisma_sync, }; -use sd_sync::OperationFactory; +use sd_sync::{option_sync_db_entry, sync_db_entry, sync_entry, OperationFactory}; use sd_utils::{ - db::{inode_from_db, 
inode_to_db, maybe_missing}, + chain_optional_iter, + db::{inode_from_db, inode_to_db, maybe_missing, size_in_bytes_to_db}, error::FileIOError, - msgpack, }; #[cfg(target_family = "unix")] @@ -352,28 +354,35 @@ async fn inner_create_file( DateTime::::from(fs_metadata.created_or_now()).into(); let int_kind = kind as i32; - sync.write_ops( - db, + let device_pub_id = sync.device_pub_id.to_db(); + + let (sync_params, db_params) = [ + sync_db_entry!(date_created, object::date_created), + sync_db_entry!(int_kind, object::kind), ( - sync.shared_create( - prisma_sync::object::SyncId { - pub_id: pub_id.to_db(), + sync_entry!( + prisma_sync::device::SyncId { + pub_id: device_pub_id.clone() }, - [ - (object::date_created::NAME, msgpack!(date_created)), - (object::kind::NAME, msgpack!(int_kind)), - ], + object::device ), - db.object() - .create( - pub_id.into(), - vec![ - object::date_created::set(Some(date_created)), - object::kind::set(Some(int_kind)), - ], - ) - .select(object_ids::select()), + object::device::connect(device::pub_id::equals(device_pub_id)), ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + sync.write_op( + db, + sync.shared_create( + prisma_sync::object::SyncId { + pub_id: pub_id.to_db(), + }, + sync_params, + ), + db.object() + .create(pub_id.into(), db_params) + .select(object_ids::select()), ) .await? }; @@ -384,17 +393,21 @@ async fn inner_create_file( prisma_sync::location::SyncId { pub_id: created_file.pub_id.clone(), }, - file_path::object::NAME, - msgpack!(prisma_sync::object::SyncId { - pub_id: object_pub_id.clone() - }), - ), - db.file_path().update( - file_path::pub_id::equals(created_file.pub_id.clone()), - vec![file_path::object::connect(object::pub_id::equals( - object_pub_id.clone(), - ))], + [sync_entry!( + prisma_sync::object::SyncId { + pub_id: object_pub_id.clone() + }, + file_path::object + )], ), + db.file_path() + .update( + file_path::pub_id::equals(created_file.pub_id.clone()), + vec![file_path::object::connect(object::pub_id::equals( + object_pub_id.clone(), + ))], + ) + .select(file_path::select!({ id })), ) .await?; @@ -583,34 +596,22 @@ async fn inner_update_file( let is_hidden = path_is_hidden(full_path, &fs_metadata); if file_path.cas_id.as_deref() != cas_id.as_ref().map(CasId::as_str) { - let (sync_params, db_params): (Vec<_>, Vec<_>) = { - use file_path::*; - + let (sync_params, db_params) = chain_optional_iter( [ - ( - (cas_id::NAME, msgpack!(file_path.cas_id)), - Some(cas_id::set(file_path.cas_id.clone())), + sync_db_entry!( + size_in_bytes_to_db(fs_metadata.len()), + file_path::size_in_bytes_bytes ), - ( - ( - size_in_bytes_bytes::NAME, - msgpack!(fs_metadata.len().to_be_bytes().to_vec()), - ), - Some(size_in_bytes_bytes::set(Some( - fs_metadata.len().to_be_bytes().to_vec(), - ))), + sync_db_entry!( + DateTime::::from(fs_metadata.modified_or_now()), + file_path::date_modified ), - { - let date = DateTime::::from(fs_metadata.modified_or_now()).into(); - - ( - (date_modified::NAME, msgpack!(date)), - Some(date_modified::set(Some(date))), - ) - }, - { - // TODO: Should this be a skip rather than a null-set? - let checksum = if file_path.integrity_checksum.is_some() { + ], + [ + option_sync_db_entry!(file_path.cas_id.clone(), file_path::cas_id), + option_sync_db_entry!( + if file_path.integrity_checksum.is_some() { + // TODO: Should this be a skip rather than a null-set? 
// If a checksum was already computed, we need to recompute it Some( file_checksum(full_path) @@ -619,62 +620,37 @@ async fn inner_update_file( ) } else { None - }; - - ( - (integrity_checksum::NAME, msgpack!(checksum)), - Some(integrity_checksum::set(checksum)), - ) - }, - { - if current_inode != inode { - ( - (inode::NAME, msgpack!(inode)), - Some(inode::set(Some(inode_to_db(inode)))), - ) - } else { - ((inode::NAME, msgpack!(nil)), None) - } - }, - { - if is_hidden != file_path.hidden.unwrap_or_default() { - ( - (hidden::NAME, msgpack!(inode)), - Some(hidden::set(Some(is_hidden))), - ) - } else { - ((hidden::NAME, msgpack!(nil)), None) - } - }, - ] - .into_iter() - .filter_map(|(sync_param, maybe_db_param)| { - maybe_db_param.map(|db_param| (sync_param, db_param)) - }) - .unzip() - }; + }, + file_path::integrity_checksum + ), + option_sync_db_entry!( + (current_inode != inode).then(|| inode_to_db(inode)), + file_path::inode + ), + option_sync_db_entry!( + (is_hidden != file_path.hidden.unwrap_or_default()).then_some(is_hidden), + file_path::hidden + ), + ], + ) + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); // file content changed - sync.write_ops( + sync.write_op( db, - ( - sync_params - .into_iter() - .map(|(field, value)| { - sync.shared_update( - prisma_sync::file_path::SyncId { - pub_id: file_path.pub_id.clone(), - }, - field, - value, - ) - }) - .collect(), - db.file_path().update( + sync.shared_update( + prisma_sync::file_path::SyncId { + pub_id: file_path.pub_id.clone(), + }, + sync_params, + ), + db.file_path() + .update( file_path::pub_id::equals(file_path.pub_id.clone()), db_params, - ), - ), + ) + .select(file_path::select!({ id })), ) .await?; @@ -688,19 +664,18 @@ async fn inner_update_file( .await? == 1 { if object.kind.map(|k| k != int_kind).unwrap_or_default() { + let (sync_param, db_param) = sync_db_entry!(int_kind, object::kind); sync.write_op( db, sync.shared_update( prisma_sync::object::SyncId { pub_id: object.pub_id.clone(), }, - object::kind::NAME, - msgpack!(int_kind), - ), - db.object().update( - object::id::equals(object.id), - vec![object::kind::set(Some(int_kind))], + [sync_param], ), + db.object() + .update(object::id::equals(object.id), vec![db_param]) + .select(object::select!({ id })), ) .await?; } @@ -709,26 +684,33 @@ async fn inner_update_file( let date_created: DateTime = DateTime::::from(fs_metadata.created_or_now()).into(); - sync.write_ops( - db, + let device_pub_id = sync.device_pub_id.to_db(); + + let (sync_params, db_params) = [ + sync_db_entry!(date_created, object::date_created), + sync_db_entry!(int_kind, object::kind), ( - sync.shared_create( - prisma_sync::object::SyncId { - pub_id: pub_id.to_db(), + sync_entry!( + prisma_sync::device::SyncId { + pub_id: device_pub_id.clone() }, - [ - (object::date_created::NAME, msgpack!(date_created)), - (object::kind::NAME, msgpack!(int_kind)), - ], - ), - db.object().create( - pub_id.to_db(), - vec![ - object::date_created::set(Some(date_created)), - object::kind::set(Some(int_kind)), - ], + object::device ), + object::device::connect(device::pub_id::equals(device_pub_id)), ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + sync.write_op( + db, + sync.shared_create( + prisma_sync::object::SyncId { + pub_id: pub_id.to_db(), + }, + sync_params, + ), + db.object().create(pub_id.to_db(), db_params), ) .await?; @@ -738,17 +720,21 @@ async fn inner_update_file( prisma_sync::location::SyncId { pub_id: file_path.pub_id.clone(), }, - file_path::object::NAME, - msgpack!(prisma_sync::object::SyncId 
{ - pub_id: pub_id.to_db() - }), - ), - db.file_path().update( - file_path::pub_id::equals(file_path.pub_id.clone()), - vec![file_path::object::connect(object::pub_id::equals( - pub_id.into(), - ))], + [sync_entry!( + prisma_sync::object::SyncId { + pub_id: pub_id.to_db() + }, + file_path::object + )], ), + db.file_path() + .update( + file_path::pub_id::equals(file_path.pub_id.clone()), + vec![file_path::object::connect(object::pub_id::equals( + pub_id.into(), + ))], + ) + .select(file_path::select!({ id })), ) .await?; } @@ -856,21 +842,22 @@ async fn inner_update_file( invalidate_query!(library, "search.paths"); invalidate_query!(library, "search.objects"); } else if is_hidden != file_path.hidden.unwrap_or_default() { - sync.write_ops( + let (sync_param, db_param) = sync_db_entry!(is_hidden, file_path::hidden); + + sync.write_op( db, - ( - vec![sync.shared_update( - prisma_sync::file_path::SyncId { - pub_id: file_path.pub_id.clone(), - }, - file_path::hidden::NAME, - msgpack!(is_hidden), - )], - db.file_path().update( - file_path::pub_id::equals(file_path.pub_id.clone()), - vec![file_path::hidden::set(Some(is_hidden))], - ), + sync.shared_update( + prisma_sync::file_path::SyncId { + pub_id: file_path.pub_id.clone(), + }, + [sync_param], ), + db.file_path() + .update( + file_path::pub_id::equals(file_path.pub_id.clone()), + vec![db_param], + ) + .select(file_path::select!({ id })), ) .await?; @@ -954,7 +941,7 @@ pub(super) async fn rename( .await?; let total_paths_count = paths.len(); - let (sync_params, db_params): (Vec<_>, Vec<_>) = paths + let (sync_params, db_params) = paths .into_iter() .filter_map(|path| path.materialized_path.map(|mp| (path.id, path.pub_id, mp))) .map(|(id, pub_id, mp)| { @@ -963,75 +950,55 @@ pub(super) async fn rename( &format!("{}/{}/", new_parts.materialized_path, new_parts.name), ); + let (sync_param, db_param) = + sync_db_entry!(new_path, file_path::materialized_path); + ( sync.shared_update( sd_prisma::prisma_sync::file_path::SyncId { pub_id }, - file_path::materialized_path::NAME, - msgpack!(&new_path), - ), - db.file_path().update( - file_path::id::equals(id), - vec![file_path::materialized_path::set(Some(new_path))], + [sync_param], ), + db.file_path() + .update(file_path::id::equals(id), vec![db_param]) + .select(file_path::select!({ id })), ) }) - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); - sync.write_ops(db, (sync_params, db_params)).await?; + if !sync_params.is_empty() && !db_params.is_empty() { + sync.write_ops(db, (sync_params, db_params)).await?; + } trace!(%total_paths_count, "Updated file_paths;"); } - let is_hidden = path_is_hidden(new_path, &new_path_metadata); - - let date_modified = DateTime::::from(new_path_metadata.modified_or_now()).into(); - - let (sync_params, db_params): (Vec<_>, Vec<_>) = [ - ( - ( - file_path::materialized_path::NAME, - msgpack!(new_path_materialized_str), - ), - file_path::materialized_path::set(Some(new_path_materialized_str)), + let (sync_params, db_params) = [ + sync_db_entry!(new_path_materialized_str, file_path::materialized_path), + sync_db_entry!(new_parts.name.to_string(), file_path::name), + sync_db_entry!(new_parts.extension.to_string(), file_path::extension), + sync_db_entry!( + DateTime::::from(new_path_metadata.modified_or_now()), + file_path::date_modified ), - ( - (file_path::name::NAME, msgpack!(new_parts.name)), - file_path::name::set(Some(new_parts.name.to_string())), - ), - ( - (file_path::extension::NAME, msgpack!(new_parts.extension)), - 
file_path::extension::set(Some(new_parts.extension.to_string())), - ), - ( - (file_path::date_modified::NAME, msgpack!(&date_modified)), - file_path::date_modified::set(Some(date_modified)), - ), - ( - (file_path::hidden::NAME, msgpack!(is_hidden)), - file_path::hidden::set(Some(is_hidden)), + sync_db_entry!( + path_is_hidden(new_path, &new_path_metadata), + file_path::hidden ), ] .into_iter() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); - sync.write_ops( + sync.write_op( db, - ( - sync_params - .into_iter() - .map(|(k, v)| { - sync.shared_update( - prisma_sync::file_path::SyncId { - pub_id: file_path.pub_id.clone(), - }, - k, - v, - ) - }) - .collect(), - db.file_path() - .update(file_path::pub_id::equals(file_path.pub_id), db_params), + sync.shared_update( + prisma_sync::file_path::SyncId { + pub_id: file_path.pub_id.clone(), + }, + sync_params, ), + db.file_path() + .update(file_path::pub_id::equals(file_path.pub_id), db_params) + .select(file_path::select!({ id })), ) .await?; @@ -1060,19 +1027,20 @@ pub(super) async fn remove( &location_path, full_path, )?) + .select(file_path_watcher_remove::select()) .exec() .await? else { return Ok(()); }; - remove_by_file_path(location_id, full_path, &file_path, library).await + remove_by_file_path(location_id, full_path, file_path, library).await } async fn remove_by_file_path( location_id: location::id::Type, path: impl AsRef + Send, - file_path: &file_path::Data, + file_path: file_path_watcher_remove::Data, library: &Library, ) -> Result<(), LocationManagerError> { // check file still exists on disk @@ -1096,28 +1064,42 @@ async fn remove_by_file_path( delete_directory( library, location_id, - Some(&IsolatedFilePathData::try_from(file_path)?), + Some(&IsolatedFilePathData::try_from(&file_path)?), ) .await?; } else { sync.write_op( db, sync.shared_delete(prisma_sync::file_path::SyncId { - pub_id: file_path.pub_id.clone(), + pub_id: file_path.pub_id, }), db.file_path().delete(file_path::id::equals(file_path.id)), ) .await?; - if let Some(object_id) = file_path.object_id { - db.object() - .delete_many(vec![ - object::id::equals(object_id), + if let Some(object) = file_path.object { + // If this object doesn't have any other file paths, delete it + if db + .object() + .count(vec![ + object::id::equals(object.id), // https://www.prisma.io/docs/reference/api-reference/prisma-client-reference#none object::file_paths::none(vec![]), ]) .exec() + .await? 
== 1 + { + sync.write_op( + db, + sync.shared_delete(prisma_sync::object::SyncId { + pub_id: object.pub_id, + }), + db.object() + .delete(object::id::equals(object.id)) + .select(object::select!({ id })), + ) .await?; + } } } } @@ -1186,6 +1168,7 @@ pub(super) async fn recalculate_directories_size( candidates: &mut HashMap, buffer: &mut Vec<(PathBuf, Instant)>, location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: &Library, ) -> Result<(), LocationManagerError> { let mut location_path_cache = None; @@ -1244,7 +1227,7 @@ pub(super) async fn recalculate_directories_size( } if should_update_location_size { - update_location_size(location_id, library).await?; + update_location_size(location_id, location_pub_id, library).await?; } if should_invalidate { diff --git a/core/src/location/manager/watcher/windows.rs b/core/src/location/manager/watcher/windows.rs index a9b24c54c..bd85693e8 100644 --- a/core/src/location/manager/watcher/windows.rs +++ b/core/src/location/manager/watcher/windows.rs @@ -39,6 +39,7 @@ use super::{ #[derive(Debug)] pub(super) struct EventHandler { location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: Arc, node: Arc, last_events_eviction_check: Instant, @@ -54,12 +55,18 @@ pub(super) struct EventHandler { } impl super::EventHandler for EventHandler { - fn new(location_id: location::id::Type, library: Arc, node: Arc) -> Self + fn new( + location_id: location::id::Type, + location_pub_id: location::pub_id::Type, + library: Arc, + node: Arc, + ) -> Self where Self: Sized, { Self { location_id, + location_pub_id, library, node, last_events_eviction_check: Instant::now(), @@ -277,6 +284,7 @@ impl super::EventHandler for EventHandler { &mut self.to_recalculate_size, &mut self.path_and_instant_buffer, self.location_id, + self.location_pub_id.clone(), &self.library, ) .await diff --git a/core/src/location/mod.rs b/core/src/location/mod.rs index 79ed55466..e89639285 100644 --- a/core/src/location/mod.rs +++ b/core/src/location/mod.rs @@ -13,14 +13,14 @@ use sd_core_heavy_lifting::{ use sd_core_prisma_helpers::{location_with_indexer_rules, CasId}; use sd_prisma::{ - prisma::{file_path, indexer_rules_in_location, location, PrismaClient}, + prisma::{device, file_path, indexer_rules_in_location, instance, location, PrismaClient}, prisma_sync, }; use sd_sync::*; use sd_utils::{ - db::{maybe_missing, MissingFieldError}, + db::{maybe_missing, size_in_bytes_from_db, size_in_bytes_to_db}, error::{FileIOError, NonUtf8PathError}, - msgpack, uuid_to_bytes, + uuid_to_bytes, }; use std::{ @@ -163,7 +163,7 @@ impl LocationCreateArgs { } ); - let uuid = Uuid::new_v4(); + let uuid = Uuid::now_v7(); let location = create_location( library, @@ -246,7 +246,7 @@ impl LocationCreateArgs { }, ); - let uuid = Uuid::new_v4(); + let uuid = Uuid::now_v7(); let location = create_location( library, @@ -304,63 +304,36 @@ impl LocationUpdateArgs { let name = self.name.clone(); - let (sync_params, db_params): (Vec<_>, Vec<_>) = [ - self.name - .filter(|name| location.name.as_ref() != Some(name)) - .map(|v| { - ( - (location::name::NAME, msgpack!(v)), - location::name::set(Some(v)), - ) - }), - self.generate_preview_media.map(|v| { - ( - (location::generate_preview_media::NAME, msgpack!(v)), - location::generate_preview_media::set(Some(v)), - ) - }), - self.sync_preview_media.map(|v| { - ( - (location::sync_preview_media::NAME, msgpack!(v)), - location::sync_preview_media::set(Some(v)), - ) - }), - self.hidden.map(|v| { - ( - (location::hidden::NAME, 
msgpack!(v)), - location::hidden::set(Some(v)), - ) - }), - self.path.clone().map(|v| { - ( - (location::path::NAME, msgpack!(v)), - location::path::set(Some(v)), - ) - }), + let (sync_params, db_params) = [ + option_sync_db_entry!( + self.name + .filter(|name| location.name.as_ref() != Some(name)), + location::name + ), + option_sync_db_entry!( + self.generate_preview_media, + location::generate_preview_media + ), + option_sync_db_entry!(self.sync_preview_media, location::sync_preview_media), + option_sync_db_entry!(self.hidden, location::hidden), + option_sync_db_entry!(self.path.clone(), location::path), ] .into_iter() .flatten() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); if !sync_params.is_empty() { - sync.write_ops( + sync.write_op( db, - ( - sync_params - .into_iter() - .map(|p| { - sync.shared_update( - prisma_sync::location::SyncId { - pub_id: location.pub_id.clone(), - }, - p.0, - p.1, - ) - }) - .collect(), - db.location() - .update(location::id::equals(self.id), db_params), + sync.shared_update( + prisma_sync::location::SyncId { + pub_id: location.pub_id.clone(), + }, + sync_params, ), + db.location() + .update(location::id::equals(self.id), db_params) + .select(location::select!({ id })), ) .await?; @@ -493,6 +466,7 @@ pub async fn scan_location( ) .await? } + ScanState::Indexed => { node.job_system .dispatch( @@ -505,6 +479,7 @@ pub async fn scan_location( ) .await? } + ScanState::FilesIdentified => { node.job_system .dispatch( @@ -651,33 +626,25 @@ pub async fn relink_location( .map(str::to_string) .ok_or_else(|| NonUtf8PathError(location_path.into()))?; - sync.write_op( - db, - sync.shared_update( - prisma_sync::location::SyncId { - pub_id: pub_id.clone(), - }, - location::path::NAME, - msgpack!(path), - ), - db.location().update( - location::pub_id::equals(pub_id.clone()), - vec![location::path::set(Some(path))], - ), - ) - .await?; + let (sync_param, db_param) = sync_db_entry!(path, location::path); - let location_id = db - .location() - .find_unique(location::pub_id::equals(pub_id)) - .select(location::select!({ id })) - .exec() + let location_id = sync + .write_op( + db, + sync.shared_update( + prisma_sync::location::SyncId { + pub_id: pub_id.clone(), + }, + [sync_param], + ), + db.location() + .update(location::pub_id::equals(pub_id.clone()), vec![db_param]) + .select(location::select!({ id })), + ) .await? 
- .ok_or_else(|| { - LocationError::MissingField(MissingFieldError::new("missing id of location")) - })?; + .id; - Ok(location_id.id) + Ok(location_id) } #[derive(Debug)] @@ -766,58 +733,56 @@ async fn create_location( return Ok(None); } - let date_created = Utc::now(); - - let location = sync - .write_ops( - db, - ( - sync.shared_create( - prisma_sync::location::SyncId { - pub_id: location_pub_id.as_bytes().to_vec(), - }, - [ - (location::name::NAME, msgpack!(&name)), - (location::path::NAME, msgpack!(&path)), - (location::date_created::NAME, msgpack!(date_created)), - // ( - // location::instance::NAME, - // msgpack!(prisma_sync::instance::SyncId { - // pub_id: uuid_to_bytes(sync.instance) - // }), - // ), - ], - ), - db.location() - .create( - location_pub_id.as_bytes().to_vec(), - vec![ - location::name::set(Some(name.clone())), - location::path::set(Some(path)), - location::date_created::set(Some(date_created.into())), - location::instance_id::set(Some(library.config().await.instance_id)), - // location::instance::connect(instance::id::equals( - // library.config.instance_id.as_bytes().to_vec(), - // )), - ], - ) - .include(location_with_indexer_rules::include()), + let (sync_values, mut db_params) = [ + sync_db_entry!(&name, location::name), + sync_db_entry!(path, location::path), + sync_db_entry!(Utc::now(), location::date_created), + ( + sync_entry!( + prisma_sync::device::SyncId { + pub_id: sync.device_pub_id.to_db() + }, + location::device ), + location::device::connect(device::pub_id::equals(sync.device_pub_id.to_db())), + ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + // temporary workaround until we remove instances from locations + db_params.push(location::instance::connect(instance::id::equals( + library.config().await.instance_id, + ))); + + let location_id = sync + .write_op( + db, + sync.shared_create( + prisma_sync::location::SyncId { + pub_id: uuid_to_bytes(&location_pub_id), + }, + sync_values, + ), + db.location() + .create(uuid_to_bytes(&location_pub_id), db_params) + .select(location::select!({ id })), ) - .await?; + .await? + .id; debug!("New location created in db"); if !indexer_rules_ids.is_empty() { - link_location_and_indexer_rules(library, location.id, indexer_rules_ids).await?; + link_location_and_indexer_rules(library, location_id, indexer_rules_ids).await?; } // Updating our location variable to include information about the indexer rules - let location = find_location(library, location.id) + let location = find_location(library, location_id) .include(location_with_indexer_rules::include()) .exec() .await? - .ok_or(LocationError::IdNotFound(location.id))?; + .ok_or(LocationError::IdNotFound(location_id))?; invalidate_query!(library, "locations.list"); @@ -915,11 +880,9 @@ pub async fn delete_directory( library: &Library, location_id: location::id::Type, parent_iso_file_path: Option<&IsolatedFilePathData<'_>>, -) -> Result<(), QueryError> { +) -> Result<(), sd_core_sync::Error> { let Library { db, .. } = library; - // This is NOT sync-compatible! - // Sync requires having sync ids available. let children_params = sd_utils::chain_optional_iter( [file_path::location_id::equals(Some(location_id))], [parent_iso_file_path.and_then(|parent| { @@ -934,7 +897,39 @@ pub async fn delete_directory( })], ); - db.file_path().delete_many(children_params).exec().await?; + let pub_ids = library + .db + .file_path() + .find_many(children_params.clone()) + .select(file_path::select!({ pub_id })) + .exec() + .await? 
+ .into_iter() + .map(|fp| fp.pub_id) + .collect::>(); + + if pub_ids.is_empty() { + debug!("No file paths to delete"); + return Ok(()); + } + + library + .sync + .write_ops( + &library.db, + ( + pub_ids + .into_iter() + .map(|pub_id| { + library + .sync + .shared_delete(prisma_sync::file_path::SyncId { pub_id }) + }) + .collect(), + db.file_path().delete_many(children_params), + ), + ) + .await?; // library.orphan_remover.invoke().await; @@ -1004,45 +999,44 @@ async fn check_nested_location( #[instrument(skip_all, err)] pub async fn update_location_size( location_id: location::id::Type, + location_pub_id: location::pub_id::Type, library: &Library, -) -> Result<(), QueryError> { - let Library { db, .. } = library; +) -> Result<(), sd_core_sync::Error> { + let Library { db, sync, .. } = library; - let total_size = db - .file_path() - .find_many(vec![ - file_path::location_id::equals(Some(location_id)), - file_path::materialized_path::equals(Some("/".to_string())), - ]) - .select(file_path::select!({ size_in_bytes_bytes })) - .exec() - .await? - .into_iter() - .filter_map(|file_path| { - file_path.size_in_bytes_bytes.map(|size_in_bytes_bytes| { - u64::from_be_bytes([ - size_in_bytes_bytes[0], - size_in_bytes_bytes[1], - size_in_bytes_bytes[2], - size_in_bytes_bytes[3], - size_in_bytes_bytes[4], - size_in_bytes_bytes[5], - size_in_bytes_bytes[6], - size_in_bytes_bytes[7], - ]) + let total_size = size_in_bytes_to_db( + db.file_path() + .find_many(vec![ + file_path::location_id::equals(Some(location_id)), + file_path::materialized_path::equals(Some("/".to_string())), + ]) + .select(file_path::select!({ size_in_bytes_bytes })) + .exec() + .await? + .into_iter() + .filter_map(|file_path| { + file_path + .size_in_bytes_bytes + .map(|size_in_bytes_bytes| size_in_bytes_from_db(&size_in_bytes_bytes)) }) - }) - .sum::(); + .sum::(), + ); - db.location() - .update( - location::id::equals(location_id), - vec![location::size_in_bytes::set(Some( - total_size.to_be_bytes().to_vec(), - ))], - ) - .exec() - .await?; + let (sync_param, db_param) = sync_db_entry!(total_size, location::size_in_bytes); + + sync.write_op( + db, + sync.shared_update( + prisma_sync::location::SyncId { + pub_id: location_pub_id, + }, + [sync_param], + ), + db.location() + .update(location::id::equals(location_id), vec![db_param]) + .select(location::select!({ id })), + ) + .await?; invalidate_query!(library, "locations.list"); invalidate_query!(library, "locations.get"); @@ -1102,81 +1096,60 @@ pub async fn create_file_path( location_id, ))?; - let (sync_params, db_params): (Vec<_>, Vec<_>) = { - use file_path::*; + let device_pub_id = sync.device_pub_id.to_db(); - [ - ( - ( - location::NAME, - msgpack!(prisma_sync::location::SyncId { - pub_id: location.pub_id - }), - ), - location::connect(prisma::location::id::equals(location.id)), + let (sync_params, db_params) = [ + ( + sync_entry!( + prisma_sync::location::SyncId { + pub_id: location.pub_id + }, + file_path::location ), - ( - (cas_id::NAME, msgpack!(cas_id)), - cas_id::set(cas_id.map(Into::into)), + file_path::location::connect(prisma::location::id::equals(location.id)), + ), + ( + sync_entry!(cas_id, file_path::cas_id), + file_path::cas_id::set(cas_id.map(Into::into)), + ), + sync_db_entry!(materialized_path, file_path::materialized_path), + sync_db_entry!(name, file_path::name), + sync_db_entry!(extension, file_path::extension), + sync_db_entry!( + size_in_bytes_to_db(metadata.size_in_bytes), + file_path::size_in_bytes_bytes + ), + sync_db_entry!(inode_to_db(metadata.inode), 
file_path::inode), + sync_db_entry!(is_dir, file_path::is_dir), + sync_db_entry!(metadata.created_at, file_path::date_created), + sync_db_entry!(metadata.modified_at, file_path::date_modified), + sync_db_entry!(indexed_at, file_path::date_indexed), + sync_db_entry!(metadata.hidden, file_path::hidden), + ( + sync_entry!( + prisma_sync::device::SyncId { + pub_id: device_pub_id.clone() + }, + file_path::device ), - ( - (materialized_path::NAME, msgpack!(materialized_path)), - materialized_path::set(Some(materialized_path.into())), - ), - ((name::NAME, msgpack!(name)), name::set(Some(name.into()))), - ( - (extension::NAME, msgpack!(extension)), - extension::set(Some(extension.into())), - ), - ( - ( - size_in_bytes_bytes::NAME, - msgpack!(metadata.size_in_bytes.to_be_bytes().to_vec()), - ), - size_in_bytes_bytes::set(Some(metadata.size_in_bytes.to_be_bytes().to_vec())), - ), - ( - (inode::NAME, msgpack!(metadata.inode.to_le_bytes())), - inode::set(Some(inode_to_db(metadata.inode))), - ), - ((is_dir::NAME, msgpack!(is_dir)), is_dir::set(Some(is_dir))), - ( - (date_created::NAME, msgpack!(metadata.created_at)), - date_created::set(Some(metadata.created_at.into())), - ), - ( - (date_modified::NAME, msgpack!(metadata.modified_at)), - date_modified::set(Some(metadata.modified_at.into())), - ), - ( - (date_indexed::NAME, msgpack!(indexed_at)), - date_indexed::set(Some(indexed_at.into())), - ), - ( - (hidden::NAME, msgpack!(metadata.hidden)), - hidden::set(Some(metadata.hidden)), - ), - ] - .into_iter() - .unzip() - }; + file_path::device::connect(prisma::device::pub_id::equals(device_pub_id)), + ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); - let pub_id = sd_utils::uuid_to_bytes(&Uuid::new_v4()); + let pub_id = sd_utils::uuid_to_bytes(&Uuid::now_v7()); - let created_path = sync - .write_ops( - db, - ( - sync.shared_create( - prisma_sync::file_path::SyncId { - pub_id: pub_id.clone(), - }, - sync_params, - ), - db.file_path().create(pub_id, db_params), - ), - ) - .await?; - - Ok(created_path) + sync.write_op( + db, + sync.shared_create( + prisma_sync::file_path::SyncId { + pub_id: pub_id.clone(), + }, + sync_params, + ), + db.file_path().create(pub_id, db_params), + ) + .await + .map_err(Into::into) } diff --git a/core/src/location/non_indexed.rs b/core/src/location/non_indexed.rs index 04d040f1c..5ad43989e 100644 --- a/core/src/location/non_indexed.rs +++ b/core/src/location/non_indexed.rs @@ -376,7 +376,9 @@ impl Entry { /// /// From my M1 Macbook Pro this: /// - takes 11ms per 10 000 files -/// and +/// +/// and +/// /// - consumes 0.16MB of RAM per 10 000 entries. /// /// The reason we collect these all up is so we can apply ordering, and then begin streaming the data as it's processed to the frontend. 
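Note on the `update_location_size` hunk above: the hand-rolled `u64::from_be_bytes([...])` / `to_be_bytes().to_vec()` pair is replaced by shared `sd_utils::db` helpers. A minimal sketch of what those helpers plausibly do, inferred from the inline code they replace — the real implementations may differ, e.g. in how they treat blobs that aren't exactly 8 bytes. Sizes are presumably stored as 8-byte big-endian blobs because SQLite's signed 64-bit INTEGER can't represent the full `u64` range:

```rust
/// Encode a file size as the 8-byte big-endian blob stored in the database.
pub fn size_in_bytes_to_db(size: u64) -> Vec<u8> {
    size.to_be_bytes().to_vec()
}

/// Decode the stored blob back into a u64; mirrors the removed
/// `u64::from_be_bytes([bytes[0], ..., bytes[7]])` expression.
pub fn size_in_bytes_from_db(blob: &[u8]) -> u64 {
    let mut bytes = [0u8; 8];
    bytes.copy_from_slice(&blob[..8]); // panics on short blobs, like the old inline code
    u64::from_be_bytes(bytes)
}

fn main() {
    let size = 3_221_225_472_u64; // 3 GiB survives the round trip
    assert_eq!(size_in_bytes_from_db(&size_in_bytes_to_db(size)), size);
}
```

The same hunk also shows the `.select(location::select!({ id }))` trick this PR applies throughout: writes that previously returned the whole updated row now project only `id`, which (per the comment elsewhere in this diff) is the smallest allowed selection when the caller doesn't need the data.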
diff --git a/core/src/node/config.rs b/core/src/node/config.rs index 17098ca7c..28ddb7a4a 100644 --- a/core/src/node/config.rs +++ b/core/src/node/config.rs @@ -4,6 +4,8 @@ use crate::{ util::version_manager::{Kind, ManagedVersion, VersionManager, VersionManagerError}, }; +use sd_cloud_schema::devices::DeviceOS; +use sd_core_sync::DevicePubId; use sd_p2p::Identity; use sd_utils::error::FileIOError; @@ -26,6 +28,8 @@ use tokio::{ use tracing::error; use uuid::Uuid; +use super::HardwareModel; + /// NODE_STATE_CONFIG_NAME is the name of the file which stores the NodeState pub const NODE_STATE_CONFIG_NAME: &str = "node_state.sdconfig"; @@ -88,9 +92,10 @@ pub struct NodeConfigP2P { /// /// All of these are valid values: /// - `localhost` - /// - `otbeaumont.me` or `otbeaumont.me:3000` + /// - `spacedrive.com` or `spacedrive.com:3000` /// - `127.0.0.1` or `127.0.0.1:300` /// - `[::1]` or `[::1]:3000` + /// /// which is why we use `String` not `SocketAddr` #[serde(default)] pub manual_peers: HashSet, @@ -110,11 +115,12 @@ impl Default for NodeConfigP2P { } } -/// NodeConfig is the configuration for a node. This is shared between all libraries and is stored in a JSON file on disk. +/// NodeConfig is the configuration for a node. +/// This is shared between all libraries and is stored in a JSON file on disk. #[derive(Debug, Clone, Serialize, Deserialize)] // If you are adding `specta::Type` on this your probably about to leak the P2P private key pub struct NodeConfig { /// id is a unique identifier for the current node. Each node has a public identifier (this one) and is given a local id for each library (done within the library code). - pub id: Uuid, + pub id: DevicePubId, /// name is the display name of the current node. This is set by the user and is shown in the UI. // TODO: Length validation so it can fit in DNS record pub name: String, /// core level notifications @@ -123,6 +129,8 @@ pub struct NodeConfig { /// The p2p identity keypair for this node. This is used to identify the node on the network. /// This keypair does effectively nothing except for provide libp2p with a stable peer_id. #[serde(with = "identity_serde")] + // TODO(@fogodev): remove these from here, we must not store secret keys in plaintext... 
+ // Put then on secret storage when we have a keyring compatible with all our supported platforms pub identity: Identity, /// P2P config #[serde(default)] @@ -130,15 +138,12 @@ pub struct NodeConfig { /// Feature flags enabled on the node #[serde(default)] pub features: Vec, - /// Authentication for Spacedrive Accounts - pub auth_token: Option, - /// URL of the Spacedrive API - #[serde(default, skip_serializing_if = "Option::is_none")] - pub sd_api_origin: Option, /// The aggregation of many different preferences for the node pub preferences: NodePreferences, - // Model version for the image labeler - pub image_labeler_version: Option, + /// Operating System of the node + pub os: DeviceOS, + /// Hardware model of the node + pub hardware_model: HardwareModel, version: NodeConfigVersion, } @@ -182,45 +187,44 @@ pub enum NodeConfigVersion { V1 = 1, V2 = 2, V3 = 3, + V4 = 4, + V5 = 5, } impl ManagedVersion for NodeConfig { - const LATEST_VERSION: NodeConfigVersion = NodeConfigVersion::V3; + const LATEST_VERSION: NodeConfigVersion = NodeConfigVersion::V5; const KIND: Kind = Kind::Json("version"); type MigrationError = NodeConfigError; fn from_latest_version() -> Option { - let mut name = match hostname::get() { - // SAFETY: This is just for display purposes so it doesn't matter if it's lossy - Ok(hostname) => hostname.to_string_lossy().into_owned(), - Err(e) => { - error!( - ?e, - "Falling back to default node name as an error occurred getting your systems hostname;", - ); + #[cfg(not(any(target_os = "ios", target_os = "android")))] + let mut name = whoami::devicename(); - "my-spacedrive".into() - } - }; - name.truncate(250); + #[cfg(target_os = "ios")] + let mut name = "iOS Device".to_string(); - #[cfg(feature = "ai")] - let image_labeler_version = Some(sd_ai::old_image_labeler::DEFAULT_MODEL_VERSION.to_string()); - #[cfg(not(feature = "ai"))] - let image_labeler_version = None; + #[cfg(target_os = "android")] + let mut name = "Android Device".to_string(); + + name.truncate(255); + + let os = DeviceOS::from_env(); + let hardware_model = HardwareModel::try_get().unwrap_or_else(|e| { + error!(?e, "Failed to get hardware model"); + HardwareModel::Other + }); Some(Self { - id: Uuid::new_v4(), + id: Uuid::now_v7().into(), name, identity: Identity::default(), p2p: NodeConfigP2P::default(), version: Self::LATEST_VERSION, features: vec![], notifications: vec![], - auth_token: None, - sd_api_origin: None, preferences: NodePreferences::default(), - image_labeler_version, + os, + hardware_model, }) } } @@ -313,6 +317,107 @@ impl NodeConfig { .map_err(|e| FileIOError::from((path, e)))?; } + (NodeConfigVersion::V3, NodeConfigVersion::V4) => { + let mut config: Map = + serde_json::from_slice(&fs::read(path).await.map_err(|e| { + FileIOError::from(( + path, + e, + "Failed to read node config file for migration", + )) + })?) 
+ .map_err(VersionManagerError::SerdeJson)?; + + config.remove("id"); + config.insert( + String::from("id"), + serde_json::to_value(Uuid::now_v7()) + .map_err(VersionManagerError::SerdeJson)?, + ); + + config.remove("name"); + + #[cfg(not(any(target_os = "ios", target_os = "android")))] + config.insert( + String::from("name"), + serde_json::to_value(whoami::devicename()) + .map_err(VersionManagerError::SerdeJson)?, + ); + + #[cfg(target_os = "ios")] + config.insert( + String::from("name"), + serde_json::to_value("iOS Device") + .map_err(VersionManagerError::SerdeJson)?, + ); + + #[cfg(target_os = "android")] + config.insert( + String::from("name"), + serde_json::to_value("Android Device") + .map_err(VersionManagerError::SerdeJson)?, + ); + + config.insert( + String::from("os"), + serde_json::to_value(std::env::consts::OS) + .map_err(VersionManagerError::SerdeJson)?, + ); + + let a = + serde_json::to_vec(&config).map_err(VersionManagerError::SerdeJson)?; + + fs::write(path, a) + .await + .map_err(|e| FileIOError::from((path, e)))?; + } + + (NodeConfigVersion::V4, NodeConfigVersion::V5) => { + let mut config: Map = + serde_json::from_slice(&fs::read(path).await.map_err(|e| { + FileIOError::from(( + path, + e, + "Failed to read node config file for migration", + )) + })?) + .map_err(VersionManagerError::SerdeJson)?; + + config.insert( + String::from("os"), + serde_json::to_value(DeviceOS::from_env()) + .map_err(VersionManagerError::SerdeJson)?, + ); + config.insert( + String::from("hardware_model"), + serde_json::to_value( + HardwareModel::try_get().unwrap_or(HardwareModel::Other), + ) + .map_err(VersionManagerError::SerdeJson)?, + ); + + config.remove("features"); + config.remove("auth_token"); + config.remove("sd_api_origin"); + config.remove("image_labeler_version"); + + config.remove("id"); + config.insert( + String::from("id"), + serde_json::to_value(DevicePubId::from(Uuid::now_v7())) + .map_err(VersionManagerError::SerdeJson)?, + ); + + fs::write( + path, + serde_json::to_vec(&config).map_err(VersionManagerError::SerdeJson)?, + ) + .await + .map_err(|e| { + FileIOError::from((path, e, "Failed to write back updated config")) + })?; + } + _ => { error!(current_version = ?current, "Node config version is not handled;"); return Err(VersionManagerError::UnexpectedMigration { @@ -354,18 +459,7 @@ impl Manager { let data_directory_path = data_directory_path.as_ref().to_path_buf(); let config_file_path = data_directory_path.join(NODE_STATE_CONFIG_NAME); - let mut config = NodeConfig::load(&config_file_path).await?; - - #[cfg(feature = "ai")] - if config.image_labeler_version.is_none() { - config.image_labeler_version = - Some(sd_ai::old_image_labeler::DEFAULT_MODEL_VERSION.to_string()); - } - - #[cfg(not(feature = "ai"))] - { - config.image_labeler_version = None; - } + let config = NodeConfig::load(&config_file_path).await?; let (preferences_watcher_tx, _preferences_watcher_rx) = watch::channel(config.preferences.clone()); diff --git a/core/src/node/hardware.rs b/core/src/node/hardware.rs index be0370e00..b0d6625cf 100644 --- a/core/src/node/hardware.rs +++ b/core/src/node/hardware.rs @@ -1,150 +1,210 @@ -use std::io::Error; -use std::str; +use std::io; use serde::{Deserialize, Serialize}; use specta::Type; +use strum::IntoEnumIterator; use strum_macros::{Display, EnumIter}; #[repr(i32)] #[derive(Debug, Clone, Display, Copy, EnumIter, Type, Serialize, Deserialize, Eq, PartialEq)] +#[specta(rename = "CoreHardwareModel")] pub enum HardwareModel { - Other, - MacStudio, - MacBookAir, - MacBookPro, 
- MacBook, - MacMini, - MacPro, - IMac, - IMacPro, - IPad, - IPhone, - Simulator, - Android, + Other = 0, + MacStudio = 1, + MacBookAir = 2, + MacBookPro = 3, + MacBook = 4, + MacMini = 5, + MacPro = 6, + IMac = 7, + IMacPro = 8, + IPad = 9, + IPhone = 10, + Simulator = 11, + Android = 12, } -impl HardwareModel { - pub fn from_display_name(name: &str) -> Self { - use strum::IntoEnumIterator; - HardwareModel::iter() +impl From for HardwareModel { + fn from(value: i32) -> Self { + match value { + 1 => Self::MacStudio, + 2 => Self::MacBookAir, + 3 => Self::MacBookPro, + 4 => Self::MacBook, + 5 => Self::MacMini, + 6 => Self::MacPro, + 7 => Self::IMac, + 8 => Self::IMacPro, + 9 => Self::IPad, + 10 => Self::IPhone, + 11 => Self::Simulator, + 12 => Self::Android, + _ => Self::Other, + } + } +} + +impl From for sd_cloud_schema::devices::HardwareModel { + fn from(model: HardwareModel) -> Self { + match model { + HardwareModel::MacStudio => Self::MacStudio, + HardwareModel::MacBookAir => Self::MacBookAir, + HardwareModel::MacBookPro => Self::MacBookPro, + HardwareModel::MacBook => Self::MacBook, + HardwareModel::MacMini => Self::MacMini, + HardwareModel::MacPro => Self::MacPro, + HardwareModel::IMac => Self::IMac, + HardwareModel::IMacPro => Self::IMacPro, + HardwareModel::IPad => Self::IPad, + HardwareModel::IPhone => Self::IPhone, + HardwareModel::Simulator => Self::Simulator, + HardwareModel::Android => Self::Android, + HardwareModel::Other => Self::Other, + } + } +} + +impl From for HardwareModel { + fn from(model: sd_cloud_schema::devices::HardwareModel) -> Self { + match model { + sd_cloud_schema::devices::HardwareModel::MacStudio => Self::MacStudio, + sd_cloud_schema::devices::HardwareModel::MacBookAir => Self::MacBookAir, + sd_cloud_schema::devices::HardwareModel::MacBookPro => Self::MacBookPro, + sd_cloud_schema::devices::HardwareModel::MacBook => Self::MacBook, + sd_cloud_schema::devices::HardwareModel::MacMini => Self::MacMini, + sd_cloud_schema::devices::HardwareModel::MacPro => Self::MacPro, + sd_cloud_schema::devices::HardwareModel::IMac => Self::IMac, + sd_cloud_schema::devices::HardwareModel::IMacPro => Self::IMacPro, + sd_cloud_schema::devices::HardwareModel::IPad => Self::IPad, + sd_cloud_schema::devices::HardwareModel::IPhone => Self::IPhone, + sd_cloud_schema::devices::HardwareModel::Simulator => Self::Simulator, + sd_cloud_schema::devices::HardwareModel::Android => Self::Android, + sd_cloud_schema::devices::HardwareModel::Other => Self::Other, + } + } +} + +impl From<&str> for HardwareModel { + fn from(name: &str) -> Self { + Self::iter() .find(|&model| { model.to_string().to_lowercase().replace(' ', "") == name.to_lowercase().replace(' ', "") }) - .unwrap_or(HardwareModel::Other) + .unwrap_or(Self::Other) } } -pub fn get_hardware_model_name() -> Result { - #[cfg(target_os = "macos")] - { - use std::process::Command; +impl HardwareModel { + pub fn try_get() -> Result { + #[cfg(target_os = "macos")] + { + use std::process::Command; - let output = Command::new("system_profiler") - .arg("SPHardwareDataType") - .output()?; + let output = Command::new("system_profiler") + .arg("SPHardwareDataType") + .output()?; - if output.status.success() { - let output_str = std::str::from_utf8(&output.stdout).unwrap_or_default(); - let hardware_model = output_str - .lines() - .find(|line| line.to_lowercase().contains("model name")) - .and_then(|line| line.split_once(':')) - .map(|(_, model_name)| HardwareModel::from_display_name(model_name.trim())) - .unwrap_or(HardwareModel::Other); + if 
output.status.success() { + let output_str = std::str::from_utf8(&output.stdout).unwrap_or_default(); + let hardware_model = output_str + .lines() + .find(|line| line.to_lowercase().contains("model name")) + .and_then(|line| line.split_once(':')) + .map(|(_, model_name)| model_name.trim().into()) + .unwrap_or(Self::Other); - Ok(hardware_model) - } else { - Err(Error::new( - std::io::ErrorKind::Other, - format!( - "Failed to get hardware model name: {}", - String::from_utf8_lossy(&output.stderr) - ), - )) - } - } - #[cfg(target_os = "ios")] - { - use std::ffi::CString; - use std::ptr; - - extern "C" { - fn sysctlbyname( - name: *const libc::c_char, - oldp: *mut libc::c_void, - oldlenp: *mut usize, - newp: *mut libc::c_void, - newlen: usize, - ) -> libc::c_int; - } - - fn get_device_type() -> Option { - let mut size: usize = 0; - let name = CString::new("hw.machine").expect("CString::new failed"); - - // First, get the size of the buffer needed - unsafe { - sysctlbyname( - name.as_ptr(), - ptr::null_mut(), - &mut size, - ptr::null_mut(), - 0, - ); - } - - // Allocate a buffer with the correct size - let mut buffer: Vec = vec![0; size]; - - // Get the actual machine type - unsafe { - sysctlbyname( - name.as_ptr(), - buffer.as_mut_ptr() as *mut libc::c_void, - &mut size, - ptr::null_mut(), - 0, - ); - } - - // Convert the buffer to a String - let machine_type = String::from_utf8_lossy(&buffer).trim().to_string(); - - // Check if the device is an iPad or iPhone - if machine_type.starts_with("iPad") { - Some("iPad".to_string()) - } else if machine_type.starts_with("iPhone") { - Some("iPhone".to_string()) - } else if machine_type.starts_with("arm") { - Some("Simulator".to_string()) + Ok(hardware_model) } else { - None + Err(io::Error::new( + io::ErrorKind::Other, + format!( + "Failed to get hardware model name: {}", + String::from_utf8_lossy(&output.stderr) + ), + )) + } + } + #[cfg(target_os = "ios")] + { + use std::ffi::CString; + use std::io::Error; + use std::ptr; + + extern "C" { + fn sysctlbyname( + name: *const libc::c_char, + oldp: *mut libc::c_void, + oldlenp: *mut usize, + newp: *mut libc::c_void, + newlen: usize, + ) -> libc::c_int; + } + + fn get_device_type() -> Option { + let mut size: usize = 0; + let name = CString::new("hw.machine").expect("CString::new failed"); + + // First, get the size of the buffer needed + unsafe { + sysctlbyname( + name.as_ptr(), + ptr::null_mut(), + &mut size, + ptr::null_mut(), + 0, + ); + } + + // Allocate a buffer with the correct size + let mut buffer: Vec = vec![0; size]; + + // Get the actual machine type + unsafe { + sysctlbyname( + name.as_ptr(), + buffer.as_mut_ptr() as *mut libc::c_void, + &mut size, + ptr::null_mut(), + 0, + ); + } + + // Convert the buffer to a String + let machine_type = String::from_utf8_lossy(&buffer).trim().to_string(); + + // Check if the device is an iPad or iPhone + if machine_type.starts_with("iPad") { + Some("iPad".to_string()) + } else if machine_type.starts_with("iPhone") { + Some("iPhone".to_string()) + } else if machine_type.starts_with("arm") { + Some("Simulator".to_string()) + } else { + None + } + } + + if let Some(device_type) = get_device_type() { + let hardware_model = HardwareModel::from(device_type.as_str()); + + Ok(hardware_model) + } else { + Err(Error::new( + std::io::ErrorKind::Other, + "Failed to get hardware model name", + )) } } - if let Some(device_type) = get_device_type() { - let hardware_model = HardwareModel::from_display_name(&device_type.as_str()); + #[cfg(target_os = "android")] + { + 
Ok(Self::Android) + } - Ok(hardware_model) - } else { - Err(Error::new( - std::io::ErrorKind::Other, - "Failed to get hardware model name", - )) + #[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "android")))] + { + Ok(Self::Other) } } - - #[cfg(target_os = "android")] - { - Ok(HardwareModel::Android) - } - - #[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "android")))] - { - Err(Error::new( - std::io::ErrorKind::Unsupported, - "Unsupported operating system", - )) - } } diff --git a/core/src/object/fs/old_copy.rs b/core/src/object/fs/old_copy.rs index 8b760b920..2d7b0fb70 100644 --- a/core/src/object/fs/old_copy.rs +++ b/core/src/object/fs/old_copy.rs @@ -323,8 +323,8 @@ impl StatefulJob for OldFileCopierJobInit { .await?; dirs.extend(more_dirs); - let (dir_source_file_data, dir_target_full_path): (Vec<_>, Vec<_>) = - dirs.into_iter().unzip(); + let (dir_source_file_data, dir_target_full_path) = + dirs.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); let step_files = dir_source_file_data .into_iter() diff --git a/core/src/object/tag/mod.rs b/core/src/object/tag/mod.rs index 41b4e88bd..34b609a83 100644 --- a/core/src/object/tag/mod.rs +++ b/core/src/object/tag/mod.rs @@ -21,28 +21,26 @@ impl TagCreateArgs { self, Library { db, sync, .. }: &Library, ) -> Result { - let pub_id = Uuid::new_v4().as_bytes().to_vec(); + let pub_id = Uuid::now_v7().as_bytes().to_vec(); - let (sync_params, db_params): (Vec<_>, Vec<_>) = [ + let (sync_params, db_params) = [ sync_db_entry!(self.name, tag::name), sync_db_entry!(self.color, tag::color), sync_db_entry!(false, tag::is_hidden), sync_db_entry!(Utc::now(), tag::date_created), ] .into_iter() - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); - sync.write_ops( + sync.write_op( db, - ( - sync.shared_create( - prisma_sync::tag::SyncId { - pub_id: pub_id.clone(), - }, - sync_params, - ), - db.tag().create(pub_id, db_params), + sync.shared_create( + prisma_sync::tag::SyncId { + pub_id: pub_id.clone(), + }, + sync_params, ), + db.tag().create(pub_id, db_params), ) .await } diff --git a/core/src/object/validation/old_validator_job.rs b/core/src/object/validation/old_validator_job.rs index d90fc56cb..7ddd42938 100644 --- a/core/src/object/validation/old_validator_job.rs +++ b/core/src/object/validation/old_validator_job.rs @@ -15,8 +15,8 @@ use sd_prisma::{ prisma::{file_path, location}, prisma_sync, }; -use sd_sync::OperationFactory; -use sd_utils::{db::maybe_missing, error::FileIOError, msgpack}; +use sd_sync::{sync_db_entry, OperationFactory}; +use sd_utils::{db::maybe_missing, error::FileIOError}; use std::{ hash::{Hash, Hasher}, @@ -157,19 +157,22 @@ impl StatefulJob for OldObjectValidatorJobInit { .await .map_err(|e| ValidatorError::FileIO(FileIOError::from((full_path, e))))?; + let (sync_param, db_param) = sync_db_entry!(checksum, file_path::integrity_checksum); + sync.write_op( db, sync.shared_update( prisma_sync::file_path::SyncId { pub_id: file_path.pub_id.clone(), }, - file_path::integrity_checksum::NAME, - msgpack!(&checksum), - ), - db.file_path().update( - file_path::pub_id::equals(file_path.pub_id.clone()), - vec![file_path::integrity_checksum::set(Some(checksum))], + [sync_param], ), + db.file_path() + .update( + file_path::pub_id::equals(file_path.pub_id.clone()), + vec![db_param], + ) + .select(file_path::select!({ id })), ) .await?; } diff --git a/core/src/old_job/manager.rs b/core/src/old_job/manager.rs index f47164759..c9e5cc892 100644 --- a/core/src/old_job/manager.rs +++ b/core/src/old_job/manager.rs @@ -320,6 
+320,7 @@ impl OldJobs { job::id::equals(job.id.as_bytes().to_vec()), vec![job::status::set(Some(JobStatus::Canceled as i32))], ) + .select(job::select!({ id })) .exec() .await?; } diff --git a/core/src/old_job/report.rs b/core/src/old_job/report.rs index ed40df23d..af7333267 100644 --- a/core/src/old_job/report.rs +++ b/core/src/old_job/report.rs @@ -395,6 +395,7 @@ impl OldJobReport { job::date_completed::set(self.completed_at.map(Into::into)), ], ) + .select(job::select!({ id })) .exec() .await?; Ok(()) diff --git a/core/src/p2p/manager.rs b/core/src/p2p/manager.rs index 7dfcb95ea..96ee7a264 100644 --- a/core/src/p2p/manager.rs +++ b/core/src/p2p/manager.rs @@ -1,7 +1,7 @@ use crate::{ node::{ config::{self, P2PDiscoveryState}, - get_hardware_model_name, HardwareModel, + HardwareModel, }, p2p::{ libraries::libraries_hook, operations, sync::SyncMessage, Header, OperatingSystem, @@ -116,7 +116,8 @@ impl P2PManager { let client = reqwest::Client::new(); loop { match client - .get(format!("{}/api/p2p/relays", node.env.api_url.lock().await)) + // FIXME(@fogodev): hardcoded URL for now as I'm moving stuff around + .get(format!("{}/api/p2p/relays", "https://app.spacedrive.com")) .send() .await { @@ -207,7 +208,7 @@ impl P2PManager { PeerMetadata { name: config.name.clone(), operating_system: Some(OperatingSystem::get_os()), - device_model: Some(get_hardware_model_name().unwrap_or(HardwareModel::Other)), + device_model: Some(HardwareModel::try_get().unwrap_or(HardwareModel::Other)), version: Some(env!("CARGO_PKG_VERSION").to_string()), } .update(&mut self.p2p.metadata_mut()); diff --git a/core/src/p2p/metadata.rs b/core/src/p2p/metadata.rs index 5e03e9c7d..054eea0ca 100644 --- a/core/src/p2p/metadata.rs +++ b/core/src/p2p/metadata.rs @@ -47,7 +47,7 @@ impl PeerMetadata { .get("os") .map(|os| os.parse().map_err(|_| "Unable to parse 'OperationSystem'!")) .transpose()?, - device_model: Some(HardwareModel::from_display_name( + device_model: Some(HardwareModel::from( data.get("device_model") .map(|s| s.as_str()) .unwrap_or("Other"), diff --git a/core/src/p2p/sync/mod.rs b/core/src/p2p/sync/mod.rs index 8ec7c29c0..4832be93e 100644 --- a/core/src/p2p/sync/mod.rs +++ b/core/src/p2p/sync/mod.rs @@ -1,17 +1,13 @@ #![allow(clippy::panic, clippy::unwrap_used)] // TODO: Finish this -use crate::{ - library::Library, - sync::{self, GetOpsArgs}, -}; +use crate::library::Library; -use sd_p2p_proto::{decode, encode}; -use sd_sync::CompressedCRDTOperations; +// use sd_p2p_proto::{decode, encode}; +// use sd_sync::CompressedCRDTOperationsPerModelPerDevice; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tracing::*; +use tokio::io::{AsyncRead, AsyncWrite}; use super::P2PManager; @@ -20,244 +16,246 @@ pub use proto::*; pub use originator::run as originator; mod originator { - use crate::p2p::Header; + // use crate::p2p::Header; use super::*; - use responder::tx as rx; - use sd_p2p_tunnel::Tunnel; + // use responder::tx as rx; + use sd_core_sync::SyncManager; + // use sd_p2p_tunnel::Tunnel; pub mod tx { - use super::*; + // use super::*; - #[derive(Debug, PartialEq)] - pub struct Operations(pub CompressedCRDTOperations); + // #[derive(Debug)] + // pub struct Operations(pub CompressedCRDTOperationsPerModelPerDevice); - impl Operations { - // TODO: Per field errors for better error handling - pub async fn from_stream( - stream: &mut (impl AsyncRead + Unpin), - ) -> std::io::Result { - Ok(Self( - rmp_serde::from_slice(&decode::buf(stream).await.unwrap()).unwrap(), - )) - } + // impl 
Operations { + // // TODO: Per field errors for better error handling + // pub async fn from_stream( + // stream: &mut (impl AsyncRead + Unpin), + // ) -> std::io::Result { + // Ok(Self( + // rmp_serde::from_slice(&decode::buf(stream).await.unwrap()).unwrap(), + // )) + // } - pub fn to_bytes(&self) -> Vec { - let Self(args) = self; - let mut buf = vec![]; + // pub fn to_bytes(&self) -> Vec { + // let Self(args) = self; + // let mut buf = vec![]; - // TODO: Error handling - encode::buf(&mut buf, &rmp_serde::to_vec_named(&args).unwrap()); - buf - } - } + // // TODO: Error handling + // encode::buf(&mut buf, &rmp_serde::to_vec_named(&args).unwrap()); + // buf + // } + // } - #[cfg(test)] - #[tokio::test] - async fn test() { - use sd_sync::CRDTOperation; - use uuid::Uuid; + // #[cfg(test)] + // #[tokio::test] + // async fn test() { + // use sd_sync::CRDTOperation; + // use uuid::Uuid; - { - let original = Operations(CompressedCRDTOperations::new(vec![])); + // { + // let original = Operations(CompressedCRDTOperationsPerModelPerDevice::new(vec![])); - let mut cursor = std::io::Cursor::new(original.to_bytes()); - let result = Operations::from_stream(&mut cursor).await.unwrap(); - assert_eq!(original, result); - } + // let mut cursor = std::io::Cursor::new(original.to_bytes()); + // let result = Operations::from_stream(&mut cursor).await.unwrap(); + // assert_eq!(original, result); + // } - { - let original = Operations(CompressedCRDTOperations::new(vec![CRDTOperation { - instance: Uuid::new_v4(), - timestamp: sync::NTP64(0), - record_id: rmpv::Value::Nil, - model: 0, - data: sd_sync::CRDTOperationData::create(), - }])); + // { + // let original = Operations(CompressedCRDTOperationsPerModelPerDevice::new(vec![ + // CRDTOperation { + // device_pub_id: Uuid::new_v4(), + // timestamp: sync::NTP64(0), + // record_id: rmpv::Value::Nil, + // model_id: 0, + // data: sd_sync::CRDTOperationData::create(), + // }, + // ])); - let mut cursor = std::io::Cursor::new(original.to_bytes()); - let result = Operations::from_stream(&mut cursor).await.unwrap(); - assert_eq!(original, result); - } - } + // let mut cursor = std::io::Cursor::new(original.to_bytes()); + // let result = Operations::from_stream(&mut cursor).await.unwrap(); + // assert_eq!(original, result); + // } + // } } - #[instrument(skip(sync, p2p))] + // #[instrument(skip(sync, p2p))] /// REMEMBER: This only syncs one direction! 
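Note on the commented-out `tx::Operations` above: it framed an rmp-serde (msgpack) payload behind a length prefix via `sd-p2p-proto`'s `encode::buf`/`decode::buf`. A self-contained sketch of that style of framing, under stated assumptions: a plain little-endian `u32` length prefix (the actual `sd-p2p-proto` encoding may differ) and a stand-in payload type instead of `CompressedCRDTOperationsPerModelPerDevice`:

```rust
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Operations(Vec<u64>); // stand-in for the real compressed-operations type

async fn write_frame(w: &mut (impl AsyncWrite + Unpin), ops: &Operations) -> std::io::Result<()> {
    let payload = rmp_serde::to_vec_named(ops)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
    w.write_u32_le(payload.len() as u32).await?; // length prefix, then body
    w.write_all(&payload).await
}

async fn read_frame(r: &mut (impl AsyncRead + Unpin)) -> std::io::Result<Operations> {
    let len = r.read_u32_le().await? as usize;
    let mut payload = vec![0u8; len];
    r.read_exact(&mut payload).await?;
    rmp_serde::from_slice(&payload)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let original = Operations(vec![1, 2, 3]);
    let mut buf = Vec::new();
    write_frame(&mut buf, &original).await?;
    let mut cursor = std::io::Cursor::new(buf); // tokio implements AsyncRead for Cursor
    assert_eq!(read_frame(&mut cursor).await?, original);
    Ok(())
}
```

The round trip mirrors the disabled unit tests kept above: serialize, read back through an in-memory cursor, and compare for equality.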
- pub async fn run( - library: Arc, - sync: &Arc, - p2p: &Arc, - ) { - for (remote_identity, peer) in p2p.get_library_instances(&library.id) { - if !peer.is_connected() { - continue; - }; + pub async fn run(_library: Arc, _sync: &SyncManager, _p2p: &Arc) { + // for (remote_identity, peer) in p2p.get_library_instances(&library.id) { + // if !peer.is_connected() { + // continue; + // }; - let sync = sync.clone(); + // let sync = sync.clone(); - let library = library.clone(); - tokio::spawn(async move { - debug!( - ?remote_identity, - %library.id, - "Alerting peer of new sync events for library;" - ); + // let library = library.clone(); + // tokio::spawn(async move { + // debug!( + // ?remote_identity, + // %library.id, + // "Alerting peer of new sync events for library;" + // ); - let mut stream = peer.new_stream().await.unwrap(); + // let mut stream = peer.new_stream().await.unwrap(); - stream.write_all(&Header::Sync.to_bytes()).await.unwrap(); + // stream.write_all(&Header::Sync.to_bytes()).await.unwrap(); - let mut tunnel = Tunnel::initiator(stream, &library.identity).await.unwrap(); + // let mut tunnel = Tunnel::initiator(stream, &library.identity).await.unwrap(); - tunnel - .write_all(&SyncMessage::NewOperations.to_bytes()) - .await - .unwrap(); - tunnel.flush().await.unwrap(); + // tunnel + // .write_all(&SyncMessage::NewOperations.to_bytes()) + // .await + // .unwrap(); + // tunnel.flush().await.unwrap(); - while let Ok(rx::MainRequest::GetOperations(args)) = - rx::MainRequest::from_stream(&mut tunnel).await - { - let ops = sync.get_ops(args).await.unwrap(); - - tunnel - .write_all(&tx::Operations(CompressedCRDTOperations::new(ops)).to_bytes()) - .await - .unwrap(); - tunnel.flush().await.unwrap(); - } - }); - } + // while let Ok(rx::MainRequest::GetOperations(GetOpsArgs { + // timestamp_per_device, + // count, + // })) = rx::MainRequest::from_stream(&mut tunnel).await + // { + // tunnel + // .write_all( + // &tx::Operations(CompressedCRDTOperationsPerModelPerDevice::new( + // sync.get_ops(count, timestamp_per_device).await.unwrap(), + // )) + // .to_bytes(), + // ) + // .await + // .unwrap(); + // tunnel.flush().await.unwrap(); + // } + // }); + // } } } pub use responder::run as responder; mod responder { - use std::pin::pin; use super::*; - use futures::StreamExt; - use originator::tx as rx; + // use futures::StreamExt; - pub mod tx { - use serde::{Deserialize, Serialize}; + // pub mod tx { + // use serde::{Deserialize, Serialize}; - use super::*; + // use super::*; - #[derive(Serialize, Deserialize, PartialEq, Debug)] - pub enum MainRequest { - GetOperations(GetOpsArgs), - Done, - } + // #[derive(Serialize, Deserialize, PartialEq, Debug)] + // pub enum MainRequest { + // GetOperations(GetOpsArgs), + // Done, + // } - impl MainRequest { - // TODO: Per field errors for better error handling - pub async fn from_stream( - stream: &mut (impl AsyncRead + Unpin), - ) -> std::io::Result { - Ok( - // TODO: Error handling - rmp_serde::from_slice(&decode::buf(stream).await.unwrap()).unwrap(), - ) - } + // impl MainRequest { + // // TODO: Per field errors for better error handling + // pub async fn from_stream( + // stream: &mut (impl AsyncRead + Unpin), + // ) -> std::io::Result { + // Ok( + // // TODO: Error handling + // rmp_serde::from_slice(&decode::buf(stream).await.unwrap()).unwrap(), + // ) + // } - pub fn to_bytes(&self) -> Vec { - let mut buf = vec![]; - // TODO: Error handling - encode::buf(&mut buf, &rmp_serde::to_vec_named(&self).unwrap()); - buf - } - } + // pub fn 
to_bytes(&self) -> Vec { + // let mut buf = vec![]; + // // TODO: Error handling + // encode::buf(&mut buf, &rmp_serde::to_vec_named(&self).unwrap()); + // buf + // } + // } - #[cfg(test)] - #[tokio::test] - async fn test() { - { - let original = MainRequest::GetOperations(GetOpsArgs { - clocks: vec![], - count: 0, - }); + // #[cfg(test)] + // #[tokio::test] + // async fn test() { + // { + // let original = MainRequest::GetOperations(GetOpsArgs { + // timestamp_per_device: vec![], + // count: 0, + // }); - let mut cursor = std::io::Cursor::new(original.to_bytes()); - let result = MainRequest::from_stream(&mut cursor).await.unwrap(); - assert_eq!(original, result); - } + // let mut cursor = std::io::Cursor::new(original.to_bytes()); + // let result = MainRequest::from_stream(&mut cursor).await.unwrap(); + // assert_eq!(original, result); + // } - { - let original = MainRequest::Done; + // { + // let original = MainRequest::Done; - let mut cursor = std::io::Cursor::new(original.to_bytes()); - let result = MainRequest::from_stream(&mut cursor).await.unwrap(); - assert_eq!(original, result); - } - } - } + // let mut cursor = std::io::Cursor::new(original.to_bytes()); + // let result = MainRequest::from_stream(&mut cursor).await.unwrap(); + // assert_eq!(original, result); + // } + // } + // } pub async fn run( - stream: &mut (impl AsyncRead + AsyncWrite + Unpin), - library: Arc, + _stream: &mut (impl AsyncRead + AsyncWrite + Unpin), + _library: Arc, ) -> Result<(), ()> { - use sync::ingest::*; + // use sync::ingest::*; - let ingest = &library.sync.ingest; + // let ingest = &library.sync.ingest; - ingest.event_tx.send(Event::Notification).await.unwrap(); + // ingest.event_tx.send(Event::Notification).await.unwrap(); - let mut rx = pin!(ingest.req_rx.clone()); + // let mut rx = pin!(ingest.req_rx.clone()); - while let Some(req) = rx.next().await { - const OPS_PER_REQUEST: u32 = 1000; + // while let Some(req) = rx.next().await { + // const OPS_PER_REQUEST: u32 = 1000; - let timestamps = match req { - Request::FinishedIngesting => break, - Request::Messages { timestamps, .. } => timestamps, - }; + // let timestamps = match req { + // Request::FinishedIngesting => break, + // Request::Messages { timestamps, .. } => timestamps, + // }; - debug!(?timestamps, "Getting ops for timestamps;"); + // debug!(?timestamps, "Getting ops for timestamps;"); - stream - .write_all( - &tx::MainRequest::GetOperations(sync::GetOpsArgs { - clocks: timestamps, - count: OPS_PER_REQUEST, - }) - .to_bytes(), - ) - .await - .unwrap(); - stream.flush().await.unwrap(); + // stream + // .write_all( + // &tx::MainRequest::GetOperations(sync::GetOpsArgs { + // timestamp_per_device: timestamps, + // count: OPS_PER_REQUEST, + // }) + // .to_bytes(), + // ) + // .await + // .unwrap(); + // stream.flush().await.unwrap(); - let rx::Operations(ops) = rx::Operations::from_stream(stream).await.unwrap(); + // let rx::Operations(ops) = rx::Operations::from_stream(stream).await.unwrap(); - let (wait_tx, wait_rx) = tokio::sync::oneshot::channel::<()>(); + // let (wait_tx, wait_rx) = tokio::sync::oneshot::channel::<()>(); - // FIXME: If there are exactly a multiple of OPS_PER_REQUEST operations, - // then this will bug, as we sent `has_more` as true, but we don't have - // more operations to send. + // // FIXME: If there are exactly a multiple of OPS_PER_REQUEST operations, + // // then this will bug, as we sent `has_more` as true, but we don't have + // // more operations to send. 
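Note on the FIXME retained above: the old pull loop inferred `has_more` from `ops.len() == OPS_PER_REQUEST`, so when the number of pending operations is an exact multiple of the batch size it promised a further batch that would turn out empty. A sketch of the usual lookahead fix, with a plain slice standing in for the database query (a hypothetical helper, not code from this PR):

```rust
const OPS_PER_REQUEST: usize = 1000;

/// Return the next batch plus an exact `has_more` flag by checking whether
/// anything exists past the batch boundary, instead of inferring it from
/// the batch being full.
fn next_batch(pending: &[u64], cursor: usize) -> (&[u64], bool) {
    let end = (cursor + OPS_PER_REQUEST).min(pending.len());
    let has_more = pending.len() > end;
    (&pending[cursor..end], has_more)
}

fn main() {
    // Exactly two full batches: the old `len == OPS_PER_REQUEST` heuristic
    // would have reported a phantom third batch here.
    let pending: Vec<u64> = (0..2 * OPS_PER_REQUEST as u64).collect();

    let (first, more) = next_batch(&pending, 0);
    assert_eq!(first.len(), OPS_PER_REQUEST);
    assert!(more);

    let (second, more) = next_batch(&pending, OPS_PER_REQUEST);
    assert_eq!(second.len(), OPS_PER_REQUEST);
    assert!(!more);
}
```

In database terms the equivalent would be fetching `count + 1` rows and truncating to `count`, so the flag reflects actual remaining work rather than batch fullness.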
- ingest - .event_tx - .send(Event::Messages(MessagesEvent { - instance_id: library.sync.instance, - has_more: ops.len() == OPS_PER_REQUEST as usize, - messages: ops, - wait_tx: Some(wait_tx), - })) - .await - .expect("TODO: Handle ingest channel closed, so we don't loose ops"); + // ingest + // .event_tx + // .send(Event::Messages(MessagesEvent { + // device_pub_id: library.sync.device_pub_id.clone(), + // has_more: ops.len() == OPS_PER_REQUEST as usize, + // messages: ops, + // wait_tx: Some(wait_tx), + // })) + // .await + // .expect("TODO: Handle ingest channel closed, so we don't loose ops"); - wait_rx.await.unwrap() - } + // wait_rx.await.unwrap() + // } - debug!("Sync responder done"); + // debug!("Sync responder done"); - stream - .write_all(&tx::MainRequest::Done.to_bytes()) - .await - .unwrap(); - stream.flush().await.unwrap(); + // stream + // .write_all(&tx::MainRequest::Done.to_bytes()) + // .await + // .unwrap(); + // stream.flush().await.unwrap(); Ok(()) } diff --git a/core/src/p2p/sync/proto.rs b/core/src/p2p/sync/proto.rs index e586b631a..e3c8d60f7 100644 --- a/core/src/p2p/sync/proto.rs +++ b/core/src/p2p/sync/proto.rs @@ -19,25 +19,25 @@ impl SyncMessage { } } - pub fn to_bytes(&self) -> Vec { - match self { - Self::NewOperations => vec![b'N'], - } - } + // pub fn to_bytes(&self) -> Vec { + // match self { + // Self::NewOperations => vec![b'N'], + // } + // } } -#[cfg(test)] -mod tests { - use super::*; +// #[cfg(test)] +// mod tests { +// use super::*; - #[tokio::test] - async fn test_types() { - { - let original = SyncMessage::NewOperations; +// #[tokio::test] +// async fn test_types() { +// { +// let original = SyncMessage::NewOperations; - let mut cursor = std::io::Cursor::new(original.to_bytes()); - let result = SyncMessage::from_stream(&mut cursor).await.unwrap(); - assert_eq!(original, result); - } - } -} +// let mut cursor = std::io::Cursor::new(original.to_bytes()); +// let result = SyncMessage::from_stream(&mut cursor).await.unwrap(); +// assert_eq!(original, result); +// } +// } +// } diff --git a/core/src/util/debug_initializer.rs b/core/src/util/debug_initializer.rs index 8221aa77e..562ca7b07 100644 --- a/core/src/util/debug_initializer.rs +++ b/core/src/util/debug_initializer.rs @@ -130,7 +130,7 @@ impl InitConfig { lib } else { let library = library_manager - .create_with_uuid(lib.id, lib.name, lib.description, true, None, node, false) + .create_with_uuid(lib.id, lib.name, lib.description, true, None, node) .await?; let Some(lib) = library_manager.get_library(&library.id).await else { diff --git a/core/src/util/mpscrr.rs b/core/src/util/mpscrr.rs index 4c7826bea..72daf3441 100644 --- a/core/src/util/mpscrr.rs +++ b/core/src/util/mpscrr.rs @@ -230,7 +230,7 @@ impl<'a> Bomb<'a> { } } -impl<'a> Drop for Bomb<'a> { +impl Drop for Bomb<'_> { fn drop(&mut self) { self.0.store(false, Ordering::Relaxed); } diff --git a/core/src/volume/mod.rs b/core/src/volume/mod.rs index ada4d4ae3..9519d639a 100644 --- a/core/src/volume/mod.rs +++ b/core/src/volume/mod.rs @@ -2,13 +2,13 @@ use crate::{library::Library, Node}; -use sd_core_sync::Manager as SyncManager; +use sd_core_sync::SyncManager; use sd_prisma::{ - prisma::{storage_statistics, PrismaClient}, + prisma::{device, storage_statistics, PrismaClient}, prisma_sync, }; -use sd_sync::OperationFactory; -use sd_utils::{msgpack, uuid_to_bytes}; +use sd_sync::{sync_db_not_null_entry, sync_entry, OperationFactory}; +use sd_utils::uuid_to_bytes; use std::{ fmt::Display, @@ -515,99 +515,85 @@ fn compute_stats<'v>(volumes: impl 
IntoIterator) -> (u64, u64 async fn update_storage_statistics( db: &PrismaClient, sync: &SyncManager, - instance_pub_id: &Uuid, total_capacity: u64, available_capacity: u64, ) -> Result<(), VolumeError> { - let instance_pub_id = uuid_to_bytes(instance_pub_id); + let device_pub_id = sync.device_pub_id.to_db(); let storage_statistics_pub_id = db .storage_statistics() - .find_unique(storage_statistics::instance_pub_id::equals( - instance_pub_id.clone(), - )) + .find_first(vec![storage_statistics::device::is(vec![ + device::pub_id::equals(device_pub_id.clone()), + ])]) .select(storage_statistics::select!({ pub_id })) .exec() .await? .map(|s| s.pub_id); if let Some(storage_statistics_pub_id) = storage_statistics_pub_id { - sync.write_ops( - db, - ( - [ - ( - storage_statistics::total_capacity::NAME, - msgpack!(total_capacity), - ), - ( - storage_statistics::available_capacity::NAME, - msgpack!(available_capacity), - ), - ] - .into_iter() - .map(|(field, value)| { - sync.shared_update( - prisma_sync::storage_statistics::SyncId { - pub_id: storage_statistics_pub_id.clone(), - }, - field, - value, - ) - }) - .collect(), - db.storage_statistics() - .update( - storage_statistics::pub_id::equals(storage_statistics_pub_id), - vec![ - storage_statistics::total_capacity::set(total_capacity as i64), - storage_statistics::available_capacity::set(available_capacity as i64), - ], - ) - // We don't need any data here, just the id avoids receiving the entire object - // as we can't pass an empty select macro call - .select(storage_statistics::select!({ id })), + let (sync_params, db_params) = [ + sync_db_not_null_entry!(total_capacity as i64, storage_statistics::total_capacity), + sync_db_not_null_entry!( + available_capacity as i64, + storage_statistics::available_capacity ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + sync.write_op( + db, + sync.shared_update( + prisma_sync::storage_statistics::SyncId { + pub_id: storage_statistics_pub_id.clone(), + }, + sync_params, + ), + db.storage_statistics() + .update( + storage_statistics::pub_id::equals(storage_statistics_pub_id), + db_params, + ) + // We don't need any data here, just the id avoids receiving the entire object + // as we can't pass an empty select macro call + .select(storage_statistics::select!({ id })), ) .await?; } else { - let new_storage_statistics_id = uuid_to_bytes(&Uuid::new_v4()); + let new_storage_statistics_id = uuid_to_bytes(&Uuid::now_v7()); - sync.write_ops( - db, - ( - sync.shared_create( - prisma_sync::storage_statistics::SyncId { - pub_id: new_storage_statistics_id.clone(), - }, - [ - ( - storage_statistics::total_capacity::NAME, - msgpack!(total_capacity), - ), - ( - storage_statistics::available_capacity::NAME, - msgpack!(available_capacity), - ), - ( - storage_statistics::instance_pub_id::NAME, - msgpack!(instance_pub_id), - ), - ], - ), - db.storage_statistics() - .create( - new_storage_statistics_id, - vec![ - storage_statistics::total_capacity::set(total_capacity as i64), - storage_statistics::available_capacity::set(available_capacity as i64), - storage_statistics::instance_pub_id::set(Some(instance_pub_id.clone())), - ], - ) - // We don't need any data here, just the id avoids receiving the entire object - // as we can't pass an empty select macro call - .select(storage_statistics::select!({ id })), + let (sync_params, db_params) = [ + sync_db_not_null_entry!(total_capacity as i64, storage_statistics::total_capacity), + sync_db_not_null_entry!( + available_capacity as i64, + 
storage_statistics::available_capacity ), + ( + sync_entry!( + prisma_sync::device::SyncId { + pub_id: device_pub_id.clone() + }, + storage_statistics::device + ), + storage_statistics::device::connect(device::pub_id::equals(device_pub_id)), + ), + ] + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + sync.write_op( + db, + sync.shared_create( + prisma_sync::storage_statistics::SyncId { + pub_id: new_storage_statistics_id.clone(), + }, + sync_params, + ), + db.storage_statistics() + .create(new_storage_statistics_id, db_params) + // We don't need any data here, just the id avoids receiving the entire object + // as we can't pass an empty select macro call + .select(storage_statistics::select!({ id })), ) .await?; } @@ -626,21 +612,9 @@ pub fn save_storage_statistics(node: &Node) { .await .into_iter() .map(move |library: Arc| async move { - let Library { - db, - sync, - instance_uuid, - .. - } = &*library; + let Library { db, sync, .. } = &*library; - update_storage_statistics( - db, - sync, - instance_uuid, - total_capacity, - available_capacity, - ) - .await + update_storage_statistics(db, sync, total_capacity, available_capacity).await }) .collect::>() .join() diff --git a/core/src/volume/watcher.rs b/core/src/volume/watcher.rs index 4c71bade3..efc0556a1 100644 --- a/core/src/volume/watcher.rs +++ b/core/src/volume/watcher.rs @@ -29,7 +29,6 @@ pub fn spawn_volume_watcher(library: Arc) { if let Err(e) = super::update_storage_statistics( &library.db, &library.sync, - &library.instance_uuid, total_capacity, available_capacity, ) diff --git a/crates/actors/Cargo.toml b/crates/actors/Cargo.toml index 65d33cff7..172ac53f3 100644 --- a/crates/actors/Cargo.toml +++ b/crates/actors/Cargo.toml @@ -8,6 +8,7 @@ repository.workspace = true [dependencies] async-channel = { workspace = true } +async-trait = { workspace = true } futures = { workspace = true } pin-project-lite = { workspace = true } tokio = { workspace = true } diff --git a/crates/actors/src/lib.rs b/crates/actors/src/lib.rs index 9c9c263fb..0604bed52 100644 --- a/crates/actors/src/lib.rs +++ b/crates/actors/src/lib.rs @@ -29,7 +29,10 @@ use std::{ collections::HashMap, + fmt, future::{Future, IntoFuture}, + hash::Hash, + marker::PhantomData, panic::{panic_any, AssertUnwindSafe}, pin::Pin, sync::{ @@ -44,7 +47,7 @@ use async_channel as chan; use futures::FutureExt; use tokio::{ spawn, - sync::{broadcast, RwLock}, + sync::{broadcast, Mutex, RwLock}, task::JoinHandle, time::timeout, }; @@ -52,52 +55,122 @@ use tracing::{error, instrument, warn}; const ONE_MINUTE: Duration = Duration::from_secs(60); -type ActorFn = dyn Fn(Stopper) -> Pin + Send>> + Send + Sync; +pub trait ActorId: Hash + Eq + Send + Sync + Copy + fmt::Debug + fmt::Display + 'static {} -pub struct Actor { - spawn_fn: Arc, +impl ActorId for T {} + +pub trait Actor: Send + Sync + 'static { + const IDENTIFIER: Id; + + fn run(&mut self, stop: Stopper) -> impl Future + Send; +} + +mod sealed { + pub trait Sealed {} +} + +#[async_trait::async_trait] +pub trait DynActor: Send + Sync + sealed::Sealed + 'static { + async fn run(&mut self, stop: Stopper); +} + +pub trait IntoActor: Send + Sync { + fn into_actor(self) -> (Id, Box>); +} + +struct AnyActor> { + actor: A, + _marker: PhantomData, +} + +impl> sealed::Sealed for AnyActor {} + +#[async_trait::async_trait] +impl> DynActor for AnyActor { + async fn run(&mut self, stop: Stopper) { + self.actor.run(stop).await; + } +} + +impl> IntoActor for A { + fn into_actor(self) -> (Id, Box>) { + ( + A::IDENTIFIER, + Box::new(AnyActor { + 
actor: self, + _marker: PhantomData, + }), + ) + } +} + +struct ActorHandler { + actor: Arc>>>, maybe_handle: Option>, is_running: Arc, stop_tx: chan::Sender<()>, stop_rx: chan::Receiver<()>, } -pub struct Actors { +/// Actors holder, holds all actors for some generic purpose, like for cloud sync. +/// You should use an enum to identify the actors. +pub struct ActorsCollection { pub invalidate_rx: broadcast::Receiver<()>, invalidate_tx: broadcast::Sender<()>, - actors: Arc>>, + actors_map: Arc>>>, } -impl Actors { - pub async fn declare( - self: &Arc, - name: &'static str, - actor_fn: impl FnOnce(Stopper) -> Fut + Send + Sync + Clone + 'static, - autostart: bool, - ) where - Fut: Future + Send + 'static, - { - let (stop_tx, stop_rx) = chan::bounded(1); +impl ActorsCollection { + pub async fn declare(&self, actor: impl IntoActor) { + async fn inner( + this: &ActorsCollection, + identifier: Id, + actor: Box>, + ) { + let (stop_tx, stop_rx) = chan::bounded(1); - self.actors.write().await.insert( - name, - Actor { - spawn_fn: Arc::new(move |stop| Box::pin((actor_fn.clone())(stop))), - maybe_handle: None, - is_running: Arc::new(AtomicBool::new(false)), - stop_tx, - stop_rx, - }, - ); + this.actors_map.write().await.insert( + identifier, + ActorHandler { + actor: Arc::new(Mutex::new(actor)), + maybe_handle: None, + is_running: Arc::new(AtomicBool::new(false)), + stop_tx, + stop_rx, + }, + ); + } - if autostart { - self.start(name).await; + let (identifier, actor) = actor.into_actor(); + inner(self, identifier, actor).await; + } + + pub async fn declare_many_boxed( + &self, + actors: impl IntoIterator>)> + Send, + ) { + let mut actor_map = self.actors_map.write().await; + + for (id, actor) in actors { + let (stop_tx, stop_rx) = chan::bounded(1); + + actor_map.insert( + id, + ActorHandler { + actor: Arc::new(Mutex::new(actor)), + maybe_handle: None, + is_running: Arc::new(AtomicBool::new(false)), + stop_tx, + stop_rx, + }, + ); } } #[instrument(skip(self))] - pub async fn start(self: &Arc, name: &str) { - if let Some(actor) = self.actors.write().await.get_mut(name) { + pub async fn start(&self, identifier: Id) { + let mut actors_map = self.actors_map.write().await; + if let Some(actor) = actors_map.get_mut(&identifier) { if actor.is_running.load(Ordering::Acquire) { warn!("Actor already running!"); return; @@ -122,15 +195,19 @@ impl Actors { } actor.maybe_handle = Some(spawn({ - let spawn_fn = Arc::clone(&actor.spawn_fn); - let stop_actor = Stopper(actor.stop_rx.clone()); + let actor = Arc::clone(&actor.actor); async move { - if (AssertUnwindSafe((spawn_fn)(stop_actor))) - .catch_unwind() - .await - .is_err() + if (AssertUnwindSafe( + actor + .try_lock() + .expect("actors can only have a single run at a time") + .run(stop_actor), + )) + .catch_unwind() + .await + .is_err() { error!("Actor unexpectedly panicked"); } @@ -146,8 +223,9 @@ impl Actors { } #[instrument(skip(self))] - pub async fn stop(self: &Arc, name: &str) { - if let Some(actor) = self.actors.write().await.get_mut(name) { + pub async fn stop(&self, identifier: Id) { + let mut actors_map = self.actors_map.write().await; + if let Some(actor) = actors_map.get_mut(&identifier) { if !actor.is_running.load(Ordering::Acquire) { warn!("Actor already stopped!"); return; @@ -167,28 +245,43 @@ impl Actors { } } - pub async fn get_state(&self) -> HashMap { - self.actors + pub async fn get_state(&self) -> Vec<(String, bool)> { + self.actors_map .read() .await .iter() - .map(|(&name, actor)| (name.to_string(), 
actor.is_running.load(Ordering::Relaxed))) + .map(|(identifier, actor)| { + ( + identifier.to_string(), + actor.is_running.load(Ordering::Relaxed), + ) + }) .collect() } } -impl Default for Actors { +impl Default for ActorsCollection { fn default() -> Self { let (invalidate_tx, invalidate_rx) = broadcast::channel(1); Self { - actors: Arc::default(), + actors_map: Arc::default(), invalidate_rx, invalidate_tx, } } } +impl Clone for ActorsCollection { + fn clone(&self) -> Self { + Self { + actors_map: Arc::clone(&self.actors_map), + invalidate_rx: self.invalidate_rx.resubscribe(), + invalidate_tx: self.invalidate_tx.clone(), + } + } +} + pub struct Stopper(chan::Receiver<()>); impl Stopper { diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index b521ee93d..35f9dc198 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -40,6 +40,7 @@ uuid = { workspace = true, features = ["serde", "v4"] } # Note: half and ndarray version must be the same as used in ort half = { version = "2.4", features = ['num-traits'] } ndarray = "0.15" +ort-sys = '=2.0.0-rc.0' # lock sys crate to the same version as ort url = '2.5' # Microsoft does not provide a release for osx-gpu. See: https://github.com/microsoft/onnxruntime/releases diff --git a/crates/ai/src/old_image_labeler/old_actor.rs b/crates/ai/src/old_image_labeler/old_actor.rs index 257f69986..669d00ec6 100644 --- a/crates/ai/src/old_image_labeler/old_actor.rs +++ b/crates/ai/src/old_image_labeler/old_actor.rs @@ -1,6 +1,7 @@ use sd_core_prisma_helpers::file_path_for_media_processor; +use sd_core_sync::SyncManager; -use sd_prisma::prisma::{location, PrismaClient}; +use sd_prisma::prisma::{device, location, PrismaClient}; use sd_utils::error::FileIOError; use std::{ @@ -38,7 +39,7 @@ const PENDING_BATCHES_FILE: &str = "pending_image_labeler_batches.bin"; type ResumeBatchRequest = ( BatchToken, Arc, - Arc, + SyncManager, oneshot::Sender, ImageLabelerError>>, ); @@ -51,17 +52,19 @@ pub(super) struct Batch { pub(super) token: BatchToken, pub(super) location_id: location::id::Type, pub(super) location_path: PathBuf, + pub(super) device_id: device::id::Type, pub(super) file_paths: Vec, pub(super) output_tx: chan::Sender, pub(super) is_resumable: bool, pub(super) db: Arc, - pub(super) sync: Arc, + pub(super) sync: SyncManager, } #[derive(Serialize, Deserialize, Debug)] struct ResumableBatch { location_id: location::id::Type, location_path: PathBuf, + device_id: device::id::Type, file_paths: Vec, } @@ -162,13 +165,15 @@ impl OldImageLabeler { }) } + #[allow(clippy::too_many_arguments)] async fn new_batch_inner( &self, location_id: location::id::Type, location_path: PathBuf, + device_id: device::id::Type, file_paths: Vec, db: Arc, - sync: Arc, + sync: SyncManager, is_resumable: bool, ) -> (BatchToken, chan::Receiver) { let (tx, rx) = chan::bounded(usize::max(file_paths.len(), 1)); @@ -180,6 +185,7 @@ impl OldImageLabeler { token, location_id, location_path, + device_id, file_paths, output_tx: tx, is_resumable, @@ -203,14 +209,23 @@ impl OldImageLabeler { pub async fn new_batch( &self, location_id: location::id::Type, + device_id: device::id::Type, location_path: PathBuf, file_paths: Vec, db: Arc, - sync: Arc, + sync: SyncManager, ) -> chan::Receiver { - self.new_batch_inner(location_id, location_path, file_paths, db, sync, false) - .await - .1 + self.new_batch_inner( + location_id, + location_path, + device_id, + file_paths, + db, + sync, + false, + ) + .await + .1 } /// Resumable batches have lower priority than normal batches @@ -218,12 +233,21 @@ 
impl OldImageLabeler { &self, location_id: location::id::Type, location_path: PathBuf, + device_id: device::id::Type, file_paths: Vec, db: Arc, - sync: Arc, + sync: SyncManager, ) -> (BatchToken, chan::Receiver) { - self.new_batch_inner(location_id, location_path, file_paths, db, sync, true) - .await + self.new_batch_inner( + location_id, + location_path, + device_id, + file_paths, + db, + sync, + true, + ) + .await } pub async fn change_model(&self, model: Box) -> Result<(), ImageLabelerError> { @@ -291,7 +315,7 @@ impl OldImageLabeler { &self, token: BatchToken, db: Arc, - sync: Arc, + sync: SyncManager, ) -> Result, ImageLabelerError> { let (tx, rx) = oneshot::channel(); @@ -344,7 +368,7 @@ async fn actor_loop( ResumeBatch( BatchToken, Arc, - Arc, + SyncManager, oneshot::Sender, ImageLabelerError>>, ), UpdateModel( @@ -393,6 +417,7 @@ async fn actor_loop( to_resume_batches.write().await.remove(&token).map( |ResumableBatch { location_id, + device_id, location_path, file_paths, }| { @@ -403,6 +428,7 @@ async fn actor_loop( token, db, sync, + device_id, output_tx, location_id, location_path, @@ -529,6 +555,7 @@ async fn actor_loop( token, location_id, location_path, + device_id, file_paths, is_resumable, .. @@ -538,6 +565,7 @@ async fn actor_loop( ResumableBatch { location_id, location_path, + device_id, file_paths, }, )) @@ -554,6 +582,7 @@ async fn actor_loop( token, location_id, location_path, + device_id, file_paths, is_resumable, .. @@ -563,6 +592,7 @@ async fn actor_loop( ResumableBatch { location_id, location_path, + device_id, file_paths, }, )) diff --git a/crates/ai/src/old_image_labeler/process.rs b/crates/ai/src/old_image_labeler/process.rs index 125dbe21c..0232afa79 100644 --- a/crates/ai/src/old_image_labeler/process.rs +++ b/crates/ai/src/old_image_labeler/process.rs @@ -1,8 +1,9 @@ use sd_core_file_path_helper::IsolatedFilePathData; use sd_core_prisma_helpers::file_path_for_media_processor; +use sd_core_sync::SyncManager; use sd_prisma::{ - prisma::{file_path, label, label_on_object, object, PrismaClient}, + prisma::{device, file_path, label, label_on_object, object, PrismaClient}, prisma_sync, }; use sd_sync::OperationFactory; @@ -69,6 +70,7 @@ pub(super) async fn spawned_processing( token, location_id, location_path, + device_id, file_paths, output_tx, db, @@ -202,6 +204,7 @@ pub(super) async fn spawned_processing( let ids = ( file_path.id, file_path.object.as_ref().expect("already checked above").id, + device_id, ); if output_tx.is_closed() { @@ -250,6 +253,7 @@ pub(super) async fn spawned_processing( token, location_id, location_path, + device_id, file_paths: on_flight .into_values() .chain(queue.into_iter().map(|(file_path, _, _)| file_path)) @@ -292,7 +296,7 @@ pub(super) async fn spawned_processing( #[allow(clippy::too_many_arguments)] async fn spawned_process_single_file( model_and_session: Arc>, - (file_path_id, object_id): (file_path::id::Type, object::id::Type), + (file_path_id, object_id, device_id): (file_path::id::Type, object::id::Type, device::id::Type), path: PathBuf, format: ImageFormat, (output_tx, completed_tx): ( @@ -300,7 +304,7 @@ async fn spawned_process_single_file( chan::Sender, ), db: Arc, - sync: Arc, + sync: SyncManager, _permit: OwnedSemaphorePermit, ) { let image = @@ -350,10 +354,11 @@ async fn spawned_process_single_file( } }; - let (has_new_labels, result) = match assign_labels(object_id, labels, &db, &sync).await { - Ok(has_new_labels) => (has_new_labels, Ok(())), - Err(e) => (false, Err(e)), - }; + let (has_new_labels, result) = + match 
assign_labels(object_id, device_id, labels, &db, &sync).await { + Ok(has_new_labels) => (has_new_labels, Ok(())), + Err(e) => (false, Err(e)), + }; if output_tx .send(LabelerOutput { @@ -396,9 +401,10 @@ async fn extract_file_data( pub async fn assign_labels( object_id: object::id::Type, + device_id: device::id::Type, mut labels: HashSet, db: &PrismaClient, - sync: &sd_core_sync::Manager, + sync: &SyncManager, ) -> Result { let object = db .object() @@ -432,7 +438,7 @@ pub async fn assign_labels( let db_params = labels .into_iter() .map(|name| { - sync_params.extend(sync.shared_create( + sync_params.push(sync.shared_create( prisma_sync::label::SyncId { name: name.clone() }, [(label::date_created::NAME, msgpack!(&date_created))], )); @@ -455,37 +461,47 @@ pub async fn assign_labels( let mut sync_params = Vec::with_capacity(labels_ids.len() * 2); - let db_params: Vec<_> = labels_ids - .into_iter() - .map(|(label_id, name)| { - sync_params.extend(sync.relation_create( - prisma_sync::label_on_object::SyncId { - label: prisma_sync::label::SyncId { name }, - object: prisma_sync::object::SyncId { - pub_id: object.pub_id.clone(), + if !labels_ids.is_empty() { + let db_params: Vec<_> = labels_ids + .into_iter() + .map(|(label_id, name)| { + sync_params.push(sync.relation_create( + prisma_sync::label_on_object::SyncId { + label: prisma_sync::label::SyncId { name }, + object: prisma_sync::object::SyncId { + pub_id: object.pub_id.clone(), + }, }, - }, - [], - )); + [( + label_on_object::device::NAME, + msgpack!(prisma_sync::device::SyncId { + pub_id: sync.device_pub_id.to_db(), + }), + )], + )); - label_on_object::create_unchecked( - label_id, - object_id, - vec![label_on_object::date_created::set(date_created)], - ) - }) - .collect(); + label_on_object::create_unchecked( + label_id, + object_id, + vec![ + label_on_object::date_created::set(date_created), + label_on_object::device_id::set(Some(device_id)), + ], + ) + }) + .collect(); - sync.write_ops( - db, - ( - sync_params, - db.label_on_object() - .create_many(db_params) - .skip_duplicates(), - ), - ) - .await?; + sync.write_ops( + db, + ( + sync_params, + db.label_on_object() + .create_many(db_params) + .skip_duplicates(), + ), + ) + .await?; + } Ok(has_new_labels) } diff --git a/crates/cloud-api/Cargo.toml b/crates/cloud-api/Cargo.toml deleted file mode 100644 index 491b2fe7c..000000000 --- a/crates/cloud-api/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "sd-cloud-api" -version = "0.1.0" - -edition.workspace = true -license.workspace = true -repository.workspace = true - -[dependencies] -# Spacedrive Sub-crates -sd-p2p = { path = "../p2p" } - -# Workspace dependencies -reqwest = { workspace = true, features = ["native-tls-vendored"] } -rspc = { workspace = true } -serde = { workspace = true } -serde_json = { workspace = true } -specta = { workspace = true } -thiserror = { workspace = true } -tracing = { workspace = true } -uuid = { workspace = true } diff --git a/crates/cloud-api/src/auth.rs b/crates/cloud-api/src/auth.rs deleted file mode 100644 index f8d879641..000000000 --- a/crates/cloud-api/src/auth.rs +++ /dev/null @@ -1,17 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct OAuthToken { - pub access_token: String, - pub refresh_token: String, - pub token_type: String, - pub expires_in: i32, -} - -impl OAuthToken { - pub fn to_header(&self) -> String { - format!("{} {}", self.token_type, self.access_token) - } -} - -pub const DEVICE_CODE_URN: &str = 
"urn:ietf:params:oauth:grant-type:device_code"; diff --git a/crates/cloud-api/src/lib.rs b/crates/cloud-api/src/lib.rs deleted file mode 100644 index b505dfa23..000000000 --- a/crates/cloud-api/src/lib.rs +++ /dev/null @@ -1,635 +0,0 @@ -pub mod auth; - -use std::{collections::HashMap, future::Future, sync::Arc}; - -use auth::OAuthToken; -use sd_p2p::RemoteIdentity; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use specta::Type; -use uuid::Uuid; - -pub struct RequestConfig { - pub client: reqwest::Client, - pub api_url: String, - pub auth_token: Option, -} - -pub trait RequestConfigProvider { - fn get_request_config(self: &Arc) -> impl Future + Send; -} - -#[derive(thiserror::Error, Debug)] -#[error("{0}")] -pub struct Error(String); - -impl From for rspc::Error { - fn from(e: Error) -> rspc::Error { - rspc::Error::new(rspc::ErrorCode::InternalServerError, e.0) - } -} - -#[derive(Serialize, Deserialize, Debug, Type)] -#[serde(rename_all = "camelCase")] -#[specta(rename = "CloudLibrary")] -pub struct Library { - pub id: String, - pub uuid: Uuid, - pub name: String, - pub instances: Vec, - pub owner_id: String, -} - -#[derive(Serialize, Deserialize, Debug, Type)] -#[serde(rename_all = "camelCase")] -#[specta(rename = "CloudInstance")] -pub struct Instance { - pub id: String, - pub uuid: Uuid, - pub identity: RemoteIdentity, - #[serde(rename = "nodeId")] - pub node_id: Uuid, - pub node_remote_identity: String, - pub metadata: HashMap, -} - -#[derive(Serialize, Deserialize, Debug, Type)] -#[serde(rename_all = "camelCase")] -#[specta(rename = "CloudMessageCollection")] -pub struct MessageCollection { - pub instance_uuid: Uuid, - pub start_time: String, - pub end_time: String, - pub contents: String, -} - -trait WithAuth { - fn with_auth(self, token: OAuthToken) -> Self; -} - -impl WithAuth for reqwest::RequestBuilder { - fn with_auth(self, token: OAuthToken) -> Self { - self.header( - "authorization", - format!("{} {}", token.token_type, token.access_token), - ) - } -} - -pub mod feedback { - use super::*; - - pub use send::exec as send; - pub mod send { - use super::*; - - pub async fn exec(config: RequestConfig, message: String, emoji: u8) -> Result<(), Error> { - let mut req = config - .client - .post(format!("{}/api/v1/feedback", config.api_url)) - .json(&json!({ - "message": message, - "emoji": emoji, - })); - - if let Some(auth_token) = config.auth_token { - req = req.with_auth(auth_token); - } - - req.send() - .await - .and_then(|r| r.error_for_status()) - .map_err(|e| Error(e.to_string()))?; - - Ok(()) - } - } -} - -pub mod user { - use super::*; - - pub use me::exec as me; - pub mod me { - use super::*; - - #[derive(Serialize, Deserialize, Type)] - #[specta(inline)] - pub struct Response { - id: String, - email: String, - } - - pub async fn exec(config: RequestConfig) -> Result { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .get(format!("{}/api/v1/user/me", config.api_url)) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string()))? 
- .json() - .await - .map_err(|e| Error(e.to_string())) - } - } -} - -pub mod library { - use super::*; - - pub use get::exec as get; - pub mod get { - use super::*; - - pub async fn exec(config: RequestConfig, library_id: Uuid) -> Result { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .get(format!( - "{}/api/v1/libraries/{}", - config.api_url, library_id - )) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string()))? - .json() - .await - .map_err(|e| Error(e.to_string())) - } - - pub type Response = Option; - } - - pub use list::exec as list; - pub mod list { - use super::*; - - pub async fn exec(config: RequestConfig) -> Result { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .get(format!("{}/api/v1/libraries", config.api_url)) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string()))? - .json() - .await - .map_err(|e| Error(e.to_string())) - } - - pub type Response = Vec; - } - - pub use create::exec as create; - pub mod create { - use super::*; - - #[derive(Debug, Deserialize)] - pub struct CreateResult { - pub id: String, - } - - #[allow(clippy::too_many_arguments)] - pub async fn exec( - config: RequestConfig, - library_id: Uuid, - name: &str, - instance_uuid: Uuid, - instance_identity: RemoteIdentity, - node_id: Uuid, - node_remote_identity: RemoteIdentity, - metadata: &HashMap, - ) -> Result { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .post(format!( - "{}/api/v1/libraries/{}", - config.api_url, library_id - )) - .json(&json!({ - "name":name, - "instanceUuid": instance_uuid, - "instanceIdentity": instance_identity, - "nodeId": node_id, - "nodeRemoteIdentity": node_remote_identity, - "metadata": metadata, - })) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string()))? 
- .json() - .await - .map_err(|e| Error(e.to_string())) - } - } - - pub use update::exec as update; - pub mod update { - use super::*; - - pub async fn exec( - config: RequestConfig, - library_id: Uuid, - name: Option, - ) -> Result<(), Error> { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .patch(format!( - "{}/api/v1/libraries/{}", - config.api_url, library_id - )) - .json(&json!({ - "name":name - })) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string())) - .map(|_| ()) - } - } - - pub use update_instance::exec as update_instance; - pub mod update_instance { - use super::*; - - pub async fn exec( - config: RequestConfig, - library_id: Uuid, - instance_id: Uuid, - node_id: Option, - node_remote_identity: Option, - metadata: Option>, - ) -> Result<(), Error> { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .patch(format!( - "{}/api/v1/libraries/{}/{}", - config.api_url, library_id, instance_id - )) - .json(&json!({ - "nodeId": node_id, - "nodeRemoteIdentity": node_remote_identity, - "metadata": metadata, - })) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string())) - .map(|_| ()) - } - } - - pub use join::exec as join; - pub mod join { - use super::*; - - pub async fn exec( - config: RequestConfig, - library_id: Uuid, - instance_uuid: Uuid, - instance_identity: RemoteIdentity, - node_id: Uuid, - node_remote_identity: RemoteIdentity, - metadata: HashMap, - ) -> Result, Error> { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .post(format!( - "{}/api/v1/libraries/{library_id}/instances/{instance_uuid}", - config.api_url - )) - .json(&json!({ - "instanceIdentity": instance_identity, - "nodeId": node_id, - "nodeRemoteIdentity": node_remote_identity, - "metadata": metadata, - })) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string()))? - .json() - .await - .map_err(|e| Error(e.to_string())) - } - } - - pub mod message_collections { - use super::*; - - pub use get::exec as get; - pub mod get { - use super::*; - use tracing::debug; - - #[derive(Serialize)] - #[serde(rename_all = "camelCase")] - pub struct InstanceTimestamp { - pub instance_uuid: Uuid, - pub from_time: String, - } - - pub async fn exec( - config: RequestConfig, - library_id: Uuid, - this_instance_uuid: Uuid, - timestamps: Vec, - ) -> Result { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - let res = config - .client - .post(format!( - "{}/api/v1/libraries/{}/messageCollections/get", - config.api_url, library_id - )) - .json(&json!({ - "instanceUuid": this_instance_uuid, - "timestamps": timestamps - })) - .with_auth(auth_token) - .send() - .await; - - debug!("get message collections response: {:?}", res); - - match res { - Ok(response) => { - let status = response.status(); - let body = response.text().await.map_err(|e| Error(e.to_string()))?; - debug!("Response status: {}", status); - debug!("Response body: {}", body); - - // Attempt to parse the body as JSON - match serde_json::from_str::(&body) { - Ok(json) => Ok(json), - Err(e) => Err(Error(format!( - "error decoding response body: {}. 
Body: {}", - e, body - ))), - } - } - Err(e) => Err(Error(e.to_string())), - } - } - - pub type Response = Vec; - } - - pub use request_add::exec as request_add; - pub mod request_add { - use super::*; - use tracing::debug; - - #[derive(Deserialize, Debug)] - #[serde(rename_all = "camelCase")] - pub struct RequestAdd { - pub instance_uuid: Uuid, - pub from_time: Option, - // mutex key on the instance - pub key: String, - } - - pub async fn exec( - config: RequestConfig, - library_id: Uuid, - instances: Vec, - ) -> Result { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - let instances = instances - .into_iter() - .map(|i| json!({"instanceUuid": i })) - .collect::>(); - - let res = config - .client - .post(format!( - "{}/api/v1/libraries/{}/messageCollections/requestAdd", - config.api_url, library_id - )) - .json(&json!({ "instances": instances })) - .with_auth(auth_token) - .send() - .await; - - debug!("request add response: {:?}", res); - - match res { - Ok(response) => { - let status = response.status(); - let body = response.text().await.map_err(|e| Error(e.to_string()))?; - debug!("Response status: {}", status); - debug!("Response body: {}", body); - - // Attempt to parse the body as JSON - match serde_json::from_str::(&body) { - Ok(json) => Ok(json), - Err(e) => Err(Error(format!( - "error decoding response body: {}. Body: {}", - e, body - ))), - } - } - Err(e) => Err(Error(e.to_string())), - } - } - - pub type Response = Vec; - } - - pub use do_add::exec as do_add; - pub mod do_add { - use super::*; - - #[derive(Serialize, Debug)] - #[serde(rename_all = "camelCase")] - pub struct Input { - pub uuid: Uuid, - pub key: String, - pub start_time: String, - pub end_time: String, - pub contents: String, - pub ops_count: usize, - } - - pub async fn exec( - config: RequestConfig, - library_id: Uuid, - instances: Vec, - ) -> Result<(), Error> { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .post(format!( - "{}/api/v1/libraries/{}/messageCollections/doAdd", - config.api_url, library_id - )) - .json(&json!({ "instances": instances })) - .with_auth(auth_token) - .send() - .await - .and_then(|r| r.error_for_status()) - .map_err(|e| Error(e.to_string()))?; - - Ok(()) - } - } - } -} - -#[derive(Type, Serialize, Deserialize)] -pub struct CloudLocation { - id: String, - name: String, -} - -pub mod locations { - use super::*; - - pub use list::exec as list; - pub mod list { - use super::*; - - pub async fn exec(config: RequestConfig) -> Result { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .get(format!("{}/api/v1/locations", config.api_url)) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string()))? - .json() - .await - .map_err(|e| Error(e.to_string())) - } - - pub type Response = Vec; - } - - pub use create::exec as create; - pub mod create { - use super::*; - - pub async fn exec(config: RequestConfig, name: String) -> Result { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .post(format!("{}/api/v1/locations", config.api_url)) - .json(&json!({ - "name": name, - })) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string()))? 
- .json() - .await - .map_err(|e| Error(e.to_string())) - } - - pub type Response = CloudLocation; - } - - pub use remove::exec as remove; - pub mod remove { - use super::*; - - pub async fn exec(config: RequestConfig, id: String) -> Result { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .post(format!("{}/api/v1/locations/delete", config.api_url)) - .json(&json!({ - "id": id, - })) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string()))? - .json() - .await - .map_err(|e| Error(e.to_string())) - } - - pub type Response = CloudLocation; - } - - pub use authorize::exec as authorize; - pub mod authorize { - use super::*; - - pub async fn exec(config: RequestConfig, id: String) -> Result { - let Some(auth_token) = config.auth_token else { - return Err(Error("Authentication required".to_string())); - }; - - config - .client - .post(format!("{}/api/v1/locations/authorize", config.api_url)) - .json(&json!({ "id": id })) - .with_auth(auth_token) - .send() - .await - .map_err(|e| Error(e.to_string()))? - .json() - .await - .map_err(|e| Error(e.to_string())) - } - - #[derive(Debug, Clone, Type, Deserialize)] - pub struct Response { - pub access_key_id: String, - pub secret_access_key: String, - pub session_token: String, - } - } -} diff --git a/crates/crypto/Cargo.toml b/crates/crypto/Cargo.toml index 3305a1222..376769023 100644 --- a/crates/crypto/Cargo.toml +++ b/crates/crypto/Cargo.toml @@ -24,6 +24,7 @@ rand = { workspace = true } serde = { workspace = true, features = ["derive"] } thiserror = { workspace = true } tokio = { workspace = true, features = ["io-util", "macros", "rt-multi-thread", "sync"] } +zeroize = { workspace = true, features = ["derive"] } # External dependencies aead = { version = "0.6.0-rc.0", default-features = false, features = ["stream"] } @@ -35,7 +36,8 @@ rand_chacha = "0.9.0-alpha.2" rand_core = "0.9.0-alpha.2" serdect = "0.3.0-pre.0" typenum = "1.17" -zeroize = { version = "1.7", features = ["aarch64", "derive"] } + +old-rand-core = { package = "rand_core", version = "0.6.4" } [dev-dependencies] paste = "1.0" diff --git a/crates/crypto/src/cloud/decrypt.rs b/crates/crypto/src/cloud/decrypt.rs index 1ba41f35b..94913f64b 100644 --- a/crates/crypto/src/cloud/decrypt.rs +++ b/crates/crypto/src/cloud/decrypt.rs @@ -1,5 +1,5 @@ use crate::{ - primitives::{EncryptedBlock, StreamNonce}, + primitives::{EncryptedBlock, EncryptedBlockRef, StreamNonce}, Error, }; @@ -12,7 +12,8 @@ use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader use super::secret_key::SecretKey; pub trait OneShotDecryption { - fn decrypt(&self, cipher_text: &EncryptedBlock) -> Result, Error>; + fn decrypt(&self, cipher_text: EncryptedBlockRef<'_>) -> Result, Error>; + fn decrypt_owned(&self, cipher_text: &EncryptedBlock) -> Result, Error>; } pub trait StreamDecryption { @@ -26,6 +27,15 @@ pub trait StreamDecryption { impl OneShotDecryption for SecretKey { fn decrypt( + &self, + EncryptedBlockRef { nonce, cipher_text }: EncryptedBlockRef<'_>, + ) -> Result, Error> { + XChaCha20Poly1305::new(&self.0) + .decrypt(nonce, cipher_text) + .map_err(|aead::Error| Error::Decrypt) + } + + fn decrypt_owned( &self, EncryptedBlock { nonce, cipher_text }: &EncryptedBlock, ) -> Result, Error> { diff --git a/crates/crypto/src/cloud/encrypt.rs b/crates/crypto/src/cloud/encrypt.rs index efdcd670a..096c8c928 100644 --- a/crates/crypto/src/cloud/encrypt.rs +++ 
b/crates/crypto/src/cloud/encrypt.rs @@ -1,11 +1,11 @@ use crate::{ - primitives::{EncryptedBlock, StreamNonce}, + primitives::{EncryptedBlock, OneShotNonce, StreamNonce}, Error, }; use aead::{stream::EncryptorLE31, Aead, KeyInit}; use async_stream::stream; -use chacha20poly1305::{XChaCha20Poly1305, XNonce}; +use chacha20poly1305::{Tag, XChaCha20Poly1305, XNonce}; use futures::Stream; use rand::CryptoRng; use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader}; @@ -14,6 +14,10 @@ use super::secret_key::SecretKey; pub trait OneShotEncryption { fn encrypt(&self, plaintext: &[u8], rng: &mut impl CryptoRng) -> Result; + + fn cipher_text_size(&self, plain_text_size: usize) -> usize { + size_of::() + plain_text_size + size_of::() + } } pub trait StreamEncryption { @@ -25,6 +29,13 @@ pub trait StreamEncryption { StreamNonce, impl Stream, Error>> + Send, ); + + fn cipher_text_size(&self, plain_text_size: usize) -> usize { + size_of::() + + (plain_text_size / EncryptedBlock::PLAIN_TEXT_SIZE * EncryptedBlock::CIPHER_TEXT_SIZE) + + plain_text_size % EncryptedBlock::PLAIN_TEXT_SIZE + + size_of::() + } } impl OneShotEncryption for SecretKey { diff --git a/crates/crypto/src/cloud/mod.rs b/crates/crypto/src/cloud/mod.rs index 4a09a47d8..6fc76f574 100644 --- a/crates/crypto/src/cloud/mod.rs +++ b/crates/crypto/src/cloud/mod.rs @@ -1,3 +1,7 @@ pub mod decrypt; pub mod encrypt; pub mod secret_key; + +pub use decrypt::{OneShotDecryption, StreamDecryption}; +pub use encrypt::{OneShotEncryption, StreamEncryption}; +pub use secret_key::SecretKey; diff --git a/crates/crypto/src/cloud/secret_key.rs b/crates/crypto/src/cloud/secret_key.rs index 9fbf78b21..2477684ad 100644 --- a/crates/crypto/src/cloud/secret_key.rs +++ b/crates/crypto/src/cloud/secret_key.rs @@ -1,6 +1,7 @@ use crate::{ ct::{Choice, ConstantTimeEq, ConstantTimeEqNull}, rng::CryptoRng, + Error, }; use std::fmt; @@ -9,7 +10,7 @@ use aead::array::Array; use blake3::{Hash, Hasher}; use generic_array::GenericArray; use serde::{Deserialize, Serialize}; -use typenum::consts::U32; +use typenum::{consts::U32, U64}; use zeroize::{Zeroize, ZeroizeOnDrop}; /// This should be used for encrypting and decrypting data. 
@@ -30,8 +31,8 @@ impl fmt::Debug for SecretKey { impl SecretKey { #[inline] #[must_use] - pub fn new(v: impl Into>) -> Self { - Self(v.into()) + pub const fn new(v: Array) -> Self { + Self(v) } #[inline] @@ -75,6 +76,12 @@ impl Serialize for SecretKey { } } +impl AsRef<[u8]> for SecretKey { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + impl<'de> Deserialize<'de> for SecretKey { fn deserialize(deserializer: D) -> Result where @@ -82,7 +89,7 @@ impl<'de> Deserialize<'de> for SecretKey { { let mut buf = [0u8; 32]; serdect::array::deserialize_hex_or_bin(&mut buf, deserializer)?; - Ok(Self::new(buf)) + Ok(Self::new(buf.into())) } } @@ -92,6 +99,35 @@ impl From<&SecretKey> for Array { } } +impl From<&SecretKey> for Vec { + fn from(SecretKey(key): &SecretKey) -> Self { + key.to_vec() + } +} + +impl From for Vec { + fn from(SecretKey(key): SecretKey) -> Self { + key.to_vec() + } +} + +impl TryFrom<&[u8]> for SecretKey { + type Error = Error; + + fn try_from(key: &[u8]) -> Result { + if key.len() != 32 { + return Err(Error::InvalidKeySize(key.len())); + } + + Ok(Self(Array([ + key[0], key[1], key[2], key[3], key[4], key[5], key[6], key[7], key[8], key[9], + key[10], key[11], key[12], key[13], key[14], key[15], key[16], key[17], key[18], + key[19], key[20], key[21], key[22], key[23], key[24], key[25], key[26], key[27], + key[28], key[29], key[30], key[31], + ]))) + } +} + impl From> for SecretKey { fn from(key: GenericArray) -> Self { Self(Array([ @@ -103,6 +139,18 @@ impl From> for SecretKey { } } +/// We take only the first 32 bytes of the key, since the rest doesn't fit +impl From> for SecretKey { + fn from(key: GenericArray) -> Self { + Self(Array([ + key[0], key[1], key[2], key[3], key[4], key[5], key[6], key[7], key[8], key[9], + key[10], key[11], key[12], key[13], key[14], key[15], key[16], key[17], key[18], + key[19], key[20], key[21], key[22], key[23], key[24], key[25], key[26], key[27], + key[28], key[29], key[30], key[31], + ])) + } +} + #[cfg(test)] mod tests { use std::pin::pin; @@ -125,7 +173,33 @@ mod tests { let key = SecretKey::generate(&mut rng); let encrypted_block = key.encrypt(message, &mut rng).unwrap(); - let decrypted_message = key.decrypt(&encrypted_block).unwrap(); + let decrypted_message = key.decrypt_owned(&encrypted_block).unwrap(); + + assert_eq!(message, decrypted_message.as_slice()); + } + + #[test] + fn one_shot_ref_test() { + use super::super::{decrypt::OneShotDecryption, encrypt::OneShotEncryption}; + let mut rng = CryptoRng::new().unwrap(); + + let message = b"Eu queria um apartamento no Guarujah; \ + Mas o melhor que eu consegui foi um barraco em Itaquah."; + + let key = SecretKey::generate(&mut rng); + + let EncryptedBlock { nonce, cipher_text } = key.encrypt(message, &mut rng).unwrap(); + + let mut bytes = Vec::with_capacity(nonce.len() + cipher_text.len()); + bytes.extend_from_slice(nonce.as_slice()); + bytes.extend(cipher_text); + + assert_eq!( + bytes.len(), + OneShotEncryption::cipher_text_size(&key, message.len()) + ); + + let decrypted_message = key.decrypt(bytes.as_slice().into()).unwrap(); assert_eq!(message, decrypted_message.as_slice()); } diff --git a/crates/crypto/src/crypto/mod.rs b/crates/crypto/src/crypto/mod.rs deleted file mode 100644 index 6ab19cca1..000000000 --- a/crates/crypto/src/crypto/mod.rs +++ /dev/null @@ -1,821 +0,0 @@ -//! This module contains all encryption and decryption items. These are used throughout the crate for all encryption/decryption needs. 
- -mod stream; - -pub use self::stream::{Decryptor, Encryptor}; - -#[cfg(test)] -mod tests { - use std::io::Cursor; - - use crate::{ - crypto::{Decryptor, Encryptor}, - primitives::{ - AAD_LEN, AEAD_TAG_LEN, AES_256_GCM_SIV_NONCE_LEN, BLOCK_LEN, KEY_LEN, - XCHACHA20_POLY1305_NONCE_LEN, - }, - rng::CryptoRng, - types::{Aad, Algorithm, EncryptedKey, Key, Nonce}, - }; - - // const KEY: Key = Key::new([0x23; KEY_LEN]); - - const XCHACHA20_POLY1305_NONCE: Nonce = - Nonce::XChaCha20Poly1305([0xE9; XCHACHA20_POLY1305_NONCE_LEN]); - - const AES_256_GCM_SIV_NONCE: Nonce = Nonce::Aes256GcmSiv([0xE9; AES_256_GCM_SIV_NONCE_LEN]); - - const PLAINTEXT: [u8; 32] = [0x5A; 32]; - // const PLAINTEXT_KEY: Key = Key::new([1u8; KEY_LEN]); - - const AAD: Aad = Aad::Standard([0x92; AAD_LEN]); - - // for the `const` arrays below, [0] is without AAD, [1] is with AAD - - const AES_256_GCM_SIV_BYTES_EXPECTED: [[u8; 48]; 2] = [ - [ - 41, 231, 183, 92, 73, 104, 69, 207, 245, 250, 21, 50, 145, 41, 104, 165, 130, 59, 70, - 185, 65, 77, 215, 15, 131, 214, 183, 47, 166, 223, 185, 181, 117, 138, 62, 204, 246, - 227, 198, 32, 132, 5, 97, 120, 15, 70, 229, 218, - ], - [ - 3, 180, 75, 64, 231, 67, 228, 189, 149, 69, 47, 83, 8, 214, 103, 12, 21, 11, 39, 108, - 7, 142, 10, 169, 85, 163, 76, 53, 53, 69, 160, 134, 2, 87, 72, 121, 75, 186, 102, 176, - 163, 170, 81, 101, 242, 237, 173, 133, - ], - ]; - - const XCHACHA20_POLY1305_BYTES_EXPECTED: [[u8; 48]; 2] = [ - [ - 35, 174, 252, 59, 215, 65, 5, 237, 198, 2, 51, 72, 239, 88, 36, 177, 136, 252, 64, 157, - 141, 53, 138, 98, 185, 2, 75, 173, 253, 99, 133, 207, 145, 54, 100, 51, 44, 230, 60, 5, - 157, 70, 110, 145, 166, 41, 215, 95, - ], - [ - 35, 174, 252, 59, 215, 65, 5, 237, 198, 2, 51, 72, 239, 88, 36, 177, 136, 252, 64, 157, - 141, 53, 138, 98, 185, 2, 75, 173, 253, 99, 133, 207, 125, 139, 247, 158, 207, 216, 60, - 114, 72, 44, 6, 212, 233, 141, 251, 239, - ], - ]; - - const XCHACHA20_POLY1305_ENCRYPTED_KEY: EncryptedKey = EncryptedKey::new( - [ - 120, 245, 167, 96, 140, 26, 94, 182, 157, 89, 104, 19, 180, 3, 127, 234, 211, 167, 27, - 198, 214, 110, 209, 57, 226, 89, 16, 246, 166, 56, 222, 148, 40, 198, 237, 205, 45, 49, - 205, 18, 69, 102, 16, 78, 199, 141, 246, 165, - ], - XCHACHA20_POLY1305_NONCE, - ); - - const AES_256_GCM_ENCRYPTED_KEY: EncryptedKey = EncryptedKey::new( - [ - 227, 231, 27, 182, 122, 118, 64, 35, 125, 176, 152, 244, 156, 26, 234, 96, 178, 121, - 73, 213, 228, 189, 45, 152, 189, 68, 214, 187, 123, 182, 91, 83, 216, 50, 174, 13, 157, - 121, 165, 129, 227, 220, 139, 166, 9, 71, 215, 145, - ], - AES_256_GCM_SIV_NONCE, - ); - - #[test] - fn aes_256_gcm_siv_encrypt_bytes() { - let output = Encryptor::encrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - &PLAINTEXT, - Aad::Null, - ) - .unwrap(); - - assert_eq!(output, AES_256_GCM_SIV_BYTES_EXPECTED[0]); - } - - #[test] - fn aes_256_gcm_siv_encrypt_bytes_with_aad() { - let output = Encryptor::encrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - &PLAINTEXT, - AAD, - ) - .unwrap(); - - assert_eq!(output, AES_256_GCM_SIV_BYTES_EXPECTED[1]); - } - - #[test] - fn aes_256_gcm_siv_decrypt_bytes() { - let output = Decryptor::decrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - &AES_256_GCM_SIV_BYTES_EXPECTED[0], - Aad::Null, - ) - .unwrap(); - - assert_eq!(output.expose(), &PLAINTEXT); - } - - #[test] - fn aes_256_gcm_siv_decrypt_bytes_with_aad() { - let output = 
Decryptor::decrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - &AES_256_GCM_SIV_BYTES_EXPECTED[1], - AAD, - ) - .unwrap(); - - assert_eq!(output.expose(), &PLAINTEXT); - } - - #[test] - fn aes_256_gcm_siv_encrypt_key() { - let output = Encryptor::encrypt_key( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - &Key::new([1u8; KEY_LEN]), - Aad::Null, - ) - .unwrap(); - - assert_eq!(output, AES_256_GCM_ENCRYPTED_KEY); - } - - #[test] - fn aes_256_gcm_siv_decrypt_key() { - let output = Decryptor::decrypt_key( - &Key::new([0x23; KEY_LEN]), - Algorithm::Aes256GcmSiv, - &AES_256_GCM_ENCRYPTED_KEY, - Aad::Null, - ) - .unwrap(); - - assert_eq!(output, Key::new([1u8; KEY_LEN])); - } - - #[test] - fn aes_256_gcm_siv_encrypt_tiny() { - let output = Encryptor::encrypt_tiny( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - &PLAINTEXT, - Aad::Null, - ) - .unwrap(); - - assert_eq!(output, AES_256_GCM_SIV_BYTES_EXPECTED[0]); - } - - #[test] - fn aes_256_gcm_siv_decrypt_tiny() { - let output = Decryptor::decrypt_tiny( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - &AES_256_GCM_SIV_BYTES_EXPECTED[0], - Aad::Null, - ) - .unwrap(); - - assert_eq!(output.expose(), &PLAINTEXT); - } - - #[test] - #[should_panic(expected = "LengthMismatch")] - fn aes_256_gcm_siv_encrypt_tiny_too_large() { - Encryptor::encrypt_tiny( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - &vec![0u8; BLOCK_LEN], - Aad::Null, - ) - .unwrap(); - } - - #[test] - #[should_panic(expected = "LengthMismatch")] - fn aes_256_gcm_siv_decrypt_tiny_too_large() { - Decryptor::decrypt_tiny( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - &vec![0u8; BLOCK_LEN + AEAD_TAG_LEN], - Aad::Null, - ) - .unwrap(); - } - - #[test] - #[should_panic(expected = "Decrypt")] - fn aes_256_gcm_siv_decrypt_bytes_missing_aad() { - Decryptor::decrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - &AES_256_GCM_SIV_BYTES_EXPECTED[1], - Aad::Null, - ) - .unwrap(); - } - - #[test] - #[cfg_attr(miri, ignore)] - fn aes_256_gcm_siv_encrypt_and_decrypt_5_blocks() { - let buf = CryptoRng::generate_vec(BLOCK_LEN * 5); - - let mut reader = Cursor::new(&buf); - let mut writer = Cursor::new(Vec::new()); - - let encryptor = Encryptor::new( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - ) - .unwrap(); - - encryptor - .encrypt_streams(&mut reader, &mut writer, Aad::Null) - .unwrap(); - - let mut reader = Cursor::new(writer.into_inner()); - let mut writer = Cursor::new(Vec::new()); - - let decryptor = Decryptor::new( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - ) - .unwrap(); - - decryptor - .decrypt_streams(&mut reader, &mut writer, Aad::Null) - .unwrap(); - - let output = writer.into_inner(); - - assert_eq!(buf, output); - } - - #[test] - #[ignore] - fn aes_256_gcm_siv_encrypt_and_decrypt_128mib() { - let buf = vec![1u8; BLOCK_LEN * 128].into_boxed_slice(); - - let mut reader = Cursor::new(&buf); - let mut writer = Cursor::new(Vec::new()); - - let encryptor = Encryptor::new( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - ) - .unwrap(); - - encryptor - .encrypt_streams(&mut reader, &mut writer, Aad::Null) - .unwrap(); - - let mut reader = Cursor::new(writer.into_inner()); - let mut writer = 
Cursor::new(Vec::new()); - - let decryptor = Decryptor::new( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - ) - .unwrap(); - - decryptor - .decrypt_streams(&mut reader, &mut writer, Aad::Null) - .unwrap(); - - let output = writer.into_inner().into_boxed_slice(); - - assert_eq!(buf, output); - } - - #[test] - #[cfg_attr(miri, ignore)] - fn aes_256_gcm_siv_encrypt_and_decrypt_5_blocks_with_aad() { - let buf = CryptoRng::generate_vec(BLOCK_LEN * 5); - - let mut reader = Cursor::new(&buf); - let mut writer = Cursor::new(Vec::new()); - - let encryptor = Encryptor::new( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - ) - .unwrap(); - - encryptor - .encrypt_streams(&mut reader, &mut writer, AAD) - .unwrap(); - - let mut reader = Cursor::new(writer.into_inner()); - let mut writer = Cursor::new(Vec::new()); - - let decryptor = Decryptor::new( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - ) - .unwrap(); - - decryptor - .decrypt_streams(&mut reader, &mut writer, AAD) - .unwrap(); - - let output = writer.into_inner(); - - assert_eq!(buf, output); - } - - #[tokio::test] - #[cfg_attr(miri, ignore)] - async fn aes_256_gcm_siv_encrypt_and_decrypt_5_blocks_async() { - let buf = CryptoRng::generate_vec(BLOCK_LEN * 5); - - let mut reader = Cursor::new(&buf); - let mut writer = Cursor::new(Vec::new()); - - let encryptor = Encryptor::new( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - ) - .unwrap(); - - encryptor - .encrypt_streams_async(&mut reader, &mut writer, Aad::Null) - .await - .unwrap(); - - let mut reader = Cursor::new(writer.into_inner()); - let mut writer = Cursor::new(Vec::new()); - - let decryptor = Decryptor::new( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - ) - .unwrap(); - - decryptor - .decrypt_streams_async(&mut reader, &mut writer, Aad::Null) - .await - .unwrap(); - - let output = writer.into_inner(); - - assert_eq!(buf, output); - } - - #[tokio::test] - #[cfg(feature = "tokio")] - #[cfg_attr(miri, ignore)] - async fn aes_256_gcm_siv_encrypt_and_decrypt_5_blocks_with_aad_async() { - let buf = CryptoRng::generate_vec(BLOCK_LEN * 5); - - let mut reader = Cursor::new(&buf); - let mut writer = Cursor::new(Vec::new()); - - let encryptor = Encryptor::new( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - ) - .unwrap(); - - encryptor - .encrypt_streams_async(&mut reader, &mut writer, AAD) - .await - .unwrap(); - - let mut reader = Cursor::new(writer.into_inner()); - let mut writer = Cursor::new(Vec::new()); - - let decryptor = Decryptor::new( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::Aes256GcmSiv, - ) - .unwrap(); - - decryptor - .decrypt_streams_async(&mut reader, &mut writer, AAD) - .await - .unwrap(); - - let output = writer.into_inner(); - - assert_eq!(buf, output); - } - - #[test] - fn xchacha20_poly1305_encrypt_bytes() { - let output = Encryptor::encrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &PLAINTEXT, - Aad::Null, - ) - .unwrap(); - - assert_eq!(output, XCHACHA20_POLY1305_BYTES_EXPECTED[0]); - } - - #[test] - fn xchacha20_poly1305_encrypt_key() { - let output = Encryptor::encrypt_key( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &Key::new([1u8; KEY_LEN]), - Aad::Null, - ) - .unwrap(); - - assert_eq!(output, 
XCHACHA20_POLY1305_ENCRYPTED_KEY); - } - - #[test] - fn xchacha20_poly1305_decrypt_key() { - let output = Decryptor::decrypt_key( - &Key::new([0x23; KEY_LEN]), - Algorithm::XChaCha20Poly1305, - &XCHACHA20_POLY1305_ENCRYPTED_KEY, - Aad::Null, - ) - .unwrap(); - - assert_eq!(output, Key::new([1u8; KEY_LEN])); - } - - #[test] - fn xchacha20_poly1305_encrypt_tiny() { - let output = Encryptor::encrypt_tiny( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &PLAINTEXT, - Aad::Null, - ) - .unwrap(); - - assert_eq!(output, XCHACHA20_POLY1305_BYTES_EXPECTED[0]); - } - - #[test] - fn xchacha20_poly1305_decrypt_tiny() { - let output = Decryptor::decrypt_tiny( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &XCHACHA20_POLY1305_BYTES_EXPECTED[0], - Aad::Null, - ) - .unwrap(); - - assert_eq!(output.expose(), &PLAINTEXT); - } - - #[test] - #[should_panic(expected = "LengthMismatch")] - fn xchacha20_poly1305_encrypt_tiny_too_large() { - Encryptor::encrypt_tiny( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &vec![0u8; BLOCK_LEN], - Aad::Null, - ) - .unwrap(); - } - - #[test] - #[should_panic(expected = "LengthMismatch")] - fn xchacha20_poly1305_decrypt_tiny_too_large() { - Decryptor::decrypt_tiny( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &vec![0u8; BLOCK_LEN + AEAD_TAG_LEN], - Aad::Null, - ) - .unwrap(); - } - - #[test] - fn xchacha20_poly1305_encrypt_bytes_with_aad() { - let output = Encryptor::encrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &PLAINTEXT, - AAD, - ) - .unwrap(); - - assert_eq!(output, XCHACHA20_POLY1305_BYTES_EXPECTED[1]); - } - - #[test] - fn xchacha20_poly1305_decrypt_bytes() { - let output = Decryptor::decrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &XCHACHA20_POLY1305_BYTES_EXPECTED[0], - Aad::Null, - ) - .unwrap(); - - assert_eq!(output.expose(), &PLAINTEXT); - } - - #[test] - fn xchacha20_poly1305_decrypt_bytes_with_aad() { - let output = Decryptor::decrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &XCHACHA20_POLY1305_BYTES_EXPECTED[1], - AAD, - ) - .unwrap(); - - assert_eq!(output.expose(), &PLAINTEXT); - } - - #[test] - #[should_panic(expected = "Decrypt")] - fn xchacha20_poly1305_decrypt_bytes_missing_aad() { - Decryptor::decrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &XCHACHA20_POLY1305_BYTES_EXPECTED[1], - Aad::Null, - ) - .unwrap(); - } - - #[test] - #[cfg_attr(miri, ignore)] - fn xchacha20_poly1305_encrypt_and_decrypt_5_blocks() { - let buf = CryptoRng::generate_vec(BLOCK_LEN * 5); - - let mut reader = Cursor::new(&buf); - let mut writer = Cursor::new(Vec::new()); - - let encryptor = Encryptor::new( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - ) - .unwrap(); - - encryptor - .encrypt_streams(&mut reader, &mut writer, Aad::Null) - .unwrap(); - - let mut reader = Cursor::new(writer.into_inner()); - let mut writer = Cursor::new(Vec::new()); - - let decryptor = Decryptor::new( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - ) - .unwrap(); - - decryptor - .decrypt_streams(&mut reader, &mut writer, Aad::Null) - .unwrap(); - - let output = writer.into_inner(); - - 
assert_eq!(buf, output); - } - - #[test] - #[ignore] - fn xchacha20_poly1305_encrypt_and_decrypt_128mib() { - let buf = vec![1u8; BLOCK_LEN * 128].into_boxed_slice(); - - let mut reader = Cursor::new(&buf); - let mut writer = Cursor::new(Vec::new()); - - let encryptor = Encryptor::new( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - ) - .unwrap(); - - encryptor - .encrypt_streams(&mut reader, &mut writer, Aad::Null) - .unwrap(); - - let mut reader = Cursor::new(writer.into_inner()); - let mut writer = Cursor::new(Vec::new()); - - let decryptor = Decryptor::new( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - ) - .unwrap(); - - decryptor - .decrypt_streams(&mut reader, &mut writer, Aad::Null) - .unwrap(); - - let output = writer.into_inner().into_boxed_slice(); - - assert_eq!(buf, output); - } - - #[test] - #[cfg_attr(miri, ignore)] - fn xchacha20_poly1305_encrypt_and_decrypt_5_blocks_with_aad() { - let buf = CryptoRng::generate_vec(BLOCK_LEN * 5); - - let mut reader = Cursor::new(&buf); - let mut writer = Cursor::new(Vec::new()); - - let encryptor = Encryptor::new( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - ) - .unwrap(); - - encryptor - .encrypt_streams(&mut reader, &mut writer, AAD) - .unwrap(); - - let mut reader = Cursor::new(writer.into_inner()); - let mut writer = Cursor::new(Vec::new()); - - let decryptor = Decryptor::new( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - ) - .unwrap(); - - decryptor - .decrypt_streams(&mut reader, &mut writer, AAD) - .unwrap(); - - let output = writer.into_inner(); - - assert_eq!(buf, output); - } - - #[tokio::test] - #[cfg(feature = "tokio")] - #[cfg_attr(miri, ignore)] - async fn xchacha20_poly1305_encrypt_and_decrypt_5_blocks_async() { - let buf = CryptoRng::generate_vec(BLOCK_LEN * 5); - - let mut reader = Cursor::new(&buf); - let mut writer = Cursor::new(Vec::new()); - - let encryptor = Encryptor::new( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - ) - .unwrap(); - - encryptor - .encrypt_streams_async(&mut reader, &mut writer, Aad::Null) - .await - .unwrap(); - - let mut reader = Cursor::new(writer.into_inner()); - let mut writer = Cursor::new(Vec::new()); - - let decryptor = Decryptor::new( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - ) - .unwrap(); - - decryptor - .decrypt_streams_async(&mut reader, &mut writer, Aad::Null) - .await - .unwrap(); - - let output = writer.into_inner(); - - assert_eq!(buf, output); - } - - #[tokio::test] - #[cfg(feature = "tokio")] - #[cfg_attr(miri, ignore)] - async fn xchacha20_poly1305_encrypt_and_decrypt_5_blocks_with_aad_async() { - let buf = CryptoRng::generate_vec(BLOCK_LEN * 5); - - let mut reader = Cursor::new(&buf); - let mut writer = Cursor::new(Vec::new()); - - let encryptor = Encryptor::new( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - ) - .unwrap(); - - encryptor - .encrypt_streams_async(&mut reader, &mut writer, AAD) - .await - .unwrap(); - - let mut reader = Cursor::new(writer.into_inner()); - let mut writer = Cursor::new(Vec::new()); - - let decryptor = Decryptor::new( - &Key::new([0x23; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - ) - .unwrap(); - - decryptor - .decrypt_streams_async(&mut reader, &mut writer, AAD) - .await - .unwrap(); - - let 
output = writer.into_inner(); - - assert_eq!(buf, output); - } - - #[test] - #[should_panic(expected = "Validity")] - fn encrypt_with_invalid_nonce() { - Encryptor::encrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::XChaCha20Poly1305, - &PLAINTEXT, - Aad::Null, - ) - .unwrap(); - } - - #[test] - #[should_panic(expected = "Validity")] - fn encrypt_with_null_nonce() { - Encryptor::encrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &Nonce::XChaCha20Poly1305([0u8; 20]), - Algorithm::XChaCha20Poly1305, - &PLAINTEXT, - Aad::Null, - ) - .unwrap(); - } - - #[test] - #[should_panic(expected = "Validity")] - fn encrypt_with_null_key() { - Encryptor::encrypt_bytes( - &Key::new([0u8; KEY_LEN]), - &XCHACHA20_POLY1305_NONCE, - Algorithm::XChaCha20Poly1305, - &PLAINTEXT, - Aad::Null, - ) - .unwrap(); - } - - #[test] - #[should_panic(expected = "Validity")] - fn decrypt_with_invalid_nonce() { - Decryptor::decrypt_bytes( - &Key::new([0x23; KEY_LEN]), - &AES_256_GCM_SIV_NONCE, - Algorithm::XChaCha20Poly1305, - &XCHACHA20_POLY1305_BYTES_EXPECTED[0], - Aad::Null, - ) - .unwrap(); - } -} diff --git a/crates/crypto/src/crypto/stream.rs b/crates/crypto/src/crypto/stream.rs deleted file mode 100644 index dce48d57b..000000000 --- a/crates/crypto/src/crypto/stream.rs +++ /dev/null @@ -1,314 +0,0 @@ -use std::io::{Cursor, Read, Write}; - -use crate::{ - primitives::{AEAD_TAG_LEN, BLOCK_LEN}, - types::{Aad, Algorithm, EncryptedKey, Key, Nonce}, - utils::ToArray, - Error, Protected, Result, -}; -use aead::{ - stream::{DecryptorLE31, EncryptorLE31}, - Payload, -}; -use aes_gcm_siv::Aes256GcmSiv; -use chacha20poly1305::XChaCha20Poly1305; - -#[cfg(feature = "tokio")] -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - -macro_rules! impl_stream { - ( - $name:ident, // "Decryptor", "Encryptor" - $error:expr, - $next_fn:ident, // "encrypt_next" - $last_fn:ident, // "encrypt_last" - $last_in_place_fn:ident, - $stream_primitive:ident, // "DecryptorLE31" - $streams_fn:ident, // "encrypt_streams" - $streams_fn_async:ident, // "encrypt_streams_async" - $bytes_fn:ident, // "encrypt_bytes" - $bytes_return:ty, - $size:expr, - $($algorithm:tt),* -) => { - pub enum $name { - $( - $algorithm(Box<$stream_primitive<$algorithm>>), - )* - } - - impl $name { - /// This should be used to initialize a stream object. - /// - /// The desired master key, nonce and algorithm should be provided. - /// - /// This function ensures that both the nonce and key are *valid*. - /// For more information, view `Key::validate()` and `Nonce::validate()` - pub fn new(key: &Key, nonce: &Nonce, algorithm: Algorithm) -> Result { - nonce.validate(algorithm)?; - - let s = match algorithm { - $( - Algorithm::$algorithm => Self::$algorithm(Box::new($stream_primitive::new(&key.into(), &nonce.into()))), - )* - }; - - Ok(s) - } - - fn $next_fn<'msg, 'aad>( - &mut self, - payload: impl Into>, - ) -> Result> { - match self { - $( - Self::$algorithm(s) => s.$next_fn(payload), - )* - } - .map_err(|_| $error) - } - - fn $last_fn<'msg, 'aad>(self, payload: impl Into>) -> Result> { - match self { - $( - Self::$algorithm(s) => s.$last_fn(payload), - )* - } - .map_err(|_| $error) - } - - fn $last_in_place_fn(self, aad: Aad, buf: &mut dyn aead::Buffer) -> Result<()> { - match self { - $( - Self::$algorithm(s) => s.$last_in_place_fn(aad.inner(), buf), - )* - } - .map_err(|_| $error) - } - - /// This function should be used for large amounts of data. 
- /// - /// The streaming implementation reads blocks of data in `BLOCK_LEN`, encrypts/decrypts, and writes to the writer. - /// - /// It requires a reader, a writer, and any relevant AAD. - /// - /// The AAD will be authenticated with every block of data. - pub fn $streams_fn( - mut self, - mut reader: R, - mut writer: W, - aad: Aad, - ) -> Result<()> - where - R: Read, - W: Write, - { - let mut buffer = vec![0u8; $size].into_boxed_slice(); - - loop { - let count = reader.read(&mut buffer)?; - - let payload = Payload { - aad: aad.inner(), - msg: &buffer[..count], - }; - - if count == $size { - let data = self.$next_fn(payload)?; - writer.write_all(&data)?; - } else { - let data = self.$last_fn(payload)?; - writer.write_all(&data)?; - break; - } - } - - writer.flush()?; - - Ok(()) - } - - /// This function should be used for large amounts of data. - /// - /// The streaming implementation reads blocks of data in `BLOCK_LEN`, encrypts/decrypts, and writes to the writer. - /// - /// It requires a reader, a writer, and any relevant AAD. - /// - /// The AAD will be authenticated with every block of data. - #[cfg(feature = "tokio")] - pub async fn $streams_fn_async( - mut self, - mut reader: R, - mut writer: W, - aad: Aad, - ) -> Result<()> - where - R: AsyncReadExt + Unpin + Send, - W: AsyncWriteExt + Unpin + Send, - { - let mut buffer = vec![0u8; $size].into_boxed_slice(); - - loop { - let count = reader.read(&mut buffer).await?; - - // TODO(brxken128): block on `next_fn` and `last_fn` exclusively - - let payload = Payload { - aad: aad.inner(), - msg: &buffer[..count], - }; - - if count == $size { - let data = self.$next_fn(payload)?; - writer.write_all(&data).await?; - } else { - let data = self.$last_fn(payload)?; - writer.write_all(&data).await?; - break; - } - } - - writer.flush().await?; - - Ok(()) - } - - /// This should ideally only be used for small amounts of data. - /// - /// It is just a thin wrapper around the associated `encrypt/decrypt_streams` function. - pub fn $bytes_fn( - key: &Key, - nonce: &Nonce, - algorithm: Algorithm, - bytes: &[u8], - aad: Aad, - ) -> Result<$bytes_return> { - let mut writer = Cursor::new(Vec::new()); - let s = Self::new(key, nonce, algorithm)?; - - s - .$streams_fn(bytes, &mut writer, aad) - .map(|()| writer.into_inner().into()) - } - } - }; -} - -impl Encryptor { - pub fn encrypt_key( - key: &Key, - nonce: &Nonce, - algorithm: Algorithm, - key_to_encrypt: &Key, - aad: Aad, - ) -> Result { - Self::encrypt_tiny(key, nonce, algorithm, key_to_encrypt.expose(), aad) - .map(|b| Ok(EncryptedKey::new(b.to_array()?, *nonce))) - .map_err(|_| Error::Encrypt)? - } - - /// This is only for encrypting inputs < `BLOCK_LEN`. For anything larger, - /// see [`Encryptor::encrypt_bytes`] or [`Encryptor::encrypt_streams`]. - /// - /// It uses `encrypt_last_in_place` under the hood due to the input always being less than `BLOCK_LEN`. - /// - /// It's faster than the alternatives (for small sizes) as we don't need to allocate the - /// full buffer - we only allocate what is required. 
-	pub fn encrypt_tiny(
-		key: &Key,
-		nonce: &Nonce,
-		algorithm: Algorithm,
-		bytes: &[u8],
-		aad: Aad,
-	) -> Result<Vec<u8>> {
-		if bytes.len() >= BLOCK_LEN {
-			return Err(Error::LengthMismatch);
-		}
-
-		let s = Self::new(key, nonce, algorithm)?;
-		let mut buffer = Vec::with_capacity(bytes.len() + AEAD_TAG_LEN);
-		buffer.extend_from_slice(bytes);
-		s.encrypt_last_in_place(aad, &mut buffer)?;
-
-		Ok(buffer)
-	}
-}
-
-impl Decryptor {
-	pub fn decrypt_key(
-		key: &Key,
-		algorithm: Algorithm,
-		encrypted_key: &EncryptedKey,
-		aad: Aad,
-	) -> Result<Key> {
-		Self::decrypt_tiny(
-			key,
-			encrypted_key.nonce(),
-			algorithm,
-			encrypted_key.inner(),
-			aad,
-		)
-		.map(Key::try_from)
-		.map_err(|_| Error::Decrypt)?
-	}
-
-	/// This is only for decrypting inputs < `BLOCK_LEN + AEAD_TAG_LEN`. For anything larger,
-	/// see [`Decryptor::decrypt_bytes`] or [`Decryptor::decrypt_streams`].
-	///
-	/// It uses `decrypt_last_in_place` under the hood due to the input always being less than `BLOCK_LEN + AEAD_TAG_LEN`.
-	///
-	/// It's faster than the alternatives (for small sizes) as we don't need to allocate the
-	/// full buffer - we only allocate what is required.
-	pub fn decrypt_tiny(
-		key: &Key,
-		nonce: &Nonce,
-		algorithm: Algorithm,
-		bytes: &[u8],
-		aad: Aad,
-	) -> Result<Protected<Vec<u8>>> {
-		if bytes.len() >= (BLOCK_LEN + AEAD_TAG_LEN) {
-			return Err(Error::LengthMismatch);
-		}
-
-		let s = Self::new(key, nonce, algorithm)?;
-		let mut buffer = Vec::with_capacity(bytes.len() + AEAD_TAG_LEN);
-		buffer.extend_from_slice(bytes);
-		s.decrypt_last_in_place(aad, &mut buffer)?;
-
-		buffer.truncate(bytes.len() - AEAD_TAG_LEN);
-
-		Ok(buffer.into())
-	}
-}
-
-impl_stream!(
-	Encryptor,
-	Error::Encrypt,
-	encrypt_next,
-	encrypt_last,
-	encrypt_last_in_place,
-	EncryptorLE31,
-	encrypt_streams,
-	encrypt_streams_async,
-	encrypt_bytes,
-	Vec<u8>,
-	BLOCK_LEN,
-	Aes256GcmSiv,
-	XChaCha20Poly1305
-);
-
-impl_stream!(
-	Decryptor,
-	Error::Decrypt,
-	decrypt_next,
-	decrypt_last,
-	decrypt_last_in_place,
-	DecryptorLE31,
-	decrypt_streams,
-	decrypt_streams_async,
-	decrypt_bytes,
-	Protected<Vec<u8>>,
-	(BLOCK_LEN + AEAD_TAG_LEN),
-	Aes256GcmSiv,
-	XChaCha20Poly1305
-);
diff --git a/crates/crypto/src/ct.rs b/crates/crypto/src/ct.rs
index e7edf6a89..8ce937ab9 100644
--- a/crates/crypto/src/ct.rs
+++ b/crates/crypto/src/ct.rs
@@ -87,7 +87,7 @@ impl ConstantTimeEq for String {
 	}
 }
 
-impl<'a> ConstantTimeEq for &'a str {
+impl ConstantTimeEq for &str {
 	fn ct_eq(&self, rhs: &Self) -> Choice {
 		// Here we are just able to convert both values to bytes and use the
 		// appropriate methods to compare the two in constant-time.
diff --git a/crates/crypto/src/error.rs b/crates/crypto/src/error.rs
index f7444b20f..c8371d1ae 100644
--- a/crates/crypto/src/error.rs
+++ b/crates/crypto/src/error.rs
@@ -7,6 +7,8 @@ use tokio::io;
 pub enum Error {
 	#[error("Block too big for oneshot encryption: size in bytes = {0}")]
 	BlockTooBig(usize),
+	#[error("Invalid key size: expected 32 bytes, got {0}")]
+	InvalidKeySize(usize),
 
 	/// Encrypt and decrypt errors, AEAD crate doesn't provide any error context for these
 	/// as it can be a security hazard to leak information about the error.
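For reviewers tracing the new `cipher_text_size` helpers added to `crates/crypto/src/cloud/encrypt.rs` above: the one-shot framing they describe is simply nonce, then ciphertext, where the ciphertext is the plaintext plus an authentication tag. Below is a minimal standalone sketch of that arithmetic, assuming only the XChaCha20-Poly1305 sizes visible in this diff (24-byte `XNonce`, 16-byte Poly1305 `Tag`); the constants are written out by hand here rather than pulled from the crate.

// Sketch, not crate code: mirrors OneShotEncryption::cipher_text_size.
// Assumed framing from the diff: [24-byte nonce][plaintext-sized body][16-byte tag].
const ONE_SHOT_NONCE_LEN: usize = 24; // size_of::<XNonce>() for XChaCha20
const POLY1305_TAG_LEN: usize = 16; // size_of::<Tag>()

const fn one_shot_cipher_text_size(plain_text_size: usize) -> usize {
	ONE_SHOT_NONCE_LEN + plain_text_size + POLY1305_TAG_LEN
}

fn main() {
	// A 100-byte message carries a fixed 40 bytes of framing overhead: 24 + 100 + 16.
	assert_eq!(one_shot_cipher_text_size(100), 140);
}

This is the same layout that `one_shot_ref_test` in `secret_key.rs` exercises by concatenating the nonce and ciphertext before decrypting through `EncryptedBlockRef`; the streaming variant only differs by using the shorter LE31 stream nonce and per-block tags.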
diff --git a/crates/crypto/src/lib.rs b/crates/crypto/src/lib.rs
index d5b5fee06..8238e9e21 100644
--- a/crates/crypto/src/lib.rs
+++ b/crates/crypto/src/lib.rs
@@ -11,7 +11,6 @@
 	clippy::unwrap_used,
 	unused_qualifications,
 	rust_2018_idioms,
-	clippy::expect_used,
 	trivial_casts,
 	trivial_numeric_casts,
 	unused_allocation,
diff --git a/crates/crypto/src/primitives.rs b/crates/crypto/src/primitives.rs
index efa7bd33b..1d8335fc1 100644
--- a/crates/crypto/src/primitives.rs
+++ b/crates/crypto/src/primitives.rs
@@ -3,10 +3,11 @@
 // DO NOT EDIT THIS FILE. IF THESE CONSTANTS CHANGE, THINGS CAN (AND PROBABLY WILL) BREAK
 
 use aead::stream::{Nonce, StreamLE31};
-use chacha20poly1305::{Tag, XChaCha20Poly1305, XNonce};
+use chacha20poly1305::{XChaCha20Poly1305, XNonce};
 
 pub type OneShotNonce = XNonce;
 pub type StreamNonce = Nonce<XChaCha20Poly1305, StreamLE31<XChaCha20Poly1305>>;
+pub use chacha20poly1305::Tag;
 
 #[derive(Debug, Clone)]
 pub struct EncryptedBlock {
@@ -14,6 +15,22 @@ pub struct EncryptedBlock {
 	pub cipher_text: Vec<u8>,
 }
 
+pub struct EncryptedBlockRef<'e> {
+	pub nonce: &'e OneShotNonce,
+	pub cipher_text: &'e [u8],
+}
+
+impl<'e> From<&'e [u8]> for EncryptedBlockRef<'e> {
+	fn from(cipher_text: &'e [u8]) -> Self {
+		let (nonce, cipher_text) = cipher_text.split_at(size_of::<OneShotNonce>());
+
+		Self {
+			nonce: nonce.try_into().expect("we split the correct amount"),
+			cipher_text,
+		}
+	}
+}
+
 impl EncryptedBlock {
 	/// The block size used for STREAM encryption/decryption. This size seems to offer
 	/// the best performance compared to alternatives.
diff --git a/crates/crypto/src/rng/csprng.rs b/crates/crypto/src/rng/csprng.rs
index 6275aafea..2d38ef6e0 100644
--- a/crates/crypto/src/rng/csprng.rs
+++ b/crates/crypto/src/rng/csprng.rs
@@ -9,7 +9,7 @@ use zeroize::{Zeroize, Zeroizing};
 ///
 /// On `Drop`, it re-seeds the inner RNG, erasing the previous state and making all future
 /// values unpredictable.
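 ///
 /// Note: `Clone` (derived below) duplicates the inner ChaCha20 state, so a clone
 /// initially yields the same byte stream as its source until one of them is
 /// re-seeded or advanced independently.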
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct CryptoRng(ChaCha20Rng);
 
 impl CryptoRng {
@@ -86,3 +86,25 @@ impl Drop for CryptoRng {
 		self.zeroize();
 	}
 }
+
+// implementing old-rand-core traits for compatibility with old code
+impl old_rand_core::CryptoRng for CryptoRng {}
+
+impl old_rand_core::RngCore for CryptoRng {
+	fn next_u32(&mut self) -> u32 {
+		<Self as rand_core::RngCore>::next_u32(self)
+	}
+
+	fn next_u64(&mut self) -> u64 {
+		<Self as rand_core::RngCore>::next_u64(self)
+	}
+
+	fn fill_bytes(&mut self, dest: &mut [u8]) {
+		<Self as rand_core::RngCore>::fill_bytes(self, dest);
+	}
+
+	fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), old_rand_core::Error> {
+		<Self as rand_core::RngCore>::fill_bytes(self, dest);
+		Ok(())
+	}
+}
diff --git a/crates/ffmpeg/src/dict.rs b/crates/ffmpeg/src/dict.rs
index 7d1d5726b..feb84184f 100644
--- a/crates/ffmpeg/src/dict.rs
+++ b/crates/ffmpeg/src/dict.rs
@@ -87,7 +87,7 @@ pub struct FFmpegDictIter<'a> {
 	_lifetime: std::marker::PhantomData<&'a ()>,
 }
 
-impl<'a> Iterator for FFmpegDictIter<'a> {
+impl Iterator for FFmpegDictIter<'_> {
 	type Item = (String, Option<String>);
 
 	fn next(&mut self) -> Option<(String, Option<String>)> {
diff --git a/crates/ffmpeg/src/frame_decoder.rs b/crates/ffmpeg/src/frame_decoder.rs
index 95516b6f2..4a98202e3 100644
--- a/crates/ffmpeg/src/frame_decoder.rs
+++ b/crates/ffmpeg/src/frame_decoder.rs
@@ -92,7 +92,7 @@ impl FrameDecoder {
 		})
 	}
 
-	pub(crate) fn use_embedded(&mut self) -> bool {
+	pub(crate) const fn use_embedded(&self) -> bool {
 		self.embedded
 	}
 
diff --git a/crates/images/src/consts.rs b/crates/images/src/consts.rs
index d68675b69..4afd8da84 100644
--- a/crates/images/src/consts.rs
+++ b/crates/images/src/consts.rs
@@ -6,7 +6,7 @@ const MIB: u64 = 1_048_576;
 /// The maximum file size that an image can be in order to have a thumbnail generated.
 ///
 /// This value is in MiB.
-pub const MAXIMUM_FILE_SIZE: u64 = MIB * 192;
+pub const MAXIMUM_FILE_SIZE: u64 = MIB * 1024;
 
 /// These are roughly all extensions supported by the `image` crate, as of `v0.24.7`.
/// @@ -159,7 +159,7 @@ impl serde::Serialize for ConvertibleExtension { struct ExtensionVisitor; #[cfg(feature = "serde")] -impl<'de> serde::de::Visitor<'de> for ExtensionVisitor { +impl serde::de::Visitor<'_> for ExtensionVisitor { type Value = ConvertibleExtension; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/crates/media-metadata/src/exif/datetime.rs b/crates/media-metadata/src/exif/datetime.rs index 39c6a40b6..b238dcd4f 100644 --- a/crates/media-metadata/src/exif/datetime.rs +++ b/crates/media-metadata/src/exif/datetime.rs @@ -77,7 +77,7 @@ impl serde::Serialize for MediaDate { struct MediaDateVisitor; -impl<'de> Visitor<'de> for MediaDateVisitor { +impl Visitor<'_> for MediaDateVisitor { type Value = MediaDate; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/crates/p2p/Cargo.toml b/crates/p2p/Cargo.toml index 4727ec7f9..d915e2024 100644 --- a/crates/p2p/Cargo.toml +++ b/crates/p2p/Cargo.toml @@ -19,6 +19,7 @@ specta = [] # Workspace dependencies base64 = { workspace = true } ed25519-dalek = { workspace = true } +flume = { workspace = true } futures = { workspace = true } rmp-serde = { workspace = true } serde = { workspace = true, features = ["derive"] } @@ -28,10 +29,10 @@ tokio = { workspace = true, features = ["fs", "io-util", "macros", "sync tokio-util = { workspace = true, features = ["compat"] } tracing = { workspace = true } uuid = { workspace = true, features = ["serde"] } +zeroize = { workspace = true, features = ["derive"] } # Specific P2P dependencies dns-lookup = "2.0" -flume = "=0.11.1" # Must match version used by `mdns-sd` hash_map_diff = "0.2.0" if-watch = { version = "=3.2.0", features = ["tokio"] } # Override features used by libp2p-quic libp2p-stream = "=0.2.0-alpha" # Update blocked due to custom patch @@ -39,7 +40,6 @@ mdns-sd = "0.11.5" rand_core = "0.6.4" stable-vec = "0.4.1" sync_wrapper = "1.0" -zeroize = { version = "1.8", features = ["derive"] } [dependencies.libp2p] features = ["autonat", "dcutr", "macros", "noise", "quic", "relay", "serde", "tokio", "yamux"] diff --git a/crates/p2p/crates/tunnel/src/lib.rs b/crates/p2p/crates/tunnel/src/lib.rs index df7706255..482c736c7 100644 --- a/crates/p2p/crates/tunnel/src/lib.rs +++ b/crates/p2p/crates/tunnel/src/lib.rs @@ -52,7 +52,7 @@ impl Tunnel { library_identity: &Identity, ) -> Result { stream - .write_all(&[b'T']) + .write_all(b"T") .await .map_err(|_| TunnelError::DiscriminatorWriteError)?; diff --git a/crates/p2p/src/smart_guards.rs b/crates/p2p/src/smart_guards.rs index 6177ed930..a920508aa 100644 --- a/crates/p2p/src/smart_guards.rs +++ b/crates/p2p/src/smart_guards.rs @@ -28,7 +28,7 @@ impl<'a, T: Clone> SmartWriteGuard<'a, T> { } } -impl<'a, T> Deref for SmartWriteGuard<'a, T> { +impl Deref for SmartWriteGuard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { @@ -36,13 +36,13 @@ impl<'a, T> Deref for SmartWriteGuard<'a, T> { } } -impl<'a, T> DerefMut for SmartWriteGuard<'a, T> { +impl DerefMut for SmartWriteGuard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.lock } } -impl<'a, T> Drop for SmartWriteGuard<'a, T> { +impl Drop for SmartWriteGuard<'_, T> { fn drop(&mut self) { (self.save)( self.p2p, diff --git a/crates/sync-generator/src/model.rs b/crates/sync-generator/src/model.rs index 767c1d820..e171634b8 100644 --- a/crates/sync-generator/src/model.rs +++ b/crates/sync-generator/src/model.rs @@ -46,7 +46,7 @@ pub fn module((model, sync_type): ModelWithSyncType<'_>) -> 
Module { RefinedFieldWalker::Scalar(scalar_field) => { (!scalar_field.is_in_required_relation()).then(|| { quote! { - #model_name_snake::#field_name_snake::set(::rmpv::ext::from_value(val).unwrap()), + #model_name_snake::#field_name_snake::set(::rmpv::ext::from_value(val)?), } }) } @@ -59,11 +59,19 @@ pub fn module((model, sync_type): ModelWithSyncType<'_>) -> Module { |i| { if i.count() == 1 { Some(quote! {{ - let val: std::collections::HashMap = ::rmpv::ext::from_value(val).unwrap(); - let val = val.into_iter().next().unwrap(); + + let (field, value) = ::rmpv + ::ext + ::from_value::>(val)? + .into_iter() + .next() + .ok_or(Error::MissingRelationData { + field: field.to_string(), + model: #relation_model_name_snake::NAME.to_string() + })?; #model_name_snake::#field_name_snake::connect( - #relation_model_name_snake::UniqueWhereParam::deserialize(&val.0, val.1).unwrap() + #relation_model_name_snake::UniqueWhereParam::deserialize(&field, value)? ) }}) } else { @@ -81,10 +89,13 @@ pub fn module((model, sync_type): ModelWithSyncType<'_>) -> Module { } else { quote! { impl #model_name_snake::SetParam { - pub fn deserialize(field: &str, val: ::rmpv::Value) -> Option { - Some(match field { + pub fn deserialize(field: &str, val: ::rmpv::Value) -> Result { + Ok(match field { #(#field_matches)* - _ => return None + _ => return Err(Error::FieldNotFound { + field: field.to_string(), + model: #model_name_snake::NAME.to_string(), + }), }) } } @@ -97,9 +108,12 @@ pub fn module((model, sync_type): ModelWithSyncType<'_>) -> Module { Module::new( model.name(), quote! { - use super::prisma::*; + use super::Error; + use prisma_client_rust::scalar_types::*; + use super::prisma::*; + #sync_id #set_param_impl @@ -172,7 +186,7 @@ fn process_unique_params(model: Walker<'_, ModelId>, model_name_snake: &Ident) - Some(quote!(#model_name_snake::#field_name_snake::NAME => #model_name_snake::#field_name_snake::equals( - ::rmpv::ext::from_value(val).unwrap() + ::rmpv::ext::from_value(val)? ), )) } @@ -185,10 +199,13 @@ fn process_unique_params(model: Walker<'_, ModelId>, model_name_snake: &Ident) - } else { quote! { impl #model_name_snake::UniqueWhereParam { - pub fn deserialize(field: &str, val: ::rmpv::Value) -> Option { - Some(match field { + pub fn deserialize(field: &str, val: ::rmpv::Value) -> Result { + Ok(match field { #(#field_matches)* - _ => return None + _ => return Err(Error::FieldNotFound { + field: field.to_string(), + model: #model_name_snake::NAME.to_string(), + }) }) } } diff --git a/crates/sync-generator/src/sync_data.rs b/crates/sync-generator/src/sync_data.rs index 66556e752..e8ee713e6 100644 --- a/crates/sync-generator/src/sync_data.rs +++ b/crates/sync-generator/src/sync_data.rs @@ -7,7 +7,7 @@ use prisma_models::walkers::{FieldWalker, ScalarFieldWalker}; use crate::{ModelSyncType, ModelWithSyncType}; pub fn enumerate(models: &[ModelWithSyncType<'_>]) -> TokenStream { - let (variants, matches): (Vec<_>, Vec<_>) = models + let (variants, matches) = models .iter() .filter_map(|(model, sync_type)| { let model_name_snake = snake_ident(model.name()); @@ -26,12 +26,12 @@ pub fn enumerate(models: &[ModelWithSyncType<'_>]) -> TokenStream { quote!(#model_name_pascal(#model_name_snake::SyncId, sd_sync::CRDTOperationData)), quote! 
{ #model_name_snake::MODEL_ID => - Self::#model_name_pascal(rmpv::ext::from_value(op.record_id).ok()?, op.data) + Self::#model_name_pascal(rmpv::ext::from_value(op.record_id)?, op.data) }, ) }) }) - .unzip(); + .unzip::<_, _, Vec<_>, Vec<_>>(); let exec_matches = models.iter().filter_map(|(model, sync_type)| { let model_name_pascal = pascal_ident(model.name()); @@ -54,20 +54,22 @@ pub fn enumerate(models: &[ModelWithSyncType<'_>]) -> TokenStream { }) }); + let error_enum = declare_error_enum(); + quote! { pub enum ModelSyncData { #(#variants),* } impl ModelSyncData { - pub fn from_op(op: sd_sync::CRDTOperation) -> Option { - Some(match op.model { + pub fn from_op(op: sd_sync::CRDTOperation) -> Result { + Ok(match op.model_id { #(#matches),*, - _ => return None + _ => return Err(Error::InvalidModelId(op.model_id)), }) } - pub async fn exec(self, db: &prisma::PrismaClient) -> prisma_client_rust::Result<()> { + pub async fn exec(self, db: &prisma::PrismaClient) -> Result<(), Error> { match self { #(#exec_matches),* } @@ -75,6 +77,69 @@ pub fn enumerate(models: &[ModelWithSyncType<'_>]) -> TokenStream { Ok(()) } } + + #error_enum + } +} + +fn declare_error_enum() -> TokenStream { + quote! { + #[derive(Debug)] + pub enum Error { + Rmpv(rmpv::ext::Error), + RmpSerialize(rmp_serde::encode::Error), + Prisma(prisma_client_rust::QueryError), + InvalidModelId(sd_sync::ModelId), + FieldNotFound { field: String, model: String }, + MissingRelationData { field: String, model: String }, + RelatedEntryNotFound { field: String, model: String }, + } + + impl From for Error { + fn from(e: rmpv::ext::Error) -> Self { + Self::Rmpv(e) + } + } + + impl From for Error { + fn from(e: rmp_serde::encode::Error) -> Self { + Self::RmpSerialize(e) + } + } + + impl From for Error { + fn from(e: prisma_client_rust::QueryError) -> Self { + Self::Prisma(e) + } + } + + impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Rmpv(e) => write!(f, "Failed to serialize or deserialize rmpv data: {e}"), + Self::RmpSerialize(e) => write!(f, "Failed to serialize rmp data: {e}"), + Self::Prisma(e) => write!(f, "Prisma error: {e}"), + Self::InvalidModelId(id) => write!(f, "Invalid model id: {id}"), + Self::FieldNotFound { field, model } => { + write!(f, "Field '{field}' not found in model '{model}'") + } + Self::MissingRelationData { field, model } => { + write!( + f, + "Field '{field}' missing relation data in model '{model}'" + ) + } + Self::RelatedEntryNotFound { field, model } => { + write!( + f, + "Related entry for field '{field}' not found in table '{model}'" + ) + } + } + } + } + + impl std::error::Error for Error {} } } @@ -103,6 +168,7 @@ fn handle_crdt_ops_relation( .and_then(|(_m, sync)| sync.as_ref()) .map(|sync| snake_ident(sync.sync_id()[0].name())) .expect("missing sync id field name for relation"); + let item_model_name_snake = snake_ident(item.related_model().name()); let item_field_name_snake = snake_ident(item.name()); @@ -155,11 +221,15 @@ fn handle_crdt_ops_relation( vec![], ) .exec() - .await - .ok(); + .await?; }, - sd_sync::CRDTOperationData::Update { field, value } => { - let data = vec![prisma::#model_name_snake::SetParam::deserialize(&field, value).unwrap()]; + + sd_sync::CRDTOperationData::Update(data) => { + let data = data.into_iter() + .map(|(field, value)| { + prisma::#model_name_snake::SetParam::deserialize(&field, value) + }) + .collect::, _>>()?; db.#model_name_snake() .upsert( @@ -171,15 +241,14 @@ fn handle_crdt_ops_relation( 
data, ) .exec() - .await - .ok(); + .await?; }, + sd_sync::CRDTOperationData::Delete => { db.#model_name_snake() .delete(id) .exec() - .await - .ok(); + .await?; }, } } @@ -198,8 +267,10 @@ fn handle_crdt_ops_shared( .expect("missing fields") .next() .expect("empty fields"); + let id_name_snake = snake_ident(scalar_field.name()); let field_name_snake = snake_ident(rel.name()); + let opposite_model_name_snake = snake_ident( rel.opposite_relation_field() .expect("missing opposite relation field") @@ -211,12 +282,16 @@ fn handle_crdt_ops_shared( id.#field_name_snake.pub_id.clone() )); + let pub_id_field = format!("{field_name_snake}::pub_id"); + let rel_fetch = quote! { let rel = db.#opposite_model_name_snake() .find_unique(#relation_equals_condition) .exec() - .await? - .unwrap(); + .await?.ok_or_else(|| Error::RelatedEntryNotFound { + field: #pub_id_field.to_string(), + model: prisma::#opposite_model_name_snake::NAME.to_string(), + })?; }; ( @@ -226,6 +301,7 @@ fn handle_crdt_ops_shared( relation_equals_condition, ) } + RefinedFieldWalker::Scalar(s) => { let field_name_snake = snake_ident(s.name()); let thing = quote!(id.#field_name_snake.clone()); @@ -238,24 +314,12 @@ fn handle_crdt_ops_shared( #get_id match data { - sd_sync::CRDTOperationData::Create(data) => { - let data: Vec<_> = data.into_iter().map(|(field, value)| { - prisma::#model_name_snake::SetParam::deserialize(&field, value).unwrap() - }).collect(); - - db.#model_name_snake() - .upsert( - prisma::#model_name_snake::#id_name_snake::equals(#equals_value), - prisma::#model_name_snake::create(#create_id, data.clone()), - data - ) - .exec() - .await?; - }, - sd_sync::CRDTOperationData::Update { field, value } => { - let data = vec![ - prisma::#model_name_snake::SetParam::deserialize(&field, value).unwrap() - ]; + sd_sync::CRDTOperationData::Create(data) | sd_sync::CRDTOperationData::Update(data) => { + let data = data.into_iter() + .map(|(field, value)| { + prisma::#model_name_snake::SetParam::deserialize(&field, value) + }) + .collect::, _>>()?; db.#model_name_snake() .upsert( @@ -266,6 +330,7 @@ fn handle_crdt_ops_shared( .exec() .await?; }, + sd_sync::CRDTOperationData::Delete => { db.#model_name_snake() .delete(prisma::#model_name_snake::#id_name_snake::equals(#equals_value)) @@ -275,8 +340,8 @@ fn handle_crdt_ops_shared( db.crdt_operation() .delete_many(vec![ prisma::crdt_operation::model::equals(#model_id as i32), - prisma::crdt_operation::record_id::equals(rmp_serde::to_vec(&id).unwrap()), - prisma::crdt_operation::kind::equals(sd_sync::OperationKind::Create.to_string()) + prisma::crdt_operation::record_id::equals(rmp_serde::to_vec(&id)?), + prisma::crdt_operation::kind::equals(sd_sync::OperationKind::Create.to_string()), ]) .exec() .await?; diff --git a/crates/sync/Cargo.toml b/crates/sync/Cargo.toml index 8b15355ca..302b37a53 100644 --- a/crates/sync/Cargo.toml +++ b/crates/sync/Cargo.toml @@ -12,7 +12,5 @@ rmp = { workspace = true } rmp-serde = { workspace = true } rmpv = { workspace = true } serde = { workspace = true } -serde_json = { workspace = true } -specta = { workspace = true, features = ["serde_json", "uhlc", "uuid"] } uhlc = { workspace = true } -uuid = { workspace = true, features = ["serde", "v4"] } +uuid = { workspace = true, features = ["serde", "v7"] } diff --git a/crates/sync/src/compressed.rs b/crates/sync/src/compressed.rs index 0db151330..a2e3a147d 100644 --- a/crates/sync/src/compressed.rs +++ b/crates/sync/src/compressed.rs @@ -1,81 +1,105 @@ -use std::mem; +use crate::{CRDTOperation, 
CRDTOperationData, DevicePubId, ModelId, RecordId};
+
+use std::collections::{hash_map::Entry, BTreeMap, HashMap};
 
 use serde::{Deserialize, Serialize};
 use uhlc::NTP64;
-use uuid::Uuid;
 
-use crate::{CRDTOperation, CRDTOperationData};
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct CompressedCRDTOperationsPerModel(pub Vec<(ModelId, CompressedCRDTOperationsPerRecord)>);
 
-pub type CompressedCRDTOperationsForModel = Vec<(rmpv::Value, Vec<CompressedCRDTOperation>)>;
+pub type CompressedCRDTOperationsPerRecord = Vec<(RecordId, Vec<CompressedCRDTOperation>)>;
 
 /// Stores a bunch of [`CRDTOperation`]s in a more memory-efficient form for sending to the cloud.
-#[derive(Serialize, Deserialize, Debug, PartialEq)]
-pub struct CompressedCRDTOperations(pub Vec<(Uuid, Vec<(u16, CompressedCRDTOperationsForModel)>)>);
+#[derive(Serialize, Deserialize, Debug)]
+pub struct CompressedCRDTOperationsPerModelPerDevice(
+	pub Vec<(DevicePubId, CompressedCRDTOperationsPerModel)>,
+);
 
-impl CompressedCRDTOperations {
+impl CompressedCRDTOperationsPerModelPerDevice {
+	/// Creates a new [`CompressedCRDTOperationsPerModelPerDevice`] from a vector of [`CRDTOperation`]s.
+	///
+	/// # Panics
+	///
+	/// Will panic if for some reason `rmp_serde::to_vec` fails to serialize an `rmpv::Value` to bytes.
 	#[must_use]
 	pub fn new(ops: Vec<CRDTOperation>) -> Self {
-		let mut compressed = vec![];
+		let mut compressed_map = BTreeMap::<
+			DevicePubId,
+			BTreeMap<ModelId, HashMap<Vec<u8>, (RecordId, Vec<CompressedCRDTOperation>)>>,
+		>::new();
 
-		let mut ops_iter = ops.into_iter();
+		for CRDTOperation {
+			device_pub_id,
+			timestamp,
+			model_id,
+			record_id,
+			data,
+		} in ops
+		{
+			let records = compressed_map
+				.entry(device_pub_id)
+				.or_default()
+				.entry(model_id)
+				.or_default();
 
-		let Some(first) = ops_iter.next() else {
-			return Self(vec![]);
-		};
+			// Can't use RecordId as a key because rmpv::Value doesn't implement Hash + Eq.
+			// So we use its serialized bytes as a key.
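+			// Two ops touching the same record therefore serialize to the same byte key,
+			// letting the `Entry` match below fold them into a single `Vec` of operations.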
+			let record_id_bytes =
+				rmp_serde::to_vec_named(&record_id).expect("already serialized to Value");
 
-		let mut instance_id = first.instance;
-		let mut instance = vec![];
-
-		let mut model_str = first.model;
-		let mut model = vec![];
-
-		let mut record_id = first.record_id.clone();
-		let mut record = vec![first.into()];
-
-		for op in ops_iter {
-			if instance_id != op.instance {
-				model.push((
-					mem::replace(&mut record_id, op.record_id.clone()),
-					mem::take(&mut record),
-				));
-				instance.push((
-					mem::replace(&mut model_str, op.model),
-					mem::take(&mut model),
-				));
-				compressed.push((
-					mem::replace(&mut instance_id, op.instance),
-					mem::take(&mut instance),
-				));
-			} else if model_str != op.model {
-				model.push((
-					mem::replace(&mut record_id, op.record_id.clone()),
-					mem::take(&mut record),
-				));
-				instance.push((
-					mem::replace(&mut model_str, op.model),
-					mem::take(&mut model),
-				));
-			} else if record_id != op.record_id {
-				model.push((
-					mem::replace(&mut record_id, op.record_id.clone()),
-					mem::take(&mut record),
-				));
+			match records.entry(record_id_bytes) {
+				Entry::Occupied(mut entry) => {
+					entry
+						.get_mut()
+						.1
+						.push(CompressedCRDTOperation { timestamp, data });
+				}
+				Entry::Vacant(entry) => {
+					entry.insert((record_id, vec![CompressedCRDTOperation { timestamp, data }]));
+				}
 			}
-
-			record.push(CompressedCRDTOperation::from(op));
 		}
 
-		model.push((record_id, record));
-		instance.push((model_str, model));
-		compressed.push((instance_id, instance));
+		Self(
+			compressed_map
+				.into_iter()
+				.map(|(device_pub_id, model_map)| {
+					(
+						device_pub_id,
+						CompressedCRDTOperationsPerModel(
+							model_map
+								.into_iter()
+								.map(|(model_id, ops_per_record_map)| {
+									(model_id, ops_per_record_map.into_values().collect())
+								})
+								.collect(),
+						),
+					)
+				})
+				.collect(),
+		)
+	}
 
-		Self(compressed)
+	/// Creates a new [`CompressedCRDTOperationsPerModel`] from the CRDT operations of a single device.
+	///
+	/// # Panics
+	/// Will panic if there is more than one device.
+ #[must_use] + pub fn new_single_device( + ops: Vec, + ) -> (DevicePubId, CompressedCRDTOperationsPerModel) { + let Self(mut compressed) = Self::new(ops); + + assert_eq!(compressed.len(), 1, "Expected a single device"); + + compressed.remove(0) } #[must_use] - pub fn first(&self) -> Option<(Uuid, u16, &rmpv::Value, &CompressedCRDTOperation)> { + pub fn first(&self) -> Option<(DevicePubId, ModelId, &RecordId, &CompressedCRDTOperation)> { self.0.first().and_then(|(instance, data)| { - data.first().and_then(|(model, data)| { + data.0.first().and_then(|(model, data)| { data.first() .and_then(|(record, ops)| ops.first().map(|op| (*instance, *model, record, op))) }) @@ -83,9 +107,9 @@ impl CompressedCRDTOperations { } #[must_use] - pub fn last(&self) -> Option<(Uuid, u16, &rmpv::Value, &CompressedCRDTOperation)> { + pub fn last(&self) -> Option<(DevicePubId, ModelId, &RecordId, &CompressedCRDTOperation)> { self.0.last().and_then(|(instance, data)| { - data.last().and_then(|(model, data)| { + data.0.last().and_then(|(model, data)| { data.last() .and_then(|(record, ops)| ops.last().map(|op| (*instance, *model, record, op))) }) @@ -97,11 +121,12 @@ impl CompressedCRDTOperations { self.0 .iter() .map(|(_, data)| { - data.iter() + data.0 + .iter() .map(|(_, data)| data.iter().map(|(_, ops)| ops.len()).sum::()) .sum::() }) - .sum::() + .sum() } #[must_use] @@ -111,15 +136,15 @@ impl CompressedCRDTOperations { #[must_use] pub fn into_ops(self) -> Vec { - let mut ops = vec![]; + let mut ops = Vec::with_capacity(self.len()); - for (instance_id, instance) in self.0 { - for (model_str, model) in instance { - for (record_id, record) in model { + for (device_pub_id, device_messages) in self.0 { + for (model_id, model_messages) in device_messages.0 { + for (record_id, record) in model_messages { for op in record { ops.push(CRDTOperation { - instance: instance_id, - model: model_str, + device_pub_id, + model_id, record_id: record_id.clone(), timestamp: op.timestamp, data: op.data, @@ -133,6 +158,58 @@ impl CompressedCRDTOperations { } } +impl CompressedCRDTOperationsPerModel { + #[must_use] + pub fn first(&self) -> Option<(ModelId, &RecordId, &CompressedCRDTOperation)> { + self.0.first().and_then(|(model_id, data)| { + data.first() + .and_then(|(record_id, ops)| ops.first().map(|op| (*model_id, record_id, op))) + }) + } + + #[must_use] + pub fn last(&self) -> Option<(ModelId, &RecordId, &CompressedCRDTOperation)> { + self.0.last().and_then(|(model_id, data)| { + data.last() + .and_then(|(record_id, ops)| ops.last().map(|op| (*model_id, record_id, op))) + }) + } + + #[must_use] + pub fn len(&self) -> usize { + self.0 + .iter() + .map(|(_, data)| data.iter().map(|(_, ops)| ops.len()).sum::()) + .sum() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[must_use] + pub fn into_ops(self, device_pub_id: DevicePubId) -> Vec { + let mut ops = Vec::with_capacity(self.len()); + + for (model_id, model_messages) in self.0 { + for (record_id, record) in model_messages { + for op in record { + ops.push(CRDTOperation { + device_pub_id, + model_id, + record_id: record_id.clone(), + timestamp: op.timestamp, + data: op.data, + }); + } + } + } + + ops + } +} + #[derive(PartialEq, Serialize, Deserialize, Clone, Debug)] pub struct CompressedCRDTOperation { pub timestamp: NTP64, @@ -140,90 +217,91 @@ pub struct CompressedCRDTOperation { } impl From for CompressedCRDTOperation { - fn from(value: CRDTOperation) -> Self { - Self { - timestamp: value.timestamp, - data: value.data, - } + fn from( + 
CRDTOperation { + timestamp, data, .. + }: CRDTOperation, + ) -> Self { + Self { timestamp, data } } } #[cfg(test)] mod test { use super::*; + use uuid::Uuid; #[test] fn compress() { - let instance = Uuid::new_v4(); + let device_pub_id = Uuid::now_v7(); let uncompressed = vec![ CRDTOperation { - instance, + device_pub_id, timestamp: NTP64(0), - model: 0, + model_id: 0, record_id: rmpv::Value::Nil, data: CRDTOperationData::create(), }, CRDTOperation { - instance, + device_pub_id, timestamp: NTP64(0), - model: 0, + model_id: 0, record_id: rmpv::Value::Nil, data: CRDTOperationData::create(), }, CRDTOperation { - instance, + device_pub_id, timestamp: NTP64(0), - model: 0, + model_id: 0, record_id: rmpv::Value::Nil, data: CRDTOperationData::create(), }, CRDTOperation { - instance, + device_pub_id, timestamp: NTP64(0), - model: 1, + model_id: 1, record_id: rmpv::Value::Nil, data: CRDTOperationData::create(), }, CRDTOperation { - instance, + device_pub_id, timestamp: NTP64(0), - model: 1, + model_id: 1, record_id: rmpv::Value::Nil, data: CRDTOperationData::create(), }, CRDTOperation { - instance, + device_pub_id, timestamp: NTP64(0), - model: 0, + model_id: 0, record_id: rmpv::Value::Nil, data: CRDTOperationData::create(), }, CRDTOperation { - instance, + device_pub_id, timestamp: NTP64(0), - model: 0, + model_id: 0, record_id: rmpv::Value::Nil, data: CRDTOperationData::create(), }, ]; - let CompressedCRDTOperations(compressed) = CompressedCRDTOperations::new(uncompressed); + let CompressedCRDTOperationsPerModelPerDevice(compressed) = + CompressedCRDTOperationsPerModelPerDevice::new(uncompressed); - assert_eq!(compressed[0].1[0].0, 0); - assert_eq!(compressed[0].1[1].0, 1); - assert_eq!(compressed[0].1[2].0, 0); + assert_eq!(compressed[0].1 .0[0].0, 0); + assert_eq!(compressed[0].1 .0[1].0, 1); - assert_eq!(compressed[0].1[0].1[0].1.len(), 3); - assert_eq!(compressed[0].1[1].1[0].1.len(), 2); - assert_eq!(compressed[0].1[2].1[0].1.len(), 2); + assert_eq!(compressed[0].1 .0[0].1[0].1.len(), 5); + assert_eq!(compressed[0].1 .0[1].1[0].1.len(), 2); } #[test] fn into_ops() { - let compressed = CompressedCRDTOperations(vec![( + let compressed = CompressedCRDTOperationsPerModelPerDevice(vec![( Uuid::new_v4(), - vec![ + CompressedCRDTOperationsPerModel(vec![ ( 0, vec![( @@ -241,6 +319,14 @@ mod test { timestamp: NTP64(0), data: CRDTOperationData::create(), }, + CompressedCRDTOperation { + timestamp: NTP64(0), + data: CRDTOperationData::create(), + }, + CompressedCRDTOperation { + timestamp: NTP64(0), + data: CRDTOperationData::create(), + }, ], )], ), @@ -260,30 +346,14 @@ mod test { ], )], ), - ( - 0, - vec![( - rmpv::Value::Nil, - vec![ - CompressedCRDTOperation { - timestamp: NTP64(0), - data: CRDTOperationData::create(), - }, - CompressedCRDTOperation { - timestamp: NTP64(0), - data: CRDTOperationData::create(), - }, - ], - )], - ), - ], + ]), )]); let uncompressed = compressed.into_ops(); assert_eq!(uncompressed.len(), 7); - assert_eq!(uncompressed[2].model, 0); - assert_eq!(uncompressed[4].model, 1); - assert_eq!(uncompressed[6].model, 0); + assert_eq!(uncompressed[2].model_id, 0); + assert_eq!(uncompressed[4].model_id, 0); + assert_eq!(uncompressed[6].model_id, 1); } } diff --git a/crates/sync/src/crdt.rs b/crates/sync/src/crdt.rs index 2a3872c92..3cbdf23d2 100644 --- a/crates/sync/src/crdt.rs +++ b/crates/sync/src/crdt.rs @@ -1,13 +1,13 @@ +use crate::{DevicePubId, ModelId}; + use std::{collections::BTreeMap, fmt}; use serde::{Deserialize, Serialize}; -use specta::Type; use uhlc::NTP64; -use 
uuid::Uuid; pub enum OperationKind<'a> { Create, - Update(&'a str), + Update(Vec<&'a str>), Delete, } @@ -15,22 +15,18 @@ impl fmt::Display for OperationKind<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { OperationKind::Create => write!(f, "c"), - OperationKind::Update(field) => write!(f, "u:{field}"), + OperationKind::Update(fields) => write!(f, "u:{}:", fields.join(":")), OperationKind::Delete => write!(f, "d"), } } } -#[derive(PartialEq, Serialize, Deserialize, Clone, Debug, Type)] +#[derive(PartialEq, Serialize, Deserialize, Clone, Debug)] pub enum CRDTOperationData { #[serde(rename = "c")] - Create(#[specta(type = BTreeMap)] BTreeMap), + Create(BTreeMap), #[serde(rename = "u")] - Update { - field: String, - #[specta(type = serde_json::Value)] - value: rmpv::Value, - }, + Update(BTreeMap), #[serde(rename = "d")] Delete, } @@ -45,19 +41,19 @@ impl CRDTOperationData { pub fn as_kind(&self) -> OperationKind<'_> { match self { Self::Create(_) => OperationKind::Create, - Self::Update { field, .. } => OperationKind::Update(field), + Self::Update(fields_and_values) => { + OperationKind::Update(fields_and_values.keys().map(String::as_str).collect()) + } Self::Delete => OperationKind::Delete, } } } -#[derive(PartialEq, Serialize, Deserialize, Clone, Type)] +#[derive(PartialEq, Serialize, Deserialize, Clone)] pub struct CRDTOperation { - pub instance: Uuid, - #[specta(type = u32)] + pub device_pub_id: DevicePubId, pub timestamp: NTP64, - pub model: u16, - #[specta(type = serde_json::Value)] + pub model_id: ModelId, pub record_id: rmpv::Value, pub data: CRDTOperationData, } @@ -73,7 +69,7 @@ impl fmt::Debug for CRDTOperation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("CRDTOperation") .field("data", &self.data) - .field("model", &self.model) + .field("model", &self.model_id) .field("record_id", &self.record_id.to_string()) .finish_non_exhaustive() } diff --git a/crates/sync/src/factory.rs b/crates/sync/src/factory.rs index dd553173c..7c73f8b5f 100644 --- a/crates/sync/src/factory.rs +++ b/crates/sync/src/factory.rs @@ -1,47 +1,38 @@ -use uhlc::HLC; -use uuid::Uuid; - use crate::{ - CRDTOperation, CRDTOperationData, RelationSyncId, RelationSyncModel, SharedSyncModel, SyncId, + CRDTOperation, CRDTOperationData, DevicePubId, RelationSyncId, RelationSyncModel, + SharedSyncModel, SyncId, SyncModel, }; -macro_rules! 
msgpack { - (nil) => { - ::rmpv::Value::Nil - }; - ($e:expr) => {{ - let bytes = rmp_serde::to_vec_named(&$e).expect("failed to serialize msgpack"); - let value: rmpv::Value = rmp_serde::from_slice(&bytes).expect("failed to deserialize msgpack"); - - value - }} -} +use uhlc::HLC; pub trait OperationFactory { fn get_clock(&self) -> &HLC; - fn get_instance(&self) -> Uuid; - fn new_op(&self, id: &TSyncId, data: CRDTOperationData) -> CRDTOperation - where - TSyncId::Model: crate::SyncModel, - { - let timestamp = self.get_clock().new_timestamp(); + fn get_device_pub_id(&self) -> DevicePubId; + fn new_op>( + &self, + id: &SId, + data: CRDTOperationData, + ) -> CRDTOperation { CRDTOperation { - instance: self.get_instance(), - timestamp: *timestamp.get_time(), - model: ::MODEL_ID, - record_id: msgpack!(id), + device_pub_id: self.get_device_pub_id(), + timestamp: *self.get_clock().new_timestamp().get_time(), + model_id: ::MODEL_ID, + record_id: rmp_serde::from_slice::( + &rmp_serde::to_vec_named(id).expect("failed to serialize record id to msgpack"), + ) + .expect("failed to deserialize record id to msgpack value"), data, } } - fn shared_create, TModel: SharedSyncModel>( + fn shared_create( &self, - id: TSyncId, + id: impl SyncId, values: impl IntoIterator + 'static, - ) -> Vec { - vec![self.new_op( + ) -> CRDTOperation { + self.new_op( &id, CRDTOperationData::Create( values @@ -49,35 +40,35 @@ pub trait OperationFactory { .map(|(name, value)| (name.to_string(), value)) .collect(), ), - )] + ) } - fn shared_update, TModel: SharedSyncModel>( + + fn shared_update( &self, - id: TSyncId, - field: impl Into, - value: rmpv::Value, + id: impl SyncId, + values: impl IntoIterator + 'static, ) -> CRDTOperation { self.new_op( &id, - CRDTOperationData::Update { - field: field.into(), - value, - }, + CRDTOperationData::Update( + values + .into_iter() + .map(|(name, value)| (name.to_string(), value)) + .collect(), + ), ) } - fn shared_delete, TModel: SharedSyncModel>( - &self, - id: TSyncId, - ) -> CRDTOperation { + + fn shared_delete(&self, id: impl SyncId) -> CRDTOperation { self.new_op(&id, CRDTOperationData::Delete) } - fn relation_create, TModel: RelationSyncModel>( + fn relation_create( &self, - id: TSyncId, + id: impl RelationSyncId, values: impl IntoIterator + 'static, - ) -> Vec { - vec![self.new_op( + ) -> CRDTOperation { + self.new_op( &id, CRDTOperationData::Create( values @@ -85,25 +76,28 @@ pub trait OperationFactory { .map(|(name, value)| (name.to_string(), value)) .collect(), ), - )] + ) } - fn relation_update, TModel: RelationSyncModel>( + + fn relation_update( &self, - id: TSyncId, - field: impl Into, - value: rmpv::Value, + id: impl RelationSyncId, + values: impl IntoIterator + 'static, ) -> CRDTOperation { self.new_op( &id, - CRDTOperationData::Update { - field: field.into(), - value, - }, + CRDTOperationData::Update( + values + .into_iter() + .map(|(name, value)| (name.to_string(), value)) + .collect(), + ), ) } - fn relation_delete, TModel: RelationSyncModel>( + + fn relation_delete( &self, - id: TSyncId, + id: impl RelationSyncId, ) -> CRDTOperation { self.new_op(&id, CRDTOperationData::Delete) } @@ -111,29 +105,59 @@ pub trait OperationFactory { #[macro_export] macro_rules! 
sync_entry { - ($v:expr, $($m:tt)*) => { - ($($m)*::NAME, ::sd_utils::msgpack!($v)) - } + (nil, $($prisma_column_module:tt)+) => { + ($($prisma_column_module)+::NAME, ::sd_utils::msgpack!(nil)) + }; + + ($value:expr, $($prisma_column_module:tt)+) => { + ($($prisma_column_module)+::NAME, ::sd_utils::msgpack!($value)) + }; + } #[macro_export] macro_rules! option_sync_entry { - ($v:expr, $($m:tt)*) => { - $v.map(|v| $crate::sync_entry!(v, $($m)*)) + ($value:expr, $($prisma_column_module:tt)+) => { + $value.map(|value| $crate::sync_entry!(value, $($prisma_column_module)+)) } } #[macro_export] macro_rules! sync_db_entry { - ($v:expr, $($m:tt)*) => {{ - let v = $v.into(); - ($crate::sync_entry!(&v, $($m)*), $($m)*::set(Some(v))) + ($value:expr, $($prisma_column_module:tt)+) => {{ + let value = $value.into(); + ( + $crate::sync_entry!(&value, $($prisma_column_module)+), + $($prisma_column_module)+::set(Some(value)) + ) + }} +} + +#[macro_export] +macro_rules! sync_db_nullable_entry { + ($value:expr, $($prisma_column_module:tt)+) => {{ + let value = $value.into(); + ( + $crate::sync_entry!(&value, $($prisma_column_module)+), + $($prisma_column_module)+::set(value) + ) + }} +} + +#[macro_export] +macro_rules! sync_db_not_null_entry { + ($value:expr, $($prisma_column_module:tt)+) => {{ + let value = $value.into(); + ( + $crate::sync_entry!(&value, $($prisma_column_module)+), + $($prisma_column_module)+::set(value) + ) }} } #[macro_export] macro_rules! option_sync_db_entry { - ($v:expr, $($m:tt)*) => { - $v.map(|v| $crate::sync_db_entry!(v, $($m)*)) + ($value:expr, $($prisma_column_module:tt)+) => { + $value.map(|value| $crate::sync_db_entry!(value, $($prisma_column_module)+)) }; } diff --git a/crates/sync/src/lib.rs b/crates/sync/src/lib.rs index 3d5eac56f..239a1298d 100644 --- a/crates/sync/src/lib.rs +++ b/crates/sync/src/lib.rs @@ -38,3 +38,7 @@ pub use factory::*; pub use model_traits::*; pub use uhlc::NTP64; + +pub type DevicePubId = uuid::Uuid; +pub type ModelId = u16; +pub type RecordId = rmpv::Value; diff --git a/crates/sync/src/model_traits.rs b/crates/sync/src/model_traits.rs index b0a063f2e..48a4efacd 100644 --- a/crates/sync/src/model_traits.rs +++ b/crates/sync/src/model_traits.rs @@ -1,3 +1,5 @@ +use crate::ModelId; + use prisma_client_rust::ModelTypes; use serde::{de::DeserializeOwned, Serialize}; @@ -6,7 +8,7 @@ pub trait SyncId: Serialize + DeserializeOwned { } pub trait SyncModel: ModelTypes { - const MODEL_ID: u16; + const MODEL_ID: ModelId; } pub trait SharedSyncModel: SyncModel { diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index c535b0c0f..2960a21c4 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -2,14 +2,22 @@ name = "sd-utils" version = "0.1.0" -edition = "2021" +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true [dependencies] # Spacedrive Sub-crates sd-prisma = { path = "../prisma" } # Workspace dependencies +chrono = { workspace = true } prisma-client-rust = { workspace = true } +rmp-serde = { workspace = true } +rmpv = { workspace = true } rspc = { workspace = true, features = ["unstable"] } thiserror = { workspace = true } +tracing = { workspace = true } +uhlc = { workspace = true } uuid = { workspace = true } diff --git a/crates/utils/src/error.rs b/crates/utils/src/error.rs index 081b61be8..6bd3c08f6 100644 --- a/crates/utils/src/error.rs +++ b/crates/utils/src/error.rs @@ -1,6 +1,16 @@ use std::{io, path::Path}; use thiserror::Error; +use tracing::error; + +pub fn 
report_error<E: std::fmt::Debug>(
+	message: &'static str,
+) -> impl Fn(E) -> E {
+	move |e| {
+		error!(?e, "{message}");
+		e
+	}
+}
 
 #[derive(Debug, Error)]
 #[error("error accessing path: '{}'", .path.display())]
diff --git a/crates/utils/src/lib.rs b/crates/utils/src/lib.rs
index da92b7b79..09f45d609 100644
--- a/crates/utils/src/lib.rs
+++ b/crates/utils/src/lib.rs
@@ -27,6 +27,10 @@
 #![forbid(deprecated_in_future)]
 #![allow(clippy::missing_errors_doc, clippy::module_name_repetitions)]
 
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use chrono::{DateTime, Utc};
+use uhlc::NTP64;
 use uuid::Uuid;
 
 pub mod db;
@@ -104,6 +108,23 @@ macro_rules! msgpack {
 	}}
 }
 
+/// Helper function to convert a [`chrono::DateTime<Utc>`] to a [`uhlc::NTP64`]
+#[allow(clippy::missing_panics_doc)] // Doesn't actually panic
+#[must_use]
+pub fn datetime_to_timestamp(latest_time: DateTime<Utc>) -> NTP64 {
+	NTP64::from(
+		SystemTime::from(latest_time)
+			.duration_since(UNIX_EPOCH)
+			.expect("hardcoded earlier time, nothing is earlier than UNIX_EPOCH"),
+	)
+}
+
+/// Helper function to convert a [`uhlc::NTP64`] to a [`chrono::DateTime<Utc>`]
+#[must_use]
+pub fn timestamp_to_datetime(timestamp: NTP64) -> DateTime<Utc> {
+	DateTime::from(timestamp.to_system_time())
+}
+
 // Only used for testing purposes. Do not use in production code.
 use std::any::type_name;
diff --git a/interface/app/$libraryId/Explorer/util.ts b/interface/app/$libraryId/Explorer/util.ts
index bdae303e7..8a0e847ec 100644
--- a/interface/app/$libraryId/Explorer/util.ts
+++ b/interface/app/$libraryId/Explorer/util.ts
@@ -190,3 +190,12 @@ export function translateKindName(kindName: string): string {
 		return kindName;
 	}
 }
+
+export function fetchAccessToken(): string {
+	const accessToken: string =
+		JSON.parse(window.localStorage.getItem('frontendCookies') ?? '[]')
+			.find((cookie: string) => cookie.startsWith('st-access-token'))
+			?.split('=')[1]
+			.split(';')[0] || '';
+	return accessToken;
+}
diff --git a/interface/app/$libraryId/Layout/Sidebar/DebugPopover.tsx b/interface/app/$libraryId/Layout/Sidebar/DebugPopover.tsx
index f035c5996..5fd40631f 100644
--- a/interface/app/$libraryId/Layout/Sidebar/DebugPopover.tsx
+++ b/interface/app/$libraryId/Layout/Sidebar/DebugPopover.tsx
@@ -62,9 +62,9 @@ export default () => {
 			}
 		>
- + {/* - + */} { > - + {/* */} {/* */} - - +
+ + +
{/* {platform.showDevtools && ( - - Feature Flags - - } - className="z-[999] mt-1 shadow-none data-[side=bottom]:slide-in-from-top-2 dark:divide-menu-selected/30 dark:border-sidebar-line dark:bg-sidebar-box" - alignToTrigger - > - {[...features, ...backendFeatures].map((feat) => ( - toggleFeatureFlag(feat)} - className="font-medium text-white" - icon={ - featureFlags.find((f) => feat === f) !== undefined - ? CheckSquare - : undefined - } - /> - ))} - - - ); -} - // function TestNotifications() { // const coreNotif = useBridgeMutation(['notifications.test']); // const libraryNotif = useLibraryMutation(['notifications.testLibrary']); @@ -260,33 +229,33 @@ function FeatureFlagSelector() { // ); // } -function CloudOriginSelect() { - const origin = useBridgeQuery(['cloud.getApiOrigin']); - const setOrigin = useBridgeMutation(['cloud.setApiOrigin']); +// function CloudOriginSelect() { +// const origin = useBridgeQuery(['cloud.getApiOrigin']); +// const setOrigin = useBridgeMutation(['cloud.setApiOrigin']); - const queryClient = useQueryClient(); +// const queryClient = useQueryClient(); - return ( - <> - {origin.data && ( - - )} - - ); -} +// return ( +// <> +// {origin.data && ( +// +// )} +// +// ); +// } function ExplorerBehaviorSelect() { const { explorerOperatingSystem } = useExplorerOperatingSystem(); diff --git a/interface/app/$libraryId/Layout/Sidebar/SidebarLayout/Footer.tsx b/interface/app/$libraryId/Layout/Sidebar/SidebarLayout/Footer.tsx index b8eb0bffc..6876bee9d 100644 --- a/interface/app/$libraryId/Layout/Sidebar/SidebarLayout/Footer.tsx +++ b/interface/app/$libraryId/Layout/Sidebar/SidebarLayout/Footer.tsx @@ -9,7 +9,7 @@ import { useDebugState, useLibrarySubscription } from '@sd/client'; -import { Button, ButtonLink, Tooltip } from '@sd/ui'; +import { Button, ButtonLink, Loader, Tooltip } from '@sd/ui'; import { useKeysMatcher, useLocale, useShortcut } from '~/hooks'; import { usePlatform } from '~/util/Platform'; @@ -80,5 +80,12 @@ function SyncStatusIndicator() { onData: setStatus }); - return null; + return ( +
+			{status?.cloud_ingest && <Loader />}
+			{status?.cloud_send && <Loader />}
+			{status?.cloud_receive && <Loader />}
+			{status?.ingest && <Loader />}
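+			{/* one Loader per active sync direction; nothing renders while idle */}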
+ ); } diff --git a/interface/app/$libraryId/Layout/Sidebar/SidebarLayout/LibrariesDropdown.tsx b/interface/app/$libraryId/Layout/Sidebar/SidebarLayout/LibrariesDropdown.tsx index 034be0f2f..44515a41c 100644 --- a/interface/app/$libraryId/Layout/Sidebar/SidebarLayout/LibrariesDropdown.tsx +++ b/interface/app/$libraryId/Layout/Sidebar/SidebarLayout/LibrariesDropdown.tsx @@ -1,8 +1,7 @@ -import { CloudArrowDown, Gear, Lock, Plus } from '@phosphor-icons/react'; +import { Gear, Plus } from '@phosphor-icons/react'; import clsx from 'clsx'; import { useClientContext } from '@sd/client'; import { dialogManager, Dropdown, DropdownMenu } from '@sd/ui'; -import JoinDialog from '~/app/$libraryId/settings/node/libraries/JoinDialog'; import { useLocale } from '~/hooks'; import CreateDialog from '../../../settings/node/libraries/CreateDialog'; @@ -63,17 +62,6 @@ export default () => { onClick={() => dialogManager.create((dp) => )} className="font-medium" /> - - dialogManager.create((dp) => ( - - )) - } - className="font-medium" - /> { to="settings/library/general" className="font-medium" /> - {/* alert('TODO: Not implemented yet!')} - className="font-medium" - /> */} ); }; diff --git a/interface/app/$libraryId/Layout/Sidebar/sections/Devices/index.tsx b/interface/app/$libraryId/Layout/Sidebar/sections/Devices/index.tsx index eb9084044..a4e9ccb09 100644 --- a/interface/app/$libraryId/Layout/Sidebar/sections/Devices/index.tsx +++ b/interface/app/$libraryId/Layout/Sidebar/sections/Devices/index.tsx @@ -15,7 +15,11 @@ export default function DevicesSection() { return (
{node && ( - + {node.device_model ? ( { + const navigate = useNavigate(); + const [query] = useSearchParams(); + const { hash } = useLocation(); + + useEffect(() => { + (window.location as any).__TEMP_URL_PARAMS = query; + (window.location as any).__TEMP_URL_HASH = hash; + handleMagicLinkClicked(navigate); + }, []); + + return <>; +}; diff --git a/interface/app/$libraryId/Layout/index.tsx b/interface/app/$libraryId/Layout/index.tsx index b7c11da4f..12044df66 100644 --- a/interface/app/$libraryId/Layout/index.tsx +++ b/interface/app/$libraryId/Layout/index.tsx @@ -15,6 +15,7 @@ import { useRootContext } from '~/app/RootContext'; import { LibraryIdParamsSchema } from '~/app/route-schemas'; import ErrorFallback, { BetterErrorBoundary } from '~/ErrorFallback'; import { + useDeeplinkEventHandler, useKeybindEventHandler, useOperatingSystem, useRedirectToNewLocation, @@ -40,6 +41,7 @@ const Layout = () => { const windowState = useWindowState(); useKeybindEventHandler(library?.uuid); + useDeeplinkEventHandler(); const layoutRef = useRef(null); diff --git a/interface/app/$libraryId/debug/actors.tsx b/interface/app/$libraryId/debug/actors.tsx deleted file mode 100644 index cb14658c1..000000000 --- a/interface/app/$libraryId/debug/actors.tsx +++ /dev/null @@ -1,69 +0,0 @@ -import { inferSubscriptionResult } from '@spacedrive/rspc-client'; -import { useMemo, useState } from 'react'; -import { Procedures, useLibraryMutation, useLibrarySubscription } from '@sd/client'; -import { Button } from '@sd/ui'; -import { useRouteTitle } from '~/hooks/useRouteTitle'; - -// million-ignore -export const Component = () => { - useRouteTitle('Actors'); - - const [data, setData] = useState>({}); - - useLibrarySubscription(['library.actors'], { onData: setData }); - - const sortedData = useMemo(() => { - const sorted = Object.entries(data).sort(([a], [b]) => a.localeCompare(b)); - return sorted; - }, [data]); - - return ( -
- - - - - - {sortedData.map(([name, running]) => ( - - - - - - ))} -
NameRunning
{name} - {running ? 'Running' : 'Not Running'} - - {running ? : } -
-
- ); -}; - -function StartButton({ name }: { name: string }) { - const startActor = useLibraryMutation(['library.startActor']); - - return ( - - ); -} - -function StopButton({ name }: { name: string }) { - const stopActor = useLibraryMutation(['library.stopActor']); - - return ( - - ); -} diff --git a/interface/app/$libraryId/debug/cloud.tsx b/interface/app/$libraryId/debug/cloud.tsx deleted file mode 100644 index fcec4339f..000000000 --- a/interface/app/$libraryId/debug/cloud.tsx +++ /dev/null @@ -1,249 +0,0 @@ -import { CheckCircle, XCircle } from '@phosphor-icons/react'; -import { Suspense, useMemo } from 'react'; -import { - auth, - CloudInstance, - CloudLibrary, - HardwareModel, - useLibraryContext, - useLibraryMutation, - useLibraryQuery -} from '@sd/client'; -import { Button, Card, Loader, tw } from '@sd/ui'; -import { Icon } from '~/components'; -import { AuthRequiredOverlay } from '~/components/AuthRequiredOverlay'; -import { LoginButton } from '~/components/LoginButton'; -import { useLocale, useRouteTitle } from '~/hooks'; -import { hardwareModelToIcon } from '~/util/hardware'; - -const DataBox = tw.div`max-w-[300px] rounded-md border border-app-line/50 bg-app-lightBox/20 p-2`; -const Count = tw.div`min-w-[20px] flex h-[20px] px-1 items-center justify-center rounded-full border border-app-button/40 text-[9px]`; - -export const Component = () => { - useRouteTitle('Cloud'); - - const authState = auth.useStateSnapshot(); - - const authSensitiveChild = () => { - if (authState.status === 'loggedIn') return ; - if (authState.status === 'notLoggedIn' || authState.status === 'loggingIn') - return ( -
- -
- -

- To access cloud related features, please login -

-
- -
-
- ); - - return null; - }; - - return
{authSensitiveChild()}
; -}; - -// million-ignore -function Authenticated() { - const { library } = useLibraryContext(); - const cloudLibrary = useLibraryQuery(['cloud.library.get'], { suspense: true, retry: false }); - const createLibrary = useLibraryMutation(['cloud.library.create']); - const { t } = useLocale(); - - const thisInstance = useMemo(() => { - if (!cloudLibrary.data) return undefined; - return cloudLibrary.data.instances.find( - (instance) => instance.uuid === library.instance_id - ); - }, [cloudLibrary.data, library.instance_id]); - - return ( - - -
- } - > - {cloudLibrary.data ? ( -
- - {thisInstance && } - -
- ) : ( -
- - -
- -

- {t('cloud_connect_description')} -

-
- -
-
- )} - - ); -} - -// million-ignore -const Instances = ({ instances }: { instances: CloudInstance[] }) => { - const { library } = useLibraryContext(); - const filteredInstances = instances.filter((instance) => instance.uuid !== library.instance_id); - return ( -
-
-

Instances

- {filteredInstances.length} -
-
- {filteredInstances.map((instance) => ( - -
- -

- {instance.metadata.name} -

-
-
- -

- Id:{' '} - {instance.id} -

-
- -

- UUID:{' '} - - {instance.uuid} - -

-
- -

- Public Key:{' '} - - {instance.identity} - -

-
-
-
- ))} -
-
- ); -}; - -interface LibraryProps { - cloudLibrary: CloudLibrary; - thisInstance: CloudInstance | undefined; -} - -// million-ignore -const Library = ({ thisInstance, cloudLibrary }: LibraryProps) => { - const syncLibrary = useLibraryMutation(['cloud.library.sync']); - return ( -
-

Library

- -

- Name: {cloudLibrary.name} -

- -
-
- ); -}; - -interface ThisInstanceProps { - instance: CloudInstance; -} - -// million-ignore -const ThisInstance = ({ instance }: ThisInstanceProps) => { - return ( -
-

This Instance

- -
- -

- {instance.metadata.name} -

-
-
- -

- Id: {instance.id} -

-
- -

- UUID: {instance.uuid} -

-
- -

- Public Key:{' '} - {instance.identity} -

-
-
-
-
- ); -}; diff --git a/interface/app/$libraryId/debug/index.ts b/interface/app/$libraryId/debug/index.ts deleted file mode 100644 index 4cf60b56c..000000000 --- a/interface/app/$libraryId/debug/index.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { RouteObject } from 'react-router'; - -export const debugRoutes = [ - { path: 'cloud', lazy: () => import('./cloud') }, - { path: 'actors', lazy: () => import('./actors') } -] satisfies RouteObject[]; diff --git a/interface/app/$libraryId/index.tsx b/interface/app/$libraryId/index.tsx index e1e9a8a6f..6ff9f4b3e 100644 --- a/interface/app/$libraryId/index.tsx +++ b/interface/app/$libraryId/index.tsx @@ -3,16 +3,12 @@ import { type RouteObject } from 'react-router-dom'; import { guessOperatingSystem } from '~/hooks'; import { Platform } from '~/util/Platform'; -import { debugRoutes } from './debug'; import settingsRoutes from './settings'; // Routes that should be contained within the standard Page layout const pageRoutes: RouteObject = { lazy: () => import('./PageLayout'), - children: [ - { path: 'overview', lazy: () => import('./overview') }, - { path: 'debug', children: debugRoutes } - ] + children: [{ path: 'overview', lazy: () => import('./overview') }] }; // Routes that render the explorer and don't need padding and stuff @@ -37,8 +33,7 @@ function loadTopBarRoutes() { return [ ...explorerRoutes, pageRoutes, - { path: 'settings', lazy: () => import('./settings/Layout'), children: settingsRoutes }, - { path: 'debug', children: debugRoutes } + { path: 'settings', lazy: () => import('./settings/Layout'), children: settingsRoutes } ]; } else return [...explorerRoutes, pageRoutes]; } @@ -77,5 +72,10 @@ export default (platform: Platform) => lazy: () => import('./settings/Layout'), children: settingsRoutes }, + { + path: 'auth', + lazy: () => import('./Layout/auth'), + children: [] + }, { path: '*', lazy: () => import('./404') } ] satisfies RouteObject[]; diff --git a/interface/app/$libraryId/overview/index.tsx b/interface/app/$libraryId/overview/index.tsx index 32250a67f..2f4c6143d 100644 --- a/interface/app/$libraryId/overview/index.tsx +++ b/interface/app/$libraryId/overview/index.tsx @@ -1,7 +1,8 @@ import { keepPreviousData } from '@tanstack/react-query'; +import { Key, useEffect } from 'react'; import { Link } from 'react-router-dom'; -import { useBridgeQuery, useLibraryQuery } from '@sd/client'; -import { useLocale, useOperatingSystem } from '~/hooks'; +import { HardwareModel, useBridgeQuery, useLibraryQuery } from '@sd/client'; +import { useAccessToken, useLocale, useOperatingSystem } from '~/hooks'; import { useRouteTitle } from '~/hooks/useRouteTitle'; import { hardwareModelToIcon } from '~/util/hardware'; @@ -28,18 +29,28 @@ export const Component = () => { const os = useOperatingSystem(); const { t } = useLocale(); + const accessToken = useAccessToken(); const locationsQuery = useLibraryQuery(['locations.list'], { placeholderData: keepPreviousData }); const locations = locationsQuery.data ?? 
[]; + // not sure if we'll need the node state in the future, as it should be returned with the cloud.devices.list query + // const { data: node } = useBridgeQuery(['nodeState']); + const cloudDevicesList = useBridgeQuery(['cloud.devices.list']); + + useEffect(() => { + const interval = setInterval(async () => { + await cloudDevicesList.refetch(); + }, 10000); + return () => clearInterval(interval); + }, []); const { data: node } = useBridgeQuery(['nodeState']); + const stats = useLibraryQuery(['library.statistics']); const search = useSearchFromSearchParams({ defaultTarget: 'paths' }); - const stats = useLibraryQuery(['library.statistics']); - return (
@@ -60,7 +71,10 @@ export const Component = () => { - + {node && ( { connectionType={null} /> )} - - {/**/} + {cloudDevicesList.data?.map((device) => ( + + ))} diff --git a/interface/app/$libraryId/saved-search/$id.tsx b/interface/app/$libraryId/saved-search/$id.tsx index 4562a96d1..d5f57c5b7 100644 --- a/interface/app/$libraryId/saved-search/$id.tsx +++ b/interface/app/$libraryId/saved-search/$id.tsx @@ -127,7 +127,6 @@ function SaveButton({ searchId }: { searchId: number }) { const updateSavedSearch = useLibraryMutation(['search.saved.update']); const search = useSearchContext(); - return (

- {t('logged_in_as', { email: me.data?.email })} + {t('logged_in_as', 'TODO')}
); } diff --git a/interface/app/$libraryId/settings/client/account.tsx b/interface/app/$libraryId/settings/client/account.tsx deleted file mode 100644 index 151a1b9b1..000000000 --- a/interface/app/$libraryId/settings/client/account.tsx +++ /dev/null @@ -1,160 +0,0 @@ -import { Envelope, User } from '@phosphor-icons/react'; -import { useEffect, useState } from 'react'; -import { auth, useBridgeMutation, useBridgeQuery, useFeatureFlag } from '@sd/client'; -import { Button, Card, Input, toast } from '@sd/ui'; -import { TruncatedText } from '~/components'; -import { AuthRequiredOverlay } from '~/components/AuthRequiredOverlay'; -import { useLocale } from '~/hooks'; - -import { Heading } from '../Layout'; - -export const Component = () => { - const { t } = useLocale(); - const me = useBridgeQuery(['auth.me'], { retry: false }); - const authStore = auth.useStateSnapshot(); - return ( - <> - - {authStore.status === 'loggedIn' && ( -
- -
- )} - - } - title={t('spacedrive_cloud')} - description={t('spacedrive_cloud_description')} - /> -
- -
- {useFeatureFlag('hostedLocations') && } - - ); -}; - -const Profile = ({ email, authStore }: { email?: string; authStore: { status: string } }) => { - const emailName = authStore.status === 'loggedIn' ? email?.split('@')[0] : 'guest user'; - return ( - - -
- -
-

- Welcome {emailName}, -

-
- -
- -
- - {authStore.status === 'loggedIn' ? email : 'guestuser@outlook.com'} - -
-
-
- ); -}; - -function HostedLocationsPlayground() { - const locations = useBridgeQuery(['cloud.locations.list'], { retry: false }); - - const [locationName, setLocationName] = useState(''); - const [path, setPath] = useState(''); - const createLocation = useBridgeMutation('cloud.locations.create', { - onSuccess(data) { - // console.log('DATA', data); // TODO: Optimistic UI - - locations.refetch(); - setLocationName(''); - } - }); - const removeLocation = useBridgeMutation('cloud.locations.remove', { - onSuccess() { - // TODO: Optimistic UI - - locations.refetch(); - } - }); - - useEffect(() => { - if (path === '' && locations.data?.[0]) { - setPath(`location/${locations.data[0].id}/hello.txt`); - } - }, [path, locations.data]); - - const isPending = createLocation.isPending || removeLocation.isPending; - - return ( - <> - - {/* TODO: We need UI for this. I wish I could use `prompt` for now but Tauri doesn't have it :( */} -
- setLocationName(e.currentTarget.value)} - placeholder="My sick location" - disabled={isPending} - /> - - -
- - } - title="Hosted Locations" - description="Augment your local storage with our cloud!" - /> - - {/* TODO: Cleanup this mess + styles */} - {locations.status === 'pending' ?
Loading!
: null} - {locations.status !== 'pending' && locations.data?.length === 0 ? ( -
Looks like you don't have any!
- ) : ( -
- {locations.data?.map((location) => ( -
-

{location.name}

- -
- ))} -
- )} - -
-

Path to save when clicking 'Do the thing':

- setPath(e.currentTarget.value)} - disabled={isPending} - /> -
- - ); -} diff --git a/interface/app/$libraryId/settings/client/account/Profile.tsx b/interface/app/$libraryId/settings/client/account/Profile.tsx new file mode 100644 index 000000000..250f3b930 --- /dev/null +++ b/interface/app/$libraryId/settings/client/account/Profile.tsx @@ -0,0 +1,191 @@ +import { Envelope } from '@phosphor-icons/react'; +import clsx from 'clsx'; +import { Dispatch, SetStateAction, useEffect, useState } from 'react'; +import { + SyncStatus, + useBridgeMutation, + useBridgeQuery, + useBridgeSubscription, + useLibraryMutation, + useLibrarySubscription +} from '@sd/client'; +import { Button, Card, tw } from '@sd/ui'; +import StatCard from '~/app/$libraryId/overview/StatCard'; +import { TruncatedText } from '~/components'; +import { getTokens } from '~/util'; +import { hardwareModelToIcon } from '~/util/hardware'; + +type User = { + email: string; + id: string; + timejoined: number; + roles: string[]; +}; + +const Pill = tw.div`px-1.5 py-[1px] rounded text-tiny font-medium text-ink-dull bg-app-box border border-app-line`; + +const Profile = ({ + user, + setReload +}: { + user: User; + setReload: Dispatch>; +}) => { + const emailName = user.email?.split('@')[0]; + const capitalizedEmailName = (emailName?.charAt(0).toUpperCase() ?? '') + emailName?.slice(1); + const { accessToken, refreshToken } = getTokens(); + + const cloudBootstrap = useBridgeMutation('cloud.bootstrap'); + const devices = useBridgeQuery(['cloud.devices.list']); + const addLibraryToCloud = useLibraryMutation('cloud.libraries.create'); + const [syncStatus, setSyncStatus] = useState(null); + useLibrarySubscription(['sync.active'], { + onData: (data) => { + console.log('sync activity', data); + setSyncStatus(data); + } + }); + const listLibraries = useBridgeQuery(['cloud.libraries.list', true]); + const createSyncGroup = useLibraryMutation('cloud.syncGroups.create'); + const listSyncGroups = useBridgeQuery(['cloud.syncGroups.list']); + const requestJoinSyncGroup = useBridgeMutation('cloud.syncGroups.request_join'); + const currentDevice = useBridgeQuery(['cloud.devices.get_current_device']); + + // Refetch libraries and devices every 10 seconds + useEffect(() => { + const interval = setInterval(async () => { + await devices.refetch(); + await listLibraries.refetch(); + }, 10000); + return () => clearInterval(interval); + }, [devices, listLibraries]); + + return ( +
+ {/* Top Section with Profile Information */} +
+ +
+

Profile Information

+
+ + {user.email} +
+
+
+
+

Joined on

+

+ {new Date(user.timejoined).toLocaleDateString()} +

+
+
+

User ID

+

{user.id}

+
+
+
+
+ + {/* Sync activity */} +
+

Sync Activity

+
+ {Object.keys(syncStatus ?? {}).map((status, index) => ( + +
+

{status}

+ + ))} +
+
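Aside (assumption, not confirmed by this diff): the payload shape of the `sync.active` subscription is not visible here; the rendering above only needs something whose keys can be iterated, which suggests a record of actor names mapped to activity flags, roughly:

// Assumed shape of the `SyncStatus` payload driving the pills above.
type AssumedSyncStatus = Record<string, boolean>;

const sample: AssumedSyncStatus = { Ingest: true, CloudSend: false };
for (const [actor, active] of Object.entries(sample)) {
	console.log(`${actor}: ${active ? 'active' : 'idle'}`);
}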
+ + {/* Automatically list libraries */} +
+

Cloud Libraries

+ {listLibraries.data?.map((library) => ( + +

{library.name}

+
+ )) ||

No libraries found.

} +
+ + {/* Debug Buttons */} +
+ + + +
+ + {/* Automatically list sync groups and provide a join button */} +
+

Library Sync Groups

+ {listSyncGroups.data?.map((group) => ( + +

{group.library.name}

+ +
+ )) ||

No sync groups found.

} +
+ + {/* List all devices from const devices */} + {devices.data?.map((device) => ( + + ))} +
+ ); +}; + +export default Profile; diff --git a/interface/app/$libraryId/settings/client/account/ShowPassword.tsx b/interface/app/$libraryId/settings/client/account/ShowPassword.tsx new file mode 100644 index 000000000..d3e846da0 --- /dev/null +++ b/interface/app/$libraryId/settings/client/account/ShowPassword.tsx @@ -0,0 +1,27 @@ +import { Eye, EyeClosed } from '@phosphor-icons/react'; +import { Button, Tooltip } from '@sd/ui'; + +interface Props { + showPassword: boolean; + setShowPassword: (value: boolean) => void; +} + +const ShowPassword = ({ showPassword, setShowPassword }: Props) => { + return ( + + + + ); +}; + +export default ShowPassword; diff --git a/interface/app/$libraryId/settings/client/account/handlers/cookieHandler.ts b/interface/app/$libraryId/settings/client/account/handlers/cookieHandler.ts new file mode 100644 index 000000000..4a7362d8e --- /dev/null +++ b/interface/app/$libraryId/settings/client/account/handlers/cookieHandler.ts @@ -0,0 +1,110 @@ +import { CookieHandlerInterface } from 'supertokens-website/utils/cookieHandler/types'; + +const frontendCookiesKey = 'frontendCookies'; + +/** + * Tauri handles cookies differently than in browser environments. The SuperTokens + * SDK uses frontend cookies, to make sure everything work correctly we add custom + * cookie handling and store cookies in local storage instead (This is not a problem + * since these are frontend cookies and not server side cookies) + */ +function getCookiesFromStorage(): string { + const cookiesFromStorage = window.localStorage.getItem(frontendCookiesKey); + + if (cookiesFromStorage === null) { + window.localStorage.setItem(frontendCookiesKey, '[]'); + return ''; + } + + /** + * Because we store cookies in local storage, we need to manually check + * for expiry before returning all cookies + */ + const cookieArrayInStorage: string[] = JSON.parse(cookiesFromStorage); + const cookieArrayToReturn: string[] = []; + + for (let cookieIndex = 0; cookieIndex < cookieArrayInStorage.length; cookieIndex++) { + const currentCookieString = cookieArrayInStorage[cookieIndex]; + const parts = currentCookieString?.split(';'); + let expirationString: string = ''; + + for (let partIndex = 0; partIndex < parts!.length; partIndex++) { + const currentPart = parts![partIndex]; + + if (currentPart!.toLocaleLowerCase().includes('expires=')) { + expirationString = currentPart!; + break; + } + } + + if (expirationString !== '') { + const expirationValueString = expirationString.split('=')[1]; + const expirationDate = new Date(expirationValueString!); + const currentTimeInMillis = Date.now(); + + // if the cookie has expired, we skip it + if (expirationDate.getTime() < currentTimeInMillis) { + continue; + } + } + + cookieArrayToReturn.push(currentCookieString!); + } + + /** + * After processing and removing expired cookies we need to update the cookies + * in storage so we dont have to process the expired ones again + */ + window.localStorage.setItem(frontendCookiesKey, JSON.stringify(cookieArrayToReturn)); + + return cookieArrayToReturn.join('; '); +} + +function setCookieToStorage(cookieString: string) { + const cookieName = cookieString.split(';')[0]!.split('=')[0]; + const cookiesFromStorage = window.localStorage.getItem(frontendCookiesKey); + let cookiesArray: string[] = []; + + if (cookiesFromStorage !== null) { + const cookiesArrayFromStorage: string[] = JSON.parse(cookiesFromStorage); + cookiesArray = cookiesArrayFromStorage; + } + + let cookieIndex = -1; + + for (let i = 0; i < cookiesArray.length; i++) { + const 
currentCookie = cookiesArray[i]; + + if (currentCookie!.indexOf(`${cookieName}=`) !== -1) { + cookieIndex = i; + break; + } + } + + /** + * If a cookie with the same name already exists (index != -1) then we + * need to remove the old value and replace it with the new one. + * + * If it does not exist then simply add the new cookie + */ + if (cookieIndex !== -1) { + cookiesArray[cookieIndex] = cookieString; + } else { + cookiesArray.push(cookieString); + } + + window.localStorage.setItem(frontendCookiesKey, JSON.stringify(cookiesArray)); +} + +export default function getCookieHandler(original: CookieHandlerInterface): CookieHandlerInterface { + return { + ...original, + getCookie: async function () { + const cookies = getCookiesFromStorage(); + return cookies; + }, + setCookie: async function (cookieString: string) { + setCookieToStorage(cookieString); + } + }; +} diff --git a/interface/app/$libraryId/settings/client/account/handlers/windowHandler.ts b/interface/app/$libraryId/settings/client/account/handlers/windowHandler.ts new file mode 100644 index 000000000..90a75c675 --- /dev/null +++ b/interface/app/$libraryId/settings/client/account/handlers/windowHandler.ts @@ -0,0 +1,55 @@ +import { WindowHandlerInterface } from 'supertokens-website/utils/windowHandler/types'; + +/** + * This example app uses HashRouter from react-router-dom. The SuperTokens SDK relies on + * some window properties like location hash, query params etc. Because HashRouter places + * everything other than the website base in the location hash, we need to add custom + * handling for some of the properties of the Window API + */ +export default function getWindowHandler(original: WindowHandlerInterface): WindowHandlerInterface { + return { + ...original, + location: { + ...original.location, + getSearch: function () { + const params: URLSearchParams | string = + (window.location as any).__TEMP_URL_PARAMS ?? ''; + console.log('params', params); + return params.toString(); + }, + getHash: function () { + // Location hash always starts with a #, when returning we prepend it + const locationHash: string = (window.location as any).__TEMP_URL_HASH ?? ''; + console.log('locationHash', locationHash); + return locationHash; + }, + getOrigin: function () { + return 'http://localhost:8001'; + }, + getHostName: function () { + return 'localhost'; + }, + getPathName: function () { + let locationHash = window.location.hash; + + if (locationHash === '') { + return ''; + } + + if (locationHash.startsWith('#')) { + // Remove the starting pound symbol + locationHash = locationHash.substring(1); + } + + locationHash = locationHash.split('?')[0] ?? ''; + + if (locationHash.includes('#')) { + // Remove location hash + locationHash = locationHash.split('#')[0] ?? 
''; + } + + return locationHash; + } + } + }; +} diff --git a/interface/app/$libraryId/settings/client/account/index.tsx b/interface/app/$libraryId/settings/client/account/index.tsx new file mode 100644 index 000000000..7037536dc --- /dev/null +++ b/interface/app/$libraryId/settings/client/account/index.tsx @@ -0,0 +1,179 @@ +import clsx from 'clsx'; +import React, { useEffect, useState } from 'react'; +import { signOut } from 'supertokens-web-js/recipe/passwordless'; +import { useBridgeMutation } from '@sd/client'; +import { Button } from '@sd/ui'; +import { Authentication } from '~/components'; +import { useLocale } from '~/hooks'; +import { AUTH_SERVER_URL } from '~/util'; + +import { Heading } from '../../Layout'; +import Profile from './Profile'; + +type User = { + email: string; + id: string; + timejoined: number; + roles: string[]; +}; + +export const Component = () => { + const { t } = useLocale(); + const [userInfo, setUserInfo] = useState(null); + const [reload, setReload] = useState(false); + + useEffect(() => { + async function _() { + const user_data = await fetch(`${AUTH_SERVER_URL}/api/user`, { + method: 'GET' + }); + + const data = await user_data.json(); + + setUserInfo(data.id ? data : null); + } + _(); + setReload(false); + }, [reload]); + + const cloudBootstrap = useBridgeMutation('cloud.bootstrap'); + + return ( + <> + + {userInfo?.id && ( +
+ +
+ )} + + } + /> +
+
+ {userInfo === null ? ( + <> + + + ) : ( + <> + + + )} +
+
+ {/* {useFeatureFlag('hostedLocations') && } */} + + ); +}; + +// Not supporting this feature for now +// function HostedLocationsPlayground() { +// const locations = useBridgeQuery(['cloud.locations.list'], { retry: false }); + +// const [locationName, setLocationName] = useState(''); +// const [path, setPath] = useState(''); +// const createLocation = useBridgeMutation('cloud.locations.create', { +// onSuccess(data) { +// // console.log('DATA', data); // TODO: Optimistic UI + +// locations.refetch(); +// setLocationName(''); +// } +// }); +// const removeLocation = useBridgeMutation('cloud.locations.remove', { +// onSuccess() { +// // TODO: Optimistic UI + +// locations.refetch(); +// } +// }); + +// useEffect(() => { +// if (path === '' && locations.data?.[0]) { +// setPath(`location/${locations.data[0].id}/hello.txt`); +// } +// }, [path, locations.data]); + +// const isLoading = createLocation.isLoading || removeLocation.isLoading; + +// return ( +// <> +// +// {/* TODO: We need UI for this. I wish I could use `prompt` for now but Tauri doesn't have it :( */} +//
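Aside (sketch, not part of this patch): the custom cookie and window handlers added above only take effect if they are handed to the SDK at startup. That init call is not in this diff, but with supertokens-web-js it would look roughly like the following; the base path and recipe list are assumptions inferred from the recipes used elsewhere in the patch:

import SuperTokens from 'supertokens-web-js';
import EmailPassword from 'supertokens-web-js/recipe/emailpassword';
import Passwordless from 'supertokens-web-js/recipe/passwordless';
import Session from 'supertokens-web-js/recipe/session';
import ThirdParty from 'supertokens-web-js/recipe/thirdparty';
import { AUTH_SERVER_URL } from '~/util';

import getCookieHandler from './handlers/cookieHandler';
import getWindowHandler from './handlers/windowHandler';

SuperTokens.init({
	appInfo: {
		apiDomain: AUTH_SERVER_URL,
		apiBasePath: '/api/auth', // assumption
		appName: 'Spacedrive'
	},
	// Route the SDK's frontend-cookie reads/writes through localStorage,
	// since cookies behave differently under Tauri (see cookieHandler.ts).
	cookieHandler: getCookieHandler,
	// Compensate for HashRouter keeping the route in location.hash (see windowHandler.ts).
	windowHandler: getWindowHandler,
	recipeList: [Session.init(), EmailPassword.init(), Passwordless.init(), ThirdParty.init()]
});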
+// setLocationName(e.currentTarget.value)} +// placeholder="My sick location" +// disabled={isLoading} +// /> + +// +//
+//
+// } +// title="Hosted Locations" +// description="Augment your local storage with our cloud!" +// /> + +// {/* TODO: Cleanup this mess + styles */} +// {locations.status === 'loading' ?
Loading!
: null} +// {locations.status !== 'loading' && locations.data?.length === 0 ? ( +//
Looks like you don't have any!
+// ) : ( +//
+// {locations.data?.map((location) => ( +//
+//

{location.name}

+// +//
+// ))} +//
+// )} + +//
+//

Path to save when clicking 'Do the thing':

+// setPath(e.currentTarget.value)} +// disabled={isLoading} +// /> +//
+// +// ); +// } diff --git a/interface/app/$libraryId/settings/client/general.tsx b/interface/app/$libraryId/settings/client/general.tsx index 4b42f6753..2196b0e27 100644 --- a/interface/app/$libraryId/settings/client/general.tsx +++ b/interface/app/$libraryId/settings/client/general.tsx @@ -50,7 +50,7 @@ export const Component = () => { schema: z .object({ name: z.string().min(1).max(250).optional(), - image_labeler_version: z.string().optional(), + // image_labeler_version: z.string().optional(), background_processing_percentage: z.coerce .number({ invalid_type_error: 'Must use numbers from 0 to 100' @@ -62,8 +62,8 @@ export const Component = () => { .strict(), reValidateMode: 'onChange', defaultValues: { - name: node.data?.name, - image_labeler_version: node.data?.image_labeler_version ?? undefined + name: node.data?.name + // image_labeler_version: node.data?.image_labeler_version ?? undefined // background_processing_percentage: // node.data?.preferences.thumbnailer.background_processing_percentage || 50 } diff --git a/interface/app/$libraryId/settings/library/index.tsx b/interface/app/$libraryId/settings/library/index.tsx index 3185dc91b..314890315 100644 --- a/interface/app/$libraryId/settings/library/index.tsx +++ b/interface/app/$libraryId/settings/library/index.tsx @@ -7,7 +7,6 @@ export default [ { path: 'contacts', lazy: () => import('./contacts') }, { path: 'security', lazy: () => import('./security') }, { path: 'sharing', lazy: () => import('./sharing') }, - { path: 'sync', lazy: () => import('./sync') }, { path: 'general', lazy: () => import('./general') }, { path: 'tags', lazy: () => import('./tags') }, // { path: 'saved-searches', lazy: () => import('./saved-searches') }, diff --git a/interface/app/$libraryId/settings/library/sync.tsx b/interface/app/$libraryId/settings/library/sync.tsx deleted file mode 100644 index 44593e9a6..000000000 --- a/interface/app/$libraryId/settings/library/sync.tsx +++ /dev/null @@ -1,228 +0,0 @@ -import { inferSubscriptionResult } from '@spacedrive/rspc-client'; -import clsx from 'clsx'; -import { useEffect, useState } from 'react'; -import { - Procedures, - useFeatureFlag, - useLibraryMutation, - useLibraryQuery, - useLibrarySubscription, - useZodForm -} from '@sd/client'; -import { Button, Dialog, dialogManager, useDialog, UseDialogProps, z } from '@sd/ui'; -import { useLocale } from '~/hooks'; - -import { Heading } from '../Layout'; -import Setting from '../Setting'; - -const ACTORS = { - Ingest: 'Sync Ingest', - CloudSend: 'Cloud Sync Sender', - CloudReceive: 'Cloud Sync Receiver', - CloudIngest: 'Cloud Sync Ingest' -}; - -export const Component = () => { - const { t } = useLocale(); - - const syncEnabled = useLibraryQuery(['sync.enabled']); - - const backfillSync = useLibraryMutation(['sync.backfill'], { - onSuccess: async () => { - await syncEnabled.refetch(); - } - }); - - const [data, setData] = useState>({}); - - useLibrarySubscription(['library.actors'], { onData: setData }); - - const cloudSync = useFeatureFlag('cloudSync'); - - return ( - <> - - {syncEnabled.data === false ? ( - -
- -
-
- ) : ( - <> - - {t('ingester')} - - - } - description={t('injester_description')} - > -
- {data[ACTORS.Ingest] ? ( - - ) : ( - - )} -
-
- - {cloudSync && } - - )} - - ); -}; - -function SyncBackfillDialog(props: UseDialogProps & { onEnabled: () => void }) { - const form = useZodForm({ schema: z.object({}) }); - const dialog = useDialog(props); - const { t } = useLocale(); - - const enableSync = useLibraryMutation(['sync.backfill'], {}); - - // dialog is in charge of enabling sync - useEffect(() => { - form.handleSubmit( - async () => { - await enableSync.mutateAsync(null).then(() => (dialog.state.open = false)); - await props.onEnabled(); - }, - () => {} - )(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); - - return ( - - ); -} - -function CloudSync({ data }: { data: inferSubscriptionResult }) { - const { t } = useLocale(); - return ( - <> -
-

{t('cloud_sync')}

-

{t('cloud_sync_description')}

-
- - {t('sender')} - - } - description={t('sender_description')} - > -
- {data[ACTORS.CloudSend] ? ( - - ) : ( - - )} -
-
- - {t('receiver')} - - - } - description={t('receiver_description')} - > -
- {data[ACTORS.CloudReceive] ? ( - - ) : ( - - )} -
-
- - {t('ingester')} - - - } - description={t('ingester_description')} - > -
- {data[ACTORS.CloudIngest] ? ( - - ) : ( - - )} -
-
- - ); -} - -function StartButton({ name }: { name: string }) { - const startActor = useLibraryMutation(['library.startActor']); - const { t } = useLocale(); - - return ( - - ); -} - -function StopButton({ name }: { name: string }) { - const stopActor = useLibraryMutation(['library.stopActor']); - const { t } = useLocale(); - - return ( - - ); -} - -function OnlineIndicator({ online }: { online: boolean }) { - return ( -
- ); -} diff --git a/interface/app/$libraryId/settings/node/libraries/DeleteDeviceDialog.tsx b/interface/app/$libraryId/settings/node/libraries/DeleteDeviceDialog.tsx new file mode 100644 index 000000000..0a6f1e51e --- /dev/null +++ b/interface/app/$libraryId/settings/node/libraries/DeleteDeviceDialog.tsx @@ -0,0 +1,91 @@ +import { useQueryClient } from '@tanstack/react-query'; +import { useEffect } from 'react'; +import { useNavigate } from 'react-router'; +import { HardwareModel, useBridgeMutation, useBridgeQuery, useZodForm } from '@sd/client'; +import { Dialog, ErrorMessage, useDialog, UseDialogProps } from '@sd/ui'; +import { Icon } from '~/components'; +import { useAccessToken, useLocale } from '~/hooks'; +import { hardwareModelToIcon } from '~/util/hardware'; +import { usePlatform } from '~/util/Platform'; + +interface Props extends UseDialogProps { + pubId: string; + name: string; + device_model: string; +} + +interface CorePubId { + Uuid: string; +} + +export default function DeleteLibraryDialog(props: Props) { + const { t } = useLocale(); + + const queryClient = useQueryClient(); + const platform = usePlatform(); + const navigate = useNavigate(); + const accessToken = useAccessToken(); + const { data: node } = useBridgeQuery(['nodeState']); + const deleteDevice = useBridgeMutation('cloud.devices.delete'); + const deviceAmount = useBridgeQuery(['cloud.devices.list']).data?.length; + + const form = useZodForm(); + + // Check if the current device matches the UUID or if it's the only device + useEffect(() => { + if (deviceAmount === 1) { + form.setError('pubId', { + type: 'manual', + message: t('error_only_device') + }); + } else if ((node?.id as CorePubId).Uuid === props.pubId) { + form.setError('pubId', { + type: 'manual', + message: t('error_current_device') + }); + } + }, [form, node, props.pubId, deviceAmount, t]); + + const onSubmit = form.handleSubmit(async () => { + try { + // Check for form errors before proceeding + if (form.formState.errors.pubId) { + return; + } + + await deleteDevice.mutateAsync(props.pubId); + queryClient.invalidateQueries({ queryKey: ['library.list'] }); + + // eslint-disable-next-line @typescript-eslint/no-unused-expressions + platform.refreshMenuBar && platform.refreshMenuBar(); + navigate('/'); + } catch (e) { + alert(`Failed to delete device: ${e}`); + } + }); + + return ( + +
+ +

{props.name}

+ +
+
+ ); +} diff --git a/interface/app/$libraryId/settings/node/libraries/DeviceItem.tsx b/interface/app/$libraryId/settings/node/libraries/DeviceItem.tsx new file mode 100644 index 000000000..8c0e734c4 --- /dev/null +++ b/interface/app/$libraryId/settings/node/libraries/DeviceItem.tsx @@ -0,0 +1,68 @@ +import { Trash } from '@phosphor-icons/react'; +import { iconNames } from '@sd/assets/util'; +import { Key } from 'react'; +import { HardwareModel, humanizeSize } from '@sd/client'; +import { Button, Card, dialogManager, Tooltip } from '@sd/ui'; +import { Icon } from '~/components'; +import { useAccessToken, useLocale } from '~/hooks'; +import { hardwareModelToIcon } from '~/util/hardware'; + +import DeleteDeviceDialog from './DeleteDeviceDialog'; + +interface DeviceItemProps { + pub_id: Key | null | undefined; + name: string; + os: string; + device_model: string; + storage_size: bigint; + used_storage: bigint; + created_at: string; +} + +// unsure where to put pub_id/if this information is important for a user? also have not included used_storage +export default (props: DeviceItemProps) => { + const { t } = useLocale(); + + return ( + + +
+

{props.name}

+

+ {props.os}, {`${t('added')}`} {new Date(props.created_at).toLocaleDateString()} +

+

+
+
{`${humanizeSize(props.storage_size)}`}
+ +
+ ); +}; diff --git a/interface/app/$libraryId/settings/node/libraries/JoinDialog.tsx b/interface/app/$libraryId/settings/node/libraries/JoinDialog.tsx deleted file mode 100644 index 759d890e7..000000000 --- a/interface/app/$libraryId/settings/node/libraries/JoinDialog.tsx +++ /dev/null @@ -1,102 +0,0 @@ -import { useQueryClient } from '@tanstack/react-query'; -import { useNavigate } from 'react-router'; -import { - LibraryConfigWrapped, - useBridgeMutation, - useBridgeQuery, - useClientContext, - useLibraryContext, - usePlausibleEvent, - useZodForm -} from '@sd/client'; -import { Button, Dialog, Select, SelectOption, toast, useDialog, UseDialogProps, z } from '@sd/ui'; -import { useLocale } from '~/hooks'; -import { usePlatform } from '~/util/Platform'; - -const schema = z.object({ - libraryId: z.string().refine((value) => value !== 'select_library', { - message: 'Please select a library' - }) -}); - -export default (props: UseDialogProps & { librariesCtx: LibraryConfigWrapped[] | undefined }) => { - const cloudLibraries = useBridgeQuery(['cloud.library.list']); - const joinLibrary = useBridgeMutation(['cloud.library.join']); - - const { t } = useLocale(); - const navigate = useNavigate(); - const platform = usePlatform(); - const queryClient = useQueryClient(); - - const form = useZodForm({ schema, defaultValues: { libraryId: 'select_library' } }); - - // const queryClient = useQueryClient(); - // const submitPlausibleEvent = usePlausibleEvent(); - // const platform = usePlatform(); - - const onSubmit = form.handleSubmit(async (data) => { - try { - const library = await joinLibrary.mutateAsync(data.libraryId); - - queryClient.setQueryData(['library.list'], (libraries: any) => { - // The invalidation system beat us to it - if ((libraries || []).find((l: any) => l.uuid === library.uuid)) return libraries; - - return [...(libraries || []), library]; - }); - - if (platform.refreshMenuBar) platform.refreshMenuBar(); - - navigate(`/${library.uuid}`, { replace: true }); - } catch (e: any) { - console.error(e); - toast.error(e); - } - }); - - return ( - -
- {cloudLibraries.isLoading && {t('loading')}...} - {cloudLibraries.data && ( - - )} -
-
- ); -}; diff --git a/interface/app/$libraryId/settings/node/libraries/ListItem.tsx b/interface/app/$libraryId/settings/node/libraries/ListItem.tsx index 06a95d443..0d1bc8550 100644 --- a/interface/app/$libraryId/settings/node/libraries/ListItem.tsx +++ b/interface/app/$libraryId/settings/node/libraries/ListItem.tsx @@ -1,8 +1,10 @@ -import { Pencil, Trash } from '@phosphor-icons/react'; +import { CaretRight, Pencil, Trash } from '@phosphor-icons/react'; +import { AnimatePresence, motion } from 'framer-motion'; +import { useState } from 'react'; import { LibraryConfigWrapped } from '@sd/client'; import { Button, ButtonLink, Card, dialogManager, Tooltip } from '@sd/ui'; import { Icon } from '~/components'; -import { useLocale } from '~/hooks'; +import { useAccessToken, useLocale } from '~/hooks'; import DeleteDialog from './DeleteDialog'; @@ -13,51 +15,78 @@ interface Props { export default (props: Props) => { const { t } = useLocale(); + const [isExpanded, setIsExpanded] = useState(false); + + const accessToken = useAccessToken(); + const toggleExpansion = () => { + setIsExpanded((prev) => !prev); + }; return ( - - {/* */} - -
-

- {props.library.config.name} - {props.current && ( - - {t('current')} - - )} -

-

{props.library.uuid}

-
-
- {/* */} - - - - - - -
-
+
+ +
+ +
+

+ {props.library.config.name} + {props.current && ( + + {t('current')} + + )} +

+

{props.library.uuid}

+
+
+
+ + + + + + + +
+
+ + + {isExpanded && ( + +
+
+ )} +
+
); }; diff --git a/interface/app/$libraryId/settings/node/libraries/index.tsx b/interface/app/$libraryId/settings/node/libraries/index.tsx index dc8ff541b..db7d6f5d8 100644 --- a/interface/app/$libraryId/settings/node/libraries/index.tsx +++ b/interface/app/$libraryId/settings/node/libraries/index.tsx @@ -1,18 +1,15 @@ -import { useBridgeQuery, useClientContext, useFeatureFlag, useLibraryContext } from '@sd/client'; +import { useBridgeQuery, useClientContext, useLibraryContext } from '@sd/client'; import { Button, dialogManager } from '@sd/ui'; import { useLocale } from '~/hooks'; import { Heading } from '../../Layout'; import CreateDialog from './CreateDialog'; -import JoinDialog from './JoinDialog'; import ListItem from './ListItem'; export const Component = () => { const librariesQuery = useBridgeQuery(['library.list']); const libraries = librariesQuery.data; - const cloudEnabled = useFeatureFlag('cloudSync'); - const { library } = useLibraryContext(); const { libraries: librariesCtx } = useClientContext(); const librariesCtxData = librariesCtx.data; @@ -35,19 +32,6 @@ export const Component = () => { > {t('add_library')} - {cloudEnabled && ( - - )}
} /> diff --git a/interface/app/onboarding/index.tsx b/interface/app/onboarding/index.tsx index f57a4e379..22ba946c3 100644 --- a/interface/app/onboarding/index.tsx +++ b/interface/app/onboarding/index.tsx @@ -4,7 +4,6 @@ import { onboardingStore } from '@sd/client'; import { useOnboardingContext } from './context'; import CreatingLibrary from './creating-library'; import { FullDisk } from './full-disk'; -import { JoinLibrary } from './join-library'; import Locations from './locations'; import NewLibrary from './new-library'; import PreRelease from './prerelease'; @@ -38,7 +37,6 @@ export default [ // path: 'login' // }, { Component: NewLibrary, path: 'new-library' }, - { Component: JoinLibrary, path: 'join-library' }, { Component: FullDisk, path: 'full-disk' }, { Component: Locations, path: 'locations' }, { Component: Privacy, path: 'privacy' }, diff --git a/interface/app/onboarding/join-library.tsx b/interface/app/onboarding/join-library.tsx deleted file mode 100644 index 1baf7f967..000000000 --- a/interface/app/onboarding/join-library.tsx +++ /dev/null @@ -1,84 +0,0 @@ -import { useQueryClient } from '@tanstack/react-query'; -import { useNavigate } from 'react-router'; -import { - resetOnboardingStore, - useBridgeMutation, - useBridgeQuery, - useLibraryMutation -} from '@sd/client'; -import { Button } from '@sd/ui'; -import { Icon } from '~/components'; -import { AuthRequiredOverlay } from '~/components/AuthRequiredOverlay'; -import { useLocale, useRouteTitle } from '~/hooks'; -import { usePlatform } from '~/util/Platform'; - -import { OnboardingContainer, OnboardingDescription, OnboardingTitle } from './components'; - -export function JoinLibrary() { - const { t } = useLocale(); - - useRouteTitle('Join Library'); - - return ( - - - {t('join_library')} - {t('join_library_description')} - -
- Cloud Libraries -
    - - -
-
-
- ); -} - -function CloudLibraries() { - const { t } = useLocale(); - - const cloudLibraries = useBridgeQuery(['cloud.library.list']); - const joinLibrary = useBridgeMutation(['cloud.library.join']); - - const navigate = useNavigate(); - const queryClient = useQueryClient(); - const platform = usePlatform(); - - if (cloudLibraries.isLoading) return {t('loading')}...; - - return ( - <> - {cloudLibraries.data?.map((cloudLibrary) => ( -
- {cloudLibrary.name} - -
- ))} - ); } diff --git a/interface/app/onboarding/login.tsx index 24c794bb9..f521cdd22 100644 --- a/interface/app/onboarding/login.tsx +++ b/interface/app/onboarding/login.tsx @@ -13,7 +13,7 @@ export default function OnboardingLogin() { const authState = auth.useStateSnapshot(); const navigate = useNavigate(); - const me = useBridgeQuery(['auth.me'], { retry: false }); + // const me = useBridgeQuery(['auth.me'], { retry: false }); return ( @@ -31,7 +31,7 @@ className="mb-3" />

    - Logged in as {me.data?.email} + Logged in as TODO

    diff --git a/interface/app/onboarding/new-library.tsx b/interface/app/onboarding/new-library.tsx index 6242a2d1e..36d6ea453 100644 --- a/interface/app/onboarding/new-library.tsx +++ b/interface/app/onboarding/new-library.tsx @@ -1,6 +1,5 @@ import { useState } from 'react'; import { useNavigate } from 'react-router'; -import { useFeatureFlag } from '@sd/client'; import { Button, Form, InputField } from '@sd/ui'; import { Icon } from '~/components'; import { useLocale, useOperatingSystem } from '~/hooks'; @@ -21,8 +20,6 @@ export default function OnboardingNewLibrary() { // TODO }; - const cloudFeatureFlag = useFeatureFlag('cloudSync'); - return (
    */} - {cloudFeatureFlag && ( - <> - {t('or')} - - - )} )} diff --git a/interface/app/onboarding/prerelease.tsx b/interface/app/onboarding/prerelease.tsx index 156d6e5a5..5432ebec5 100644 --- a/interface/app/onboarding/prerelease.tsx +++ b/interface/app/onboarding/prerelease.tsx @@ -17,7 +17,7 @@ export default function OnboardingPreRelease() {
    diff --git a/interface/components/Authentication.tsx b/interface/components/Authentication.tsx new file mode 100644 index 000000000..22c4c7269 --- /dev/null +++ b/interface/components/Authentication.tsx @@ -0,0 +1,170 @@ +import { GoogleLogo, Icon } from '@phosphor-icons/react'; +import { Apple, Github } from '@sd/assets/svgs/brands'; +import { RSPCError } from '@spacedrive/rspc-client'; +import { UseMutationResult } from '@tanstack/react-query'; +import { open } from '@tauri-apps/plugin-shell'; +import clsx from 'clsx'; +import { motion } from 'framer-motion'; +import { Dispatch, SetStateAction, useState } from 'react'; +import { getAuthorisationURLWithQueryParamsAndSetState } from 'supertokens-web-js/recipe/thirdparty'; +import { Card, toast } from '@sd/ui'; +import { Icon as Logo } from '~/components'; +import { useIsDark } from '~/hooks'; + +import Login from './Login'; +import Register from './Register'; + +export const AccountTabs = ['Login', 'Register'] as const; + +export type SocialLogin = { + name: 'Github' | 'Google' | 'Apple'; + icon: Icon; +}; + +export const SocialLogins: SocialLogin[] = [ + { name: 'Github', icon: Github }, + { name: 'Google', icon: GoogleLogo }, + { name: 'Apple', icon: Apple } +]; + +export const Authentication = ({ + reload, + cloudBootstrap +}: { + reload: Dispatch>; + cloudBootstrap: UseMutationResult; // Cloud bootstrap mutation +}) => { + const [activeTab, setActiveTab] = useState<'Login' | 'Register'>('Login'); + const isDark = useIsDark(); + + // Currently not in use due to backend issues - @Rocky43007 + const socialLoginHandlers = (name: SocialLogin['name']) => { + return { + Github: async () => { + try { + const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({ + thirdPartyId: 'github', + frontendRedirectURI: 'http://localhost:9420/api/auth/callback/github' + }); + await open(authUrl); + } catch (err: any) { + if (err.isSuperTokensGeneralError === true) { + toast.error(err.message); + } else { + toast.error('Oops! Something went wrong.'); + } + } + }, + Google: async () => { + try { + const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({ + thirdPartyId: 'google', + frontendRedirectURI: 'spacedrive://-/auth' + }); + await open(authUrl); + } catch (err: any) { + if (err.isSuperTokensGeneralError === true) { + toast.error(err.message); + } else { + toast.error('Oops! Something went wrong.'); + } + } + }, + Apple: async () => { + try { + const authUrl = await getAuthorisationURLWithQueryParamsAndSetState({ + thirdPartyId: 'apple', + frontendRedirectURI: 'http://localhost:9420/api/auth/callback/apple' + }); + await open(authUrl); + } catch (err: any) { + if (err.isSuperTokensGeneralError === true) { + toast.error(err.message); + } else { + toast.error('Oops! Something went wrong.'); + } + } + } + }[name](); + }; + + return ( + +
    + {AccountTabs.map((text) => ( +
    setActiveTab(text)} + className={clsx( + 'relative flex-1 cursor-pointer border-b border-app-line p-3 text-center transition-colors duration-200', + text === 'Login' ? 'rounded-tl-lg' : 'rounded-tr-lg', + text === activeTab ? 'bg-app-background-alt' : '' + )} + > +

    + {text} +

    + {text === activeTab && ( + + )} +
    + ))} +
    +
    +
    + +

    + Spacedrive Cloud +

    +
    + {activeTab === 'Login' ? ( + + ) : ( + + )} +
    + Social auth and SSO (Single Sign On) available soon! +
    + {/* Optionally, uncomment the social login block when ready */} + {/*
    + +

    OR

    + +
    +
    + {SocialLogins.map((social) => ( + + + + ))} +
    */} +
    +
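Aside (sketch, not part of this patch): the social handlers above only open the provider's authorisation URL; finishing the flow needs a callback route that calls the third-party recipe once the `spacedrive://` deep link lands back in the app (see the new useDeeplinkEventHandler). A minimal sketch, where the route and function name are assumptions:

import { signInAndUp } from 'supertokens-web-js/recipe/thirdparty';

// Hypothetical callback handler: run on the route the deep link navigates to,
// after the window handler has exposed the link's query params to the SDK.
export async function handleThirdPartyCallback(): Promise<boolean> {
	try {
		const response = await signInAndUp();
		if (response.status === 'OK') return true;
		// e.g. NO_EMAIL_GIVEN_BY_PROVIDER or SIGN_IN_UP_NOT_ALLOWED
		console.error('Third-party sign in failed:', response.status);
		return false;
	} catch (err) {
		console.error(err);
		return false;
	}
}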
    + ); +}; diff --git a/interface/components/Login.tsx b/interface/components/Login.tsx new file mode 100644 index 000000000..a4ac8bf63 --- /dev/null +++ b/interface/components/Login.tsx @@ -0,0 +1,345 @@ +import { ArrowLeft } from '@phosphor-icons/react'; +import { RSPCError } from '@spacedrive/rspc-client'; +import { UseMutationResult } from '@tanstack/react-query'; +import clsx from 'clsx'; +import { Dispatch, SetStateAction, useState } from 'react'; +import { Controller } from 'react-hook-form'; +import { signIn } from 'supertokens-web-js/recipe/emailpassword'; +import { createCode } from 'supertokens-web-js/recipe/passwordless'; +import { useZodForm } from '@sd/client'; +import { Button, Divider, Form, Input, toast, z } from '@sd/ui'; +import { useLocale } from '~/hooks'; +import { getTokens } from '~/util'; + +import ShowPassword from './ShowPassword'; + +async function signInClicked( + email: string, + password: string, + reload: Dispatch>, + cloudBootstrap: UseMutationResult // Cloud bootstrap mutation +) { + try { + const response = await signIn({ + formFields: [ + { + id: 'email', + value: email + }, + { + id: 'password', + value: password + } + ] + }); + + if (response.status === 'FIELD_ERROR') { + response.formFields.forEach((formField) => { + if (formField.id === 'email') { + toast.error(formField.error); + } + }); + } else if (response.status === 'WRONG_CREDENTIALS_ERROR') { + toast.error('Email & password combination is incorrect.'); + } else if (response.status === 'SIGN_IN_NOT_ALLOWED') { + toast.error(response.reason); + } else { + const tokens = getTokens(); + console.log(cloudBootstrap); + cloudBootstrap.mutate([tokens.accessToken, tokens.refreshToken]); + toast.success('Sign in successful'); + reload(true); + } + } catch (err: any) { + if (err.isSuperTokensGeneralError === true) { + toast.error(err.message); + } else { + console.error('Error signing in', err); + toast.error('Oops! Something went wrong.'); + } + } +} + +const LoginSchema = z.object({ + email: z.string().email({ + message: 'Email is required' + }), + password: z.string().min(6, { + message: 'Password must be at least 6 characters' + }) +}); + +const ContinueWithEmailSchema = z.object({ + email: z.string().email({ + message: 'Email is required' + }) +}); + +const Login = ({ + reload, + cloudBootstrap +}: { + reload: Dispatch>; + cloudBootstrap: UseMutationResult; // Cloud bootstrap mutation +}) => { + const [continueWithEmail, setContinueWithEmail] = useState(false); + + return ( + <> + {continueWithEmail ? ( + + ) : ( + + )} + + ); +}; + +interface LoginProps { + reload: Dispatch>; + cloudBootstrap: UseMutationResult; // Cloud bootstrap mutation + setContinueWithEmail: Dispatch>; +} + +const LoginForm = ({ reload, cloudBootstrap, setContinueWithEmail }: LoginProps) => { + const { t } = useLocale(); + const [showPassword, setShowPassword] = useState(false); + const form = useZodForm({ + schema: LoginSchema, + defaultValues: { + email: '', + password: '' + } + }); + + return ( + { + await signInClicked(data.email, data.password, reload, cloudBootstrap); + })} + className="w-full" + form={form} + > +
    +
    + + ( + + )} + /> + {form.formState.errors.email && ( +

    + {form.formState.errors.email.message} +

    + )} +
    + +
    + + ( +
    + { + const pastedText = e.clipboardData.getData('text'); + field.onChange(pastedText); + }} + /> + +
    + )} + /> + {form.formState.errors.password && ( +

    + {form.formState.errors.password.message} +

    + )} +
    +
    + + + +
    + +

    Or

    + +
    + + + + ); +}; + +interface Props { + setContinueWithEmail: Dispatch>; + reload: Dispatch>; + cloudBootstrap: UseMutationResult; // Cloud bootstrap mutation +} + +const ContinueWithEmail = ({ setContinueWithEmail, reload, cloudBootstrap }: Props) => { + const { t } = useLocale(); + const ContinueWithEmailForm = useZodForm({ + schema: ContinueWithEmailSchema, + defaultValues: { + email: '' + } + }); + const [step, setStep] = useState(1); + + return ( +
    { + // send email + await sendMagicLink(data.email); + setStep((step) => step + 1); + })} + className="w-full" + form={ContinueWithEmailForm} + > + {step === 1 ? ( + <> +
    + + ( + + )} + /> + {ContinueWithEmailForm.formState.errors.email && ( +

    + {ContinueWithEmailForm.formState.errors.email.message} +

    + )} +
    + + + ) : ( +
    +

    Check your email

    +
    +

    {t('login_link_sent')}

    +

    + {t('check_your_inbox')}{' '} + + {ContinueWithEmailForm.getValues().email} + +

    +
    +
    + )} + +
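Aside (sketch, not part of this patch): `sendMagicLink` below only starts the passwordless flow; the emailed link still has to be consumed when it deep-links back into the app, which this diff does not show. With supertokens-web-js that is `consumeCode` from the passwordless recipe; a minimal sketch, assuming the SDK can read the link's code from the URL via the window handler:

import { consumeCode } from 'supertokens-web-js/recipe/passwordless';

// Hypothetical completion step for the magic-link flow.
export async function finishMagicLinkSignIn(): Promise<boolean> {
	try {
		const response = await consumeCode();
		if (response.status === 'OK') return true;
		// e.g. RESTART_FLOW_ERROR if the link was opened in another browser/device
		console.error('Magic link sign-in failed:', response.status);
		return false;
	} catch (err) {
		console.error(err);
		return false;
	}
}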
    + ); +}; + +async function sendMagicLink(email: string) { + try { + const response = await createCode({ + email + }); + + if (response.status === 'SIGN_IN_UP_NOT_ALLOWED') { + // the reason string is a user friendly message + // about what went wrong. It can also contain a support code which users + // can tell you so you know why their sign in / up was not allowed. + toast.error(response.reason); + } + } catch (err: any) { + if (err.isSuperTokensGeneralError === true) { + // this may be a custom error message sent from the API by you, + // or if the input email / phone number is not valid. + toast.error(err.message); + } else { + console.error(err); + toast.error('Oops! Something went wrong.'); + } + } +} + +export default Login; diff --git a/interface/components/Register.tsx b/interface/components/Register.tsx new file mode 100644 index 000000000..1262274aa --- /dev/null +++ b/interface/components/Register.tsx @@ -0,0 +1,204 @@ +import { zodResolver } from '@hookform/resolvers/zod'; +import clsx from 'clsx'; +import { useState } from 'react'; +import { Controller, useForm } from 'react-hook-form'; +import { signUp } from 'supertokens-web-js/recipe/emailpassword'; +import { Button, Form, Input, toast, z } from '@sd/ui'; +import { useLocale } from '~/hooks'; + +import ShowPassword from './ShowPassword'; + +const RegisterSchema = z + .object({ + email: z.string().email({ + message: 'Email is required' + }), + password: z.string().min(6, { + message: 'Password must be at least 6 characters' + }), + confirmPassword: z.string().min(6, { + message: 'Password must be at least 6 characters' + }) + }) + .refine((data) => data.password === data.confirmPassword, { + message: 'Passwords do not match', + path: ['confirmPassword'] + }); +type RegisterType = z.infer; + +async function signUpClicked(email: string, password: string) { + try { + const response = await signUp({ + formFields: [ + { + id: 'email', + value: email + }, + { + id: 'password', + value: password + } + ] + }); + + if (response.status === 'FIELD_ERROR') { + // one of the input formFields failed validaiton + response.formFields.forEach((formField) => { + if (formField.id === 'email') { + // Email validation failed (for example incorrect email syntax), + // or the email is not unique. + toast.error(formField.error); + } else if (formField.id === 'password') { + // Password validation failed. + // Maybe it didn't match the password strength + toast.error(formField.error); + } + }); + } else if (response.status === 'SIGN_UP_NOT_ALLOWED') { + // the reason string is a user friendly message + // about what went wrong. It can also contain a support code which users + // can tell you so you know why their sign up was not allowed. + toast.error(response.reason); + } else { + // sign up successful. The session tokens are automatically handled by + // the frontend SDK. + toast.success('Sign up successful'); + // FIXME: This is a temporary workaround. We will provide a better way to handle this. + window.location.reload(); + } + } catch (err: any) { + if (err.isSuperTokensGeneralError === true) { + // this may be a custom error message sent from the API by you. + toast.error(err.message); + } else { + toast.error('Oops! 
Something went wrong.'); + } + } +} + +const Register = () => { + const { t } = useLocale(); + const [showPassword, setShowPassword] = useState(false); + // useZodForm seems to be out-dated or needs + //fixing as it does not support the schema using zod.refine + const form = useForm({ + resolver: zodResolver(RegisterSchema), + defaultValues: { + email: '', + password: '', + confirmPassword: '' + } + }); + return ( +
    { + // handle sign-up submission + console.log(data); + await signUpClicked(data.email, data.password); + })} + className="w-full" + form={form} + > +
    +
    + + ( + + )} + /> + {form.formState.errors.email && ( +

    + {form.formState.errors.email.message} +

    + )} +
    + +
    + + ( +
    + { + const pastedText = e.clipboardData.getData('text'); + field.onChange(pastedText); + }} + /> + +
    + )} + /> + {form.formState.errors.password && ( +

    + {form.formState.errors.password.message} +

    + )} + ( +
    + + +
    + )} + /> + {form.formState.errors.confirmPassword && ( +

    + {form.formState.errors.confirmPassword.message} +

    + )} +
    +
    + +
    + ); +}; + +export default Register; diff --git a/interface/components/RequestAddDialog.tsx b/interface/components/RequestAddDialog.tsx new file mode 100644 index 000000000..e8d1bf035 --- /dev/null +++ b/interface/components/RequestAddDialog.tsx @@ -0,0 +1,96 @@ +import { ArrowRight } from '@phosphor-icons/react'; +import { useQueryClient } from '@tanstack/react-query'; +import { useNavigate } from 'react-router'; +import { HardwareModel, useBridgeMutation, useZodForm } from '@sd/client'; +import { Dialog, toast, useDialog, UseDialogProps, z } from '@sd/ui'; +import { Icon } from '~/components'; +import { useLocale } from '~/hooks'; +import { hardwareModelToIcon } from '~/util/hardware'; +import { usePlatform } from '~/util/Platform'; + +export default ( + props: { + device_name: string; + device_model: HardwareModel; + library_name: string; + } & UseDialogProps +) => { + // PROPS = device_name, device_model, library_name + // you will probably have to change the props to accept the library id and device id to pair them properly. Omitted for now as + // unsure what the data will look like when the backend is populated + + // const joinLibrary = useBridgeMutation(['cloud.library.join']); + + const { t } = useLocale(); + const navigate = useNavigate(); + const platform = usePlatform(); + const queryClient = useQueryClient(); + + const form = useZodForm({ defaultValues: { libraryId: 'select_library' } }); + + // adapted from another dialog - we can change the form submit/remove form if needed but didn't want to + // unnecessarily remove code + const onSubmit = form.handleSubmit(async (data) => { + try { + // const library = await joinLibrary.mutateAsync(data.libraryId); + const library = { uuid: '1234' }; // dummy data + + queryClient.setQueryData(['library.list'], (libraries: any) => { + // The invalidation system beat us to it + if ((libraries || []).find((l: any) => l.uuid === library.uuid)) return libraries; + + return [...(libraries || []), library]; + }); + + if (platform.refreshMenuBar) platform.refreshMenuBar(); + + navigate(`/${library.uuid}`, { replace: true }); + } catch (e: any) { + console.error(e); + toast.error(e); + } + }); + + return ( + + {/* device */} +
    +
    + +

    {props.device_name}

    +
    + + {/* library */} +
    + +

    {props.library_name}

    +
    +
    +
    + ); +}; diff --git a/interface/components/ShowPassword.tsx b/interface/components/ShowPassword.tsx new file mode 100644 index 000000000..d3e846da0 --- /dev/null +++ b/interface/components/ShowPassword.tsx @@ -0,0 +1,27 @@ +import { Eye, EyeClosed } from '@phosphor-icons/react'; +import { Button, Tooltip } from '@sd/ui'; + +interface Props { + showPassword: boolean; + setShowPassword: (value: boolean) => void; +} + +const ShowPassword = ({ showPassword, setShowPassword }: Props) => { + return ( + + + + ); +}; + +export default ShowPassword; diff --git a/interface/components/index.ts b/interface/components/index.ts index edc3d5dc5..2f6420216 100644 --- a/interface/components/index.ts +++ b/interface/components/index.ts @@ -13,3 +13,4 @@ export * from './TextViewer'; export * from './TrafficLights'; export * from './TruncatedText'; export * from './Accordion'; +export * from './Authentication'; diff --git a/interface/hooks/index.ts b/interface/hooks/index.ts index 6f2fbdf72..fa9720835 100644 --- a/interface/hooks/index.ts +++ b/interface/hooks/index.ts @@ -31,3 +31,5 @@ export * from './useWindowState'; export * from './useZodParams'; export * from './useZodRouteParams'; export * from './useZodSearchParams'; +export * from './useDeeplinkEventHandler'; +export * from './useAccessToken'; diff --git a/interface/hooks/useAccessToken.ts b/interface/hooks/useAccessToken.ts new file mode 100644 index 000000000..d0a20a1e9 --- /dev/null +++ b/interface/hooks/useAccessToken.ts @@ -0,0 +1,8 @@ +export function useAccessToken(): string { + const accessToken: string = + JSON.parse(window.localStorage.getItem('frontendCookies') ?? '[]') + .find((cookie: string) => cookie.startsWith('st-access-token')) + ?.split('=')[1] + .split(';')[0] || ''; + return accessToken.trim(); +} diff --git a/interface/hooks/useDeeplinkEventHandler.ts b/interface/hooks/useDeeplinkEventHandler.ts new file mode 100644 index 000000000..11556ec66 --- /dev/null +++ b/interface/hooks/useDeeplinkEventHandler.ts @@ -0,0 +1,37 @@ +import { useEffect } from 'react'; +import { useNavigate } from 'react-router'; +import { DeeplinkEvent } from '~/util/events'; + +export const useDeeplinkEventHandler = () => { + const navigate = useNavigate(); + useEffect(() => { + const handler = (e: DeeplinkEvent) => { + e.preventDefault(); + + const url = e.detail.url; + if (!url) return; + // If the URL has search params, we need to navigate to the URL with the search params + const [path, search] = url.split('?'); + // If hash is present, we need to split it from the search params, and remove it from the search value + const [searchParams, hash] = search ? 
search.split('#') : ['', '']; + const searchParamsObj = new URLSearchParams(searchParams); + const searchParamsString = searchParamsObj.toString(); + console.log('Navigating to', { + path, + searchParamsString, + hash + }); + + navigate({ pathname: path, search: searchParamsString, hash }); + + // if (search) { + // navigate({ pathname: path, search, hash }); + // } else { + // navigate(url); + // } + }; + + document.addEventListener('deeplink', handler); + return () => document.removeEventListener('deeplink', handler); + }, [navigate]); +}; diff --git a/interface/hooks/useKeybindEventHandler.ts b/interface/hooks/useKeybindEventHandler.ts index a515cdb18..6ac1808e2 100644 --- a/interface/hooks/useKeybindEventHandler.ts +++ b/interface/hooks/useKeybindEventHandler.ts @@ -1,7 +1,7 @@ import { useEffect } from 'react'; import { useLocation, useNavigate } from 'react-router'; -import { KeybindEvent } from '../util/keybind'; +import { KeybindEvent } from '../util/events'; import { useQuickRescan } from './useQuickRescan'; import { getWindowState } from './useWindowState'; diff --git a/interface/index.tsx b/interface/index.tsx index 1d8799098..da5e021a6 100644 --- a/interface/index.tsx +++ b/interface/index.tsx @@ -8,11 +8,13 @@ import { RouterProvider, RouterProviderProps } from 'react-router-dom'; import { InteropProviderReact, P2PContextProvider, + useBridgeMutation, useBridgeSubscription, useInvalidateQuery, useLoadBackendFeatureFlags } from '@sd/client'; -import { toast, TooltipProvider } from '@sd/ui'; +import { dialogManager, toast, TooltipProvider } from '@sd/ui'; +import RequestAddDialog from '~/components/RequestAddDialog'; import { createRoutes } from './app'; import { SpacedropProvider } from './app/$libraryId/Spacedrop'; @@ -26,7 +28,7 @@ import { RouterContext, RoutingContext } from './RoutingContext'; export * from './app'; export { ErrorPage } from './ErrorFallback'; export * from './TabsContext'; -export * from './util/keybind'; +export * from './util/events'; export * from './util/Platform'; dayjs.extend(advancedFormat); @@ -87,6 +89,42 @@ export function SpacedriveInterfaceRoot({ children }: PropsWithChildren) { } }); + const userResponse = useBridgeMutation('cloud.userResponse'); + + useBridgeSubscription(['cloud.listenCloudServicesNotifications'], { + onData: (d) => { + console.log('Received cloud service notification', d); + switch (d.kind) { + case 'ReceivedJoinSyncGroupRequest': + // WARNING: This is a debug solution to accept the device into the sync group. 
THIS SHOULD NOT MAKE IT TO PRODUCTION + userResponse.mutate({ + kind: 'AcceptDeviceInSyncGroup', + data: { + ticket: d.data.ticket, + accepted: { + id: d.data.sync_group.library.pub_id, + name: d.data.sync_group.library.name, + description: null + } + } + }); + // TODO: Move the code above into the dialog below (@Rocky43007) + // dialogManager.create((dp) => ( + // + // )); + break; + default: + toast({ title: 'Cloud Service Notification', body: d.kind }, { type: 'info' }); + break; + } + } + }); + return ( diff --git a/interface/locales/ar/common.json b/interface/locales/ar/common.json index 2f3172406..f2426d118 100644 --- a/interface/locales/ar/common.json +++ b/interface/locales/ar/common.json @@ -41,6 +41,7 @@ "audio_preview_not_supported": "معاينة الصوت غير مدعومة.", "auto": "آلي", "back": "رجوع", + "back_to_login": "العودة إلى تسجيل الدخول", "backfill_sync": "عمليات مزامنة الردم", "backfill_sync_description": "تم إيقاف المكتبة مؤقتًا حتى اكتمال عملية الردم", "backups": "نسخ احتياطية", @@ -437,6 +438,7 @@ "log_out": "تسجيل الخروج", "logged_in_as": "تم تسجيل الدخول كـ {{email}}", "logging_in": "جار تسجيل الدخول...", + "login": "تسجيل الدخول", "logout": "تسجيل الخروج", "manage_library": "إدارة المكتبة", "managed": "مُدار", diff --git a/interface/locales/be/common.json b/interface/locales/be/common.json index ac582e19d..6415b9e8f 100644 --- a/interface/locales/be/common.json +++ b/interface/locales/be/common.json @@ -49,6 +49,7 @@ "audio_preview_not_supported": "Папярэдні прагляд аўдыя не падтрымваецца.", "auto": "Аўто", "back": "Назад", + "back_to_login": "Вярнуцца да ўваходу", "backfill_sync": "Аперацыі поўнай сінхранізацыі", "backfill_sync_description": "Праца бібліятэкі прыпынена да завяршэння сінхранізацыі", "backups": "Рэз. копіі", @@ -71,6 +72,7 @@ "changelog": "Што новага", "changelog_page_description": "Даведайцеся, якія новыя магчымасці мы дадалі", "changelog_page_title": "Спіс змен", + "check_your_inbox": "Праверце паштовую скрыню па адрасе", "checksum": "Кантрольная сума", "clear_finished_jobs": "Ачысціць скончаныя заданні", "click_to_hide": "Націсніце, каб схаваць", @@ -482,6 +484,8 @@ "log_out": "Выйсці з сістэмы", "logged_in_as": "Увайшлі ў сістэму як {{email}}", "logging_in": "Уваход у сістэму...", + "login": "Увайсці", + "login_link_sent": "Мы адправілі часовую спасылку для ўваходу.", "logout": "Выйсці", "manage_library": "Кіраванне бібліятэкай", "managed": "Кіраваны", diff --git a/interface/locales/cs/common.json b/interface/locales/cs/common.json index 20b19cd00..24eafdb7f 100644 --- a/interface/locales/cs/common.json +++ b/interface/locales/cs/common.json @@ -47,6 +47,7 @@ "audio_preview_not_supported": "Náhled zvuku není podporován.", "auto": "Automaticky", "back": "Zpět", + "back_to_login": "Zpět k přihlášení", "backfill_sync": "Doplňování synchronizačních operací", "backfill_sync_description": "Knihovna je pozastavena, dokud se nedokončí doplňování", "backups": "Zálohy", @@ -458,6 +459,7 @@ "log_out": "Odhlásit se", "logged_in_as": "Přihlášen jako {{email}}", "logging_in": "Přihlašování...", + "login": "Přihlášení", "logout": "Odhlásit se", "manage_library": "Spravovat knihovnu", "managed": "Spravováno", diff --git a/interface/locales/de/common.json b/interface/locales/de/common.json index d2ee35f84..5b41b164b 100644 --- a/interface/locales/de/common.json +++ b/interface/locales/de/common.json @@ -41,6 +41,7 @@ "audio_preview_not_supported": "Audio-Vorschau wird nicht unterstützt.", "auto": "Auto", "back": "Zurück", + "back_to_login": "Zurück zur Anmeldung", 
"backfill_sync": "Synchronisierungsvorgänge auffüllen", "backfill_sync_description": "Die Bibliothek wird angehalten, bis der Backfill abgeschlossen ist", "backups": "Backups", @@ -419,6 +420,7 @@ "log_out": "Abmelden", "logged_in_as": "Angemeldet als {{email}}", "logging_in": "Einloggen...", + "login": "Login", "logout": "Abmelden", "manage_library": "Bibliothek verwalten", "managed": "Verwaltet", diff --git a/interface/locales/en/common.json b/interface/locales/en/common.json index c3570f860..fcac6fd36 100644 --- a/interface/locales/en/common.json +++ b/interface/locales/en/common.json @@ -4,6 +4,7 @@ "about_vision_title": "Vision", "accept": "Accept", "accept_files": "Accept files", + "accepting": "Accepting...", "accessed": "Accessed", "account": "Account", "actions": "Actions", @@ -47,6 +48,7 @@ "audio_preview_not_supported": "Audio preview is not supported.", "auto": "Auto", "back": "Back", + "back_to_login": "Back to login", "backfill_sync": "Backfilling Sync Operations", "backfill_sync_description": "Library is paused until backfill completes", "backups": "Backups", @@ -459,6 +461,7 @@ "log_out": "Log out", "logged_in_as": "Logged in as {{email}}", "logging_in": "Logging in...", + "login": "Login", "logout": "Logout", "manage_library": "Manage Library", "managed": "Managed", @@ -617,6 +620,8 @@ "rename": "Rename", "rename_object": "Rename object", "replica": "Replica", + "request_add_device": "Library Join Request", + "request_add_device_description": "A device is requesting to join one of your libraries. Please review the device and the library it is requesting to join below.", "rescan": "Rescan", "rescan_directory": "Rescan Directory", "rescan_location": "Rescan Location", diff --git a/interface/locales/es/common.json b/interface/locales/es/common.json index 89c24f98a..ff78c5505 100644 --- a/interface/locales/es/common.json +++ b/interface/locales/es/common.json @@ -41,6 +41,7 @@ "audio_preview_not_supported": "La previsualización de audio no está soportada.", "auto": "Auto", "back": "Atrás", + "back_to_login": "Volver al inicio de sesión", "backfill_sync": "Operaciones de sincronización de reabastecimiento", "backfill_sync_description": "La biblioteca está en pausa hasta que se complete el reabastecimiento", "backups": "Copias de seguridad", @@ -421,6 +422,7 @@ "log_out": "Cerrar sesión", "logged_in_as": "Conectado como {{email}}", "logging_in": "Iniciando sesión...", + "login": "Acceso", "logout": "Cerrar sesión", "manage_library": "Administrar Biblioteca", "managed": "Gestionado", diff --git a/interface/locales/fr/common.json b/interface/locales/fr/common.json index a9402af74..7dbc1f203 100644 --- a/interface/locales/fr/common.json +++ b/interface/locales/fr/common.json @@ -41,6 +41,7 @@ "audio_preview_not_supported": "L'aperçu audio n'est pas pris en charge.", "auto": "Auto", "back": "Retour", + "back_to_login": "Retour à la connexion", "backfill_sync": "Opérations de synchronisation de remplissage", "backfill_sync_description": "La bibliothèque est suspendue jusqu'à ce que le remplissage soit terminé", "backups": "Sauvegardes", @@ -421,6 +422,7 @@ "log_out": "Se déconnecter", "logged_in_as": "Connecté en tant que {{email}}", "logging_in": "Se connecter...", + "login": "Se connecter", "logout": "Déconnexion", "manage_library": "Gérer la bibliothèque", "managed": "Géré", diff --git a/interface/locales/it/common.json b/interface/locales/it/common.json index b54bd2a46..f9e65d437 100644 --- a/interface/locales/it/common.json +++ b/interface/locales/it/common.json @@ -41,6 +41,7 
@@ "audio_preview_not_supported": "L'anteprima audio non è disponibile.", "auto": "Auto", "back": "Indietro", + "back_to_login": "Torna al login", "backfill_sync": "Operazioni di sincronizzazione del backfill", "backfill_sync_description": "La raccolta viene sospesa fino al completamento del recupero", "backups": "Backups", @@ -421,6 +422,7 @@ "log_out": "Disconnettiti", "logged_in_as": "Accesso effettuato come {{email}}", "logging_in": "Entrando...", + "login": "Login", "logout": "Esci", "manage_library": "Gestisci la Libreria", "managed": "Gestito", diff --git a/interface/locales/ja/common.json b/interface/locales/ja/common.json index 111090a1c..f5e840084 100644 --- a/interface/locales/ja/common.json +++ b/interface/locales/ja/common.json @@ -41,6 +41,7 @@ "audio_preview_not_supported": "オーディオのプレビューには対応していません。", "auto": "自動", "back": "戻る", + "back_to_login": "ログインに戻る", "backfill_sync": "同期操作のバックフィル", "backfill_sync_description": "バックフィルが完了するまでライブラリは一時停止されます", "backups": "バックアップ", @@ -413,6 +414,7 @@ "log_out": "ログアウト", "logged_in_as": "{{email}} でログイン", "logging_in": "ログイン中...", + "login": "ログイン", "logout": "ログアウト", "manage_library": "ライブラリの設定", "managed": "Managed", diff --git a/interface/locales/nl/common.json b/interface/locales/nl/common.json index 38bf0016a..f39a5cde9 100644 --- a/interface/locales/nl/common.json +++ b/interface/locales/nl/common.json @@ -41,6 +41,7 @@ "audio_preview_not_supported": "Audio voorvertoning wordt niet ondersteund.", "auto": "Auto", "back": "Terug", + "back_to_login": "Terug naar inloggen", "backfill_sync": "Synchronisatiebewerkingen voor opvulling", "backfill_sync_description": "De bibliotheek wordt gepauzeerd totdat het aanvullen is voltooid", "backups": "Backups", @@ -419,6 +420,7 @@ "log_out": "Uitloggen", "logged_in_as": "Ingelogd als {{email}}", "logging_in": "Inloggen...", + "login": "Login", "logout": "Uitloggen", "manage_library": "Beheer Bibliotheek", "managed": "Beheerd", diff --git a/interface/locales/ru/common.json b/interface/locales/ru/common.json index bd230704d..44f9d7c23 100644 --- a/interface/locales/ru/common.json +++ b/interface/locales/ru/common.json @@ -49,6 +49,7 @@ "audio_preview_not_supported": "Предварительный просмотр аудио не поддерживается.", "auto": "Авто", "back": "Назад", + "back_to_login": "Вернуться к входу", "backfill_sync": "Операции полной синхронизации", "backfill_sync_description": "Работа библиотеки приостановлена ​​до завершения синхронизации", "backups": "Рез. 
копии", @@ -71,6 +72,7 @@ "changelog": "Что нового", "changelog_page_description": "Узнайте, какие новые возможности мы добавили", "changelog_page_title": "Список изменений", + "check_your_inbox": "Пожалуйста, проверьте свой почтовый ящик по адресу", "checksum": "Контрольная сумма", "clear_finished_jobs": "Очистить законченные задачи", "click_to_hide": "Щелкните, чтобы скрыть", @@ -482,6 +484,8 @@ "log_out": "Выйти из системы", "logged_in_as": "Вошли в систему как {{email}}", "logging_in": "Вход в систему...", + "login": "Авторизоваться", + "login_link_sent": "Мы отправили временную ссылку для входа.", "logout": "Выйти", "manage_library": "Управление библиотекой", "managed": "Управляемый", diff --git a/interface/locales/tr/common.json b/interface/locales/tr/common.json index c95752eb1..505044abf 100644 --- a/interface/locales/tr/common.json +++ b/interface/locales/tr/common.json @@ -41,6 +41,7 @@ "audio_preview_not_supported": "Ses önizlemesi desteklenmiyor.", "auto": "Oto", "back": "Geri", + "back_to_login": "Girişe geri dön", "backfill_sync": "Dolgu Senkronizasyon İşlemleri", "backfill_sync_description": "Dolgu tamamlanana kadar kitaplık duraklatıldı", "backups": "Yedeklemeler", @@ -419,6 +420,7 @@ "log_out": "Çıkış Yap", "logged_in_as": "{{email}} olarak giriş yapıldı", "logging_in": "Giriş...", + "login": "Giriş yapmak", "logout": "Çıkış Yap", "manage_library": "Kütüphaneyi Yönet", "managed": "Yönetilen", diff --git a/interface/locales/zh-CN/common.json b/interface/locales/zh-CN/common.json index 09460240b..dc3bc2799 100644 --- a/interface/locales/zh-CN/common.json +++ b/interface/locales/zh-CN/common.json @@ -41,6 +41,7 @@ "audio_preview_not_supported": "不支持音频预览。", "auto": "自动", "back": "返回", + "back_to_login": "返回登录", "backfill_sync": "回填同步操作", "backfill_sync_description": "库暂停直至回填完成", "backups": "备份", @@ -412,6 +413,7 @@ "log_out": "退出登录", "logged_in_as": "已登录为 {{email}}", "logging_in": "正在登录...", + "login": "登录", "logout": "退出登录", "manage_library": "管理库", "managed": "已管理", diff --git a/interface/locales/zh-TW/common.json b/interface/locales/zh-TW/common.json index d3efb1df8..3bd1097ec 100644 --- a/interface/locales/zh-TW/common.json +++ b/interface/locales/zh-TW/common.json @@ -41,6 +41,7 @@ "audio_preview_not_supported": "不支援音頻預覽。", "auto": "汽車", "back": "返回", + "back_to_login": "返回登入", "backfill_sync": "回填同步操作", "backfill_sync_description": "庫暫停至回填完成", "backups": "備份", @@ -412,6 +413,7 @@ "log_out": "登出", "logged_in_as": "已登入為{{email}}", "logging_in": "在登入...", + "login": "登入", "logout": "登出", "manage_library": "管理圖書館", "managed": "已管理", diff --git a/interface/package.json b/interface/package.json index 47b292ae1..cbcb35c67 100644 --- a/interface/package.json +++ b/interface/package.json @@ -13,6 +13,7 @@ "@dnd-kit/utilities": "^3.2.2", "@headlessui/react": "^1.7.17", "@icons-pack/react-simple-icons": "^9.1.0", + "@spacedrive/rspc-client": "github:spacedriveapp/rspc#path:packages/client&6a77167495", "@phosphor-icons/react": "^2.0.13", "@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-dropdown-menu": "^2.0.6", @@ -65,10 +66,11 @@ "rooks": "^7.14.1", "solid-js": "^1.8.8", "solid-refresh": "^0.6.3", + "supertokens-web-js": "^0.13.0", "use-count-up": "^3.0.1", "use-debounce": "^9.0.4", "use-resize-observer": "^9.1.0", - "uuid": "^9.0.1", + "uuid": "^10.0.0", "valtio": "^2.0" }, "devDependencies": { @@ -81,7 +83,7 @@ "tailwindcss": "^3.4.10", "type-fest": "^4.13.0", "typescript": "^5.6.2", - "vite": "^5.2.0", + "vite": "^5.4.9", "vite-plugin-svgr": "^3.3.0" } } diff --git 
diff --git a/interface/util/keybind.ts b/interface/util/events.ts
similarity index 58%
rename from interface/util/keybind.ts
rename to interface/util/events.ts
index 3500f3c50..4b1e1a77c 100644
--- a/interface/util/keybind.ts
+++ b/interface/util/events.ts
@@ -1,6 +1,7 @@
 declare global {
 	interface GlobalEventHandlersEventMap {
 		keybindexec: KeybindEvent;
+		deeplink: DeeplinkEvent;
 	}
 }
 
@@ -13,3 +14,13 @@ export class KeybindEvent extends CustomEvent<{ action: string }> {
 		});
 	}
 }
+
+export class DeeplinkEvent extends CustomEvent<{ url: string }> {
+	constructor(url: string) {
+		super('deeplink', {
+			detail: {
+				url
+			}
+		});
+	}
+}
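Illustrative usage of the new DeeplinkEvent (not part of the diff; the import path is assumed). Because `deeplink` is registered in GlobalEventHandlersEventMap above, listeners receive a fully typed event without casts:

import { DeeplinkEvent } from '~/util/events'; // path assumed

// e.g. the Tauri layer forwards an OS-level deep link into the DOM:
document.dispatchEvent(new DeeplinkEvent('spacedrive://auth/callback'));

// ...and interface code reacts to it, with e typed as DeeplinkEvent:
document.addEventListener('deeplink', (e) => {
	console.log('deep link received:', e.detail.url);
});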
diff --git a/interface/util/index.tsx b/interface/util/index.tsx
index 34fd7b43e..81ea777a8 100644
--- a/interface/util/index.tsx
+++ b/interface/util/index.tsx
@@ -8,3 +8,31 @@ export type NonEmptyArray<T> = [T, ...T[]];
 export const isNonEmpty = <T>(input: T[]): input is NonEmptyArray<T> => input.length > 0;
 
 export const isNonEmptyObject = (input: object) => Object.keys(input).length > 0;
+
+export const AUTH_SERVER_URL = 'https://auth.spacedrive.com';
+// export const AUTH_SERVER_URL = 'http://localhost:9420';
+
+export function getTokens() {
+	if (typeof window === 'undefined') {
+		return {
+			refreshToken: '',
+			accessToken: ''
+		};
+	}
+
+	const refreshToken: string =
+		JSON.parse(window.localStorage.getItem('frontendCookies') ?? '[]')
+			.find((cookie: string) => cookie.startsWith('st-refresh-token'))
+			?.split('=')[1]
+			.split(';')[0] || '';
+	const accessToken: string =
+		JSON.parse(window.localStorage.getItem('frontendCookies') ?? '[]')
+			.find((cookie: string) => cookie.startsWith('st-access-token'))
+			?.split('=')[1]
+			.split(';')[0] || '';
+
+	return {
+		refreshToken,
+		accessToken
+	};
+}
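Note that getTokens() reads and JSON-parses the frontendCookies blob twice, and the chain `?.split('=')[1].split(';')[0]` can still throw when an entry contains no '='. A possible tightening, illustrative only and assuming the same storage format (a JSON array of Set-Cookie-style strings written by supertokens-web-js):

// Sketch of an equivalent helper: parse once, share the extraction logic,
// and never throw on malformed entries.
function readCookieValue(cookies: string[], name: string): string {
	const entry = cookies.find((cookie) => cookie.startsWith(name));
	// "st-access-token=abc123; Path=/; ..." -> "abc123"
	return entry?.split('=')[1]?.split(';')[0] ?? '';
}

export function getTokensAlt() {
	if (typeof window === 'undefined') return { refreshToken: '', accessToken: '' };

	const cookies: string[] = JSON.parse(
		window.localStorage.getItem('frontendCookies') ?? '[]'
	);

	return {
		refreshToken: readCookieValue(cookies, 'st-refresh-token'),
		accessToken: readCookieValue(cookies, 'st-access-token')
	};
}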
"locations.addLibrary", input: LibraryArgs, result: number | null } | { key: "locations.create", input: LibraryArgs, result: number | null } | @@ -130,22 +138,26 @@ export type Procedures = { { key: "tags.update", input: LibraryArgs, result: null } | { key: "toggleFeatureFlag", input: BackendFeature, result: null }, subscriptions: - { key: "auth.loginSession", input: never, result: Response } | + { key: "cloud.listenCloudServicesNotifications", input: never, result: CloudP2PNotifyUser } | { key: "invalidation.listen", input: never, result: InvalidateOperationEvent[] } | { key: "jobs.newFilePathIdentified", input: LibraryArgs, result: number[] } | { key: "jobs.newThumbnail", input: LibraryArgs, result: ThumbKey } | { key: "jobs.progress", input: LibraryArgs, result: JobProgressEvent } | - { key: "library.actors", input: LibraryArgs, result: { [key in string]: boolean } } | + { key: "library.actors", input: LibraryArgs, result: ([string, boolean])[] } | { key: "library.updatedKindStatistic", input: LibraryArgs, result: KindStatistic } | { key: "locations.online", input: never, result: number[][] } | { key: "locations.quickRescan", input: LibraryArgs, result: null } | { key: "notifications.listen", input: never, result: Notification } | { key: "p2p.events", input: never, result: P2PEvent } | { key: "search.ephemeralPaths", input: LibraryArgs, result: { entries: ExplorerItem[]; errors: Error[] } } | - { key: "sync.active", input: LibraryArgs, result: SyncStatus } | - { key: "sync.newMessage", input: LibraryArgs, result: null } + { key: "sync.active", input: LibraryArgs, result: SyncStatus } }; +/** + * Newtype wrapper for the access token + */ +export type AccessToken = string + export type Args = { search?: string | null; filters?: string | null; name?: string | null; icon?: string | null; description?: string | null } export type AudioProps = { delay: number; padding: number; sample_rate: number | null; sample_format: string | null; bit_per_sample: number | null; channel_layout: string | null } @@ -155,16 +167,14 @@ export type AudioProps = { delay: number; padding: number; sample_rate: number | * * If you want a variant of this to show up on the frontend it must be added to `backendFeatures` in `useFeatureFlag.tsx` */ -export type BackendFeature = "cloudSync" +export type BackendFeature = never export type Backup = ({ id: string; timestamp: string; library_id: string; library_name: string }) & { path: string } +export type BasicLibraryCreationArgs = { id: CloudLibraryPubId; name: string; description: string | null } + export type BuildInfo = { version: string; commit: string } -export type CRDTOperation = { instance: string; timestamp: number; model: number; record_id: JsonValue; data: CRDTOperationData } - -export type CRDTOperationData = { c: { [key in string]: JsonValue } } | { u: { field: string; value: JsonValue } } | "d" - export type CameraData = { device_make: string | null; device_model: string | null; color_space: string | null; color_profile: ColorProfile | null; focal_length: number | null; shutter_speed: number | null; flash: Flash | null; orientation: Orientation; lens_make: string | null; lens_model: string | null; bit_depth: number | null; zoom: number | null; iso: number | null; software: string | null; serial_number: string | null; lens_serial_number: string | null; contrast: number | null; saturation: number | null; sharpness: number | null; composite: Composite | null } export type CasId = string @@ -173,11 +183,51 @@ export type ChangeNodeNameArgs = { name: string | 
@@ -155,16 +167,14 @@ export type AudioProps = { delay: number; padding: number; sample_rate: number |
 *
 * If you want a variant of this to show up on the frontend it must be added to `backendFeatures` in `useFeatureFlag.tsx`
 */
-export type BackendFeature = "cloudSync"
+export type BackendFeature = never
 
 export type Backup = ({ id: string; timestamp: string; library_id: string; library_name: string }) & { path: string }
 
+export type BasicLibraryCreationArgs = { id: CloudLibraryPubId; name: string; description: string | null }
+
 export type BuildInfo = { version: string; commit: string }
 
-export type CRDTOperation = { instance: string; timestamp: number; model: number; record_id: JsonValue; data: CRDTOperationData }
-
-export type CRDTOperationData = { c: { [key in string]: JsonValue } } | { u: { field: string; value: JsonValue } } | "d"
-
 export type CameraData = { device_make: string | null; device_model: string | null; color_space: string | null; color_profile: ColorProfile | null; focal_length: number | null; shutter_speed: number | null; flash: Flash | null; orientation: Orientation; lens_make: string | null; lens_model: string | null; bit_depth: number | null; zoom: number | null; iso: number | null; software: string | null; serial_number: string | null; lens_serial_number: string | null; contrast: number | null; saturation: number | null; sharpness: number | null; composite: Composite | null }
 
 export type CasId = string
 
@@ -173,11 +183,51 @@ export type ChangeNodeNameArgs = { name: string | null; p2p_port: Port | null; p
 export type Chapter = { id: number; start: [number, number]; end: [number, number]; time_base_den: number; time_base_num: number; metadata: Metadata }
 
-export type CloudInstance = { id: string; uuid: string; identity: RemoteIdentity; nodeId: string; nodeRemoteIdentity: string; metadata: { [key in string]: string } }
+export type CloudCreateLocationArgs = { pub_id: CloudLocationPubId; name: string; library_pub_id: CloudLibraryPubId; device_pub_id: CloudDevicePubId }
 
-export type CloudLibrary = { id: string; uuid: string; name: string; instances: CloudInstance[]; ownerId: string }
+export type CloudDevice = { pub_id: CloudDevicePubId; name: string; os: DeviceOS; hardware_model: HardwareModel; connection_id: string; created_at: string; updated_at: string }
 
-export type CloudLocation = { id: string; name: string }
+export type CloudDevicePubId = string
+
+export type CloudGetLibraryArgs = { pub_id: CloudLibraryPubId; with_device: boolean }
+
+export type CloudGetSyncGroupArgs = { pub_id: CloudSyncGroupPubId; kind: CloudSyncGroupGetRequestKind }
+
+export type CloudLibrary = { pub_id: CloudLibraryPubId; name: string; original_device: CloudDevice | null; created_at: string; updated_at: string }
+
+export type CloudLibraryPubId = string
+
+export type CloudListLocationsArgs = { library_pub_id: CloudLibraryPubId; with_library: boolean; with_device: boolean }
+
+export type CloudLocation = { pub_id: CloudLocationPubId; name: string; device: CloudDevice | null; library: CloudLibrary | null; created_at: string; updated_at: string }
+
+export type CloudLocationPubId = string
+
+export type CloudP2PError = "Rejected" | "UnableToConnect" | "TimedOut"
+
+export type CloudP2PNotifyUser = { kind: "ReceivedJoinSyncGroupRequest"; data: { ticket: CloudP2PTicket; asking_device: CloudDevice; sync_group: CloudSyncGroupWithDevices } } | { kind: "ReceivedJoinSyncGroupResponse"; data: { response: JoinSyncGroupResponse; sync_group: CloudSyncGroupWithDevices } } | { kind: "SendingJoinSyncGroupResponseError"; data: { error: JoinSyncGroupError; sync_group: CloudSyncGroupWithDevices } } | { kind: "TimedOutJoinRequest"; data: { device: CloudDevice; succeeded: boolean } }
+
+export type CloudP2PTicket = bigint
+
+export type CloudP2PUserResponse = { kind: "AcceptDeviceInSyncGroup"; data: { ticket: CloudP2PTicket; accepted: BasicLibraryCreationArgs | null } }
+
+export type CloudSyncGroup = { pub_id: CloudSyncGroupPubId; latest_key_hash: CloudSyncKeyHash; library: CloudLibrary; devices: CloudDevice[]; total_sync_messages_bytes: bigint; total_space_files_bytes: bigint; created_at: string; updated_at: string }
+
+export type CloudSyncGroupBaseData = { pub_id: CloudSyncGroupPubId; latest_key_hash: CloudSyncKeyHash; library: CloudLibrary; created_at: string; updated_at: string }
+
+export type CloudSyncGroupGetRequestKind = "WithDevices" | "DevicesConnectionIds" | "FullData"
+
+export type CloudSyncGroupGetResponseKind = { kind: "WithDevices"; data: CloudSyncGroupWithDevices } | { kind: "FullData"; data: CloudSyncGroup }
+
+export type CloudSyncGroupPubId = string
+
+export type CloudSyncGroupWithDevices = { pub_id: CloudSyncGroupPubId; latest_key_hash: CloudSyncKeyHash; library: CloudLibrary; devices: CloudDevice[]; created_at: string; updated_at: string }
+
+export type CloudSyncGroupsRemoveDeviceArgs = { group_pub_id: CloudSyncGroupPubId; to_remove_device_pub_id: CloudDevicePubId }
+
+export type CloudSyncKeyHash = string
+
+export type CloudUpdateDeviceArgs = { pub_id: CloudDevicePubId; name: string }
 
 export type Codec = { kind: string | null; sub_kind: string | null; tag: string | null; name: string | null; profile: string | null; bit_rate: number; props: Props | null }
 
@@ -211,6 +261,12 @@ export type ConvertImageArgs = { location_id: number; file_path_id: number; dele
 export type ConvertibleExtension = "bmp" | "dib" | "ff" | "gif" | "ico" | "jpg" | "jpeg" | "png" | "pnm" | "qoi" | "tga" | "icb" | "vda" | "vst" | "tiff" | "tif" | "hif" | "heif" | "heifs" | "heic" | "heics" | "avif" | "avci" | "avcs" | "svg" | "svgz" | "pdf" | "webp"
 
+export type CoreDevicePubId = CorePubId
+
+export type CoreHardwareModel = "Other" | "MacStudio" | "MacBookAir" | "MacBookPro" | "MacBook" | "MacMini" | "MacPro" | "IMac" | "IMacPro" | "IPad" | "IPhone" | "Simulator" | "Android"
+
+export type CorePubId = { Uuid: string } | { Vec: number[] }
+
 export type CreateEphemeralFileArgs = { path: string; context: EphemeralFileCreateContextTypes; name: string | null }
 
 export type CreateEphemeralFolderArgs = { path: string; name: string | null }
 
@@ -225,6 +281,8 @@ export type CursorOrderItem<T> = { order: SortOrder; data: T }
 export type DefaultLocations = { desktop: boolean; documents: boolean; downloads: boolean; pictures: boolean; music: boolean; videos: boolean }
 
+export type DeviceOS = "Linux" | "Windows" | "MacOS" | "iOS" | "Android"
+
 /**
  * The method used for the discovery of this peer.
  * *Technically* you can have multiple under the hood but this simplifies things for the UX.
  */
@@ -282,7 +340,7 @@ export type FfmpegMediaVideoProps = { id: number; pixel_format: string | null; c
 export type FileCreateContextTypes = "empty" | "text"
 
-export type FilePath = { id: number; pub_id: number[]; is_dir: boolean | null; cas_id: string | null; integrity_checksum: string | null; location_id: number | null; materialized_path: string | null; name: string | null; extension: string | null; hidden: boolean | null; size_in_bytes: string | null; size_in_bytes_bytes: number[] | null; inode: number[] | null; object_id: number | null; key_id: number | null; date_created: string | null; date_modified: string | null; date_indexed: string | null }
+export type FilePath = { id: number; pub_id: number[]; is_dir: boolean | null; cas_id: string | null; integrity_checksum: string | null; location_id: number | null; materialized_path: string | null; name: string | null; extension: string | null; hidden: boolean | null; size_in_bytes: string | null; size_in_bytes_bytes: number[] | null; inode: number[] | null; object_id: number | null; key_id: number | null; date_created: string | null; date_modified: string | null; date_indexed: string | null; device_id: number | null }
 
 export type FilePathCursor = { isDir: boolean; variant: FilePathCursorVariant }
 
@@ -290,7 +348,7 @@ export type FilePathCursorVariant = "none" | { name: CursorOrderItem } |
 export type FilePathFilterArgs = { locations: InOrNotIn } | { path: { location_id: number; path: string; include_descendants: boolean } } | { name: TextMatch } | { extension: InOrNotIn } | { createdAt: Range } | { modifiedAt: Range } | { indexedAt: Range } | { hidden: boolean }
 
-export type FilePathForFrontend = { id: number; pub_id: number[]; is_dir: boolean | null; cas_id: string | null; integrity_checksum: string | null; location_id: number | null; materialized_path: string | null; name: string | null; extension: string | null; hidden: boolean | null; size_in_bytes: string | null; size_in_bytes_bytes: number[] | null; inode: number[] | null; object_id: number | null; object: { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; tags: ({ object_id: number; tag_id: number; tag: Tag; date_created: string | null })[]; exif_data: { resolution: number[] | null; media_date: number[] | null; media_location: number[] | null; camera_data: number[] | null; artist: string | null; description: string | null; copyright: string | null; exif_version: string | null } | null } | null; key_id: number | null; date_created: string | null; date_modified: string | null; date_indexed: string | null }
+export type FilePathForFrontend = { id: number; pub_id: number[]; is_dir: boolean | null; cas_id: string | null; integrity_checksum: string | null; location_id: number | null; materialized_path: string | null; name: string | null; extension: string | null; hidden: boolean | null; size_in_bytes: string | null; size_in_bytes_bytes: number[] | null; inode: number[] | null; object_id: number | null; object: { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; tags: ({ object_id: number; tag_id: number; tag: Tag; date_created: string | null; device_id: number | null })[]; exif_data: { resolution: number[] | null; media_date: number[] | null; media_location: number[] | null; camera_data: number[] | null; artist: string | null; description: string | null; copyright: string | null; exif_version: string | null } | null; device_id: number | null } | null; key_id: number | null; date_created: string | null; date_modified: string | null; date_indexed: string | null; device_id: number | null }
 
 export type FilePathObjectCursor = { dateAccessed: CursorOrderItem } | { kind: CursorOrderItem }
 
@@ -379,6 +437,10 @@ export type JobName = "Indexer" | "FileIdentifier" | "MediaProcessor" | "Copy" |
 export type JobProgressEvent = { id: string; library_id: string; task_count: number; completed_task_count: number; phase: string; message: string; info: string; estimated_completion: string }
 
+export type JoinSyncGroupError = "Communication" | "InternalServer" | "Auth"
+
+export type JoinSyncGroupResponse = { Accepted: { authorizor_device: CloudDevice } } | { Failed: CloudP2PError } | "CriticalError"
+
 export type JsonValue = null | boolean | number | string | JsonValue[] | { [key in string]: JsonValue }
 
 export type KindStatistic = { kind: number; name: string; count: [number, number]; total_bytes: [number, number] }
 
@@ -430,7 +492,7 @@ export type ListenerState = { type: "Listening" } | { type: "Error"; error: stri
 export type Listeners = { ipv4: ListenerState; ipv6: ListenerState; relay: ListenerState }
 
-export type Location = { id: number; pub_id: number[]; name: string | null; path: string | null; total_capacity: number | null; available_capacity: number | null; size_in_bytes: number[] | null; is_archived: boolean | null; generate_preview_media: boolean | null; sync_preview_media: boolean | null; hidden: boolean | null; date_created: string | null; scan_state: number; instance_id: number | null }
+export type Location = { id: number; pub_id: number[]; name: string | null; path: string | null; total_capacity: number | null; available_capacity: number | null; size_in_bytes: number[] | null; is_archived: boolean | null; generate_preview_media: boolean | null; sync_preview_media: boolean | null; hidden: boolean | null; date_created: string | null; scan_state: number; device_id: number | null; instance_id: number | null }
 
 /**
  * `LocationCreateArgs` is the argument received from the client using `rspc` to create a new location.
@@ -473,9 +535,10 @@ export type NodeConfigP2P = { discovery?: P2PDiscoveryState; port: Port; disable
 *
 * All of these are valid values:
 * - `localhost`
- * - `otbeaumont.me` or `otbeaumont.me:3000`
+ * - `spacedrive.com` or `spacedrive.com:3000`
 * - `127.0.0.1` or `127.0.0.1:300`
 * - `[::1]` or `[::1]:3000`
+ *
 * which is why we use `String` not `SocketAddr`
 */
 manual_peers?: string[] }
 
@@ -486,11 +549,11 @@ export type NodeState = ({
 /**
  * id is a unique identifier for the current node. Each node has a public identifier (this one) and is given a local id for each library (done within the library code).
  */
-id: string;
+id: CoreDevicePubId;
 /**
  * name is the display name of the current node. This is set by the user and is shown in the UI. // TODO: Length validation so it can fit in DNS record
  */
-name: string; identity: RemoteIdentity; p2p: NodeConfigP2P; features: BackendFeature[]; preferences: NodePreferences; image_labeler_version: string | null }) & { data_path: string; device_model: string | null; is_in_docker: boolean }
+name: string; identity: RemoteIdentity; p2p: NodeConfigP2P; features: BackendFeature[]; preferences: NodePreferences; os: DeviceOS; hardware_model: CoreHardwareModel }) & { data_path: string; device_model: string | null; is_in_docker: boolean }
 
 export type NonCriticalError = { indexer: NonCriticalIndexerError } | { file_identifier: NonCriticalFileIdentifierError } | { media_processor: NonCriticalMediaProcessorError }
 
@@ -521,7 +584,7 @@ export type NotificationId = { type: "library"; id: [string, number] } | { type:
 export type NotificationKind = "info" | "success" | "error" | "warning"
 
-export type Object = { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null }
+export type Object = { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; device_id: number | null }
 
 export type ObjectCursor = "none" | { dateAccessed: CursorOrderItem } | { kind: CursorOrderItem }
 
@@ -535,9 +598,9 @@ export type ObjectSearchArgs = { take: number; orderAndPagination?: OrderAndPagi
 export type ObjectValidatorArgs = { id: number; path: string }
 
-export type ObjectWithFilePaths = { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; file_paths: ({ id: number; pub_id: number[]; is_dir: boolean | null; cas_id: string | null; integrity_checksum: string | null; location_id: number | null; materialized_path: string | null; name: string | null; extension: string | null; hidden: boolean | null; size_in_bytes: string | null; size_in_bytes_bytes: number[] | null; inode: number[] | null; object_id: number | null; object: { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; exif_data: { resolution: number[] | null; media_date: number[] | null; media_location: number[] | null; camera_data: number[] | null; artist: string | null; description: string | null; copyright: string | null; exif_version: string | null } | null; ffmpeg_data: { id: number; formats: string; bit_rate: number[]; duration: number[] | null; start_time: number[] | null; chapters: FfmpegMediaChapter[]; programs: ({ program_id: number; streams: ({ stream_id: number; name: string | null; codec: { id: number; kind: string | null; sub_kind: string | null; tag: string | null; name: string | null; profile: string | null; bit_rate: number; video_props: FfmpegMediaVideoProps | null; audio_props: FfmpegMediaAudioProps | null; stream_id: number; program_id: number; ffmpeg_data_id: number } | null; aspect_ratio_num: number; aspect_ratio_den: number; frames_per_second_num: number; frames_per_second_den: number; time_base_real_den: number; time_base_real_num: number; dispositions: string | null; title: string | null; encoder: string | null; language: string | null; duration: number[] | null; metadata: number[] | null; program_id: number; ffmpeg_data_id: number })[]; name: string | null; metadata: number[] | null; ffmpeg_data_id: number })[]; title: string | null; creation_time: string | null; date: string | null; album_artist: string | null; disc: string | null; track: string | null; album: string | null; artist: string | null; metadata: number[] | null; object_id: number } | null } | null; key_id: number | null; date_created: string | null; date_modified: string | null; date_indexed: string | null })[] }
+export type ObjectWithFilePaths = { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; file_paths: ({ id: number; pub_id: number[]; is_dir: boolean | null; cas_id: string | null; integrity_checksum: string | null; location_id: number | null; materialized_path: string | null; name: string | null; extension: string | null; hidden: boolean | null; size_in_bytes: string | null; size_in_bytes_bytes: number[] | null; inode: number[] | null; object_id: number | null; object: { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; exif_data: { resolution: number[] | null; media_date: number[] | null; media_location: number[] | null; camera_data: number[] | null; artist: string | null; description: string | null; copyright: string | null; exif_version: string | null } | null; ffmpeg_data: { id: number; formats: string; bit_rate: number[]; duration: number[] | null; start_time: number[] | null; chapters: FfmpegMediaChapter[]; programs: ({ program_id: number; streams: ({ stream_id: number; name: string | null; codec: { id: number; kind: string | null; sub_kind: string | null; tag: string | null; name: string | null; profile: string | null; bit_rate: number; video_props: FfmpegMediaVideoProps | null; audio_props: FfmpegMediaAudioProps | null; stream_id: number; program_id: number; ffmpeg_data_id: number } | null; aspect_ratio_num: number; aspect_ratio_den: number; frames_per_second_num: number; frames_per_second_den: number; time_base_real_den: number; time_base_real_num: number; dispositions: string | null; title: string | null; encoder: string | null; language: string | null; duration: number[] | null; metadata: number[] | null; program_id: number; ffmpeg_data_id: number })[]; name: string | null; metadata: number[] | null; ffmpeg_data_id: number })[]; title: string | null; creation_time: string | null; date: string | null; album_artist: string | null; disc: string | null; track: string | null; album: string | null; artist: string | null; metadata: number[] | null; object_id: number } | null; device_id: number | null } | null; key_id: number | null; date_created: string | null; date_modified: string | null; date_indexed: string | null; device_id: number | null })[]; device_id: number | null }
 
-export type ObjectWithFilePaths2 = { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; file_paths: ({ id: number; pub_id: number[]; is_dir: boolean | null; cas_id: string | null; integrity_checksum: string | null; location_id: number | null; materialized_path: string | null; name: string | null; extension: string | null; hidden: boolean | null; size_in_bytes: string | null; size_in_bytes_bytes: number[] | null; inode: number[] | null; object_id: number | null; object: { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; exif_data: { resolution: number[] | null; media_date: number[] | null; media_location: number[] | null; camera_data: number[] | null; artist: string | null; description: string | null; copyright: string | null; exif_version: string | null } | null; ffmpeg_data: { id: number; formats: string; bit_rate: number[]; duration: number[] | null; start_time: number[] | null; chapters: FfmpegMediaChapter[]; programs: ({ program_id: number; streams: ({ stream_id: number; name: string | null; codec: { id: number; kind: string | null; sub_kind: string | null; tag: string | null; name: string | null; profile: string | null; bit_rate: number; video_props: FfmpegMediaVideoProps | null; audio_props: FfmpegMediaAudioProps | null; stream_id: number; program_id: number; ffmpeg_data_id: number } | null; aspect_ratio_num: number; aspect_ratio_den: number; frames_per_second_num: number; frames_per_second_den: number; time_base_real_den: number; time_base_real_num: number; dispositions: string | null; title: string | null; encoder: string | null; language: string | null; duration: number[] | null; metadata: number[] | null; program_id: number; ffmpeg_data_id: number })[]; name: string | null; metadata: number[] | null; ffmpeg_data_id: number })[]; title: string | null; creation_time: string | null; date: string | null; album_artist: string | null; disc: string | null; track: string | null; album: string | null; artist: string | null; metadata: number[] | null; object_id: number } | null } | null; key_id: number | null; date_created: string | null; date_modified: string | null; date_indexed: string | null })[] }
+export type ObjectWithFilePaths2 = { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; file_paths: ({ id: number; pub_id: number[]; is_dir: boolean | null; cas_id: string | null; integrity_checksum: string | null; location_id: number | null; materialized_path: string | null; name: string | null; extension: string | null; hidden: boolean | null; size_in_bytes: string | null; size_in_bytes_bytes: number[] | null; inode: number[] | null; object_id: number | null; object: { id: number; pub_id: number[]; kind: number | null; key_id: number | null; hidden: boolean | null; favorite: boolean | null; important: boolean | null; note: string | null; date_created: string | null; date_accessed: string | null; exif_data: { resolution: number[] | null; media_date: number[] | null; media_location: number[] | null; camera_data: number[] | null; artist: string | null; description: string | null; copyright: string | null; exif_version: string | null } | null; ffmpeg_data: { id: number; formats: string; bit_rate: number[]; duration: number[] | null; start_time: number[] | null; chapters: FfmpegMediaChapter[]; programs: ({ program_id: number; streams: ({ stream_id: number; name: string | null; codec: { id: number; kind: string | null; sub_kind: string | null; tag: string | null; name: string | null; profile: string | null; bit_rate: number; video_props: FfmpegMediaVideoProps | null; audio_props: FfmpegMediaAudioProps | null; stream_id: number; program_id: number; ffmpeg_data_id: number } | null; aspect_ratio_num: number; aspect_ratio_den: number; frames_per_second_num: number; frames_per_second_den: number; time_base_real_den: number; time_base_real_num: number; dispositions: string | null; title: string | null; encoder: string | null; language: string | null; duration: number[] | null; metadata: number[] | null; program_id: number; ffmpeg_data_id: number })[]; name: string | null; metadata: number[] | null; ffmpeg_data_id: number })[]; title: string | null; creation_time: string | null; date: string | null; album_artist: string | null; disc: string | null; track: string | null; album: string | null; artist: string | null; metadata: number[] | null; object_id: number } | null; device_id: number | null } | null; key_id: number | null; date_created: string | null; date_modified: string | null; date_indexed: string | null; device_id: number | null })[] }
 
 export type OldFileCopierJobInit = { source_location_id: number; target_location_id: number; sources_file_path_ids: number[]; target_location_relative_directory_path: string }
 
@@ -561,7 +624,7 @@ export type P2PDiscoveryState = "Everyone" | "ContactsOnly" | "Disabled"
 export type P2PEvent = { type: "PeerChange"; identity: RemoteIdentity; connection: ConnectionMethod; discovery: DiscoveryMethod; metadata: PeerMetadata; addrs: string[] } | { type: "PeerDelete"; identity: RemoteIdentity } | { type: "SpacedropRequest"; id: string; identity: RemoteIdentity; peer_name: string; files: string[] } | { type: "SpacedropProgress"; id: string; percent: number } | { type: "SpacedropTimedOut"; id: string } | { type: "SpacedropRejected"; id: string }
 
-export type PeerMetadata = { name: string; operating_system: OperatingSystem | null; device_model: HardwareModel | null; version: string | null }
+export type PeerMetadata = { name: string; operating_system: OperatingSystem | null; device_model: CoreHardwareModel | null; version: string | null }
 
 export type PlusCode = string
 
@@ -573,6 +636,11 @@ export type Props = { Video: VideoProps } | { Audio: AudioProps } | { Subtitle:
 export type Range<T> = { from: T } | { to: T }
 
+/**
+ * Newtype wrapper for the refresh token
+ */
+export type RefreshToken = string
+
 export type RemoteIdentity = string
 
 export type RenameFileArgs = { location_id: number; kind: RenameKind }
 
@@ -595,8 +663,6 @@ export type RescanArgs = { location_id: number; sub_path: string }
 export type Resolution = { width: number; height: number }
 
-export type Response = { Start: { user_code: string; verification_url: string; verification_url_complete: string } } | "Complete" | { Error: string }
-
 export type RuleKind = "AcceptFilesByGlob" | "RejectFilesByGlob" | "AcceptIfChildrenDirectoriesArePresent" | "RejectIfChildrenDirectoriesArePresent" | "IgnoredByGit"
 
 export type SavedSearch = { id: number; pub_id: number[]; target: string | null; search: string | null; filters: string | null; name: string | null; icon: string | null; description: string | null; date_created: string | null; date_modified: string | null }
 
@@ -631,6 +697,8 @@ export type Stream = { id: number; name: string | null; codec: Codec | null; asp
 export type SubtitleProps = { width: number; height: number }
 
+export type SyncGroupsRequestJoinArgs = { sync_group: CloudSyncGroupWithDevices; asking_device: CloudDevice }
+
 export type SyncStatus = { ingest: boolean; cloud_send: boolean; cloud_receive: boolean; cloud_ingest: boolean }
 
 export type SystemLocations = { desktop: string | null; documents: string | null; downloads: string | null; pictures: string | null; music: string | null; videos: string | null }
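With the AccessToken and RefreshToken newtypes above, the core learns about a web-layer session through the new cloud.bootstrap mutation, whose input is their tuple. A hedged sketch of that wiring — function name and import path are assumptions, not code from this PR:

// Illustrative only: pass in the strings returned by getTokens() from
// interface/util/index.tsx. `cloud.hasBootstrapped` (query) and
// `cloud.bootstrap` (mutation) are declared in the Procedures map above.
import { nonLibraryClient } from './rspc'; // import path assumed

export async function bootstrapCloudServices(accessToken: string, refreshToken: string) {
	if (!accessToken || !refreshToken) return; // no session yet

	const alreadyBootstrapped = await nonLibraryClient.query(['cloud.hasBootstrapped']);
	if (!alreadyBootstrapped) {
		// input is the [AccessToken, RefreshToken] tuple
		await nonLibraryClient.mutation(['cloud.bootstrap', [accessToken, refreshToken]]);
	}
}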
diff --git a/packages/client/src/stores/auth.ts b/packages/client/src/stores/auth.ts
index a494854c0..936988544 100644
--- a/packages/client/src/stores/auth.ts
+++ b/packages/client/src/stores/auth.ts
@@ -24,15 +24,15 @@ export function useStateSnapshot() {
 	return useSolidStore(store).state;
 }
 
-nonLibraryClient
-	.query(['auth.me'])
-	.then(() => (store.state = { status: 'loggedIn' }))
-	.catch((e) => {
-		if (e instanceof RSPCError && e.code === 401) {
-			// TODO: handle error?
-		}
-		store.state = { status: 'notLoggedIn' };
-	});
+// nonLibraryClient
+// 	.query(['auth.me'])
+// 	.then(() => (store.state = { status: 'loggedIn' }))
+// 	.catch((e) => {
+// 		if (e instanceof RSPCError && e.code === 401) {
+// 			// TODO: handle error?
+// 		}
+// 		store.state = { status: 'notLoggedIn' };
+// 	});
 
 type CallbackStatus = 'success' | { error: string } | 'cancel';
 const loginCallbacks = new Set<(status: CallbackStatus) => void>();
@@ -46,28 +46,28 @@ export async function login(config: ProviderConfig) {
 
 	store.state = { status: 'loggingIn' };
 
-	let authCleanup = nonLibraryClient.addSubscription(['auth.loginSession'], {
-		onData(data) {
-			if (data === 'Complete') {
-				config.finish?.(authCleanup);
-				loginCallbacks.forEach((cb) => cb('success'));
-			} else if ('Error' in data) {
-				onError(data.Error);
-			} else {
-				Promise.resolve()
-					.then(() => config.start(data.Start.verification_url_complete))
-					.then(
-						(res) => {
-							authCleanup = res;
-						},
-						(e) => onError(e.message)
-					);
-			}
-		},
-		onError(e) {
-			onError(e.message);
-		}
-	});
+	// let authCleanup = nonLibraryClient.addSubscription(['auth.loginSession'], {
+	// 	onData(data) {
+	// 		if (data === 'Complete') {
+	// 			config.finish?.(authCleanup);
+	// 			loginCallbacks.forEach((cb) => cb('success'));
+	// 		} else if ('Error' in data) {
+	// 			onError(data.Error);
+	// 		} else {
+	// 			Promise.resolve()
+	// 				.then(() => config.start(data.Start.verification_url_complete))
+	// 				.then(
+	// 					(res) => {
+	// 						authCleanup = res;
+	// 					},
+	// 					(e) => onError(e.message)
+	// 				);
+	// 		}
+	// 	},
+	// 	onError(e) {
+	// 		onError(e.message);
+	// 	}
+	// });
 
 	return new Promise((res, rej) => {
 		const cb = async (status: CallbackStatus) => {
@@ -75,7 +75,7 @@
 			if (status === 'success') {
 				store.state = { status: 'loggedIn' };
-				nonLibraryClient.query(['auth.me']);
+				// nonLibraryClient.query(['auth.me']);
 				res();
 			} else {
 				store.state = { status: 'notLoggedIn' };
@@ -88,8 +88,8 @@
 export async function logout() {
 	store.state = { status: 'loggingOut' };
 
-	await nonLibraryClient.mutation(['auth.logout']);
-	await nonLibraryClient.query(['auth.me']);
+	// await nonLibraryClient.mutation(['auth.logout']);
+	// await nonLibraryClient.query(['auth.me']);
 
 	store.state = { status: 'notLoggedIn' };
 }
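The entire auth.me / auth.loginSession flow is commented out rather than deleted, and supertokens-web-js is added to both package.json files above — suggesting session handling moves into the web layer. A speculative sketch of the replacement check, using standard supertokens-web-js calls but not code taken from this PR:

// Hypothetical: initialize SuperTokens against AUTH_SERVER_URL (from
// interface/util/index.tsx) and derive the auth store state from the session.
import SuperTokens from 'supertokens-web-js';
import Session from 'supertokens-web-js/recipe/session';

SuperTokens.init({
	appInfo: {
		appName: 'Spacedrive', // assumed
		apiDomain: 'https://auth.spacedrive.com' // AUTH_SERVER_URL
	},
	recipeList: [Session.init()]
});

export async function sessionStatus(): Promise<'loggedIn' | 'notLoggedIn'> {
	// doesSessionExist() checks the local session tokens without a round trip
	return (await Session.doesSessionExist()) ? 'loggedIn' : 'notLoggedIn';
}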
diff --git a/packages/client/src/stores/featureFlags.tsx b/packages/client/src/stores/featureFlags.tsx
index 4001c6f23..5b58c9ee7 100644
--- a/packages/client/src/stores/featureFlags.tsx
+++ b/packages/client/src/stores/featureFlags.tsx
@@ -2,8 +2,8 @@ import { useEffect } from 'react';
 import { createMutable } from 'solid-js/store';
 
 import type { BackendFeature } from '../core';
-import { nonLibraryClient, useBridgeQuery } from '../rspc';
-import { createPersistedMutable, useObserver, useSolidStore } from '../solid';
+import { useBridgeQuery } from '../rspc';
+import { createPersistedMutable, useObserver } from '../solid';
 
 export const features = [
 	'backups',
@@ -17,7 +17,7 @@
 
 // This defines which backend feature flags show up in the UI.
 // This is kinda a hack to not having the runtime array of possible features as Specta only exports the types.
-export const backendFeatures: BackendFeature[] = ['cloudSync'];
+export const backendFeatures: BackendFeature[] = [];
 
 export type FeatureFlag = (typeof features)[number] | BackendFeature;
 
@@ -82,7 +82,7 @@ export function toggleFeatureFlag(flags: FeatureFlag | FeatureFlag[]) {
 	);
 
 	if (result) {
-		nonLibraryClient.mutation(['toggleFeatureFlag', f as any]);
+		// nonLibraryClient.mutation(['toggleFeatureFlag', f as any]);
 	}
 })();
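Emptying backendFeatures is forced by the regenerated bindings, not just a cleanup: core.ts now exports BackendFeature as never (see above), and never[] admits only the empty array. A two-line illustration:

// Mirrors the regenerated binding in packages/client/src/core.ts:
type BackendFeature = never;

const none: BackendFeature[] = []; // fine: [] is the only possible value
// const bad: BackendFeature[] = ['cloudSync']; // type error: string is not assignable to never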
diff --git a/packages/config/package.json b/packages/config/package.json
index 90ee1a5fc..fab5fce96 100644
--- a/packages/config/package.json
+++ b/packages/config/package.json
@@ -11,6 +11,7 @@
 		"lint": "eslint . --cache"
 	},
 	"devDependencies": {
+		"@babel/preset-typescript": "^7.24.0",
 		"@typescript-eslint/eslint-plugin": "^8.8.0",
 		"@typescript-eslint/parser": "^8.8.0",
 		"@vitejs/plugin-react-swc": "^3.6.0",
@@ -26,8 +27,8 @@
 		"eslint-utils": "^3.0.0",
 		"regexpp": "^3.2.0",
 		"vite-plugin-html": "^3.2.2",
-		"vite-plugin-i18next-loader": "^2.0.12",
-		"vite-plugin-inspect": "^0.8.3",
+		"vite-plugin-i18next-loader": "^2.0.14",
+		"vite-plugin-inspect": "^0.8.7",
 		"vite-plugin-solid": "^2.10.2",
 		"vite-plugin-svgr": "^3.3.0"
 	},
diff --git a/packages/ui/src/Dialog.tsx b/packages/ui/src/Dialog.tsx
index 6efbfeb3b..2f7d028e9 100644
--- a/packages/ui/src/Dialog.tsx
+++ b/packages/ui/src/Dialog.tsx
@@ -124,6 +124,7 @@ export interface DialogProps
 	onSubmitSecond?: ReturnType>;
 	children?: ReactNode;
 	ctaDanger?: boolean;
+	cancelDanger?: boolean;
 	closeLabel?: string;
 	cancelLabel?: string;
 	cancelBtn?: boolean;
@@ -167,8 +168,9 @@ export function Dialog({
@@ -180,7 +182,7 @@ export function Dialog({