diff options
Diffstat (limited to 'packages/cli/src/ui/hooks')
4 files changed, 66 insertions, 12 deletions
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 137098df..45f52074 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -159,7 +159,7 @@ describe('useSlashCommandProcessor', () => { stats: { sessionStartTime: new Date('2025-01-01T00:00:00.000Z'), cumulative: { - turnCount: 0, + promptCount: 0, promptTokenCount: 0, candidatesTokenCount: 0, totalTokenCount: 0, @@ -1311,7 +1311,10 @@ describe('useSlashCommandProcessor', () => { hook.rerender(); }); expect(hook.result.current.pendingHistoryItems).toEqual([]); - expect(mockGeminiClient.tryCompressChat).toHaveBeenCalledWith(true); + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalledWith( + 'Prompt Id not set', + true, + ); expect(mockAddItem).toHaveBeenNthCalledWith( 2, expect.objectContaining({ diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index 66cf4e39..f53bdc12 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -880,7 +880,8 @@ export const useSlashCommandProcessor = ( try { const compressed = await config! .getGeminiClient()! - .tryCompressChat(true); + // TODO: Set Prompt id for CompressChat from SlashCommandProcessor. + .tryCompressChat('Prompt Id not set', true); if (compressed) { addMessage({ type: MessageType.COMPRESSION, diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 62ade50f..e0e21f55 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -109,12 +109,13 @@ vi.mock('./useLogger.js', () => ({ }), })); -const mockStartNewTurn = vi.fn(); +const mockStartNewPrompt = vi.fn(); const mockAddUsage = vi.fn(); vi.mock('../contexts/SessionContext.js', () => ({ useSessionStats: vi.fn(() => ({ - startNewTurn: mockStartNewTurn, + startNewPrompt: mockStartNewPrompt, addUsage: mockAddUsage, + getPromptCount: vi.fn(() => 5), })), })); @@ -301,6 +302,9 @@ describe('useGeminiStream', () => { getUsageStatisticsEnabled: () => true, getDebugMode: () => false, addHistory: vi.fn(), + getSessionId() { + return 'test-session-id'; + }, setQuotaErrorOccurred: vi.fn(), getQuotaErrorOccurred: vi.fn(() => false), } as unknown as Config; @@ -426,6 +430,7 @@ describe('useGeminiStream', () => { name: 'tool1', args: {}, isClientInitiated: false, + prompt_id: 'prompt-id-1', }, status: 'success', responseSubmittedToGemini: false, @@ -444,7 +449,12 @@ describe('useGeminiStream', () => { endTime: Date.now(), } as TrackedCompletedToolCall, { - request: { callId: 'call2', name: 'tool2', args: {} }, + request: { + callId: 'call2', + name: 'tool2', + args: {}, + prompt_id: 'prompt-id-1', + }, status: 'executing', responseSubmittedToGemini: false, tool: { @@ -481,6 +491,7 @@ describe('useGeminiStream', () => { name: 'tool1', args: {}, isClientInitiated: false, + prompt_id: 'prompt-id-2', }, status: 'success', responseSubmittedToGemini: false, @@ -492,6 +503,7 @@ describe('useGeminiStream', () => { name: 'tool2', args: {}, isClientInitiated: false, + prompt_id: 'prompt-id-2', }, status: 'error', responseSubmittedToGemini: false, @@ -546,6 +558,7 @@ describe('useGeminiStream', () => { expect(mockSendMessageStream).toHaveBeenCalledWith( expectedMergedResponse, expect.any(AbortSignal), + 'prompt-id-2', ); }); @@ -557,6 +570,7 @@ describe('useGeminiStream', () => { name: 'testTool', args: {}, isClientInitiated: false, + prompt_id: 'prompt-id-3', }, status: 'cancelled', response: { callId: '1', responseParts: [{ text: 'cancelled' }] }, @@ -618,6 +632,7 @@ describe('useGeminiStream', () => { name: 'toolA', args: {}, isClientInitiated: false, + prompt_id: 'prompt-id-7', }, tool: { name: 'toolA', @@ -641,6 +656,7 @@ describe('useGeminiStream', () => { name: 'toolB', args: {}, isClientInitiated: false, + prompt_id: 'prompt-id-8', }, tool: { name: 'toolB', @@ -731,6 +747,7 @@ describe('useGeminiStream', () => { name: 'tool1', args: {}, isClientInitiated: false, + prompt_id: 'prompt-id-4', }, status: 'executing', responseSubmittedToGemini: false, @@ -824,6 +841,7 @@ describe('useGeminiStream', () => { expect(mockSendMessageStream).toHaveBeenCalledWith( toolCallResponseParts, expect.any(AbortSignal), + 'prompt-id-4', ); }); @@ -1036,6 +1054,7 @@ describe('useGeminiStream', () => { name: 'save_memory', args: { fact: 'test' }, isClientInitiated: true, + prompt_id: 'prompt-id-6', }, status: 'success', responseSubmittedToGemini: false, diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index d32c9ffa..b82b0cb2 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -53,6 +53,7 @@ import { TrackedCompletedToolCall, TrackedCancelledToolCall, } from './useReactToolScheduler.js'; +import { useSessionStats } from '../contexts/SessionContext.js'; export function mergePartListUnions(list: PartListUnion[]): PartListUnion { const resultParts: PartListUnion = []; @@ -101,6 +102,7 @@ export const useGeminiStream = ( const [pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef<HistoryItemWithoutId | null>(null); const processedMemoryToolsRef = useRef<Set<string>>(new Set()); + const { startNewPrompt, getPromptCount } = useSessionStats(); const logger = useLogger(); const gitService = useMemo(() => { if (!config.getProjectRoot()) { @@ -203,6 +205,7 @@ export const useGeminiStream = ( query: PartListUnion, userMessageTimestamp: number, abortSignal: AbortSignal, + prompt_id: string, ): Promise<{ queryToSend: PartListUnion | null; shouldProceed: boolean; @@ -220,7 +223,7 @@ export const useGeminiStream = ( const trimmedQuery = query.trim(); logUserPrompt( config, - new UserPromptEvent(trimmedQuery.length, trimmedQuery), + new UserPromptEvent(trimmedQuery.length, prompt_id, trimmedQuery), ); onDebugMessage(`User query: '${trimmedQuery}'`); await logger?.logMessage(MessageSenderType.USER, trimmedQuery); @@ -236,6 +239,7 @@ export const useGeminiStream = ( name: toolName, args: toolArgs, isClientInitiated: true, + prompt_id, }; scheduleToolCalls([toolCallRequest], abortSignal); } @@ -485,7 +489,11 @@ export const useGeminiStream = ( ); const submitQuery = useCallback( - async (query: PartListUnion, options?: { isContinuation: boolean }) => { + async ( + query: PartListUnion, + options?: { isContinuation: boolean }, + prompt_id?: string, + ) => { if ( (streamingState === StreamingState.Responding || streamingState === StreamingState.WaitingForConfirmation) && @@ -506,21 +514,34 @@ export const useGeminiStream = ( const abortSignal = abortControllerRef.current.signal; turnCancelledRef.current = false; + if (!prompt_id) { + prompt_id = config.getSessionId() + '########' + getPromptCount(); + } + const { queryToSend, shouldProceed } = await prepareQueryForGemini( query, userMessageTimestamp, abortSignal, + prompt_id!, ); if (!shouldProceed || queryToSend === null) { return; } + if (!options?.isContinuation) { + startNewPrompt(); + } + setIsResponding(true); setInitError(null); try { - const stream = geminiClient.sendMessageStream(queryToSend, abortSignal); + const stream = geminiClient.sendMessageStream( + queryToSend, + abortSignal, + prompt_id!, + ); const processingStatus = await processGeminiStreamEvents( stream, userMessageTimestamp, @@ -570,6 +591,8 @@ export const useGeminiStream = ( geminiClient, onAuthError, config, + startNewPrompt, + getPromptCount, ], ); @@ -676,6 +699,10 @@ export const useGeminiStream = ( (toolCall) => toolCall.request.callId, ); + const prompt_ids = geminiTools.map( + (toolCall) => toolCall.request.prompt_id, + ); + markToolsAsSubmitted(callIdsToMarkAsSubmitted); // Don't continue if model was switched due to quota error @@ -683,9 +710,13 @@ export const useGeminiStream = ( return; } - submitQuery(mergePartListUnions(responsesToSend), { - isContinuation: true, - }); + submitQuery( + mergePartListUnions(responsesToSend), + { + isContinuation: true, + }, + prompt_ids[0], + ); }, [ isResponding, |
