configure basic MCP server

This commit is contained in:
geoffsee
2025-06-04 22:33:37 -04:00
committed by Geoff Seemueller
parent dbc8d78fb5
commit 06a233633e
10 changed files with 921 additions and 50 deletions

211
src/counter.rs Normal file
View File

@@ -0,0 +1,211 @@
use std::sync::Arc;
use rmcp::{
Error as McpError, RoleServer, ServerHandler, const_string, model::*, schemars,
service::RequestContext, tool,
};
use serde_json::json;
use tokio::sync::Mutex;
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct StructRequest {
pub a: i32,
pub b: i32,
}
#[derive(Clone)]
pub struct Counter {
counter: Arc<Mutex<i32>>,
}
#[tool(tool_box)]
impl Counter {
#[allow(dead_code)]
pub fn new() -> Self {
Self {
counter: Arc::new(Mutex::new(0)),
}
}
fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
RawResource::new(uri, name.to_string()).no_annotation()
}
#[tool(description = "Increment the counter by 1")]
async fn increment(&self) -> Result<CallToolResult, McpError> {
let mut counter = self.counter.lock().await;
*counter += 1;
Ok(CallToolResult::success(vec![Content::text(
counter.to_string(),
)]))
}
#[tool(description = "Decrement the counter by 1")]
async fn decrement(&self) -> Result<CallToolResult, McpError> {
let mut counter = self.counter.lock().await;
*counter -= 1;
Ok(CallToolResult::success(vec![Content::text(
counter.to_string(),
)]))
}
#[tool(description = "Get the current counter value")]
async fn get_value(&self) -> Result<CallToolResult, McpError> {
let counter = self.counter.lock().await;
Ok(CallToolResult::success(vec![Content::text(
counter.to_string(),
)]))
}
#[tool(description = "Say hello to the client")]
fn say_hello(&self) -> Result<CallToolResult, McpError> {
Ok(CallToolResult::success(vec![Content::text("hello")]))
}
#[tool(description = "Repeat what you say")]
fn echo(
&self,
#[tool(param)]
#[schemars(description = "Repeat what you say")]
saying: String,
) -> Result<CallToolResult, McpError> {
Ok(CallToolResult::success(vec![Content::text(saying)]))
}
#[tool(description = "Calculate the sum of two numbers")]
fn sum(
&self,
#[tool(aggr)] StructRequest { a, b }: StructRequest,
) -> Result<CallToolResult, McpError> {
Ok(CallToolResult::success(vec![Content::text(
(a + b).to_string(),
)]))
}
}
const_string!(Echo = "echo");
#[tool(tool_box)]
impl ServerHandler for Counter {
fn get_info(&self) -> ServerInfo {
ServerInfo {
protocol_version: ProtocolVersion::V_2024_11_05,
capabilities: ServerCapabilities::builder()
.enable_prompts()
.enable_resources()
.enable_tools()
.build(),
server_info: Implementation::from_build_env(),
instructions: Some("This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()),
}
}
async fn list_resources(
&self,
_request: Option<PaginatedRequestParam>,
_: RequestContext<RoleServer>,
) -> Result<ListResourcesResult, McpError> {
Ok(ListResourcesResult {
resources: vec![
self._create_resource_text("str:////Users/to/some/path/", "cwd"),
self._create_resource_text("memo://insights", "memo-name"),
],
next_cursor: None,
})
}
async fn read_resource(
&self,
ReadResourceRequestParam { uri }: ReadResourceRequestParam,
_: RequestContext<RoleServer>,
) -> Result<ReadResourceResult, McpError> {
match uri.as_str() {
"str:////Users/to/some/path/" => {
let cwd = "/Users/to/some/path/";
Ok(ReadResourceResult {
contents: vec![ResourceContents::text(cwd, uri)],
})
}
"memo://insights" => {
let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
Ok(ReadResourceResult {
contents: vec![ResourceContents::text(memo, uri)],
})
}
_ => Err(McpError::resource_not_found(
"resource_not_found",
Some(json!({
"uri": uri
})),
)),
}
}
async fn list_prompts(
&self,
_request: Option<PaginatedRequestParam>,
_: RequestContext<RoleServer>,
) -> Result<ListPromptsResult, McpError> {
Ok(ListPromptsResult {
next_cursor: None,
prompts: vec![Prompt::new(
"example_prompt",
Some("This is an example prompt that takes one required argument, message"),
Some(vec![PromptArgument {
name: "message".to_string(),
description: Some("A message to put in the prompt".to_string()),
required: Some(true),
}]),
)],
})
}
async fn get_prompt(
&self,
GetPromptRequestParam { name, arguments }: GetPromptRequestParam,
_: RequestContext<RoleServer>,
) -> Result<GetPromptResult, McpError> {
match name.as_str() {
"example_prompt" => {
let message = arguments
.and_then(|json| json.get("message")?.as_str().map(|s| s.to_string()))
.ok_or_else(|| {
McpError::invalid_params("No message provided to example_prompt", None)
})?;
let prompt =
format!("This is an example prompt with your message here: '{message}'");
Ok(GetPromptResult {
description: None,
messages: vec![PromptMessage {
role: PromptMessageRole::User,
content: PromptMessageContent::text(prompt),
}],
})
}
_ => Err(McpError::invalid_params("prompt not found", None)),
}
}
async fn list_resource_templates(
&self,
_request: Option<PaginatedRequestParam>,
_: RequestContext<RoleServer>,
) -> Result<ListResourceTemplatesResult, McpError> {
Ok(ListResourceTemplatesResult {
next_cursor: None,
resource_templates: Vec::new(),
})
}
async fn initialize(
&self,
_request: InitializeRequestParam,
context: RequestContext<RoleServer>,
) -> Result<InitializeResult, McpError> {
if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
let initialize_headers = &http_request_part.headers;
let initialize_uri = &http_request_part.uri;
tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
}
Ok(self.get_info())
}
}

