diff options
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 115 |
1 files changed, 79 insertions, 36 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index fcfa1c57..09b14666 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -89,6 +89,7 @@ export const useGeminiStream = ( shellModeActive: boolean, getPreferredEditor: () => EditorType | undefined, onAuthError: () => void, + performMemoryRefresh: () => Promise<void>, ) => { const [initError, setInitError] = useState<string | null>(null); const abortControllerRef = useRef<AbortController | null>(null); @@ -97,6 +98,7 @@ export const useGeminiStream = ( const [thought, setThought] = useState<ThoughtSummary | null>(null); const [pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef<HistoryItemWithoutId | null>(null); + const processedMemoryToolsRef = useRef<Set<string>>(new Set()); const logger = useLogger(); const { startNewTurn, addUsage } = useSessionStats(); const gitService = useMemo(() => { @@ -234,6 +236,7 @@ export const useGeminiStream = ( callId: `${toolName}-${Date.now()}-${Math.random().toString(16).slice(2)}`, name: toolName, args: toolArgs, + isClientInitiated: true, }; scheduleToolCalls([toolCallRequest], abortSignal); } @@ -566,38 +569,77 @@ export const useGeminiStream = ( * is not already generating a response. */ useEffect(() => { - if (isResponding) { - return; - } + const run = async () => { + if (isResponding) { + return; + } - const completedAndReadyToSubmitTools = toolCalls.filter( - ( - tc: TrackedToolCall, - ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => { - const isTerminalState = - tc.status === 'success' || - tc.status === 'error' || - tc.status === 'cancelled'; + const completedAndReadyToSubmitTools = toolCalls.filter( + ( + tc: TrackedToolCall, + ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => { + const isTerminalState = + tc.status === 'success' || + tc.status === 'error' || + tc.status === 'cancelled'; - if (isTerminalState) { - const completedOrCancelledCall = tc as - | TrackedCompletedToolCall - | TrackedCancelledToolCall; - return ( - !completedOrCancelledCall.responseSubmittedToGemini && - completedOrCancelledCall.response?.responseParts !== undefined - ); - } - return false; - }, - ); + if (isTerminalState) { + const completedOrCancelledCall = tc as + | TrackedCompletedToolCall + | TrackedCancelledToolCall; + return ( + !completedOrCancelledCall.responseSubmittedToGemini && + completedOrCancelledCall.response?.responseParts !== undefined + ); + } + return false; + }, + ); + + // Finalize any client-initiated tools as soon as they are done. + const clientTools = completedAndReadyToSubmitTools.filter( + (t) => t.request.isClientInitiated, + ); + if (clientTools.length > 0) { + markToolsAsSubmitted(clientTools.map((t) => t.request.callId)); + } + + // Identify new, successful save_memory calls that we haven't processed yet. + const newSuccessfulMemorySaves = completedAndReadyToSubmitTools.filter( + (t) => + t.request.name === 'save_memory' && + t.status === 'success' && + !processedMemoryToolsRef.current.has(t.request.callId), + ); + + if (newSuccessfulMemorySaves.length > 0) { + // Perform the refresh only if there are new ones. + void performMemoryRefresh(); + // Mark them as processed so we don't do this again on the next render. + newSuccessfulMemorySaves.forEach((t) => + processedMemoryToolsRef.current.add(t.request.callId), + ); + } + + // Only proceed with submitting to Gemini if ALL tools are complete. + const allToolsAreComplete = + toolCalls.length > 0 && + toolCalls.length === completedAndReadyToSubmitTools.length; + + if (!allToolsAreComplete) { + return; + } + + const geminiTools = completedAndReadyToSubmitTools.filter( + (t) => !t.request.isClientInitiated, + ); + + if (geminiTools.length === 0) { + return; + } - if ( - completedAndReadyToSubmitTools.length > 0 && - completedAndReadyToSubmitTools.length === toolCalls.length - ) { // If all the tools were cancelled, don't submit a response to Gemini. - const allToolsCancelled = completedAndReadyToSubmitTools.every( + const allToolsCancelled = geminiTools.every( (tc) => tc.status === 'cancelled', ); @@ -605,7 +647,7 @@ export const useGeminiStream = ( if (geminiClient) { // We need to manually add the function responses to the history // so the model knows the tools were cancelled. - const responsesToAdd = completedAndReadyToSubmitTools.flatMap( + const responsesToAdd = geminiTools.flatMap( (toolCall) => toolCall.response.responseParts, ); for (const response of responsesToAdd) { @@ -624,18 +666,17 @@ export const useGeminiStream = ( } } - const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map( + const callIdsToMarkAsSubmitted = geminiTools.map( (toolCall) => toolCall.request.callId, ); markToolsAsSubmitted(callIdsToMarkAsSubmitted); return; } - const responsesToSend: PartListUnion[] = - completedAndReadyToSubmitTools.map( - (toolCall) => toolCall.response.responseParts, - ); - const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map( + const responsesToSend: PartListUnion[] = geminiTools.map( + (toolCall) => toolCall.response.responseParts, + ); + const callIdsToMarkAsSubmitted = geminiTools.map( (toolCall) => toolCall.request.callId, ); @@ -643,7 +684,8 @@ export const useGeminiStream = ( submitQuery(mergePartListUnions(responsesToSend), { isContinuation: true, }); - } + }; + void run(); }, [ toolCalls, isResponding, @@ -651,6 +693,7 @@ export const useGeminiStream = ( markToolsAsSubmitted, addItem, geminiClient, + performMemoryRefresh, ]); const pendingHistoryItems = [ |
