// recipes/backend/src/data/db.rs
// 2023-04-10 09:35:10 +02:00
//
// 868 lines
// 28 KiB
// Rust

use std::{
fmt,
fs::{self, File},
io::Read,
path::Path,
};
use chrono::{prelude::*, Duration};
use itertools::Itertools;
use r2d2::{Pool, PooledConnection};
use r2d2_sqlite::SqliteConnectionManager;
use rand::distributions::{Alphanumeric, DistString};
use rusqlite::{named_params, params, OptionalExtension, Params};
use crate::{
consts,
hash::{hash, verify_password},
model,
};
/// Latest database schema version known by this code.
/// `Connection::create_or_update_db` migrates older databases, one version at a time, up to it.
const CURRENT_DB_VERSION: u32 = 1;
/// Errors returned by the database layer.
#[derive(Debug)]
pub enum DBError {
    /// Error reported by SQLite through `rusqlite`.
    SqliteError(rusqlite::Error),
    /// Error reported by the `r2d2` connection pool.
    R2d2Error(r2d2::Error),
    /// The database schema version cannot be handled by this code.
    UnsupportedVersion(u32),
    /// Any other failure, described by a message.
    Other(String),
}

impl fmt::Display for DBError {
    /// Writes a human-readable message (use `{:?}` for the full diagnostic form).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            DBError::SqliteError(error) => write!(f, "SQLite error: {}", error),
            DBError::R2d2Error(error) => write!(f, "Connection pool error: {}", error),
            DBError::UnsupportedVersion(version) => {
                write!(f, "Unsupported database version: {}", version)
            }
            DBError::Other(message) => write!(f, "{}", message),
        }
    }
}

impl std::error::Error for DBError {
    /// Exposes the wrapped `rusqlite`/`r2d2` error so error-chain reporters can display it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            DBError::SqliteError(error) => Some(error),
            DBError::R2d2Error(error) => Some(error),
            DBError::UnsupportedVersion(_) | DBError::Other(_) => None,
        }
    }
}

impl From<rusqlite::Error> for DBError {
    fn from(error: rusqlite::Error) -> Self {
        DBError::SqliteError(error)
    }
}

impl From<r2d2::Error> for DBError {
    fn from(error: r2d2::Error) -> Self {
        DBError::R2d2Error(error)
    }
}

impl DBError {
    /// Converts a boxed error (e.g. from the password hashing helpers) into `DBError::Other`,
    /// keeping only its message.
    fn from_dyn_error(error: Box<dyn std::error::Error>) -> Self {
        DBError::Other(error.to_string())
    }
}

