summaryrefslogtreecommitdiff
path: root/packages/core/src/tools/mcp-client.test.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src/tools/mcp-client.test.ts')
-rw-r--r--packages/core/src/tools/mcp-client.test.ts465
1 files changed, 82 insertions, 383 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index 3467ad95..b8f61856 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -4,16 +4,14 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest';
+import { afterEach, describe, expect, it, vi } from 'vitest';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import {
populateMcpServerCommand,
createTransport,
isEnabled,
- discoverTools,
- discoverPrompts,
hasValidTypes,
- connectToMcpServer,
+ McpClient,
} from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -22,26 +20,36 @@ import * as GenAiLib from '@google/genai';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { AuthProviderType } from '../config/config.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
-
-import { DiscoveredMCPTool } from './mcp-tool.js';
+import { ToolRegistry } from './tool-registry.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
-import { pathToFileURL } from 'node:url';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js');
vi.mock('@google/genai');
vi.mock('../mcp/oauth-provider.js');
vi.mock('../mcp/oauth-token-storage.js');
-vi.mock('./mcp-tool.js');
describe('mcp-client', () => {
afterEach(() => {
vi.restoreAllMocks();
});
- describe('discoverTools', () => {
+ describe('McpClient', () => {
it('should discover tools', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
+ const mockedClient = {
+ connect: vi.fn(),
+ discover: vi.fn(),
+ disconnect: vi.fn(),
+ getStatus: vi.fn(),
+ registerCapabilities: vi.fn(),
+ setRequestHandler: vi.fn(),
+ };
+ vi.mocked(ClientLib.Client).mockReturnValue(
+ mockedClient as unknown as ClientLib.Client,
+ );
+ vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
+ {} as SdkClientStdioLib.StdioClientTransport,
+ );
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () => ({
functionDeclarations: [
@@ -51,62 +59,43 @@ describe('mcp-client', () => {
],
}),
} as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(1);
- expect(mockedMcpToTool).toHaveBeenCalledOnce();
- });
-
- it('should log an error if there is an error discovering a tool', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
- const consoleErrorSpy = vi
- .spyOn(console, 'error')
- .mockImplementation(() => {});
-
- const testError = new Error('Invalid tool name');
- vi.mocked(DiscoveredMCPTool).mockImplementation(
- (
- _mcpCallableTool: GenAiLib.CallableTool,
- _serverName: string,
- name: string,
- ) => {
- if (name === 'invalid tool name') {
- throw testError;
- }
- return { name: 'validTool' } as DiscoveredMCPTool;
+ const mockedToolRegistry = {
+ registerTool: vi.fn(),
+ } as unknown as ToolRegistry;
+ const client = new McpClient(
+ 'test-server',
+ {
+ command: 'test-command',
},
+ mockedToolRegistry,
+ {} as PromptRegistry,
+ {} as WorkspaceContext,
+ false,
);
-
- vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'validTool',
- },
- {
- name: 'invalid tool name', // this will fail validation
- },
- ],
- }),
- } as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(1);
- expect(tools[0].name).toBe('validTool');
- expect(consoleErrorSpy).toHaveBeenCalledOnce();
- expect(consoleErrorSpy).toHaveBeenCalledWith(
- `Error discovering tool: 'invalid tool name' from MCP server 'test-server': ${testError.message}`,
- );
+ await client.connect();
+ await client.discover();
+ expect(mockedMcpToTool).toHaveBeenCalledOnce();
});
it('should skip tools if a parameter is missing a type', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
+ const mockedClient = {
+ connect: vi.fn(),
+ discover: vi.fn(),
+ disconnect: vi.fn(),
+ getStatus: vi.fn(),
+ registerCapabilities: vi.fn(),
+ setRequestHandler: vi.fn(),
+ tool: vi.fn(),
+ };
+ vi.mocked(ClientLib.Client).mockReturnValue(
+ mockedClient as unknown as ClientLib.Client,
+ );
+ vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
+ {} as SdkClientStdioLib.StdioClientTransport,
+ );
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
@@ -132,89 +121,22 @@ describe('mcp-client', () => {
],
}),
} as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(1);
- expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
- expect(consoleWarnSpy).toHaveBeenCalledOnce();
- expect(consoleWarnSpy).toHaveBeenCalledWith(
- `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
- `missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
- );
- consoleWarnSpy.mockRestore();
- });
-
- it('should skip tools if a nested parameter is missing a type', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
- const consoleWarnSpy = vi
- .spyOn(console, 'warn')
- .mockImplementation(() => {});
- vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'invalidTool',
- parametersJsonSchema: {
- type: 'object',
- properties: {
- param1: {
- type: 'object',
- properties: {
- nestedParam: {
- description: 'a nested param with no type',
- },
- },
- },
- },
- },
- },
- ],
- }),
- } as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(0);
- expect(consoleWarnSpy).toHaveBeenCalledOnce();
- expect(consoleWarnSpy).toHaveBeenCalledWith(
- `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
- `missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
+ const mockedToolRegistry = {
+ registerTool: vi.fn(),
+ } as unknown as ToolRegistry;
+ const client = new McpClient(
+ 'test-server',
+ {
+ command: 'test-command',
+ },
+ mockedToolRegistry,
+ {} as PromptRegistry,
+ {} as WorkspaceContext,
+ false,
);
- consoleWarnSpy.mockRestore();
- });
-
- it('should skip tool if an array item is missing a type', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
- const consoleWarnSpy = vi
- .spyOn(console, 'warn')
- .mockImplementation(() => {});
- vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'invalidTool',
- parametersJsonSchema: {
- type: 'object',
- properties: {
- param1: {
- type: 'array',
- items: {
- description: 'an array item with no type',
- },
- },
- },
- },
- },
- ],
- }),
- } as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(0);
+ await client.connect();
+ await client.discover();
+ expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
@@ -223,109 +145,19 @@ describe('mcp-client', () => {
consoleWarnSpy.mockRestore();
});
- it('should discover tool with no properties in schema', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
- const consoleWarnSpy = vi
- .spyOn(console, 'warn')
- .mockImplementation(() => {});
- vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'validTool',
- parametersJsonSchema: {
- type: 'object',
- },
- },
- ],
- }),
- } as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(1);
- expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
- expect(consoleWarnSpy).not.toHaveBeenCalled();
- consoleWarnSpy.mockRestore();
- });
-
- it('should discover tool with empty properties object in schema', async () => {
- const mockedClient = {} as unknown as ClientLib.Client;
- const consoleWarnSpy = vi
- .spyOn(console, 'warn')
+ it('should handle errors when discovering prompts', async () => {
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
.mockImplementation(() => {});
- vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
- tool: () =>
- Promise.resolve({
- functionDeclarations: [
- {
- name: 'validTool',
- parametersJsonSchema: {
- type: 'object',
- properties: {},
- },
- },
- ],
- }),
- } as unknown as GenAiLib.CallableTool);
-
- const tools = await discoverTools('test-server', {}, mockedClient);
-
- expect(tools.length).toBe(1);
- expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
- expect(consoleWarnSpy).not.toHaveBeenCalled();
- consoleWarnSpy.mockRestore();
- });
- });
-
- describe('connectToMcpServer', () => {
- it('should send a notification when directories change', async () => {
const mockedClient = {
- registerCapabilities: vi.fn(),
- setRequestHandler: vi.fn(),
- notification: vi.fn(),
- callTool: vi.fn(),
connect: vi.fn(),
- };
- vi.mocked(ClientLib.Client).mockReturnValue(
- mockedClient as unknown as ClientLib.Client,
- );
- vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
- {} as SdkClientStdioLib.StdioClientTransport,
- );
- let onDirectoriesChangedCallback: () => void = () => {};
- const mockWorkspaceContext = {
- getDirectories: vi
- .fn()
- .mockReturnValue(['/test/dir', '/another/project']),
- onDirectoriesChanged: vi.fn().mockImplementation((callback) => {
- onDirectoriesChangedCallback = callback;
- }),
- } as unknown as WorkspaceContext;
-
- await connectToMcpServer(
- 'test-server',
- {
- command: 'test-command',
- },
- false,
- mockWorkspaceContext,
- );
-
- onDirectoriesChangedCallback();
-
- expect(mockedClient.notification).toHaveBeenCalledWith({
- method: 'notifications/roots/list_changed',
- });
- });
-
- it('should register a roots/list handler', async () => {
- const mockedClient = {
+ discover: vi.fn(),
+ disconnect: vi.fn(),
+ getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
- callTool: vi.fn(),
- connect: vi.fn(),
+ getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
+ request: vi.fn().mockRejectedValue(new Error('Test error')),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
@@ -333,151 +165,29 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
- const mockWorkspaceContext = {
- getDirectories: vi
- .fn()
- .mockReturnValue(['/test/dir', '/another/project']),
- onDirectoriesChanged: vi.fn(),
- } as unknown as WorkspaceContext;
-
- await connectToMcpServer(
+ vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
+ tool: () => Promise.resolve({ functionDeclarations: [] }),
+ } as unknown as GenAiLib.CallableTool);
+ const client = new McpClient(
'test-server',
{
command: 'test-command',
},
+ {} as ToolRegistry,
+ {} as PromptRegistry,
+ {} as WorkspaceContext,
false,
- mockWorkspaceContext,
);
-
- expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
- roots: {
- listChanged: true,
- },
- });
- expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
- const handler = mockedClient.setRequestHandler.mock.calls[0][1];
- const roots = await handler();
- expect(roots).toEqual({
- roots: [
- {
- uri: pathToFileURL('/test/dir').toString(),
- name: 'dir',
- },
- {
- uri: pathToFileURL('/another/project').toString(),
- name: 'project',
- },
- ],
- });
- });
- });
-
- describe('discoverPrompts', () => {
- const mockedPromptRegistry = {
- registerPrompt: vi.fn(),
- } as unknown as PromptRegistry;
-
- it('should discover and log prompts', async () => {
- const mockRequest = vi.fn().mockResolvedValue({
- prompts: [
- { name: 'prompt1', description: 'desc1' },
- { name: 'prompt2' },
- ],
- });
- const mockGetServerCapabilities = vi.fn().mockReturnValue({
- prompts: {},
- });
- const mockedClient = {
- getServerCapabilities: mockGetServerCapabilities,
- request: mockRequest,
- } as unknown as ClientLib.Client;
-
- await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
-
- expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
- expect(mockRequest).toHaveBeenCalledWith(
- { method: 'prompts/list', params: {} },
- expect.anything(),
+ await client.connect();
+ await expect(client.discover()).rejects.toThrow(
+ 'No prompts or tools found on the server.',
);
- });
-
- it('should do nothing if no prompts are discovered', async () => {
- const mockRequest = vi.fn().mockResolvedValue({
- prompts: [],
- });
- const mockGetServerCapabilities = vi.fn().mockReturnValue({
- prompts: {},
- });
-
- const mockedClient = {
- getServerCapabilities: mockGetServerCapabilities,
- request: mockRequest,
- } as unknown as ClientLib.Client;
-
- const consoleLogSpy = vi
- .spyOn(console, 'debug')
- .mockImplementation(() => {});
-
- await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
-
- expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
- expect(mockRequest).toHaveBeenCalledOnce();
- expect(consoleLogSpy).not.toHaveBeenCalled();
-
- consoleLogSpy.mockRestore();
- });
-
- it('should do nothing if the server has no prompt support', async () => {
- const mockRequest = vi.fn().mockResolvedValue({
- prompts: [],
- });
- const mockGetServerCapabilities = vi.fn().mockReturnValue({});
-
- const mockedClient = {
- getServerCapabilities: mockGetServerCapabilities,
- request: mockRequest,
- } as unknown as ClientLib.Client;
-
- const consoleLogSpy = vi
- .spyOn(console, 'debug')
- .mockImplementation(() => {});
-
- await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
-
- expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
- expect(mockRequest).not.toHaveBeenCalled();
- expect(consoleLogSpy).not.toHaveBeenCalled();
-
- consoleLogSpy.mockRestore();
- });
-
- it('should log an error if discovery fails', async () => {
- const testError = new Error('test error');
- testError.message = 'test error';
- const mockRequest = vi.fn().mockRejectedValue(testError);
- const mockGetServerCapabilities = vi.fn().mockReturnValue({
- prompts: {},
- });
- const mockedClient = {
- getServerCapabilities: mockGetServerCapabilities,
- request: mockRequest,
- } as unknown as ClientLib.Client;
-
- const consoleErrorSpy = vi
- .spyOn(console, 'error')
- .mockImplementation(() => {});
-
- await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
-
- expect(mockRequest).toHaveBeenCalledOnce();
expect(consoleErrorSpy).toHaveBeenCalledWith(
- `Error discovering prompts from test-server: ${testError.message}`,
+ `Error discovering prompts from test-server: Test error`,
);
-
consoleErrorSpy.mockRestore();
});
});
-
describe('appendMcpServerCommand', () => {
it('should do nothing if no MCP servers or command are configured', () => {
const out = populateMcpServerCommand({}, undefined);
@@ -501,17 +211,6 @@ describe('mcp-client', () => {
});
describe('createTransport', () => {
- const originalEnv = process.env;
-
- beforeEach(() => {
- vi.resetModules();
- process.env = {};
- });
-
- afterEach(() => {
- process.env = originalEnv;
- });
-
describe('should connect via httpUrl', () => {
it('without headers', async () => {
const transport = await createTransport(
@@ -601,7 +300,7 @@ describe('mcp-client', () => {
command: 'test-command',
args: ['--foo', 'bar'],
cwd: 'test/cwd',
- env: { FOO: 'bar' },
+ env: { ...process.env, FOO: 'bar' },
stderr: 'pipe',
});
});