summaryrefslogtreecommitdiff
path: root/packages/core/src/tools/mcp-client.test.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src/tools/mcp-client.test.ts')
-rw-r--r--packages/core/src/tools/mcp-client.test.ts57
1 files changed, 56 insertions, 1 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index 1ccba76a..d37c6eae 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -13,6 +13,7 @@ import {
discoverTools,
discoverPrompts,
hasValidTypes,
+ connectToMcpServer,
} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -23,6 +24,8 @@ import { AuthProviderType } from '../config/config.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
+import { WorkspaceContext } from '../utils/workspaceContext.js';
+import { pathToFileURL } from 'node:url';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
@@ -276,6 +279,56 @@ describe('mcp-client', () => {
});
});
+ describe('connectToMcpServer', () => {
+ it('should register a roots/list handler', async () => {
+ const mockedClient = {
+ registerCapabilities: vi.fn(),
+ setRequestHandler: vi.fn(),
+ callTool: vi.fn(),
+ connect: vi.fn(),
+ };
+ vi.mocked(ClientLib.Client).mockReturnValue(
+ mockedClient as unknown as ClientLib.Client,
+ );
+ vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
+ {} as SdkClientStdioLib.StdioClientTransport,
+ );
+ const mockWorkspaceContext = {
+ getDirectories: vi
+ .fn()
+ .mockReturnValue(['/test/dir', '/another/project']),
+ } as unknown as WorkspaceContext;
+
+ await connectToMcpServer(
+ 'test-server',
+ {
+ command: 'test-command',
+ },
+ false,
+ mockWorkspaceContext,
+ );
+
+ expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
+ roots: {},
+ });
+ expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
+ const handler = mockedClient.setRequestHandler.mock.calls[0][1];
+ const roots = await handler();
+ expect(roots).toEqual({
+ roots: [
+ {
+ uri: pathToFileURL('/test/dir').toString(),
+ name: 'dir',
+ },
+ {
+ uri: pathToFileURL('/another/project').toString(),
+ name: 'project',
+ },
+ ],
+ });
+ });
+ });
+
describe('discoverPrompts', () => {
const mockedPromptRegistry = {
registerPrompt: vi.fn(),
@@ -486,7 +539,9 @@ describe('mcp-client', () => {
});
it('should connect via command', async () => {
- const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport);
+ const mockedTransport = vi
+ .spyOn(SdkClientStdioLib, 'StdioClientTransport')
+ .mockReturnValue({} as SdkClientStdioLib.StdioClientTransport);
await createTransport(
'test-server',