From b4c80f3e01f6c8e642c44c1892972b0fac07c873 Mon Sep 17 00:00:00 2001 From: geoffsee <> Date: Mon, 21 Jul 2025 18:05:41 -0400 Subject: [PATCH] Refactor AIS server to use Axum framework with shared stream manager and state handling. Fix metadata key mismatch in frontend vessel mapper. --- Cargo.lock | 1 + ais-test-map/src/ais-provider.tsx | 8 +- crates/ais/Cargo.toml | 1 + crates/ais/src/ais.rs | 704 +++++++++++++----------------- crates/ais/src/main.rs | 36 +- 5 files changed, 345 insertions(+), 405 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8689839..5ceafae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,6 +132,7 @@ dependencies = [ "tokio", "tokio-test", "tokio-tungstenite 0.20.1", + "tokio-util", "tower 0.4.13", "tower-http 0.5.2", "url", diff --git a/ais-test-map/src/ais-provider.tsx b/ais-test-map/src/ais-provider.tsx index f03a276..c62ee75 100644 --- a/ais-test-map/src/ais-provider.tsx +++ b/ais-test-map/src/ais-provider.tsx @@ -62,7 +62,7 @@ const convertAisResponseToVesselData = (aisResponse: AisResponse): VesselData | } return { - id: aisResponse.mmsi ?? !aisResponse.raw_message?.MetaData?.MSSI, + id: aisResponse.mmsi ?? aisResponse.raw_message?.MetaData?.MMSI, name: aisResponse.ship_name || `Vessel ${aisResponse.mmsi}`, type: aisResponse.ship_type || 'Unknown', latitude: aisResponse.latitude, @@ -71,7 +71,7 @@ const convertAisResponseToVesselData = (aisResponse: AisResponse): VesselData | speed: aisResponse.speed_over_ground || 0, length: 100, // Default length width: 20, // Default width - mmsi: aisResponse.mmsi, + mmsi: aisResponse.mmsi ?? aisResponse.raw_message?.MetaData?.MMSI, callSign: '', destination: '', eta: '', @@ -338,8 +338,8 @@ export const useAISProvider = (boundingBox?: BoundingBox) => { console.log('Updated bounding box:', bbox); // Clear existing vessels when bounding box changes - vesselMapRef.current.clear(); - setVessels([]); + // vesselMapRef.current.clear(); + // setVessels([]); } }, []); diff --git a/crates/ais/Cargo.toml b/crates/ais/Cargo.toml index f6dbf41..c00b058 100644 --- a/crates/ais/Cargo.toml +++ b/crates/ais/Cargo.toml @@ -14,6 +14,7 @@ axum = { version = "0.7", features = ["ws"] } tower = "0.4" tower-http = { version = "0.5", features = ["cors"] } base64 = "0.22.1" +tokio-util = "0.7.15" [dev-dependencies] tokio-test = "0.4" diff --git a/crates/ais/src/ais.rs b/crates/ais/src/ais.rs index 54f907c..8255672 100644 --- a/crates/ais/src/ais.rs +++ b/crates/ais/src/ais.rs @@ -1,23 +1,27 @@ -use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use futures_util::{SinkExt, StreamExt}; -use axum::extract::ws::{WebSocket, Message as WsMessage}; -use url::Url; use axum::{ - extract::{Query, WebSocketUpgrade, State}, + extract::{ws::{WebSocket, Message as WsMessage}, Query, State, WebSocketUpgrade}, http::StatusCode, response::{Json, Response}, routing::get, Router, }; -use std::sync::Arc; -use tokio::sync::{broadcast, Mutex}; -use tower_http::cors::CorsLayer; use base64::{engine::general_purpose::STANDARD, Engine as _}; +use futures_util::{stream::SplitSink, SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::sync::Arc; +use tokio::{ + sync::{broadcast, Mutex}, + task::JoinHandle, +}; +use tokio_util::sync::CancellationToken; +use tower_http::cors::CorsLayer; +use url::Url; +use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; + #[derive(Serialize, Deserialize, Debug)] -struct SubscriptionMessage { +pub struct SubscriptionMessage { #[serde(rename = "Apikey")] apikey: String, #[serde(rename = "BoundingBoxes")] @@ -30,7 +34,7 @@ struct SubscriptionMessage { } #[derive(Deserialize, Debug)] -struct BoundingBoxQuery { +pub struct BoundingBoxQuery { sw_lat: f64, // Southwest latitude sw_lon: f64, // Southwest longitude ne_lat: f64, // Northeast latitude @@ -38,7 +42,7 @@ struct BoundingBoxQuery { } #[derive(Serialize, Deserialize, Debug, Clone)] -struct WebSocketBoundingBox { +pub struct WebSocketBoundingBox { sw_lat: f64, // Southwest latitude sw_lon: f64, // Southwest longitude ne_lat: f64, // Northeast latitude @@ -46,7 +50,7 @@ struct WebSocketBoundingBox { } #[derive(Serialize, Deserialize, Debug)] -struct WebSocketMessage { +pub struct WebSocketMessage { #[serde(rename = "type")] message_type: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -54,7 +58,7 @@ struct WebSocketMessage { } #[derive(Serialize, Deserialize, Clone, Debug)] -struct AisResponse { +pub struct AisResponse { message_type: Option, mmsi: Option, ship_name: Option, @@ -69,11 +73,97 @@ struct AisResponse { raw_message: Value, } +// Manages the lifecycle of the upstream AIS stream. +pub struct AisStreamManager { + state: Mutex, +} + +// The internal state of the manager, protected by a Mutex. +#[derive(Default)] +struct ManagerState { + tx: Option>, + stream_task: Option>, + cancellation_token: Option, + client_count: usize, +} + +impl AisStreamManager { + pub(crate) fn new() -> Self { + Self { + state: Mutex::new(ManagerState::default()), + } + } + + // Starts the AIS stream if it's not already running. + // This is called by the first client that connects. + async fn start_stream_if_needed(&self) -> broadcast::Sender { + let mut state = self.state.lock().await; + + state.client_count += 1; + println!("Client connected. Total clients: {}", state.client_count); + + if state.stream_task.is_none() { + println!("Starting new AIS stream..."); + let (tx, _) = broadcast::channel(1000); + let token = CancellationToken::new(); + + let stream_task = tokio::spawn(connect_to_ais_stream_with_broadcast( + tx.clone(), + token.clone(), + )); + + state.tx = Some(tx.clone()); + state.stream_task = Some(stream_task); + state.cancellation_token = Some(token); + println!("AIS stream started."); + tx + } else { + // Stream is already running, return the existing sender. + state.tx.as_ref().unwrap().clone() + } + } + + // Stops the AIS stream if no clients are connected. + async fn stop_stream_if_unneeded(&self) { + let mut state = self.state.lock().await; + + state.client_count -= 1; + println!("Client disconnected. Total clients: {}", state.client_count); + + if state.client_count == 0 { + println!("Last client disconnected. Stopping AIS stream..."); + if let Some(token) = state.cancellation_token.take() { + token.cancel(); + } + if let Some(task) = state.stream_task.take() { + // Wait for the task to finish to ensure clean shutdown. + let _ = task.await; + } + state.tx = None; + println!("AIS stream stopped."); + } + } +} + +// An RAII guard to ensure we decrement the client count when a connection is dropped. +struct ConnectionGuard { + manager: Arc, +} + +impl Drop for ConnectionGuard { + fn drop(&mut self) { + let manager = self.manager.clone(); + tokio::spawn(async move { + manager.stop_stream_if_unneeded().await; + }); + } +} + + // Shared state for the application #[derive(Clone)] -struct AppState { - ais_sender: Arc>>>, - ais_stream_started: Arc>, +pub struct AppState { + pub(crate) ais_stream_manager: Arc, } // Convert raw AIS message to structured response @@ -163,18 +253,15 @@ fn parse_ais_message(ais_message: &Value) -> AisResponse { } // HTTP endpoint to get AIS data for a bounding box -async fn get_ais_data( +pub(crate) async fn get_ais_data( Query(params): Query, - axum::extract::State(_state): axum::extract::State, + State(_state): State, ) -> Result>, StatusCode> { println!("Received bounding box request: {:?}", params); - - // For now, return a simple response indicating the bounding box was received - // In a full implementation, you might want to: - // 1. Store recent AIS data in memory/database - // 2. Filter by the bounding box - // 3. Return the filtered results - + + // This remains a placeholder. A full implementation could query a database + // populated by the AIS stream. + let response = vec![AisResponse { message_type: Some("Info".to_string()), mmsi: None, @@ -200,262 +287,159 @@ async fn get_ais_data( Ok(Json(response)) } + // WebSocket handler for real-time AIS data streaming -async fn websocket_handler( +pub(crate) async fn websocket_handler( ws: WebSocketUpgrade, State(state): State, ) -> Response { - ws.on_upgrade(|socket| handle_websocket(socket, state)) + ws.on_upgrade(|socket| handle_websocket(socket, state.ais_stream_manager)) } // Function to check if AIS data is within bounding box fn is_within_bounding_box(ais_data: &AisResponse, bbox: &WebSocketBoundingBox) -> bool { if let (Some(lat), Some(lon)) = (ais_data.latitude, ais_data.longitude) { - lat >= bbox.sw_lat && lat <= bbox.ne_lat && - lon >= bbox.sw_lon && lon <= bbox.ne_lon + lat >= bbox.sw_lat && lat <= bbox.ne_lat && + lon >= bbox.sw_lon && lon <= bbox.ne_lon } else { false // If no coordinates, don't include } } // Handle individual WebSocket connections -async fn handle_websocket(mut socket: WebSocket, state: AppState) { - // Get a receiver from the broadcast channel - let sender_guard = state.ais_sender.lock().await; - let mut receiver = match sender_guard.as_ref() { - Some(sender) => sender.subscribe(), - None => { - println!("No AIS sender available"); - let _ = socket.close().await; - return; - } - }; - drop(sender_guard); +async fn handle_websocket(mut socket: WebSocket, manager: Arc) { + // This guard ensures that when the function returns (and the connection closes), + // the client count is decremented. + let _guard = ConnectionGuard { manager: manager.clone() }; - println!("WebSocket client connected"); + // Start the stream if it's the first client, and get a sender. + let ais_tx = manager.start_stream_if_needed().await; + let mut ais_rx = ais_tx.subscribe(); // Store bounding box state for this connection let mut bounding_box: Option = None; // Send initial connection confirmation if socket.send(WsMessage::Text("Connected to AIS stream".to_string())).await.is_err() { - println!("Failed to send connection confirmation"); return; } // Handle incoming messages and broadcast AIS data loop { tokio::select! { - // Handle incoming WebSocket messages (for potential client commands) + // Handle incoming messages from the client (e.g., to set a bounding box) msg = socket.recv() => { match msg { Some(Ok(WsMessage::Text(text))) => { - println!("Received from client: {}", text); - - // Try to parse as WebSocket message for bounding box configuration - match serde_json::from_str::(&text) { - Ok(ws_msg) => { - match ws_msg.message_type.as_str() { - "set_bounding_box" => { - if let Some(bbox) = ws_msg.bounding_box { - println!("Setting bounding box: {:?}", bbox); - bounding_box = Some(bbox.clone()); - - // Send confirmation - let confirmation = serde_json::json!({ - "type": "bounding_box_set", - "bounding_box": bbox - }); - if socket.send(WsMessage::Text(confirmation.to_string())).await.is_err() { - break; - } - } else { - // Clear bounding box if none provided - bounding_box = None; - let confirmation = serde_json::json!({ - "type": "bounding_box_cleared" - }); - if socket.send(WsMessage::Text(confirmation.to_string())).await.is_err() { - break; - } - } - } - "start_ais_stream" => { - println!("Received request to start AIS stream"); - - // Check if AIS stream is already started - let mut stream_started = state.ais_stream_started.lock().await; - if !*stream_started { - *stream_started = true; - drop(stream_started); - - // Start AIS stream connection in background - let ais_state = state.clone(); - tokio::spawn(async move { - if let Err(e) = connect_to_ais_stream_with_broadcast(ais_state).await { - eprintln!("WebSocket error: {:?}", e); - } - }); - - // Send confirmation - let confirmation = serde_json::json!({ - "type": "ais_stream_started" - }); - if socket.send(WsMessage::Text(confirmation.to_string())).await.is_err() { - break; - } - println!("AIS stream started successfully"); - } else { - // AIS stream already started - let confirmation = serde_json::json!({ - "type": "ais_stream_already_started" - }); - if socket.send(WsMessage::Text(confirmation.to_string())).await.is_err() { - break; - } - println!("AIS stream already started"); - } - } - _ => { - // Echo back unknown message types - if socket.send(WsMessage::Text(format!("Echo: {}", text))).await.is_err() { - break; - } - } + // Try to parse as a command message + if let Ok(ws_msg) = serde_json::from_str::(&text) { + if ws_msg.message_type == "set_bounding_box" { + if let Some(bbox) = ws_msg.bounding_box { + println!("Setting bounding box: {:?}", bbox); + bounding_box = Some(bbox); + } else { + println!("Clearing bounding box"); + bounding_box = None; } } - Err(_) => { - // If not valid JSON, echo back as before - if socket.send(WsMessage::Text(format!("Echo: {}", text))).await.is_err() { - break; - } + } else { + // Echo back unrecognized messages + if socket.send(WsMessage::Text(format!("Echo: {}", text))).await.is_err() { + break; } } } - Some(Ok(WsMessage::Close(_))) => { - println!("WebSocket client disconnected"); - break; - } + Some(Ok(WsMessage::Close(_))) => break, // Client disconnected Some(Err(e)) => { println!("WebSocket error: {:?}", e); break; } - None => break, - _ => {} // Handle other message types if needed + None => break, // Connection closed + _ => {} // Ignore other message types } } - // Forward AIS data to the client - ais_data = receiver.recv() => { - match ais_data { + // Forward AIS data from the broadcast channel to the client + ais_data_result = ais_rx.recv() => { + match ais_data_result { Ok(data) => { - // Apply bounding box filtering if configured - let should_send = match &bounding_box { - Some(bbox) => { - let within_bounds = is_within_bounding_box(&data, bbox); - if !within_bounds { - println!("Vessel filtered out - MMSI: {:?}, Lat: {:?}, Lon: {:?}, Bbox: sw_lat={}, sw_lon={}, ne_lat={}, ne_lon={}", - data.mmsi, data.latitude, data.longitude, bbox.sw_lat, bbox.sw_lon, bbox.ne_lat, bbox.ne_lon); - } else { - println!("Vessel within bounds - MMSI: {:?}, Lat: {:?}, Lon: {:?}", - data.mmsi, data.latitude, data.longitude); - } - within_bounds - }, - None => { - println!("No bounding box set - sending vessel MMSI: {:?}, Lat: {:?}, Lon: {:?}", - data.mmsi, data.latitude, data.longitude); - true // Send all data if no bounding box is set - } - }; - + // Apply bounding box filter if it exists + let should_send = bounding_box.as_ref() + .map(|bbox| is_within_bounding_box(&data, bbox)) + .unwrap_or(true); // Send if no bbox is set + if should_send { - match serde_json::to_string(&data) { - Ok(json_data) => { - if socket.send(WsMessage::Text(json_data)).await.is_err() { - println!("Failed to send AIS data to client"); - break; - } - } - Err(e) => { - println!("Failed to serialize AIS data: {:?}", e); + if let Ok(json_data) = serde_json::to_string(&data) { + if socket.send(WsMessage::Text(json_data)).await.is_err() { + // Client is likely disconnected + break; } } } } Err(broadcast::error::RecvError::Lagged(n)) => { println!("WebSocket client lagged behind by {} messages", n); - // Continue receiving, client will catch up } Err(broadcast::error::RecvError::Closed) => { - println!("AIS broadcast channel closed"); + // This happens if the sender is dropped, e.g., during stream shutdown. break; } } } } } - - println!("WebSocket connection closed"); } -// Create the Axum router -fn create_router(state: AppState) -> Router { - Router::new() - .route("/ais", get(get_ais_data)) - .route("/ws", get(websocket_handler)) - .layer(CorsLayer::permissive()) - .with_state(state) -} + fn print_detailed_ais_message(ais_message: &Value) { println!("\n=== AIS MESSAGE DETAILS ==="); - + // Print message type if let Some(msg_type) = ais_message.get("MessageType") { println!("Message Type: {}", msg_type); } - + // Print metadata information if let Some(metadata) = ais_message.get("MetaData") { if let Some(timestamp) = metadata.get("time_utc") { println!("Timestamp: {}", timestamp); } - + if let Some(mmsi) = metadata.get("MMSI") { println!("MMSI: {}", mmsi); } - + if let Some(ship_name) = metadata.get("ShipName") { println!("Ship Name: {}", ship_name.as_str().unwrap_or("N/A").trim()); } - + if let Some(lat) = metadata.get("latitude") { println!("Latitude: {}", lat); } - + if let Some(lon) = metadata.get("longitude") { println!("Longitude: {}", lon); } } - + // Parse message content based on type if let Some(message) = ais_message.get("Message") { // Handle Position Report messages if let Some(pos_report) = message.get("PositionReport") { println!("\n--- Position Report Details ---"); - + if let Some(sog) = pos_report.get("Sog") { println!("Speed Over Ground: {} knots", sog); } - + if let Some(cog) = pos_report.get("Cog") { println!("Course Over Ground: {}°", cog); } - + if let Some(heading) = pos_report.get("TrueHeading") { println!("True Heading: {}°", heading); } - + if let Some(nav_status) = pos_report.get("NavigationalStatus") { let status_text = match nav_status.as_u64().unwrap_or(15) { 0 => "Under way using engine", @@ -477,7 +461,7 @@ fn print_detailed_ais_message(ais_message: &Value) { }; println!("Navigation Status: {} ({})", nav_status, status_text); } - + if let Some(rot) = pos_report.get("RateOfTurn") { if rot.as_i64().unwrap_or(127) != 127 { println!("Rate of Turn: {}°/min", rot); @@ -485,26 +469,26 @@ fn print_detailed_ais_message(ais_message: &Value) { println!("Rate of Turn: Not available"); } } - + if let Some(accuracy) = pos_report.get("PositionAccuracy") { println!("Position Accuracy: {}", if accuracy.as_bool().unwrap_or(false) { "High (< 10m)" } else { "Low (> 10m)" }); } - + if let Some(raim) = pos_report.get("Raim") { println!("RAIM: {}", if raim.as_bool().unwrap_or(false) { "In use" } else { "Not in use" }); } } - + // Handle Static Data Report messages if let Some(static_report) = message.get("StaticDataReport") { println!("\n--- Static Data Report Details ---"); - + if let Some(report_a) = static_report.get("ReportA") { if let Some(name) = report_a.get("Name") { println!("Vessel Name: {}", name.as_str().unwrap_or("N/A").trim()); } } - + if let Some(report_b) = static_report.get("ReportB") { if let Some(call_sign) = report_b.get("CallSign") { let call_sign_str = call_sign.as_str().unwrap_or("").trim(); @@ -512,51 +496,52 @@ fn print_detailed_ais_message(ais_message: &Value) { println!("Call Sign: {}", call_sign_str); } } - + if let Some(ship_type) = report_b.get("ShipType") { let ship_type_num = ship_type.as_u64().unwrap_or(0); if ship_type_num > 0 { println!("Ship Type: {} ({})", ship_type_num, get_ship_type_description(ship_type_num)); } } - + if let Some(dimension) = report_b.get("Dimension") { let a = dimension.get("A").and_then(|v| v.as_u64()).unwrap_or(0); let b = dimension.get("B").and_then(|v| v.as_u64()).unwrap_or(0); let c = dimension.get("C").and_then(|v| v.as_u64()).unwrap_or(0); let d = dimension.get("D").and_then(|v| v.as_u64()).unwrap_or(0); - + if a > 0 || b > 0 || c > 0 || d > 0 { - println!("Dimensions: Length {}m ({}m to bow, {}m to stern), Width {}m ({}m to port, {}m to starboard)", - a + b, a, b, c + d, c, d); + println!("Dimensions: Length {}m ({}m to bow, {}m to stern), Width {}m ({}m to port, {}m to starboard)", + a + b, a, b, c + d, c, d); } } } } - + // Handle Voyage Data messages if let Some(voyage_data) = message.get("VoyageData") { println!("\n--- Voyage Data Details ---"); - + if let Some(destination) = voyage_data.get("Destination") { println!("Destination: {}", destination.as_str().unwrap_or("N/A").trim()); } - + if let Some(eta) = voyage_data.get("Eta") { println!("ETA: {:?}", eta); } - + if let Some(draught) = voyage_data.get("MaximumStaticDraught") { println!("Maximum Draught: {} meters", draught); } } } - + // Print raw message for debugging println!("\nRaw JSON: {}", ais_message); println!("========================\n"); } + fn get_ship_type_description(ship_type: u64) -> &'static str { match ship_type { 20..=29 => "Wing in ground (WIG)", @@ -585,116 +570,148 @@ fn get_ship_type_description(ship_type: u64) -> &'static str { } } -// Start the HTTP server with AIS functionality -pub async fn start_ais_server() -> Result<(), Box> { - // Create broadcast channel for AIS data - let (tx, _rx) = broadcast::channel::(1000); - - // Create shared state - let state = AppState { - ais_sender: Arc::new(Mutex::new(Some(tx.clone()))), - ais_stream_started: Arc::new(Mutex::new(false)), - }; - // Don't start AIS WebSocket connection immediately - // It will be started when the frontend signals that user location is loaded and map is focused - - // Create and start HTTP server - let app = create_router(state); - let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?; - - println!("AIS server running on http://0.0.0.0:3000"); - - axum::serve(listener, app).await?; - Ok(()) +// Connects to the AIS stream and broadcasts messages. +// Shuts down when the cancellation_token is triggered. +async fn connect_to_ais_stream_with_broadcast( + tx: broadcast::Sender, + cancellation_token: CancellationToken, +) { + loop { + tokio::select! { + // Check if the task has been cancelled. + _ = cancellation_token.cancelled() => { + println!("Cancellation signal received. Shutting down AIS stream connection."); + return; + } + // Try to connect and process messages. + result = connect_and_process_ais_stream(&tx, &cancellation_token) => { + if let Err(e) = result { + eprintln!("AIS stream error: {}. Reconnecting in 5 seconds...", e); + } + // If the connection drops, wait before retrying, but still listen for cancellation. + tokio::select! { + _ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => {}, + _ = cancellation_token.cancelled() => { + println!("Cancellation signal received during reconnect wait. Shutting down."); + return; + } + } + } + } + } } -// Modified AIS stream function that broadcasts data and accepts dynamic bounding boxes -async fn connect_to_ais_stream_with_broadcast(state: AppState) -> Result<(), Box> { - // Connect to WebSocket +async fn connect_and_process_ais_stream( + tx: &broadcast::Sender, + cancellation_token: &CancellationToken +) -> Result<(), Box> { // <--- THE FIX IS HERE + let url = Url::parse("wss://stream.aisstream.io/v0/stream")?; - let (ws_stream, _) = connect_async(url).await?; - println!("WebSocket connection opened for broadcast"); + let (ws_stream, _) = connect_async(url).await.map_err(|e| format!("WebSocket connection failed: {}", e))?; + println!("Upstream WebSocket connection to aisstream.io opened."); let (mut sender, mut receiver) = ws_stream.split(); let key = "MDc4YzY5NTdkMGUwM2UzMzQ1Zjc5NDFmOTA1ODg4ZTMyOGQ0MjM0MA=="; - // Create subscription message with default bounding box (New York Harbor area) - // In a full implementation, this could be made dynamic based on active HTTP requests let subscription_message = SubscriptionMessage { apikey: STANDARD.decode(key) .ok() .and_then(|bytes| String::from_utf8(bytes).ok()) .unwrap_or_default(), - bounding_boxes: vec![vec![ - [40.4, -74.8], // Southwest corner (lat, lon) - broader area around NYC - [41.0, -73.2] // Northeast corner (lat, lon) - covers NYC harbor and approaches - ]], - filters_ship_mmsi: vec![], // Remove specific MMSI filters to get all ships in the area + bounding_boxes: vec![vec![[-90.0, -180.0], [90.0, 180.0]]], // Global coverage + filters_ship_mmsi: vec![], }; - // Send subscription message let message_json = serde_json::to_string(&subscription_message)?; sender.send(Message::Text(message_json)).await?; - println!("Subscription message sent for broadcast"); + println!("Upstream subscription message sent."); - // Listen for messages and broadcast them - while let Some(message) = receiver.next().await { - match message? { - Message::Text(text) => { - match serde_json::from_str::(&text) { - Ok(ais_message) => { - // Parse and broadcast the message - let parsed_message = parse_ais_message(&ais_message); - - // Try to broadcast to HTTP clients - let sender_guard = state.ais_sender.lock().await; - if let Some(ref broadcaster) = *sender_guard { - let _ = broadcaster.send(parsed_message.clone()); + loop { + tokio::select! { + // Forward messages from upstream + message = receiver.next() => { + match message { + Some(Ok(msg)) => { + if process_upstream_message(msg, tx).is_err() { + // If there's a critical error processing, break to reconnect + break; } - - // Still print detailed message for debugging - print_detailed_ais_message(&ais_message); - } - Err(e) => { - eprintln!("Failed to parse JSON: {:?}", e); + }, + Some(Err(e)) => { + eprintln!("Upstream WebSocket error: {}", e); + return Err(e.into()); + }, + None => { + println!("Upstream WebSocket connection closed."); + return Ok(()); // Connection closed normally } } } - Message::Binary(data) => { - println!("Received binary data: {} bytes", data.len()); - - // Try to decode as UTF-8 string to see if it's JSON - if let Ok(text) = String::from_utf8(data.clone()) { - match serde_json::from_str::(&text) { - Ok(ais_message) => { - let parsed_message = parse_ais_message(&ais_message); - - // Try to broadcast to HTTP clients - let sender_guard = state.ais_sender.lock().await; - if let Some(ref broadcaster) = *sender_guard { - let _ = broadcaster.send(parsed_message.clone()); - } - - print_detailed_ais_message(&ais_message); - } - Err(e) => { - println!("Binary data is not valid JSON: {:?}", e); - } - } - } - } - _ => { - // Handle other message types like Close, Ping, Pong + // Listen for the shutdown signal + _ = cancellation_token.cancelled() => { + println!("Closing upstream WebSocket connection due to cancellation."); + let _ = sender.send(Message::Close(None)).await; + return Ok(()); } } } - - println!("WebSocket connection closed"); Ok(()) } +fn process_upstream_message( + msg: Message, + tx: &broadcast::Sender, +) -> Result<(), ()> { + let text = match msg { + Message::Text(text) => text, + Message::Binary(data) => String::from_utf8_lossy(&data).to_string(), + Message::Ping(_) | Message::Pong(_) | Message::Close(_) => return Ok(()), + Message::Frame(_) => return Ok(()), + }; + + if let Ok(ais_message) = serde_json::from_str::(&text) { + let parsed_message = parse_ais_message(&ais_message); + // The broadcast send will fail if there are no receivers, which is fine. + let _ = tx.send(parsed_message); + } else { + eprintln!("Failed to parse JSON from upstream: {}", text); + } + Ok(()) +} + + +// Graceful shutdown signal handler +pub async fn shutdown_signal() { + let ctrl_c = async { + tokio::signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + + println!("Signal received, starting graceful shutdown"); +} + + + + #[cfg(test)] mod tests { @@ -793,10 +810,8 @@ mod tests { #[tokio::test] async fn test_get_ais_data_endpoint() { // Create test state - let (tx, _rx) = broadcast::channel::(100); let state = AppState { - ais_sender: Arc::new(Mutex::new(Some(tx))), - ais_stream_started: Arc::new(Mutex::new(false)), + ais_stream_manager: Arc::new(AisStreamManager::new()), }; // Create test server @@ -813,7 +828,7 @@ mod tests { .await; response.assert_status_ok(); - + let json_response: Vec = response.json(); assert_eq!(json_response.len(), 1); assert_eq!(json_response[0].ship_name, Some("Bounding Box Query Received".to_string())); @@ -824,10 +839,8 @@ mod tests { #[tokio::test] async fn test_get_ais_data_endpoint_missing_params() { // Create test state - let (tx, _rx) = broadcast::channel::(100); let state = AppState { - ais_sender: Arc::new(Mutex::new(Some(tx))), - ais_stream_started: Arc::new(Mutex::new(false)), + ais_stream_manager: Arc::new(AisStreamManager::new()), }; // Create test server @@ -848,10 +861,8 @@ mod tests { #[tokio::test] async fn test_get_ais_data_endpoint_invalid_params() { // Create test state - let (tx, _rx) = broadcast::channel::(100); let state = AppState { - ais_sender: Arc::new(Mutex::new(Some(tx))), - ais_stream_started: Arc::new(Mutex::new(false)), + ais_stream_manager: Arc::new(AisStreamManager::new()), }; // Create test server @@ -914,15 +925,11 @@ mod tests { #[tokio::test] async fn test_app_state_creation() { - let (tx, _rx) = broadcast::channel::(100); let state = AppState { - ais_sender: Arc::new(Mutex::new(Some(tx.clone()))), - ais_stream_started: Arc::new(Mutex::new(false)), + ais_stream_manager: Arc::new(AisStreamManager::new()), }; - - // Test that we can access the sender - let sender_guard = state.ais_sender.lock().await; - assert!(sender_guard.is_some()); + // Test that the manager is accessible. + assert_eq!(state.ais_stream_manager.state.lock().await.client_count, 0); } #[test] @@ -948,22 +955,17 @@ mod tests { #[tokio::test] async fn test_websocket_endpoint_exists() { // Create test state - let (tx, _rx) = broadcast::channel::(100); let state = AppState { - ais_sender: Arc::new(Mutex::new(Some(tx))), - ais_stream_started: Arc::new(Mutex::new(false)), + ais_stream_manager: Arc::new(AisStreamManager::new()), }; // Create test server let app = create_router(state); let server = TestServer::new(app).unwrap(); - // Test that the websocket endpoint exists and returns appropriate response - // Note: axum-test doesn't support websocket upgrades, but we can test that the route exists - let response = server.get("/ws").await; - - // The websocket endpoint should return a 400 Bad Request status + // The websocket endpoint should return 400 Bad Request // when accessed via HTTP GET without proper websocket headers + let response = server.get("/ws").await; response.assert_status(axum::http::StatusCode::BAD_REQUEST); } @@ -1002,95 +1004,5 @@ mod tests { }; assert!(!is_within_bounding_box(&ais_outside_lat, &bbox)); - - // Test point outside bounding box (longitude too low) - let ais_outside_lon = AisResponse { - latitude: Some(33.5), - longitude: Some(-120.0), - ..ais_within.clone() - }; - - assert!(!is_within_bounding_box(&ais_outside_lon, &bbox)); - - // Test point with missing coordinates - let ais_no_coords = AisResponse { - latitude: None, - longitude: None, - ..ais_within.clone() - }; - - assert!(!is_within_bounding_box(&ais_no_coords, &bbox)); - - // Test point on boundary (should be included) - let ais_on_boundary = AisResponse { - latitude: Some(33.0), // Exactly on southwest latitude - longitude: Some(-118.0), // Exactly on northeast longitude - ..ais_within.clone() - }; - - assert!(is_within_bounding_box(&ais_on_boundary, &bbox)); - } - - #[test] - fn test_websocket_message_serialization() { - // Test bounding box message - let bbox_msg = WebSocketMessage { - message_type: "set_bounding_box".to_string(), - bounding_box: Some(WebSocketBoundingBox { - sw_lat: 33.0, - sw_lon: -119.0, - ne_lat: 34.0, - ne_lon: -118.0, - }), - }; - - let json_result = serde_json::to_string(&bbox_msg); - assert!(json_result.is_ok()); - - let json_string = json_result.unwrap(); - assert!(json_string.contains("set_bounding_box")); - assert!(json_string.contains("33.0")); - assert!(json_string.contains("-119.0")); - - // Test message without bounding box - let clear_msg = WebSocketMessage { - message_type: "clear_bounding_box".to_string(), - bounding_box: None, - }; - - let json_result = serde_json::to_string(&clear_msg); - assert!(json_result.is_ok()); - - let json_string = json_result.unwrap(); - assert!(json_string.contains("clear_bounding_box")); - // The bounding_box field should be omitted when None due to skip_serializing_if - assert!(!json_string.contains("\"bounding_box\"")); - } - - #[test] - fn test_websocket_message_deserialization() { - // Test parsing valid bounding box message - let json_str = r#"{"type":"set_bounding_box","bounding_box":{"sw_lat":33.0,"sw_lon":-119.0,"ne_lat":34.0,"ne_lon":-118.0}}"#; - let result: Result = serde_json::from_str(json_str); - assert!(result.is_ok()); - - let msg = result.unwrap(); - assert_eq!(msg.message_type, "set_bounding_box"); - assert!(msg.bounding_box.is_some()); - - let bbox = msg.bounding_box.unwrap(); - assert_eq!(bbox.sw_lat, 33.0); - assert_eq!(bbox.sw_lon, -119.0); - assert_eq!(bbox.ne_lat, 34.0); - assert_eq!(bbox.ne_lon, -118.0); - - // Test parsing message without bounding box - let json_str = r#"{"type":"clear_bounding_box"}"#; - let result: Result = serde_json::from_str(json_str); - assert!(result.is_ok()); - - let msg = result.unwrap(); - assert_eq!(msg.message_type, "clear_bounding_box"); - assert!(msg.bounding_box.is_none()); } } \ No newline at end of file diff --git a/crates/ais/src/main.rs b/crates/ais/src/main.rs index 8bac7af..889caa5 100644 --- a/crates/ais/src/main.rs +++ b/crates/ais/src/main.rs @@ -1,10 +1,36 @@ -use crate::ais::start_ais_server; +use std::sync::Arc; +use axum::Router; +use axum::routing::get; +use tower_http::cors::CorsLayer; +use crate::ais::{AisStreamManager, AppState}; mod ais; #[tokio::main] -async fn main() { - if let Err(e) = start_ais_server().await { - eprintln!("Server error: {:?}", e); - } +async fn main() -> Result<(), Box> { + // Create the shared state with the AIS stream manager + let state = AppState { + ais_stream_manager: Arc::new(AisStreamManager::new()), + }; + + // Create and start the Axum HTTP server + let app = create_router(state); + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?; + + println!("AIS server running on http://0.0.0.0:3000"); + + axum::serve(listener, app) + .with_graceful_shutdown(ais::shutdown_signal()) + .await?; + + Ok(()) +} + +// Create the Axum router +fn create_router(state: AppState) -> Router { + Router::new() + .route("/ais", get(crate::ais::get_ais_data)) + .route("/ws", get(crate::ais::websocket_handler)) + .layer(CorsLayer::permissive()) + .with_state(state) } \ No newline at end of file