diff options
Diffstat (limited to 'packages/core/src')
| -rw-r--r-- | packages/core/src/core/client.test.ts | 58 | ||||
| -rw-r--r-- | packages/core/src/core/client.ts | 12 | ||||
| -rw-r--r-- | packages/core/src/core/turn.test.ts | 63 | ||||
| -rw-r--r-- | packages/core/src/core/turn.ts | 27 |
4 files changed, 155 insertions, 5 deletions
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index cbbbd113..58ad5dbd 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -13,14 +13,32 @@ import { GoogleGenAI, } from '@google/genai'; import { GeminiClient } from './client.js'; +import { ContentGenerator } from './contentGenerator.js'; +import { GeminiChat } from './geminiChat.js'; import { Config } from '../config/config.js'; +import { Turn } from './turn.js'; // --- Mocks --- const mockChatCreateFn = vi.fn(); const mockGenerateContentFn = vi.fn(); const mockEmbedContentFn = vi.fn(); +const mockTurnRunFn = vi.fn(); vi.mock('@google/genai'); +vi.mock('./turn', () => { + // Define a mock class that has the same shape as the real Turn + class MockTurn { + pendingToolCalls = []; + // The run method is a property that holds our mock function + run = mockTurnRunFn; + + constructor() { + // The constructor can be empty or do some mock setup + } + } + // Export the mock class as 'Turn' + return { Turn: MockTurn }; +}); vi.mock('../config/config.js'); vi.mock('./prompts'); @@ -237,4 +255,44 @@ describe('Gemini Client (client.ts)', () => { expect(mockChat.addHistory).toHaveBeenCalledWith(newContent); }); }); + + describe('sendMessageStream', () => { + it('should return the turn instance after the stream is complete', async () => { + // Arrange + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial<GeminiChat> = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = Promise.resolve(mockChat as GeminiChat); + + const mockGenerator: Partial<ContentGenerator> = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), + }; + client['contentGenerator'] = mockGenerator as ContentGenerator; + + // Act + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + ); + + // Consume the stream manually to get the final return value. + let finalResult: Turn | undefined; + while (true) { + const result = await stream.next(); + if (result.done) { + finalResult = result.value; + break; + } + } + + // Assert + expect(finalResult).toBeInstanceOf(Turn); + }); + }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 8b921ab1..1b953d30 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -174,9 +174,10 @@ export class GeminiClient { request: PartListUnion, signal: AbortSignal, turns: number = this.MAX_TURNS, - ): AsyncGenerator<ServerGeminiStreamEvent> { + ): AsyncGenerator<ServerGeminiStreamEvent, Turn> { if (!turns) { - return; + const chat = await this.chat; + return new Turn(chat); } const compressed = await this.tryCompressChat(); @@ -193,9 +194,12 @@ export class GeminiClient { const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal); if (nextSpeakerCheck?.next_speaker === 'model') { const nextRequest = [{ text: 'Please continue.' }]; + // This recursive call's events will be yielded out, but the final + // turn object will be from the top-level call. yield* this.sendMessageStream(nextRequest, signal, turns - 1); } } + return turn; } private _logApiRequest(model: string, inputTokenCount: number): void { @@ -423,6 +427,10 @@ export class GeminiClient { }); const result = await retryWithBackoff(apiCall); + console.log( + 'Raw API Response in client.ts:', + JSON.stringify(result, null, 2), + ); const durationMs = Date.now() - startTime; this._logApiResponse(modelToUse, durationMs, attempt, result); return result; diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index 8fb3a4c1..2217e5da 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -10,8 +10,14 @@ import { GeminiEventType, ServerGeminiToolCallRequestEvent, ServerGeminiErrorEvent, + ServerGeminiUsageMetadataEvent, } from './turn.js'; -import { GenerateContentResponse, Part, Content } from '@google/genai'; +import { + GenerateContentResponse, + Part, + Content, + GenerateContentResponseUsageMetadata, +} from '@google/genai'; import { reportError } from '../utils/errorReporting.js'; import { GeminiChat } from './geminiChat.js'; @@ -49,6 +55,24 @@ describe('Turn', () => { }; let mockChatInstance: MockedChatInstance; + const mockMetadata1: GenerateContentResponseUsageMetadata = { + promptTokenCount: 10, + candidatesTokenCount: 20, + totalTokenCount: 30, + cachedContentTokenCount: 5, + toolUsePromptTokenCount: 2, + thoughtsTokenCount: 3, + }; + + const mockMetadata2: GenerateContentResponseUsageMetadata = { + promptTokenCount: 100, + candidatesTokenCount: 200, + totalTokenCount: 300, + cachedContentTokenCount: 50, + toolUsePromptTokenCount: 20, + thoughtsTokenCount: 30, + }; + beforeEach(() => { vi.resetAllMocks(); mockChatInstance = { @@ -96,6 +120,7 @@ describe('Turn', () => { message: reqParts, config: { abortSignal: expect.any(AbortSignal) }, }); + expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Hello' }, { type: GeminiEventType.Content, value: ' world' }, @@ -208,6 +233,41 @@ describe('Turn', () => { ); }); + it('should yield the last UsageMetadata event from the stream', async () => { + const mockResponseStream = (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'First response' }] } }], + usageMetadata: mockMetadata1, + } as unknown as GenerateContentResponse; + yield { + functionCalls: [{ name: 'aTool' }], + usageMetadata: mockMetadata2, + } as unknown as GenerateContentResponse; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + const reqParts: Part[] = [{ text: 'Test metadata' }]; + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { + events.push(event); + } + + // There should be a content event, a tool call, and our metadata event + expect(events.length).toBe(3); + + const metadataEvent = events[2] as ServerGeminiUsageMetadataEvent; + expect(metadataEvent.type).toBe(GeminiEventType.UsageMetadata); + + // The value should be the *last* metadata object received. + expect(metadataEvent.value).toEqual(mockMetadata2); + + // Also check the public getter + expect(turn.getUsageMetadata()).toEqual(mockMetadata2); + }); + it('should handle function calls with undefined name or args', async () => { const mockResponseStream = (async function* () { yield { @@ -219,7 +279,6 @@ describe('Turn', () => { } as unknown as GenerateContentResponse; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); - const events = []; const reqParts: Part[] = [{ text: 'Test undefined tool parts' }]; for await (const event of turn.run( diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 637fc19d..34e4a494 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -9,6 +9,7 @@ import { GenerateContentResponse, FunctionCall, FunctionDeclaration, + GenerateContentResponseUsageMetadata, } from '@google/genai'; import { ToolCallConfirmationDetails, @@ -43,6 +44,7 @@ export enum GeminiEventType { UserCancelled = 'user_cancelled', Error = 'error', ChatCompressed = 'chat_compressed', + UsageMetadata = 'usage_metadata', } export interface GeminiErrorEventValue { @@ -100,6 +102,11 @@ export type ServerGeminiChatCompressedEvent = { type: GeminiEventType.ChatCompressed; }; +export type ServerGeminiUsageMetadataEvent = { + type: GeminiEventType.UsageMetadata; + value: GenerateContentResponseUsageMetadata; +}; + // The original union type, now composed of the individual types export type ServerGeminiStreamEvent = | ServerGeminiContentEvent @@ -108,7 +115,8 @@ export type ServerGeminiStreamEvent = | ServerGeminiToolCallConfirmationEvent | ServerGeminiUserCancelledEvent | ServerGeminiErrorEvent - | ServerGeminiChatCompressedEvent; + | ServerGeminiChatCompressedEvent + | ServerGeminiUsageMetadataEvent; // A turn manages the agentic loop turn within the server context. export class Turn { @@ -118,6 +126,7 @@ export class Turn { args: Record<string, unknown>; }>; private debugResponses: GenerateContentResponse[]; + private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null; constructor(private readonly chat: GeminiChat) { this.pendingToolCalls = []; @@ -157,6 +166,18 @@ export class Turn { yield event; } } + + if (resp.usageMetadata) { + this.lastUsageMetadata = + resp.usageMetadata as GenerateContentResponseUsageMetadata; + } + } + + if (this.lastUsageMetadata) { + yield { + type: GeminiEventType.UsageMetadata, + value: this.lastUsageMetadata, + }; } } catch (error) { if (signal.aborted) { @@ -197,4 +218,8 @@ export class Turn { getDebugResponses(): GenerateContentResponse[] { return this.debugResponses; } + + getUsageMetadata(): GenerateContentResponseUsageMetadata | null { + return this.lastUsageMetadata; + } } |
