summaryrefslogtreecommitdiff
path: root/packages/cli/src/ui/hooks/useGeminiStream.ts
diff options
context:
space:
mode:
authorTaylor Mullen <[email protected]>2025-05-09 23:29:02 -0700
committerN. Taylor Mullen <[email protected]>2025-05-10 00:21:09 -0700
commit6b518dc9e4c601c0108768932dc1450c036075fd (patch)
treeaac19953db5a8cc2d1a68f46b51f1e5bef570e0e /packages/cli/src/ui/hooks/useGeminiStream.ts
parent090198a7d644f24c617bd35db6a287b930729280 (diff)
Enable tools to cancel active execution.
- Plumbed abort signals through to tools - Updated the shell tool to properly cancel active requests by killing the entire child process tree of the underlying shell process and then report that the shell itself was canceled. Fixes https://b.corp.google.com/issues/416829935
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts115
1 files changed, 82 insertions, 33 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 3f8cee40..e86ae0b9 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -89,7 +89,7 @@ export const useGeminiStream = (
}, [config, addItem]);
useInput((_input, key) => {
- if (streamingState === StreamingState.Responding && key.escape) {
+ if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort();
}
});
@@ -104,6 +104,9 @@ export const useGeminiStream = (
setShowHelp(false);
+ abortControllerRef.current ??= new AbortController();
+ const signal = abortControllerRef.current.signal;
+
if (typeof query === 'string') {
const trimmedQuery = query.trim();
setDebugMessage(`User query: '${trimmedQuery}'`);
@@ -120,6 +123,7 @@ export const useGeminiStream = (
addItem,
setDebugMessage,
messageId: userMessageTimestamp,
+ signal,
});
if (!atCommandResult.shouldProceed) return;
queryToSendToGemini = atCommandResult.processedQuery;
@@ -165,9 +169,6 @@ export const useGeminiStream = (
const chat = chatSessionRef.current;
try {
- abortControllerRef.current = new AbortController();
- const signal = abortControllerRef.current.signal;
-
const stream = client.sendMessageStream(
chat,
queryToSendToGemini,
@@ -294,7 +295,26 @@ export const useGeminiStream = (
} else if (event.type === ServerGeminiEventType.UserCancelled) {
// Flush out existing pending history item.
if (pendingHistoryItemRef.current) {
- addItem(pendingHistoryItemRef.current, userMessageTimestamp);
+ // If the pending item is a tool_group, update statuses to Canceled
+ if (pendingHistoryItemRef.current.type === 'tool_group') {
+ const updatedTools = pendingHistoryItemRef.current.tools.map(
+ (tool) => {
+ if (
+ tool.status === ToolCallStatus.Pending ||
+ tool.status === ToolCallStatus.Confirming ||
+ tool.status === ToolCallStatus.Executing
+ ) {
+ return { ...tool, status: ToolCallStatus.Canceled };
+ }
+ return tool;
+ },
+ );
+ const pendingHistoryItem = pendingHistoryItemRef.current;
+ pendingHistoryItem.tools = updatedTools;
+ addItem(pendingHistoryItem, userMessageTimestamp);
+ } else {
+ addItem(pendingHistoryItemRef.current, userMessageTimestamp);
+ }
setPendingHistoryItem(null);
}
addItem(
@@ -412,6 +432,59 @@ export const useGeminiStream = (
}
if (outcome === ToolConfirmationOutcome.Cancel) {
+ declineToolExecution(
+ 'User rejected function call.',
+ ToolCallStatus.Error,
+ );
+ } else {
+ const tool = toolRegistry.getTool(request.name);
+ if (!tool) {
+ throw new Error(
+ `Tool "${request.name}" not found or is not registered.`,
+ );
+ }
+
+ try {
+ abortControllerRef.current = new AbortController();
+ const result = await tool.execute(
+ request.args,
+ abortControllerRef.current.signal,
+ );
+
+ if (abortControllerRef.current.signal.aborted) {
+ declineToolExecution(
+ result.llmContent,
+ ToolCallStatus.Canceled,
+ );
+ return;
+ }
+
+ const functionResponse: Part = {
+ functionResponse: {
+ name: request.name,
+ id: request.callId,
+ response: { output: result.llmContent },
+ },
+ };
+
+ const responseInfo: ToolCallResponseInfo = {
+ callId: request.callId,
+ responsePart: functionResponse,
+ resultDisplay: result.returnDisplay,
+ error: undefined,
+ };
+ updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
+ setStreamingState(StreamingState.Idle);
+ await submitQuery(functionResponse);
+ } finally {
+ abortControllerRef.current = null;
+ }
+ }
+
+ function declineToolExecution(
+ declineMessage: string,
+ status: ToolCallStatus,
+ ) {
let resultDisplay: ToolResultDisplay | undefined;
if ('fileDiff' in originalConfirmationDetails) {
resultDisplay = {
@@ -426,43 +499,19 @@ export const useGeminiStream = (
functionResponse: {
id: request.callId,
name: request.name,
- response: { error: 'User rejected function call.' },
+ response: { error: declineMessage },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay,
- error: new Error('User rejected function call.'),
- };
- // Update UI to show cancellation/error
- updateFunctionResponseUI(responseInfo, ToolCallStatus.Error);
- setStreamingState(StreamingState.Idle);
- } else {
- const tool = toolRegistry.getTool(request.name);
- if (!tool) {
- throw new Error(
- `Tool "${request.name}" not found or is not registered.`,
- );
- }
- const result = await tool.execute(request.args);
- const functionResponse: Part = {
- functionResponse: {
- name: request.name,
- id: request.callId,
- response: { output: result.llmContent },
- },
+ error: new Error(declineMessage),
};
- const responseInfo: ToolCallResponseInfo = {
- callId: request.callId,
- responsePart: functionResponse,
- resultDisplay: result.returnDisplay,
- error: undefined,
- };
- updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
+ // Update UI to show cancellation/error
+ updateFunctionResponseUI(responseInfo, status);
setStreamingState(StreamingState.Idle);
- await submitQuery(functionResponse);
}
};