update project structure

This commit is contained in:
geoffsee
2025-06-05 22:42:17 -04:00
parent 1270a6b0ba
commit c5b8bd812c
45 changed files with 4921 additions and 128 deletions

View File

@@ -0,0 +1,36 @@
[package]
name = "agent-server"
version = "0.1.0"
edition = "2021"
license = "MIT"
[[bin]]
edition = "2021"
name = "agent-server"
path = "src/main.rs"
[dependencies]
axum = { version = "0.8", features = ["multipart"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
http = "1.1.0"
tokio-stream = "0.1.16"
uuid = { version = "1.11.0", features = ["v4"] }
tokio-util = { version = "0.7", features = ["io"] }
serde_json = "1.0.133"
futures = "0.3.31"
dotenv = "0.15.0"
shell-escape = "0.1.5"
rust-embed = "8.5.0"
bytes = "1.8.0"
lazy_static = "1.5.0"
sled = "0.34.7"
tower-http = { version = "0.6.2", features = ["trace", "cors"] }
tower = "0.5.2"
anyhow = "1.0.97"
base64 = "0.22.1"
fips204 = "0.4.6"
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = ["server", "transport-streamable-http-server", "transport-sse-server", "transport-io",] }
mime_guess = "2.0.5"

View File

@@ -0,0 +1,32 @@
use crate::utils::utils::run_agent;
use tokio::process::Child;
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/deep-research.genai.mts", 60).await
}
#[cfg(test)]
mod tests {
use crate::agents::deep_research::agent;
#[tokio::test]
async fn test_deepresearch() {
// an open-ended prompt that tends to produce different research output on every run
let input = "What is a life of meaning?";
let command = agent("test-deepresearch-agent", input).await.unwrap();
// wait_with_output consumes the child and collects stdout/stderr;
// inspect _output.stdout / _output.stderr here if assertions are needed
let _output = command.wait_with_output().await.expect("Failed to wait for output");
}
}

View File

@@ -0,0 +1,10 @@
use crate::utils::utils::run_agent;
use tokio::process::Child;
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
tracing::debug!("Running image generator, input: {}", input);
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/image-generator.genai.mts", 10).await
}

View File

@@ -0,0 +1,136 @@
pub(crate) mod news;
pub(crate) mod scrape;
pub(crate) mod search;
pub(crate) mod image_generator;
pub(crate) mod deep_research;
use rmcp::{
Error as McpError, RoleServer, ServerHandler, model::*,
service::RequestContext, tool,
};
use tokio::process::Child;
#[derive(Clone)]
pub struct Agents;
#[tool(tool_box)]
impl Agents {
pub fn new() -> Self {
Self {}
}
#[tool(description = "Search the web for information")]
async fn search(
&self,
#[tool(param)]
#[schemars(description = "The search query")]
query: String,
) -> Result<CallToolResult, McpError> {
match search::agent("tool-search", &query).await {
Ok(child) => handle_agent_result(child).await,
Err(e) => Err(McpError::internal_error(e.to_string(), None))
}
}
#[tool(description = "Search for news articles")]
async fn news(
&self,
#[tool(param)]
#[schemars(description = "The news search query")]
query: String,
) -> Result<CallToolResult, McpError> {
match news::agent("tool-news", &query).await {
Ok(child) => handle_agent_result(child).await,
Err(e) => Err(McpError::internal_error(e.to_string(), None))
}
}
#[tool(description = "Scrape content from a webpage")]
async fn scrape(
&self,
#[tool(param)]
#[schemars(description = "The URL to scrape")]
url: String,
) -> Result<CallToolResult, McpError> {
match scrape::agent("tool-scrape", &url).await {
Ok(child) => handle_agent_result(child).await,
Err(e) => Err(McpError::internal_error(e.to_string(), None))
}
}
#[tool(description = "Generate an image based on a description")]
async fn generate_image(
&self,
#[tool(param)]
#[schemars(description = "The image description")]
description: String,
) -> Result<CallToolResult, McpError> {
match image_generator::agent("tool-image", &description).await {
Ok(child) => handle_agent_result(child).await,
Err(e) => Err(McpError::internal_error(e.to_string(), None))
}
}
#[tool(description = "Perform deep research on a topic")]
async fn deep_research(
&self,
#[tool(param)]
#[schemars(description = "The research topic")]
topic: String,
) -> Result<CallToolResult, McpError> {
match deep_research::agent("tool-research", &topic).await {
Ok(child) => handle_agent_result(child).await,
Err(e) => Err(McpError::internal_error(e.to_string(), None))
}
}
}
#[tool(tool_box)]
impl ServerHandler for Agents {
fn get_info(&self) -> ServerInfo {
ServerInfo {
protocol_version: ProtocolVersion::V_2024_11_05,
capabilities: ServerCapabilities::builder()
.enable_tools()
.build(),
server_info: Implementation::from_build_env(),
instructions: Some("This server provides various agent tools for web search, news search, web scraping, image generation, and deep research.".to_string()),
}
}
async fn initialize(
&self,
_request: InitializeRequestParam,
context: RequestContext<RoleServer>,
) -> Result<InitializeResult, McpError> {
if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
let initialize_headers = &http_request_part.headers;
let initialize_uri = &http_request_part.uri;
tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
}
Ok(self.get_info())
}
}
async fn handle_agent_result(child: Child) -> Result<CallToolResult, McpError> {
let output = match child.wait_with_output().await {
Ok(output) => output,
Err(e) => return Err(McpError::internal_error(format!("Failed to get agent output: {}", e), None)),
};
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
return Err(McpError::internal_error(
format!("Agent failed with status {}: {}", output.status, stderr),
None,
));
}
Ok(CallToolResult::success(vec![Content::text(stdout)]))
}
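Every tool above follows the same shape: take a single string parameter, delegate to the matching agent module, and collapse the child's output through handle_agent_result. A sketch of what one more tool would look like under that pattern; the summarize module, its description, and its GenAIScript file are hypothetical and do not exist in this commit:

// Hypothetical additional tool following the existing pattern; it would live
// inside the #[tool(tool_box)] impl Agents block. `summarize::agent` is assumed.
#[tool(description = "Summarize a block of text")]
async fn summarize(
    &self,
    #[tool(param)]
    #[schemars(description = "The text to summarize")]
    text: String,
) -> Result<CallToolResult, McpError> {
    match summarize::agent("tool-summarize", &text).await {
        Ok(child) => handle_agent_result(child).await,
        Err(e) => Err(McpError::internal_error(e.to_string(), None)),
    }
}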

View File

@@ -0,0 +1,6 @@
use crate::utils::utils::run_agent;
use tokio::process::Child;
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/news-search.genai.mts", 10).await
}

View File

@@ -0,0 +1,6 @@
use crate::utils::utils::run_agent;
use tokio::process::Child;
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-scrape.genai.mts", 10).await
}

View File

@@ -0,0 +1,27 @@
use tokio::process::Child;
use crate::utils::utils::run_agent;
pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-search.genai.mts", 10).await
}
#[cfg(test)]
mod tests {
use crate::agents::search::agent;
#[tokio::test]
async fn test_search_execution() {
let input = "Who won the 2024 presidential election?";
let command = agent("test-stream", input).await.unwrap();
// wait_with_output consumes the child and collects stdout/stderr for inspection
let output = command.wait_with_output().await.expect("Failed to wait for output");
println!("Stdout: {}", String::from_utf8_lossy(&output.stdout));
println!("Stderr: {}", String::from_utf8_lossy(&output.stderr));
}
}

View File

@@ -0,0 +1,27 @@
pub struct Runtime {
pub env_vars: Vec<String>,
}
impl Runtime {
pub fn configure() -> Self {
// works for both local and Docker environments: a missing .env file is logged, not fatal
match dotenv::dotenv() {
Ok(_) => tracing::debug!("Loaded .env file successfully"),
Err(e) => tracing::debug!("No .env file found or error loading it: {}", e),
}
Self {
env_vars: vec![
"OPENAI_API_KEY".to_string(),
"GENAISCRIPT_MODEL_LARGE".to_string(),
"GENAISCRIPT_MODEL_SMALL".to_string(),
"SEARXNG_API_BASE_URL".to_string(),
],
}
}
pub fn get_env_var(&self, key: &str) -> String {
std::env::var(key).unwrap_or_default()
}
}
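A minimal usage sketch, assuming the listed variables are provided either in the process environment or a .env file:

// Sketch: load configuration at startup and log the declared variables.
let runtime = Runtime::configure();
for key in &runtime.env_vars {
    // get_env_var falls back to an empty string when a variable is unset
    tracing::debug!("{} = {}", key, runtime.get_env_var(key));
}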

View File

@@ -0,0 +1,211 @@
use std::sync::Arc;
use rmcp::{
Error as McpError, RoleServer, ServerHandler, const_string, model::*, schemars,
service::RequestContext, tool,
};
use serde_json::json;
use tokio::sync::Mutex;
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct StructRequest {
pub a: i32,
pub b: i32,
}
#[derive(Clone)]
pub struct Counter {
counter: Arc<Mutex<i32>>,
}
#[tool(tool_box)]
impl Counter {
#[allow(dead_code)]
pub fn new() -> Self {
Self {
counter: Arc::new(Mutex::new(0)),
}
}
fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
RawResource::new(uri, name.to_string()).no_annotation()
}
#[tool(description = "Increment the counter by 1")]
async fn increment(&self) -> Result<CallToolResult, McpError> {
let mut counter = self.counter.lock().await;
*counter += 1;
Ok(CallToolResult::success(vec![Content::text(
counter.to_string(),
)]))
}
#[tool(description = "Decrement the counter by 1")]
async fn decrement(&self) -> Result<CallToolResult, McpError> {
let mut counter = self.counter.lock().await;
*counter -= 1;
Ok(CallToolResult::success(vec![Content::text(
counter.to_string(),
)]))
}
#[tool(description = "Get the current counter value")]
async fn get_value(&self) -> Result<CallToolResult, McpError> {
let counter = self.counter.lock().await;
Ok(CallToolResult::success(vec![Content::text(
counter.to_string(),
)]))
}
#[tool(description = "Say hello to the client")]
fn say_hello(&self) -> Result<CallToolResult, McpError> {
Ok(CallToolResult::success(vec![Content::text("hello")]))
}
#[tool(description = "Repeat what you say")]
fn echo(
&self,
#[tool(param)]
#[schemars(description = "Repeat what you say")]
saying: String,
) -> Result<CallToolResult, McpError> {
Ok(CallToolResult::success(vec![Content::text(saying)]))
}
#[tool(description = "Calculate the sum of two numbers")]
fn sum(
&self,
#[tool(aggr)] StructRequest { a, b }: StructRequest,
) -> Result<CallToolResult, McpError> {
Ok(CallToolResult::success(vec![Content::text(
(a + b).to_string(),
)]))
}
}
const_string!(Echo = "echo");
#[tool(tool_box)]
impl ServerHandler for Counter {
fn get_info(&self) -> ServerInfo {
ServerInfo {
protocol_version: ProtocolVersion::V_2024_11_05,
capabilities: ServerCapabilities::builder()
.enable_prompts()
.enable_resources()
.enable_tools()
.build(),
server_info: Implementation::from_build_env(),
instructions: Some("This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()),
}
}
async fn list_resources(
&self,
_request: Option<PaginatedRequestParam>,
_: RequestContext<RoleServer>,
) -> Result<ListResourcesResult, McpError> {
Ok(ListResourcesResult {
resources: vec![
self._create_resource_text("str:////Users/to/some/path/", "cwd"),
self._create_resource_text("memo://insights", "memo-name"),
],
next_cursor: None,
})
}
async fn read_resource(
&self,
ReadResourceRequestParam { uri }: ReadResourceRequestParam,
_: RequestContext<RoleServer>,
) -> Result<ReadResourceResult, McpError> {
match uri.as_str() {
"str:////Users/to/some/path/" => {
let cwd = "/Users/to/some/path/";
Ok(ReadResourceResult {
contents: vec![ResourceContents::text(cwd, uri)],
})
}
"memo://insights" => {
let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
Ok(ReadResourceResult {
contents: vec![ResourceContents::text(memo, uri)],
})
}
_ => Err(McpError::resource_not_found(
"resource_not_found",
Some(json!({
"uri": uri
})),
)),
}
}
async fn list_prompts(
&self,
_request: Option<PaginatedRequestParam>,
_: RequestContext<RoleServer>,
) -> Result<ListPromptsResult, McpError> {
Ok(ListPromptsResult {
next_cursor: None,
prompts: vec![Prompt::new(
"example_prompt",
Some("This is an example prompt that takes one required argument, message"),
Some(vec![PromptArgument {
name: "message".to_string(),
description: Some("A message to put in the prompt".to_string()),
required: Some(true),
}]),
)],
})
}
async fn get_prompt(
&self,
GetPromptRequestParam { name, arguments }: GetPromptRequestParam,
_: RequestContext<RoleServer>,
) -> Result<GetPromptResult, McpError> {
match name.as_str() {
"example_prompt" => {
let message = arguments
.and_then(|json| json.get("message")?.as_str().map(|s| s.to_string()))
.ok_or_else(|| {
McpError::invalid_params("No message provided to example_prompt", None)
})?;
let prompt =
format!("This is an example prompt with your message here: '{message}'");
Ok(GetPromptResult {
description: None,
messages: vec![PromptMessage {
role: PromptMessageRole::User,
content: PromptMessageContent::text(prompt),
}],
})
}
_ => Err(McpError::invalid_params("prompt not found", None)),
}
}
async fn list_resource_templates(
&self,
_request: Option<PaginatedRequestParam>,
_: RequestContext<RoleServer>,
) -> Result<ListResourceTemplatesResult, McpError> {
Ok(ListResourceTemplatesResult {
next_cursor: None,
resource_templates: Vec::new(),
})
}
async fn initialize(
&self,
_request: InitializeRequestParam,
context: RequestContext<RoleServer>,
) -> Result<InitializeResult, McpError> {
if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
let initialize_headers = &http_request_part.headers;
let initialize_uri = &http_request_part.uri;
tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
}
Ok(self.get_info())
}
}

View File

@@ -0,0 +1,262 @@
use axum::response::Response;
use axum::{
body::Body, extract::Path, http::StatusCode, response::IntoResponse, Json,
};
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::sync::Mutex;
// init sled
lazy_static! {
static ref DB: Arc<Mutex<sled::Db>> = Arc::new(Mutex::new(
sled::open("./open-web-agent-rs/db/stream_store").expect("Failed to open sled database")
));
}
pub async fn use_agent(Path(agent_id): Path<String>) -> impl IntoResponse {
let db = DB.lock().await;
match db.get(&agent_id) {
Ok(Some(data)) => {
let mut info: StreamInfo = match serde_json::from_slice(&data) {
Ok(info) => info,
Err(e) => {
tracing::error!("Failed to deserialize StreamInfo: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
// Increment the call_count in the database
info.call_count += 1;
let updated_info_bytes = match serde_json::to_vec(&info) {
Ok(data) => data,
Err(e) => {
tracing::error!("Failed to serialize updated StreamInfo: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
match db.insert(&agent_id, updated_info_bytes) {
Ok(_) => {
if let Err(e) = db.flush_async().await {
tracing::error!(
"Failed to persist updated call_count to the database: {}",
e
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
}
Err(e) => {
tracing::error!("Failed to update call_count in the database: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
let info: StreamInfo = match db.get(&agent_id) {
Ok(Some(updated_data)) => match serde_json::from_slice(&updated_data) {
Ok(info) => info,
Err(e) => {
tracing::error!("Failed to deserialize updated StreamInfo: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
},
Ok(None) => {
tracing::error!("Stream ID not found after update: {}", agent_id);
return StatusCode::NOT_FOUND.into_response();
}
Err(e) => {
tracing::error!("Failed to fetch updated record from DB: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
// each registered stream is one-shot: repeat calls are acknowledged without re-running the agent
if info.call_count > 1 {
return StatusCode::OK.into_response();
}
let resource = info.resource;
let input = serde_json::to_string(&info.payload.input).unwrap_or_default();
tracing::debug!(
"Executing agent - Type: {}, Id: {}",
resource,
agent_id
);
let cmd = match resource.as_str() {
"web-search" => crate::agents::search::agent(agent_id.as_str(), &*input).await,
"news-search" => crate::agents::news::agent(agent_id.as_str(), &*input).await,
"image-generator" => {
crate::agents::image_generator::agent(agent_id.as_str(), &*input).await
}
"deep-research" => {
crate::agents::deep_research::agent(agent_id.as_str(), &*input).await
}
"web-scrape" => crate::agents::scrape::agent(agent_id.as_str(), &*input).await,
_ => {
tracing::error!("Unsupported resource type: {}", resource);
return StatusCode::BAD_REQUEST.into_response();
}
};
let mut cmd = match cmd {
Ok(cmd) => cmd,
Err(e) => {
tracing::error!("Agent execution failed: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
let stdout = match cmd.stdout.take() {
Some(stdout) => stdout,
None => {
tracing::error!("No stdout available for the command.");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
let reader = BufReader::new(stdout);
let sse_stream = reader_to_stream(reader, agent_id.clone());
return Response::builder()
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache, no-transform")
.header("Connection", "keep-alive")
.header("X-Accel-Buffering", "yes")
.body(Body::from_stream(sse_stream))
.unwrap();
}
Ok(None) => {
tracing::error!("Stream ID not found: {}", agent_id);
StatusCode::NOT_FOUND.into_response()
}
Err(e) => {
tracing::error!("Failed to fetch from DB: {}", e);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
fn reader_to_stream<R>(
reader: BufReader<R>,
stream_id: String,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
where
R: tokio::io::AsyncRead + Unpin + Send + 'static,
{
let stream = futures::stream::unfold(reader, move |mut reader| async move {
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => None,
Ok(_) => Some((
Ok(Bytes::from(format!("data: {}\n\n", line.trim()))),
reader,
)),
Err(e) => Some((Err(e), reader)),
}
});
let stream_with_done = stream.chain(futures::stream::once(async {
Ok(Bytes::from("data: [DONE]\n\n"))
}));
Box::pin(stream_with_done)
}
#[derive(Deserialize, Serialize, Debug)]
struct Payload {
input: Value,
}
#[derive(Serialize, Deserialize, Debug)]
struct StreamInfo {
resource: String,
payload: Payload,
parent: String,
call_count: i32,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct WebhookPostRequest {
id: String,
resource: String,
payload: Payload,
parent: String,
}
#[derive(Deserialize, Serialize, Debug)]
struct WebhookPostResponse {
stream_url: String,
}
pub async fn create_agent(Json(payload): Json<WebhookPostRequest>) -> impl IntoResponse {
let db = DB.lock().await;
tracing::info!("Received webhook post request with ID: {}", payload.id);
let stream_id = payload.id.clone();
let info = StreamInfo {
resource: payload.resource.clone(),
payload: payload.payload,
parent: payload.parent.clone(),
call_count: 0,
};
let info_bytes = match serde_json::to_vec(&info) {
Ok(data) => data,
Err(e) => {
tracing::error!("Failed to serialize StreamInfo: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
// Use atomic compare_and_swap operation
match db.compare_and_swap(
&stream_id,
None as Option<&[u8]>,
Some(info_bytes.as_slice()),
) {
Ok(_) => {
// Force an immediate sync to disk
match db.flush_async().await {
Ok(_) => {
// Verify the write by attempting to read it back
match db.get(&stream_id) {
Ok(Some(_)) => {
let stream_url = format!("/agents/{}", stream_id);
tracing::info!(
"Successfully created and verified stream URL: {}",
stream_url
);
Json(WebhookPostResponse { stream_url }).into_response()
}
Ok(None) => {
tracing::error!("Failed to verify stream creation: {}", stream_id);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
Err(e) => {
tracing::error!("Error verifying stream creation: {}", e);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
Err(e) => {
tracing::error!("Failed to flush DB: {}", e);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
Err(e) => {
tracing::error!("Failed to insert stream info: {}", e);
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
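Together the two handlers form a register-then-stream flow: create_agent stores a StreamInfo record under the supplied id, and the first GET on the returned URL runs the agent and streams its stdout as server-sent events (later GETs return 200 without re-running, since call_count has already been incremented). A sketch of the request body create_agent expects; the id, resource, and input values below are illustrative:

// Sketch of the JSON a client would POST to create_agent.
// Field names mirror WebhookPostRequest; values are examples only.
let body = serde_json::json!({
    "id": "demo-stream-1",                                  // sled key, also the /agents/{id} path segment
    "resource": "web-search",                               // must match an arm in use_agent
    "payload": { "input": "latest Rust release notes" },
    "parent": "demo-session"
});
// Expected reply: {"stream_url": "/agents/demo-stream-1"}. A GET on that URL
// streams `data: ...` lines and finishes with `data: [DONE]`.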

View File

@@ -0,0 +1,3 @@
pub mod not_found;
pub mod ui;
pub mod agents;

View File

@@ -0,0 +1,185 @@
use axum::response::Response;
use axum::{
body::Body, extract::Json, http::StatusCode, response::IntoResponse,
};
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use tokio::io::{AsyncBufReadExt, BufReader};
use crate::utils::utils::run_agent;
// Custom function to format streaming responses according to OpenAI API format
pub fn openai_stream_format<R>(
reader: BufReader<R>,
request_id: String,
model: String,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
where
R: tokio::io::AsyncRead + Unpin + Send + 'static,
{
let stream = futures::stream::unfold((reader, 0), move |(mut reader, index)| {
let request_id = request_id.clone();
let model = model.clone();
async move {
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => None,
Ok(_) => {
let content = line.trim();
// Skip empty lines
if content.is_empty() {
return Some((Ok(Bytes::from("")), (reader, index)));
}
// Format as OpenAI API streaming response
let chunk = serde_json::json!({
"id": format!("chatcmpl-{}", request_id),
"object": "chat.completion.chunk",
"created": std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
"model": model,
"choices": [{
"index": index,
"delta": {
"content": content
},
"finish_reason": null
}]
});
Some((
Ok(Bytes::from(format!("data: {}\n\n", chunk.to_string()))),
(reader, index),
))
}
Err(e) => Some((Err(e), (reader, index))),
}
}
});
// Add the [DONE] message at the end
let stream_with_done = stream.filter(|result| {
futures::future::ready(match result {
Ok(bytes) => !bytes.is_empty(),
Err(_) => true,
})
}).chain(futures::stream::once(async {
Ok(Bytes::from("data: [DONE]\n\n"))
}));
Box::pin(stream_with_done)
}
#[derive(Deserialize, Debug)]
pub struct ModelContextRequest {
messages: Vec<Message>,
model: Option<String>,
stream: Option<bool>,
temperature: Option<f32>,
max_tokens: Option<u32>,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct Message {
role: String,
content: String,
}
#[derive(Serialize, Debug)]
pub struct ModelContextResponse {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<Choice>,
}
#[derive(Serialize, Debug)]
pub struct Choice {
index: u32,
message: Message,
finish_reason: String,
}
pub async fn model_context(
headers: axum::http::HeaderMap,
Json(payload): Json<ModelContextRequest>
) -> impl IntoResponse {
// Generate a unique ID for this request
let request_id = uuid::Uuid::new_v4().to_string();
// Convert messages to a format that can be passed to the agent
let input = serde_json::to_string(&payload.messages).unwrap_or_default();
// Use the web-search agent for now, but this could be customized based on the model parameter
let agent_file = "./packages/genaiscript/genaisrc/web-search.genai.mts";
tracing::debug!(
"Executing model context request - Id: {}",
request_id
);
// Default timeout of 60 seconds
let mut cmd = match run_agent(&request_id, &input, agent_file, 60).await {
Ok(cmd) => cmd,
Err(e) => {
tracing::error!("Model context execution failed: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
// Check if streaming is requested either via the stream parameter or Accept header
let accept_header = headers.get("accept").and_then(|h| h.to_str().ok()).unwrap_or("");
let is_streaming = payload.stream.unwrap_or(false) || accept_header.contains("text/event-stream");
// If streaming is requested, return a streaming response
if is_streaming {
let stdout = match cmd.stdout.take() {
Some(stdout) => stdout,
None => {
tracing::error!("No stdout available for the command.");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
let reader = BufReader::new(stdout);
let model = payload.model.clone().unwrap_or_else(|| "default-model".to_string());
let sse_stream = openai_stream_format(reader, request_id.clone(), model);
return Response::builder()
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache, no-transform")
.header("Connection", "keep-alive")
.header("X-Accel-Buffering", "yes")
.body(Body::from_stream(sse_stream))
.unwrap();
} else {
// For non-streaming responses, we need to collect all output and return it as a single response
// This is a simplified implementation and might need to be adjusted based on actual requirements
let response = ModelContextResponse {
id: format!("chatcmpl-{}", request_id),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
model: payload.model.unwrap_or_else(|| "default-model".to_string()),
choices: vec![Choice {
index: 0,
message: Message {
role: "assistant".to_string(),
content: "This is a placeholder response. The actual implementation would process the agent's output.".to_string(),
},
finish_reason: "stop".to_string(),
}],
};
return Json(response).into_response();
}
}
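A sketch of the OpenAI-style request body model_context accepts, mirroring ModelContextRequest; the model name and message content are illustrative:

// Sketch: request body for the model_context handler.
let request = serde_json::json!({
    "model": "gpt-3.5-turbo",
    "stream": true,                 // streaming also triggers on `Accept: text/event-stream`
    "messages": [
        { "role": "user", "content": "Find recent articles about Rust async runtimes" }
    ]
});
// With streaming enabled the response is SSE: one chat.completion.chunk object per
// non-empty stdout line from the agent, terminated by `data: [DONE]`.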

View File

@@ -0,0 +1,48 @@
use axum::{
extract::Json,
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
#[derive(Serialize, Debug)]
pub struct ModelsResponse {
object: String,
data: Vec<Model>,
}
#[derive(Serialize, Debug)]
pub struct Model {
id: String,
object: String,
created: u64,
owned_by: String,
}
pub async fn list_models() -> impl IntoResponse {
let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
// Create a response with a default model
let response = ModelsResponse {
object: "list".to_string(),
data: vec![
Model {
id: "gpt-3.5-turbo".to_string(),
object: "model".to_string(),
created: current_time,
owned_by: "open-web-agent-rs".to_string(),
},
Model {
id: "gpt-4".to_string(),
object: "model".to_string(),
created: current_time,
owned_by: "open-web-agent-rs".to_string(),
},
],
};
Json(response)
}

View File

@@ -0,0 +1,16 @@
use axum::{
http::StatusCode,
Json,
response::IntoResponse,
};
pub async fn handle_not_found() -> impl IntoResponse {
tracing::warn!("404 Not Found error occurred");
let error_response = serde_json::json!({
"error": "Route Not Found",
"status": 404
});
(StatusCode::NOT_FOUND, Json(error_response))
}

View File

@@ -0,0 +1,37 @@
use crate::config::{Runtime};
use crate::routes::create_router;
use crate::setup::init_logging;
mod config;
mod routes;
mod setup;
mod handlers;
mod agents;
mod utils;
mod counter;
#[tokio::main]
async fn main() {
init_logging();
Runtime::configure();
let router = create_router();
let addr = "0.0.0.0:3006";
tracing::info!("Attempting to bind server to {}", addr);
let listener = match tokio::net::TcpListener::bind(addr).await {
Ok(l) => {
tracing::info!("Successfully bound to {}", l.local_addr().unwrap());
l
}
Err(e) => {
tracing::error!("Failed to bind to {}: {}", addr, e);
panic!("Server failed to start");
}
};
tracing::info!("Server starting on {}", listener.local_addr().unwrap());
axum::serve(listener, router.into_make_service()).await.unwrap();
}

View File

@@ -0,0 +1,140 @@
use axum::response::Response;
use crate::handlers::{not_found::handle_not_found};
use axum::routing::{get, Router};
use http::StatusCode;
use tower_http::trace::{self, TraceLayer};
use tracing::Level;
use rmcp::transport::streamable_http_server::{
StreamableHttpService, session::local::LocalSessionManager,
};
use rust_embed::Embed;
use crate::agents::Agents;
#[derive(Embed)]
#[folder = "../../node_modules/@modelcontextprotocol/inspector-client/dist"]
struct Asset;
pub struct StaticFile<T>(pub T);
impl<T> axum::response::IntoResponse for StaticFile<T>
where
T: Into<String>,
{
fn into_response(self) -> Response {
let path = self.0.into();
match Asset::get(path.as_str()) {
Some(content) => {
let mime = mime_guess::from_path(path).first_or_octet_stream();
([(http::header::CONTENT_TYPE, mime.as_ref())], content.data).into_response()
}
None => (StatusCode::NOT_FOUND, "404 Not Found").into_response(),
}
}
}
async fn ui_index_handler() -> impl axum::response::IntoResponse {
StaticFile("index.html")
}
async fn static_handler(uri: http::Uri) -> impl axum::response::IntoResponse {
let path = uri.path().trim_start_matches("/").to_string();
StaticFile(path)
}
pub fn create_router() -> Router {
let mcp_service = StreamableHttpService::new(
Agents::new,
LocalSessionManager::default().into(),
Default::default(),
);
Router::new()
.nest_service("/mcp", mcp_service)
.route("/health", get(health))
.route("/", get(ui_index_handler))
.route("/index.html", get(ui_index_handler))
.route("/{*path}", get(static_handler))
.layer(
TraceLayer::new_for_http()
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
.on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
)
.layer(tower_http::cors::CorsLayer::very_permissive())
.fallback(handle_not_found)
}
async fn health() -> String {
return "ok".to_string();
}
#[cfg(test)]
mod tests {
use super::*;
use axum::body::{Body, Bytes};
use axum::http::{Request, StatusCode};
use axum::response::Response;
use tower::ServiceExt;
#[tokio::test]
async fn test_health_endpoint() {
// Call the health function directly
let response = health().await;
assert_eq!(response, "ok".to_string());
}
#[tokio::test]
async fn test_health_route() {
// Create the router
let app = create_router();
// Create a request to the health endpoint
let request = Request::builder()
.uri("/health")
.method("GET")
.body(Body::empty())
.unwrap();
// Process the request
let response = app.oneshot(request).await.unwrap();
// Check the response status
assert_eq!(response.status(), StatusCode::OK);
// Check the response body
let body = response_body_bytes(response).await;
assert_eq!(&body[..], b"ok");
}
#[tokio::test]
async fn test_not_found_route() {
// Create the router
let app = create_router();
// Create a request to a non-existent endpoint
let request = Request::builder()
.uri("/non-existent")
.method("GET")
.body(Body::empty())
.unwrap();
// Process the request
let response = app.oneshot(request).await.unwrap();
// Check the response status
assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
// Helper function to extract bytes from a response body
async fn response_body_bytes(response: Response) -> Bytes {
let body = response.into_body();
// Use a reasonable size limit for the body (16MB)
let bytes = axum::body::to_bytes(body, 16 * 1024 * 1024)
.await
.expect("Failed to read response body");
bytes
}
}

View File

@@ -0,0 +1,10 @@
// src/setup.rs
pub fn init_logging() {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(true)
.with_thread_ids(true)
.with_file(true)
.with_line_number(true)
.init();
}

View File

@@ -0,0 +1,65 @@
use base64::Engine;
use base64::engine::GeneralPurpose;
use base64::engine::general_purpose::STANDARD;
pub struct Base64Encoder {
payload_engine: GeneralPurpose,
signature_engine: GeneralPurpose,
public_key_engine: GeneralPurpose,
secret_key_engine: GeneralPurpose,
}
impl Base64Encoder {
pub(crate) fn b64_encode(&self, p0: &[u8]) -> String {
self.payload_engine.encode(p0)
}
pub(crate) fn b64_decode(&self, p0: String) -> Result<Vec<u8>, base64::DecodeError> {
self.payload_engine.decode(p0)
}
}
pub const B64_ENCODER: &Base64Encoder = &Base64Encoder::new();
impl Base64Encoder {
pub const fn new() -> Self {
Base64Encoder {
payload_engine: STANDARD,
signature_engine: STANDARD,
public_key_engine: STANDARD,
secret_key_engine: STANDARD,
}
}
pub fn b64_encode_payload<T: AsRef<[u8]>>(&self, input: T) -> String {
self.payload_engine.encode(input)
}
pub fn b64_decode_payload<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> {
self.payload_engine.decode(input)
}
pub fn b64_decode_signature<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> {
self.signature_engine.decode(input)
}
pub fn b64_encode_signature<T: AsRef<[u8]>>(&self, input: T) -> String {
self.signature_engine.encode(input)
}
pub fn b64_encode_public_key<T: AsRef<[u8]>>(&self, input: T) -> String {
self.public_key_engine.encode(input)
}
pub fn b64_decode_public_key<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> {
self.public_key_engine.decode(input)
}
pub fn b64_encode_secret_key<T: AsRef<[u8]>>(&self, input: T) -> String {
self.secret_key_engine.encode(input)
}
pub fn b64_decode_secret_key<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, base64::DecodeError> {
self.secret_key_engine.decode(input)
}
}
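A short usage sketch; since all four engines are currently the padded STANDARD alphabet, the payload helpers are representative of the rest:

// Sketch: round-trip a payload through the shared encoder.
let encoded = B64_ENCODER.b64_encode_payload(b"hello world");
assert_eq!(encoded, "aGVsbG8gd29ybGQ=");
let decoded = B64_ENCODER.b64_decode_payload(&encoded).expect("valid base64");
assert_eq!(decoded, b"hello world".to_vec());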

View File

@@ -0,0 +1,2 @@
pub mod utils;
pub mod base64;

View File

@@ -0,0 +1,71 @@
// utils.rs
use tokio::process::{Child, Command};
use std::env;
use tokio::time::{timeout, Duration};
use tracing;
pub struct ShimBinding {
user_input: String,
file_path: String,
openai_api_key: String,
openai_api_base: String,
genaiscript_model_large: String,
genaiscript_model_small: String,
searxng_api_base_url: String,
searxng_password: String,
}
impl ShimBinding {
pub fn new(user_input: String, file_path: String) -> Self {
Self {
user_input,
file_path,
openai_api_key: env::var("OPENAI_API_KEY").unwrap_or_default(),
openai_api_base: env::var("OPENAI_API_BASE").unwrap_or_default(),
genaiscript_model_large: env::var("GENAISCRIPT_MODEL_LARGE").unwrap_or_default(),
genaiscript_model_small: env::var("GENAISCRIPT_MODEL_SMALL").unwrap_or_default(),
searxng_api_base_url: env::var("SEARXNG_API_BASE_URL").unwrap_or_default(),
searxng_password: env::var("SEARXNG_PASSWORD").unwrap_or_default(),
}
}
pub fn execute(&self) -> std::io::Result<Child> {
let mut command = Command::new("./dist/genaiscript-rust-shim.js");
command
.arg("--file")
.arg(&self.file_path)
.arg(format!("USER_INPUT={}", self.user_input))
.env("OPENAI_API_KEY", &self.openai_api_key)
.env("OPENAI_API_BASE", &self.openai_api_base)
.env("GENAISCRIPT_MODEL_LARGE", &self.genaiscript_model_large)
.env("GENAISCRIPT_MODEL_SMALL", &self.genaiscript_model_small)
.env("SEARXNG_API_BASE_URL", &self.searxng_api_base_url)
.env("SEARXNG_PASSWORD", &self.searxng_password)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit());
command.spawn()
}
}
/// Spawns an agent process for a stream; the timeout bounds spawning only, not the agent's full run.
pub async fn run_agent(stream_id: &str, input: &str, file_path: &str, timeout_seconds: u64 ) -> Result<Child, String> {
tracing::debug!("Initiating agent for stream {} with file path {}", stream_id, file_path);
let shim_binding = ShimBinding::new(input.to_string(), file_path.to_string());
let spawn_future = async move {
match shim_binding.execute() {
Ok(child) => Ok(child),
Err(e) => {
tracing::error!("Failed to spawn shim process: {}", e);
Err(e.to_string())
}
}
};
timeout(Duration::from_secs(timeout_seconds), spawn_future)
.await
.unwrap_or_else(|_| Err(format!("Command timed out after {} seconds", timeout_seconds)))
}
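Note that the timeout above only bounds spawning the shim process; once run_agent returns, the child keeps running with no time limit. A sketch of how a caller could bound the whole run instead, using the same tokio timeout utilities already imported in this file (the 120-second budget is an assumption):

// Sketch: bound the full agent run, not just the spawn.
if let Ok(child) = run_agent("example-stream", "example input",
    "./packages/genaiscript/genaisrc/web-search.genai.mts", 10).await {
    match timeout(Duration::from_secs(120), child.wait_with_output()).await {
        Ok(Ok(output)) => tracing::info!("agent finished: {}", String::from_utf8_lossy(&output.stdout)),
        Ok(Err(e)) => tracing::error!("agent failed: {}", e),
        Err(_) => tracing::warn!("agent exceeded the 120-second budget"),
    }
}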