211 lines
6.1 KiB
Rust
211 lines
6.1 KiB
Rust
use std::{
|
|
fmt,
|
|
fs::{self, File},
|
|
io::Read,
|
|
path::Path,
|
|
str::FromStr,
|
|
time::Duration,
|
|
};
|
|
|
|
use sqlx::{
|
|
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous},
|
|
Pool, Sqlite, Transaction,
|
|
};
|
|
use thiserror::Error;
|
|
use tracing::{event, Level};
|
|
|
|
use crate::consts;
|
|
|
|
pub mod recipe;
|
|
pub mod user;
|
|
|
|
/// Latest database schema version known to this build of the application.
///
/// Used as the upper bound when logging schema migrations and reported as the
/// "application version" in `DBError::UnsupportedVersion` messages.
const CURRENT_DB_VERSION: u32 = 1;
|
|
|
|
#[derive(Error, Debug)]
|
|
pub enum DBError {
|
|
#[error("Sqlx error: {0}")]
|
|
Sqlx(#[from] sqlx::Error),
|
|
|
|
#[error(
|
|
"Unsupported database version: {0} (application version: {current})",
|
|
current = CURRENT_DB_VERSION
|
|
)]
|
|
UnsupportedVersion(u32),
|
|
|
|
#[error("Unknown error: {0}")]
|
|
Other(String),
|
|
}
|
|
|
|
impl DBError {
|
|
fn from_dyn_error(error: Box<dyn std::error::Error>) -> Self {
|
|
DBError::Other(error.to_string())
|
|
}
|
|
}
|
|
|
|
/// Shorthand for results whose error type is [`DBError`].
type Result<T> = core::result::Result<T, DBError>;
|
|
|
|
#[derive(Clone)]
|
|
pub struct Connection {
|
|
pool: Pool<Sqlite>,
|
|
}
|
|
|
|
impl Connection {
|
|
pub async fn new() -> Result<Connection> {
|
|
let path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME);
|
|
Self::new_from_file(path).await
|
|
}
|
|
|
|
#[cfg(test)]
|
|
pub async fn new_in_memory() -> Result<Connection> {
|
|
Self::create_connection(SqlitePoolOptions::new().connect("sqlite::memory:").await?).await
|
|
}
|
|
|
|
pub async fn new_from_file<P: AsRef<Path>>(file: P) -> Result<Connection> {
|
|
if let Some(data_dir) = file.as_ref().parent() {
|
|
if !data_dir.exists() {
|
|
fs::DirBuilder::new().create(data_dir).unwrap();
|
|
}
|
|
}
|
|
|
|
let options = SqliteConnectOptions::from_str(&format!(
|
|
"sqlite://{}",
|
|
file.as_ref().to_str().unwrap()
|
|
))?
|
|
.journal_mode(SqliteJournalMode::Wal) // TODO: use 'Wal2' when available.
|
|
.create_if_missing(true)
|
|
.busy_timeout(Duration::from_secs(10))
|
|
.foreign_keys(true)
|
|
.synchronous(SqliteSynchronous::Normal);
|
|
|
|
Self::create_connection(
|
|
SqlitePoolOptions::new()
|
|
.max_connections(consts::MAX_DB_CONNECTION)
|
|
.connect_with(options)
|
|
.await?,
|
|
)
|
|
.await
|
|
}
|
|
|
|
async fn create_connection(pool: Pool<Sqlite>) -> Result<Connection> {
|
|
let connection = Connection { pool };
|
|
connection.create_or_update_db().await?;
|
|
Ok(connection)
|
|
}
|
|
|
|
async fn tx(&self) -> Result<Transaction<Sqlite>> {
|
|
self.pool.begin().await.map_err(DBError::from)
|
|
}
|
|
|
|
/// Called after the connection has been established for creating or updating the database.
|
|
/// The 'Version' table tracks the current state of the database.
|
|
async fn create_or_update_db(&self) -> Result<()> {
|
|
let mut tx = self.tx().await?; //con.transaction()?;
|
|
|
|
// Check current database version. (Version 0 corresponds to an empty database).
|
|
let mut version = match sqlx::query(
|
|
r#"
|
|
SELECT [name] FROM [sqlite_master]
|
|
WHERE [type] = 'table' AND [name] = 'Version'
|
|
"#,
|
|
)
|
|
.fetch_one(&mut *tx)
|
|
.await
|
|
{
|
|
Ok(_) => sqlx::query_scalar("SELECT [version] FROM [Version] ORDER BY [id] DESC")
|
|
.fetch_optional(&mut *tx)
|
|
.await?
|
|
.unwrap_or(0),
|
|
Err(_) => 0, // If the database doesn't exist.
|
|
};
|
|
|
|
while Self::update_to_next_version(version, &mut tx).await? {
|
|
version += 1;
|
|
}
|
|
|
|
tx.commit().await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn update_to_next_version(
|
|
current_version: u32,
|
|
tx: &mut Transaction<'_, Sqlite>,
|
|
) -> Result<bool> {
|
|
let next_version = current_version + 1;
|
|
|
|
if next_version <= CURRENT_DB_VERSION {
|
|
event!(Level::INFO, "Update to version {}...", next_version);
|
|
}
|
|
|
|
async fn update_version(to_version: u32, tx: &mut Transaction<'_, Sqlite>) -> Result<()> {
|
|
sqlx::query(
|
|
"INSERT INTO [Version] ([version], [datetime]) VALUES ($1, datetime('now'))",
|
|
)
|
|
.bind(to_version)
|
|
.execute(&mut **tx)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
fn ok(updated: bool) -> Result<bool> {
|
|
if updated {
|
|
event!(Level::INFO, "Version updated");
|
|
}
|
|
Ok(updated)
|
|
}
|
|
|
|
match next_version {
|
|
1 => {
|
|
let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &next_version.to_string());
|
|
sqlx::query(&load_sql_file(&sql_file)?)
|
|
.execute(&mut **tx)
|
|
.await?;
|
|
update_version(next_version, tx).await?;
|
|
|
|
ok(true)
|
|
}
|
|
|
|
// Version 2 doesn't exist yet.
|
|
2 => ok(false),
|
|
|
|
v => Err(DBError::UnsupportedVersion(v)),
|
|
}
|
|
}
|
|
|
|
/// Execute a given SQL file.
|
|
pub async fn execute_file<P: AsRef<Path> + fmt::Display>(&self, file: P) -> Result<()> {
|
|
let sql = load_sql_file(file)?;
|
|
sqlx::query(&sql)
|
|
.execute(&self.pool)
|
|
.await
|
|
.map(|_| ())
|
|
.map_err(DBError::from)
|
|
}
|
|
|
|
pub async fn execute_sql<'a>(
|
|
&self,
|
|
query: sqlx::query::Query<'a, Sqlite, sqlx::sqlite::SqliteArguments<'a>>,
|
|
) -> Result<u64> {
|
|
query
|
|
.execute(&self.pool)
|
|
.await
|
|
.map(|db_result| db_result.rows_affected())
|
|
.map_err(DBError::from)
|
|
}
|
|
|
|
// pub async fn execute_sql_and_fetch_all<'a>(
|
|
// &self,
|
|
// query: sqlx::query::Query<'a, Sqlite, sqlx::sqlite::SqliteArguments<'a>>,
|
|
// ) -> Result<Vec<SqliteRow>> {
|
|
// query.fetch_all(&self.pool).await.map_err(DBError::from)
|
|
// }
|
|
}
|
|
|
|
fn load_sql_file<P: AsRef<Path> + fmt::Display>(sql_file: P) -> Result<String> {
|
|
let mut file = File::open(&sql_file)
|
|
.map_err(|err| DBError::Other(format!("Cannot open SQL file ({}): {}", &sql_file, err)))?;
|
|
let mut sql = String::new();
|
|
file.read_to_string(&mut sql)
|
|
.map_err(|err| DBError::Other(format!("Cannot read SQL file ({}) : {}", &sql_file, err)))?;
|
|
Ok(sql)
|
|
}
|