summaryrefslogtreecommitdiff
path: root/packages/cli/src/ui/hooks/useGeminiStream.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts311
1 files changed, 80 insertions, 231 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index de7980d5..324a4ffa 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -4,34 +4,28 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { useState, useRef, useCallback, useEffect } from 'react';
+import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import { useInput } from 'ink';
import {
GeminiClient,
GeminiEventType as ServerGeminiEventType,
ServerGeminiStreamEvent as GeminiEvent,
ServerGeminiContentEvent as ContentEvent,
- ServerGeminiToolCallRequestEvent as ToolCallRequestEvent,
- ServerGeminiToolCallResponseEvent as ToolCallResponseEvent,
- ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent,
ServerGeminiErrorEvent as ErrorEvent,
getErrorMessage,
isNodeError,
Config,
MessageSenderType,
ServerToolCallConfirmationDetails,
- ToolCallConfirmationDetails,
ToolCallResponseInfo,
- ToolConfirmationOutcome,
ToolEditConfirmationDetails,
ToolExecuteConfirmationDetails,
ToolResultDisplay,
- partListUnionToString,
+ ToolCallRequestInfo,
} from '@gemini-code/server';
import { type Chat, type PartListUnion, type Part } from '@google/genai';
import {
StreamingState,
- IndividualToolCallDisplay,
ToolCallStatus,
HistoryItemWithoutId,
HistoryItemToolGroup,
@@ -44,6 +38,7 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { useLogger } from './useLogger.js';
+import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
enum StreamProcessingStatus {
Completed,
@@ -65,7 +60,6 @@ export const useGeminiStream = (
handleSlashCommand: (cmd: PartListUnion) => boolean,
shellModeActive: boolean,
) => {
- const toolRegistry = config.getToolRegistry();
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const chatSessionRef = useRef<Chat | null>(null);
@@ -74,6 +68,25 @@ export const useGeminiStream = (
const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger();
+ const [toolCalls, schedule, cancel] = useToolScheduler((tools) => {
+ if (tools.length) {
+ addItem(mapToDisplay(tools), Date.now());
+ submitQuery(
+ tools
+ .filter(
+ (t) =>
+ t.status === 'error' ||
+ t.status === 'cancelled' ||
+ t.status === 'success',
+ )
+ .map((t) => t.response.responsePart),
+ );
+ }
+ }, config);
+ const pendingToolCalls = useMemo(
+ () => (toolCalls.length ? mapToDisplay(toolCalls) : undefined),
+ [toolCalls],
+ );
const onExec = useCallback(async (done: Promise<void>) => {
setIsResponding(true);
@@ -104,6 +117,7 @@ export const useGeminiStream = (
useInput((_input, key) => {
if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort();
+ cancel();
}
});
@@ -215,157 +229,48 @@ export const useGeminiStream = (
);
};
- const updateConfirmingFunctionStatusUI = (
- callId: string,
- confirmationDetails: ToolCallConfirmationDetails | undefined,
- ) => {
- setPendingHistoryItem((item) =>
- item?.type === 'tool_group'
- ? {
- ...item,
- tools: item.tools.map((tool) =>
- tool.callId === callId
- ? {
- ...tool,
- status: ToolCallStatus.Confirming,
- confirmationDetails,
- }
- : tool,
- ),
- }
- : item,
- );
- };
-
- const wireConfirmationSubmission = (
- confirmationDetails: ServerToolCallConfirmationDetails,
- ): ToolCallConfirmationDetails => {
- const originalConfirmationDetails = confirmationDetails.details;
- const request = confirmationDetails.request;
- const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => {
- originalConfirmationDetails.onConfirm(outcome);
- if (pendingHistoryItemRef?.current?.type === 'tool_group') {
- setPendingHistoryItem((item) =>
- item?.type === 'tool_group'
- ? {
- ...item,
- tools: item.tools.map((tool) =>
- tool.callId === request.callId
- ? {
- ...tool,
- confirmationDetails: undefined,
- status: ToolCallStatus.Executing,
- }
- : tool,
- ),
- }
- : item,
- );
- refreshStatic();
- }
-
- if (outcome === ToolConfirmationOutcome.Cancel) {
- declineToolExecution(
- 'User rejected function call.',
- ToolCallStatus.Error,
- request,
- originalConfirmationDetails,
- );
- } else {
- const tool = toolRegistry.getTool(request.name);
- if (!tool) {
- throw new Error(
- `Tool "${request.name}" not found or is not registered.`,
- );
- }
- try {
- abortControllerRef.current = new AbortController();
- const result = await tool.execute(
- request.args,
- abortControllerRef.current.signal,
- );
- if (abortControllerRef.current.signal.aborted) {
- declineToolExecution(
- partListUnionToString(result.llmContent),
- ToolCallStatus.Canceled,
- request,
- originalConfirmationDetails,
- );
- return;
- }
-
- const functionResponse: Part = {
- functionResponse: {
- name: request.name,
- id: request.callId,
- response: { output: result.llmContent },
- },
- };
- const responseInfo: ToolCallResponseInfo = {
- callId: request.callId,
- responsePart: functionResponse,
- resultDisplay: result.returnDisplay,
- error: undefined,
- };
- updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
- if (pendingHistoryItemRef.current) {
- addItem(pendingHistoryItemRef.current, Date.now());
- setPendingHistoryItem(null);
- }
- setIsResponding(false);
- await submitQuery(functionResponse); // Recursive call
- } finally {
- if (streamingState !== StreamingState.WaitingForConfirmation) {
- abortControllerRef.current = null;
- }
- }
- }
- };
-
- // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
- // or could be a standalone helper if more params are passed.
- function declineToolExecution(
- declineMessage: string,
- status: ToolCallStatus,
- request: ServerToolCallConfirmationDetails['request'],
- originalDetails: ServerToolCallConfirmationDetails['details'],
- ) {
- let resultDisplay: ToolResultDisplay | undefined;
- if ('fileDiff' in originalDetails) {
- resultDisplay = {
- fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
- fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
- };
- } else {
- resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
- }
- const functionResponse: Part = {
- functionResponse: {
- id: request.callId,
- name: request.name,
- response: { error: declineMessage },
- },
- };
- const responseInfo: ToolCallResponseInfo = {
- callId: request.callId,
- responsePart: functionResponse,
- resultDisplay,
- error: new Error(declineMessage),
+ // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
+ // or could be a standalone helper if more params are passed.
+ // TODO: handle file diff result display stuff
+ function _declineToolExecution(
+ declineMessage: string,
+ status: ToolCallStatus,
+ request: ServerToolCallConfirmationDetails['request'],
+ originalDetails: ServerToolCallConfirmationDetails['details'],
+ ) {
+ let resultDisplay: ToolResultDisplay | undefined;
+ if ('fileDiff' in originalDetails) {
+ resultDisplay = {
+ fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
+ fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
};
- const history = chatSessionRef.current?.getHistory();
- if (history) {
- history.push({ role: 'model', parts: [functionResponse] });
- }
- updateFunctionResponseUI(responseInfo, status);
- if (pendingHistoryItemRef.current) {
- addItem(pendingHistoryItemRef.current, Date.now());
- setPendingHistoryItem(null);
- }
- setIsResponding(false);
+ } else {
+ resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
}
-
- return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm };
- };
+ const functionResponse: Part = {
+ functionResponse: {
+ id: request.callId,
+ name: request.name,
+ response: { error: declineMessage },
+ },
+ };
+ const responseInfo: ToolCallResponseInfo = {
+ callId: request.callId,
+ responsePart: functionResponse,
+ resultDisplay,
+ error: new Error(declineMessage),
+ };
+ const history = chatSessionRef.current?.getHistory();
+ if (history) {
+ history.push({ role: 'model', parts: [functionResponse] });
+ }
+ updateFunctionResponseUI(responseInfo, status);
+ if (pendingHistoryItemRef.current) {
+ addItem(pendingHistoryItemRef.current, Date.now());
+ setPendingHistoryItem(null);
+ }
+ setIsResponding(false);
+ }
// --- Stream Event Handlers ---
const handleContentEvent = (
@@ -419,62 +324,6 @@ export const useGeminiStream = (
return newGeminiMessageBuffer;
};
- const handleToolCallRequestEvent = (
- eventValue: ToolCallRequestEvent['value'],
- userMessageTimestamp: number,
- ) => {
- const { callId, name, args } = eventValue;
- const cliTool = toolRegistry.getTool(name);
- if (!cliTool) {
- console.error(`CLI Tool "${name}" not found!`);
- return; // Skip this event if tool is not found
- }
- if (pendingHistoryItemRef.current?.type !== 'tool_group') {
- if (pendingHistoryItemRef.current) {
- addItem(pendingHistoryItemRef.current, userMessageTimestamp);
- }
- setPendingHistoryItem({ type: 'tool_group', tools: [] });
- }
- let description: string;
- try {
- description = cliTool.getDescription(args);
- } catch (e) {
- description = `Error: Unable to get description: ${getErrorMessage(e)}`;
- }
- const toolCallDisplay: IndividualToolCallDisplay = {
- callId,
- name: cliTool.displayName,
- description,
- status: ToolCallStatus.Pending,
- resultDisplay: undefined,
- confirmationDetails: undefined,
- };
- setPendingHistoryItem((pending) =>
- pending?.type === 'tool_group'
- ? { ...pending, tools: [...pending.tools, toolCallDisplay] }
- : null,
- );
- };
-
- const handleToolCallResponseEvent = (
- eventValue: ToolCallResponseEvent['value'],
- ) => {
- const status = eventValue.error
- ? ToolCallStatus.Error
- : ToolCallStatus.Success;
- updateFunctionResponseUI(eventValue, status);
- };
-
- const handleToolCallConfirmationEvent = (
- eventValue: ToolCallConfirmationEvent['value'],
- ) => {
- const confirmationDetails = wireConfirmationSubmission(eventValue);
- updateConfirmingFunctionStatusUI(
- eventValue.request.callId,
- confirmationDetails,
- );
- };
-
const handleUserCancelledEvent = (userMessageTimestamp: number) => {
if (pendingHistoryItemRef.current) {
if (pendingHistoryItemRef.current.type === 'tool_group') {
@@ -500,6 +349,7 @@ export const useGeminiStream = (
userMessageTimestamp,
);
setIsResponding(false);
+ cancel();
};
const handleErrorEvent = (
@@ -521,7 +371,7 @@ export const useGeminiStream = (
userMessageTimestamp: number,
): Promise<StreamProcessingStatus> => {
let geminiMessageBuffer = '';
-
+ const toolCallRequests: ToolCallRequestInfo[] = [];
for await (const event of stream) {
if (event.type === ServerGeminiEventType.Content) {
geminiMessageBuffer = handleContentEvent(
@@ -530,12 +380,7 @@ export const useGeminiStream = (
userMessageTimestamp,
);
} else if (event.type === ServerGeminiEventType.ToolCallRequest) {
- handleToolCallRequestEvent(event.value, userMessageTimestamp);
- } else if (event.type === ServerGeminiEventType.ToolCallResponse) {
- handleToolCallResponseEvent(event.value);
- } else if (event.type === ServerGeminiEventType.ToolCallConfirmation) {
- handleToolCallConfirmationEvent(event.value);
- return StreamProcessingStatus.PausedForConfirmation;
+ toolCallRequests.push(event.value);
} else if (event.type === ServerGeminiEventType.UserCancelled) {
handleUserCancelledEvent(userMessageTimestamp);
return StreamProcessingStatus.UserCancelled;
@@ -544,9 +389,18 @@ export const useGeminiStream = (
return StreamProcessingStatus.Error;
}
}
+ schedule(toolCallRequests);
return StreamProcessingStatus.Completed;
};
+ const streamingState: StreamingState = isResponding
+ ? StreamingState.Responding
+ : pendingToolCalls?.tools.some(
+ (t) => t.status === ToolCallStatus.Confirming,
+ )
+ ? StreamingState.WaitingForConfirmation
+ : StreamingState.Idle;
+
const submitQuery = useCallback(
async (query: PartListUnion) => {
if (isResponding) return;
@@ -625,20 +479,15 @@ export const useGeminiStream = (
],
);
- const streamingState: StreamingState = isResponding
- ? StreamingState.Responding
- : pendingConfirmations(pendingHistoryItemRef.current)
- ? StreamingState.WaitingForConfirmation
- : StreamingState.Idle;
+ const pendingHistoryItems = [
+ pendingHistoryItemRef.current,
+ pendingToolCalls,
+ ].filter((i) => i !== undefined && i !== null);
return {
streamingState,
submitQuery,
initError,
- pendingHistoryItem: pendingHistoryItemRef.current,
+ pendingHistoryItems,
};
};
-
-const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean =>
- item?.type === 'tool_group' &&
- item.tools.some((t) => t.status === ToolCallStatus.Confirming);