diff options
Diffstat (limited to 'packages/core/src/tools/mcp-tool.ts')
| -rw-r--r-- | packages/core/src/tools/mcp-tool.ts | 115 |
1 files changed, 77 insertions, 38 deletions
diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 59f83db3..01a8d75c 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -5,14 +5,16 @@ */ import { - BaseTool, - ToolResult, + BaseDeclarativeTool, + BaseToolInvocation, + Kind, ToolCallConfirmationDetails, ToolConfirmationOutcome, + ToolInvocation, ToolMcpConfirmationDetails, - Kind, + ToolResult, } from './tools.js'; -import { CallableTool, Part, FunctionCall } from '@google/genai'; +import { CallableTool, FunctionCall, Part } from '@google/genai'; type ToolParams = Record<string, unknown>; @@ -50,45 +52,25 @@ type McpContentBlock = | McpResourceBlock | McpResourceLinkBlock; -export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> { +class DiscoveredMCPToolInvocation extends BaseToolInvocation< + ToolParams, + ToolResult +> { private static readonly allowlist: Set<string> = new Set(); constructor( private readonly mcpTool: CallableTool, readonly serverName: string, readonly serverToolName: string, - description: string, - readonly parameterSchema: unknown, + readonly displayName: string, readonly timeout?: number, readonly trust?: boolean, - nameOverride?: string, + params: ToolParams = {}, ) { - super( - nameOverride ?? generateValidName(serverToolName), - `${serverToolName} (${serverName} MCP Server)`, - description, - Kind.Other, - parameterSchema, - true, // isOutputMarkdown - false, // canUpdateOutput - ); - } - - asFullyQualifiedTool(): DiscoveredMCPTool { - return new DiscoveredMCPTool( - this.mcpTool, - this.serverName, - this.serverToolName, - this.description, - this.parameterSchema, - this.timeout, - this.trust, - `${this.serverName}__${this.serverToolName}`, - ); + super(params); } async shouldConfirmExecute( - _params: ToolParams, _abortSignal: AbortSignal, ): Promise<ToolCallConfirmationDetails | false> { const serverAllowListKey = this.serverName; @@ -99,8 +81,8 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> { } if ( - DiscoveredMCPTool.allowlist.has(serverAllowListKey) || - DiscoveredMCPTool.allowlist.has(toolAllowListKey) + DiscoveredMCPToolInvocation.allowlist.has(serverAllowListKey) || + DiscoveredMCPToolInvocation.allowlist.has(toolAllowListKey) ) { return false; // server and/or tool already allowlisted } @@ -110,23 +92,23 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> { title: 'Confirm MCP Tool Execution', serverName: this.serverName, toolName: this.serverToolName, // Display original tool name in confirmation - toolDisplayName: this.name, // Display global registry name exposed to model and user + toolDisplayName: this.displayName, // Display global registry name exposed to model and user onConfirm: async (outcome: ToolConfirmationOutcome) => { if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) { - DiscoveredMCPTool.allowlist.add(serverAllowListKey); + DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey); } else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) { - DiscoveredMCPTool.allowlist.add(toolAllowListKey); + DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey); } }, }; return confirmationDetails; } - async execute(params: ToolParams): Promise<ToolResult> { + async execute(): Promise<ToolResult> { const functionCalls: FunctionCall[] = [ { name: this.serverToolName, - args: params, + args: this.params, }, ]; @@ -138,6 +120,63 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> { returnDisplay: getStringifiedResultForDisplay(rawResponseParts), }; } + + getDescription(): string { + return this.displayName; + } +} + +export class DiscoveredMCPTool extends BaseDeclarativeTool< + ToolParams, + ToolResult +> { + constructor( + private readonly mcpTool: CallableTool, + readonly serverName: string, + readonly serverToolName: string, + description: string, + readonly parameterSchema: unknown, + readonly timeout?: number, + readonly trust?: boolean, + nameOverride?: string, + ) { + super( + nameOverride ?? generateValidName(serverToolName), + `${serverToolName} (${serverName} MCP Server)`, + description, + Kind.Other, + parameterSchema, + true, // isOutputMarkdown + false, // canUpdateOutput + ); + } + + asFullyQualifiedTool(): DiscoveredMCPTool { + return new DiscoveredMCPTool( + this.mcpTool, + this.serverName, + this.serverToolName, + this.description, + this.parameterSchema, + this.timeout, + this.trust, + `${this.serverName}__${this.serverToolName}`, + ); + } + + protected createInvocation( + params: ToolParams, + ): ToolInvocation<ToolParams, ToolResult> { + return new DiscoveredMCPToolInvocation( + this.mcpTool, + this.serverName, + this.serverToolName, + this.displayName, + this.timeout, + this.trust, + params, + ); + } } function transformTextBlock(block: McpTextBlock): Part { |
