use std::{ fmt, fs::{self, File}, io::Read, path::Path, }; use chrono::{prelude::*, Duration}; use itertools::Itertools; use r2d2::{Pool, PooledConnection}; use r2d2_sqlite::SqliteConnectionManager; use rand::distributions::{Alphanumeric, DistString}; use rusqlite::{named_params, params, OptionalExtension, Params}; use crate::{ consts, hash::{hash, verify_password}, model, }; const CURRENT_DB_VERSION: u32 = 1; #[derive(Debug)] pub enum DBError { SqliteError(rusqlite::Error), R2d2Error(r2d2::Error), UnsupportedVersion(u32), Other(String), } impl fmt::Display for DBError { fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), fmt::Error> { write!(f, "{:?}", self) } } impl std::error::Error for DBError {} impl From for DBError { fn from(error: rusqlite::Error) -> Self { DBError::SqliteError(error) } } impl From for DBError { fn from(error: r2d2::Error) -> Self { DBError::R2d2Error(error) } } impl DBError { fn from_dyn_error(error: Box) -> Self { DBError::Other(error.to_string()) } } type Result = std::result::Result; #[derive(Debug)] pub enum SignUpResult { UserAlreadyExists, UserCreatedWaitingForValidation(String), // Validation token. } #[derive(Debug)] pub enum ValidationResult { UnknownUser, ValidationExpired, Ok(String, i64), // Returns token and user id. } #[derive(Debug)] pub enum SignInResult { UserNotFound, WrongPassword, AccountNotValidated, Ok(String, i64), // Returns token and user id. } #[derive(Debug)] pub enum AuthenticationResult { NotValidToken, Ok(i64), // Returns user id. } #[derive(Clone)] pub struct Connection { pool: Pool, } impl Connection { pub fn new() -> Result { let path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME); Self::new_from_file(path) } pub fn new_in_memory() -> Result { Self::create_connection(SqliteConnectionManager::memory()) } pub fn new_from_file>(file: P) -> Result { if let Some(data_dir) = file.as_ref().parent() { if !data_dir.exists() { fs::DirBuilder::new().create(data_dir).unwrap(); } } Self::create_connection(SqliteConnectionManager::file(file)) } fn create_connection(manager: SqliteConnectionManager) -> Result { let pool = r2d2::Pool::new(manager).unwrap(); let connection = Connection { pool }; connection.create_or_update_db()?; Ok(connection) } fn get(&self) -> Result> { let con = self.pool.get()?; // ('foreign_keys' is ON by default). con.pragma_update(None, "synchronous", "NORMAL")?; Ok(con) } /// Called after the connection has been established for creating or updating the database. /// The 'Version' table tracks the current state of the database. fn create_or_update_db(&self) -> Result<()> { let mut con = self.get()?; con.pragma_update(None, "journal_mode", "WAL")?; // Note: use "WAL2" when available. let tx = con.transaction()?; // Check current database version. (Version 0 corresponds to an empty database). let mut version = { match tx.query_row( "SELECT [name] FROM [sqlite_master] WHERE [type] = 'table' AND [name] = 'Version'", [], |row| row.get::(0), ) { Ok(_) => tx .query_row( "SELECT [version] FROM [Version] ORDER BY [id] DESC", [], |row| row.get(0), ) .unwrap_or_default(), Err(_) => 0, } }; while Self::update_to_next_version(version, &tx)? { version += 1; } tx.commit()?; Ok(()) } fn update_to_next_version(current_version: u32, tx: &rusqlite::Transaction) -> Result { let next_version = current_version + 1; if next_version <= CURRENT_DB_VERSION { println!("Update to version {}...", next_version); } fn update_version(to_version: u32, tx: &rusqlite::Transaction) -> Result<()> { tx.execute( "INSERT INTO [Version] ([version], [datetime]) VALUES (?1, datetime('now'))", [to_version], ) .map(|_| ()) .map_err(DBError::from) } fn ok(updated: bool) -> Result { if updated { println!("Version updated"); } Ok(updated) } match next_version { 1 => { let sql_file = consts::SQL_FILENAME.replace("{VERSION}", &next_version.to_string()); tx.execute_batch(&load_sql_file(&sql_file)?)?; update_version(next_version, tx)?; ok(true) } // Version 1 doesn't exist yet. 2 => ok(false), v => Err(DBError::UnsupportedVersion(v)), } } pub fn get_all_recipe_titles(&self) -> Result> { let con = self.get()?; let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?; let titles: std::result::Result, rusqlite::Error> = stmt .query_map([], |row| Ok((row.get("id")?, row.get("title")?)))? .collect(); titles.map_err(DBError::from) } /* Not used for the moment. pub fn get_all_recipes(&self) -> Result> { let con = self.get()?; let mut stmt = con.prepare("SELECT [id], [title] FROM [Recipe] ORDER BY [title]")?; let recipes = stmt.query_map([], |row| { Ok(model::Recipe::new(row.get(0)?, row.get(1)?)) })?.map(|r| r.unwrap()).collect_vec(); // TODO: remove unwrap. Ok(recipes) } */ pub fn get_recipe(&self, id: i64) -> Result { let con = self.get()?; con.query_row( "SELECT [id], [user_id], [title], [description] FROM [Recipe] WHERE [id] = ?1", [id], |row| { Ok(model::Recipe::new( row.get("id")?, row.get("user_id")?, row.get("title")?, row.get("description")?, )) }, ) .map_err(DBError::from) } pub fn get_user_login_info(&self, token: &str) -> Result { let con = self.get()?; con.query_row("SELECT [last_login_datetime], [ip], [user_agent] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| { Ok(model::UserLoginInfo { last_login_datetime: r.get("last_login_datetime")?, ip: r.get("ip")?, user_agent: r.get("user_agent")?, }) }).map_err(DBError::from) } pub fn load_user(&self, user_id: i64) -> Result { let con = self.get()?; con.query_row( "SELECT [email] FROM [User] WHERE [id] = ?1", [user_id], |r| { Ok(model::User { id: user_id, email: r.get("email")?, }) }, ) .map_err(DBError::from) } pub fn sign_up(&self, email: &str, password: &str) -> Result { self.sign_up_with_given_time(email, password, Utc::now()) } fn sign_up_with_given_time( &self, email: &str, password: &str, datetime: DateTime, ) -> Result { let mut con = self.get()?; let tx = con.transaction()?; let token = match tx .query_row( "SELECT [id], [validation_token] FROM [User] WHERE [email] = ?1", [email], |r| { Ok(( r.get::<&str, i64>("id")?, r.get::<&str, Option>("validation_token")?, )) }, ) .optional()? { Some((id, validation_token)) => { if validation_token.is_none() { return Ok(SignUpResult::UserAlreadyExists); } let token = generate_token(); let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?; tx.execute( "UPDATE [User] SET [validation_token] = ?2, [creation_datetime] = ?3, [password] = ?4 WHERE [id] = ?1", params![id, token, datetime, hashed_password], )?; token } None => { let token = generate_token(); let hashed_password = hash(password).map_err(|e| DBError::from_dyn_error(e))?; tx.execute( "INSERT INTO [User] ([email], [validation_token], [creation_datetime], [password]) VALUES (?1, ?2, ?3, ?4)", params![email, token, datetime, hashed_password], )?; token } }; tx.commit()?; Ok(SignUpResult::UserCreatedWaitingForValidation(token)) } pub fn validation( &self, token: &str, validation_time: Duration, ip: &str, user_agent: &str, ) -> Result { let mut con = self.get()?; let tx = con.transaction()?; let user_id = match tx .query_row( "SELECT [id], [creation_datetime] FROM [User] WHERE [validation_token] = ?1", [token], |r| { Ok(( r.get::<&str, i64>("id")?, r.get::<&str, DateTime>("creation_datetime")?, )) }, ) .optional()? { Some((id, creation_datetime)) => { if Utc::now() - creation_datetime > validation_time { return Ok(ValidationResult::ValidationExpired); } tx.execute( "UPDATE [User] SET [validation_token] = NULL WHERE [id] = ?1", [id], )?; id } None => return Ok(ValidationResult::UnknownUser), }; let token = Connection::create_login_token(&tx, user_id, ip, user_agent)?; tx.commit()?; Ok(ValidationResult::Ok(token, user_id)) } pub fn sign_in( &self, email: &str, password: &str, ip: &str, user_agent: &str, ) -> Result { let mut con = self.get()?; let tx = con.transaction()?; match tx .query_row( "SELECT [id], [password], [validation_token] FROM [User] WHERE [email] = ?1", [email], |r| { Ok(( r.get::<&str, i64>("id")?, r.get::<&str, String>("password")?, r.get::<&str, Option>("validation_token")?, )) }, ) .optional()? { Some((id, stored_password, validation_token)) => { if validation_token.is_some() { Ok(SignInResult::AccountNotValidated) } else if verify_password(password, &stored_password) .map_err(DBError::from_dyn_error)? { let token = Connection::create_login_token(&tx, id, ip, user_agent)?; tx.commit()?; Ok(SignInResult::Ok(token, id)) } else { Ok(SignInResult::WrongPassword) } } None => Ok(SignInResult::UserNotFound), } } pub fn authentication( &self, token: &str, ip: &str, user_agent: &str, ) -> Result { let mut con = self.get()?; let tx = con.transaction()?; match tx .query_row( "SELECT [id], [user_id] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| Ok((r.get::<&str, i64>("id")?, r.get::<&str, i64>("user_id")?)), ) .optional()? { Some((login_id, user_id)) => { tx.execute( "UPDATE [UserLoginToken] SET [last_login_datetime] = ?2, [ip] = ?3, [user_agent] = ?4 WHERE [id] = ?1", params![login_id, Utc::now(), ip, user_agent], )?; tx.commit()?; Ok(AuthenticationResult::Ok(user_id)) } None => Ok(AuthenticationResult::NotValidToken), } } pub fn sign_out(&self, token: &str) -> Result<()> { let mut con = self.get()?; let tx = con.transaction()?; match tx .query_row( "SELECT [id] FROM [UserLoginToken] WHERE [token] = ?1", [token], |r| Ok(r.get::<&str, i64>("id")?), ) .optional()? { Some(login_id) => { tx.execute( "DELETE FROM [UserLoginToken] WHERE [id] = ?1", params![login_id], )?; tx.commit()? } None => (), } Ok(()) } pub fn create_recipe(&self, user_id: i64) -> Result { let con = self.get()?; // Verify if an empty recipe already exists. Returns its id if one exists. match con .query_row( "SELECT [Recipe].[id] FROM [Recipe] LEFT JOIN [Image] ON [Image].[recipe_id] = [Recipe].[id] LEFT JOIN [Group] ON [Group].[recipe_id] = [Recipe].[id] WHERE [Recipe].[user_id] = ?1 AND [Recipe].[title] = '' AND [Recipe].[estimate_time] IS NULL AND [Recipe].[description] = '' AND [Image].[id] IS NULL AND [Group].[id] IS NULL", [user_id], |r| Ok(r.get::<&str, i64>("id")?), ) .optional()? { Some(recipe_id) => Ok(recipe_id), None => { con.execute( "INSERT INTO [Recipe] ([user_id], [title]) VALUES (?1, '')", [user_id], )?; Ok(con.last_insert_rowid()) } } } pub fn set_recipe_title(&self, recipe_id: i64, title: &str) -> Result<()> { let con = self.get()?; con.execute( "UPDATE [Recipe] SET [title] = ?2 WHERE [id] = ?1", params![recipe_id, title], ) .map(|_n| ()) .map_err(DBError::from) } pub fn set_recipe_description(&self, recipe_id: i64, description: &str) -> Result<()> { let con = self.get()?; con.execute( "UPDATE [Recipe] SET [description] = ?2 WHERE [id] = ?1", params![recipe_id, description], ) .map(|_n| ()) .map_err(DBError::from) } /// Execute a given SQL file. pub fn execute_file + fmt::Display>(&self, file: P) -> Result<()> { let con = self.get()?; let sql = load_sql_file(file)?; con.execute_batch(&sql).map_err(DBError::from) } /// Execute any SQL statement. /// Mainly used for testing. pub fn execute_sql(&self, sql: &str, params: P) -> Result { let con = self.get()?; con.execute(sql, params).map_err(DBError::from) } // Return the token. fn create_login_token( tx: &rusqlite::Transaction, user_id: i64, ip: &str, user_agent: &str, ) -> Result { let token = generate_token(); tx.execute( "INSERT INTO [UserLoginToken] ([user_id], [last_login_datetime], [token], [ip], [user_agent]) VALUES (?1, ?2, ?3, ?4, ?5)", params![user_id, Utc::now(), token, ip, user_agent], )?; Ok(token) } } fn load_sql_file + fmt::Display>(sql_file: P) -> Result { let mut file = File::open(&sql_file).map_err(|err| { DBError::Other(format!( "Cannot open SQL file ({}): {}", &sql_file, err.to_string() )) })?; let mut sql = String::new(); file.read_to_string(&mut sql).map_err(|err| { DBError::Other(format!( "Cannot read SQL file ({}) : {}", &sql_file, err.to_string() )) })?; Ok(sql) } fn generate_token() -> String { Alphanumeric.sample_string(&mut rand::thread_rng(), consts::AUTHENTICATION_TOKEN_SIZE) } #[cfg(test)] mod tests { use super::*; use rusqlite::{ffi, types::Value, Error, ErrorCode}; #[test] fn sign_up() -> Result<()> { let connection = Connection::new_in_memory()?; match connection.sign_up("paul@atreides.com", "12345")? { SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[test] fn sign_up_to_an_already_existing_user() -> Result<()> { let connection = Connection::new_in_memory()?; connection.execute_sql(" INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token]) VALUES ( 1, 'paul@atreides.com', 'paul', '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY', 0, NULL );", [])?; match connection.sign_up("paul@atreides.com", "12345")? { SignUpResult::UserAlreadyExists => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[test] fn sign_up_and_sign_in_without_validation() -> Result<()> { let connection = Connection::new_in_memory()?; let email = "paul@atreides.com"; let password = "12345"; match connection.sign_up(email, password)? { SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case. other => panic!("{:?}", other), } match connection.sign_in(email, password, "127.0.0.1", "Mozilla/5.0")? { SignInResult::AccountNotValidated => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[test] fn sign_up_to_an_unvalidated_already_existing_user() -> Result<()> { let connection = Connection::new_in_memory()?; let token = generate_token(); connection.execute_sql(" INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token]) VALUES ( 1, 'paul@atreides.com', 'paul', '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY', 0, :token );", named_params! { ":token": token })?; match connection.sign_up("paul@atreides.com", "12345")? { SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[test] fn sign_up_then_send_validation_at_time() -> Result<()> { let connection = Connection::new_in_memory()?; let validation_token = match connection.sign_up("paul@atreides.com", "12345")? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; match connection.validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0", )? { ValidationResult::Ok(_, _) => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[test] fn sign_up_then_send_validation_too_late() -> Result<()> { let connection = Connection::new_in_memory()?; let validation_token = match connection.sign_up_with_given_time( "paul@atreides.com", "12345", Utc::now() - Duration::days(1), )? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; match connection.validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0", )? { ValidationResult::ValidationExpired => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[test] fn sign_up_then_send_validation_with_bad_token() -> Result<()> { let connection = Connection::new_in_memory()?; let _validation_token = match connection.sign_up("paul@atreides.com", "12345")? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; let random_token = generate_token(); match connection.validation( &random_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0", )? { ValidationResult::UnknownUser => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[test] fn sign_up_then_send_validation_then_sign_in() -> Result<()> { let connection = Connection::new_in_memory()?; let email = "paul@atreides.com"; let password = "12345"; // Sign up. let validation_token = match connection.sign_up(email, password)? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; // Validation. match connection.validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0", )? { ValidationResult::Ok(_, _) => (), other => panic!("{:?}", other), }; // Sign in. match connection.sign_in(email, password, "127.0.0.1", "Mozilla/5.0")? { SignInResult::Ok(_, _) => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[test] fn sign_up_then_send_validation_then_authentication() -> Result<()> { let connection = Connection::new_in_memory()?; let email = "paul@atreides.com"; let password = "12345"; // Sign up. let validation_token = match connection.sign_up(email, password)? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; // Validation. let (authentication_token, user_id) = match connection.validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla", )? { ValidationResult::Ok(token, user_id) => (token, user_id), other => panic!("{:?}", other), }; // Check user login information. let user_login_info_1 = connection.get_user_login_info(&authentication_token)?; assert_eq!(user_login_info_1.ip, "127.0.0.1"); assert_eq!(user_login_info_1.user_agent, "Mozilla"); // Authentication. let _user_id = match connection.authentication(&authentication_token, "192.168.1.1", "Chrome")? { AuthenticationResult::Ok(user_id) => user_id, // Nominal case. other => panic!("{:?}", other), }; // Check user login information. let user_login_info_2 = connection.get_user_login_info(&authentication_token)?; assert_eq!(user_login_info_2.ip, "192.168.1.1"); assert_eq!(user_login_info_2.user_agent, "Chrome"); Ok(()) } #[test] fn sign_up_then_send_validation_then_sign_out_then_sign_in() -> Result<()> { let connection = Connection::new_in_memory()?; let email = "paul@atreides.com"; let password = "12345"; // Sign up. let validation_token = match connection.sign_up(email, password)? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; // Validation. let (authentication_token_1, user_id_1) = match connection.validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla", )? { ValidationResult::Ok(token, user_id) => (token, user_id), other => panic!("{:?}", other), }; // Check user login information. let user_login_info_1 = connection.get_user_login_info(&authentication_token_1)?; assert_eq!(user_login_info_1.ip, "127.0.0.1"); assert_eq!(user_login_info_1.user_agent, "Mozilla"); // Sign out. connection.sign_out(&authentication_token_1)?; // Sign in. let (authentication_token_2, user_id_2) = match connection.sign_in(email, password, "192.168.1.1", "Chrome")? { SignInResult::Ok(token, user_id) => (token, user_id), other => panic!("{:?}", other), }; assert_eq!(user_id_1, user_id_2); assert_ne!(authentication_token_1, authentication_token_2); // Check user login information. let user_login_info_2 = connection.get_user_login_info(&authentication_token_2)?; assert_eq!(user_login_info_2.ip, "192.168.1.1"); assert_eq!(user_login_info_2.user_agent, "Chrome"); Ok(()) } #[test] fn create_a_new_recipe_then_update_its_title() -> Result<()> { let connection = Connection::new_in_memory()?; connection.execute_sql( "INSERT INTO [User] ([id], [email], [name], [password], [creation_datetime], [validation_token]) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", params![ 1, "paul@atreides.com", "paul", "$argon2id$v=19$m=4096,t=3,p=1$G4fjepS05MkRbTqEImUdYg$GGziE8uVQe1L1oFHk37lBno10g4VISnVqynSkLCH3Lc", "2022-11-29 22:05:04.121407300+00:00", Value::Null, ] )?; match connection.create_recipe(2) { Err(DBError::SqliteError(Error::SqliteFailure( ffi::Error { code: ErrorCode::ConstraintViolation, extended_code: _, }, Some(_), ))) => (), // Nominal case. other => panic!( "Creating a recipe with an inexistant user must fail: {:?}", other ), } let recipe_id = connection.create_recipe(1)?; assert_eq!(recipe_id, 1); connection.set_recipe_title(recipe_id, "Crêpe")?; let recipe = connection.get_recipe(recipe_id)?; assert_eq!(recipe.title, "Crêpe".to_string()); Ok(()) } }