update project structure
This commit is contained in:
36
crates/agent-server/Cargo.toml
Normal file
36
crates/agent-server/Cargo.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[package]
|
||||
name = "agent-server"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
|
||||
[[bin]]
|
||||
edition = "2021"
|
||||
name = "agent-server"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.8", features = ["multipart"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
http = "1.1.0"
|
||||
tokio-stream = "0.1.16"
|
||||
uuid = { version = "1.11.0", features = ["v4"] }
|
||||
tokio-util = { version = "0.7", features = ["io"] }
|
||||
serde_json = "1.0.133"
|
||||
futures = "0.3.31"
|
||||
dotenv = "0.15.0"
|
||||
shell-escape = "0.1.5"
|
||||
rust-embed = "8.5.0"
|
||||
bytes = "1.8.0"
|
||||
lazy_static = "1.5.0"
|
||||
sled = "0.34.7"
|
||||
tower-http = { version = "0.6.2", features = ["trace", "cors"] }
|
||||
tower = "0.5.2"
|
||||
anyhow = "1.0.97"
|
||||
base64 = "0.22.1"
|
||||
fips204 = "0.4.6"
|
||||
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = ["server", "transport-streamable-http-server", "transport-sse-server", "transport-io",] }
|
||||
mime_guess = "2.0.5"
|
32
crates/agent-server/src/agents/deep_research.rs
Normal file
32
crates/agent-server/src/agents/deep_research.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
use tracing;
|
||||
|
||||
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/deep-research.genai.mts", 60).await
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::agents::deep_research::agent;
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deepresearch() {
|
||||
// a really provocative question for research that generally yields infinite complexity with each run
|
||||
let input = "What is a life of meaning?";
|
||||
|
||||
let mut command = agent("test-deepresearch-agent", input).await.unwrap();
|
||||
|
||||
// let mut stdout = String::new();
|
||||
|
||||
// command.stdout.take().unwrap().read_to_string(&mut stdout).await.unwrap();
|
||||
|
||||
// println!("stdout: {}", stdout);
|
||||
// // Optionally, you can capture and inspect stdout if needed:
|
||||
let _output = command.wait_with_output().await.expect("Failed to wait for output");
|
||||
// println!("Stdout: {}", String::from_utf8_lossy(&output.stdout));
|
||||
// println!("Stderr: {}", String::from_utf8_lossy(&output.stderr));
|
||||
}
|
||||
}
|
10
crates/agent-server/src/agents/image_generator.rs
Normal file
10
crates/agent-server/src/agents/image_generator.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
|
||||
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
tracing::debug!(
|
||||
"Running image generator, \ninput: {}",
|
||||
input
|
||||
);
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/image-generator.genai.mts", 10).await
|
||||
}
|
136
crates/agent-server/src/agents/mod.rs
Normal file
136
crates/agent-server/src/agents/mod.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
pub(crate) mod news;
|
||||
pub(crate) mod scrape;
|
||||
pub(crate) mod search;
|
||||
pub(crate) mod image_generator;
|
||||
pub(crate) mod deep_research;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use rmcp::{
|
||||
Error as McpError, RoleServer, ServerHandler, const_string, model::*,
|
||||
service::RequestContext, tool,
|
||||
};
|
||||
use tokio::process::Child;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Agents;
|
||||
|
||||
#[tool(tool_box)]
|
||||
impl Agents {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
#[tool(description = "Search the web for information")]
|
||||
async fn search(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "The search query")]
|
||||
query: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match search::agent("tool-search", &query).await {
|
||||
Ok(child) => handle_agent_result(child).await,
|
||||
Err(e) => Err(McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(description = "Search for news articles")]
|
||||
async fn news(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "The news search query")]
|
||||
query: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match news::agent("tool-news", &query).await {
|
||||
Ok(child) => handle_agent_result(child).await,
|
||||
Err(e) => Err(McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(description = "Scrape content from a webpage")]
|
||||
async fn scrape(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "The URL to scrape")]
|
||||
url: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match scrape::agent("tool-scrape", &url).await {
|
||||
Ok(child) => handle_agent_result(child).await,
|
||||
Err(e) => Err(McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(description = "Generate an image based on a description")]
|
||||
async fn generate_image(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "The image description")]
|
||||
description: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match image_generator::agent("tool-image", &description).await {
|
||||
Ok(child) => handle_agent_result(child).await,
|
||||
Err(e) => Err(McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(description = "Perform deep research on a topic")]
|
||||
async fn deep_research(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "The research topic")]
|
||||
topic: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match deep_research::agent("tool-research", &topic).await {
|
||||
Ok(child) => handle_agent_result(child).await,
|
||||
Err(e) => Err(McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(tool_box)]
|
||||
impl ServerHandler for Agents {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2024_11_05,
|
||||
capabilities: ServerCapabilities::builder()
|
||||
.enable_tools()
|
||||
.build(),
|
||||
server_info: Implementation::from_build_env(),
|
||||
instructions: Some("This server provides various agent tools for web search, news search, web scraping, image generation, and deep research.".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
&self,
|
||||
_request: InitializeRequestParam,
|
||||
context: RequestContext<RoleServer>,
|
||||
) -> Result<InitializeResult, McpError> {
|
||||
if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
|
||||
let initialize_headers = &http_request_part.headers;
|
||||
let initialize_uri = &http_request_part.uri;
|
||||
tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
|
||||
}
|
||||
Ok(self.get_info())
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_agent_result(mut child: Child) -> Result<CallToolResult, McpError> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
let output = match child.wait_with_output().await {
|
||||
Ok(output) => output,
|
||||
Err(e) => return Err(McpError::internal_error(format!("Failed to get agent output: {}", e), None)),
|
||||
};
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
return Err(McpError::internal_error(
|
||||
format!("Agent failed with status {}: {}", output.status, stderr),
|
||||
None,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(stdout)]))
|
||||
}
|
6
crates/agent-server/src/agents/news.rs
Normal file
6
crates/agent-server/src/agents/news.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
|
||||
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/news-search.genai.mts", 10).await
|
||||
}
|
6
crates/agent-server/src/agents/scrape.rs
Normal file
6
crates/agent-server/src/agents/scrape.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
use crate::utils::utils::run_agent;
|
||||
use tokio::process::Child;
|
||||
|
||||
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-scrape.genai.mts", 10).await
|
||||
}
|
27
crates/agent-server/src/agents/search.rs
Normal file
27
crates/agent-server/src/agents/search.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use tokio::process::Child;
|
||||
use tracing;
|
||||
use crate::utils::utils::run_agent;
|
||||
|
||||
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
|
||||
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-search.genai.mts", 10).await
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fmt::Debug;
|
||||
use crate::agents::search::agent;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_execution() {
|
||||
let input = "Who won the 2024 presidential election?";
|
||||
|
||||
let mut command = agent("test-stream", input).await.unwrap();
|
||||
|
||||
// command.stdout.take().unwrap().read_to_string(&mut String::new()).await.unwrap();
|
||||
// Optionally, you can capture and inspect stdout if needed:
|
||||
let output = command.wait_with_output().await.expect("Failed to wait for output");
|
||||
println!("Stdout: {}", String::from_utf8_lossy(&output.stdout));
|
||||
println!("Stderr: {}", String::from_utf8_lossy(&output.stderr));
|
||||
}
|
||||
}
|
27
crates/agent-server/src/config.rs
Normal file
27
crates/agent-server/src/config.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
pub struct Runtime {
|
||||
pub env_vars: Vec<String>,
|
||||
}
|
||||
|
||||
|
||||
impl Runtime {
|
||||
pub fn configure() -> Self {
|
||||
// automatic configuration between local/docker environments
|
||||
match dotenv::dotenv() {
|
||||
Ok(_) => tracing::debug!("Loaded .env file successfully"),
|
||||
Err(e) => tracing::debug!("No .env file found or error loading it: {}", e),
|
||||
}
|
||||
|
||||
Self {
|
||||
env_vars: vec![
|
||||
"OPENAI_API_KEY".to_string(),
|
||||
"GENAISCRIPT_MODEL_LARGE".to_string(),
|
||||
"GENAISCRIPT_MODEL_SMALL".to_string(),
|
||||
"SEARXNG_API_BASE_URL".to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_env_var(&self, key: &str) -> String {
|
||||
std::env::var(key).unwrap_or_default()
|
||||
}
|
||||
}
|
211
crates/agent-server/src/counter.rs
Normal file
211
crates/agent-server/src/counter.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use rmcp::{
|
||||
Error as McpError, RoleServer, ServerHandler, const_string, model::*, schemars,
|
||||
service::RequestContext, tool,
|
||||
};
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct StructRequest {
|
||||
pub a: i32,
|
||||
pub b: i32,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Counter {
|
||||
counter: Arc<Mutex<i32>>,
|
||||
}
|
||||
|
||||
#[tool(tool_box)]
|
||||
impl Counter {
|
||||
#[allow(dead_code)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
counter: Arc::new(Mutex::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
|
||||
RawResource::new(uri, name.to_string()).no_annotation()
|
||||
}
|
||||
|
||||
#[tool(description = "Increment the counter by 1")]
|
||||
async fn increment(&self) -> Result<CallToolResult, McpError> {
|
||||
let mut counter = self.counter.lock().await;
|
||||
*counter += 1;
|
||||
Ok(CallToolResult::success(vec![Content::text(
|
||||
counter.to_string(),
|
||||
)]))
|
||||
}
|
||||
|
||||
#[tool(description = "Decrement the counter by 1")]
|
||||
async fn decrement(&self) -> Result<CallToolResult, McpError> {
|
||||
let mut counter = self.counter.lock().await;
|
||||
*counter -= 1;
|
||||
Ok(CallToolResult::success(vec![Content::text(
|
||||
counter.to_string(),
|
||||
)]))
|
||||
}
|
||||
|
||||
#[tool(description = "Get the current counter value")]
|
||||
async fn get_value(&self) -> Result<CallToolResult, McpError> {
|
||||
let counter = self.counter.lock().await;
|
||||
Ok(CallToolResult::success(vec![Content::text(
|
||||
counter.to_string(),
|
||||
)]))
|
||||
}
|
||||
|
||||
#[tool(description = "Say hello to the client")]
|
||||
fn say_hello(&self) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::success(vec![Content::text("hello")]))
|
||||
}
|
||||
|
||||
#[tool(description = "Repeat what you say")]
|
||||
fn echo(
|
||||
&self,
|
||||
#[tool(param)]
|
||||
#[schemars(description = "Repeat what you say")]
|
||||
saying: String,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::success(vec![Content::text(saying)]))
|
||||
}
|
||||
|
||||
#[tool(description = "Calculate the sum of two numbers")]
|
||||
fn sum(
|
||||
&self,
|
||||
#[tool(aggr)] StructRequest { a, b }: StructRequest,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::success(vec![Content::text(
|
||||
(a + b).to_string(),
|
||||
)]))
|
||||
}
|
||||
}
|
||||
const_string!(Echo = "echo");
|
||||
#[tool(tool_box)]
|
||||
impl ServerHandler for Counter {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2024_11_05,
|
||||
capabilities: ServerCapabilities::builder()
|
||||
.enable_prompts()
|
||||
.enable_resources()
|
||||
.enable_tools()
|
||||
.build(),
|
||||
server_info: Implementation::from_build_env(),
|
||||
instructions: Some("This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_resources(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_: RequestContext<RoleServer>,
|
||||
) -> Result<ListResourcesResult, McpError> {
|
||||
Ok(ListResourcesResult {
|
||||
resources: vec![
|
||||
self._create_resource_text("str:////Users/to/some/path/", "cwd"),
|
||||
self._create_resource_text("memo://insights", "memo-name"),
|
||||
],
|
||||
next_cursor: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn read_resource(
|
||||
&self,
|
||||
ReadResourceRequestParam { uri }: ReadResourceRequestParam,
|
||||
_: RequestContext<RoleServer>,
|
||||
) -> Result<ReadResourceResult, McpError> {
|
||||
match uri.as_str() {
|
||||
"str:////Users/to/some/path/" => {
|
||||
let cwd = "/Users/to/some/path/";
|
||||
Ok(ReadResourceResult {
|
||||
contents: vec![ResourceContents::text(cwd, uri)],
|
||||
})
|
||||
}
|
||||
"memo://insights" => {
|
||||
let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
|
||||
Ok(ReadResourceResult {
|
||||
contents: vec![ResourceContents::text(memo, uri)],
|
||||
})
|
||||
}
|
||||
_ => Err(McpError::resource_not_found(
|
||||
"resource_not_found",
|
||||
Some(json!({
|
||||
"uri": uri
|
||||
})),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_prompts(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_: RequestContext<RoleServer>,
|
||||
) -> Result<ListPromptsResult, McpError> {
|
||||
Ok(ListPromptsResult {
|
||||
next_cursor: None,
|
||||
prompts: vec![Prompt::new(
|
||||
"example_prompt",
|
||||
Some("This is an example prompt that takes one required argument, message"),
|
||||
Some(vec![PromptArgument {
|
||||
name: "message".to_string(),
|
||||
description: Some("A message to put in the prompt".to_string()),
|
||||
required: Some(true),
|
||||
}]),
|
||||
)],
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_prompt(
|
||||
&self,
|
||||
GetPromptRequestParam { name, arguments }: GetPromptRequestParam,
|
||||
_: RequestContext<RoleServer>,
|
||||
) -> Result<GetPromptResult, McpError> {
|
||||
match name.as_str() {
|
||||
"example_prompt" => {
|
||||
let message = arguments
|
||||
.and_then(|json| json.get("message")?.as_str().map(|s| s.to_string()))
|
||||
.ok_or_else(|| {
|
||||
McpError::invalid_params("No message provided to example_prompt", None)
|
||||
})?;
|
||||
|
||||
let prompt =
|
||||
format!("This is an example prompt with your message here: '{message}'");
|
||||
Ok(GetPromptResult {
|
||||
description: None,
|
||||
messages: vec![PromptMessage {
|
||||
role: PromptMessageRole::User,
|
||||
content: PromptMessageContent::text(prompt),
|
||||
}],
|
||||
})
|
||||
}
|
||||
_ => Err(McpError::invalid_params("prompt not found", None)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_resource_templates(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_: RequestContext<RoleServer>,
|
||||
) -> Result<ListResourceTemplatesResult, McpError> {
|
||||
Ok(ListResourceTemplatesResult {
|
||||
next_cursor: None,
|
||||
resource_templates: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
&self,
|
||||
_request: InitializeRequestParam,
|
||||
context: RequestContext<RoleServer>,
|
||||
) -> Result<InitializeResult, McpError> {
|
||||
if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
|
||||
let initialize_headers = &http_request_part.headers;
|
||||
let initialize_uri = &http_request_part.uri;
|
||||
tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
|
||||
}
|
||||
Ok(self.get_info())
|
||||
}
|
||||
}
|
262
crates/agent-server/src/handlers/agents.rs
Normal file
262
crates/agent-server/src/handlers/agents.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
use axum::response::Response;
|
||||
use axum::{
|
||||
body::Body, extract::Path, extract::Query, http::StatusCode, response::IntoResponse, Json,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sled;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
// init sled
|
||||
lazy_static! {
|
||||
static ref DB: Arc<Mutex<sled::Db>> = Arc::new(Mutex::new(
|
||||
sled::open("./open-web-agent-rs/db/stream_store").expect("Failed to open sled database")
|
||||
));
|
||||
}
|
||||
|
||||
pub async fn use_agent(Path(agent_id): Path<String>) -> impl IntoResponse {
|
||||
let db = DB.lock().await;
|
||||
match db.get(&agent_id) {
|
||||
Ok(Some(data)) => {
|
||||
let mut info: StreamInfo = match serde_json::from_slice(&data) {
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to deserialize StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Increment the call_count in the database
|
||||
info.call_count += 1;
|
||||
let updated_info_bytes = match serde_json::to_vec(&info) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to serialize updated StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
match db.insert(&agent_id, updated_info_bytes) {
|
||||
Ok(_) => {
|
||||
if let Err(e) = db.flush_async().await {
|
||||
tracing::error!(
|
||||
"Failed to persist updated call_count to the database: {}",
|
||||
e
|
||||
);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to update call_count in the database: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let info: StreamInfo = match db.get(&agent_id) {
|
||||
Ok(Some(updated_data)) => match serde_json::from_slice(&updated_data) {
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to deserialize updated StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
tracing::error!("Stream ID not found after update: {}", agent_id);
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to fetch updated record from DB: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if (info.call_count > 1) {
|
||||
return StatusCode::OK.into_response();
|
||||
}
|
||||
|
||||
let resource = info.resource;
|
||||
let input = serde_json::to_string(&info.payload.input).unwrap_or_default();
|
||||
|
||||
tracing::debug!(
|
||||
"Executing agent - Type: {}, Id: {}",
|
||||
resource,
|
||||
agent_id
|
||||
);
|
||||
|
||||
let cmd = match resource.as_str() {
|
||||
"web-search" => crate::agents::search::agent(agent_id.as_str(), &*input).await,
|
||||
"news-search" => crate::agents::news::agent(agent_id.as_str(), &*input).await,
|
||||
"image-generator" => {
|
||||
crate::agents::image_generator::agent(agent_id.as_str(), &*input).await
|
||||
}
|
||||
"deep-research" => {
|
||||
crate::agents::deep_research::agent(agent_id.as_str(), &*input).await
|
||||
}
|
||||
"web-scrape" => crate::agents::scrape::agent(agent_id.as_str(), &*input).await,
|
||||
_ => {
|
||||
tracing::error!("Unsupported resource type: {}", resource);
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let mut cmd = match cmd {
|
||||
Ok(cmd) => cmd,
|
||||
Err(e) => {
|
||||
tracing::error!("Agent execution failed: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let stdout = match cmd.stdout.take() {
|
||||
Some(stdout) => stdout,
|
||||
None => {
|
||||
tracing::error!("No stdout available for the command.");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let reader = BufReader::new(stdout);
|
||||
let sse_stream = reader_to_stream(reader, agent_id.clone());
|
||||
|
||||
return Response::builder()
|
||||
.header("Content-Type", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache, no-transform")
|
||||
.header("Connection", "keep-alive")
|
||||
.header("X-Accel-Buffering", "yes")
|
||||
.body(Body::from_stream(sse_stream))
|
||||
.unwrap();
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::error!("Stream ID not found: {}", agent_id);
|
||||
StatusCode::NOT_FOUND.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to fetch from DB: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reader_to_stream<R>(
|
||||
reader: BufReader<R>,
|
||||
stream_id: String,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let stream = futures::stream::unfold(reader, move |mut reader| async move {
|
||||
let mut line = String::new();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => None,
|
||||
Ok(_) => Some((
|
||||
Ok(Bytes::from(format!("data: {}\n\n", line.trim()))),
|
||||
reader,
|
||||
)),
|
||||
Err(e) => Some((Err(e), reader)),
|
||||
}
|
||||
});
|
||||
|
||||
let stream_with_done = stream.chain(futures::stream::once(async {
|
||||
Ok(Bytes::from("data: [DONE]\n\n"))
|
||||
}));
|
||||
|
||||
Box::pin(stream_with_done)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
struct Payload {
|
||||
input: Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct StreamInfo {
|
||||
resource: String,
|
||||
payload: Payload,
|
||||
parent: String,
|
||||
call_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub struct WebhookPostRequest {
|
||||
id: String,
|
||||
resource: String,
|
||||
payload: Payload,
|
||||
parent: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
struct WebhookPostResponse {
|
||||
stream_url: String,
|
||||
}
|
||||
|
||||
pub async fn create_agent(Json(payload): Json<WebhookPostRequest>) -> impl IntoResponse {
|
||||
let db = DB.lock().await;
|
||||
|
||||
tracing::info!("Received webhook post request with ID: {}", payload.id);
|
||||
|
||||
let stream_id = payload.id.clone();
|
||||
let info = StreamInfo {
|
||||
resource: payload.resource.clone(),
|
||||
payload: payload.payload,
|
||||
parent: payload.parent.clone(),
|
||||
call_count: 0,
|
||||
};
|
||||
|
||||
let info_bytes = match serde_json::to_vec(&info) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to serialize StreamInfo: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Use atomic compare_and_swap operation
|
||||
match db.compare_and_swap(
|
||||
&stream_id,
|
||||
None as Option<&[u8]>,
|
||||
Some(info_bytes.as_slice()),
|
||||
) {
|
||||
Ok(_) => {
|
||||
// Force an immediate sync to disk
|
||||
match db.flush_async().await {
|
||||
Ok(_) => {
|
||||
// Verify the write by attempting to read it back
|
||||
match db.get(&stream_id) {
|
||||
Ok(Some(_)) => {
|
||||
let stream_url = format!("/agents/{}", stream_id);
|
||||
tracing::info!(
|
||||
"Successfully created and verified stream URL: {}",
|
||||
stream_url
|
||||
);
|
||||
Json(WebhookPostResponse { stream_url }).into_response()
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::error!("Failed to verify stream creation: {}", stream_id);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error verifying stream creation: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to flush DB: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to insert stream info: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
3
crates/agent-server/src/handlers/mod.rs
Normal file
3
crates/agent-server/src/handlers/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod not_found;
|
||||
pub mod ui;
|
||||
pub mod agents;
|
185
crates/agent-server/src/handlers/model_context.rs
Normal file
185
crates/agent-server/src/handlers/model_context.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
use axum::response::Response;
|
||||
use axum::{
|
||||
body::Body, extract::Json, http::StatusCode, response::IntoResponse,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::pin::Pin;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::utils::utils::run_agent;
|
||||
|
||||
// Custom function to format streaming responses according to OpenAI API format
|
||||
pub fn openai_stream_format<R>(
|
||||
reader: BufReader<R>,
|
||||
request_id: String,
|
||||
model: String,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let stream = futures::stream::unfold((reader, 0), move |(mut reader, index)| {
|
||||
let request_id = request_id.clone();
|
||||
let model = model.clone();
|
||||
async move {
|
||||
let mut line = String::new();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => None,
|
||||
Ok(_) => {
|
||||
let content = line.trim();
|
||||
// Skip empty lines
|
||||
if content.is_empty() {
|
||||
return Some((Ok(Bytes::from("")), (reader, index)));
|
||||
}
|
||||
|
||||
// Format as OpenAI API streaming response
|
||||
let chunk = serde_json::json!({
|
||||
"id": format!("chatcmpl-{}", request_id),
|
||||
"object": "chat.completion.chunk",
|
||||
"created": std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": index,
|
||||
"delta": {
|
||||
"content": content
|
||||
},
|
||||
"finish_reason": null
|
||||
}]
|
||||
});
|
||||
|
||||
Some((
|
||||
Ok(Bytes::from(format!("data: {}\n\n", chunk.to_string()))),
|
||||
(reader, index),
|
||||
))
|
||||
}
|
||||
Err(e) => Some((Err(e), (reader, index))),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Add the [DONE] message at the end
|
||||
let stream_with_done = stream.filter(|result| {
|
||||
futures::future::ready(match result {
|
||||
Ok(bytes) => !bytes.is_empty(),
|
||||
Err(_) => true,
|
||||
})
|
||||
}).chain(futures::stream::once(async {
|
||||
Ok(Bytes::from("data: [DONE]\n\n"))
|
||||
}));
|
||||
|
||||
Box::pin(stream_with_done)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ModelContextRequest {
|
||||
messages: Vec<Message>,
|
||||
model: Option<String>,
|
||||
stream: Option<bool>,
|
||||
temperature: Option<f32>,
|
||||
max_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ModelContextResponse {
|
||||
id: String,
|
||||
object: String,
|
||||
created: u64,
|
||||
model: String,
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct Choice {
|
||||
index: u32,
|
||||
message: Message,
|
||||
finish_reason: String,
|
||||
}
|
||||
|
||||
pub async fn model_context(
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<ModelContextRequest>
|
||||
) -> impl IntoResponse {
|
||||
// Generate a unique ID for this request
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// Convert messages to a format that can be passed to the agent
|
||||
let input = serde_json::to_string(&payload.messages).unwrap_or_default();
|
||||
|
||||
// Use the web-search agent for now, but this could be customized based on the model parameter
|
||||
let agent_file = "./packages/genaiscript/genaisrc/web-search.genai.mts";
|
||||
|
||||
tracing::debug!(
|
||||
"Executing model context request - Id: {}",
|
||||
request_id
|
||||
);
|
||||
|
||||
// Default timeout of 60 seconds
|
||||
let mut cmd = match run_agent(&request_id, &input, agent_file, 60).await {
|
||||
Ok(cmd) => cmd,
|
||||
Err(e) => {
|
||||
tracing::error!("Model context execution failed: {}", e);
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Check if streaming is requested either via the stream parameter or Accept header
|
||||
let accept_header = headers.get("accept").and_then(|h| h.to_str().ok()).unwrap_or("");
|
||||
let is_streaming = payload.stream.unwrap_or(false) || accept_header.contains("text/event-stream");
|
||||
|
||||
// If streaming is requested, return a streaming response
|
||||
if is_streaming {
|
||||
let stdout = match cmd.stdout.take() {
|
||||
Some(stdout) => stdout,
|
||||
None => {
|
||||
tracing::error!("No stdout available for the command.");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let reader = BufReader::new(stdout);
|
||||
let model = payload.model.clone().unwrap_or_else(|| "default-model".to_string());
|
||||
let sse_stream = openai_stream_format(reader, request_id.clone(), model);
|
||||
|
||||
return Response::builder()
|
||||
.header("Content-Type", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache, no-transform")
|
||||
.header("Connection", "keep-alive")
|
||||
.header("X-Accel-Buffering", "yes")
|
||||
.body(Body::from_stream(sse_stream))
|
||||
.unwrap();
|
||||
} else {
|
||||
// For non-streaming responses, we need to collect all output and return it as a single response
|
||||
// This is a simplified implementation and might need to be adjusted based on actual requirements
|
||||
let response = ModelContextResponse {
|
||||
id: format!("chatcmpl-{}", request_id),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
model: payload.model.unwrap_or_else(|| "default-model".to_string()),
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "This is a placeholder response. The actual implementation would process the agent's output.".to_string(),
|
||||
},
|
||||
finish_reason: "stop".to_string(),
|
||||
}],
|
||||
};
|
||||
|
||||
return Json(response).into_response();
|
||||
}
|
||||
}
|
48
crates/agent-server/src/handlers/models.rs
Normal file
48
crates/agent-server/src/handlers/models.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use axum::{
|
||||
extract::Json,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ModelsResponse {
|
||||
object: String,
|
||||
data: Vec<Model>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct Model {
|
||||
id: String,
|
||||
object: String,
|
||||
created: u64,
|
||||
owned_by: String,
|
||||
}
|
||||
|
||||
pub async fn list_models() -> impl IntoResponse {
|
||||
let current_time = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
// Create a response with a default model
|
||||
let response = ModelsResponse {
|
||||
object: "list".to_string(),
|
||||
data: vec![
|
||||
Model {
|
||||
id: "gpt-3.5-turbo".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: current_time,
|
||||
owned_by: "open-web-agent-rs".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gpt-4".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: current_time,
|
||||
owned_by: "open-web-agent-rs".to_string(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
Json(response)
|
||||
}
|
16
crates/agent-server/src/handlers/not_found.rs
Normal file
16
crates/agent-server/src/handlers/not_found.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
Json,
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
pub async fn handle_not_found() -> impl IntoResponse {
|
||||
tracing::warn!("404 Not Found error occurred");
|
||||
|
||||
let error_response = serde_json::json!({
|
||||
"error": "Route Not Found",
|
||||
"status": 404
|
||||
});
|
||||
|
||||
(StatusCode::NOT_FOUND, Json(error_response))
|
||||
}
|
0
crates/agent-server/src/handlers/ui.rs
Normal file
0
crates/agent-server/src/handlers/ui.rs
Normal file
37
crates/agent-server/src/main.rs
Normal file
37
crates/agent-server/src/main.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use crate::config::{Runtime};
|
||||
use crate::routes::create_router;
|
||||
use crate::setup::init_logging;
|
||||
|
||||
mod config;
|
||||
mod routes;
|
||||
mod setup;
|
||||
mod handlers;
|
||||
mod agents;
|
||||
mod utils;
|
||||
mod counter;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
init_logging();
|
||||
|
||||
Runtime::configure();
|
||||
|
||||
let router = create_router();
|
||||
|
||||
let addr = "0.0.0.0:3006";
|
||||
tracing::info!("Attempting to bind server to {}", addr);
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => {
|
||||
tracing::info!("Successfully bound to {}", l.local_addr().unwrap());
|
||||
l
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to bind to {}: {}", addr, e);
|
||||
panic!("Server failed to start");
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!("Server starting on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, router.into_make_service()).await.unwrap();
|
||||
}
|
140
crates/agent-server/src/routes.rs
Normal file
140
crates/agent-server/src/routes.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
use axum::response::Response;
|
||||
use crate::handlers::{not_found::handle_not_found};
|
||||
use axum::routing::{get, Router};
|
||||
use http::StatusCode;
|
||||
use tower_http::trace::{self, TraceLayer};
|
||||
use tracing::Level;
|
||||
|
||||
use rmcp::transport::streamable_http_server::{
|
||||
StreamableHttpService, session::local::LocalSessionManager,
|
||||
};
|
||||
use rust_embed::Embed;
|
||||
use crate::agents::Agents;
|
||||
|
||||
|
||||
#[derive(Embed)]
|
||||
#[folder = "../../node_modules/@modelcontextprotocol/inspector-client/dist"]
|
||||
struct Asset;
|
||||
|
||||
pub struct StaticFile<T>(pub T);
|
||||
|
||||
impl<T> axum::response::IntoResponse for StaticFile<T>
|
||||
where
|
||||
T: Into<String>,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
let path = self.0.into();
|
||||
|
||||
match Asset::get(path.as_str()) {
|
||||
Some(content) => {
|
||||
let mime = mime_guess::from_path(path).first_or_octet_stream();
|
||||
([(http::header::CONTENT_TYPE, mime.as_ref())], content.data).into_response()
|
||||
}
|
||||
None => (StatusCode::NOT_FOUND, "404 Not Found").into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn ui_index_handler() -> impl axum::response::IntoResponse {
|
||||
StaticFile("index.html")
|
||||
}
|
||||
|
||||
async fn static_handler(uri: http::Uri) -> impl axum::response::IntoResponse {
|
||||
let path = uri.path().trim_start_matches("/").to_string();
|
||||
StaticFile(path)
|
||||
}
|
||||
|
||||
pub fn create_router() -> Router {
|
||||
|
||||
let mcp_service = StreamableHttpService::new(
|
||||
Agents::new,
|
||||
LocalSessionManager::default().into(),
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
Router::new()
|
||||
.nest_service("/mcp", mcp_service)
|
||||
.route("/health", get(health))
|
||||
.route("/", get(ui_index_handler))
|
||||
.route("/index.html", get(ui_index_handler))
|
||||
.route("/{*path}", get(static_handler))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
|
||||
.on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
|
||||
)
|
||||
.layer(tower_http::cors::CorsLayer::very_permissive())
|
||||
.fallback(handle_not_found)
|
||||
}
|
||||
|
||||
async fn health() -> String {
|
||||
return "ok".to_string();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::body::{Body, Bytes};
|
||||
use axum::http::{Request, StatusCode};
|
||||
use axum::response::Response;
|
||||
use tower::ServiceExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_endpoint() {
|
||||
// Call the health function directly
|
||||
let response = health().await;
|
||||
assert_eq!(response, "ok".to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_route() {
|
||||
// Create the router
|
||||
let app = create_router();
|
||||
|
||||
// Create a request to the health endpoint
|
||||
let request = Request::builder()
|
||||
.uri("/health")
|
||||
.method("GET")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
// Process the request
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
|
||||
// Check the response status
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
// Check the response body
|
||||
let body = response_body_bytes(response).await;
|
||||
assert_eq!(&body[..], b"ok");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_not_found_route() {
|
||||
// Create the router
|
||||
let app = create_router();
|
||||
|
||||
// Create a request to a non-existent endpoint
|
||||
let request = Request::builder()
|
||||
.uri("/non-existent")
|
||||
.method("GET")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
// Process the request
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
|
||||
// Check the response status
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
// Helper function to extract bytes from a response body
|
||||
async fn response_body_bytes(response: Response) -> Bytes {
|
||||
let body = response.into_body();
|
||||
// Use a reasonable size limit for the body (16MB)
|
||||
let bytes = axum::body::to_bytes(body, 16 * 1024 * 1024)
|
||||
.await
|
||||
.expect("Failed to read response body");
|
||||
bytes
|
||||
}
|
||||
}
|
10
crates/agent-server/src/setup.rs
Normal file
10
crates/agent-server/src/setup.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
// src/setup.rs
|
||||
pub fn init_logging() {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::DEBUG)
|
||||
.with_target(true)
|
||||
.with_thread_ids(true)
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.init();
|
||||
}
|
65
crates/agent-server/src/utils/base64.rs
Normal file
65
crates/agent-server/src/utils/base64.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use base64::Engine;
|
||||
use base64::engine::GeneralPurpose;
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::engine::general_purpose::STANDARD_NO_PAD;
|
||||
|
||||
pub struct Base64Encoder {
|
||||
payload_engine: GeneralPurpose,
|
||||
signature_engine: GeneralPurpose,
|
||||
public_key_engine: GeneralPurpose,
|
||||
secret_key_engine: GeneralPurpose,
|
||||
}
|
||||
|
||||
impl Base64Encoder {
|
||||
pub(crate) fn b64_encode(&self, p0: &[u8]) -> String {
|
||||
self.payload_engine.encode(p0)
|
||||
}
|
||||
pub(crate) fn b64_decode(&self, p0: String) -> Result<Vec<u8>, base64::DecodeError> {
|
||||
self.payload_engine.decode(p0)
|
||||
}
|
||||
}
|
||||
|
||||
pub const B64_ENCODER: &Base64Encoder = &Base64Encoder::new();
|
||||
|
||||
impl Base64Encoder {
|
||||
pub const fn new() -> Self { // Made new() a const fn
|
||||
Base64Encoder {
|
||||
payload_engine: STANDARD,
|
||||
signature_engine: STANDARD,
|
||||
public_key_engine: STANDARD,
|
||||
secret_key_engine: STANDARD,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b64_encode_payload<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.payload_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_payload<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.payload_engine.decode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_signature<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.signature_engine.decode(input)
|
||||
}
|
||||
|
||||
pub fn b64_encode_signature<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.signature_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_encode_public_key<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.public_key_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_public_key<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.public_key_engine.decode(input)
|
||||
}
|
||||
|
||||
pub fn b64_encode_secret_key<T: AsRef<[u8]>>(&self, input: T) -> String { // Added trait bound
|
||||
self.secret_key_engine.encode(input)
|
||||
}
|
||||
|
||||
pub fn b64_decode_secret_key<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> { // Added trait bound
|
||||
self.secret_key_engine.decode(input)
|
||||
}
|
||||
}
|
2
crates/agent-server/src/utils/mod.rs
Normal file
2
crates/agent-server/src/utils/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod utils;
|
||||
pub mod base64;
|
71
crates/agent-server/src/utils/utils.rs
Normal file
71
crates/agent-server/src/utils/utils.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
// utils.rs
|
||||
use tokio::process::{Child, Command}; // Use tokio::process::Child and Command
|
||||
use std::env;
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tracing;
|
||||
|
||||
|
||||
pub struct ShimBinding {
|
||||
user_input: String,
|
||||
file_path: String,
|
||||
openai_api_key: String,
|
||||
openai_api_base: String,
|
||||
genaiscript_model_large: String,
|
||||
genaiscript_model_small: String,
|
||||
searxng_api_base_url: String,
|
||||
searxng_password: String,
|
||||
}
|
||||
|
||||
impl ShimBinding {
|
||||
pub fn new(user_input: String, file_path: String) -> Self {
|
||||
Self {
|
||||
user_input,
|
||||
file_path, // Initialize the new field
|
||||
openai_api_key: env::var("OPENAI_API_KEY").unwrap_or_default(),
|
||||
openai_api_base: env::var("OPENAI_API_BASE").unwrap_or_default(),
|
||||
genaiscript_model_large: env::var("GENAISCRIPT_MODEL_LARGE").unwrap_or_default(),
|
||||
genaiscript_model_small: env::var("GENAISCRIPT_MODEL_SMALL").unwrap_or_default(),
|
||||
searxng_api_base_url: env::var("SEARXNG_API_BASE_URL").unwrap_or_default(),
|
||||
searxng_password: env::var("SEARXNG_PASSWORD").unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute(&self) -> std::io::Result<Child> {
|
||||
let mut command = Command::new("./dist/genaiscript-rust-shim.js");
|
||||
command
|
||||
.arg("--file")
|
||||
.arg(&self.file_path) // Use the file_path field instead of hardcoded value
|
||||
.arg(format!("USER_INPUT={}", self.user_input))
|
||||
.env("OPENAI_API_KEY", &self.openai_api_key)
|
||||
.env("OPENAI_API_BASE", &self.openai_api_base)
|
||||
.env("GENAISCRIPT_MODEL_LARGE", &self.genaiscript_model_large)
|
||||
.env("GENAISCRIPT_MODEL_SMALL", &self.genaiscript_model_small)
|
||||
.env("SEARXNG_API_BASE_URL", &self.searxng_api_base_url)
|
||||
.env("SEARXNG_PASSWORD", &self.searxng_password)
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit());
|
||||
|
||||
command.spawn()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// wrapper executes an agent with a timeout
|
||||
pub async fn run_agent(stream_id: &str, input: &str, file_path: &str, timeout_seconds: u64 ) -> Result<Child, String> {
|
||||
tracing::debug!("Initiating agent for stream {} with file path {}", stream_id, file_path);
|
||||
|
||||
let shim_binding = ShimBinding::new(input.to_string(), file_path.to_string());
|
||||
let spawn_future = async move {
|
||||
match shim_binding.execute() {
|
||||
Ok(child) => Ok(child),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to spawn shim process: {}", e);
|
||||
Err(e.to_string())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
timeout(Duration::from_secs(timeout_seconds), spawn_future)
|
||||
.await
|
||||
.unwrap_or_else(|_| Err("Command timed out after 10 seconds".to_string()))
|
||||
}
|
Reference in New Issue
Block a user