diff options
Diffstat (limited to 'packages/core/src/tools/mcp-client.ts')
| -rw-r--r-- | packages/core/src/tools/mcp-client.ts | 119 |
1 files changed, 109 insertions, 10 deletions
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index c59b1592..d175af1f 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -15,12 +15,20 @@ import { StreamableHTTPClientTransport, StreamableHTTPClientTransportOptions, } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { + Prompt, + ListPromptsResultSchema, + GetPromptResult, + GetPromptResultSchema, +} from '@modelcontextprotocol/sdk/types.js'; import { parse } from 'shell-quote'; import { AuthProviderType, MCPServerConfig } from '../config/config.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; + import { FunctionDeclaration, mcpToTool } from '@google/genai'; import { ToolRegistry } from './tool-registry.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; @@ -28,6 +36,11 @@ import { getErrorMessage } from '../utils/errors.js'; export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes +export type DiscoveredMCPPrompt = Prompt & { + serverName: string; + invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>; +}; + /** * Enum representing the connection status of an MCP server */ @@ -55,7 +68,7 @@ export enum MCPDiscoveryState { /** * Map to track the status of each MCP server within the core package */ -const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map(); +const serverStatuses: Map<string, MCPServerStatus> = new Map(); /** * Track the overall MCP discovery state @@ -104,7 +117,7 @@ function updateMCPServerStatus( serverName: string, status: MCPServerStatus, ): void { - mcpServerStatusesInternal.set(serverName, status); + serverStatuses.set(serverName, status); // Notify all listeners for (const listener of statusChangeListeners) { listener(serverName, status); @@ -115,16 +128,14 @@ function updateMCPServerStatus( * Get the current status of an MCP server */ export function getMCPServerStatus(serverName: string): MCPServerStatus { - return ( - mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED - ); + return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED; } /** * Get all MCP server statuses */ export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> { - return new Map(mcpServerStatusesInternal); + return new Map(serverStatuses); } /** @@ -307,6 +318,7 @@ export async function discoverMcpTools( mcpServers: Record<string, MCPServerConfig>, mcpServerCommand: string | undefined, toolRegistry: ToolRegistry, + promptRegistry: PromptRegistry, debugMode: boolean, ): Promise<void> { mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS; @@ -319,6 +331,7 @@ export async function discoverMcpTools( mcpServerName, mcpServerConfig, toolRegistry, + promptRegistry, debugMode, ), ); @@ -362,6 +375,7 @@ export async function connectAndDiscover( mcpServerName: string, mcpServerConfig: MCPServerConfig, toolRegistry: ToolRegistry, + promptRegistry: PromptRegistry, debugMode: boolean, ): Promise<void> { updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); @@ -378,6 +392,7 @@ export async function connectAndDiscover( console.error(`MCP ERROR (${mcpServerName}):`, error.toString()); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); }; + await discoverPrompts(mcpServerName, mcpClient, promptRegistry); const tools = await discoverTools( mcpServerName, @@ -393,7 +408,9 @@ export async function connectAndDiscover( } } catch (error) { console.error( - `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(error)}`, + `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage( + error, + )}`, ); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); } @@ -441,9 +458,6 @@ export async function discoverTools( ), ); } - if (discoveredTools.length === 0) { - throw Error('No enabled tools found'); - } return discoveredTools; } catch (error) { throw new Error(`Error discovering tools: ${error}`); @@ -451,6 +465,91 @@ export async function discoverTools( } /** + * Discovers and logs prompts from a connected MCP client. + * It retrieves prompt declarations from the client and logs their names. + * + * @param mcpServerName The name of the MCP server. + * @param mcpClient The active MCP client instance. + */ +export async function discoverPrompts( + mcpServerName: string, + mcpClient: Client, + promptRegistry: PromptRegistry, +): Promise<void> { + try { + const response = await mcpClient.request( + { method: 'prompts/list', params: {} }, + ListPromptsResultSchema, + ); + + for (const prompt of response.prompts) { + promptRegistry.registerPrompt({ + ...prompt, + serverName: mcpServerName, + invoke: (params: Record<string, unknown>) => + invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params), + }); + } + } catch (error) { + // It's okay if this fails, not all servers will have prompts. + // Don't log an error if the method is not found, which is a common case. + if ( + error instanceof Error && + !error.message?.includes('Method not found') + ) { + console.error( + `Error discovering prompts from ${mcpServerName}: ${getErrorMessage( + error, + )}`, + ); + } + } +} + +/** + * Invokes a prompt on a connected MCP client. + * + * @param mcpServerName The name of the MCP server. + * @param mcpClient The active MCP client instance. + * @param promptName The name of the prompt to invoke. + * @param promptParams The parameters to pass to the prompt. + * @returns A promise that resolves to the result of the prompt invocation. + */ +export async function invokeMcpPrompt( + mcpServerName: string, + mcpClient: Client, + promptName: string, + promptParams: Record<string, unknown>, +): Promise<GetPromptResult> { + try { + const response = await mcpClient.request( + { + method: 'prompts/get', + params: { + name: promptName, + arguments: promptParams, + }, + }, + GetPromptResultSchema, + ); + + return response; + } catch (error) { + if ( + error instanceof Error && + !error.message?.includes('Method not found') + ) { + console.error( + `Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage( + error, + )}`, + ); + } + throw error; + } +} + +/** * Creates and connects an MCP client to a server based on the provided configuration. * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and * establishes a connection. It also applies a patch to handle request timeouts. |
