recipes/backend/src/main.rs

543 lines
17 KiB
Rust

use std::{net::SocketAddr, path::Path};
use axum::{
BoxError, Router, ServiceExt,
error_handling::HandleErrorLayer,
extract::{ConnectInfo, Extension, FromRef, Request, State},
http::{StatusCode, Uri},
middleware::{self, Next},
response::Response,
routing::{delete, get, patch, post, put},
};
use axum_extra::extract::cookie::CookieJar;
use chrono::prelude::*;
use clap::Parser;
use config::Config;
use itertools::Itertools;
use tokio::signal;
use tower::layer::Layer;
use tower::{ServiceBuilder, buffer::BufferLayer, limit::RateLimitLayer};
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
use tracing::{Level, event};
use data::{backup, db, model};
use translation::Tr;
mod config;
mod consts;
mod data;
mod email;
mod hash;
mod html_templates;
mod ron_extractor;
mod ron_utils;
mod services;
mod translation;
mod utils;
/// State shared by all the handlers of the web application.
///
/// Handlers can extract the whole state or, thanks to the `FromRef` impls below,
/// only one of its parts (`State<Config>`, `State<db::Connection>`).
#[derive(Clone)]
struct AppState {
    config: Config,
    db_connection: db::Connection,
}

/// Allows extracting `State<Config>` from an `AppState`-based router.
impl FromRef<AppState> for Config {
    fn from_ref(app_state: &AppState) -> Self {
        Clone::clone(&app_state.config)
    }
}

