summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
authorJerop Kipruto <[email protected]>2025-06-29 15:32:26 -0400
committerGitHub <[email protected]>2025-06-29 19:32:26 +0000
commitd8d78d73f9638d11ba8b6ba184b49d4dc7caa8f4 (patch)
treefd747168058eb730afc1766f5ad4712df335f6cf /packages/core/src
parent19a0276142b61208e5d4b723e422e37bf005845a (diff)
feat: allow command-specific restrictions for ShellTool (#2605)
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/config/config.ts38
-rw-r--r--packages/core/src/tools/shell.test.ts171
-rw-r--r--packages/core/src/tools/shell.ts60
3 files changed, 254 insertions, 15 deletions
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index 59c9c1bd..4ee2d23f 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -456,25 +456,33 @@ export class Config {
export function createToolRegistry(config: Config): Promise<ToolRegistry> {
const registry = new ToolRegistry(config);
const targetDir = config.getTargetDir();
- const tools = config.getCoreTools()
- ? new Set(config.getCoreTools())
- : undefined;
- const excludeTools = config.getExcludeTools()
- ? new Set(config.getExcludeTools())
- : undefined;
// helper to create & register core tools that are enabled
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const registerCoreTool = (ToolClass: any, ...args: unknown[]) => {
- // check both the tool name (.Name) and the class name (.name)
- if (
- // coreTools contain tool name
- (!tools || tools.has(ToolClass.Name) || tools.has(ToolClass.name)) &&
- // excludeTools don't contain tool name
- (!excludeTools ||
- (!excludeTools.has(ToolClass.Name) &&
- !excludeTools.has(ToolClass.name)))
- ) {
+ const className = ToolClass.name;
+ const toolName = ToolClass.Name || className;
+ const coreTools = config.getCoreTools();
+ const excludeTools = config.getExcludeTools();
+
+ let isEnabled = false;
+ if (coreTools === undefined) {
+ isEnabled = true;
+ } else {
+ isEnabled = coreTools.some(
+ (tool) =>
+ tool === className ||
+ tool === toolName ||
+ tool.startsWith(`${className}(`) ||
+ tool.startsWith(`${toolName}(`),
+ );
+ }
+
+ if (excludeTools?.includes(className) || excludeTools?.includes(toolName)) {
+ isEnabled = false;
+ }
+
+ if (isEnabled) {
registry.registerTool(new ToolClass(...args));
}
};
diff --git a/packages/core/src/tools/shell.test.ts b/packages/core/src/tools/shell.test.ts
new file mode 100644
index 00000000..2cbd0ff4
--- /dev/null
+++ b/packages/core/src/tools/shell.test.ts
@@ -0,0 +1,171 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { expect, describe, it } from 'vitest';
+import { ShellTool } from './shell.js';
+import { Config } from '../config/config.js';
+
+describe('ShellTool', () => {
+ it('should allow a command if no restrictions are provided', async () => {
+ const config = {
+ getCoreTools: () => undefined,
+ getExcludeTools: () => undefined,
+ } as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('ls -l');
+ expect(isAllowed).toBe(true);
+ });
+
+ it('should allow a command if it is in the allowed list', async () => {
+ const config = {
+ getCoreTools: () => ['ShellTool(ls -l)'],
+ getExcludeTools: () => undefined,
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('ls -l');
+ expect(isAllowed).toBe(true);
+ });
+
+ it('should block a command if it is not in the allowed list', async () => {
+ const config = {
+ getCoreTools: () => ['ShellTool(ls -l)'],
+ getExcludeTools: () => undefined,
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('rm -rf /');
+ expect(isAllowed).toBe(false);
+ });
+
+ it('should block a command if it is in the blocked list', async () => {
+ const config = {
+ getCoreTools: () => undefined,
+ getExcludeTools: () => ['ShellTool(rm -rf /)'],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('rm -rf /');
+ expect(isAllowed).toBe(false);
+ });
+
+ it('should allow a command if it is not in the blocked list', async () => {
+ const config = {
+ getCoreTools: () => undefined,
+ getExcludeTools: () => ['ShellTool(rm -rf /)'],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('ls -l');
+ expect(isAllowed).toBe(true);
+ });
+
+ it('should block a command if it is in both the allowed and blocked lists', async () => {
+ const config = {
+ getCoreTools: () => ['ShellTool(rm -rf /)'],
+ getExcludeTools: () => ['ShellTool(rm -rf /)'],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('rm -rf /');
+ expect(isAllowed).toBe(false);
+ });
+
+ it('should allow any command when ShellTool is in coreTools without specific commands', async () => {
+ const config = {
+ getCoreTools: () => ['ShellTool'],
+ getExcludeTools: () => [],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('any command');
+ expect(isAllowed).toBe(true);
+ });
+
+ it('should block any command when ShellTool is in excludeTools without specific commands', async () => {
+ const config = {
+ getCoreTools: () => [],
+ getExcludeTools: () => ['ShellTool'],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('any command');
+ expect(isAllowed).toBe(false);
+ });
+
+ it('should allow a command if it is in the allowed list using the public-facing name', async () => {
+ const config = {
+ getCoreTools: () => ['run_shell_command(ls -l)'],
+ getExcludeTools: () => undefined,
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('ls -l');
+ expect(isAllowed).toBe(true);
+ });
+
+ it('should block a command if it is in the blocked list using the public-facing name', async () => {
+ const config = {
+ getCoreTools: () => undefined,
+ getExcludeTools: () => ['run_shell_command(rm -rf /)'],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('rm -rf /');
+ expect(isAllowed).toBe(false);
+ });
+
+ it('should block any command when ShellTool is in excludeTools using the public-facing name', async () => {
+ const config = {
+ getCoreTools: () => [],
+ getExcludeTools: () => ['run_shell_command'],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('any command');
+ expect(isAllowed).toBe(false);
+ });
+
+ it('should block any command if coreTools contains an empty ShellTool command list using the public-facing name', async () => {
+ const config = {
+ getCoreTools: () => ['run_shell_command()'],
+ getExcludeTools: () => [],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('any command');
+ expect(isAllowed).toBe(false);
+ });
+
+ it('should block any command if coreTools contains an empty ShellTool command list', async () => {
+ const config = {
+ getCoreTools: () => ['ShellTool()'],
+ getExcludeTools: () => [],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('any command');
+ expect(isAllowed).toBe(false);
+ });
+
+ it('should block a command with extra whitespace if it is in the blocked list', async () => {
+ const config = {
+ getCoreTools: () => undefined,
+ getExcludeTools: () => ['ShellTool(rm -rf /)'],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed(' rm -rf / ');
+ expect(isAllowed).toBe(false);
+ });
+
+ it('should allow any command when ShellTool is present with specific commands', async () => {
+ const config = {
+ getCoreTools: () => ['ShellTool', 'ShellTool(ls)'],
+ getExcludeTools: () => [],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('any command');
+ expect(isAllowed).toBe(true);
+ });
+
+ it('should block a command on the blocklist even with a wildcard allow', async () => {
+ const config = {
+ getCoreTools: () => ['ShellTool'],
+ getExcludeTools: () => ['ShellTool(rm -rf /)'],
+ } as unknown as Config;
+ const shellTool = new ShellTool(config);
+ const isAllowed = shellTool.isCommandAllowed('rm -rf /');
+ expect(isAllowed).toBe(false);
+ });
+});
diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts
index 1ca90768..a2fa5ce4 100644
--- a/packages/core/src/tools/shell.ts
+++ b/packages/core/src/tools/shell.ts
@@ -98,7 +98,67 @@ Process Group PGID: Process group started or \`(none)\``,
.pop(); // take last part and return command root (or undefined if previous line was empty)
}
+ isCommandAllowed(command: string): boolean {
+ const normalize = (cmd: string) => cmd.trim().replace(/\s+/g, ' ');
+
+ const extractCommands = (tools: string[]): string[] =>
+ tools.flatMap((tool) => {
+ if (tool.startsWith(`${ShellTool.name}(`) && tool.endsWith(')')) {
+ return [normalize(tool.slice(ShellTool.name.length + 1, -1))];
+ } else if (
+ tool.startsWith(`${ShellTool.Name}(`) &&
+ tool.endsWith(')')
+ ) {
+ return [normalize(tool.slice(ShellTool.Name.length + 1, -1))];
+ }
+ return [];
+ });
+
+ const coreTools = this.config.getCoreTools() || [];
+ const excludeTools = this.config.getExcludeTools() || [];
+
+ if (
+ excludeTools.includes(ShellTool.name) ||
+ excludeTools.includes(ShellTool.Name)
+ ) {
+ return false;
+ }
+
+ const blockedCommands = extractCommands(excludeTools);
+ const normalizedCommand = normalize(command);
+
+ if (blockedCommands.includes(normalizedCommand)) {
+ return false;
+ }
+
+ const hasSpecificCommands = coreTools.some(
+ (tool) =>
+ (tool.startsWith(`${ShellTool.name}(`) && tool.endsWith(')')) ||
+ (tool.startsWith(`${ShellTool.Name}(`) && tool.endsWith(')')),
+ );
+
+ if (hasSpecificCommands) {
+ // If the generic `ShellTool` is also present, it acts as a wildcard,
+ // allowing all commands (that are not explicitly blocked).
+ if (
+ coreTools.includes(ShellTool.name) ||
+ coreTools.includes(ShellTool.Name)
+ ) {
+ return true;
+ }
+
+ // Otherwise, we are in strict allow-list mode.
+ const allowedCommands = extractCommands(coreTools);
+ return allowedCommands.includes(normalizedCommand);
+ }
+
+ return true;
+ }
+
validateToolParams(params: ShellToolParams): string | null {
+ if (!this.isCommandAllowed(params.command)) {
+ return `Command is not allowed: ${params.command}`;
+ }
if (
!SchemaValidator.validate(
this.parameterSchema as Record<string, unknown>,