summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorN. Taylor Mullen <[email protected]>2025-06-08 15:42:49 -0700
committerGitHub <[email protected]>2025-06-08 22:42:49 +0000
commitf2ea78d0e4e5d25ab3cc25dc9f1492135630c9be (patch)
treecdc80f281095a279c1c1746a5b4c1fbfa008dc20
parent7868ef82299ae1da5a09334f67d57eb3b472563a (diff)
fix(tool-scheduler): Correctly pipe cancellation signal to tool calls (#852)
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.test.tsx91
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts57
-rw-r--r--packages/cli/src/ui/hooks/useReactToolScheduler.ts20
-rw-r--r--packages/cli/src/ui/hooks/useToolScheduler.test.ts118
-rw-r--r--packages/core/src/core/coreToolScheduler.test.ts105
-rw-r--r--packages/core/src/core/coreToolScheduler.ts46
-rw-r--r--packages/core/src/tools/shell.ts7
7 files changed, 235 insertions, 209 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
index 1335eb8e..f41f7f9c 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
+++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
@@ -18,6 +18,7 @@ import {
import { Config } from '@gemini-cli/core';
import { Part, PartListUnion } from '@google/genai';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
+import { Dispatch, SetStateAction } from 'react';
// --- MOCKS ---
const mockSendMessageStream = vi
@@ -309,16 +310,41 @@ describe('useGeminiStream', () => {
const client = geminiClient || mockConfig.getGeminiClient();
- const { result, rerender } = renderHook(() =>
- useGeminiStream(
- client,
- mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
- mockSetShowHelp,
- mockConfig,
- mockOnDebugMessage,
- mockHandleSlashCommand,
- false, // shellModeActive
- ),
+ const { result, rerender } = renderHook(
+ (props: {
+ client: any;
+ addItem: UseHistoryManagerReturn['addItem'];
+ setShowHelp: Dispatch<SetStateAction<boolean>>;
+ config: Config;
+ onDebugMessage: (message: string) => void;
+ handleSlashCommand: (
+ command: PartListUnion,
+ ) =>
+ | import('./slashCommandProcessor.js').SlashCommandActionReturn
+ | boolean;
+ shellModeActive: boolean;
+ }) =>
+ useGeminiStream(
+ props.client,
+ props.addItem,
+ props.setShowHelp,
+ props.config,
+ props.onDebugMessage,
+ props.handleSlashCommand,
+ props.shellModeActive,
+ ),
+ {
+ initialProps: {
+ client,
+ addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
+ setShowHelp: mockSetShowHelp,
+ config: mockConfig,
+ onDebugMessage: mockOnDebugMessage,
+ handleSlashCommand:
+ mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
+ shellModeActive: false,
+ },
+ },
);
return {
result,
@@ -326,7 +352,6 @@ describe('useGeminiStream', () => {
mockMarkToolsAsSubmitted,
mockSendMessageStream,
client,
- // mockFilter removed
};
};
@@ -423,24 +448,29 @@ describe('useGeminiStream', () => {
} as TrackedCancelledToolCall,
];
- const hookResult = await act(async () =>
- renderTestHook(simplifiedToolCalls),
- );
-
const {
+ rerender,
mockMarkToolsAsSubmitted,
mockSendMessageStream: localMockSendMessageStream,
- } = hookResult!;
-
- // It seems the initial render + effect run should be enough.
- // If rerender was for a specific state change, it might still be needed.
- // For now, let's test if the initial effect run (covered by the first act) is sufficient.
- // If not, we can add back: await act(async () => { rerender({}); });
+ client,
+ } = renderTestHook(simplifiedToolCalls);
- expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['call1', 'call2']);
+ act(() => {
+ rerender({
+ client,
+ addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
+ setShowHelp: mockSetShowHelp,
+ config: mockConfig,
+ onDebugMessage: mockOnDebugMessage,
+ handleSlashCommand:
+ mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
+ shellModeActive: false,
+ });
+ });
await waitFor(() => {
- expect(localMockSendMessageStream).toHaveBeenCalledTimes(1);
+ expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
+ expect(localMockSendMessageStream).toHaveBeenCalledTimes(0);
});
const expectedMergedResponse = mergePartListUnions([
@@ -479,12 +509,21 @@ describe('useGeminiStream', () => {
client,
);
- await act(async () => {
- rerender({} as any);
+ act(() => {
+ rerender({
+ client,
+ addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
+ setShowHelp: mockSetShowHelp,
+ config: mockConfig,
+ onDebugMessage: mockOnDebugMessage,
+ handleSlashCommand:
+ mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
+ shellModeActive: false,
+ });
});
await waitFor(() => {
- expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']);
+ expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
expect(client.addHistory).toHaveBeenCalledTimes(2);
expect(client.addHistory).toHaveBeenCalledWith({
role: 'user',
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 3b3d01e0..2b47ae6f 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -83,28 +83,24 @@ export const useGeminiStream = (
useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger();
- const [
- toolCalls,
- scheduleToolCalls,
- cancelAllToolCalls,
- markToolsAsSubmitted,
- ] = useReactToolScheduler(
- (completedToolCallsFromScheduler) => {
- // This onComplete is called when ALL scheduled tools for a given batch are done.
- if (completedToolCallsFromScheduler.length > 0) {
- // Add the final state of these tools to the history for display.
- // The new useEffect will handle submitting their responses.
- addItem(
- mapTrackedToolCallsToDisplay(
- completedToolCallsFromScheduler as TrackedToolCall[],
- ),
- Date.now(),
- );
- }
- },
- config,
- setPendingHistoryItem,
- );
+ const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
+ useReactToolScheduler(
+ (completedToolCallsFromScheduler) => {
+ // This onComplete is called when ALL scheduled tools for a given batch are done.
+ if (completedToolCallsFromScheduler.length > 0) {
+ // Add the final state of these tools to the history for display.
+ // The new useEffect will handle submitting their responses.
+ addItem(
+ mapTrackedToolCallsToDisplay(
+ completedToolCallsFromScheduler as TrackedToolCall[],
+ ),
+ Date.now(),
+ );
+ }
+ },
+ config,
+ setPendingHistoryItem,
+ );
const pendingToolCallGroupDisplay = useMemo(
() =>
@@ -143,10 +139,15 @@ export const useGeminiStream = (
return StreamingState.Idle;
}, [isResponding, toolCalls]);
+ useEffect(() => {
+ if (streamingState === StreamingState.Idle) {
+ abortControllerRef.current = null;
+ }
+ }, [streamingState]);
+
useInput((_input, key) => {
if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort();
- cancelAllToolCalls(); // Also cancel any pending/executing tool calls
}
});
@@ -191,7 +192,7 @@ export const useGeminiStream = (
name: toolName,
args: toolArgs,
};
- scheduleToolCalls([toolCallRequest]);
+ scheduleToolCalls([toolCallRequest], abortSignal);
}
return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool
}
@@ -330,9 +331,8 @@ export const useGeminiStream = (
userMessageTimestamp,
);
setIsResponding(false);
- cancelAllToolCalls();
},
- [addItem, pendingHistoryItemRef, setPendingHistoryItem, cancelAllToolCalls],
+ [addItem, pendingHistoryItemRef, setPendingHistoryItem],
);
const handleErrorEvent = useCallback(
@@ -365,6 +365,7 @@ export const useGeminiStream = (
async (
stream: AsyncIterable<GeminiEvent>,
userMessageTimestamp: number,
+ signal: AbortSignal,
): Promise<StreamProcessingStatus> => {
let geminiMessageBuffer = '';
const toolCallRequests: ToolCallRequestInfo[] = [];
@@ -401,7 +402,7 @@ export const useGeminiStream = (
}
}
if (toolCallRequests.length > 0) {
- scheduleToolCalls(toolCallRequests);
+ scheduleToolCalls(toolCallRequests, signal);
}
return StreamProcessingStatus.Completed;
},
@@ -453,6 +454,7 @@ export const useGeminiStream = (
const processingStatus = await processGeminiStreamEvents(
stream,
userMessageTimestamp,
+ abortSignal,
);
if (processingStatus === StreamProcessingStatus.UserCancelled) {
@@ -476,7 +478,6 @@ export const useGeminiStream = (
);
}
} finally {
- abortControllerRef.current = null; // Always reset
setIsResponding(false);
}
},
diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts
index 32a3e77f..ae58ed38 100644
--- a/packages/cli/src/ui/hooks/useReactToolScheduler.ts
+++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts
@@ -32,8 +32,8 @@ import {
export type ScheduleFn = (
request: ToolCallRequestInfo | ToolCallRequestInfo[],
+ signal: AbortSignal,
) => void;
-export type CancelFn = (reason?: string) => void;
export type MarkToolsAsSubmittedFn = (callIds: string[]) => void;
export type TrackedScheduledToolCall = ScheduledToolCall & {
@@ -69,7 +69,7 @@ export function useReactToolScheduler(
setPendingHistoryItem: React.Dispatch<
React.SetStateAction<HistoryItemWithoutId | null>
>,
-): [TrackedToolCall[], ScheduleFn, CancelFn, MarkToolsAsSubmittedFn] {
+): [TrackedToolCall[], ScheduleFn, MarkToolsAsSubmittedFn] {
const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
TrackedToolCall[]
>([]);
@@ -172,15 +172,11 @@ export function useReactToolScheduler(
);
const schedule: ScheduleFn = useCallback(
- async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => {
- scheduler.schedule(request);
- },
- [scheduler],
- );
-
- const cancel: CancelFn = useCallback(
- (reason: string = 'unspecified') => {
- scheduler.cancelAll(reason);
+ async (
+ request: ToolCallRequestInfo | ToolCallRequestInfo[],
+ signal: AbortSignal,
+ ) => {
+ scheduler.schedule(request, signal);
},
[scheduler],
);
@@ -198,7 +194,7 @@ export function useReactToolScheduler(
[],
);
- return [toolCallsForDisplay, schedule, cancel, markToolsAsSubmitted];
+ return [toolCallsForDisplay, schedule, markToolsAsSubmitted];
}
/**
diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts
index 1959b031..f5a3529c 100644
--- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts
+++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts
@@ -137,7 +137,7 @@ describe('useReactToolScheduler in YOLO Mode', () => {
};
act(() => {
- schedule(request);
+ schedule(request, new AbortController().signal);
});
await act(async () => {
@@ -290,7 +290,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
- schedule(request);
+ schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -337,7 +337,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
- schedule(request);
+ schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -374,7 +374,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
- schedule(request);
+ schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -410,7 +410,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
- schedule(request);
+ schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -451,7 +451,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
- schedule(request);
+ schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -507,7 +507,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
- schedule(request);
+ schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -579,7 +579,7 @@ describe('useReactToolScheduler', () => {
};
act(() => {
- schedule(request);
+ schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -634,102 +634,6 @@ describe('useReactToolScheduler', () => {
expect(result.current[0]).toEqual([]);
});
- it.skip('should cancel tool calls before execution (e.g. when status is scheduled)', async () => {
- mockToolRegistry.getTool.mockReturnValue(mockTool);
- (mockTool.shouldConfirmExecute as Mock).mockResolvedValue(null);
- (mockTool.execute as Mock).mockReturnValue(new Promise(() => {}));
-
- const { result } = renderScheduler();
- const schedule = result.current[1];
- const cancel = result.current[2];
- const request: ToolCallRequestInfo = {
- callId: 'cancelCall',
- name: 'mockTool',
- args: {},
- };
-
- act(() => {
- schedule(request);
- });
- await act(async () => {
- await vi.runAllTimersAsync();
- });
-
- act(() => {
- cancel();
- });
- await act(async () => {
- await vi.runAllTimersAsync();
- });
-
- expect(onComplete).toHaveBeenCalledWith([
- expect.objectContaining({
- status: 'cancelled',
- request,
- response: expect.objectContaining({
- responseParts: expect.arrayContaining([
- expect.objectContaining({
- functionResponse: expect.objectContaining({
- response: expect.objectContaining({
- error:
- '[Operation Cancelled] Reason: User cancelled before execution',
- }),
- }),
- }),
- ]),
- }),
- }),
- ]);
- expect(mockTool.execute).not.toHaveBeenCalled();
- expect(result.current[0]).toEqual([]);
- });
-
- it.skip('should cancel tool calls that are awaiting approval', async () => {
- mockToolRegistry.getTool.mockReturnValue(mockToolRequiresConfirmation);
- const { result } = renderScheduler();
- const schedule = result.current[1];
- const cancelFn = result.current[2];
- const request: ToolCallRequestInfo = {
- callId: 'cancelApprovalCall',
- name: 'mockToolRequiresConfirmation',
- args: {},
- };
-
- act(() => {
- schedule(request);
- });
- await act(async () => {
- await vi.runAllTimersAsync();
- });
-
- act(() => {
- cancelFn();
- });
- await act(async () => {
- await vi.runAllTimersAsync();
- });
-
- expect(onComplete).toHaveBeenCalledWith([
- expect.objectContaining({
- status: 'cancelled',
- request,
- response: expect.objectContaining({
- responseParts: expect.arrayContaining([
- expect.objectContaining({
- functionResponse: expect.objectContaining({
- response: expect.objectContaining({
- error:
- '[Operation Cancelled] Reason: User cancelled during approval',
- }),
- }),
- }),
- ]),
- }),
- }),
- ]);
- expect(result.current[0]).toEqual([]);
- });
-
it('should schedule and execute multiple tool calls', async () => {
const tool1 = {
...mockTool,
@@ -766,7 +670,7 @@ describe('useReactToolScheduler', () => {
];
act(() => {
- schedule(requests);
+ schedule(requests, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
@@ -848,13 +752,13 @@ describe('useReactToolScheduler', () => {
};
act(() => {
- schedule(request1);
+ schedule(request1, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
});
- expect(() => schedule(request2)).toThrow(
+ expect(() => schedule(request2, new AbortController().signal)).toThrow(
'Cannot schedule tool calls while other tool calls are running',
);
diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts
index be42bb24..1e8b2b2a 100644
--- a/packages/core/src/core/coreToolScheduler.test.ts
+++ b/packages/core/src/core/coreToolScheduler.test.ts
@@ -4,9 +4,110 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { describe, it, expect } from 'vitest';
-import { convertToFunctionResponse } from './coreToolScheduler.js';
+/* eslint-disable @typescript-eslint/no-explicit-any */
+import { describe, it, expect, vi } from 'vitest';
+import {
+ CoreToolScheduler,
+ ToolCall,
+ ValidatingToolCall,
+} from './coreToolScheduler.js';
+import {
+ BaseTool,
+ ToolCallConfirmationDetails,
+ ToolConfirmationOutcome,
+ ToolResult,
+} from '../index.js';
import { Part, PartListUnion } from '@google/genai';
+import { convertToFunctionResponse } from './coreToolScheduler.js';
+
+class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
+ shouldConfirm = false;
+ executeFn = vi.fn();
+
+ constructor(name = 'mockTool') {
+ super(name, name, 'A mock tool', {});
+ }
+
+ async shouldConfirmExecute(
+ _params: Record<string, unknown>,
+ _abortSignal: AbortSignal,
+ ): Promise<ToolCallConfirmationDetails | false> {
+ if (this.shouldConfirm) {
+ return {
+ type: 'exec',
+ title: 'Confirm Mock Tool',
+ command: 'do_thing',
+ rootCommand: 'do_thing',
+ onConfirm: async () => {},
+ };
+ }
+ return false;
+ }
+
+ async execute(
+ params: Record<string, unknown>,
+ _abortSignal: AbortSignal,
+ ): Promise<ToolResult> {
+ this.executeFn(params);
+ return { llmContent: 'Tool executed', returnDisplay: 'Tool executed' };
+ }
+}
+
+describe('CoreToolScheduler', () => {
+ it('should cancel a tool call if the signal is aborted before confirmation', async () => {
+ const mockTool = new MockTool();
+ mockTool.shouldConfirm = true;
+ const toolRegistry = {
+ getTool: () => mockTool,
+ getFunctionDeclarations: () => [],
+ tools: new Map(),
+ discovery: {} as any,
+ config: {} as any,
+ registerTool: () => {},
+ getToolByName: () => mockTool,
+ getToolByDisplayName: () => mockTool,
+ getTools: () => [],
+ discoverTools: async () => {},
+ getAllTools: () => [],
+ getToolsByServer: () => [],
+ };
+
+ const onAllToolCallsComplete = vi.fn();
+ const onToolCallsUpdate = vi.fn();
+
+ const scheduler = new CoreToolScheduler({
+ toolRegistry: Promise.resolve(toolRegistry as any),
+ onAllToolCallsComplete,
+ onToolCallsUpdate,
+ });
+
+ const abortController = new AbortController();
+ const request = { callId: '1', name: 'mockTool', args: {} };
+
+ abortController.abort();
+ await scheduler.schedule([request], abortController.signal);
+
+ const _waitingCall = onToolCallsUpdate.mock
+ .calls[1][0][0] as ValidatingToolCall;
+ const confirmationDetails = await mockTool.shouldConfirmExecute(
+ {},
+ abortController.signal,
+ );
+ if (confirmationDetails) {
+ await scheduler.handleConfirmationResponse(
+ '1',
+ confirmationDetails.onConfirm,
+ ToolConfirmationOutcome.ProceedOnce,
+ abortController.signal,
+ );
+ }
+
+ expect(onAllToolCallsComplete).toHaveBeenCalled();
+ const completedCalls = onAllToolCallsComplete.mock
+ .calls[0][0] as ToolCall[];
+ expect(completedCalls[0].status).toBe('cancelled');
+ });
+});
describe('convertToFunctionResponse', () => {
const toolName = 'testTool';
diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts
index 5278ae76..679d86aa 100644
--- a/packages/core/src/core/coreToolScheduler.ts
+++ b/packages/core/src/core/coreToolScheduler.ts
@@ -208,7 +208,6 @@ interface CoreToolSchedulerOptions {
export class CoreToolScheduler {
private toolRegistry: Promise<ToolRegistry>;
private toolCalls: ToolCall[] = [];
- private abortController: AbortController;
private outputUpdateHandler?: OutputUpdateHandler;
private onAllToolCallsComplete?: AllToolCallsCompleteHandler;
private onToolCallsUpdate?: ToolCallsUpdateHandler;
@@ -220,7 +219,6 @@ export class CoreToolScheduler {
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
this.onToolCallsUpdate = options.onToolCallsUpdate;
this.approvalMode = options.approvalMode ?? ApprovalMode.DEFAULT;
- this.abortController = new AbortController();
}
private setStatusInternal(
@@ -379,6 +377,7 @@ export class CoreToolScheduler {
async schedule(
request: ToolCallRequestInfo | ToolCallRequestInfo[],
+ signal: AbortSignal,
): Promise<void> {
if (this.isRunning()) {
throw new Error(
@@ -426,7 +425,7 @@ export class CoreToolScheduler {
} else {
const confirmationDetails = await toolInstance.shouldConfirmExecute(
reqInfo.args,
- this.abortController.signal,
+ signal,
);
if (confirmationDetails) {
@@ -438,6 +437,7 @@ export class CoreToolScheduler {
reqInfo.callId,
originalOnConfirm,
outcome,
+ signal,
),
};
this.setStatusInternal(
@@ -460,7 +460,7 @@ export class CoreToolScheduler {
);
}
}
- this.attemptExecutionOfScheduledCalls();
+ this.attemptExecutionOfScheduledCalls(signal);
this.checkAndNotifyCompletion();
}
@@ -468,6 +468,7 @@ export class CoreToolScheduler {
callId: string,
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
outcome: ToolConfirmationOutcome,
+ signal: AbortSignal,
): Promise<void> {
const toolCall = this.toolCalls.find(
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
@@ -477,7 +478,7 @@ export class CoreToolScheduler {
await originalOnConfirm(outcome);
}
- if (outcome === ToolConfirmationOutcome.Cancel) {
+ if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
this.setStatusInternal(
callId,
'cancelled',
@@ -497,7 +498,7 @@ export class CoreToolScheduler {
const modifyResults = await editTool.onModify(
waitingToolCall.request.args as unknown as EditToolParams,
- this.abortController.signal,
+ signal,
outcome,
);
@@ -513,10 +514,10 @@ export class CoreToolScheduler {
} else {
this.setStatusInternal(callId, 'scheduled');
}
- this.attemptExecutionOfScheduledCalls();
+ this.attemptExecutionOfScheduledCalls(signal);
}
- private attemptExecutionOfScheduledCalls(): void {
+ private attemptExecutionOfScheduledCalls(signal: AbortSignal): void {
const allCallsFinalOrScheduled = this.toolCalls.every(
(call) =>
call.status === 'scheduled' ||
@@ -553,17 +554,13 @@ export class CoreToolScheduler {
: undefined;
scheduledCall.tool
- .execute(
- scheduledCall.request.args,
- this.abortController.signal,
- liveOutputCallback,
- )
+ .execute(scheduledCall.request.args, signal, liveOutputCallback)
.then((toolResult: ToolResult) => {
- if (this.abortController.signal.aborted) {
+ if (signal.aborted) {
this.setStatusInternal(
callId,
'cancelled',
- this.abortController.signal.reason || 'Execution aborted.',
+ 'User cancelled tool execution.',
);
return;
}
@@ -613,29 +610,10 @@ export class CoreToolScheduler {
if (this.onAllToolCallsComplete) {
this.onAllToolCallsComplete(completedCalls);
}
- this.abortController = new AbortController();
this.notifyToolCallsUpdate();
}
}
- cancelAll(reason: string = 'User initiated cancellation.'): void {
- if (!this.abortController.signal.aborted) {
- this.abortController.abort(reason);
- }
- this.abortController = new AbortController();
-
- const callsToCancel = [...this.toolCalls];
- callsToCancel.forEach((call) => {
- if (
- call.status !== 'error' &&
- call.status !== 'success' &&
- call.status !== 'cancelled'
- ) {
- this.setStatusInternal(call.request.callId, 'cancelled', reason);
- }
- });
- }
-
private notifyToolCallsUpdate(): void {
if (this.onToolCallsUpdate) {
this.onToolCallsUpdate([...this.toolCalls]);
diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts
index 9ced00a4..caef67b9 100644
--- a/packages/core/src/tools/shell.ts
+++ b/packages/core/src/tools/shell.ts
@@ -162,6 +162,13 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
};
}
+ if (abortSignal.aborted) {
+ return {
+ llmContent: 'Command was cancelled by user before it could start.',
+ returnDisplay: 'Command cancelled by user.',
+ };
+ }
+
// wrap command to append subprocess pids (via pgrep) to temporary file
const tempFileName = `shell_pgrep_${crypto.randomBytes(6).toString('hex')}.tmp`;
const tempFilePath = path.join(os.tmpdir(), tempFileName);