summaryrefslogtreecommitdiff
path: root/packages/cli/src
diff options
context:
space:
mode:
authorAbhi <[email protected]>2025-06-09 20:25:37 -0400
committerGitHub <[email protected]>2025-06-09 20:25:37 -0400
commit7f1252d364ec251a4a76becbcb3f101b361f2656 (patch)
tree0091370d4b2a2c7cf6766b70243c146f2f463c5a /packages/cli/src
parent6484dc9008448637ebdebd21f83d876aaac127c8 (diff)
feat: Display initial token usage metrics in /stats (#879)
Diffstat (limited to 'packages/cli/src')
-rw-r--r--packages/cli/src/ui/App.tsx6
-rw-r--r--packages/cli/src/ui/contexts/SessionContext.test.tsx185
-rw-r--r--packages/cli/src/ui/contexts/SessionContext.tsx131
-rw-r--r--packages/cli/src/ui/hooks/slashCommandProcessor.test.ts67
-rw-r--r--packages/cli/src/ui/hooks/slashCommandProcessor.ts29
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.test.tsx68
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts17
7 files changed, 449 insertions, 54 deletions
diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx
index b458a822..bf8c2abb 100644
--- a/packages/cli/src/ui/App.tsx
+++ b/packages/cli/src/ui/App.tsx
@@ -48,7 +48,7 @@ import {
} from '@gemini-cli/core';
import { useLogger } from './hooks/useLogger.js';
import { StreamingContext } from './contexts/StreamingContext.js';
-import { SessionProvider } from './contexts/SessionContext.js';
+import { SessionStatsProvider } from './contexts/SessionContext.js';
import { useGitBranchName } from './hooks/useGitBranchName.js';
const CTRL_C_PROMPT_DURATION_MS = 1000;
@@ -60,9 +60,9 @@ interface AppProps {
}
export const AppWrapper = (props: AppProps) => (
- <SessionProvider>
+ <SessionStatsProvider>
<App {...props} />
- </SessionProvider>
+ </SessionStatsProvider>
);
const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
diff --git a/packages/cli/src/ui/contexts/SessionContext.test.tsx b/packages/cli/src/ui/contexts/SessionContext.test.tsx
index 3b5454cf..fedf3d74 100644
--- a/packages/cli/src/ui/contexts/SessionContext.test.tsx
+++ b/packages/cli/src/ui/contexts/SessionContext.test.tsx
@@ -4,26 +4,181 @@
* SPDX-License-Identifier: Apache-2.0
*/
+import { type MutableRefObject } from 'react';
import { render } from 'ink-testing-library';
-import { Text } from 'ink';
-import { SessionProvider, useSession } from './SessionContext.js';
-import { describe, it, expect } from 'vitest';
+import { act } from 'react-dom/test-utils';
+import { SessionStatsProvider, useSessionStats } from './SessionContext.js';
+import { describe, it, expect, vi } from 'vitest';
+import { GenerateContentResponseUsageMetadata } from '@google/genai';
-const TestComponent = () => {
- const { startTime } = useSession();
- return <Text>{startTime.toISOString()}</Text>;
+// Mock data that simulates what the Gemini API would return.
+const mockMetadata1: GenerateContentResponseUsageMetadata = {
+ promptTokenCount: 100,
+ candidatesTokenCount: 200,
+ totalTokenCount: 300,
+ cachedContentTokenCount: 50,
+ toolUsePromptTokenCount: 10,
+ thoughtsTokenCount: 20,
};
-describe('SessionContext', () => {
- it('should provide a start time', () => {
- const { lastFrame } = render(
- <SessionProvider>
- <TestComponent />
- </SessionProvider>,
+const mockMetadata2: GenerateContentResponseUsageMetadata = {
+ promptTokenCount: 10,
+ candidatesTokenCount: 20,
+ totalTokenCount: 30,
+ cachedContentTokenCount: 5,
+ toolUsePromptTokenCount: 1,
+ thoughtsTokenCount: 2,
+};
+
+/**
+ * A test harness component that uses the hook and exposes the context value
+ * via a mutable ref. This allows us to interact with the context's functions
+ * and assert against its state directly in our tests.
+ */
+const TestHarness = ({
+ contextRef,
+}: {
+ contextRef: MutableRefObject<ReturnType<typeof useSessionStats> | undefined>;
+}) => {
+ contextRef.current = useSessionStats();
+ return null;
+};
+
+describe('SessionStatsContext', () => {
+ it('should provide the correct initial state', () => {
+ const contextRef: MutableRefObject<
+ ReturnType<typeof useSessionStats> | undefined
+ > = { current: undefined };
+
+ render(
+ <SessionStatsProvider>
+ <TestHarness contextRef={contextRef} />
+ </SessionStatsProvider>,
+ );
+
+ const stats = contextRef.current?.stats;
+
+ expect(stats?.sessionStartTime).toBeInstanceOf(Date);
+ expect(stats?.lastTurn).toBeNull();
+ expect(stats?.cumulative.turnCount).toBe(0);
+ expect(stats?.cumulative.totalTokenCount).toBe(0);
+ expect(stats?.cumulative.promptTokenCount).toBe(0);
+ });
+
+ it('should increment turnCount when startNewTurn is called', () => {
+ const contextRef: MutableRefObject<
+ ReturnType<typeof useSessionStats> | undefined
+ > = { current: undefined };
+
+ render(
+ <SessionStatsProvider>
+ <TestHarness contextRef={contextRef} />
+ </SessionStatsProvider>,
+ );
+
+ act(() => {
+ contextRef.current?.startNewTurn();
+ });
+
+ const stats = contextRef.current?.stats;
+ expect(stats?.cumulative.turnCount).toBe(1);
+ // Ensure token counts are unaffected
+ expect(stats?.cumulative.totalTokenCount).toBe(0);
+ });
+
+ it('should aggregate token usage correctly when addUsage is called', () => {
+ const contextRef: MutableRefObject<
+ ReturnType<typeof useSessionStats> | undefined
+ > = { current: undefined };
+
+ render(
+ <SessionStatsProvider>
+ <TestHarness contextRef={contextRef} />
+ </SessionStatsProvider>,
+ );
+
+ act(() => {
+ contextRef.current?.addUsage(mockMetadata1);
+ });
+
+ const stats = contextRef.current?.stats;
+
+ // Check that token counts are updated
+ expect(stats?.cumulative.totalTokenCount).toBe(
+ mockMetadata1.totalTokenCount ?? 0,
+ );
+ expect(stats?.cumulative.promptTokenCount).toBe(
+ mockMetadata1.promptTokenCount ?? 0,
+ );
+
+ // Check that turn count is NOT incremented
+ expect(stats?.cumulative.turnCount).toBe(0);
+
+ // Check that lastTurn is updated
+ expect(stats?.lastTurn?.metadata).toEqual(mockMetadata1);
+ });
+
+ it('should correctly track a full logical turn with multiple API calls', () => {
+ const contextRef: MutableRefObject<
+ ReturnType<typeof useSessionStats> | undefined
+ > = { current: undefined };
+
+ render(
+ <SessionStatsProvider>
+ <TestHarness contextRef={contextRef} />
+ </SessionStatsProvider>,
+ );
+
+ // 1. User starts a new turn
+ act(() => {
+ contextRef.current?.startNewTurn();
+ });
+
+ // 2. First API call (e.g., prompt with a tool request)
+ act(() => {
+ contextRef.current?.addUsage(mockMetadata1);
+ });
+
+ // 3. Second API call (e.g., sending tool response back)
+ act(() => {
+ contextRef.current?.addUsage(mockMetadata2);
+ });
+
+ const stats = contextRef.current?.stats;
+
+ // Turn count should only be 1
+ expect(stats?.cumulative.turnCount).toBe(1);
+
+ // These fields should be the SUM of both calls
+ expect(stats?.cumulative.totalTokenCount).toBe(330); // 300 + 30
+ expect(stats?.cumulative.candidatesTokenCount).toBe(220); // 200 + 20
+ expect(stats?.cumulative.thoughtsTokenCount).toBe(22); // 20 + 2
+
+ // These fields should ONLY be from the FIRST call, because isNewTurnForAggregation was true
+ expect(stats?.cumulative.promptTokenCount).toBe(100);
+ expect(stats?.cumulative.cachedContentTokenCount).toBe(50);
+ expect(stats?.cumulative.toolUsePromptTokenCount).toBe(10);
+
+ // Last turn should hold the metadata from the most recent call
+ expect(stats?.lastTurn?.metadata).toEqual(mockMetadata2);
+ });
+
+ it('should throw an error when useSessionStats is used outside of a provider', () => {
+ // Suppress the expected console error during this test.
+ const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
+
+ const contextRef = { current: undefined };
+
+ // We expect rendering to fail, which React will catch and log as an error.
+ render(<TestHarness contextRef={contextRef} />);
+
+ // Assert that the first argument of the first call to console.error
+ // contains the expected message. This is more robust than checking
+ // the exact arguments, which can be affected by React/JSDOM internals.
+ expect(errorSpy.mock.calls[0][0]).toContain(
+ 'useSessionStats must be used within a SessionStatsProvider',
);
- const frameText = lastFrame();
- // Check if the output is a valid ISO string, which confirms it's a Date object.
- expect(new Date(frameText!).toString()).not.toBe('Invalid Date');
+ errorSpy.mockRestore();
});
});
diff --git a/packages/cli/src/ui/contexts/SessionContext.tsx b/packages/cli/src/ui/contexts/SessionContext.tsx
index c511aa46..0549e3e1 100644
--- a/packages/cli/src/ui/contexts/SessionContext.tsx
+++ b/packages/cli/src/ui/contexts/SessionContext.tsx
@@ -4,35 +4,140 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import React, { createContext, useContext, useState, useMemo } from 'react';
+import React, {
+ createContext,
+ useContext,
+ useState,
+ useMemo,
+ useCallback,
+} from 'react';
-interface SessionContextType {
- startTime: Date;
+import { type GenerateContentResponseUsageMetadata } from '@google/genai';
+
+// --- Interface Definitions ---
+
+interface CumulativeStats {
+ turnCount: number;
+ promptTokenCount: number;
+ candidatesTokenCount: number;
+ totalTokenCount: number;
+ cachedContentTokenCount: number;
+ toolUsePromptTokenCount: number;
+ thoughtsTokenCount: number;
}
-const SessionContext = createContext<SessionContextType | null>(null);
+interface LastTurnStats {
+ metadata: GenerateContentResponseUsageMetadata;
+ // TODO(abhipatel12): Add apiTime, etc. here in a future step.
+}
+
+interface SessionStatsState {
+ sessionStartTime: Date;
+ cumulative: CumulativeStats;
+ lastTurn: LastTurnStats | null;
+ isNewTurnForAggregation: boolean;
+}
+
+// Defines the final "value" of our context, including the state
+// and the functions to update it.
+interface SessionStatsContextValue {
+ stats: SessionStatsState;
+ startNewTurn: () => void;
+ addUsage: (metadata: GenerateContentResponseUsageMetadata) => void;
+}
+
+// --- Context Definition ---
+
+const SessionStatsContext = createContext<SessionStatsContextValue | undefined>(
+ undefined,
+);
-export const SessionProvider: React.FC<{ children: React.ReactNode }> = ({
+// --- Provider Component ---
+
+export const SessionStatsProvider: React.FC<{ children: React.ReactNode }> = ({
children,
}) => {
- const [startTime] = useState(new Date());
+ const [stats, setStats] = useState<SessionStatsState>({
+ sessionStartTime: new Date(),
+ cumulative: {
+ turnCount: 0,
+ promptTokenCount: 0,
+ candidatesTokenCount: 0,
+ totalTokenCount: 0,
+ cachedContentTokenCount: 0,
+ toolUsePromptTokenCount: 0,
+ thoughtsTokenCount: 0,
+ },
+ lastTurn: null,
+ isNewTurnForAggregation: true,
+ });
+
+ // A single, internal worker function to handle all metadata aggregation.
+ const aggregateTokens = useCallback(
+ (metadata: GenerateContentResponseUsageMetadata) => {
+ setStats((prevState) => {
+ const { isNewTurnForAggregation } = prevState;
+ const newCumulative = { ...prevState.cumulative };
+
+ newCumulative.candidatesTokenCount +=
+ metadata.candidatesTokenCount ?? 0;
+ newCumulative.thoughtsTokenCount += metadata.thoughtsTokenCount ?? 0;
+ newCumulative.totalTokenCount += metadata.totalTokenCount ?? 0;
+
+ if (isNewTurnForAggregation) {
+ newCumulative.promptTokenCount += metadata.promptTokenCount ?? 0;
+ newCumulative.cachedContentTokenCount +=
+ metadata.cachedContentTokenCount ?? 0;
+ newCumulative.toolUsePromptTokenCount +=
+ metadata.toolUsePromptTokenCount ?? 0;
+ }
+
+ return {
+ ...prevState,
+ cumulative: newCumulative,
+ lastTurn: { metadata },
+ isNewTurnForAggregation: false,
+ };
+ });
+ },
+ [],
+ );
+
+ const startNewTurn = useCallback(() => {
+ setStats((prevState) => ({
+ ...prevState,
+ cumulative: {
+ ...prevState.cumulative,
+ turnCount: prevState.cumulative.turnCount + 1,
+ },
+ isNewTurnForAggregation: true,
+ }));
+ }, []);
const value = useMemo(
() => ({
- startTime,
+ stats,
+ startNewTurn,
+ addUsage: aggregateTokens,
}),
- [startTime],
+ [stats, startNewTurn, aggregateTokens],
);
return (
- <SessionContext.Provider value={value}>{children}</SessionContext.Provider>
+ <SessionStatsContext.Provider value={value}>
+ {children}
+ </SessionStatsContext.Provider>
);
};
-export const useSession = () => {
- const context = useContext(SessionContext);
- if (!context) {
- throw new Error('useSession must be used within a SessionProvider');
+// --- Consumer Hook ---
+
+export const useSessionStats = () => {
+ const context = useContext(SessionStatsContext);
+ if (context === undefined) {
+ throw new Error(
+ 'useSessionStats must be used within a SessionStatsProvider',
+ );
}
return context;
};
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
index aa1e701f..cc6be49e 100644
--- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
+++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
@@ -61,13 +61,13 @@ import {
MCPServerStatus,
getMCPServerStatus,
} from '@gemini-cli/core';
-import { useSession } from '../contexts/SessionContext.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
import * as ShowMemoryCommandModule from './useShowMemoryCommand.js';
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
vi.mock('../contexts/SessionContext.js', () => ({
- useSession: vi.fn(),
+ useSessionStats: vi.fn(),
}));
vi.mock('./useShowMemoryCommand.js', () => ({
@@ -89,7 +89,7 @@ describe('useSlashCommandProcessor', () => {
let mockPerformMemoryRefresh: ReturnType<typeof vi.fn>;
let mockConfig: Config;
let mockCorgiMode: ReturnType<typeof vi.fn>;
- const mockUseSession = useSession as Mock;
+ const mockUseSessionStats = useSessionStats as Mock;
beforeEach(() => {
mockAddItem = vi.fn();
@@ -105,8 +105,19 @@ describe('useSlashCommandProcessor', () => {
getModel: vi.fn(() => 'test-model'),
} as unknown as Config;
mockCorgiMode = vi.fn();
- mockUseSession.mockReturnValue({
- startTime: new Date('2025-01-01T00:00:00.000Z'),
+ mockUseSessionStats.mockReturnValue({
+ stats: {
+ sessionStartTime: new Date('2025-01-01T00:00:00.000Z'),
+ cumulative: {
+ turnCount: 0,
+ promptTokenCount: 0,
+ candidatesTokenCount: 0,
+ totalTokenCount: 0,
+ cachedContentTokenCount: 0,
+ toolUsePromptTokenCount: 0,
+ thoughtsTokenCount: 0,
+ },
+ },
});
(open as Mock).mockClear();
@@ -240,29 +251,55 @@ describe('useSlashCommandProcessor', () => {
});
describe('/stats command', () => {
- it('should show the session duration', async () => {
- const { handleSlashCommand } = getProcessor();
- let commandResult: SlashCommandActionReturn | boolean = false;
+ it('should show detailed session statistics', async () => {
+ // Arrange
+ mockUseSessionStats.mockReturnValue({
+ stats: {
+ sessionStartTime: new Date('2025-01-01T00:00:00.000Z'),
+ cumulative: {
+ totalTokenCount: 900,
+ promptTokenCount: 200,
+ candidatesTokenCount: 400,
+ cachedContentTokenCount: 100,
+ turnCount: 1,
+ toolUsePromptTokenCount: 50,
+ thoughtsTokenCount: 150,
+ },
+ },
+ });
- // Mock current time
- const mockDate = new Date('2025-01-01T00:01:05.000Z');
+ const { handleSlashCommand } = getProcessor();
+ const mockDate = new Date('2025-01-01T01:02:03.000Z'); // 1h 2m 3s duration
vi.setSystemTime(mockDate);
+ // Act
await act(async () => {
- commandResult = handleSlashCommand('/stats');
+ handleSlashCommand('/stats');
});
+ // Assert
+ const expectedContent = [
+ ` ⎿ Total duration (wall): 1h 2m 3s`,
+ ` Total Token usage:`,
+ ` Turns: 1`,
+ ` Total: 900`,
+ ` ├─ Input: 200`,
+ ` ├─ Output: 400`,
+ ` ├─ Cached: 100`,
+ ` └─ Overhead: 200`,
+ ` ├─ Model thoughts: 150`,
+ ` └─ Tool-use prompts: 50`,
+ ].join('\n');
+
expect(mockAddItem).toHaveBeenNthCalledWith(
- 2,
+ 2, // Called after the user message
expect.objectContaining({
type: MessageType.INFO,
- text: 'Session duration: 1m 5s',
+ text: expectedContent,
}),
expect.any(Number),
);
- expect(commandResult).toBe(true);
- // Restore system time
vi.useRealTimers();
});
});
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts
index daec0379..6159fe89 100644
--- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts
+++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts
@@ -11,7 +11,7 @@ import process from 'node:process';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { Config, MCPServerStatus, getMCPServerStatus } from '@gemini-cli/core';
import { Message, MessageType, HistoryItemWithoutId } from '../types.js';
-import { useSession } from '../contexts/SessionContext.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
import { createShowMemoryAction } from './useShowMemoryCommand.js';
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
import { formatMemoryUsage } from '../utils/formatters.js';
@@ -50,8 +50,7 @@ export const useSlashCommandProcessor = (
toggleCorgiMode: () => void,
showToolDescriptions: boolean = false,
) => {
- const session = useSession();
-
+ const session = useSessionStats();
const addMessage = useCallback(
(message: Message) => {
// Convert Message to HistoryItemWithoutId
@@ -147,7 +146,9 @@ export const useSlashCommandProcessor = (
description: 'check session stats',
action: (_mainCommand, _subCommand, _args) => {
const now = new Date();
- const duration = now.getTime() - session.startTime.getTime();
+ const { sessionStartTime, cumulative } = session.stats;
+
+ const duration = now.getTime() - sessionStartTime.getTime();
const durationInSeconds = Math.floor(duration / 1000);
const hours = Math.floor(durationInSeconds / 3600);
const minutes = Math.floor((durationInSeconds % 3600) / 60);
@@ -161,9 +162,25 @@ export const useSlashCommandProcessor = (
.filter(Boolean)
.join(' ');
+ const overheadTotal =
+ cumulative.thoughtsTokenCount + cumulative.toolUsePromptTokenCount;
+
+ const statsContent = [
+ ` ⎿ Total duration (wall): ${durationString}`,
+ ` Total Token usage:`,
+ ` Turns: ${cumulative.turnCount.toLocaleString()}`,
+ ` Total: ${cumulative.totalTokenCount.toLocaleString()}`,
+ ` ├─ Input: ${cumulative.promptTokenCount.toLocaleString()}`,
+ ` ├─ Output: ${cumulative.candidatesTokenCount.toLocaleString()}`,
+ ` ├─ Cached: ${cumulative.cachedContentTokenCount.toLocaleString()}`,
+ ` └─ Overhead: ${overheadTotal.toLocaleString()}`,
+ ` ├─ Model thoughts: ${cumulative.thoughtsTokenCount.toLocaleString()}`,
+ ` └─ Tool-use prompts: ${cumulative.toolUsePromptTokenCount.toLocaleString()}`,
+ ].join('\n');
+
addMessage({
type: MessageType.INFO,
- content: `Session duration: ${durationString}`,
+ content: statsContent,
timestamp: new Date(),
});
},
@@ -477,7 +494,7 @@ Add any other context about the problem here.
toggleCorgiMode,
config,
showToolDescriptions,
- session.startTime,
+ session,
],
);
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
index f41f7f9c..ed0f2aac 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
+++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
@@ -96,6 +96,15 @@ vi.mock('./useLogger.js', () => ({
}),
}));
+const mockStartNewTurn = vi.fn();
+const mockAddUsage = vi.fn();
+vi.mock('../contexts/SessionContext.js', () => ({
+ useSessionStats: vi.fn(() => ({
+ startNewTurn: mockStartNewTurn,
+ addUsage: mockAddUsage,
+ })),
+}));
+
vi.mock('./slashCommandProcessor.js', () => ({
handleSlashCommand: vi.fn().mockReturnValue(false),
}));
@@ -531,4 +540,63 @@ describe('useGeminiStream', () => {
});
});
});
+
+ describe('Session Stats Integration', () => {
+ it('should call startNewTurn and addUsage for a simple prompt', async () => {
+ const mockMetadata = { totalTokenCount: 123 };
+ const mockStream = (async function* () {
+ yield { type: 'content', value: 'Response' };
+ yield { type: 'usage_metadata', value: mockMetadata };
+ })();
+ mockSendMessageStream.mockReturnValue(mockStream);
+
+ const { result } = renderTestHook();
+
+ await act(async () => {
+ await result.current.submitQuery('Hello, world!');
+ });
+
+ expect(mockStartNewTurn).toHaveBeenCalledTimes(1);
+ expect(mockAddUsage).toHaveBeenCalledTimes(1);
+ expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata);
+ });
+
+ it('should only call addUsage for a tool continuation prompt', async () => {
+ const mockMetadata = { totalTokenCount: 456 };
+ const mockStream = (async function* () {
+ yield { type: 'content', value: 'Final Answer' };
+ yield { type: 'usage_metadata', value: mockMetadata };
+ })();
+ mockSendMessageStream.mockReturnValue(mockStream);
+
+ const { result } = renderTestHook();
+
+ await act(async () => {
+ await result.current.submitQuery([{ text: 'tool response' }], {
+ isContinuation: true,
+ });
+ });
+
+ expect(mockStartNewTurn).not.toHaveBeenCalled();
+ expect(mockAddUsage).toHaveBeenCalledTimes(1);
+ expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata);
+ });
+
+ it('should not call addUsage if the stream contains no usage metadata', async () => {
+ // Arrange: A stream that yields content but never a usage_metadata event
+ const mockStream = (async function* () {
+ yield { type: 'content', value: 'Some response text' };
+ })();
+ mockSendMessageStream.mockReturnValue(mockStream);
+
+ const { result } = renderTestHook();
+
+ await act(async () => {
+ await result.current.submitQuery('Query with no usage data');
+ });
+
+ expect(mockStartNewTurn).toHaveBeenCalledTimes(1);
+ expect(mockAddUsage).not.toHaveBeenCalled();
+ });
+ });
});
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 2b47ae6f..bad9f78a 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -42,6 +42,7 @@ import {
TrackedCompletedToolCall,
TrackedCancelledToolCall,
} from './useReactToolScheduler.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
export function mergePartListUnions(list: PartListUnion[]): PartListUnion {
const resultParts: PartListUnion = [];
@@ -82,6 +83,7 @@ export const useGeminiStream = (
const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger();
+ const { startNewTurn, addUsage } = useSessionStats();
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
useReactToolScheduler(
@@ -390,6 +392,9 @@ export const useGeminiStream = (
case ServerGeminiEventType.ChatCompressed:
handleChatCompressionEvent();
break;
+ case ServerGeminiEventType.UsageMetadata:
+ addUsage(event.value);
+ break;
case ServerGeminiEventType.ToolCallConfirmation:
case ServerGeminiEventType.ToolCallResponse:
// do nothing
@@ -412,11 +417,12 @@ export const useGeminiStream = (
handleErrorEvent,
scheduleToolCalls,
handleChatCompressionEvent,
+ addUsage,
],
);
const submitQuery = useCallback(
- async (query: PartListUnion) => {
+ async (query: PartListUnion, options?: { isContinuation: boolean }) => {
if (
streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation
@@ -426,6 +432,10 @@ export const useGeminiStream = (
const userMessageTimestamp = Date.now();
setShowHelp(false);
+ if (!options?.isContinuation) {
+ startNewTurn();
+ }
+
abortControllerRef.current = new AbortController();
const abortSignal = abortControllerRef.current.signal;
@@ -491,6 +501,7 @@ export const useGeminiStream = (
setPendingHistoryItem,
setInitError,
geminiClient,
+ startNewTurn,
],
);
@@ -576,7 +587,9 @@ export const useGeminiStream = (
);
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
- submitQuery(mergePartListUnions(responsesToSend));
+ submitQuery(mergePartListUnions(responsesToSend), {
+ isContinuation: true,
+ });
}
}, [
toolCalls,