diff options
| author | N. Taylor Mullen <[email protected]> | 2025-06-02 13:39:25 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-06-02 20:39:25 +0000 |
| commit | 58597c29d30eb0d95e1792f02eb7f1e7edc4218a (patch) | |
| tree | 2dfb528ab008e454422fc27c941aa7aa925ec5d7 /packages/core/src/tools/mcp-client.test.ts | |
| parent | 0795e55f0e7d2f2822bcd83eaf066eb99c67f858 (diff) | |
refactor: Update MCP tool discovery to use @google/genai - Also fixes JSON schema issues. (#682)
Diffstat (limited to 'packages/core/src/tools/mcp-client.test.ts')
| -rw-r--r-- | packages/core/src/tools/mcp-client.test.ts | 193 |
1 files changed, 158 insertions, 35 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 4664669d..121cd1d8 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -16,7 +16,6 @@ import { } from 'vitest'; import { discoverMcpTools } from './mcp-client.js'; import { Config, MCPServerConfig } from '../config/config.js'; -import { ToolRegistry } from './tool-registry.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; @@ -51,33 +50,56 @@ vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => { // Always return a new object with a fresh reference to the global mock for .on this.options = options; this.stderr = { on: mockGlobalStdioStderrOn }; + this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method return this; }); return { StdioClientTransport: MockedStdioTransport }; }); vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => { - const MockedSSETransport = vi.fn(); + const MockedSSETransport = vi.fn().mockImplementation(function (this: any) { + this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method + return this; + }); return { SSEClientTransport: MockedSSETransport }; }); -vi.mock('./tool-registry.js'); +const mockToolRegistryInstance = { + registerTool: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array + // Add other methods if they are called by the code under test, with default mocks + getTool: vi.fn(), + getAllTools: vi.fn().mockReturnValue([]), + getFunctionDeclarations: vi.fn().mockReturnValue([]), + discoverTools: vi.fn().mockResolvedValue(undefined), +}; +vi.mock('./tool-registry.js', () => ({ + ToolRegistry: vi.fn(() => mockToolRegistryInstance), +})); describe('discoverMcpTools', () => { let mockConfig: Mocked<Config>; - let mockToolRegistry: Mocked<ToolRegistry>; + // Use the instance from the module mock + let mockToolRegistry: typeof mockToolRegistryInstance; beforeEach(() => { + // Assign the shared mock instance to the test-scoped variable + mockToolRegistry = mockToolRegistryInstance; + // Reset individual spies on the shared instance before each test + mockToolRegistry.registerTool.mockClear(); + mockToolRegistry.getToolsByServer.mockClear().mockReturnValue([]); // Reset to default + mockToolRegistry.getTool.mockClear().mockReturnValue(undefined); // Default to no existing tool + mockToolRegistry.getAllTools.mockClear().mockReturnValue([]); + mockToolRegistry.getFunctionDeclarations.mockClear().mockReturnValue([]); + mockToolRegistry.discoverTools.mockClear().mockResolvedValue(undefined); + mockConfig = { getMcpServers: vi.fn().mockReturnValue({}), getMcpServerCommand: vi.fn().mockReturnValue(undefined), + // getToolRegistry should now return the same shared mock instance + getToolRegistry: vi.fn(() => mockToolRegistry), } as any; - mockToolRegistry = new (ToolRegistry as any)( - mockConfig, - ) as Mocked<ToolRegistry>; - mockToolRegistry.registerTool = vi.fn(); - vi.mocked(parse).mockClear(); vi.mocked(Client).mockClear(); vi.mocked(Client.prototype.connect) @@ -88,9 +110,24 @@ describe('discoverMcpTools', () => { .mockResolvedValue({ tools: [] }); vi.mocked(StdioClientTransport).mockClear(); + // Ensure the StdioClientTransport mock constructor returns an object with a close method + vi.mocked(StdioClientTransport).mockImplementation(function ( + this: any, + options: any, + ) { + this.options = options; + this.stderr = { on: mockGlobalStdioStderrOn }; + this.close = vi.fn().mockResolvedValue(undefined); + return this; + }); mockGlobalStdioStderrOn.mockClear(); // Clear the global mock in beforeEach vi.mocked(SSEClientTransport).mockClear(); + // Ensure the SSEClientTransport mock constructor returns an object with a close method + vi.mocked(SSEClientTransport).mockImplementation(function (this: any) { + this.close = vi.fn().mockResolvedValue(undefined); + return this; + }); }); afterEach(() => { @@ -98,7 +135,7 @@ describe('discoverMcpTools', () => { }); it('should do nothing if no MCP servers or command are configured', async () => { - await discoverMcpTools(mockConfig, mockToolRegistry); + await discoverMcpTools(mockConfig); expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1); expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1); expect(Client).not.toHaveBeenCalled(); @@ -120,7 +157,11 @@ describe('discoverMcpTools', () => { tools: [mockTool], }); - await discoverMcpTools(mockConfig, mockToolRegistry); + // PRE-MOCK getToolsByServer for the expected server name + // In this case, listTools fails, so no tools are registered. + // The default mock `mockReturnValue([])` from beforeEach should apply. + + await discoverMcpTools(mockConfig); expect(parse).toHaveBeenCalledWith(commandString, process.env); expect(StdioClientTransport).toHaveBeenCalledWith({ @@ -158,7 +199,12 @@ describe('discoverMcpTools', () => { tools: [mockTool], }); - await discoverMcpTools(mockConfig, mockToolRegistry); + // PRE-MOCK getToolsByServer for the expected server name + mockToolRegistry.getToolsByServer.mockReturnValueOnce([ + expect.any(DiscoveredMCPTool), + ]); + + await discoverMcpTools(mockConfig); expect(StdioClientTransport).toHaveBeenCalledWith({ command: serverConfig.command, @@ -188,7 +234,12 @@ describe('discoverMcpTools', () => { tools: [mockTool], }); - await discoverMcpTools(mockConfig, mockToolRegistry); + // PRE-MOCK getToolsByServer for the expected server name + mockToolRegistry.getToolsByServer.mockReturnValueOnce([ + expect.any(DiscoveredMCPTool), + ]); + + await discoverMcpTools(mockConfig); expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!)); expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( @@ -208,32 +259,96 @@ describe('discoverMcpTools', () => { }); const mockTool1 = { - name: 'toolA', + name: 'toolA', // Same original name description: 'd1', inputSchema: { type: 'object' as const, properties: {} }, }; const mockTool2 = { - name: 'toolB', + name: 'toolA', // Same original name description: 'd2', inputSchema: { type: 'object' as const, properties: {} }, }; + const mockToolB = { + name: 'toolB', + description: 'dB', + inputSchema: { type: 'object' as const, properties: {} }, + }; vi.mocked(Client.prototype.listTools) - .mockResolvedValueOnce({ tools: [mockTool1] }) - .mockResolvedValueOnce({ tools: [mockTool2] }); + .mockResolvedValueOnce({ tools: [mockTool1, mockToolB] }) // Tools for server1 + .mockResolvedValueOnce({ tools: [mockTool2] }); // Tool for server2 (toolA) - await discoverMcpTools(mockConfig, mockToolRegistry); + const effectivelyRegisteredTools = new Map<string, any>(); - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2); - const registeredTool1 = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - const registeredTool2 = mockToolRegistry.registerTool.mock - .calls[1][0] as DiscoveredMCPTool; + mockToolRegistry.getTool.mockImplementation((toolName: string) => + effectivelyRegisteredTools.get(toolName), + ); - expect(registeredTool1.name).toBe('server1__toolA'); - expect(registeredTool1.serverToolName).toBe('toolA'); - expect(registeredTool2.name).toBe('server2__toolB'); - expect(registeredTool2.serverToolName).toBe('toolB'); + // Store the original spy implementation if needed, or just let the new one be the behavior. + // The mockToolRegistry.registerTool is already a vi.fn() from mockToolRegistryInstance. + // We are setting its behavior for this test. + mockToolRegistry.registerTool.mockImplementation((toolToRegister: any) => { + // Simulate the actual registration name being stored for getTool to find + effectivelyRegisteredTools.set(toolToRegister.name, toolToRegister); + // If it's the first time toolA is registered (from server1, not prefixed), + // also make it findable by its original name for the prefixing check of server2/toolA. + if ( + toolToRegister.serverName === 'server1' && + toolToRegister.serverToolName === 'toolA' && + toolToRegister.name === 'toolA' + ) { + effectivelyRegisteredTools.set('toolA', toolToRegister); + } + // The spy call count is inherently tracked by mockToolRegistry.registerTool itself. + }); + + // PRE-MOCK getToolsByServer for the expected server names + // This is for the final check in connectAndDiscover to see if any tools were registered *from that server* + mockToolRegistry.getToolsByServer.mockImplementation( + (serverName: string) => { + if (serverName === 'server1') + return [ + expect.objectContaining({ name: 'toolA' }), + expect.objectContaining({ name: 'toolB' }), + ]; + if (serverName === 'server2') + return [expect.objectContaining({ name: 'server2__toolA' })]; + return []; + }, + ); + + await discoverMcpTools(mockConfig); + + expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3); + const registeredArgs = mockToolRegistry.registerTool.mock.calls.map( + (call) => call[0], + ) as DiscoveredMCPTool[]; + + // The order of server processing by Promise.all is not guaranteed. + // One 'toolA' will be unprefixed, the other will be prefixed. + const toolA_from_server1 = registeredArgs.find( + (t) => t.serverToolName === 'toolA' && t.serverName === 'server1', + ); + const toolA_from_server2 = registeredArgs.find( + (t) => t.serverToolName === 'toolA' && t.serverName === 'server2', + ); + const toolB_from_server1 = registeredArgs.find( + (t) => t.serverToolName === 'toolB' && t.serverName === 'server1', + ); + + expect(toolA_from_server1).toBeDefined(); + expect(toolA_from_server2).toBeDefined(); + expect(toolB_from_server1).toBeDefined(); + + expect(toolB_from_server1?.name).toBe('toolB'); // toolB is unique + + // Check that one of toolA is prefixed and the other is not, and the prefixed one is correct. + if (toolA_from_server1?.name === 'toolA') { + expect(toolA_from_server2?.name).toBe('server2__toolA'); + } else { + expect(toolA_from_server1?.name).toBe('server1__toolA'); + expect(toolA_from_server2?.name).toBe('toolA'); + } }); it('should clean schema properties ($schema, additionalProperties)', async () => { @@ -261,8 +376,12 @@ describe('discoverMcpTools', () => { vi.mocked(Client.prototype.listTools).mockResolvedValue({ tools: [mockTool], }); + // PRE-MOCK getToolsByServer for the expected server name + mockToolRegistry.getToolsByServer.mockReturnValueOnce([ + expect.any(DiscoveredMCPTool), + ]); - await discoverMcpTools(mockConfig, mockToolRegistry); + await discoverMcpTools(mockConfig); expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); const registeredTool = mockToolRegistry.registerTool.mock @@ -291,9 +410,9 @@ describe('discoverMcpTools', () => { }); vi.spyOn(console, 'error').mockImplementation(() => {}); - await expect( - discoverMcpTools(mockConfig, mockToolRegistry), - ).rejects.toThrow('Parsing failed'); + await expect(discoverMcpTools(mockConfig)).rejects.toThrow( + 'Parsing failed', + ); expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); expect(console.error).not.toHaveBeenCalled(); }); @@ -302,7 +421,7 @@ describe('discoverMcpTools', () => { mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any }); vi.spyOn(console, 'error').mockImplementation(() => {}); - await discoverMcpTools(mockConfig, mockToolRegistry); + await discoverMcpTools(mockConfig); expect(console.error).toHaveBeenCalledWith( expect.stringContaining( @@ -323,7 +442,7 @@ describe('discoverMcpTools', () => { ); vi.spyOn(console, 'error').mockImplementation(() => {}); - await discoverMcpTools(mockConfig, mockToolRegistry); + await discoverMcpTools(mockConfig); expect(console.error).toHaveBeenCalledWith( expect.stringContaining( @@ -344,7 +463,7 @@ describe('discoverMcpTools', () => { ); vi.spyOn(console, 'error').mockImplementation(() => {}); - await discoverMcpTools(mockConfig, mockToolRegistry); + await discoverMcpTools(mockConfig); expect(console.error).toHaveBeenCalledWith( expect.stringContaining( @@ -359,8 +478,12 @@ describe('discoverMcpTools', () => { mockConfig.getMcpServers.mockReturnValue({ 'onerror-server': serverConfig, }); + // PRE-MOCK getToolsByServer for the expected server name + mockToolRegistry.getToolsByServer.mockReturnValueOnce([ + expect.any(DiscoveredMCPTool), + ]); - await discoverMcpTools(mockConfig, mockToolRegistry); + await discoverMcpTools(mockConfig); const clientInstances = vi.mocked(Client).mock.results; expect(clientInstances.length).toBeGreaterThan(0); |
