init
This commit is contained in:
17
src/handlers/error.rs
Normal file
17
src/handlers/error.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
// src/handlers/error.rs
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
Json,
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
pub async fn handle_not_found() -> impl IntoResponse {
|
||||
tracing::warn!("404 Not Found error occurred");
|
||||
|
||||
let error_response = serde_json::json!({
|
||||
"error": "Route Not Found",
|
||||
"status": 404
|
||||
});
|
||||
|
||||
(StatusCode::NOT_FOUND, Json(error_response))
|
||||
}
|
7
src/handlers/mod.rs
Normal file
7
src/handlers/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
|
||||
// src/handlers/mod.rs
|
||||
pub mod error;
|
||||
pub mod status;
|
||||
pub mod stream;
|
||||
pub mod ui;
|
||||
pub mod webhooks;
|
5
src/handlers/status.rs
Normal file
5
src/handlers/status.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
// src/handlers/status.rs
|
||||
pub async fn handle_status() -> &'static str {
|
||||
tracing::debug!("Status check requested");
|
||||
"Server is running"
|
||||
}
|
82
src/handlers/stream.rs
Normal file
82
src/handlers/stream.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use tokio_util::io::ReaderStream;
|
||||
|
||||
pub async fn handle_stream() -> impl IntoResponse {
|
||||
use tokio::process::Command;
|
||||
|
||||
let user_input = "Who won the 2024 election?";
|
||||
tracing::debug!("Handling stream request with input: {}", user_input);
|
||||
|
||||
// Check environment variables
|
||||
for env_var in ["OPENAI_API_KEY", "BING_SEARCH_API_KEY", "TAVILY_API_KEY"] {
|
||||
if std::env::var(env_var).is_ok() {
|
||||
tracing::debug!("{} is set", env_var);
|
||||
} else {
|
||||
tracing::warn!("{} is not set", env_var);
|
||||
}
|
||||
}
|
||||
|
||||
let mut cmd = match Command::new("genaiscript")
|
||||
.arg("run")
|
||||
.arg("genaisrc/web-search.genai.mts")
|
||||
.arg("--vars")
|
||||
.arg(format!("USER_INPUT='{}'", user_input))
|
||||
.env("OPENAI_API_KEY", std::env::var("OPENAI_API_KEY").unwrap_or_default())
|
||||
.env("BING_SEARCH_API_KEY", std::env::var("BING_SEARCH_API_KEY").unwrap_or_default())
|
||||
.env("TAVILY_API_KEY", std::env::var("TAVILY_API_KEY").unwrap_or_default())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.spawn() {
|
||||
Ok(cmd) => {
|
||||
tracing::debug!("Successfully spawned genaiscript process");
|
||||
cmd
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to spawn genaiscript process: {}", e);
|
||||
return Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::from("Failed to start process"))
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
let stdout = match cmd.stdout.take() {
|
||||
Some(stdout) => {
|
||||
tracing::debug!("Successfully captured stdout from process");
|
||||
stdout
|
||||
}
|
||||
None => {
|
||||
tracing::error!("Failed to capture stdout from process");
|
||||
return Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::from("Failed to capture process output"))
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
let reader = tokio::io::BufReader::new(stdout);
|
||||
let stream = ReaderStream::new(reader);
|
||||
let mapped_stream = stream.map(|r| {
|
||||
match r {
|
||||
Ok(bytes) => {
|
||||
tracing::trace!("Received {} bytes from stream", bytes.len());
|
||||
Ok(bytes)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error reading from stream: {}", e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tracing::debug!("Setting up SSE response");
|
||||
Response::builder()
|
||||
.header("Content-Type", "text/event-stream")
|
||||
.body(Body::from_stream(mapped_stream))
|
||||
.unwrap()
|
||||
}
|
34
src/handlers/ui.rs
Normal file
34
src/handlers/ui.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{StatusCode, header::CONTENT_TYPE},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use rust_embed::RustEmbed;
|
||||
use tracing::{debug, error};
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "assets/"]
|
||||
struct Asset;
|
||||
|
||||
pub async fn serve_ui() -> impl IntoResponse {
|
||||
debug!("Serving UI request");
|
||||
|
||||
// Attempt to retrieve the embedded "index.html"
|
||||
match Asset::get("index.html") {
|
||||
Some(content) => {
|
||||
debug!("Successfully retrieved index.html");
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(CONTENT_TYPE, "text/html")
|
||||
.body(Body::from(content.data))
|
||||
.unwrap()
|
||||
}
|
||||
None => {
|
||||
error!("index.html not found in embedded assets");
|
||||
Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::from("404 Not Found"))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
261
src/handlers/webhooks.rs
Normal file
261
src/handlers/webhooks.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
use crate::agents;
|
||||
use crate::agents::news::news_agent;
|
||||
use crate::agents::scrape::scrape_agent;
|
||||
use crate::agents::search::search_agent;
|
||||
use axum::response::Response;
|
||||
use axum::{
|
||||
body::Body, extract::Path, extract::Query, http::StatusCode, response::IntoResponse, Json,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sled;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::agents::crypto_market::finance_query_agent;
|
||||
use crate::agents::image_generator::image_generator;
|
||||
|
||||
// init sled
|
||||
lazy_static! {
|
||||
static ref DB: Arc<Mutex<sled::Db>> = Arc::new(Mutex::new(
|
||||
sled::open("./web-agent-rs/db/stream_store").expect("Failed to open sled database")
|
||||
));
|
||||
}
|
||||
|
||||
pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse {
|
||||
let db = DB.lock().await;
|
||||
match db.get(&stream_id) {
|
||||
Ok(Some(data)) => {
|
||||
|
||||
let mut info: StreamInfo = match serde_json::from_slice(&data) {
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to deserialize StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Increment the call_count in the database
|
||||
info.call_count += 1;
|
||||
let updated_info_bytes = match serde_json::to_vec(&info) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to serialize updated StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
match db.insert(&stream_id, updated_info_bytes) {
|
||||
Ok(_) => {
|
||||
if let Err(e) = db.flush_async().await {
|
||||
tracing::error!("Failed to persist updated call_count to the database: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to update call_count in the database: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let info: StreamInfo = match db.get(&stream_id) {
|
||||
Ok(Some(updated_data)) => match serde_json::from_slice(&updated_data) {
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to deserialize updated StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
tracing::error!("Stream ID not found after update: {}", stream_id);
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to fetch updated record from DB: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if(info.call_count > 1) {
|
||||
return StatusCode::OK.into_response();
|
||||
}
|
||||
|
||||
let resource = info.resource;
|
||||
let input = serde_json::to_string(&info.payload.input).unwrap_or_default();
|
||||
|
||||
tracing::debug!(
|
||||
"Processing webhook - Resource: {}, Stream ID: {}",
|
||||
resource,
|
||||
stream_id
|
||||
);
|
||||
|
||||
let cmd = match resource.as_str() {
|
||||
"web-search" => search_agent(stream_id.as_str(), &*input).await,
|
||||
"news-search" => news_agent(stream_id.as_str(), &*input).await,
|
||||
"image-generator" => image_generator(stream_id.as_str(), &*input).await,
|
||||
"finance-query" => finance_query_agent(stream_id.as_str(), &*input).await,
|
||||
"web-scrape" => scrape_agent(stream_id.as_str(), &*input).await,
|
||||
_ => {
|
||||
tracing::error!("Unsupported resource type: {}", resource);
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mut cmd = match cmd {
|
||||
Ok(cmd) => cmd,
|
||||
Err(e) => {
|
||||
tracing::error!("Agent execution failed: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let stdout = match cmd.stdout.take() {
|
||||
Some(stdout) => stdout,
|
||||
None => {
|
||||
tracing::error!("No stdout available for the command.");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let reader = BufReader::new(stdout);
|
||||
let sse_stream = reader_to_stream(reader, stream_id.clone());
|
||||
|
||||
return Response::builder()
|
||||
.header("Content-Type", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache, no-transform")
|
||||
.header("Connection", "keep-alive")
|
||||
.header("X-Accel-Buffering", "yes")
|
||||
.body(Body::from_stream(sse_stream))
|
||||
.unwrap()
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::error!("Stream ID not found: {}", stream_id);
|
||||
StatusCode::NOT_FOUND.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to fetch from DB: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reader_to_stream<R>(
|
||||
reader: BufReader<R>,
|
||||
stream_id: String,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let stream = futures::stream::unfold(reader, move |mut reader| async move {
|
||||
let mut line = String::new();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => None,
|
||||
Ok(_) => Some((
|
||||
Ok(Bytes::from(format!("data: {}\n\n", line.trim()))),
|
||||
reader,
|
||||
)),
|
||||
Err(e) => Some((Err(e), reader)),
|
||||
}
|
||||
});
|
||||
|
||||
let stream_with_done = stream.chain(futures::stream::once(async {
|
||||
Ok(Bytes::from("data: [DONE]\n\n"))
|
||||
}));
|
||||
|
||||
Box::pin(stream_with_done)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
struct Payload {
|
||||
input: Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct StreamInfo {
|
||||
resource: String,
|
||||
payload: Payload,
|
||||
parent: String,
|
||||
call_count: i32,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub struct WebhookPostRequest {
|
||||
id: String,
|
||||
resource: String,
|
||||
payload: Payload,
|
||||
parent: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
struct WebhookPostResponse {
|
||||
stream_url: String,
|
||||
}
|
||||
|
||||
pub async fn handle_webhooks_post(Json(payload): Json<WebhookPostRequest>) -> impl IntoResponse {
|
||||
let db = DB.lock().await;
|
||||
|
||||
tracing::info!("Received webhook post request with ID: {}", payload.id);
|
||||
|
||||
let stream_id = payload.id.clone();
|
||||
let info = StreamInfo {
|
||||
resource: payload.resource.clone(),
|
||||
payload: payload.payload,
|
||||
parent: payload.parent.clone(),
|
||||
call_count: 0
|
||||
};
|
||||
|
||||
let info_bytes = match serde_json::to_vec(&info) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to serialize StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Use atomic compare_and_swap operation
|
||||
match db.compare_and_swap(
|
||||
&stream_id,
|
||||
None as Option<&[u8]>,
|
||||
Some(info_bytes.as_slice()),
|
||||
) {
|
||||
Ok(_) => {
|
||||
// Force an immediate sync to disk
|
||||
match db.flush_async().await {
|
||||
Ok(_) => {
|
||||
// Verify the write by attempting to read it back
|
||||
match db.get(&stream_id) {
|
||||
Ok(Some(_)) => {
|
||||
let stream_url = format!("/webhooks/{}", stream_id);
|
||||
tracing::info!("Successfully created and verified stream URL: {}", stream_url);
|
||||
Json(WebhookPostResponse { stream_url }).into_response()
|
||||
},
|
||||
Ok(None) => {
|
||||
tracing::error!("Failed to verify stream creation: {}", stream_id);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Error verifying stream creation: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to flush DB: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to insert stream info: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user