author     Tommaso Sciortino <[email protected]>  2025-07-14 10:09:11 -0700
committer  GitHub <[email protected]>  2025-07-14 17:09:11 +0000
commit     2f1d6234def2c8c77c2afebd9f83a2dcf3d6aacd (patch)
tree       b07162bc9016c400d2ed01bec5555daf5f775e2d /packages/core/src
parent     c313c3dee1872a0edc943ad096eab68a03a3dda5 (diff)
Don't start uncompressed history with a function response (#4141)
Diffstat (limited to 'packages/core/src')
-rw-r--r--  packages/core/src/core/client.test.ts  86
-rw-r--r--  packages/core/src/core/client.ts         4
2 files changed, 77 insertions(+), 13 deletions(-)
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index bbcb549b..03793bda 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -470,34 +470,31 @@ describe('Gemini Client (client.ts)', () => {
   describe('tryCompressChat', () => {
     const mockCountTokens = vi.fn();
     const mockSendMessage = vi.fn();
+    const mockGetHistory = vi.fn();
 
     beforeEach(() => {
       vi.mock('./tokenLimits', () => ({
         tokenLimit: vi.fn(),
       }));
 
-      const mockGenerator: Partial<ContentGenerator> = {
+      client['contentGenerator'] = {
         countTokens: mockCountTokens,
-      };
-      client['contentGenerator'] = mockGenerator as ContentGenerator;
+      } as unknown as ContentGenerator;
 
-      // Mock the chat's sendMessage method
-      const mockChat: Partial<GeminiChat> = {
-        getHistory: vi
-          .fn()
-          .mockReturnValue([
-            { role: 'user', parts: [{ text: '...history...' }] },
-          ]),
+      client['chat'] = {
+        getHistory: mockGetHistory,
         addHistory: vi.fn(),
         setHistory: vi.fn(),
         sendMessage: mockSendMessage,
-      };
-      client['chat'] = mockChat as GeminiChat;
+      } as unknown as GeminiChat;
     });
 
     it('should not trigger summarization if token count is below threshold', async () => {
       const MOCKED_TOKEN_LIMIT = 1000;
       vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history...' }] },
+      ]);
 
       mockCountTokens.mockResolvedValue({
         totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7
@@ -515,6 +512,9 @@ describe('Gemini Client (client.ts)', () => {
     it('should trigger summarization if token count is at threshold', async () => {
       const MOCKED_TOKEN_LIMIT = 1000;
       vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history...' }] },
+      ]);
 
       const originalTokenCount = 1000 * 0.7;
       const newTokenCount = 100;
@@ -546,7 +546,69 @@
       expect(newChat).not.toBe(initialChat);
     });
 
+    it('should not compress across a function call response', async () => {
+      const MOCKED_TOKEN_LIMIT = 1000;
+      vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history 1...' }] },
+        { role: 'model', parts: [{ text: '...history 2...' }] },
+        { role: 'user', parts: [{ text: '...history 3...' }] },
+        { role: 'model', parts: [{ text: '...history 4...' }] },
+        { role: 'user', parts: [{ text: '...history 5...' }] },
+        { role: 'model', parts: [{ text: '...history 6...' }] },
+        { role: 'user', parts: [{ text: '...history 7...' }] },
+        { role: 'model', parts: [{ text: '...history 8...' }] },
+        // Normally we would break here, but we have a function response.
+        {
+          role: 'user',
+          parts: [{ functionResponse: { name: '...history 8...' } }],
+        },
+        { role: 'model', parts: [{ text: '...history 10...' }] },
+        // Instead we will break here.
+        { role: 'user', parts: [{ text: '...history 10...' }] },
+      ]);
+
+      const originalTokenCount = 1000 * 0.7;
+      const newTokenCount = 100;
+
+      mockCountTokens
+        .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
+        .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
+
+      // Mock the summary response from the chat
+      mockSendMessage.mockResolvedValue({
+        role: 'model',
+        parts: [{ text: 'This is a summary.' }],
+      });
+
+      const initialChat = client.getChat();
+      const result = await client.tryCompressChat('prompt-id-3');
+      const newChat = client.getChat();
+
+      expect(tokenLimit).toHaveBeenCalled();
+      expect(mockSendMessage).toHaveBeenCalled();
+
+      // Assert that summarization happened and returned the correct stats
+      expect(result).toEqual({
+        originalTokenCount,
+        newTokenCount,
+      });
+      // Assert that the chat was reset
+      expect(newChat).not.toBe(initialChat);
+
+      // 1. standard start context message
+      // 2. standard canned user start message
+      // 3. compressed summary message
+      // 4. standard canned user summary message
+      // 5. The last user message (not the last 3 because that would start with a function response)
+      expect(newChat.getHistory().length).toEqual(5);
+    });
+
     it('should always trigger summarization when force is true, regardless of token count', async () => {
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history...' }] },
+      ]);
+
       const originalTokenCount = 10; // Well below threshold
       const newTokenCount = 5;
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index ed903788..d8143d05 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -30,6 +30,7 @@ import { reportError } from '../utils/errorReporting.js';
 import { GeminiChat } from './geminiChat.js';
 import { retryWithBackoff } from '../utils/retry.js';
 import { getErrorMessage } from '../utils/errors.js';
+import { isFunctionResponse } from '../utils/messageInspectors.js';
 import { tokenLimit } from './tokenLimits.js';
 import {
   AuthType,
@@ -547,7 +548,8 @@ export class GeminiClient {
     // Find the first user message after the index. This is the start of the next turn.
     while (
       compressBeforeIndex < curatedHistory.length &&
-      curatedHistory[compressBeforeIndex]?.role !== 'user'
+      (curatedHistory[compressBeforeIndex]?.role === 'model' ||
+        isFunctionResponse(curatedHistory[compressBeforeIndex]))
    ) {
       compressBeforeIndex++;
     }
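
The helper that makes this change work, isFunctionResponse, is imported from ../utils/messageInspectors.js but its body is not part of this diff. As a rough sketch of the intent, assuming the Content shape from @google/genai (a role plus a parts array whose entries may carry a functionResponse), the helper and the boundary scan could look like the following. isFunctionResponse here is a hypothetical reconstruction, and findCompressBoundary is an illustrative standalone name, not a function in the repository:

import { Content } from '@google/genai';

// Hypothetical reconstruction: a history entry counts as a function
// response when it is a user turn whose parts all carry functionResponse
// payloads. Such a turn is only valid directly after the model turn that
// issued the functionCall, so it must never open the uncompressed tail.
function isFunctionResponse(content: Content): boolean {
  return (
    content.role === 'user' &&
    !!content.parts &&
    content.parts.every((part) => !!part.functionResponse)
  );
}

// Illustrative version of the patched boundary scan: advance past model
// turns and past user turns that are merely function responses, so the
// history kept after compression always starts with a real user message.
function findCompressBoundary(history: Content[], start: number): number {
  let compressBeforeIndex = start;
  while (
    compressBeforeIndex < history.length &&
    (history[compressBeforeIndex]?.role === 'model' ||
      isFunctionResponse(history[compressBeforeIndex]!))
  ) {
    compressBeforeIndex++;
  }
  return compressBeforeIndex;
}

Run against the fixture in the new test, such a scan refuses to split at the functionResponse turn and keeps advancing to the final plain user message, which is why the rebuilt chat history ends up with exactly five entries.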