summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/tools/mcp-client.test.ts1096
-rw-r--r--packages/core/src/tools/mcp-client.ts423
-rw-r--r--packages/core/src/tools/tool-registry.test.ts109
3 files changed, 537 insertions, 1091 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index df4d71ef..353b4f05 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -4,950 +4,304 @@
* SPDX-License-Identifier: Apache-2.0
*/
-/* eslint-disable @typescript-eslint/no-explicit-any */
+import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest';
+import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import {
- describe,
- it,
- expect,
- vi,
- beforeEach,
- afterEach,
- Mocked,
-} from 'vitest';
-import { discoverMcpTools } from './mcp-client.js';
-import { sanitizeParameters } from './tool-registry.js';
-import { Schema, Type } from '@google/genai';
-import { Config, MCPServerConfig } from '../config/config.js';
-import { DiscoveredMCPTool } from './mcp-tool.js';
-import { Client } from '@modelcontextprotocol/sdk/client/index.js';
-import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
+ populateMcpServerCommand,
+ createTransport,
+ generateValidName,
+ isEnabled,
+ discoverTools,
+} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
-import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
-import { parse, ParseEntry } from 'shell-quote';
-
-// Mock dependencies
-vi.mock('shell-quote');
-
-vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
- const MockedClient = vi.fn();
- MockedClient.prototype.connect = vi.fn();
- MockedClient.prototype.listTools = vi.fn();
- // Ensure instances have an onerror property that can be spied on or assigned to
- MockedClient.mockImplementation(() => ({
- connect: MockedClient.prototype.connect,
- listTools: MockedClient.prototype.listTools,
- onerror: vi.fn(), // Each instance gets its own onerror mock
- }));
- return { Client: MockedClient };
-});
-
-// Define a global mock for stderr.on that can be cleared and checked
-const mockGlobalStdioStderrOn = vi.fn();
-
-vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
- // This is the constructor for StdioClientTransport
- const MockedStdioTransport = vi.fn().mockImplementation(function (
- this: any,
- options: any,
- ) {
- // Always return a new object with a fresh reference to the global mock for .on
- this.options = options;
- this.stderr = { on: mockGlobalStdioStderrOn };
- this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
- return this;
- });
- return { StdioClientTransport: MockedStdioTransport };
-});
-
-vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
- const MockedSSETransport = vi.fn().mockImplementation(function (this: any) {
- this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
- return this;
- });
- return { SSEClientTransport: MockedSSETransport };
-});
-
-vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => {
- const MockedStreamableHTTPTransport = vi.fn().mockImplementation(function (
- this: any,
- ) {
- this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
- return this;
- });
- return { StreamableHTTPClientTransport: MockedStreamableHTTPTransport };
-});
-
-const mockToolRegistryInstance = {
- registerTool: vi.fn(),
- getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array
- // Add other methods if they are called by the code under test, with default mocks
- getTool: vi.fn(),
- getAllTools: vi.fn().mockReturnValue([]),
- getFunctionDeclarations: vi.fn().mockReturnValue([]),
- discoverTools: vi.fn().mockResolvedValue(undefined),
-};
-vi.mock('./tool-registry.js', async (importOriginal) => {
- const actual = await importOriginal();
- return {
- ...(actual as any),
- ToolRegistry: vi.fn(() => mockToolRegistryInstance),
- sanitizeParameters: (actual as any).sanitizeParameters,
- };
-});
-
-describe('discoverMcpTools', () => {
- let mockConfig: Mocked<Config>;
- // Use the instance from the module mock
- let mockToolRegistry: typeof mockToolRegistryInstance;
-
- beforeEach(() => {
- // Assign the shared mock instance to the test-scoped variable
- mockToolRegistry = mockToolRegistryInstance;
- // Reset individual spies on the shared instance before each test
- mockToolRegistry.registerTool.mockClear();
- mockToolRegistry.getToolsByServer.mockClear().mockReturnValue([]); // Reset to default
- mockToolRegistry.getTool.mockClear().mockReturnValue(undefined); // Default to no existing tool
- mockToolRegistry.getAllTools.mockClear().mockReturnValue([]);
- mockToolRegistry.getFunctionDeclarations.mockClear().mockReturnValue([]);
- mockToolRegistry.discoverTools.mockClear().mockResolvedValue(undefined);
-
- mockConfig = {
- getMcpServers: vi.fn().mockReturnValue({}),
- getMcpServerCommand: vi.fn().mockReturnValue(undefined),
- // getToolRegistry should now return the same shared mock instance
- getToolRegistry: vi.fn(() => mockToolRegistry),
- } as any;
-
- vi.mocked(parse).mockClear();
- vi.mocked(Client).mockClear();
- vi.mocked(Client.prototype.connect)
- .mockClear()
- .mockResolvedValue(undefined);
- vi.mocked(Client.prototype.listTools)
- .mockClear()
- .mockResolvedValue({ tools: [] });
-
- vi.mocked(StdioClientTransport).mockClear();
- // Ensure the StdioClientTransport mock constructor returns an object with a close method
- vi.mocked(StdioClientTransport).mockImplementation(function (
- this: any,
- options: any,
- ) {
- this.options = options;
- this.stderr = { on: mockGlobalStdioStderrOn };
- this.close = vi.fn().mockResolvedValue(undefined);
- return this;
- });
- mockGlobalStdioStderrOn.mockClear(); // Clear the global mock in beforeEach
-
- vi.mocked(SSEClientTransport).mockClear();
- // Ensure the SSEClientTransport mock constructor returns an object with a close method
- vi.mocked(SSEClientTransport).mockImplementation(function (this: any) {
- this.close = vi.fn().mockResolvedValue(undefined);
- return this;
- });
+import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
+import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
+import * as GenAiLib from '@google/genai';
- vi.mocked(StreamableHTTPClientTransport).mockClear();
- // Ensure the StreamableHTTPClientTransport mock constructor returns an object with a close method
- vi.mocked(StreamableHTTPClientTransport).mockImplementation(function (
- this: any,
- ) {
- this.close = vi.fn().mockResolvedValue(undefined);
- return this;
- });
- });
+vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
+vi.mock('@modelcontextprotocol/sdk/client/index.js');
+vi.mock('@google/genai');
+describe('mcp-client', () => {
afterEach(() => {
vi.restoreAllMocks();
});
- it('should do nothing if no MCP servers or command are configured', async () => {
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
- expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1);
- expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1);
- expect(Client).not.toHaveBeenCalled();
- expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
- });
-
- it('should discover tools via mcpServerCommand', async () => {
- const commandString = 'my-mcp-server --start';
- const parsedCommand = ['my-mcp-server', '--start'] as ParseEntry[];
- mockConfig.getMcpServerCommand.mockReturnValue(commandString);
- vi.mocked(parse).mockReturnValue(parsedCommand);
-
- const mockTool = {
- name: 'tool1',
- description: 'desc1',
- inputSchema: { type: 'object' as const, properties: {} },
- };
- vi.mocked(Client.prototype.listTools).mockResolvedValue({
- tools: [mockTool],
- });
-
- // PRE-MOCK getToolsByServer for the expected server name
- // In this case, listTools fails, so no tools are registered.
- // The default mock `mockReturnValue([])` from beforeEach should apply.
+ describe('discoverTools', () => {
+ it('should discover tools', async () => {
+ const mockedClient = {} as unknown as ClientLib.Client;
+ const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
+ tool: () => ({
+ functionDeclarations: [
+ {
+ name: 'testFunction',
+ },
+ ],
+ }),
+ } as unknown as GenAiLib.CallableTool);
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
+ const tools = await discoverTools('test-server', {}, mockedClient);
- expect(parse).toHaveBeenCalledWith(commandString, process.env);
- expect(StdioClientTransport).toHaveBeenCalledWith({
- command: parsedCommand[0],
- args: parsedCommand.slice(1),
- env: expect.any(Object),
- cwd: undefined,
- stderr: 'pipe',
+ expect(tools.length).toBe(1);
+ expect(mockedMcpToTool).toHaveBeenCalledOnce();
});
- expect(Client.prototype.connect).toHaveBeenCalledTimes(1);
- expect(Client.prototype.listTools).toHaveBeenCalledTimes(1);
- expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
- expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
- expect.any(DiscoveredMCPTool),
- );
- const registeredTool = mockToolRegistry.registerTool.mock
- .calls[0][0] as DiscoveredMCPTool;
- expect(registeredTool.name).toBe('tool1');
- expect(registeredTool.serverToolName).toBe('tool1');
});
- it('should discover tools via mcpServers config (stdio)', async () => {
- const serverConfig: MCPServerConfig = {
- command: './mcp-stdio',
- args: ['arg1'],
- };
- mockConfig.getMcpServers.mockReturnValue({ 'stdio-server': serverConfig });
-
- const mockTool = {
- name: 'tool-stdio',
- description: 'desc-stdio',
- inputSchema: { type: 'object' as const, properties: {} },
- };
- vi.mocked(Client.prototype.listTools).mockResolvedValue({
- tools: [mockTool],
+ describe('appendMcpServerCommand', () => {
+ it('should do nothing if no MCP servers or command are configured', () => {
+ const out = populateMcpServerCommand({}, undefined);
+ expect(out).toEqual({});
});
- // PRE-MOCK getToolsByServer for the expected server name
- mockToolRegistry.getToolsByServer.mockReturnValueOnce([
- expect.any(DiscoveredMCPTool),
- ]);
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
+ it('should discover tools via mcpServerCommand', () => {
+ const commandString = 'command --arg1 value1';
+ const out = populateMcpServerCommand({}, commandString);
+ expect(out).toEqual({
+ mcp: {
+ command: 'command',
+ args: ['--arg1', 'value1'],
+ },
+ });
+ });
- expect(StdioClientTransport).toHaveBeenCalledWith({
- command: serverConfig.command,
- args: serverConfig.args,
- env: expect.any(Object),
- cwd: undefined,
- stderr: 'pipe',
+ it('should handle error if mcpServerCommand parsing fails', () => {
+ expect(() => populateMcpServerCommand({}, 'derp && herp')).toThrowError();
});
- expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
- expect.any(DiscoveredMCPTool),
- );
- const registeredTool = mockToolRegistry.registerTool.mock
- .calls[0][0] as DiscoveredMCPTool;
- expect(registeredTool.name).toBe('tool-stdio');
});
- it('should discover tools via mcpServers config (sse)', async () => {
- const serverConfig: MCPServerConfig = { url: 'http://localhost:1234/sse' };
- mockConfig.getMcpServers.mockReturnValue({ 'sse-server': serverConfig });
+ describe('createTransport', () => {
+ const originalEnv = process.env;
- const mockTool = {
- name: 'tool-sse',
- description: 'desc-sse',
- inputSchema: { type: 'object' as const, properties: {} },
- };
- vi.mocked(Client.prototype.listTools).mockResolvedValue({
- tools: [mockTool],
+ beforeEach(() => {
+ vi.resetModules();
+ process.env = {};
});
- // PRE-MOCK getToolsByServer for the expected server name
- mockToolRegistry.getToolsByServer.mockReturnValueOnce([
- expect.any(DiscoveredMCPTool),
- ]);
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
+ afterEach(() => {
+ process.env = originalEnv;
+ });
- expect(SSEClientTransport).toHaveBeenCalledWith(
- new URL(serverConfig.url!),
- {},
- );
- expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
- expect.any(DiscoveredMCPTool),
- );
- const registeredTool = mockToolRegistry.registerTool.mock
- .calls[0][0] as DiscoveredMCPTool;
- expect(registeredTool.name).toBe('tool-sse');
- });
+ describe('should connect via httpUrl', () => {
+ it('without headers', async () => {
+ const transport = createTransport(
+ 'test-server',
+ {
+ httpUrl: 'http://test-server',
+ },
+ false,
+ );
- describe('SseClientTransport headers', () => {
- const setupSseTest = async (headers?: Record<string, string>) => {
- const serverConfig: MCPServerConfig = {
- url: 'http://localhost:1234/sse',
- ...(headers && { headers }),
- };
- const serverName = headers
- ? 'sse-server-with-headers'
- : 'sse-server-no-headers';
- const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers';
+ expect(transport).toEqual(
+ new StreamableHTTPClientTransport(new URL('http://test-server'), {}),
+ );
+ });
- mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig });
+ it('with headers', async () => {
+ const transport = createTransport(
+ 'test-server',
+ {
+ httpUrl: 'http://test-server',
+ headers: { Authorization: 'derp' },
+ },
+ false,
+ );
- const mockTool = {
- name: toolName,
- description: `desc-${toolName}`,
- inputSchema: { type: 'object' as const, properties: {} },
- };
- vi.mocked(Client.prototype.listTools).mockResolvedValue({
- tools: [mockTool],
+ expect(transport).toEqual(
+ new StreamableHTTPClientTransport(new URL('http://test-server'), {
+ requestInit: {
+ headers: { Authorization: 'derp' },
+ },
+ }),
+ );
});
- mockToolRegistry.getToolsByServer.mockReturnValueOnce([
- expect.any(DiscoveredMCPTool),
- ]);
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
+ });
- return { serverConfig };
- };
+ describe('should connect via url', () => {
+ it('without headers', async () => {
+ const transport = createTransport(
+ 'test-server',
+ {
+ url: 'http://test-server',
+ },
+ false,
+ );
+ expect(transport).toEqual(
+ new SSEClientTransport(new URL('http://test-server'), {}),
+ );
+ });
- it('should pass headers when provided', async () => {
- const headers = {
- Authorization: 'Bearer test-token',
- 'X-Custom-Header': 'custom-value',
- };
- const { serverConfig } = await setupSseTest(headers);
+ it('with headers', async () => {
+ const transport = createTransport(
+ 'test-server',
+ {
+ url: 'http://test-server',
+ headers: { Authorization: 'derp' },
+ },
+ false,
+ );
- expect(SSEClientTransport).toHaveBeenCalledWith(
- new URL(serverConfig.url!),
- { requestInit: { headers } },
- );
+ expect(transport).toEqual(
+ new SSEClientTransport(new URL('http://test-server'), {
+ requestInit: {
+ headers: { Authorization: 'derp' },
+ },
+ }),
+ );
+ });
});
- it('should work without headers (backwards compatibility)', async () => {
- const { serverConfig } = await setupSseTest();
+ it('should connect via command', () => {
+ const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport);
- expect(SSEClientTransport).toHaveBeenCalledWith(
- new URL(serverConfig.url!),
- {},
+ createTransport(
+ 'test-server',
+ {
+ command: 'test-command',
+ args: ['--foo', 'bar'],
+ env: { FOO: 'bar' },
+ cwd: 'test/cwd',
+ },
+ false,
);
- });
-
- it('should pass oauth token when provided', async () => {
- const headers = {
- Authorization: 'Bearer test-token',
- };
- const { serverConfig } = await setupSseTest(headers);
- expect(SSEClientTransport).toHaveBeenCalledWith(
- new URL(serverConfig.url!),
- { requestInit: { headers } },
- );
+ expect(mockedTransport).toHaveBeenCalledWith({
+ command: 'test-command',
+ args: ['--foo', 'bar'],
+ cwd: 'test/cwd',
+ env: { FOO: 'bar' },
+ stderr: 'pipe',
+ });
});
});
-
- it('should discover tools via mcpServers config (streamable http)', async () => {
- const serverConfig: MCPServerConfig = {
- httpUrl: 'http://localhost:3000/mcp',
- };
- mockConfig.getMcpServers.mockReturnValue({ 'http-server': serverConfig });
-
- const mockTool = {
- name: 'tool-http',
- description: 'desc-http',
- inputSchema: { type: 'object' as const, properties: {} },
- };
- vi.mocked(Client.prototype.listTools).mockResolvedValue({
- tools: [mockTool],
+ describe('generateValidName', () => {
+ it('should return a valid name for a simple function', () => {
+ const funcDecl = { name: 'myFunction' };
+ const serverName = 'myServer';
+ const result = generateValidName(funcDecl, serverName);
+ expect(result).toBe('myServer__myFunction');
});
- mockToolRegistry.getToolsByServer.mockReturnValueOnce([
- expect.any(DiscoveredMCPTool),
- ]);
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
-
- expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
- new URL(serverConfig.httpUrl!),
- {},
- );
- expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
- expect.any(DiscoveredMCPTool),
- );
- const registeredTool = mockToolRegistry.registerTool.mock
- .calls[0][0] as DiscoveredMCPTool;
- expect(registeredTool.name).toBe('tool-http');
- });
-
- describe('StreamableHTTPClientTransport headers', () => {
- const setupHttpTest = async (headers?: Record<string, string>) => {
- const serverConfig: MCPServerConfig = {
- httpUrl: 'http://localhost:3000/mcp',
- ...(headers && { headers }),
- };
- const serverName = headers
- ? 'http-server-with-headers'
- : 'http-server-no-headers';
- const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers';
-
- mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig });
-
- const mockTool = {
- name: toolName,
- description: `desc-${toolName}`,
- inputSchema: { type: 'object' as const, properties: {} },
- };
- vi.mocked(Client.prototype.listTools).mockResolvedValue({
- tools: [mockTool],
- });
- mockToolRegistry.getToolsByServer.mockReturnValueOnce([
- expect.any(DiscoveredMCPTool),
- ]);
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
-
- return { serverConfig };
- };
-
- it('should pass headers when provided', async () => {
- const headers = {
- Authorization: 'Bearer test-token',
- 'X-Custom-Header': 'custom-value',
- };
- const { serverConfig } = await setupHttpTest(headers);
-
- expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
- new URL(serverConfig.httpUrl!),
- { requestInit: { headers } },
- );
+ it('should prepend the server name', () => {
+ const funcDecl = { name: 'anotherFunction' };
+ const serverName = 'production-server';
+ const result = generateValidName(funcDecl, serverName);
+ expect(result).toBe('production-server__anotherFunction');
});
- it('should work without headers (backwards compatibility)', async () => {
- const { serverConfig } = await setupHttpTest();
-
- expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
- new URL(serverConfig.httpUrl!),
- {},
- );
+ it('should replace invalid characters with underscores', () => {
+ const funcDecl = { name: 'invalid-name with spaces' };
+ const serverName = 'test_server';
+ const result = generateValidName(funcDecl, serverName);
+ expect(result).toBe('test_server__invalid-name_with_spaces');
});
- it('should pass oauth token when provided', async () => {
- const headers = {
- Authorization: 'Bearer test-token',
+ it('should truncate long names', () => {
+ const funcDecl = {
+ name: 'a_very_long_function_name_that_will_definitely_exceed_the_limit',
};
- const { serverConfig } = await setupHttpTest(headers);
-
- expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
- new URL(serverConfig.httpUrl!),
- { requestInit: { headers } },
+ const serverName = 'a_long_server_name';
+ const result = generateValidName(funcDecl, serverName);
+ expect(result.length).toBe(63);
+ expect(result).toBe(
+ 'a_long_server_name__a_very_l___will_definitely_exceed_the_limit',
);
});
- });
- it('should prefix tool names if multiple MCP servers are configured', async () => {
- const serverConfig1: MCPServerConfig = { command: './mcp1' };
- const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' };
- mockConfig.getMcpServers.mockReturnValue({
- server1: serverConfig1,
- server2: serverConfig2,
+ it('should handle names with only invalid characters', () => {
+ const funcDecl = { name: '!@#$%^&*()' };
+ const serverName = 'special-chars';
+ const result = generateValidName(funcDecl, serverName);
+ expect(result).toBe('special-chars____________');
});
- const mockTool1 = {
- name: 'toolA', // Same original name
- description: 'd1',
- inputSchema: { type: 'object' as const, properties: {} },
- };
- const mockTool2 = {
- name: 'toolA', // Same original name
- description: 'd2',
- inputSchema: { type: 'object' as const, properties: {} },
- };
- const mockToolB = {
- name: 'toolB',
- description: 'dB',
- inputSchema: { type: 'object' as const, properties: {} },
- };
-
- vi.mocked(Client.prototype.listTools)
- .mockResolvedValueOnce({ tools: [mockTool1, mockToolB] }) // Tools for server1
- .mockResolvedValueOnce({ tools: [mockTool2] }); // Tool for server2 (toolA)
-
- const effectivelyRegisteredTools = new Map<string, any>();
-
- mockToolRegistry.getTool.mockImplementation((toolName: string) =>
- effectivelyRegisteredTools.get(toolName),
- );
-
- // Store the original spy implementation if needed, or just let the new one be the behavior.
- // The mockToolRegistry.registerTool is already a vi.fn() from mockToolRegistryInstance.
- // We are setting its behavior for this test.
- mockToolRegistry.registerTool.mockImplementation((toolToRegister: any) => {
- // Simulate the actual registration name being stored for getTool to find
- effectivelyRegisteredTools.set(toolToRegister.name, toolToRegister);
- // If it's the first time toolA is registered (from server1, not prefixed),
- // also make it findable by its original name for the prefixing check of server2/toolA.
- if (
- toolToRegister.serverName === 'server1' &&
- toolToRegister.serverToolName === 'toolA' &&
- toolToRegister.name === 'toolA'
- ) {
- effectivelyRegisteredTools.set('toolA', toolToRegister);
- }
- // The spy call count is inherently tracked by mockToolRegistry.registerTool itself.
+ it('should handle names that are already valid', () => {
+ const funcDecl = { name: 'already_valid' };
+ const serverName = 'validator';
+ const result = generateValidName(funcDecl, serverName);
+ expect(result).toBe('validator__already_valid');
});
- // PRE-MOCK getToolsByServer for the expected server names
- // This is for the final check in connectAndDiscover to see if any tools were registered *from that server*
- mockToolRegistry.getToolsByServer.mockImplementation(
- (serverName: string) => {
- if (serverName === 'server1')
- return [
- expect.objectContaining({ name: 'toolA' }),
- expect.objectContaining({ name: 'toolB' }),
- ];
- if (serverName === 'server2')
- return [expect.objectContaining({ name: 'server2__toolA' })];
- return [];
- },
- );
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
-
- expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3);
- const registeredArgs = mockToolRegistry.registerTool.mock.calls.map(
- (call) => call[0],
- ) as DiscoveredMCPTool[];
-
- // The order of server processing by Promise.all is not guaranteed.
- // One 'toolA' will be unprefixed, the other will be prefixed.
- const toolA_from_server1 = registeredArgs.find(
- (t) => t.serverToolName === 'toolA' && t.serverName === 'server1',
- );
- const toolA_from_server2 = registeredArgs.find(
- (t) => t.serverToolName === 'toolA' && t.serverName === 'server2',
- );
- const toolB_from_server1 = registeredArgs.find(
- (t) => t.serverToolName === 'toolB' && t.serverName === 'server1',
- );
-
- expect(toolA_from_server1).toBeDefined();
- expect(toolA_from_server2).toBeDefined();
- expect(toolB_from_server1).toBeDefined();
-
- expect(toolB_from_server1?.name).toBe('toolB'); // toolB is unique
-
- // Check that one of toolA is prefixed and the other is not, and the prefixed one is correct.
- if (toolA_from_server1?.name === 'toolA') {
- expect(toolA_from_server2?.name).toBe('server2__toolA');
- } else {
- expect(toolA_from_server1?.name).toBe('server1__toolA');
- expect(toolA_from_server2?.name).toBe('toolA');
- }
- });
-
- it('should clean schema properties ($schema, additionalProperties)', async () => {
- const serverConfig: MCPServerConfig = { command: './mcp-clean' };
- mockConfig.getMcpServers.mockReturnValue({ 'clean-server': serverConfig });
-
- const rawSchema = {
- type: 'object' as const,
- $schema: 'http://json-schema.org/draft-07/schema#',
- additionalProperties: true,
- properties: {
- prop1: { type: 'string', $schema: 'remove-this' },
- prop2: {
- type: 'object' as const,
- additionalProperties: false,
- properties: { nested: { type: 'number' } },
- },
- },
- };
- const mockTool = {
- name: 'cleanTool',
- description: 'd',
- inputSchema: JSON.parse(JSON.stringify(rawSchema)),
- };
- vi.mocked(Client.prototype.listTools).mockResolvedValue({
- tools: [mockTool],
+ it('should handle names with leading/trailing invalid characters', () => {
+ const funcDecl = { name: '-_invalid-_' };
+ const serverName = 'trim-test';
+ const result = generateValidName(funcDecl, serverName);
+ expect(result).toBe('trim-test__-_invalid-_');
});
- // PRE-MOCK getToolsByServer for the expected server name
- mockToolRegistry.getToolsByServer.mockReturnValueOnce([
- expect.any(DiscoveredMCPTool),
- ]);
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
-
- expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
- const registeredTool = mockToolRegistry.registerTool.mock
- .calls[0][0] as DiscoveredMCPTool;
- const cleanedParams = registeredTool.schema.parameters as any;
-
- expect(cleanedParams).not.toHaveProperty('$schema');
- expect(cleanedParams).not.toHaveProperty('additionalProperties');
- expect(cleanedParams.properties.prop1).not.toHaveProperty('$schema');
- expect(cleanedParams.properties.prop2).not.toHaveProperty(
- 'additionalProperties',
- );
- expect(cleanedParams.properties.prop2.properties.nested).not.toHaveProperty(
- '$schema',
- );
- expect(cleanedParams.properties.prop2.properties.nested).not.toHaveProperty(
- 'additionalProperties',
- );
- });
- it('should handle error if mcpServerCommand parsing fails', async () => {
- const commandString = 'my-mcp-server "unterminated quote';
- mockConfig.getMcpServerCommand.mockReturnValue(commandString);
- vi.mocked(parse).mockImplementation(() => {
- throw new Error('Parsing failed');
+ it('should handle names that are exactly 63 characters long', () => {
+ const longName = 'a'.repeat(45);
+ const funcDecl = { name: longName };
+ const serverName = 'server';
+ const result = generateValidName(funcDecl, serverName);
+ expect(result).toBe(`server__${longName}`);
+ expect(result.length).toBe(53);
});
- vi.spyOn(console, 'error').mockImplementation(() => {});
- await expect(
- discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- ),
- ).rejects.toThrow('Parsing failed');
- expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
- expect(console.error).not.toHaveBeenCalled();
- });
-
- it('should log error and skip server if config is invalid (missing url and command)', async () => {
- mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });
- vi.spyOn(console, 'error').mockImplementation(() => {});
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
-
- expect(console.error).toHaveBeenCalledWith(
- expect.stringContaining(
- "MCP server 'bad-server' has invalid configuration",
- ),
- );
- // Client constructor should not be called if config is invalid before instantiation
- expect(Client).not.toHaveBeenCalled();
- });
-
- it('should log error and skip server if mcpClient.connect fails', async () => {
- const serverConfig: MCPServerConfig = { command: './mcp-fail-connect' };
- mockConfig.getMcpServers.mockReturnValue({
- 'fail-connect-server': serverConfig,
+ it('should handle names that are exactly 64 characters long', () => {
+ const longName = 'a'.repeat(55);
+ const funcDecl = { name: longName };
+ const serverName = 'server';
+ const result = generateValidName(funcDecl, serverName);
+ expect(result.length).toBe(63);
+ expect(result).toBe(
+ 'server__aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
+ );
});
- vi.mocked(Client.prototype.connect).mockRejectedValue(
- new Error('Connection refused'),
- );
- vi.spyOn(console, 'error').mockImplementation(() => {});
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
-
- expect(console.error).toHaveBeenCalledWith(
- expect.stringContaining(
- "failed to start or connect to MCP server 'fail-connect-server'",
- ),
- );
- expect(Client.prototype.listTools).not.toHaveBeenCalled();
- expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
- });
- it('should log error and skip server if mcpClient.listTools fails', async () => {
- const serverConfig: MCPServerConfig = { command: './mcp-fail-list' };
- mockConfig.getMcpServers.mockReturnValue({
- 'fail-list-server': serverConfig,
+ it('should handle names that are longer than 64 characters', () => {
+ const longName = 'a'.repeat(100);
+ const funcDecl = { name: longName };
+ const serverName = 'long-server';
+ const result = generateValidName(funcDecl, serverName);
+ expect(result.length).toBe(63);
+ expect(result).toBe(
+ 'long-server__aaaaaaaaaaaaaaa___aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
+ );
});
- vi.mocked(Client.prototype.listTools).mockRejectedValue(
- new Error('ListTools error'),
- );
- vi.spyOn(console, 'error').mockImplementation(() => {});
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
-
- expect(console.error).toHaveBeenCalledWith(
- expect.stringContaining(
- "Failed to list or register tools for MCP server 'fail-list-server'",
- ),
- );
- expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
});
+ describe('isEnabled', () => {
+ const funcDecl = { name: 'myTool' };
+ const serverName = 'myServer';
- it('should assign mcpClient.onerror handler', async () => {
- const serverConfig: MCPServerConfig = { command: './mcp-onerror' };
- mockConfig.getMcpServers.mockReturnValue({
- 'onerror-server': serverConfig,
+ it('should return true if no include or exclude lists are provided', () => {
+ const mcpServerConfig = {};
+ expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true);
});
- // PRE-MOCK getToolsByServer for the expected server name
- mockToolRegistry.getToolsByServer.mockReturnValueOnce([
- expect.any(DiscoveredMCPTool),
- ]);
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
-
- const clientInstances = vi.mocked(Client).mock.results;
- expect(clientInstances.length).toBeGreaterThan(0);
- const lastClientInstance =
- clientInstances[clientInstances.length - 1]?.value;
- expect(lastClientInstance?.onerror).toEqual(expect.any(Function));
- });
- describe('Tool Filtering', () => {
- const mockTools = [
- {
- name: 'toolA',
- description: 'descA',
- inputSchema: { type: 'object' as const, properties: {} },
- },
- {
- name: 'toolB',
- description: 'descB',
- inputSchema: { type: 'object' as const, properties: {} },
- },
- {
- name: 'toolC',
- description: 'descC',
- inputSchema: { type: 'object' as const, properties: {} },
- },
- ];
-
- beforeEach(() => {
- vi.mocked(Client.prototype.listTools).mockResolvedValue({
- tools: mockTools,
- });
- mockToolRegistry.getToolsByServer.mockReturnValue([
- expect.any(DiscoveredMCPTool),
- ]);
+ it('should return false if the tool is in the exclude list', () => {
+ const mcpServerConfig = { excludeTools: ['myTool'] };
+ expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false);
});
- it('should only include specified tools with includeTools', async () => {
- const serverConfig: MCPServerConfig = {
- command: './mcp-include',
- includeTools: ['toolA', 'toolC'],
- };
- mockConfig.getMcpServers.mockReturnValue({
- 'include-server': serverConfig,
- });
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
-
- expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2);
- expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
- expect.objectContaining({ serverToolName: 'toolA' }),
- );
- expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
- expect.objectContaining({ serverToolName: 'toolC' }),
- );
- expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith(
- expect.objectContaining({ serverToolName: 'toolB' }),
- );
+ it('should return true if the tool is in the include list', () => {
+ const mcpServerConfig = { includeTools: ['myTool'] };
+ expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true);
});
- it('should exclude specified tools with excludeTools', async () => {
- const serverConfig: MCPServerConfig = {
- command: './mcp-exclude',
- excludeTools: ['toolB'],
- };
- mockConfig.getMcpServers.mockReturnValue({
- 'exclude-server': serverConfig,
- });
-
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
- false,
- );
+ it('should return true if the tool is in the include list with parentheses', () => {
+ const mcpServerConfig = { includeTools: ['myTool()'] };
+ expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true);
+ });
- expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2);
- expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
- expect.objectContaining({ serverToolName: 'toolA' }),
- );
- expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
- expect.objectContaining({ serverToolName: 'toolC' }),
- );
- expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith(
- expect.objectContaining({ serverToolName: 'toolB' }),
- );
+ it('should return false if the include list exists but does not contain the tool', () => {
+ const mcpServerConfig = { includeTools: ['anotherTool'] };
+ expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false);
});
- it('should handle both includeTools and excludeTools', async () => {
- const serverConfig: MCPServerConfig = {
- command: './mcp-both',
- includeTools: ['toolA', 'toolB'],
- excludeTools: ['toolB'],
+ it('should return false if the tool is in both the include and exclude lists', () => {
+ const mcpServerConfig = {
+ includeTools: ['myTool'],
+ excludeTools: ['myTool'],
};
- mockConfig.getMcpServers.mockReturnValue({ 'both-server': serverConfig });
+ expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false);
+ });
- await discoverMcpTools(
- mockConfig.getMcpServers() ?? {},
- mockConfig.getMcpServerCommand(),
- mockToolRegistry as any,
+ it('should return false if the function declaration has no name', () => {
+ const namelessFuncDecl = {};
+ const mcpServerConfig = {};
+ expect(isEnabled(namelessFuncDecl, serverName, mcpServerConfig)).toBe(
false,
);
-
- expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
- expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
- expect.objectContaining({ serverToolName: 'toolA' }),
- );
- expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith(
- expect.objectContaining({ serverToolName: 'toolB' }),
- );
- expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith(
- expect.objectContaining({ serverToolName: 'toolC' }),
- );
});
});
});
-
-describe('sanitizeParameters', () => {
- it('should do nothing for an undefined schema', () => {
- const schema = undefined;
- sanitizeParameters(schema);
- });
-
- it('should remove default when anyOf is present', () => {
- const schema: Schema = {
- anyOf: [{ type: Type.STRING }, { type: Type.NUMBER }],
- default: 'hello',
- };
- sanitizeParameters(schema);
- expect(schema.default).toBeUndefined();
- });
-
- it('should recursively sanitize items in anyOf', () => {
- const schema: Schema = {
- anyOf: [
- {
- anyOf: [{ type: Type.STRING }],
- default: 'world',
- },
- { type: Type.NUMBER },
- ],
- };
- sanitizeParameters(schema);
- expect(schema.anyOf![0].default).toBeUndefined();
- });
-
- it('should recursively sanitize items in items', () => {
- const schema: Schema = {
- items: {
- anyOf: [{ type: Type.STRING }],
- default: 'world',
- },
- };
- sanitizeParameters(schema);
- expect(schema.items!.default).toBeUndefined();
- });
-
- it('should recursively sanitize items in properties', () => {
- const schema: Schema = {
- properties: {
- prop1: {
- anyOf: [{ type: Type.STRING }],
- default: 'world',
- },
- },
- };
- sanitizeParameters(schema);
- expect(schema.properties!.prop1.default).toBeUndefined();
- });
-
- it('should handle complex nested schemas', () => {
- const schema: Schema = {
- properties: {
- prop1: {
- items: {
- anyOf: [{ type: Type.STRING }],
- default: 'world',
- },
- },
- prop2: {
- anyOf: [
- {
- properties: {
- nestedProp: {
- anyOf: [{ type: Type.NUMBER }],
- default: 123,
- },
- },
- },
- ],
- },
- },
- };
- sanitizeParameters(schema);
- expect(schema.properties!.prop1.items!.default).toBeUndefined();
- const nestedProp =
- schema.properties!.prop2.anyOf![0].properties!.nestedProp;
- expect(nestedProp?.default).toBeUndefined();
- });
-});
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts
index 6edfbac8..eb82190b 100644
--- a/packages/core/src/tools/mcp-client.ts
+++ b/packages/core/src/tools/mcp-client.ts
@@ -5,6 +5,7 @@
*/
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
+import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import {
SSEClientTransport,
@@ -17,7 +18,7 @@ import {
import { parse } from 'shell-quote';
import { MCPServerConfig } from '../config/config.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
-import { Type, mcpToTool } from '@google/genai';
+import { FunctionDeclaration, Type, mcpToTool } from '@google/genai';
import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
@@ -123,28 +124,25 @@ export function getMCPDiscoveryState(): MCPDiscoveryState {
return mcpDiscoveryState;
}
+/**
+ * Discovers tools from all configured MCP servers and registers them with the tool registry.
+ * It orchestrates the connection and discovery process for each server defined in the
+ * configuration, as well as any server specified via a command-line argument.
+ *
+ * @param mcpServers A record of named MCP server configurations.
+ * @param mcpServerCommand An optional command string for a dynamically specified MCP server.
+ * @param toolRegistry The central registry where discovered tools will be registered.
+ * @returns A promise that resolves when the discovery process has been attempted for all servers.
+ */
export async function discoverMcpTools(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
toolRegistry: ToolRegistry,
debugMode: boolean,
): Promise<void> {
- // Set discovery state to in progress
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
-
try {
- if (mcpServerCommand) {
- const cmd = mcpServerCommand;
- const args = parse(cmd, process.env) as string[];
- if (args.some((arg) => typeof arg !== 'string')) {
- throw new Error('failed to parse mcpServerCommand: ' + cmd);
- }
- // use generic server name 'mcp'
- mcpServers['mcp'] = {
- command: args[0],
- args: args.slice(1),
- };
- }
+ mcpServers = populateMcpServerCommand(mcpServers, mcpServerCommand);
const discoveryPromises = Object.entries(mcpServers).map(
([mcpServerName, mcpServerConfig]) =>
@@ -156,16 +154,31 @@ export async function discoverMcpTools(
),
);
await Promise.all(discoveryPromises);
-
- // Mark discovery as completed
- mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
- } catch (error) {
- // Still mark as completed even with errors
+ } finally {
mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
- throw error;
}
}
+/** Visible for Testing */
+export function populateMcpServerCommand(
+ mcpServers: Record<string, MCPServerConfig>,
+ mcpServerCommand: string | undefined,
+): Record<string, MCPServerConfig> {
+ if (mcpServerCommand) {
+ const cmd = mcpServerCommand;
+ const args = parse(cmd, process.env) as string[];
+ if (args.some((arg) => typeof arg !== 'string')) {
+ throw new Error('failed to parse mcpServerCommand: ' + cmd);
+ }
+ // use generic server name 'mcp'
+ mcpServers['mcp'] = {
+ command: args[0],
+ args: args.slice(1),
+ };
+ }
+ return mcpServers;
+}
+
/**
* Connects to an MCP server and discovers available tools, registering them with the tool registry.
* This function handles the complete lifecycle of connecting to a server, discovering tools,
@@ -176,71 +189,117 @@ export async function discoverMcpTools(
* @param toolRegistry The registry to register discovered tools with
* @returns Promise that resolves when discovery is complete
*/
-async function connectAndDiscover(
+export async function connectAndDiscover(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
toolRegistry: ToolRegistry,
debugMode: boolean,
): Promise<void> {
- // Initialize the server status as connecting
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
- let transport;
- if (mcpServerConfig.httpUrl) {
- const transportOptions: StreamableHTTPClientTransportOptions = {};
+ try {
+ const mcpClient = await connectToMcpServer(
+ mcpServerName,
+ mcpServerConfig,
+ debugMode,
+ );
+ try {
+ updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
- if (mcpServerConfig.headers) {
- transportOptions.requestInit = {
- headers: mcpServerConfig.headers,
+ mcpClient.onerror = (error) => {
+ console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
+ updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
};
- }
- transport = new StreamableHTTPClientTransport(
- new URL(mcpServerConfig.httpUrl),
- transportOptions,
- );
- } else if (mcpServerConfig.url) {
- const transportOptions: SSEClientTransportOptions = {};
- if (mcpServerConfig.headers) {
- transportOptions.requestInit = {
- headers: mcpServerConfig.headers,
- };
+ const tools = await discoverTools(
+ mcpServerName,
+ mcpServerConfig,
+ mcpClient,
+ );
+ for (const tool of tools) {
+ toolRegistry.registerTool(tool);
+ }
+ } catch (error) {
+ mcpClient.close();
+ throw error;
}
- transport = new SSEClientTransport(
- new URL(mcpServerConfig.url),
- transportOptions,
- );
- } else if (mcpServerConfig.command) {
- transport = new StdioClientTransport({
- command: mcpServerConfig.command,
- args: mcpServerConfig.args || [],
- env: {
- ...process.env,
- ...(mcpServerConfig.env || {}),
- } as Record<string, string>,
- cwd: mcpServerConfig.cwd,
- stderr: 'pipe',
- });
- } else {
- console.error(
- `MCP server '${mcpServerName}' has invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio). Skipping.`,
- );
- // Update status to disconnected
+ } catch (error) {
+ console.error(`Error connecting to MCP server '${mcpServerName}':`, error);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
- return;
}
+}
- if (
- debugMode &&
- transport instanceof StdioClientTransport &&
- transport.stderr
- ) {
- transport.stderr.on('data', (data) => {
- const stderrStr = data.toString().trim();
- console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr);
- });
+/**
+ * Discovers and sanitizes tools from a connected MCP client.
+ * It retrieves function declarations from the client, filters out disabled tools,
+ * generates valid names for them, and wraps them in `DiscoveredMCPTool` instances.
+ *
+ * @param mcpServerName The name of the MCP server.
+ * @param mcpServerConfig The configuration for the MCP server.
+ * @param mcpClient The active MCP client instance.
+ * @returns A promise that resolves to an array of discovered and enabled tools.
+ * @throws An error if no enabled tools are found or if the server provides invalid function declarations.
+ */
+export async function discoverTools(
+ mcpServerName: string,
+ mcpServerConfig: MCPServerConfig,
+ mcpClient: Client,
+): Promise<DiscoveredMCPTool[]> {
+ try {
+ const mcpCallableTool = mcpToTool(mcpClient);
+ const tool = await mcpCallableTool.tool();
+
+ if (!Array.isArray(tool.functionDeclarations)) {
+ throw new Error(`Server did not return valid function declarations.`);
+ }
+
+ const discoveredTools: DiscoveredMCPTool[] = [];
+ for (const funcDecl of tool.functionDeclarations) {
+ if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) {
+ continue;
+ }
+
+ const toolNameForModel = generateValidName(funcDecl, mcpServerName);
+
+ sanitizeParameters(funcDecl.parameters);
+
+ discoveredTools.push(
+ new DiscoveredMCPTool(
+ mcpCallableTool,
+ mcpServerName,
+ toolNameForModel,
+ funcDecl.description ?? '',
+ funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
+ funcDecl.name!,
+ mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
+ mcpServerConfig.trust,
+ ),
+ );
+ }
+ if (discoveredTools.length === 0) {
+ throw Error('No enabled tools found');
+ }
+ return discoveredTools;
+ } catch (error) {
+ throw new Error(`Error discovering tools: ${error}`);
}
+}
+/**
+ * Creates and connects an MCP client to a server based on the provided configuration.
+ * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
+ * establishes a connection. It also applies a patch to handle request timeouts.
+ *
+ * @param mcpServerName The name of the MCP server, used for logging and identification.
+ * @param mcpServerConfig The configuration specifying how to connect to the server.
+ * @returns A promise that resolves to a connected MCP `Client` instance.
+ * @throws An error if the connection fails or the configuration is invalid.
+ */
+export async function connectToMcpServer(
+ mcpServerName: string,
+ mcpServerConfig: MCPServerConfig,
+ debugMode: boolean,
+): Promise<Client> {
const mcpClient = new Client({
name: 'gemini-cli-mcp-client',
version: '0.0.1',
@@ -259,11 +318,20 @@ async function connectAndDiscover(
}
try {
- await mcpClient.connect(transport, {
- timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
- });
- // Connection successful
- updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
+ const transport = createTransport(
+ mcpServerName,
+ mcpServerConfig,
+ debugMode,
+ );
+ try {
+ await mcpClient.connect(transport, {
+ timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
+ });
+ return mcpClient;
+ } catch (error) {
+ await transport.close();
+ throw error;
+ }
} catch (error) {
// Create a safe config object that excludes sensitive information
const safeConfig = {
@@ -282,131 +350,110 @@ async function connectAndDiscover(
if (process.env.SANDBOX) {
errorString += `\nMake sure it is available in the sandbox`;
}
- console.error(errorString);
- // Update status to disconnected
- updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
- return;
+ throw new Error(errorString);
}
+}
- mcpClient.onerror = (error) => {
- console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
- // Update status to disconnected on error
- updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
- };
-
- try {
- const mcpCallableTool = mcpToTool(mcpClient);
- const tool = await mcpCallableTool.tool();
-
- if (!tool || !Array.isArray(tool.functionDeclarations)) {
- console.error(
- `MCP server '${mcpServerName}' did not return valid tool function declarations. Skipping.`,
- );
- if (
- transport instanceof StdioClientTransport ||
- transport instanceof SSEClientTransport ||
- transport instanceof StreamableHTTPClientTransport
- ) {
- await transport.close();
- }
- // Update status to disconnected
- updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
- return;
+/** Visible for Testing */
+export function createTransport(
+ mcpServerName: string,
+ mcpServerConfig: MCPServerConfig,
+ debugMode: boolean,
+): Transport {
+ if (mcpServerConfig.httpUrl) {
+ const transportOptions: StreamableHTTPClientTransportOptions = {};
+ if (mcpServerConfig.headers) {
+ transportOptions.requestInit = {
+ headers: mcpServerConfig.headers,
+ };
}
+ return new StreamableHTTPClientTransport(
+ new URL(mcpServerConfig.httpUrl),
+ transportOptions,
+ );
+ }
- for (const funcDecl of tool.functionDeclarations) {
- if (!funcDecl.name) {
- console.warn(
- `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`,
- );
- continue;
- }
-
- const { includeTools, excludeTools } = mcpServerConfig;
- const toolName = funcDecl.name;
-
- let isEnabled = false;
- if (includeTools === undefined) {
- isEnabled = true;
- } else {
- isEnabled = includeTools.some(
- (tool) => tool === toolName || tool.startsWith(`${toolName}(`),
- );
- }
-
- if (excludeTools?.includes(toolName)) {
- isEnabled = false;
- }
-
- if (!isEnabled) {
- continue;
- }
+ if (mcpServerConfig.url) {
+ const transportOptions: SSEClientTransportOptions = {};
+ if (mcpServerConfig.headers) {
+ transportOptions.requestInit = {
+ headers: mcpServerConfig.headers,
+ };
+ }
+ return new SSEClientTransport(
+ new URL(mcpServerConfig.url),
+ transportOptions,
+ );
+ }
- let toolNameForModel = funcDecl.name;
+ if (mcpServerConfig.command) {
+ const transport = new StdioClientTransport({
+ command: mcpServerConfig.command,
+ args: mcpServerConfig.args || [],
+ env: {
+ ...process.env,
+ ...(mcpServerConfig.env || {}),
+ } as Record<string, string>,
+ cwd: mcpServerConfig.cwd,
+ stderr: 'pipe',
+ });
+ if (debugMode) {
+ transport.stderr!.on('data', (data) => {
+ const stderrStr = data.toString().trim();
+ console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr);
+ });
+ }
+ return transport;
+ }
- // Replace invalid characters (based on 400 error message from Gemini API) with underscores
- toolNameForModel = toolNameForModel.replace(/[^a-zA-Z0-9_.-]/g, '_');
+ throw new Error(
+ `Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`,
+ );
+}
- const existingTool = toolRegistry.getTool(toolNameForModel);
- if (existingTool) {
- toolNameForModel = mcpServerName + '__' + toolNameForModel;
- }
+/** Visible for testing */
+export function generateValidName(
+ funcDecl: FunctionDeclaration,
+ mcpServerName: string,
+) {
+ // Replace invalid characters (based on 400 error message from Gemini API) with underscores
+ let validToolname = funcDecl.name!.replace(/[^a-zA-Z0-9_.-]/g, '_');
- // If longer than 63 characters, replace middle with '___'
- // (Gemini API says max length 64, but actual limit seems to be 63)
- if (toolNameForModel.length > 63) {
- toolNameForModel =
- toolNameForModel.slice(0, 28) + '___' + toolNameForModel.slice(-32);
- }
+ // Prepend MCP server name to avoid conflicts with other tools
+ validToolname = mcpServerName + '__' + validToolname;
- sanitizeParameters(funcDecl.parameters);
+ // If longer than 63 characters, replace middle with '___'
+ // (Gemini API says max length 64, but actual limit seems to be 63)
+ if (validToolname.length > 63) {
+ validToolname =
+ validToolname.slice(0, 28) + '___' + validToolname.slice(-32);
+ }
+ return validToolname;
+}
- toolRegistry.registerTool(
- new DiscoveredMCPTool(
- mcpCallableTool,
- mcpServerName,
- toolNameForModel,
- funcDecl.description ?? '',
- funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
- funcDecl.name,
- mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
- mcpServerConfig.trust,
- ),
- );
- }
- } catch (error) {
- console.error(
- `Failed to list or register tools for MCP server '${mcpServerName}': ${error}`,
+/** Visible for testing */
+export function isEnabled(
+ funcDecl: FunctionDeclaration,
+ mcpServerName: string,
+ mcpServerConfig: MCPServerConfig,
+): boolean {
+ if (!funcDecl.name) {
+ console.warn(
+ `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`,
);
- // Ensure transport is cleaned up on error too
- if (
- transport instanceof StdioClientTransport ||
- transport instanceof SSEClientTransport ||
- transport instanceof StreamableHTTPClientTransport
- ) {
- await transport.close();
- }
- // Update status to disconnected
- updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
+ return false;
}
+ const { includeTools, excludeTools } = mcpServerConfig;
- // If no tools were registered from this MCP server, the following 'if' block
- // will close the connection. This is done to conserve resources and prevent
- // an orphaned connection to a server that isn't providing any usable
- // functionality. Connections to servers that did provide tools are kept
- // open, as those tools will require the connection to function.
- if (toolRegistry.getToolsByServer(mcpServerName).length === 0) {
- console.log(
- `No tools registered from MCP server '${mcpServerName}'. Closing connection.`,
- );
- if (
- transport instanceof StdioClientTransport ||
- transport instanceof SSEClientTransport ||
- transport instanceof StreamableHTTPClientTransport
- ) {
- await transport.close();
- // Update status to disconnected
- updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
- }
+ // excludeTools takes precedence over includeTools
+ if (excludeTools && excludeTools.includes(funcDecl.name)) {
+ return false;
}
+
+ return (
+ !includeTools ||
+ includeTools.some(
+ (tool) => tool === funcDecl.name || tool.startsWith(`${funcDecl.name}(`),
+ )
+ );
}
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts
index fba48c17..853f6458 100644
--- a/packages/core/src/tools/tool-registry.test.ts
+++ b/packages/core/src/tools/tool-registry.test.ts
@@ -326,6 +326,83 @@ describe('ToolRegistry', () => {
});
describe('sanitizeParameters', () => {
+ it('should remove default when anyOf is present', () => {
+ const schema: Schema = {
+ anyOf: [{ type: Type.STRING }, { type: Type.NUMBER }],
+ default: 'hello',
+ };
+ sanitizeParameters(schema);
+ expect(schema.default).toBeUndefined();
+ });
+
+ it('should recursively sanitize items in anyOf', () => {
+ const schema: Schema = {
+ anyOf: [
+ {
+ anyOf: [{ type: Type.STRING }],
+ default: 'world',
+ },
+ { type: Type.NUMBER },
+ ],
+ };
+ sanitizeParameters(schema);
+ expect(schema.anyOf![0].default).toBeUndefined();
+ });
+
+ it('should recursively sanitize items in items', () => {
+ const schema: Schema = {
+ items: {
+ anyOf: [{ type: Type.STRING }],
+ default: 'world',
+ },
+ };
+ sanitizeParameters(schema);
+ expect(schema.items!.default).toBeUndefined();
+ });
+
+ it('should recursively sanitize items in properties', () => {
+ const schema: Schema = {
+ properties: {
+ prop1: {
+ anyOf: [{ type: Type.STRING }],
+ default: 'world',
+ },
+ },
+ };
+ sanitizeParameters(schema);
+ expect(schema.properties!.prop1.default).toBeUndefined();
+ });
+
+ it('should handle complex nested schemas', () => {
+ const schema: Schema = {
+ properties: {
+ prop1: {
+ items: {
+ anyOf: [{ type: Type.STRING }],
+ default: 'world',
+ },
+ },
+ prop2: {
+ anyOf: [
+ {
+ properties: {
+ nestedProp: {
+ anyOf: [{ type: Type.NUMBER }],
+ default: 123,
+ },
+ },
+ },
+ ],
+ },
+ },
+ };
+ sanitizeParameters(schema);
+ expect(schema.properties!.prop1.items!.default).toBeUndefined();
+ const nestedProp =
+ schema.properties!.prop2.anyOf![0].properties!.nestedProp;
+ expect(nestedProp?.default).toBeUndefined();
+ });
+
it('should remove unsupported format from a simple string property', () => {
const schema: Schema = {
type: Type.OBJECT,
@@ -356,25 +433,6 @@ describe('sanitizeParameters', () => {
expect(schema).toEqual(originalSchema);
});
- it('should handle nested objects recursively', () => {
- const schema: Schema = {
- type: Type.OBJECT,
- properties: {
- user: {
- type: Type.OBJECT,
- properties: {
- email: { type: Type.STRING, format: 'email' },
- },
- },
- },
- };
- sanitizeParameters(schema);
- expect(schema.properties?.['user']?.properties?.['email']).toHaveProperty(
- 'format',
- undefined,
- );
- });
-
it('should handle arrays of objects', () => {
const schema: Schema = {
type: Type.OBJECT,
@@ -414,19 +472,6 @@ describe('sanitizeParameters', () => {
expect(() => sanitizeParameters(undefined)).not.toThrow();
});
- it('should handle cyclic schemas without crashing', () => {
- const schema: any = {
- type: Type.OBJECT,
- properties: {
- name: { type: Type.STRING, format: 'hostname' },
- },
- };
- schema.properties.self = schema;
-
- expect(() => sanitizeParameters(schema)).not.toThrow();
- expect(schema.properties.name).toHaveProperty('format', undefined);
- });
-
it('should handle complex nested schemas with cycles', () => {
const userNode: any = {
type: Type.OBJECT,