sqlite: Return an error when DB version is invalid or missing (#2038)

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2023-06-09 13:02:33 +02:00
committed by GitHub
parent 0e30d85df4
commit ae2f834345
4 changed files with 54 additions and 58 deletions

View File

@@ -37,12 +37,14 @@ use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId};
use rusqlite::OptionalExtension;
use serde::{de::DeserializeOwned, Serialize};
use tokio::{fs, sync::Mutex};
use tracing::{debug, error, instrument, warn};
use tracing::{debug, instrument, warn};
use crate::{
error::{Error, Result},
get_or_create_store_cipher,
utils::{Key, SqliteConnectionExt as _, SqliteObjectExt, SqliteObjectStoreExt as _},
utils::{
load_db_version, Key, SqliteConnectionExt as _, SqliteObjectExt, SqliteObjectStoreExt as _,
},
OpenStoreError,
};
@@ -98,7 +100,8 @@ impl SqliteCryptoStore {
passphrase: Option<&str>,
) -> Result<Self, OpenStoreError> {
let conn = pool.get().await?;
run_migrations(&conn).await.map_err(OpenStoreError::Migration)?;
let version = load_db_version(&conn).await?;
run_migrations(&conn, version).await.map_err(OpenStoreError::Migration)?;
let store_cipher = match passphrase {
Some(p) => Some(Arc::new(get_or_create_store_cipher(p, &conn).await?)),
None => None,
@@ -192,36 +195,14 @@ impl SqliteCryptoStore {
const DATABASE_VERSION: u8 = 6;
async fn run_migrations(conn: &SqliteConn) -> rusqlite::Result<()> {
let kv_exists = conn
.query_row(
"SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = 'kv'",
(),
|row| row.get::<_, u32>(0),
)
.await?
> 0;
let version = if kv_exists {
match conn.get_kv("version").await?.as_deref() {
Some([v]) => *v,
Some(_) => {
error!("version database field has multiple bytes");
return Ok(());
}
None => {
error!("version database field is missing");
return Ok(());
}
}
} else {
0
};
/// Run migrations for the given version of the database.
async fn run_migrations(conn: &SqliteConn, version: u8) -> rusqlite::Result<()> {
if version == 0 {
debug!("Creating database");
} else if version < DATABASE_VERSION {
debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
} else {
return Ok(());
}
if version < 1 {

View File

@@ -32,6 +32,18 @@ pub enum OpenStoreError {
#[error(transparent)]
CreatePool(#[from] CreatePoolError),
/// Failed to load the database's version.
#[error("Failed to load database version")]
LoadVersion(#[source] rusqlite::Error),
/// The version of the database is missing.
#[error("Missing database version")]
MissingVersion,
/// The version of the database is invalid.
#[error("Invalid database version")]
InvalidVersion,
/// Failed to apply migrations.
#[error("Failed to run migrations")]
Migration(#[source] rusqlite::Error),

View File

@@ -30,12 +30,12 @@ use ruma::{
use rusqlite::{OptionalExtension, Transaction};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::fs;
use tracing::{debug, error, warn};
use tracing::{debug, warn};
use crate::{
error::{Error, Result},
get_or_create_store_cipher,
utils::{chain, Key, SqliteObjectExt},
utils::{chain, load_db_version, Key, SqliteObjectExt},
OpenStoreError, SqliteObjectStoreExt,
};
@@ -93,7 +93,8 @@ impl SqliteStateStore {
passphrase: Option<&str>,
) -> Result<Self, OpenStoreError> {
let conn = pool.get().await?;
run_migrations(&conn).await.map_err(OpenStoreError::Migration)?;
let version = load_db_version(&conn).await?;
run_migrations(&conn, version).await.map_err(OpenStoreError::Migration)?;
let store_cipher = match passphrase {
Some(p) => Some(Arc::new(get_or_create_store_cipher(p, &conn).await?)),
None => None,
@@ -197,36 +198,13 @@ impl SqliteStateStore {
const DATABASE_VERSION: u8 = 1;
async fn run_migrations(conn: &SqliteConn) -> rusqlite::Result<()> {
let kv_exists = conn
.query_row(
"SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = 'kv'",
(),
|row| row.get::<_, u32>(0),
)
.await?
> 0;
let version = if kv_exists {
match conn.get_kv("version").await?.as_deref() {
Some([v]) => *v,
Some(_) => {
error!("version database field has multiple bytes");
return Ok(());
}
None => {
error!("version database field is missing");
return Ok(());
}
}
} else {
0
};
async fn run_migrations(conn: &SqliteConn, version: u8) -> rusqlite::Result<()> {
if version == 0 {
debug!("Creating database");
} else if version < DATABASE_VERSION {
debug!(version, new_version = DATABASE_VERSION, "Upgrading database");
} else {
return Ok(());
}
if version < 1 {

View File

@@ -17,6 +17,8 @@ use std::ops::Deref;
use async_trait::async_trait;
use rusqlite::{OptionalExtension, Params, Row, Statement, Transaction};
use crate::OpenStoreError;
pub(crate) fn chain<T>(
it1: impl IntoIterator<Item = T>,
it2: impl IntoIterator<Item = T>,
@@ -181,3 +183,26 @@ impl SqliteObjectStoreExt for deadpool_sqlite::Object {
Ok(())
}
}
/// Load the version of the database with the given connection.
pub(crate) async fn load_db_version(conn: &deadpool_sqlite::Object) -> Result<u8, OpenStoreError> {
let kv_exists = conn
.query_row(
"SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = 'kv'",
(),
|row| row.get::<_, u32>(0),
)
.await
.map_err(OpenStoreError::LoadVersion)?
> 0;
if kv_exists {
match conn.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
Some([v]) => Ok(*v),
Some(_) => Err(OpenStoreError::InvalidVersion),
None => Err(OpenStoreError::MissingVersion),
}
} else {
Ok(0)
}
}