summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/cli/src/services/BuiltinCommandLoader.test.ts11
-rw-r--r--packages/cli/src/services/BuiltinCommandLoader.ts2
-rw-r--r--packages/cli/src/services/McpPromptLoader.ts231
-rw-r--r--packages/cli/src/ui/App.test.tsx1
-rw-r--r--packages/cli/src/ui/commands/mcpCommand.test.ts18
-rw-r--r--packages/cli/src/ui/commands/mcpCommand.ts65
-rw-r--r--packages/cli/src/ui/commands/types.ts1
-rw-r--r--packages/cli/src/ui/hooks/slashCommandProcessor.test.ts78
-rw-r--r--packages/cli/src/ui/hooks/slashCommandProcessor.ts155
-rw-r--r--packages/cli/src/ui/hooks/useCompletion.test.ts6
-rw-r--r--packages/cli/src/ui/hooks/useCompletion.ts11
-rw-r--r--packages/core/src/config/config.ts7
-rw-r--r--packages/core/src/index.ts3
-rw-r--r--packages/core/src/prompts/mcp-prompts.ts19
-rw-r--r--packages/core/src/prompts/prompt-registry.ts56
-rw-r--r--packages/core/src/tools/mcp-client.test.ts73
-rw-r--r--packages/core/src/tools/mcp-client.ts119
-rw-r--r--packages/core/src/tools/tool-registry.test.ts2
-rw-r--r--packages/core/src/tools/tool-registry.ts3
19 files changed, 761 insertions, 100 deletions
diff --git a/packages/cli/src/services/BuiltinCommandLoader.test.ts b/packages/cli/src/services/BuiltinCommandLoader.test.ts
index 0e64b1ac..cd449dd8 100644
--- a/packages/cli/src/services/BuiltinCommandLoader.test.ts
+++ b/packages/cli/src/services/BuiltinCommandLoader.test.ts
@@ -40,13 +40,19 @@ vi.mock('../ui/commands/extensionsCommand.js', () => ({
extensionsCommand: {},
}));
vi.mock('../ui/commands/helpCommand.js', () => ({ helpCommand: {} }));
-vi.mock('../ui/commands/mcpCommand.js', () => ({ mcpCommand: {} }));
vi.mock('../ui/commands/memoryCommand.js', () => ({ memoryCommand: {} }));
vi.mock('../ui/commands/privacyCommand.js', () => ({ privacyCommand: {} }));
vi.mock('../ui/commands/quitCommand.js', () => ({ quitCommand: {} }));
vi.mock('../ui/commands/statsCommand.js', () => ({ statsCommand: {} }));
vi.mock('../ui/commands/themeCommand.js', () => ({ themeCommand: {} }));
vi.mock('../ui/commands/toolsCommand.js', () => ({ toolsCommand: {} }));
+vi.mock('../ui/commands/mcpCommand.js', () => ({
+ mcpCommand: {
+ name: 'mcp',
+ description: 'MCP command',
+ kind: 'BUILT_IN',
+ },
+}));
describe('BuiltinCommandLoader', () => {
let mockConfig: Config;
@@ -114,5 +120,8 @@ describe('BuiltinCommandLoader', () => {
const ideCmd = commands.find((c) => c.name === 'ide');
expect(ideCmd).toBeDefined();
+
+ const mcpCmd = commands.find((c) => c.name === 'mcp');
+ expect(mcpCmd).toBeDefined();
});
});
diff --git a/packages/cli/src/services/BuiltinCommandLoader.ts b/packages/cli/src/services/BuiltinCommandLoader.ts
index 259c6013..58adf5cb 100644
--- a/packages/cli/src/services/BuiltinCommandLoader.ts
+++ b/packages/cli/src/services/BuiltinCommandLoader.ts
@@ -58,9 +58,9 @@ export class BuiltinCommandLoader implements ICommandLoader {
extensionsCommand,
helpCommand,
ideCommand(this.config),
- mcpCommand,
memoryCommand,
privacyCommand,
+ mcpCommand,
quitCommand,
restoreCommand(this.config),
statsCommand,
diff --git a/packages/cli/src/services/McpPromptLoader.ts b/packages/cli/src/services/McpPromptLoader.ts
new file mode 100644
index 00000000..e912fb3e
--- /dev/null
+++ b/packages/cli/src/services/McpPromptLoader.ts
@@ -0,0 +1,231 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import {
+ Config,
+ getErrorMessage,
+ getMCPServerPrompts,
+} from '@google/gemini-cli-core';
+import {
+ CommandContext,
+ CommandKind,
+ SlashCommand,
+ SlashCommandActionReturn,
+} from '../ui/commands/types.js';
+import { ICommandLoader } from './types.js';
+import { PromptArgument } from '@modelcontextprotocol/sdk/types.js';
+
+/**
+ * Discovers and loads executable slash commands from prompts exposed by
+ * Model-Context-Protocol (MCP) servers.
+ */
+export class McpPromptLoader implements ICommandLoader {
+ constructor(private readonly config: Config | null) {}
+
+ /**
+ * Loads all available prompts from all configured MCP servers and adapts
+ * them into executable SlashCommand objects.
+ *
+ * @param _signal An AbortSignal (unused for this synchronous loader).
+ * @returns A promise that resolves to an array of loaded SlashCommands.
+ */
+ loadCommands(_signal: AbortSignal): Promise<SlashCommand[]> {
+ const promptCommands: SlashCommand[] = [];
+ if (!this.config) {
+ return Promise.resolve([]);
+ }
+ const mcpServers = this.config.getMcpServers() || {};
+ for (const serverName in mcpServers) {
+ const prompts = getMCPServerPrompts(this.config, serverName) || [];
+ for (const prompt of prompts) {
+ const commandName = `${prompt.name}`;
+ const newPromptCommand: SlashCommand = {
+ name: commandName,
+ description: prompt.description || `Invoke prompt ${prompt.name}`,
+ kind: CommandKind.MCP_PROMPT,
+ subCommands: [
+ {
+ name: 'help',
+ description: 'Show help for this prompt',
+ kind: CommandKind.MCP_PROMPT,
+ action: async (): Promise<SlashCommandActionReturn> => {
+ if (!prompt.arguments || prompt.arguments.length === 0) {
+ return {
+ type: 'message',
+ messageType: 'info',
+ content: `Prompt "${prompt.name}" has no arguments.`,
+ };
+ }
+
+ let helpMessage = `Arguments for "${prompt.name}":\n\n`;
+ if (prompt.arguments && prompt.arguments.length > 0) {
+ helpMessage += `You can provide arguments by name (e.g., --argName="value") or by position.\n\n`;
+ helpMessage += `e.g., ${prompt.name} ${prompt.arguments?.map((_) => `"foo"`)} is equivalent to ${prompt.name} ${prompt.arguments?.map((arg) => `--${arg.name}="foo"`)}\n\n`;
+ }
+ for (const arg of prompt.arguments) {
+ helpMessage += ` --${arg.name}\n`;
+ if (arg.description) {
+ helpMessage += ` ${arg.description}\n`;
+ }
+ helpMessage += ` (required: ${
+ arg.required ? 'yes' : 'no'
+ })\n\n`;
+ }
+ return {
+ type: 'message',
+ messageType: 'info',
+ content: helpMessage,
+ };
+ },
+ },
+ ],
+ action: async (
+ context: CommandContext,
+ args: string,
+ ): Promise<SlashCommandActionReturn> => {
+ if (!this.config) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: 'Config not loaded.',
+ };
+ }
+
+ const promptInputs = this.parseArgs(args, prompt.arguments);
+ if (promptInputs instanceof Error) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: promptInputs.message,
+ };
+ }
+
+ try {
+ const mcpServers = this.config.getMcpServers() || {};
+ const mcpServerConfig = mcpServers[serverName];
+ if (!mcpServerConfig) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: `MCP server config not found for '${serverName}'.`,
+ };
+ }
+ const result = await prompt.invoke(promptInputs);
+
+ if (result.error) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: `Error invoking prompt: ${result.error}`,
+ };
+ }
+
+ if (!result.messages?.[0]?.content?.text) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content:
+ 'Received an empty or invalid prompt response from the server.',
+ };
+ }
+
+ return {
+ type: 'submit_prompt',
+ content: JSON.stringify(result.messages[0].content.text),
+ };
+ } catch (error) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: `Error: ${getErrorMessage(error)}`,
+ };
+ }
+ },
+ completion: async (_: CommandContext, partialArg: string) => {
+ if (!prompt || !prompt.arguments) {
+ return [];
+ }
+
+ const suggestions: string[] = [];
+ const usedArgNames = new Set(
+ (partialArg.match(/--([^=]+)/g) || []).map((s) => s.substring(2)),
+ );
+
+ for (const arg of prompt.arguments) {
+ if (!usedArgNames.has(arg.name)) {
+ suggestions.push(`--${arg.name}=""`);
+ }
+ }
+
+ return suggestions;
+ },
+ };
+ promptCommands.push(newPromptCommand);
+ }
+ }
+ return Promise.resolve(promptCommands);
+ }
+
+ private parseArgs(
+ userArgs: string,
+ promptArgs: PromptArgument[] | undefined,
+ ): Record<string, unknown> | Error {
+ const argValues: { [key: string]: string } = {};
+ const promptInputs: Record<string, unknown> = {};
+
+ // arg parsing: --key="value" or --key=value
+ const namedArgRegex = /--([^=]+)=(?:"((?:\\.|[^"\\])*)"|([^ ]*))/g;
+ let match;
+ const remainingArgs: string[] = [];
+ let lastIndex = 0;
+
+ while ((match = namedArgRegex.exec(userArgs)) !== null) {
+ const key = match[1];
+ const value = match[2] ?? match[3]; // Quoted or unquoted value
+ argValues[key] = value;
+ // Capture text between matches as potential positional args
+ if (match.index > lastIndex) {
+ remainingArgs.push(userArgs.substring(lastIndex, match.index).trim());
+ }
+ lastIndex = namedArgRegex.lastIndex;
+ }
+
+ // Capture any remaining text after the last named arg
+ if (lastIndex < userArgs.length) {
+ remainingArgs.push(userArgs.substring(lastIndex).trim());
+ }
+
+ const positionalArgs = remainingArgs.join(' ').split(/ +/);
+
+ if (!promptArgs) {
+ return promptInputs;
+ }
+ for (const arg of promptArgs) {
+ if (argValues[arg.name]) {
+ promptInputs[arg.name] = argValues[arg.name];
+ }
+ }
+
+ const unfilledArgs = promptArgs.filter(
+ (arg) => arg.required && !promptInputs[arg.name],
+ );
+
+ const missingArgs: string[] = [];
+ for (let i = 0; i < unfilledArgs.length; i++) {
+ if (positionalArgs.length > i && positionalArgs[i]) {
+ promptInputs[unfilledArgs[i].name] = positionalArgs[i];
+ } else {
+ missingArgs.push(unfilledArgs[i].name);
+ }
+ }
+
+ if (missingArgs.length > 0) {
+ const missingArgNames = missingArgs.map((name) => `--${name}`).join(', ');
+ return new Error(`Missing required argument(s): ${missingArgNames}`);
+ }
+ return promptInputs;
+ }
+}
diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx
index 56093562..903f4b66 100644
--- a/packages/cli/src/ui/App.test.tsx
+++ b/packages/cli/src/ui/App.test.tsx
@@ -125,6 +125,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
getToolCallCommand: vi.fn(() => opts.toolCallCommand),
getMcpServerCommand: vi.fn(() => opts.mcpServerCommand),
getMcpServers: vi.fn(() => opts.mcpServers),
+ getPromptRegistry: vi.fn(),
getExtensions: vi.fn(() => []),
getBlockedMcpServers: vi.fn(() => []),
getUserAgent: vi.fn(() => opts.userAgent || 'test-agent'),
diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts
index 2b8753a0..afa71ba5 100644
--- a/packages/cli/src/ui/commands/mcpCommand.test.ts
+++ b/packages/cli/src/ui/commands/mcpCommand.test.ts
@@ -71,6 +71,7 @@ describe('mcpCommand', () => {
getToolRegistry: ReturnType<typeof vi.fn>;
getMcpServers: ReturnType<typeof vi.fn>;
getBlockedMcpServers: ReturnType<typeof vi.fn>;
+ getPromptRegistry: ReturnType<typeof vi.fn>;
};
beforeEach(() => {
@@ -92,6 +93,10 @@ describe('mcpCommand', () => {
}),
getMcpServers: vi.fn().mockReturnValue({}),
getBlockedMcpServers: vi.fn().mockReturnValue([]),
+ getPromptRegistry: vi.fn().mockResolvedValue({
+ getAllPrompts: vi.fn().mockReturnValue([]),
+ getPromptsByServer: vi.fn().mockReturnValue([]),
+ }),
};
mockContext = createMockCommandContext({
@@ -223,7 +228,7 @@ describe('mcpCommand', () => {
// Server 2 - Connected
expect(message).toContain(
- '🟢 \u001b[1mserver2\u001b[0m - Ready (1 tools)',
+ '🟢 \u001b[1mserver2\u001b[0m - Ready (1 tool)',
);
expect(message).toContain('server2_tool1');
@@ -365,13 +370,13 @@ describe('mcpCommand', () => {
if (isMessageAction(result)) {
const message = result.content;
expect(message).toContain(
- '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tools)',
+ '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tool)',
);
expect(message).toContain('\u001b[36mserver1_tool1\u001b[0m');
expect(message).toContain(
'🔴 \u001b[1mserver2\u001b[0m - Disconnected (0 tools cached)',
);
- expect(message).toContain('No tools available');
+ expect(message).toContain('No tools or prompts available');
}
});
@@ -421,10 +426,10 @@ describe('mcpCommand', () => {
// Check server statuses
expect(message).toContain(
- '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tools)',
+ '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tool)',
);
expect(message).toContain(
- '🔄 \u001b[1mserver2\u001b[0m - Starting... (first startup may take longer) (tools will appear when ready)',
+ '🔄 \u001b[1mserver2\u001b[0m - Starting... (first startup may take longer) (tools and prompts will appear when ready)',
);
}
});
@@ -994,6 +999,9 @@ describe('mcpCommand', () => {
getBlockedMcpServers: vi.fn().mockReturnValue([]),
getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
+ getPromptRegistry: vi.fn().mockResolvedValue({
+ getPromptsByServer: vi.fn().mockReturnValue([]),
+ }),
},
},
});
diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts
index 5467b994..709053b6 100644
--- a/packages/cli/src/ui/commands/mcpCommand.ts
+++ b/packages/cli/src/ui/commands/mcpCommand.ts
@@ -12,6 +12,7 @@ import {
MessageActionReturn,
} from './types.js';
import {
+ DiscoveredMCPPrompt,
DiscoveredMCPTool,
getMCPDiscoveryState,
getMCPServerStatus,
@@ -101,6 +102,8 @@ const getMcpStatus = async (
(tool) =>
tool instanceof DiscoveredMCPTool && tool.serverName === serverName,
) as DiscoveredMCPTool[];
+ const promptRegistry = await config.getPromptRegistry();
+ const serverPrompts = promptRegistry.getPromptsByServer(serverName) || [];
const status = getMCPServerStatus(serverName);
@@ -160,9 +163,26 @@ const getMcpStatus = async (
// Add tool count with conditional messaging
if (status === MCPServerStatus.CONNECTED) {
- message += ` (${serverTools.length} tools)`;
+ const parts = [];
+ if (serverTools.length > 0) {
+ parts.push(
+ `${serverTools.length} ${serverTools.length === 1 ? 'tool' : 'tools'}`,
+ );
+ }
+ if (serverPrompts.length > 0) {
+ parts.push(
+ `${serverPrompts.length} ${
+ serverPrompts.length === 1 ? 'prompt' : 'prompts'
+ }`,
+ );
+ }
+ if (parts.length > 0) {
+ message += ` (${parts.join(', ')})`;
+ } else {
+ message += ` (0 tools)`;
+ }
} else if (status === MCPServerStatus.CONNECTING) {
- message += ` (tools will appear when ready)`;
+ message += ` (tools and prompts will appear when ready)`;
} else {
message += ` (${serverTools.length} tools cached)`;
}
@@ -186,6 +206,7 @@ const getMcpStatus = async (
message += RESET_COLOR;
if (serverTools.length > 0) {
+ message += ` ${COLOR_CYAN}Tools:${RESET_COLOR}\n`;
serverTools.forEach((tool) => {
if (showDescriptions && tool.description) {
// Format tool name in cyan using simple ANSI cyan color
@@ -222,12 +243,41 @@ const getMcpStatus = async (
}
}
});
- } else {
+ }
+ if (serverPrompts.length > 0) {
+ if (serverTools.length > 0) {
+ message += '\n';
+ }
+ message += ` ${COLOR_CYAN}Prompts:${RESET_COLOR}\n`;
+ serverPrompts.forEach((prompt: DiscoveredMCPPrompt) => {
+ if (showDescriptions && prompt.description) {
+ message += ` - ${COLOR_CYAN}${prompt.name}${RESET_COLOR}`;
+ const descLines = prompt.description.trim().split('\n');
+ if (descLines) {
+ message += ':\n';
+ for (const descLine of descLines) {
+ message += ` ${COLOR_GREEN}${descLine}${RESET_COLOR}\n`;
+ }
+ } else {
+ message += '\n';
+ }
+ } else {
+ message += ` - ${COLOR_CYAN}${prompt.name}${RESET_COLOR}\n`;
+ }
+ });
+ }
+
+ if (serverTools.length === 0 && serverPrompts.length === 0) {
+ message += ' No tools or prompts available\n';
+ } else if (serverTools.length === 0) {
message += ' No tools available';
if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) {
message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}`;
}
message += '\n';
+ } else if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) {
+ // This case is for when serverTools.length > 0
+ message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}\n`;
}
message += '\n';
}
@@ -328,11 +378,10 @@ const authCommand: SlashCommand = {
// Import dynamically to avoid circular dependencies
const { MCPOAuthProvider } = await import('@google/gemini-cli-core');
- // Create OAuth config for authentication (will be discovered automatically)
- const oauthConfig = server.oauth || {
- authorizationUrl: '', // Will be discovered automatically
- tokenUrl: '', // Will be discovered automatically
- };
+ let oauthConfig = server.oauth;
+ if (!oauthConfig) {
+ oauthConfig = { enabled: false };
+ }
// Pass the MCP server URL for OAuth discovery
const mcpServerUrl = server.httpUrl || server.url;
diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts
index 9a1088fd..1684677c 100644
--- a/packages/cli/src/ui/commands/types.ts
+++ b/packages/cli/src/ui/commands/types.ts
@@ -128,6 +128,7 @@ export type SlashCommandActionReturn =
export enum CommandKind {
BUILT_IN = 'built-in',
FILE = 'file',
+ MCP_PROMPT = 'mcp-prompt',
}
// The standardized contract for any command in the system.
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
index 84eeb033..d308af46 100644
--- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
+++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
@@ -28,6 +28,13 @@ vi.mock('../../services/FileCommandLoader.js', () => ({
})),
}));
+const mockMcpLoadCommands = vi.fn();
+vi.mock('../../services/McpPromptLoader.js', () => ({
+ McpPromptLoader: vi.fn().mockImplementation(() => ({
+ loadCommands: mockMcpLoadCommands,
+ })),
+}));
+
vi.mock('../contexts/SessionContext.js', () => ({
useSessionStats: vi.fn(() => ({ stats: {} })),
}));
@@ -41,6 +48,7 @@ import { LoadedSettings } from '../../config/settings.js';
import { MessageType } from '../types.js';
import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js';
import { FileCommandLoader } from '../../services/FileCommandLoader.js';
+import { McpPromptLoader } from '../../services/McpPromptLoader.js';
const createTestCommand = (
overrides: Partial<SlashCommand>,
@@ -75,14 +83,17 @@ describe('useSlashCommandProcessor', () => {
(vi.mocked(BuiltinCommandLoader) as Mock).mockClear();
mockBuiltinLoadCommands.mockResolvedValue([]);
mockFileLoadCommands.mockResolvedValue([]);
+ mockMcpLoadCommands.mockResolvedValue([]);
});
const setupProcessorHook = (
builtinCommands: SlashCommand[] = [],
fileCommands: SlashCommand[] = [],
+ mcpCommands: SlashCommand[] = [],
) => {
mockBuiltinLoadCommands.mockResolvedValue(Object.freeze(builtinCommands));
mockFileLoadCommands.mockResolvedValue(Object.freeze(fileCommands));
+ mockMcpLoadCommands.mockResolvedValue(Object.freeze(mcpCommands));
const { result } = renderHook(() =>
useSlashCommandProcessor(
@@ -111,6 +122,7 @@ describe('useSlashCommandProcessor', () => {
setupProcessorHook();
expect(BuiltinCommandLoader).toHaveBeenCalledWith(mockConfig);
expect(FileCommandLoader).toHaveBeenCalledWith(mockConfig);
+ expect(McpPromptLoader).toHaveBeenCalledWith(mockConfig);
});
it('should call loadCommands and populate state after mounting', async () => {
@@ -124,6 +136,7 @@ describe('useSlashCommandProcessor', () => {
expect(result.current.slashCommands[0]?.name).toBe('test');
expect(mockBuiltinLoadCommands).toHaveBeenCalledTimes(1);
expect(mockFileLoadCommands).toHaveBeenCalledTimes(1);
+ expect(mockMcpLoadCommands).toHaveBeenCalledTimes(1);
});
it('should provide an immutable array of commands to consumers', async () => {
@@ -369,6 +382,38 @@ describe('useSlashCommandProcessor', () => {
expect.any(Number),
);
});
+
+ it('should handle "submit_prompt" action returned from a mcp-based command', async () => {
+ const mcpCommand = createTestCommand(
+ {
+ name: 'mcpcmd',
+ description: 'A command from mcp',
+ action: async () => ({
+ type: 'submit_prompt',
+ content: 'The actual prompt from the mcp command.',
+ }),
+ },
+ CommandKind.MCP_PROMPT,
+ );
+
+ const result = setupProcessorHook([], [], [mcpCommand]);
+ await waitFor(() => expect(result.current.slashCommands).toHaveLength(1));
+
+ let actionResult;
+ await act(async () => {
+ actionResult = await result.current.handleSlashCommand('/mcpcmd');
+ });
+
+ expect(actionResult).toEqual({
+ type: 'submit_prompt',
+ content: 'The actual prompt from the mcp command.',
+ });
+
+ expect(mockAddItem).toHaveBeenCalledWith(
+ { type: MessageType.USER, text: '/mcpcmd' },
+ expect.any(Number),
+ );
+ });
});
describe('Command Parsing and Matching', () => {
@@ -441,6 +486,39 @@ describe('useSlashCommandProcessor', () => {
});
describe('Command Precedence', () => {
+ it('should override mcp-based commands with file-based commands of the same name', async () => {
+ const mcpAction = vi.fn();
+ const fileAction = vi.fn();
+
+ const mcpCommand = createTestCommand(
+ {
+ name: 'override',
+ description: 'mcp',
+ action: mcpAction,
+ },
+ CommandKind.MCP_PROMPT,
+ );
+ const fileCommand = createTestCommand(
+ { name: 'override', description: 'file', action: fileAction },
+ CommandKind.FILE,
+ );
+
+ const result = setupProcessorHook([], [fileCommand], [mcpCommand]);
+
+ await waitFor(() => {
+ // The service should only return one command with the name 'override'
+ expect(result.current.slashCommands).toHaveLength(1);
+ });
+
+ await act(async () => {
+ await result.current.handleSlashCommand('/override');
+ });
+
+ // Only the file-based command's action should be called.
+ expect(fileAction).toHaveBeenCalledTimes(1);
+ expect(mcpAction).not.toHaveBeenCalled();
+ });
+
it('should prioritize a command with a primary name over a command with a matching alias', async () => {
const quitAction = vi.fn();
const exitAction = vi.fn();
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts
index fa2b0b12..9e9dc21c 100644
--- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts
+++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts
@@ -23,6 +23,7 @@ import { type CommandContext, type SlashCommand } from '../commands/types.js';
import { CommandService } from '../../services/CommandService.js';
import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js';
import { FileCommandLoader } from '../../services/FileCommandLoader.js';
+import { McpPromptLoader } from '../../services/McpPromptLoader.js';
/**
* Hook to define and process slash commands (e.g., /help, /clear).
@@ -164,6 +165,7 @@ export const useSlashCommandProcessor = (
const controller = new AbortController();
const load = async () => {
const loaders = [
+ new McpPromptLoader(config),
new BuiltinCommandLoader(config),
new FileCommandLoader(config),
];
@@ -246,82 +248,95 @@ export const useSlashCommandProcessor = (
args,
},
};
- const result = await commandToExecute.action(
- fullCommandContext,
- args,
- );
+ try {
+ const result = await commandToExecute.action(
+ fullCommandContext,
+ args,
+ );
- if (result) {
- switch (result.type) {
- case 'tool':
- return {
- type: 'schedule_tool',
- toolName: result.toolName,
- toolArgs: result.toolArgs,
- };
- case 'message':
- addItem(
- {
- type:
- result.messageType === 'error'
- ? MessageType.ERROR
- : MessageType.INFO,
- text: result.content,
- },
- Date.now(),
- );
- return { type: 'handled' };
- case 'dialog':
- switch (result.dialog) {
- case 'help':
- setShowHelp(true);
- return { type: 'handled' };
- case 'auth':
- openAuthDialog();
- return { type: 'handled' };
- case 'theme':
- openThemeDialog();
- return { type: 'handled' };
- case 'editor':
- openEditorDialog();
- return { type: 'handled' };
- case 'privacy':
- openPrivacyNotice();
- return { type: 'handled' };
- default: {
- const unhandled: never = result.dialog;
- throw new Error(
- `Unhandled slash command result: ${unhandled}`,
- );
+ if (result) {
+ switch (result.type) {
+ case 'tool':
+ return {
+ type: 'schedule_tool',
+ toolName: result.toolName,
+ toolArgs: result.toolArgs,
+ };
+ case 'message':
+ addItem(
+ {
+ type:
+ result.messageType === 'error'
+ ? MessageType.ERROR
+ : MessageType.INFO,
+ text: result.content,
+ },
+ Date.now(),
+ );
+ return { type: 'handled' };
+ case 'dialog':
+ switch (result.dialog) {
+ case 'help':
+ setShowHelp(true);
+ return { type: 'handled' };
+ case 'auth':
+ openAuthDialog();
+ return { type: 'handled' };
+ case 'theme':
+ openThemeDialog();
+ return { type: 'handled' };
+ case 'editor':
+ openEditorDialog();
+ return { type: 'handled' };
+ case 'privacy':
+ openPrivacyNotice();
+ return { type: 'handled' };
+ default: {
+ const unhandled: never = result.dialog;
+ throw new Error(
+ `Unhandled slash command result: ${unhandled}`,
+ );
+ }
}
+ case 'load_history': {
+ await config
+ ?.getGeminiClient()
+ ?.setHistory(result.clientHistory);
+ fullCommandContext.ui.clear();
+ result.history.forEach((item, index) => {
+ fullCommandContext.ui.addItem(item, index);
+ });
+ return { type: 'handled' };
}
- case 'load_history': {
- await config
- ?.getGeminiClient()
- ?.setHistory(result.clientHistory);
- fullCommandContext.ui.clear();
- result.history.forEach((item, index) => {
- fullCommandContext.ui.addItem(item, index);
- });
- return { type: 'handled' };
- }
- case 'quit':
- setQuittingMessages(result.messages);
- setTimeout(() => {
- process.exit(0);
- }, 100);
- return { type: 'handled' };
+ case 'quit':
+ setQuittingMessages(result.messages);
+ setTimeout(() => {
+ process.exit(0);
+ }, 100);
+ return { type: 'handled' };
- case 'submit_prompt':
- return {
- type: 'submit_prompt',
- content: result.content,
- };
- default: {
- const unhandled: never = result;
- throw new Error(`Unhandled slash command result: ${unhandled}`);
+ case 'submit_prompt':
+ return {
+ type: 'submit_prompt',
+ content: result.content,
+ };
+ default: {
+ const unhandled: never = result;
+ throw new Error(
+ `Unhandled slash command result: ${unhandled}`,
+ );
+ }
}
}
+ } catch (e) {
+ addItem(
+ {
+ type: MessageType.ERROR,
+ text: e instanceof Error ? e.message : String(e),
+ },
+ Date.now(),
+ );
+ return { type: 'handled' };
}
return { type: 'handled' };
diff --git a/packages/cli/src/ui/hooks/useCompletion.test.ts b/packages/cli/src/ui/hooks/useCompletion.test.ts
index cd525435..da6a7ab3 100644
--- a/packages/cli/src/ui/hooks/useCompletion.test.ts
+++ b/packages/cli/src/ui/hooks/useCompletion.test.ts
@@ -1100,7 +1100,7 @@ describe('useCompletion', () => {
result.current.handleAutocomplete(0);
});
- expect(mockBuffer.setText).toHaveBeenCalledWith('/memory');
+ expect(mockBuffer.setText).toHaveBeenCalledWith('/memory ');
});
it('should append a sub-command when the parent is complete', () => {
@@ -1145,7 +1145,7 @@ describe('useCompletion', () => {
result.current.handleAutocomplete(1); // index 1 is 'add'
});
- expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add');
+ expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add ');
});
it('should complete a command with an alternative name', () => {
@@ -1190,7 +1190,7 @@ describe('useCompletion', () => {
result.current.handleAutocomplete(0);
});
- expect(mockBuffer.setText).toHaveBeenCalledWith('/help');
+ expect(mockBuffer.setText).toHaveBeenCalledWith('/help ');
});
it('should complete a file path', async () => {
diff --git a/packages/cli/src/ui/hooks/useCompletion.ts b/packages/cli/src/ui/hooks/useCompletion.ts
index dc45222d..10724c21 100644
--- a/packages/cli/src/ui/hooks/useCompletion.ts
+++ b/packages/cli/src/ui/hooks/useCompletion.ts
@@ -638,10 +638,17 @@ export function useCompletion(
// Determine the base path of the command.
// - If there's a trailing space, the whole command is the base.
// - If it's a known parent path, the whole command is the base.
+ // - If the last part is a complete argument, the whole command is the base.
// - Otherwise, the base is everything EXCEPT the last partial part.
+ const lastPart = parts.length > 0 ? parts[parts.length - 1] : '';
+ const isLastPartACompleteArg =
+ lastPart.startsWith('--') && lastPart.includes('=');
+
const basePath =
- hasTrailingSpace || isParentPath ? parts : parts.slice(0, -1);
- const newValue = `/${[...basePath, suggestion].join(' ')}`;
+ hasTrailingSpace || isParentPath || isLastPartACompleteArg
+ ? parts
+ : parts.slice(0, -1);
+ const newValue = `/${[...basePath, suggestion].join(' ')} `;
buffer.setText(newValue);
} else {
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index 96b6f2cb..7ccfdbc8 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -11,6 +11,7 @@ import {
ContentGeneratorConfig,
createContentGeneratorConfig,
} from '../core/contentGenerator.js';
+import { PromptRegistry } from '../prompts/prompt-registry.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { LSTool } from '../tools/ls.js';
import { ReadFileTool } from '../tools/read-file.js';
@@ -186,6 +187,7 @@ export interface ConfigParameters {
export class Config {
private toolRegistry!: ToolRegistry;
+ private promptRegistry!: PromptRegistry;
private readonly sessionId: string;
private contentGeneratorConfig!: ContentGeneratorConfig;
private readonly embeddingModel: string;
@@ -314,6 +316,7 @@ export class Config {
if (this.getCheckpointingEnabled()) {
await this.getGitService();
}
+ this.promptRegistry = new PromptRegistry();
this.toolRegistry = await this.createToolRegistry();
}
@@ -396,6 +399,10 @@ export class Config {
return Promise.resolve(this.toolRegistry);
}
+ getPromptRegistry(): PromptRegistry {
+ return this.promptRegistry;
+ }
+
getDebugMode(): boolean {
return this.debugMode;
}
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index 9d87ce32..829de544 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -49,6 +49,9 @@ export * from './ide/ideContext.js';
export * from './tools/tools.js';
export * from './tools/tool-registry.js';
+// Export prompt logic
+export * from './prompts/mcp-prompts.js';
+
// Export specific tool logic
export * from './tools/read-file.js';
export * from './tools/ls.js';
diff --git a/packages/core/src/prompts/mcp-prompts.ts b/packages/core/src/prompts/mcp-prompts.ts
new file mode 100644
index 00000000..7265a023
--- /dev/null
+++ b/packages/core/src/prompts/mcp-prompts.ts
@@ -0,0 +1,19 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { Config } from '../config/config.js';
+import { DiscoveredMCPPrompt } from '../tools/mcp-client.js';
+
+export function getMCPServerPrompts(
+ config: Config,
+ serverName: string,
+): DiscoveredMCPPrompt[] {
+ const promptRegistry = config.getPromptRegistry();
+ if (!promptRegistry) {
+ return [];
+ }
+ return promptRegistry.getPromptsByServer(serverName);
+}
diff --git a/packages/core/src/prompts/prompt-registry.ts b/packages/core/src/prompts/prompt-registry.ts
new file mode 100644
index 00000000..56699130
--- /dev/null
+++ b/packages/core/src/prompts/prompt-registry.ts
@@ -0,0 +1,56 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { DiscoveredMCPPrompt } from '../tools/mcp-client.js';
+
+export class PromptRegistry {
+ private prompts: Map<string, DiscoveredMCPPrompt> = new Map();
+
+ /**
+ * Registers a prompt definition.
+ * @param prompt - The prompt object containing schema and execution logic.
+ */
+ registerPrompt(prompt: DiscoveredMCPPrompt): void {
+ if (this.prompts.has(prompt.name)) {
+ const newName = `${prompt.serverName}_${prompt.name}`;
+ console.warn(
+ `Prompt with name "${prompt.name}" is already registered. Renaming to "${newName}".`,
+ );
+ this.prompts.set(newName, { ...prompt, name: newName });
+ } else {
+ this.prompts.set(prompt.name, prompt);
+ }
+ }
+
+ /**
+ * Returns an array of all registered and discovered prompt instances.
+ */
+ getAllPrompts(): DiscoveredMCPPrompt[] {
+ return Array.from(this.prompts.values()).sort((a, b) =>
+ a.name.localeCompare(b.name),
+ );
+ }
+
+ /**
+ * Get the definition of a specific prompt.
+ */
+ getPrompt(name: string): DiscoveredMCPPrompt | undefined {
+ return this.prompts.get(name);
+ }
+
+ /**
+ * Returns an array of prompts registered from a specific MCP server.
+ */
+ getPromptsByServer(serverName: string): DiscoveredMCPPrompt[] {
+ const serverPrompts: DiscoveredMCPPrompt[] = [];
+ for (const prompt of this.prompts.values()) {
+ if (prompt.serverName === serverName) {
+ serverPrompts.push(prompt);
+ }
+ }
+ return serverPrompts.sort((a, b) => a.name.localeCompare(b.name));
+ }
+}
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index 428c9d2d..4560982c 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -11,6 +11,7 @@ import {
createTransport,
isEnabled,
discoverTools,
+ discoverPrompts,
} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -18,6 +19,7 @@ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
import * as GenAiLib from '@google/genai';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { AuthProviderType } from '../config/config.js';
+import { PromptRegistry } from '../prompts/prompt-registry.js';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
@@ -50,6 +52,77 @@ describe('mcp-client', () => {
});
});
+ describe('discoverPrompts', () => {
+ const mockedPromptRegistry = {
+ registerPrompt: vi.fn(),
+ } as unknown as PromptRegistry;
+
+ it('should discover and log prompts', async () => {
+ const mockRequest = vi.fn().mockResolvedValue({
+ prompts: [
+ { name: 'prompt1', description: 'desc1' },
+ { name: 'prompt2' },
+ ],
+ });
+ const mockedClient = {
+ request: mockRequest,
+ } as unknown as ClientLib.Client;
+
+ await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
+
+ expect(mockRequest).toHaveBeenCalledWith(
+ { method: 'prompts/list', params: {} },
+ expect.anything(),
+ );
+ });
+
+ it('should do nothing if no prompts are discovered', async () => {
+ const mockRequest = vi.fn().mockResolvedValue({
+ prompts: [],
+ });
+ const mockedClient = {
+ request: mockRequest,
+ } as unknown as ClientLib.Client;
+
+ const consoleLogSpy = vi
+ .spyOn(console, 'debug')
+ .mockImplementation(() => {
+ // no-op
+ });
+
+ await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
+
+ expect(mockRequest).toHaveBeenCalledOnce();
+ expect(consoleLogSpy).not.toHaveBeenCalled();
+
+ consoleLogSpy.mockRestore();
+ });
+
+ it('should log an error if discovery fails', async () => {
+ const testError = new Error('test error');
+ testError.message = 'test error';
+ const mockRequest = vi.fn().mockRejectedValue(testError);
+ const mockedClient = {
+ request: mockRequest,
+ } as unknown as ClientLib.Client;
+
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {
+ // no-op
+ });
+
+ await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
+
+ expect(mockRequest).toHaveBeenCalledOnce();
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ `Error discovering prompts from test-server: ${testError.message}`,
+ );
+
+ consoleErrorSpy.mockRestore();
+ });
+ });
+
describe('appendMcpServerCommand', () => {
it('should do nothing if no MCP servers or command are configured', () => {
const out = populateMcpServerCommand({}, undefined);
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts
index c59b1592..d175af1f 100644
--- a/packages/core/src/tools/mcp-client.ts
+++ b/packages/core/src/tools/mcp-client.ts
@@ -15,12 +15,20 @@ import {
StreamableHTTPClientTransport,
StreamableHTTPClientTransportOptions,
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
+import {
+ Prompt,
+ ListPromptsResultSchema,
+ GetPromptResult,
+ GetPromptResultSchema,
+} from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote';
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
+
import { FunctionDeclaration, mcpToTool } from '@google/genai';
import { ToolRegistry } from './tool-registry.js';
+import { PromptRegistry } from '../prompts/prompt-registry.js';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { OAuthUtils } from '../mcp/oauth-utils.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
@@ -28,6 +36,11 @@ import { getErrorMessage } from '../utils/errors.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
+export type DiscoveredMCPPrompt = Prompt & {
+ serverName: string;
+ invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>;
+};
+
/**
* Enum representing the connection status of an MCP server
*/
@@ -55,7 +68,7 @@ export enum MCPDiscoveryState {
/**
* Map to track the status of each MCP server within the core package
*/
-const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();
+const serverStatuses: Map<string, MCPServerStatus> = new Map();
/**
* Track the overall MCP discovery state
@@ -104,7 +117,7 @@ function updateMCPServerStatus(
serverName: string,
status: MCPServerStatus,
): void {
- mcpServerStatusesInternal.set(serverName, status);
+ serverStatuses.set(serverName, status);
// Notify all listeners
for (const listener of statusChangeListeners) {
listener(serverName, status);
@@ -115,16 +128,14 @@ function updateMCPServerStatus(
* Get the current status of an MCP server
*/
export function getMCPServerStatus(serverName: string): MCPServerStatus {
- return (
- mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED
- );
+ return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED;
}
/**
* Get all MCP server statuses
*/
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
- return new Map(mcpServerStatusesInternal);
+ return new Map(serverStatuses);
}
/**
@@ -307,6 +318,7 @@ export async function discoverMcpTools(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
toolRegistry: ToolRegistry,
+ promptRegistry: PromptRegistry,
debugMode: boolean,
): Promise<void> {
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
@@ -319,6 +331,7 @@ export async function discoverMcpTools(
mcpServerName,
mcpServerConfig,
toolRegistry,
+ promptRegistry,
debugMode,
),
);
@@ -362,6 +375,7 @@ export async function connectAndDiscover(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
toolRegistry: ToolRegistry,
+ promptRegistry: PromptRegistry,
debugMode: boolean,
): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@@ -378,6 +392,7 @@ export async function connectAndDiscover(
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
};
+ await discoverPrompts(mcpServerName, mcpClient, promptRegistry);
const tools = await discoverTools(
mcpServerName,
@@ -393,7 +408,9 @@ export async function connectAndDiscover(
}
} catch (error) {
console.error(
- `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(error)}`,
+ `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(
+ error,
+ )}`,
);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
}
@@ -441,9 +458,6 @@ export async function discoverTools(
),
);
}
- if (discoveredTools.length === 0) {
- throw Error('No enabled tools found');
- }
return discoveredTools;
} catch (error) {
throw new Error(`Error discovering tools: ${error}`);
@@ -451,6 +465,91 @@ export async function discoverTools(
}
/**
+ * Discovers and logs prompts from a connected MCP client.
+ * It retrieves prompt declarations from the client and logs their names.
+ *
+ * @param mcpServerName The name of the MCP server.
+ * @param mcpClient The active MCP client instance.
+ */
+export async function discoverPrompts(
+ mcpServerName: string,
+ mcpClient: Client,
+ promptRegistry: PromptRegistry,
+): Promise<void> {
+ try {
+ const response = await mcpClient.request(
+ { method: 'prompts/list', params: {} },
+ ListPromptsResultSchema,
+ );
+
+ for (const prompt of response.prompts) {
+ promptRegistry.registerPrompt({
+ ...prompt,
+ serverName: mcpServerName,
+ invoke: (params: Record<string, unknown>) =>
+ invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
+ });
+ }
+ } catch (error) {
+ // It's okay if this fails, not all servers will have prompts.
+ // Don't log an error if the method is not found, which is a common case.
+ if (
+ error instanceof Error &&
+ !error.message?.includes('Method not found')
+ ) {
+ console.error(
+ `Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
+ error,
+ )}`,
+ );
+ }
+ }
+}
+
+/**
+ * Invokes a prompt on a connected MCP client.
+ *
+ * @param mcpServerName The name of the MCP server.
+ * @param mcpClient The active MCP client instance.
+ * @param promptName The name of the prompt to invoke.
+ * @param promptParams The parameters to pass to the prompt.
+ * @returns A promise that resolves to the result of the prompt invocation.
+ */
+export async function invokeMcpPrompt(
+ mcpServerName: string,
+ mcpClient: Client,
+ promptName: string,
+ promptParams: Record<string, unknown>,
+): Promise<GetPromptResult> {
+ try {
+ const response = await mcpClient.request(
+ {
+ method: 'prompts/get',
+ params: {
+ name: promptName,
+ arguments: promptParams,
+ },
+ },
+ GetPromptResultSchema,
+ );
+
+ return response;
+ } catch (error) {
+ if (
+ error instanceof Error &&
+ !error.message?.includes('Method not found')
+ ) {
+ console.error(
+ `Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage(
+ error,
+ )}`,
+ );
+ }
+ throw error;
+ }
+}
+
+/**
* Creates and connects an MCP client to a server based on the provided configuration.
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
* establishes a connection. It also applies a patch to handle request timeouts.
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts
index de355a98..b3fdd7a3 100644
--- a/packages/core/src/tools/tool-registry.test.ts
+++ b/packages/core/src/tools/tool-registry.test.ts
@@ -344,6 +344,7 @@ describe('ToolRegistry', () => {
mcpServerConfigVal,
undefined,
toolRegistry,
+ undefined,
false,
);
});
@@ -366,6 +367,7 @@ describe('ToolRegistry', () => {
mcpServerConfigVal,
undefined,
toolRegistry,
+ undefined,
false,
);
});
diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts
index b72ed9a5..57627ee0 100644
--- a/packages/core/src/tools/tool-registry.ts
+++ b/packages/core/src/tools/tool-registry.ts
@@ -170,6 +170,7 @@ export class ToolRegistry {
this.config.getMcpServers() ?? {},
this.config.getMcpServerCommand(),
this,
+ this.config.getPromptRegistry(),
this.config.getDebugMode(),
);
}
@@ -192,6 +193,7 @@ export class ToolRegistry {
this.config.getMcpServers() ?? {},
this.config.getMcpServerCommand(),
this,
+ this.config.getPromptRegistry(),
this.config.getDebugMode(),
);
}
@@ -215,6 +217,7 @@ export class ToolRegistry {
{ [serverName]: serverConfig },
undefined,
this,
+ this.config.getPromptRegistry(),
this.config.getDebugMode(),
);
}