summaryrefslogtreecommitdiff
path: root/packages/server/src/core/turn.test.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/server/src/core/turn.test.ts')
-rw-r--r--packages/server/src/core/turn.test.ts285
1 files changed, 0 insertions, 285 deletions
diff --git a/packages/server/src/core/turn.test.ts b/packages/server/src/core/turn.test.ts
deleted file mode 100644
index 8fb3a4c1..00000000
--- a/packages/server/src/core/turn.test.ts
+++ /dev/null
@@ -1,285 +0,0 @@
-/**
- * @license
- * Copyright 2025 Google LLC
- * SPDX-License-Identifier: Apache-2.0
- */
-
-import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
-import {
- Turn,
- GeminiEventType,
- ServerGeminiToolCallRequestEvent,
- ServerGeminiErrorEvent,
-} from './turn.js';
-import { GenerateContentResponse, Part, Content } from '@google/genai';
-import { reportError } from '../utils/errorReporting.js';
-import { GeminiChat } from './geminiChat.js';
-
-const mockSendMessageStream = vi.fn();
-const mockGetHistory = vi.fn();
-
-vi.mock('@google/genai', async (importOriginal) => {
- const actual = await importOriginal<typeof import('@google/genai')>();
- const MockChat = vi.fn().mockImplementation(() => ({
- sendMessageStream: mockSendMessageStream,
- getHistory: mockGetHistory,
- }));
- return {
- ...actual,
- Chat: MockChat,
- };
-});
-
-vi.mock('../utils/errorReporting', () => ({
- reportError: vi.fn(),
-}));
-
-vi.mock('../utils/generateContentResponseUtilities', () => ({
- getResponseText: (resp: GenerateContentResponse) =>
- resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
- undefined,
-}));
-
-describe('Turn', () => {
- let turn: Turn;
- // Define a type for the mocked Chat instance for clarity
- type MockedChatInstance = {
- sendMessageStream: typeof mockSendMessageStream;
- getHistory: typeof mockGetHistory;
- };
- let mockChatInstance: MockedChatInstance;
-
- beforeEach(() => {
- vi.resetAllMocks();
- mockChatInstance = {
- sendMessageStream: mockSendMessageStream,
- getHistory: mockGetHistory,
- };
- turn = new Turn(mockChatInstance as unknown as GeminiChat);
- mockGetHistory.mockReturnValue([]);
- mockSendMessageStream.mockResolvedValue((async function* () {})());
- });
-
- afterEach(() => {
- vi.restoreAllMocks();
- });
-
- describe('constructor', () => {
- it('should initialize pendingToolCalls and debugResponses', () => {
- expect(turn.pendingToolCalls).toEqual([]);
- expect(turn.getDebugResponses()).toEqual([]);
- });
- });
-
- describe('run', () => {
- it('should yield content events for text parts', async () => {
- const mockResponseStream = (async function* () {
- yield {
- candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
- } as unknown as GenerateContentResponse;
- yield {
- candidates: [{ content: { parts: [{ text: ' world' }] } }],
- } as unknown as GenerateContentResponse;
- })();
- mockSendMessageStream.mockResolvedValue(mockResponseStream);
-
- const events = [];
- const reqParts: Part[] = [{ text: 'Hi' }];
- for await (const event of turn.run(
- reqParts,
- new AbortController().signal,
- )) {
- events.push(event);
- }
-
- expect(mockSendMessageStream).toHaveBeenCalledWith({
- message: reqParts,
- config: { abortSignal: expect.any(AbortSignal) },
- });
- expect(events).toEqual([
- { type: GeminiEventType.Content, value: 'Hello' },
- { type: GeminiEventType.Content, value: ' world' },
- ]);
- expect(turn.getDebugResponses().length).toBe(2);
- });
-
- it('should yield tool_call_request events for function calls', async () => {
- const mockResponseStream = (async function* () {
- yield {
- functionCalls: [
- { id: 'fc1', name: 'tool1', args: { arg1: 'val1' } },
- { name: 'tool2', args: { arg2: 'val2' } }, // No ID
- ],
- } as unknown as GenerateContentResponse;
- })();
- mockSendMessageStream.mockResolvedValue(mockResponseStream);
-
- const events = [];
- const reqParts: Part[] = [{ text: 'Use tools' }];
- for await (const event of turn.run(
- reqParts,
- new AbortController().signal,
- )) {
- events.push(event);
- }
-
- expect(events.length).toBe(2);
- const event1 = events[0] as ServerGeminiToolCallRequestEvent;
- expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
- expect(event1.value).toEqual(
- expect.objectContaining({
- callId: 'fc1',
- name: 'tool1',
- args: { arg1: 'val1' },
- }),
- );
- expect(turn.pendingToolCalls[0]).toEqual(event1.value);
-
- const event2 = events[1] as ServerGeminiToolCallRequestEvent;
- expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
- expect(event2.value).toEqual(
- expect.objectContaining({ name: 'tool2', args: { arg2: 'val2' } }),
- );
- expect(event2.value.callId).toEqual(
- expect.stringMatching(/^tool2-\d{13}-\w{10,}$/),
- );
- expect(turn.pendingToolCalls[1]).toEqual(event2.value);
- expect(turn.getDebugResponses().length).toBe(1);
- });
-
- it('should yield UserCancelled event if signal is aborted', async () => {
- const abortController = new AbortController();
- const mockResponseStream = (async function* () {
- yield {
- candidates: [{ content: { parts: [{ text: 'First part' }] } }],
- } as unknown as GenerateContentResponse;
- abortController.abort();
- yield {
- candidates: [
- {
- content: {
- parts: [{ text: 'Second part - should not be processed' }],
- },
- },
- ],
- } as unknown as GenerateContentResponse;
- })();
- mockSendMessageStream.mockResolvedValue(mockResponseStream);
-
- const events = [];
- const reqParts: Part[] = [{ text: 'Test abort' }];
- for await (const event of turn.run(reqParts, abortController.signal)) {
- events.push(event);
- }
- expect(events).toEqual([
- { type: GeminiEventType.Content, value: 'First part' },
- { type: GeminiEventType.UserCancelled },
- ]);
- expect(turn.getDebugResponses().length).toBe(1);
- });
-
- it('should yield Error event and report if sendMessageStream throws', async () => {
- const error = new Error('API Error');
- mockSendMessageStream.mockRejectedValue(error);
- const reqParts: Part[] = [{ text: 'Trigger error' }];
- const historyContent: Content[] = [
- { role: 'model', parts: [{ text: 'Previous history' }] },
- ];
- mockGetHistory.mockReturnValue(historyContent);
-
- const events = [];
- for await (const event of turn.run(
- reqParts,
- new AbortController().signal,
- )) {
- events.push(event);
- }
-
- expect(events.length).toBe(1);
- const errorEvent = events[0] as ServerGeminiErrorEvent;
- expect(errorEvent.type).toBe(GeminiEventType.Error);
- expect(errorEvent.value).toEqual({ message: 'API Error' });
- expect(turn.getDebugResponses().length).toBe(0);
- expect(reportError).toHaveBeenCalledWith(
- error,
- 'Error when talking to Gemini API',
- [...historyContent, reqParts],
- 'Turn.run-sendMessageStream',
- );
- });
-
- it('should handle function calls with undefined name or args', async () => {
- const mockResponseStream = (async function* () {
- yield {
- functionCalls: [
- { id: 'fc1', name: undefined, args: { arg1: 'val1' } },
- { id: 'fc2', name: 'tool2', args: undefined },
- { id: 'fc3', name: undefined, args: undefined },
- ],
- } as unknown as GenerateContentResponse;
- })();
- mockSendMessageStream.mockResolvedValue(mockResponseStream);
-
- const events = [];
- const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
- for await (const event of turn.run(
- reqParts,
- new AbortController().signal,
- )) {
- events.push(event);
- }
-
- expect(events.length).toBe(3);
- const event1 = events[0] as ServerGeminiToolCallRequestEvent;
- expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
- expect(event1.value).toEqual(
- expect.objectContaining({
- callId: 'fc1',
- name: 'undefined_tool_name',
- args: { arg1: 'val1' },
- }),
- );
- expect(turn.pendingToolCalls[0]).toEqual(event1.value);
-
- const event2 = events[1] as ServerGeminiToolCallRequestEvent;
- expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
- expect(event2.value).toEqual(
- expect.objectContaining({ callId: 'fc2', name: 'tool2', args: {} }),
- );
- expect(turn.pendingToolCalls[1]).toEqual(event2.value);
-
- const event3 = events[2] as ServerGeminiToolCallRequestEvent;
- expect(event3.type).toBe(GeminiEventType.ToolCallRequest);
- expect(event3.value).toEqual(
- expect.objectContaining({
- callId: 'fc3',
- name: 'undefined_tool_name',
- args: {},
- }),
- );
- expect(turn.pendingToolCalls[2]).toEqual(event3.value);
- expect(turn.getDebugResponses().length).toBe(1);
- });
- });
-
- describe('getDebugResponses', () => {
- it('should return collected debug responses', async () => {
- const resp1 = {
- candidates: [{ content: { parts: [{ text: 'Debug 1' }] } }],
- } as unknown as GenerateContentResponse;
- const resp2 = {
- functionCalls: [{ name: 'debugTool' }],
- } as unknown as GenerateContentResponse;
- const mockResponseStream = (async function* () {
- yield resp1;
- yield resp2;
- })();
- mockSendMessageStream.mockResolvedValue(mockResponseStream);
- const reqParts: Part[] = [{ text: 'Hi' }];
- for await (const _ of turn.run(reqParts, new AbortController().signal)) {
- // consume stream
- }
- expect(turn.getDebugResponses()).toEqual([resp1, resp2]);
- });
- });
-});