diff options
Diffstat (limited to 'packages/core/src')
| -rw-r--r-- | packages/core/src/config/config.ts | 2 | ||||
| -rw-r--r-- | packages/core/src/tools/mcp-client.ts | 19 | ||||
| -rw-r--r-- | packages/core/src/tools/websocket-client-transport.ts | 97 |
3 files changed, 112 insertions, 6 deletions
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index abc2240b..5e13241d 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -45,6 +45,8 @@ export class MCPServerConfig { readonly cwd?: string, // For sse transport readonly url?: string, + // For websocket transport + readonly tcp?: string, // Common readonly timeout?: number, readonly trust?: boolean, diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index a7d6e00c..6f498730 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -12,6 +12,7 @@ import { MCPServerConfig } from '../config/config.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; import { CallableTool, FunctionDeclaration, mcpToTool } from '@google/genai'; import { ToolRegistry } from './tool-registry.js'; +import { WebSocketClientTransport } from './websocket-client-transport.js'; export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes @@ -164,6 +165,8 @@ async function connectAndDiscover( let transport; if (mcpServerConfig.url) { transport = new SSEClientTransport(new URL(mcpServerConfig.url)); + } else if (mcpServerConfig.tcp) { + transport = new WebSocketClientTransport(new URL(mcpServerConfig.tcp)); } else if (mcpServerConfig.command) { transport = new StdioClientTransport({ command: mcpServerConfig.command, @@ -177,7 +180,7 @@ async function connectAndDiscover( }); } else { console.error( - `MCP server '${mcpServerName}' has invalid configuration: missing both url (for SSE) and command (for stdio). Skipping.`, + `MCP server '${mcpServerName}' has invalid configuration: missing url (for SSE), tcp (for websocket), and command (for stdio). Skipping.`, ); // Update status to disconnected updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); @@ -254,9 +257,11 @@ async function connectAndDiscover( console.error( `MCP server '${mcpServerName}' did not return valid tool function declarations. Skipping.`, ); - if (transport instanceof StdioClientTransport) { - await transport.close(); - } else if (transport instanceof SSEClientTransport) { + if ( + transport instanceof StdioClientTransport || + transport instanceof SSEClientTransport || + transport instanceof WebSocketClientTransport + ) { await transport.close(); } // Update status to disconnected @@ -315,7 +320,8 @@ async function connectAndDiscover( // Ensure transport is cleaned up on error too if ( transport instanceof StdioClientTransport || - transport instanceof SSEClientTransport + transport instanceof SSEClientTransport || + transport instanceof WebSocketClientTransport ) { await transport.close(); } @@ -334,7 +340,8 @@ async function connectAndDiscover( ); if ( transport instanceof StdioClientTransport || - transport instanceof SSEClientTransport + transport instanceof SSEClientTransport || + transport instanceof WebSocketClientTransport ) { await transport.close(); // Update status to disconnected diff --git a/packages/core/src/tools/websocket-client-transport.ts b/packages/core/src/tools/websocket-client-transport.ts new file mode 100644 index 00000000..ff754c0a --- /dev/null +++ b/packages/core/src/tools/websocket-client-transport.ts @@ -0,0 +1,97 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import WebSocket from 'ws'; +import { + Transport, + TransportSendOptions, +} from '@modelcontextprotocol/sdk/shared/transport.js'; +import { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; +import { AuthInfo } from '@modelcontextprotocol/sdk/server/auth/types.js'; + +export class WebSocketClientTransport implements Transport { + private socket: WebSocket | null = null; + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: ( + message: JSONRPCMessage, + extra?: { authInfo?: AuthInfo }, + ) => void; + + constructor(private readonly url: URL) {} + + async start(): Promise<void> { + return new Promise((resolve, reject) => { + const handshakeTimeoutDuration = 10000; + let connectionTimeout: NodeJS.Timeout | null = null; + + try { + this.socket = new WebSocket(this.url.toString(), { + handshakeTimeout: handshakeTimeoutDuration, + }); + + connectionTimeout = setTimeout(() => { + this.socket?.close(); + reject( + new Error( + `WebSocket connection timed out after ${handshakeTimeoutDuration}ms`, + ), + ); + }, handshakeTimeoutDuration); + + this.socket.on('open', () => { + clearTimeout(connectionTimeout!); + resolve(); + }); + + this.socket.on('message', (data) => { + try { + const parsedMessage: JSONRPCMessage = JSON.parse(data.toString()); + this.onmessage?.(parsedMessage, { authInfo: undefined }); // Auth unsupported currently + } catch (error: unknown) { + this.onerror?.( + error instanceof Error ? error : new Error(String(error)), + ); + } + }); + + this.socket.on('error', (error) => { + clearTimeout(connectionTimeout!); + this.onerror?.(error); + reject(error); + }); + + this.socket.on('close', () => { + clearTimeout(connectionTimeout!); + this.onclose?.(); + this.socket = null; + }); + } catch (error: unknown) { + clearTimeout(connectionTimeout!); + reject(error instanceof Error ? error : new Error(String(error))); + } + }); + } + + async close(): Promise<void> { + if (this.socket) { + this.socket.close(); + this.socket = null; + } + } + + async send( + message: JSONRPCMessage, + _options?: TransportSendOptions, + ): Promise<void> { + if (!this.socket || this.socket.readyState !== WebSocket.OPEN) { + throw new Error( + 'WebSocket is not connected or not open. Cannot send message.', + ); + } + this.socket.send(JSON.stringify(message)); + } +} |
