This commit is contained in:
geoffsee
2025-05-23 09:48:26 -04:00
commit 66d3c06230
84 changed files with 6529 additions and 0 deletions

View File

@@ -0,0 +1,28 @@
use tokio::process::Child;
use tracing;
use crate::utils::utils::run_agent;
pub async fn finance_query_agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/finance-query.genai.mts").await
}
// #[cfg(test)]
// mod tests {
// use std::fmt::Debug;
// use crate::agents::search::search_agent;
//
// #[tokio::test] // Mark the test function as async
// async fn test_search_execution() {
// let input = "Who won the 2024 presidential election?";
//
// let mut command = search_agent("test-stream", input).await.unwrap();
//
// // command.stdout.take().unwrap().read_to_string(&mut String::new()).await.unwrap();
// // Optionally, you can capture and inspect stdout if needed:
// let output = command.wait_with_output().await.expect("Failed to wait for output");
// println!("Stdout: {}", String::from_utf8_lossy(&output.stdout));
// println!("Stderr: {}", String::from_utf8_lossy(&output.stderr));
// }
// }

View File

@@ -0,0 +1,10 @@
use crate::utils::utils::run_agent;
use tokio::process::Child;
pub async fn image_generator(stream_id: &str, input: &str) -> Result<Child, String> {
tracing::debug!(
"Running image generator, \ninput: {}",
input
);
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/image-generator.genai.mts").await
}

5
src/agents/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
pub mod news;
pub mod scrape;
pub mod search;
pub mod image_generator;
pub mod crypto_market;

6
src/agents/news.rs Normal file
View File

@@ -0,0 +1,6 @@
use crate::utils::utils::run_agent;
use tokio::process::Child;
pub async fn news_agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/news-search.genai.mts").await
}

6
src/agents/scrape.rs Normal file
View File

@@ -0,0 +1,6 @@
use crate::utils::utils::run_agent;
use tokio::process::Child;
pub async fn scrape_agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-scrape.genai.mts").await
}

28
src/agents/search.rs Normal file
View File

@@ -0,0 +1,28 @@
use tokio::process::Child;
use tracing;
use crate::utils::utils::run_agent;
pub async fn search_agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-search.genai.mts").await
}
#[cfg(test)]
mod tests {
use std::fmt::Debug;
use crate::agents::search::search_agent;
#[tokio::test]
async fn test_search_execution() {
let input = "Who won the 2024 presidential election?";
let mut command = search_agent("test-stream", input).await.unwrap();
// command.stdout.take().unwrap().read_to_string(&mut String::new()).await.unwrap();
// Optionally, you can capture and inspect stdout if needed:
let output = command.wait_with_output().await.expect("Failed to wait for output");
println!("Stdout: {}", String::from_utf8_lossy(&output.stdout));
println!("Stderr: {}", String::from_utf8_lossy(&output.stderr));
}
}

30
src/config.rs Normal file
View File

@@ -0,0 +1,30 @@
// src/config.rs
pub struct AppConfig {
pub env_vars: Vec<String>,
}
impl AppConfig {
pub fn new() -> Self {
// Load .env file if it exists
match dotenv::dotenv() {
Ok(_) => tracing::debug!("Loaded .env file successfully"),
Err(e) => tracing::debug!("No .env file found or error loading it: {}", e),
}
Self {
env_vars: vec![
"OPENAI_API_KEY".to_string(),
"BING_SEARCH_API_KEY".to_string(),
"TAVILY_API_KEY".to_string(),
"GENAISCRIPT_MODEL_LARGE".to_string(),
"GENAISCRIPT_MODEL_SMALL".to_string(),
"SEARXNG_API_BASE_URL".to_string(),
],
}
}
pub fn get_env_var(&self, key: &str) -> String {
std::env::var(key).unwrap_or_default()
}
}

90
src/genaiscript.rs Normal file
View File

