init
This commit is contained in:
28
src/agents/crypto_market.rs
Normal file
28
src/agents/crypto_market.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use tokio::process::Child;
|
||||
use tracing;
|
||||
|
||||
use crate::utils::utils::run_agent;
|
||||
|
||||
pub async fn finance_query_agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/finance-query.genai.mts").await
|
||||
}
|
||||
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use std::fmt::Debug;
|
||||
// use crate::agents::search::search_agent;
|
||||
//
|
||||
// #[tokio::test] // Mark the test function as async
|
||||
// async fn test_search_execution() {
|
||||
// let input = "Who won the 2024 presidential election?";
|
||||
//
|
||||
// let mut command = search_agent("test-stream", input).await.unwrap();
|
||||
//
|
||||
// // command.stdout.take().unwrap().read_to_string(&mut String::new()).await.unwrap();
|
||||
// // Optionally, you can capture and inspect stdout if needed:
|
||||
// let output = command.wait_with_output().await.expect("Failed to wait for output");
|
||||
// println!("Stdout: {}", String::from_utf8_lossy(&output.stdout));
|
||||
// println!("Stderr: {}", String::from_utf8_lossy(&output.stderr));
|
||||
// }
|
||||
// }
|
10
src/agents/image_generator.rs
Normal file
10
src/agents/image_generator.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
|
||||
pub async fn image_generator(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
tracing::debug!(
|
||||
"Running image generator, \ninput: {}",
|
||||
input
|
||||
);
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/image-generator.genai.mts").await
|
||||
}
|
5
src/agents/mod.rs
Normal file
5
src/agents/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod news;
|
||||
pub mod scrape;
|
||||
pub mod search;
|
||||
pub mod image_generator;
|
||||
pub mod crypto_market;
|
6
src/agents/news.rs
Normal file
6
src/agents/news.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
|
||||
pub async fn news_agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/news-search.genai.mts").await
|
||||
}
|
6
src/agents/scrape.rs
Normal file
6
src/agents/scrape.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
|
||||
pub async fn scrape_agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-scrape.genai.mts").await
|
||||
}
|
28
src/agents/search.rs
Normal file
28
src/agents/search.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use tokio::process::Child;
|
||||
use tracing;
|
||||
|
||||
use crate::utils::utils::run_agent;
|
||||
|
||||
pub async fn search_agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-search.genai.mts").await
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fmt::Debug;
|
||||
use crate::agents::search::search_agent;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_execution() {
|
||||
let input = "Who won the 2024 presidential election?";
|
||||
|
||||
let mut command = search_agent("test-stream", input).await.unwrap();
|
||||
|
||||
// command.stdout.take().unwrap().read_to_string(&mut String::new()).await.unwrap();
|
||||
// Optionally, you can capture and inspect stdout if needed:
|
||||
let output = command.wait_with_output().await.expect("Failed to wait for output");
|
||||
println!("Stdout: {}", String::from_utf8_lossy(&output.stdout));
|
||||
println!("Stderr: {}", String::from_utf8_lossy(&output.stderr));
|
||||
}
|
||||
}
|
30
src/config.rs
Normal file
30
src/config.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
// src/config.rs
|
||||
pub struct AppConfig {
|
||||
pub env_vars: Vec<String>,
|
||||
}
|
||||
|
||||
|
||||
impl AppConfig {
|
||||
pub fn new() -> Self {
|
||||
// Load .env file if it exists
|
||||
match dotenv::dotenv() {
|
||||
Ok(_) => tracing::debug!("Loaded .env file successfully"),
|
||||
Err(e) => tracing::debug!("No .env file found or error loading it: {}", e),
|
||||
}
|
||||
|
||||
Self {
|
||||
env_vars: vec![
|
||||
"OPENAI_API_KEY".to_string(),
|
||||
"BING_SEARCH_API_KEY".to_string(),
|
||||
"TAVILY_API_KEY".to_string(),
|
||||
"GENAISCRIPT_MODEL_LARGE".to_string(),
|
||||
"GENAISCRIPT_MODEL_SMALL".to_string(),
|
||||
"SEARXNG_API_BASE_URL".to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_env_var(&self, key: &str) -> String {
|
||||
std::env::var(key).unwrap_or_default()
|
||||
}
|
||||
}
|
90
src/genaiscript.rs
Normal file
90
src/genaiscript.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use tokio::process::{Child, Command};
|
||||
use tracing;
|
||||
|
||||
const DEFAULT_ENV_VARS: [&str; 4] = [
|
||||
"OPENAI_API_KEY",
|
||||
"OPENAI_API_BASE",
|
||||
"GENAISCRIPT_MODEL_LARGE",
|
||||
"GENAISCRIPT_MODEL_SMALL",
|
||||
];
|
||||
|
||||
pub struct GenAIScriptConfig {
|
||||
script_path: PathBuf,
|
||||
output_dir: PathBuf,
|
||||
stream_id: String,
|
||||
user_input: String,
|
||||
retry_count: u32,
|
||||
env_vars: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl GenAIScriptConfig {
|
||||
pub fn new(script_path: impl Into<PathBuf>, stream_id: impl Into<String>, user_input: impl Into<String>) -> Self {
|
||||
let mut env_vars = HashMap::new();
|
||||
|
||||
// Initialize with default environment variables
|
||||
for var in DEFAULT_ENV_VARS {
|
||||
if let Ok(value) = std::env::var(var) {
|
||||
env_vars.insert(var.to_string(), value);
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
script_path: script_path.into(),
|
||||
output_dir: PathBuf::from("./web-agent-rs/output"),
|
||||
stream_id: stream_id.into(),
|
||||
user_input: user_input.into(),
|
||||
retry_count: 0,
|
||||
env_vars,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_output_dir(mut self, dir: impl Into<PathBuf>) -> Self {
|
||||
self.output_dir = dir.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_retry_count(mut self, count: u32) -> Self {
|
||||
self.retry_count = count;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_additional_env_vars(mut self, vars: HashMap<String, String>) -> Self {
|
||||
self.env_vars.extend(vars);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_genaiscript(config: GenAIScriptConfig) -> Result<Child, String> {
|
||||
tracing::debug!("Initiating GenAIScript for stream {}", config.stream_id);
|
||||
|
||||
let output_path = config.output_dir.join(&config.stream_id);
|
||||
|
||||
let mut command = Command::new("bunx");
|
||||
command
|
||||
.arg("genaiscript")
|
||||
.arg("run")
|
||||
.arg(&config.script_path)
|
||||
// .arg("--fail-on-errors")
|
||||
.arg("—out-trace")
|
||||
.arg(output_path)
|
||||
.arg("--retry")
|
||||
.arg(config.retry_count.to_string())
|
||||
.arg("--vars")
|
||||
.arg(format!("USER_INPUT='{}'", config.user_input));
|
||||
|
||||
// Add environment variables
|
||||
for (key, value) in config.env_vars {
|
||||
command.env(key, value);
|
||||
}
|
||||
|
||||
command
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.spawn()
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to spawn genaiscript process: {}", e);
|
||||
e.to_string()
|
||||
})
|
||||
}
|
17
src/handlers/error.rs
Normal file
17
src/handlers/error.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
// src/handlers/error.rs
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
Json,
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
pub async fn handle_not_found() -> impl IntoResponse {
|
||||
tracing::warn!("404 Not Found error occurred");
|
||||
|
||||
let error_response = serde_json::json!({
|
||||
"error": "Route Not Found",
|
||||
"status": 404
|
||||
});
|
||||
|
||||
(StatusCode::NOT_FOUND, Json(error_response))
|
||||
}
|
7
src/handlers/mod.rs
Normal file
7
src/handlers/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
|
||||
// src/handlers/mod.rs
|
||||
pub mod error;
|
||||
pub mod status;
|
||||
pub mod stream;
|
||||
pub mod ui;
|
||||
pub mod webhooks;
|
5
src/handlers/status.rs
Normal file
5
src/handlers/status.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
// src/handlers/status.rs
|
||||
pub async fn handle_status() -> &'static str {
|
||||
tracing::debug!("Status check requested");
|
||||
"Server is running"
|
||||
}
|
82
src/handlers/stream.rs
Normal file
82
src/handlers/stream.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use tokio_util::io::ReaderStream;
|
||||
|
||||
pub async fn handle_stream() -> impl IntoResponse {
|
||||
use tokio::process::Command;
|
||||
|
||||
let user_input = "Who won the 2024 election?";
|
||||
tracing::debug!("Handling stream request with input: {}", user_input);
|
||||
|
||||
// Check environment variables
|
||||
for env_var in ["OPENAI_API_KEY", "BING_SEARCH_API_KEY", "TAVILY_API_KEY"] {
|
||||
if std::env::var(env_var).is_ok() {
|
||||
tracing::debug!("{} is set", env_var);
|
||||
} else {
|
||||
tracing::warn!("{} is not set", env_var);
|
||||
}
|
||||
}
|
||||
|
||||
let mut cmd = match Command::new("genaiscript")
|
||||
.arg("run")
|
||||
.arg("genaisrc/web-search.genai.mts")
|
||||
.arg("--vars")
|
||||
.arg(format!("USER_INPUT='{}'", user_input))
|
||||
.env("OPENAI_API_KEY", std::env::var("OPENAI_API_KEY").unwrap_or_default())
|
||||
.env("BING_SEARCH_API_KEY", std::env::var("BING_SEARCH_API_KEY").unwrap_or_default())
|
||||
.env("TAVILY_API_KEY", std::env::var("TAVILY_API_KEY").unwrap_or_default())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.spawn() {
|
||||
Ok(cmd) => {
|
||||
tracing::debug!("Successfully spawned genaiscript process");
|
||||
cmd
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to spawn genaiscript process: {}", e);
|
||||
return Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::from("Failed to start process"))
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
let stdout = match cmd.stdout.take() {
|
||||
Some(stdout) => {
|
||||
tracing::debug!("Successfully captured stdout from process");
|
||||
stdout
|
||||
}
|
||||
None => {
|
||||
tracing::error!("Failed to capture stdout from process");
|
||||
return Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::from("Failed to capture process output"))
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
let reader = tokio::io::BufReader::new(stdout);
|
||||
let stream = ReaderStream::new(reader);
|
||||
let mapped_stream = stream.map(|r| {
|
||||
match r {
|
||||
Ok(bytes) => {
|
||||
tracing::trace!("Received {} bytes from stream", bytes.len());
|
||||
Ok(bytes)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error reading from stream: {}", e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tracing::debug!("Setting up SSE response");
|
||||
Response::builder()
|
||||
.header("Content-Type", "text/event-stream")
|
||||
.body(Body::from_stream(mapped_stream))
|
||||
.unwrap()
|
||||
}
|
34
src/handlers/ui.rs
Normal file
34
src/handlers/ui.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{StatusCode, header::CONTENT_TYPE},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use rust_embed::RustEmbed;
|
||||
use tracing::{debug, error};
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "assets/"]
|
||||
struct Asset;
|
||||
|
||||
pub async fn serve_ui() -> impl IntoResponse {
|
||||
debug!("Serving UI request");
|
||||
|
||||
// Attempt to retrieve the embedded "index.html"
|
||||
match Asset::get("index.html") {
|
||||
Some(content) => {
|
||||
debug!("Successfully retrieved index.html");
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(CONTENT_TYPE, "text/html")
|
||||
.body(Body::from(content.data))
|
||||
.unwrap()
|
||||
}
|
||||
None => {
|
||||
error!("index.html not found in embedded assets");
|
||||
Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::from("404 Not Found"))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
261
src/handlers/webhooks.rs
Normal file
261
src/handlers/webhooks.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
use crate::agents;
|
||||
use crate::agents::news::news_agent;
|
||||
use crate::agents::scrape::scrape_agent;
|
||||
use crate::agents::search::search_agent;
|
||||
use axum::response::Response;
|
||||
use axum::{
|
||||
body::Body, extract::Path, extract::Query, http::StatusCode, response::IntoResponse, Json,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sled;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::agents::crypto_market::finance_query_agent;
|
||||
use crate::agents::image_generator::image_generator;
|
||||
|
||||
// init sled
|
||||
lazy_static! {
|
||||
static ref DB: Arc<Mutex<sled::Db>> = Arc::new(Mutex::new(
|
||||
sled::open("./web-agent-rs/db/stream_store").expect("Failed to open sled database")
|
||||
));
|
||||
}
|
||||
|
||||
pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse {
|
||||
let db = DB.lock().await;
|
||||
match db.get(&stream_id) {
|
||||
Ok(Some(data)) => {
|
||||
|
||||
let mut info: StreamInfo = match serde_json::from_slice(&data) {
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to deserialize StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Increment the call_count in the database
|
||||
info.call_count += 1;
|
||||
let updated_info_bytes = match serde_json::to_vec(&info) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to serialize updated StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
match db.insert(&stream_id, updated_info_bytes) {
|
||||
Ok(_) => {
|
||||
if let Err(e) = db.flush_async().await {
|
||||
tracing::error!("Failed to persist updated call_count to the database: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to update call_count in the database: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let info: StreamInfo = match db.get(&stream_id) {
|
||||
Ok(Some(updated_data)) => match serde_json::from_slice(&updated_data) {
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to deserialize updated StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
tracing::error!("Stream ID not found after update: {}", stream_id);
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to fetch updated record from DB: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if(info.call_count > 1) {
|
||||
return StatusCode::OK.into_response();
|
||||
}
|
||||
|
||||
let resource = info.resource;
|
||||
let input = serde_json::to_string(&info.payload.input).unwrap_or_default();
|
||||
|
||||
tracing::debug!(
|
||||
"Processing webhook - Resource: {}, Stream ID: {}",
|
||||
resource,
|
||||
stream_id
|
||||
);
|
||||
|
||||
let cmd = match resource.as_str() {
|
||||
"web-search" => search_agent(stream_id.as_str(), &*input).await,
|
||||
"news-search" => news_agent(stream_id.as_str(), &*input).await,
|
||||
"image-generator" => image_generator(stream_id.as_str(), &*input).await,
|
||||
"finance-query" => finance_query_agent(stream_id.as_str(), &*input).await,
|
||||
"web-scrape" => scrape_agent(stream_id.as_str(), &*input).await,
|
||||
_ => {
|
||||
tracing::error!("Unsupported resource type: {}", resource);
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mut cmd = match cmd {
|
||||
Ok(cmd) => cmd,
|
||||
Err(e) => {
|
||||
tracing::error!("Agent execution failed: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let stdout = match cmd.stdout.take() {
|
||||
Some(stdout) => stdout,
|
||||
None => {
|
||||
tracing::error!("No stdout available for the command.");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let reader = BufReader::new(stdout);
|
||||
let sse_stream = reader_to_stream(reader, stream_id.clone());
|
||||
|
||||
return Response::builder()
|
||||
.header("Content-Type", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache, no-transform")
|
||||
.header("Connection", "keep-alive")
|
||||
.header("X-Accel-Buffering", "yes")
|
||||
.body(Body::from_stream(sse_stream))
|
||||
.unwrap()
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::error!("Stream ID not found: {}", stream_id);
|
||||
StatusCode::NOT_FOUND.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to fetch from DB: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reader_to_stream<R>(
|
||||
reader: BufReader<R>,
|
||||
stream_id: String,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let stream = futures::stream::unfold(reader, move |mut reader| async move {
|
||||
let mut line = String::new();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => None,
|
||||
Ok(_) => Some((
|
||||
Ok(Bytes::from(format!("data: {}\n\n", line.trim()))),
|
||||
reader,
|
||||
)),
|
||||
Err(e) => Some((Err(e), reader)),
|
||||
}
|
||||
});
|
||||
|
||||
let stream_with_done = stream.chain(futures::stream::once(async {
|
||||
Ok(Bytes::from("data: [DONE]\n\n"))
|
||||
}));
|
||||
|
||||
Box::pin(stream_with_done)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
struct Payload {
|
||||
input: Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct StreamInfo {
|
||||
resource: String,
|
||||
payload: Payload,
|
||||
parent: String,
|
||||
call_count: i32,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub struct WebhookPostRequest {
|
||||
id: String,
|
||||
resource: String,
|
||||
payload: Payload,
|
||||
parent: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
struct WebhookPostResponse {
|
||||
stream_url: String,
|
||||
}
|
||||
|
||||
pub async fn handle_webhooks_post(Json(payload): Json<WebhookPostRequest>) -> impl IntoResponse {
|
||||
let db = DB.lock().await;
|
||||
|
||||
tracing::info!("Received webhook post request with ID: {}", payload.id);
|
||||
|
||||
let stream_id = payload.id.clone();
|
||||
let info = StreamInfo {
|
||||
resource: payload.resource.clone(),
|
||||
payload: payload.payload,
|
||||
parent: payload.parent.clone(),
|
||||
call_count: 0
|
||||
};
|
||||
|
||||
let info_bytes = match serde_json::to_vec(&info) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to serialize StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Use atomic compare_and_swap operation
|
||||
match db.compare_and_swap(
|
||||
&stream_id,
|
||||
None as Option<&[u8]>,
|
||||
Some(info_bytes.as_slice()),
|
||||
) {
|
||||
Ok(_) => {
|
||||
// Force an immediate sync to disk
|
||||
match db.flush_async().await {
|
||||
Ok(_) => {
|
||||
// Verify the write by attempting to read it back
|
||||
match db.get(&stream_id) {
|
||||
Ok(Some(_)) => {
|
||||
let stream_url = format!("/webhooks/{}", stream_id);
|
||||
tracing::info!("Successfully created and verified stream URL: {}", stream_url);
|
||||
Json(WebhookPostResponse { stream_url }).into_response()
|
||||
},
|
||||
Ok(None) => {
|
||||
tracing::error!("Failed to verify stream creation: {}", stream_id);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Error verifying stream creation: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to flush DB: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to insert stream info: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
43
src/main.rs
Normal file
43
src/main.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
// src/main.rs
|
||||
use crate::config::AppConfig;
|
||||
use crate::routes::create_router;
|
||||
use crate::setup::init_logging;
|
||||
|
||||
mod config;
|
||||
mod routes;
|
||||
mod setup;
|
||||
mod handlers;
|
||||
mod agents;
|
||||
mod genaiscript;
|
||||
mod utils;
|
||||
mod session_identify;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Initialize logging
|
||||
init_logging();
|
||||
|
||||
// Load configuration
|
||||
let config = AppConfig::new();
|
||||
|
||||
// Create router with all routes
|
||||
let app = create_router();
|
||||
|
||||
// Start core
|
||||
let addr = "0.0.0.0:3006";
|
||||
tracing::info!("Attempting to bind core to {}", addr);
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => {
|
||||
tracing::info!("Successfully bound to {}", l.local_addr().unwrap());
|
||||
l
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to bind to {}: {}", addr, e);
|
||||
panic!("Server failed to start");
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!("Server starting on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, app.into_make_service()).await.unwrap();
|
||||
}
|
105
src/routes.rs
Normal file
105
src/routes.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use crate::handlers::webhooks::handle_webhooks_post;
|
||||
use crate::handlers::{
|
||||
error::handle_not_found,
|
||||
ui::serve_ui
|
||||
,
|
||||
webhooks::handle_webhooks,
|
||||
};
|
||||
use crate::session_identify::session_identify;
|
||||
use axum::extract::Request;
|
||||
use axum::response::Response;
|
||||
use axum::routing::post;
|
||||
// src/routes.rs
|
||||
use axum::routing::{get, Router};
|
||||
use http::header::AUTHORIZATION;
|
||||
use http::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Number;
|
||||
use std::fmt;
|
||||
use tower_http::trace::{self, TraceLayer};
|
||||
use tracing::Level;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct CurrentUser {
|
||||
pub(crate) sub: String,
|
||||
pub name: String,
|
||||
pub email: String,
|
||||
pub exp: Number,
|
||||
pub id: String,
|
||||
pub aud: String,
|
||||
}
|
||||
|
||||
impl fmt::Display for CurrentUser {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"CurrentUser {{ id: {}, name: {}, email: {}, sub: {}, aud: {}, exp: {} }}",
|
||||
self.id, self.name, self.email, self.sub, self.aud, self.exp
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_router() -> Router {
|
||||
|
||||
Router::new()
|
||||
.route("/", get(serve_ui))
|
||||
// request a stream resource
|
||||
.route("/api/webhooks", post(handle_webhooks_post))
|
||||
// consume a stream resource
|
||||
.route("/webhooks/:stream_id", get(handle_webhooks))
|
||||
.route_layer(axum::middleware::from_fn(auth))
|
||||
.route("/health", get(health))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
|
||||
.on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
|
||||
)
|
||||
// left for smoke testing
|
||||
// .route("/api/status", get(handle_status))
|
||||
.fallback(handle_not_found)
|
||||
}
|
||||
|
||||
async fn health() -> String {
|
||||
return "ok".to_string();
|
||||
}
|
||||
|
||||
async fn auth(mut req: Request, next: axum::middleware::Next) -> Result<Response, StatusCode> {
|
||||
let session_token_header = req
|
||||
.headers()
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|header_value| header_value.to_str().ok());
|
||||
|
||||
let session_token_parts= session_token_header.expect("No credentials").split(" ").collect::<Vec<&str>>();
|
||||
|
||||
let session_token = session_token_parts.get(1);
|
||||
|
||||
|
||||
// log::info!("session_token: {:?}", session_token);
|
||||
|
||||
let session_token = session_token.expect("Unauthorized: No credentials supplied");
|
||||
|
||||
let result =
|
||||
if let Some(current_user) = authorize_current_user(&*session_token).await {
|
||||
// info!("current user: {}", current_user);
|
||||
// insert the current user into a request extension so the handler can
|
||||
// extract it
|
||||
req.extensions_mut().insert(current_user);
|
||||
Ok(next.run(req).await)
|
||||
} else {
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
};
|
||||
result
|
||||
}
|
||||
|
||||
|
||||
async fn authorize_current_user(
|
||||
session_token: &str,
|
||||
) -> Option<CurrentUser> {
|
||||
let session_identity = session_identify(session_token)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// println!("current_user: {:?}", session_identity.user);
|
||||
|
||||
Some(serde_json::from_value::<CurrentUser>(session_identity.user).unwrap())
|
||||
}
|
55
src/session_identify.rs
Normal file
55
src/session_identify.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use base64::Engine;
|
||||
use fips204::ml_dsa_44::{PrivateKey, PublicKey};
|
||||
use fips204::traits::{SerDes, Signer, Verifier};
|
||||
use crate::utils::base64::B64_ENCODER;
|
||||
|
||||
pub struct SessionIdentity {
|
||||
pub message: String,
|
||||
pub signature: String,
|
||||
pub target: String,
|
||||
pub session_id: String,
|
||||
pub user: Value
|
||||
}
|
||||
|
||||
pub async fn session_identify(session_token: &str) -> Result<SessionIdentity> {
|
||||
let session_data_base64 = session_token.split('.').nth(0).ok_or_else(|| anyhow::anyhow!("Invalid session data format"))?;
|
||||
// println!("session_data_base64: {}", session_data_base64);
|
||||
let session_data: Value = serde_json::de::from_slice(&*B64_ENCODER.b64_decode_payload(session_data_base64).map_err(|e| anyhow::anyhow!("Failed to decode session data: {}", e))?).map_err(|e| anyhow::anyhow!("Failed to parse session data: {}", e))?;
|
||||
// println!("session_data: {:?}", session_data);
|
||||
|
||||
|
||||
let signature_base64 = session_token.split('.').nth(1).ok_or_else(|| anyhow::anyhow!("Invalid session token format"))?;
|
||||
// println!("signature_base64: {}", signature_base64);
|
||||
|
||||
let target = session_data.get("aud")
|
||||
.and_then(|e| e.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Session data missing audience"))?;
|
||||
|
||||
let target = target.parse::<String>().map_err(|e| anyhow::anyhow!("Failed to parse target to String: {}", e))?;
|
||||
|
||||
let session_id = session_data.get("id")
|
||||
.and_then(|e| e.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Session data missing id"))?;
|
||||
|
||||
let session_id = session_id.parse::<String>().map_err(|e| anyhow::anyhow!("Failed to parse session_id to String: {}", e))?;
|
||||
|
||||
// let request_payload: Value = json!({
|
||||
// "message": session_data_base64,
|
||||
// "signature": signature_base64,
|
||||
// "target": target,
|
||||
// "session_id": session_id,
|
||||
// });
|
||||
|
||||
let result = SessionIdentity {
|
||||
message: session_data_base64.to_string(),
|
||||
signature: signature_base64.to_string(),
|
||||
target,
|
||||
session_id,
|
||||
user: session_data.clone()
|
||||
};
|
||||
|
||||
Ok(result)
|
||||
}
|
10
src/setup.rs
Normal file
10
src/setup.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
// src/setup.rs
|
||||
pub fn init_logging() {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::DEBUG)
|
||||
.with_target(true)
|
||||
.with_thread_ids(true)
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.init();
|
||||
}
|
65
src/utils/base64.rs
Normal file
65
src/utils/base64.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use base64::Engine;
|
||||
use base64::engine::GeneralPurpose;
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::engine::general_purpose::STANDARD_NO_PAD;
|
||||
|
||||
pub struct Base64Encoder {
|
||||
payload_engine: GeneralPurpose,
|
||||
signature_engine: GeneralPurpose,
|
||||
public_key_engine: GeneralPurpose,
|
||||
secret_key_engine: GeneralPurpose,
|
||||
}
|
||||
|
||||
impl Base64Encoder {
|
||||
pub(crate) fn b64_encode(&self, p0: &[u8]) -> String {
|
||||
self.payload_engine.encode(p0)
|
||||
}
|
||||
pub(crate) fn b64_decode(&self, p0: String) -> Result<Vec<u8>, base64::DecodeError> {
|
||||
self.payload_engine.decode(p0)
|
||||
}
|
||||
}
|
||||
|
||||
pub const B64_ENCODER: &Base64Encoder = &Base64Encoder::new();
|
||||
|
||||
impl Base64Encoder {
|
||||
pub const fn new() -> Self { // Made new() a const fn
|
||||
Base64Encoder {
|
||||
payload_engine: STANDARD,
|
||||
signature_engine: STANDARD,
|
||||
public_key_engine: STANDARD,
|
||||
secret_key_engine: STANDARD,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b64_encode_payload<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.payload_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_payload<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.payload_engine.decode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_signature<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.signature_engine.decode(input)
|
||||
}
|
||||
|
||||
pub fn b64_encode_signature<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.signature_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_encode_public_key<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.public_key_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_public_key<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.public_key_engine.decode(input)
|
||||
}
|
||||
|
||||
pub fn b64_encode_secret_key<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.secret_key_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_secret_key<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.secret_key_engine.decode(input)
|
||||
}
|
||||
}
|
2
src/utils/mod.rs
Normal file
2
src/utils/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod utils;
|
||||
pub mod base64;
|
80
src/utils/utils.rs
Normal file
80
src/utils/utils.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
// utils.rs
|
||||
use tokio::process::{Child, Command}; // Use tokio::process::Child and Command
|
||||
use std::env;
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tracing;
|
||||
|
||||
|
||||
pub struct ShimBinding {
|
||||
user_input: String,
|
||||
file_path: String, // Add new field for the file path
|
||||
openai_api_key: String,
|
||||
openai_api_base: String,
|
||||
bing_search_api_key: String,
|
||||
perigon_api_key: String,
|
||||
tavily_api_key: String,
|
||||
genaiscript_model_large: String,
|
||||
genaiscript_model_small: String,
|
||||
searxng_api_base_url: String,
|
||||
searxng_password: String,
|
||||
}
|
||||
|
||||
impl ShimBinding {
|
||||
pub fn new(user_input: String, file_path: String) -> Self { // Update constructor to take file path
|
||||
Self {
|
||||
user_input,
|
||||
file_path, // Initialize the new field
|
||||
openai_api_key: env::var("OPENAI_API_KEY").unwrap_or_default(),
|
||||
openai_api_base: env::var("OPENAI_API_BASE").unwrap_or_default(),
|
||||
bing_search_api_key: env::var("BING_SEARCH_API_KEY").unwrap_or_default(),
|
||||
tavily_api_key: env::var("TAVILY_API_KEY").unwrap_or_default(),
|
||||
genaiscript_model_large: env::var("GENAISCRIPT_MODEL_LARGE").unwrap_or_default(),
|
||||
genaiscript_model_small: env::var("GENAISCRIPT_MODEL_SMALL").unwrap_or_default(),
|
||||
perigon_api_key: env::var("PERIGON_API_KEY").unwrap_or_default(),
|
||||
searxng_api_base_url: env::var("SEARXNG_API_BASE_URL").unwrap_or_default(),
|
||||
searxng_password: env::var("SEARXNG_PASSWORD").unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute(&self) -> std::io::Result<Child> {
|
||||
let mut command = Command::new("./dist/genaiscript-rust-shim.js");
|
||||
command
|
||||
.arg("--file")
|
||||
.arg(&self.file_path) // Use the file_path field instead of hardcoded value
|
||||
.arg(format!("USER_INPUT={}", self.user_input))
|
||||
.env("OPENAI_API_KEY", &self.openai_api_key)
|
||||
.env("OPENAI_API_BASE", &self.openai_api_base)
|
||||
.env("BING_SEARCH_API_KEY", &self.bing_search_api_key)
|
||||
.env("TAVILY_API_KEY", &self.tavily_api_key)
|
||||
.env("GENAISCRIPT_MODEL_LARGE", &self.genaiscript_model_large)
|
||||
.env("GENAISCRIPT_MODEL_SMALL", &self.genaiscript_model_small)
|
||||
.env("PERIGON_API_KEY", &self.perigon_api_key)
|
||||
.env("SEARXNG_API_BASE_URL", &self.searxng_api_base_url)
|
||||
.env("SEARXNG_PASSWORD", &self.searxng_password)
|
||||
.stdout(std::process::Stdio::piped()) // Use tokio::io::Stdio::piped()
|
||||
.stderr(std::process::Stdio::inherit()); // Use tokio::io::Stdio::piped()
|
||||
|
||||
command.spawn()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Generic helper to execute a ShimBinding-based agent with a timeout
|
||||
pub async fn run_agent(stream_id: &str, input: &str, file_path: &str) -> Result<Child, String> {
|
||||
tracing::debug!("Initiating agent for stream {} with file path {}", stream_id, file_path);
|
||||
|
||||
let shim_binding = ShimBinding::new(input.to_string(), file_path.to_string());
|
||||
let spawn_future = async move {
|
||||
match shim_binding.execute() {
|
||||
Ok(child) => Ok(child),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to spawn shim process: {}", e);
|
||||
Err(e.to_string())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
timeout(Duration::from_secs(10), spawn_future)
|
||||
.await
|
||||
.unwrap_or_else(|_| Err("Command timed out after 10 seconds".to_string()))
|
||||
}
|
Reference in New Issue
Block a user