summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/core/src/core/client.ts20
-rw-r--r--packages/core/src/core/contentGenerator.ts27
-rw-r--r--packages/core/src/core/geminiChat.test.ts29
-rw-r--r--packages/core/src/core/geminiChat.ts10
-rw-r--r--packages/core/src/utils/nextSpeakerChecker.test.ts1
5 files changed, 45 insertions, 42 deletions
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index bce2c5e4..d1a59eb1 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -32,10 +32,11 @@ import {
logApiResponse,
logApiError,
} from '../telemetry/index.js';
+import { ContentGenerator } from './contentGenerator.js';
export class GeminiClient {
private chat: Promise<GeminiChat>;
- private client: GoogleGenAI;
+ private contentGenerator: ContentGenerator;
private model: string;
private generateContentConfig: GenerateContentConfig = {
temperature: 0,
@@ -48,7 +49,7 @@ export class GeminiClient {
const apiKeyFromConfig = config.getApiKey();
const vertexaiFlag = config.getVertexAI();
- this.client = new GoogleGenAI({
+ const googleGenAI = new GoogleGenAI({
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
vertexai: vertexaiFlag,
httpOptions: {
@@ -57,6 +58,7 @@ export class GeminiClient {
},
},
});
+ this.contentGenerator = googleGenAI.models;
this.model = config.getModel();
this.chat = this.startChat();
}
@@ -148,8 +150,7 @@ export class GeminiClient {
const systemInstruction = getCoreSystemPrompt(userMemory);
return new GeminiChat(
- this.client,
- this.client.models,
+ this.contentGenerator,
this.model,
{
systemInstruction,
@@ -285,7 +286,7 @@ export class GeminiClient {
let inputTokenCount = 0;
try {
- const { totalTokens } = await this.client.models.countTokens({
+ const { totalTokens } = await this.contentGenerator.countTokens({
model,
contents,
});
@@ -300,7 +301,7 @@ export class GeminiClient {
this._logApiRequest(model, inputTokenCount);
const apiCall = () =>
- this.client.models.generateContent({
+ this.contentGenerator.generateContent({
model,
config: {
...requestConfig,
@@ -400,7 +401,7 @@ export class GeminiClient {
let inputTokenCount = 0;
try {
- const { totalTokens } = await this.client.models.countTokens({
+ const { totalTokens } = await this.contentGenerator.countTokens({
model: modelToUse,
contents,
});
@@ -415,7 +416,7 @@ export class GeminiClient {
this._logApiRequest(modelToUse, inputTokenCount);
const apiCall = () =>
- this.client.models.generateContent({
+ this.contentGenerator.generateContent({
model: modelToUse,
config: requestConfig,
contents,
@@ -453,8 +454,7 @@ export class GeminiClient {
const chat = await this.chat;
const history = chat.getHistory(true); // Get curated history
- // Count tokens using the models module from the GoogleGenAI client instance
- const { totalTokens } = await this.client.models.countTokens({
+ const { totalTokens } = await this.contentGenerator.countTokens({
model: this.model,
contents: history,
});
diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts
new file mode 100644
index 00000000..32b48c5c
--- /dev/null
+++ b/packages/core/src/core/contentGenerator.ts
@@ -0,0 +1,27 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import {
+ CountTokensResponse,
+ GenerateContentResponse,
+ GenerateContentParameters,
+ CountTokensParameters,
+} from '@google/genai';
+
+/**
+ * Interface abstracting the core functionalities for generating content and counting tokens.
+ */
+export interface ContentGenerator {
+ generateContent(
+ request: GenerateContentParameters,
+ ): Promise<GenerateContentResponse>;
+
+ generateContentStream(
+ request: GenerateContentParameters,
+ ): Promise<AsyncGenerator<GenerateContentResponse>>;
+
+ countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
+}
diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts
index 3a6fb10c..6d18ebd9 100644
--- a/packages/core/src/core/geminiChat.test.ts
+++ b/packages/core/src/core/geminiChat.test.ts
@@ -5,13 +5,7 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
-import {
- Content,
- GoogleGenAI,
- Models,
- GenerateContentConfig,
- Part,
-} from '@google/genai';
+import { Content, Models, GenerateContentConfig, Part } from '@google/genai';
import { GeminiChat } from './geminiChat.js';
// Mocks
@@ -23,10 +17,6 @@ const mockModelsModule = {
batchEmbedContents: vi.fn(),
} as unknown as Models;
-const mockGoogleGenAI = {
- getGenerativeModel: vi.fn().mockReturnValue(mockModelsModule),
-} as unknown as GoogleGenAI;
-
describe('GeminiChat', () => {
let chat: GeminiChat;
const model = 'gemini-pro';
@@ -35,7 +25,7 @@ describe('GeminiChat', () => {
beforeEach(() => {
vi.clearAllMocks();
// Reset history for each test by creating a new instance
- chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, []);
+ chat = new GeminiChat(mockModelsModule, model, config, []);
});
afterEach(() => {
@@ -129,19 +119,8 @@ describe('GeminiChat', () => {
// @ts-expect-error Accessing private method for testing purposes
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
- // const history = chat.getHistory(); // Removed unused variable to satisfy linter
- // The recordHistory will push the *new* userInput first, then the consolidated newModelOutput.
- // However, the consolidation logic for *outputContents* itself should run, and then the merge with *existing* history.
- // Let's adjust the test to reflect how recordHistory is used: it adds the current userInput, then the model's response to it.
-
// Reset and set up a more realistic scenario for merging with existing history
- chat = new GeminiChat(
- mockGoogleGenAI,
- mockModelsModule,
- model,
- config,
- [],
- );
+ chat = new GeminiChat(mockModelsModule, model, config, []);
const firstUserInput: Content = {
role: 'user',
parts: [{ text: 'First user input' }],
@@ -184,7 +163,7 @@ describe('GeminiChat', () => {
role: 'model',
parts: [{ text: 'Initial model answer.' }],
};
- chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, [
+ chat = new GeminiChat(mockModelsModule, model, config, [
initialUser,
initialModel,
]);
diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts
index b4844499..54f74102 100644
--- a/packages/core/src/core/geminiChat.ts
+++ b/packages/core/src/core/geminiChat.ts
@@ -10,15 +10,14 @@
import {
GenerateContentResponse,
Content,
- Models,
GenerateContentConfig,
SendMessageParameters,
- GoogleGenAI,
createUserContent,
Part,
} from '@google/genai';
import { retryWithBackoff } from '../utils/retry.js';
import { isFunctionResponse } from '../utils/messageInspectors.js';
+import { ContentGenerator } from './contentGenerator.js';
/**
* Returns true if the response is valid, false otherwise.
@@ -120,8 +119,7 @@ export class GeminiChat {
private sendPromise: Promise<void> = Promise.resolve();
constructor(
- private readonly apiClient: GoogleGenAI,
- private readonly modelsModule: Models,
+ private readonly contentGenerator: ContentGenerator,
private readonly model: string,
private readonly config: GenerateContentConfig = {},
private history: Content[] = [],
@@ -156,7 +154,7 @@ export class GeminiChat {
const userContent = createUserContent(params.message);
const apiCall = () =>
- this.modelsModule.generateContent({
+ this.contentGenerator.generateContent({
model: this.model,
contents: this.getHistory(true).concat(userContent),
config: { ...this.config, ...params.config },
@@ -225,7 +223,7 @@ export class GeminiChat {
const userContent = createUserContent(params.message);
const apiCall = () =>
- this.modelsModule.generateContentStream({
+ this.contentGenerator.generateContentStream({
model: this.model,
contents: this.getHistory(true).concat(userContent),
config: { ...this.config, ...params.config },
diff --git a/packages/core/src/utils/nextSpeakerChecker.test.ts b/packages/core/src/utils/nextSpeakerChecker.test.ts
index 872e00f6..2514c99d 100644
--- a/packages/core/src/utils/nextSpeakerChecker.test.ts
+++ b/packages/core/src/utils/nextSpeakerChecker.test.ts
@@ -69,7 +69,6 @@ describe('checkNextSpeaker', () => {
// GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
chatInstance = new GeminiChat(
- mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
'gemini-pro', // model name
{},