@@ -0,0 +1,90 @@
use std::collections::HashMap;
use std::path::PathBuf;
use tokio::process::{Child, Command};
use tracing;
const DEFAULT_ENV_VARS: [&str; 4] = [
"OPENAI_API_KEY",
"OPENAI_API_BASE",
"GENAISCRIPT_MODEL_LARGE",
"GENAISCRIPT_MODEL_SMALL",
];
pub struct GenAIScriptConfig {
script_path: PathBuf,
output_dir: PathBuf,
stream_id: String,
user_input: String,
retry_count: u32,
env_vars: HashMap<String, String>,
}
impl GenAIScriptConfig {
pub fn new(script_path: impl Into<PathBuf>, stream_id: impl Into<String>, user_input: impl Into<String>) -> Self {
let mut env_vars = HashMap::new();
// Initialize with default environment variables
for var in DEFAULT_ENV_VARS {
if let Ok(value) = std::env::var(var) {
env_vars.insert(var.to_string(), value);
}
}
Self {
script_path: script_path.into(),
output_dir: PathBuf::from("./web-agent-rs/output"),
stream_id: stream_id.into(),
user_input: user_input.into(),
retry_count: 0,
env_vars,
}
}
pub fn with_output_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.output_dir = dir.into();
self
}
pub fn with_retry_count(mut self, count: u32) -> Self {
self.retry_count = count;
self
}
pub fn with_additional_env_vars(mut self, vars: HashMap<String, String>) -> Self {
self.env_vars.extend(vars);
self
}
}
pub async fn run_genaiscript(config: GenAIScriptConfig) -> Result<Child, String> {
tracing::debug!("Initiating GenAIScript for stream {}", config.stream_id);
let output_path = config.output_dir.join(&config.stream_id);
let mut command = Command::new("bunx");
command
.arg("genaiscript")
.arg("run")
.arg(&config.script_path)
// .arg("--fail-on-errors")
.arg("—out-trace")
.arg(output_path)
.arg("--retry")
.arg(config.retry_count.to_string())
.arg("--vars")
.arg(format!("USER_INPUT='{}'", config.user_input));
// Add environment variables
for (key, value) in config.env_vars {
command.env(key, value);
}
command
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.spawn()
.map_err(|e| {
tracing::error!("Failed to spawn genaiscript process: {}", e);
e.to_string()
})
}

17
src/handlers/error.rs Normal file
View File

@@ -0,0 +1,17 @@
// src/handlers/error.rs
use axum::{
http::StatusCode,
Json,
response::IntoResponse,
};
pub async fn handle_not_found() -> impl IntoResponse {
tracing::warn!("404 Not Found error occurred");
let error_response = serde_json::json!({
"error": "Route Not Found",
"status": 404
});
(StatusCode::NOT_FOUND, Json(error_response))
}

7
src/handlers/mod.rs Normal file
View File

@@ -0,0 +1,7 @@
// src/handlers/mod.rs
pub mod error;
pub mod status;
pub mod stream;
pub mod ui;
pub mod webhooks;

5
src/handlers/status.rs Normal file
View File

@@ -0,0 +1,5 @@
// src/handlers/status.rs
pub async fn handle_status() -> &'static str {
tracing::debug!("Status check requested");
"Server is running"
}

82
src/handlers/stream.rs Normal file
View File

