diff options
| author | N. Taylor Mullen <[email protected]> | 2025-06-02 13:39:25 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-06-02 20:39:25 +0000 |
| commit | 58597c29d30eb0d95e1792f02eb7f1e7edc4218a (patch) | |
| tree | 2dfb528ab008e454422fc27c941aa7aa925ec5d7 /packages/core/src/tools/mcp-client.ts | |
| parent | 0795e55f0e7d2f2822bcd83eaf066eb99c67f858 (diff) | |
refactor: Update MCP tool discovery to use @google/genai - Also fixes JSON schema issues. (#682)
Diffstat (limited to 'packages/core/src/tools/mcp-client.ts')
| -rw-r--r-- | packages/core/src/tools/mcp-client.ts | 127 |
1 files changed, 80 insertions, 47 deletions
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 97a73289..87835219 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -10,12 +10,9 @@ import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import { parse } from 'shell-quote'; import { Config, MCPServerConfig } from '../config/config.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; -import { ToolRegistry } from './tool-registry.js'; +import { CallableTool, FunctionDeclaration, mcpToTool } from '@google/genai'; -export async function discoverMcpTools( - config: Config, - toolRegistry: ToolRegistry, -): Promise<void> { +export async function discoverMcpTools(config: Config): Promise<void> { const mcpServers = config.getMcpServers() || {}; if (config.getMcpServerCommand()) { @@ -33,12 +30,7 @@ export async function discoverMcpTools( const discoveryPromises = Object.entries(mcpServers).map( ([mcpServerName, mcpServerConfig]) => - connectAndDiscover( - mcpServerName, - mcpServerConfig, - toolRegistry, - mcpServers, - ), + connectAndDiscover(mcpServerName, mcpServerConfig, config), ); await Promise.all(discoveryPromises); } @@ -46,8 +38,7 @@ export async function discoverMcpTools( async function connectAndDiscover( mcpServerName: string, mcpServerConfig: MCPServerConfig, - toolRegistry: ToolRegistry, - mcpServers: Record<string, MCPServerConfig>, + config: Config, ): Promise<void> { let transport; if (mcpServerConfig.url) { @@ -67,7 +58,7 @@ async function connectAndDiscover( console.error( `MCP server '${mcpServerName}' has invalid configuration: missing both url (for SSE) and command (for stdio). Skipping.`, ); - return; // Return a resolved promise as this path doesn't throw. + return; } const mcpClient = new Client({ @@ -82,63 +73,82 @@ async function connectAndDiscover( `failed to start or connect to MCP server '${mcpServerName}' ` + `${JSON.stringify(mcpServerConfig)}; \n${error}`, ); - return; // Return a resolved promise, let other MCP servers be discovered. + return; } mcpClient.onerror = (error) => { - console.error('MCP ERROR', error.toString()); + console.error(`MCP ERROR (${mcpServerName}):`, error.toString()); }; if (transport instanceof StdioClientTransport && transport.stderr) { transport.stderr.on('data', (data) => { - if (!data.toString().includes('] INFO')) { - console.debug('MCP STDERR', data.toString()); + const stderrStr = data.toString(); + // Filter out verbose INFO logs from some MCP servers + if (!stderrStr.includes('] INFO')) { + console.debug(`MCP STDERR (${mcpServerName}):`, stderrStr); } }); } + const toolRegistry = await config.getToolRegistry(); try { - const result = await mcpClient.listTools(); - for (const tool of result.tools) { - // Recursively remove additionalProperties and $schema from the inputSchema - // eslint-disable-next-line @typescript-eslint/no-explicit-any -- This function recursively navigates a deeply nested and potentially heterogeneous JSON schema object. Using 'any' is a pragmatic choice here to avoid overly complex type definitions for all possible schema variations. - const removeSchemaProps = (obj: any) => { - if (typeof obj !== 'object' || obj === null) { - return; - } - if (Array.isArray(obj)) { - obj.forEach(removeSchemaProps); - } else { - delete obj.additionalProperties; - delete obj.$schema; - Object.values(obj).forEach(removeSchemaProps); - } - }; - removeSchemaProps(tool.inputSchema); + const mcpCallableTool: CallableTool = mcpToTool(mcpClient); + const discoveredToolFunctions = await mcpCallableTool.tool(); - // if there are multiple MCP servers, prefix tool name with mcpServerName to avoid collisions - let toolNameForModel = tool.name; - if (Object.keys(mcpServers).length > 1) { - toolNameForModel = mcpServerName + '__' + toolNameForModel; + if ( + !discoveredToolFunctions || + !Array.isArray(discoveredToolFunctions.functionDeclarations) + ) { + console.error( + `MCP server '${mcpServerName}' did not return valid tool function declarations. Skipping.`, + ); + if (transport instanceof StdioClientTransport) { + await transport.close(); + } else if (transport instanceof SSEClientTransport) { + await transport.close(); + } + return; + } + + for (const funcDecl of discoveredToolFunctions.functionDeclarations) { + if (!funcDecl.name) { + console.warn( + `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`, + ); + continue; } - // replace invalid characters (based on 400 error message) with underscores + let toolNameForModel = funcDecl.name; + + // Replace invalid characters (based on 400 error message from Gemini API) with underscores toolNameForModel = toolNameForModel.replace(/[^a-zA-Z0-9_.-]/g, '_'); - // if longer than 63 characters, replace middle with '___' - // note 400 error message says max length is 64, but actual limit seems to be 63 + const existingTool = toolRegistry.getTool(toolNameForModel); + if (existingTool) { + toolNameForModel = mcpServerName + '__' + toolNameForModel; + } + + // If longer than 63 characters, replace middle with '___' + // (Gemini API says max length 64, but actual limit seems to be 63) if (toolNameForModel.length > 63) { toolNameForModel = toolNameForModel.slice(0, 28) + '___' + toolNameForModel.slice(-32); } + + // Ensure parameters is a valid JSON schema object, default to empty if not. + const parameterSchema: Record<string, unknown> = + funcDecl.parameters && typeof funcDecl.parameters === 'object' + ? { ...(funcDecl.parameters as FunctionDeclaration) } + : { type: 'object', properties: {} }; + toolRegistry.registerTool( new DiscoveredMCPTool( - mcpClient, + mcpCallableTool, mcpServerName, toolNameForModel, - tool.description ?? '', - tool.inputSchema, - tool.name, + funcDecl.description ?? '', + parameterSchema, + funcDecl.name, mcpServerConfig.timeout, mcpServerConfig.trust, ), @@ -148,6 +158,29 @@ async function connectAndDiscover( console.error( `Failed to list or register tools for MCP server '${mcpServerName}': ${error}`, ); - // Do not re-throw, allow other servers to proceed. + // Ensure transport is cleaned up on error too + if ( + transport instanceof StdioClientTransport || + transport instanceof SSEClientTransport + ) { + await transport.close(); + } + } + + // If no tools were registered from this MCP server, the following 'if' block + // will close the connection. This is done to conserve resources and prevent + // an orphaned connection to a server that isn't providing any usable + // functionality. Connections to servers that did provide tools are kept + // open, as those tools will require the connection to function. + if (toolRegistry.getToolsByServer(mcpServerName).length === 0) { + console.log( + `No tools registered from MCP server '${mcpServerName}'. Closing connection.`, + ); + if ( + transport instanceof StdioClientTransport || + transport instanceof SSEClientTransport + ) { + await transport.close(); + } } } |
