use std::{net::SocketAddr, path::Path}; use axum::{ Router, extract::{ConnectInfo, Extension, FromRef, Request, State}, http::StatusCode, middleware::{self, Next}, response::Response, routing::{delete, get, post, put}, }; use axum_extra::extract::cookie::CookieJar; use chrono::prelude::*; use clap::Parser; use config::Config; use itertools::Itertools; use tower_http::{services::ServeDir, trace::TraceLayer}; use tracing::{Level, event}; use data::{db, model}; use translation::Tr; mod config; mod consts; mod data; mod email; mod hash; mod html_templates; mod ron_extractor; mod ron_utils; mod services; mod translation; mod utils; #[derive(Clone)] struct AppState { config: Config, db_connection: db::Connection, } impl FromRef for Config { fn from_ref(app_state: &AppState) -> Config { app_state.config.clone() } } impl FromRef for db::Connection { fn from_ref(app_state: &AppState) -> db::Connection { app_state.db_connection.clone() } } impl axum::response::IntoResponse for db::DBError { fn into_response(self) -> Response { ron_utils::ron_error(StatusCode::INTERNAL_SERVER_ERROR, &self.to_string()).into_response() } } #[derive(Debug, thiserror::Error)] enum AppError { #[error("Database error: {0}")] Database(#[from] db::DBError), #[error("Template error: {0}")] Render(#[from] rinja::Error), } type Result = std::result::Result; impl axum::response::IntoResponse for AppError { fn into_response(self) -> Response { (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() } } #[cfg(debug_assertions)] const TRACING_LEVEL: tracing::Level = tracing::Level::DEBUG; #[cfg(not(debug_assertions))] const TRACING_LEVEL: tracing::Level = tracing::Level::INFO; // TODO: Should main returns 'Result'? #[tokio::main] async fn main() { tracing_subscriber::fmt() .with_max_level(TRACING_LEVEL) .init(); if process_args().await { return; } event!(Level::INFO, "Starting Recipes as web server..."); let config = config::load(); let port = config.port; event!(Level::INFO, "Configuration: {:?}", config); let db_connection = db::Connection::new().await.unwrap(); let state = AppState { config, db_connection, }; let ron_api_routes = Router::new() // Disabled: update user profile is now made with a post data ('edit_user_post'). // .route("/user/update", put(services::ron::update_user)) .route("/set_lang", put(services::ron::set_lang)) .route("/recipe/get_titles", get(services::ron::recipe::get_titles)) .route("/recipe/set_title", put(services::ron::recipe::set_title)) .route( "/recipe/set_description", put(services::ron::recipe::set_description), ) .route( "/recipe/set_servings", put(services::ron::recipe::set_servings), ) .route( "/recipe/set_estimated_time", put(services::ron::recipe::set_estimated_time), ) .route("/recipe/get_tags", get(services::ron::recipe::get_tags)) .route("/recipe/add_tags", post(services::ron::recipe::add_tags)) .route("/recipe/rm_tags", delete(services::ron::recipe::rm_tags)) .route( "/recipe/set_difficulty", put(services::ron::recipe::set_difficulty), ) .route( "/recipe/set_language", put(services::ron::recipe::set_language), ) .route( "/recipe/set_is_published", put(services::ron::recipe::set_is_published), ) .route("/recipe/remove", delete(services::ron::recipe::rm)) .route("/recipe/get_groups", get(services::ron::recipe::get_groups)) .route("/recipe/add_group", post(services::ron::recipe::add_group)) .route( "/recipe/remove_group", delete(services::ron::recipe::rm_group), ) .route( "/recipe/set_group_name", put(services::ron::recipe::set_group_name), ) .route( "/recipe/set_group_comment", put(services::ron::recipe::set_group_comment), ) .route( "/recipe/set_groups_order", put(services::ron::recipe::set_groups_order), ) .route("/recipe/add_step", post(services::ron::recipe::add_step)) .route( "/recipe/remove_step", delete(services::ron::recipe::rm_step), ) .route( "/recipe/set_step_action", put(services::ron::recipe::set_step_action), ) .route( "/recipe/set_steps_order", put(services::ron::recipe::set_steps_order), ) .route( "/recipe/add_ingredient", post(services::ron::recipe::add_ingredient), ) .route( "/recipe/remove_ingredient", delete(services::ron::recipe::rm_ingredient), ) .route( "/recipe/set_ingredient_name", put(services::ron::recipe::set_ingredient_name), ) .route( "/recipe/set_ingredient_comment", put(services::ron::recipe::set_ingredient_comment), ) .route( "/recipe/set_ingredient_quantity", put(services::ron::recipe::set_ingredient_quantity), ) .route( "/recipe/set_ingredient_unit", put(services::ron::recipe::set_ingredient_unit), ) .route( "/recipe/set_ingredients_order", put(services::ron::recipe::set_ingredients_order), ) .route( "/calendar/get_scheduled_recipes", get(services::ron::calendar::get_scheduled_recipes), ) .route( "/calendar/schedule_recipe", post(services::ron::calendar::schedule_recipe), ) .route( "/calendar/remove_scheduled_recipe", delete(services::ron::calendar::rm_scheduled_recipe), ) .route( "/shopping_list/get_list", get(services::ron::shopping_list::get), ) .route( "/shopping_list/set_checked", put(services::ron::shopping_list::set_entry_checked), ) .fallback(services::ron::not_found); let fragments_routes = Router::new().route( "/recipes_list", get(services::fragments::recipes_list_fragments), ); let html_routes = Router::new() .route("/", get(services::home_page)) .route( "/signup", get(services::user::sign_up_get).post(services::user::sign_up_post), ) .route("/validation", get(services::user::sign_up_validation)) .route("/revalidation", get(services::user::email_revalidation)) .route( "/signin", get(services::user::sign_in_get).post(services::user::sign_in_post), ) .route("/signout", get(services::user::sign_out)) .route( "/ask_reset_password", get(services::user::ask_reset_password_get) .post(services::user::ask_reset_password_post), ) .route( "/reset_password", get(services::user::reset_password_get).post(services::user::reset_password_post), ) // Recipes. .route("/recipe/new", get(services::recipe::create)) .route("/recipe/edit/{id}", get(services::recipe::edit)) .route("/recipe/view/{id}", get(services::recipe::view)) // User. .route( "/user/edit", get(services::user::edit_user_get).post(services::user::edit_user_post), ) .route_layer(middleware::from_fn(services::ron_error_to_html)); let app = Router::new() .merge(html_routes) .nest("/fragments", fragments_routes) .nest("/ron-api", ron_api_routes) .fallback(services::not_found) .layer(TraceLayer::new_for_http()) // FIXME: Should be 'route_layer' but it doesn't work for 'fallback(..)'. .layer(middleware::from_fn(translation)) .layer(middleware::from_fn_with_state( state.clone(), user_authentication, )) .nest_service("/static", ServeDir::new("static")) .with_state(state) .into_make_service_with_connect_info::(); let addr = SocketAddr::from(([0, 0, 0, 0], port)); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); axum::serve(listener, app).await.unwrap(); } async fn user_authentication( ConnectInfo(addr): ConnectInfo, State(connection): State, mut req: Request, next: Next, ) -> Result { let jar = CookieJar::from_headers(req.headers()); let (client_ip, client_user_agent) = utils::get_ip_and_user_agent(req.headers(), addr); let user = get_current_user(connection, &jar, &client_ip, &client_user_agent).await; req.extensions_mut().insert(user); Ok(next.run(req).await) } async fn translation( Extension(user): Extension>, mut req: Request, next: Next, ) -> Result { let language = if let Some(user) = user { user.lang } else { let available_codes = translation::available_codes(); let jar = CookieJar::from_headers(req.headers()); match jar.get(consts::COOKIE_LANG_NAME) { Some(lang) if available_codes.contains(&lang.value()) => lang.value().to_string(), _ => { let accept_language = req .headers() .get(axum::http::header::ACCEPT_LANGUAGE) .map(|v| v.to_str().unwrap_or_default()) .unwrap_or_default() .split(',') .map(|l| l.split('-').next().unwrap_or_default()) .find_or_first(|l| available_codes.contains(l)); match accept_language { Some(lang) if !lang.is_empty() => lang, _ => translation::DEFAULT_LANGUAGE_CODE, } .to_string() } } }; let tr = Tr::new(&language); req.extensions_mut().insert(tr); Ok(next.run(req).await) } async fn get_current_user( connection: db::Connection, jar: &CookieJar, client_ip: &str, client_user_agent: &str, ) -> Option { match jar.get(consts::COOKIE_AUTH_TOKEN_NAME) { Some(token_cookie) => match connection .authentication(token_cookie.value(), client_ip, client_user_agent) .await { Ok(db::user::AuthenticationResult::NotValidToken) => None, Ok(db::user::AuthenticationResult::Ok(user_id)) => { match connection.load_user(user_id).await { Ok(user) => user, Err(error) => { event!(Level::WARN, "Error during authentication: {}", error); None } } } Err(error) => { event!(Level::WARN, "Error during authentication: {}", error); None } }, None => None, } } #[derive(Parser, Debug)] #[command( author = "Greg Burri", version = "1.0", about = "A little cooking recipes website" )] struct Args { /// Will clear the database and insert some test data. (A backup is made first). #[arg(long)] dbtest: bool, } async fn process_args() -> bool { let args = Args::parse(); if args.dbtest { // Make a backup of the database. let db_path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME); if db_path.exists() { let db_path_bckup = (1..) .find_map(|n| { let p = db_path.with_extension(format!("sqlite.bckup{:03}", n)); if p.exists() { None } else { Some(p) } }) .unwrap(); std::fs::copy(&db_path, &db_path_bckup).unwrap_or_else(|error| { panic!( "Unable to make backup of {:?} to {:?}: {}", &db_path, &db_path_bckup, error ) }); std::fs::remove_file(&db_path).unwrap_or_else(|error| { panic!("Unable to remove db file {:?}: {}", &db_path, error) }); } match db::Connection::new().await { Ok(con) => { if let Err(error) = con.execute_file("sql/data_test.sql").await { event!(Level::ERROR, "{}", error); } // Set the creation datetime to 'now'. con.execute_sql( sqlx::query( "UPDATE [User] SET [validation_token_datetime] = $1 WHERE [email] = 'paul@test.org'") .bind(Utc::now()) ) .await .unwrap(); event!( Level::INFO, "A new test database has been created successfully" ); } Err(error) => { event!(Level::ERROR, "{}", error); } } return true; } false }