Limit the number of current request for a certain amount of time for some endpoints.

This commit is contained in:
Greg Burri 2025-03-03 10:10:55 +01:00
parent b2572ebfe5
commit b40ee5f765
4 changed files with 34 additions and 12 deletions

1
Cargo.lock generated
View file

@ -2872,6 +2872,7 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-util",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing", "tracing",

View file

@ -10,7 +10,7 @@ common = { path = "../common" }
axum = { version = "0.8", features = ["macros"] } axum = { version = "0.8", features = ["macros"] }
axum-extra = { version = "0.10", features = ["cookie", "query"] } axum-extra = { version = "0.10", features = ["cookie", "query"] }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tower = { version = "0.5", features = ["util"] } tower = { version = "0.5", features = ["util", "limit", "buffer"] }
tower-http = { version = "0.6", features = ["fs", "trace"] } tower-http = { version = "0.6", features = ["fs", "trace"] }
tracing = "0.1" tracing = "0.1"

View file

@ -18,6 +18,9 @@ pub const TOKEN_SIZE: usize = 32;
pub const SEND_EMAIL_TIMEOUT: Duration = Duration::from_secs(60); pub const SEND_EMAIL_TIMEOUT: Duration = Duration::from_secs(60);
pub const NUMBER_OF_CONCURRENT_HTTP_REQUEST_FOR_RATE_LIMIT: u64 = 5;
pub const DURATION_FOR_RATE_LIMIT: Duration = Duration::from_secs(5);
// HTTP headers, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers. // HTTP headers, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers.
// Common headers can be found in 'axum::http::header' (which is a re-export of the create 'http'). // Common headers can be found in 'axum::http::header' (which is a re-export of the create 'http').
pub const REVERSE_PROXY_IP_HTTP_FIELD: &str = "x-real-ip"; // Set by the reverse proxy (Nginx). pub const REVERSE_PROXY_IP_HTTP_FIELD: &str = "x-real-ip"; // Set by the reverse proxy (Nginx).

View file

@ -1,7 +1,8 @@
use std::{net::SocketAddr, path::Path}; use std::{net::SocketAddr, path::Path};
use axum::{ use axum::{
Router, BoxError, Router,
error_handling::HandleErrorLayer,
extract::{ConnectInfo, Extension, FromRef, Request, State}, extract::{ConnectInfo, Extension, FromRef, Request, State},
http::StatusCode, http::StatusCode,
middleware::{self, Next}, middleware::{self, Next},
@ -13,6 +14,7 @@ use chrono::prelude::*;
use clap::Parser; use clap::Parser;
use config::Config; use config::Config;
use itertools::Itertools; use itertools::Itertools;
use tower::{ServiceBuilder, buffer::BufferLayer, limit::RateLimitLayer};
use tower_http::{ use tower_http::{
services::{ServeDir, ServeFile}, services::{ServeDir, ServeFile},
trace::TraceLayer, trace::TraceLayer,
@ -226,23 +228,38 @@ async fn main() {
get(services::fragments::recipes_list_fragments), get(services::fragments::recipes_list_fragments),
); );
let html_routes_with_rate_limit = Router::new()
.route("/signin", post(services::user::sign_in_post))
.route("/signup", post(services::user::sign_up_post))
.route(
"/ask_reset_password",
post(services::user::ask_reset_password_post),
)
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|err: BoxError| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled error: {}", err),
)
}))
.layer(BufferLayer::new(1024))
.layer(RateLimitLayer::new(
consts::NUMBER_OF_CONCURRENT_HTTP_REQUEST_FOR_RATE_LIMIT,
consts::DURATION_FOR_RATE_LIMIT,
)),
);
let html_routes = Router::new() let html_routes = Router::new()
.route("/", get(services::home_page)) .route("/", get(services::home_page))
.route( .route("/signup", get(services::user::sign_up_get))
"/signup",
get(services::user::sign_up_get).post(services::user::sign_up_post),
)
.route("/validation", get(services::user::sign_up_validation)) .route("/validation", get(services::user::sign_up_validation))
.route("/revalidation", get(services::user::email_revalidation)) .route("/revalidation", get(services::user::email_revalidation))
.route( .route("/signin", get(services::user::sign_in_get))
"/signin",
get(services::user::sign_in_get).post(services::user::sign_in_post),
)
.route("/signout", get(services::user::sign_out)) .route("/signout", get(services::user::sign_out))
.route( .route(
"/ask_reset_password", "/ask_reset_password",
get(services::user::ask_reset_password_get) get(services::user::ask_reset_password_get),
.post(services::user::ask_reset_password_post),
) )
.route( .route(
"/reset_password", "/reset_password",
@ -257,6 +274,7 @@ async fn main() {
"/user/edit", "/user/edit",
get(services::user::edit_user_get).post(services::user::edit_user_post), get(services::user::edit_user_get).post(services::user::edit_user_post),
) )
.merge(html_routes_with_rate_limit)
.nest("/fragments", fragments_routes) .nest("/fragments", fragments_routes)
.route_layer(middleware::from_fn(services::ron_error_to_html)); .route_layer(middleware::from_fn(services::ron_error_to_html));