summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRamón Medrano Llamas <[email protected]>2025-08-19 21:03:19 +0200
committerGitHub <[email protected]>2025-08-19 19:03:19 +0000
commitb24c5887c45edde8690b4d73d8961e63eee13a34 (patch)
tree6136f1f6bcc61801edb9f6d6411966b3b6678984
parent4828e4daf198a675ce118cec08dcfbd0bfbb28a6 (diff)
feat: restart MCP servers on /mcp refresh (#5479)
Co-authored-by: Brian Ray <[email protected]> Co-authored-by: N. Taylor Mullen <[email protected]>
-rw-r--r--packages/cli/src/ui/commands/mcpCommand.test.ts5
-rw-r--r--packages/cli/src/ui/commands/mcpCommand.ts6
-rw-r--r--packages/cli/src/ui/hooks/atCommandProcessor.test.ts6
-rw-r--r--packages/core/src/tools/mcp-client-manager.test.ts54
-rw-r--r--packages/core/src/tools/mcp-client-manager.ts115
-rw-r--r--packages/core/src/tools/mcp-client.test.ts465
-rw-r--r--packages/core/src/tools/mcp-client.ts130
-rw-r--r--packages/core/src/tools/tool-registry.test.ts50
-rw-r--r--packages/core/src/tools/tool-registry.ts43
9 files changed, 427 insertions, 447 deletions
diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts
index 6e48c2f9..09b97bb0 100644
--- a/packages/cli/src/ui/commands/mcpCommand.test.ts
+++ b/packages/cli/src/ui/commands/mcpCommand.test.ts
@@ -972,6 +972,7 @@ describe('mcpCommand', () => {
it('should refresh the list of tools and display the status', async () => {
const mockToolRegistry = {
discoverMcpTools: vi.fn(),
+ restartMcpServers: vi.fn(),
getAllTools: vi.fn().mockReturnValue([]),
};
const mockGeminiClient = {
@@ -1004,11 +1005,11 @@ describe('mcpCommand', () => {
expect(context.ui.addItem).toHaveBeenCalledWith(
{
type: 'info',
- text: 'Refreshing MCP servers and tools...',
+ text: 'Restarting MCP servers...',
},
expect.any(Number),
);
- expect(mockToolRegistry.discoverMcpTools).toHaveBeenCalled();
+ expect(mockToolRegistry.restartMcpServers).toHaveBeenCalled();
expect(mockGeminiClient.setTools).toHaveBeenCalled();
expect(context.ui.reloadCommands).toHaveBeenCalledTimes(1);
diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts
index 686102be..9e321937 100644
--- a/packages/cli/src/ui/commands/mcpCommand.ts
+++ b/packages/cli/src/ui/commands/mcpCommand.ts
@@ -471,7 +471,7 @@ const listCommand: SlashCommand = {
const refreshCommand: SlashCommand = {
name: 'refresh',
- description: 'Refresh the list of MCP servers and tools',
+ description: 'Restarts MCP servers.',
kind: CommandKind.BUILT_IN,
action: async (
context: CommandContext,
@@ -497,12 +497,12 @@ const refreshCommand: SlashCommand = {
context.ui.addItem(
{
type: 'info',
- text: 'Refreshing MCP servers and tools...',
+ text: 'Restarting MCP servers...',
},
Date.now(),
);
- await toolRegistry.discoverMcpTools();
+ await toolRegistry.restartMcpServers();
// Update the client with the new tools
const geminiClient = config.getGeminiClient();
diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts
index 5509d9ff..7403f788 100644
--- a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts
+++ b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts
@@ -63,6 +63,12 @@ describe('handleAtCommand', () => {
isPathWithinWorkspace: () => true,
getDirectories: () => [testRootDir],
}),
+ getMcpServers: () => ({}),
+ getMcpServerCommand: () => undefined,
+ getPromptRegistry: () => ({
+ getPromptsByServer: () => [],
+ }),
+ getDebugMode: () => false,
} as unknown as Config;
const registry = new ToolRegistry(mockConfig);
diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts
new file mode 100644
index 00000000..3dba197f
--- /dev/null
+++ b/packages/core/src/tools/mcp-client-manager.test.ts
@@ -0,0 +1,54 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { afterEach, describe, expect, it, vi } from 'vitest';
+import { McpClientManager } from './mcp-client-manager.js';
+import { McpClient } from './mcp-client.js';
+import { ToolRegistry } from './tool-registry.js';
+import { PromptRegistry } from '../prompts/prompt-registry.js';
+import { WorkspaceContext } from '../utils/workspaceContext.js';
+
+vi.mock('./mcp-client.js', async () => {
+ const originalModule = await vi.importActual('./mcp-client.js');
+ return {
+ ...originalModule,
+ McpClient: vi.fn(),
+ populateMcpServerCommand: vi.fn(() => ({
+ 'test-server': {},
+ })),
+ };
+});
+
+describe('McpClientManager', () => {
+ afterEach(() => {
+ vi.restoreAllMocks();
+ });
+
+ it('should discover tools from all servers', async () => {
+ const mockedMcpClient = {
+ connect: vi.fn(),
+ discover: vi.fn(),
+ disconnect: vi.fn(),
+ getStatus: vi.fn(),
+ };
+ vi.mocked(McpClient).mockReturnValue(
+ mockedMcpClient as unknown as McpClient,
+ );
+ const manager = new McpClientManager(
+ {
+ 'test-server': {},
+ },
+ '',
+ {} as ToolRegistry,
+ {} as PromptRegistry,
+ false,
+ {} as WorkspaceContext,
+ );
+ await manager.discoverAllMcpTools();
+ expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
+ expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
+ });
+});
diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts
new file mode 100644
index 00000000..c22afb8f
--- /dev/null
+++ b/packages/core/src/tools/mcp-client-manager.ts
@@ -0,0 +1,115 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { MCPServerConfig } from '../config/config.js';
+import { ToolRegistry } from './tool-registry.js';
+import { PromptRegistry } from '../prompts/prompt-registry.js';
+import {
+ McpClient,
+ MCPDiscoveryState,
+ populateMcpServerCommand,
+} from './mcp-client.js';
+import { getErrorMessage } from '../utils/errors.js';
+import { WorkspaceContext } from '../utils/workspaceContext.js';
+
+/**
+ * Manages the lifecycle of multiple MCP clients, including local child processes.
+ * This class is responsible for starting, stopping, and discovering tools from
+ * a collection of MCP servers defined in the configuration.
+ */
+export class McpClientManager {
+ private clients: Map<string, McpClient> = new Map();
+ private readonly mcpServers: Record<string, MCPServerConfig>;
+ private readonly mcpServerCommand: string | undefined;
+ private readonly toolRegistry: ToolRegistry;
+ private readonly promptRegistry: PromptRegistry;
+ private readonly debugMode: boolean;
+ private readonly workspaceContext: WorkspaceContext;
+ private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
+
+ constructor(
+ mcpServers: Record<string, MCPServerConfig>,
+ mcpServerCommand: string | undefined,
+ toolRegistry: ToolRegistry,
+ promptRegistry: PromptRegistry,
+ debugMode: boolean,
+ workspaceContext: WorkspaceContext,
+ ) {
+ this.mcpServers = mcpServers;
+ this.mcpServerCommand = mcpServerCommand;
+ this.toolRegistry = toolRegistry;
+ this.promptRegistry = promptRegistry;
+ this.debugMode = debugMode;
+ this.workspaceContext = workspaceContext;
+ }
+
+ /**
+ * Initiates the tool discovery process for all configured MCP servers.
+ * It connects to each server, discovers its available tools, and registers
+ * them with the `ToolRegistry`.
+ */
+ async discoverAllMcpTools(): Promise<void> {
+ await this.stop();
+ this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
+ const servers = populateMcpServerCommand(
+ this.mcpServers,
+ this.mcpServerCommand,
+ );
+
+ const discoveryPromises = Object.entries(servers).map(
+ async ([name, config]) => {
+ const client = new McpClient(
+ name,
+ config,
+ this.toolRegistry,
+ this.promptRegistry,
+ this.workspaceContext,
+ this.debugMode,
+ );
+ this.clients.set(name, client);
+ try {
+ await client.connect();
+ await client.discover();
+ } catch (error) {
+ // Log the error but don't let a single failed server stop the others
+ console.error(
+ `Error during discovery for server '${name}': ${getErrorMessage(
+ error,
+ )}`,
+ );
+ }
+ },
+ );
+
+ await Promise.all(discoveryPromises);
+ this.discoveryState = MCPDiscoveryState.COMPLETED;
+ }
+
+ /**
+ * Stops all running local MCP servers and closes all client connections.
+ * This is the cleanup method to be called on application exit.
+ */
+ async stop(): Promise<void> {
+ const disconnectionPromises = Array.from(this.clients.entries()).map(
+ async ([name, client]) => {
+ try {
+ await client.disconnect();
+ } catch (error) {
+ console.error(
+ `Error stopping client '${name}': ${getErrorMessage(error)}`,
+ );
+ }
+ },
+ );
+
+ await Promise.all(disconnectionPromises);
+ this.clients.clear();
+ }
+
+ getDiscoveryState(): MCPDiscoveryState {
+ return this.discoveryState;
+ }
+}
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index 3467ad95..b8f61856 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -4,16 +4,14 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest';
+import { afterEach, describe, expect, it, vi } from 'vitest';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import {
populateMcpServerCommand,
createTransport,
isEnabled,
- discoverTools,
- discoverPrompts,
hasValidTypes,
- connectToMcpServer,
+ McpClient,
} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -22,26 +20,36 @@ import * as GenAiLib from '@google/genai';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { AuthProviderType } from '../config/config.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
-
-import { DiscoveredMCPTool } from './mcp-tool.js';
+import { ToolRegistry } from './tool-registry.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
-import { pathToFileURL } from 'node:url';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
vi.mock('@google/genai');
vi.mock('../mcp/oauth-provider.js');
vi.mock('../mcp/oauth-token-storage.js');
-vi.mock('./mcp-tool.js');
describe('mcp-client', () => {
afterEach(() => {
vi.restoreAllMocks();
});
- describe('discoverTools', () => {
+ describe('McpClient', () => {
it('should discover tools', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
+ const mockedClient = {
+ connect: vi.fn(),
+ discover: vi.fn(),
+ disconnect: vi.fn(),
+ getStatus: vi.fn(),
+ registerCapabilities: vi.fn(),
+ setRequestHandler: vi.fn(),
+ };
+ vi.mocked(ClientLib.Client).mockReturnValue(
+ mockedClient as unknown as ClientLib.Client,
+ );
+ vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
+ {} as SdkClientStdioLib.StdioClientTransport,
+ );
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () => ({
functionDeclarations: [
@@ -51,62 +59,43 @@ describe('mcp-client', () => {
],
}),
} as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(1);
- expect(mockedMcpToTool).toHaveBeenCalledOnce();
- });
-
- it('should log an error if there is an error discovering a tool', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
- const consoleErrorSpy = vi
- .spyOn(console, 'error')
- .mockImplementation(() => {});
-
- const testError = new Error('Invalid tool name');
- vi.mocked(DiscoveredMCPTool).mockImplementation(
- (
- _mcpCallableTool: GenAiLib.CallableTool,
- _serverName: string,
- name: string,
- ) => {
- if (name === 'invalid tool name') {
- throw testError;
- }
- return { name: 'validTool' } as DiscoveredMCPTool;
+ const mockedToolRegistry = {
+ registerTool: vi.fn(),
+ } as unknown as ToolRegistry;
+ const client = new McpClient(
+ 'test-server',
+ {
+ command: 'test-command',
},
+ mockedToolRegistry,
+ {} as PromptRegistry,
+ {} as WorkspaceContext,
+ false,
);
-
- vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'validTool',
- },
- {
- name: 'invalid tool name', // this will fail validation
- },
- ],
- }),
- } as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(1);
- expect(tools[0].name).toBe('validTool');
- expect(consoleErrorSpy).toHaveBeenCalledOnce();
- expect(consoleErrorSpy).toHaveBeenCalledWith(
- `Error discovering tool: 'invalid tool name' from MCP server 'test-server': ${testError.message}`,
- );
+ await client.connect();
+ await client.discover();
+ expect(mockedMcpToTool).toHaveBeenCalledOnce();
});
it('should skip tools if a parameter is missing a type', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
+ const mockedClient = {
+ connect: vi.fn(),
+ discover: vi.fn(),
+ disconnect: vi.fn(),
+ getStatus: vi.fn(),
+ registerCapabilities: vi.fn(),
+ setRequestHandler: vi.fn(),
+ tool: vi.fn(),
+ };
+ vi.mocked(ClientLib.Client).mockReturnValue(
+ mockedClient as unknown as ClientLib.Client,
+ );
+ vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
+ {} as SdkClientStdioLib.StdioClientTransport,
+ );
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
@@ -132,89 +121,22 @@ describe('mcp-client', () => {
],
}),
} as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(1);
- expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
- expect(consoleWarnSpy).toHaveBeenCalledOnce();
- expect(consoleWarnSpy).toHaveBeenCalledWith(
- `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
- `missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
- );
- consoleWarnSpy.mockRestore();
- });
-
- it('should skip tools if a nested parameter is missing a type', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
- const consoleWarnSpy = vi
- .spyOn(console, 'warn')
- .mockImplementation(() => {});
- vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'invalidTool',
- parametersJsonSchema: {
- type: 'object',
- properties: {
- param1: {
- type: 'object',
- properties: {
- nestedParam: {
- description: 'a nested param with no type',
- },
- },
- },
- },
- },
- },
- ],
- }),
- } as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(0);
- expect(consoleWarnSpy).toHaveBeenCalledOnce();
- expect(consoleWarnSpy).toHaveBeenCalledWith(
- `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
- `missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
+ const mockedToolRegistry = {
+ registerTool: vi.fn(),
+ } as unknown as ToolRegistry;
+ const client = new McpClient(
+ 'test-server',
+ {
+ command: 'test-command',
+ },
+ mockedToolRegistry,
+ {} as PromptRegistry,
+ {} as WorkspaceContext,
+ false,
);
- consoleWarnSpy.mockRestore();
- });
-
- it('should skip tool if an array item is missing a type', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
- const consoleWarnSpy = vi
- .spyOn(console, 'warn')
- .mockImplementation(() => {});
- vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'invalidTool',
- parametersJsonSchema: {
- type: 'object',
- properties: {
- param1: {
- type: 'array',
- items: {
- description: 'an array item with no type',
- },
- },
- },
- },
- },
- ],
- }),
- } as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(0);
+ await client.connect();
+ await client.discover();
+ expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
@@ -223,109 +145,19 @@ describe('mcp-client', () => {
consoleWarnSpy.mockRestore();
});
- it('should discover tool with no properties in schema', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
- const consoleWarnSpy = vi
- .spyOn(console, 'warn')
- .mockImplementation(() => {});
- vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'validTool',
- parametersJsonSchema: {
- type: 'object',
- },
- },
- ],
- }),
- } as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(1);
- expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
- expect(consoleWarnSpy).not.toHaveBeenCalled();
- consoleWarnSpy.mockRestore();
- });
-
- it('should discover tool with empty properties object in schema', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
- const consoleWarnSpy = vi
- .spyOn(console, 'warn')
+ it('should handle errors when discovering prompts', async () => {
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
.mockImplementation(() => {});
- vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'validTool',
- parametersJsonSchema: {
- type: 'object',
- properties: {},
- },
- },
- ],
- }),
- } as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(1);
- expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
- expect(consoleWarnSpy).not.toHaveBeenCalled();
- consoleWarnSpy.mockRestore();
- });
- });
-
- describe('connectToMcpServer', () => {
- it('should send a notification when directories change', async () => {
const mockedClient = {
- registerCapabilities: vi.fn(),
- setRequestHandler: vi.fn(),
- notification: vi.fn(),
- callTool: vi.fn(),
connect: vi.fn(),
- };
- vi.mocked(ClientLib.Client).mockReturnValue(
- mockedClient as unknown as ClientLib.Client,
- );
- vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
- {} as SdkClientStdioLib.StdioClientTransport,
- );
- let onDirectoriesChangedCallback: () => void = () => {};
- const mockWorkspaceContext = {
- getDirectories: vi
- .fn()
- .mockReturnValue(['/test/dir', '/another/project']),
- onDirectoriesChanged: vi.fn().mockImplementation((callback) => {
- onDirectoriesChangedCallback = callback;
- }),
- } as unknown as WorkspaceContext;
-
- await connectToMcpServer(
- 'test-server',
- {
- command: 'test-command',
- },
- false,
- mockWorkspaceContext,
- );
-
- onDirectoriesChangedCallback();
-
- expect(mockedClient.notification).toHaveBeenCalledWith({
- method: 'notifications/roots/list_changed',
- });
- });
-
- it('should register a roots/list handler', async () => {
- const mockedClient = {
+ discover: vi.fn(),
+ disconnect: vi.fn(),
+ getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
- callTool: vi.fn(),
- connect: vi.fn(),
+ getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
+ request: vi.fn().mockRejectedValue(new Error('Test error')),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
@@ -333,151 +165,29 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
- const mockWorkspaceContext = {
- getDirectories: vi
- .fn()
- .mockReturnValue(['/test/dir', '/another/project']),
- onDirectoriesChanged: vi.fn(),
- } as unknown as WorkspaceContext;
-
- await connectToMcpServer(
+ vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
+ tool: () => Promise.resolve({ functionDeclarations: [] }),
+ } as unknown as GenAiLib.CallableTool);
+ const client = new McpClient(
'test-server',
{
command: 'test-command',
},
+ {} as ToolRegistry,
+ {} as PromptRegistry,
+ {} as WorkspaceContext,
false,
- mockWorkspaceContext,
);
-
- expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
- roots: {
- listChanged: true,
- },
- });
- expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
- const handler = mockedClient.setRequestHandler.mock.calls[0][1];
- const roots = await handler();
- expect(roots).toEqual({
- roots: [
- {
- uri: pathToFileURL('/test/dir').toString(),
- name: 'dir',
- },
- {
- uri: pathToFileURL('/another/project').toString(),
- name: 'project',
- },
- ],
- });
- });
- });
-
- describe('discoverPrompts', () => {
- const mockedPromptRegistry = {
- registerPrompt: vi.fn(),
- } as unknown as PromptRegistry;
-
- it('should discover and log prompts', async () => {
- const mockRequest = vi.fn().mockResolvedValue({
- prompts: [
- { name: 'prompt1', description: 'desc1' },
- { name: 'prompt2' },
- ],
- });
- const mockGetServerCapabilities = vi.fn().mockReturnValue({
- prompts: {},
- });
- const mockedClient = {
- getServerCapabilities: mockGetServerCapabilities,
- request: mockRequest,
- } as unknown as ClientLib.Client;
-
- await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
-
- expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
- expect(mockRequest).toHaveBeenCalledWith(
- { method: 'prompts/list', params: {} },
- expect.anything(),
+ await client.connect();
+ await expect(client.discover()).rejects.toThrow(
+ 'No prompts or tools found on the server.',
);
- });
-
- it('should do nothing if no prompts are discovered', async () => {
- const mockRequest = vi.fn().mockResolvedValue({
- prompts: [],
- });
- const mockGetServerCapabilities = vi.fn().mockReturnValue({
- prompts: {},
- });
-
- const mockedClient = {
- getServerCapabilities: mockGetServerCapabilities,
- request: mockRequest,
- } as unknown as ClientLib.Client;
-
- const consoleLogSpy = vi
- .spyOn(console, 'debug')
- .mockImplementation(() => {});
-
- await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
-
- expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
- expect(mockRequest).toHaveBeenCalledOnce();
- expect(consoleLogSpy).not.toHaveBeenCalled();
-
- consoleLogSpy.mockRestore();
- });
-
- it('should do nothing if the server has no prompt support', async () => {
- const mockRequest = vi.fn().mockResolvedValue({
- prompts: [],
- });
- const mockGetServerCapabilities = vi.fn().mockReturnValue({});
-
- const mockedClient = {
- getServerCapabilities: mockGetServerCapabilities,
- request: mockRequest,
- } as unknown as ClientLib.Client;
-
- const consoleLogSpy = vi
- .spyOn(console, 'debug')
- .mockImplementation(() => {});
-
- await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
-
- expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
- expect(mockRequest).not.toHaveBeenCalled();
- expect(consoleLogSpy).not.toHaveBeenCalled();
-
- consoleLogSpy.mockRestore();
- });
-
- it('should log an error if discovery fails', async () => {
- const testError = new Error('test error');
- testError.message = 'test error';
- const mockRequest = vi.fn().mockRejectedValue(testError);
- const mockGetServerCapabilities = vi.fn().mockReturnValue({
- prompts: {},
- });
- const mockedClient = {
- getServerCapabilities: mockGetServerCapabilities,
- request: mockRequest,
- } as unknown as ClientLib.Client;
-
- const consoleErrorSpy = vi
- .spyOn(console, 'error')
- .mockImplementation(() => {});
-
- await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
-
- expect(mockRequest).toHaveBeenCalledOnce();
expect(consoleErrorSpy).toHaveBeenCalledWith(
- `Error discovering prompts from test-server: ${testError.message}`,
+ `Error discovering prompts from test-server: Test error`,
);
-
consoleErrorSpy.mockRestore();
});
});
-
describe('appendMcpServerCommand', () => {
it('should do nothing if no MCP servers or command are configured', () => {
const out = populateMcpServerCommand({}, undefined);
@@ -501,17 +211,6 @@ describe('mcp-client', () => {
});
describe('createTransport', () => {
- const originalEnv = process.env;
-
- beforeEach(() => {
- vi.resetModules();
- process.env = {};
- });
-
- afterEach(() => {
- process.env = originalEnv;
- });
-
describe('should connect via httpUrl', () => {
it('without headers', async () => {
const transport = await createTransport(
@@ -601,7 +300,7 @@ describe('mcp-client', () => {
command: 'test-command',
args: ['--foo', 'bar'],
cwd: 'test/cwd',
- env: { FOO: 'bar' },
+ env: { ...process.env, FOO: 'bar' },
stderr: 'pipe',
});
});
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts
index e9001466..ede0d036 100644
--- a/packages/core/src/tools/mcp-client.ts
+++ b/packages/core/src/tools/mcp-client.ts
@@ -70,6 +70,134 @@ export enum MCPDiscoveryState {
}
/**
+ * A client for a single MCP server.
+ *
+ * This class is responsible for connecting to, discovering tools from, and
+ * managing the state of a single MCP server.
+ */
+export class McpClient {
+ private client: Client;
+ private transport: Transport | undefined;
+ private status: MCPServerStatus = MCPServerStatus.DISCONNECTED;
+ private isDisconnecting = false;
+
+ constructor(
+ private readonly serverName: string,
+ private readonly serverConfig: MCPServerConfig,
+ private readonly toolRegistry: ToolRegistry,
+ private readonly promptRegistry: PromptRegistry,
+ private readonly workspaceContext: WorkspaceContext,
+ private readonly debugMode: boolean,
+ ) {
+ this.client = new Client({
+ name: `gemini-cli-mcp-client-${this.serverName}`,
+ version: '0.0.1',
+ });
+ }
+
+ /**
+ * Connects to the MCP server.
+ */
+ async connect(): Promise<void> {
+ this.isDisconnecting = false;
+ this.updateStatus(MCPServerStatus.CONNECTING);
+ try {
+ this.transport = await this.createTransport();
+
+ this.client.onerror = (error) => {
+ if (this.isDisconnecting) {
+ return;
+ }
+ console.error(`MCP ERROR (${this.serverName}):`, error.toString());
+ this.updateStatus(MCPServerStatus.DISCONNECTED);
+ };
+
+ this.client.registerCapabilities({
+ roots: {},
+ });
+
+ this.client.setRequestHandler(ListRootsRequestSchema, async () => {
+ const roots = [];
+ for (const dir of this.workspaceContext.getDirectories()) {
+ roots.push({
+ uri: pathToFileURL(dir).toString(),
+ name: basename(dir),
+ });
+ }
+ return {
+ roots,
+ };
+ });
+
+ await this.client.connect(this.transport, {
+ timeout: this.serverConfig.timeout,
+ });
+
+ this.updateStatus(MCPServerStatus.CONNECTED);
+ } catch (error) {
+ this.updateStatus(MCPServerStatus.DISCONNECTED);
+ throw error;
+ }
+ }
+
+ /**
+ * Discovers tools and prompts from the MCP server.
+ */
+ async discover(): Promise<void> {
+ if (this.status !== MCPServerStatus.CONNECTED) {
+ throw new Error('Client is not connected.');
+ }
+
+ const prompts = await this.discoverPrompts();
+ const tools = await this.discoverTools();
+
+ if (prompts.length === 0 && tools.length === 0) {
+ throw new Error('No prompts or tools found on the server.');
+ }
+
+ for (const tool of tools) {
+ this.toolRegistry.registerTool(tool);
+ }
+ }
+
+ /**
+ * Disconnects from the MCP server.
+ */
+ async disconnect(): Promise<void> {
+ this.isDisconnecting = true;
+ if (this.transport) {
+ await this.transport.close();
+ }
+ this.client.close();
+ this.updateStatus(MCPServerStatus.DISCONNECTED);
+ }
+
+ /**
+ * Returns the current status of the client.
+ */
+ getStatus(): MCPServerStatus {
+ return this.status;
+ }
+
+ private updateStatus(status: MCPServerStatus): void {
+ this.status = status;
+ updateMCPServerStatus(this.serverName, status);
+ }
+
+ private async createTransport(): Promise<Transport> {
+ return createTransport(this.serverName, this.serverConfig, this.debugMode);
+ }
+
+ private async discoverTools(): Promise<DiscoveredMCPTool[]> {
+ return discoverTools(this.serverName, this.serverConfig, this.client);
+ }
+
+ private async discoverPrompts(): Promise<Prompt[]> {
+ return discoverPrompts(this.serverName, this.client, this.promptRegistry);
+ }
+}
+
+/**
* Map to track the status of each MCP server within the core package
*/
const serverStatuses: Map<string, MCPServerStatus> = new Map();
@@ -117,7 +245,7 @@ export function removeMCPStatusChangeListener(
/**
* Update the status of an MCP server
*/
-function updateMCPServerStatus(
+export function updateMCPServerStatus(
serverName: string,
status: MCPServerStatus,
): void {
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts
index 13dff08c..cccf011f 100644
--- a/packages/core/src/tools/tool-registry.test.ts
+++ b/packages/core/src/tools/tool-registry.test.ts
@@ -23,15 +23,17 @@ import { spawn } from 'node:child_process';
import fs from 'node:fs';
import { MockTool } from '../test-utils/tools.js';
-vi.mock('node:fs');
+import { McpClientManager } from './mcp-client-manager.js';
-// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
-const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
+vi.mock('node:fs');
// Mock ./mcp-client.js to control its behavior within tool-registry tests
-vi.mock('./mcp-client.js', () => ({
- discoverMcpTools: mockDiscoverMcpTools,
-}));
+vi.mock('./mcp-client.js', async () => {
+ const originalModule = await vi.importActual('./mcp-client.js');
+ return {
+ ...originalModule,
+ };
+});
// Mock node:child_process
vi.mock('node:child_process', async () => {
@@ -143,7 +145,6 @@ describe('ToolRegistry', () => {
clear: vi.fn(),
removePromptsByServer: vi.fn(),
} as any);
- mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
});
afterEach(() => {
@@ -311,30 +312,10 @@ describe('ToolRegistry', () => {
});
it('should discover tools using MCP servers defined in getMcpServers', async () => {
- mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
- vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
- const mcpServerConfigVal = {
- 'my-mcp-server': {
- command: 'mcp-server-cmd',
- args: ['--port', '1234'],
- trust: true,
- },
- };
- vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
-
- await toolRegistry.discoverAllTools();
-
- expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
- mcpServerConfigVal,
- undefined,
- toolRegistry,
- config.getPromptRegistry(),
- false,
- expect.any(Object),
+ const discoverSpy = vi.spyOn(
+ McpClientManager.prototype,
+ 'discoverAllMcpTools',
);
- });
-
- it('should discover tools using MCP servers defined in getMcpServers', async () => {
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
const mcpServerConfigVal = {
@@ -348,14 +329,7 @@ describe('ToolRegistry', () => {
await toolRegistry.discoverAllTools();
- expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
- mcpServerConfigVal,
- undefined,
- toolRegistry,
- config.getPromptRegistry(),
- false,
- expect.any(Object),
- );
+ expect(discoverSpy).toHaveBeenCalled();
});
});
});
diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts
index ff155679..90531742 100644
--- a/packages/core/src/tools/tool-registry.ts
+++ b/packages/core/src/tools/tool-registry.ts
@@ -16,7 +16,8 @@ import {
import { Config } from '../config/config.js';
import { spawn } from 'node:child_process';
import { StringDecoder } from 'node:string_decoder';
-import { discoverMcpTools } from './mcp-client.js';
+import { connectAndDiscover } from './mcp-client.js';
+import { McpClientManager } from './mcp-client-manager.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { parse } from 'shell-quote';
@@ -163,9 +164,18 @@ Signal: Signal number or \`(none)\` if no signal was received.
export class ToolRegistry {
private tools: Map<string, AnyDeclarativeTool> = new Map();
private config: Config;
+ private mcpClientManager: McpClientManager;
constructor(config: Config) {
this.config = config;
+ this.mcpClientManager = new McpClientManager(
+ this.config.getMcpServers() ?? {},
+ this.config.getMcpServerCommand(),
+ this,
+ this.config.getPromptRegistry(),
+ this.config.getDebugMode(),
+ this.config.getWorkspaceContext(),
+ );
}
/**
@@ -220,14 +230,7 @@ export class ToolRegistry {
await this.discoverAndRegisterToolsFromCommand();
// discover tools using MCP servers, if configured
- await discoverMcpTools(
- this.config.getMcpServers() ?? {},
- this.config.getMcpServerCommand(),
- this,
- this.config.getPromptRegistry(),
- this.config.getDebugMode(),
- this.config.getWorkspaceContext(),
- );
+ await this.mcpClientManager.discoverAllMcpTools();
}
/**
@@ -242,14 +245,14 @@ export class ToolRegistry {
this.config.getPromptRegistry().clear();
// discover tools using MCP servers, if configured
- await discoverMcpTools(
- this.config.getMcpServers() ?? {},
- this.config.getMcpServerCommand(),
- this,
- this.config.getPromptRegistry(),
- this.config.getDebugMode(),
- this.config.getWorkspaceContext(),
- );
+ await this.mcpClientManager.discoverAllMcpTools();
+ }
+
+ /**
+ * Restarts all MCP servers and re-discovers tools.
+ */
+ async restartMcpServers(): Promise<void> {
+ await this.discoverMcpTools();
}
/**
@@ -269,9 +272,9 @@ export class ToolRegistry {
const mcpServers = this.config.getMcpServers() ?? {};
const serverConfig = mcpServers[serverName];
if (serverConfig) {
- await discoverMcpTools(
- { [serverName]: serverConfig },
- undefined,
+ await connectAndDiscover(
+ serverName,
+ serverConfig,
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),