@@ -0,0 +1,82 @@
use axum::{
body::Body,
http::StatusCode,
response::{IntoResponse, Response},
};
use futures::StreamExt;
use tokio_util::io::ReaderStream;
pub async fn handle_stream() -> impl IntoResponse {
use tokio::process::Command;
let user_input = "Who won the 2024 election?";
tracing::debug!("Handling stream request with input: {}", user_input);
// Check environment variables
for env_var in ["OPENAI_API_KEY", "BING_SEARCH_API_KEY", "TAVILY_API_KEY"] {
if std::env::var(env_var).is_ok() {
tracing::debug!("{} is set", env_var);
} else {
tracing::warn!("{} is not set", env_var);
}
}
let mut cmd = match Command::new("genaiscript")
.arg("run")
.arg("genaisrc/web-search.genai.mts")
.arg("--vars")
.arg(format!("USER_INPUT='{}'", user_input))
.env("OPENAI_API_KEY", std::env::var("OPENAI_API_KEY").unwrap_or_default())
.env("BING_SEARCH_API_KEY", std::env::var("BING_SEARCH_API_KEY").unwrap_or_default())
.env("TAVILY_API_KEY", std::env::var("TAVILY_API_KEY").unwrap_or_default())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::null())
.spawn() {
Ok(cmd) => {
tracing::debug!("Successfully spawned genaiscript process");
cmd
}
Err(e) => {
tracing::error!("Failed to spawn genaiscript process: {}", e);
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Failed to start process"))
.unwrap();
}
};
let stdout = match cmd.stdout.take() {
Some(stdout) => {
tracing::debug!("Successfully captured stdout from process");
stdout
}
None => {
tracing::error!("Failed to capture stdout from process");
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Failed to capture process output"))
.unwrap();
}
};
let reader = tokio::io::BufReader::new(stdout);
let stream = ReaderStream::new(reader);
let mapped_stream = stream.map(|r| {
match r {
Ok(bytes) => {
tracing::trace!("Received {} bytes from stream", bytes.len());
Ok(bytes)
}
Err(e) => {
tracing::error!("Error reading from stream: {}", e);
Err(e)
}
}
});
tracing::debug!("Setting up SSE response");
Response::builder()
.header("Content-Type", "text/event-stream")
.body(Body::from_stream(mapped_stream))
.unwrap()
}

34
src/handlers/ui.rs Normal file
View File

@@ -0,0 +1,34 @@
use axum::{
body::Body,
http::{StatusCode, header::CONTENT_TYPE},
response::{IntoResponse, Response},
};
use rust_embed::RustEmbed;
use tracing::{debug, error};
#[derive(RustEmbed)]
#[folder = "assets/"]
struct Asset;
pub async fn serve_ui() -> impl IntoResponse {
debug!("Serving UI request");
// Attempt to retrieve the embedded "index.html"
match Asset::get("index.html") {
Some(content) => {
debug!("Successfully retrieved index.html");
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "text/html")
.body(Body::from(content.data))
.unwrap()
}
None => {
error!("index.html not found in embedded assets");
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("404 Not Found"))
.unwrap()
}
}
}

261
src/handlers/webhooks.rs Normal file
View File

