diff options
Diffstat (limited to 'packages/core/src/tools/mcp-client.ts')
| -rw-r--r-- | packages/core/src/tools/mcp-client.ts | 423 |
1 files changed, 235 insertions, 188 deletions
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 6edfbac8..eb82190b 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -5,6 +5,7 @@ */ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; import { SSEClientTransport, @@ -17,7 +18,7 @@ import { import { parse } from 'shell-quote'; import { MCPServerConfig } from '../config/config.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; -import { Type, mcpToTool } from '@google/genai'; +import { FunctionDeclaration, Type, mcpToTool } from '@google/genai'; import { sanitizeParameters, ToolRegistry } from './tool-registry.js'; export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes @@ -123,28 +124,25 @@ export function getMCPDiscoveryState(): MCPDiscoveryState { return mcpDiscoveryState; } +/** + * Discovers tools from all configured MCP servers and registers them with the tool registry. + * It orchestrates the connection and discovery process for each server defined in the + * configuration, as well as any server specified via a command-line argument. + * + * @param mcpServers A record of named MCP server configurations. + * @param mcpServerCommand An optional command string for a dynamically specified MCP server. + * @param toolRegistry The central registry where discovered tools will be registered. + * @returns A promise that resolves when the discovery process has been attempted for all servers. + */ export async function discoverMcpTools( mcpServers: Record<string, MCPServerConfig>, mcpServerCommand: string | undefined, toolRegistry: ToolRegistry, debugMode: boolean, ): Promise<void> { - // Set discovery state to in progress mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS; - try { - if (mcpServerCommand) { - const cmd = mcpServerCommand; - const args = parse(cmd, process.env) as string[]; - if (args.some((arg) => typeof arg !== 'string')) { - throw new Error('failed to parse mcpServerCommand: ' + cmd); - } - // use generic server name 'mcp' - mcpServers['mcp'] = { - command: args[0], - args: args.slice(1), - }; - } + mcpServers = populateMcpServerCommand(mcpServers, mcpServerCommand); const discoveryPromises = Object.entries(mcpServers).map( ([mcpServerName, mcpServerConfig]) => @@ -156,16 +154,31 @@ export async function discoverMcpTools( ), ); await Promise.all(discoveryPromises); - - // Mark discovery as completed - mcpDiscoveryState = MCPDiscoveryState.COMPLETED; - } catch (error) { - // Still mark as completed even with errors + } finally { mcpDiscoveryState = MCPDiscoveryState.COMPLETED; - throw error; } } +/** Visible for Testing */ +export function populateMcpServerCommand( + mcpServers: Record<string, MCPServerConfig>, + mcpServerCommand: string | undefined, +): Record<string, MCPServerConfig> { + if (mcpServerCommand) { + const cmd = mcpServerCommand; + const args = parse(cmd, process.env) as string[]; + if (args.some((arg) => typeof arg !== 'string')) { + throw new Error('failed to parse mcpServerCommand: ' + cmd); + } + // use generic server name 'mcp' + mcpServers['mcp'] = { + command: args[0], + args: args.slice(1), + }; + } + return mcpServers; +} + /** * Connects to an MCP server and discovers available tools, registering them with the tool registry. * This function handles the complete lifecycle of connecting to a server, discovering tools, @@ -176,71 +189,117 @@ export async function discoverMcpTools( * @param toolRegistry The registry to register discovered tools with * @returns Promise that resolves when discovery is complete */ -async function connectAndDiscover( +export async function connectAndDiscover( mcpServerName: string, mcpServerConfig: MCPServerConfig, toolRegistry: ToolRegistry, debugMode: boolean, ): Promise<void> { - // Initialize the server status as connecting updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); - let transport; - if (mcpServerConfig.httpUrl) { - const transportOptions: StreamableHTTPClientTransportOptions = {}; + try { + const mcpClient = await connectToMcpServer( + mcpServerName, + mcpServerConfig, + debugMode, + ); + try { + updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED); - if (mcpServerConfig.headers) { - transportOptions.requestInit = { - headers: mcpServerConfig.headers, + mcpClient.onerror = (error) => { + console.error(`MCP ERROR (${mcpServerName}):`, error.toString()); + updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); }; - } - transport = new StreamableHTTPClientTransport( - new URL(mcpServerConfig.httpUrl), - transportOptions, - ); - } else if (mcpServerConfig.url) { - const transportOptions: SSEClientTransportOptions = {}; - if (mcpServerConfig.headers) { - transportOptions.requestInit = { - headers: mcpServerConfig.headers, - }; + const tools = await discoverTools( + mcpServerName, + mcpServerConfig, + mcpClient, + ); + for (const tool of tools) { + toolRegistry.registerTool(tool); + } + } catch (error) { + mcpClient.close(); + throw error; } - transport = new SSEClientTransport( - new URL(mcpServerConfig.url), - transportOptions, - ); - } else if (mcpServerConfig.command) { - transport = new StdioClientTransport({ - command: mcpServerConfig.command, - args: mcpServerConfig.args || [], - env: { - ...process.env, - ...(mcpServerConfig.env || {}), - } as Record<string, string>, - cwd: mcpServerConfig.cwd, - stderr: 'pipe', - }); - } else { - console.error( - `MCP server '${mcpServerName}' has invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio). Skipping.`, - ); - // Update status to disconnected + } catch (error) { + console.error(`Error connecting to MCP server '${mcpServerName}':`, error); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - return; } +} - if ( - debugMode && - transport instanceof StdioClientTransport && - transport.stderr - ) { - transport.stderr.on('data', (data) => { - const stderrStr = data.toString().trim(); - console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr); - }); +/** + * Discovers and sanitizes tools from a connected MCP client. + * It retrieves function declarations from the client, filters out disabled tools, + * generates valid names for them, and wraps them in `DiscoveredMCPTool` instances. + * + * @param mcpServerName The name of the MCP server. + * @param mcpServerConfig The configuration for the MCP server. + * @param mcpClient The active MCP client instance. + * @returns A promise that resolves to an array of discovered and enabled tools. + * @throws An error if no enabled tools are found or if the server provides invalid function declarations. + */ +export async function discoverTools( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, + mcpClient: Client, +): Promise<DiscoveredMCPTool[]> { + try { + const mcpCallableTool = mcpToTool(mcpClient); + const tool = await mcpCallableTool.tool(); + + if (!Array.isArray(tool.functionDeclarations)) { + throw new Error(`Server did not return valid function declarations.`); + } + + const discoveredTools: DiscoveredMCPTool[] = []; + for (const funcDecl of tool.functionDeclarations) { + if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) { + continue; + } + + const toolNameForModel = generateValidName(funcDecl, mcpServerName); + + sanitizeParameters(funcDecl.parameters); + + discoveredTools.push( + new DiscoveredMCPTool( + mcpCallableTool, + mcpServerName, + toolNameForModel, + funcDecl.description ?? '', + funcDecl.parameters ?? { type: Type.OBJECT, properties: {} }, + funcDecl.name!, + mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + mcpServerConfig.trust, + ), + ); + } + if (discoveredTools.length === 0) { + throw Error('No enabled tools found'); + } + return discoveredTools; + } catch (error) { + throw new Error(`Error discovering tools: ${error}`); } +} +/** + * Creates and connects an MCP client to a server based on the provided configuration. + * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and + * establishes a connection. It also applies a patch to handle request timeouts. + * + * @param mcpServerName The name of the MCP server, used for logging and identification. + * @param mcpServerConfig The configuration specifying how to connect to the server. + * @returns A promise that resolves to a connected MCP `Client` instance. + * @throws An error if the connection fails or the configuration is invalid. + */ +export async function connectToMcpServer( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, + debugMode: boolean, +): Promise<Client> { const mcpClient = new Client({ name: 'gemini-cli-mcp-client', version: '0.0.1', @@ -259,11 +318,20 @@ async function connectAndDiscover( } try { - await mcpClient.connect(transport, { - timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, - }); - // Connection successful - updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED); + const transport = createTransport( + mcpServerName, + mcpServerConfig, + debugMode, + ); + try { + await mcpClient.connect(transport, { + timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + }); + return mcpClient; + } catch (error) { + await transport.close(); + throw error; + } } catch (error) { // Create a safe config object that excludes sensitive information const safeConfig = { @@ -282,131 +350,110 @@ async function connectAndDiscover( if (process.env.SANDBOX) { errorString += `\nMake sure it is available in the sandbox`; } - console.error(errorString); - // Update status to disconnected - updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - return; + throw new Error(errorString); } +} - mcpClient.onerror = (error) => { - console.error(`MCP ERROR (${mcpServerName}):`, error.toString()); - // Update status to disconnected on error - updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - }; - - try { - const mcpCallableTool = mcpToTool(mcpClient); - const tool = await mcpCallableTool.tool(); - - if (!tool || !Array.isArray(tool.functionDeclarations)) { - console.error( - `MCP server '${mcpServerName}' did not return valid tool function declarations. Skipping.`, - ); - if ( - transport instanceof StdioClientTransport || - transport instanceof SSEClientTransport || - transport instanceof StreamableHTTPClientTransport - ) { - await transport.close(); - } - // Update status to disconnected - updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - return; +/** Visible for Testing */ +export function createTransport( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, + debugMode: boolean, +): Transport { + if (mcpServerConfig.httpUrl) { + const transportOptions: StreamableHTTPClientTransportOptions = {}; + if (mcpServerConfig.headers) { + transportOptions.requestInit = { + headers: mcpServerConfig.headers, + }; } + return new StreamableHTTPClientTransport( + new URL(mcpServerConfig.httpUrl), + transportOptions, + ); + } - for (const funcDecl of tool.functionDeclarations) { - if (!funcDecl.name) { - console.warn( - `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`, - ); - continue; - } - - const { includeTools, excludeTools } = mcpServerConfig; - const toolName = funcDecl.name; - - let isEnabled = false; - if (includeTools === undefined) { - isEnabled = true; - } else { - isEnabled = includeTools.some( - (tool) => tool === toolName || tool.startsWith(`${toolName}(`), - ); - } - - if (excludeTools?.includes(toolName)) { - isEnabled = false; - } - - if (!isEnabled) { - continue; - } + if (mcpServerConfig.url) { + const transportOptions: SSEClientTransportOptions = {}; + if (mcpServerConfig.headers) { + transportOptions.requestInit = { + headers: mcpServerConfig.headers, + }; + } + return new SSEClientTransport( + new URL(mcpServerConfig.url), + transportOptions, + ); + } - let toolNameForModel = funcDecl.name; + if (mcpServerConfig.command) { + const transport = new StdioClientTransport({ + command: mcpServerConfig.command, + args: mcpServerConfig.args || [], + env: { + ...process.env, + ...(mcpServerConfig.env || {}), + } as Record<string, string>, + cwd: mcpServerConfig.cwd, + stderr: 'pipe', + }); + if (debugMode) { + transport.stderr!.on('data', (data) => { + const stderrStr = data.toString().trim(); + console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr); + }); + } + return transport; + } - // Replace invalid characters (based on 400 error message from Gemini API) with underscores - toolNameForModel = toolNameForModel.replace(/[^a-zA-Z0-9_.-]/g, '_'); + throw new Error( + `Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`, + ); +} - const existingTool = toolRegistry.getTool(toolNameForModel); - if (existingTool) { - toolNameForModel = mcpServerName + '__' + toolNameForModel; - } +/** Visible for testing */ +export function generateValidName( + funcDecl: FunctionDeclaration, + mcpServerName: string, +) { + // Replace invalid characters (based on 400 error message from Gemini API) with underscores + let validToolname = funcDecl.name!.replace(/[^a-zA-Z0-9_.-]/g, '_'); - // If longer than 63 characters, replace middle with '___' - // (Gemini API says max length 64, but actual limit seems to be 63) - if (toolNameForModel.length > 63) { - toolNameForModel = - toolNameForModel.slice(0, 28) + '___' + toolNameForModel.slice(-32); - } + // Prepend MCP server name to avoid conflicts with other tools + validToolname = mcpServerName + '__' + validToolname; - sanitizeParameters(funcDecl.parameters); + // If longer than 63 characters, replace middle with '___' + // (Gemini API says max length 64, but actual limit seems to be 63) + if (validToolname.length > 63) { + validToolname = + validToolname.slice(0, 28) + '___' + validToolname.slice(-32); + } + return validToolname; +} - toolRegistry.registerTool( - new DiscoveredMCPTool( - mcpCallableTool, - mcpServerName, - toolNameForModel, - funcDecl.description ?? '', - funcDecl.parameters ?? { type: Type.OBJECT, properties: {} }, - funcDecl.name, - mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, - mcpServerConfig.trust, - ), - ); - } - } catch (error) { - console.error( - `Failed to list or register tools for MCP server '${mcpServerName}': ${error}`, +/** Visible for testing */ +export function isEnabled( + funcDecl: FunctionDeclaration, + mcpServerName: string, + mcpServerConfig: MCPServerConfig, +): boolean { + if (!funcDecl.name) { + console.warn( + `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`, ); - // Ensure transport is cleaned up on error too - if ( - transport instanceof StdioClientTransport || - transport instanceof SSEClientTransport || - transport instanceof StreamableHTTPClientTransport - ) { - await transport.close(); - } - // Update status to disconnected - updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); + return false; } + const { includeTools, excludeTools } = mcpServerConfig; - // If no tools were registered from this MCP server, the following 'if' block - // will close the connection. This is done to conserve resources and prevent - // an orphaned connection to a server that isn't providing any usable - // functionality. Connections to servers that did provide tools are kept - // open, as those tools will require the connection to function. - if (toolRegistry.getToolsByServer(mcpServerName).length === 0) { - console.log( - `No tools registered from MCP server '${mcpServerName}'. Closing connection.`, - ); - if ( - transport instanceof StdioClientTransport || - transport instanceof SSEClientTransport || - transport instanceof StreamableHTTPClientTransport - ) { - await transport.close(); - // Update status to disconnected - updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - } + // excludeTools takes precedence over includeTools + if (excludeTools && excludeTools.includes(funcDecl.name)) { + return false; } + + return ( + !includeTools || + includeTools.some( + (tool) => tool === funcDecl.name || tool.startsWith(`${funcDecl.name}(`), + ) + ); } |