/// Allows extracting `State<db::Connection>` from an `AppState`-based router.
impl FromRef<AppState> for db::Connection {
    fn from_ref(app_state: &AppState) -> Self {
        Clone::clone(&app_state.db_connection)
    }
}
/// Turns a database error into the response built by `ron_utils::ron_error`
/// with status 500 and the error's message.
impl axum::response::IntoResponse for db::DBError {
    fn into_response(self) -> Response {
        let message = self.to_string();
        let ron_error = ron_utils::ron_error(StatusCode::INTERNAL_SERVER_ERROR, &message);
        ron_error.into_response()
    }
}
/// Application-level error: either a database error or a template rendering error.
#[derive(Debug, thiserror::Error)]
enum AppError {
    #[error("Database error: {0}")]
    Database(#[from] db::DBError),
    #[error("Template error: {0}")]
    Render(#[from] askama::Error),
}

/// Shorthand for results whose error type is [`AppError`].
type Result<T> = std::result::Result<T, AppError>;

/// Every [`AppError`] becomes a 500 response whose body is the error's message.
impl axum::response::IntoResponse for AppError {
    fn into_response(self) -> Response {
        let body = self.to_string();
        let status = StatusCode::INTERNAL_SERVER_ERROR;
        (status, body).into_response()
    }
}
/// Maximum level given to the tracing subscriber in `main` (debug builds: DEBUG).
#[cfg(debug_assertions)]
const TRACING_LEVEL: tracing::Level = tracing::Level::DEBUG;
/// Maximum level given to the tracing subscriber in `main` (release builds: INFO).
#[cfg(not(debug_assertions))]
const TRACING_LEVEL: tracing::Level = tracing::Level::INFO;
/// Per-request data built by the `context` middleware and inserted into the
/// request extensions, where handlers can read it.
#[derive(Debug, Clone)]
pub struct Context {
/// The user authenticated by the auth-token cookie, `None` if there is no valid token.
pub user: Option<model::User>,
/// Translations for the language chosen for this request.
pub tr: Tr,
/// `true` when the dark theme cookie parses as `true`; `false` otherwise.
pub dark_theme: bool,
}
// TODO: Should main returns 'Result'?
#[tokio::main]
async fn main() {
tracing_subscriber::fmt()
.with_max_level(TRACING_LEVEL)
.init();
if !process_args().await {
return;
}
event!(Level::INFO, "Starting Recipes as web server...");
let config = config::load();
let port = config.port;
event!(Level::INFO, "Configuration: {:?}", config);
let Ok(db_connection) = db::Connection::new().await else {
event!(Level::ERROR, "Unable to connect to the database");
return;
};
backup::start(
"data",
db_connection.clone(),
// TODO: take from config.
NaiveTime::from_hms_opt(4, 0, 0).expect("Invalid time of day"),
);
let state = AppState {
config,
db_connection,
};
let ron_api_routes = Router::new()
// Disabled: update user profile is now made with a post data ('edit_user_post').
// .route("/user/update", put(services::ron::update_user))
.route("/lang", put(services::ron::set_lang))
.route("/recipe/titles", get(services::ron::recipe::get_titles))
.route("/recipe/title", patch(services::ron::recipe::set_title))
.route(
"/recipe/description",
patch(services::ron::recipe::set_description),
)
.route(
"/recipe/servings",
patch(services::ron::recipe::set_servings),
)
.route(
"/recipe/estimated_time",
patch(services::ron::recipe::set_estimated_time),
)
.route(
"/recipe/tags",
get(services::ron::recipe::get_tags)
.post(services::ron::recipe::add_tags)
.delete(services::ron::recipe::rm_tags),
)
.route(
"/recipe/difficulty",
patch(services::ron::recipe::set_difficulty),
)
.route(
"/recipe/language",
patch(services::ron::recipe::set_language),
)
.route(
"/recipe/is_published",
patch(services::ron::recipe::set_is_published),
)
.route("/recipe", delete(services::ron::recipe::rm))
.route("/recipe/groups", get(services::ron::recipe::get_groups))
.route(
"/recipe/group",
post(services::ron::recipe::add_group).delete(services::ron::recipe::rm_group),
)
.route(
"/recipe/group_name",
patch(services::ron::recipe::set_group_name),
)
.route(
"/recipe/group_comment",
patch(services::ron::recipe::set_group_comment),
)
.route(
"/recipe/groups_order",
patch(services::ron::recipe::set_groups_order),
)
.route(
"/recipe/step",
post(services::ron::recipe::add_step).delete(services::ron::recipe::rm_step),
)
.route(
"/recipe/step_action",
patch(services::ron::recipe::set_step_action),
)
.route(
"/recipe/steps_order",
patch(services::ron::recipe::set_steps_order),
)
.route(
"/recipe/ingredient",
post(services::ron::recipe::add_ingredient)
.delete(services::ron::recipe::rm_ingredient),
)
.route(
"/recipe/ingredient_name",
patch(services::ron::recipe::set_ingredient_name),
)
.route(
"/recipe/ingredient_comment",
patch(services::ron::recipe::set_ingredient_comment),
)
.route(
"/recipe/ingredient_quantity",
patch(services::ron::recipe::set_ingredient_quantity),
)
.route(
"/recipe/ingredient_unit",
patch(services::ron::recipe::set_ingredient_unit),
)
.route(
"/recipe/ingredients_order",
patch(services::ron::recipe::set_ingredients_order),
)
.route(
"/calendar/scheduled_recipes",
get(services::ron::calendar::get_scheduled_recipes),
)
.route(
"/calendar/schedule_recipe",
post(services::ron::calendar::schedule_recipe)
.delete(services::ron::calendar::rm_scheduled_recipe),
)
.route("/shopping_list", get(services::ron::shopping_list::get))
.route(
"/shopping_list/checked",
patch(services::ron::shopping_list::set_entry_checked),
)
.fallback(services::ron::not_found);
let fragments_routes = Router::new().route(
"/recipes_list",
get(services::fragments::recipes_list_fragments),
);
let html_routes_with_rate_limit = Router::new()
.route("/signin", post(services::user::sign_in_post))
.route("/signup", post(services::user::sign_up_post))
.route(
"/ask_reset_password",
post(services::user::ask_reset_password_post),
)
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|err: BoxError| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled error: {}", err),
)
}))
.layer(BufferLayer::new(1024))
.layer(RateLimitLayer::new(
consts::NUMBER_OF_CONCURRENT_HTTP_REQUEST_FOR_RATE_LIMIT,
consts::DURATION_FOR_RATE_LIMIT,
)),
);
let html_routes = Router::new()
.route("/", get(services::home_page))
.route("/signup", get(services::user::sign_up_get))
.route("/validation", get(services::user::sign_up_validation))
.route("/revalidation", get(services::user::email_revalidation))
.route("/signin", get(services::user::sign_in_get))
.route("/signout", get(services::user::sign_out))
.route(
"/ask_reset_password",
get(services::user::ask_reset_password_get),
)
.route(
"/reset_password",
get(services::user::reset_password_get).post(services::user::reset_password_post),
)
// Recipes.
.route("/recipe/new", get(services::recipe::create))
.route("/recipe/edit/{id}", get(services::recipe::edit))
.route("/recipe/view/{id}", get(services::recipe::view))
// User.
.route(
"/user/edit",
get(services::user::edit_user_get).post(services::user::edit_user_post),
)
.merge(html_routes_with_rate_limit)
.nest("/fragments", fragments_routes)
.route_layer(middleware::from_fn(services::ron_error_to_html));
let app = Router::new()
.merge(html_routes)
.nest("/ron-api", ron_api_routes)
.fallback(services::not_found)
.layer(middleware::from_fn_with_state(state.clone(), context))
.with_state(state)
.nest_service("/favicon.ico", ServeFile::new("static/favicon.ico"))
.nest_service("/static", ServeDir::new("static"))
.layer(TraceLayer::new_for_http());
let url_rewriting_middleware = tower::util::MapRequestLayer::new(url_rewriting);
let app_with_url_rewriting = url_rewriting_middleware.layer(app);
let addr = SocketAddr::from(([0, 0, 0, 0], port));
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(
listener,
app_with_url_rewriting.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
event!(Level::INFO, "Recipes stopped");
}
/// Request extension set by `url_rewriting`: the language taken from the URL
/// prefix (e.g. "fr" for '/fr/recipe/view/1'), or `None` when the URL has no
/// recognized language prefix.
#[derive(Debug, Clone)]
struct Lang(Option<String>);
/// Extracts an optional language prefix from the URL and removes it from the URI.
///
/// For example "/fr/recipe/view/1?a=b" becomes "/recipe/view/1?a=b" with
/// `Lang(Some("fr"))` inserted into the request extensions. A bare prefix such as
/// "/fr" becomes "/". The prefix is only recognized if it belongs to
/// `translation::available_codes()`. Otherwise the URI is left untouched and
/// `Lang(None)` is inserted. A `Lang` extension is inserted in every case.
fn url_rewriting(mut req: Request) -> Request {
    let lang_and_new_uri = req.uri().path_and_query().and_then(|path_query| {
        let path = path_query.path().strip_prefix('/')?;
        // "fr/recipe/view/1" -> ("fr", "recipe/view/1"); "fr" -> ("fr", "").
        let (lang, rest) = path.split_once('/').unwrap_or((path, ""));
        if !translation::available_codes().contains(&lang) {
            return None;
        }

        // The rebuilt URI must start with '/': an empty string (from "/fr") or a
        // query-only string (from "/fr?a=b") is not a valid `Uri`.
        let query = path_query.query();
        let mut new_uri =
            String::with_capacity(1 + rest.len() + query.map_or(0, |q| 1 + q.len()));
        new_uri.push('/');
        new_uri.push_str(rest);
        if let Some(query) = query {
            new_uri.push('?');
            new_uri.push_str(query);
        }

        let new_uri = new_uri.parse::<Uri>().ok()?;
        Some((lang.to_string(), new_uri))
    });

    let lang = match lang_and_new_uri {
        Some((lang, new_uri)) => {
            *req.uri_mut() = new_uri;
            Some(lang)
        }
        None => None,
    };
    req.extensions_mut().insert(Lang(lang));
    req
}
/// Middleware building the per-request [`Context`] (current user, translations,
/// theme) and inserting it into the request extensions before calling `next`.
///
/// The language of the current HTTP request is defined in the following order:
/// - Extraction from the url: like in '/fr/recipe/view/42'
/// - Get from the user database record.
/// - Get from the cookie.
/// - Get from the HTTP header `accept-language`.
/// - Set as `translation::DEFAULT_LANGUAGE_CODE`.
async fn context(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    State(connection): State<db::Connection>,
    Extension(lang_from_url): Extension<Lang>,
    mut req: Request,
    next: Next,
) -> Result<Response> {
    let jar = CookieJar::from_headers(req.headers());
    let (client_ip, client_user_agent) = utils::get_ip_and_user_agent(req.headers(), addr);
    let user = get_current_user(connection, &jar, &client_ip, &client_user_agent).await;

    let language = if let Some(lang) = lang_from_url.0 {
        lang
    } else if let Some(ref user) = user {
        user.lang.clone()
    } else {
        let available_codes = translation::available_codes();
        let is_available = |code: &str| available_codes.contains(&code);
        match jar.get(consts::COOKIE_LANG_NAME) {
            Some(lang) if is_available(lang.value()) => lang.value().to_string(),
            _ => req
                .headers()
                .get(axum::http::header::ACCEPT_LANGUAGE)
                .and_then(|value| value.to_str().ok())
                .and_then(|header| lang_from_accept_language(header, is_available))
                .unwrap_or(translation::DEFAULT_LANGUAGE_CODE)
                .to_string(),
        }
    };

    let tr = Tr::new(&language);

    // A missing or unparsable cookie means the light theme.
    let dark_theme = jar
        .get(common::consts::COOKIE_DARK_THEME)
        .map(|cookie| cookie.value().parse::<bool>().unwrap_or_default())
        .unwrap_or(false);

    req.extensions_mut().insert(Context {
        user,
        tr,
        dark_theme,
    });

    Ok(next.run(req).await)
}

