Refactor agent function names and streamline imports
Unified the naming convention for agent functions across modules to `agent` for consistency. Adjusted relevant imports and cleaned up unused imports in `webhooks.rs` to improve readability and maintainability.
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
use crate::utils::utils::run_agent;
|
use crate::utils::utils::run_agent;
|
||||||
use tokio::process::Child;
|
use tokio::process::Child;
|
||||||
|
|
||||||
pub async fn image_generator(stream_id: &str, input: &str) -> Result<Child, String> {
|
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"Running image generator, \ninput: {}",
|
"Running image generator, \ninput: {}",
|
||||||
input
|
input
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
pub mod news;
|
pub(crate) mod news;
|
||||||
pub mod scrape;
|
pub(crate) mod scrape;
|
||||||
pub mod search;
|
pub(crate) mod search;
|
||||||
pub mod image_generator;
|
pub(crate) mod image_generator;
|
@@ -1,6 +1,6 @@
|
|||||||
use crate::utils::utils::run_agent;
|
use crate::utils::utils::run_agent;
|
||||||
use tokio::process::Child;
|
use tokio::process::Child;
|
||||||
|
|
||||||
pub async fn news_agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/news-search.genai.mts", 10).await
|
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/news-search.genai.mts", 10).await
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
use crate::utils::utils::run_agent;
|
use crate::utils::utils::run_agent;
|
||||||
use tokio::process::Child;
|
use tokio::process::Child;
|
||||||
|
|
||||||
pub async fn scrape_agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-scrape.genai.mts", 10).await
|
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-scrape.genai.mts", 10).await
|
||||||
}
|
}
|
||||||
|
@@ -3,7 +3,7 @@ use tracing;
|
|||||||
|
|
||||||
use crate::utils::utils::run_agent;
|
use crate::utils::utils::run_agent;
|
||||||
|
|
||||||
pub async fn search_agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-search.genai.mts", 10).await
|
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-search.genai.mts", 10).await
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -11,13 +11,13 @@ pub async fn search_agent(stream_id: &str, input: &str) -> Result<Child, String>
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use crate::agents::search::search_agent;
|
use crate::agents::search::agent;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_search_execution() {
|
async fn test_search_execution() {
|
||||||
let input = "Who won the 2024 presidential election?";
|
let input = "Who won the 2024 presidential election?";
|
||||||
|
|
||||||
let mut command = search_agent("test-stream", input).await.unwrap();
|
let mut command = agent("test-stream", input).await.unwrap();
|
||||||
|
|
||||||
// command.stdout.take().unwrap().read_to_string(&mut String::new()).await.unwrap();
|
// command.stdout.take().unwrap().read_to_string(&mut String::new()).await.unwrap();
|
||||||
// Optionally, you can capture and inspect stdout if needed:
|
// Optionally, you can capture and inspect stdout if needed:
|
||||||
|
@@ -1,7 +1,3 @@
|
|||||||
use crate::agents;
|
|
||||||
use crate::agents::news::news_agent;
|
|
||||||
use crate::agents::scrape::scrape_agent;
|
|
||||||
use crate::agents::search::search_agent;
|
|
||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body, extract::Path, extract::Query, http::StatusCode, response::IntoResponse, Json,
|
body::Body, extract::Path, extract::Query, http::StatusCode, response::IntoResponse, Json,
|
||||||
@@ -18,7 +14,6 @@ use std::time::Duration;
|
|||||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use crate::agents::image_generator::image_generator;
|
|
||||||
|
|
||||||
// init sled
|
// init sled
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
@@ -31,7 +26,6 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
|
|||||||
let db = DB.lock().await;
|
let db = DB.lock().await;
|
||||||
match db.get(&agent_id) {
|
match db.get(&agent_id) {
|
||||||
Ok(Some(data)) => {
|
Ok(Some(data)) => {
|
||||||
|
|
||||||
let mut info: StreamInfo = match serde_json::from_slice(&data) {
|
let mut info: StreamInfo = match serde_json::from_slice(&data) {
|
||||||
Ok(info) => info,
|
Ok(info) => info,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -40,7 +34,6 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Increment the call_count in the database
|
// Increment the call_count in the database
|
||||||
info.call_count += 1;
|
info.call_count += 1;
|
||||||
let updated_info_bytes = match serde_json::to_vec(&info) {
|
let updated_info_bytes = match serde_json::to_vec(&info) {
|
||||||
@@ -54,7 +47,10 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
|
|||||||
match db.insert(&agent_id, updated_info_bytes) {
|
match db.insert(&agent_id, updated_info_bytes) {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
if let Err(e) = db.flush_async().await {
|
if let Err(e) = db.flush_async().await {
|
||||||
tracing::error!("Failed to persist updated call_count to the database: {}", e);
|
tracing::error!(
|
||||||
|
"Failed to persist updated call_count to the database: {}",
|
||||||
|
e
|
||||||
|
);
|
||||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -82,7 +78,7 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if(info.call_count > 1) {
|
if (info.call_count > 1) {
|
||||||
return StatusCode::OK.into_response();
|
return StatusCode::OK.into_response();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,10 +92,12 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
|
|||||||
);
|
);
|
||||||
|
|
||||||
let cmd = match resource.as_str() {
|
let cmd = match resource.as_str() {
|
||||||
"web-search" => search_agent(agent_id.as_str(), &*input).await,
|
"web-search" => crate::agents::search::agent(agent_id.as_str(), &*input).await,
|
||||||
"news-search" => news_agent(agent_id.as_str(), &*input).await,
|
"news-search" => crate::agents::news::agent(agent_id.as_str(), &*input).await,
|
||||||
"image-generator" => image_generator(agent_id.as_str(), &*input).await,
|
"image-generator" => {
|
||||||
"web-scrape" => scrape_agent(agent_id.as_str(), &*input).await,
|
crate::agents::image_generator::agent(agent_id.as_str(), &*input).await
|
||||||
|
}
|
||||||
|
"web-scrape" => crate::agents::scrape::agent(agent_id.as_str(), &*input).await,
|
||||||
_ => {
|
_ => {
|
||||||
tracing::error!("Unsupported resource type: {}", resource);
|
tracing::error!("Unsupported resource type: {}", resource);
|
||||||
return StatusCode::BAD_REQUEST.into_response();
|
return StatusCode::BAD_REQUEST.into_response();
|
||||||
@@ -131,7 +129,7 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
|
|||||||
.header("Connection", "keep-alive")
|
.header("Connection", "keep-alive")
|
||||||
.header("X-Accel-Buffering", "yes")
|
.header("X-Accel-Buffering", "yes")
|
||||||
.body(Body::from_stream(sse_stream))
|
.body(Body::from_stream(sse_stream))
|
||||||
.unwrap()
|
.unwrap();
|
||||||
}
|
}
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
tracing::error!("Stream ID not found: {}", agent_id);
|
tracing::error!("Stream ID not found: {}", agent_id);
|
||||||
@@ -183,7 +181,6 @@ struct StreamInfo {
|
|||||||
call_count: i32,
|
call_count: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Debug)]
|
#[derive(Deserialize, Serialize, Debug)]
|
||||||
pub struct WebhookPostRequest {
|
pub struct WebhookPostRequest {
|
||||||
id: String,
|
id: String,
|
||||||
@@ -207,7 +204,7 @@ pub async fn handle_webhooks_post(Json(payload): Json<WebhookPostRequest>) -> im
|
|||||||
resource: payload.resource.clone(),
|
resource: payload.resource.clone(),
|
||||||
payload: payload.payload,
|
payload: payload.payload,
|
||||||
parent: payload.parent.clone(),
|
parent: payload.parent.clone(),
|
||||||
call_count: 0
|
call_count: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let info_bytes = match serde_json::to_vec(&info) {
|
let info_bytes = match serde_json::to_vec(&info) {
|
||||||
@@ -232,19 +229,22 @@ pub async fn handle_webhooks_post(Json(payload): Json<WebhookPostRequest>) -> im
|
|||||||
match db.get(&stream_id) {
|
match db.get(&stream_id) {
|
||||||
Ok(Some(_)) => {
|
Ok(Some(_)) => {
|
||||||
let stream_url = format!("/webhooks/{}", stream_id);
|
let stream_url = format!("/webhooks/{}", stream_id);
|
||||||
tracing::info!("Successfully created and verified stream URL: {}", stream_url);
|
tracing::info!(
|
||||||
|
"Successfully created and verified stream URL: {}",
|
||||||
|
stream_url
|
||||||
|
);
|
||||||
Json(WebhookPostResponse { stream_url }).into_response()
|
Json(WebhookPostResponse { stream_url }).into_response()
|
||||||
},
|
}
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
tracing::error!("Failed to verify stream creation: {}", stream_id);
|
tracing::error!("Failed to verify stream creation: {}", stream_id);
|
||||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||||
},
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("Error verifying stream creation: {}", e);
|
tracing::error!("Error verifying stream creation: {}", e);
|
||||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("Failed to flush DB: {}", e);
|
tracing::error!("Failed to flush DB: {}", e);
|
||||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||||
|
Reference in New Issue
Block a user