diff options
| author | Abhi <[email protected]> | 2025-06-09 20:25:37 -0400 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-06-09 20:25:37 -0400 |
| commit | 7f1252d364ec251a4a76becbcb3f101b361f2656 (patch) | |
| tree | 0091370d4b2a2c7cf6766b70243c146f2f463c5a /packages/cli/src/ui/contexts | |
| parent | 6484dc9008448637ebdebd21f83d876aaac127c8 (diff) | |
feat: Display initial token usage metrics in /stats (#879)
Diffstat (limited to 'packages/cli/src/ui/contexts')
| -rw-r--r-- | packages/cli/src/ui/contexts/SessionContext.test.tsx | 185 | ||||
| -rw-r--r-- | packages/cli/src/ui/contexts/SessionContext.tsx | 131 |
2 files changed, 288 insertions, 28 deletions
diff --git a/packages/cli/src/ui/contexts/SessionContext.test.tsx b/packages/cli/src/ui/contexts/SessionContext.test.tsx index 3b5454cf..fedf3d74 100644 --- a/packages/cli/src/ui/contexts/SessionContext.test.tsx +++ b/packages/cli/src/ui/contexts/SessionContext.test.tsx @@ -4,26 +4,181 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { type MutableRefObject } from 'react'; import { render } from 'ink-testing-library'; -import { Text } from 'ink'; -import { SessionProvider, useSession } from './SessionContext.js'; -import { describe, it, expect } from 'vitest'; +import { act } from 'react-dom/test-utils'; +import { SessionStatsProvider, useSessionStats } from './SessionContext.js'; +import { describe, it, expect, vi } from 'vitest'; +import { GenerateContentResponseUsageMetadata } from '@google/genai'; -const TestComponent = () => { - const { startTime } = useSession(); - return <Text>{startTime.toISOString()}</Text>; +// Mock data that simulates what the Gemini API would return. +const mockMetadata1: GenerateContentResponseUsageMetadata = { + promptTokenCount: 100, + candidatesTokenCount: 200, + totalTokenCount: 300, + cachedContentTokenCount: 50, + toolUsePromptTokenCount: 10, + thoughtsTokenCount: 20, }; -describe('SessionContext', () => { - it('should provide a start time', () => { - const { lastFrame } = render( - <SessionProvider> - <TestComponent /> - </SessionProvider>, +const mockMetadata2: GenerateContentResponseUsageMetadata = { + promptTokenCount: 10, + candidatesTokenCount: 20, + totalTokenCount: 30, + cachedContentTokenCount: 5, + toolUsePromptTokenCount: 1, + thoughtsTokenCount: 2, +}; + +/** + * A test harness component that uses the hook and exposes the context value + * via a mutable ref. This allows us to interact with the context's functions + * and assert against its state directly in our tests. + */ +const TestHarness = ({ + contextRef, +}: { + contextRef: MutableRefObject<ReturnType<typeof useSessionStats> | undefined>; +}) => { + contextRef.current = useSessionStats(); + return null; +}; + +describe('SessionStatsContext', () => { + it('should provide the correct initial state', () => { + const contextRef: MutableRefObject< + ReturnType<typeof useSessionStats> | undefined + > = { current: undefined }; + + render( + <SessionStatsProvider> + <TestHarness contextRef={contextRef} /> + </SessionStatsProvider>, + ); + + const stats = contextRef.current?.stats; + + expect(stats?.sessionStartTime).toBeInstanceOf(Date); + expect(stats?.lastTurn).toBeNull(); + expect(stats?.cumulative.turnCount).toBe(0); + expect(stats?.cumulative.totalTokenCount).toBe(0); + expect(stats?.cumulative.promptTokenCount).toBe(0); + }); + + it('should increment turnCount when startNewTurn is called', () => { + const contextRef: MutableRefObject< + ReturnType<typeof useSessionStats> | undefined + > = { current: undefined }; + + render( + <SessionStatsProvider> + <TestHarness contextRef={contextRef} /> + </SessionStatsProvider>, + ); + + act(() => { + contextRef.current?.startNewTurn(); + }); + + const stats = contextRef.current?.stats; + expect(stats?.cumulative.turnCount).toBe(1); + // Ensure token counts are unaffected + expect(stats?.cumulative.totalTokenCount).toBe(0); + }); + + it('should aggregate token usage correctly when addUsage is called', () => { + const contextRef: MutableRefObject< + ReturnType<typeof useSessionStats> | undefined + > = { current: undefined }; + + render( + <SessionStatsProvider> + <TestHarness contextRef={contextRef} /> + </SessionStatsProvider>, + ); + + act(() => { + contextRef.current?.addUsage(mockMetadata1); + }); + + const stats = contextRef.current?.stats; + + // Check that token counts are updated + expect(stats?.cumulative.totalTokenCount).toBe( + mockMetadata1.totalTokenCount ?? 0, + ); + expect(stats?.cumulative.promptTokenCount).toBe( + mockMetadata1.promptTokenCount ?? 0, + ); + + // Check that turn count is NOT incremented + expect(stats?.cumulative.turnCount).toBe(0); + + // Check that lastTurn is updated + expect(stats?.lastTurn?.metadata).toEqual(mockMetadata1); + }); + + it('should correctly track a full logical turn with multiple API calls', () => { + const contextRef: MutableRefObject< + ReturnType<typeof useSessionStats> | undefined + > = { current: undefined }; + + render( + <SessionStatsProvider> + <TestHarness contextRef={contextRef} /> + </SessionStatsProvider>, + ); + + // 1. User starts a new turn + act(() => { + contextRef.current?.startNewTurn(); + }); + + // 2. First API call (e.g., prompt with a tool request) + act(() => { + contextRef.current?.addUsage(mockMetadata1); + }); + + // 3. Second API call (e.g., sending tool response back) + act(() => { + contextRef.current?.addUsage(mockMetadata2); + }); + + const stats = contextRef.current?.stats; + + // Turn count should only be 1 + expect(stats?.cumulative.turnCount).toBe(1); + + // These fields should be the SUM of both calls + expect(stats?.cumulative.totalTokenCount).toBe(330); // 300 + 30 + expect(stats?.cumulative.candidatesTokenCount).toBe(220); // 200 + 20 + expect(stats?.cumulative.thoughtsTokenCount).toBe(22); // 20 + 2 + + // These fields should ONLY be from the FIRST call, because isNewTurnForAggregation was true + expect(stats?.cumulative.promptTokenCount).toBe(100); + expect(stats?.cumulative.cachedContentTokenCount).toBe(50); + expect(stats?.cumulative.toolUsePromptTokenCount).toBe(10); + + // Last turn should hold the metadata from the most recent call + expect(stats?.lastTurn?.metadata).toEqual(mockMetadata2); + }); + + it('should throw an error when useSessionStats is used outside of a provider', () => { + // Suppress the expected console error during this test. + const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + + const contextRef = { current: undefined }; + + // We expect rendering to fail, which React will catch and log as an error. + render(<TestHarness contextRef={contextRef} />); + + // Assert that the first argument of the first call to console.error + // contains the expected message. This is more robust than checking + // the exact arguments, which can be affected by React/JSDOM internals. + expect(errorSpy.mock.calls[0][0]).toContain( + 'useSessionStats must be used within a SessionStatsProvider', ); - const frameText = lastFrame(); - // Check if the output is a valid ISO string, which confirms it's a Date object. - expect(new Date(frameText!).toString()).not.toBe('Invalid Date'); + errorSpy.mockRestore(); }); }); diff --git a/packages/cli/src/ui/contexts/SessionContext.tsx b/packages/cli/src/ui/contexts/SessionContext.tsx index c511aa46..0549e3e1 100644 --- a/packages/cli/src/ui/contexts/SessionContext.tsx +++ b/packages/cli/src/ui/contexts/SessionContext.tsx @@ -4,35 +4,140 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { createContext, useContext, useState, useMemo } from 'react'; +import React, { + createContext, + useContext, + useState, + useMemo, + useCallback, +} from 'react'; -interface SessionContextType { - startTime: Date; +import { type GenerateContentResponseUsageMetadata } from '@google/genai'; + +// --- Interface Definitions --- + +interface CumulativeStats { + turnCount: number; + promptTokenCount: number; + candidatesTokenCount: number; + totalTokenCount: number; + cachedContentTokenCount: number; + toolUsePromptTokenCount: number; + thoughtsTokenCount: number; } -const SessionContext = createContext<SessionContextType | null>(null); +interface LastTurnStats { + metadata: GenerateContentResponseUsageMetadata; + // TODO(abhipatel12): Add apiTime, etc. here in a future step. +} + +interface SessionStatsState { + sessionStartTime: Date; + cumulative: CumulativeStats; + lastTurn: LastTurnStats | null; + isNewTurnForAggregation: boolean; +} + +// Defines the final "value" of our context, including the state +// and the functions to update it. +interface SessionStatsContextValue { + stats: SessionStatsState; + startNewTurn: () => void; + addUsage: (metadata: GenerateContentResponseUsageMetadata) => void; +} + +// --- Context Definition --- + +const SessionStatsContext = createContext<SessionStatsContextValue | undefined>( + undefined, +); -export const SessionProvider: React.FC<{ children: React.ReactNode }> = ({ +// --- Provider Component --- + +export const SessionStatsProvider: React.FC<{ children: React.ReactNode }> = ({ children, }) => { - const [startTime] = useState(new Date()); + const [stats, setStats] = useState<SessionStatsState>({ + sessionStartTime: new Date(), + cumulative: { + turnCount: 0, + promptTokenCount: 0, + candidatesTokenCount: 0, + totalTokenCount: 0, + cachedContentTokenCount: 0, + toolUsePromptTokenCount: 0, + thoughtsTokenCount: 0, + }, + lastTurn: null, + isNewTurnForAggregation: true, + }); + + // A single, internal worker function to handle all metadata aggregation. + const aggregateTokens = useCallback( + (metadata: GenerateContentResponseUsageMetadata) => { + setStats((prevState) => { + const { isNewTurnForAggregation } = prevState; + const newCumulative = { ...prevState.cumulative }; + + newCumulative.candidatesTokenCount += + metadata.candidatesTokenCount ?? 0; + newCumulative.thoughtsTokenCount += metadata.thoughtsTokenCount ?? 0; + newCumulative.totalTokenCount += metadata.totalTokenCount ?? 0; + + if (isNewTurnForAggregation) { + newCumulative.promptTokenCount += metadata.promptTokenCount ?? 0; + newCumulative.cachedContentTokenCount += + metadata.cachedContentTokenCount ?? 0; + newCumulative.toolUsePromptTokenCount += + metadata.toolUsePromptTokenCount ?? 0; + } + + return { + ...prevState, + cumulative: newCumulative, + lastTurn: { metadata }, + isNewTurnForAggregation: false, + }; + }); + }, + [], + ); + + const startNewTurn = useCallback(() => { + setStats((prevState) => ({ + ...prevState, + cumulative: { + ...prevState.cumulative, + turnCount: prevState.cumulative.turnCount + 1, + }, + isNewTurnForAggregation: true, + })); + }, []); const value = useMemo( () => ({ - startTime, + stats, + startNewTurn, + addUsage: aggregateTokens, }), - [startTime], + [stats, startNewTurn, aggregateTokens], ); return ( - <SessionContext.Provider value={value}>{children}</SessionContext.Provider> + <SessionStatsContext.Provider value={value}> + {children} + </SessionStatsContext.Provider> ); }; -export const useSession = () => { - const context = useContext(SessionContext); - if (!context) { - throw new Error('useSession must be used within a SessionProvider'); +// --- Consumer Hook --- + +export const useSessionStats = () => { + const context = useContext(SessionStatsContext); + if (context === undefined) { + throw new Error( + 'useSessionStats must be used within a SessionStatsProvider', + ); } return context; }; |
