Diffstat (limited to 'packages/server/src')
-rw-r--r--  packages/server/src/core/client.ts                      16
-rw-r--r--  packages/server/src/core/geminiChat.ts                 314
-rw-r--r--  packages/server/src/core/turn.test.ts                     5
-rw-r--r--  packages/server/src/core/turn.ts                          4
-rw-r--r--  packages/server/src/utils/nextSpeakerChecker.test.ts      7
-rw-r--r--  packages/server/src/utils/nextSpeakerChecker.ts           5
6 files changed, 335 insertions, 16 deletions
diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts
index 85850da8..3d5927e3 100644
--- a/packages/server/src/core/client.ts
+++ b/packages/server/src/core/client.ts
@@ -8,7 +8,6 @@ import {
GenerateContentConfig,
GoogleGenAI,
Part,
- Chat,
SchemaUnion,
PartListUnion,
Content,
@@ -23,6 +22,7 @@ import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
import { reportError } from '../utils/errorReporting.js';
+import { GeminiChat } from './geminiChat.js';
export class GeminiClient {
private client: GoogleGenAI;
@@ -108,7 +108,7 @@ export class GeminiClient {
return initialParts;
}
- async startChat(): Promise<Chat> {
+ async startChat(): Promise<GeminiChat> {
const envParts = await this.getEnvironment();
const toolDeclarations = this.config
.getToolRegistry()
@@ -128,15 +128,17 @@ export class GeminiClient {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
- return this.client.chats.create({
- model: this.model,
- config: {
+ return new GeminiChat(
+ this.client,
+ this.client.models,
+ this.model,
+ {
systemInstruction,
...this.generateContentConfig,
tools,
},
history,
- });
+ );
} catch (error) {
await reportError(
error,
@@ -150,7 +152,7 @@ export class GeminiClient {
}
async *sendMessageStream(
- chat: Chat,
+ chat: GeminiChat,
request: PartListUnion,
signal?: AbortSignal,
turns: number = this.MAX_TURNS,
diff --git a/packages/server/src/core/geminiChat.ts b/packages/server/src/core/geminiChat.ts
new file mode 100644
index 00000000..dd5f3b7a
--- /dev/null
+++ b/packages/server/src/core/geminiChat.ts
@@ -0,0 +1,314 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+// DISCLAIMER: This is a copied version of https://github.com/googleapis/js-genai/blob/main/src/chats.ts with the intention of working around a key bug
+// where function responses are not treated as "valid" responses: https://b.corp.google.com/issues/420354090
+
+import {
+ GenerateContentResponse,
+ Content,
+ Models,
+ GenerateContentConfig,
+ SendMessageParameters,
+ GoogleGenAI,
+ createUserContent,
+} from '@google/genai';
+
+/**
+ * Returns true if the response is valid, false otherwise.
+ */
+function isValidResponse(response: GenerateContentResponse): boolean {
+ if (response.candidates === undefined || response.candidates.length === 0) {
+ return false;
+ }
+ const content = response.candidates[0]?.content;
+ if (content === undefined) {
+ return false;
+ }
+ return isValidContent(content);
+}
+
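+// Note: any non-empty part object passes this check, so functionCall and
+// functionResponse parts count as valid content (cf. the bug referenced in
+// the DISCLAIMER above).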
+function isValidContent(content: Content): boolean {
+ if (content.parts === undefined || content.parts.length === 0) {
+ return false;
+ }
+ for (const part of content.parts) {
+ if (part === undefined || Object.keys(part).length === 0) {
+ return false;
+ }
+ if (!part.thought && part.text !== undefined && part.text === '') {
+ return false;
+ }
+ }
+ return true;
+}
+
+/**
+ * Validates that the history contains only the correct roles.
+ *
+ * @throws Error if the history contains a turn with a role other than
+ * 'user' or 'model'.
+ */
+function validateHistory(history: Content[]) {
+ // Empty history is valid.
+ if (history.length === 0) {
+ return;
+ }
+ for (const content of history) {
+ if (content.role !== 'user' && content.role !== 'model') {
+ throw new Error(`Role must be user or model, but got ${content.role}.`);
+ }
+ }
+}
+
+/**
+ * Extracts the curated (valid) history from a comprehensive history.
+ *
+ * @remarks
+ * The model may sometimes generate invalid or empty contents (e.g., due to
+ * safety filters or recitation). Extracting the valid turns from the history
+ * ensures that subsequent requests can be accepted by the model.
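+ *
+ * For example, a history of
+ * `[user(A), model(<invalid>), user(B), model(C)]`
+ * curates to `[user(B), model(C)]`: the invalid model turn and the user
+ * turn that prompted it are both dropped.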
+ */
+function extractCuratedHistory(comprehensiveHistory: Content[]): Content[] {
+ if (comprehensiveHistory === undefined || comprehensiveHistory.length === 0) {
+ return [];
+ }
+ const curatedHistory: Content[] = [];
+ const length = comprehensiveHistory.length;
+ let i = 0;
+ while (i < length) {
+ if (comprehensiveHistory[i].role === 'user') {
+ curatedHistory.push(comprehensiveHistory[i]);
+ i++;
+ } else {
+ const modelOutput: Content[] = [];
+ let isValid = true;
+ while (i < length && comprehensiveHistory[i].role === 'model') {
+ modelOutput.push(comprehensiveHistory[i]);
+ if (isValid && !isValidContent(comprehensiveHistory[i])) {
+ isValid = false;
+ }
+ i++;
+ }
+ if (isValid) {
+ curatedHistory.push(...modelOutput);
+ } else {
+ // Remove the last user input when model content is invalid.
+ curatedHistory.pop();
+ }
+ }
+ }
+ return curatedHistory;
+}
+
+/**
+ * Chat session that enables sending messages to the model with previous
+ * conversation context.
+ *
+ * @remarks
+ * The session maintains all the turns between user and model.
+ */
+export class GeminiChat {
+ // A promise to represent the current state of the message being sent to the
+ // model.
+ private sendPromise: Promise<void> = Promise.resolve();
+
+ constructor(
+ private readonly apiClient: GoogleGenAI,
+ private readonly modelsModule: Models,
+ private readonly model: string,
+ private readonly config: GenerateContentConfig = {},
+ private history: Content[] = [],
+ ) {
+ validateHistory(history);
+ }
+
+ /**
+ * Sends a message to the model and returns the response.
+ *
+ * @remarks
+ * This method will wait for the previous message to be processed before
+ * sending the next message.
+ *
+ * @see {@link GeminiChat#sendMessageStream} for the streaming method.
+ * @param params - parameters for sending messages within a chat session.
+ * @returns The model's response.
+ *
+ * @example
+ * ```ts
+ * const chat = new GeminiChat(ai, ai.models, 'gemini-2.0-flash');
+ * const response = await chat.sendMessage({
+ * message: 'Why is the sky blue?'
+ * });
+ * console.log(response.text);
+ * ```
+ */
+ async sendMessage(
+ params: SendMessageParameters,
+ ): Promise<GenerateContentResponse> {
+ await this.sendPromise;
+ const userContent = createUserContent(params.message);
+ const responsePromise = this.modelsModule.generateContent({
+ model: this.model,
+ contents: this.getHistory(true).concat(userContent),
+ config: params.config ?? this.config,
+ });
+ this.sendPromise = (async () => {
+ const response = await responsePromise;
+ const outputContent = response.candidates?.[0]?.content;
+
+ // Because the AFC input contains the entire curated chat history in
+ // addition to the new user input, truncate the AFC history so that the
+ // existing chat history is not duplicated.
+ const fullAutomaticFunctionCallingHistory =
+ response.automaticFunctionCallingHistory;
+ const index = this.getHistory(true).length;
+
+ let automaticFunctionCallingHistory: Content[] = [];
+ if (fullAutomaticFunctionCallingHistory != null) {
+ automaticFunctionCallingHistory =
+ fullAutomaticFunctionCallingHistory.slice(index) ?? [];
+ }
+
+ const modelOutput = outputContent ? [outputContent] : [];
+ this.recordHistory(
+ userContent,
+ modelOutput,
+ automaticFunctionCallingHistory,
+ );
+ return;
+ })();
+ await this.sendPromise.catch(() => {
+ // Resets sendPromise to avoid subsequent calls failing
+ this.sendPromise = Promise.resolve();
+ });
+ return responsePromise;
+ }
+
+ /**
+ * Sends a message to the model and returns the response in chunks.
+ *
+ * @remarks
+ * This method will wait for the previous message to be processed before
+ * sending the next message.
+ *
+ * @see {@link GeminiChat#sendMessage} for the non-streaming method.
+ * @param params - parameters for sending the message.
+ * @returns The model's response.
+ *
+ * @example
+ * ```ts
+ * const chat = new GeminiChat(ai, ai.models, 'gemini-2.0-flash');
+ * const response = await chat.sendMessageStream({
+ * message: 'Why is the sky blue?'
+ * });
+ * for await (const chunk of response) {
+ * console.log(chunk.text);
+ * }
+ * ```
+ */
+ async sendMessageStream(
+ params: SendMessageParameters,
+ ): Promise<AsyncGenerator<GenerateContentResponse>> {
+ await this.sendPromise;
+ const userContent = createUserContent(params.message);
+ const streamResponse = this.modelsModule.generateContentStream({
+ model: this.model,
+ contents: this.getHistory(true).concat(userContent),
+ config: params.config ?? this.config,
+ });
+ // Resolve the internal send-completion promise (`sendPromise`) on both
+ // success and failure. The actual failure is still propagated by
+ // `await streamResponse` below.
+ this.sendPromise = streamResponse
+ .then(() => undefined)
+ .catch(() => undefined);
+ const response = await streamResponse;
+ const result = this.processStreamResponse(response, userContent);
+ return result;
+ }
+
+ /**
+ * Returns the chat history.
+ *
+ * @remarks
+ * The history is a list of contents alternating between user and model.
+ *
+ * There are two types of history:
+ * - The `curated history` contains only the valid turns between user and
+ * model, which will be included in the subsequent requests sent to the model.
+ * - The `comprehensive history` contains all turns, including invalid or
+ * empty model outputs, providing a complete record of the history.
+ *
+   * The history is updated after the response is received from the model;
+   * for a streaming response, that means after the last chunk has been
+   * received.
+ *
+ * The `comprehensive history` is returned by default. To get the `curated
+ * history`, set the `curated` parameter to `true`.
+ *
+ * @param curated - whether to return the curated history or the comprehensive
+ * history.
+   * @returns History contents alternating between user and model for the entire
+ * chat session.
+ */
+ getHistory(curated: boolean = false): Content[] {
+ const history = curated
+ ? extractCuratedHistory(this.history)
+ : this.history;
+ // Deep copy the history to avoid mutating the history outside of the
+ // chat session.
+ return structuredClone(history);
+ }
+
+ private async *processStreamResponse(
+ streamResponse: AsyncGenerator<GenerateContentResponse>,
+ inputContent: Content,
+ ) {
+ const outputContent: Content[] = [];
+ for await (const chunk of streamResponse) {
+ if (isValidResponse(chunk)) {
+ const content = chunk.candidates?.[0]?.content;
+ if (content !== undefined) {
+ outputContent.push(content);
+ }
+ }
+ yield chunk;
+ }
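+    // History is recorded only once the stream has been fully consumed;
+    // abandoning the generator early leaves this turn out of the history.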
+ this.recordHistory(inputContent, outputContent);
+ }
+
+ private recordHistory(
+ userInput: Content,
+ modelOutput: Content[],
+ automaticFunctionCallingHistory?: Content[],
+ ) {
+ let outputContents: Content[] = [];
+ if (
+ modelOutput.length > 0 &&
+ modelOutput.every((content) => content.role !== undefined)
+ ) {
+ outputContents = modelOutput;
+ } else {
+      // Append an empty content when the model returns an empty response so
+      // that the history always alternates between user and model.
+ outputContents.push({
+ role: 'model',
+ parts: [],
+ } as Content);
+ }
+ if (
+ automaticFunctionCallingHistory &&
+ automaticFunctionCallingHistory.length > 0
+ ) {
+ this.history.push(
+ ...extractCuratedHistory(automaticFunctionCallingHistory!),
+ );
+ } else {
+ this.history.push(userInput);
+ }
+ this.history.push(...outputContents);
+ }
+}
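
For orientation, a rough usage sketch of the new session class follows. It is a sketch only: the `ai` client, API-key handling, and model name are illustrative assumptions, not part of this change.

```ts
import { GoogleGenAI } from '@google/genai';
import { GeminiChat } from './geminiChat.js';

// Hypothetical wiring; client.ts above constructs GeminiChat the same way.
const ai = new GoogleGenAI({ apiKey: process.env.GEMINI_API_KEY });
const chat = new GeminiChat(ai, ai.models, 'gemini-2.0-flash');

const stream = await chat.sendMessageStream({ message: 'Why is the sky blue?' });
for await (const chunk of stream) {
  process.stdout.write(chunk.text ?? '');
}

// Only after the stream is drained is this turn recorded; `true` selects
// the curated history that subsequent requests are built from.
const curated = chat.getHistory(true);
```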
diff --git a/packages/server/src/core/turn.test.ts b/packages/server/src/core/turn.test.ts
index 90d3407f..44bb983f 100644
--- a/packages/server/src/core/turn.test.ts
+++ b/packages/server/src/core/turn.test.ts
@@ -11,8 +11,9 @@ import {
ServerGeminiToolCallRequestEvent,
ServerGeminiErrorEvent,
} from './turn.js';
-import { Chat, GenerateContentResponse, Part, Content } from '@google/genai';
+import { GenerateContentResponse, Part, Content } from '@google/genai';
import { reportError } from '../utils/errorReporting.js';
+import { GeminiChat } from './geminiChat.js';
const mockSendMessageStream = vi.fn();
const mockGetHistory = vi.fn();
@@ -54,7 +55,7 @@ describe('Turn', () => {
sendMessageStream: mockSendMessageStream,
getHistory: mockGetHistory,
};
- turn = new Turn(mockChatInstance as unknown as Chat);
+ turn = new Turn(mockChatInstance as unknown as GeminiChat);
mockGetHistory.mockReturnValue([]);
mockSendMessageStream.mockResolvedValue((async function* () {})());
});
diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts
index a02b5eb6..d5c7eb58 100644
--- a/packages/server/src/core/turn.ts
+++ b/packages/server/src/core/turn.ts
@@ -6,7 +6,6 @@
import {
Part,
- Chat,
PartListUnion,
GenerateContentResponse,
FunctionCall,
@@ -20,6 +19,7 @@ import {
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { reportError } from '../utils/errorReporting.js';
import { getErrorMessage } from '../utils/errors.js';
+import { GeminiChat } from './geminiChat.js';
// Define a structure for tools passed to the server
export interface ServerTool {
@@ -113,7 +113,7 @@ export class Turn {
}>;
private debugResponses: GenerateContentResponse[];
- constructor(private readonly chat: Chat) {
+ constructor(private readonly chat: GeminiChat) {
this.pendingToolCalls = [];
this.debugResponses = [];
}
diff --git a/packages/server/src/utils/nextSpeakerChecker.test.ts b/packages/server/src/utils/nextSpeakerChecker.test.ts
index b8d17875..f32227e9 100644
--- a/packages/server/src/utils/nextSpeakerChecker.test.ts
+++ b/packages/server/src/utils/nextSpeakerChecker.test.ts
@@ -5,10 +5,11 @@
*/
import { describe, it, expect, vi, beforeEach, Mock, afterEach } from 'vitest';
-import { Chat, Content } from '@google/genai';
+import { Content } from '@google/genai';
import { GeminiClient } from '../core/client.js';
import { Config } from '../config/config.js'; // Added Config import
import { checkNextSpeaker, NextSpeakerResponse } from './nextSpeakerChecker.js';
+import { GeminiChat } from '../core/geminiChat.js';
// Mock GeminiClient and Config constructor
vi.mock('../core/client.js');
@@ -39,7 +40,7 @@ vi.mock('@google/genai', async () => {
});
describe('checkNextSpeaker', () => {
- let mockChat: Chat;
+ let mockChat: GeminiChat;
let mockGeminiClient: GeminiClient;
let MockConfig: Mock;
@@ -64,7 +65,7 @@ describe('checkNextSpeaker', () => {
mockGeminiClient = new GeminiClient(mockConfigInstance);
// Simulate chat creation as done in GeminiClient
- mockChat = { getHistory: mockGetHistory } as unknown as Chat;
+ mockChat = { getHistory: mockGetHistory } as unknown as GeminiChat;
});
afterEach(() => {
diff --git a/packages/server/src/utils/nextSpeakerChecker.ts b/packages/server/src/utils/nextSpeakerChecker.ts
index 5eb0c512..3fe813db 100644
--- a/packages/server/src/utils/nextSpeakerChecker.ts
+++ b/packages/server/src/utils/nextSpeakerChecker.ts
@@ -4,8 +4,9 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import { Chat, Content, SchemaUnion, Type } from '@google/genai';
+import { Content, SchemaUnion, Type } from '@google/genai';
import { GeminiClient } from '../core/client.js';
+import { GeminiChat } from '../core/geminiChat.js';
const CHECK_PROMPT = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
**Decision Rules (apply in order):**
@@ -57,7 +58,7 @@ export interface NextSpeakerResponse {
}
export async function checkNextSpeaker(
- chat: Chat,
+ chat: GeminiChat,
geminiClient: GeminiClient,
): Promise<NextSpeakerResponse | null> {
// We need to capture the curated history because there are many moments when the model will return invalid turns
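
For reference, a minimal sketch of how the updated signature is consumed. The wiring below is hypothetical, and the `next_speaker` field is an assumption about NextSpeakerResponse, whose body this diff does not show.

```ts
// `chat` (a GeminiChat) and `geminiClient` come from the calling code; in the
// real flow, GeminiClient.sendMessageStream runs this check after each turn.
// Assumes NextSpeakerResponse carries `next_speaker: 'user' | 'model'`.
import { checkNextSpeaker } from './nextSpeakerChecker.js';

const result = await checkNextSpeaker(chat, geminiClient);
if (result?.next_speaker === 'model') {
  // The model indicated it has more to say; continue the turn automatically.
}
```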