summaryrefslogtreecommitdiff
path: root/packages/core/src/tools/mcp-client.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src/tools/mcp-client.ts')
-rw-r--r--packages/core/src/tools/mcp-client.ts119
1 files changed, 109 insertions, 10 deletions
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts
index c59b1592..d175af1f 100644
--- a/packages/core/src/tools/mcp-client.ts
+++ b/packages/core/src/tools/mcp-client.ts
@@ -15,12 +15,20 @@ import {
StreamableHTTPClientTransport,
StreamableHTTPClientTransportOptions,
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
+import {
+ Prompt,
+ ListPromptsResultSchema,
+ GetPromptResult,
+ GetPromptResultSchema,
+} from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote';
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
+
import { FunctionDeclaration, mcpToTool } from '@google/genai';
import { ToolRegistry } from './tool-registry.js';
+import { PromptRegistry } from '../prompts/prompt-registry.js';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { OAuthUtils } from '../mcp/oauth-utils.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
@@ -28,6 +36,11 @@ import { getErrorMessage } from '../utils/errors.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
+export type DiscoveredMCPPrompt = Prompt & {
+ serverName: string;
+ invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>;
+};
+
/**
* Enum representing the connection status of an MCP server
*/
@@ -55,7 +68,7 @@ export enum MCPDiscoveryState {
/**
* Map to track the status of each MCP server within the core package
*/
-const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();
+const serverStatuses: Map<string, MCPServerStatus> = new Map();
/**
* Track the overall MCP discovery state
@@ -104,7 +117,7 @@ function updateMCPServerStatus(
serverName: string,
status: MCPServerStatus,
): void {
- mcpServerStatusesInternal.set(serverName, status);
+ serverStatuses.set(serverName, status);
// Notify all listeners
for (const listener of statusChangeListeners) {
listener(serverName, status);
@@ -115,16 +128,14 @@ function updateMCPServerStatus(
* Get the current status of an MCP server
*/
export function getMCPServerStatus(serverName: string): MCPServerStatus {
- return (
- mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED
- );
+ return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED;
}
/**
* Get all MCP server statuses
*/
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
- return new Map(mcpServerStatusesInternal);
+ return new Map(serverStatuses);
}
/**
@@ -307,6 +318,7 @@ export async function discoverMcpTools(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
toolRegistry: ToolRegistry,
+ promptRegistry: PromptRegistry,
debugMode: boolean,
): Promise<void> {
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
@@ -319,6 +331,7 @@ export async function discoverMcpTools(
mcpServerName,
mcpServerConfig,
toolRegistry,
+ promptRegistry,
debugMode,
),
);
@@ -362,6 +375,7 @@ export async function connectAndDiscover(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
toolRegistry: ToolRegistry,
+ promptRegistry: PromptRegistry,
debugMode: boolean,
): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@@ -378,6 +392,7 @@ export async function connectAndDiscover(
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
};
+ await discoverPrompts(mcpServerName, mcpClient, promptRegistry);
const tools = await discoverTools(
mcpServerName,
@@ -393,7 +408,9 @@ export async function connectAndDiscover(
}
} catch (error) {
console.error(
- `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(error)}`,
+ `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(
+ error,
+ )}`,
);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
}
@@ -441,9 +458,6 @@ export async function discoverTools(
),
);
}
- if (discoveredTools.length === 0) {
- throw Error('No enabled tools found');
- }
return discoveredTools;
} catch (error) {
throw new Error(`Error discovering tools: ${error}`);
@@ -451,6 +465,91 @@ export async function discoverTools(
}
/**
+ * Discovers and logs prompts from a connected MCP client.
+ * It retrieves prompt declarations from the client and logs their names.
+ *
+ * @param mcpServerName The name of the MCP server.
+ * @param mcpClient The active MCP client instance.
+ */
+export async function discoverPrompts(
+ mcpServerName: string,
+ mcpClient: Client,
+ promptRegistry: PromptRegistry,
+): Promise<void> {
+ try {
+ const response = await mcpClient.request(
+ { method: 'prompts/list', params: {} },
+ ListPromptsResultSchema,
+ );
+
+ for (const prompt of response.prompts) {
+ promptRegistry.registerPrompt({
+ ...prompt,
+ serverName: mcpServerName,
+ invoke: (params: Record<string, unknown>) =>
+ invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
+ });
+ }
+ } catch (error) {
+ // It's okay if this fails, not all servers will have prompts.
+ // Don't log an error if the method is not found, which is a common case.
+ if (
+ error instanceof Error &&
+ !error.message?.includes('Method not found')
+ ) {
+ console.error(
+ `Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
+ error,
+ )}`,
+ );
+ }
+ }
+}
+
+/**
+ * Invokes a prompt on a connected MCP client.
+ *
+ * @param mcpServerName The name of the MCP server.
+ * @param mcpClient The active MCP client instance.
+ * @param promptName The name of the prompt to invoke.
+ * @param promptParams The parameters to pass to the prompt.
+ * @returns A promise that resolves to the result of the prompt invocation.
+ */
+export async function invokeMcpPrompt(
+ mcpServerName: string,
+ mcpClient: Client,
+ promptName: string,
+ promptParams: Record<string, unknown>,
+): Promise<GetPromptResult> {
+ try {
+ const response = await mcpClient.request(
+ {
+ method: 'prompts/get',
+ params: {
+ name: promptName,
+ arguments: promptParams,
+ },
+ },
+ GetPromptResultSchema,
+ );
+
+ return response;
+ } catch (error) {
+ if (
+ error instanceof Error &&
+ !error.message?.includes('Method not found')
+ ) {
+ console.error(
+ `Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage(
+ error,
+ )}`,
+ );
+ }
+ throw error;
+ }
+}
+
+/**
* Creates and connects an MCP client to a server based on the provided configuration.
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
* establishes a connection. It also applies a patch to handle request timeouts.