configure basic MCP server
committed by Geoff Seemueller
parent dbc8d78fb5
commit 06a233633e
src/counter.rs (new file, +211)
@@ -0,0 +1,211 @@
use std::sync::Arc;

use rmcp::{
    Error as McpError, RoleServer, ServerHandler, const_string, model::*, schemars,
    service::RequestContext, tool,
};
use serde_json::json;
use tokio::sync::Mutex;

#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct StructRequest {
    pub a: i32,
    pub b: i32,
}

#[derive(Clone)]
pub struct Counter {
    counter: Arc<Mutex<i32>>,
}

#[tool(tool_box)]
impl Counter {
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            counter: Arc::new(Mutex::new(0)),
        }
    }

    fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
        RawResource::new(uri, name.to_string()).no_annotation()
    }

    #[tool(description = "Increment the counter by 1")]
    async fn increment(&self) -> Result<CallToolResult, McpError> {
        let mut counter = self.counter.lock().await;
        *counter += 1;
        Ok(CallToolResult::success(vec![Content::text(
            counter.to_string(),
        )]))
    }

    #[tool(description = "Decrement the counter by 1")]
    async fn decrement(&self) -> Result<CallToolResult, McpError> {
        let mut counter = self.counter.lock().await;
        *counter -= 1;
        Ok(CallToolResult::success(vec![Content::text(
            counter.to_string(),
        )]))
    }

    #[tool(description = "Get the current counter value")]
    async fn get_value(&self) -> Result<CallToolResult, McpError> {
        let counter = self.counter.lock().await;
        Ok(CallToolResult::success(vec![Content::text(
            counter.to_string(),
        )]))
    }

    #[tool(description = "Say hello to the client")]
    fn say_hello(&self) -> Result<CallToolResult, McpError> {
        Ok(CallToolResult::success(vec![Content::text("hello")]))
    }

    #[tool(description = "Repeat what you say")]
    fn echo(
        &self,
        #[tool(param)]
        #[schemars(description = "Repeat what you say")]
        saying: String,
    ) -> Result<CallToolResult, McpError> {
        Ok(CallToolResult::success(vec![Content::text(saying)]))
    }

    #[tool(description = "Calculate the sum of two numbers")]
    fn sum(
        &self,
        #[tool(aggr)] StructRequest { a, b }: StructRequest,
    ) -> Result<CallToolResult, McpError> {
        Ok(CallToolResult::success(vec![Content::text(
            (a + b).to_string(),
        )]))
    }
}

const_string!(Echo = "echo");

#[tool(tool_box)]
impl ServerHandler for Counter {
    fn get_info(&self) -> ServerInfo {
        ServerInfo {
            protocol_version: ProtocolVersion::V_2024_11_05,
            capabilities: ServerCapabilities::builder()
                .enable_prompts()
                .enable_resources()
                .enable_tools()
                .build(),
            server_info: Implementation::from_build_env(),
            instructions: Some("This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()),
        }
    }

    async fn list_resources(
        &self,
        _request: Option<PaginatedRequestParam>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListResourcesResult, McpError> {
        Ok(ListResourcesResult {
            resources: vec![
                self._create_resource_text("str:////Users/to/some/path/", "cwd"),
                self._create_resource_text("memo://insights", "memo-name"),
            ],
            next_cursor: None,
        })
    }

    async fn read_resource(
        &self,
        ReadResourceRequestParam { uri }: ReadResourceRequestParam,
        _: RequestContext<RoleServer>,
    ) -> Result<ReadResourceResult, McpError> {
        match uri.as_str() {
            "str:////Users/to/some/path/" => {
                let cwd = "/Users/to/some/path/";
                Ok(ReadResourceResult {
                    contents: vec![ResourceContents::text(cwd, uri)],
                })
            }
            "memo://insights" => {
                let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
                Ok(ReadResourceResult {
                    contents: vec![ResourceContents::text(memo, uri)],
                })
            }
            _ => Err(McpError::resource_not_found(
                "resource_not_found",
                Some(json!({
                    "uri": uri
                })),
            )),
        }
    }

    async fn list_prompts(
        &self,
        _request: Option<PaginatedRequestParam>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListPromptsResult, McpError> {
        Ok(ListPromptsResult {
            next_cursor: None,
            prompts: vec![Prompt::new(
                "example_prompt",
                Some("This is an example prompt that takes one required argument, message"),
                Some(vec![PromptArgument {
                    name: "message".to_string(),
                    description: Some("A message to put in the prompt".to_string()),
                    required: Some(true),
                }]),
            )],
        })
    }

    async fn get_prompt(
        &self,
        GetPromptRequestParam { name, arguments }: GetPromptRequestParam,
        _: RequestContext<RoleServer>,
    ) -> Result<GetPromptResult, McpError> {
        match name.as_str() {
            "example_prompt" => {
                let message = arguments
                    .and_then(|json| json.get("message")?.as_str().map(|s| s.to_string()))
                    .ok_or_else(|| {
                        McpError::invalid_params("No message provided to example_prompt", None)
                    })?;

                let prompt =
                    format!("This is an example prompt with your message here: '{message}'");
                Ok(GetPromptResult {
                    description: None,
                    messages: vec![PromptMessage {
                        role: PromptMessageRole::User,
                        content: PromptMessageContent::text(prompt),
                    }],
                })
            }
            _ => Err(McpError::invalid_params("prompt not found", None)),
        }
    }

    async fn list_resource_templates(
        &self,
        _request: Option<PaginatedRequestParam>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListResourceTemplatesResult, McpError> {
        Ok(ListResourceTemplatesResult {
            next_cursor: None,
            resource_templates: Vec::new(),
        })
    }

    async fn initialize(
        &self,
        _request: InitializeRequestParam,
        context: RequestContext<RoleServer>,
    ) -> Result<InitializeResult, McpError> {
        if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
            let initialize_headers = &http_request_part.headers;
            let initialize_uri = &http_request_part.uri;
            tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
        }
        Ok(self.get_info())
    }
}
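For a quick sanity check of the counter tools without a full MCP client, the methods can be exercised directly. A minimal sketch, assuming the #[tool] macro leaves the inherent methods callable and that this test module is appended to counter.rs:

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn increment_then_decrement() {
        let counter = Counter::new();
        counter.increment().await.unwrap();
        counter.increment().await.unwrap();
        counter.decrement().await.unwrap();
        // Two increments and one decrement leave the shared state at 1.
        assert_eq!(*counter.counter.lock().await, 1);
    }
}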
src/handlers/model_context.rs (new file, +185)
@@ -0,0 +1,185 @@
use axum::response::Response;
use axum::{
    body::Body, extract::Json, http::StatusCode, response::IntoResponse,
};
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::pin::Pin;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;

use crate::utils::utils::run_agent;

// Custom function to format streaming responses according to the OpenAI API format
pub fn openai_stream_format<R>(
    reader: BufReader<R>,
    request_id: String,
    model: String,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
where
    R: tokio::io::AsyncRead + Unpin + Send + 'static,
{
    let stream = futures::stream::unfold((reader, 0), move |(mut reader, index)| {
        let request_id = request_id.clone();
        let model = model.clone();
        async move {
            let mut line = String::new();
            match reader.read_line(&mut line).await {
                Ok(0) => None,
                Ok(_) => {
                    let content = line.trim();
                    // Emit an empty placeholder for blank lines; these are filtered out below
                    if content.is_empty() {
                        return Some((Ok(Bytes::from("")), (reader, index)));
                    }

                    // Format as an OpenAI API streaming response chunk
                    let chunk = serde_json::json!({
                        "id": format!("chatcmpl-{}", request_id),
                        "object": "chat.completion.chunk",
                        "created": std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap()
                            .as_secs(),
                        "model": model,
                        "choices": [{
                            "index": index,
                            "delta": {
                                "content": content
                            },
                            "finish_reason": null
                        }]
                    });

                    Some((
                        Ok(Bytes::from(format!("data: {}\n\n", chunk))),
                        (reader, index),
                    ))
                }
                Err(e) => Some((Err(e), (reader, index))),
            }
        }
    });

    // Filter out the empty placeholder frames and append the [DONE] message at the end
    let stream_with_done = stream
        .filter(|result| {
            futures::future::ready(match result {
                Ok(bytes) => !bytes.is_empty(),
                Err(_) => true,
            })
        })
        .chain(futures::stream::once(async {
            Ok(Bytes::from("data: [DONE]\n\n"))
        }));

    Box::pin(stream_with_done)
}
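// A hedged check of the SSE framing above (a sketch, not part of the commit):
// `&[u8]` implements tokio::io::AsyncRead, so a byte literal can stand in for
// the child process stdout that the handler normally passes in.
#[cfg(test)]
mod stream_format_tests {
    use super::*;
    use futures::StreamExt;
    use tokio::io::BufReader;

    #[tokio::test]
    async fn wraps_lines_in_openai_chunks() {
        let reader = BufReader::new(&b"hello\nworld\n"[..]);
        let mut stream = openai_stream_format(reader, "test".into(), "test-model".into());
        let mut frames = Vec::new();
        while let Some(frame) = stream.next().await {
            frames.push(String::from_utf8(frame.unwrap().to_vec()).unwrap());
        }
        // Two `data: {...}` chunks plus the terminating `data: [DONE]` frame.
        assert_eq!(frames.len(), 3);
        assert!(frames[0].starts_with("data: {"));
        assert_eq!(frames.last().unwrap().as_str(), "data: [DONE]\n\n");
    }
}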
#[derive(Deserialize, Debug)]
pub struct ModelContextRequest {
    messages: Vec<Message>,
    model: Option<String>,
    stream: Option<bool>,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Message {
    role: String,
    content: String,
}

#[derive(Serialize, Debug)]
pub struct ModelContextResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<Choice>,
}

#[derive(Serialize, Debug)]
pub struct Choice {
    index: u32,
    message: Message,
    finish_reason: String,
}

pub async fn model_context(
    headers: axum::http::HeaderMap,
    Json(payload): Json<ModelContextRequest>,
) -> impl IntoResponse {
    // Generate a unique ID for this request
    let request_id = uuid::Uuid::new_v4().to_string();

    // Convert messages to a format that can be passed to the agent
    let input = serde_json::to_string(&payload.messages).unwrap_or_default();

    // Use the web-search agent for now, but this could be customized based on the model parameter
    let agent_file = "./packages/genaiscript/genaisrc/web-search.genai.mts";

    tracing::debug!(
        "Executing model context request - Id: {}",
        request_id
    );

    // Default timeout of 60 seconds
    let mut cmd = match run_agent(&request_id, &input, agent_file, 60).await {
        Ok(cmd) => cmd,
        Err(e) => {
            tracing::error!("Model context execution failed: {}", e);
            return StatusCode::INTERNAL_SERVER_ERROR.into_response();
        }
    };

    // Check if streaming is requested either via the stream parameter or the Accept header
    let accept_header = headers.get("accept").and_then(|h| h.to_str().ok()).unwrap_or("");
    let is_streaming = payload.stream.unwrap_or(false) || accept_header.contains("text/event-stream");

    // If streaming is requested, return a streaming response
    if is_streaming {
        let stdout = match cmd.stdout.take() {
            Some(stdout) => stdout,
            None => {
                tracing::error!("No stdout available for the command.");
                return StatusCode::INTERNAL_SERVER_ERROR.into_response();
            }
        };

        let reader = BufReader::new(stdout);
        let model = payload.model.clone().unwrap_or_else(|| "default-model".to_string());
        let sse_stream = openai_stream_format(reader, request_id.clone(), model);

        return Response::builder()
            .header("Content-Type", "text/event-stream")
            .header("Cache-Control", "no-cache, no-transform")
            .header("Connection", "keep-alive")
.header("X-Accel-Buffering", "yes")
            .body(Body::from_stream(sse_stream))
            .unwrap();
    } else {
        // For non-streaming responses, we need to collect all output and return it as a single response.
        // This is a simplified implementation and might need to be adjusted based on actual requirements.
        let response = ModelContextResponse {
            id: format!("chatcmpl-{}", request_id),
            object: "chat.completion".to_string(),
            created: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
            model: payload.model.unwrap_or_else(|| "default-model".to_string()),
            choices: vec![Choice {
                index: 0,
                message: Message {
                    role: "assistant".to_string(),
                    content: "This is a placeholder response. The actual implementation would process the agent's output.".to_string(),
                },
                finish_reason: "stop".to_string(),
            }],
        };

        return Json(response).into_response();
    }
}
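To exercise the handler end to end, a hedged sketch using reqwest; the mount path /v1/chat/completions and the port are hypothetical, since the route wiring for model_context is not shown in this commit:

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let body = json!({
        "model": "gpt-3.5-turbo",
        "stream": false,
        "messages": [{ "role": "user", "content": "hello" }]
    });
    let resp = reqwest::Client::new()
        .post("http://localhost:3000/v1/chat/completions") // hypothetical route
        .json(&body)
        .send()
        .await?;
    println!("{}", resp.text().await?);
    Ok(())
}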
src/handlers/models.rs (new file, +48)
@@ -0,0 +1,48 @@
use axum::{
    extract::Json,
    response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};

#[derive(Serialize, Debug)]
pub struct ModelsResponse {
    object: String,
    data: Vec<Model>,
}

#[derive(Serialize, Debug)]
pub struct Model {
    id: String,
    object: String,
    created: u64,
    owned_by: String,
}

pub async fn list_models() -> impl IntoResponse {
    let current_time = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();

    // Create a response with a default model
    let response = ModelsResponse {
        object: "list".to_string(),
        data: vec![
            Model {
                id: "gpt-3.5-turbo".to_string(),
                object: "model".to_string(),
                created: current_time,
                owned_by: "open-web-agent-rs".to_string(),
            },
            Model {
                id: "gpt-4".to_string(),
                object: "model".to_string(),
                created: current_time,
                owned_by: "open-web-agent-rs".to_string(),
            },
        ],
    };

    Json(response)
}
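Serialized, list_models yields the OpenAI-compatible model list shape, roughly as follows; created is the server's current Unix timestamp (the value below is illustrative):

{
  "object": "list",
  "data": [
    { "id": "gpt-3.5-turbo", "object": "model", "created": 1700000000, "owned_by": "open-web-agent-rs" },
    { "id": "gpt-4", "object": "model", "created": 1700000000, "owned_by": "open-web-agent-rs" }
  ]
}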
src/main.rs
@@ -8,6 +8,7 @@ mod setup;
mod handlers;
mod agents;
mod utils;
mod counter;

#[tokio::main]
async fn main() {
@@ -1,17 +1,25 @@
use crate::handlers::agents::create_agent;
use crate::handlers::{not_found::handle_not_found, ui::serve_ui, agents::use_agent};
use axum::routing::post;
use crate::handlers::{not_found::handle_not_found, ui::serve_ui};
use axum::routing::{get, Router};
use tower_http::trace::{self, TraceLayer};
use tracing::Level;

use rmcp::transport::streamable_http_server::{
    StreamableHttpService, session::local::LocalSessionManager,
};
use crate::counter::Counter;

pub fn create_router() -> Router {
    let service = StreamableHttpService::new(
        Counter::new,
        LocalSessionManager::default().into(),
        Default::default(),
    );

    Router::new()
        .nest_service("/mcp", service)
        .route("/", get(serve_ui))
        // create an agent
        .route("/api/agents", post(create_agent))
        // connect the agent
        .route("/agents/:agent_id", get(use_agent))
        .route("/health", get(health))
        .layer(
            TraceLayer::new_for_http()
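Not part of this commit, but for context: a minimal sketch of how main might serve create_router(), assuming axum 0.7's serve API and a hypothetical port 3000:

use tokio::net::TcpListener;

#[tokio::main]
async fn main() {
    // Hypothetical wiring; the actual main.rs body is not shown in this diff.
    let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, create_router()).await.unwrap();
}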