summaryrefslogtreecommitdiff
path: root/packages/cli/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/cli/src')
-rw-r--r--packages/cli/src/nonInteractiveCli.test.ts372
-rw-r--r--packages/cli/src/nonInteractiveCli.ts65
2 files changed, 153 insertions, 284 deletions
diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts
index 8b0419f1..a0fc6f9f 100644
--- a/packages/cli/src/nonInteractiveCli.test.ts
+++ b/packages/cli/src/nonInteractiveCli.test.ts
@@ -4,196 +4,167 @@
* SPDX-License-Identifier: Apache-2.0
*/
-/* eslint-disable @typescript-eslint/no-explicit-any */
-import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
+import {
+ Config,
+ executeToolCall,
+ ToolRegistry,
+ shutdownTelemetry,
+ GeminiEventType,
+ ServerGeminiStreamEvent,
+} from '@google/gemini-cli-core';
+import { Part } from '@google/genai';
import { runNonInteractive } from './nonInteractiveCli.js';
-import { Config, GeminiClient, ToolRegistry } from '@google/gemini-cli-core';
-import { GenerateContentResponse, Part, FunctionCall } from '@google/genai';
+import { vi } from 'vitest';
-// Mock dependencies
-vi.mock('@google/gemini-cli-core', async () => {
- const actualCore = await vi.importActual<
- typeof import('@google/gemini-cli-core')
- >('@google/gemini-cli-core');
+// Mock core modules
+vi.mock('@google/gemini-cli-core', async (importOriginal) => {
+ const original =
+ await importOriginal<typeof import('@google/gemini-cli-core')>();
return {
- ...actualCore,
- GeminiClient: vi.fn(),
- ToolRegistry: vi.fn(),
+ ...original,
executeToolCall: vi.fn(),
+ shutdownTelemetry: vi.fn(),
+ isTelemetrySdkInitialized: vi.fn().mockReturnValue(true),
};
});
describe('runNonInteractive', () => {
let mockConfig: Config;
- let mockGeminiClient: GeminiClient;
let mockToolRegistry: ToolRegistry;
- let mockChat: {
- sendMessageStream: ReturnType<typeof vi.fn>;
+ let mockCoreExecuteToolCall: vi.Mock;
+ let mockShutdownTelemetry: vi.Mock;
+ let consoleErrorSpy: vi.SpyInstance;
+ let processExitSpy: vi.SpyInstance;
+ let processStdoutSpy: vi.SpyInstance;
+ let mockGeminiClient: {
+ sendMessageStream: vi.Mock;
};
- let mockProcessStdoutWrite: ReturnType<typeof vi.fn>;
- let mockProcessExit: ReturnType<typeof vi.fn>;
beforeEach(() => {
- vi.resetAllMocks();
- mockChat = {
- sendMessageStream: vi.fn(),
- };
- mockGeminiClient = {
- getChat: vi.fn().mockResolvedValue(mockChat),
- } as unknown as GeminiClient;
+ mockCoreExecuteToolCall = vi.mocked(executeToolCall);
+ mockShutdownTelemetry = vi.mocked(shutdownTelemetry);
+
+ consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
+ processExitSpy = vi
+ .spyOn(process, 'exit')
+ .mockImplementation((() => {}) as (code?: number) => never);
+ processStdoutSpy = vi
+ .spyOn(process.stdout, 'write')
+ .mockImplementation(() => true);
+
mockToolRegistry = {
- getFunctionDeclarations: vi.fn().mockReturnValue([]),
getTool: vi.fn(),
+ getFunctionDeclarations: vi.fn().mockReturnValue([]),
} as unknown as ToolRegistry;
- vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClient);
- vi.mocked(ToolRegistry).mockImplementation(() => mockToolRegistry);
+ mockGeminiClient = {
+ sendMessageStream: vi.fn(),
+ };
mockConfig = {
- getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
+ initialize: vi.fn().mockResolvedValue(undefined),
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
- getContentGeneratorConfig: vi.fn().mockReturnValue({}),
+ getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
getMaxSessionTurns: vi.fn().mockReturnValue(10),
- initialize: vi.fn(),
+ getIdeMode: vi.fn().mockReturnValue(false),
+ getFullContext: vi.fn().mockReturnValue(false),
+ getContentGeneratorConfig: vi.fn().mockReturnValue({}),
} as unknown as Config;
-
- mockProcessStdoutWrite = vi.fn().mockImplementation(() => true);
- process.stdout.write = mockProcessStdoutWrite as any; // Use any to bypass strict signature matching for mock
- mockProcessExit = vi
- .fn()
- .mockImplementation((_code?: number) => undefined as never);
- process.exit = mockProcessExit as any; // Use any for process.exit mock
});
afterEach(() => {
vi.restoreAllMocks();
- // Restore original process methods if they were globally patched
- // This might require storing the original methods before patching them in beforeEach
});
+ async function* createStreamFromEvents(
+ events: ServerGeminiStreamEvent[],
+ ): AsyncGenerator<ServerGeminiStreamEvent> {
+ for (const event of events) {
+ yield event;
+ }
+ }
+
it('should process input and write text output', async () => {
- const inputStream = (async function* () {
- yield {
- candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
- } as GenerateContentResponse;
- yield {
- candidates: [{ content: { parts: [{ text: ' World' }] } }],
- } as GenerateContentResponse;
- })();
- mockChat.sendMessageStream.mockResolvedValue(inputStream);
+ const events: ServerGeminiStreamEvent[] = [
+ { type: GeminiEventType.Content, value: 'Hello' },
+ { type: GeminiEventType.Content, value: ' World' },
+ ];
+ mockGeminiClient.sendMessageStream.mockReturnValue(
+ createStreamFromEvents(events),
+ );
await runNonInteractive(mockConfig, 'Test input', 'prompt-id-1');
- expect(mockChat.sendMessageStream).toHaveBeenCalledWith(
- {
- message: [{ text: 'Test input' }],
- config: {
- abortSignal: expect.any(AbortSignal),
- tools: [{ functionDeclarations: [] }],
- },
- },
- expect.any(String),
+ expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith(
+ [{ text: 'Test input' }],
+ expect.any(AbortSignal),
+ 'prompt-id-1',
);
- expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Hello');
- expect(mockProcessStdoutWrite).toHaveBeenCalledWith(' World');
- expect(mockProcessStdoutWrite).toHaveBeenCalledWith('\n');
+ expect(processStdoutSpy).toHaveBeenCalledWith('Hello');
+ expect(processStdoutSpy).toHaveBeenCalledWith(' World');
+ expect(processStdoutSpy).toHaveBeenCalledWith('\n');
+ expect(mockShutdownTelemetry).toHaveBeenCalled();
});
it('should handle a single tool call and respond', async () => {
- const functionCall: FunctionCall = {
- id: 'fc1',
- name: 'testTool',
- args: { p: 'v' },
- };
- const toolResponsePart: Part = {
- functionResponse: {
+ const toolCallEvent: ServerGeminiStreamEvent = {
+ type: GeminiEventType.ToolCallRequest,
+ value: {
+ callId: 'tool-1',
name: 'testTool',
- id: 'fc1',
- response: { result: 'tool success' },
+ args: { arg1: 'value1' },
+ isClientInitiated: false,
+ prompt_id: 'prompt-id-2',
},
};
+ const toolResponse: Part[] = [{ text: 'Tool response' }];
+ mockCoreExecuteToolCall.mockResolvedValue({ responseParts: toolResponse });
- const { executeToolCall: mockCoreExecuteToolCall } = await import(
- '@google/gemini-cli-core'
- );
- vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
- callId: 'fc1',
- responseParts: [toolResponsePart],
- resultDisplay: 'Tool success display',
- error: undefined,
- });
+ const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent];
+ const secondCallEvents: ServerGeminiStreamEvent[] = [
+ { type: GeminiEventType.Content, value: 'Final answer' },
+ ];
- const stream1 = (async function* () {
- yield { functionCalls: [functionCall] } as GenerateContentResponse;
- })();
- const stream2 = (async function* () {
- yield {
- candidates: [{ content: { parts: [{ text: 'Final answer' }] } }],
- } as GenerateContentResponse;
- })();
- mockChat.sendMessageStream
- .mockResolvedValueOnce(stream1)
- .mockResolvedValueOnce(stream2);
+ mockGeminiClient.sendMessageStream
+ .mockReturnValueOnce(createStreamFromEvents(firstCallEvents))
+ .mockReturnValueOnce(createStreamFromEvents(secondCallEvents));
await runNonInteractive(mockConfig, 'Use a tool', 'prompt-id-2');
- expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2);
+ expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
mockConfig,
- expect.objectContaining({ callId: 'fc1', name: 'testTool' }),
+ expect.objectContaining({ name: 'testTool' }),
mockToolRegistry,
expect.any(AbortSignal),
);
- expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
- expect.objectContaining({
- message: [toolResponsePart],
- }),
- expect.any(String),
+ expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith(
+ 2,
+ [{ text: 'Tool response' }],
+ expect.any(AbortSignal),
+ 'prompt-id-2',
);
- expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Final answer');
+ expect(processStdoutSpy).toHaveBeenCalledWith('Final answer');
+ expect(processStdoutSpy).toHaveBeenCalledWith('\n');
});
it('should handle error during tool execution', async () => {
- const functionCall: FunctionCall = {
- id: 'fcError',
- name: 'errorTool',
- args: {},
- };
- const errorResponsePart: Part = {
- functionResponse: {
+ const toolCallEvent: ServerGeminiStreamEvent = {
+ type: GeminiEventType.ToolCallRequest,
+ value: {
+ callId: 'tool-1',
name: 'errorTool',
- id: 'fcError',
- response: { error: 'Tool failed' },
+ args: {},
+ isClientInitiated: false,
+ prompt_id: 'prompt-id-3',
},
};
-
- const { executeToolCall: mockCoreExecuteToolCall } = await import(
- '@google/gemini-cli-core'
- );
- vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
- callId: 'fcError',
- responseParts: [errorResponsePart],
- resultDisplay: 'Tool execution failed badly',
- error: new Error('Tool failed'),
+ mockCoreExecuteToolCall.mockResolvedValue({
+ error: new Error('Tool execution failed badly'),
});
-
- const stream1 = (async function* () {
- yield { functionCalls: [functionCall] } as GenerateContentResponse;
- })();
-
- const stream2 = (async function* () {
- yield {
- candidates: [
- { content: { parts: [{ text: 'Could not complete request.' }] } },
- ],
- } as GenerateContentResponse;
- })();
- mockChat.sendMessageStream
- .mockResolvedValueOnce(stream1)
- .mockResolvedValueOnce(stream2);
- const consoleErrorSpy = vi
- .spyOn(console, 'error')
- .mockImplementation(() => {});
+ mockGeminiClient.sendMessageStream.mockReturnValue(
+ createStreamFromEvents([toolCallEvent]),
+ );
await runNonInteractive(mockConfig, 'Trigger tool error', 'prompt-id-3');
@@ -201,75 +172,48 @@ describe('runNonInteractive', () => {
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Error executing tool errorTool: Tool execution failed badly',
);
- expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
- expect.objectContaining({
- message: [errorResponsePart],
- }),
- expect.any(String),
- );
- expect(mockProcessStdoutWrite).toHaveBeenCalledWith(
- 'Could not complete request.',
- );
+ expect(processExitSpy).toHaveBeenCalledWith(1);
});
it('should exit with error if sendMessageStream throws initially', async () => {
const apiError = new Error('API connection failed');
- mockChat.sendMessageStream.mockRejectedValue(apiError);
- const consoleErrorSpy = vi
- .spyOn(console, 'error')
- .mockImplementation(() => {});
+ mockGeminiClient.sendMessageStream.mockImplementation(() => {
+ throw apiError;
+ });
await runNonInteractive(mockConfig, 'Initial fail', 'prompt-id-4');
expect(consoleErrorSpy).toHaveBeenCalledWith(
'[API Error: API connection failed]',
);
+ expect(processExitSpy).toHaveBeenCalledWith(1);
});
it('should not exit if a tool is not found, and should send error back to model', async () => {
- const functionCall: FunctionCall = {
- id: 'fcNotFound',
- name: 'nonexistentTool',
- args: {},
- };
- const errorResponsePart: Part = {
- functionResponse: {
+ const toolCallEvent: ServerGeminiStreamEvent = {
+ type: GeminiEventType.ToolCallRequest,
+ value: {
+ callId: 'tool-1',
name: 'nonexistentTool',
- id: 'fcNotFound',
- response: { error: 'Tool "nonexistentTool" not found in registry.' },
+ args: {},
+ isClientInitiated: false,
+ prompt_id: 'prompt-id-5',
},
};
-
- const { executeToolCall: mockCoreExecuteToolCall } = await import(
- '@google/gemini-cli-core'
- );
- vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
- callId: 'fcNotFound',
- responseParts: [errorResponsePart],
- resultDisplay: 'Tool "nonexistentTool" not found in registry.',
+ mockCoreExecuteToolCall.mockResolvedValue({
error: new Error('Tool "nonexistentTool" not found in registry.'),
+ resultDisplay: 'Tool "nonexistentTool" not found in registry.',
});
+ const finalResponse: ServerGeminiStreamEvent[] = [
+ {
+ type: GeminiEventType.Content,
+ value: "Sorry, I can't find that tool.",
+ },
+ ];
- const stream1 = (async function* () {
- yield { functionCalls: [functionCall] } as GenerateContentResponse;
- })();
- const stream2 = (async function* () {
- yield {
- candidates: [
- {
- content: {
- parts: [{ text: 'Unfortunately the tool does not exist.' }],
- },
- },
- ],
- } as GenerateContentResponse;
- })();
- mockChat.sendMessageStream
- .mockResolvedValueOnce(stream1)
- .mockResolvedValueOnce(stream2);
- const consoleErrorSpy = vi
- .spyOn(console, 'error')
- .mockImplementation(() => {});
+ mockGeminiClient.sendMessageStream
+ .mockReturnValueOnce(createStreamFromEvents([toolCallEvent]))
+ .mockReturnValueOnce(createStreamFromEvents(finalResponse));
await runNonInteractive(
mockConfig,
@@ -277,68 +221,22 @@ describe('runNonInteractive', () => {
'prompt-id-5',
);
+ expect(mockCoreExecuteToolCall).toHaveBeenCalled();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Error executing tool nonexistentTool: Tool "nonexistentTool" not found in registry.',
);
-
- expect(mockProcessExit).not.toHaveBeenCalled();
-
- expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2);
- expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
- expect.objectContaining({
- message: [errorResponsePart],
- }),
- expect.any(String),
- );
-
- expect(mockProcessStdoutWrite).toHaveBeenCalledWith(
- 'Unfortunately the tool does not exist.',
+ expect(processExitSpy).not.toHaveBeenCalled();
+ expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
+ expect(processStdoutSpy).toHaveBeenCalledWith(
+ "Sorry, I can't find that tool.",
);
});
it('should exit when max session turns are exceeded', async () => {
- const functionCall: FunctionCall = {
- id: 'fcLoop',
- name: 'loopTool',
- args: {},
- };
- const toolResponsePart: Part = {
- functionResponse: {
- name: 'loopTool',
- id: 'fcLoop',
- response: { result: 'still looping' },
- },
- };
-
- // Config with a max turn of 1
- vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(1);
-
- const { executeToolCall: mockCoreExecuteToolCall } = await import(
- '@google/gemini-cli-core'
- );
- vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
- callId: 'fcLoop',
- responseParts: [toolResponsePart],
- resultDisplay: 'Still looping',
- error: undefined,
- });
-
- const stream = (async function* () {
- yield { functionCalls: [functionCall] } as GenerateContentResponse;
- })();
-
- mockChat.sendMessageStream.mockResolvedValue(stream);
- const consoleErrorSpy = vi
- .spyOn(console, 'error')
- .mockImplementation(() => {});
-
- await runNonInteractive(mockConfig, 'Trigger loop');
-
- expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1);
+ vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(0);
+ await runNonInteractive(mockConfig, 'Trigger loop', 'prompt-id-6');
expect(consoleErrorSpy).toHaveBeenCalledWith(
- `
- Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.`,
+ '\n Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.',
);
- expect(mockProcessExit).not.toHaveBeenCalled();
});
});
diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts
index 7bc0f6aa..1d0a7f3d 100644
--- a/packages/cli/src/nonInteractiveCli.ts
+++ b/packages/cli/src/nonInteractiveCli.ts
@@ -11,38 +11,12 @@ import {
ToolRegistry,
shutdownTelemetry,
isTelemetrySdkInitialized,
+ GeminiEventType,
} from '@google/gemini-cli-core';
-import {
- Content,
- Part,
- FunctionCall,
- GenerateContentResponse,
-} from '@google/genai';
+import { Content, Part, FunctionCall } from '@google/genai';
import { parseAndFormatApiError } from './ui/utils/errorParsing.js';
-function getResponseText(response: GenerateContentResponse): string | null {
- if (response.candidates && response.candidates.length > 0) {
- const candidate = response.candidates[0];
- if (
- candidate.content &&
- candidate.content.parts &&
- candidate.content.parts.length > 0
- ) {
- // We are running in headless mode so we don't need to return thoughts to STDOUT.
- const thoughtPart = candidate.content.parts[0];
- if (thoughtPart?.thought) {
- return null;
- }
- return candidate.content.parts
- .filter((part) => part.text)
- .map((part) => part.text)
- .join('');
- }
- }
- return null;
-}
-
export async function runNonInteractive(
config: Config,
input: string,
@@ -60,7 +34,6 @@ export async function runNonInteractive(
const geminiClient = config.getGeminiClient();
const toolRegistry: ToolRegistry = await config.getToolRegistry();
- const chat = await geminiClient.getChat();
const abortController = new AbortController();
let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }];
let turnCount = 0;
@@ -68,7 +41,7 @@ export async function runNonInteractive(
while (true) {
turnCount++;
if (
- config.getMaxSessionTurns() > 0 &&
+ config.getMaxSessionTurns() >= 0 &&
turnCount > config.getMaxSessionTurns()
) {
console.error(
@@ -78,30 +51,28 @@ export async function runNonInteractive(
}
const functionCalls: FunctionCall[] = [];
- const responseStream = await chat.sendMessageStream(
- {
- message: currentMessages[0]?.parts || [], // Ensure parts are always provided
- config: {
- abortSignal: abortController.signal,
- tools: [
- { functionDeclarations: toolRegistry.getFunctionDeclarations() },
- ],
- },
- },
+ const responseStream = geminiClient.sendMessageStream(
+ currentMessages[0]?.parts || [],
+ abortController.signal,
prompt_id,
);
- for await (const resp of responseStream) {
+ for await (const event of responseStream) {
if (abortController.signal.aborted) {
console.error('Operation cancelled.');
return;
}
- const textPart = getResponseText(resp);
- if (textPart) {
- process.stdout.write(textPart);
- }
- if (resp.functionCalls) {
- functionCalls.push(...resp.functionCalls);
+
+ if (event.type === GeminiEventType.Content) {
+ process.stdout.write(event.value);
+ } else if (event.type === GeminiEventType.ToolCallRequest) {
+ const toolCallRequest = event.value;
+ const fc: FunctionCall = {
+ name: toolCallRequest.name,
+ args: toolCallRequest.args,
+ id: toolCallRequest.callId,
+ };
+ functionCalls.push(fc);
}
}