211 lines
7.0 KiB
Rust
211 lines
7.0 KiB
Rust
use std::sync::Arc;
|
|
|
|
use rmcp::{
|
|
Error as McpError, RoleServer, ServerHandler, const_string, model::*, schemars,
|
|
service::RequestContext, tool,
|
|
};
|
|
use serde_json::json;
|
|
use tokio::sync::Mutex;
|
|
|
|
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
|
pub struct StructRequest {
|
|
pub a: i32,
|
|
pub b: i32,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct Counter {
|
|
counter: Arc<Mutex<i32>>,
|
|
}
|
|
|
|
#[tool(tool_box)]
|
|
impl Counter {
|
|
#[allow(dead_code)]
|
|
pub fn new() -> Self {
|
|
Self {
|
|
counter: Arc::new(Mutex::new(0)),
|
|
}
|
|
}
|
|
|
|
fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
|
|
RawResource::new(uri, name.to_string()).no_annotation()
|
|
}
|
|
|
|
#[tool(description = "Increment the counter by 1")]
|
|
async fn increment(&self) -> Result<CallToolResult, McpError> {
|
|
let mut counter = self.counter.lock().await;
|
|
*counter += 1;
|
|
Ok(CallToolResult::success(vec![Content::text(
|
|
counter.to_string(),
|
|
)]))
|
|
}
|
|
|
|
#[tool(description = "Decrement the counter by 1")]
|
|
async fn decrement(&self) -> Result<CallToolResult, McpError> {
|
|
let mut counter = self.counter.lock().await;
|
|
*counter -= 1;
|
|
Ok(CallToolResult::success(vec![Content::text(
|
|
counter.to_string(),
|
|
)]))
|
|
}
|
|
|
|
#[tool(description = "Get the current counter value")]
|
|
async fn get_value(&self) -> Result<CallToolResult, McpError> {
|
|
let counter = self.counter.lock().await;
|
|
Ok(CallToolResult::success(vec![Content::text(
|
|
counter.to_string(),
|
|
)]))
|
|
}
|
|
|
|
#[tool(description = "Say hello to the client")]
|
|
fn say_hello(&self) -> Result<CallToolResult, McpError> {
|
|
Ok(CallToolResult::success(vec![Content::text("hello")]))
|
|
}
|
|
|
|
#[tool(description = "Repeat what you say")]
|
|
fn echo(
|
|
&self,
|
|
#[tool(param)]
|
|
#[schemars(description = "Repeat what you say")]
|
|
saying: String,
|
|
) -> Result<CallToolResult, McpError> {
|
|
Ok(CallToolResult::success(vec![Content::text(saying)]))
|
|
}
|
|
|
|
#[tool(description = "Calculate the sum of two numbers")]
|
|
fn sum(
|
|
&self,
|
|
#[tool(aggr)] StructRequest { a, b }: StructRequest,
|
|
) -> Result<CallToolResult, McpError> {
|
|
Ok(CallToolResult::success(vec![Content::text(
|
|
(a + b).to_string(),
|
|
)]))
|
|
}
|
|
}
|
|
// NOTE(review): presumably declares a marker type `Echo` tied to the constant
// string "echo"; nothing else in this file refers to `Echo` — confirm it is
// still needed.
const_string!(Echo = "echo");
|
|
#[tool(tool_box)]
|
|
impl ServerHandler for Counter {
|
|
fn get_info(&self) -> ServerInfo {
|
|
ServerInfo {
|
|
protocol_version: ProtocolVersion::V_2024_11_05,
|
|
capabilities: ServerCapabilities::builder()
|
|
.enable_prompts()
|
|
.enable_resources()
|
|
.enable_tools()
|
|
.build(),
|
|
server_info: Implementation::from_build_env(),
|
|
instructions: Some("This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()),
|
|
}
|
|
}
|
|
|
|
async fn list_resources(
|
|
&self,
|
|
_request: Option<PaginatedRequestParam>,
|
|
_: RequestContext<RoleServer>,
|
|
) -> Result<ListResourcesResult, McpError> {
|
|
Ok(ListResourcesResult {
|
|
resources: vec![
|
|
self._create_resource_text("str:////Users/to/some/path/", "cwd"),
|
|
self._create_resource_text("memo://insights", "memo-name"),
|
|
],
|
|
next_cursor: None,
|
|
})
|
|
}
|
|
|
|
async fn read_resource(
|
|
&self,
|
|
ReadResourceRequestParam { uri }: ReadResourceRequestParam,
|
|
_: RequestContext<RoleServer>,
|
|
) -> Result<ReadResourceResult, McpError> {
|
|
match uri.as_str() {
|
|
"str:////Users/to/some/path/" => {
|
|
let cwd = "/Users/to/some/path/";
|
|
Ok(ReadResourceResult {
|
|
contents: vec![ResourceContents::text(cwd, uri)],
|
|
})
|
|
}
|
|
"memo://insights" => {
|
|
let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
|
|
Ok(ReadResourceResult {
|
|
contents: vec![ResourceContents::text(memo, uri)],
|
|
})
|
|
}
|
|
_ => Err(McpError::resource_not_found(
|
|
"resource_not_found",
|
|
Some(json!({
|
|
"uri": uri
|
|
})),
|
|
)),
|
|
}
|
|
}
|
|
|
|
async fn list_prompts(
|
|
&self,
|
|
_request: Option<PaginatedRequestParam>,
|
|
_: RequestContext<RoleServer>,
|
|
) -> Result<ListPromptsResult, McpError> {
|
|
Ok(ListPromptsResult {
|
|
next_cursor: None,
|
|
prompts: vec![Prompt::new(
|
|
"example_prompt",
|
|
Some("This is an example prompt that takes one required argument, message"),
|
|
Some(vec![PromptArgument {
|
|
name: "message".to_string(),
|
|
description: Some("A message to put in the prompt".to_string()),
|
|
required: Some(true),
|
|
}]),
|
|
)],
|
|
})
|
|
}
|
|
|
|
async fn get_prompt(
|
|
&self,
|
|
GetPromptRequestParam { name, arguments }: GetPromptRequestParam,
|
|
_: RequestContext<RoleServer>,
|
|
) -> Result<GetPromptResult, McpError> {
|
|
match name.as_str() {
|
|
"example_prompt" => {
|
|
let message = arguments
|
|
.and_then(|json| json.get("message")?.as_str().map(|s| s.to_string()))
|
|
.ok_or_else(|| {
|
|
McpError::invalid_params("No message provided to example_prompt", None)
|
|
})?;
|
|
|
|
let prompt =
|
|
format!("This is an example prompt with your message here: '{message}'");
|
|
Ok(GetPromptResult {
|
|
description: None,
|
|
messages: vec![PromptMessage {
|
|
role: PromptMessageRole::User,
|
|
content: PromptMessageContent::text(prompt),
|
|
}],
|
|
})
|
|
}
|
|
_ => Err(McpError::invalid_params("prompt not found", None)),
|
|
}
|
|
}
|
|
|
|
async fn list_resource_templates(
|
|
&self,
|
|
_request: Option<PaginatedRequestParam>,
|
|
_: RequestContext<RoleServer>,
|
|
) -> Result<ListResourceTemplatesResult, McpError> {
|
|
Ok(ListResourceTemplatesResult {
|
|
next_cursor: None,
|
|
resource_templates: Vec::new(),
|
|
})
|
|
}
|
|
|
|
async fn initialize(
|
|
&self,
|
|
_request: InitializeRequestParam,
|
|
context: RequestContext<RoleServer>,
|
|
) -> Result<InitializeResult, McpError> {
|
|
if let Some(http_request_part) = context.extensions.get::<axum::http::request::Parts>() {
|
|
let initialize_headers = &http_request_part.headers;
|
|
let initialize_uri = &http_request_part.uri;
|
|
tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server");
|
|
}
|
|
Ok(self.get_info())
|
|
}
|
|
} |