diff options
Diffstat (limited to 'packages/core/src/tools/mcp-client.test.ts')
| -rw-r--r-- | packages/core/src/tools/mcp-client.test.ts | 1096 |
1 files changed, 225 insertions, 871 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index df4d71ef..353b4f05 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -4,950 +4,304 @@ * SPDX-License-Identifier: Apache-2.0 */ -/* eslint-disable @typescript-eslint/no-explicit-any */ +import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { - describe, - it, - expect, - vi, - beforeEach, - afterEach, - Mocked, -} from 'vitest'; -import { discoverMcpTools } from './mcp-client.js'; -import { sanitizeParameters } from './tool-registry.js'; -import { Schema, Type } from '@google/genai'; -import { Config, MCPServerConfig } from '../config/config.js'; -import { DiscoveredMCPTool } from './mcp-tool.js'; -import { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; + populateMcpServerCommand, + createTransport, + generateValidName, + isEnabled, + discoverTools, +} from './mcp-client.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; -import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; -import { parse, ParseEntry } from 'shell-quote'; - -// Mock dependencies -vi.mock('shell-quote'); - -vi.mock('@modelcontextprotocol/sdk/client/index.js', () => { - const MockedClient = vi.fn(); - MockedClient.prototype.connect = vi.fn(); - MockedClient.prototype.listTools = vi.fn(); - // Ensure instances have an onerror property that can be spied on or assigned to - MockedClient.mockImplementation(() => ({ - connect: MockedClient.prototype.connect, - listTools: MockedClient.prototype.listTools, - onerror: vi.fn(), // Each instance gets its own onerror mock - })); - return { Client: MockedClient }; -}); - -// Define a global mock for stderr.on that can be cleared and checked -const mockGlobalStdioStderrOn = vi.fn(); - -vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => { - // This is the constructor for StdioClientTransport - const MockedStdioTransport = vi.fn().mockImplementation(function ( - this: any, - options: any, - ) { - // Always return a new object with a fresh reference to the global mock for .on - this.options = options; - this.stderr = { on: mockGlobalStdioStderrOn }; - this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method - return this; - }); - return { StdioClientTransport: MockedStdioTransport }; -}); - -vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => { - const MockedSSETransport = vi.fn().mockImplementation(function (this: any) { - this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method - return this; - }); - return { SSEClientTransport: MockedSSETransport }; -}); - -vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => { - const MockedStreamableHTTPTransport = vi.fn().mockImplementation(function ( - this: any, - ) { - this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method - return this; - }); - return { StreamableHTTPClientTransport: MockedStreamableHTTPTransport }; -}); - -const mockToolRegistryInstance = { - registerTool: vi.fn(), - getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array - // Add other methods if they are called by the code under test, with default mocks - getTool: vi.fn(), - getAllTools: vi.fn().mockReturnValue([]), - getFunctionDeclarations: vi.fn().mockReturnValue([]), - discoverTools: vi.fn().mockResolvedValue(undefined), -}; -vi.mock('./tool-registry.js', async (importOriginal) => { - const actual = await importOriginal(); - return { - ...(actual as any), - ToolRegistry: vi.fn(() => mockToolRegistryInstance), - sanitizeParameters: (actual as any).sanitizeParameters, - }; -}); - -describe('discoverMcpTools', () => { - let mockConfig: Mocked<Config>; - // Use the instance from the module mock - let mockToolRegistry: typeof mockToolRegistryInstance; - - beforeEach(() => { - // Assign the shared mock instance to the test-scoped variable - mockToolRegistry = mockToolRegistryInstance; - // Reset individual spies on the shared instance before each test - mockToolRegistry.registerTool.mockClear(); - mockToolRegistry.getToolsByServer.mockClear().mockReturnValue([]); // Reset to default - mockToolRegistry.getTool.mockClear().mockReturnValue(undefined); // Default to no existing tool - mockToolRegistry.getAllTools.mockClear().mockReturnValue([]); - mockToolRegistry.getFunctionDeclarations.mockClear().mockReturnValue([]); - mockToolRegistry.discoverTools.mockClear().mockResolvedValue(undefined); - - mockConfig = { - getMcpServers: vi.fn().mockReturnValue({}), - getMcpServerCommand: vi.fn().mockReturnValue(undefined), - // getToolRegistry should now return the same shared mock instance - getToolRegistry: vi.fn(() => mockToolRegistry), - } as any; - - vi.mocked(parse).mockClear(); - vi.mocked(Client).mockClear(); - vi.mocked(Client.prototype.connect) - .mockClear() - .mockResolvedValue(undefined); - vi.mocked(Client.prototype.listTools) - .mockClear() - .mockResolvedValue({ tools: [] }); - - vi.mocked(StdioClientTransport).mockClear(); - // Ensure the StdioClientTransport mock constructor returns an object with a close method - vi.mocked(StdioClientTransport).mockImplementation(function ( - this: any, - options: any, - ) { - this.options = options; - this.stderr = { on: mockGlobalStdioStderrOn }; - this.close = vi.fn().mockResolvedValue(undefined); - return this; - }); - mockGlobalStdioStderrOn.mockClear(); // Clear the global mock in beforeEach - - vi.mocked(SSEClientTransport).mockClear(); - // Ensure the SSEClientTransport mock constructor returns an object with a close method - vi.mocked(SSEClientTransport).mockImplementation(function (this: any) { - this.close = vi.fn().mockResolvedValue(undefined); - return this; - }); +import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; +import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; +import * as GenAiLib from '@google/genai'; - vi.mocked(StreamableHTTPClientTransport).mockClear(); - // Ensure the StreamableHTTPClientTransport mock constructor returns an object with a close method - vi.mocked(StreamableHTTPClientTransport).mockImplementation(function ( - this: any, - ) { - this.close = vi.fn().mockResolvedValue(undefined); - return this; - }); - }); +vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); +vi.mock('@modelcontextprotocol/sdk/client/index.js'); +vi.mock('@google/genai'); +describe('mcp-client', () => { afterEach(() => { vi.restoreAllMocks(); }); - it('should do nothing if no MCP servers or command are configured', async () => { - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1); - expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1); - expect(Client).not.toHaveBeenCalled(); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); - }); - - it('should discover tools via mcpServerCommand', async () => { - const commandString = 'my-mcp-server --start'; - const parsedCommand = ['my-mcp-server', '--start'] as ParseEntry[]; - mockConfig.getMcpServerCommand.mockReturnValue(commandString); - vi.mocked(parse).mockReturnValue(parsedCommand); - - const mockTool = { - name: 'tool1', - description: 'desc1', - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], - }); - - // PRE-MOCK getToolsByServer for the expected server name - // In this case, listTools fails, so no tools are registered. - // The default mock `mockReturnValue([])` from beforeEach should apply. + describe('discoverTools', () => { + it('should discover tools', async () => { + const mockedClient = {} as unknown as ClientLib.Client; + const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ + tool: () => ({ + functionDeclarations: [ + { + name: 'testFunction', + }, + ], + }), + } as unknown as GenAiLib.CallableTool); - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); + const tools = await discoverTools('test-server', {}, mockedClient); - expect(parse).toHaveBeenCalledWith(commandString, process.env); - expect(StdioClientTransport).toHaveBeenCalledWith({ - command: parsedCommand[0], - args: parsedCommand.slice(1), - env: expect.any(Object), - cwd: undefined, - stderr: 'pipe', + expect(tools.length).toBe(1); + expect(mockedMcpToTool).toHaveBeenCalledOnce(); }); - expect(Client.prototype.connect).toHaveBeenCalledTimes(1); - expect(Client.prototype.listTools).toHaveBeenCalledTimes(1); - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.any(DiscoveredMCPTool), - ); - const registeredTool = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - expect(registeredTool.name).toBe('tool1'); - expect(registeredTool.serverToolName).toBe('tool1'); }); - it('should discover tools via mcpServers config (stdio)', async () => { - const serverConfig: MCPServerConfig = { - command: './mcp-stdio', - args: ['arg1'], - }; - mockConfig.getMcpServers.mockReturnValue({ 'stdio-server': serverConfig }); - - const mockTool = { - name: 'tool-stdio', - description: 'desc-stdio', - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], + describe('appendMcpServerCommand', () => { + it('should do nothing if no MCP servers or command are configured', () => { + const out = populateMcpServerCommand({}, undefined); + expect(out).toEqual({}); }); - // PRE-MOCK getToolsByServer for the expected server name - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); + it('should discover tools via mcpServerCommand', () => { + const commandString = 'command --arg1 value1'; + const out = populateMcpServerCommand({}, commandString); + expect(out).toEqual({ + mcp: { + command: 'command', + args: ['--arg1', 'value1'], + }, + }); + }); - expect(StdioClientTransport).toHaveBeenCalledWith({ - command: serverConfig.command, - args: serverConfig.args, - env: expect.any(Object), - cwd: undefined, - stderr: 'pipe', + it('should handle error if mcpServerCommand parsing fails', () => { + expect(() => populateMcpServerCommand({}, 'derp && herp')).toThrowError(); }); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.any(DiscoveredMCPTool), - ); - const registeredTool = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - expect(registeredTool.name).toBe('tool-stdio'); }); - it('should discover tools via mcpServers config (sse)', async () => { - const serverConfig: MCPServerConfig = { url: 'http://localhost:1234/sse' }; - mockConfig.getMcpServers.mockReturnValue({ 'sse-server': serverConfig }); + describe('createTransport', () => { + const originalEnv = process.env; - const mockTool = { - name: 'tool-sse', - description: 'desc-sse', - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], + beforeEach(() => { + vi.resetModules(); + process.env = {}; }); - // PRE-MOCK getToolsByServer for the expected server name - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); + afterEach(() => { + process.env = originalEnv; + }); - expect(SSEClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.url!), - {}, - ); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.any(DiscoveredMCPTool), - ); - const registeredTool = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - expect(registeredTool.name).toBe('tool-sse'); - }); + describe('should connect via httpUrl', () => { + it('without headers', async () => { + const transport = createTransport( + 'test-server', + { + httpUrl: 'http://test-server', + }, + false, + ); - describe('SseClientTransport headers', () => { - const setupSseTest = async (headers?: Record<string, string>) => { - const serverConfig: MCPServerConfig = { - url: 'http://localhost:1234/sse', - ...(headers && { headers }), - }; - const serverName = headers - ? 'sse-server-with-headers' - : 'sse-server-no-headers'; - const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers'; + expect(transport).toEqual( + new StreamableHTTPClientTransport(new URL('http://test-server'), {}), + ); + }); - mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig }); + it('with headers', async () => { + const transport = createTransport( + 'test-server', + { + httpUrl: 'http://test-server', + headers: { Authorization: 'derp' }, + }, + false, + ); - const mockTool = { - name: toolName, - description: `desc-${toolName}`, - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], + expect(transport).toEqual( + new StreamableHTTPClientTransport(new URL('http://test-server'), { + requestInit: { + headers: { Authorization: 'derp' }, + }, + }), + ); }); - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); + }); - return { serverConfig }; - }; + describe('should connect via url', () => { + it('without headers', async () => { + const transport = createTransport( + 'test-server', + { + url: 'http://test-server', + }, + false, + ); + expect(transport).toEqual( + new SSEClientTransport(new URL('http://test-server'), {}), + ); + }); - it('should pass headers when provided', async () => { - const headers = { - Authorization: 'Bearer test-token', - 'X-Custom-Header': 'custom-value', - }; - const { serverConfig } = await setupSseTest(headers); + it('with headers', async () => { + const transport = createTransport( + 'test-server', + { + url: 'http://test-server', + headers: { Authorization: 'derp' }, + }, + false, + ); - expect(SSEClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.url!), - { requestInit: { headers } }, - ); + expect(transport).toEqual( + new SSEClientTransport(new URL('http://test-server'), { + requestInit: { + headers: { Authorization: 'derp' }, + }, + }), + ); + }); }); - it('should work without headers (backwards compatibility)', async () => { - const { serverConfig } = await setupSseTest(); + it('should connect via command', () => { + const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport); - expect(SSEClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.url!), - {}, + createTransport( + 'test-server', + { + command: 'test-command', + args: ['--foo', 'bar'], + env: { FOO: 'bar' }, + cwd: 'test/cwd', + }, + false, ); - }); - - it('should pass oauth token when provided', async () => { - const headers = { - Authorization: 'Bearer test-token', - }; - const { serverConfig } = await setupSseTest(headers); - expect(SSEClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.url!), - { requestInit: { headers } }, - ); + expect(mockedTransport).toHaveBeenCalledWith({ + command: 'test-command', + args: ['--foo', 'bar'], + cwd: 'test/cwd', + env: { FOO: 'bar' }, + stderr: 'pipe', + }); }); }); - - it('should discover tools via mcpServers config (streamable http)', async () => { - const serverConfig: MCPServerConfig = { - httpUrl: 'http://localhost:3000/mcp', - }; - mockConfig.getMcpServers.mockReturnValue({ 'http-server': serverConfig }); - - const mockTool = { - name: 'tool-http', - description: 'desc-http', - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], + describe('generateValidName', () => { + it('should return a valid name for a simple function', () => { + const funcDecl = { name: 'myFunction' }; + const serverName = 'myServer'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('myServer__myFunction'); }); - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.httpUrl!), - {}, - ); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.any(DiscoveredMCPTool), - ); - const registeredTool = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - expect(registeredTool.name).toBe('tool-http'); - }); - - describe('StreamableHTTPClientTransport headers', () => { - const setupHttpTest = async (headers?: Record<string, string>) => { - const serverConfig: MCPServerConfig = { - httpUrl: 'http://localhost:3000/mcp', - ...(headers && { headers }), - }; - const serverName = headers - ? 'http-server-with-headers' - : 'http-server-no-headers'; - const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers'; - - mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig }); - - const mockTool = { - name: toolName, - description: `desc-${toolName}`, - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], - }); - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - return { serverConfig }; - }; - - it('should pass headers when provided', async () => { - const headers = { - Authorization: 'Bearer test-token', - 'X-Custom-Header': 'custom-value', - }; - const { serverConfig } = await setupHttpTest(headers); - - expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.httpUrl!), - { requestInit: { headers } }, - ); + it('should prepend the server name', () => { + const funcDecl = { name: 'anotherFunction' }; + const serverName = 'production-server'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('production-server__anotherFunction'); }); - it('should work without headers (backwards compatibility)', async () => { - const { serverConfig } = await setupHttpTest(); - - expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.httpUrl!), - {}, - ); + it('should replace invalid characters with underscores', () => { + const funcDecl = { name: 'invalid-name with spaces' }; + const serverName = 'test_server'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('test_server__invalid-name_with_spaces'); }); - it('should pass oauth token when provided', async () => { - const headers = { - Authorization: 'Bearer test-token', + it('should truncate long names', () => { + const funcDecl = { + name: 'a_very_long_function_name_that_will_definitely_exceed_the_limit', }; - const { serverConfig } = await setupHttpTest(headers); - - expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.httpUrl!), - { requestInit: { headers } }, + const serverName = 'a_long_server_name'; + const result = generateValidName(funcDecl, serverName); + expect(result.length).toBe(63); + expect(result).toBe( + 'a_long_server_name__a_very_l___will_definitely_exceed_the_limit', ); }); - }); - it('should prefix tool names if multiple MCP servers are configured', async () => { - const serverConfig1: MCPServerConfig = { command: './mcp1' }; - const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' }; - mockConfig.getMcpServers.mockReturnValue({ - server1: serverConfig1, - server2: serverConfig2, + it('should handle names with only invalid characters', () => { + const funcDecl = { name: '!@#$%^&*()' }; + const serverName = 'special-chars'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('special-chars____________'); }); - const mockTool1 = { - name: 'toolA', // Same original name - description: 'd1', - inputSchema: { type: 'object' as const, properties: {} }, - }; - const mockTool2 = { - name: 'toolA', // Same original name - description: 'd2', - inputSchema: { type: 'object' as const, properties: {} }, - }; - const mockToolB = { - name: 'toolB', - description: 'dB', - inputSchema: { type: 'object' as const, properties: {} }, - }; - - vi.mocked(Client.prototype.listTools) - .mockResolvedValueOnce({ tools: [mockTool1, mockToolB] }) // Tools for server1 - .mockResolvedValueOnce({ tools: [mockTool2] }); // Tool for server2 (toolA) - - const effectivelyRegisteredTools = new Map<string, any>(); - - mockToolRegistry.getTool.mockImplementation((toolName: string) => - effectivelyRegisteredTools.get(toolName), - ); - - // Store the original spy implementation if needed, or just let the new one be the behavior. - // The mockToolRegistry.registerTool is already a vi.fn() from mockToolRegistryInstance. - // We are setting its behavior for this test. - mockToolRegistry.registerTool.mockImplementation((toolToRegister: any) => { - // Simulate the actual registration name being stored for getTool to find - effectivelyRegisteredTools.set(toolToRegister.name, toolToRegister); - // If it's the first time toolA is registered (from server1, not prefixed), - // also make it findable by its original name for the prefixing check of server2/toolA. - if ( - toolToRegister.serverName === 'server1' && - toolToRegister.serverToolName === 'toolA' && - toolToRegister.name === 'toolA' - ) { - effectivelyRegisteredTools.set('toolA', toolToRegister); - } - // The spy call count is inherently tracked by mockToolRegistry.registerTool itself. + it('should handle names that are already valid', () => { + const funcDecl = { name: 'already_valid' }; + const serverName = 'validator'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('validator__already_valid'); }); - // PRE-MOCK getToolsByServer for the expected server names - // This is for the final check in connectAndDiscover to see if any tools were registered *from that server* - mockToolRegistry.getToolsByServer.mockImplementation( - (serverName: string) => { - if (serverName === 'server1') - return [ - expect.objectContaining({ name: 'toolA' }), - expect.objectContaining({ name: 'toolB' }), - ]; - if (serverName === 'server2') - return [expect.objectContaining({ name: 'server2__toolA' })]; - return []; - }, - ); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3); - const registeredArgs = mockToolRegistry.registerTool.mock.calls.map( - (call) => call[0], - ) as DiscoveredMCPTool[]; - - // The order of server processing by Promise.all is not guaranteed. - // One 'toolA' will be unprefixed, the other will be prefixed. - const toolA_from_server1 = registeredArgs.find( - (t) => t.serverToolName === 'toolA' && t.serverName === 'server1', - ); - const toolA_from_server2 = registeredArgs.find( - (t) => t.serverToolName === 'toolA' && t.serverName === 'server2', - ); - const toolB_from_server1 = registeredArgs.find( - (t) => t.serverToolName === 'toolB' && t.serverName === 'server1', - ); - - expect(toolA_from_server1).toBeDefined(); - expect(toolA_from_server2).toBeDefined(); - expect(toolB_from_server1).toBeDefined(); - - expect(toolB_from_server1?.name).toBe('toolB'); // toolB is unique - - // Check that one of toolA is prefixed and the other is not, and the prefixed one is correct. - if (toolA_from_server1?.name === 'toolA') { - expect(toolA_from_server2?.name).toBe('server2__toolA'); - } else { - expect(toolA_from_server1?.name).toBe('server1__toolA'); - expect(toolA_from_server2?.name).toBe('toolA'); - } - }); - - it('should clean schema properties ($schema, additionalProperties)', async () => { - const serverConfig: MCPServerConfig = { command: './mcp-clean' }; - mockConfig.getMcpServers.mockReturnValue({ 'clean-server': serverConfig }); - - const rawSchema = { - type: 'object' as const, - $schema: 'http://json-schema.org/draft-07/schema#', - additionalProperties: true, - properties: { - prop1: { type: 'string', $schema: 'remove-this' }, - prop2: { - type: 'object' as const, - additionalProperties: false, - properties: { nested: { type: 'number' } }, - }, - }, - }; - const mockTool = { - name: 'cleanTool', - description: 'd', - inputSchema: JSON.parse(JSON.stringify(rawSchema)), - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], + it('should handle names with leading/trailing invalid characters', () => { + const funcDecl = { name: '-_invalid-_' }; + const serverName = 'trim-test'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('trim-test__-_invalid-_'); }); - // PRE-MOCK getToolsByServer for the expected server name - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); - const registeredTool = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - const cleanedParams = registeredTool.schema.parameters as any; - - expect(cleanedParams).not.toHaveProperty('$schema'); - expect(cleanedParams).not.toHaveProperty('additionalProperties'); - expect(cleanedParams.properties.prop1).not.toHaveProperty('$schema'); - expect(cleanedParams.properties.prop2).not.toHaveProperty( - 'additionalProperties', - ); - expect(cleanedParams.properties.prop2.properties.nested).not.toHaveProperty( - '$schema', - ); - expect(cleanedParams.properties.prop2.properties.nested).not.toHaveProperty( - 'additionalProperties', - ); - }); - it('should handle error if mcpServerCommand parsing fails', async () => { - const commandString = 'my-mcp-server "unterminated quote'; - mockConfig.getMcpServerCommand.mockReturnValue(commandString); - vi.mocked(parse).mockImplementation(() => { - throw new Error('Parsing failed'); + it('should handle names that are exactly 63 characters long', () => { + const longName = 'a'.repeat(45); + const funcDecl = { name: longName }; + const serverName = 'server'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe(`server__${longName}`); + expect(result.length).toBe(53); }); - vi.spyOn(console, 'error').mockImplementation(() => {}); - await expect( - discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ), - ).rejects.toThrow('Parsing failed'); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); - expect(console.error).not.toHaveBeenCalled(); - }); - - it('should log error and skip server if config is invalid (missing url and command)', async () => { - mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any }); - vi.spyOn(console, 'error').mockImplementation(() => {}); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(console.error).toHaveBeenCalledWith( - expect.stringContaining( - "MCP server 'bad-server' has invalid configuration", - ), - ); - // Client constructor should not be called if config is invalid before instantiation - expect(Client).not.toHaveBeenCalled(); - }); - - it('should log error and skip server if mcpClient.connect fails', async () => { - const serverConfig: MCPServerConfig = { command: './mcp-fail-connect' }; - mockConfig.getMcpServers.mockReturnValue({ - 'fail-connect-server': serverConfig, + it('should handle names that are exactly 64 characters long', () => { + const longName = 'a'.repeat(55); + const funcDecl = { name: longName }; + const serverName = 'server'; + const result = generateValidName(funcDecl, serverName); + expect(result.length).toBe(63); + expect(result).toBe( + 'server__aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + ); }); - vi.mocked(Client.prototype.connect).mockRejectedValue( - new Error('Connection refused'), - ); - vi.spyOn(console, 'error').mockImplementation(() => {}); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(console.error).toHaveBeenCalledWith( - expect.stringContaining( - "failed to start or connect to MCP server 'fail-connect-server'", - ), - ); - expect(Client.prototype.listTools).not.toHaveBeenCalled(); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); - }); - it('should log error and skip server if mcpClient.listTools fails', async () => { - const serverConfig: MCPServerConfig = { command: './mcp-fail-list' }; - mockConfig.getMcpServers.mockReturnValue({ - 'fail-list-server': serverConfig, + it('should handle names that are longer than 64 characters', () => { + const longName = 'a'.repeat(100); + const funcDecl = { name: longName }; + const serverName = 'long-server'; + const result = generateValidName(funcDecl, serverName); + expect(result.length).toBe(63); + expect(result).toBe( + 'long-server__aaaaaaaaaaaaaaa___aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + ); }); - vi.mocked(Client.prototype.listTools).mockRejectedValue( - new Error('ListTools error'), - ); - vi.spyOn(console, 'error').mockImplementation(() => {}); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(console.error).toHaveBeenCalledWith( - expect.stringContaining( - "Failed to list or register tools for MCP server 'fail-list-server'", - ), - ); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); }); + describe('isEnabled', () => { + const funcDecl = { name: 'myTool' }; + const serverName = 'myServer'; - it('should assign mcpClient.onerror handler', async () => { - const serverConfig: MCPServerConfig = { command: './mcp-onerror' }; - mockConfig.getMcpServers.mockReturnValue({ - 'onerror-server': serverConfig, + it('should return true if no include or exclude lists are provided', () => { + const mcpServerConfig = {}; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true); }); - // PRE-MOCK getToolsByServer for the expected server name - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - const clientInstances = vi.mocked(Client).mock.results; - expect(clientInstances.length).toBeGreaterThan(0); - const lastClientInstance = - clientInstances[clientInstances.length - 1]?.value; - expect(lastClientInstance?.onerror).toEqual(expect.any(Function)); - }); - describe('Tool Filtering', () => { - const mockTools = [ - { - name: 'toolA', - description: 'descA', - inputSchema: { type: 'object' as const, properties: {} }, - }, - { - name: 'toolB', - description: 'descB', - inputSchema: { type: 'object' as const, properties: {} }, - }, - { - name: 'toolC', - description: 'descC', - inputSchema: { type: 'object' as const, properties: {} }, - }, - ]; - - beforeEach(() => { - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: mockTools, - }); - mockToolRegistry.getToolsByServer.mockReturnValue([ - expect.any(DiscoveredMCPTool), - ]); + it('should return false if the tool is in the exclude list', () => { + const mcpServerConfig = { excludeTools: ['myTool'] }; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false); }); - it('should only include specified tools with includeTools', async () => { - const serverConfig: MCPServerConfig = { - command: './mcp-include', - includeTools: ['toolA', 'toolC'], - }; - mockConfig.getMcpServers.mockReturnValue({ - 'include-server': serverConfig, - }); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolA' }), - ); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolC' }), - ); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolB' }), - ); + it('should return true if the tool is in the include list', () => { + const mcpServerConfig = { includeTools: ['myTool'] }; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true); }); - it('should exclude specified tools with excludeTools', async () => { - const serverConfig: MCPServerConfig = { - command: './mcp-exclude', - excludeTools: ['toolB'], - }; - mockConfig.getMcpServers.mockReturnValue({ - 'exclude-server': serverConfig, - }); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); + it('should return true if the tool is in the include list with parentheses', () => { + const mcpServerConfig = { includeTools: ['myTool()'] }; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true); + }); - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolA' }), - ); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolC' }), - ); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolB' }), - ); + it('should return false if the include list exists but does not contain the tool', () => { + const mcpServerConfig = { includeTools: ['anotherTool'] }; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false); }); - it('should handle both includeTools and excludeTools', async () => { - const serverConfig: MCPServerConfig = { - command: './mcp-both', - includeTools: ['toolA', 'toolB'], - excludeTools: ['toolB'], + it('should return false if the tool is in both the include and exclude lists', () => { + const mcpServerConfig = { + includeTools: ['myTool'], + excludeTools: ['myTool'], }; - mockConfig.getMcpServers.mockReturnValue({ 'both-server': serverConfig }); + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false); + }); - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, + it('should return false if the function declaration has no name', () => { + const namelessFuncDecl = {}; + const mcpServerConfig = {}; + expect(isEnabled(namelessFuncDecl, serverName, mcpServerConfig)).toBe( false, ); - - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolA' }), - ); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolB' }), - ); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolC' }), - ); }); }); }); - -describe('sanitizeParameters', () => { - it('should do nothing for an undefined schema', () => { - const schema = undefined; - sanitizeParameters(schema); - }); - - it('should remove default when anyOf is present', () => { - const schema: Schema = { - anyOf: [{ type: Type.STRING }, { type: Type.NUMBER }], - default: 'hello', - }; - sanitizeParameters(schema); - expect(schema.default).toBeUndefined(); - }); - - it('should recursively sanitize items in anyOf', () => { - const schema: Schema = { - anyOf: [ - { - anyOf: [{ type: Type.STRING }], - default: 'world', - }, - { type: Type.NUMBER }, - ], - }; - sanitizeParameters(schema); - expect(schema.anyOf![0].default).toBeUndefined(); - }); - - it('should recursively sanitize items in items', () => { - const schema: Schema = { - items: { - anyOf: [{ type: Type.STRING }], - default: 'world', - }, - }; - sanitizeParameters(schema); - expect(schema.items!.default).toBeUndefined(); - }); - - it('should recursively sanitize items in properties', () => { - const schema: Schema = { - properties: { - prop1: { - anyOf: [{ type: Type.STRING }], - default: 'world', - }, - }, - }; - sanitizeParameters(schema); - expect(schema.properties!.prop1.default).toBeUndefined(); - }); - - it('should handle complex nested schemas', () => { - const schema: Schema = { - properties: { - prop1: { - items: { - anyOf: [{ type: Type.STRING }], - default: 'world', - }, - }, - prop2: { - anyOf: [ - { - properties: { - nestedProp: { - anyOf: [{ type: Type.NUMBER }], - default: 123, - }, - }, - }, - ], - }, - }, - }; - sanitizeParameters(schema); - expect(schema.properties!.prop1.items!.default).toBeUndefined(); - const nestedProp = - schema.properties!.prop2.anyOf![0].properties!.nestedProp; - expect(nestedProp?.default).toBeUndefined(); - }); -}); |
