Diffstat (limited to 'packages/core/src')
 packages/core/src/core/client.test.ts | 110 +++++++++++++++++++++++++++++
 packages/core/src/core/client.ts      |  22 ++++++++------
 2 files changed, 122 insertions(+), 10 deletions(-)
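
This commit removes the `model` field that `GeminiClient` cached at construction time and instead reads `this.config.getModel()` at every call site, so a model switched at runtime (for example by the flash fallback) is picked up immediately. A minimal sketch of the pitfall being fixed, using illustrative names rather than the real gemini-cli classes:

```ts
interface ModelConfig {
  getModel(): string;
  setModel(model: string): void;
}

// Caches the model once; goes stale after config.setModel() runs.
class CachingClient {
  private model: string;
  constructor(private config: ModelConfig) {
    this.model = config.getModel(); // snapshot taken at construction
  }
  modelForRequest(): string {
    return this.model; // stale if the config changed since construction
  }
}

// Reads the model on every request; always reflects the current config.
class FreshClient {
  constructor(private config: ModelConfig) {}
  modelForRequest(): string {
    return this.config.getModel();
  }
}
```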
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index 0adbf986..dc3b8455 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -687,4 +687,114 @@ describe('Gemini Client (client.ts)', () => {
);
});
});
+
+ describe('generateContent', () => {
+ it('should use current model from config for content generation', async () => {
+ const initialModel = client['config'].getModel();
+ const contents = [{ role: 'user', parts: [{ text: 'test' }] }];
+ const currentModel = initialModel + '-changed';
+
+ vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
+
+ const mockGenerator: Partial<ContentGenerator> = {
+ countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
+ generateContent: mockGenerateContentFn,
+ };
+ client['contentGenerator'] = mockGenerator as ContentGenerator;
+
+ await client.generateContent(contents, {}, new AbortController().signal);
+
+ expect(mockGenerateContentFn).not.toHaveBeenCalledWith({
+ model: initialModel,
+ config: expect.any(Object),
+ contents,
+ });
+ expect(mockGenerateContentFn).toHaveBeenCalledWith({
+ model: currentModel,
+ config: expect.any(Object),
+ contents,
+ });
+ });
+ });
+
+ describe('tryCompressChat', () => {
+ it('should use current model from config for token counting after sendMessage', async () => {
+ const initialModel = client['config'].getModel();
+
+ const mockCountTokens = vi
+ .fn()
+ .mockResolvedValueOnce({ totalTokens: 100000 })
+ .mockResolvedValueOnce({ totalTokens: 5000 });
+
+ const mockSendMessage = vi.fn().mockResolvedValue({ text: 'Summary' });
+
+ const mockChatHistory = [
+ { role: 'user', parts: [{ text: 'Long conversation' }] },
+ { role: 'model', parts: [{ text: 'Long response' }] },
+ ];
+
+ const mockChat: Partial<GeminiChat> = {
+ getHistory: vi.fn().mockReturnValue(mockChatHistory),
+ sendMessage: mockSendMessage,
+ };
+
+ const mockGenerator: Partial<ContentGenerator> = {
+ countTokens: mockCountTokens,
+ };
+
+ // simulate the model changing between calls to `countTokens`
+ const firstCurrentModel = initialModel + '-changed-1';
+ const secondCurrentModel = initialModel + '-changed-2';
+ vi.spyOn(client['config'], 'getModel')
+ .mockReturnValueOnce(firstCurrentModel)
+ .mockReturnValueOnce(secondCurrentModel);
+
+ client['chat'] = mockChat as GeminiChat;
+ client['contentGenerator'] = mockGenerator as ContentGenerator;
+ client['startChat'] = vi.fn().mockResolvedValue(mockChat);
+
+ const result = await client.tryCompressChat(true);
+
+ expect(mockCountTokens).toHaveBeenCalledTimes(2);
+ expect(mockCountTokens).toHaveBeenNthCalledWith(1, {
+ model: firstCurrentModel,
+ contents: mockChatHistory,
+ });
+ expect(mockCountTokens).toHaveBeenNthCalledWith(2, {
+ model: secondCurrentModel,
+ contents: expect.any(Array),
+ });
+
+ expect(result).toEqual({
+ originalTokenCount: 100000,
+ newTokenCount: 5000,
+ });
+ });
+ });
+
+ describe('handleFlashFallback', () => {
+ it('should use current model from config when checking for fallback', async () => {
+ const initialModel = client['config'].getModel();
+ const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
+
+ // simulate the configured model having changed
+ const currentModel = initialModel + '-changed';
+ vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
+
+ const mockFallbackHandler = vi.fn().mockResolvedValue(true);
+ client['config'].flashFallbackHandler = mockFallbackHandler;
+ client['config'].setModel = vi.fn();
+
+ const result = await client['handleFlashFallback'](
+ AuthType.LOGIN_WITH_GOOGLE,
+ );
+
+ expect(result).toBe(fallbackModel);
+
+ expect(mockFallbackHandler).toHaveBeenCalledWith(
+ currentModel,
+ fallbackModel,
+ );
+ });
+ });
});
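
The new tests all use the same technique: spy on `Config.getModel` with `mockReturnValueOnce` so successive reads return different model names, then assert that the downstream call received the fresh value rather than a constructor-time snapshot. A condensed, self-contained version of the pattern (the `Config`/`Client` stand-ins here are hypothetical; the vitest APIs are real):

```ts
import { describe, expect, it, vi } from 'vitest';

class Config {
  getModel(): string {
    return 'base-model';
  }
}

class Client {
  constructor(private readonly config: Config) {}
  send(generate: (args: { model: string }) => void): void {
    // Read the model at call time, mirroring the fix in client.ts.
    generate({ model: this.config.getModel() });
  }
}

describe('fresh model lookup', () => {
  it('passes the current model, not a cached one', () => {
    const config = new Config();
    const client = new Client(config);
    const generate = vi.fn();

    // Simulate the model changing after the client was constructed.
    vi.spyOn(config, 'getModel').mockReturnValueOnce('base-model-changed');

    client.send(generate);
    expect(generate).toHaveBeenCalledWith({ model: 'base-model-changed' });
  });
});
```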
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index b39b10a0..69ed0dff 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -48,7 +48,6 @@ function isThinkingSupported(model: string) {
export class GeminiClient {
private chat?: GeminiChat;
private contentGenerator?: ContentGenerator;
- private model: string;
private embeddingModel: string;
private generateContentConfig: GenerateContentConfig = {
temperature: 0,
@@ -62,7 +61,6 @@ export class GeminiClient {
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
}
- this.model = config.getModel();
this.embeddingModel = config.getEmbeddingModel();
}
@@ -187,7 +185,9 @@ export class GeminiClient {
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
- const generateContentConfigWithThinking = isThinkingSupported(this.model)
+ const generateContentConfigWithThinking = isThinkingSupported(
+ this.config.getModel(),
+ )
? {
...this.generateContentConfig,
thinkingConfig: {
@@ -345,7 +345,7 @@ export class GeminiClient {
generationConfig: GenerateContentConfig,
abortSignal: AbortSignal,
): Promise<GenerateContentResponse> {
- const modelToUse = this.model;
+ const modelToUse = this.config.getModel();
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
...generationConfig,
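
The two hunks above are the same mechanical substitution: `isThinkingSupported` and `generateContent` now resolve the model at call time via `this.config.getModel()` rather than from the removed field.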
@@ -439,13 +439,15 @@ export class GeminiClient {
return null;
}
+ const model = this.config.getModel();
+
let { totalTokens: originalTokenCount } =
await this.getContentGenerator().countTokens({
- model: this.model,
+ model,
contents: curatedHistory,
});
if (originalTokenCount === undefined) {
- console.warn(`Could not determine token count for model ${this.model}.`);
+ console.warn(`Could not determine token count for model ${model}.`);
originalTokenCount = 0;
}
@@ -453,7 +455,7 @@ export class GeminiClient {
if (
!force &&
originalTokenCount <
- this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(this.model)
+ this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(model)
) {
return null;
}
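
Within a single `tryCompressChat` pass, the model is read once into a local `model` and reused for the initial token count, the warning message, and the threshold check, so the whole compression decision is made against one consistent model name. A sketch of that snapshot-per-operation shape (the helpers are hypothetical; only `config.getModel()` and the threshold idea come from the diff):

```ts
declare function countTokens(
  model: string,
  contents: unknown[],
): Promise<number>;
declare function tokenLimit(model: string): number;

async function shouldCompress(
  config: { getModel(): string },
  contents: unknown[],
  threshold: number, // e.g. TOKEN_THRESHOLD_FOR_SUMMARIZATION
): Promise<boolean> {
  const model = config.getModel(); // read once, used consistently below
  const tokens = await countTokens(model, contents);
  return tokens >= threshold * tokenLimit(model);
}
```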
@@ -479,7 +481,8 @@ export class GeminiClient {
const { totalTokens: newTokenCount } =
await this.getContentGenerator().countTokens({
- model: this.model,
+ // the model might change after calling `sendMessage`, so we get the newest value from config
+ model: this.config.getModel(),
contents: this.getChat().getHistory(),
});
if (newTokenCount === undefined) {
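
The second `countTokens` call deliberately re-reads the config instead of reusing the earlier `model` snapshot: the configured model can change while `sendMessage` runs (in this file, the flash fallback below is one path that calls `config.setModel`), and the post-compression count should be attributed to whatever model is active afterwards. The `tryCompressChat` test above pins this down by asserting that the two `countTokens` calls receive different model names.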
@@ -503,7 +506,7 @@ export class GeminiClient {
return null;
}
- const currentModel = this.model;
+ const currentModel = this.config.getModel();
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
// Don't fallback if already using Flash model
@@ -518,7 +521,6 @@ export class GeminiClient {
const accepted = await fallbackHandler(currentModel, fallbackModel);
if (accepted) {
this.config.setModel(fallbackModel);
- this.model = fallbackModel;
return fallbackModel;
}
} catch (error) {
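
With the cached field gone, accepting the fallback only requires `this.config.setModel(fallbackModel)`; the removed `this.model = fallbackModel;` line was the synchronization step that is no longer needed, since there is no second copy of the model to keep in sync with the config.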