diff options
| author | Ramón Medrano Llamas <[email protected]> | 2025-08-19 21:03:19 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-08-19 19:03:19 +0000 |
| commit | b24c5887c45edde8690b4d73d8961e63eee13a34 (patch) | |
| tree | 6136f1f6bcc61801edb9f6d6411966b3b6678984 /packages/core/src/tools/mcp-client.test.ts | |
| parent | 4828e4daf198a675ce118cec08dcfbd0bfbb28a6 (diff) | |
feat: restart MCP servers on /mcp refresh (#5479)
Co-authored-by: Brian Ray <[email protected]>
Co-authored-by: N. Taylor Mullen <[email protected]>
Diffstat (limited to 'packages/core/src/tools/mcp-client.test.ts')
| -rw-r--r-- | packages/core/src/tools/mcp-client.test.ts | 465 |
1 files changed, 82 insertions, 383 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 3467ad95..b8f61856 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -4,16 +4,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest'; +import { afterEach, describe, expect, it, vi } from 'vitest'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { populateMcpServerCommand, createTransport, isEnabled, - discoverTools, - discoverPrompts, hasValidTypes, - connectToMcpServer, + McpClient, } from './mcp-client.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; @@ -22,26 +20,36 @@ import * as GenAiLib from '@google/genai'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { AuthProviderType } from '../config/config.js'; import { PromptRegistry } from '../prompts/prompt-registry.js'; - -import { DiscoveredMCPTool } from './mcp-tool.js'; +import { ToolRegistry } from './tool-registry.js'; import { WorkspaceContext } from '../utils/workspaceContext.js'; -import { pathToFileURL } from 'node:url'; vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js'); vi.mock('@google/genai'); vi.mock('../mcp/oauth-provider.js'); vi.mock('../mcp/oauth-token-storage.js'); -vi.mock('./mcp-tool.js'); describe('mcp-client', () => { afterEach(() => { vi.restoreAllMocks(); }); - describe('discoverTools', () => { + describe('McpClient', () => { it('should discover tools', async () => { - const mockedClient = {} as unknown as ClientLib.Client; + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ tool: () => ({ functionDeclarations: [ @@ -51,62 +59,43 @@ describe('mcp-client', () => { ], }), } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(1); - expect(mockedMcpToTool).toHaveBeenCalledOnce(); - }); - - it('should log an error if there is an error discovering a tool', async () => { - const mockedClient = {} as unknown as ClientLib.Client; - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); - - const testError = new Error('Invalid tool name'); - vi.mocked(DiscoveredMCPTool).mockImplementation( - ( - _mcpCallableTool: GenAiLib.CallableTool, - _serverName: string, - name: string, - ) => { - if (name === 'invalid tool name') { - throw testError; - } - return { name: 'validTool' } as DiscoveredMCPTool; + const mockedToolRegistry = { + registerTool: vi.fn(), + } as unknown as ToolRegistry; + const client = new McpClient( + 'test-server', + { + command: 'test-command', }, + mockedToolRegistry, + {} as PromptRegistry, + {} as WorkspaceContext, + false, ); - - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'validTool', - }, - { - name: 'invalid tool name', // this will fail validation - }, - ], - }), - } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(1); - expect(tools[0].name).toBe('validTool'); - expect(consoleErrorSpy).toHaveBeenCalledOnce(); - expect(consoleErrorSpy).toHaveBeenCalledWith( - `Error discovering tool: 'invalid tool name' from MCP server 'test-server': ${testError.message}`, - ); + await client.connect(); + await client.discover(); + expect(mockedMcpToTool).toHaveBeenCalledOnce(); }); it('should skip tools if a parameter is missing a type', async () => { - const mockedClient = {} as unknown as ClientLib.Client; const consoleWarnSpy = vi .spyOn(console, 'warn') .mockImplementation(() => {}); + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + tool: vi.fn(), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ tool: () => Promise.resolve({ @@ -132,89 +121,22 @@ describe('mcp-client', () => { ], }), } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(1); - expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool'); - expect(consoleWarnSpy).toHaveBeenCalledOnce(); - expect(consoleWarnSpy).toHaveBeenCalledWith( - `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` + - `missing types in its parameter schema. Please file an issue with the owner of the MCP server.`, - ); - consoleWarnSpy.mockRestore(); - }); - - it('should skip tools if a nested parameter is missing a type', async () => { - const mockedClient = {} as unknown as ClientLib.Client; - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'invalidTool', - parametersJsonSchema: { - type: 'object', - properties: { - param1: { - type: 'object', - properties: { - nestedParam: { - description: 'a nested param with no type', - }, - }, - }, - }, - }, - }, - ], - }), - } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(0); - expect(consoleWarnSpy).toHaveBeenCalledOnce(); - expect(consoleWarnSpy).toHaveBeenCalledWith( - `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` + - `missing types in its parameter schema. Please file an issue with the owner of the MCP server.`, + const mockedToolRegistry = { + registerTool: vi.fn(), + } as unknown as ToolRegistry; + const client = new McpClient( + 'test-server', + { + command: 'test-command', + }, + mockedToolRegistry, + {} as PromptRegistry, + {} as WorkspaceContext, + false, ); - consoleWarnSpy.mockRestore(); - }); - - it('should skip tool if an array item is missing a type', async () => { - const mockedClient = {} as unknown as ClientLib.Client; - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'invalidTool', - parametersJsonSchema: { - type: 'object', - properties: { - param1: { - type: 'array', - items: { - description: 'an array item with no type', - }, - }, - }, - }, - }, - ], - }), - } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(0); + await client.connect(); + await client.discover(); + expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); expect(consoleWarnSpy).toHaveBeenCalledOnce(); expect(consoleWarnSpy).toHaveBeenCalledWith( `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` + @@ -223,109 +145,19 @@ describe('mcp-client', () => { consoleWarnSpy.mockRestore(); }); - it('should discover tool with no properties in schema', async () => { - const mockedClient = {} as unknown as ClientLib.Client; - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'validTool', - parametersJsonSchema: { - type: 'object', - }, - }, - ], - }), - } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(1); - expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool'); - expect(consoleWarnSpy).not.toHaveBeenCalled(); - consoleWarnSpy.mockRestore(); - }); - - it('should discover tool with empty properties object in schema', async () => { - const mockedClient = {} as unknown as ClientLib.Client; - const consoleWarnSpy = vi - .spyOn(console, 'warn') + it('should handle errors when discovering prompts', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') .mockImplementation(() => {}); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'validTool', - parametersJsonSchema: { - type: 'object', - properties: {}, - }, - }, - ], - }), - } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(1); - expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool'); - expect(consoleWarnSpy).not.toHaveBeenCalled(); - consoleWarnSpy.mockRestore(); - }); - }); - - describe('connectToMcpServer', () => { - it('should send a notification when directories change', async () => { const mockedClient = { - registerCapabilities: vi.fn(), - setRequestHandler: vi.fn(), - notification: vi.fn(), - callTool: vi.fn(), connect: vi.fn(), - }; - vi.mocked(ClientLib.Client).mockReturnValue( - mockedClient as unknown as ClientLib.Client, - ); - vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( - {} as SdkClientStdioLib.StdioClientTransport, - ); - let onDirectoriesChangedCallback: () => void = () => {}; - const mockWorkspaceContext = { - getDirectories: vi - .fn() - .mockReturnValue(['/test/dir', '/another/project']), - onDirectoriesChanged: vi.fn().mockImplementation((callback) => { - onDirectoriesChangedCallback = callback; - }), - } as unknown as WorkspaceContext; - - await connectToMcpServer( - 'test-server', - { - command: 'test-command', - }, - false, - mockWorkspaceContext, - ); - - onDirectoriesChangedCallback(); - - expect(mockedClient.notification).toHaveBeenCalledWith({ - method: 'notifications/roots/list_changed', - }); - }); - - it('should register a roots/list handler', async () => { - const mockedClient = { + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), - callTool: vi.fn(), - connect: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }), + request: vi.fn().mockRejectedValue(new Error('Test error')), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, @@ -333,151 +165,29 @@ describe('mcp-client', () => { vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); - const mockWorkspaceContext = { - getDirectories: vi - .fn() - .mockReturnValue(['/test/dir', '/another/project']), - onDirectoriesChanged: vi.fn(), - } as unknown as WorkspaceContext; - - await connectToMcpServer( + vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ + tool: () => Promise.resolve({ functionDeclarations: [] }), + } as unknown as GenAiLib.CallableTool); + const client = new McpClient( 'test-server', { command: 'test-command', }, + {} as ToolRegistry, + {} as PromptRegistry, + {} as WorkspaceContext, false, - mockWorkspaceContext, ); - - expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({ - roots: { - listChanged: true, - }, - }); - expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce(); - const handler = mockedClient.setRequestHandler.mock.calls[0][1]; - const roots = await handler(); - expect(roots).toEqual({ - roots: [ - { - uri: pathToFileURL('/test/dir').toString(), - name: 'dir', - }, - { - uri: pathToFileURL('/another/project').toString(), - name: 'project', - }, - ], - }); - }); - }); - - describe('discoverPrompts', () => { - const mockedPromptRegistry = { - registerPrompt: vi.fn(), - } as unknown as PromptRegistry; - - it('should discover and log prompts', async () => { - const mockRequest = vi.fn().mockResolvedValue({ - prompts: [ - { name: 'prompt1', description: 'desc1' }, - { name: 'prompt2' }, - ], - }); - const mockGetServerCapabilities = vi.fn().mockReturnValue({ - prompts: {}, - }); - const mockedClient = { - getServerCapabilities: mockGetServerCapabilities, - request: mockRequest, - } as unknown as ClientLib.Client; - - await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); - - expect(mockGetServerCapabilities).toHaveBeenCalledOnce(); - expect(mockRequest).toHaveBeenCalledWith( - { method: 'prompts/list', params: {} }, - expect.anything(), + await client.connect(); + await expect(client.discover()).rejects.toThrow( + 'No prompts or tools found on the server.', ); - }); - - it('should do nothing if no prompts are discovered', async () => { - const mockRequest = vi.fn().mockResolvedValue({ - prompts: [], - }); - const mockGetServerCapabilities = vi.fn().mockReturnValue({ - prompts: {}, - }); - - const mockedClient = { - getServerCapabilities: mockGetServerCapabilities, - request: mockRequest, - } as unknown as ClientLib.Client; - - const consoleLogSpy = vi - .spyOn(console, 'debug') - .mockImplementation(() => {}); - - await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); - - expect(mockGetServerCapabilities).toHaveBeenCalledOnce(); - expect(mockRequest).toHaveBeenCalledOnce(); - expect(consoleLogSpy).not.toHaveBeenCalled(); - - consoleLogSpy.mockRestore(); - }); - - it('should do nothing if the server has no prompt support', async () => { - const mockRequest = vi.fn().mockResolvedValue({ - prompts: [], - }); - const mockGetServerCapabilities = vi.fn().mockReturnValue({}); - - const mockedClient = { - getServerCapabilities: mockGetServerCapabilities, - request: mockRequest, - } as unknown as ClientLib.Client; - - const consoleLogSpy = vi - .spyOn(console, 'debug') - .mockImplementation(() => {}); - - await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); - - expect(mockGetServerCapabilities).toHaveBeenCalledOnce(); - expect(mockRequest).not.toHaveBeenCalled(); - expect(consoleLogSpy).not.toHaveBeenCalled(); - - consoleLogSpy.mockRestore(); - }); - - it('should log an error if discovery fails', async () => { - const testError = new Error('test error'); - testError.message = 'test error'; - const mockRequest = vi.fn().mockRejectedValue(testError); - const mockGetServerCapabilities = vi.fn().mockReturnValue({ - prompts: {}, - }); - const mockedClient = { - getServerCapabilities: mockGetServerCapabilities, - request: mockRequest, - } as unknown as ClientLib.Client; - - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); - - await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); - - expect(mockRequest).toHaveBeenCalledOnce(); expect(consoleErrorSpy).toHaveBeenCalledWith( - `Error discovering prompts from test-server: ${testError.message}`, + `Error discovering prompts from test-server: Test error`, ); - consoleErrorSpy.mockRestore(); }); }); - describe('appendMcpServerCommand', () => { it('should do nothing if no MCP servers or command are configured', () => { const out = populateMcpServerCommand({}, undefined); @@ -501,17 +211,6 @@ describe('mcp-client', () => { }); describe('createTransport', () => { - const originalEnv = process.env; - - beforeEach(() => { - vi.resetModules(); - process.env = {}; - }); - - afterEach(() => { - process.env = originalEnv; - }); - describe('should connect via httpUrl', () => { it('without headers', async () => { const transport = await createTransport( @@ -601,7 +300,7 @@ describe('mcp-client', () => { command: 'test-command', args: ['--foo', 'bar'], cwd: 'test/cwd', - env: { FOO: 'bar' }, + env: { ...process.env, FOO: 'bar' }, stderr: 'pipe', }); }); |
