mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-04-22 16:17:31 -04:00
feat(sqlite): SqliteCryptoStore has 1 write connection.
Similarly to #5382 and #5744, this patch introduces a write-only connection in `SqliteCryptoStore`. The idea is to get many read-only connections, and a single write-only connection behind a lock, so that there is a single writer at a time. This patch renames the `acquire` method to `read`, and it introduces a new `write` connection.
This commit is contained in:
@@ -41,7 +41,10 @@ use ruma::{
|
||||
events::secret::request::SecretName,
|
||||
};
|
||||
use rusqlite::{OptionalExtension, named_params, params_from_iter};
|
||||
use tokio::{fs, sync::Mutex};
|
||||
use tokio::{
|
||||
fs,
|
||||
sync::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use tracing::{debug, instrument, warn};
|
||||
use vodozemac::Curve25519PublicKey;
|
||||
|
||||
@@ -62,8 +65,16 @@ const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3";
|
||||
#[derive(Clone)]
|
||||
pub struct SqliteCryptoStore {
|
||||
store_cipher: Option<Arc<StoreCipher>>,
|
||||
|
||||
/// The pool of connections.
|
||||
pool: SqlitePool,
|
||||
|
||||
/// We make the difference between connections for read operations, and for
|
||||
/// write operations. We keep a single connection apart from write
|
||||
/// operations. All other connections are used for read operations. The
|
||||
/// lock is used to ensure there is one owner at a time.
|
||||
write_connection: Arc<Mutex<SqliteAsyncConn>>,
|
||||
|
||||
// DB values cached in memory
|
||||
static_account: Arc<RwLock<Option<StaticAccountData>>>,
|
||||
save_changes_lock: Arc<Mutex<()>>,
|
||||
@@ -135,6 +146,8 @@ impl SqliteCryptoStore {
|
||||
Ok(SqliteCryptoStore {
|
||||
store_cipher,
|
||||
pool,
|
||||
// Use `conn` as our selected write connections.
|
||||
write_connection: Arc::new(Mutex::new(conn)),
|
||||
static_account: Arc::new(RwLock::new(None)),
|
||||
save_changes_lock: Default::default(),
|
||||
})
|
||||
@@ -168,9 +181,17 @@ impl SqliteCryptoStore {
|
||||
self.static_account.read().unwrap().clone()
|
||||
}
|
||||
|
||||
async fn acquire(&self) -> Result<SqliteAsyncConn> {
|
||||
/// Acquire a connection for executing read operations.
|
||||
#[instrument(skip_all)]
|
||||
async fn read(&self) -> Result<SqliteAsyncConn> {
|
||||
Ok(self.pool.get().await?)
|
||||
}
|
||||
|
||||
/// Acquire a connection for executing write operations.
|
||||
#[instrument(skip_all)]
|
||||
async fn write(&self) -> OwnedMutexGuard<SqliteAsyncConn> {
|
||||
self.write_connection.clone().lock_owned().await
|
||||
}
|
||||
}
|
||||
|
||||
const DATABASE_VERSION: u8 = 14;
|
||||
@@ -855,7 +876,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
type Error = Error;
|
||||
|
||||
async fn load_account(&self) -> Result<Option<Account>> {
|
||||
let conn = self.acquire().await?;
|
||||
let conn = self.read().await?;
|
||||
if let Some(pickle) = conn.get_kv("account").await? {
|
||||
let pickle = self.deserialize_value(&pickle)?;
|
||||
|
||||
@@ -870,7 +891,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
}
|
||||
|
||||
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
||||
let conn = self.acquire().await?;
|
||||
let conn = self.read().await?;
|
||||
if let Some(i) = conn.get_kv("identity").await? {
|
||||
let pickle = self.deserialize_value(&i)?;
|
||||
Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?))
|
||||
@@ -894,8 +915,8 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
};
|
||||
|
||||
let this = self.clone();
|
||||
self.acquire()
|
||||
.await?
|
||||
self.write()
|
||||
.await
|
||||
.with_transaction(move |txn| {
|
||||
if let Some(pickled_account) = pickled_account {
|
||||
let serialized_account = this.serialize_value(&pickled_account)?;
|
||||
@@ -946,8 +967,8 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
}
|
||||
|
||||
let this = self.clone();
|
||||
self.acquire()
|
||||
.await?
|
||||
self.write()
|
||||
.await
|
||||
.with_transaction(move |txn| {
|
||||
if let Some(pickled_private_identity) = &pickled_private_identity {
|
||||
let serialized_private_identity =
|
||||
@@ -1094,7 +1115,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
let device_keys = self.get_own_device().await?.as_device_keys().clone();
|
||||
|
||||
let sessions: Vec<_> = self
|
||||
.acquire()
|
||||
.read()
|
||||
.await?
|
||||
.get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes()))
|
||||
.await?
|
||||
@@ -1116,7 +1137,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
) -> Result<Option<InboundGroupSession>> {
|
||||
let session_id = self.encode_key("inbound_group_session", session_id);
|
||||
let Some((room_id_from_db, value, backed_up)) =
|
||||
self.acquire().await?.get_inbound_group_session(session_id).await?
|
||||
self.read().await?.get_inbound_group_session(session_id).await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
@@ -1131,7 +1152,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
}
|
||||
|
||||
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_inbound_group_sessions()
|
||||
.await?
|
||||
@@ -1147,7 +1168,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
room_id: &RoomId,
|
||||
) -> Result<Vec<InboundGroupSession>> {
|
||||
let room_id = self.encode_key("inbound_group_session", room_id.as_bytes());
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_inbound_group_sessions_by_room_id(room_id)
|
||||
.await?
|
||||
@@ -1169,7 +1190,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id));
|
||||
let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64());
|
||||
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_inbound_group_sessions_for_device_batch(
|
||||
sender_key,
|
||||
@@ -1189,7 +1210,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
&self,
|
||||
backup_version: Option<&str>,
|
||||
) -> Result<RoomKeyCounts> {
|
||||
Ok(self.acquire().await?.get_inbound_group_session_counts(backup_version).await?)
|
||||
Ok(self.read().await?.get_inbound_group_session_counts(backup_version).await?)
|
||||
}
|
||||
|
||||
async fn inbound_group_sessions_for_backup(
|
||||
@@ -1197,7 +1218,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
_backup_version: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<InboundGroupSession>> {
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_inbound_group_sessions_for_backup(limit)
|
||||
.await?
|
||||
@@ -1212,8 +1233,8 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
session_ids: &[(&RoomId, &str)],
|
||||
) -> Result<()> {
|
||||
Ok(self
|
||||
.acquire()
|
||||
.await?
|
||||
.write()
|
||||
.await
|
||||
.mark_inbound_group_sessions_as_backed_up(
|
||||
session_ids
|
||||
.iter()
|
||||
@@ -1224,11 +1245,11 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
}
|
||||
|
||||
async fn reset_backup_state(&self) -> Result<()> {
|
||||
Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?)
|
||||
Ok(self.write().await.reset_inbound_group_session_backup_state().await?)
|
||||
}
|
||||
|
||||
async fn load_backup_keys(&self) -> Result<BackupKeys> {
|
||||
let conn = self.acquire().await?;
|
||||
let conn = self.read().await?;
|
||||
|
||||
let backup_version = conn
|
||||
.get_kv("backup_version_v1")
|
||||
@@ -1246,7 +1267,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
}
|
||||
|
||||
async fn load_dehydrated_device_pickle_key(&self) -> Result<Option<DehydratedDeviceKey>> {
|
||||
let conn = self.acquire().await?;
|
||||
let conn = self.read().await?;
|
||||
|
||||
conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY)
|
||||
.await?
|
||||
@@ -1255,17 +1276,14 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
}
|
||||
|
||||
async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> {
|
||||
let conn = self.acquire().await?;
|
||||
conn.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?;
|
||||
|
||||
Ok(())
|
||||
Ok(self.write().await.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?)
|
||||
}
|
||||
async fn get_outbound_group_session(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<OutboundGroupSession>> {
|
||||
let room_id = self.encode_key("outbound_group_session", room_id.as_bytes());
|
||||
let Some(value) = self.acquire().await?.get_outbound_group_session(room_id).await? else {
|
||||
let Some(value) = self.read().await?.get_outbound_group_session(room_id).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
@@ -1283,7 +1301,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
}
|
||||
|
||||
async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_tracked_users()
|
||||
.await?
|
||||
@@ -1303,7 +1321,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
Ok(self.acquire().await?.add_tracked_users(users).await?)
|
||||
Ok(self.write().await.add_tracked_users(users).await?)
|
||||
}
|
||||
|
||||
async fn get_device(
|
||||
@@ -1314,7 +1332,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
let user_id = self.encode_key("device", user_id.as_bytes());
|
||||
let device_id = self.encode_key("device", device_id.as_bytes());
|
||||
Ok(self
|
||||
.acquire()
|
||||
.read()
|
||||
.await?
|
||||
.get_device(user_id, device_id)
|
||||
.await?
|
||||
@@ -1327,7 +1345,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
user_id: &UserId,
|
||||
) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
|
||||
let user_id = self.encode_key("device", user_id.as_bytes());
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_user_devices(user_id)
|
||||
.await?
|
||||
@@ -1351,7 +1369,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
|
||||
let user_id = self.encode_key("identity", user_id.as_bytes());
|
||||
Ok(self
|
||||
.acquire()
|
||||
.read()
|
||||
.await?
|
||||
.get_user_identity(user_id)
|
||||
.await?
|
||||
@@ -1364,7 +1382,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
|
||||
) -> Result<bool> {
|
||||
let value = rmp_serde::to_vec(message_hash)?;
|
||||
Ok(self.acquire().await?.has_olm_hash(value).await?)
|
||||
Ok(self.read().await?.has_olm_hash(value).await?)
|
||||
}
|
||||
|
||||
async fn get_outgoing_secret_requests(
|
||||
@@ -1373,7 +1391,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
) -> Result<Option<GossipRequest>> {
|
||||
let request_id = self.encode_key("key_requests", request_id.as_bytes());
|
||||
Ok(self
|
||||
.acquire()
|
||||
.read()
|
||||
.await?
|
||||
.get_outgoing_secret_request(request_id)
|
||||
.await?
|
||||
@@ -1385,7 +1403,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
&self,
|
||||
key_info: &SecretInfo,
|
||||
) -> Result<Option<GossipRequest>> {
|
||||
let requests = self.acquire().await?.get_outgoing_secret_requests().await?;
|
||||
let requests = self.read().await?.get_outgoing_secret_requests().await?;
|
||||
for (request, sent_out) in requests {
|
||||
let request = self.deserialize_key_request(&request, sent_out)?;
|
||||
if request.info == *key_info {
|
||||
@@ -1396,7 +1414,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
}
|
||||
|
||||
async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_unsent_secret_requests()
|
||||
.await?
|
||||
@@ -1410,7 +1428,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
|
||||
async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
|
||||
let request_id = self.encode_key("key_requests", request_id.as_bytes());
|
||||
Ok(self.acquire().await?.delete_key_request(request_id).await?)
|
||||
Ok(self.write().await.delete_key_request(request_id).await?)
|
||||
}
|
||||
|
||||
async fn get_secrets_from_inbox(
|
||||
@@ -1419,7 +1437,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
) -> Result<Vec<GossippedSecret>> {
|
||||
let secret_name = self.encode_key("secrets", secret_name.to_string());
|
||||
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_secrets_from_inbox(secret_name)
|
||||
.await?
|
||||
@@ -1430,7 +1448,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
|
||||
async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
|
||||
let secret_name = self.encode_key("secrets", secret_name.to_string());
|
||||
self.acquire().await?.delete_secrets_from_inbox(secret_name).await
|
||||
self.write().await.delete_secrets_from_inbox(secret_name).await
|
||||
}
|
||||
|
||||
async fn get_withheld_info(
|
||||
@@ -1441,7 +1459,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
let room_id = self.encode_key("direct_withheld_info", room_id);
|
||||
let session_id = self.encode_key("direct_withheld_info", session_id);
|
||||
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_direct_withheld_info(session_id, room_id)
|
||||
.await?
|
||||
@@ -1458,7 +1476,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
) -> matrix_sdk_crypto::store::Result<Vec<RoomKeyWithheldEntry>, Self::Error> {
|
||||
let room_id = self.encode_key("direct_withheld_info", room_id);
|
||||
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_withheld_sessions_by_room_id(room_id)
|
||||
.await?
|
||||
@@ -1469,7 +1487,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
|
||||
async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
|
||||
let room_id = self.encode_key("room_settings", room_id.as_bytes());
|
||||
let Some(value) = self.acquire().await?.get_room_settings(room_id).await? else {
|
||||
let Some(value) = self.read().await?.get_room_settings(room_id).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
@@ -1485,7 +1503,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
) -> Result<Option<StoredRoomKeyBundleData>> {
|
||||
let room_id = self.encode_key("received_room_key_bundle", room_id);
|
||||
let user_id = self.encode_key("received_room_key_bundle", user_id);
|
||||
self.acquire()
|
||||
self.read()
|
||||
.await?
|
||||
.get_received_room_key_bundle(room_id, user_id)
|
||||
.await?
|
||||
@@ -1495,11 +1513,11 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
|
||||
async fn has_downloaded_all_room_keys(&self, room_id: &RoomId) -> Result<bool> {
|
||||
let room_id = self.encode_key("room_key_backups_fully_downloaded", room_id);
|
||||
self.acquire().await?.has_downloaded_all_room_keys(room_id).await
|
||||
self.read().await?.has_downloaded_all_room_keys(room_id).await
|
||||
}
|
||||
|
||||
async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
|
||||
let Some(serialized) = self.acquire().await?.get_kv(key).await? else {
|
||||
let Some(serialized) = self.read().await?.get_kv(key).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let value = if let Some(cipher) = &self.store_cipher {
|
||||
@@ -1520,14 +1538,14 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
value
|
||||
};
|
||||
|
||||
self.acquire().await?.set_kv(key, serialized).await?;
|
||||
self.write().await.set_kv(key, serialized).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_custom_value(&self, key: &str) -> Result<()> {
|
||||
let key = key.to_owned();
|
||||
self.acquire()
|
||||
.await?
|
||||
self.write()
|
||||
.await
|
||||
.interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,)))
|
||||
.await
|
||||
.unwrap()?;
|
||||
@@ -1549,8 +1567,8 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
|
||||
// Learn about the `excluded` keyword in https://sqlite.org/lang_upsert.html.
|
||||
let generation = self
|
||||
.acquire()
|
||||
.await?
|
||||
.write()
|
||||
.await
|
||||
.with_transaction(move |txn| {
|
||||
txn.query_row(
|
||||
"INSERT INTO lease_locks (key, holder, expiration)
|
||||
@@ -1581,7 +1599,7 @@ impl CryptoStore for SqliteCryptoStore {
|
||||
}
|
||||
|
||||
async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
|
||||
let conn = self.acquire().await?;
|
||||
let conn = self.read().await?;
|
||||
if let Some(token) = conn.get_kv("next_batch_token").await? {
|
||||
let maybe_token: Option<String> = self.deserialize_value(&token)?;
|
||||
Ok(maybe_token)
|
||||
|
||||
Reference in New Issue
Block a user