diff options
| author | Olcan <[email protected]> | 2025-05-30 15:32:21 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-05-30 15:32:21 -0700 |
| commit | 2e57989aec569055a11f21762f72b961377281ab (patch) | |
| tree | 0fb4a97f195801472b07ed1ba679c21c1151e867 /packages/server/src | |
| parent | a60e51f44d84b1a13c335e6b75339b8ad8c544bb (diff) | |
confirm mcp tool executions from untrusted servers (per "trust" setting) (#631)
Diffstat (limited to 'packages/server/src')
| -rw-r--r-- | packages/server/src/config/config.ts | 1 | ||||
| -rw-r--r-- | packages/server/src/tools/mcp-client.ts | 2 | ||||
| -rw-r--r-- | packages/server/src/tools/mcp-tool.test.ts | 5 | ||||
| -rw-r--r-- | packages/server/src/tools/mcp-tool.ts | 47 | ||||
| -rw-r--r-- | packages/server/src/tools/tool-registry.test.ts | 2 | ||||
| -rw-r--r-- | packages/server/src/tools/tools.ts | 20 |
6 files changed, 72 insertions, 5 deletions
diff --git a/packages/server/src/config/config.ts b/packages/server/src/config/config.ts index 9c03a5c1..0cd7a4fa 100644 --- a/packages/server/src/config/config.ts +++ b/packages/server/src/config/config.ts @@ -33,6 +33,7 @@ export class MCPServerConfig { readonly url?: string, // Common readonly timeout?: number, + readonly trust?: boolean, ) {} } diff --git a/packages/server/src/tools/mcp-client.ts b/packages/server/src/tools/mcp-client.ts index 3b55f5e3..97a73289 100644 --- a/packages/server/src/tools/mcp-client.ts +++ b/packages/server/src/tools/mcp-client.ts @@ -134,11 +134,13 @@ async function connectAndDiscover( toolRegistry.registerTool( new DiscoveredMCPTool( mcpClient, + mcpServerName, toolNameForModel, tool.description ?? '', tool.inputSchema, tool.name, mcpServerConfig.timeout, + mcpServerConfig.trust, ), ); } diff --git a/packages/server/src/tools/mcp-tool.test.ts b/packages/server/src/tools/mcp-tool.test.ts index e28cf586..331696f7 100644 --- a/packages/server/src/tools/mcp-tool.test.ts +++ b/packages/server/src/tools/mcp-tool.test.ts @@ -55,6 +55,7 @@ describe('DiscoveredMCPTool', () => { it('should set properties correctly and augment description', () => { const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, baseDescription, inputSchema, @@ -78,6 +79,7 @@ describe('DiscoveredMCPTool', () => { const customTimeout = 5000; const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, baseDescription, inputSchema, @@ -92,6 +94,7 @@ describe('DiscoveredMCPTool', () => { it('should call mcpClient.callTool with correct parameters and default timeout', async () => { const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, baseDescription, inputSchema, @@ -122,6 +125,7 @@ describe('DiscoveredMCPTool', () => { const customTimeout = 15000; const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, baseDescription, inputSchema, @@ -146,6 +150,7 @@ describe('DiscoveredMCPTool', () => { it('should propagate rejection if mcpClient.callTool rejects', async () => { const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, baseDescription, inputSchema, diff --git a/packages/server/src/tools/mcp-tool.ts b/packages/server/src/tools/mcp-tool.ts index 2a561179..80e6bbde 100644 --- a/packages/server/src/tools/mcp-tool.ts +++ b/packages/server/src/tools/mcp-tool.ts @@ -5,20 +5,30 @@ */ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { BaseTool, ToolResult } from './tools.js'; +import { + BaseTool, + ToolResult, + ToolCallConfirmationDetails, + ToolConfirmationOutcome, + ToolMcpConfirmationDetails, +} from './tools.js'; type ToolParams = Record<string, unknown>; export const MCP_TOOL_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> { + private static readonly whitelist: Set<string> = new Set(); + constructor( private readonly mcpClient: Client, + private readonly serverName: string, // Added for server identification readonly name: string, readonly description: string, readonly parameterSchema: Record<string, unknown>, readonly serverToolName: string, readonly timeout?: number, + readonly trust?: boolean, ) { description += ` @@ -37,6 +47,41 @@ Returns the MCP server response as a json string. ); } + async shouldConfirmExecute( + _params: ToolParams, + _abortSignal: AbortSignal, + ): Promise<ToolCallConfirmationDetails | false> { + const serverWhitelistKey = this.serverName; + const toolWhitelistKey = `${this.serverName}.${this.serverToolName}`; + + if (this.trust) { + return false; // server is trusted, no confirmation needed + } + + if ( + DiscoveredMCPTool.whitelist.has(serverWhitelistKey) || + DiscoveredMCPTool.whitelist.has(toolWhitelistKey) + ) { + return false; // server and/or tool already whitelisted + } + + const confirmationDetails: ToolMcpConfirmationDetails = { + type: 'mcp', + title: 'Confirm MCP Tool Execution', + serverName: this.serverName, + toolName: this.serverToolName, + toolDisplayName: this.name, + onConfirm: async (outcome: ToolConfirmationOutcome) => { + if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) { + DiscoveredMCPTool.whitelist.add(serverWhitelistKey); + } else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) { + DiscoveredMCPTool.whitelist.add(toolWhitelistKey); + } + }, + }; + return confirmationDetails; + } + async execute(params: ToolParams): Promise<ToolResult> { const result = await this.mcpClient.callTool( { diff --git a/packages/server/src/tools/tool-registry.test.ts b/packages/server/src/tools/tool-registry.test.ts index 6a960a27..c93109ae 100644 --- a/packages/server/src/tools/tool-registry.test.ts +++ b/packages/server/src/tools/tool-registry.test.ts @@ -729,6 +729,7 @@ describe('DiscoveredMCPTool', () => { it('constructor should set up properties correctly and enhance description', () => { const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, toolDescription, toolInputSchema, @@ -744,6 +745,7 @@ describe('DiscoveredMCPTool', () => { it('execute should call mcpClient.callTool with correct params and return serialized result', async () => { const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, toolDescription, toolInputSchema, diff --git a/packages/server/src/tools/tools.ts b/packages/server/src/tools/tools.ts index e5d0c7cf..a2e7fa06 100644 --- a/packages/server/src/tools/tools.ts +++ b/packages/server/src/tools/tools.ts @@ -212,12 +212,24 @@ export interface ToolExecuteConfirmationDetails { rootCommand: string; } +export interface ToolMcpConfirmationDetails { + type: 'mcp'; + title: string; + serverName: string; + toolName: string; + toolDisplayName: string; + onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void> | void; +} + export type ToolCallConfirmationDetails = | ToolEditConfirmationDetails - | ToolExecuteConfirmationDetails; + | ToolExecuteConfirmationDetails + | ToolMcpConfirmationDetails; export enum ToolConfirmationOutcome { - ProceedOnce, - ProceedAlways, - Cancel, + ProceedOnce = 'proceed_once', + ProceedAlways = 'proceed_always', + ProceedAlwaysServer = 'proceed_always_server', + ProceedAlwaysTool = 'proceed_always_tool', + Cancel = 'cancel', } |