/// Returns the first language of an `Accept-Language` header value accepted by
/// `is_available`, or `None` if there is none.
///
/// Each comma-separated entry is trimmed and its `;q=` weight is ignored. Only the
/// primary subtag (before the first '-') is checked. For example,
/// "de-CH, fr;q=0.9" yields the candidates "de" then "fr". Entries are tried in
/// header order; weights are not used to re-rank them.
fn lang_from_accept_language<'h>(
    header: &'h str,
    is_available: impl Fn(&str) -> bool,
) -> Option<&'h str> {
    header
        .split(',')
        .map(|entry| {
            let range = entry.split(';').next().unwrap_or_default().trim();
            range.split('-').next().unwrap_or_default()
        })
        .find(|primary| !primary.is_empty() && is_available(*primary))
}
/// Returns the user authenticated by the auth-token cookie found in `jar`.
///
/// The token is checked by `connection.authentication` together with the client
/// IP and user agent. Returns `None` when there is no cookie or the token is not
/// valid. Database errors are logged as warnings and also give `None`.
async fn get_current_user(
    connection: db::Connection,
    jar: &CookieJar,
    client_ip: &str,
    client_user_agent: &str,
) -> Option<model::User> {
    let token = jar.get(consts::COOKIE_AUTH_TOKEN_NAME)?.value();

    let authentication = connection
        .authentication(token, client_ip, client_user_agent)
        .await;

    let user_id = match authentication {
        Ok(db::user::AuthenticationResult::Ok(user_id)) => user_id,
        Ok(db::user::AuthenticationResult::NotValidToken) => return None,
        Err(error) => {
            event!(Level::WARN, "Error during authentication: {}", error);
            return None;
        }
    };

    connection.load_user(user_id).await.unwrap_or_else(|error| {
        event!(Level::WARN, "Error during authentication: {}", error);
        None
    })
}
// Command line arguments, parsed by clap in `process_args`.
// NOTE: plain `//` comments are used here on purpose: `///` doc comments on a
// clap derive become help text and would change the `--help` output.
#[derive(Parser, Debug)]
#[command(
author = "Greg Burri",
version = "1.0",
about = "A little cooking recipes website"
)]
struct Args {
/// Will clear the database and insert some test data. (A backup is made first).
#[arg(long)]
dbtest: bool,
}
/// Processes the command line arguments.
///
/// With `--dbtest`, the existing database file (if any) is copied to a new
/// `*.sqlite.bckupNNN` file and removed. A new database is then created and filled
/// from 'sql/data_test.sql'. In that mode the server must not be started.
///
/// Returns `true` if the server can be started.
///
/// # Panics
/// If the existing database file cannot be backed up or removed.
async fn process_args() -> bool {
    let args = Args::parse();
    if !args.dbtest {
        return true;
    }

    backup_and_remove_db_file(&Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME));
    create_test_db().await;
    false
}

