summaryrefslogtreecommitdiff
path: root/packages/cli/src/ui/hooks/useGeminiStream.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts115
1 files changed, 79 insertions, 36 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index fcfa1c57..09b14666 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -89,6 +89,7 @@ export const useGeminiStream = (
shellModeActive: boolean,
getPreferredEditor: () => EditorType | undefined,
onAuthError: () => void,
+ performMemoryRefresh: () => Promise<void>,
) => {
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
@@ -97,6 +98,7 @@ export const useGeminiStream = (
const [thought, setThought] = useState<ThoughtSummary | null>(null);
const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
+ const processedMemoryToolsRef = useRef<Set<string>>(new Set());
const logger = useLogger();
const { startNewTurn, addUsage } = useSessionStats();
const gitService = useMemo(() => {
@@ -234,6 +236,7 @@ export const useGeminiStream = (
callId: `${toolName}-${Date.now()}-${Math.random().toString(16).slice(2)}`,
name: toolName,
args: toolArgs,
+ isClientInitiated: true,
};
scheduleToolCalls([toolCallRequest], abortSignal);
}
@@ -566,38 +569,77 @@ export const useGeminiStream = (
* is not already generating a response.
*/
useEffect(() => {
- if (isResponding) {
- return;
- }
+ const run = async () => {
+ if (isResponding) {
+ return;
+ }
- const completedAndReadyToSubmitTools = toolCalls.filter(
- (
- tc: TrackedToolCall,
- ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => {
- const isTerminalState =
- tc.status === 'success' ||
- tc.status === 'error' ||
- tc.status === 'cancelled';
+ const completedAndReadyToSubmitTools = toolCalls.filter(
+ (
+ tc: TrackedToolCall,
+ ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => {
+ const isTerminalState =
+ tc.status === 'success' ||
+ tc.status === 'error' ||
+ tc.status === 'cancelled';
- if (isTerminalState) {
- const completedOrCancelledCall = tc as
- | TrackedCompletedToolCall
- | TrackedCancelledToolCall;
- return (
- !completedOrCancelledCall.responseSubmittedToGemini &&
- completedOrCancelledCall.response?.responseParts !== undefined
- );
- }
- return false;
- },
- );
+ if (isTerminalState) {
+ const completedOrCancelledCall = tc as
+ | TrackedCompletedToolCall
+ | TrackedCancelledToolCall;
+ return (
+ !completedOrCancelledCall.responseSubmittedToGemini &&
+ completedOrCancelledCall.response?.responseParts !== undefined
+ );
+ }
+ return false;
+ },
+ );
+
+ // Finalize any client-initiated tools as soon as they are done.
+ const clientTools = completedAndReadyToSubmitTools.filter(
+ (t) => t.request.isClientInitiated,
+ );
+ if (clientTools.length > 0) {
+ markToolsAsSubmitted(clientTools.map((t) => t.request.callId));
+ }
+
+ // Identify new, successful save_memory calls that we haven't processed yet.
+ const newSuccessfulMemorySaves = completedAndReadyToSubmitTools.filter(
+ (t) =>
+ t.request.name === 'save_memory' &&
+ t.status === 'success' &&
+ !processedMemoryToolsRef.current.has(t.request.callId),
+ );
+
+ if (newSuccessfulMemorySaves.length > 0) {
+ // Perform the refresh only if there are new ones.
+ void performMemoryRefresh();
+ // Mark them as processed so we don't do this again on the next render.
+ newSuccessfulMemorySaves.forEach((t) =>
+ processedMemoryToolsRef.current.add(t.request.callId),
+ );
+ }
+
+ // Only proceed with submitting to Gemini if ALL tools are complete.
+ const allToolsAreComplete =
+ toolCalls.length > 0 &&
+ toolCalls.length === completedAndReadyToSubmitTools.length;
+
+ if (!allToolsAreComplete) {
+ return;
+ }
+
+ const geminiTools = completedAndReadyToSubmitTools.filter(
+ (t) => !t.request.isClientInitiated,
+ );
+
+ if (geminiTools.length === 0) {
+ return;
+ }
- if (
- completedAndReadyToSubmitTools.length > 0 &&
- completedAndReadyToSubmitTools.length === toolCalls.length
- ) {
// If all the tools were cancelled, don't submit a response to Gemini.
- const allToolsCancelled = completedAndReadyToSubmitTools.every(
+ const allToolsCancelled = geminiTools.every(
(tc) => tc.status === 'cancelled',
);
@@ -605,7 +647,7 @@ export const useGeminiStream = (
if (geminiClient) {
// We need to manually add the function responses to the history
// so the model knows the tools were cancelled.
- const responsesToAdd = completedAndReadyToSubmitTools.flatMap(
+ const responsesToAdd = geminiTools.flatMap(
(toolCall) => toolCall.response.responseParts,
);
for (const response of responsesToAdd) {
@@ -624,18 +666,17 @@ export const useGeminiStream = (
}
}
- const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map(
+ const callIdsToMarkAsSubmitted = geminiTools.map(
(toolCall) => toolCall.request.callId,
);
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
return;
}
- const responsesToSend: PartListUnion[] =
- completedAndReadyToSubmitTools.map(
- (toolCall) => toolCall.response.responseParts,
- );
- const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map(
+ const responsesToSend: PartListUnion[] = geminiTools.map(
+ (toolCall) => toolCall.response.responseParts,
+ );
+ const callIdsToMarkAsSubmitted = geminiTools.map(
(toolCall) => toolCall.request.callId,
);
@@ -643,7 +684,8 @@ export const useGeminiStream = (
submitQuery(mergePartListUnions(responsesToSend), {
isContinuation: true,
});
- }
+ };
+ void run();
}, [
toolCalls,
isResponding,
@@ -651,6 +693,7 @@ export const useGeminiStream = (
markToolsAsSubmitted,
addItem,
geminiClient,
+ performMemoryRefresh,
]);
const pendingHistoryItems = [