summaryrefslogtreecommitdiff
path: root/packages/core/src/tools/mcp-client.test.ts
diff options
context:
space:
mode:
authorAdam Spiers <[email protected]>2025-06-30 01:09:08 +0100
committerGitHub <[email protected]>2025-06-30 00:09:08 +0000
commit0fd602eb43eea7abca980dc2ae3fd7bf2ba76a2a (patch)
treeb181996faa4e7ee66926ce68e2aeac2d823a91ed /packages/core/src/tools/mcp-client.test.ts
parentd1eb86581ce800778e5a093039ce237ec6da6118 (diff)
feat: add support to remote MCP servers for custom HTTP headers (#2477)
Diffstat (limited to 'packages/core/src/tools/mcp-client.test.ts')
-rw-r--r--packages/core/src/tools/mcp-client.test.ts114
1 files changed, 114 insertions, 0 deletions
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index f963a060..91524a2f 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -21,6 +21,7 @@ import { DiscoveredMCPTool } from './mcp-tool.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
+import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { parse, ParseEntry } from 'shell-quote';
// Mock dependencies
@@ -65,6 +66,16 @@ vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
return { SSEClientTransport: MockedSSETransport };
});
+vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => {
+ const MockedStreamableHTTPTransport = vi.fn().mockImplementation(function (
+ this: any,
+ ) {
+ this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
+ return this;
+ });
+ return { StreamableHTTPClientTransport: MockedStreamableHTTPTransport };
+});
+
const mockToolRegistryInstance = {
registerTool: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array
@@ -129,6 +140,15 @@ describe('discoverMcpTools', () => {
this.close = vi.fn().mockResolvedValue(undefined);
return this;
});
+
+ vi.mocked(StreamableHTTPClientTransport).mockClear();
+ // Ensure the StreamableHTTPClientTransport mock constructor returns an object with a close method
+ vi.mocked(StreamableHTTPClientTransport).mockImplementation(function (
+ this: any,
+ ) {
+ this.close = vi.fn().mockResolvedValue(undefined);
+ return this;
+ });
});
afterEach(() => {
@@ -267,6 +287,100 @@ describe('discoverMcpTools', () => {
expect(registeredTool.name).toBe('tool-sse');
});
+ it('should discover tools via mcpServers config (streamable http)', async () => {
+ const serverConfig: MCPServerConfig = {
+ httpUrl: 'http://localhost:3000/mcp',
+ };
+ mockConfig.getMcpServers.mockReturnValue({ 'http-server': serverConfig });
+
+ const mockTool = {
+ name: 'tool-http',
+ description: 'desc-http',
+ inputSchema: { type: 'object' as const, properties: {} },
+ };
+ vi.mocked(Client.prototype.listTools).mockResolvedValue({
+ tools: [mockTool],
+ });
+
+ mockToolRegistry.getToolsByServer.mockReturnValueOnce([
+ expect.any(DiscoveredMCPTool),
+ ]);
+
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
+
+ expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
+ new URL(serverConfig.httpUrl!),
+ {},
+ );
+ expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
+ expect.any(DiscoveredMCPTool),
+ );
+ const registeredTool = mockToolRegistry.registerTool.mock
+ .calls[0][0] as DiscoveredMCPTool;
+ expect(registeredTool.name).toBe('tool-http');
+ });
+
+ describe('StreamableHTTPClientTransport headers', () => {
+ const setupHttpTest = async (headers?: Record<string, string>) => {
+ const serverConfig: MCPServerConfig = {
+ httpUrl: 'http://localhost:3000/mcp',
+ ...(headers && { headers }),
+ };
+ const serverName = headers
+ ? 'http-server-with-headers'
+ : 'http-server-no-headers';
+ const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers';
+
+ mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig });
+
+ const mockTool = {
+ name: toolName,
+ description: `desc-${toolName}`,
+ inputSchema: { type: 'object' as const, properties: {} },
+ };
+ vi.mocked(Client.prototype.listTools).mockResolvedValue({
+ tools: [mockTool],
+ });
+ mockToolRegistry.getToolsByServer.mockReturnValueOnce([
+ expect.any(DiscoveredMCPTool),
+ ]);
+
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
+
+ return { serverConfig };
+ };
+
+ it('should pass headers when provided', async () => {
+ const headers = {
+ Authorization: 'Bearer test-token',
+ 'X-Custom-Header': 'custom-value',
+ };
+ const { serverConfig } = await setupHttpTest(headers);
+
+ expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
+ new URL(serverConfig.httpUrl!),
+ { requestInit: { headers } },
+ );
+ });
+
+ it('should work without headers (backwards compatibility)', async () => {
+ const { serverConfig } = await setupHttpTest();
+
+ expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
+ new URL(serverConfig.httpUrl!),
+ {},
+ );
+ });
+ });
+
it('should prefix tool names if multiple MCP servers are configured', async () => {
const serverConfig1: MCPServerConfig = { command: './mcp1' };
const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' };