/// Copies the database file at `db_path` (if it exists) to the first free
/// `<db_path>.sqlite.bckupNNN` path, then removes the original file.
///
/// # Panics
/// If the copy or the removal fails.
fn backup_and_remove_db_file(db_path: &Path) {
    if !db_path.exists() {
        return;
    }

    let db_path_bckup = (1..)
        .map(|n| db_path.with_extension(format!("sqlite.bckup{:03}", n)))
        .find(|p| !p.exists())
        .expect("the range of backup numbers is unbounded");

    std::fs::copy(db_path, &db_path_bckup).unwrap_or_else(|error| {
        panic!(
            "Unable to make backup of {:?} to {:?}: {}",
            db_path, &db_path_bckup, error
        )
    });
    std::fs::remove_file(db_path)
        .unwrap_or_else(|error| panic!("Unable to remove db file {:?}: {}", db_path, error));
}

/// Creates a new database and fills it with the test data from 'sql/data_test.sql'.
///
/// Each failing step is logged as an error and stops the process. The success
/// message is only logged when every step succeeded.
async fn create_test_db() {
    let con = match db::Connection::new().await {
        Ok(con) => con,
        Err(error) => {
            event!(Level::ERROR, "{}", error);
            return;
        }
    };

    if let Err(error) = con.execute_file("sql/data_test.sql").await {
        event!(Level::ERROR, "{}", error);
        return;
    }

    // Set the creation datetime to 'now'.
    let update = con
        .execute_sql(
            sqlx::query(
                "UPDATE [User] SET [validation_token_datetime] = $1 WHERE [email] = 'paul@test.org'")
            .bind(Utc::now()),
        )
        .await;
    if let Err(error) = update {
        event!(Level::ERROR, "{}", error);
        return;
    }

    event!(
        Level::INFO,
        "A new test database has been created successfully"
    );
}
/// Completes when the process receives Ctrl+C or, on Unix, SIGTERM.
///
/// Used as the graceful-shutdown trigger of the server. Panics if a signal
/// handler cannot be installed.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        sigterm.recv().await;
    };

    // Without Unix signals, only Ctrl+C can stop the server.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        () = ctrl_c => {}
        () = terminate => {}
    }
}