diff options
Diffstat (limited to 'packages/core/src/tools/mcp-client.test.ts')
| -rw-r--r-- | packages/core/src/tools/mcp-client.test.ts | 57 |
1 files changed, 56 insertions, 1 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 1ccba76a..d37c6eae 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -13,6 +13,7 @@ import { discoverTools, discoverPrompts, hasValidTypes, + connectToMcpServer, } from './mcp-client.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; @@ -23,6 +24,8 @@ import { AuthProviderType } from '../config/config.js'; import { PromptRegistry } from '../prompts/prompt-registry.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; +import { WorkspaceContext } from '../utils/workspaceContext.js'; +import { pathToFileURL } from 'node:url'; vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js'); @@ -276,6 +279,56 @@ describe('mcp-client', () => { }); }); + describe('connectToMcpServer', () => { + it('should register a roots/list handler', async () => { + const mockedClient = { + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + callTool: vi.fn(), + connect: vi.fn(), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + const mockWorkspaceContext = { + getDirectories: vi + .fn() + .mockReturnValue(['/test/dir', '/another/project']), + } as unknown as WorkspaceContext; + + await connectToMcpServer( + 'test-server', + { + command: 'test-command', + }, + false, + mockWorkspaceContext, + ); + + expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({ + roots: {}, + }); + expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce(); + const handler = mockedClient.setRequestHandler.mock.calls[0][1]; + const roots = await handler(); + expect(roots).toEqual({ + roots: [ + { + uri: pathToFileURL('/test/dir').toString(), + name: 'dir', + }, + { + uri: pathToFileURL('/another/project').toString(), + name: 'project', + }, + ], + }); + }); + }); + describe('discoverPrompts', () => { const mockedPromptRegistry = { registerPrompt: vi.fn(), @@ -486,7 +539,9 @@ describe('mcp-client', () => { }); it('should connect via command', async () => { - const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport); + const mockedTransport = vi + .spyOn(SdkClientStdioLib, 'StdioClientTransport') + .mockReturnValue({} as SdkClientStdioLib.StdioClientTransport); await createTransport( 'test-server', |
