diff options
Diffstat (limited to 'packages/cli/src')
| -rw-r--r-- | packages/cli/src/ui/hooks/atCommandProcessor.ts | 4 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 115 |
2 files changed, 85 insertions, 34 deletions
diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts index 5ffa5383..a13a7d36 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts @@ -26,6 +26,7 @@ interface HandleAtCommandParams { addItem: UseHistoryManagerReturn['addItem']; setDebugMessage: React.Dispatch<React.SetStateAction<string>>; messageId: number; + signal: AbortSignal; } interface HandleAtCommandResult { @@ -90,6 +91,7 @@ export async function handleAtCommand({ addItem, setDebugMessage, messageId: userMessageTimestamp, + signal, }: HandleAtCommandParams): Promise<HandleAtCommandResult> { const trimmedQuery = query.trim(); const parsedCommand = parseAtCommand(trimmedQuery); @@ -163,7 +165,7 @@ export async function handleAtCommand({ let toolCallDisplay: IndividualToolCallDisplay; try { - const result = await readManyFilesTool.execute(toolArgs); + const result = await readManyFilesTool.execute(toolArgs, signal); const fileContent = result.llmContent || ''; toolCallDisplay = { diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 3f8cee40..e86ae0b9 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -89,7 +89,7 @@ export const useGeminiStream = ( }, [config, addItem]); useInput((_input, key) => { - if (streamingState === StreamingState.Responding && key.escape) { + if (streamingState !== StreamingState.Idle && key.escape) { abortControllerRef.current?.abort(); } }); @@ -104,6 +104,9 @@ export const useGeminiStream = ( setShowHelp(false); + abortControllerRef.current ??= new AbortController(); + const signal = abortControllerRef.current.signal; + if (typeof query === 'string') { const trimmedQuery = query.trim(); setDebugMessage(`User query: '${trimmedQuery}'`); @@ -120,6 +123,7 @@ export const useGeminiStream = ( addItem, setDebugMessage, messageId: userMessageTimestamp, + signal, }); if (!atCommandResult.shouldProceed) return; queryToSendToGemini = atCommandResult.processedQuery; @@ -165,9 +169,6 @@ export const useGeminiStream = ( const chat = chatSessionRef.current; try { - abortControllerRef.current = new AbortController(); - const signal = abortControllerRef.current.signal; - const stream = client.sendMessageStream( chat, queryToSendToGemini, @@ -294,7 +295,26 @@ export const useGeminiStream = ( } else if (event.type === ServerGeminiEventType.UserCancelled) { // Flush out existing pending history item. if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); + // If the pending item is a tool_group, update statuses to Canceled + if (pendingHistoryItemRef.current.type === 'tool_group') { + const updatedTools = pendingHistoryItemRef.current.tools.map( + (tool) => { + if ( + tool.status === ToolCallStatus.Pending || + tool.status === ToolCallStatus.Confirming || + tool.status === ToolCallStatus.Executing + ) { + return { ...tool, status: ToolCallStatus.Canceled }; + } + return tool; + }, + ); + const pendingHistoryItem = pendingHistoryItemRef.current; + pendingHistoryItem.tools = updatedTools; + addItem(pendingHistoryItem, userMessageTimestamp); + } else { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } setPendingHistoryItem(null); } addItem( @@ -412,6 +432,59 @@ export const useGeminiStream = ( } if (outcome === ToolConfirmationOutcome.Cancel) { + declineToolExecution( + 'User rejected function call.', + ToolCallStatus.Error, + ); + } else { + const tool = toolRegistry.getTool(request.name); + if (!tool) { + throw new Error( + `Tool "${request.name}" not found or is not registered.`, + ); + } + + try { + abortControllerRef.current = new AbortController(); + const result = await tool.execute( + request.args, + abortControllerRef.current.signal, + ); + + if (abortControllerRef.current.signal.aborted) { + declineToolExecution( + result.llmContent, + ToolCallStatus.Canceled, + ); + return; + } + + const functionResponse: Part = { + functionResponse: { + name: request.name, + id: request.callId, + response: { output: result.llmContent }, + }, + }; + + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay: result.returnDisplay, + error: undefined, + }; + updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + setStreamingState(StreamingState.Idle); + await submitQuery(functionResponse); + } finally { + abortControllerRef.current = null; + } + } + + function declineToolExecution( + declineMessage: string, + status: ToolCallStatus, + ) { let resultDisplay: ToolResultDisplay | undefined; if ('fileDiff' in originalConfirmationDetails) { resultDisplay = { @@ -426,43 +499,19 @@ export const useGeminiStream = ( functionResponse: { id: request.callId, name: request.name, - response: { error: 'User rejected function call.' }, + response: { error: declineMessage }, }, }; const responseInfo: ToolCallResponseInfo = { callId: request.callId, responsePart: functionResponse, resultDisplay, - error: new Error('User rejected function call.'), - }; - // Update UI to show cancellation/error - updateFunctionResponseUI(responseInfo, ToolCallStatus.Error); - setStreamingState(StreamingState.Idle); - } else { - const tool = toolRegistry.getTool(request.name); - if (!tool) { - throw new Error( - `Tool "${request.name}" not found or is not registered.`, - ); - } - const result = await tool.execute(request.args); - const functionResponse: Part = { - functionResponse: { - name: request.name, - id: request.callId, - response: { output: result.llmContent }, - }, + error: new Error(declineMessage), }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + // Update UI to show cancellation/error + updateFunctionResponseUI(responseInfo, status); setStreamingState(StreamingState.Idle); - await submitQuery(functionResponse); } }; |
