diff options
Diffstat (limited to 'packages/core/src/tools/mcp-tool.test.ts')
| -rw-r--r-- | packages/core/src/tools/mcp-tool.test.ts | 317 |
1 files changed, 238 insertions, 79 deletions
diff --git a/packages/core/src/tools/mcp-tool.test.ts b/packages/core/src/tools/mcp-tool.test.ts index 5c784c5d..86968b3d 100644 --- a/packages/core/src/tools/mcp-tool.test.ts +++ b/packages/core/src/tools/mcp-tool.test.ts @@ -14,37 +14,37 @@ import { afterEach, Mocked, } from 'vitest'; -import { - DiscoveredMCPTool, - MCP_TOOL_DEFAULT_TIMEOUT_MSEC, -} from './mcp-tool.js'; -import { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { ToolResult } from './tools.js'; +import { DiscoveredMCPTool } from './mcp-tool.js'; // Added getStringifiedResultForDisplay +import { ToolResult, ToolConfirmationOutcome } from './tools.js'; // Added ToolConfirmationOutcome +import { CallableTool, Part } from '@google/genai'; -// Mock MCP SDK Client -vi.mock('@modelcontextprotocol/sdk/client/index.js', () => { - const MockClient = vi.fn(); - MockClient.prototype.callTool = vi.fn(); - return { Client: MockClient }; -}); +// Mock @google/genai mcpToTool and CallableTool +// We only need to mock the parts of CallableTool that DiscoveredMCPTool uses. +const mockCallTool = vi.fn(); +const mockToolMethod = vi.fn(); + +const mockCallableToolInstance: Mocked<CallableTool> = { + tool: mockToolMethod as any, // Not directly used by DiscoveredMCPTool instance methods + callTool: mockCallTool as any, + // Add other methods if DiscoveredMCPTool starts using them +}; describe('DiscoveredMCPTool', () => { - let mockMcpClient: Mocked<Client>; - const toolName = 'test-mcp-tool'; + const serverName = 'mock-mcp-server'; + const toolNameForModel = 'test-mcp-tool-for-model'; const serverToolName = 'actual-server-tool-name'; const baseDescription = 'A test MCP tool.'; - const inputSchema = { + const inputSchema: Record<string, unknown> = { type: 'object' as const, properties: { param: { type: 'string' } }, + required: ['param'], }; beforeEach(() => { - // Create a new mock client for each test to reset call history - mockMcpClient = new (Client as any)({ - name: 'test-client', - version: '0.0.1', - }) as Mocked<Client>; - vi.mocked(mockMcpClient.callTool).mockClear(); + mockCallTool.mockClear(); + mockToolMethod.mockClear(); + // Clear allowlist before each relevant test, especially for shouldConfirmExecute + (DiscoveredMCPTool as any).allowlist.clear(); }); afterEach(() => { @@ -52,35 +52,45 @@ describe('DiscoveredMCPTool', () => { }); describe('constructor', () => { - it('should set properties correctly and augment description', () => { + it('should set properties correctly and augment description (non-generic server)', () => { const tool = new DiscoveredMCPTool( - mockMcpClient, - 'mock-mcp-server', - toolName, + mockCallableToolInstance, + serverName, // serverName is 'mock-mcp-server', not 'mcp' + toolNameForModel, baseDescription, inputSchema, serverToolName, ); - expect(tool.name).toBe(toolName); - expect(tool.schema.name).toBe(toolName); - expect(tool.schema.description).toContain(baseDescription); - expect(tool.schema.description).toContain('This MCP tool was discovered'); - // Corrected assertion for backticks and template literal - expect(tool.schema.description).toContain( - `tools/call\` method for tool name \`${toolName}\``, - ); + expect(tool.name).toBe(toolNameForModel); + expect(tool.schema.name).toBe(toolNameForModel); + const expectedDescription = `${baseDescription}\n\nThis MCP tool named '${serverToolName}' was discovered from an MCP server.`; + expect(tool.schema.description).toBe(expectedDescription); expect(tool.schema.parameters).toEqual(inputSchema); expect(tool.serverToolName).toBe(serverToolName); expect(tool.timeout).toBeUndefined(); }); + it('should set properties correctly and augment description (generic "mcp" server)', () => { + const genericServerName = 'mcp'; + const tool = new DiscoveredMCPTool( + mockCallableToolInstance, + genericServerName, // serverName is 'mcp' + toolNameForModel, + baseDescription, + inputSchema, + serverToolName, + ); + const expectedDescription = `${baseDescription}\n\nThis MCP tool named '${serverToolName}' was discovered from '${genericServerName}' MCP server.`; + expect(tool.schema.description).toBe(expectedDescription); + }); + it('should accept and store a custom timeout', () => { const customTimeout = 5000; const tool = new DiscoveredMCPTool( - mockMcpClient, - 'mock-mcp-server', - toolName, + mockCallableToolInstance, + serverName, + toolNameForModel, baseDescription, inputSchema, serverToolName, @@ -91,77 +101,226 @@ describe('DiscoveredMCPTool', () => { }); describe('execute', () => { - it('should call mcpClient.callTool with correct parameters and default timeout', async () => { + it('should call mcpTool.callTool with correct parameters and format display output', async () => { const tool = new DiscoveredMCPTool( - mockMcpClient, - 'mock-mcp-server', - toolName, + mockCallableToolInstance, + serverName, + toolNameForModel, baseDescription, inputSchema, serverToolName, ); const params = { param: 'testValue' }; - const expectedMcpResult = { success: true, details: 'executed' }; - vi.mocked(mockMcpClient.callTool).mockResolvedValue(expectedMcpResult); - - const result: ToolResult = await tool.execute(params); - - expect(mockMcpClient.callTool).toHaveBeenCalledWith( + const mockToolSuccessResultObject = { + success: true, + details: 'executed', + }; + const mockFunctionResponseContent: Part[] = [ + { text: JSON.stringify(mockToolSuccessResultObject) }, + ]; + const mockMcpToolResponseParts: Part[] = [ { - name: serverToolName, - arguments: params, - }, - undefined, - { - timeout: MCP_TOOL_DEFAULT_TIMEOUT_MSEC, + functionResponse: { + name: serverToolName, + response: { content: mockFunctionResponseContent }, + }, }, + ]; + mockCallTool.mockResolvedValue(mockMcpToolResponseParts); + + const toolResult: ToolResult = await tool.execute(params); + + expect(mockCallTool).toHaveBeenCalledWith([ + { name: serverToolName, args: params }, + ]); + expect(toolResult.llmContent).toEqual(mockMcpToolResponseParts); + + const stringifiedResponseContent = JSON.stringify( + mockToolSuccessResultObject, ); - const expectedOutput = - '```json\n' + JSON.stringify(expectedMcpResult, null, 2) + '\n```'; - expect(result.llmContent).toBe(expectedOutput); - expect(result.returnDisplay).toBe(expectedOutput); + // getStringifiedResultForDisplay joins text parts, then wraps the array of processed parts in JSON + const expectedDisplayOutput = + '```json\n' + + JSON.stringify([stringifiedResponseContent], null, 2) + + '\n```'; + expect(toolResult.returnDisplay).toBe(expectedDisplayOutput); }); - it('should call mcpClient.callTool with custom timeout if provided', async () => { - const customTimeout = 15000; + it('should handle empty result from getStringifiedResultForDisplay', async () => { const tool = new DiscoveredMCPTool( - mockMcpClient, - 'mock-mcp-server', - toolName, + mockCallableToolInstance, + serverName, + toolNameForModel, + baseDescription, + inputSchema, + serverToolName, + ); + const params = { param: 'testValue' }; + const mockMcpToolResponsePartsEmpty: Part[] = []; + mockCallTool.mockResolvedValue(mockMcpToolResponsePartsEmpty); + const toolResult: ToolResult = await tool.execute(params); + expect(toolResult.returnDisplay).toBe('```json\n[]\n```'); + }); + + it('should propagate rejection if mcpTool.callTool rejects', async () => { + const tool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + toolNameForModel, baseDescription, inputSchema, serverToolName, - customTimeout, ); - const params = { param: 'anotherValue' }; - const expectedMcpResult = { result: 'done' }; - vi.mocked(mockMcpClient.callTool).mockResolvedValue(expectedMcpResult); + const params = { param: 'failCase' }; + const expectedError = new Error('MCP call failed'); + mockCallTool.mockRejectedValue(expectedError); - await tool.execute(params); + await expect(tool.execute(params)).rejects.toThrow(expectedError); + }); + }); - expect(mockMcpClient.callTool).toHaveBeenCalledWith( - expect.anything(), + describe('shouldConfirmExecute', () => { + // beforeEach is already clearing allowlist + + it('should return false if trust is true', async () => { + const tool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + toolNameForModel, + baseDescription, + inputSchema, + serverToolName, undefined, - { - timeout: customTimeout, - }, + true, ); + expect( + await tool.shouldConfirmExecute({}, new AbortController().signal), + ).toBe(false); }); - it('should propagate rejection if mcpClient.callTool rejects', async () => { + it('should return false if server is allowlisted', async () => { + (DiscoveredMCPTool as any).allowlist.add(serverName); const tool = new DiscoveredMCPTool( - mockMcpClient, - 'mock-mcp-server', - toolName, + mockCallableToolInstance, + serverName, + toolNameForModel, baseDescription, inputSchema, serverToolName, ); - const params = { param: 'failCase' }; - const expectedError = new Error('MCP call failed'); - vi.mocked(mockMcpClient.callTool).mockRejectedValue(expectedError); + expect( + await tool.shouldConfirmExecute({}, new AbortController().signal), + ).toBe(false); + }); - await expect(tool.execute(params)).rejects.toThrow(expectedError); + it('should return false if tool is allowlisted', async () => { + const toolAllowlistKey = `${serverName}.${serverToolName}`; + (DiscoveredMCPTool as any).allowlist.add(toolAllowlistKey); + const tool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + toolNameForModel, + baseDescription, + inputSchema, + serverToolName, + ); + expect( + await tool.shouldConfirmExecute({}, new AbortController().signal), + ).toBe(false); + }); + + it('should return confirmation details if not trusted and not allowlisted', async () => { + const tool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + toolNameForModel, + baseDescription, + inputSchema, + serverToolName, + ); + const confirmation = await tool.shouldConfirmExecute( + {}, + new AbortController().signal, + ); + expect(confirmation).not.toBe(false); + if (confirmation && confirmation.type === 'mcp') { + // Type guard for ToolMcpConfirmationDetails + expect(confirmation.type).toBe('mcp'); + expect(confirmation.serverName).toBe(serverName); + expect(confirmation.toolName).toBe(serverToolName); + } else if (confirmation) { + // Handle other possible confirmation types if necessary, or strengthen test if only MCP is expected + throw new Error( + 'Confirmation was not of expected type MCP or was false', + ); + } else { + throw new Error( + 'Confirmation details not in expected format or was false', + ); + } + }); + + it('should add server to allowlist on ProceedAlwaysServer', async () => { + const tool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + toolNameForModel, + baseDescription, + inputSchema, + serverToolName, + ); + const confirmation = await tool.shouldConfirmExecute( + {}, + new AbortController().signal, + ); + expect(confirmation).not.toBe(false); + if ( + confirmation && + typeof confirmation === 'object' && + 'onConfirm' in confirmation && + typeof confirmation.onConfirm === 'function' + ) { + await confirmation.onConfirm( + ToolConfirmationOutcome.ProceedAlwaysServer, + ); + expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(true); + } else { + throw new Error( + 'Confirmation details or onConfirm not in expected format', + ); + } + }); + + it('should add tool to allowlist on ProceedAlwaysTool', async () => { + const tool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + toolNameForModel, + baseDescription, + inputSchema, + serverToolName, + ); + const toolAllowlistKey = `${serverName}.${serverToolName}`; + const confirmation = await tool.shouldConfirmExecute( + {}, + new AbortController().signal, + ); + expect(confirmation).not.toBe(false); + if ( + confirmation && + typeof confirmation === 'object' && + 'onConfirm' in confirmation && + typeof confirmation.onConfirm === 'function' + ) { + await confirmation.onConfirm(ToolConfirmationOutcome.ProceedAlwaysTool); + expect((DiscoveredMCPTool as any).allowlist.has(toolAllowlistKey)).toBe( + true, + ); + } else { + throw new Error( + 'Confirmation details or onConfirm not in expected format', + ); + } }); }); }); |
