summaryrefslogtreecommitdiff
path: root/packages/core/src/tools/mcp-client.test.ts
diff options
context:
space:
mode:
authorN. Taylor Mullen <[email protected]>2025-06-02 13:39:25 -0700
committerGitHub <[email protected]>2025-06-02 20:39:25 +0000
commit58597c29d30eb0d95e1792f02eb7f1e7edc4218a (patch)
tree2dfb528ab008e454422fc27c941aa7aa925ec5d7 /packages/core/src/tools/mcp-client.test.ts
parent0795e55f0e7d2f2822bcd83eaf066eb99c67f858 (diff)
refactor: Update MCP tool discovery to use @google/genai - Also fixes JSON schema issues. (#682)
Diffstat (limited to 'packages/core/src/tools/mcp-client.test.ts')
-rw-r--r--packages/core/src/tools/mcp-client.test.ts193
1 files changed, 158 insertions, 35 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index 4664669d..121cd1d8 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -16,7 +16,6 @@ import {
} from 'vitest';
import { discoverMcpTools } from './mcp-client.js';
import { Config, MCPServerConfig } from '../config/config.js';
-import { ToolRegistry } from './tool-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -51,33 +50,56 @@ vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
// Always return a new object with a fresh reference to the global mock for .on
this.options = options;
this.stderr = { on: mockGlobalStdioStderrOn };
+ this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
return this;
});
return { StdioClientTransport: MockedStdioTransport };
});
vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
- const MockedSSETransport = vi.fn();
+ const MockedSSETransport = vi.fn().mockImplementation(function (this: any) {
+ this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
+ return this;
+ });
return { SSEClientTransport: MockedSSETransport };
});
-vi.mock('./tool-registry.js');
+const mockToolRegistryInstance = {
+ registerTool: vi.fn(),
+ getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array
+ // Add other methods if they are called by the code under test, with default mocks
+ getTool: vi.fn(),
+ getAllTools: vi.fn().mockReturnValue([]),
+ getFunctionDeclarations: vi.fn().mockReturnValue([]),
+ discoverTools: vi.fn().mockResolvedValue(undefined),
+};
+vi.mock('./tool-registry.js', () => ({
+ ToolRegistry: vi.fn(() => mockToolRegistryInstance),
+}));
describe('discoverMcpTools', () => {
let mockConfig: Mocked<Config>;
- let mockToolRegistry: Mocked<ToolRegistry>;
+ // Use the instance from the module mock
+ let mockToolRegistry: typeof mockToolRegistryInstance;
beforeEach(() => {
+ // Assign the shared mock instance to the test-scoped variable
+ mockToolRegistry = mockToolRegistryInstance;
+ // Reset individual spies on the shared instance before each test
+ mockToolRegistry.registerTool.mockClear();
+ mockToolRegistry.getToolsByServer.mockClear().mockReturnValue([]); // Reset to default
+ mockToolRegistry.getTool.mockClear().mockReturnValue(undefined); // Default to no existing tool
+ mockToolRegistry.getAllTools.mockClear().mockReturnValue([]);
+ mockToolRegistry.getFunctionDeclarations.mockClear().mockReturnValue([]);
+ mockToolRegistry.discoverTools.mockClear().mockResolvedValue(undefined);
+
mockConfig = {
getMcpServers: vi.fn().mockReturnValue({}),
getMcpServerCommand: vi.fn().mockReturnValue(undefined),
+ // getToolRegistry should now return the same shared mock instance
+ getToolRegistry: vi.fn(() => mockToolRegistry),
} as any;
- mockToolRegistry = new (ToolRegistry as any)(
- mockConfig,
- ) as Mocked<ToolRegistry>;
- mockToolRegistry.registerTool = vi.fn();
-
vi.mocked(parse).mockClear();
vi.mocked(Client).mockClear();
vi.mocked(Client.prototype.connect)
@@ -88,9 +110,24 @@ describe('discoverMcpTools', () => {
.mockResolvedValue({ tools: [] });
vi.mocked(StdioClientTransport).mockClear();
+ // Ensure the StdioClientTransport mock constructor returns an object with a close method
+ vi.mocked(StdioClientTransport).mockImplementation(function (
+ this: any,
+ options: any,
+ ) {
+ this.options = options;
+ this.stderr = { on: mockGlobalStdioStderrOn };
+ this.close = vi.fn().mockResolvedValue(undefined);
+ return this;
+ });
mockGlobalStdioStderrOn.mockClear(); // Clear the global mock in beforeEach
vi.mocked(SSEClientTransport).mockClear();
+ // Ensure the SSEClientTransport mock constructor returns an object with a close method
+ vi.mocked(SSEClientTransport).mockImplementation(function (this: any) {
+ this.close = vi.fn().mockResolvedValue(undefined);
+ return this;
+ });
});
afterEach(() => {
@@ -98,7 +135,7 @@ describe('discoverMcpTools', () => {
});
it('should do nothing if no MCP servers or command are configured', async () => {
- await discoverMcpTools(mockConfig, mockToolRegistry);
+ await discoverMcpTools(mockConfig);
expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1);
expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1);
expect(Client).not.toHaveBeenCalled();
@@ -120,7 +157,11 @@ describe('discoverMcpTools', () => {
tools: [mockTool],
});
- await discoverMcpTools(mockConfig, mockToolRegistry);
+ // PRE-MOCK getToolsByServer for the expected server name
+ // In this case, listTools fails, so no tools are registered.
+ // The default mock `mockReturnValue([])` from beforeEach should apply.
+
+ await discoverMcpTools(mockConfig);
expect(parse).toHaveBeenCalledWith(commandString, process.env);
expect(StdioClientTransport).toHaveBeenCalledWith({
@@ -158,7 +199,12 @@ describe('discoverMcpTools', () => {
tools: [mockTool],
});
- await discoverMcpTools(mockConfig, mockToolRegistry);
+ // PRE-MOCK getToolsByServer for the expected server name
+ mockToolRegistry.getToolsByServer.mockReturnValueOnce([
+ expect.any(DiscoveredMCPTool),
+ ]);
+
+ await discoverMcpTools(mockConfig);
expect(StdioClientTransport).toHaveBeenCalledWith({
command: serverConfig.command,
@@ -188,7 +234,12 @@ describe('discoverMcpTools', () => {
tools: [mockTool],
});
- await discoverMcpTools(mockConfig, mockToolRegistry);
+ // PRE-MOCK getToolsByServer for the expected server name
+ mockToolRegistry.getToolsByServer.mockReturnValueOnce([
+ expect.any(DiscoveredMCPTool),
+ ]);
+
+ await discoverMcpTools(mockConfig);
expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!));
expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
@@ -208,32 +259,96 @@ describe('discoverMcpTools', () => {
});
const mockTool1 = {
- name: 'toolA',
+ name: 'toolA', // Same original name
description: 'd1',
inputSchema: { type: 'object' as const, properties: {} },
};
const mockTool2 = {
- name: 'toolB',
+ name: 'toolA', // Same original name
description: 'd2',
inputSchema: { type: 'object' as const, properties: {} },
};
+ const mockToolB = {
+ name: 'toolB',
+ description: 'dB',
+ inputSchema: { type: 'object' as const, properties: {} },
+ };
vi.mocked(Client.prototype.listTools)
- .mockResolvedValueOnce({ tools: [mockTool1] })
- .mockResolvedValueOnce({ tools: [mockTool2] });
+ .mockResolvedValueOnce({ tools: [mockTool1, mockToolB] }) // Tools for server1
+ .mockResolvedValueOnce({ tools: [mockTool2] }); // Tool for server2 (toolA)
- await discoverMcpTools(mockConfig, mockToolRegistry);
+ const effectivelyRegisteredTools = new Map<string, any>();
- expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2);
- const registeredTool1 = mockToolRegistry.registerTool.mock
- .calls[0][0] as DiscoveredMCPTool;
- const registeredTool2 = mockToolRegistry.registerTool.mock
- .calls[1][0] as DiscoveredMCPTool;
+ mockToolRegistry.getTool.mockImplementation((toolName: string) =>
+ effectivelyRegisteredTools.get(toolName),
+ );
- expect(registeredTool1.name).toBe('server1__toolA');
- expect(registeredTool1.serverToolName).toBe('toolA');
- expect(registeredTool2.name).toBe('server2__toolB');
- expect(registeredTool2.serverToolName).toBe('toolB');
+ // Store the original spy implementation if needed, or just let the new one be the behavior.
+ // The mockToolRegistry.registerTool is already a vi.fn() from mockToolRegistryInstance.
+ // We are setting its behavior for this test.
+ mockToolRegistry.registerTool.mockImplementation((toolToRegister: any) => {
+ // Simulate the actual registration name being stored for getTool to find
+ effectivelyRegisteredTools.set(toolToRegister.name, toolToRegister);
+ // If it's the first time toolA is registered (from server1, not prefixed),
+ // also make it findable by its original name for the prefixing check of server2/toolA.
+ if (
+ toolToRegister.serverName === 'server1' &&
+ toolToRegister.serverToolName === 'toolA' &&
+ toolToRegister.name === 'toolA'
+ ) {
+ effectivelyRegisteredTools.set('toolA', toolToRegister);
+ }
+ // The spy call count is inherently tracked by mockToolRegistry.registerTool itself.
+ });
+
+ // PRE-MOCK getToolsByServer for the expected server names
+ // This is for the final check in connectAndDiscover to see if any tools were registered *from that server*
+ mockToolRegistry.getToolsByServer.mockImplementation(
+ (serverName: string) => {
+ if (serverName === 'server1')
+ return [
+ expect.objectContaining({ name: 'toolA' }),
+ expect.objectContaining({ name: 'toolB' }),
+ ];
+ if (serverName === 'server2')
+ return [expect.objectContaining({ name: 'server2__toolA' })];
+ return [];
+ },
+ );
+
+ await discoverMcpTools(mockConfig);
+
+ expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3);
+ const registeredArgs = mockToolRegistry.registerTool.mock.calls.map(
+ (call) => call[0],
+ ) as DiscoveredMCPTool[];
+
+ // The order of server processing by Promise.all is not guaranteed.
+ // One 'toolA' will be unprefixed, the other will be prefixed.
+ const toolA_from_server1 = registeredArgs.find(
+ (t) => t.serverToolName === 'toolA' && t.serverName === 'server1',
+ );
+ const toolA_from_server2 = registeredArgs.find(
+ (t) => t.serverToolName === 'toolA' && t.serverName === 'server2',
+ );
+ const toolB_from_server1 = registeredArgs.find(
+ (t) => t.serverToolName === 'toolB' && t.serverName === 'server1',
+ );
+
+ expect(toolA_from_server1).toBeDefined();
+ expect(toolA_from_server2).toBeDefined();
+ expect(toolB_from_server1).toBeDefined();
+
+ expect(toolB_from_server1?.name).toBe('toolB'); // toolB is unique
+
+ // Check that one of toolA is prefixed and the other is not, and the prefixed one is correct.
+ if (toolA_from_server1?.name === 'toolA') {
+ expect(toolA_from_server2?.name).toBe('server2__toolA');
+ } else {
+ expect(toolA_from_server1?.name).toBe('server1__toolA');
+ expect(toolA_from_server2?.name).toBe('toolA');
+ }
});
it('should clean schema properties ($schema, additionalProperties)', async () => {
@@ -261,8 +376,12 @@ describe('discoverMcpTools', () => {
vi.mocked(Client.prototype.listTools).mockResolvedValue({
tools: [mockTool],
});
+ // PRE-MOCK getToolsByServer for the expected server name
+ mockToolRegistry.getToolsByServer.mockReturnValueOnce([
+ expect.any(DiscoveredMCPTool),
+ ]);
- await discoverMcpTools(mockConfig, mockToolRegistry);
+ await discoverMcpTools(mockConfig);
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
const registeredTool = mockToolRegistry.registerTool.mock
@@ -291,9 +410,9 @@ describe('discoverMcpTools', () => {
});
vi.spyOn(console, 'error').mockImplementation(() => {});
- await expect(
- discoverMcpTools(mockConfig, mockToolRegistry),
- ).rejects.toThrow('Parsing failed');
+ await expect(discoverMcpTools(mockConfig)).rejects.toThrow(
+ 'Parsing failed',
+ );
expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
expect(console.error).not.toHaveBeenCalled();
});
@@ -302,7 +421,7 @@ describe('discoverMcpTools', () => {
mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });
vi.spyOn(console, 'error').mockImplementation(() => {});
- await discoverMcpTools(mockConfig, mockToolRegistry);
+ await discoverMcpTools(mockConfig);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@@ -323,7 +442,7 @@ describe('discoverMcpTools', () => {
);
vi.spyOn(console, 'error').mockImplementation(() => {});
- await discoverMcpTools(mockConfig, mockToolRegistry);
+ await discoverMcpTools(mockConfig);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@@ -344,7 +463,7 @@ describe('discoverMcpTools', () => {
);
vi.spyOn(console, 'error').mockImplementation(() => {});
- await discoverMcpTools(mockConfig, mockToolRegistry);
+ await discoverMcpTools(mockConfig);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@@ -359,8 +478,12 @@ describe('discoverMcpTools', () => {
mockConfig.getMcpServers.mockReturnValue({
'onerror-server': serverConfig,
});
+ // PRE-MOCK getToolsByServer for the expected server name
+ mockToolRegistry.getToolsByServer.mockReturnValueOnce([
+ expect.any(DiscoveredMCPTool),
+ ]);
- await discoverMcpTools(mockConfig, mockToolRegistry);
+ await discoverMcpTools(mockConfig);
const clientInstances = vi.mocked(Client).mock.results;
expect(clientInstances.length).toBeGreaterThan(0);