use axum::{ extract::{Host, Json, OriginalUri}, http::{header::HeaderMap, Method}, Router, }; use clap::Parser; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::BTreeMap; use tower_http::trace; use tracing::{info, warn, Level}; #[derive(Parser)] #[command(author, version, about, long_about = None)] struct CliArgs { /// Port to run on. #[arg(short, long, default_value_t = 8080)] port: u16, /// Listen to IP mask #[arg(short, long, default_value = "0.0.0.0")] ips: String, } #[derive(Serialize, Deserialize, Debug, PartialEq)] struct EchoResponse { method: String, path: String, host: String, headers: BTreeMap, #[serde(skip_serializing_if = "Option::is_none")] body: Option, } async fn echo_request( method: Method, original_uri: OriginalUri, host: Host, header_map: HeaderMap, body: Option>, ) -> Json { let method = method.to_string(); let host = host.0; let path = original_uri.path().to_string(); let headers = header_map .iter() .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string())) .collect(); let body = match body { None => { warn!("Received a non-JSON body."); None } Some(Json(value)) => { info!("JSON request: {}", value.to_string()); Some(value) } }; let response = EchoResponse { method, host, path, headers, body, }; Json(response) } fn app() -> Router { Router::new().fallback(echo_request).layer( trace::TraceLayer::new_for_http() .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO)) .on_response(trace::DefaultOnResponse::new().level(Level::INFO)), ) } #[tokio::main] async fn main() { let cli_args = CliArgs::parse(); let listen_on = (cli_args.ips, cli_args.port); let listen_on = format!("{ip}:{port}", ip = listen_on.0, port = listen_on.1); // From https://stackoverflow.com/questions/75009289/how-to-enable-logging-tracing-with-axum tracing_subscriber::fmt().with_max_level(Level::INFO).init(); info!("Starting the mirror-server to listen to {}", listen_on); let listener = tokio::net::TcpListener::bind(&listen_on) .await .unwrap_or_else(|_| panic!("Failed to binding to {}", listen_on)); axum::serve(listener, app()) .await .expect("Server should start"); } #[cfg(test)] mod tests { use super::*; use axum::http::{HeaderName, HeaderValue, Method, StatusCode}; use axum_test::TestServer; use rstest::rstest; #[tokio::test] async fn handles_simple_get_request() { let expected_headers: BTreeMap = BTreeMap::new(); let expected_response = EchoResponse { method: "GET".to_string(), path: "/".to_string(), host: "localhost".to_string(), body: None, headers: expected_headers, }; let server = TestServer::new(app()).unwrap(); let response = server.get("/").await; response.assert_status(StatusCode::OK); response.assert_json::(&expected_response); } #[rstest] #[case::single_level("/test")] #[case::multiple_level("/test/multiple/levels/")] #[case::slash_ending("/test/")] #[tokio::test] async fn handles_different_urls(#[case] url: String) { let expected_headers: BTreeMap = BTreeMap::new(); let expected_response = EchoResponse { method: "GET".to_string(), path: url.clone(), host: "localhost".to_string(), body: None, headers: expected_headers, }; let server = TestServer::new(app()).unwrap(); let response = server.get(&url).await; response.assert_status(StatusCode::OK); response.assert_json::(&expected_response); } #[rstest] #[case::post(Method::POST)] #[case::put(Method::PUT)] #[case::patch(Method::PATCH)] #[case::delete(Method::DELETE)] #[tokio::test] async fn handles_different_http_methods(#[case] http_method: Method) { let expected_headers: BTreeMap = BTreeMap::new(); let expected_response = EchoResponse { method: http_method.to_string(), path: "/testing".to_string(), host: "localhost".to_string(), body: None, headers: expected_headers, }; let server = TestServer::new(app()).unwrap(); let response = server.method(http_method, "/testing").await; response.assert_status(StatusCode::OK); response.assert_json::(&expected_response); } #[tokio::test] async fn handle_non_json_request() { let mut expected_headers: BTreeMap = BTreeMap::new(); expected_headers.insert("content-type".to_string(), "text/plain".to_string()); let expected_response = EchoResponse { method: "POST".to_string(), path: "/".to_string(), host: "localhost".to_string(), body: None, headers: expected_headers, }; let server = TestServer::new(app()).unwrap(); let response = server.post("/").text("hello world").await; response.assert_status(StatusCode::OK); response.assert_json::(&expected_response); } #[tokio::test] async fn handle_json_request() { let test_json = serde_json::json!({"hello": "world"}); let mut expected_headers: BTreeMap = BTreeMap::new(); expected_headers.insert("content-type".to_string(), "application/json".to_string()); let expected_response = EchoResponse { method: "POST".to_string(), path: "/".to_string(), host: "localhost".to_string(), body: Some(test_json.clone()), headers: expected_headers, }; let server = TestServer::new(app()).unwrap(); let response = server.post("/").json(&test_json).await; response.assert_status(StatusCode::OK); response.assert_json::(&expected_response); } #[tokio::test] async fn handle_extra_headers() { let mut expected_headers: BTreeMap = BTreeMap::new(); expected_headers.insert("x-test-message".to_string(), "Howdy!".to_string()); let expected_response = EchoResponse { method: "GET".to_string(), path: "/".to_string(), host: "localhost".to_string(), body: None, headers: expected_headers, }; let server = TestServer::new(app()).unwrap(); let response = server .get("/") .add_header( HeaderName::from_static("x-test-message"), HeaderValue::from_static("Howdy!"), ) .await; response.assert_status(StatusCode::OK); response.assert_json::(&expected_response); } }