/// Result type used by the database layer.
type Result<T> = std::result::Result<T, DBError>;
/// Outcome of a sign-up request.
#[derive(Debug)]
pub enum SignUpResult {
    /// A validated account already uses this email.
    UserAlreadyExists,
    /// The account awaits validation; carries the validation token.
    UserCreatedWaitingForValidation(String),
}
/// Outcome of an account validation request.
#[derive(Debug)]
pub enum ValidationResult {
    /// No account waits for validation with the given token.
    UnknownUser,
    /// The validation delay is over.
    ValidationExpired,
    /// Account validated; carries the new login token and the user id.
    Ok(String, i64),
}
/// Outcome of a sign-in request.
#[derive(Debug)]
pub enum SignInResult {
    /// No account uses this email.
    UserNotFound,
    /// The password does not match.
    WrongPassword,
    /// The account still waits for validation.
    AccountNotValidated,
    /// Signed in; carries the new login token and the user id.
    Ok(String, i64),
}
/// Outcome of a login-token authentication.
#[derive(Debug)]
pub enum AuthenticationResult {
    /// The token does not match any login session.
    NotValidToken,
    /// Authenticated; carries the user id.
    Ok(i64),
}
/// Handle on the application database.
///
/// Cloning it is cheap: every clone shares the same `r2d2` pool of SQLite connections.
#[derive(Clone)]
pub struct Connection {
    /// Pool from which every operation checks out its SQLite connection.
    pool: Pool<SqliteConnectionManager>,
}
impl Connection {
pub fn new() -> Result<Connection> {
let path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME);
Self::new_from_file(path)
}
pub fn new_in_memory() -> Result<Connection> {
Self::create_connection(SqliteConnectionManager::memory())
}
pub fn new_from_file<P: AsRef<Path>>(file: P) -> Result<Connection> {
if let Some(data_dir) = file.as_ref().parent() {
if !data_dir.exists() {
fs::DirBuilder::new().create(data_dir).unwrap();
}
}
Self::create_connection(SqliteConnectionManager::file(file))
}
fn create_connection(manager: SqliteConnectionManager) -> Result<Connection> {
let pool = r2d2::Pool::new(manager).unwrap();
let connection = Connection { pool };
connection.create_or_update_db()?;
Ok(connection)
}
fn get(&self) -> Result<PooledConnection<SqliteConnectionManager>> {
let con = self.pool.get()?;
// ('foreign_keys' is ON by default).
con.pragma_update(None, "synchronous", "NORMAL")?;
Ok(con)
}
/// Called after the connection has been established for creating or updating the database.
/// The 'Version' table tracks the current state of the database.
fn create_or_update_db(&self) -> Result<()> {
let mut con = self.get()?;
con.pragma_update(None, "journal_mode", "WAL")?; // Note: use "WAL2" when available.
let tx = con.transaction()?;
// Check current database version. (Version 0 corresponds to an empty database).
let mut version = {
match tx.query_row(
"SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'",
[],
|row| row.get::<usize, String>(0),
) {
Ok(_) => tx
.query_row(
"SELECT [version] FROM [Version] ORDER BY [id] DESC",
[],
|row| row.get(0),
)
.unwrap_or_default(),
Err(_) => 0,
}
};
while Self::update_to_next_version(version, &tx)? {
version += 1;
}
tx.commit()?;
Ok(())
}
fn update_to_next_version(current_version: u32, tx: &rusqlite::Transaction) -> Result<bool> {
let next_version = current_version + 1;
if next_version <= CURRENT_DB_VERSION {
println!("Update to version {}...", next_version);
}
fn update_version(to_version: u32, tx: &rusqlite::Transaction) -> Result<()> {
tx.execute(
"INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))",
[to_version],
)
.map(|_| ())
.map_err(DBError::from)
}
fn ok(updated: bool) -> Result<bool> {
if updated {
println!("Version updated");
}
Ok(updated)
}
match next_version {
1 => {
let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &next_version.to_string());
tx.execute_batch(&load_sql_file(&sql_file)?)?;
update_version(next_version, tx)?;
ok(true)
}
// Version 1 doesn't exist yet.
2 => ok(false),
v => Err(DBError::UnsupportedVersion(v)),
}
}
pub fn get_all_recipe_titles(&self) -> Result<Vec<(i64, String)>> {
let con = self.get()?;
let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
let titles: std::result::Result<Vec<(i64, String)>, rusqlite::Error> = stmt
.query_map([], |row| Ok((row.get("id")?, row.get("title")?)))?
.collect();
titles.map_err(DBError::from)
}
/* Not used for the moment.
pub fn get_all_recipes(&self) -> Result<Vec<model::Recipe>> {
let con = self.get()?;
let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?;
let recipes =
stmt.query_map([], |row| {
Ok(model::Recipe::new(row.get(0)?, row.get(1)?))
})?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap.
Ok(recipes)
} */
pub fn get_recipe(&self, id: i64) -> Result<model::Recipe> {
let con = self.get()?;
con.query_row(
"SELECT [id], [user_id], [title], [description] FROM [Recipe] WHERE [id] = ?1",
[id],
|row| {
Ok(model::Recipe::new(
row.get("id")?,
row.get("user_id")?,
row.get("title")?,
row.get("description")?,
))
},
)
.map_err(DBError::from)
}
pub fn get_user_login_info(&self, token: &str) -> Result<model::UserLoginInfo> {
let con = self.get()?;
con.query_row("SELECT [last_login_datetime], [ip], [user_agent] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| {
Ok(model::UserLoginInfo {
last_login_datetime: r.get("last_login_datetime")?,
ip: r.get("ip")?,
user_agent: r.get("user_agent")?,
})
}).map_err(DBError::from)
}
pub fn load_user(&self, user_id: i64) -> Result<model::User> {
let con = self.get()?;
con.query_row(
"SELECT [email] FROM [User] WHERE [id] = ?1",
[user_id],
|r| {
Ok(model::User {
id: user_id,
email: r.get("email")?,
})
},
)
.map_err(DBError::from)
}
pub fn sign_up(&self, email: &str, password: &str) -> Result<SignUpResult> {
self.sign_up_with_given_time(email, password, Utc::now())
}
fn sign_up_with_given_time(
&self,
email: &str,
password: &str,
datetime: DateTime<Utc>,
) -> Result<SignUpResult> {
let mut con = self.get()?;
let tx = con.transaction()?;
let token = match tx
.query_row(
"SELECT [id], [validation_token] FROM [User] WHERE [email] = ?1",
[email],
|r| {
Ok((
r.get::<&str, i64>("id")?,
r.get::<&str, Option<String>>("validation_token")?,
))
},
)
.optional()?
{
Some((id, validation_token)) => {
if validation_token.is_none() {
return Ok(SignUpResult::UserAlreadyExists);
}
let token = generate_token();
let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
tx.execute(
"UPDATE [User]
SET [validation_token] = ?2, [creation_datetime] = ?3, [password] = ?4
WHERE [id] = ?1",
params![id, token, datetime, hashed_password],
)?;
token
}
None => {
let token = generate_token();
let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?;
tx.execute(
"INSERT INTO [User]
([email], [validation_token], [creation_datetime], [password])
VALUES (?1, ?2, ?3, ?4)",
params![email, token, datetime, hashed_password],
)?;
token
}
};
tx.commit()?;
Ok(SignUpResult::UserCreatedWaitingForValidation(token))
}
pub fn validation(
&self,
token: &str,
validation_time: Duration,
ip: &str,
user_agent: &str,
) -> Result<ValidationResult> {
let mut con = self.get()?;
let tx = con.transaction()?;
let user_id = match tx
.query_row(
"SELECT [id], [creation_datetime] FROM [User] WHERE [validation_token] = ?1",
[token],
|r| {
Ok((
r.get::<&str, i64>("id")?,
r.get::<&str, DateTime<Utc>>("creation_datetime")?,
))
},
)
.optional()?
{
Some((id, creation_datetime)) => {
if Utc::now() - creation_datetime > validation_time {
return Ok(ValidationResult::ValidationExpired);
}
tx.execute(
"UPDATE [User] SET [validation_token] = NULL WHERE [id] = ?1",
[id],
)?;
id
}
None => return Ok(ValidationResult::UnknownUser),
};
let token = Connection::create_login_token(&tx, user_id, ip, user_agent)?;
tx.commit()?;
Ok(ValidationResult::Ok(token, user_id))
}
pub fn sign_in(
&self,
email: &str,
password: &str,
ip: &str,
user_agent: &str,
) -> Result<SignInResult> {
let mut con = self.get()?;
let tx = con.transaction()?;
match tx
.query_row(
"SELECT [id], [password], [validation_token] FROM [User] WHERE [email] = ?1",
[email],
|r| {
Ok((
r.get::<&str, i64>("id")?,
r.get::<&str, String>("password")?,
r.get::<&str, Option<String>>("validation_token")?,
))
},
)
.optional()?
{
Some((id, stored_password, validation_token)) => {
if validation_token.is_some() {
Ok(SignInResult::AccountNotValidated)
} else if verify_password(password, &stored_password)
.map_err(DBError::from_dyn_error)?
{
let token = Connection::create_login_token(&tx, id, ip, user_agent)?;
tx.commit()?;
Ok(SignInResult::Ok(token, id))
} else {
Ok(SignInResult::WrongPassword)
}
}
None => Ok(SignInResult::UserNotFound),
}
}
pub fn authentication(
&self,
token: &str,
ip: &str,
user_agent: &str,
) -> Result<AuthenticationResult> {
let mut con = self.get()?;
let tx = con.transaction()?;
match tx
.query_row(
"SELECT [id], [user_id] FROM [UserLoginToken] WHERE [token] = ?1",
[token],
|r| Ok((r.get::<&str, i64>("id")?, r.get::<&str, i64>("user_id")?)),
)
.optional()?
{
Some((login_id, user_id)) => {
tx.execute(
"UPDATE [UserLoginToken]
SET [last_login_datetime] = ?2, [ip] = ?3, [user_agent] = ?4
WHERE [id] = ?1",
params![login_id, Utc::now(), ip, user_agent],
)?;
tx.commit()?;
Ok(AuthenticationResult::Ok(user_id))
}
None => Ok(AuthenticationResult::NotValidToken),
}
}
pub fn sign_out(&self, token: &str) -> Result<()> {
let mut con = self.get()?;
let tx = con.transaction()?;
match tx
.query_row(
"SELECT [id] FROM [UserLoginToken] WHERE [token] = ?1",
[token],
|r| Ok(r.get::<&str, i64>("id")?),
)
.optional()?
{
Some(login_id) => {
tx.execute(
"DELETE FROM [UserLoginToken] WHERE [id] = ?1",
params![login_id],
)?;
tx.commit()?
}
None => (),
}
Ok(())
}
pub fn create_recipe(&self, user_id: i64) -> Result<i64> {
let con = self.get()?;
// Verify if an empty recipe already exists. Returns its id if one exists.
match con
.query_row(
"SELECT [Recipe].[id] FROM [Recipe]
LEFT JOIN [Image] ON [Image].[recipe_id] = [Recipe].[id]
LEFT JOIN [Group] ON [Group].[recipe_id] = [Recipe].[id]
WHERE [Recipe].[user_id] = ?1
AND [Recipe].[title] = ''
AND [Recipe].[estimate_time] IS NULL
AND [Recipe].[description] = ''
AND [Image].[id] IS NULL
AND [Group].[id] IS NULL",
[user_id],
|r| Ok(r.get::<&str, i64>("id")?),
)
.optional()?
{
Some(recipe_id) => Ok(recipe_id),
None => {
con.execute(
"INSERT INTO [Recipe] ([user_id], [title]) VALUES (?1, '')",
[user_id],
)?;
Ok(con.last_insert_rowid())
}
}
}
pub fn set_recipe_title(&self, recipe_id: i64, title: &str) -> Result<()> {
let con = self.get()?;
con.execute(
"UPDATE [Recipe] SET [title] = ?2 WHERE [id] = ?1",
params![recipe_id, title],
)
.map(|_n| ())
.map_err(DBError::from)
}
pub fn set_recipe_description(&self, recipe_id: i64, description: &str) -> Result<()> {
let con = self.get()?;
con.execute(
"UPDATE [Recipe] SET [description] = ?2 WHERE [id] = ?1",
params![recipe_id, description],
)
.map(|_n| ())
.map_err(DBError::from)
}
/// Execute a given SQL file.
pub fn execute_file<P: AsRef<Path> + fmt::Display>(&self, file: P) -> Result<()> {
let con = self.get()?;
let sql = load_sql_file(file)?;
con.execute_batch(&sql).map_err(DBError::from)
}
/// Execute any SQL statement.
/// Mainly used for testing.
pub fn execute_sql<P: Params>(&self, sql: &str, params: P) -> Result<usize> {
let con = self.get()?;
con.execute(sql, params).map_err(DBError::from)
}
// Return the token.
fn create_login_token(
tx: &rusqlite::Transaction,
user_id: i64,
ip: &str,
user_agent: &str,
) -> Result<String> {
let token = generate_token();
tx.execute(
"INSERT INTO [UserLoginToken]
([user_id], [last_login_datetime], [token], [ip], [user_agent])
VALUES (?1, ?2, ?3, ?4, ?5)",
params![user_id, Utc::now(), token, ip, user_agent],
)?;
Ok(token)
}
}
/// Reads the whole SQL file `sql_file` into a string.
///
/// Failing to open or to read the file yields `DBError::Other`, with a message naming the file.
fn load_sql_file<P: AsRef<Path> + fmt::Display>(sql_file: P) -> Result<String> {
    let mut content = String::new();
    File::open(&sql_file)
        .map_err(|err| {
            DBError::Other(format!("Cannot open SQL file ({}): {}", &sql_file, err))
        })?
        .read_to_string(&mut content)
        .map_err(|err| {
            DBError::Other(format!("Cannot read SQL file ({}) : {}", &sql_file, err))
        })?;
    Ok(content)
}
/// Draws a random alphanumeric token of `consts::AUTHENTICATION_TOKEN_SIZE` characters from the
/// thread-local RNG. Used for both validation and login tokens.
fn generate_token() -> String {
    let mut rng = rand::thread_rng();
    let length = consts::AUTHENTICATION_TOKEN_SIZE;
    Alphanumeric.sample_string(&mut rng, length)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::{ffi, types::Value, Error, ErrorCode};

    /// Signs up `email`/`password`, then validates the account right away.
    /// Returns the login token and the user id given by the validation.
    fn sign_up_and_validate(
        connection: &Connection,
        email: &str,
        password: &str,
        ip: &str,
        user_agent: &str,
    ) -> Result<(String, i64)> {
        let validation_token = match connection.sign_up(email, password)? {
            SignUpResult::UserCreatedWaitingForValidation(token) => token,
            other => panic!("{:?}", other),
        };
        match connection.validation(&validation_token, Duration::hours(1), ip, user_agent)? {
            ValidationResult::Ok(token, user_id) => Ok((token, user_id)),
            other => panic!("{:?}", other),
        }
    }

    #[test]
    fn sign_up() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        match connection.sign_up("paul@atreides.com", "12345")? {
            SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_up_to_an_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        connection.execute_sql("
INSERT INTO
    [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
    VALUES (
        1,
        'paul@atreides.com',
        'paul',
        '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
        0,
        NULL
    );", [])?;
        match connection.sign_up("paul@atreides.com", "12345")? {
            SignUpResult::UserAlreadyExists => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_up_and_sign_in_without_validation() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let email = "paul@atreides.com";
        let password = "12345";
        match connection.sign_up(email, password)? {
            SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        match connection.sign_in(email, password, "127.0.0.1", "Mozilla/5.0")? {
            SignInResult::AccountNotValidated => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_up_to_an_unvalidated_already_existing_user() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let token = generate_token();
        connection.execute_sql("
INSERT INTO
    [User] ([id], [email], [name], [password], [creation_datetime], [validation_token])
    VALUES (
        1,
        'paul@atreides.com',
        'paul',
        '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY',
        0,
        :token
    );", named_params! { ":token": token })?;
        match connection.sign_up("paul@atreides.com", "12345")? {
            SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_at_time() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let validation_token = match connection.sign_up("paul@atreides.com", "12345")? {
            SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case.
            other => panic!("{:?}", other),
        };
        match connection.validation(
            &validation_token,
            Duration::hours(1),
            "127.0.0.1",
            "Mozilla/5.0",
        )? {
            ValidationResult::Ok(_, _) => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_too_late() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let validation_token = match connection.sign_up_with_given_time(
            "paul@atreides.com",
            "12345",
            Utc::now() - Duration::days(1),
        )? {
            SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case.
            other => panic!("{:?}", other),
        };
        match connection.validation(
            &validation_token,
            Duration::hours(1),
            "127.0.0.1",
            "Mozilla/5.0",
        )? {
            ValidationResult::ValidationExpired => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_with_bad_token() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        match connection.sign_up("paul@atreides.com", "12345")? {
            SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        let random_token = generate_token();
        match connection.validation(
            &random_token,
            Duration::hours(1),
            "127.0.0.1",
            "Mozilla/5.0",
        )? {
            ValidationResult::UnknownUser => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_in_with_an_unknown_email() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        match connection.sign_in("paul@atreides.com", "12345", "127.0.0.1", "Mozilla/5.0")? {
            SignInResult::UserNotFound => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_then_sign_in() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let email = "paul@atreides.com";
        let password = "12345";
        // Sign up and validation.
        let (_, validated_user_id) =
            sign_up_and_validate(&connection, email, password, "127.0.0.1", "Mozilla/5.0")?;
        // Sign in.
        match connection.sign_in(email, password, "127.0.0.1", "Mozilla/5.0")? {
            SignInResult::Ok(_, user_id) => assert_eq!(user_id, validated_user_id), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_then_sign_in_with_wrong_password() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let email = "paul@atreides.com";
        sign_up_and_validate(&connection, email, "12345", "127.0.0.1", "Mozilla/5.0")?;
        // NOTE(review): relies on `verify_password` returning `Ok(false)` on a mismatch.
        match connection.sign_in(email, "54321", "127.0.0.1", "Mozilla/5.0")? {
            SignInResult::WrongPassword => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_then_authentication() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let email = "paul@atreides.com";
        let password = "12345";
        // Sign up and validation.
        let (authentication_token, user_id) =
            sign_up_and_validate(&connection, email, password, "127.0.0.1", "Mozilla")?;
        // Check user login information.
        let user_login_info_1 = connection.get_user_login_info(&authentication_token)?;
        assert_eq!(user_login_info_1.ip, "127.0.0.1");
        assert_eq!(user_login_info_1.user_agent, "Mozilla");
        // Authentication.
        let authenticated_user_id =
            match connection.authentication(&authentication_token, "192.168.1.1", "Chrome")? {
                AuthenticationResult::Ok(user_id) => user_id, // Nominal case.
                other => panic!("{:?}", other),
            };
        assert_eq!(authenticated_user_id, user_id);
        // Check user login information.
        let user_login_info_2 = connection.get_user_login_info(&authentication_token)?;
        assert_eq!(user_login_info_2.ip, "192.168.1.1");
        assert_eq!(user_login_info_2.user_agent, "Chrome");
        Ok(())
    }

    #[test]
    fn sign_up_then_send_validation_then_sign_out_then_sign_in() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        let email = "paul@atreides.com";
        let password = "12345";
        // Sign up and validation.
        let (authentication_token_1, user_id_1) =
            sign_up_and_validate(&connection, email, password, "127.0.0.1", "Mozilla")?;
        // Check user login information.
        let user_login_info_1 = connection.get_user_login_info(&authentication_token_1)?;
        assert_eq!(user_login_info_1.ip, "127.0.0.1");
        assert_eq!(user_login_info_1.user_agent, "Mozilla");
        // Sign out.
        connection.sign_out(&authentication_token_1)?;
        // The token of the closed session must not authenticate anymore.
        match connection.authentication(&authentication_token_1, "127.0.0.1", "Mozilla")? {
            AuthenticationResult::NotValidToken => (), // Nominal case.
            other => panic!("{:?}", other),
        }
        // Sign in.
        let (authentication_token_2, user_id_2) =
            match connection.sign_in(email, password, "192.168.1.1", "Chrome")? {
                SignInResult::Ok(token, user_id) => (token, user_id),
                other => panic!("{:?}", other),
            };
        assert_eq!(user_id_1, user_id_2);
        assert_ne!(authentication_token_1, authentication_token_2);
        // Check user login information.
        let user_login_info_2 = connection.get_user_login_info(&authentication_token_2)?;
        assert_eq!(user_login_info_2.ip, "192.168.1.1");
        assert_eq!(user_login_info_2.user_agent, "Chrome");
        Ok(())
    }

    #[test]
    fn create_a_new_recipe_then_update_its_title() -> Result<()> {
        let connection = Connection::new_in_memory()?;
        connection.execute_sql(
            "INSERT INTO [User]
            ([id], [email], [name], [password], [creation_datetime], [validation_token])
            VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
            params![
                1,
                "paul@atreides.com",
                "paul",
                "$argon2id$v=19$m=4096,t=3,p=1$G4fjepS05MkRbTqEImUdYg$GGziE8uVQe1L1oFHk37lBno10g4VISnVqynSkLCH3Lc",
                "2022-11-29 22:05:04.121407300+00:00",
                Value::Null,
            ],
        )?;
        match connection.create_recipe(2) {
            Err(DBError::SqliteError(Error::SqliteFailure(
                ffi::Error {
                    code: ErrorCode::ConstraintViolation,
                    extended_code: _,
                },
                Some(_),
            ))) => (), // Nominal case.
            other => panic!(
                "Creating a recipe with an inexistant user must fail: {:?}",
                other
            ),
        }
        let recipe_id = connection.create_recipe(1)?;
        assert_eq!(recipe_id, 1);
        connection.set_recipe_title(recipe_id, "Crêpe")?;
        let recipe = connection.get_recipe(recipe_id)?;
        assert_eq!(recipe.title, "Crêpe".to_string());
        Ok(())
    }
}