diff options
| -rw-r--r-- | packages/cli/src/ui/commands/mcpCommand.test.ts | 9 | ||||
| -rw-r--r-- | packages/cli/src/ui/commands/mcpCommand.ts | 6 | ||||
| -rw-r--r-- | packages/cli/src/ui/commands/types.ts | 1 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.ts | 9 | ||||
| -rw-r--r-- | packages/core/src/tools/tool-registry.ts | 12 |
5 files changed, 36 insertions, 1 deletions
diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index ad04cb69..0f339665 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -881,9 +881,14 @@ describe('mcpCommand', () => { }), getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry), getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), + getPromptRegistry: vi.fn().mockResolvedValue({ + removePromptsByServer: vi.fn(), + }), }, }, }); + // Mock the reloadCommands function + context.ui.reloadCommands = vi.fn(); const { MCPOAuthProvider } = await import('@google/gemini-cli-core'); @@ -901,6 +906,7 @@ describe('mcpCommand', () => { 'test-server', ); expect(mockGeminiClient.setTools).toHaveBeenCalled(); + expect(context.ui.reloadCommands).toHaveBeenCalledTimes(1); expect(isMessageAction(result)).toBe(true); if (isMessageAction(result)) { @@ -985,6 +991,8 @@ describe('mcpCommand', () => { }, }, }); + // Mock the reloadCommands function, which is new logic. + context.ui.reloadCommands = vi.fn(); const refreshCommand = mcpCommand.subCommands?.find( (cmd) => cmd.name === 'refresh', @@ -1002,6 +1010,7 @@ describe('mcpCommand', () => { ); expect(mockToolRegistry.discoverMcpTools).toHaveBeenCalled(); expect(mockGeminiClient.setTools).toHaveBeenCalled(); + expect(context.ui.reloadCommands).toHaveBeenCalledTimes(1); expect(isMessageAction(result)).toBe(true); if (isMessageAction(result)) { diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index 11c71f1a..686102be 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -417,6 +417,9 @@ const authCommand: SlashCommand = { await geminiClient.setTools(); } + // Reload the slash commands to reflect the changes. + context.ui.reloadCommands(); + return { type: 'message', messageType: 'info', @@ -507,6 +510,9 @@ const refreshCommand: SlashCommand = { await geminiClient.setTools(); } + // Reload the slash commands to reflect the changes. + context.ui.reloadCommands(); + return getMcpStatus(context, false, false, false); }, }; diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts index d4f0b454..876409d0 100644 --- a/packages/cli/src/ui/commands/types.ts +++ b/packages/cli/src/ui/commands/types.ts @@ -61,6 +61,7 @@ export interface CommandContext { toggleCorgiMode: () => void; toggleVimEnabled: () => Promise<boolean>; setGeminiMdFileCount: (count: number) => void; + reloadCommands: () => void; }; // Session-specific data session: { diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index b4ce0d4d..32f55de2 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -57,6 +57,11 @@ export const useSlashCommandProcessor = ( ) => { const session = useSessionStats(); const [commands, setCommands] = useState<readonly SlashCommand[]>([]); + const [reloadTrigger, setReloadTrigger] = useState(0); + + const reloadCommands = useCallback(() => { + setReloadTrigger((v) => v + 1); + }, []); const [shellConfirmationRequest, setShellConfirmationRequest] = useState<null | { commands: string[]; @@ -172,6 +177,7 @@ export const useSlashCommandProcessor = ( toggleCorgiMode, toggleVimEnabled, setGeminiMdFileCount, + reloadCommands, }, session: { stats: session.stats, @@ -195,6 +201,7 @@ export const useSlashCommandProcessor = ( toggleVimEnabled, sessionShellAllowlist, setGeminiMdFileCount, + reloadCommands, ], ); @@ -220,7 +227,7 @@ export const useSlashCommandProcessor = ( return () => { controller.abort(); }; - }, [config, ideMode]); + }, [config, ideMode, reloadTrigger]); const handleSlashCommand = useCallback( async ( diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index 02f77727..b3625285 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -159,6 +159,18 @@ export class ToolRegistry { } /** + * Removes all tools from a specific MCP server. + * @param serverName The name of the server to remove tools from. + */ + removeMcpToolsByServer(serverName: string): void { + for (const [name, tool] of this.tools.entries()) { + if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) { + this.tools.delete(name); + } + } + } + + /** * Discovers tools from project (if available and configured). * Can be called multiple times to update discovered tools. * This will discover tools from the command line and from MCP servers. |
