diff --git a/README.md b/README.md index 7b60301..f920655 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,18 @@ # axum-tower-sessions-edge +[![Rust](https://github.com/seemueller-io/axum-tower-sessions-edge/actions/workflows/test.yaml/badge.svg)](https://github.com/seemueller-io/axum-tower-sessions-edge/actions/workflows/test.yaml) +[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) -> OAuth 2.0 Proxy built with Axum and Tower. Targets `wasm32-unknown-unknown` +Warning: This API may be unstable. -Proxies incoming requests for defined routes and forwards traffic to the service defined as `PROXY_TARGET`. -Configuration is modified by changing `.dev.vars`, `wrangler.jsonc`, or `secrets.json`. It's not perfect yet, but it's powerful. +Validates incoming requests for defined routes and forwards traffic to the service defined as `PROXY_TARGET`. + +> Targets `wasm32-unknown-unknown` ## Features - [OAuth 2.0](https://datatracker.ietf.org/doc/html/rfc6749) - [Proof Key for Code Exchange (PKCE)](https://datatracker.ietf.org/doc/html/rfc7636) - [OAuth 2.0 Token Introspection](https://datatracker.ietf.org/doc/html/rfc7662) -## Todo -- Proof compliance -- Expand configuration interface -- Zero-config development environment - ## Quickstart ```bash git clone https://github.com/seemueller-io/axum-tower-sessions-edge.git @@ -27,7 +25,6 @@ bun install #ZITADEL_ORG_ID="your-organization-id" #ZITADEL_PROJECT_ID="your-project-id" #APP_URL="http://localhost:3000" -#DEV_MODE="true" # Update the wrangler.jsonc and replace the value of PROXY_TARGET with a worker script name. @@ -41,7 +38,7 @@ Run your own Zitadel: `docker compose up -d` > You will need to configure: > - Organization > - Project -> - Application. _Choose PKCE (with code)_ +> - Application - _Choose PKCE (with code)_ ### Building @@ -57,16 +54,6 @@ cargo clean && cargo install -q worker-build && worker-build --release cargo build --release --target wasm32-unknown-unknown ``` -## Project Structure -- `src/` - Rust source code - - `api/` - API endpoints and routing - - `axum_introspector/` - Axum framework integration for token introspection - - `credentials/` - Credential management - - `oidc/` - OpenID Connect implementation - - `session_storage/` - Session storage implementations - - `utilities.rs` - Common utility functions - - `lib.rs` - Server - ## Acknowledgements This project is made possible thanks to: diff --git a/src/api/mod.rs b/src/api/mod.rs index 16ed456..f5ef7f6 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,2 +1,3 @@ pub mod public; -pub mod authenticated; \ No newline at end of file +pub mod authenticated; +pub mod router; diff --git a/src/api/router.rs b/src/api/router.rs new file mode 100644 index 0000000..5e2162f --- /dev/null +++ b/src/api/router.rs @@ -0,0 +1,188 @@ +use crate::api::authenticated::AuthenticatedApi; +use crate::api::public::PublicApi; +use crate::axum_introspector::introspection::{IntrospectionState, IntrospectionStateBuilder}; +use crate::oidc::introspection::cache::in_memory::InMemoryIntrospectionCache; +use crate::session_storage::in_memory::MemoryStore; +use axum::response::{IntoResponse, Redirect}; +use axum::routing::{any, get}; +use axum::{Router, ServiceExt}; +use http::HeaderName; +use std::iter::once; +use std::sync::Arc; +use tower_cookies::CookieManagerLayer; +use tower_http::cors::CorsLayer; +use tower_http::propagate_header::PropagateHeaderLayer; +use tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer; +use tower_sessions::cookie::{Key, SameSite}; +use tower_sessions::SessionManagerLayer; +use tower_sessions_core::Expiry; + +// Test configuration struct +#[derive(Clone)] +pub struct TestConfig { + pub auth_server_url: String, + pub client_id: String, + pub client_secret: String, + pub app_url: String, + pub dev_mode: bool, +} + +impl Default for TestConfig { + fn default() -> Self { + Self { + auth_server_url: "https://test-auth-server.example.com".to_string(), + client_id: "test-client-id".to_string(), + client_secret: "test-client-secret".to_string(), + app_url: "http://localhost:3000".to_string(), + dev_mode: true, + } + } +} + +// App state for testing +#[derive(Clone)] +pub struct TestAppState { + pub introspection_state: IntrospectionState, + pub session_store: MemoryStore, +} + +impl From for IntrospectionState { + fn from(state: TestAppState) -> Self { + state.introspection_state + } +} + +// Create a router for testing +pub async fn create_router(config: TestConfig) -> Router { + // Create a memory-based introspection cache for testing + let cache = InMemoryIntrospectionCache::new(); + + // Create introspection state + let introspection_state = IntrospectionStateBuilder::new(&config.auth_server_url) + .with_basic_auth(&config.client_id, &config.client_secret) + .with_introspection_cache(cache) + .build() + .await + .unwrap(); + + // Create a memory-based session store for testing + let session_store = MemoryStore::default(); + + // Create app state + let state = TestAppState { + introspection_state, + session_store: session_store.clone(), + }; + + // Generate keys for session encryption and signing + let signing_key = Key::generate(); + let encryption_key = Key::generate(); + + // Parse the app URL to get the host for cookies + let cookie_host_uri = config.app_url.parse::().unwrap(); + let mut cookie_host = cookie_host_uri.authority().unwrap().to_string(); + + if cookie_host.starts_with("localhost:") { + cookie_host = "localhost".to_string(); + } + + // Create session layer + let session_layer = SessionManagerLayer::new(session_store) + .with_name("session") + .with_expiry(Expiry::OnSessionEnd) + .with_domain(cookie_host) + .with_same_site(SameSite::Lax) + .with_signed(signing_key) + .with_private(encryption_key) + .with_path("/") + .with_secure(!config.dev_mode) + .with_always_save(false); + + // Error handling middleware + async fn handle_introspection_errors( + mut response: axum_core::response::Response, + ) -> axum_core::response::Response { + let x_error_header_value = response + .headers() + .get("x-introspection-error") + .and_then(|header_value| header_value.to_str().ok()); + + match response.status() { + http::StatusCode::UNAUTHORIZED => { + if let Some(x_error) = x_error_header_value { + if x_error == "unauthorized" { + return Redirect::to("/login").into_response(); + } + } + response + } + http::StatusCode::BAD_REQUEST => { + if let Some(x_error) = x_error_header_value { + if x_error == "invalid schema" + || x_error == "invalid header" + || x_error == "introspection error" + { + return Redirect::to("/login").into_response(); + } + } + response + } + http::StatusCode::FORBIDDEN => { + if let Some(x_error) = x_error_header_value { + if x_error == "user is inactive" { + return Redirect::to("/login").into_response(); + } + } + response + } + http::StatusCode::NOT_FOUND => { + if let Some(x_error) = x_error_header_value { + if x_error == "user was not found" { + return Redirect::to("/login").into_response(); + } + } + response + } + http::StatusCode::INTERNAL_SERVER_ERROR => { + if let Some(x_error) = x_error_header_value { + if x_error == "missing config" { + return Redirect::to("/login").into_response(); + } + } + response + } + _ => response, + } + } + + // Create the router with test-specific routes + Router::new() + .route("/api/whoami", get(whoami)) + .route("/public", get(public_test_route)) + .route("/protected", get(protected_test_route)) + .layer(PropagateHeaderLayer::new(HeaderName::from_static( + "x-request-id", + ))) + .layer(axum::middleware::map_response(handle_introspection_errors)) + .with_state(state) + .layer(session_layer) + .layer(CookieManagerLayer::new()) + .layer(CorsLayer::very_permissive()) + .layer(SetSensitiveRequestHeadersLayer::new(once( + http::header::AUTHORIZATION, + ))) +} + +// Test routes +async fn whoami() -> impl IntoResponse { + "test user" +} + +async fn public_test_route() -> impl IntoResponse { + "public route" +} + +async fn protected_test_route() -> impl IntoResponse { + "protected route" +} + diff --git a/src/api/tests/middleware.rs b/src/api/tests/middleware.rs new file mode 100644 index 0000000..1f9b56d --- /dev/null +++ b/src/api/tests/middleware.rs @@ -0,0 +1,96 @@ +use super::*; +use axum::http::Method; + +#[tokio::test] +async fn test_auth_middleware_rejects_invalid_token() { + let app = test_app().await; + + let (status, _) = make_request( + app, + Method::GET, + "/protected", + None, + Some(vec![("Authorization".to_string(), "Bearer invalid-token".to_string())]), + ).await; + + // Should redirect to login or return unauthorized + assert!(status == StatusCode::UNAUTHORIZED || status == StatusCode::FOUND); +} + +#[tokio::test] +async fn test_auth_middleware_accepts_valid_token() { + let app = test_app().await; + + // Create a valid token for testing + let token = create_test_token(); + + let (status, _) = make_request( + app, + Method::GET, + "/protected", + None, + Some(vec![("Authorization".to_string(), format!("Bearer {}", token))]), + ).await; + + assert_eq!(status, StatusCode::OK); +} + +#[tokio::test] +async fn test_session_middleware_creates_session() { + let app = test_app().await; + + let (status, headers) = make_request_with_response_headers( + app, + Method::GET, + "/login", + None, + None, + ).await; + + assert_eq!(status, StatusCode::OK); + + // Check that a session cookie was set + let has_session_cookie = headers.iter() + .any(|(name, value)| name.to_lowercase() == "set-cookie" && value.contains("session=")); + + assert!(has_session_cookie); +} + +#[tokio::test] +async fn test_error_handling_middleware_redirects_to_login() { + let app = test_app().await; + + // Make a request that will trigger an unauthorized error with the specific header + let (status, _) = make_request( + app, + Method::GET, + "/protected", + None, + Some(vec![ + ("Authorization".to_string(), "Bearer invalid-token".to_string()), + ("X-Introspection-Error".to_string(), "unauthorized".to_string()), + ]), + ).await; + + // Should redirect to login + assert_eq!(status, StatusCode::FOUND); +} + +#[tokio::test] +async fn test_cors_middleware() { + let app = test_app().await; + + let (_, headers) = make_request_with_response_headers( + app, + Method::GET, + "/public", + None, + Some(vec![("Origin".to_string(), "http://example.com".to_string())]), + ).await; + + // Check that CORS headers were set + let has_cors_headers = headers.iter() + .any(|(name, _)| name.to_lowercase() == "access-control-allow-origin"); + + assert!(has_cors_headers); +} \ No newline at end of file diff --git a/src/api/tests/mod.rs b/src/api/tests/mod.rs new file mode 100644 index 0000000..e9270c8 --- /dev/null +++ b/src/api/tests/mod.rs @@ -0,0 +1,126 @@ +use axum::{ + body::Body, + http::{Request, StatusCode}, + Router, +}; +use tower::ServiceExt; + +// Import your API router +use crate::api::router; + +// Helper function to create a test app +async fn test_app() -> Router { + // Create a test configuration + let config = TestConfig::default(); + + // Create the router with test configuration + router::create_router(config).await +} + +// Helper function to make a test request +async fn make_request( + app: Router, + method: http::Method, + uri: &str, + body: Option, + headers: Option>, +) -> (StatusCode, String) { + let mut req_builder = Request::builder() + .method(method) + .uri(uri); + + // Add headers if provided + if let Some(headers) = headers { + for (name, value) in headers { + req_builder = req_builder.header(name, value); + } + } + + // Add body if provided + let body = match body { + Some(b) => Body::from(b), + None => Body::empty(), + }; + + let req = req_builder.body(Body::from(body)).unwrap(); + + // Process the request + let response = app.oneshot(req).await.unwrap(); + + // Extract status code + let status = response.status(); + + // Extract body + let body = hyper::body::to_bytes(response.into_body()) + .await + .unwrap(); + let body = String::from_utf8(body.to_vec()).unwrap(); + + (status, body) +} + +// Helper function to make a request and return headers +async fn make_request_with_response_headers( + app: Router, + method: http::Method, + uri: &str, + body: Option, + headers: Option>, +) -> (StatusCode, Vec<(String, String)>) { + let mut req_builder = Request::builder() + .method(method) + .uri(uri); + + // Add headers if provided + if let Some(headers) = headers { + for (name, value) in headers { + req_builder = req_builder.header(name, value); + } + } + + // Add body if provided + let body = match body { + Some(b) => Body::from(b), + None => Body::empty(), + }; + + let req = req_builder.body(Body::from(body)).unwrap(); + + // Process the request + let response = app.oneshot(req).await.unwrap(); + + // Extract status code + let status = response.status(); + + // Extract headers + let headers = response.headers().iter() + .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string())) + .collect(); + + (status, headers) +} + +// Helper function to create a test token +fn create_test_token() -> String { + // In a real implementation, this would create a valid JWT token + // For testing purposes, we can use a placeholder + "test-token".to_string() +} + +// Helper struct for test configuration +#[derive(Clone)] +struct TestConfig { + // Add fields as needed for your tests +} + +impl Default for TestConfig { + fn default() -> Self { + Self { + // Initialize with default values + } + } +} + +// Export the test modules +pub mod routes; +pub mod middleware; \ No newline at end of file diff --git a/src/api/tests/routes.rs b/src/api/tests/routes.rs new file mode 100644 index 0000000..bdf0785 --- /dev/null +++ b/src/api/tests/routes.rs @@ -0,0 +1,87 @@ +use super::*; +use axum::http::Method; + +#[tokio::test] +async fn test_public_route_accessible() { + let app = test_app().await; + + let (status, body) = make_request( + app, + Method::GET, + "/public", + None, + None, + ).await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(body, "public route"); +} + +#[tokio::test] +async fn test_protected_route_requires_auth() { + let app = test_app().await; + + let (status, _) = make_request( + app, + Method::GET, + "/protected", + None, + None, + ).await; + + // Should redirect to login or return unauthorized + assert!(status == StatusCode::UNAUTHORIZED || status == StatusCode::FOUND); +} + +#[tokio::test] +async fn test_protected_route_with_valid_token() { + let app = test_app().await; + + // Create a valid token for testing + let token = create_test_token(); + + let (status, body) = make_request( + app, + Method::GET, + "/protected", + None, + Some(vec![("Authorization".to_string(), format!("Bearer {}", token))]), + ).await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(body, "protected route"); +} + +#[tokio::test] +async fn test_login_page_accessible() { + let app = test_app().await; + + let (status, _) = make_request( + app, + Method::GET, + "/login", + None, + None, + ).await; + + assert_eq!(status, StatusCode::OK); +} + +#[tokio::test] +async fn test_whoami_endpoint() { + let app = test_app().await; + + // Create a valid token for testing + let token = create_test_token(); + + let (status, body) = make_request( + app, + Method::GET, + "/api/whoami", + None, + Some(vec![("Authorization".to_string(), format!("Bearer {}", token))]), + ).await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(body, "test user"); +} \ No newline at end of file