use std::{ fmt, fs::{self, File}, io::Read, path::Path, str::FromStr, time::Duration, }; use sqlx::{ sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous}, Pool, Sqlite, Transaction, }; use thiserror::Error; use tracing::{event, Level}; use crate::consts; pub mod recipe; pub mod user; const CURRENT_DB_VERSION: u32 = 1; #[derive(Error, Debug)] pub enum DBError { #[error("Sqlx error: {0}")] Sqlx(#[from] sqlx::Error), #[error( "Unsupported database version: {0} (application version: {current})", current = CURRENT_DB_VERSION )] UnsupportedVersion(u32), #[error("Unknown error: {0}")] Other(String), } impl DBError { fn from_dyn_error(error: Box) -> Self { DBError::Other(error.to_string()) } } type Result = std::result::Result; #[derive(Clone)] pub struct Connection { pool: Pool, } impl Connection { pub async fn new() -> Result { let path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME); Self::new_from_file(path).await } #[cfg(test)] pub async fn new_in_memory() -> Result { Self::create_connection(SqlitePoolOptions::new().connect("sqlite::memory:").await?).await } pub async fn new_from_file>(file: P) -> Result { if let Some(data_dir) = file.as_ref().parent() { if !data_dir.exists() { fs::DirBuilder::new().create(data_dir).unwrap(); } } let options = SqliteConnectOptions::from_str(&format!( "sqlite://{}", file.as_ref().to_str().unwrap() ))? .journal_mode(SqliteJournalMode::Wal) // TODO: use 'Wal2' when available. .create_if_missing(true) .busy_timeout(Duration::from_secs(10)) .foreign_keys(true) .synchronous(SqliteSynchronous::Normal); Self::create_connection( SqlitePoolOptions::new() .max_connections(consts::MAX_DB_CONNECTION) .connect_with(options) .await?, ) .await } async fn create_connection(pool: Pool) -> Result { let connection = Connection { pool }; connection.create_or_update_db().await?; Ok(connection) } async fn tx(&self) -> Result> { self.pool.begin().await.map_err(DBError::from) } /// Called after the connection has been established for creating or updating the database. /// The 'Version' table tracks the current state of the database. async fn create_or_update_db(&self) -> Result<()> { let mut tx = self.tx().await?; //con.transaction()?; // Check current database version. (Version 0 corresponds to an empty database). let mut version = match sqlx::query( r#" SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version' "#, ) .fetch_one(&mut *tx) .await { Ok(_) => sqlx::query_scalar("SELECT [version] FROM [Version] ORDER BY [id] DESC") .fetch_optional(&mut *tx) .await? .unwrap_or(0), Err(_) => 0, // If the database doesn't exist. }; while Self::update_to_next_version(version, &mut tx).await? { version += 1; } tx.commit().await?; Ok(()) } async fn update_to_next_version( current_version: u32, tx: &mut Transaction<'_, Sqlite>, ) -> Result { let next_version = current_version + 1; if next_version <= CURRENT_DB_VERSION { event!(Level::INFO, "Update to version {}...", next_version); } async fn update_version(to_version: u32, tx: &mut Transaction<'_, Sqlite>) -> Result<()> { sqlx::query( "INSERT INTO [Version] ([version], [datetime]) VALUES ($1, datetime('now'))", ) .bind(to_version) .execute(&mut **tx) .await?; Ok(()) } fn ok(updated: bool) -> Result { if updated { event!(Level::INFO, "Version updated"); } Ok(updated) } match next_version { 1 => { let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &next_version.to_string()); sqlx::query(&load_sql_file(&sql_file)?) .execute(&mut **tx) .await?; update_version(next_version, tx).await?; ok(true) } // Version 2 doesn't exist yet. 2 => ok(false), v => Err(DBError::UnsupportedVersion(v)), } } /// Execute a given SQL file. pub async fn execute_file + fmt::Display>(&self, file: P) -> Result<()> { let sql = load_sql_file(file)?; sqlx::query(&sql) .execute(&self.pool) .await .map(|_| ()) .map_err(DBError::from) } pub async fn execute_sql<'a>( &self, query: sqlx::query::Query<'a, Sqlite, sqlx::sqlite::SqliteArguments<'a>>, ) -> Result { query .execute(&self.pool) .await .map(|db_result| db_result.rows_affected()) .map_err(DBError::from) } // pub async fn execute_sql_and_fetch_all<'a>( // &self, // query: sqlx::query::Query<'a, Sqlite, sqlx::sqlite::SqliteArguments<'a>>, // ) -> Result> { // query.fetch_all(&self.pool).await.map_err(DBError::from) // } } fn load_sql_file + fmt::Display>(sql_file: P) -> Result { let mut file = File::open(&sql_file) .map_err(|err| DBError::Other(format!("Cannot open SQL file ({}): {}", &sql_file, err)))?; let mut sql = String::new(); file.read_to_string(&mut sql) .map_err(|err| DBError::Other(format!("Cannot read SQL file ({}) : {}", &sql_file, err)))?; Ok(sql) }