diff options
Diffstat (limited to 'packages/core/src/tools/tool-registry.test.ts')
| -rw-r--r-- | packages/core/src/tools/tool-registry.test.ts | 50 |
1 files changed, 12 insertions, 38 deletions
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index 13dff08c..cccf011f 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -23,15 +23,17 @@ import { spawn } from 'node:child_process'; import fs from 'node:fs'; import { MockTool } from '../test-utils/tools.js'; -vi.mock('node:fs'); +import { McpClientManager } from './mcp-client-manager.js'; -// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory -const mockDiscoverMcpTools = vi.hoisted(() => vi.fn()); +vi.mock('node:fs'); // Mock ./mcp-client.js to control its behavior within tool-registry tests -vi.mock('./mcp-client.js', () => ({ - discoverMcpTools: mockDiscoverMcpTools, -})); +vi.mock('./mcp-client.js', async () => { + const originalModule = await vi.importActual('./mcp-client.js'); + return { + ...originalModule, + }; +}); // Mock node:child_process vi.mock('node:child_process', async () => { @@ -143,7 +145,6 @@ describe('ToolRegistry', () => { clear: vi.fn(), removePromptsByServer: vi.fn(), } as any); - mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined); }); afterEach(() => { @@ -311,30 +312,10 @@ describe('ToolRegistry', () => { }); it('should discover tools using MCP servers defined in getMcpServers', async () => { - mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); - vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined); - const mcpServerConfigVal = { - 'my-mcp-server': { - command: 'mcp-server-cmd', - args: ['--port', '1234'], - trust: true, - }, - }; - vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal); - - await toolRegistry.discoverAllTools(); - - expect(mockDiscoverMcpTools).toHaveBeenCalledWith( - mcpServerConfigVal, - undefined, - toolRegistry, - config.getPromptRegistry(), - false, - expect.any(Object), + const discoverSpy = vi.spyOn( + McpClientManager.prototype, + 'discoverAllMcpTools', ); - }); - - it('should discover tools using MCP servers defined in getMcpServers', async () => { mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined); const mcpServerConfigVal = { @@ -348,14 +329,7 @@ describe('ToolRegistry', () => { await toolRegistry.discoverAllTools(); - expect(mockDiscoverMcpTools).toHaveBeenCalledWith( - mcpServerConfigVal, - undefined, - toolRegistry, - config.getPromptRegistry(), - false, - expect.any(Object), - ); + expect(discoverSpy).toHaveBeenCalled(); }); }); }); |