@@ -0,0 +1,261 @@
use crate::agents;
use crate::agents::news::news_agent;
use crate::agents::scrape::scrape_agent;
use crate::agents::search::search_agent;
use axum::response::Response;
use axum::{
body::Body, extract::Path, extract::Query, http::StatusCode, response::IntoResponse, Json,
};
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sled;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio::sync::Mutex;
use crate::agents::crypto_market::finance_query_agent;
use crate::agents::image_generator::image_generator;
// init sled
lazy_static! {
static ref DB: Arc<Mutex<sled::Db>> = Arc::new(Mutex::new(
sled::open("./web-agent-rs/db/stream_store").expect("Failed to open sled database")
));
}
pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse {
let db = DB.lock().await;
match db.get(&stream_id) {
Ok(Some(data)) => {
let mut info: StreamInfo = match serde_json::from_slice(&data) {
Ok(info) => info,
Err(e) => {
tracing::error!("Failed to deserialize StreamInfo: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
// Increment the call_count in the database
info.call_count += 1;
let updated_info_bytes = match serde_json::to_vec(&info) {
Ok(data) => data,
Err(e) => {
tracing::error!("Failed to serialize updated StreamInfo: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
match db.insert(&stream_id, updated_info_bytes) {
Ok(_) => {
if let Err(e) = db.flush_async().await {
tracing::error!("Failed to persist updated call_count to the database: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
}
Err(e) => {
tracing::error!("Failed to update call_count in the database: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
let info: StreamInfo = match db.get(&stream_id) {
Ok(Some(updated_data)) => match serde_json::from_slice(&updated_data) {
Ok(info) => info,
Err(e) => {
tracing::error!("Failed to deserialize updated StreamInfo: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
},
Ok(None) => {
tracing::error!("Stream ID not found after update: {}", stream_id);
return StatusCode::NOT_FOUND.into_response();
}
Err(e) => {
tracing::error!("Failed to fetch updated record from DB: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
if(info.call_count > 1) {
return StatusCode::OK.into_response();
}
let resource = info.resource;
let input = serde_json::to_string(&info.payload.input).unwrap_or_default();
tracing::debug!(
"Processing webhook - Resource: {}, Stream ID: {}",
resource,
stream_id
);
let cmd = match resource.as_str() {
"web-search" => search_agent(stream_id.as_str(), &*input).await,
"news-search" => news_agent(stream_id.as_str(), &*input).await,
"image-generator" => image_generator(stream_id.as_str(), &*input).await,
"finance-query" => finance_query_agent(stream_id.as_str(), &*input).await,
"web-scrape" => scrape_agent(stream_id.as_str(), &*input).await,
_ => {
tracing::error!("Unsupported resource type: {}", resource);
return StatusCode::BAD_REQUEST.into_response();
}
};
let mut cmd = match cmd {
Ok(cmd) => cmd,
Err(e) => {
tracing::error!("Agent execution failed: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
let stdout = match cmd.stdout.take() {
Some(stdout) => stdout,
None => {
tracing::error!("No stdout available for the command.");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
let reader = BufReader::new(stdout);
let sse_stream = reader_to_stream(reader, stream_id.clone());
return Response::builder()
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache, no-transform")
.header("Connection", "keep-alive")
.header("X-Accel-Buffering", "yes")
.body(Body::from_stream(sse_stream))
.unwrap()
}
Ok(None) => {
tracing::error!("Stream ID not found: {}", stream_id);
StatusCode::NOT_FOUND.into_response()
}
Err(e) => {
tracing::error!("Failed to fetch from DB: {}", e);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
fn reader_to_stream<R>(
reader: BufReader<R>,
stream_id: String,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
where
R: tokio::io::AsyncRead + Unpin + Send + 'static,
{
let stream = futures::stream::unfold(reader, move |mut reader| async move {
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => None,
Ok(_) => Some((
Ok(Bytes::from(format!("data: {}\n\n", line.trim()))),
reader,
)),
Err(e) => Some((Err(e), reader)),
}
});
let stream_with_done = stream.chain(futures::stream::once(async {
Ok(Bytes::from("data: [DONE]\n\n"))
}));
Box::pin(stream_with_done)
}
#[derive(Deserialize, Serialize, Debug)]
struct Payload {
input: Value,
}
#[derive(Serialize, Deserialize, Debug)]
struct StreamInfo {
resource: String,
payload: Payload,
parent: String,
call_count: i32,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct WebhookPostRequest {
id: String,
resource: String,
payload: Payload,
parent: String,
}
#[derive(Deserialize, Serialize, Debug)]
struct WebhookPostResponse {
stream_url: String,
}
pub async fn handle_webhooks_post(Json(payload): Json<WebhookPostRequest>) -> impl IntoResponse {
let db = DB.lock().await;
tracing::info!("Received webhook post request with ID: {}", payload.id);
let stream_id = payload.id.clone();
let info = StreamInfo {
resource: payload.resource.clone(),
payload: payload.payload,
parent: payload.parent.clone(),
call_count: 0
};
let info_bytes = match serde_json::to_vec(&info) {
Ok(data) => data,
Err(e) => {
tracing::error!("Failed to serialize StreamInfo: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
// Use atomic compare_and_swap operation
match db.compare_and_swap(
&stream_id,
None as Option<&[u8]>,
Some(info_bytes.as_slice()),
) {
Ok(_) => {
// Force an immediate sync to disk
match db.flush_async().await {
Ok(_) => {
// Verify the write by attempting to read it back
match db.get(&stream_id) {
Ok(Some(_)) => {
let stream_url = format!("/webhooks/{}", stream_id);
tracing::info!("Successfully created and verified stream URL: {}", stream_url);
Json(WebhookPostResponse { stream_url }).into_response()
},
Ok(None) => {
tracing::error!("Failed to verify stream creation: {}", stream_id);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
},
Err(e) => {
tracing::error!("Error verifying stream creation: {}", e);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
},
Err(e) => {
tracing::error!("Failed to flush DB: {}", e);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
Err(e) => {
tracing::error!("Failed to insert stream info: {}", e);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}

43
src/main.rs Normal file
View File

@@ -0,0 +1,43 @@
// src/main.rs
use crate::config::AppConfig;
use crate::routes::create_router;
use crate::setup::init_logging;
mod config;
mod routes;
mod setup;
mod handlers;
mod agents;
mod genaiscript;
mod utils;
mod session_identify;
#[tokio::main]
async fn main() {
// Initialize logging
init_logging();
// Load configuration
let config = AppConfig::new();
// Create router with all routes
let app = create_router();
// Start core
let addr = "0.0.0.0:3006";
tracing::info!("Attempting to bind core to {}", addr);
let listener = match tokio::net::TcpListener::bind(addr).await {
Ok(l) => {
tracing::info!("Successfully bound to {}", l.local_addr().unwrap());
l
}
Err(e) => {
tracing::error!("Failed to bind to {}: {}", addr, e);
panic!("Server failed to start");
}
};
tracing::info!("Server starting on {}", listener.local_addr().unwrap());
axum::serve(listener, app.into_make_service()).await.unwrap();
}

105
src/routes.rs Normal file
View File

@@ -0,0 +1,105 @@
use crate::handlers::webhooks::handle_webhooks_post;
use crate::handlers::{
error::handle_not_found,
ui::serve_ui
,
webhooks::handle_webhooks,
};
use crate::session_identify::session_identify;
use axum::extract::Request;
use axum::response::Response;
use axum::routing::post;
// src/routes.rs
use axum::routing::{get, Router};
use http::header::AUTHORIZATION;
use http::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Number;
use std::fmt;
use tower_http::trace::{self, TraceLayer};
use tracing::Level;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CurrentUser {
pub(crate) sub: String,
pub name: String,
pub email: String,
pub exp: Number,
pub id: String,
pub aud: String,
}
impl fmt::Display for CurrentUser {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"CurrentUser {{ id: {}, name: {}, email: {}, sub: {}, aud: {}, exp: {} }}",
self.id, self.name, self.email, self.sub, self.aud, self.exp
)
}
}
pub fn create_router() -> Router {
Router::new()
.route("/", get(serve_ui))
// request a stream resource
.route("/api/webhooks", post(handle_webhooks_post))
// consume a stream resource
.route("/webhooks/:stream_id", get(handle_webhooks))
.route_layer(axum::middleware::from_fn(auth))
.route("/health", get(health))
.layer(
TraceLayer::new_for_http()
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
.on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
)
// left for smoke testing
// .route("/api/status", get(handle_status))
.fallback(handle_not_found)
}
async fn health() -> String {
return "ok".to_string();
}
async fn auth(mut req: Request, next: axum::middleware::Next) -> Result<Response, StatusCode> {
let session_token_header = req
.headers()
.get(AUTHORIZATION)
.and_then(|header_value| header_value.to_str().ok());
let session_token_parts= session_token_header.expect("No credentials").split(" ").collect::<Vec<&str>>();
let session_token = session_token_parts.get(1);
// log::info!("session_token: {:?}", session_token);
let session_token = session_token.expect("Unauthorized: No credentials supplied");
let result =
if let Some(current_user) = authorize_current_user(&*session_token).await {
// info!("current user: {}", current_user);
// insert the current user into a request extension so the handler can
// extract it
req.extensions_mut().insert(current_user);
Ok(next.run(req).await)
} else {
Err(StatusCode::UNAUTHORIZED)
};
result
}
async fn authorize_current_user(
session_token: &str,
) -> Option<CurrentUser> {
let session_identity = session_identify(session_token)
.await
.unwrap();
// println!("current_user: {:?}", session_identity.user);
Some(serde_json::from_value::<CurrentUser>(session_identity.user).unwrap())
}

55
src/session_identify.rs Normal file
View File

@@ -0,0 +1,55 @@
use anyhow::Result;
use serde_json::Value;
use serde_json::json;
use base64::Engine;
use fips204::ml_dsa_44::{PrivateKey, PublicKey};
use fips204::traits::{SerDes, Signer, Verifier};
use crate::utils::base64::B64_ENCODER;
pub struct SessionIdentity {
pub message: String,
pub signature: String,
pub target: String,
pub session_id: String,
pub user: Value
}
pub async fn session_identify(session_token: &str) -> Result<SessionIdentity> {
let session_data_base64 = session_token.split('.').nth(0).ok_or_else(|| anyhow::anyhow!("Invalid session data format"))?;
// println!("session_data_base64: {}", session_data_base64);
let session_data: Value = serde_json::de::from_slice(&*B64_ENCODER.b64_decode_payload(session_data_base64).map_err(|e| anyhow::anyhow!("Failed to decode session data: {}", e))?).map_err(|e| anyhow::anyhow!("Failed to parse session data: {}", e))?;
// println!("session_data: {:?}", session_data);
let signature_base64 = session_token.split('.').nth(1).ok_or_else(|| anyhow::anyhow!("Invalid session token format"))?;
// println!("signature_base64: {}", signature_base64);
let target = session_data.get("aud")
.and_then(|e| e.as_str())
.ok_or_else(|| anyhow::anyhow!("Session data missing audience"))?;
let target = target.parse::<String>().map_err(|e| anyhow::anyhow!("Failed to parse target to String: {}", e))?;
let session_id = session_data.get("id")
.and_then(|e| e.as_str())
.ok_or_else(|| anyhow::anyhow!("Session data missing id"))?;
let session_id = session_id.parse::<String>().map_err(|e| anyhow::anyhow!("Failed to parse session_id to String: {}", e))?;
// let request_payload: Value = json!({
// "message": session_data_base64,
// "signature": signature_base64,
// "target": target,
// "session_id": session_id,
// });
let result = SessionIdentity {
message: session_data_base64.to_string(),
signature: signature_base64.to_string(),
target,
session_id,
user: session_data.clone()
};
Ok(result)
}

10
src/setup.rs Normal file
View File

@@ -0,0 +1,10 @@
// src/setup.rs
pub fn init_logging() {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(true)
.with_thread_ids(true)
.with_file(true)
.with_line_number(true)
.init();
}

65
src/utils/base64.rs Normal file
View File

@@ -0,0 +1,65 @@
use base64::Engine;
use base64::engine::GeneralPurpose;
use base64::engine::general_purpose::STANDARD;
use base64::engine::general_purpose::STANDARD_NO_PAD;
pub struct Base64Encoder {
payload_engine: GeneralPurpose,
signature_engine: GeneralPurpose,
public_key_engine: GeneralPurpose,
secret_key_engine: GeneralPurpose,
}
impl Base64Encoder {
pub(crate) fn b64_encode(&self, p0: &[u8]) -> String {
self.payload_engine.encode(p0)
}
pub(crate) fn b64_decode(&self, p0: String) -> Result<Vec<u8>, base64::DecodeError> {
self.payload_engine.decode(p0)
}
}
pub const B64_ENCODER: &Base64Encoder = &Base64Encoder::new();
impl Base64Encoder {
pub const fn new() -> Self { // Made new() a const fn
Base64Encoder {
payload_engine: STANDARD,
signature_engine: STANDARD,
public_key_engine: STANDARD,
secret_key_engine: STANDARD,
}
}
pub fn b64_encode_payload<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
self.payload_engine.encode(input)
}
pub fn b64_decode_payload<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
self.payload_engine.decode(input)
}
pub fn b64_decode_signature<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
self.signature_engine.decode(input)
}
pub fn b64_encode_signature<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
self.signature_engine.encode(input)
}
pub fn b64_encode_public_key<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
self.public_key_engine.encode(input)
}
pub fn b64_decode_public_key<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
self.public_key_engine.decode(input)
}
pub fn b64_encode_secret_key<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
self.secret_key_engine.encode(input)
}
pub fn b64_decode_secret_key<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
self.secret_key_engine.decode(input)
}
}

2
src/utils/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod utils;
pub mod base64;

80
src/utils/utils.rs Normal file
View File

@@ -0,0 +1,80 @@
// utils.rs
use tokio::process::{Child, Command}; // Use tokio::process::Child and Command
use std::env;
use tokio::time::{timeout, Duration};
use tracing;
pub struct ShimBinding {
user_input: String,
file_path: String, // Add new field for the file path
openai_api_key: String,
openai_api_base: String,
bing_search_api_key: String,
perigon_api_key: String,
tavily_api_key: String,
genaiscript_model_large: String,
genaiscript_model_small: String,
searxng_api_base_url: String,
searxng_password: String,
}
impl ShimBinding {
pub fn new(user_input: String, file_path: String) -> Self { // Update constructor to take file path
Self {
user_input,
file_path, // Initialize the new field
openai_api_key: env::var("OPENAI_API_KEY").unwrap_or_default(),
openai_api_base: env::var("OPENAI_API_BASE").unwrap_or_default(),
bing_search_api_key: env::var("BING_SEARCH_API_KEY").unwrap_or_default(),
tavily_api_key: env::var("TAVILY_API_KEY").unwrap_or_default(),
genaiscript_model_large: env::var("GENAISCRIPT_MODEL_LARGE").unwrap_or_default(),
genaiscript_model_small: env::var("GENAISCRIPT_MODEL_SMALL").unwrap_or_default(),
perigon_api_key: env::var("PERIGON_API_KEY").unwrap_or_default(),
searxng_api_base_url: env::var("SEARXNG_API_BASE_URL").unwrap_or_default(),
searxng_password: env::var("SEARXNG_PASSWORD").unwrap_or_default(),
}
}
pub fn execute(&self) -> std::io::Result<Child> {
let mut command = Command::new("./dist/genaiscript-rust-shim.js");
command
.arg("--file")
.arg(&self.file_path) // Use the file_path field instead of hardcoded value
.arg(format!("USER_INPUT={}", self.user_input))
.env("OPENAI_API_KEY", &self.openai_api_key)
.env("OPENAI_API_BASE", &self.openai_api_base)
.env("BING_SEARCH_API_KEY", &self.bing_search_api_key)
.env("TAVILY_API_KEY", &self.tavily_api_key)
.env("GENAISCRIPT_MODEL_LARGE", &self.genaiscript_model_large)
.env("GENAISCRIPT_MODEL_SMALL", &self.genaiscript_model_small)
.env("PERIGON_API_KEY", &self.perigon_api_key)
.env("SEARXNG_API_BASE_URL", &self.searxng_api_base_url)
.env("SEARXNG_PASSWORD", &self.searxng_password)
.stdout(std::process::Stdio::piped()) // Use tokio::io::Stdio::piped()
.stderr(std::process::Stdio::inherit()); // Use tokio::io::Stdio::piped()
command.spawn()
}
}
/// Generic helper to execute a ShimBinding-based agent with a timeout
pub async fn run_agent(stream_id: &str, input: &str, file_path: &str) -> Result<Child, String> {
tracing::debug!("Initiating agent for stream {} with file path {}", stream_id, file_path);
let shim_binding = ShimBinding::new(input.to_string(), file_path.to_string());
let spawn_future = async move {
match shim_binding.execute() {
Ok(child) => Ok(child),
Err(e) => {
tracing::error!("Failed to spawn shim process: {}", e);
Err(e.to_string())
}
}
};
timeout(Duration::from_secs(10), spawn_future)
.await
.unwrap_or_else(|_| Err("Command timed out after 10 seconds".to_string()))
}