diff options
| -rw-r--r-- | packages/server/src/core/turn.test.ts | 269 | ||||
| -rw-r--r-- | packages/server/src/core/turn.ts | 3 |
2 files changed, 271 insertions, 1 deletions
diff --git a/packages/server/src/core/turn.test.ts b/packages/server/src/core/turn.test.ts new file mode 100644 index 00000000..90d3407f --- /dev/null +++ b/packages/server/src/core/turn.test.ts @@ -0,0 +1,269 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + Turn, + GeminiEventType, + ServerGeminiToolCallRequestEvent, + ServerGeminiErrorEvent, +} from './turn.js'; +import { Chat, GenerateContentResponse, Part, Content } from '@google/genai'; +import { reportError } from '../utils/errorReporting.js'; + +const mockSendMessageStream = vi.fn(); +const mockGetHistory = vi.fn(); + +vi.mock('@google/genai', async (importOriginal) => { + const actual = await importOriginal<typeof import('@google/genai')>(); + const MockChat = vi.fn().mockImplementation(() => ({ + sendMessageStream: mockSendMessageStream, + getHistory: mockGetHistory, + })); + return { + ...actual, + Chat: MockChat, + }; +}); + +vi.mock('../utils/errorReporting', () => ({ + reportError: vi.fn(), +})); + +vi.mock('../utils/generateContentResponseUtilities', () => ({ + getResponseText: (resp: GenerateContentResponse) => + resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') || + undefined, +})); + +describe('Turn', () => { + let turn: Turn; + // Define a type for the mocked Chat instance for clarity + type MockedChatInstance = { + sendMessageStream: typeof mockSendMessageStream; + getHistory: typeof mockGetHistory; + }; + let mockChatInstance: MockedChatInstance; + + beforeEach(() => { + vi.resetAllMocks(); + mockChatInstance = { + sendMessageStream: mockSendMessageStream, + getHistory: mockGetHistory, + }; + turn = new Turn(mockChatInstance as unknown as Chat); + mockGetHistory.mockReturnValue([]); + mockSendMessageStream.mockResolvedValue((async function* () {})()); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('constructor', () => { + it('should initialize pendingToolCalls and debugResponses', () => { + expect(turn.pendingToolCalls).toEqual([]); + expect(turn.getDebugResponses()).toEqual([]); + }); + }); + + describe('run', () => { + it('should yield content events for text parts', async () => { + const mockResponseStream = (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'Hello' }] } }], + } as unknown as GenerateContentResponse; + yield { + candidates: [{ content: { parts: [{ text: ' world' }] } }], + } as unknown as GenerateContentResponse; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + const reqParts: Part[] = [{ text: 'Hi' }]; + for await (const event of turn.run(reqParts)) { + events.push(event); + } + + expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts }); + expect(events).toEqual([ + { type: GeminiEventType.Content, value: 'Hello' }, + { type: GeminiEventType.Content, value: ' world' }, + ]); + expect(turn.getDebugResponses().length).toBe(2); + }); + + it('should yield tool_call_request events for function calls', async () => { + const mockResponseStream = (async function* () { + yield { + functionCalls: [ + { id: 'fc1', name: 'tool1', args: { arg1: 'val1' } }, + { name: 'tool2', args: { arg2: 'val2' } }, // No ID + ], + } as unknown as GenerateContentResponse; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + const reqParts: Part[] = [{ text: 'Use tools' }]; + for await (const event of turn.run(reqParts)) { + events.push(event); + } + + expect(events.length).toBe(2); + const event1 = events[0] as ServerGeminiToolCallRequestEvent; + expect(event1.type).toBe(GeminiEventType.ToolCallRequest); + expect(event1.value).toEqual( + expect.objectContaining({ + callId: 'fc1', + name: 'tool1', + args: { arg1: 'val1' }, + }), + ); + expect(turn.pendingToolCalls[0]).toEqual(event1.value); + + const event2 = events[1] as ServerGeminiToolCallRequestEvent; + expect(event2.type).toBe(GeminiEventType.ToolCallRequest); + expect(event2.value).toEqual( + expect.objectContaining({ name: 'tool2', args: { arg2: 'val2' } }), + ); + expect(event2.value.callId).toEqual( + expect.stringMatching(/^tool2-\d{13}-\w{10,}$/), + ); + expect(turn.pendingToolCalls[1]).toEqual(event2.value); + expect(turn.getDebugResponses().length).toBe(1); + }); + + it('should yield UserCancelled event if signal is aborted', async () => { + const abortController = new AbortController(); + const mockResponseStream = (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'First part' }] } }], + } as unknown as GenerateContentResponse; + abortController.abort(); + yield { + candidates: [ + { + content: { + parts: [{ text: 'Second part - should not be processed' }], + }, + }, + ], + } as unknown as GenerateContentResponse; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + const reqParts: Part[] = [{ text: 'Test abort' }]; + for await (const event of turn.run(reqParts, abortController.signal)) { + events.push(event); + } + expect(events).toEqual([ + { type: GeminiEventType.Content, value: 'First part' }, + { type: GeminiEventType.UserCancelled }, + ]); + expect(turn.getDebugResponses().length).toBe(1); + }); + + it('should yield Error event and report if sendMessageStream throws', async () => { + const error = new Error('API Error'); + mockSendMessageStream.mockRejectedValue(error); + const reqParts: Part[] = [{ text: 'Trigger error' }]; + const historyContent: Content[] = [ + { role: 'model', parts: [{ text: 'Previous history' }] }, + ]; + mockGetHistory.mockReturnValue(historyContent); + + const events = []; + for await (const event of turn.run(reqParts)) { + events.push(event); + } + + expect(events.length).toBe(1); + const errorEvent = events[0] as ServerGeminiErrorEvent; + expect(errorEvent.type).toBe(GeminiEventType.Error); + expect(errorEvent.value).toEqual({ message: 'API Error' }); + expect(turn.getDebugResponses().length).toBe(0); + expect(reportError).toHaveBeenCalledWith( + error, + 'Error when talking to Gemini API', + [...historyContent, reqParts], + 'Turn.run-sendMessageStream', + ); + }); + + it('should handle function calls with undefined name or args', async () => { + const mockResponseStream = (async function* () { + yield { + functionCalls: [ + { id: 'fc1', name: undefined, args: { arg1: 'val1' } }, + { id: 'fc2', name: 'tool2', args: undefined }, + { id: 'fc3', name: undefined, args: undefined }, + ], + } as unknown as GenerateContentResponse; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + const reqParts: Part[] = [{ text: 'Test undefined tool parts' }]; + for await (const event of turn.run(reqParts)) { + events.push(event); + } + + expect(events.length).toBe(3); + const event1 = events[0] as ServerGeminiToolCallRequestEvent; + expect(event1.type).toBe(GeminiEventType.ToolCallRequest); + expect(event1.value).toEqual( + expect.objectContaining({ + callId: 'fc1', + name: 'undefined_tool_name', + args: { arg1: 'val1' }, + }), + ); + expect(turn.pendingToolCalls[0]).toEqual(event1.value); + + const event2 = events[1] as ServerGeminiToolCallRequestEvent; + expect(event2.type).toBe(GeminiEventType.ToolCallRequest); + expect(event2.value).toEqual( + expect.objectContaining({ callId: 'fc2', name: 'tool2', args: {} }), + ); + expect(turn.pendingToolCalls[1]).toEqual(event2.value); + + const event3 = events[2] as ServerGeminiToolCallRequestEvent; + expect(event3.type).toBe(GeminiEventType.ToolCallRequest); + expect(event3.value).toEqual( + expect.objectContaining({ + callId: 'fc3', + name: 'undefined_tool_name', + args: {}, + }), + ); + expect(turn.pendingToolCalls[2]).toEqual(event3.value); + expect(turn.getDebugResponses().length).toBe(1); + }); + }); + + describe('getDebugResponses', () => { + it('should return collected debug responses', async () => { + const resp1 = { + candidates: [{ content: { parts: [{ text: 'Debug 1' }] } }], + } as unknown as GenerateContentResponse; + const resp2 = { + functionCalls: [{ name: 'debugTool' }], + } as unknown as GenerateContentResponse; + const mockResponseStream = (async function* () { + yield resp1; + yield resp2; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + const reqParts: Part[] = [{ text: 'Hi' }]; + for await (const _ of turn.run(reqParts)) { + // consume stream + } + expect(turn.getDebugResponses()).toEqual([resp1, resp2]); + }); + }); +}); diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index 38932041..a02b5eb6 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -128,11 +128,12 @@ export class Turn { }); for await (const resp of responseStream) { - this.debugResponses.push(resp); if (signal?.aborted) { yield { type: GeminiEventType.UserCancelled }; + // Do not add resp to debugResponses if aborted before processing return; } + this.debugResponses.push(resp); const text = getResponseText(resp); if (text) { |
