summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Spiers <[email protected]>2025-06-30 01:09:08 +0100
committerGitHub <[email protected]>2025-06-30 00:09:08 +0000
commit0fd602eb43eea7abca980dc2ae3fd7bf2ba76a2a (patch)
treeb181996faa4e7ee66926ce68e2aeac2d823a91ed
parentd1eb86581ce800778e5a093039ce237ec6da6118 (diff)
feat: add support to remote MCP servers for custom HTTP headers (#2477)
-rw-r--r--docs/tools/mcp-server.md19
-rw-r--r--packages/core/src/config/config.ts1
-rw-r--r--packages/core/src/tools/mcp-client.test.ts114
-rw-r--r--packages/core/src/tools/mcp-client.ts17
4 files changed, 149 insertions, 2 deletions
diff --git a/docs/tools/mcp-server.md b/docs/tools/mcp-server.md
index ebce6160..0be9a34b 100644
--- a/docs/tools/mcp-server.md
+++ b/docs/tools/mcp-server.md
@@ -87,6 +87,7 @@ Each server configuration supports the following properties:
#### Optional
- **`args`** (string[]): Command-line arguments for Stdio transport
+- **`headers`** (object): Custom HTTP headers when using `httpUrl`
- **`env`** (object): Environment variables for the server process. Values can reference environment variables using `$VAR_NAME` or `${VAR_NAME}` syntax
- **`cwd`** (string): Working directory for Stdio transport
- **`timeout`** (number): Request timeout in milliseconds (default: 600,000ms = 10 minutes)
@@ -166,6 +167,24 @@ Each server configuration supports the following properties:
}
```
+#### HTTP-based MCP Server with Custom Headers
+
+```json
+{
+ "mcpServers": {
+ "httpServerWithAuth": {
+ "httpUrl": "http://localhost:3000/mcp",
+ "headers": {
+ "Authorization": "Bearer your-api-token",
+ "X-Custom-Header": "custom-value",
+ "Content-Type": "application/json"
+ },
+ "timeout": 5000
+ }
+ }
+}
+```
+
## Discovery Process Deep Dive
When the Gemini CLI starts, it performs MCP server discovery through the following detailed process:
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index 4ee2d23f..3bb5b85e 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -76,6 +76,7 @@ export class MCPServerConfig {
readonly url?: string,
// For streamable http transport
readonly httpUrl?: string,
+ readonly headers?: Record<string, string>,
// For websocket transport
readonly tcp?: string,
// Common
diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts
index f963a060..91524a2f 100644
--- a/packages/core/src/tools/mcp-client.test.ts
+++ b/packages/core/src/tools/mcp-client.test.ts
@@ -21,6 +21,7 @@ import { DiscoveredMCPTool } from './mcp-tool.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
+import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { parse, ParseEntry } from 'shell-quote';
// Mock dependencies
@@ -65,6 +66,16 @@ vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
return { SSEClientTransport: MockedSSETransport };
});
+vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => {
+ const MockedStreamableHTTPTransport = vi.fn().mockImplementation(function (
+ this: any,
+ ) {
+ this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
+ return this;
+ });
+ return { StreamableHTTPClientTransport: MockedStreamableHTTPTransport };
+});
+
const mockToolRegistryInstance = {
registerTool: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array
@@ -129,6 +140,15 @@ describe('discoverMcpTools', () => {
this.close = vi.fn().mockResolvedValue(undefined);
return this;
});
+
+ vi.mocked(StreamableHTTPClientTransport).mockClear();
+ // Ensure the StreamableHTTPClientTransport mock constructor returns an object with a close method
+ vi.mocked(StreamableHTTPClientTransport).mockImplementation(function (
+ this: any,
+ ) {
+ this.close = vi.fn().mockResolvedValue(undefined);
+ return this;
+ });
});
afterEach(() => {
@@ -267,6 +287,100 @@ describe('discoverMcpTools', () => {
expect(registeredTool.name).toBe('tool-sse');
});
+ it('should discover tools via mcpServers config (streamable http)', async () => {
+ const serverConfig: MCPServerConfig = {
+ httpUrl: 'http://localhost:3000/mcp',
+ };
+ mockConfig.getMcpServers.mockReturnValue({ 'http-server': serverConfig });
+
+ const mockTool = {
+ name: 'tool-http',
+ description: 'desc-http',
+ inputSchema: { type: 'object' as const, properties: {} },
+ };
+ vi.mocked(Client.prototype.listTools).mockResolvedValue({
+ tools: [mockTool],
+ });
+
+ mockToolRegistry.getToolsByServer.mockReturnValueOnce([
+ expect.any(DiscoveredMCPTool),
+ ]);
+
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
+
+ expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
+ new URL(serverConfig.httpUrl!),
+ {},
+ );
+ expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
+ expect.any(DiscoveredMCPTool),
+ );
+ const registeredTool = mockToolRegistry.registerTool.mock
+ .calls[0][0] as DiscoveredMCPTool;
+ expect(registeredTool.name).toBe('tool-http');
+ });
+
+ describe('StreamableHTTPClientTransport headers', () => {
+ const setupHttpTest = async (headers?: Record<string, string>) => {
+ const serverConfig: MCPServerConfig = {
+ httpUrl: 'http://localhost:3000/mcp',
+ ...(headers && { headers }),
+ };
+ const serverName = headers
+ ? 'http-server-with-headers'
+ : 'http-server-no-headers';
+ const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers';
+
+ mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig });
+
+ const mockTool = {
+ name: toolName,
+ description: `desc-${toolName}`,
+ inputSchema: { type: 'object' as const, properties: {} },
+ };
+ vi.mocked(Client.prototype.listTools).mockResolvedValue({
+ tools: [mockTool],
+ });
+ mockToolRegistry.getToolsByServer.mockReturnValueOnce([
+ expect.any(DiscoveredMCPTool),
+ ]);
+
+ await discoverMcpTools(
+ mockConfig.getMcpServers() ?? {},
+ mockConfig.getMcpServerCommand(),
+ mockToolRegistry as any,
+ );
+
+ return { serverConfig };
+ };
+
+ it('should pass headers when provided', async () => {
+ const headers = {
+ Authorization: 'Bearer test-token',
+ 'X-Custom-Header': 'custom-value',
+ };
+ const { serverConfig } = await setupHttpTest(headers);
+
+ expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
+ new URL(serverConfig.httpUrl!),
+ { requestInit: { headers } },
+ );
+ });
+
+ it('should work without headers (backwards compatibility)', async () => {
+ const { serverConfig } = await setupHttpTest();
+
+ expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
+ new URL(serverConfig.httpUrl!),
+ {},
+ );
+ });
+ });
+
it('should prefix tool names if multiple MCP servers are configured', async () => {
const serverConfig1: MCPServerConfig = { command: './mcp1' };
const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' };
diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts
index 72382ac1..52196b80 100644
--- a/packages/core/src/tools/mcp-client.ts
+++ b/packages/core/src/tools/mcp-client.ts
@@ -7,7 +7,10 @@
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
-import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
+import {
+ StreamableHTTPClientTransport,
+ StreamableHTTPClientTransportOptions,
+} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { parse } from 'shell-quote';
import { MCPServerConfig } from '../config/config.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
@@ -169,8 +172,17 @@ async function connectAndDiscover(
let transport;
if (mcpServerConfig.httpUrl) {
+ const transportOptions: StreamableHTTPClientTransportOptions = {};
+
+ if (mcpServerConfig.headers) {
+ transportOptions.requestInit = {
+ headers: mcpServerConfig.headers,
+ };
+ }
+
transport = new StreamableHTTPClientTransport(
new URL(mcpServerConfig.httpUrl),
+ transportOptions,
);
} else if (mcpServerConfig.url) {
transport = new SSEClientTransport(new URL(mcpServerConfig.url));
@@ -222,10 +234,11 @@ async function connectAndDiscover(
const safeConfig = {
command: mcpServerConfig.command,
url: mcpServerConfig.url,
+ httpUrl: mcpServerConfig.httpUrl,
cwd: mcpServerConfig.cwd,
timeout: mcpServerConfig.timeout,
trust: mcpServerConfig.trust,
- // Exclude args and env which may contain sensitive data
+ // Exclude args, env, and headers which may contain sensitive data
};
let errorString =