summaryrefslogtreecommitdiff
path: root/packages/cli/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/cli/src')
-rw-r--r--packages/cli/src/ui/App.tsx17
-rw-r--r--packages/cli/src/ui/components/messages/ToolGroupMessage.tsx12
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts311
3 files changed, 97 insertions, 243 deletions
diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx
index 42613530..74c1ea5d 100644
--- a/packages/cli/src/ui/App.tsx
+++ b/packages/cli/src/ui/App.tsx
@@ -134,7 +134,7 @@ export const App = ({
cliVersion,
);
- const { streamingState, submitQuery, initError, pendingHistoryItem } =
+ const { streamingState, submitQuery, initError, pendingHistoryItems } =
useGeminiStream(
addItem,
refreshStatic,
@@ -209,7 +209,7 @@ export const App = ({
}, [terminalHeight, footerHeight]);
useEffect(() => {
- if (!pendingHistoryItem) {
+ if (!pendingHistoryItems.length) {
return;
}
@@ -223,7 +223,7 @@ export const App = ({
if (pendingItemDimensions.height > availableTerminalHeight) {
setStaticNeedsRefresh(true);
}
- }, [pendingHistoryItem, availableTerminalHeight, streamingState]);
+ }, [pendingHistoryItems.length, availableTerminalHeight, streamingState]);
useEffect(() => {
if (streamingState === StreamingState.Idle && staticNeedsRefresh) {
@@ -264,17 +264,18 @@ export const App = ({
>
{(item) => item}
</Static>
- {pendingHistoryItem && (
- <Box ref={pendingHistoryItemRef}>
+ <Box ref={pendingHistoryItemRef}>
+ {pendingHistoryItems.map((item, i) => (
<HistoryItemDisplay
+ key={i}
availableTerminalHeight={availableTerminalHeight}
// TODO(taehykim): It seems like references to ids aren't necessary in
// HistoryItemDisplay. Refactor later. Use a fake id for now.
- item={{ ...pendingHistoryItem, id: 0 }}
+ item={{ ...item, id: 0 }}
isPending={true}
/>
- </Box>
- )}
+ ))}
+ </Box>
{showHelp && <Help commands={slashCommands} />}
<Box flexDirection="column" ref={mainControlsRef}>
diff --git a/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx b/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx
index d0ad1c5f..4b2c7dfe 100644
--- a/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx
+++ b/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx
@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import React from 'react';
+import React, { useMemo } from 'react';
import { Box } from 'ink';
import { IndividualToolCallDisplay, ToolCallStatus } from '../../types.js';
import { ToolMessage } from './ToolMessage.js';
@@ -19,7 +19,6 @@ interface ToolGroupMessageProps {
// Main component renders the border and maps the tools using ToolMessage
export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
- groupId,
toolCalls,
availableTerminalHeight,
}) => {
@@ -30,9 +29,13 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
const staticHeight = /* border */ 2 + /* marginBottom */ 1;
+ const toolAwaitingApproval = useMemo(
+ () => toolCalls.find((tc) => tc.status === ToolCallStatus.Confirming),
+ [toolCalls],
+ );
+
return (
<Box
- key={groupId}
flexDirection="column"
borderStyle="round"
/*
@@ -48,7 +51,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
marginBottom={1}
>
{toolCalls.map((tool) => (
- <Box key={groupId + '-' + tool.callId} flexDirection="column">
+ <Box key={tool.callId} flexDirection="column">
<ToolMessage
key={tool.callId}
callId={tool.callId}
@@ -60,6 +63,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
availableTerminalHeight={availableTerminalHeight - staticHeight}
/>
{tool.status === ToolCallStatus.Confirming &&
+ tool.callId === toolAwaitingApproval?.callId &&
tool.confirmationDetails && (
<ToolConfirmationMessage
confirmationDetails={tool.confirmationDetails}
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index de7980d5..324a4ffa 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -4,34 +4,28 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { useState, useRef, useCallback, useEffect } from 'react';
+import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import { useInput } from 'ink';
import {
GeminiClient,
GeminiEventType as ServerGeminiEventType,
ServerGeminiStreamEvent as GeminiEvent,
ServerGeminiContentEvent as ContentEvent,
- ServerGeminiToolCallRequestEvent as ToolCallRequestEvent,
- ServerGeminiToolCallResponseEvent as ToolCallResponseEvent,
- ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent,
ServerGeminiErrorEvent as ErrorEvent,
getErrorMessage,
isNodeError,
Config,
MessageSenderType,
ServerToolCallConfirmationDetails,
- ToolCallConfirmationDetails,
ToolCallResponseInfo,
- ToolConfirmationOutcome,
ToolEditConfirmationDetails,
ToolExecuteConfirmationDetails,
ToolResultDisplay,
- partListUnionToString,
+ ToolCallRequestInfo,
} from '@gemini-code/server';
import { type Chat, type PartListUnion, type Part } from '@google/genai';
import {
StreamingState,
- IndividualToolCallDisplay,
ToolCallStatus,
HistoryItemWithoutId,
HistoryItemToolGroup,
@@ -44,6 +38,7 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { useLogger } from './useLogger.js';
+import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
enum StreamProcessingStatus {
Completed,
@@ -65,7 +60,6 @@ export const useGeminiStream = (
handleSlashCommand: (cmd: PartListUnion) => boolean,
shellModeActive: boolean,
) => {
- const toolRegistry = config.getToolRegistry();
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const chatSessionRef = useRef<Chat | null>(null);
@@ -74,6 +68,25 @@ export const useGeminiStream = (
const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger();
+ const [toolCalls, schedule, cancel] = useToolScheduler((tools) => {
+ if (tools.length) {
+ addItem(mapToDisplay(tools), Date.now());
+ submitQuery(
+ tools
+ .filter(
+ (t) =>
+ t.status === 'error' ||
+ t.status === 'cancelled' ||
+ t.status === 'success',
+ )
+ .map((t) => t.response.responsePart),
+ );
+ }
+ }, config);
+ const pendingToolCalls = useMemo(
+ () => (toolCalls.length ? mapToDisplay(toolCalls) : undefined),
+ [toolCalls],
+ );
const onExec = useCallback(async (done: Promise<void>) => {
setIsResponding(true);
@@ -104,6 +117,7 @@ export const useGeminiStream = (
useInput((_input, key) => {
if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort();
+ cancel();
}
});
@@ -215,157 +229,48 @@ export const useGeminiStream = (
);
};
- const updateConfirmingFunctionStatusUI = (
- callId: string,
- confirmationDetails: ToolCallConfirmationDetails | undefined,
- ) => {
- setPendingHistoryItem((item) =>
- item?.type === 'tool_group'
- ? {
- ...item,
- tools: item.tools.map((tool) =>
- tool.callId === callId
- ? {
- ...tool,
- status: ToolCallStatus.Confirming,
- confirmationDetails,
- }
- : tool,
- ),
- }
- : item,
- );
- };
-
- const wireConfirmationSubmission = (
- confirmationDetails: ServerToolCallConfirmationDetails,
- ): ToolCallConfirmationDetails => {
- const originalConfirmationDetails = confirmationDetails.details;
- const request = confirmationDetails.request;
- const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => {
- originalConfirmationDetails.onConfirm(outcome);
- if (pendingHistoryItemRef?.current?.type === 'tool_group') {
- setPendingHistoryItem((item) =>
- item?.type === 'tool_group'
- ? {
- ...item,
- tools: item.tools.map((tool) =>
- tool.callId === request.callId
- ? {
- ...tool,
- confirmationDetails: undefined,
- status: ToolCallStatus.Executing,
- }
- : tool,
- ),
- }
- : item,
- );
- refreshStatic();
- }
-
- if (outcome === ToolConfirmationOutcome.Cancel) {
- declineToolExecution(
- 'User rejected function call.',
- ToolCallStatus.Error,
- request,
- originalConfirmationDetails,
- );
- } else {
- const tool = toolRegistry.getTool(request.name);
- if (!tool) {
- throw new Error(
- `Tool "${request.name}" not found or is not registered.`,
- );
- }
- try {
- abortControllerRef.current = new AbortController();
- const result = await tool.execute(
- request.args,
- abortControllerRef.current.signal,
- );
- if (abortControllerRef.current.signal.aborted) {
- declineToolExecution(
- partListUnionToString(result.llmContent),
- ToolCallStatus.Canceled,
- request,
- originalConfirmationDetails,
- );
- return;
- }
-
- const functionResponse: Part = {
- functionResponse: {
- name: request.name,
- id: request.callId,
- response: { output: result.llmContent },
- },
- };
- const responseInfo: ToolCallResponseInfo = {
- callId: request.callId,
- responsePart: functionResponse,
- resultDisplay: result.returnDisplay,
- error: undefined,
- };
- updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
- if (pendingHistoryItemRef.current) {
- addItem(pendingHistoryItemRef.current, Date.now());
- setPendingHistoryItem(null);
- }
- setIsResponding(false);
- await submitQuery(functionResponse); // Recursive call
- } finally {
- if (streamingState !== StreamingState.WaitingForConfirmation) {
- abortControllerRef.current = null;
- }
- }
- }
- };
-
- // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
- // or could be a standalone helper if more params are passed.
- function declineToolExecution(
- declineMessage: string,
- status: ToolCallStatus,
- request: ServerToolCallConfirmationDetails['request'],
- originalDetails: ServerToolCallConfirmationDetails['details'],
- ) {
- let resultDisplay: ToolResultDisplay | undefined;
- if ('fileDiff' in originalDetails) {
- resultDisplay = {
- fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
- fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
- };
- } else {
- resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
- }
- const functionResponse: Part = {
- functionResponse: {
- id: request.callId,
- name: request.name,
- response: { error: declineMessage },
- },
- };
- const responseInfo: ToolCallResponseInfo = {
- callId: request.callId,
- responsePart: functionResponse,
- resultDisplay,
- error: new Error(declineMessage),
+ // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
+ // or could be a standalone helper if more params are passed.
+ // TODO: handle file diff result display stuff
+ function _declineToolExecution(
+ declineMessage: string,
+ status: ToolCallStatus,
+ request: ServerToolCallConfirmationDetails['request'],
+ originalDetails: ServerToolCallConfirmationDetails['details'],
+ ) {
+ let resultDisplay: ToolResultDisplay | undefined;
+ if ('fileDiff' in originalDetails) {
+ resultDisplay = {
+ fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
+ fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
};
- const history = chatSessionRef.current?.getHistory();
- if (history) {
- history.push({ role: 'model', parts: [functionResponse] });
- }
- updateFunctionResponseUI(responseInfo, status);
- if (pendingHistoryItemRef.current) {
- addItem(pendingHistoryItemRef.current, Date.now());
- setPendingHistoryItem(null);
- }
- setIsResponding(false);
+ } else {
+ resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
}
-
- return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm };
- };
+ const functionResponse: Part = {
+ functionResponse: {
+ id: request.callId,
+ name: request.name,
+ response: { error: declineMessage },
+ },
+ };
+ const responseInfo: ToolCallResponseInfo = {
+ callId: request.callId,
+ responsePart: functionResponse,
+ resultDisplay,
+ error: new Error(declineMessage),
+ };
+ const history = chatSessionRef.current?.getHistory();
+ if (history) {
+ history.push({ role: 'model', parts: [functionResponse] });
+ }
+ updateFunctionResponseUI(responseInfo, status);
+ if (pendingHistoryItemRef.current) {
+ addItem(pendingHistoryItemRef.current, Date.now());
+ setPendingHistoryItem(null);
+ }
+ setIsResponding(false);
+ }
// --- Stream Event Handlers ---
const handleContentEvent = (
@@ -419,62 +324,6 @@ export const useGeminiStream = (
return newGeminiMessageBuffer;
};
- const handleToolCallRequestEvent = (
- eventValue: ToolCallRequestEvent['value'],
- userMessageTimestamp: number,
- ) => {
- const { callId, name, args } = eventValue;
- const cliTool = toolRegistry.getTool(name);
- if (!cliTool) {
- console.error(`CLI Tool "${name}" not found!`);
- return; // Skip this event if tool is not found
- }
- if (pendingHistoryItemRef.current?.type !== 'tool_group') {
- if (pendingHistoryItemRef.current) {
- addItem(pendingHistoryItemRef.current, userMessageTimestamp);
- }
- setPendingHistoryItem({ type: 'tool_group', tools: [] });
- }
- let description: string;
- try {
- description = cliTool.getDescription(args);
- } catch (e) {
- description = `Error: Unable to get description: ${getErrorMessage(e)}`;
- }
- const toolCallDisplay: IndividualToolCallDisplay = {
- callId,
- name: cliTool.displayName,
- description,
- status: ToolCallStatus.Pending,
- resultDisplay: undefined,
- confirmationDetails: undefined,
- };
- setPendingHistoryItem((pending) =>
- pending?.type === 'tool_group'
- ? { ...pending, tools: [...pending.tools, toolCallDisplay] }
- : null,
- );
- };
-
- const handleToolCallResponseEvent = (
- eventValue: ToolCallResponseEvent['value'],
- ) => {
- const status = eventValue.error
- ? ToolCallStatus.Error
- : ToolCallStatus.Success;
- updateFunctionResponseUI(eventValue, status);
- };
-
- const handleToolCallConfirmationEvent = (
- eventValue: ToolCallConfirmationEvent['value'],
- ) => {
- const confirmationDetails = wireConfirmationSubmission(eventValue);
- updateConfirmingFunctionStatusUI(
- eventValue.request.callId,
- confirmationDetails,
- );
- };
-
const handleUserCancelledEvent = (userMessageTimestamp: number) => {
if (pendingHistoryItemRef.current) {
if (pendingHistoryItemRef.current.type === 'tool_group') {
@@ -500,6 +349,7 @@ export const useGeminiStream = (
userMessageTimestamp,
);
setIsResponding(false);
+ cancel();
};
const handleErrorEvent = (
@@ -521,7 +371,7 @@ export const useGeminiStream = (
userMessageTimestamp: number,
): Promise<StreamProcessingStatus> => {
let geminiMessageBuffer = '';
-
+ const toolCallRequests: ToolCallRequestInfo[] = [];
for await (const event of stream) {
if (event.type === ServerGeminiEventType.Content) {
geminiMessageBuffer = handleContentEvent(
@@ -530,12 +380,7 @@ export const useGeminiStream = (
userMessageTimestamp,
);
} else if (event.type === ServerGeminiEventType.ToolCallRequest) {
- handleToolCallRequestEvent(event.value, userMessageTimestamp);
- } else if (event.type === ServerGeminiEventType.ToolCallResponse) {
- handleToolCallResponseEvent(event.value);
- } else if (event.type === ServerGeminiEventType.ToolCallConfirmation) {
- handleToolCallConfirmationEvent(event.value);
- return StreamProcessingStatus.PausedForConfirmation;
+ toolCallRequests.push(event.value);
} else if (event.type === ServerGeminiEventType.UserCancelled) {
handleUserCancelledEvent(userMessageTimestamp);
return StreamProcessingStatus.UserCancelled;
@@ -544,9 +389,18 @@ export const useGeminiStream = (
return StreamProcessingStatus.Error;
}
}
+ schedule(toolCallRequests);
return StreamProcessingStatus.Completed;
};
+ const streamingState: StreamingState = isResponding
+ ? StreamingState.Responding
+ : pendingToolCalls?.tools.some(
+ (t) => t.status === ToolCallStatus.Confirming,
+ )
+ ? StreamingState.WaitingForConfirmation
+ : StreamingState.Idle;
+
const submitQuery = useCallback(
async (query: PartListUnion) => {
if (isResponding) return;
@@ -625,20 +479,15 @@ export const useGeminiStream = (
],
);
- const streamingState: StreamingState = isResponding
- ? StreamingState.Responding
- : pendingConfirmations(pendingHistoryItemRef.current)
- ? StreamingState.WaitingForConfirmation
- : StreamingState.Idle;
+ const pendingHistoryItems = [
+ pendingHistoryItemRef.current,
+ pendingToolCalls,
+ ].filter((i) => i !== undefined && i !== null);
return {
streamingState,
submitQuery,
initError,
- pendingHistoryItem: pendingHistoryItemRef.current,
+ pendingHistoryItems,
};
};
-
-const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean =>
- item?.type === 'tool_group' &&
- item.tools.some((t) => t.status === ToolCallStatus.Confirming);