diff options
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 43 |
1 files changed, 37 insertions, 6 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index d32c9ffa..b82b0cb2 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -53,6 +53,7 @@ import { TrackedCompletedToolCall, TrackedCancelledToolCall, } from './useReactToolScheduler.js'; +import { useSessionStats } from '../contexts/SessionContext.js'; export function mergePartListUnions(list: PartListUnion[]): PartListUnion { const resultParts: PartListUnion = []; @@ -101,6 +102,7 @@ export const useGeminiStream = ( const [pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef<HistoryItemWithoutId | null>(null); const processedMemoryToolsRef = useRef<Set<string>>(new Set()); + const { startNewPrompt, getPromptCount } = useSessionStats(); const logger = useLogger(); const gitService = useMemo(() => { if (!config.getProjectRoot()) { @@ -203,6 +205,7 @@ export const useGeminiStream = ( query: PartListUnion, userMessageTimestamp: number, abortSignal: AbortSignal, + prompt_id: string, ): Promise<{ queryToSend: PartListUnion | null; shouldProceed: boolean; @@ -220,7 +223,7 @@ export const useGeminiStream = ( const trimmedQuery = query.trim(); logUserPrompt( config, - new UserPromptEvent(trimmedQuery.length, trimmedQuery), + new UserPromptEvent(trimmedQuery.length, prompt_id, trimmedQuery), ); onDebugMessage(`User query: '${trimmedQuery}'`); await logger?.logMessage(MessageSenderType.USER, trimmedQuery); @@ -236,6 +239,7 @@ export const useGeminiStream = ( name: toolName, args: toolArgs, isClientInitiated: true, + prompt_id, }; scheduleToolCalls([toolCallRequest], abortSignal); } @@ -485,7 +489,11 @@ export const useGeminiStream = ( ); const submitQuery = useCallback( - async (query: PartListUnion, options?: { isContinuation: boolean }) => { + async ( + query: PartListUnion, + options?: { isContinuation: boolean }, + prompt_id?: string, + ) => { if ( (streamingState === StreamingState.Responding || streamingState === StreamingState.WaitingForConfirmation) && @@ -506,21 +514,34 @@ export const useGeminiStream = ( const abortSignal = abortControllerRef.current.signal; turnCancelledRef.current = false; + if (!prompt_id) { + prompt_id = config.getSessionId() + '########' + getPromptCount(); + } + const { queryToSend, shouldProceed } = await prepareQueryForGemini( query, userMessageTimestamp, abortSignal, + prompt_id!, ); if (!shouldProceed || queryToSend === null) { return; } + if (!options?.isContinuation) { + startNewPrompt(); + } + setIsResponding(true); setInitError(null); try { - const stream = geminiClient.sendMessageStream(queryToSend, abortSignal); + const stream = geminiClient.sendMessageStream( + queryToSend, + abortSignal, + prompt_id!, + ); const processingStatus = await processGeminiStreamEvents( stream, userMessageTimestamp, @@ -570,6 +591,8 @@ export const useGeminiStream = ( geminiClient, onAuthError, config, + startNewPrompt, + getPromptCount, ], ); @@ -676,6 +699,10 @@ export const useGeminiStream = ( (toolCall) => toolCall.request.callId, ); + const prompt_ids = geminiTools.map( + (toolCall) => toolCall.request.prompt_id, + ); + markToolsAsSubmitted(callIdsToMarkAsSubmitted); // Don't continue if model was switched due to quota error @@ -683,9 +710,13 @@ export const useGeminiStream = ( return; } - submitQuery(mergePartListUnions(responsesToSend), { - isContinuation: true, - }); + submitQuery( + mergePartListUnions(responsesToSend), + { + isContinuation: true, + }, + prompt_ids[0], + ); }, [ isResponding, |
