summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
authorJacob MacDonald <[email protected]>2025-08-08 16:29:06 -0700
committerGitHub <[email protected]>2025-08-08 23:29:06 +0000
commitf35921a77171d011d244cba1b2da0531f9749332 (patch)
tree8b14be742c8fcf94b56cbe6eedf6ebeb3febd4be /packages/core/src
parentc03ae4377729fb993426e8535cb041f8014f7b3b (diff)
Add MCP Roots support (#5856)
Co-authored-by: Jacob Richman <[email protected]>
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/tools/mcp-client.test.ts57
-rw-r--r--packages/core/src/tools/mcp-client.ts26
-rw-r--r--packages/core/src/tools/tool-registry.test.ts2
-rw-r--r--packages/core/src/tools/tool-registry.ts3
4 files changed, 87 insertions, 1 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index 1ccba76a..d37c6eae 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -13,6 +13,7 @@ import {
discoverTools,
discoverPrompts,
hasValidTypes,
+ connectToMcpServer,
} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -23,6 +24,8 @@ import { AuthProviderType } from '../config/config.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
+import { WorkspaceContext } from '../utils/workspaceContext.js';
+import { pathToFileURL } from 'node:url';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
@@ -276,6 +279,56 @@ describe('mcp-client', () => {
});
});
+ describe('connectToMcpServer', () => {
+ it('should register a roots/list handler', async () => {
+ const mockedClient = {
+ registerCapabilities: vi.fn(),
+ setRequestHandler: vi.fn(),
+ callTool: vi.fn(),
+ connect: vi.fn(),
+ };
+ vi.mocked(ClientLib.Client).mockReturnValue(
+ mockedClient as unknown as ClientLib.Client,
+ );
+ vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
+ {} as SdkClientStdioLib.StdioClientTransport,
+ );
+ const mockWorkspaceContext = {
+ getDirectories: vi
+ .fn()
+ .mockReturnValue(['/test/dir', '/another/project']),
+ } as unknown as WorkspaceContext;
+
+ await connectToMcpServer(
+ 'test-server',
+ {
+ command: 'test-command',
+ },
+ false,
+ mockWorkspaceContext,
+ );
+
+ expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
+ roots: {},
+ });
+ expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
+ const handler = mockedClient.setRequestHandler.mock.calls[0][1];
+ const roots = await handler();
+ expect(roots).toEqual({
+ roots: [
+ {
+ uri: pathToFileURL('/test/dir').toString(),
+ name: 'dir',
+ },
+ {
+ uri: pathToFileURL('/another/project').toString(),
+ name: 'project',
+ },
+ ],
+ });
+ });
+ });
+
describe('discoverPrompts', () => {
const mockedPromptRegistry = {
registerPrompt: vi.fn(),
@@ -486,7 +539,9 @@ describe('mcp-client', () => {
});
it('should connect via command', async () => {
- const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport);
+ const mockedTransport = vi
+ .spyOn(SdkClientStdioLib, 'StdioClientTransport')
+ .mockReturnValue({} as SdkClientStdioLib.StdioClientTransport);
await createTransport(
'test-server',
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts
index 9a35b84e..83bc4024 100644
--- a/packages/core/src/tools/mcp-client.ts
+++ b/packages/core/src/tools/mcp-client.ts
@@ -20,6 +20,7 @@ import {
ListPromptsResultSchema,
GetPromptResult,
GetPromptResultSchema,
+ ListRootsRequestSchema,
} from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote';
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
@@ -33,6 +34,9 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { OAuthUtils } from '../mcp/oauth-utils.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
import { getErrorMessage } from '../utils/errors.js';
+import { basename } from 'node:path';
+import { pathToFileURL } from 'node:url';
+import { WorkspaceContext } from '../utils/workspaceContext.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
@@ -306,6 +310,7 @@ export async function discoverMcpTools(
toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry,
debugMode: boolean,
+ workspaceContext: WorkspaceContext,
): Promise<void> {
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
try {
@@ -319,6 +324,7 @@ export async function discoverMcpTools(
toolRegistry,
promptRegistry,
debugMode,
+ workspaceContext,
),
);
await Promise.all(discoveryPromises);
@@ -363,6 +369,7 @@ export async function connectAndDiscover(
toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry,
debugMode: boolean,
+ workspaceContext: WorkspaceContext,
): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@@ -372,6 +379,7 @@ export async function connectAndDiscover(
mcpServerName,
mcpServerConfig,
debugMode,
+ workspaceContext,
);
mcpClient.onerror = (error) => {
@@ -655,12 +663,30 @@ export async function connectToMcpServer(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
debugMode: boolean,
+ workspaceContext: WorkspaceContext,
): Promise<Client> {
const mcpClient = new Client({
name: 'gemini-cli-mcp-client',
version: '0.0.1',
});
+ mcpClient.registerCapabilities({
+ roots: {},
+ });
+
+ mcpClient.setRequestHandler(ListRootsRequestSchema, async () => {
+ const roots = [];
+ for (const dir of workspaceContext.getDirectories()) {
+ roots.push({
+ uri: pathToFileURL(dir).toString(),
+ name: basename(dir),
+ });
+ }
+ return {
+ roots,
+ };
+ });
+
// patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it
// TODO: remove this hack once GenAI SDK does callTool with request options
if ('callTool' in mcpClient) {
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts
index e7c71e14..d8e536b7 100644
--- a/packages/core/src/tools/tool-registry.test.ts
+++ b/packages/core/src/tools/tool-registry.test.ts
@@ -336,6 +336,7 @@ describe('ToolRegistry', () => {
toolRegistry,
config.getPromptRegistry(),
false,
+ expect.any(Object),
);
});
@@ -359,6 +360,7 @@ describe('ToolRegistry', () => {
toolRegistry,
config.getPromptRegistry(),
false,
+ expect.any(Object),
);
});
});
diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts
index c77fab8c..70226052 100644
--- a/packages/core/src/tools/tool-registry.ts
+++ b/packages/core/src/tools/tool-registry.ts
@@ -178,6 +178,7 @@ export class ToolRegistry {
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
+ this.config.getWorkspaceContext(),
);
}
@@ -199,6 +200,7 @@ export class ToolRegistry {
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
+ this.config.getWorkspaceContext(),
);
}
@@ -225,6 +227,7 @@ export class ToolRegistry {
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
+ this.config.getWorkspaceContext(),
);
}
}