summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/server/src/tools/mcp-client.test.ts371
-rw-r--r--packages/server/src/tools/mcp-client.ts138
-rw-r--r--packages/server/src/tools/mcp-tool.test.ts161
-rw-r--r--packages/server/src/tools/mcp-tool.ts49
-rw-r--r--packages/server/src/tools/tool-registry.test.ts26
-rw-r--r--packages/server/src/tools/tool-registry.ts154
6 files changed, 737 insertions, 162 deletions
diff --git a/packages/server/src/tools/mcp-client.test.ts b/packages/server/src/tools/mcp-client.test.ts
new file mode 100644
index 00000000..4664669d
--- /dev/null
+++ b/packages/server/src/tools/mcp-client.test.ts
@@ -0,0 +1,371 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/* eslint-disable @typescript-eslint/no-explicit-any */
+import {
+ describe,
+ it,
+ expect,
+ vi,
+ beforeEach,
+ afterEach,
+ Mocked,
+} from 'vitest';
+import { discoverMcpTools } from './mcp-client.js';
+import { Config, MCPServerConfig } from '../config/config.js';
+import { ToolRegistry } from './tool-registry.js';
+import { DiscoveredMCPTool } from './mcp-tool.js';
+import { Client } from '@modelcontextprotocol/sdk/client/index.js';
+import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
+import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
+import { parse, ParseEntry } from 'shell-quote';
+
+// Mock dependencies
+vi.mock('shell-quote');
+
+vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
+ const MockedClient = vi.fn();
+ MockedClient.prototype.connect = vi.fn();
+ MockedClient.prototype.listTools = vi.fn();
+ // Ensure instances have an onerror property that can be spied on or assigned to
+ MockedClient.mockImplementation(() => ({
+ connect: MockedClient.prototype.connect,
+ listTools: MockedClient.prototype.listTools,
+ onerror: vi.fn(), // Each instance gets its own onerror mock
+ }));
+ return { Client: MockedClient };
+});
+
+// Define a global mock for stderr.on that can be cleared and checked
+const mockGlobalStdioStderrOn = vi.fn();
+
+vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
+ // This is the constructor for StdioClientTransport
+ const MockedStdioTransport = vi.fn().mockImplementation(function (
+ this: any,
+ options: any,
+ ) {
+ // Always return a new object with a fresh reference to the global mock for .on
+ this.options = options;
+ this.stderr = { on: mockGlobalStdioStderrOn };
+ return this;
+ });
+ return { StdioClientTransport: MockedStdioTransport };
+});
+
+vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
+ const MockedSSETransport = vi.fn();
+ return { SSEClientTransport: MockedSSETransport };
+});
+
+vi.mock('./tool-registry.js');
+
+describe('discoverMcpTools', () => {
+ let mockConfig: Mocked<Config>;
+ let mockToolRegistry: Mocked<ToolRegistry>;
+
+ beforeEach(() => {
+ mockConfig = {
+ getMcpServers: vi.fn().mockReturnValue({}),
+ getMcpServerCommand: vi.fn().mockReturnValue(undefined),
+ } as any;
+
+ mockToolRegistry = new (ToolRegistry as any)(
+ mockConfig,
+ ) as Mocked<ToolRegistry>;
+ mockToolRegistry.registerTool = vi.fn();
+
+ vi.mocked(parse).mockClear();
+ vi.mocked(Client).mockClear();
+ vi.mocked(Client.prototype.connect)
+ .mockClear()
+ .mockResolvedValue(undefined);
+ vi.mocked(Client.prototype.listTools)
+ .mockClear()
+ .mockResolvedValue({ tools: [] });
+
+ vi.mocked(StdioClientTransport).mockClear();
+ mockGlobalStdioStderrOn.mockClear(); // Clear the global mock in beforeEach
+
+ vi.mocked(SSEClientTransport).mockClear();
+ });
+
+ afterEach(() => {
+ vi.restoreAllMocks();
+ });
+
+ it('should do nothing if no MCP servers or command are configured', async () => {
+ await discoverMcpTools(mockConfig, mockToolRegistry);
+ expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1);
+ expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1);
+ expect(Client).not.toHaveBeenCalled();
+ expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
+ });
+
+ it('should discover tools via mcpServerCommand', async () => {
+ const commandString = 'my-mcp-server --start';
+ const parsedCommand = ['my-mcp-server', '--start'] as ParseEntry[];
+ mockConfig.getMcpServerCommand.mockReturnValue(commandString);
+ vi.mocked(parse).mockReturnValue(parsedCommand);
+
+ const mockTool = {
+ name: 'tool1',
+ description: 'desc1',
+ inputSchema: { type: 'object' as const, properties: {} },
+ };
+ vi.mocked(Client.prototype.listTools).mockResolvedValue({
+ tools: [mockTool],
+ });
+
+ await discoverMcpTools(mockConfig, mockToolRegistry);
+
+ expect(parse).toHaveBeenCalledWith(commandString, process.env);
+ expect(StdioClientTransport).toHaveBeenCalledWith({
+ command: parsedCommand[0],
+ args: parsedCommand.slice(1),
+ env: expect.any(Object),
+ cwd: undefined,
+ stderr: 'pipe',
+ });
+ expect(Client.prototype.connect).toHaveBeenCalledTimes(1);
+ expect(Client.prototype.listTools).toHaveBeenCalledTimes(1);
+ expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
+ expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
+ expect.any(DiscoveredMCPTool),
+ );
+ const registeredTool = mockToolRegistry.registerTool.mock
+ .calls[0][0] as DiscoveredMCPTool;
+ expect(registeredTool.name).toBe('tool1');
+ expect(registeredTool.serverToolName).toBe('tool1');
+ });
+
+ it('should discover tools via mcpServers config (stdio)', async () => {
+ const serverConfig: MCPServerConfig = {
+ command: './mcp-stdio',
+ args: ['arg1'],
+ };
+ mockConfig.getMcpServers.mockReturnValue({ 'stdio-server': serverConfig });
+
+ const mockTool = {
+ name: 'tool-stdio',
+ description: 'desc-stdio',
+ inputSchema: { type: 'object' as const, properties: {} },
+ };
+ vi.mocked(Client.prototype.listTools).mockResolvedValue({
+ tools: [mockTool],
+ });
+
+ await discoverMcpTools(mockConfig, mockToolRegistry);
+
+ expect(StdioClientTransport).toHaveBeenCalledWith({
+ command: serverConfig.command,
+ args: serverConfig.args,
+ env: expect.any(Object),
+ cwd: undefined,
+ stderr: 'pipe',
+ });
+ expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
+ expect.any(DiscoveredMCPTool),
+ );
+ const registeredTool = mockToolRegistry.registerTool.mock
+ .calls[0][0] as DiscoveredMCPTool;
+ expect(registeredTool.name).toBe('tool-stdio');
+ });
+
+ it('should discover tools via mcpServers config (sse)', async () => {
+ const serverConfig: MCPServerConfig = { url: 'http://localhost:1234/sse' };
+ mockConfig.getMcpServers.mockReturnValue({ 'sse-server': serverConfig });
+
+ const mockTool = {
+ name: 'tool-sse',
+ description: 'desc-sse',
+ inputSchema: { type: 'object' as const, properties: {} },
+ };
+ vi.mocked(Client.prototype.listTools).mockResolvedValue({
+ tools: [mockTool],
+ });
+
+ await discoverMcpTools(mockConfig, mockToolRegistry);
+
+ expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!));
+ expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
+ expect.any(DiscoveredMCPTool),
+ );
+ const registeredTool = mockToolRegistry.registerTool.mock
+ .calls[0][0] as DiscoveredMCPTool;
+ expect(registeredTool.name).toBe('tool-sse');
+ });
+
+ it('should prefix tool names if multiple MCP servers are configured', async () => {
+ const serverConfig1: MCPServerConfig = { command: './mcp1' };
+ const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' };
+ mockConfig.getMcpServers.mockReturnValue({
+ server1: serverConfig1,
+ server2: serverConfig2,
+ });
+
+ const mockTool1 = {
+ name: 'toolA',
+ description: 'd1',
+ inputSchema: { type: 'object' as const, properties: {} },
+ };
+ const mockTool2 = {
+ name: 'toolB',
+ description: 'd2',
+ inputSchema: { type: 'object' as const, properties: {} },
+ };
+
+ vi.mocked(Client.prototype.listTools)
+ .mockResolvedValueOnce({ tools: [mockTool1] })
+ .mockResolvedValueOnce({ tools: [mockTool2] });
+
+ await discoverMcpTools(mockConfig, mockToolRegistry);
+
+ expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2);
+ const registeredTool1 = mockToolRegistry.registerTool.mock
+ .calls[0][0] as DiscoveredMCPTool;
+ const registeredTool2 = mockToolRegistry.registerTool.mock
+ .calls[1][0] as DiscoveredMCPTool;
+
+ expect(registeredTool1.name).toBe('server1__toolA');
+ expect(registeredTool1.serverToolName).toBe('toolA');
+ expect(registeredTool2.name).toBe('server2__toolB');
+ expect(registeredTool2.serverToolName).toBe('toolB');
+ });
+
+ it('should clean schema properties ($schema, additionalProperties)', async () => {
+ const serverConfig: MCPServerConfig = { command: './mcp-clean' };
+ mockConfig.getMcpServers.mockReturnValue({ 'clean-server': serverConfig });
+
+ const rawSchema = {
+ type: 'object' as const,
+ $schema: 'http://json-schema.org/draft-07/schema#',
+ additionalProperties: true,
+ properties: {
+ prop1: { type: 'string', $schema: 'remove-this' },
+ prop2: {
+ type: 'object' as const,
+ additionalProperties: false,
+ properties: { nested: { type: 'number' } },
+ },
+ },
+ };
+ const mockTool = {
+ name: 'cleanTool',
+ description: 'd',
+ inputSchema: JSON.parse(JSON.stringify(rawSchema)),
+ };
+ vi.mocked(Client.prototype.listTools).mockResolvedValue({
+ tools: [mockTool],
+ });
+
+ await discoverMcpTools(mockConfig, mockToolRegistry);
+
+ expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
+ const registeredTool = mockToolRegistry.registerTool.mock
+ .calls[0][0] as DiscoveredMCPTool;
+ const cleanedParams = registeredTool.schema.parameters as any;
+
+ expect(cleanedParams).not.toHaveProperty('$schema');
+ expect(cleanedParams).not.toHaveProperty('additionalProperties');
+ expect(cleanedParams.properties.prop1).not.toHaveProperty('$schema');
+ expect(cleanedParams.properties.prop2).not.toHaveProperty(
+ 'additionalProperties',
+ );
+ expect(cleanedParams.properties.prop2.properties.nested).not.toHaveProperty(
+ '$schema',
+ );
+ expect(cleanedParams.properties.prop2.properties.nested).not.toHaveProperty(
+ 'additionalProperties',
+ );
+ });
+
+ it('should handle error if mcpServerCommand parsing fails', async () => {
+ const commandString = 'my-mcp-server "unterminated quote';
+ mockConfig.getMcpServerCommand.mockReturnValue(commandString);
+ vi.mocked(parse).mockImplementation(() => {
+ throw new Error('Parsing failed');
+ });
+ vi.spyOn(console, 'error').mockImplementation(() => {});
+
+ await expect(
+ discoverMcpTools(mockConfig, mockToolRegistry),
+ ).rejects.toThrow('Parsing failed');
+ expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
+ expect(console.error).not.toHaveBeenCalled();
+ });
+
+ it('should log error and skip server if config is invalid (missing url and command)', async () => {
+ mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });
+ vi.spyOn(console, 'error').mockImplementation(() => {});
+
+ await discoverMcpTools(mockConfig, mockToolRegistry);
+
+ expect(console.error).toHaveBeenCalledWith(
+ expect.stringContaining(
+ "MCP server 'bad-server' has invalid configuration",
+ ),
+ );
+ // Client constructor should not be called if config is invalid before instantiation
+ expect(Client).not.toHaveBeenCalled();
+ });
+
+ it('should log error and skip server if mcpClient.connect fails', async () => {
+ const serverConfig: MCPServerConfig = { command: './mcp-fail-connect' };
+ mockConfig.getMcpServers.mockReturnValue({
+ 'fail-connect-server': serverConfig,
+ });
+ vi.mocked(Client.prototype.connect).mockRejectedValue(
+ new Error('Connection refused'),
+ );
+ vi.spyOn(console, 'error').mockImplementation(() => {});
+
+ await discoverMcpTools(mockConfig, mockToolRegistry);
+
+ expect(console.error).toHaveBeenCalledWith(
+ expect.stringContaining(
+ "failed to start or connect to MCP server 'fail-connect-server'",
+ ),
+ );
+ expect(Client.prototype.listTools).not.toHaveBeenCalled();
+ expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
+ });
+
+ it('should log error and skip server if mcpClient.listTools fails', async () => {
+ const serverConfig: MCPServerConfig = { command: './mcp-fail-list' };
+ mockConfig.getMcpServers.mockReturnValue({
+ 'fail-list-server': serverConfig,
+ });
+ vi.mocked(Client.prototype.listTools).mockRejectedValue(
+ new Error('ListTools error'),
+ );
+ vi.spyOn(console, 'error').mockImplementation(() => {});
+
+ await discoverMcpTools(mockConfig, mockToolRegistry);
+
+ expect(console.error).toHaveBeenCalledWith(
+ expect.stringContaining(
+ "Failed to list or register tools for MCP server 'fail-list-server'",
+ ),
+ );
+ expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
+ });
+
+ it('should assign mcpClient.onerror handler', async () => {
+ const serverConfig: MCPServerConfig = { command: './mcp-onerror' };
+ mockConfig.getMcpServers.mockReturnValue({
+ 'onerror-server': serverConfig,
+ });
+
+ await discoverMcpTools(mockConfig, mockToolRegistry);
+
+ const clientInstances = vi.mocked(Client).mock.results;
+ expect(clientInstances.length).toBeGreaterThan(0);
+ const lastClientInstance =
+ clientInstances[clientInstances.length - 1]?.value;
+ expect(lastClientInstance?.onerror).toEqual(expect.any(Function));
+ });
+});
diff --git a/packages/server/src/tools/mcp-client.ts b/packages/server/src/tools/mcp-client.ts
new file mode 100644
index 00000000..8c2b4879
--- /dev/null
+++ b/packages/server/src/tools/mcp-client.ts
@@ -0,0 +1,138 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { Client } from '@modelcontextprotocol/sdk/client/index.js';
+import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
+import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
+import { parse } from 'shell-quote';
+import { Config, MCPServerConfig } from '../config/config.js';
+import { DiscoveredMCPTool } from './mcp-tool.js';
+import { ToolRegistry } from './tool-registry.js';
+
+export async function discoverMcpTools(
+ config: Config,
+ toolRegistry: ToolRegistry,
+): Promise<void> {
+ const mcpServers = config.getMcpServers() || {};
+
+ if (config.getMcpServerCommand()) {
+ const cmd = config.getMcpServerCommand()!;
+ const args = parse(cmd, process.env) as string[];
+ if (args.some((arg) => typeof arg !== 'string')) {
+ throw new Error('failed to parse mcpServerCommand: ' + cmd);
+ }
+ // use generic server name 'mcp'
+ mcpServers['mcp'] = {
+ command: args[0],
+ args: args.slice(1),
+ };
+ }
+
+ const discoveryPromises = Object.entries(mcpServers).map(
+ ([mcpServerName, mcpServerConfig]) =>
+ connectAndDiscover(
+ mcpServerName,
+ mcpServerConfig,
+ toolRegistry,
+ mcpServers,
+ ),
+ );
+ await Promise.all(discoveryPromises);
+}
+
+async function connectAndDiscover(
+ mcpServerName: string,
+ mcpServerConfig: MCPServerConfig,
+ toolRegistry: ToolRegistry,
+ mcpServers: Record<string, MCPServerConfig>,
+): Promise<void> {
+ let transport;
+ if (mcpServerConfig.url) {
+ transport = new SSEClientTransport(new URL(mcpServerConfig.url));
+ } else if (mcpServerConfig.command) {
+ transport = new StdioClientTransport({
+ command: mcpServerConfig.command,
+ args: mcpServerConfig.args || [],
+ env: {
+ ...process.env,
+ ...(mcpServerConfig.env || {}),
+ } as Record<string, string>,
+ cwd: mcpServerConfig.cwd,
+ stderr: 'pipe',
+ });
+ } else {
+ console.error(
+ `MCP server '${mcpServerName}' has invalid configuration: missing both url (for SSE) and command (for stdio). Skipping.`,
+ );
+ return; // Return a resolved promise as this path doesn't throw.
+ }
+
+ const mcpClient = new Client({
+ name: 'gemini-cli-mcp-client',
+ version: '0.0.1',
+ });
+
+ try {
+ await mcpClient.connect(transport);
+ } catch (error) {
+ console.error(
+ `failed to start or connect to MCP server '${mcpServerName}' ` +
+ `${JSON.stringify(mcpServerConfig)}; \n${error}`,
+ );
+ return; // Return a resolved promise, let other MCP servers be discovered.
+ }
+
+ mcpClient.onerror = (error) => {
+ console.error('MCP ERROR', error.toString());
+ };
+
+ if (transport instanceof StdioClientTransport && transport.stderr) {
+ transport.stderr.on('data', (data) => {
+ if (!data.toString().includes('] INFO')) {
+ console.debug('MCP STDERR', data.toString());
+ }
+ });
+ }
+
+ try {
+ const result = await mcpClient.listTools();
+ for (const tool of result.tools) {
+ // Recursively remove additionalProperties and $schema from the inputSchema
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any -- This function recursively navigates a deeply nested and potentially heterogeneous JSON schema object. Using 'any' is a pragmatic choice here to avoid overly complex type definitions for all possible schema variations.
+ const removeSchemaProps = (obj: any) => {
+ if (typeof obj !== 'object' || obj === null) {
+ return;
+ }
+ if (Array.isArray(obj)) {
+ obj.forEach(removeSchemaProps);
+ } else {
+ delete obj.additionalProperties;
+ delete obj.$schema;
+ Object.values(obj).forEach(removeSchemaProps);
+ }
+ };
+ removeSchemaProps(tool.inputSchema);
+
+ toolRegistry.registerTool(
+ new DiscoveredMCPTool(
+ mcpClient,
+ Object.keys(mcpServers).length > 1
+ ? mcpServerName + '__' + tool.name
+ : tool.name,
+ tool.description ?? '',
+ tool.inputSchema,
+ tool.name,
+ mcpServerConfig.timeout,
+ ),
+ );
+ }
+ } catch (error) {
+ console.error(
+ `Failed to list or register tools for MCP server '${mcpServerName}': ${error}`,
+ );
+ // Do not re-throw, allow other servers to proceed.
+ }
+}
diff --git a/packages/server/src/tools/mcp-tool.test.ts b/packages/server/src/tools/mcp-tool.test.ts
new file mode 100644
index 00000000..e28cf586
--- /dev/null
+++ b/packages/server/src/tools/mcp-tool.test.ts
@@ -0,0 +1,161 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/* eslint-disable @typescript-eslint/no-explicit-any */
+import {
+ describe,
+ it,
+ expect,
+ vi,
+ beforeEach,
+ afterEach,
+ Mocked,
+} from 'vitest';
+import {
+ DiscoveredMCPTool,
+ MCP_TOOL_DEFAULT_TIMEOUT_MSEC,
+} from './mcp-tool.js';
+import { Client } from '@modelcontextprotocol/sdk/client/index.js';
+import { ToolResult } from './tools.js';
+
+// Mock MCP SDK Client
+vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
+ const MockClient = vi.fn();
+ MockClient.prototype.callTool = vi.fn();
+ return { Client: MockClient };
+});
+
+describe('DiscoveredMCPTool', () => {
+ let mockMcpClient: Mocked<Client>;
+ const toolName = 'test-mcp-tool';
+ const serverToolName = 'actual-server-tool-name';
+ const baseDescription = 'A test MCP tool.';
+ const inputSchema = {
+ type: 'object' as const,
+ properties: { param: { type: 'string' } },
+ };
+
+ beforeEach(() => {
+ // Create a new mock client for each test to reset call history
+ mockMcpClient = new (Client as any)({
+ name: 'test-client',
+ version: '0.0.1',
+ }) as Mocked<Client>;
+ vi.mocked(mockMcpClient.callTool).mockClear();
+ });
+
+ afterEach(() => {
+ vi.restoreAllMocks();
+ });
+
+ describe('constructor', () => {
+ it('should set properties correctly and augment description', () => {
+ const tool = new DiscoveredMCPTool(
+ mockMcpClient,
+ toolName,
+ baseDescription,
+ inputSchema,
+ serverToolName,
+ );
+
+ expect(tool.name).toBe(toolName);
+ expect(tool.schema.name).toBe(toolName);
+ expect(tool.schema.description).toContain(baseDescription);
+ expect(tool.schema.description).toContain('This MCP tool was discovered');
+ // Corrected assertion for backticks and template literal
+ expect(tool.schema.description).toContain(
+ `tools/call\` method for tool name \`${toolName}\``,
+ );
+ expect(tool.schema.parameters).toEqual(inputSchema);
+ expect(tool.serverToolName).toBe(serverToolName);
+ expect(tool.timeout).toBeUndefined();
+ });
+
+ it('should accept and store a custom timeout', () => {
+ const customTimeout = 5000;
+ const tool = new DiscoveredMCPTool(
+ mockMcpClient,
+ toolName,
+ baseDescription,
+ inputSchema,
+ serverToolName,
+ customTimeout,
+ );
+ expect(tool.timeout).toBe(customTimeout);
+ });
+ });
+
+ describe('execute', () => {
+ it('should call mcpClient.callTool with correct parameters and default timeout', async () => {
+ const tool = new DiscoveredMCPTool(
+ mockMcpClient,
+ toolName,
+ baseDescription,
+ inputSchema,
+ serverToolName,
+ );
+ const params = { param: 'testValue' };
+ const expectedMcpResult = { success: true, details: 'executed' };
+ vi.mocked(mockMcpClient.callTool).mockResolvedValue(expectedMcpResult);
+
+ const result: ToolResult = await tool.execute(params);
+
+ expect(mockMcpClient.callTool).toHaveBeenCalledWith(
+ {
+ name: serverToolName,
+ arguments: params,
+ },
+ undefined,
+ {
+ timeout: MCP_TOOL_DEFAULT_TIMEOUT_MSEC,
+ },
+ );
+ const expectedOutput = JSON.stringify(expectedMcpResult, null, 2);
+ expect(result.llmContent).toBe(expectedOutput);
+ expect(result.returnDisplay).toBe(expectedOutput);
+ });
+
+ it('should call mcpClient.callTool with custom timeout if provided', async () => {
+ const customTimeout = 15000;
+ const tool = new DiscoveredMCPTool(
+ mockMcpClient,
+ toolName,
+ baseDescription,
+ inputSchema,
+ serverToolName,
+ customTimeout,
+ );
+ const params = { param: 'anotherValue' };
+ const expectedMcpResult = { result: 'done' };
+ vi.mocked(mockMcpClient.callTool).mockResolvedValue(expectedMcpResult);
+
+ await tool.execute(params);
+
+ expect(mockMcpClient.callTool).toHaveBeenCalledWith(
+ expect.anything(),
+ undefined,
+ {
+ timeout: customTimeout,
+ },
+ );
+ });
+
+ it('should propagate rejection if mcpClient.callTool rejects', async () => {
+ const tool = new DiscoveredMCPTool(
+ mockMcpClient,
+ toolName,
+ baseDescription,
+ inputSchema,
+ serverToolName,
+ );
+ const params = { param: 'failCase' };
+ const expectedError = new Error('MCP call failed');
+ vi.mocked(mockMcpClient.callTool).mockRejectedValue(expectedError);
+
+ await expect(tool.execute(params)).rejects.toThrow(expectedError);
+ });
+ });
+});
diff --git a/packages/server/src/tools/mcp-tool.ts b/packages/server/src/tools/mcp-tool.ts
new file mode 100644
index 00000000..05ad750c
--- /dev/null
+++ b/packages/server/src/tools/mcp-tool.ts
@@ -0,0 +1,49 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { Client } from '@modelcontextprotocol/sdk/client/index.js';
+import { BaseTool, ToolResult } from './tools.js';
+
+type ToolParams = Record<string, unknown>;
+
+export const MCP_TOOL_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
+
+export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
+ constructor(
+ private readonly mcpClient: Client,
+ readonly name: string,
+ readonly description: string,
+ readonly parameterSchema: Record<string, unknown>,
+ readonly serverToolName: string,
+ readonly timeout?: number,
+ ) {
+ description += `
+
+This MCP tool was discovered from a local MCP server using JSON RPC 2.0 over stdio transport protocol.
+When called, this tool will invoke the \`tools/call\` method for tool name \`${name}\`.
+MCP servers can be configured in project or user settings.
+Returns the MCP server response as a json string.
+`;
+ super(name, name, description, parameterSchema);
+ }
+
+ async execute(params: ToolParams): Promise<ToolResult> {
+ const result = await this.mcpClient.callTool(
+ {
+ name: this.serverToolName,
+ arguments: params,
+ },
+ undefined, // skip resultSchema to specify options (RequestOptions)
+ {
+ timeout: this.timeout ?? MCP_TOOL_DEFAULT_TIMEOUT_MSEC,
+ },
+ );
+ return {
+ llmContent: JSON.stringify(result, null, 2),
+ returnDisplay: JSON.stringify(result, null, 2),
+ };
+ }
+}
diff --git a/packages/server/src/tools/tool-registry.test.ts b/packages/server/src/tools/tool-registry.test.ts
index 4c2bff38..bb41b35c 100644
--- a/packages/server/src/tools/tool-registry.test.ts
+++ b/packages/server/src/tools/tool-registry.test.ts
@@ -14,11 +14,8 @@ import {
afterEach,
Mocked,
} from 'vitest';
-import {
- ToolRegistry,
- DiscoveredTool,
- DiscoveredMCPTool,
-} from './tool-registry.js';
+import { ToolRegistry, DiscoveredTool } from './tool-registry.js';
+import { DiscoveredMCPTool } from './mcp-tool.js';
import { Config } from '../config/config.js';
import { BaseTool, ToolResult } from './tools.js';
import { FunctionDeclaration } from '@google/genai';
@@ -347,7 +344,7 @@ describe('ToolRegistry', () => {
toolRegistry = new ToolRegistry(config);
});
- it('should discover tools using discovery command', () => {
+ it('should discover tools using discovery command', async () => {
const discoveryCommand = 'my-discovery-command';
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
const mockToolDeclarations: FunctionDeclaration[] = [
@@ -366,7 +363,7 @@ describe('ToolRegistry', () => {
),
);
- toolRegistry.discoverTools();
+ await toolRegistry.discoverTools();
expect(execSync).toHaveBeenCalledWith(discoveryCommand);
const discoveredTool = toolRegistry.getTool('discovered-tool-1');
@@ -376,7 +373,7 @@ describe('ToolRegistry', () => {
expect(discoveredTool?.description).toContain(discoveryCommand);
});
- it('should remove previously discovered tools before discovering new ones', () => {
+ it('should remove previously discovered tools before discovering new ones', async () => {
const discoveryCommand = 'my-discovery-command';
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
mockExecSync.mockReturnValueOnce(
@@ -394,7 +391,7 @@ describe('ToolRegistry', () => {
]),
),
);
- toolRegistry.discoverTools();
+ await toolRegistry.discoverTools();
expect(toolRegistry.getTool('old-discovered-tool')).toBeInstanceOf(
DiscoveredTool,
);
@@ -414,7 +411,7 @@ describe('ToolRegistry', () => {
]),
),
);
- toolRegistry.discoverTools();
+ await toolRegistry.discoverTools();
expect(toolRegistry.getTool('old-discovered-tool')).toBeUndefined();
expect(toolRegistry.getTool('new-discovered-tool')).toBeInstanceOf(
DiscoveredTool,
@@ -457,8 +454,7 @@ describe('ToolRegistry', () => {
});
mockMcpClientInstance.connect.mockResolvedValue(undefined);
- toolRegistry.discoverTools();
- await new Promise((resolve) => setTimeout(resolve, 100)); // Wait for async operations
+ await toolRegistry.discoverTools();
expect(Client).toHaveBeenCalledTimes(1);
expect(StdioClientTransport).toHaveBeenCalledWith({
@@ -511,8 +507,7 @@ describe('ToolRegistry', () => {
});
mockMcpClientInstance.connect.mockResolvedValue(undefined);
- toolRegistry.discoverTools();
- await new Promise((resolve) => setTimeout(resolve, 100));
+ await toolRegistry.discoverTools();
expect(Client).toHaveBeenCalledTimes(1);
expect(StdioClientTransport).toHaveBeenCalledWith({
@@ -544,8 +539,7 @@ describe('ToolRegistry', () => {
// Need to await the async IIFE within discoverTools.
// Since discoverTools itself isn't async, we can't directly await it.
// We'll check the console.error mock.
- toolRegistry.discoverTools();
- await new Promise((resolve) => setTimeout(resolve, 100)); // Wait for async operations
+ await toolRegistry.discoverTools();
expect(console.error).toHaveBeenCalledWith(
`failed to start or connect to MCP server 'failing-mcp' ${JSON.stringify({ command: 'fail-cmd' })}; \nError: Connection failed`,
diff --git a/packages/server/src/tools/tool-registry.ts b/packages/server/src/tools/tool-registry.ts
index 7b75e0f2..a2677e63 100644
--- a/packages/server/src/tools/tool-registry.ts
+++ b/packages/server/src/tools/tool-registry.ts
@@ -7,15 +7,11 @@
import { FunctionDeclaration } from '@google/genai';
import { Tool, ToolResult, BaseTool } from './tools.js';
import { Config } from '../config/config.js';
-import { parse } from 'shell-quote';
import { spawn, execSync } from 'node:child_process';
-// TODO: remove this dependency once MCP support is built into genai SDK
-import { Client } from '@modelcontextprotocol/sdk/client/index.js';
-import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
-import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
-type ToolParams = Record<string, unknown>;
+import { discoverMcpTools } from './mcp-client.js';
+import { DiscoveredMCPTool } from './mcp-tool.js';
-const MCP_TOOL_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
+type ToolParams = Record<string, unknown>;
export class DiscoveredTool extends BaseTool<ToolParams, ToolResult> {
constructor(
@@ -95,43 +91,6 @@ Signal: Signal number or \`(none)\` if no signal was received.
}
}
-export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
- constructor(
- private readonly mcpClient: Client,
- readonly name: string,
- readonly description: string,
- readonly parameterSchema: Record<string, unknown>,
- readonly serverToolName: string,
- readonly timeout?: number,
- ) {
- description += `
-
-This MCP tool was discovered from a local MCP server using JSON RPC 2.0 over stdio transport protocol.
-When called, this tool will invoke the \`tools/call\` method for tool name \`${name}\`.
-MCP servers can be configured in project or user settings.
-Returns the MCP server response as a json string.
-`;
- super(name, name, description, parameterSchema);
- }
-
- async execute(params: ToolParams): Promise<ToolResult> {
- const result = await this.mcpClient.callTool(
- {
- name: this.serverToolName,
- arguments: params,
- },
- undefined, // skip resultSchema to specify options (RequestOptions)
- {
- timeout: this.timeout ?? MCP_TOOL_DEFAULT_TIMEOUT_MSEC,
- },
- );
- return {
- llmContent: JSON.stringify(result, null, 2),
- returnDisplay: JSON.stringify(result, null, 2),
- };
- }
-}
-
export class ToolRegistry {
private tools: Map<string, Tool> = new Map();
private config: Config;
@@ -158,11 +117,13 @@ export class ToolRegistry {
* Discovers tools from project, if a discovery command is configured.
* Can be called multiple times to update discovered tools.
*/
- discoverTools(): void {
+ async discoverTools(): Promise<void> {
// remove any previously discovered tools
for (const tool of this.tools.values()) {
- if (tool instanceof DiscoveredTool) {
+ if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
this.tools.delete(tool.name);
+ } else {
+ // Keep manually registered tools
}
}
// discover tools using discovery command, if configured
@@ -186,106 +147,7 @@ export class ToolRegistry {
}
}
// discover tools using MCP servers, if configured
- // convert mcpServerCommand (if any) to StdioServerParameters
- const mcpServers = this.config.getMcpServers() || {};
-
- if (this.config.getMcpServerCommand()) {
- const cmd = this.config.getMcpServerCommand()!;
- const args = parse(cmd, process.env) as string[];
- if (args.some((arg) => typeof arg !== 'string')) {
- throw new Error('failed to parse mcpServerCommand: ' + cmd);
- }
- // use generic server name 'mcp'
- mcpServers['mcp'] = {
- command: args[0],
- args: args.slice(1),
- };
- }
- for (const [mcpServerName, mcpServerConfig] of Object.entries(mcpServers)) {
- (async () => {
- const mcpClient = new Client({
- name: 'mcp-client',
- version: '0.0.1',
- });
- let transport;
- if (mcpServerConfig.url) {
- // SSE transport if URL is provided
- transport = new SSEClientTransport(new URL(mcpServerConfig.url));
- } else if (mcpServerConfig.command) {
- // Stdio transport if command is provided
- transport = new StdioClientTransport({
- command: mcpServerConfig.command,
- args: mcpServerConfig.args || [],
- env: {
- ...process.env,
- ...(mcpServerConfig.env || {}),
- } as Record<string, string>,
- cwd: mcpServerConfig.cwd,
- stderr: 'pipe',
- });
- } else {
- console.error(
- `MCP server '${mcpServerName}' has invalid configuration: missing both url (for SSE) and command (for stdio). Skipping.`,
- );
- return;
- }
- try {
- await mcpClient.connect(transport);
- } catch (error) {
- console.error(
- `failed to start or connect to MCP server '${mcpServerName}' ` +
- `${JSON.stringify(mcpServerConfig)}; \n${error}`,
- );
- // Do not re-throw, let other MCP servers be discovered.
- return; // Exit this async IIFE if connection failed
- }
- mcpClient.onerror = (error) => {
- console.error('MCP ERROR', error.toString());
- };
- if (transport instanceof StdioClientTransport && !transport.stderr) {
- throw new Error('transport missing stderr stream');
- }
- if (transport instanceof StdioClientTransport) {
- transport.stderr!.on('data', (data) => {
- // filter out INFO messages logged for each request received
- if (!data.toString().includes('] INFO')) {
- console.debug('MCP STDERR', data.toString());
- }
- });
- }
- const result = await mcpClient.listTools();
- for (const tool of result.tools) {
- // Recursively remove additionalProperties and $schema from the inputSchema
- // eslint-disable-next-line @typescript-eslint/no-explicit-any -- This function recursively navigates a deeply nested and potentially heterogeneous JSON schema object. Using 'any' is a pragmatic choice here to avoid overly complex type definitions for all possible schema variations.
- const removeSchemaProps = (obj: any) => {
- if (typeof obj !== 'object' || obj === null) {
- return;
- }
- if (Array.isArray(obj)) {
- obj.forEach(removeSchemaProps);
- } else {
- delete obj.additionalProperties;
- delete obj.$schema;
- Object.values(obj).forEach(removeSchemaProps);
- }
- };
- removeSchemaProps(tool.inputSchema);
-
- this.registerTool(
- new DiscoveredMCPTool(
- mcpClient,
- Object.keys(mcpServers).length > 1
- ? mcpServerName + '__' + tool.name
- : tool.name,
- tool.description ?? '',
- tool.inputSchema,
- tool.name,
- mcpServerConfig.timeout,
- ),
- );
- }
- })();
- }
+ await discoverMcpTools(this.config, this);
}
/**