summaryrefslogtreecommitdiff
path: root/packages/cli/src/nonInteractiveCli.test.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/cli/src/nonInteractiveCli.test.ts')
-rw-r--r--packages/cli/src/nonInteractiveCli.test.ts224
1 files changed, 224 insertions, 0 deletions
diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts
new file mode 100644
index 00000000..dca3b855
--- /dev/null
+++ b/packages/cli/src/nonInteractiveCli.test.ts
@@ -0,0 +1,224 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/* eslint-disable @typescript-eslint/no-explicit-any */
+import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
+import { runNonInteractive } from './nonInteractiveCli.js';
+import { Config, GeminiClient, ToolRegistry } from '@gemini-code/core';
+import { GenerateContentResponse, Part, FunctionCall } from '@google/genai';
+
+// Mock dependencies
+vi.mock('@gemini-code/core', async () => {
+ const actualCore =
+ await vi.importActual<typeof import('@gemini-code/core')>(
+ '@gemini-code/core',
+ );
+ return {
+ ...actualCore,
+ GeminiClient: vi.fn(),
+ ToolRegistry: vi.fn(),
+ executeToolCall: vi.fn(),
+ };
+});
+
+describe('runNonInteractive', () => {
+ let mockConfig: Config;
+ let mockGeminiClient: GeminiClient;
+ let mockToolRegistry: ToolRegistry;
+ let mockChat: {
+ sendMessageStream: ReturnType<typeof vi.fn>;
+ };
+ let mockProcessStdoutWrite: ReturnType<typeof vi.fn>;
+ let mockProcessExit: ReturnType<typeof vi.fn>;
+
+ beforeEach(() => {
+ mockChat = {
+ sendMessageStream: vi.fn(),
+ };
+ mockGeminiClient = {
+ startChat: vi.fn().mockResolvedValue(mockChat),
+ } as unknown as GeminiClient;
+ mockToolRegistry = {
+ discoverTools: vi.fn().mockResolvedValue(undefined),
+ getFunctionDeclarations: vi.fn().mockReturnValue([]),
+ getTool: vi.fn(),
+ } as unknown as ToolRegistry;
+
+ vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClient);
+ vi.mocked(ToolRegistry).mockImplementation(() => mockToolRegistry);
+
+ mockConfig = {
+ getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
+ } as unknown as Config;
+
+ mockProcessStdoutWrite = vi.fn().mockImplementation(() => true);
+ process.stdout.write = mockProcessStdoutWrite as any; // Use any to bypass strict signature matching for mock
+ mockProcessExit = vi
+ .fn()
+ .mockImplementation((_code?: number) => undefined as never);
+ process.exit = mockProcessExit as any; // Use any for process.exit mock
+ });
+
+ afterEach(() => {
+ vi.restoreAllMocks();
+ // Restore original process methods if they were globally patched
+ // This might require storing the original methods before patching them in beforeEach
+ });
+
+ it('should process input and write text output', async () => {
+ const inputStream = (async function* () {
+ yield {
+ candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
+ } as GenerateContentResponse;
+ yield {
+ candidates: [{ content: { parts: [{ text: ' World' }] } }],
+ } as GenerateContentResponse;
+ })();
+ mockChat.sendMessageStream.mockResolvedValue(inputStream);
+
+ await runNonInteractive(mockConfig, 'Test input');
+
+ expect(mockGeminiClient.startChat).toHaveBeenCalled();
+ expect(mockToolRegistry.discoverTools).toHaveBeenCalled();
+ expect(mockChat.sendMessageStream).toHaveBeenCalledWith({
+ message: [{ text: 'Test input' }],
+ config: {
+ abortSignal: expect.any(AbortSignal),
+ tools: [{ functionDeclarations: [] }],
+ },
+ });
+ expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Hello');
+ expect(mockProcessStdoutWrite).toHaveBeenCalledWith(' World');
+ expect(mockProcessStdoutWrite).toHaveBeenCalledWith('\n');
+ });
+
+ it('should handle a single tool call and respond', async () => {
+ const functionCall: FunctionCall = {
+ id: 'fc1',
+ name: 'testTool',
+ args: { p: 'v' },
+ };
+ const toolResponsePart: Part = {
+ functionResponse: {
+ name: 'testTool',
+ id: 'fc1',
+ response: { result: 'tool success' },
+ },
+ };
+
+ const { executeToolCall: mockCoreExecuteToolCall } = await import(
+ '@gemini-code/core'
+ );
+ vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
+ callId: 'fc1',
+ responseParts: [toolResponsePart],
+ resultDisplay: 'Tool success display',
+ error: undefined,
+ });
+
+ const stream1 = (async function* () {
+ yield { functionCalls: [functionCall] } as GenerateContentResponse;
+ })();
+ const stream2 = (async function* () {
+ yield {
+ candidates: [{ content: { parts: [{ text: 'Final answer' }] } }],
+ } as GenerateContentResponse;
+ })();
+ mockChat.sendMessageStream
+ .mockResolvedValueOnce(stream1)
+ .mockResolvedValueOnce(stream2);
+
+ await runNonInteractive(mockConfig, 'Use a tool');
+
+ expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2);
+ expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
+ expect.objectContaining({ callId: 'fc1', name: 'testTool' }),
+ mockToolRegistry,
+ expect.any(AbortSignal),
+ );
+ expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
+ expect.objectContaining({
+ message: [toolResponsePart],
+ }),
+ );
+ expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Final answer');
+ });
+
+ it('should handle error during tool execution', async () => {
+ const functionCall: FunctionCall = {
+ id: 'fcError',
+ name: 'errorTool',
+ args: {},
+ };
+ const errorResponsePart: Part = {
+ functionResponse: {
+ name: 'errorTool',
+ id: 'fcError',
+ response: { error: 'Tool failed' },
+ },
+ };
+
+ const { executeToolCall: mockCoreExecuteToolCall } = await import(
+ '@gemini-code/core'
+ );
+ vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
+ callId: 'fcError',
+ responseParts: [errorResponsePart],
+ resultDisplay: 'Tool execution failed badly',
+ error: new Error('Tool failed'),
+ });
+
+ const stream1 = (async function* () {
+ yield { functionCalls: [functionCall] } as GenerateContentResponse;
+ })();
+
+ const stream2 = (async function* () {
+ yield {
+ candidates: [
+ { content: { parts: [{ text: 'Could not complete request.' }] } },
+ ],
+ } as GenerateContentResponse;
+ })();
+ mockChat.sendMessageStream
+ .mockResolvedValueOnce(stream1)
+ .mockResolvedValueOnce(stream2);
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {});
+
+ await runNonInteractive(mockConfig, 'Trigger tool error');
+
+ expect(mockCoreExecuteToolCall).toHaveBeenCalled();
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ 'Error executing tool errorTool: Tool execution failed badly',
+ );
+ expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
+ expect.objectContaining({
+ message: [errorResponsePart],
+ }),
+ );
+ expect(mockProcessStdoutWrite).toHaveBeenCalledWith(
+ 'Could not complete request.',
+ );
+ consoleErrorSpy.mockRestore();
+ });
+
+ it('should exit with error if sendMessageStream throws initially', async () => {
+ const apiError = new Error('API connection failed');
+ mockChat.sendMessageStream.mockRejectedValue(apiError);
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {});
+
+ await runNonInteractive(mockConfig, 'Initial fail');
+
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ 'Error processing input:',
+ apiError,
+ );
+ consoleErrorSpy.mockRestore();
+ });
+});