diff options
Diffstat (limited to 'packages/cli/src')
| -rw-r--r-- | packages/cli/src/gemini.tsx | 66 | ||||
| -rw-r--r-- | packages/cli/src/nonInteractiveCli.test.ts | 224 | ||||
| -rw-r--r-- | packages/cli/src/nonInteractiveCli.ts | 114 |
3 files changed, 380 insertions, 24 deletions
diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 0ed27a99..07551813 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -9,7 +9,6 @@ import { render } from 'ink'; import { App } from './ui/App.js'; import { loadCliConfig } from './config/config.js'; import { readStdin } from './utils/readStdin.js'; -import { GeminiClient } from '@gemini-code/core'; import { readPackageUp } from 'read-package-up'; import { fileURLToPath } from 'node:url'; import { dirname } from 'node:path'; @@ -17,14 +16,25 @@ import { sandbox_command, start_sandbox } from './utils/sandbox.js'; import { loadSettings } from './config/settings.js'; import { themeManager } from './ui/themes/theme-manager.js'; import { getStartupWarnings } from './utils/startupWarnings.js'; +import { runNonInteractive } from './nonInteractiveCli.js'; +import { + EditTool, + GlobTool, + GrepTool, + LSTool, + MemoryTool, + ReadFileTool, + ReadManyFilesTool, + ShellTool, + WebFetchTool, + WebSearchTool, + WriteFileTool, +} from '@gemini-code/core'; const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); async function main() { - const settings = loadSettings(process.cwd()); - const config = await loadCliConfig(settings.merged); - // warn about deprecated environment variables if (process.env.GEMINI_CODE_MODEL) { console.warn('GEMINI_CODE_MODEL is deprecated. Use GEMINI_MODEL instead.'); @@ -43,6 +53,9 @@ async function main() { process.env.GEMINI_SANDBOX_IMAGE = process.env.GEMINI_CODE_SANDBOX_IMAGE; } + const settings = loadSettings(process.cwd()); + const config = await loadCliConfig(settings.merged); + if (settings.merged.theme) { if (!themeManager.setActiveTheme(settings.merged.theme)) { // If the theme is not found during initial load, log a warning and continue. @@ -92,26 +105,31 @@ async function main() { process.exit(1); } - // If not a TTY and we have initial input, process it directly - const geminiClient = new GeminiClient(config); - const chat = await geminiClient.startChat(); - try { - for await (const event of geminiClient.sendMessageStream( - chat, - [{ text: input }], - new AbortController().signal, - )) { - if (event.type === 'content') { - process.stdout.write(event.value); - } - // We might need to handle other event types later, but for now, just content. - } - process.stdout.write('\n'); // Add a newline at the end - process.exit(0); - } catch (error) { - console.error('Error processing piped input:', error); - process.exit(1); - } + // Non-interactive mode handled by runNonInteractive + let existingCoreTools = config.getCoreTools(); + existingCoreTools = existingCoreTools || [ + ReadFileTool.Name, + LSTool.Name, + GrepTool.Name, + GlobTool.Name, + EditTool.Name, + WriteFileTool.Name, + WebFetchTool.Name, + WebSearchTool.Name, + ReadManyFilesTool.Name, + ShellTool.Name, + MemoryTool.Name, + ]; + const interactiveTools = [ShellTool.Name, EditTool.Name, WriteFileTool.Name]; + const nonInteractiveTools = existingCoreTools.filter( + (tool) => !interactiveTools.includes(tool), + ); + const nonInteractiveSettings = { + ...settings.merged, + coreTools: nonInteractiveTools, + }; + const nonInteractiveConfig = await loadCliConfig(nonInteractiveSettings); + await runNonInteractive(nonInteractiveConfig, input); } // --- Global Unhandled Rejection Handler --- diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts new file mode 100644 index 00000000..dca3b855 --- /dev/null +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -0,0 +1,224 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { runNonInteractive } from './nonInteractiveCli.js'; +import { Config, GeminiClient, ToolRegistry } from '@gemini-code/core'; +import { GenerateContentResponse, Part, FunctionCall } from '@google/genai'; + +// Mock dependencies +vi.mock('@gemini-code/core', async () => { + const actualCore = + await vi.importActual<typeof import('@gemini-code/core')>( + '@gemini-code/core', + ); + return { + ...actualCore, + GeminiClient: vi.fn(), + ToolRegistry: vi.fn(), + executeToolCall: vi.fn(), + }; +}); + +describe('runNonInteractive', () => { + let mockConfig: Config; + let mockGeminiClient: GeminiClient; + let mockToolRegistry: ToolRegistry; + let mockChat: { + sendMessageStream: ReturnType<typeof vi.fn>; + }; + let mockProcessStdoutWrite: ReturnType<typeof vi.fn>; + let mockProcessExit: ReturnType<typeof vi.fn>; + + beforeEach(() => { + mockChat = { + sendMessageStream: vi.fn(), + }; + mockGeminiClient = { + startChat: vi.fn().mockResolvedValue(mockChat), + } as unknown as GeminiClient; + mockToolRegistry = { + discoverTools: vi.fn().mockResolvedValue(undefined), + getFunctionDeclarations: vi.fn().mockReturnValue([]), + getTool: vi.fn(), + } as unknown as ToolRegistry; + + vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClient); + vi.mocked(ToolRegistry).mockImplementation(() => mockToolRegistry); + + mockConfig = { + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + } as unknown as Config; + + mockProcessStdoutWrite = vi.fn().mockImplementation(() => true); + process.stdout.write = mockProcessStdoutWrite as any; // Use any to bypass strict signature matching for mock + mockProcessExit = vi + .fn() + .mockImplementation((_code?: number) => undefined as never); + process.exit = mockProcessExit as any; // Use any for process.exit mock + }); + + afterEach(() => { + vi.restoreAllMocks(); + // Restore original process methods if they were globally patched + // This might require storing the original methods before patching them in beforeEach + }); + + it('should process input and write text output', async () => { + const inputStream = (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'Hello' }] } }], + } as GenerateContentResponse; + yield { + candidates: [{ content: { parts: [{ text: ' World' }] } }], + } as GenerateContentResponse; + })(); + mockChat.sendMessageStream.mockResolvedValue(inputStream); + + await runNonInteractive(mockConfig, 'Test input'); + + expect(mockGeminiClient.startChat).toHaveBeenCalled(); + expect(mockToolRegistry.discoverTools).toHaveBeenCalled(); + expect(mockChat.sendMessageStream).toHaveBeenCalledWith({ + message: [{ text: 'Test input' }], + config: { + abortSignal: expect.any(AbortSignal), + tools: [{ functionDeclarations: [] }], + }, + }); + expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Hello'); + expect(mockProcessStdoutWrite).toHaveBeenCalledWith(' World'); + expect(mockProcessStdoutWrite).toHaveBeenCalledWith('\n'); + }); + + it('should handle a single tool call and respond', async () => { + const functionCall: FunctionCall = { + id: 'fc1', + name: 'testTool', + args: { p: 'v' }, + }; + const toolResponsePart: Part = { + functionResponse: { + name: 'testTool', + id: 'fc1', + response: { result: 'tool success' }, + }, + }; + + const { executeToolCall: mockCoreExecuteToolCall } = await import( + '@gemini-code/core' + ); + vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ + callId: 'fc1', + responseParts: [toolResponsePart], + resultDisplay: 'Tool success display', + error: undefined, + }); + + const stream1 = (async function* () { + yield { functionCalls: [functionCall] } as GenerateContentResponse; + })(); + const stream2 = (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'Final answer' }] } }], + } as GenerateContentResponse; + })(); + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + + await runNonInteractive(mockConfig, 'Use a tool'); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( + expect.objectContaining({ callId: 'fc1', name: 'testTool' }), + mockToolRegistry, + expect.any(AbortSignal), + ); + expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( + expect.objectContaining({ + message: [toolResponsePart], + }), + ); + expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Final answer'); + }); + + it('should handle error during tool execution', async () => { + const functionCall: FunctionCall = { + id: 'fcError', + name: 'errorTool', + args: {}, + }; + const errorResponsePart: Part = { + functionResponse: { + name: 'errorTool', + id: 'fcError', + response: { error: 'Tool failed' }, + }, + }; + + const { executeToolCall: mockCoreExecuteToolCall } = await import( + '@gemini-code/core' + ); + vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ + callId: 'fcError', + responseParts: [errorResponsePart], + resultDisplay: 'Tool execution failed badly', + error: new Error('Tool failed'), + }); + + const stream1 = (async function* () { + yield { functionCalls: [functionCall] } as GenerateContentResponse; + })(); + + const stream2 = (async function* () { + yield { + candidates: [ + { content: { parts: [{ text: 'Could not complete request.' }] } }, + ], + } as GenerateContentResponse; + })(); + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + + await runNonInteractive(mockConfig, 'Trigger tool error'); + + expect(mockCoreExecuteToolCall).toHaveBeenCalled(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Error executing tool errorTool: Tool execution failed badly', + ); + expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( + expect.objectContaining({ + message: [errorResponsePart], + }), + ); + expect(mockProcessStdoutWrite).toHaveBeenCalledWith( + 'Could not complete request.', + ); + consoleErrorSpy.mockRestore(); + }); + + it('should exit with error if sendMessageStream throws initially', async () => { + const apiError = new Error('API connection failed'); + mockChat.sendMessageStream.mockRejectedValue(apiError); + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + + await runNonInteractive(mockConfig, 'Initial fail'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Error processing input:', + apiError, + ); + consoleErrorSpy.mockRestore(); + }); +}); diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts new file mode 100644 index 00000000..9077ecbf --- /dev/null +++ b/packages/cli/src/nonInteractiveCli.ts @@ -0,0 +1,114 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Config, + GeminiClient, + ToolCallRequestInfo, + executeToolCall, + ToolRegistry, +} from '@gemini-code/core'; +import { + Content, + Part, + FunctionCall, + GenerateContentResponse, +} from '@google/genai'; + +function getResponseText(response: GenerateContentResponse): string | null { + if (response.candidates && response.candidates.length > 0) { + const candidate = response.candidates[0]; + if ( + candidate.content && + candidate.content.parts && + candidate.content.parts.length > 0 + ) { + return candidate.content.parts + .filter((part) => part.text) + .map((part) => part.text) + .join(''); + } + } + return null; +} + +export async function runNonInteractive( + config: Config, + input: string, +): Promise<void> { + const geminiClient = new GeminiClient(config); + const toolRegistry: ToolRegistry = config.getToolRegistry(); + await toolRegistry.discoverTools(); + + const chat = await geminiClient.startChat(); + const abortController = new AbortController(); + let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }]; + + try { + while (true) { + const functionCalls: FunctionCall[] = []; + + const responseStream = await chat.sendMessageStream({ + message: currentMessages[0]?.parts || [], // Ensure parts are always provided + config: { + abortSignal: abortController.signal, + tools: [ + { functionDeclarations: toolRegistry.getFunctionDeclarations() }, + ], + }, + }); + + for await (const resp of responseStream) { + if (abortController.signal.aborted) { + console.error('Operation cancelled.'); + return; + } + const textPart = getResponseText(resp); + if (textPart) { + process.stdout.write(textPart); + } + if (resp.functionCalls) { + functionCalls.push(...resp.functionCalls); + } + } + + if (functionCalls.length > 0) { + const toolResponseParts: Part[] = []; + + for (const fc of functionCalls) { + const callId = fc.id ?? `${fc.name}-${Date.now()}`; + const requestInfo: ToolCallRequestInfo = { + callId, + name: fc.name as string, + args: (fc.args ?? {}) as Record<string, unknown>, + }; + + const toolResponse = await executeToolCall( + requestInfo, + toolRegistry, + abortController.signal, + ); + + if (toolResponse.error) { + console.error( + `Error executing tool ${fc.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`, + ); + toolResponseParts.push(...(toolResponse.responseParts as Part[])); + } else { + toolResponseParts.push(...(toolResponse.responseParts as Part[])); + } + } + currentMessages = [{ role: 'user', parts: toolResponseParts }]; + } else { + process.stdout.write('\n'); // Ensure a final newline + return; + } + } + } catch (error) { + console.error('Error processing input:', error); + process.exit(1); + } +} |
