summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/server/src/core/turn.test.ts269
-rw-r--r--packages/server/src/core/turn.ts3
2 files changed, 271 insertions, 1 deletions
diff --git a/packages/server/src/core/turn.test.ts b/packages/server/src/core/turn.test.ts
new file mode 100644
index 00000000..90d3407f
--- /dev/null
+++ b/packages/server/src/core/turn.test.ts
@@ -0,0 +1,269 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
+import {
+ Turn,
+ GeminiEventType,
+ ServerGeminiToolCallRequestEvent,
+ ServerGeminiErrorEvent,
+} from './turn.js';
+import { Chat, GenerateContentResponse, Part, Content } from '@google/genai';
+import { reportError } from '../utils/errorReporting.js';
+
+const mockSendMessageStream = vi.fn();
+const mockGetHistory = vi.fn();
+
+vi.mock('@google/genai', async (importOriginal) => {
+ const actual = await importOriginal<typeof import('@google/genai')>();
+ const MockChat = vi.fn().mockImplementation(() => ({
+ sendMessageStream: mockSendMessageStream,
+ getHistory: mockGetHistory,
+ }));
+ return {
+ ...actual,
+ Chat: MockChat,
+ };
+});
+
+vi.mock('../utils/errorReporting', () => ({
+ reportError: vi.fn(),
+}));
+
+vi.mock('../utils/generateContentResponseUtilities', () => ({
+ getResponseText: (resp: GenerateContentResponse) =>
+ resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
+ undefined,
+}));
+
+describe('Turn', () => {
+ let turn: Turn;
+ // Define a type for the mocked Chat instance for clarity
+ type MockedChatInstance = {
+ sendMessageStream: typeof mockSendMessageStream;
+ getHistory: typeof mockGetHistory;
+ };
+ let mockChatInstance: MockedChatInstance;
+
+ beforeEach(() => {
+ vi.resetAllMocks();
+ mockChatInstance = {
+ sendMessageStream: mockSendMessageStream,
+ getHistory: mockGetHistory,
+ };
+ turn = new Turn(mockChatInstance as unknown as Chat);
+ mockGetHistory.mockReturnValue([]);
+ mockSendMessageStream.mockResolvedValue((async function* () {})());
+ });
+
+ afterEach(() => {
+ vi.restoreAllMocks();
+ });
+
+ describe('constructor', () => {
+ it('should initialize pendingToolCalls and debugResponses', () => {
+ expect(turn.pendingToolCalls).toEqual([]);
+ expect(turn.getDebugResponses()).toEqual([]);
+ });
+ });
+
+ describe('run', () => {
+ it('should yield content events for text parts', async () => {
+ const mockResponseStream = (async function* () {
+ yield {
+ candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
+ } as unknown as GenerateContentResponse;
+ yield {
+ candidates: [{ content: { parts: [{ text: ' world' }] } }],
+ } as unknown as GenerateContentResponse;
+ })();
+ mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+ const events = [];
+ const reqParts: Part[] = [{ text: 'Hi' }];
+ for await (const event of turn.run(reqParts)) {
+ events.push(event);
+ }
+
+ expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts });
+ expect(events).toEqual([
+ { type: GeminiEventType.Content, value: 'Hello' },
+ { type: GeminiEventType.Content, value: ' world' },
+ ]);
+ expect(turn.getDebugResponses().length).toBe(2);
+ });
+
+ it('should yield tool_call_request events for function calls', async () => {
+ const mockResponseStream = (async function* () {
+ yield {
+ functionCalls: [
+ { id: 'fc1', name: 'tool1', args: { arg1: 'val1' } },
+ { name: 'tool2', args: { arg2: 'val2' } }, // No ID
+ ],
+ } as unknown as GenerateContentResponse;
+ })();
+ mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+ const events = [];
+ const reqParts: Part[] = [{ text: 'Use tools' }];
+ for await (const event of turn.run(reqParts)) {
+ events.push(event);
+ }
+
+ expect(events.length).toBe(2);
+ const event1 = events[0] as ServerGeminiToolCallRequestEvent;
+ expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
+ expect(event1.value).toEqual(
+ expect.objectContaining({
+ callId: 'fc1',
+ name: 'tool1',
+ args: { arg1: 'val1' },
+ }),
+ );
+ expect(turn.pendingToolCalls[0]).toEqual(event1.value);
+
+ const event2 = events[1] as ServerGeminiToolCallRequestEvent;
+ expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
+ expect(event2.value).toEqual(
+ expect.objectContaining({ name: 'tool2', args: { arg2: 'val2' } }),
+ );
+ expect(event2.value.callId).toEqual(
+ expect.stringMatching(/^tool2-\d{13}-\w{10,}$/),
+ );
+ expect(turn.pendingToolCalls[1]).toEqual(event2.value);
+ expect(turn.getDebugResponses().length).toBe(1);
+ });
+
+ it('should yield UserCancelled event if signal is aborted', async () => {
+ const abortController = new AbortController();
+ const mockResponseStream = (async function* () {
+ yield {
+ candidates: [{ content: { parts: [{ text: 'First part' }] } }],
+ } as unknown as GenerateContentResponse;
+ abortController.abort();
+ yield {
+ candidates: [
+ {
+ content: {
+ parts: [{ text: 'Second part - should not be processed' }],
+ },
+ },
+ ],
+ } as unknown as GenerateContentResponse;
+ })();
+ mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+ const events = [];
+ const reqParts: Part[] = [{ text: 'Test abort' }];
+ for await (const event of turn.run(reqParts, abortController.signal)) {
+ events.push(event);
+ }
+ expect(events).toEqual([
+ { type: GeminiEventType.Content, value: 'First part' },
+ { type: GeminiEventType.UserCancelled },
+ ]);
+ expect(turn.getDebugResponses().length).toBe(1);
+ });
+
+ it('should yield Error event and report if sendMessageStream throws', async () => {
+ const error = new Error('API Error');
+ mockSendMessageStream.mockRejectedValue(error);
+ const reqParts: Part[] = [{ text: 'Trigger error' }];
+ const historyContent: Content[] = [
+ { role: 'model', parts: [{ text: 'Previous history' }] },
+ ];
+ mockGetHistory.mockReturnValue(historyContent);
+
+ const events = [];
+ for await (const event of turn.run(reqParts)) {
+ events.push(event);
+ }
+
+ expect(events.length).toBe(1);
+ const errorEvent = events[0] as ServerGeminiErrorEvent;
+ expect(errorEvent.type).toBe(GeminiEventType.Error);
+ expect(errorEvent.value).toEqual({ message: 'API Error' });
+ expect(turn.getDebugResponses().length).toBe(0);
+ expect(reportError).toHaveBeenCalledWith(
+ error,
+ 'Error when talking to Gemini API',
+ [...historyContent, reqParts],
+ 'Turn.run-sendMessageStream',
+ );
+ });
+
+ it('should handle function calls with undefined name or args', async () => {
+ const mockResponseStream = (async function* () {
+ yield {
+ functionCalls: [
+ { id: 'fc1', name: undefined, args: { arg1: 'val1' } },
+ { id: 'fc2', name: 'tool2', args: undefined },
+ { id: 'fc3', name: undefined, args: undefined },
+ ],
+ } as unknown as GenerateContentResponse;
+ })();
+ mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+ const events = [];
+ const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
+ for await (const event of turn.run(reqParts)) {
+ events.push(event);
+ }
+
+ expect(events.length).toBe(3);
+ const event1 = events[0] as ServerGeminiToolCallRequestEvent;
+ expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
+ expect(event1.value).toEqual(
+ expect.objectContaining({
+ callId: 'fc1',
+ name: 'undefined_tool_name',
+ args: { arg1: 'val1' },
+ }),
+ );
+ expect(turn.pendingToolCalls[0]).toEqual(event1.value);
+
+ const event2 = events[1] as ServerGeminiToolCallRequestEvent;
+ expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
+ expect(event2.value).toEqual(
+ expect.objectContaining({ callId: 'fc2', name: 'tool2', args: {} }),
+ );
+ expect(turn.pendingToolCalls[1]).toEqual(event2.value);
+
+ const event3 = events[2] as ServerGeminiToolCallRequestEvent;
+ expect(event3.type).toBe(GeminiEventType.ToolCallRequest);
+ expect(event3.value).toEqual(
+ expect.objectContaining({
+ callId: 'fc3',
+ name: 'undefined_tool_name',
+ args: {},
+ }),
+ );
+ expect(turn.pendingToolCalls[2]).toEqual(event3.value);
+ expect(turn.getDebugResponses().length).toBe(1);
+ });
+ });
+
+ describe('getDebugResponses', () => {
+ it('should return collected debug responses', async () => {
+ const resp1 = {
+ candidates: [{ content: { parts: [{ text: 'Debug 1' }] } }],
+ } as unknown as GenerateContentResponse;
+ const resp2 = {
+ functionCalls: [{ name: 'debugTool' }],
+ } as unknown as GenerateContentResponse;
+ const mockResponseStream = (async function* () {
+ yield resp1;
+ yield resp2;
+ })();
+ mockSendMessageStream.mockResolvedValue(mockResponseStream);
+ const reqParts: Part[] = [{ text: 'Hi' }];
+ for await (const _ of turn.run(reqParts)) {
+ // consume stream
+ }
+ expect(turn.getDebugResponses()).toEqual([resp1, resp2]);
+ });
+ });
+});
diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts
index 38932041..a02b5eb6 100644
--- a/packages/server/src/core/turn.ts
+++ b/packages/server/src/core/turn.ts
@@ -128,11 +128,12 @@ export class Turn {
});
for await (const resp of responseStream) {
- this.debugResponses.push(resp);
if (signal?.aborted) {
yield { type: GeminiEventType.UserCancelled };
+ // Do not add resp to debugResponses if aborted before processing
return;
}
+ this.debugResponses.push(resp);
const text = getResponseText(resp);
if (text) {