summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
authorchristine betts <[email protected]>2025-07-25 20:56:33 +0000
committerGitHub <[email protected]>2025-07-25 20:56:33 +0000
commiteb65034117f7722554a717de034e891ba1996e93 (patch)
treef279bee5ca55b0e447eabc70a11e96de307d76f3 /packages/core/src
parentde968877895a8ae5f0edb83a43b37fa190cc8ec9 (diff)
Load and use MCP server prompts as slash commands in the CLI (#4828)
Co-authored-by: harold <[email protected]> Co-authored-by: N. Taylor Mullen <[email protected]>
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/config/config.ts7
-rw-r--r--packages/core/src/index.ts3
-rw-r--r--packages/core/src/prompts/mcp-prompts.ts19
-rw-r--r--packages/core/src/prompts/prompt-registry.ts56
-rw-r--r--packages/core/src/tools/mcp-client.test.ts73
-rw-r--r--packages/core/src/tools/mcp-client.ts119
-rw-r--r--packages/core/src/tools/tool-registry.test.ts2
-rw-r--r--packages/core/src/tools/tool-registry.ts3
8 files changed, 272 insertions, 10 deletions
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index 96b6f2cb..7ccfdbc8 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -11,6 +11,7 @@ import {
ContentGeneratorConfig,
createContentGeneratorConfig,
} from '../core/contentGenerator.js';
+import { PromptRegistry } from '../prompts/prompt-registry.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { LSTool } from '../tools/ls.js';
import { ReadFileTool } from '../tools/read-file.js';
@@ -186,6 +187,7 @@ export interface ConfigParameters {
export class Config {
private toolRegistry!: ToolRegistry;
+ private promptRegistry!: PromptRegistry;
private readonly sessionId: string;
private contentGeneratorConfig!: ContentGeneratorConfig;
private readonly embeddingModel: string;
@@ -314,6 +316,7 @@ export class Config {
if (this.getCheckpointingEnabled()) {
await this.getGitService();
}
+ this.promptRegistry = new PromptRegistry();
this.toolRegistry = await this.createToolRegistry();
}
@@ -396,6 +399,10 @@ export class Config {
return Promise.resolve(this.toolRegistry);
}
+ getPromptRegistry(): PromptRegistry {
+ return this.promptRegistry;
+ }
+
getDebugMode(): boolean {
return this.debugMode;
}
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index 9d87ce32..829de544 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -49,6 +49,9 @@ export * from './ide/ideContext.js';
export * from './tools/tools.js';
export * from './tools/tool-registry.js';
+// Export prompt logic
+export * from './prompts/mcp-prompts.js';
+
// Export specific tool logic
export * from './tools/read-file.js';
export * from './tools/ls.js';
diff --git a/packages/core/src/prompts/mcp-prompts.ts b/packages/core/src/prompts/mcp-prompts.ts
new file mode 100644
index 00000000..7265a023
--- /dev/null
+++ b/packages/core/src/prompts/mcp-prompts.ts
@@ -0,0 +1,19 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { Config } from '../config/config.js';
+import { DiscoveredMCPPrompt } from '../tools/mcp-client.js';
+
+export function getMCPServerPrompts(
+ config: Config,
+ serverName: string,
+): DiscoveredMCPPrompt[] {
+ const promptRegistry = config.getPromptRegistry();
+ if (!promptRegistry) {
+ return [];
+ }
+ return promptRegistry.getPromptsByServer(serverName);
+}
diff --git a/packages/core/src/prompts/prompt-registry.ts b/packages/core/src/prompts/prompt-registry.ts
new file mode 100644
index 00000000..56699130
--- /dev/null
+++ b/packages/core/src/prompts/prompt-registry.ts
@@ -0,0 +1,56 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { DiscoveredMCPPrompt } from '../tools/mcp-client.js';
+
+export class PromptRegistry {
+ private prompts: Map<string, DiscoveredMCPPrompt> = new Map();
+
+ /**
+ * Registers a prompt definition.
+ * @param prompt - The prompt object containing schema and execution logic.
+ */
+ registerPrompt(prompt: DiscoveredMCPPrompt): void {
+ if (this.prompts.has(prompt.name)) {
+ const newName = `${prompt.serverName}_${prompt.name}`;
+ console.warn(
+ `Prompt with name "${prompt.name}" is already registered. Renaming to "${newName}".`,
+ );
+ this.prompts.set(newName, { ...prompt, name: newName });
+ } else {
+ this.prompts.set(prompt.name, prompt);
+ }
+ }
+
+ /**
+ * Returns an array of all registered and discovered prompt instances.
+ */
+ getAllPrompts(): DiscoveredMCPPrompt[] {
+ return Array.from(this.prompts.values()).sort((a, b) =>
+ a.name.localeCompare(b.name),
+ );
+ }
+
+ /**
+ * Get the definition of a specific prompt.
+ */
+ getPrompt(name: string): DiscoveredMCPPrompt | undefined {
+ return this.prompts.get(name);
+ }
+
+ /**
+ * Returns an array of prompts registered from a specific MCP server.
+ */
+ getPromptsByServer(serverName: string): DiscoveredMCPPrompt[] {
+ const serverPrompts: DiscoveredMCPPrompt[] = [];
+ for (const prompt of this.prompts.values()) {
+ if (prompt.serverName === serverName) {
+ serverPrompts.push(prompt);
+ }
+ }
+ return serverPrompts.sort((a, b) => a.name.localeCompare(b.name));
+ }
+}
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index 428c9d2d..4560982c 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -11,6 +11,7 @@ import {
createTransport,
isEnabled,
discoverTools,
+ discoverPrompts,
} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -18,6 +19,7 @@ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
import * as GenAiLib from '@google/genai';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { AuthProviderType } from '../config/config.js';
+import { PromptRegistry } from '../prompts/prompt-registry.js';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
@@ -50,6 +52,77 @@ describe('mcp-client', () => {
});
});
+ describe('discoverPrompts', () => {
+ const mockedPromptRegistry = {
+ registerPrompt: vi.fn(),
+ } as unknown as PromptRegistry;
+
+ it('should discover and log prompts', async () => {
+ const mockRequest = vi.fn().mockResolvedValue({
+ prompts: [
+ { name: 'prompt1', description: 'desc1' },
+ { name: 'prompt2' },
+ ],
+ });
+ const mockedClient = {
+ request: mockRequest,
+ } as unknown as ClientLib.Client;
+
+ await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
+
+ expect(mockRequest).toHaveBeenCalledWith(
+ { method: 'prompts/list', params: {} },
+ expect.anything(),
+ );
+ });
+
+ it('should do nothing if no prompts are discovered', async () => {
+ const mockRequest = vi.fn().mockResolvedValue({
+ prompts: [],
+ });
+ const mockedClient = {
+ request: mockRequest,
+ } as unknown as ClientLib.Client;
+
+ const consoleLogSpy = vi
+ .spyOn(console, 'debug')
+ .mockImplementation(() => {
+ // no-op
+ });
+
+ await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
+
+ expect(mockRequest).toHaveBeenCalledOnce();
+ expect(consoleLogSpy).not.toHaveBeenCalled();
+
+ consoleLogSpy.mockRestore();
+ });
+
+ it('should log an error if discovery fails', async () => {
+ const testError = new Error('test error');
+ testError.message = 'test error';
+ const mockRequest = vi.fn().mockRejectedValue(testError);
+ const mockedClient = {
+ request: mockRequest,
+ } as unknown as ClientLib.Client;
+
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {
+ // no-op
+ });
+
+ await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
+
+ expect(mockRequest).toHaveBeenCalledOnce();
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ `Error discovering prompts from test-server: ${testError.message}`,
+ );
+
+ consoleErrorSpy.mockRestore();
+ });
+ });
+
describe('appendMcpServerCommand', () => {
it('should do nothing if no MCP servers or command are configured', () => {
const out = populateMcpServerCommand({}, undefined);
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts
index c59b1592..d175af1f 100644
--- a/packages/core/src/tools/mcp-client.ts
+++ b/packages/core/src/tools/mcp-client.ts
@@ -15,12 +15,20 @@ import {
StreamableHTTPClientTransport,
StreamableHTTPClientTransportOptions,
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
+import {
+ Prompt,
+ ListPromptsResultSchema,
+ GetPromptResult,
+ GetPromptResultSchema,
+} from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote';
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
+
import { FunctionDeclaration, mcpToTool } from '@google/genai';
import { ToolRegistry } from './tool-registry.js';
+import { PromptRegistry } from '../prompts/prompt-registry.js';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { OAuthUtils } from '../mcp/oauth-utils.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
@@ -28,6 +36,11 @@ import { getErrorMessage } from '../utils/errors.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
+export type DiscoveredMCPPrompt = Prompt & {
+ serverName: string;
+ invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>;
+};
+
/**
* Enum representing the connection status of an MCP server
*/
@@ -55,7 +68,7 @@ export enum MCPDiscoveryState {
/**
* Map to track the status of each MCP server within the core package
*/
-const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();
+const serverStatuses: Map<string, MCPServerStatus> = new Map();
/**
* Track the overall MCP discovery state
@@ -104,7 +117,7 @@ function updateMCPServerStatus(
serverName: string,
status: MCPServerStatus,
): void {
- mcpServerStatusesInternal.set(serverName, status);
+ serverStatuses.set(serverName, status);
// Notify all listeners
for (const listener of statusChangeListeners) {
listener(serverName, status);
@@ -115,16 +128,14 @@ function updateMCPServerStatus(
* Get the current status of an MCP server
*/
export function getMCPServerStatus(serverName: string): MCPServerStatus {
- return (
- mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED
- );
+ return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED;
}
/**
* Get all MCP server statuses
*/
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
- return new Map(mcpServerStatusesInternal);
+ return new Map(serverStatuses);
}
/**
@@ -307,6 +318,7 @@ export async function discoverMcpTools(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
toolRegistry: ToolRegistry,
+ promptRegistry: PromptRegistry,
debugMode: boolean,
): Promise<void> {
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
@@ -319,6 +331,7 @@ export async function discoverMcpTools(
mcpServerName,
mcpServerConfig,
toolRegistry,
+ promptRegistry,
debugMode,
),
);
@@ -362,6 +375,7 @@ export async function connectAndDiscover(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
toolRegistry: ToolRegistry,
+ promptRegistry: PromptRegistry,
debugMode: boolean,
): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@@ -378,6 +392,7 @@ export async function connectAndDiscover(
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
};
+ await discoverPrompts(mcpServerName, mcpClient, promptRegistry);
const tools = await discoverTools(
mcpServerName,
@@ -393,7 +408,9 @@ export async function connectAndDiscover(
}
} catch (error) {
console.error(
- `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(error)}`,
+ `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(
+ error,
+ )}`,
);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
}
@@ -441,9 +458,6 @@ export async function discoverTools(
),
);
}
- if (discoveredTools.length === 0) {
- throw Error('No enabled tools found');
- }
return discoveredTools;
} catch (error) {
throw new Error(`Error discovering tools: ${error}`);
@@ -451,6 +465,91 @@ export async function discoverTools(
}
/**
+ * Discovers and logs prompts from a connected MCP client.
+ * It retrieves prompt declarations from the client and logs their names.
+ *
+ * @param mcpServerName The name of the MCP server.
+ * @param mcpClient The active MCP client instance.
+ */
+export async function discoverPrompts(
+ mcpServerName: string,
+ mcpClient: Client,
+ promptRegistry: PromptRegistry,
+): Promise<void> {
+ try {
+ const response = await mcpClient.request(
+ { method: 'prompts/list', params: {} },
+ ListPromptsResultSchema,
+ );
+
+ for (const prompt of response.prompts) {
+ promptRegistry.registerPrompt({
+ ...prompt,
+ serverName: mcpServerName,
+ invoke: (params: Record<string, unknown>) =>
+ invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
+ });
+ }
+ } catch (error) {
+ // It's okay if this fails, not all servers will have prompts.
+ // Don't log an error if the method is not found, which is a common case.
+ if (
+ error instanceof Error &&
+ !error.message?.includes('Method not found')
+ ) {
+ console.error(
+ `Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
+ error,
+ )}`,
+ );
+ }
+ }
+}
+
+/**
+ * Invokes a prompt on a connected MCP client.
+ *
+ * @param mcpServerName The name of the MCP server.
+ * @param mcpClient The active MCP client instance.
+ * @param promptName The name of the prompt to invoke.
+ * @param promptParams The parameters to pass to the prompt.
+ * @returns A promise that resolves to the result of the prompt invocation.
+ */
+export async function invokeMcpPrompt(
+ mcpServerName: string,
+ mcpClient: Client,
+ promptName: string,
+ promptParams: Record<string, unknown>,
+): Promise<GetPromptResult> {
+ try {
+ const response = await mcpClient.request(
+ {
+ method: 'prompts/get',
+ params: {
+ name: promptName,
+ arguments: promptParams,
+ },
+ },
+ GetPromptResultSchema,
+ );
+
+ return response;
+ } catch (error) {
+ if (
+ error instanceof Error &&
+ !error.message?.includes('Method not found')
+ ) {
+ console.error(
+ `Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage(
+ error,
+ )}`,
+ );
+ }
+ throw error;
+ }
+}
+
+/**
* Creates and connects an MCP client to a server based on the provided configuration.
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
* establishes a connection. It also applies a patch to handle request timeouts.
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts
index de355a98..b3fdd7a3 100644
--- a/packages/core/src/tools/tool-registry.test.ts
+++ b/packages/core/src/tools/tool-registry.test.ts
@@ -344,6 +344,7 @@ describe('ToolRegistry', () => {
mcpServerConfigVal,
undefined,
toolRegistry,
+ undefined,
false,
);
});
@@ -366,6 +367,7 @@ describe('ToolRegistry', () => {
mcpServerConfigVal,
undefined,
toolRegistry,
+ undefined,
false,
);
});
diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts
index b72ed9a5..57627ee0 100644
--- a/packages/core/src/tools/tool-registry.ts
+++ b/packages/core/src/tools/tool-registry.ts
@@ -170,6 +170,7 @@ export class ToolRegistry {
this.config.getMcpServers() ?? {},
this.config.getMcpServerCommand(),
this,
+ this.config.getPromptRegistry(),
this.config.getDebugMode(),
);
}
@@ -192,6 +193,7 @@ export class ToolRegistry {
this.config.getMcpServers() ?? {},
this.config.getMcpServerCommand(),
this,
+ this.config.getPromptRegistry(),
this.config.getDebugMode(),
);
}
@@ -215,6 +217,7 @@ export class ToolRegistry {
{ [serverName]: serverConfig },
undefined,
this,
+ this.config.getPromptRegistry(),
this.config.getDebugMode(),
);
}