diff options
| author | christine betts <[email protected]> | 2025-07-25 20:56:33 +0000 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-07-25 20:56:33 +0000 |
| commit | eb65034117f7722554a717de034e891ba1996e93 (patch) | |
| tree | f279bee5ca55b0e447eabc70a11e96de307d76f3 /packages/core/src | |
| parent | de968877895a8ae5f0edb83a43b37fa190cc8ec9 (diff) | |
Load and use MCP server prompts as slash commands in the CLI (#4828)
Co-authored-by: harold <[email protected]>
Co-authored-by: N. Taylor Mullen <[email protected]>
Diffstat (limited to 'packages/core/src')
| -rw-r--r-- | packages/core/src/config/config.ts | 7 | ||||
| -rw-r--r-- | packages/core/src/index.ts | 3 | ||||
| -rw-r--r-- | packages/core/src/prompts/mcp-prompts.ts | 19 | ||||
| -rw-r--r-- | packages/core/src/prompts/prompt-registry.ts | 56 | ||||
| -rw-r--r-- | packages/core/src/tools/mcp-client.test.ts | 73 | ||||
| -rw-r--r-- | packages/core/src/tools/mcp-client.ts | 119 | ||||
| -rw-r--r-- | packages/core/src/tools/tool-registry.test.ts | 2 | ||||
| -rw-r--r-- | packages/core/src/tools/tool-registry.ts | 3 |
8 files changed, 272 insertions, 10 deletions
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 96b6f2cb..7ccfdbc8 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -11,6 +11,7 @@ import { ContentGeneratorConfig, createContentGeneratorConfig, } from '../core/contentGenerator.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; import { ToolRegistry } from '../tools/tool-registry.js'; import { LSTool } from '../tools/ls.js'; import { ReadFileTool } from '../tools/read-file.js'; @@ -186,6 +187,7 @@ export interface ConfigParameters { export class Config { private toolRegistry!: ToolRegistry; + private promptRegistry!: PromptRegistry; private readonly sessionId: string; private contentGeneratorConfig!: ContentGeneratorConfig; private readonly embeddingModel: string; @@ -314,6 +316,7 @@ export class Config { if (this.getCheckpointingEnabled()) { await this.getGitService(); } + this.promptRegistry = new PromptRegistry(); this.toolRegistry = await this.createToolRegistry(); } @@ -396,6 +399,10 @@ export class Config { return Promise.resolve(this.toolRegistry); } + getPromptRegistry(): PromptRegistry { + return this.promptRegistry; + } + getDebugMode(): boolean { return this.debugMode; } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 9d87ce32..829de544 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -49,6 +49,9 @@ export * from './ide/ideContext.js'; export * from './tools/tools.js'; export * from './tools/tool-registry.js'; +// Export prompt logic +export * from './prompts/mcp-prompts.js'; + // Export specific tool logic export * from './tools/read-file.js'; export * from './tools/ls.js'; diff --git a/packages/core/src/prompts/mcp-prompts.ts b/packages/core/src/prompts/mcp-prompts.ts new file mode 100644 index 00000000..7265a023 --- /dev/null +++ b/packages/core/src/prompts/mcp-prompts.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Config } from '../config/config.js'; +import { DiscoveredMCPPrompt } from '../tools/mcp-client.js'; + +export function getMCPServerPrompts( + config: Config, + serverName: string, +): DiscoveredMCPPrompt[] { + const promptRegistry = config.getPromptRegistry(); + if (!promptRegistry) { + return []; + } + return promptRegistry.getPromptsByServer(serverName); +} diff --git a/packages/core/src/prompts/prompt-registry.ts b/packages/core/src/prompts/prompt-registry.ts new file mode 100644 index 00000000..56699130 --- /dev/null +++ b/packages/core/src/prompts/prompt-registry.ts @@ -0,0 +1,56 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { DiscoveredMCPPrompt } from '../tools/mcp-client.js'; + +export class PromptRegistry { + private prompts: Map<string, DiscoveredMCPPrompt> = new Map(); + + /** + * Registers a prompt definition. + * @param prompt - The prompt object containing schema and execution logic. + */ + registerPrompt(prompt: DiscoveredMCPPrompt): void { + if (this.prompts.has(prompt.name)) { + const newName = `${prompt.serverName}_${prompt.name}`; + console.warn( + `Prompt with name "${prompt.name}" is already registered. Renaming to "${newName}".`, + ); + this.prompts.set(newName, { ...prompt, name: newName }); + } else { + this.prompts.set(prompt.name, prompt); + } + } + + /** + * Returns an array of all registered and discovered prompt instances. + */ + getAllPrompts(): DiscoveredMCPPrompt[] { + return Array.from(this.prompts.values()).sort((a, b) => + a.name.localeCompare(b.name), + ); + } + + /** + * Get the definition of a specific prompt. + */ + getPrompt(name: string): DiscoveredMCPPrompt | undefined { + return this.prompts.get(name); + } + + /** + * Returns an array of prompts registered from a specific MCP server. + */ + getPromptsByServer(serverName: string): DiscoveredMCPPrompt[] { + const serverPrompts: DiscoveredMCPPrompt[] = []; + for (const prompt of this.prompts.values()) { + if (prompt.serverName === serverName) { + serverPrompts.push(prompt); + } + } + return serverPrompts.sort((a, b) => a.name.localeCompare(b.name)); + } +} diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 428c9d2d..4560982c 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -11,6 +11,7 @@ import { createTransport, isEnabled, discoverTools, + discoverPrompts, } from './mcp-client.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; @@ -18,6 +19,7 @@ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; import * as GenAiLib from '@google/genai'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { AuthProviderType } from '../config/config.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js'); @@ -50,6 +52,77 @@ describe('mcp-client', () => { }); }); + describe('discoverPrompts', () => { + const mockedPromptRegistry = { + registerPrompt: vi.fn(), + } as unknown as PromptRegistry; + + it('should discover and log prompts', async () => { + const mockRequest = vi.fn().mockResolvedValue({ + prompts: [ + { name: 'prompt1', description: 'desc1' }, + { name: 'prompt2' }, + ], + }); + const mockedClient = { + request: mockRequest, + } as unknown as ClientLib.Client; + + await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); + + expect(mockRequest).toHaveBeenCalledWith( + { method: 'prompts/list', params: {} }, + expect.anything(), + ); + }); + + it('should do nothing if no prompts are discovered', async () => { + const mockRequest = vi.fn().mockResolvedValue({ + prompts: [], + }); + const mockedClient = { + request: mockRequest, + } as unknown as ClientLib.Client; + + const consoleLogSpy = vi + .spyOn(console, 'debug') + .mockImplementation(() => { + // no-op + }); + + await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); + + expect(mockRequest).toHaveBeenCalledOnce(); + expect(consoleLogSpy).not.toHaveBeenCalled(); + + consoleLogSpy.mockRestore(); + }); + + it('should log an error if discovery fails', async () => { + const testError = new Error('test error'); + testError.message = 'test error'; + const mockRequest = vi.fn().mockRejectedValue(testError); + const mockedClient = { + request: mockRequest, + } as unknown as ClientLib.Client; + + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => { + // no-op + }); + + await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); + + expect(mockRequest).toHaveBeenCalledOnce(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + `Error discovering prompts from test-server: ${testError.message}`, + ); + + consoleErrorSpy.mockRestore(); + }); + }); + describe('appendMcpServerCommand', () => { it('should do nothing if no MCP servers or command are configured', () => { const out = populateMcpServerCommand({}, undefined); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index c59b1592..d175af1f 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -15,12 +15,20 @@ import { StreamableHTTPClientTransport, StreamableHTTPClientTransportOptions, } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { + Prompt, + ListPromptsResultSchema, + GetPromptResult, + GetPromptResultSchema, +} from '@modelcontextprotocol/sdk/types.js'; import { parse } from 'shell-quote'; import { AuthProviderType, MCPServerConfig } from '../config/config.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; + import { FunctionDeclaration, mcpToTool } from '@google/genai'; import { ToolRegistry } from './tool-registry.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; @@ -28,6 +36,11 @@ import { getErrorMessage } from '../utils/errors.js'; export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes +export type DiscoveredMCPPrompt = Prompt & { + serverName: string; + invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>; +}; + /** * Enum representing the connection status of an MCP server */ @@ -55,7 +68,7 @@ export enum MCPDiscoveryState { /** * Map to track the status of each MCP server within the core package */ -const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map(); +const serverStatuses: Map<string, MCPServerStatus> = new Map(); /** * Track the overall MCP discovery state @@ -104,7 +117,7 @@ function updateMCPServerStatus( serverName: string, status: MCPServerStatus, ): void { - mcpServerStatusesInternal.set(serverName, status); + serverStatuses.set(serverName, status); // Notify all listeners for (const listener of statusChangeListeners) { listener(serverName, status); @@ -115,16 +128,14 @@ function updateMCPServerStatus( * Get the current status of an MCP server */ export function getMCPServerStatus(serverName: string): MCPServerStatus { - return ( - mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED - ); + return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED; } /** * Get all MCP server statuses */ export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> { - return new Map(mcpServerStatusesInternal); + return new Map(serverStatuses); } /** @@ -307,6 +318,7 @@ export async function discoverMcpTools( mcpServers: Record<string, MCPServerConfig>, mcpServerCommand: string | undefined, toolRegistry: ToolRegistry, + promptRegistry: PromptRegistry, debugMode: boolean, ): Promise<void> { mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS; @@ -319,6 +331,7 @@ export async function discoverMcpTools( mcpServerName, mcpServerConfig, toolRegistry, + promptRegistry, debugMode, ), ); @@ -362,6 +375,7 @@ export async function connectAndDiscover( mcpServerName: string, mcpServerConfig: MCPServerConfig, toolRegistry: ToolRegistry, + promptRegistry: PromptRegistry, debugMode: boolean, ): Promise<void> { updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); @@ -378,6 +392,7 @@ export async function connectAndDiscover( console.error(`MCP ERROR (${mcpServerName}):`, error.toString()); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); }; + await discoverPrompts(mcpServerName, mcpClient, promptRegistry); const tools = await discoverTools( mcpServerName, @@ -393,7 +408,9 @@ export async function connectAndDiscover( } } catch (error) { console.error( - `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(error)}`, + `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage( + error, + )}`, ); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); } @@ -441,9 +458,6 @@ export async function discoverTools( ), ); } - if (discoveredTools.length === 0) { - throw Error('No enabled tools found'); - } return discoveredTools; } catch (error) { throw new Error(`Error discovering tools: ${error}`); @@ -451,6 +465,91 @@ export async function discoverTools( } /** + * Discovers and logs prompts from a connected MCP client. + * It retrieves prompt declarations from the client and logs their names. + * + * @param mcpServerName The name of the MCP server. + * @param mcpClient The active MCP client instance. + */ +export async function discoverPrompts( + mcpServerName: string, + mcpClient: Client, + promptRegistry: PromptRegistry, +): Promise<void> { + try { + const response = await mcpClient.request( + { method: 'prompts/list', params: {} }, + ListPromptsResultSchema, + ); + + for (const prompt of response.prompts) { + promptRegistry.registerPrompt({ + ...prompt, + serverName: mcpServerName, + invoke: (params: Record<string, unknown>) => + invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params), + }); + } + } catch (error) { + // It's okay if this fails, not all servers will have prompts. + // Don't log an error if the method is not found, which is a common case. + if ( + error instanceof Error && + !error.message?.includes('Method not found') + ) { + console.error( + `Error discovering prompts from ${mcpServerName}: ${getErrorMessage( + error, + )}`, + ); + } + } +} + +/** + * Invokes a prompt on a connected MCP client. + * + * @param mcpServerName The name of the MCP server. + * @param mcpClient The active MCP client instance. + * @param promptName The name of the prompt to invoke. + * @param promptParams The parameters to pass to the prompt. + * @returns A promise that resolves to the result of the prompt invocation. + */ +export async function invokeMcpPrompt( + mcpServerName: string, + mcpClient: Client, + promptName: string, + promptParams: Record<string, unknown>, +): Promise<GetPromptResult> { + try { + const response = await mcpClient.request( + { + method: 'prompts/get', + params: { + name: promptName, + arguments: promptParams, + }, + }, + GetPromptResultSchema, + ); + + return response; + } catch (error) { + if ( + error instanceof Error && + !error.message?.includes('Method not found') + ) { + console.error( + `Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage( + error, + )}`, + ); + } + throw error; + } +} + +/** * Creates and connects an MCP client to a server based on the provided configuration. * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and * establishes a connection. It also applies a patch to handle request timeouts. diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index de355a98..b3fdd7a3 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -344,6 +344,7 @@ describe('ToolRegistry', () => { mcpServerConfigVal, undefined, toolRegistry, + undefined, false, ); }); @@ -366,6 +367,7 @@ describe('ToolRegistry', () => { mcpServerConfigVal, undefined, toolRegistry, + undefined, false, ); }); diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index b72ed9a5..57627ee0 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -170,6 +170,7 @@ export class ToolRegistry { this.config.getMcpServers() ?? {}, this.config.getMcpServerCommand(), this, + this.config.getPromptRegistry(), this.config.getDebugMode(), ); } @@ -192,6 +193,7 @@ export class ToolRegistry { this.config.getMcpServers() ?? {}, this.config.getMcpServerCommand(), this, + this.config.getPromptRegistry(), this.config.getDebugMode(), ); } @@ -215,6 +217,7 @@ export class ToolRegistry { { [serverName]: serverConfig }, undefined, this, + this.config.getPromptRegistry(), this.config.getDebugMode(), ); } |