View File

@@ -0,0 +1,185 @@
use axum::response::Response;
use axum::{
body::Body, extract::Json, http::StatusCode, response::IntoResponse,
};
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::pin::Pin;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use crate::utils::utils::run_agent;
// Custom function to format streaming responses according to OpenAI API format
pub fn openai_stream_format<R>(
reader: BufReader<R>,
request_id: String,
model: String,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
where
R: tokio::io::AsyncRead + Unpin + Send + 'static,
{
let stream = futures::stream::unfold((reader, 0), move |(mut reader, index)| {
let request_id = request_id.clone();
let model = model.clone();
async move {
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => None,
Ok(_) => {
let content = line.trim();
// Skip empty lines
if content.is_empty() {
return Some((Ok(Bytes::from("")), (reader, index)));
}
// Format as OpenAI API streaming response
let chunk = serde_json::json!({
"id": format!("chatcmpl-{}", request_id),
"object": "chat.completion.chunk",
"created": std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
"model": model,
"choices": [{
"index": index,
"delta": {
"content": content
},
"finish_reason": null
}]
});
Some((
Ok(Bytes::from(format!("data: {}\n\n", chunk.to_string()))),
(reader, index),
))
}
Err(e) => Some((Err(e), (reader, index))),
}
}
});
// Add the [DONE] message at the end
let stream_with_done = stream.filter(|result| {
futures::future::ready(match result {
Ok(bytes) => !bytes.is_empty(),
Err(_) => true,
})
}).chain(futures::stream::once(async {
Ok(Bytes::from("data: [DONE]\n\n"))
}));
Box::pin(stream_with_done)
}
#[derive(Deserialize, Debug)]
pub struct ModelContextRequest {
messages: Vec<Message>,
model: Option<String>,
stream: Option<bool>,
temperature: Option<f32>,
max_tokens: Option<u32>,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct Message {
role: String,
content: String,
}
#[derive(Serialize, Debug)]
pub struct ModelContextResponse {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<Choice>,
}
#[derive(Serialize, Debug)]
pub struct Choice {
index: u32,
message: Message,
finish_reason: String,
}
pub async fn model_context(
headers: axum::http::HeaderMap,
Json(payload): Json<ModelContextRequest>
) -> impl IntoResponse {
// Generate a unique ID for this request
let request_id = uuid::Uuid::new_v4().to_string();
// Convert messages to a format that can be passed to the agent
let input = serde_json::to_string(&payload.messages).unwrap_or_default();
// Use the web-search agent for now, but this could be customized based on the model parameter
let agent_file = "./packages/genaiscript/genaisrc/web-search.genai.mts";
tracing::debug!(
"Executing model context request - Id: {}",
request_id
);
// Default timeout of 60 seconds
let mut cmd = match run_agent(&request_id, &input, agent_file, 60).await {
Ok(cmd) => cmd,
Err(e) => {
tracing::error!("Model context execution failed: {}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
// Check if streaming is requested either via the stream parameter or Accept header
let accept_header = headers.get("accept").and_then(|h| h.to_str().ok()).unwrap_or("");
let is_streaming = payload.stream.unwrap_or(false) || accept_header.contains("text/event-stream");
// If streaming is requested, return a streaming response
if is_streaming {
let stdout = match cmd.stdout.take() {
Some(stdout) => stdout,
None => {
tracing::error!("No stdout available for the command.");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
let reader = BufReader::new(stdout);
let model = payload.model.clone().unwrap_or_else(|| "default-model".to_string());
let sse_stream = openai_stream_format(reader, request_id.clone(), model);
return Response::builder()
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache, no-transform")
.header("Connection", "keep-alive")
.header("X-Accel-Buffering", "yes")
.body(Body::from_stream(sse_stream))
.unwrap();
} else {
// For non-streaming responses, we need to collect all output and return it as a single response
// This is a simplified implementation and might need to be adjusted based on actual requirements
let response = ModelContextResponse {
id: format!("chatcmpl-{}", request_id),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
model: payload.model.unwrap_or_else(|| "default-model".to_string()),
choices: vec![Choice {
index: 0,
message: Message {
role: "assistant".to_string(),
content: "This is a placeholder response. The actual implementation would process the agent's output.".to_string(),
},
finish_reason: "stop".to_string(),
}],
};
return Json(response).into_response();
}
}

48
src/handlers/models.rs Normal file
View File

@@ -0,0 +1,48 @@
use axum::{
extract::Json,
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
#[derive(Serialize, Debug)]
pub struct ModelsResponse {
object: String,
data: Vec<Model>,
}
#[derive(Serialize, Debug)]
pub struct Model {
id: String,
object: String,
created: u64,
owned_by: String,
}
pub async fn list_models() -> impl IntoResponse {
let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
// Create a response with a default model
let response = ModelsResponse {
object: "list".to_string(),
data: vec![
Model {
id: "gpt-3.5-turbo".to_string(),
object: "model".to_string(),
created: current_time,
owned_by: "open-web-agent-rs".to_string(),
},
Model {
id: "gpt-4".to_string(),
object: "model".to_string(),
created: current_time,
owned_by: "open-web-agent-rs".to_string(),
},
],
};
Json(response)
}

View File

@@ -8,6 +8,7 @@ mod setup;
mod handlers;
mod agents;
mod utils;
mod counter;
#[tokio::main]
async fn main() {

View File

@@ -1,17 +1,25 @@
use crate::handlers::agents::create_agent;
use crate::handlers::{not_found::handle_not_found, ui::serve_ui, agents::use_agent};
use axum::routing::post;
use crate::handlers::{not_found::handle_not_found, ui::serve_ui};
use axum::routing::{get, Router};
use tower_http::trace::{self, TraceLayer};
use tracing::Level;
use rmcp::transport::streamable_http_server::{
StreamableHttpService, session::local::LocalSessionManager,
};
use crate::counter::Counter;
pub fn create_router() -> Router {
let service = StreamableHttpService::new(
Counter::new,
LocalSessionManager::default().into(),
Default::default(),
);
Router::new()
.nest_service("/mcp", service)
.route("/", get(serve_ui))
// create an agent
.route("/api/agents", post(create_agent))
// connect the agent
.route("/agents/:agent_id", get(use_agent))
.route("/health", get(health))
.layer(
TraceLayer::new_for_http()