diff options
| author | BigUncle <[email protected]> | 2025-07-06 05:58:51 +0800 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-07-05 21:58:51 +0000 |
| commit | b564d4a088d11ae5a90291e642b104761d72ee61 (patch) | |
| tree | 2423d004716aa77f05e55b019fe652a18d539f1d /packages/core/src/tools/tool-registry.ts | |
| parent | 5c9372372c73afcff893499e538cf5522a4400e2 (diff) | |
fix(core): Sanitize tool parameters to fix 400 API errors (#3300)
Diffstat (limited to 'packages/core/src/tools/tool-registry.ts')
| -rw-r--r-- | packages/core/src/tools/tool-registry.ts | 200 |
1 files changed, 180 insertions, 20 deletions
diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index f3162ac0..62ae2a51 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -4,12 +4,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { FunctionDeclaration } from '@google/genai'; +import { FunctionDeclaration, Schema, Type } from '@google/genai'; import { Tool, ToolResult, BaseTool } from './tools.js'; import { Config } from '../config/config.js'; -import { spawn, execSync } from 'node:child_process'; +import { spawn } from 'node:child_process'; +import { StringDecoder } from 'node:string_decoder'; import { discoverMcpTools } from './mcp-client.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; +import { parse } from 'shell-quote'; type ToolParams = Record<string, unknown>; @@ -157,38 +159,137 @@ export class ToolRegistry { // Keep manually registered tools } } - // discover tools using discovery command, if configured + + await this.discoverAndRegisterToolsFromCommand(); + + // discover tools using MCP servers, if configured + await discoverMcpTools( + this.config.getMcpServers() ?? {}, + this.config.getMcpServerCommand(), + this, + ); + } + + private async discoverAndRegisterToolsFromCommand(): Promise<void> { const discoveryCmd = this.config.getToolDiscoveryCommand(); - if (discoveryCmd) { + if (!discoveryCmd) { + return; + } + + try { + const cmdParts = parse(discoveryCmd); + if (cmdParts.length === 0) { + throw new Error( + 'Tool discovery command is empty or contains only whitespace.', + ); + } + const proc = spawn(cmdParts[0] as string, cmdParts.slice(1) as string[]); + let stdout = ''; + const stdoutDecoder = new StringDecoder('utf8'); + let stderr = ''; + const stderrDecoder = new StringDecoder('utf8'); + let sizeLimitExceeded = false; + const MAX_STDOUT_SIZE = 10 * 1024 * 1024; // 10MB limit + const MAX_STDERR_SIZE = 10 * 1024 * 1024; // 10MB limit + + let stdoutByteLength = 0; + let stderrByteLength = 0; + + proc.stdout.on('data', (data) => { + if (sizeLimitExceeded) return; + if (stdoutByteLength + data.length > MAX_STDOUT_SIZE) { + sizeLimitExceeded = true; + proc.kill(); + return; + } + stdoutByteLength += data.length; + stdout += stdoutDecoder.write(data); + }); + + proc.stderr.on('data', (data) => { + if (sizeLimitExceeded) return; + if (stderrByteLength + data.length > MAX_STDERR_SIZE) { + sizeLimitExceeded = true; + proc.kill(); + return; + } + stderrByteLength += data.length; + stderr += stderrDecoder.write(data); + }); + + await new Promise<void>((resolve, reject) => { + proc.on('error', reject); + proc.on('close', (code) => { + stdout += stdoutDecoder.end(); + stderr += stderrDecoder.end(); + + if (sizeLimitExceeded) { + return reject( + new Error( + `Tool discovery command output exceeded size limit of ${MAX_STDOUT_SIZE} bytes.`, + ), + ); + } + + if (code !== 0) { + console.error(`Command failed with code ${code}`); + console.error(stderr); + return reject( + new Error(`Tool discovery command failed with exit code ${code}`), + ); + } + resolve(); + }); + }); + // execute discovery command and extract function declarations (w/ or w/o "tool" wrappers) const functions: FunctionDeclaration[] = []; - for (const tool of JSON.parse(execSync(discoveryCmd).toString().trim())) { - if (tool['function_declarations']) { - functions.push(...tool['function_declarations']); - } else if (tool['functionDeclarations']) { - functions.push(...tool['functionDeclarations']); - } else if (tool['name']) { - functions.push(tool); + const discoveredItems = JSON.parse(stdout.trim()); + + if (!discoveredItems || !Array.isArray(discoveredItems)) { + throw new Error( + 'Tool discovery command did not return a JSON array of tools.', + ); + } + + for (const tool of discoveredItems) { + if (tool && typeof tool === 'object') { + if (Array.isArray(tool['function_declarations'])) { + functions.push(...tool['function_declarations']); + } else if (Array.isArray(tool['functionDeclarations'])) { + functions.push(...tool['functionDeclarations']); + } else if (tool['name']) { + functions.push(tool as FunctionDeclaration); + } } } // register each function as a tool for (const func of functions) { + if (!func.name) { + console.warn('Discovered a tool with no name. Skipping.'); + continue; + } + // Sanitize the parameters before registering the tool. + const parameters = + func.parameters && + typeof func.parameters === 'object' && + !Array.isArray(func.parameters) + ? (func.parameters as Schema) + : {}; + sanitizeParameters(parameters); this.registerTool( new DiscoveredTool( this.config, - func.name!, - func.description!, - func.parameters! as Record<string, unknown>, + func.name, + func.description ?? '', + parameters as Record<string, unknown>, ), ); } + } catch (e) { + console.error(`Tool discovery command "${discoveryCmd}" failed:`, e); + throw e; } - // discover tools using MCP servers, if configured - await discoverMcpTools( - this.config.getMcpServers() ?? {}, - this.config.getMcpServerCommand(), - this, - ); } /** @@ -232,3 +333,62 @@ export class ToolRegistry { return this.tools.get(name); } } + +/** + * Sanitizes a schema object in-place to ensure compatibility with the Gemini API. + * + * NOTE: This function mutates the passed schema object. + * + * It performs the following actions: + * - Removes the `default` property when `anyOf` is present. + * - Removes unsupported `format` values from string properties, keeping only 'enum' and 'date-time'. + * - Recursively sanitizes nested schemas within `anyOf`, `items`, and `properties`. + * - Handles circular references within the schema to prevent infinite loops. + * + * @param schema The schema object to sanitize. It will be modified directly. + */ +export function sanitizeParameters(schema?: Schema) { + _sanitizeParameters(schema, new Set<Schema>()); +} + +/** + * Internal recursive implementation for sanitizeParameters. + * @param schema The schema object to sanitize. + * @param visited A set used to track visited schema objects during recursion. + */ +function _sanitizeParameters(schema: Schema | undefined, visited: Set<Schema>) { + if (!schema || visited.has(schema)) { + return; + } + visited.add(schema); + + if (schema.anyOf) { + // Vertex AI gets confused if both anyOf and default are set. + schema.default = undefined; + for (const item of schema.anyOf) { + if (typeof item !== 'boolean') { + _sanitizeParameters(item, visited); + } + } + } + if (schema.items && typeof schema.items !== 'boolean') { + _sanitizeParameters(schema.items, visited); + } + if (schema.properties) { + for (const item of Object.values(schema.properties)) { + if (typeof item !== 'boolean') { + _sanitizeParameters(item, visited); + } + } + } + // Vertex AI only supports 'enum' and 'date-time' for STRING format. + if (schema.type === Type.STRING) { + if ( + schema.format && + schema.format !== 'enum' && + schema.format !== 'date-time' + ) { + schema.format = undefined; + } + } +} |
