From 22ef371c5b3ae60d5554db9f9f14526d4ee4ef8c Mon Sep 17 00:00:00 2001 From: geoffsee <> Date: Thu, 5 Jun 2025 14:06:00 -0400 Subject: [PATCH] add basic router tests --- Cargo.lock | 1 + Cargo.toml | 3 +- src/routes.rs | 80 ++++++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 73 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ddf4aac..2ca7fac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -746,6 +746,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", + "tower", "tower-http", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index 803ef9f..d31dbcb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,8 @@ bytes = "1.8.0" lazy_static = "1.5.0" sled = "0.34.7" tower-http = { version = "0.6.2", features = ["trace", "cors"] } +tower = "0.5.2" anyhow = "1.0.97" base64 = "0.22.1" fips204 = "0.4.6" -rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = ["server", "transport-streamable-http-server", "transport-sse-server", "transport-io",] } \ No newline at end of file +rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = ["server", "transport-streamable-http-server", "transport-sse-server", "transport-io",] } diff --git a/src/routes.rs b/src/routes.rs index a63aecb..300ceb2 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -6,26 +6,18 @@ use tracing::Level; use rmcp::transport::streamable_http_server::{ StreamableHttpService, session::local::LocalSessionManager, }; -use crate::counter::Counter; use crate::agents::Agents; pub fn create_router() -> Router { - let counter_service = StreamableHttpService::new( - Counter::new, - LocalSessionManager::default().into(), - Default::default(), - ); - - let agents_service = StreamableHttpService::new( + let mcp_service = StreamableHttpService::new( Agents::new, LocalSessionManager::default().into(), Default::default(), ); Router::new() - .nest_service("/mcp/counter", counter_service) - .nest_service("/mcp/agents", agents_service) + .nest_service("/mcp", mcp_service) .route("/", get(serve_ui)) .route("/health", get(health)) .layer( @@ -40,3 +32,71 @@ pub fn create_router() -> Router { async fn health() -> String { return "ok".to_string(); } + +#[cfg(test)] +mod tests { + use super::*; + use axum::body::{Body, Bytes}; + use axum::http::{Request, StatusCode}; + use axum::response::Response; + use tower::ServiceExt; + + #[tokio::test] + async fn test_health_endpoint() { + // Call the health function directly + let response = health().await; + assert_eq!(response, "ok".to_string()); + } + + #[tokio::test] + async fn test_health_route() { + // Create the router + let app = create_router(); + + // Create a request to the health endpoint + let request = Request::builder() + .uri("/health") + .method("GET") + .body(Body::empty()) + .unwrap(); + + // Process the request + let response = app.oneshot(request).await.unwrap(); + + // Check the response status + assert_eq!(response.status(), StatusCode::OK); + + // Check the response body + let body = response_body_bytes(response).await; + assert_eq!(&body[..], b"ok"); + } + + #[tokio::test] + async fn test_not_found_route() { + // Create the router + let app = create_router(); + + // Create a request to a non-existent endpoint + let request = Request::builder() + .uri("/non-existent") + .method("GET") + .body(Body::empty()) + .unwrap(); + + // Process the request + let response = app.oneshot(request).await.unwrap(); + + // Check the response status + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } + + // Helper function to extract bytes from a response body + async fn response_body_bytes(response: Response) -> Bytes { + let body = response.into_body(); + // Use a reasonable size limit for the body (16MB) + let bytes = axum::body::to_bytes(body, 16 * 1024 * 1024) + .await + .expect("Failed to read response body"); + bytes + } +}