summaryrefslogtreecommitdiff
path: root/packages/cli/src/ui/hooks/useGeminiStream.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts43
1 files changed, 37 insertions, 6 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index d32c9ffa..b82b0cb2 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -53,6 +53,7 @@ import {
TrackedCompletedToolCall,
TrackedCancelledToolCall,
} from './useReactToolScheduler.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
export function mergePartListUnions(list: PartListUnion[]): PartListUnion {
const resultParts: PartListUnion = [];
@@ -101,6 +102,7 @@ export const useGeminiStream = (
const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
const processedMemoryToolsRef = useRef<Set<string>>(new Set());
+ const { startNewPrompt, getPromptCount } = useSessionStats();
const logger = useLogger();
const gitService = useMemo(() => {
if (!config.getProjectRoot()) {
@@ -203,6 +205,7 @@ export const useGeminiStream = (
query: PartListUnion,
userMessageTimestamp: number,
abortSignal: AbortSignal,
+ prompt_id: string,
): Promise<{
queryToSend: PartListUnion | null;
shouldProceed: boolean;
@@ -220,7 +223,7 @@ export const useGeminiStream = (
const trimmedQuery = query.trim();
logUserPrompt(
config,
- new UserPromptEvent(trimmedQuery.length, trimmedQuery),
+ new UserPromptEvent(trimmedQuery.length, prompt_id, trimmedQuery),
);
onDebugMessage(`User query: '${trimmedQuery}'`);
await logger?.logMessage(MessageSenderType.USER, trimmedQuery);
@@ -236,6 +239,7 @@ export const useGeminiStream = (
name: toolName,
args: toolArgs,
isClientInitiated: true,
+ prompt_id,
};
scheduleToolCalls([toolCallRequest], abortSignal);
}
@@ -485,7 +489,11 @@ export const useGeminiStream = (
);
const submitQuery = useCallback(
- async (query: PartListUnion, options?: { isContinuation: boolean }) => {
+ async (
+ query: PartListUnion,
+ options?: { isContinuation: boolean },
+ prompt_id?: string,
+ ) => {
if (
(streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation) &&
@@ -506,21 +514,34 @@ export const useGeminiStream = (
const abortSignal = abortControllerRef.current.signal;
turnCancelledRef.current = false;
+ if (!prompt_id) {
+ prompt_id = config.getSessionId() + '########' + getPromptCount();
+ }
+
const { queryToSend, shouldProceed } = await prepareQueryForGemini(
query,
userMessageTimestamp,
abortSignal,
+ prompt_id!,
);
if (!shouldProceed || queryToSend === null) {
return;
}
+ if (!options?.isContinuation) {
+ startNewPrompt();
+ }
+
setIsResponding(true);
setInitError(null);
try {
- const stream = geminiClient.sendMessageStream(queryToSend, abortSignal);
+ const stream = geminiClient.sendMessageStream(
+ queryToSend,
+ abortSignal,
+ prompt_id!,
+ );
const processingStatus = await processGeminiStreamEvents(
stream,
userMessageTimestamp,
@@ -570,6 +591,8 @@ export const useGeminiStream = (
geminiClient,
onAuthError,
config,
+ startNewPrompt,
+ getPromptCount,
],
);
@@ -676,6 +699,10 @@ export const useGeminiStream = (
(toolCall) => toolCall.request.callId,
);
+ const prompt_ids = geminiTools.map(
+ (toolCall) => toolCall.request.prompt_id,
+ );
+
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
// Don't continue if model was switched due to quota error
@@ -683,9 +710,13 @@ export const useGeminiStream = (
return;
}
- submitQuery(mergePartListUnions(responsesToSend), {
- isContinuation: true,
- });
+ submitQuery(
+ mergePartListUnions(responsesToSend),
+ {
+ isContinuation: true,
+ },
+ prompt_ids[0],
+ );
},
[
isResponding,