mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-06 15:04:11 -04:00
sqlite: Return an error when DB version is invalid or missing (#2038)
Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
@@ -37,12 +37,14 @@ use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId};
|
||||
use rusqlite::OptionalExtension;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use tokio::{fs, sync::Mutex};
|
||||
use tracing::{debug, error, instrument, warn};
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
get_or_create_store_cipher,
|
||||
utils::{Key, SqliteConnectionExt as _, SqliteObjectExt, SqliteObjectStoreExt as _},
|
||||
utils::{
|
||||
load_db_version, Key, SqliteConnectionExt as _, SqliteObjectExt, SqliteObjectStoreExt as _,
|
||||
},
|
||||
OpenStoreError,
|
||||
};
|
||||
|
||||
@@ -98,7 +100,8 @@ impl SqliteCryptoStore {
|
||||
passphrase: Option<&str>,
|
||||
) -> Result<Self, OpenStoreError> {
|
||||
let conn = pool.get().await?;
|
||||
run_migrations(&conn).await.map_err(OpenStoreError::Migration)?;
|
||||
let version = load_db_version(&conn).await?;
|
||||
run_migrations(&conn, version).await.map_err(OpenStoreError::Migration)?;
|
||||
let store_cipher = match passphrase {
|
||||
Some(p) => Some(Arc::new(get_or_create_store_cipher(p, &conn).await?)),
|
||||
None => None,
|
||||
@@ -192,36 +195,14 @@ impl SqliteCryptoStore {
|
||||
|
||||
const DATABASE_VERSION: u8 = 6;
|
||||
|
||||
async fn run_migrations(conn: &SqliteConn) -> rusqlite::Result<()> {
|
||||
let kv_exists = conn
|
||||
.query_row(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = 'kv'",
|
||||
(),
|
||||
|row| row.get::<_, u32>(0),
|
||||
)
|
||||
.await?
|
||||
> 0;
|
||||
|
||||
let version = if kv_exists {
|
||||
match conn.get_kv("version").await?.as_deref() {
|
||||
Some([v]) => *v,
|
||||
Some(_) => {
|
||||
error!("version database field has multiple bytes");
|
||||
return Ok(());
|
||||
}
|
||||
None => {
|
||||
error!("version database field is missing");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
/// Run migrations for the given version of the database.
|
||||
async fn run_migrations(conn: &SqliteConn, version: u8) -> rusqlite::Result<()> {
|
||||
if version == 0 {
|
||||
debug!("Creating database");
|
||||
} else if version < DATABASE_VERSION {
|
||||
debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if version < 1 {
|
||||
|
||||
@@ -32,6 +32,18 @@ pub enum OpenStoreError {
|
||||
#[error(transparent)]
|
||||
CreatePool(#[from] CreatePoolError),
|
||||
|
||||
/// Failed to load the database's version.
|
||||
#[error("Failed to load database version")]
|
||||
LoadVersion(#[source] rusqlite::Error),
|
||||
|
||||
/// The version of the database is missing.
|
||||
#[error("Missing database version")]
|
||||
MissingVersion,
|
||||
|
||||
/// The version of the database is invalid.
|
||||
#[error("Invalid database version")]
|
||||
InvalidVersion,
|
||||
|
||||
/// Failed to apply migrations.
|
||||
#[error("Failed to run migrations")]
|
||||
Migration(#[source] rusqlite::Error),
|
||||
|
||||
@@ -30,12 +30,12 @@ use ruma::{
|
||||
use rusqlite::{OptionalExtension, Transaction};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio::fs;
|
||||
use tracing::{debug, error, warn};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
get_or_create_store_cipher,
|
||||
utils::{chain, Key, SqliteObjectExt},
|
||||
utils::{chain, load_db_version, Key, SqliteObjectExt},
|
||||
OpenStoreError, SqliteObjectStoreExt,
|
||||
};
|
||||
|
||||
@@ -93,7 +93,8 @@ impl SqliteStateStore {
|
||||
passphrase: Option<&str>,
|
||||
) -> Result<Self, OpenStoreError> {
|
||||
let conn = pool.get().await?;
|
||||
run_migrations(&conn).await.map_err(OpenStoreError::Migration)?;
|
||||
let version = load_db_version(&conn).await?;
|
||||
run_migrations(&conn, version).await.map_err(OpenStoreError::Migration)?;
|
||||
let store_cipher = match passphrase {
|
||||
Some(p) => Some(Arc::new(get_or_create_store_cipher(p, &conn).await?)),
|
||||
None => None,
|
||||
@@ -197,36 +198,13 @@ impl SqliteStateStore {
|
||||
|
||||
const DATABASE_VERSION: u8 = 1;
|
||||
|
||||
async fn run_migrations(conn: &SqliteConn) -> rusqlite::Result<()> {
|
||||
let kv_exists = conn
|
||||
.query_row(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = 'kv'",
|
||||
(),
|
||||
|row| row.get::<_, u32>(0),
|
||||
)
|
||||
.await?
|
||||
> 0;
|
||||
|
||||
let version = if kv_exists {
|
||||
match conn.get_kv("version").await?.as_deref() {
|
||||
Some([v]) => *v,
|
||||
Some(_) => {
|
||||
error!("version database field has multiple bytes");
|
||||
return Ok(());
|
||||
}
|
||||
None => {
|
||||
error!("version database field is missing");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
async fn run_migrations(conn: &SqliteConn, version: u8) -> rusqlite::Result<()> {
|
||||
if version == 0 {
|
||||
debug!("Creating database");
|
||||
} else if version < DATABASE_VERSION {
|
||||
debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if version < 1 {
|
||||
|
||||
@@ -17,6 +17,8 @@ use std::ops::Deref;
|
||||
use async_trait::async_trait;
|
||||
use rusqlite::{OptionalExtension, Params, Row, Statement, Transaction};
|
||||
|
||||
use crate::OpenStoreError;
|
||||
|
||||
pub(crate) fn chain<T>(
|
||||
it1: impl IntoIterator<Item = T>,
|
||||
it2: impl IntoIterator<Item = T>,
|
||||
@@ -181,3 +183,26 @@ impl SqliteObjectStoreExt for deadpool_sqlite::Object {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the version of the database with the given connection.
|
||||
pub(crate) async fn load_db_version(conn: &deadpool_sqlite::Object) -> Result<u8, OpenStoreError> {
|
||||
let kv_exists = conn
|
||||
.query_row(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = 'kv'",
|
||||
(),
|
||||
|row| row.get::<_, u32>(0),
|
||||
)
|
||||
.await
|
||||
.map_err(OpenStoreError::LoadVersion)?
|
||||
> 0;
|
||||
|
||||
if kv_exists {
|
||||
match conn.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
|
||||
Some([v]) => Ok(*v),
|
||||
Some(_) => Err(OpenStoreError::InvalidVersion),
|
||||
None => Err(OpenStoreError::MissingVersion),
|
||||
}
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user