diff options
Diffstat (limited to 'packages/cli/src/nonInteractiveCli.test.ts')
| -rw-r--r-- | packages/cli/src/nonInteractiveCli.test.ts | 372 |
1 files changed, 135 insertions, 237 deletions
diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index 8b0419f1..a0fc6f9f 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -4,196 +4,167 @@ * SPDX-License-Identifier: Apache-2.0 */ -/* eslint-disable @typescript-eslint/no-explicit-any */ -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + Config, + executeToolCall, + ToolRegistry, + shutdownTelemetry, + GeminiEventType, + ServerGeminiStreamEvent, +} from '@google/gemini-cli-core'; +import { Part } from '@google/genai'; import { runNonInteractive } from './nonInteractiveCli.js'; -import { Config, GeminiClient, ToolRegistry } from '@google/gemini-cli-core'; -import { GenerateContentResponse, Part, FunctionCall } from '@google/genai'; +import { vi } from 'vitest'; -// Mock dependencies -vi.mock('@google/gemini-cli-core', async () => { - const actualCore = await vi.importActual< - typeof import('@google/gemini-cli-core') - >('@google/gemini-cli-core'); +// Mock core modules +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const original = + await importOriginal<typeof import('@google/gemini-cli-core')>(); return { - ...actualCore, - GeminiClient: vi.fn(), - ToolRegistry: vi.fn(), + ...original, executeToolCall: vi.fn(), + shutdownTelemetry: vi.fn(), + isTelemetrySdkInitialized: vi.fn().mockReturnValue(true), }; }); describe('runNonInteractive', () => { let mockConfig: Config; - let mockGeminiClient: GeminiClient; let mockToolRegistry: ToolRegistry; - let mockChat: { - sendMessageStream: ReturnType<typeof vi.fn>; + let mockCoreExecuteToolCall: vi.Mock; + let mockShutdownTelemetry: vi.Mock; + let consoleErrorSpy: vi.SpyInstance; + let processExitSpy: vi.SpyInstance; + let processStdoutSpy: vi.SpyInstance; + let mockGeminiClient: { + sendMessageStream: vi.Mock; }; - let mockProcessStdoutWrite: ReturnType<typeof vi.fn>; - let mockProcessExit: ReturnType<typeof vi.fn>; beforeEach(() => { - vi.resetAllMocks(); - mockChat = { - sendMessageStream: vi.fn(), - }; - mockGeminiClient = { - getChat: vi.fn().mockResolvedValue(mockChat), - } as unknown as GeminiClient; + mockCoreExecuteToolCall = vi.mocked(executeToolCall); + mockShutdownTelemetry = vi.mocked(shutdownTelemetry); + + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + processExitSpy = vi + .spyOn(process, 'exit') + .mockImplementation((() => {}) as (code?: number) => never); + processStdoutSpy = vi + .spyOn(process.stdout, 'write') + .mockImplementation(() => true); + mockToolRegistry = { - getFunctionDeclarations: vi.fn().mockReturnValue([]), getTool: vi.fn(), + getFunctionDeclarations: vi.fn().mockReturnValue([]), } as unknown as ToolRegistry; - vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClient); - vi.mocked(ToolRegistry).mockImplementation(() => mockToolRegistry); + mockGeminiClient = { + sendMessageStream: vi.fn(), + }; mockConfig = { - getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + initialize: vi.fn().mockResolvedValue(undefined), getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), - getContentGeneratorConfig: vi.fn().mockReturnValue({}), + getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry), getMaxSessionTurns: vi.fn().mockReturnValue(10), - initialize: vi.fn(), + getIdeMode: vi.fn().mockReturnValue(false), + getFullContext: vi.fn().mockReturnValue(false), + getContentGeneratorConfig: vi.fn().mockReturnValue({}), } as unknown as Config; - - mockProcessStdoutWrite = vi.fn().mockImplementation(() => true); - process.stdout.write = mockProcessStdoutWrite as any; // Use any to bypass strict signature matching for mock - mockProcessExit = vi - .fn() - .mockImplementation((_code?: number) => undefined as never); - process.exit = mockProcessExit as any; // Use any for process.exit mock }); afterEach(() => { vi.restoreAllMocks(); - // Restore original process methods if they were globally patched - // This might require storing the original methods before patching them in beforeEach }); + async function* createStreamFromEvents( + events: ServerGeminiStreamEvent[], + ): AsyncGenerator<ServerGeminiStreamEvent> { + for (const event of events) { + yield event; + } + } + it('should process input and write text output', async () => { - const inputStream = (async function* () { - yield { - candidates: [{ content: { parts: [{ text: 'Hello' }] } }], - } as GenerateContentResponse; - yield { - candidates: [{ content: { parts: [{ text: ' World' }] } }], - } as GenerateContentResponse; - })(); - mockChat.sendMessageStream.mockResolvedValue(inputStream); + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Hello' }, + { type: GeminiEventType.Content, value: ' World' }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); await runNonInteractive(mockConfig, 'Test input', 'prompt-id-1'); - expect(mockChat.sendMessageStream).toHaveBeenCalledWith( - { - message: [{ text: 'Test input' }], - config: { - abortSignal: expect.any(AbortSignal), - tools: [{ functionDeclarations: [] }], - }, - }, - expect.any(String), + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: 'Test input' }], + expect.any(AbortSignal), + 'prompt-id-1', ); - expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Hello'); - expect(mockProcessStdoutWrite).toHaveBeenCalledWith(' World'); - expect(mockProcessStdoutWrite).toHaveBeenCalledWith('\n'); + expect(processStdoutSpy).toHaveBeenCalledWith('Hello'); + expect(processStdoutSpy).toHaveBeenCalledWith(' World'); + expect(processStdoutSpy).toHaveBeenCalledWith('\n'); + expect(mockShutdownTelemetry).toHaveBeenCalled(); }); it('should handle a single tool call and respond', async () => { - const functionCall: FunctionCall = { - id: 'fc1', - name: 'testTool', - args: { p: 'v' }, - }; - const toolResponsePart: Part = { - functionResponse: { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', name: 'testTool', - id: 'fc1', - response: { result: 'tool success' }, + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-2', }, }; + const toolResponse: Part[] = [{ text: 'Tool response' }]; + mockCoreExecuteToolCall.mockResolvedValue({ responseParts: toolResponse }); - const { executeToolCall: mockCoreExecuteToolCall } = await import( - '@google/gemini-cli-core' - ); - vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ - callId: 'fc1', - responseParts: [toolResponsePart], - resultDisplay: 'Tool success display', - error: undefined, - }); + const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent]; + const secondCallEvents: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Final answer' }, + ]; - const stream1 = (async function* () { - yield { functionCalls: [functionCall] } as GenerateContentResponse; - })(); - const stream2 = (async function* () { - yield { - candidates: [{ content: { parts: [{ text: 'Final answer' }] } }], - } as GenerateContentResponse; - })(); - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents(firstCallEvents)) + .mockReturnValueOnce(createStreamFromEvents(secondCallEvents)); await runNonInteractive(mockConfig, 'Use a tool', 'prompt-id-2'); - expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( mockConfig, - expect.objectContaining({ callId: 'fc1', name: 'testTool' }), + expect.objectContaining({ name: 'testTool' }), mockToolRegistry, expect.any(AbortSignal), ); - expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( - expect.objectContaining({ - message: [toolResponsePart], - }), - expect.any(String), + expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith( + 2, + [{ text: 'Tool response' }], + expect.any(AbortSignal), + 'prompt-id-2', ); - expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Final answer'); + expect(processStdoutSpy).toHaveBeenCalledWith('Final answer'); + expect(processStdoutSpy).toHaveBeenCalledWith('\n'); }); it('should handle error during tool execution', async () => { - const functionCall: FunctionCall = { - id: 'fcError', - name: 'errorTool', - args: {}, - }; - const errorResponsePart: Part = { - functionResponse: { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', name: 'errorTool', - id: 'fcError', - response: { error: 'Tool failed' }, + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-3', }, }; - - const { executeToolCall: mockCoreExecuteToolCall } = await import( - '@google/gemini-cli-core' - ); - vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ - callId: 'fcError', - responseParts: [errorResponsePart], - resultDisplay: 'Tool execution failed badly', - error: new Error('Tool failed'), + mockCoreExecuteToolCall.mockResolvedValue({ + error: new Error('Tool execution failed badly'), }); - - const stream1 = (async function* () { - yield { functionCalls: [functionCall] } as GenerateContentResponse; - })(); - - const stream2 = (async function* () { - yield { - candidates: [ - { content: { parts: [{ text: 'Could not complete request.' }] } }, - ], - } as GenerateContentResponse; - })(); - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents([toolCallEvent]), + ); await runNonInteractive(mockConfig, 'Trigger tool error', 'prompt-id-3'); @@ -201,75 +172,48 @@ describe('runNonInteractive', () => { expect(consoleErrorSpy).toHaveBeenCalledWith( 'Error executing tool errorTool: Tool execution failed badly', ); - expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( - expect.objectContaining({ - message: [errorResponsePart], - }), - expect.any(String), - ); - expect(mockProcessStdoutWrite).toHaveBeenCalledWith( - 'Could not complete request.', - ); + expect(processExitSpy).toHaveBeenCalledWith(1); }); it('should exit with error if sendMessageStream throws initially', async () => { const apiError = new Error('API connection failed'); - mockChat.sendMessageStream.mockRejectedValue(apiError); - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); + mockGeminiClient.sendMessageStream.mockImplementation(() => { + throw apiError; + }); await runNonInteractive(mockConfig, 'Initial fail', 'prompt-id-4'); expect(consoleErrorSpy).toHaveBeenCalledWith( '[API Error: API connection failed]', ); + expect(processExitSpy).toHaveBeenCalledWith(1); }); it('should not exit if a tool is not found, and should send error back to model', async () => { - const functionCall: FunctionCall = { - id: 'fcNotFound', - name: 'nonexistentTool', - args: {}, - }; - const errorResponsePart: Part = { - functionResponse: { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', name: 'nonexistentTool', - id: 'fcNotFound', - response: { error: 'Tool "nonexistentTool" not found in registry.' }, + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-5', }, }; - - const { executeToolCall: mockCoreExecuteToolCall } = await import( - '@google/gemini-cli-core' - ); - vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ - callId: 'fcNotFound', - responseParts: [errorResponsePart], - resultDisplay: 'Tool "nonexistentTool" not found in registry.', + mockCoreExecuteToolCall.mockResolvedValue({ error: new Error('Tool "nonexistentTool" not found in registry.'), + resultDisplay: 'Tool "nonexistentTool" not found in registry.', }); + const finalResponse: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Content, + value: "Sorry, I can't find that tool.", + }, + ]; - const stream1 = (async function* () { - yield { functionCalls: [functionCall] } as GenerateContentResponse; - })(); - const stream2 = (async function* () { - yield { - candidates: [ - { - content: { - parts: [{ text: 'Unfortunately the tool does not exist.' }], - }, - }, - ], - } as GenerateContentResponse; - })(); - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents([toolCallEvent])) + .mockReturnValueOnce(createStreamFromEvents(finalResponse)); await runNonInteractive( mockConfig, @@ -277,68 +221,22 @@ describe('runNonInteractive', () => { 'prompt-id-5', ); + expect(mockCoreExecuteToolCall).toHaveBeenCalled(); expect(consoleErrorSpy).toHaveBeenCalledWith( 'Error executing tool nonexistentTool: Tool "nonexistentTool" not found in registry.', ); - - expect(mockProcessExit).not.toHaveBeenCalled(); - - expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); - expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( - expect.objectContaining({ - message: [errorResponsePart], - }), - expect.any(String), - ); - - expect(mockProcessStdoutWrite).toHaveBeenCalledWith( - 'Unfortunately the tool does not exist.', + expect(processExitSpy).not.toHaveBeenCalled(); + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); + expect(processStdoutSpy).toHaveBeenCalledWith( + "Sorry, I can't find that tool.", ); }); it('should exit when max session turns are exceeded', async () => { - const functionCall: FunctionCall = { - id: 'fcLoop', - name: 'loopTool', - args: {}, - }; - const toolResponsePart: Part = { - functionResponse: { - name: 'loopTool', - id: 'fcLoop', - response: { result: 'still looping' }, - }, - }; - - // Config with a max turn of 1 - vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(1); - - const { executeToolCall: mockCoreExecuteToolCall } = await import( - '@google/gemini-cli-core' - ); - vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ - callId: 'fcLoop', - responseParts: [toolResponsePart], - resultDisplay: 'Still looping', - error: undefined, - }); - - const stream = (async function* () { - yield { functionCalls: [functionCall] } as GenerateContentResponse; - })(); - - mockChat.sendMessageStream.mockResolvedValue(stream); - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); - - await runNonInteractive(mockConfig, 'Trigger loop'); - - expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(0); + await runNonInteractive(mockConfig, 'Trigger loop', 'prompt-id-6'); expect(consoleErrorSpy).toHaveBeenCalledWith( - ` - Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.`, + '\n Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', ); - expect(mockProcessExit).not.toHaveBeenCalled(); }); }); |
