summaryrefslogtreecommitdiff
path: root/packages/server/src
diff options
context:
space:
mode:
authorTaylor Mullen <[email protected]>2025-05-29 23:53:35 -0700
committerN. Taylor Mullen <[email protected]>2025-05-30 10:43:48 -0700
commit9537ff476219486574fb6a50e54389a78beefe8e (patch)
tree95fe901641dec941f81b3a9f8d3dff610bd2e3d9 /packages/server/src
parent7c4a5464f68db32bb0069927f65deb2a61bd094f (diff)
feat(server): consolidate adjacent model content in chat history
- Consolidates consecutive model messages into a single message in the chat history. - This prevents multiple model messages from being displayed in a row, improving readability. - This may also address some instances of 500 errors that could have been caused by multiple, rapidly succeeding model messages. - Adds tests for the new consolidation logic. Fixes https://b.corp.google.com/issues/421010429
Diffstat (limited to 'packages/server/src')
-rw-r--r--packages/server/src/core/geminiChat.test.ts282
-rw-r--r--packages/server/src/core/geminiChat.ts40
2 files changed, 321 insertions, 1 deletions
diff --git a/packages/server/src/core/geminiChat.test.ts b/packages/server/src/core/geminiChat.test.ts
new file mode 100644
index 00000000..11e222c9
--- /dev/null
+++ b/packages/server/src/core/geminiChat.test.ts
@@ -0,0 +1,282 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
+import {
+ Content,
+ GoogleGenAI,
+ Models,
+ GenerateContentConfig,
+ Part,
+} from '@google/genai';
+import { GeminiChat } from './geminiChat.js';
+
+// Mocks
+const mockModelsModule = {
+ generateContent: vi.fn(),
+ generateContentStream: vi.fn(),
+ countTokens: vi.fn(),
+ embedContent: vi.fn(),
+ batchEmbedContents: vi.fn(),
+} as unknown as Models;
+
+const mockGoogleGenAI = {
+ getGenerativeModel: vi.fn().mockReturnValue(mockModelsModule),
+} as unknown as GoogleGenAI;
+
+describe('GeminiChat', () => {
+ let chat: GeminiChat;
+ const model = 'gemini-pro';
+ const config: GenerateContentConfig = {};
+
+ beforeEach(() => {
+ vi.clearAllMocks();
+ // Reset history for each test by creating a new instance
+ chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, []);
+ });
+
+ afterEach(() => {
+ vi.restoreAllMocks();
+ });
+
+ describe('recordHistory', () => {
+ const userInput: Content = {
+ role: 'user',
+ parts: [{ text: 'User input' }],
+ };
+
+ it('should add user input and a single model output to history', () => {
+ const modelOutput: Content[] = [
+ { role: 'model', parts: [{ text: 'Model output' }] },
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(userInput, modelOutput);
+ const history = chat.getHistory();
+ expect(history).toEqual([userInput, modelOutput[0]]);
+ });
+
+ it('should consolidate adjacent model outputs', () => {
+ const modelOutputParts: Content[] = [
+ { role: 'model', parts: [{ text: 'Model part 1' }] },
+ { role: 'model', parts: [{ text: 'Model part 2' }] },
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(userInput, modelOutputParts);
+ const history = chat.getHistory();
+ expect(history.length).toBe(2);
+ expect(history[0]).toEqual(userInput);
+ expect(history[1].role).toBe('model');
+ expect(history[1].parts).toEqual([
+ { text: 'Model part 1' },
+ { text: 'Model part 2' },
+ ]);
+ });
+
+ it('should handle a mix of user and model roles in outputContents (though unusual)', () => {
+ const mixedOutput: Content[] = [
+ { role: 'model', parts: [{ text: 'Model 1' }] },
+ { role: 'user', parts: [{ text: 'Unexpected User' }] }, // This should be pushed as is
+ { role: 'model', parts: [{ text: 'Model 2' }] },
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(userInput, mixedOutput);
+ const history = chat.getHistory();
+ expect(history.length).toBe(4); // user, model1, user_unexpected, model2
+ expect(history[0]).toEqual(userInput);
+ expect(history[1]).toEqual(mixedOutput[0]);
+ expect(history[2]).toEqual(mixedOutput[1]);
+ expect(history[3]).toEqual(mixedOutput[2]);
+ });
+
+ it('should consolidate multiple adjacent model outputs correctly', () => {
+ const modelOutputParts: Content[] = [
+ { role: 'model', parts: [{ text: 'M1' }] },
+ { role: 'model', parts: [{ text: 'M2' }] },
+ { role: 'model', parts: [{ text: 'M3' }] },
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(userInput, modelOutputParts);
+ const history = chat.getHistory();
+ expect(history.length).toBe(2);
+ expect(history[1].parts).toEqual([
+ { text: 'M1' },
+ { text: 'M2' },
+ { text: 'M3' },
+ ]);
+ });
+
+ it('should not consolidate if roles are different between model outputs', () => {
+ const modelOutputParts: Content[] = [
+ { role: 'model', parts: [{ text: 'M1' }] },
+ { role: 'user', parts: [{ text: 'Interjecting User' }] },
+ { role: 'model', parts: [{ text: 'M2' }] },
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(userInput, modelOutputParts);
+ const history = chat.getHistory();
+ expect(history.length).toBe(4); // user, M1, Interjecting User, M2
+ expect(history[1].parts).toEqual([{ text: 'M1' }]);
+ expect(history[3].parts).toEqual([{ text: 'M2' }]);
+ });
+
+ it('should merge with last history entry if it is also a model output', () => {
+ // @ts-expect-error Accessing private property for test setup
+ chat.history = [
+ userInput,
+ { role: 'model', parts: [{ text: 'Initial Model Output' }] },
+ ]; // Prime the history
+
+ const newModelOutput: Content[] = [
+ { role: 'model', parts: [{ text: 'New Model Part 1' }] },
+ { role: 'model', parts: [{ text: 'New Model Part 2' }] },
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
+
+ // const history = chat.getHistory(); // Removed unused variable to satisfy linter
+ // The recordHistory will push the *new* userInput first, then the consolidated newModelOutput.
+ // However, the consolidation logic for *outputContents* itself should run, and then the merge with *existing* history.
+ // Let's adjust the test to reflect how recordHistory is used: it adds the current userInput, then the model's response to it.
+
+ // Reset and set up a more realistic scenario for merging with existing history
+ chat = new GeminiChat(
+ mockGoogleGenAI,
+ mockModelsModule,
+ model,
+ config,
+ [],
+ );
+ const firstUserInput: Content = {
+ role: 'user',
+ parts: [{ text: 'First user input' }],
+ };
+ const firstModelOutput: Content[] = [
+ { role: 'model', parts: [{ text: 'First model response' }] },
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(firstUserInput, firstModelOutput);
+
+ const secondUserInput: Content = {
+ role: 'user',
+ parts: [{ text: 'Second user input' }],
+ };
+ const secondModelOutput: Content[] = [
+ { role: 'model', parts: [{ text: 'Second model response part 1' }] },
+ { role: 'model', parts: [{ text: 'Second model response part 2' }] },
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(secondUserInput, secondModelOutput);
+
+ const finalHistory = chat.getHistory();
+ expect(finalHistory.length).toBe(4); // user1, model1, user2, model2(consolidated)
+ expect(finalHistory[0]).toEqual(firstUserInput);
+ expect(finalHistory[1]).toEqual(firstModelOutput[0]);
+ expect(finalHistory[2]).toEqual(secondUserInput);
+ expect(finalHistory[3].role).toBe('model');
+ expect(finalHistory[3].parts).toEqual([
+ { text: 'Second model response part 1' },
+ { text: 'Second model response part 2' },
+ ]);
+ });
+
+ it('should correctly merge consolidated new output with existing model history', () => {
+ // Setup: history ends with a model turn
+ const initialUser: Content = {
+ role: 'user',
+ parts: [{ text: 'Initial user query' }],
+ };
+ const initialModel: Content = {
+ role: 'model',
+ parts: [{ text: 'Initial model answer.' }],
+ };
+ chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, [
+ initialUser,
+ initialModel,
+ ]);
+
+ // New interaction
+ const currentUserInput: Content = {
+ role: 'user',
+ parts: [{ text: 'Follow-up question' }],
+ };
+ const newModelParts: Content[] = [
+ { role: 'model', parts: [{ text: 'Part A of new answer.' }] },
+ { role: 'model', parts: [{ text: 'Part B of new answer.' }] },
+ ];
+
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(currentUserInput, newModelParts);
+ const history = chat.getHistory();
+
+ // Expected: initialUser, initialModel, currentUserInput, consolidatedNewModelParts
+ expect(history.length).toBe(4);
+ expect(history[0]).toEqual(initialUser);
+ expect(history[1]).toEqual(initialModel);
+ expect(history[2]).toEqual(currentUserInput);
+ expect(history[3].role).toBe('model');
+ expect(history[3].parts).toEqual([
+ { text: 'Part A of new answer.' },
+ { text: 'Part B of new answer.' },
+ ]);
+ });
+
+ it('should handle empty modelOutput array', () => {
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(userInput, []);
+ const history = chat.getHistory();
+ // If modelOutput is empty, it might push a default empty model part depending on isFunctionResponse
+ // Assuming isFunctionResponse(userInput) is false for this simple text input
+ expect(history.length).toBe(2);
+ expect(history[0]).toEqual(userInput);
+ expect(history[1].role).toBe('model');
+ expect(history[1].parts).toEqual([]);
+ });
+
+ it('should handle modelOutput with parts being undefined or empty (if they pass initial every check)', () => {
+ const modelOutputUndefinedParts: Content[] = [
+ { role: 'model', parts: [{ text: 'Text part' }] },
+ { role: 'model', parts: undefined as unknown as Part[] }, // Test undefined parts
+ { role: 'model', parts: [] }, // Test empty parts array
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(userInput, modelOutputUndefinedParts);
+ const history = chat.getHistory();
+ expect(history.length).toBe(2);
+ expect(history[1].role).toBe('model');
+ // The consolidation logic should handle undefined/empty parts by spreading `|| []`
+ expect(history[1].parts).toEqual([{ text: 'Text part' }]);
+ });
+
+ it('should correctly handle automaticFunctionCallingHistory', () => {
+ const afcHistory: Content[] = [
+ { role: 'user', parts: [{ text: 'AFC User' }] },
+ { role: 'model', parts: [{ text: 'AFC Model' }] },
+ ];
+ const modelOutput: Content[] = [
+ { role: 'model', parts: [{ text: 'Regular Model Output' }] },
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(userInput, modelOutput, afcHistory);
+ const history = chat.getHistory();
+ expect(history.length).toBe(3);
+ expect(history[0]).toEqual(afcHistory[0]);
+ expect(history[1]).toEqual(afcHistory[1]);
+ expect(history[2]).toEqual(modelOutput[0]);
+ });
+
+ it('should add userInput if AFC history is present but empty', () => {
+ const modelOutput: Content[] = [
+ { role: 'model', parts: [{ text: 'Model Output' }] },
+ ];
+ // @ts-expect-error Accessing private method for testing purposes
+ chat.recordHistory(userInput, modelOutput, []); // Empty AFC history
+ const history = chat.getHistory();
+ expect(history.length).toBe(2);
+ expect(history[0]).toEqual(userInput);
+ expect(history[1]).toEqual(modelOutput[0]);
+ });
+ });
+});
diff --git a/packages/server/src/core/geminiChat.ts b/packages/server/src/core/geminiChat.ts
index 5ba8ce2d..877d0825 100644
--- a/packages/server/src/core/geminiChat.ts
+++ b/packages/server/src/core/geminiChat.ts
@@ -313,6 +313,44 @@ export class GeminiChat {
} else {
this.history.push(userInput);
}
- this.history.push(...outputContents);
+
+ // Consolidate adjacent model roles in outputContents
+ const consolidatedOutputContents: Content[] = [];
+ for (const content of outputContents) {
+ const lastContent =
+ consolidatedOutputContents[consolidatedOutputContents.length - 1];
+ if (
+ lastContent &&
+ lastContent.role === 'model' &&
+ content.role === 'model' &&
+ lastContent.parts
+ ) {
+ lastContent.parts.push(...(content.parts || []));
+ } else {
+ consolidatedOutputContents.push(content);
+ }
+ }
+
+ if (consolidatedOutputContents.length > 0) {
+ const lastHistoryEntry = this.history[this.history.length - 1];
+ // Only merge if AFC history was NOT just added, to prevent merging with last AFC model turn.
+ const canMergeWithLastHistory =
+ !automaticFunctionCallingHistory ||
+ automaticFunctionCallingHistory.length === 0;
+
+ if (
+ canMergeWithLastHistory &&
+ lastHistoryEntry &&
+ lastHistoryEntry.role === 'model' &&
+ lastHistoryEntry.parts &&
+ consolidatedOutputContents[0].role === 'model'
+ ) {
+ lastHistoryEntry.parts.push(
+ ...(consolidatedOutputContents[0].parts || []),
+ );
+ consolidatedOutputContents.shift(); // Remove the first element as it's merged
+ }
+ this.history.push(...consolidatedOutputContents);
+ }
}
}