summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/cli/src/ui/commands/mcpCommand.test.ts80
-rw-r--r--packages/cli/src/ui/commands/mcpCommand.ts47
-rw-r--r--packages/core/src/config/config.test.ts2
-rw-r--r--packages/core/src/config/config.ts2
-rw-r--r--packages/core/src/tools/tool-registry.test.ts6
-rw-r--r--packages/core/src/tools/tool-registry.ts25
6 files changed, 155 insertions, 7 deletions
diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts
index e52cb9df..2b8753a0 100644
--- a/packages/cli/src/ui/commands/mcpCommand.test.ts
+++ b/packages/cli/src/ui/commands/mcpCommand.test.ts
@@ -976,4 +976,84 @@ describe('mcpCommand', () => {
}
});
});
+
+ describe('refresh subcommand', () => {
+ it('should refresh the list of tools and display the status', async () => {
+ const mockToolRegistry = {
+ discoverMcpTools: vi.fn(),
+ getAllTools: vi.fn().mockReturnValue([]),
+ };
+ const mockGeminiClient = {
+ setTools: vi.fn(),
+ };
+
+ const context = createMockCommandContext({
+ services: {
+ config: {
+ getMcpServers: vi.fn().mockReturnValue({ server1: {} }),
+ getBlockedMcpServers: vi.fn().mockReturnValue([]),
+ getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
+ getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
+ },
+ },
+ });
+
+ const refreshCommand = mcpCommand.subCommands?.find(
+ (cmd) => cmd.name === 'refresh',
+ );
+ expect(refreshCommand).toBeDefined();
+
+ const result = await refreshCommand!.action!(context, '');
+
+ expect(context.ui.addItem).toHaveBeenCalledWith(
+ {
+ type: 'info',
+ text: 'Refreshing MCP servers and tools...',
+ },
+ expect.any(Number),
+ );
+ expect(mockToolRegistry.discoverMcpTools).toHaveBeenCalled();
+ expect(mockGeminiClient.setTools).toHaveBeenCalled();
+
+ expect(isMessageAction(result)).toBe(true);
+ if (isMessageAction(result)) {
+ expect(result.messageType).toBe('info');
+ expect(result.content).toContain('Configured MCP servers:');
+ }
+ });
+
+ it('should show an error if config is not available', async () => {
+ const contextWithoutConfig = createMockCommandContext({
+ services: {
+ config: null,
+ },
+ });
+
+ const refreshCommand = mcpCommand.subCommands?.find(
+ (cmd) => cmd.name === 'refresh',
+ );
+ const result = await refreshCommand!.action!(contextWithoutConfig, '');
+
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content: 'Config not loaded.',
+ });
+ });
+
+ it('should show an error if tool registry is not available', async () => {
+ mockConfig.getToolRegistry = vi.fn().mockResolvedValue(undefined);
+
+ const refreshCommand = mcpCommand.subCommands?.find(
+ (cmd) => cmd.name === 'refresh',
+ );
+ const result = await refreshCommand!.action!(mockContext, '');
+
+ expect(result).toEqual({
+ type: 'message',
+ messageType: 'error',
+ content: 'Could not retrieve tool registry.',
+ });
+ });
+ });
});
diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts
index c33a25d1..5467b994 100644
--- a/packages/cli/src/ui/commands/mcpCommand.ts
+++ b/packages/cli/src/ui/commands/mcpCommand.ts
@@ -417,12 +417,57 @@ const listCommand: SlashCommand = {
},
};
+const refreshCommand: SlashCommand = {
+ name: 'refresh',
+ description: 'Refresh the list of MCP servers and tools',
+ kind: CommandKind.BUILT_IN,
+ action: async (
+ context: CommandContext,
+ ): Promise<SlashCommandActionReturn> => {
+ const { config } = context.services;
+ if (!config) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: 'Config not loaded.',
+ };
+ }
+
+ const toolRegistry = await config.getToolRegistry();
+ if (!toolRegistry) {
+ return {
+ type: 'message',
+ messageType: 'error',
+ content: 'Could not retrieve tool registry.',
+ };
+ }
+
+ context.ui.addItem(
+ {
+ type: 'info',
+ text: 'Refreshing MCP servers and tools...',
+ },
+ Date.now(),
+ );
+
+ await toolRegistry.discoverMcpTools();
+
+ // Update the client with the new tools
+ const geminiClient = config.getGeminiClient();
+ if (geminiClient) {
+ await geminiClient.setTools();
+ }
+
+ return getMcpStatus(context, false, false, false);
+ },
+};
+
export const mcpCommand: SlashCommand = {
name: 'mcp',
description:
'list configured MCP servers and tools, or authenticate with OAuth-enabled servers',
kind: CommandKind.BUILT_IN,
- subCommands: [listCommand, authCommand],
+ subCommands: [listCommand, authCommand, refreshCommand],
// Default action when no subcommand is provided
action: async (context: CommandContext, args: string) =>
// If no subcommand, run the list command
diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts
index 9fec505f..3f0b3db5 100644
--- a/packages/core/src/config/config.test.ts
+++ b/packages/core/src/config/config.test.ts
@@ -23,7 +23,7 @@ import { GitService } from '../services/gitService.js';
vi.mock('../tools/tool-registry', () => {
const ToolRegistryMock = vi.fn();
ToolRegistryMock.prototype.registerTool = vi.fn();
- ToolRegistryMock.prototype.discoverTools = vi.fn();
+ ToolRegistryMock.prototype.discoverAllTools = vi.fn();
ToolRegistryMock.prototype.getAllTools = vi.fn(() => []); // Mock methods if needed
ToolRegistryMock.prototype.getTool = vi.fn();
ToolRegistryMock.prototype.getFunctionDeclarations = vi.fn(() => []);
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index 231bbcd5..485a56c4 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -630,7 +630,7 @@ export class Config {
registerCoreTool(MemoryTool);
registerCoreTool(WebSearchTool, this);
- await registry.discoverTools();
+ await registry.discoverAllTools();
return registry;
}
}
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts
index ab337252..de355a98 100644
--- a/packages/core/src/tools/tool-registry.test.ts
+++ b/packages/core/src/tools/tool-registry.test.ts
@@ -312,7 +312,7 @@ describe('ToolRegistry', () => {
return mockChildProcess as any;
});
- await toolRegistry.discoverTools();
+ await toolRegistry.discoverAllTools();
const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
expect(discoveredTool).toBeDefined();
@@ -338,7 +338,7 @@ describe('ToolRegistry', () => {
};
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
- await toolRegistry.discoverTools();
+ await toolRegistry.discoverAllTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
mcpServerConfigVal,
@@ -360,7 +360,7 @@ describe('ToolRegistry', () => {
};
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
- await toolRegistry.discoverTools();
+ await toolRegistry.discoverAllTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
mcpServerConfigVal,
diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts
index a6742c06..b72ed9a5 100644
--- a/packages/core/src/tools/tool-registry.ts
+++ b/packages/core/src/tools/tool-registry.ts
@@ -153,8 +153,9 @@ export class ToolRegistry {
/**
* Discovers tools from project (if available and configured).
* Can be called multiple times to update discovered tools.
+ * This will discover tools from the command line and from MCP servers.
*/
- async discoverTools(): Promise<void> {
+ async discoverAllTools(): Promise<void> {
// remove any previously discovered tools
for (const tool of this.tools.values()) {
if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
@@ -174,6 +175,28 @@ export class ToolRegistry {
}
/**
+ * Discovers tools from project (if available and configured).
+ * Can be called multiple times to update discovered tools.
+ * This will NOT discover tools from the command line, only from MCP servers.
+ */
+ async discoverMcpTools(): Promise<void> {
+ // remove any previously discovered tools
+ for (const tool of this.tools.values()) {
+ if (tool instanceof DiscoveredMCPTool) {
+ this.tools.delete(tool.name);
+ }
+ }
+
+ // discover tools using MCP servers, if configured
+ await discoverMcpTools(
+ this.config.getMcpServers() ?? {},
+ this.config.getMcpServerCommand(),
+ this,
+ this.config.getDebugMode(),
+ );
+ }
+
+ /**
* Discover or re-discover tools for a single MCP server.
* @param serverName - The name of the server to discover tools from.
*/