diff options
Diffstat (limited to 'packages/core/src/tools/tool-registry.test.ts')
| -rw-r--r-- | packages/core/src/tools/tool-registry.test.ts | 328 |
1 files changed, 227 insertions, 101 deletions
diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index b39ec7b9..4d586d62 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -14,22 +14,22 @@ import { afterEach, Mocked, } from 'vitest'; -import { ToolRegistry, DiscoveredTool } from './tool-registry.js'; -import { DiscoveredMCPTool } from './mcp-tool.js'; import { - Config, - ConfigParameters, - MCPServerConfig, - ApprovalMode, -} from '../config/config.js'; + ToolRegistry, + DiscoveredTool, + sanitizeParameters, +} from './tool-registry.js'; +import { DiscoveredMCPTool } from './mcp-tool.js'; +import { Config, ConfigParameters, ApprovalMode } from '../config/config.js'; import { BaseTool, ToolResult } from './tools.js'; import { FunctionDeclaration, CallableTool, mcpToTool, Type, + Schema, } from '@google/genai'; -import { execSync } from 'node:child_process'; +import { spawn } from 'node:child_process'; // Use vi.hoisted to define the mock function so it can be used in the vi.mock factory const mockDiscoverMcpTools = vi.hoisted(() => vi.fn()); @@ -61,7 +61,6 @@ vi.mock('@modelcontextprotocol/sdk/client/index.js', () => { set onerror(handler: any) { mockMcpClientOnError(handler); }, - // listTools and callTool are no longer directly used by ToolRegistry/discoverMcpTools })); return { Client: MockClient }; }); @@ -90,7 +89,6 @@ vi.mock('@google/genai', async () => { return { ...actualGenai, mcpToTool: vi.fn().mockImplementation(() => ({ - // Default mock implementation tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }), callTool: vi.fn(), })), @@ -139,6 +137,7 @@ const baseConfigParams: ConfigParameters = { describe('ToolRegistry', () => { let config: Config; let toolRegistry: ToolRegistry; + let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>; beforeEach(() => { config = new Config(baseConfigParams); @@ -148,13 +147,19 @@ describe('ToolRegistry', () => { vi.spyOn(console, 'debug').mockImplementation(() => {}); vi.spyOn(console, 'log').mockImplementation(() => {}); - // Reset mocks for MCP parts - mockMcpClientConnect.mockReset().mockResolvedValue(undefined); // Default connect success + mockMcpClientConnect.mockReset().mockResolvedValue(undefined); mockStdioTransportClose.mockReset(); mockSseTransportClose.mockReset(); vi.mocked(mcpToTool).mockClear(); - // Default mcpToTool to return a callable tool that returns no functions vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([])); + + mockConfigGetToolDiscoveryCommand = vi.spyOn( + config, + 'getToolDiscoveryCommand', + ); + vi.spyOn(config, 'getMcpServers'); + vi.spyOn(config, 'getMcpServerCommand'); + mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined); }); afterEach(() => { @@ -167,21 +172,18 @@ describe('ToolRegistry', () => { toolRegistry.registerTool(tool); expect(toolRegistry.getTool('mock-tool')).toBe(tool); }); - // ... other registerTool tests }); describe('getToolsByServer', () => { it('should return an empty array if no tools match the server name', () => { - toolRegistry.registerTool(new MockTool()); // A non-MCP tool + toolRegistry.registerTool(new MockTool()); expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]); }); it('should return only tools matching the server name', async () => { const server1Name = 'mcp-server-uno'; const server2Name = 'mcp-server-dos'; - - // Manually register mock MCP tools for this test - const mockCallable = {} as CallableTool; // Minimal mock callable + const mockCallable = {} as CallableTool; const mcpTool1 = new DiscoveredMCPTool( mockCallable, server1Name, @@ -207,73 +209,87 @@ describe('ToolRegistry', () => { const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name); expect(toolsFromServer1).toHaveLength(1); expect(toolsFromServer1[0].name).toBe(mcpTool1.name); - expect((toolsFromServer1[0] as DiscoveredMCPTool).serverName).toBe( - server1Name, - ); const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name); expect(toolsFromServer2).toHaveLength(1); expect(toolsFromServer2[0].name).toBe(mcpTool2.name); - expect((toolsFromServer2[0] as DiscoveredMCPTool).serverName).toBe( - server2Name, - ); - - expect(toolRegistry.getToolsByServer('non-existent-server')).toEqual([]); }); }); describe('discoverTools', () => { - let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>; - let mockConfigGetMcpServers: ReturnType<typeof vi.spyOn>; - let mockConfigGetMcpServerCommand: ReturnType<typeof vi.spyOn>; - let mockExecSync: ReturnType<typeof vi.mocked<typeof execSync>>; - - beforeEach(() => { - mockConfigGetToolDiscoveryCommand = vi.spyOn( - config, - 'getToolDiscoveryCommand', - ); - mockConfigGetMcpServers = vi.spyOn(config, 'getMcpServers'); - mockConfigGetMcpServerCommand = vi.spyOn(config, 'getMcpServerCommand'); - mockExecSync = vi.mocked(execSync); - toolRegistry = new ToolRegistry(config); // Reset registry - // Reset the mock for discoverMcpTools before each test in this suite - mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined); - }); - - it('should discover tools using discovery command', async () => { - // ... this test remains largely the same + it('should sanitize tool parameters during discovery from command', async () => { const discoveryCommand = 'my-discovery-command'; mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand); - const mockToolDeclarations: FunctionDeclaration[] = [ - { - name: 'discovered-tool-1', - description: 'A discovered tool', - parameters: { type: Type.OBJECT, properties: {} }, + + const unsanitizedToolDeclaration: FunctionDeclaration = { + name: 'tool-with-bad-format', + description: 'A tool with an invalid format property', + parameters: { + type: Type.OBJECT, + properties: { + some_string: { + type: Type.STRING, + format: 'uuid', // This is an unsupported format + }, + }, }, - ]; - mockExecSync.mockReturnValue( - Buffer.from( - JSON.stringify([{ function_declarations: mockToolDeclarations }]), - ), - ); + }; + + const mockSpawn = vi.mocked(spawn); + const mockChildProcess = { + stdout: { on: vi.fn() }, + stderr: { on: vi.fn() }, + on: vi.fn(), + }; + mockSpawn.mockReturnValue(mockChildProcess as any); + + // Simulate stdout data + mockChildProcess.stdout.on.mockImplementation((event, callback) => { + if (event === 'data') { + callback( + Buffer.from( + JSON.stringify([ + { function_declarations: [unsanitizedToolDeclaration] }, + ]), + ), + ); + } + return mockChildProcess as any; + }); + + // Simulate process close + mockChildProcess.on.mockImplementation((event, callback) => { + if (event === 'close') { + callback(0); + } + return mockChildProcess as any; + }); + await toolRegistry.discoverTools(); - expect(execSync).toHaveBeenCalledWith(discoveryCommand); - const discoveredTool = toolRegistry.getTool('discovered-tool-1'); - expect(discoveredTool).toBeInstanceOf(DiscoveredTool); + + const discoveredTool = toolRegistry.getTool('tool-with-bad-format'); + expect(discoveredTool).toBeDefined(); + + const registeredParams = (discoveredTool as DiscoveredTool).schema + .parameters as Schema; + expect(registeredParams.properties?.['some_string']).toBeDefined(); + expect(registeredParams.properties?.['some_string']).toHaveProperty( + 'format', + undefined, + ); }); it('should discover tools using MCP servers defined in getMcpServers', async () => { mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); - mockConfigGetMcpServerCommand.mockReturnValue(undefined); + vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined); const mcpServerConfigVal = { 'my-mcp-server': { command: 'mcp-server-cmd', args: ['--port', '1234'], trust: true, - } as MCPServerConfig, + }, }; - mockConfigGetMcpServers.mockReturnValue(mcpServerConfigVal); + vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal); await toolRegistry.discoverTools(); @@ -282,56 +298,166 @@ describe('ToolRegistry', () => { undefined, toolRegistry, ); - // We no longer check these as discoverMcpTools is mocked - // expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1); - // expect(Client).toHaveBeenCalledTimes(1); - // expect(StdioClientTransport).toHaveBeenCalledWith({ - // command: 'mcp-server-cmd', - // args: ['--port', '1234'], - // env: expect.any(Object), - // stderr: 'pipe', - // }); - // expect(mockMcpClientConnect).toHaveBeenCalled(); - - // To verify that tools *would* have been registered, we'd need mockDiscoverMcpTools - // to call toolRegistry.registerTool, or we test that separately. - // For now, we just check that the delegation happened. }); - it('should discover tools using MCP server command from getMcpServerCommand', async () => { + it('should discover tools using MCP servers defined in getMcpServers', async () => { mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); - mockConfigGetMcpServers.mockReturnValue({}); - mockConfigGetMcpServerCommand.mockReturnValue( - 'mcp-server-start-command --param', - ); + vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined); + const mcpServerConfigVal = { + 'my-mcp-server': { + command: 'mcp-server-cmd', + args: ['--port', '1234'], + trust: true, + }, + }; + vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal); await toolRegistry.discoverTools(); + expect(mockDiscoverMcpTools).toHaveBeenCalledWith( - {}, - 'mcp-server-start-command --param', + mcpServerConfigVal, + undefined, toolRegistry, ); }); + }); +}); - it('should handle errors during MCP client connection gracefully and close transport', async () => { - mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); - mockConfigGetMcpServers.mockReturnValue({ - 'failing-mcp': { command: 'fail-cmd' } as MCPServerConfig, - }); +describe('sanitizeParameters', () => { + it('should remove unsupported format from a simple string property', () => { + const schema: Schema = { + type: Type.OBJECT, + properties: { + name: { type: Type.STRING }, + id: { type: Type.STRING, format: 'uuid' }, + }, + }; + sanitizeParameters(schema); + expect(schema.properties?.['id']).toHaveProperty('format', undefined); + expect(schema.properties?.['name']).not.toHaveProperty('format'); + }); - mockMcpClientConnect.mockRejectedValue(new Error('Connection failed')); + it('should NOT remove supported format values', () => { + const schema: Schema = { + type: Type.OBJECT, + properties: { + date: { type: Type.STRING, format: 'date-time' }, + role: { + type: Type.STRING, + format: 'enum', + enum: ['admin', 'user'], + }, + }, + }; + const originalSchema = JSON.parse(JSON.stringify(schema)); + sanitizeParameters(schema); + expect(schema).toEqual(originalSchema); + }); - await toolRegistry.discoverTools(); - expect(mockDiscoverMcpTools).toHaveBeenCalledWith( - { - 'failing-mcp': { command: 'fail-cmd' }, + it('should handle nested objects recursively', () => { + const schema: Schema = { + type: Type.OBJECT, + properties: { + user: { + type: Type.OBJECT, + properties: { + email: { type: Type.STRING, format: 'email' }, + }, }, - undefined, - toolRegistry, - ); - expect(toolRegistry.getAllTools()).toHaveLength(0); - }); + }, + }; + sanitizeParameters(schema); + expect(schema.properties?.['user']?.properties?.['email']).toHaveProperty( + 'format', + undefined, + ); + }); + + it('should handle arrays of objects', () => { + const schema: Schema = { + type: Type.OBJECT, + properties: { + items: { + type: Type.ARRAY, + items: { + type: Type.OBJECT, + properties: { + itemId: { type: Type.STRING, format: 'uuid' }, + }, + }, + }, + }, + }; + sanitizeParameters(schema); + expect( + (schema.properties?.['items']?.items as Schema)?.properties?.['itemId'], + ).toHaveProperty('format', undefined); + }); + + it('should handle schemas with no properties to sanitize', () => { + const schema: Schema = { + type: Type.OBJECT, + properties: { + count: { type: Type.NUMBER }, + isActive: { type: Type.BOOLEAN }, + }, + }; + const originalSchema = JSON.parse(JSON.stringify(schema)); + sanitizeParameters(schema); + expect(schema).toEqual(originalSchema); + }); + + it('should not crash on an empty or undefined schema', () => { + expect(() => sanitizeParameters({})).not.toThrow(); + expect(() => sanitizeParameters(undefined)).not.toThrow(); + }); + + it('should handle cyclic schemas without crashing', () => { + const schema: any = { + type: Type.OBJECT, + properties: { + name: { type: Type.STRING, format: 'hostname' }, + }, + }; + schema.properties.self = schema; + + expect(() => sanitizeParameters(schema)).not.toThrow(); + expect(schema.properties.name).toHaveProperty('format', undefined); + }); + + it('should handle complex nested schemas with cycles', () => { + const userNode: any = { + type: Type.OBJECT, + properties: { + id: { type: Type.STRING, format: 'uuid' }, + name: { type: Type.STRING }, + manager: { + type: Type.OBJECT, + properties: { + id: { type: Type.STRING, format: 'uuid' }, + }, + }, + }, + }; + userNode.properties.reports = { + type: Type.ARRAY, + items: userNode, + }; + + const schema: Schema = { + type: Type.OBJECT, + properties: { + ceo: userNode, + }, + }; + + expect(() => sanitizeParameters(schema)).not.toThrow(); + expect(schema.properties?.['ceo']?.properties?.['id']).toHaveProperty( + 'format', + undefined, + ); + expect( + schema.properties?.['ceo']?.properties?.['manager']?.properties?.['id'], + ).toHaveProperty('format', undefined); }); - // Other tests for DiscoveredTool and DiscoveredMCPTool can be simplified or removed - // if their core logic is now tested in their respective dedicated test files (mcp-tool.test.ts) }); |
