use std::net::SocketAddr; use axum::{ BoxError, Router, ServiceExt, error_handling::HandleErrorLayer, extract::{ ConnectInfo, Extension, FromRef, Request, State, connect_info::IntoMakeServiceWithConnectInfo, }, http::{StatusCode, Uri}, middleware::{self, Next}, response::Response, routing::{delete, get, patch, post, put}, }; use axum_extra::extract::cookie::CookieJar; use chrono::prelude::*; use itertools::Itertools; use tower::layer::Layer; use tower::{ServiceBuilder, buffer::BufferLayer, limit::RateLimitLayer}; use tower_http::{ services::{ServeDir, ServeFile}, trace::TraceLayer, }; use tracing::warn; use crate::{ config::Config, consts, data::{db, model}, log::Log, ron_utils, services, translation::{self, Tr}, utils, }; #[derive(Clone)] pub struct AppState { pub config: Config, pub db_connection: db::Connection, pub log: Log, } impl FromRef for Config { fn from_ref(app_state: &AppState) -> Config { app_state.config.clone() } } impl FromRef for db::Connection { fn from_ref(app_state: &AppState) -> db::Connection { app_state.db_connection.clone() } } impl FromRef for Log { fn from_ref(app_state: &AppState) -> Log { app_state.log.clone() } } impl axum::response::IntoResponse for db::DBError { fn into_response(self) -> Response { ron_utils::ron_error(StatusCode::INTERNAL_SERVER_ERROR, &self.to_string()).into_response() } } #[derive(Debug, thiserror::Error)] pub enum AppError { #[error("Database error: {0}")] Database(#[from] db::DBError), #[error("Template error: {0}")] Render(#[from] askama::Error), } pub type Result = std::result::Result; impl axum::response::IntoResponse for AppError { fn into_response(self) -> Response { (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() } } #[derive(Debug, Clone)] pub struct Context { pub user: Option, pub tr: Tr, pub dark_theme: bool, } impl Context { pub fn first_day_of_the_week(&self) -> Weekday { if let Some(user) = &self.user { user.first_day_of_the_week } else { self.tr.first_day_of_week() } } } #[derive(Debug, Clone)] struct Lang(Option); // TODO: Refactor the signature into something like 'impl Service<...>'. pub fn make_service( state: AppState, ) -> IntoMakeServiceWithConnectInfo< tower::util::MapRequest) -> Request>, SocketAddr, > // ) -> impl Service, Error = core::convert::Infallible, Response = S> // where // L: axum::serve::Listener, // S: Service // + Clone // + Send // + 'static, // S::Future: Send, // T: , // std::net::SocketAddr: // axum::extract::connect_info::Connected>, { let ron_api_routes = Router::new() // Disabled: update user profile is now made with a post data ('edit_user_post'). // .route("/user/update", put(services::ron::update_user)) .route("/lang", put(services::ron::set_lang)) .route("/recipe/titles", get(services::ron::recipe::get_titles)) .route("/recipe/title", patch(services::ron::recipe::set_title)) .route( "/recipe/description", patch(services::ron::recipe::set_description), ) .route( "/recipe/servings", patch(services::ron::recipe::set_servings), ) .route( "/recipe/estimated_time", patch(services::ron::recipe::set_estimated_time), ) .route( "/recipe/tags", get(services::ron::recipe::get_tags) .post(services::ron::recipe::add_tags) .delete(services::ron::recipe::rm_tags), ) .route( "/recipe/difficulty", patch(services::ron::recipe::set_difficulty), ) .route( "/recipe/language", patch(services::ron::recipe::set_language), ) .route( "/recipe/is_public", patch(services::ron::recipe::set_is_public), ) .route("/recipe", delete(services::ron::recipe::rm)) .route("/recipe/groups", get(services::ron::recipe::get_groups)) .route( "/recipe/group", post(services::ron::recipe::add_group).delete(services::ron::recipe::rm_group), ) .route( "/recipe/group_name", patch(services::ron::recipe::set_group_name), ) .route( "/recipe/group_comment", patch(services::ron::recipe::set_group_comment), ) .route( "/recipe/groups_order", patch(services::ron::recipe::set_groups_order), ) .route( "/recipe/step", post(services::ron::recipe::add_step).delete(services::ron::recipe::rm_step), ) .route( "/recipe/step_action", patch(services::ron::recipe::set_step_action), ) .route( "/recipe/steps_order", patch(services::ron::recipe::set_steps_order), ) .route( "/recipe/ingredient", post(services::ron::recipe::add_ingredient) .delete(services::ron::recipe::rm_ingredient), ) .route( "/recipe/ingredient_name", patch(services::ron::recipe::set_ingredient_name), ) .route( "/recipe/ingredient_comment", patch(services::ron::recipe::set_ingredient_comment), ) .route( "/recipe/ingredient_quantity", patch(services::ron::recipe::set_ingredient_quantity), ) .route( "/recipe/ingredient_unit", patch(services::ron::recipe::set_ingredient_unit), ) .route( "/recipe/ingredients_order", patch(services::ron::recipe::set_ingredients_order), ) .route( "/calendar/scheduled_recipes", get(services::ron::calendar::get_scheduled_recipes), ) .route( "/calendar/scheduled_recipe", post(services::ron::calendar::add_scheduled_recipe) .delete(services::ron::calendar::rm_scheduled_recipe), ) .route("/shopping_list", get(services::ron::shopping_list::get)) .route( "/shopping_list/checked", patch(services::ron::shopping_list::set_entry_checked), ) .fallback(services::ron::not_found); let fragments_routes = Router::new().route( "/recipes_list", get(services::fragments::recipes_list_fragments), ); let html_routes_with_rate_limit = Router::new() .route("/signin", post(services::user::sign_in_post)) .route("/signup", post(services::user::sign_up_post)) .route( "/ask_reset_password", post(services::user::ask_reset_password_post), ) .layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(|err: BoxError| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled error: {}", err), ) })) .layer(BufferLayer::new(1024)) .layer(RateLimitLayer::new( consts::NUMBER_OF_CONCURRENT_HTTP_REQUEST_FOR_RATE_LIMIT, consts::DURATION_FOR_RATE_LIMIT, )), ); let html_routes = Router::new() .route("/", get(services::home_page)) .route("/dev_panel", get(services::dev_panel)) .route("/logs", get(services::logs)) .route("/signup", get(services::user::sign_up_get)) .route("/validation", get(services::user::sign_up_validation)) .route("/revalidation", get(services::user::email_revalidation)) .route("/signin", get(services::user::sign_in_get)) .route("/signout", get(services::user::sign_out)) .route( "/ask_reset_password", get(services::user::ask_reset_password_get), ) .route( "/reset_password", get(services::user::reset_password_get).post(services::user::reset_password_post), ) // Recipes. .route("/recipe/new", get(services::recipe::create)) .route("/recipe/edit/{id}", get(services::recipe::edit)) .route("/recipe/view/{id}", get(services::recipe::view)) // User. .route( "/user/edit", get(services::user::edit_user_get).post(services::user::edit_user_post), ) .merge(html_routes_with_rate_limit) .nest("/fragments", fragments_routes) .route_layer(middleware::from_fn(services::ron_error_to_html)); let app = Router::new() .merge(html_routes) .nest("/ron-api", ron_api_routes) .fallback(services::not_found) .layer(middleware::from_fn_with_state(state.clone(), context)) .with_state(state) .nest_service("/favicon.ico", ServeFile::new("static/favicon.ico")) .nest_service("/static", ServeDir::new("static")) .layer(TraceLayer::new_for_http()); let url_rewriting_middleware = tower::util::MapRequestLayer::new( url_rewriting as fn(axum::http::Request) -> axum::http::Request, ); url_rewriting_middleware .layer(app) .into_make_service_with_connect_info::() } fn url_rewriting(mut req: Request) -> Request { // Here we are extracting the language from the url then rewriting it. // For example: // "/fr/recipe/view/1" // lang = "fr" and uri rewritten as = "/recipe/view/1" let lang_and_new_uri = 'lang_and_new_uri: { if let Some(path_query) = req.uri().path_and_query() { let mut parts = path_query.path().split('/'); let _ = parts.next(); // Empty part due to the first '/'. if let Some(lang) = parts.next() { let available_codes = translation::available_codes(); if available_codes.contains(&lang) { let mut rest = String::from(""); for part in parts { rest.push('/'); rest.push_str(part); } if let Some(query) = path_query.query() { rest.push('?'); rest.push_str(query); } if let Ok(new_uri) = rest.parse::() { break 'lang_and_new_uri Some((lang.to_string(), new_uri)); } } } } None }; if let Some((lang, new_uri)) = lang_and_new_uri { *req.uri_mut() = new_uri; req.extensions_mut().insert(Lang(Some(lang))); } else { req.extensions_mut().insert(Lang(None)); } req } /// The language associated to the current HTTP request is defined in the current order: /// - Extraction from the url: like in `/fr/recipe/view/42` /// - Get from the user database record. /// - Get from the cookie. /// - Get from the HTTP header `accept-language`. /// - Set as `translation::DEFAULT_LANGUAGE_CODE`. async fn context( ConnectInfo(addr): ConnectInfo, State(connection): State, Extension(lang_from_url): Extension, mut req: Request, next: Next, ) -> Result { let jar = CookieJar::from_headers(req.headers()); let (client_ip, client_user_agent) = utils::get_ip_and_user_agent(req.headers(), addr); let user = get_current_user(connection, &jar, &client_ip, &client_user_agent).await; let language = if let Some(lang) = lang_from_url.0 { lang } else if let Some(ref user) = user { user.lang.clone() } else { let available_codes = translation::available_codes(); let jar = CookieJar::from_headers(req.headers()); match jar.get(consts::COOKIE_LANG_NAME) { Some(lang) if available_codes.contains(&lang.value()) => lang.value().to_string(), _ => { let accept_language = req .headers() .get(axum::http::header::ACCEPT_LANGUAGE) .map(|v| v.to_str().unwrap_or_default()) .unwrap_or_default() .split(',') .map(|l| l.split('-').next().unwrap_or_default()) .find_or_first(|l| available_codes.contains(l)); match accept_language { Some(lang) if !lang.is_empty() => lang, _ => translation::DEFAULT_LANGUAGE_CODE, } .to_string() } } }; let tr = Tr::new(&language); let dark_theme = match jar.get(common::consts::COOKIE_DARK_THEME) { Some(dark_theme_cookie) => dark_theme_cookie .value() .parse() .inspect_err(|err| warn!("Can't parse dark theme cookie: {}", err)) .unwrap_or_default(), None => false, }; req.extensions_mut().insert(Context { user, tr, dark_theme, }); Ok(next.run(req).await) } async fn get_current_user( connection: db::Connection, jar: &CookieJar, client_ip: &str, client_user_agent: &str, ) -> Option { match jar.get(consts::COOKIE_AUTH_TOKEN_NAME) { Some(token_cookie) => match connection .authentication(token_cookie.value(), client_ip, client_user_agent) .await { Ok(db::user::AuthenticationResult::NotValidToken) => None, Ok(db::user::AuthenticationResult::Ok(user_id)) => { match connection.load_user(user_id).await { Ok(user) => user, Err(error) => { warn!("Error during authentication: {}", error); None } } } Err(error) => { warn!("Error during authentication: {}", error); None } }, None => None, } }