recipes/backend/src/data/db.rs

211 lines
6.1 KiB
Rust

use std::{
fmt,
fs::{self, File},
io::Read,
path::Path,
str::FromStr,
time::Duration,
};
use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous},
Pool, Sqlite, Transaction,
};
use thiserror::Error;
use tracing::{event, Level};
use crate::consts;
pub mod recipe;
pub mod user;
/// Schema version expected by this build of the application. `Connection` applies migration
/// steps one version at a time, recording each in the `Version` table, until this value is reached.
const CURRENT_DB_VERSION: u32 = 1;
/// Errors produced by the database layer.
#[derive(Error, Debug)]
pub enum DBError {
    /// An error reported by sqlx. Converted automatically via `From`, so `?` works on sqlx results.
    #[error("Sqlx error: {0}")]
    Sqlx(#[from] sqlx::Error),
    /// A schema version this application cannot handle. The message also shows
    /// `CURRENT_DB_VERSION`, the version this build expects.
    #[error(
        "Unsupported database version: {0} (application version: {current})",
        current = CURRENT_DB_VERSION
    )]
    UnsupportedVersion(u32),
    /// Any other failure, described by its message (e.g. an SQL file that cannot be opened or read).
    #[error("Unknown error: {0}")]
    Other(String),
}
impl DBError {
fn from_dyn_error(error: Box<dyn std::error::Error>) -> Self {
DBError::Other(error.to_string())
}
}
/// Result type of the database layer, with [`DBError`] as the error type.
type Result<T> = std::result::Result<T, DBError>;
/// Handle to the application's SQLite database, wrapping an sqlx connection pool.
#[derive(Clone)]
pub struct Connection {
    /// Pool from which queries and transactions obtain their connections.
    pool: Pool<Sqlite>,
}
impl Connection {
    /// Opens (creating it if needed) the application database located at
    /// `consts::DB_DIRECTORY` / `consts::DB_FILENAME` and brings its schema up to date.
    ///
    /// # Errors
    /// Same as [`Connection::new_from_file`].
    pub async fn new() -> Result<Connection> {
        let path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME);
        Self::new_from_file(path).await
    }

    /// Opens an in-memory database and creates its schema. Only available in tests.
    #[cfg(test)]
    pub async fn new_in_memory() -> Result<Connection> {
        Self::create_connection(SqlitePoolOptions::new().connect("sqlite::memory:").await?).await
    }

    /// Opens the SQLite database stored in `file` and brings its schema up to date.
    ///
    /// Missing parent directories and the database file itself are created. Connections use
    /// WAL journaling, a 10 s busy timeout, enforced foreign keys and `NORMAL` synchronous mode.
    /// The pool holds at most `consts::MAX_DB_CONNECTION` connections.
    ///
    /// # Errors
    /// - [`DBError::Other`] if the parent directory cannot be created or the path is not
    ///   valid UTF-8.
    /// - [`DBError::Sqlx`] if the database cannot be opened.
    /// - Any error raised while creating or migrating the schema.
    pub async fn new_from_file<P: AsRef<Path>>(file: P) -> Result<Connection> {
        let file = file.as_ref();
        if let Some(data_dir) = file.parent() {
            // `create_dir_all` has three properties we need:
            // - it creates every missing level of the path;
            // - it succeeds when the directory already exists;
            // - it is a no-op for the empty parent of a bare file name such as "db.sqlite".
            // The previous non-recursive `create` + `unwrap` panicked in the last two cases.
            fs::create_dir_all(data_dir).map_err(|err| {
                DBError::Other(format!(
                    "Cannot create database directory ({}): {}",
                    data_dir.display(),
                    err
                ))
            })?;
        }
        // The connection URL must be a `str`; report non-UTF-8 paths instead of panicking.
        let file_str = file.to_str().ok_or_else(|| {
            DBError::Other(format!(
                "Database path is not valid UTF-8: {}",
                file.display()
            ))
        })?;
        let options = SqliteConnectOptions::from_str(&format!("sqlite://{}", file_str))?
            .journal_mode(SqliteJournalMode::Wal) // TODO: use 'Wal2' when available.
            .create_if_missing(true)
            .busy_timeout(Duration::from_secs(10))
            .foreign_keys(true)
            .synchronous(SqliteSynchronous::Normal);
        Self::create_connection(
            SqlitePoolOptions::new()
                .max_connections(consts::MAX_DB_CONNECTION)
                .connect_with(options)
                .await?,
        )
        .await
    }

    /// Wraps `pool` into a [`Connection`] and creates or migrates the schema before returning it.
    async fn create_connection(pool: Pool<Sqlite>) -> Result<Connection> {
        let connection = Connection { pool };
        connection.create_or_update_db().await?;
        Ok(connection)
    }

    /// Begins a new transaction on the pool.
    async fn tx(&self) -> Result<Transaction<'_, Sqlite>> {
        self.pool.begin().await.map_err(DBError::from)
    }

    /// Called after the connection has been established to create or update the database.
    ///
    /// The `Version` table tracks the current state of the database. All migration steps run
    /// in a single transaction, which is only committed once every step has succeeded.
    async fn create_or_update_db(&self) -> Result<()> {
        let mut tx = self.tx().await?;

        // Version 0 corresponds to an empty database, i.e. one without a 'Version' table.
        // Real errors (I/O, locked or corrupt database, ...) are propagated rather than
        // being mistaken for an empty database.
        let version_table_exists = sqlx::query(
            "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
        )
        .fetch_optional(&mut *tx)
        .await?
        .is_some();

        let mut version: u32 = if version_table_exists {
            sqlx::query_scalar("SELECT [version] FROM [Version] ORDER BY [id] DESC")
                .fetch_optional(&mut *tx)
                .await?
                .unwrap_or(0)
        } else {
            0
        };

        while Self::update_to_next_version(version, &mut tx).await? {
            version += 1;
        }
        tx.commit().await?;
        Ok(())
    }

    /// Applies the migration from `current_version` to `current_version + 1`.
    ///
    /// Returns `Ok(true)` when a migration step was applied, and `Ok(false)` when the database
    /// is already at [`CURRENT_DB_VERSION`].
    ///
    /// # Errors
    /// - [`DBError::UnsupportedVersion`] carrying the database's own version when that version
    ///   is newer than [`CURRENT_DB_VERSION`].
    /// - [`DBError::Other`] if no migration step is defined for a version up to
    ///   [`CURRENT_DB_VERSION`], or if the migration SQL file cannot be loaded.
    /// - [`DBError::Sqlx`] if executing the migration fails.
    async fn update_to_next_version(
        current_version: u32,
        tx: &mut Transaction<'_, Sqlite>,
    ) -> Result<bool> {
        // A database written by a newer application: report its actual version.
        if current_version > CURRENT_DB_VERSION {
            return Err(DBError::UnsupportedVersion(current_version));
        }
        if current_version == CURRENT_DB_VERSION {
            return Ok(false);
        }

        // Cannot overflow: `current_version < CURRENT_DB_VERSION` here.
        let next_version = current_version + 1;
        event!(Level::INFO, "Update to version {}...", next_version);

        /// Records `to_version` as a new row of the `Version` table, timestamped with `datetime('now')`.
        async fn update_version(to_version: u32, tx: &mut Transaction<'_, Sqlite>) -> Result<()> {
            sqlx::query(
                "INSERT INTO [Version] ([version], [datetime]) VALUES ($1, datetime('now'))",
            )
            .bind(to_version)
            .execute(&mut **tx)
            .await?;
            Ok(())
        }

        match next_version {
            1 => {
                let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &next_version.to_string());
                sqlx::query(&load_sql_file(&sql_file)?)
                    .execute(&mut **tx)
                    .await?;
                update_version(next_version, tx).await?;
            }
            // Reached only if `CURRENT_DB_VERSION` was bumped without adding its migration arm.
            v => {
                return Err(DBError::Other(format!(
                    "No migration defined to update the database to version {}",
                    v
                )));
            }
        }

        event!(Level::INFO, "Version updated");
        Ok(true)
    }

    /// Executes the content of the SQL file `file` against the pool.
    ///
    /// # Errors
    /// - [`DBError::Other`] if the file cannot be opened or read.
    /// - [`DBError::Sqlx`] if the execution fails.
    pub async fn execute_file<P: AsRef<Path> + fmt::Display>(&self, file: P) -> Result<()> {
        let sql = load_sql_file(file)?;
        sqlx::query(&sql)
            .execute(&self.pool)
            .await
            .map(|_| ())
            .map_err(DBError::from)
    }

    /// Executes `query` on the pool and returns the number of rows it affected.
    ///
    /// # Errors
    /// [`DBError::Sqlx`] if the execution fails.
    pub async fn execute_sql<'a>(
        &self,
        query: sqlx::query::Query<'a, Sqlite, sqlx::sqlite::SqliteArguments<'a>>,
    ) -> Result<u64> {
        query
            .execute(&self.pool)
            .await
            .map(|db_result| db_result.rows_affected())
            .map_err(DBError::from)
    }
}
fn load_sql_file<P: AsRef<Path> + fmt::Display>(sql_file: P) -> Result<String> {
let mut file = File::open(&sql_file)
.map_err(|err| DBError::Other(format!("Cannot open SQL file ({}): {}", &sql_file, err)))?;
let mut sql = String::new();
file.read_to_string(&mut sql)
.map_err(|err| DBError::Other(format!("Cannot read SQL file ({}) : {}", &sql_file, err)))?;
Ok(sql)
}