summaryrefslogtreecommitdiff
path: root/packages/core/src/tools/mcp-client.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src/tools/mcp-client.ts')
-rw-r--r--packages/core/src/tools/mcp-client.ts423
1 files changed, 235 insertions, 188 deletions
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts
index 6edfbac8..eb82190b 100644
--- a/packages/core/src/tools/mcp-client.ts
+++ b/packages/core/src/tools/mcp-client.ts
@@ -5,6 +5,7 @@
*/
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
+import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import {
SSEClientTransport,
@@ -17,7 +18,7 @@ import {
import { parse } from 'shell-quote';
import { MCPServerConfig } from '../config/config.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
-import { Type, mcpToTool } from '@google/genai';
+import { FunctionDeclaration, Type, mcpToTool } from '@google/genai';
import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
@@ -123,28 +124,25 @@ export function getMCPDiscoveryState(): MCPDiscoveryState {
return mcpDiscoveryState;
}
+/**
+ * Discovers tools from all configured MCP servers and registers them with the tool registry.
+ * It orchestrates the connection and discovery process for each server defined in the
+ * configuration, as well as any server specified via a command-line argument.
+ *
+ * @param mcpServers A record of named MCP server configurations.
+ * @param mcpServerCommand An optional command string for a dynamically specified MCP server.
+ * @param toolRegistry The central registry where discovered tools will be registered.
+ * @returns A promise that resolves when the discovery process has been attempted for all servers.
+ */
export async function discoverMcpTools(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
toolRegistry: ToolRegistry,
debugMode: boolean,
): Promise<void> {
- // Set discovery state to in progress
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
-
try {
- if (mcpServerCommand) {
- const cmd = mcpServerCommand;
- const args = parse(cmd, process.env) as string[];
- if (args.some((arg) => typeof arg !== 'string')) {
- throw new Error('failed to parse mcpServerCommand: ' + cmd);
- }
- // use generic server name 'mcp'
- mcpServers['mcp'] = {
- command: args[0],
- args: args.slice(1),
- };
- }
+ mcpServers = populateMcpServerCommand(mcpServers, mcpServerCommand);
const discoveryPromises = Object.entries(mcpServers).map(
([mcpServerName, mcpServerConfig]) =>
@@ -156,16 +154,31 @@ export async function discoverMcpTools(
),
);
await Promise.all(discoveryPromises);
-
- // Mark discovery as completed
- mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
- } catch (error) {
- // Still mark as completed even with errors
+ } finally {
mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
- throw error;
}
}
+/** Visible for Testing */
+export function populateMcpServerCommand(
+ mcpServers: Record<string, MCPServerConfig>,
+ mcpServerCommand: string | undefined,
+): Record<string, MCPServerConfig> {
+ if (mcpServerCommand) {
+ const cmd = mcpServerCommand;
+ const args = parse(cmd, process.env) as string[];
+ if (args.some((arg) => typeof arg !== 'string')) {
+ throw new Error('failed to parse mcpServerCommand: ' + cmd);
+ }
+ // use generic server name 'mcp'
+ mcpServers['mcp'] = {
+ command: args[0],
+ args: args.slice(1),
+ };
+ }
+ return mcpServers;
+}
+
/**
* Connects to an MCP server and discovers available tools, registering them with the tool registry.
* This function handles the complete lifecycle of connecting to a server, discovering tools,
@@ -176,71 +189,117 @@ export async function discoverMcpTools(
* @param toolRegistry The registry to register discovered tools with
* @returns Promise that resolves when discovery is complete
*/
-async function connectAndDiscover(
+export async function connectAndDiscover(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
toolRegistry: ToolRegistry,
debugMode: boolean,
): Promise<void> {
- // Initialize the server status as connecting
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
- let transport;
- if (mcpServerConfig.httpUrl) {
- const transportOptions: StreamableHTTPClientTransportOptions = {};
+ try {
+ const mcpClient = await connectToMcpServer(
+ mcpServerName,
+ mcpServerConfig,
+ debugMode,
+ );
+ try {
+ updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
- if (mcpServerConfig.headers) {
- transportOptions.requestInit = {
- headers: mcpServerConfig.headers,
+ mcpClient.onerror = (error) => {
+ console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
+ updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
};
- }
- transport = new StreamableHTTPClientTransport(
- new URL(mcpServerConfig.httpUrl),
- transportOptions,
- );
- } else if (mcpServerConfig.url) {
- const transportOptions: SSEClientTransportOptions = {};
- if (mcpServerConfig.headers) {
- transportOptions.requestInit = {
- headers: mcpServerConfig.headers,
- };
+ const tools = await discoverTools(
+ mcpServerName,
+ mcpServerConfig,
+ mcpClient,
+ );
+ for (const tool of tools) {
+ toolRegistry.registerTool(tool);
+ }
+ } catch (error) {
+ mcpClient.close();
+ throw error;
}
- transport = new SSEClientTransport(
- new URL(mcpServerConfig.url),
- transportOptions,
- );
- } else if (mcpServerConfig.command) {
- transport = new StdioClientTransport({
- command: mcpServerConfig.command,
- args: mcpServerConfig.args || [],
- env: {
- ...process.env,
- ...(mcpServerConfig.env || {}),
- } as Record<string, string>,
- cwd: mcpServerConfig.cwd,
- stderr: 'pipe',
- });
- } else {
- console.error(
- `MCP server '${mcpServerName}' has invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio). Skipping.`,
- );
- // Update status to disconnected
+ } catch (error) {
+ console.error(`Error connecting to MCP server '${mcpServerName}':`, error);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
- return;
}
+}
- if (
- debugMode &&
- transport instanceof StdioClientTransport &&
- transport.stderr
- ) {
- transport.stderr.on('data', (data) => {
- const stderrStr = data.toString().trim();
- console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr);
- });
+/**
+ * Discovers and sanitizes tools from a connected MCP client.
+ * It retrieves function declarations from the client, filters out disabled tools,
+ * generates valid names for them, and wraps them in `DiscoveredMCPTool` instances.
+ *
+ * @param mcpServerName The name of the MCP server.
+ * @param mcpServerConfig The configuration for the MCP server.
+ * @param mcpClient The active MCP client instance.
+ * @returns A promise that resolves to an array of discovered and enabled tools.
+ * @throws An error if no enabled tools are found or if the server provides invalid function declarations.
+ */
+export async function discoverTools(
+ mcpServerName: string,
+ mcpServerConfig: MCPServerConfig,
+ mcpClient: Client,
+): Promise<DiscoveredMCPTool[]> {
+ try {
+ const mcpCallableTool = mcpToTool(mcpClient);
+ const tool = await mcpCallableTool.tool();
+
+ if (!Array.isArray(tool.functionDeclarations)) {
+ throw new Error(`Server did not return valid function declarations.`);
+ }
+
+ const discoveredTools: DiscoveredMCPTool[] = [];
+ for (const funcDecl of tool.functionDeclarations) {
+ if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) {
+ continue;
+ }
+
+ const toolNameForModel = generateValidName(funcDecl, mcpServerName);
+
+ sanitizeParameters(funcDecl.parameters);
+
+ discoveredTools.push(
+ new DiscoveredMCPTool(
+ mcpCallableTool,
+ mcpServerName,
+ toolNameForModel,
+ funcDecl.description ?? '',
+ funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
+ funcDecl.name!,
+ mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
+ mcpServerConfig.trust,
+ ),
+ );
+ }
+ if (discoveredTools.length === 0) {
+ throw Error('No enabled tools found');
+ }
+ return discoveredTools;
+ } catch (error) {
+ throw new Error(`Error discovering tools: ${error}`);
}
+}
+/**
+ * Creates and connects an MCP client to a server based on the provided configuration.
+ * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
+ * establishes a connection. It also applies a patch to handle request timeouts.
+ *
+ * @param mcpServerName The name of the MCP server, used for logging and identification.
+ * @param mcpServerConfig The configuration specifying how to connect to the server.
+ * @returns A promise that resolves to a connected MCP `Client` instance.
+ * @throws An error if the connection fails or the configuration is invalid.
+ */
+export async function connectToMcpServer(
+ mcpServerName: string,
+ mcpServerConfig: MCPServerConfig,
+ debugMode: boolean,
+): Promise<Client> {
const mcpClient = new Client({
name: 'gemini-cli-mcp-client',
version: '0.0.1',
@@ -259,11 +318,20 @@ async function connectAndDiscover(
}
try {
- await mcpClient.connect(transport, {
- timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
- });
- // Connection successful
- updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
+ const transport = createTransport(
+ mcpServerName,
+ mcpServerConfig,
+ debugMode,
+ );
+ try {
+ await mcpClient.connect(transport, {
+ timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
+ });
+ return mcpClient;
+ } catch (error) {
+ await transport.close();
+ throw error;
+ }
} catch (error) {
// Create a safe config object that excludes sensitive information
const safeConfig = {
@@ -282,131 +350,110 @@ async function connectAndDiscover(
if (process.env.SANDBOX) {
errorString += `\nMake sure it is available in the sandbox`;
}
- console.error(errorString);
- // Update status to disconnected
- updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
- return;
+ throw new Error(errorString);
}
+}
- mcpClient.onerror = (error) => {
- console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
- // Update status to disconnected on error
- updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
- };
-
- try {
- const mcpCallableTool = mcpToTool(mcpClient);
- const tool = await mcpCallableTool.tool();
-
- if (!tool || !Array.isArray(tool.functionDeclarations)) {
- console.error(
- `MCP server '${mcpServerName}' did not return valid tool function declarations. Skipping.`,
- );
- if (
- transport instanceof StdioClientTransport ||
- transport instanceof SSEClientTransport ||
- transport instanceof StreamableHTTPClientTransport
- ) {
- await transport.close();
- }
- // Update status to disconnected
- updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
- return;
+/** Visible for Testing */
+export function createTransport(
+ mcpServerName: string,
+ mcpServerConfig: MCPServerConfig,
+ debugMode: boolean,
+): Transport {
+ if (mcpServerConfig.httpUrl) {
+ const transportOptions: StreamableHTTPClientTransportOptions = {};
+ if (mcpServerConfig.headers) {
+ transportOptions.requestInit = {
+ headers: mcpServerConfig.headers,
+ };
}
+ return new StreamableHTTPClientTransport(
+ new URL(mcpServerConfig.httpUrl),
+ transportOptions,
+ );
+ }
- for (const funcDecl of tool.functionDeclarations) {
- if (!funcDecl.name) {
- console.warn(
- `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`,
- );
- continue;
- }
-
- const { includeTools, excludeTools } = mcpServerConfig;
- const toolName = funcDecl.name;
-
- let isEnabled = false;
- if (includeTools === undefined) {
- isEnabled = true;
- } else {
- isEnabled = includeTools.some(
- (tool) => tool === toolName || tool.startsWith(`${toolName}(`),
- );
- }
-
- if (excludeTools?.includes(toolName)) {
- isEnabled = false;
- }
-
- if (!isEnabled) {
- continue;
- }
+ if (mcpServerConfig.url) {
+ const transportOptions: SSEClientTransportOptions = {};
+ if (mcpServerConfig.headers) {
+ transportOptions.requestInit = {
+ headers: mcpServerConfig.headers,
+ };
+ }
+ return new SSEClientTransport(
+ new URL(mcpServerConfig.url),
+ transportOptions,
+ );
+ }
- let toolNameForModel = funcDecl.name;
+ if (mcpServerConfig.command) {
+ const transport = new StdioClientTransport({
+ command: mcpServerConfig.command,
+ args: mcpServerConfig.args || [],
+ env: {
+ ...process.env,
+ ...(mcpServerConfig.env || {}),
+ } as Record<string, string>,
+ cwd: mcpServerConfig.cwd,
+ stderr: 'pipe',
+ });
+ if (debugMode) {
+ transport.stderr!.on('data', (data) => {
+ const stderrStr = data.toString().trim();
+ console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr);
+ });
+ }
+ return transport;
+ }
- // Replace invalid characters (based on 400 error message from Gemini API) with underscores
- toolNameForModel = toolNameForModel.replace(/[^a-zA-Z0-9_.-]/g, '_');
+ throw new Error(
+ `Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`,
+ );
+}
- const existingTool = toolRegistry.getTool(toolNameForModel);
- if (existingTool) {
- toolNameForModel = mcpServerName + '__' + toolNameForModel;
- }
+/** Visible for testing */
+export function generateValidName(
+ funcDecl: FunctionDeclaration,
+ mcpServerName: string,
+) {
+ // Replace invalid characters (based on 400 error message from Gemini API) with underscores
+ let validToolname = funcDecl.name!.replace(/[^a-zA-Z0-9_.-]/g, '_');
- // If longer than 63 characters, replace middle with '___'
- // (Gemini API says max length 64, but actual limit seems to be 63)
- if (toolNameForModel.length > 63) {
- toolNameForModel =
- toolNameForModel.slice(0, 28) + '___' + toolNameForModel.slice(-32);
- }
+ // Prepend MCP server name to avoid conflicts with other tools
+ validToolname = mcpServerName + '__' + validToolname;
- sanitizeParameters(funcDecl.parameters);
+ // If longer than 63 characters, replace middle with '___'
+ // (Gemini API says max length 64, but actual limit seems to be 63)
+ if (validToolname.length > 63) {
+ validToolname =
+ validToolname.slice(0, 28) + '___' + validToolname.slice(-32);
+ }
+ return validToolname;
+}
- toolRegistry.registerTool(
- new DiscoveredMCPTool(
- mcpCallableTool,
- mcpServerName,
- toolNameForModel,
- funcDecl.description ?? '',
- funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
- funcDecl.name,
- mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
- mcpServerConfig.trust,
- ),
- );
- }
- } catch (error) {
- console.error(
- `Failed to list or register tools for MCP server '${mcpServerName}': ${error}`,
+/** Visible for testing */
+export function isEnabled(
+ funcDecl: FunctionDeclaration,
+ mcpServerName: string,
+ mcpServerConfig: MCPServerConfig,
+): boolean {
+ if (!funcDecl.name) {
+ console.warn(
+ `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`,
);
- // Ensure transport is cleaned up on error too
- if (
- transport instanceof StdioClientTransport ||
- transport instanceof SSEClientTransport ||
- transport instanceof StreamableHTTPClientTransport
- ) {
- await transport.close();
- }
- // Update status to disconnected
- updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
+ return false;
}
+ const { includeTools, excludeTools } = mcpServerConfig;
- // If no tools were registered from this MCP server, the following 'if' block
- // will close the connection. This is done to conserve resources and prevent
- // an orphaned connection to a server that isn't providing any usable
- // functionality. Connections to servers that did provide tools are kept
- // open, as those tools will require the connection to function.
- if (toolRegistry.getToolsByServer(mcpServerName).length === 0) {
- console.log(
- `No tools registered from MCP server '${mcpServerName}'. Closing connection.`,
- );
- if (
- transport instanceof StdioClientTransport ||
- transport instanceof SSEClientTransport ||
- transport instanceof StreamableHTTPClientTransport
- ) {
- await transport.close();
- // Update status to disconnected
- updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
- }
+ // excludeTools takes precedence over includeTools
+ if (excludeTools && excludeTools.includes(funcDecl.name)) {
+ return false;
}
+
+ return (
+ !includeTools ||
+ includeTools.some(
+ (tool) => tool === funcDecl.name || tool.startsWith(`${funcDecl.name}(`),
+ )
+ );
}