summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/core/client.test.ts53
-rw-r--r--packages/core/src/core/client.ts73
2 files changed, 118 insertions, 8 deletions
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index dc3b8455..9d3791fd 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -8,11 +8,12 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
Chat,
+ Content,
EmbedContentResponse,
GenerateContentResponse,
GoogleGenAI,
} from '@google/genai';
-import { GeminiClient } from './client.js';
+import { findIndexAfterFraction, GeminiClient } from './client.js';
import { AuthType, ContentGenerator } from './contentGenerator.js';
import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js';
@@ -65,6 +66,54 @@ vi.mock('../telemetry/index.js', () => ({
logApiError: vi.fn(),
}));
+describe('findIndexAfterFraction', () => {
+ const history: Content[] = [
+ { role: 'user', parts: [{ text: 'This is the first message.' }] },
+ { role: 'model', parts: [{ text: 'This is the second message.' }] },
+ { role: 'user', parts: [{ text: 'This is the third message.' }] },
+ { role: 'model', parts: [{ text: 'This is the fourth message.' }] },
+ { role: 'user', parts: [{ text: 'This is the fifth message.' }] },
+ ];
+
+ it('should throw an error for non-positive numbers', () => {
+ expect(() => findIndexAfterFraction(history, 0)).toThrow(
+ 'Fraction must be between 0 and 1',
+ );
+ });
+
+ it('should throw an error for a fraction greater than or equal to 1', () => {
+ expect(() => findIndexAfterFraction(history, 1)).toThrow(
+ 'Fraction must be between 0 and 1',
+ );
+ });
+
+ it('should handle a fraction in the middle', () => {
+ // Total length is 257. 257 * 0.5 = 128.5
+ // 0: 53
+ // 1: 53 + 54 = 107
+ // 2: 107 + 53 = 160
+ // 160 >= 128.5, so index is 2
+ expect(findIndexAfterFraction(history, 0.5)).toBe(2);
+ });
+
+ it('should handle an empty history', () => {
+ expect(findIndexAfterFraction([], 0.5)).toBe(0);
+ });
+
+ it('should handle a history with only one item', () => {
+ expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(0);
+ });
+
+ it('should handle history with weird parts', () => {
+ const historyWithEmptyParts: Content[] = [
+ { role: 'user', parts: [{ text: 'Message 1' }] },
+ { role: 'model', parts: [{ fileData: { fileUri: 'derp' } }] },
+ { role: 'user', parts: [{ text: 'Message 2' }] },
+ ];
+ expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(1);
+ });
+});
+
describe('Gemini Client (client.ts)', () => {
let client: GeminiClient;
beforeEach(async () => {
@@ -384,6 +433,7 @@ describe('Gemini Client (client.ts)', () => {
{ role: 'user', parts: [{ text: '...history...' }] },
]),
addHistory: vi.fn(),
+ setHistory: vi.fn(),
sendMessage: mockSendMessage,
};
client['chat'] = mockChat as GeminiChat;
@@ -735,6 +785,7 @@ describe('Gemini Client (client.ts)', () => {
const mockChat: Partial<GeminiChat> = {
getHistory: vi.fn().mockReturnValue(mockChatHistory),
+ setHistory: vi.fn(),
sendMessage: mockSendMessage,
};
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index 69ed0dff..6cfcd407 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -45,6 +45,39 @@ function isThinkingSupported(model: string) {
return false;
}
+/**
+ * Returns the index of the content after the fraction of the total characters in the history.
+ *
+ * Exported for testing purposes.
+ */
+export function findIndexAfterFraction(
+ history: Content[],
+ fraction: number,
+): number {
+ if (fraction <= 0 || fraction >= 1) {
+ throw new Error('Fraction must be between 0 and 1');
+ }
+
+ const contentLengths = history.map(
+ (content) => JSON.stringify(content).length,
+ );
+
+ const totalCharacters = contentLengths.reduce(
+ (sum, length) => sum + length,
+ 0,
+ );
+ const targetCharacters = totalCharacters * fraction;
+
+ let charactersSoFar = 0;
+ for (let i = 0; i < contentLengths.length; i++) {
+ charactersSoFar += contentLengths[i];
+ if (charactersSoFar >= targetCharacters) {
+ return i;
+ }
+ }
+ return contentLengths.length;
+}
+
export class GeminiClient {
private chat?: GeminiChat;
private contentGenerator?: ContentGenerator;
@@ -54,7 +87,16 @@ export class GeminiClient {
topP: 1,
};
private readonly MAX_TURNS = 100;
- private readonly TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7;
+ /**
+ * Threshold for compression token count as a fraction of the model's token limit.
+ * If the chat history exceeds this threshold, it will be compressed.
+ */
+ private readonly COMPRESSION_TOKEN_THRESHOLD = 0.7;
+ /**
+ * The fraction of the latest chat history to keep. A value of 0.3
+ * means that only the last 30% of the chat history will be kept after compression.
+ */
+ private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3;
constructor(private config: Config) {
if (config.getProxy()) {
@@ -90,11 +132,11 @@ export class GeminiClient {
return this.chat;
}
- async getHistory(): Promise<Content[]> {
+ getHistory(): Content[] {
return this.getChat().getHistory();
}
- async setHistory(history: Content[]): Promise<void> {
+ setHistory(history: Content[]) {
this.getChat().setHistory(history);
}
@@ -441,25 +483,41 @@ export class GeminiClient {
const model = this.config.getModel();
- let { totalTokens: originalTokenCount } =
+ const { totalTokens: originalTokenCount } =
await this.getContentGenerator().countTokens({
model,
contents: curatedHistory,
});
if (originalTokenCount === undefined) {
console.warn(`Could not determine token count for model ${model}.`);
- originalTokenCount = 0;
+ return null;
}
// Don't compress if not forced and we are under the limit.
if (
!force &&
- originalTokenCount <
- this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(model)
+ originalTokenCount < this.COMPRESSION_TOKEN_THRESHOLD * tokenLimit(model)
) {
return null;
}
+ let compressBeforeIndex = findIndexAfterFraction(
+ curatedHistory,
+ 1 - this.COMPRESSION_PRESERVE_THRESHOLD,
+ );
+ // Find the first user message after the index. This is the start of the next turn.
+ while (
+ compressBeforeIndex < curatedHistory.length &&
+ curatedHistory[compressBeforeIndex]?.role !== 'user'
+ ) {
+ compressBeforeIndex++;
+ }
+
+ const historyToCompress = curatedHistory.slice(0, compressBeforeIndex);
+ const historyToKeep = curatedHistory.slice(compressBeforeIndex);
+
+ this.getChat().setHistory(historyToCompress);
+
const { text: summary } = await this.getChat().sendMessage({
message: {
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
@@ -477,6 +535,7 @@ export class GeminiClient {
role: 'model',
parts: [{ text: 'Got it. Thanks for the additional context!' }],
},
+ ...historyToKeep,
]);
const { totalTokens: newTokenCount } =