use chrono::{prelude::*, Duration}; use rand::distributions::{Alphanumeric, DistString}; use sqlx::Sqlite; use super::{Connection, DBError, Result}; use crate::{ consts, data::model, hash::{hash, verify_password}, }; #[derive(Debug)] pub enum SignUpResult { UserAlreadyExists, UserCreatedWaitingForValidation(String), // Validation token. } #[derive(Debug)] pub enum UpdateUserResult { EmailAlreadyTaken, UserUpdatedWaitingForRevalidation(String), // Validation token. Ok, } #[derive(Debug)] pub enum ValidationResult { UnknownUser, ValidationExpired, Ok(String, i64), // Returns token and user id. } #[derive(Debug)] pub enum SignInResult { UserNotFound, WrongPassword, AccountNotValidated, Ok(String, i64), // Returns token and user id. } #[derive(Debug)] pub enum AuthenticationResult { NotValidToken, Ok(i64), // Returns user id. } #[derive(Debug)] pub enum GetTokenResetPasswordResult { PasswordAlreadyReset, EmailUnknown, Ok(String), } #[derive(Debug)] pub enum ResetPasswordResult { ResetTokenExpired, Ok, } fn generate_token() -> String { Alphanumeric.sample_string(&mut rand::thread_rng(), consts::TOKEN_SIZE) } impl Connection { #[cfg(test)] pub async fn get_user_login_info(&self, token: &str) -> Result { sqlx::query_as( r#" SELECT [last_login_datetime], [ip], [user_agent] FROM [UserLoginToken] WHERE [token] = $1 "#, ) .bind(token) .fetch_one(&self.pool) .await .map_err(DBError::from) } pub async fn load_user(&self, user_id: i64) -> Result> { sqlx::query_as("SELECT [id], [email], [name], [lang] FROM [User] WHERE [id] = $1") .bind(user_id) .fetch_optional(&self.pool) .await .map_err(DBError::from) } /// If a new email is given and it doesn't match the current one then it has to be /// Revalidated. pub async fn update_user( &self, user_id: i64, new_email: Option<&str>, new_name: Option<&str>, new_password: Option<&str>, ) -> Result { let mut tx = self.tx().await?; let hashed_new_password = new_password.map(|p| hash(p).unwrap()); let (email, name, hashed_password) = sqlx::query_as::<_, (String, String, String)>( "SELECT [email], [name], [password] FROM [User] WHERE [id] = $1", ) .bind(user_id) .fetch_one(&mut *tx) .await?; let new_email = new_email.map(str::trim); let email_changed = new_email.is_some_and(|new_email| new_email != email); // Check if email not already taken. let validation_token = if email_changed { if sqlx::query_scalar( r#" SELECT COUNT(*) > 0 FROM [User] WHERE [email] = $1 "#, ) .bind(new_email.unwrap()) .fetch_one(&mut *tx) .await? { return Ok(UpdateUserResult::EmailAlreadyTaken); } let token = Some(generate_token()); sqlx::query( r#" UPDATE [User] SET [validation_token] = $2, [validation_token_datetime] = $3 WHERE [id] = $1 "#, ) .bind(user_id) .bind(&token) .bind(Utc::now()) .execute(&mut *tx) .await?; token } else { None }; sqlx::query( r#" UPDATE [User] SET [email] = $2, [name] = $3, [password] = $4 WHERE [id] = $1 "#, ) .bind(user_id) .bind(new_email.unwrap_or(&email)) .bind(new_name.map(str::trim).unwrap_or(&name)) .bind(hashed_new_password.unwrap_or(hashed_password)) .execute(&mut *tx) .await?; tx.commit().await?; Ok(if let Some(validation_token) = validation_token { UpdateUserResult::UserUpdatedWaitingForRevalidation(validation_token) } else { UpdateUserResult::Ok }) } pub async fn set_user_lang(&self, user_id: i64, lang: &str) -> Result<()> { sqlx::query("UPDATE [User] SET [lang] = $2 WHERE [id] = $1") .bind(user_id) .bind(lang) .execute(&self.pool) .await .map(|_| ()) .map_err(DBError::from) } pub async fn sign_up(&self, email: &str, password: &str) -> Result { self.sign_up_with_given_time(email, password, Utc::now()) .await } async fn sign_up_with_given_time( &self, email: &str, password: &str, datetime: DateTime, ) -> Result { let mut tx = self.tx().await?; let token = match sqlx::query_as::<_, (i64, Option)>( r#" SELECT [id], [validation_token] FROM [User] WHERE [email] = $1 "#, ) .bind(email) .fetch_optional(&mut *tx) .await? { Some((id, validation_token)) => { if validation_token.is_none() { return Ok(SignUpResult::UserAlreadyExists); } let token = generate_token(); let hashed_password = hash(password).map_err(DBError::from_dyn_error)?; sqlx::query( r#" UPDATE [User] SET [validation_token] = $2, [validation_token_datetime] = $3, [password] = $4 WHERE [id] = $1 "#, ) .bind(id) .bind(&token) .bind(datetime) .bind(hashed_password) .execute(&mut *tx) .await?; token } None => { let token = generate_token(); let hashed_password = hash(password).map_err(DBError::from_dyn_error)?; sqlx::query( r#" INSERT INTO [User] ([email], [creation_datetime], [validation_token], [validation_token_datetime], [password]) VALUES ($1, $2, $3, $4, $5) "#, ) .bind(email) .bind(Utc::now()) .bind(&token) .bind(datetime) .bind(hashed_password) .execute(&mut *tx) .await?; token } }; tx.commit().await?; Ok(SignUpResult::UserCreatedWaitingForValidation(token)) } pub async fn validation( &self, token: &str, validation_time: Duration, ip: &str, user_agent: &str, ) -> Result { let mut tx = self.tx().await?; // There is no index on [validation_token]. Is it useful? let user_id = match sqlx::query_as::<_, (i64, DateTime)>( "SELECT [id], [validation_token_datetime] FROM [User] WHERE [validation_token] = $1", ) .bind(token) .fetch_optional(&mut *tx) .await? { Some((id, validation_token_datetime)) => { if Utc::now() - validation_token_datetime > validation_time { return Ok(ValidationResult::ValidationExpired); } sqlx::query("UPDATE [User] SET [validation_token] = NULL WHERE [id] = $1") .bind(id) .execute(&mut *tx) .await?; id } None => return Ok(ValidationResult::UnknownUser), }; let token = Self::create_login_token(&mut tx, user_id, ip, user_agent).await?; tx.commit().await?; Ok(ValidationResult::Ok(token, user_id)) } pub async fn sign_in( &self, email: &str, password: &str, ip: &str, user_agent: &str, ) -> Result { let mut tx = self.tx().await?; match sqlx::query_as::<_, (i64, String, Option)>( "SELECT [id], [password], [validation_token] FROM [User] WHERE [email] = $1", ) .bind(email) .fetch_optional(&mut *tx) .await? { Some((id, stored_password, validation_token)) => { if validation_token.is_some() { Ok(SignInResult::AccountNotValidated) } else if verify_password(password, &stored_password) .map_err(DBError::from_dyn_error)? { let token = Self::create_login_token(&mut tx, id, ip, user_agent).await?; tx.commit().await?; Ok(SignInResult::Ok(token, id)) } else { Ok(SignInResult::WrongPassword) } } None => Ok(SignInResult::UserNotFound), } } pub async fn authentication( &self, token: &str, ip: &str, user_agent: &str, ) -> Result { let mut tx = self.tx().await?; match sqlx::query_as::<_, (i64, i64)>( "SELECT [id], [user_id] FROM [UserLoginToken] WHERE [token] = $1", ) .bind(token) .fetch_optional(&mut *tx) .await? { Some((login_id, user_id)) => { sqlx::query( r#" UPDATE [UserLoginToken] SET [last_login_datetime] = $2, [ip] = $3, [user_agent] = $4 WHERE [id] = $1 "#, ) .bind(login_id) .bind(Utc::now()) .bind(ip) .bind(user_agent) .execute(&mut *tx) .await?; tx.commit().await?; Ok(AuthenticationResult::Ok(user_id)) } None => Ok(AuthenticationResult::NotValidToken), } } pub async fn sign_out(&self, token: &str) -> Result<()> { let mut tx = self.tx().await?; if let Some(login_id) = sqlx::query_scalar::<_, i64>("SELECT [id] FROM [UserLoginToken] WHERE [token] = $1") .bind(token) .fetch_optional(&mut *tx) .await? { sqlx::query("DELETE FROM [UserLoginToken] WHERE [id] = $1") .bind(login_id) .execute(&mut *tx) .await?; tx.commit().await?; } Ok(()) } pub async fn get_token_reset_password( &self, email: &str, validation_time: Duration, ) -> Result { let mut tx = self.tx().await?; if let Some(db_datetime_nullable) = sqlx::query_scalar::<_, Option>>( r#" SELECT [password_reset_datetime] FROM [User] WHERE [email] = $1 "#, ) .bind(email) .fetch_optional(&mut *tx) .await? { if let Some(db_datetime) = db_datetime_nullable { if Utc::now() - db_datetime <= validation_time { return Ok(GetTokenResetPasswordResult::PasswordAlreadyReset); } } } else { return Ok(GetTokenResetPasswordResult::EmailUnknown); } let token = generate_token(); sqlx::query( r#" UPDATE [User] SET [password_reset_token] = $2, [password_reset_datetime] = $3 WHERE [email] = $1 "#, ) .bind(email) .bind(&token) .bind(Utc::now()) .execute(&mut *tx) .await?; tx.commit().await?; Ok(GetTokenResetPasswordResult::Ok(token)) } pub async fn is_reset_password_token_valid( &self, token: &str, validation_time: Duration, ) -> Result { if let Some(Some(db_datetime)) = sqlx::query_scalar::<_, Option>>( r#" SELECT [password_reset_datetime] FROM [User] WHERE [password_reset_token] = $1 "#, ) .bind(token) .fetch_optional(&self.pool) .await? { Ok(Utc::now() - db_datetime <= validation_time) } else { Ok(false) } } pub async fn reset_password( &self, new_password: &str, token: &str, validation_time: Duration, ) -> Result { let mut tx = self.tx().await?; // There is no index on [password_reset_token]. Is it useful? if let (user_id, Some(db_datetime)) = sqlx::query_as::<_, (i64, Option>)>( r#" SELECT [id], [password_reset_datetime] FROM [User] WHERE [password_reset_token] = $1 "#, ) .bind(token) .fetch_one(&mut *tx) .await? { if Utc::now() - db_datetime > validation_time { return Ok(ResetPasswordResult::ResetTokenExpired); } // Remove all login tokens (for security reasons). sqlx::query("DELETE FROM [UserLoginToken] WHERE [user_id] = $1") .bind(user_id) .execute(&mut *tx) .await?; let hashed_new_password = hash(new_password).map_err(DBError::from_dyn_error)?; sqlx::query( r#" UPDATE [User] SET [password] = $2, [password_reset_token] = NULL, [password_reset_datetime] = NULL WHERE [id] = $1 "#, ) .bind(user_id) .bind(hashed_new_password) .execute(&mut *tx) .await?; tx.commit().await?; Ok(ResetPasswordResult::Ok) } else { Err(DBError::Other( "Can't reset password: stored token or datetime not set (NULL)".to_string(), )) } } // Return the token. async fn create_login_token( tx: &mut sqlx::Transaction<'_, Sqlite>, user_id: i64, ip: &str, user_agent: &str, ) -> Result { let token = generate_token(); sqlx::query( r#" INSERT INTO [UserLoginToken] ([user_id], [last_login_datetime], [token], [ip], [user_agent]) VALUES ($1, $2, $3, $4, $5) "#, ) .bind(user_id) .bind(Utc::now()) .bind(&token) .bind(ip) .bind(user_agent) .execute(&mut **tx) .await?; Ok(token) } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn sign_up() -> Result<()> { let connection = Connection::new_in_memory().await?; match connection.sign_up("paul@atreides.com", "12345").await? { SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[tokio::test] async fn sign_up_to_an_already_existing_user() -> Result<()> { let connection = Connection::new_in_memory().await?; connection.execute_sql( sqlx::query( r#" INSERT INTO [User] ([id], [email], [name], [creation_datetime], [password], [validation_token_datetime], [validation_token]) VALUES ( 1, 'paul@atreides.com', 'paul', '', '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY', 0, NULL ); "#)).await?; match connection.sign_up("paul@atreides.com", "12345").await? { SignUpResult::UserAlreadyExists => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[tokio::test] async fn sign_up_and_sign_in_without_validation() -> Result<()> { let connection = Connection::new_in_memory().await?; let email = "paul@atreides.com"; let password = "12345"; match connection.sign_up(email, password).await? { SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case. other => panic!("{:?}", other), } match connection .sign_in(email, password, "127.0.0.1", "Mozilla/5.0") .await? { SignInResult::AccountNotValidated => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[tokio::test] async fn sign_up_to_an_unvalidated_already_existing_user() -> Result<()> { let connection = Connection::new_in_memory().await?; let token = generate_token(); connection.execute_sql( sqlx::query( r#" INSERT INTO [User] ([id], [email], [creation_datetime], [name], [password], [validation_token_datetime], [validation_token]) VALUES ( 1, 'paul@atreides.com', '', 'paul', '$argon2id$v=19$m=4096,t=3,p=1$1vtXcacYjUHZxMrN6b2Xng$wW8Z59MIoMcsIljnjHmxn3EBcc5ymEySZPUVXHlRxcY', 0, $1 ) "# ).bind(token)).await?; match connection.sign_up("paul@atreides.com", "12345").await? { SignUpResult::UserCreatedWaitingForValidation(_) => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[tokio::test] async fn sign_up_then_send_validation_at_time() -> Result<()> { let connection = Connection::new_in_memory().await?; let validation_token = match connection.sign_up("paul@atreides.com", "12345").await? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; match connection .validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0", ) .await? { ValidationResult::Ok(_, _) => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[tokio::test] async fn sign_up_then_send_validation_too_late() -> Result<()> { let connection = Connection::new_in_memory().await?; let validation_token = match connection .sign_up_with_given_time("paul@atreides.com", "12345", Utc::now() - Duration::days(1)) .await? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; match connection .validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0", ) .await? { ValidationResult::ValidationExpired => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[tokio::test] async fn sign_up_then_send_validation_with_bad_token() -> Result<()> { let connection = Connection::new_in_memory().await?; let _validation_token = match connection.sign_up("paul@atreides.com", "12345").await? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; let random_token = generate_token(); match connection .validation( &random_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0", ) .await? { ValidationResult::UnknownUser => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[tokio::test] async fn sign_up_then_send_validation_then_sign_in() -> Result<()> { let connection = Connection::new_in_memory().await?; let email = "paul@atreides.com"; let password = "12345"; // Sign up. let validation_token = match connection.sign_up(email, password).await? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; // Validation. match connection .validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0", ) .await? { ValidationResult::Ok(_, _) => (), other => panic!("{:?}", other), }; // Sign in. match connection .sign_in(email, password, "127.0.0.1", "Mozilla/5.0") .await? { SignInResult::Ok(_, _) => (), // Nominal case. other => panic!("{:?}", other), } Ok(()) } #[tokio::test] async fn sign_up_then_send_validation_then_authentication() -> Result<()> { let connection = Connection::new_in_memory().await?; let email = "paul@atreides.com"; let password = "12345"; // Sign up. let validation_token = match connection.sign_up(email, password).await? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; // Validation. let (authentication_token, _user_id) = match connection .validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla", ) .await? { ValidationResult::Ok(token, user_id) => (token, user_id), other => panic!("{:?}", other), }; // Check user login information. let user_login_info_1 = connection .get_user_login_info(&authentication_token) .await?; assert_eq!(user_login_info_1.ip, "127.0.0.1"); assert_eq!(user_login_info_1.user_agent, "Mozilla"); // Authentication. let _user_id = match connection .authentication(&authentication_token, "192.168.1.1", "Chrome") .await? { AuthenticationResult::Ok(user_id) => user_id, // Nominal case. other => panic!("{:?}", other), }; // Check user login information. let user_login_info_2 = connection .get_user_login_info(&authentication_token) .await?; assert_eq!(user_login_info_2.ip, "192.168.1.1"); assert_eq!(user_login_info_2.user_agent, "Chrome"); Ok(()) } #[tokio::test] async fn sign_up_then_send_validation_then_sign_out_then_sign_in() -> Result<()> { let connection = Connection::new_in_memory().await?; let email = "paul@atreides.com"; let password = "12345"; // Sign up. let validation_token = match connection.sign_up(email, password).await? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; // Validation. let (authentication_token_1, user_id_1) = match connection .validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla", ) .await? { ValidationResult::Ok(token, user_id) => (token, user_id), other => panic!("{:?}", other), }; // Check user login information. let user_login_info_1 = connection .get_user_login_info(&authentication_token_1) .await?; assert_eq!(user_login_info_1.ip, "127.0.0.1"); assert_eq!(user_login_info_1.user_agent, "Mozilla"); // Sign out. connection.sign_out(&authentication_token_1).await?; // Sign in. let (authentication_token_2, user_id_2) = match connection .sign_in(email, password, "192.168.1.1", "Chrome") .await? { SignInResult::Ok(token, user_id) => (token, user_id), other => panic!("{:?}", other), }; assert_eq!(user_id_1, user_id_2); assert_ne!(authentication_token_1, authentication_token_2); // Check user login information. let user_login_info_2 = connection .get_user_login_info(&authentication_token_2) .await?; assert_eq!(user_login_info_2.ip, "192.168.1.1"); assert_eq!(user_login_info_2.user_agent, "Chrome"); Ok(()) } #[tokio::test] async fn ask_to_reset_password_for_unknown_email() -> Result<()> { let connection = Connection::new_in_memory().await?; let email = "paul@atreides.com"; // Ask for password reset. match connection .get_token_reset_password(email, Duration::hours(1)) .await? { GetTokenResetPasswordResult::EmailUnknown => Ok(()), // Nominal case. other => panic!("{:?}", other), } } #[tokio::test] async fn sign_up_then_send_validation_then_sign_out_then_ask_to_reset_password() -> Result<()> { let connection = Connection::new_in_memory().await?; let email = "paul@atreides.com"; let password = "12345"; let new_password = "54321"; // Sign up. let validation_token = match connection.sign_up(email, password).await? { SignUpResult::UserCreatedWaitingForValidation(token) => token, // Nominal case. other => panic!("{:?}", other), }; // Validation. let (authentication_token_1, user_id_1) = match connection .validation( &validation_token, Duration::hours(1), "127.0.0.1", "Mozilla", ) .await? { ValidationResult::Ok(token, user_id) => (token, user_id), other => panic!("{:?}", other), }; // Check user login information. let user_login_info_1 = connection .get_user_login_info(&authentication_token_1) .await?; assert_eq!(user_login_info_1.ip, "127.0.0.1"); assert_eq!(user_login_info_1.user_agent, "Mozilla"); // Sign out. connection.sign_out(&authentication_token_1).await?; // Ask for password reset. let token = match connection .get_token_reset_password(email, Duration::hours(1)) .await? { GetTokenResetPasswordResult::Ok(token) => token, other => panic!("{:?}", other), }; connection .reset_password(new_password, &token, Duration::hours(1)) .await?; // Sign in. let (authentication_token_2, user_id_2) = match connection .sign_in(email, new_password, "192.168.1.1", "Chrome") .await? { SignInResult::Ok(token, user_id) => (token, user_id), other => panic!("{:?}", other), }; assert_eq!(user_id_1, user_id_2); assert_ne!(authentication_token_1, authentication_token_2); // Check user login information. let user_login_info_2 = connection .get_user_login_info(&authentication_token_2) .await?; assert_eq!(user_login_info_2.ip, "192.168.1.1"); assert_eq!(user_login_info_2.user_agent, "Chrome"); Ok(()) } #[tokio::test] async fn update_user() -> Result<()> { let connection = Connection::new_in_memory().await?; connection.execute_sql( sqlx::query( r#" INSERT INTO [User] ([id], [email], [name], [creation_datetime], [password], [validation_token_datetime], [validation_token]) VALUES (1, 'paul@atreides.com', 'paul', '', '$argon2id$v=19$m=4096,t=3,p=1$G4fjepS05MkRbTqEImUdYg$GGziE8uVQe1L1oFHk37lBno10g4VISnVqynSkLCH3Lc', '2022-11-29 22:05:04.121407300+00:00', NULL) "# ) ).await?; let user = connection.load_user(1).await?.unwrap(); assert_eq!(user.name, "paul"); assert_eq!(user.email, "paul@atreides.com"); if let UpdateUserResult::UserUpdatedWaitingForRevalidation(token) = connection .update_user( 1, Some("muaddib@fremen.com"), Some("muaddib"), Some("Chani"), ) .await? { let (_authentication_token_1, user_id_1) = match connection .validation(&token, Duration::hours(1), "127.0.0.1", "Mozilla/5.0") .await? { ValidationResult::Ok(token, user_id) => (token, user_id), other => panic!("{:?}", other), }; assert_eq!(user_id_1, 1); } else { panic!("A revalidation token must be created when changin e-mail"); } let user = connection.load_user(1).await?.unwrap(); assert_eq!(user.name, "muaddib"); assert_eq!(user.email, "muaddib@fremen.com"); // Tests if password has been updated correctly. if let SignInResult::Ok(_token, id) = connection .sign_in("muaddib@fremen.com", "Chani", "127.0.0.1", "Mozilla/5.0") .await? { assert_eq!(id, 1); } else { panic!("Can't sign in"); } Ok(()) } }