recipes/backend/src/main.rs
2024-12-21 23:13:06 +01:00

252 lines
7.5 KiB
Rust

use std::{net::SocketAddr, path::Path};
use axum::{
extract::{ConnectInfo, FromRef, Request, State},
http::StatusCode,
middleware::{self, Next},
response::{Response, Result},
routing::get,
Router,
};
use axum_extra::extract::cookie::CookieJar;
use chrono::prelude::*;
use clap::Parser;
use config::Config;
use tower_http::{services::ServeDir, trace::TraceLayer};
use tracing::{event, Level};
use data::{db, model};
mod config;
mod consts;
mod data;
mod email;
mod hash;
mod html_templates;
mod ron_extractor;
mod ron_utils;
mod services;
mod utils;
/// State shared by every route and middleware: the loaded configuration and the
/// database connection.
///
/// Thanks to the `FromRef` impls below, an extractor can ask for only one part of it
/// (e.g. `State<db::Connection>`); each impl hands out a clone of the matching field.
#[derive(Clone)]
struct AppState {
    config: Config,
    db_connection: db::Connection,
}

impl FromRef<AppState> for Config {
    fn from_ref(state: &AppState) -> Self {
        Clone::clone(&state.config)
    }
}

impl FromRef<AppState> for db::Connection {
    fn from_ref(state: &AppState) -> Self {
        Clone::clone(&state.db_connection)
    }
}
/// Lets handlers return `db::DBError` directly: the error becomes an HTTP 500 response
/// whose body is produced by `ron_utils::ron_error` from the error's `Display` text.
impl axum::response::IntoResponse for db::DBError {
    fn into_response(self) -> Response {
        let message = self.to_string();
        let error_body = ron_utils::ron_error(StatusCode::INTERNAL_SERVER_ERROR, &message);
        error_body.into_response()
    }
}
/// Maximum verbosity of the tracing subscriber installed in `main`:
/// `DEBUG` for debug builds, `INFO` for release builds.
const TRACING_LEVEL: tracing::Level = if cfg!(debug_assertions) {
    tracing::Level::DEBUG
} else {
    tracing::Level::INFO
};
// TODO: Should main returns 'Result'?
#[tokio::main]
async fn main() {
if process_args().await {
return;
}
tracing_subscriber::fmt()
.with_max_level(TRACING_LEVEL)
.init();
event!(Level::INFO, "Starting Recipes as web server...");
let config = config::load();
let port = config.port;
event!(Level::INFO, "Configuration: {:?}", config);
let db_connection = db::Connection::new().await.unwrap();
let state = AppState {
config,
db_connection,
};
let ron_api_routes = Router::new()
// Disabled: update user profile is now made with a post data ('edit_user_post').
// .route("/user/update", put(services::ron::update_user))
.fallback(services::ron::not_found);
let html_routes = Router::new()
.route("/", get(services::home_page))
.route(
"/signup",
get(services::user::sign_up_get).post(services::user::sign_up_post),
)
.route("/validation", get(services::user::sign_up_validation))
.route("/revalidation", get(services::user::email_revalidation))
.route(
"/signin",
get(services::user::sign_in_get).post(services::user::sign_in_post),
)
.route("/signout", get(services::user::sign_out))
.route(
"/ask_reset_password",
get(services::user::ask_reset_password_get)
.post(services::user::ask_reset_password_post),
)
.route(
"/reset_password",
get(services::user::reset_password_get).post(services::user::reset_password_post),
)
// Recipes.
.route("/recipe/new", get(services::recipe::create))
.route("/recipe/edit/:id", get(services::recipe::edit_recipe))
.route("/recipe/view/:id", get(services::recipe::view))
// User.
.route(
"/user/edit",
get(services::user::edit_user_get).post(services::user::edit_user_post),
)
.route_layer(middleware::from_fn(services::ron_error_to_html));
let app = Router::new()
.merge(html_routes)
.nest("/ron-api", ron_api_routes)
.fallback(services::not_found)
.layer(TraceLayer::new_for_http())
// FIXME: Should be 'route_layer' but it doesn't work for 'fallback(..)'.
.layer(middleware::from_fn_with_state(
state.clone(),
user_authentication,
))
.nest_service("/static", ServeDir::new("static"))
.with_state(state)
.into_make_service_with_connect_info::<SocketAddr>();
let addr = SocketAddr::from(([0, 0, 0, 0], port));
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
async fn user_authentication(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(connection): State<db::Connection>,
mut req: Request,
next: Next,
) -> Result<Response> {
let jar = CookieJar::from_headers(req.headers());
let (client_ip, client_user_agent) = utils::get_ip_and_user_agent(req.headers(), addr);
let user = get_current_user(connection, &jar, &client_ip, &client_user_agent).await;
req.extensions_mut().insert(user);
Ok(next.run(req).await)
}
/// Returns the user authenticated by the request's auth-token cookie, if any.
///
/// The token stored in the `consts::COOKIE_AUTH_TOKEN_NAME` cookie is checked by
/// `db::Connection::authentication` together with the client IP and user agent. When it
/// is accepted, the corresponding user is loaded with `load_user`.
///
/// Returns `None` in any of these cases:
/// - there is no auth-token cookie;
/// - the token is not valid;
/// - `load_user` finds no user;
/// - a database error occurs.
///
/// Database errors are logged as warnings, not propagated, so a failing database
/// degrades to an anonymous visitor instead of an error page.
async fn get_current_user(
    connection: db::Connection,
    jar: &CookieJar,
    client_ip: &str,
    client_user_agent: &str,
) -> Option<model::User> {
    let token_cookie = jar.get(consts::COOKIE_AUTH_TOKEN_NAME)?;

    let user_id = match connection
        .authentication(token_cookie.value(), client_ip, client_user_agent)
        .await
    {
        Ok(db::user::AuthenticationResult::Ok(user_id)) => user_id,
        Ok(db::user::AuthenticationResult::NotValidToken) => return None,
        Err(error) => {
            event!(Level::WARN, "Error during authentication: {}", error);
            return None;
        }
    };

    match connection.load_user(user_id).await {
        Ok(user) => user,
        Err(error) => {
            // Distinct message so the log tells which step of the authentication failed.
            event!(
                Level::WARN,
                "Error while loading the authenticated user: {}",
                error
            );
            None
        }
    }
}
// Command-line arguments, parsed with clap's derive API by `Args::parse()` in
// `process_args`.
// Plain `//` comments are used on purpose here: clap turns `///` doc comments into
// `--help` text, so adding any would change what the program prints.
#[derive(Parser, Debug)]
#[command(
    author = "Greg Burri",
    version = "1.0",
    about = "A little cooking recipes website"
)]
struct Args {
    // The `///` line below is the runtime help text of `--dbtest`; keep it as is.
    /// Will clear the database and insert some test data. (A backup is made first).
    #[arg(long)]
    dbtest: bool,
}
/// Parses the command-line arguments and executes the one-shot command they request.
///
/// Returns `true` when a command was executed, in which case the caller should exit
/// without starting the web server. Returns `false` when there is nothing to do.
///
/// With `--dbtest`, it performs these steps:
/// 1. If a database file exists, copies it to `<file stem>.sqlite.bckupNNN` (first
///    free number, starting at 001) and then removes it.
/// 2. Opens a new connection, which presumably creates a fresh database.
/// 3. Runs `sql/data_test.sql`.
/// 4. Sets the `validation_token_datetime` of the test user `paul@test.org` to now.
///
/// Any database error is logged and aborts the remaining steps. The function still
/// returns `true`, because a command was run.
///
/// # Panics
///
/// If the existing database file cannot be copied to its backup or cannot be removed.
async fn process_args() -> bool {
    let args = Args::parse();
    if !args.dbtest {
        return false;
    }

    // Make a backup of the database, then remove it so a new one is created below.
    // NOTE(review): only the main database file is copied/removed; if SQLite runs in WAL
    // mode, the '-wal'/'-shm' side files are left in place — confirm this is harmless.
    let db_path = Path::new(consts::DB_DIRECTORY).join(consts::DB_FILENAME);
    if db_path.exists() {
        let db_path_bckup = (1..)
            .map(|n| db_path.with_extension(format!("sqlite.bckup{:03}", n)))
            .find(|candidate| !candidate.exists())
            .expect("an unbounded range always yields a free backup file name");
        if let Err(error) = std::fs::copy(&db_path, &db_path_bckup) {
            panic!(
                "Unable to make backup of {:?} to {:?}: {}",
                db_path, db_path_bckup, error
            );
        }
        if let Err(error) = std::fs::remove_file(&db_path) {
            panic!("Unable to remove db file: {:?}: {}", db_path, error);
        }
    }

    let con = match db::Connection::new().await {
        Ok(con) => con,
        Err(error) => {
            event!(Level::ERROR, "{}", error);
            return true;
        }
    };

    // Stop here on failure: otherwise the success message below would be logged for a
    // database that was not filled.
    if let Err(error) = con.execute_file("sql/data_test.sql").await {
        event!(Level::ERROR, "{}", error);
        return true;
    }

    // Set the validation token datetime of the test user to 'now' (presumably so that
    // its validation token is not considered expired).
    if let Err(error) = con
        .execute_sql(
            sqlx::query(
                "UPDATE [User] SET [validation_token_datetime] = $1 WHERE [email] = 'paul@test.org'",
            )
            .bind(Utc::now()),
        )
        .await
    {
        event!(Level::ERROR, "{}", error);
        return true;
    }

    event!(
        Level::INFO,
        "A new test database has been created successfully"
    );
    true
}