summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/core/client.test.ts58
-rw-r--r--packages/core/src/core/client.ts12
-rw-r--r--packages/core/src/core/turn.test.ts63
-rw-r--r--packages/core/src/core/turn.ts27
4 files changed, 155 insertions, 5 deletions
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index cbbbd113..58ad5dbd 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -13,14 +13,32 @@ import {
GoogleGenAI,
} from '@google/genai';
import { GeminiClient } from './client.js';
+import { ContentGenerator } from './contentGenerator.js';
+import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js';
+import { Turn } from './turn.js';
// --- Mocks ---
const mockChatCreateFn = vi.fn();
const mockGenerateContentFn = vi.fn();
const mockEmbedContentFn = vi.fn();
+const mockTurnRunFn = vi.fn();
vi.mock('@google/genai');
+vi.mock('./turn', () => {
+ // Define a mock class that has the same shape as the real Turn
+ class MockTurn {
+ pendingToolCalls = [];
+ // The run method is a property that holds our mock function
+ run = mockTurnRunFn;
+
+ constructor() {
+ // The constructor can be empty or do some mock setup
+ }
+ }
+ // Export the mock class as 'Turn'
+ return { Turn: MockTurn };
+});
vi.mock('../config/config.js');
vi.mock('./prompts');
@@ -237,4 +255,44 @@ describe('Gemini Client (client.ts)', () => {
expect(mockChat.addHistory).toHaveBeenCalledWith(newContent);
});
});
+
+ describe('sendMessageStream', () => {
+ it('should return the turn instance after the stream is complete', async () => {
+ // Arrange
+ const mockStream = (async function* () {
+ yield { type: 'content', value: 'Hello' };
+ })();
+ mockTurnRunFn.mockReturnValue(mockStream);
+
+ const mockChat: Partial<GeminiChat> = {
+ addHistory: vi.fn(),
+ getHistory: vi.fn().mockReturnValue([]),
+ };
+ client['chat'] = Promise.resolve(mockChat as GeminiChat);
+
+ const mockGenerator: Partial<ContentGenerator> = {
+ countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
+ };
+ client['contentGenerator'] = mockGenerator as ContentGenerator;
+
+ // Act
+ const stream = client.sendMessageStream(
+ [{ text: 'Hi' }],
+ new AbortController().signal,
+ );
+
+ // Consume the stream manually to get the final return value.
+ let finalResult: Turn | undefined;
+ while (true) {
+ const result = await stream.next();
+ if (result.done) {
+ finalResult = result.value;
+ break;
+ }
+ }
+
+ // Assert
+ expect(finalResult).toBeInstanceOf(Turn);
+ });
+ });
});
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index 8b921ab1..1b953d30 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -174,9 +174,10 @@ export class GeminiClient {
request: PartListUnion,
signal: AbortSignal,
turns: number = this.MAX_TURNS,
- ): AsyncGenerator<ServerGeminiStreamEvent> {
+ ): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!turns) {
- return;
+ const chat = await this.chat;
+ return new Turn(chat);
}
const compressed = await this.tryCompressChat();
@@ -193,9 +194,12 @@ export class GeminiClient {
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
+ // This recursive call's events will be yielded out, but the final
+ // turn object will be from the top-level call.
yield* this.sendMessageStream(nextRequest, signal, turns - 1);
}
}
+ return turn;
}
private _logApiRequest(model: string, inputTokenCount: number): void {
@@ -423,6 +427,10 @@ export class GeminiClient {
});
const result = await retryWithBackoff(apiCall);
+ console.log(
+ 'Raw API Response in client.ts:',
+ JSON.stringify(result, null, 2),
+ );
const durationMs = Date.now() - startTime;
this._logApiResponse(modelToUse, durationMs, attempt, result);
return result;
diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts
index 8fb3a4c1..2217e5da 100644
--- a/packages/core/src/core/turn.test.ts
+++ b/packages/core/src/core/turn.test.ts
@@ -10,8 +10,14 @@ import {
GeminiEventType,
ServerGeminiToolCallRequestEvent,
ServerGeminiErrorEvent,
+ ServerGeminiUsageMetadataEvent,
} from './turn.js';
-import { GenerateContentResponse, Part, Content } from '@google/genai';
+import {
+ GenerateContentResponse,
+ Part,
+ Content,
+ GenerateContentResponseUsageMetadata,
+} from '@google/genai';
import { reportError } from '../utils/errorReporting.js';
import { GeminiChat } from './geminiChat.js';
@@ -49,6 +55,24 @@ describe('Turn', () => {
};
let mockChatInstance: MockedChatInstance;
+ const mockMetadata1: GenerateContentResponseUsageMetadata = {
+ promptTokenCount: 10,
+ candidatesTokenCount: 20,
+ totalTokenCount: 30,
+ cachedContentTokenCount: 5,
+ toolUsePromptTokenCount: 2,
+ thoughtsTokenCount: 3,
+ };
+
+ const mockMetadata2: GenerateContentResponseUsageMetadata = {
+ promptTokenCount: 100,
+ candidatesTokenCount: 200,
+ totalTokenCount: 300,
+ cachedContentTokenCount: 50,
+ toolUsePromptTokenCount: 20,
+ thoughtsTokenCount: 30,
+ };
+
beforeEach(() => {
vi.resetAllMocks();
mockChatInstance = {
@@ -96,6 +120,7 @@ describe('Turn', () => {
message: reqParts,
config: { abortSignal: expect.any(AbortSignal) },
});
+
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Hello' },
{ type: GeminiEventType.Content, value: ' world' },
@@ -208,6 +233,41 @@ describe('Turn', () => {
);
});
+ it('should yield the last UsageMetadata event from the stream', async () => {
+ const mockResponseStream = (async function* () {
+ yield {
+ candidates: [{ content: { parts: [{ text: 'First response' }] } }],
+ usageMetadata: mockMetadata1,
+ } as unknown as GenerateContentResponse;
+ yield {
+ functionCalls: [{ name: 'aTool' }],
+ usageMetadata: mockMetadata2,
+ } as unknown as GenerateContentResponse;
+ })();
+ mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+ const events = [];
+ const reqParts: Part[] = [{ text: 'Test metadata' }];
+ for await (const event of turn.run(
+ reqParts,
+ new AbortController().signal,
+ )) {
+ events.push(event);
+ }
+
+ // There should be a content event, a tool call, and our metadata event
+ expect(events.length).toBe(3);
+
+ const metadataEvent = events[2] as ServerGeminiUsageMetadataEvent;
+ expect(metadataEvent.type).toBe(GeminiEventType.UsageMetadata);
+
+ // The value should be the *last* metadata object received.
+ expect(metadataEvent.value).toEqual(mockMetadata2);
+
+ // Also check the public getter
+ expect(turn.getUsageMetadata()).toEqual(mockMetadata2);
+ });
+
it('should handle function calls with undefined name or args', async () => {
const mockResponseStream = (async function* () {
yield {
@@ -219,7 +279,6 @@ describe('Turn', () => {
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
-
const events = [];
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
for await (const event of turn.run(
diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts
index 637fc19d..34e4a494 100644
--- a/packages/core/src/core/turn.ts
+++ b/packages/core/src/core/turn.ts
@@ -9,6 +9,7 @@ import {
GenerateContentResponse,
FunctionCall,
FunctionDeclaration,
+ GenerateContentResponseUsageMetadata,
} from '@google/genai';
import {
ToolCallConfirmationDetails,
@@ -43,6 +44,7 @@ export enum GeminiEventType {
UserCancelled = 'user_cancelled',
Error = 'error',
ChatCompressed = 'chat_compressed',
+ UsageMetadata = 'usage_metadata',
}
export interface GeminiErrorEventValue {
@@ -100,6 +102,11 @@ export type ServerGeminiChatCompressedEvent = {
type: GeminiEventType.ChatCompressed;
};
+export type ServerGeminiUsageMetadataEvent = {
+ type: GeminiEventType.UsageMetadata;
+ value: GenerateContentResponseUsageMetadata;
+};
+
// The original union type, now composed of the individual types
export type ServerGeminiStreamEvent =
| ServerGeminiContentEvent
@@ -108,7 +115,8 @@ export type ServerGeminiStreamEvent =
| ServerGeminiToolCallConfirmationEvent
| ServerGeminiUserCancelledEvent
| ServerGeminiErrorEvent
- | ServerGeminiChatCompressedEvent;
+ | ServerGeminiChatCompressedEvent
+ | ServerGeminiUsageMetadataEvent;
// A turn manages the agentic loop turn within the server context.
export class Turn {
@@ -118,6 +126,7 @@ export class Turn {
args: Record<string, unknown>;
}>;
private debugResponses: GenerateContentResponse[];
+ private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null;
constructor(private readonly chat: GeminiChat) {
this.pendingToolCalls = [];
@@ -157,6 +166,18 @@ export class Turn {
yield event;
}
}
+
+ if (resp.usageMetadata) {
+ this.lastUsageMetadata =
+ resp.usageMetadata as GenerateContentResponseUsageMetadata;
+ }
+ }
+
+ if (this.lastUsageMetadata) {
+ yield {
+ type: GeminiEventType.UsageMetadata,
+ value: this.lastUsageMetadata,
+ };
}
} catch (error) {
if (signal.aborted) {
@@ -197,4 +218,8 @@ export class Turn {
getDebugResponses(): GenerateContentResponse[] {
return this.debugResponses;
}
+
+ getUsageMetadata(): GenerateContentResponseUsageMetadata | null {
+ return this.lastUsageMetadata;
+ }
}