diff options
| author | Jacob Richman <[email protected]> | 2025-05-20 13:02:41 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-05-20 13:02:41 -0700 |
| commit | 716f7875a2fe4cec5433f64651a7f50cce58a41e (patch) | |
| tree | b440d482e12bc7efb55a9a813a7c4f6b67e3a117 /packages/cli/src | |
| parent | 4002e980d9e02e973e19226dbc25aeec00a65cf5 (diff) | |
Support Images and PDFs (#447)
Diffstat (limited to 'packages/cli/src')
| -rw-r--r-- | packages/cli/src/ui/hooks/atCommandProcessor.test.ts | 323 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/atCommandProcessor.ts | 21 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 4 |
3 files changed, 341 insertions, 7 deletions
diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts new file mode 100644 index 00000000..966fd0fa --- /dev/null +++ b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts @@ -0,0 +1,323 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest'; +import { handleAtCommand } from './atCommandProcessor.js'; +import { Config, ToolResult } from '@gemini-code/server'; +import { ToolCallStatus } from '../types.js'; // Adjusted import +import { /* PartListUnion, */ Part } from '@google/genai'; // Removed PartListUnion +import { UseHistoryManagerReturn } from './useHistoryManager.js'; +import * as fsPromises from 'fs/promises'; // Import for mocking stat +import type { Stats } from 'fs'; // Import Stats type for mocking + +// Mock Config and ToolRegistry +const mockGetToolRegistry = vi.fn(); +const mockGetTargetDir = vi.fn(); +const mockConfig = { + getToolRegistry: mockGetToolRegistry, + getTargetDir: mockGetTargetDir, +} as unknown as Config; + +// Mock read_many_files tool +const mockReadManyFilesExecute = vi.fn(); +const mockReadManyFilesTool = { + name: 'read_many_files', + displayName: 'Read Many Files', + description: 'Reads multiple files.', + execute: mockReadManyFilesExecute, + getDescription: vi.fn((params) => `Read files: ${params.paths.join(', ')}`), +}; + +// Mock addItem from useHistoryManager +const mockAddItem: Mock<UseHistoryManagerReturn['addItem']> = vi.fn(); +const mockOnDebugMessage: Mock<(message: string) => void> = vi.fn(); + +vi.mock('fs/promises', async () => { + const actual = await vi.importActual('fs/promises'); + return { + ...actual, + stat: vi.fn(), // Mock stat here + }; +}); + +describe('handleAtCommand', () => { + let abortController: AbortController; + + beforeEach(() => { + vi.resetAllMocks(); + abortController = new AbortController(); + mockGetTargetDir.mockReturnValue('/test/dir'); + mockGetToolRegistry.mockReturnValue({ + getTool: vi.fn((toolName: string) => { + if (toolName === 'read_many_files') { + return mockReadManyFilesTool; + } + return undefined; + }), + }); + // Default mock for fs.stat if not overridden by a specific test + vi.mocked(fsPromises.stat).mockResolvedValue({ + isDirectory: () => false, + } as unknown as Stats); + }); + + afterEach(() => { + abortController.abort(); // Ensure any pending operations are cancelled + }); + + it('should pass through query if no @ command is present', async () => { + const query = 'regular user query'; + const result = await handleAtCommand({ + query, + config: mockConfig, + addItem: mockAddItem, + onDebugMessage: mockOnDebugMessage, + messageId: 123, + signal: abortController.signal, + }); + + expect(mockAddItem).toHaveBeenCalledWith( + { type: 'user', text: query }, + 123, + ); + expect(result.processedQuery).toEqual([{ text: query }]); + expect(result.shouldProceed).toBe(true); + expect(mockReadManyFilesExecute).not.toHaveBeenCalled(); + }); + + it('should pass through query if only a lone @ symbol is present', async () => { + const queryWithSpaces = ' @ '; + // const trimmedQuery = queryWithSpaces.trim(); // Not needed for addItem expectation here + const result = await handleAtCommand({ + query: queryWithSpaces, // Pass the version with spaces to the function + config: mockConfig, + addItem: mockAddItem, + onDebugMessage: mockOnDebugMessage, + messageId: 124, + signal: abortController.signal, + }); + + // For a lone '@', addItem is called with the *original untrimmed* query + expect(mockAddItem).toHaveBeenCalledWith( + { type: 'user', text: queryWithSpaces }, + 124, + ); + // processedQuery should also be the original untrimmed version for lone @ + expect(result.processedQuery).toEqual([{ text: queryWithSpaces }]); + expect(result.shouldProceed).toBe(true); + expect(mockOnDebugMessage).toHaveBeenCalledWith( + 'Lone @ detected, passing directly to LLM.', + ); + }); + + it('should process a valid text file path', async () => { + const filePath = 'path/to/file.txt'; + const query = `@${filePath}`; + const fileContent = 'This is the file content.'; + mockReadManyFilesExecute.mockResolvedValue({ + llmContent: fileContent, + returnDisplay: 'Read 1 file.', + } as ToolResult); + // fs.stat will use the default mock (isDirectory: false) + + const result = await handleAtCommand({ + query, + config: mockConfig, + addItem: mockAddItem, + onDebugMessage: mockOnDebugMessage, + messageId: 125, + signal: abortController.signal, + }); + + expect(mockAddItem).toHaveBeenCalledWith( + { type: 'user', text: query }, + 125, + ); + expect(mockReadManyFilesExecute).toHaveBeenCalledWith( + { paths: [filePath] }, + abortController.signal, + ); + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'tool_group', + tools: expect.arrayContaining([ + expect.objectContaining({ + name: 'Read Many Files', + status: ToolCallStatus.Success, + resultDisplay: 'Read 1 file.', + }), + ]), + }), + 125, + ); + expect(result.processedQuery).toEqual([ + '\n--- Content from: ${contentLabel} ---\n', + fileContent, + '\n--- End of content ---\n', + ]); + expect(result.shouldProceed).toBe(true); + }); + + it('should process a valid directory path and convert to glob', async () => { + const dirPath = 'path/to/dir'; + const query = `@${dirPath}`; + const dirContent = [ + 'Content of file 1.', + 'Content of file 2.', + 'Content of file 3.', + ]; + vi.mocked(fsPromises.stat).mockResolvedValue({ + isDirectory: () => true, + } as unknown as Stats); + + mockReadManyFilesExecute.mockResolvedValue({ + llmContent: dirContent, + returnDisplay: 'Read directory contents.', + } as ToolResult); + + const result = await handleAtCommand({ + query, + config: mockConfig, + addItem: mockAddItem, + onDebugMessage: mockOnDebugMessage, + messageId: 126, + signal: abortController.signal, + }); + + expect(mockAddItem).toHaveBeenCalledWith( + { type: 'user', text: query }, + 126, + ); + expect(mockReadManyFilesExecute).toHaveBeenCalledWith( + { paths: [`${dirPath}/**`] }, // Expect glob pattern + abortController.signal, + ); + expect(mockOnDebugMessage).toHaveBeenCalledWith( + `Path resolved to directory, using glob: ${dirPath}/**`, + ); + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ type: 'tool_group' }), + 126, + ); + expect(result.processedQuery).toEqual([ + '\n--- Content from: ${contentLabel} ---\n', + ...dirContent, + '\n--- End of content ---\n', + ]); + expect(result.shouldProceed).toBe(true); + }); + + it('should process a valid image file path', async () => { + const imagePath = 'path/to/image.png'; + const query = `@${imagePath}`; + const imageData: Part = { + inlineData: { mimeType: 'image/png', data: 'base64imagedata' }, + }; + mockReadManyFilesExecute.mockResolvedValue({ + llmContent: [imageData], + returnDisplay: 'Read 1 image.', + } as ToolResult); + // fs.stat will use the default mock (isDirectory: false) + + const result = await handleAtCommand({ + query, + config: mockConfig, + addItem: mockAddItem, + onDebugMessage: mockOnDebugMessage, + messageId: 127, + signal: abortController.signal, + }); + + expect(mockAddItem).toHaveBeenCalledWith( + { type: 'user', text: query }, + 127, + ); + expect(mockReadManyFilesExecute).toHaveBeenCalledWith( + { paths: [imagePath] }, + abortController.signal, + ); + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'tool_group', + tools: expect.arrayContaining([ + expect.objectContaining({ + name: 'Read Many Files', + status: ToolCallStatus.Success, + resultDisplay: 'Read 1 image.', + }), + ]), + }), + 127, + ); + expect(result.processedQuery).toEqual([ + '\n--- Content from: ${contentLabel} ---\n', + imageData, + '\n--- End of content ---\n', + ]); + expect(result.shouldProceed).toBe(true); + }); + + it('should handle query with text before and after @command', async () => { + const textBefore = 'Explain this:'; + const filePath = 'doc.md'; + const textAfter = 'in detail.'; + const query = `${textBefore} @${filePath} ${textAfter}`; + const fileContent = 'Markdown content.'; + mockReadManyFilesExecute.mockResolvedValue({ + llmContent: fileContent, + returnDisplay: 'Read 1 doc.', + } as ToolResult); + // fs.stat will use the default mock (isDirectory: false) + + const result = await handleAtCommand({ + query, + config: mockConfig, + addItem: mockAddItem, + onDebugMessage: mockOnDebugMessage, + messageId: 128, + signal: abortController.signal, + }); + + expect(mockAddItem).toHaveBeenCalledWith( + { type: 'user', text: query }, // Expect original query for addItem + 128, + ); + expect(result.processedQuery).toEqual([ + { text: textBefore }, + '\n--- Content from: ${contentLabel} ---\n', + fileContent, + '\n--- End of content ---\n', + { text: textAfter }, + ]); + expect(result.shouldProceed).toBe(true); + }); + + it('should correctly unescape paths with escaped spaces', async () => { + const rawPath = 'path/to/my\\ file.txt'; + const unescapedPath = 'path/to/my file.txt'; + const query = `@${rawPath}`; + const fileContent = 'Content of file with space.'; + mockReadManyFilesExecute.mockResolvedValue({ + llmContent: fileContent, + returnDisplay: 'Read 1 file.', + } as ToolResult); + // fs.stat will use the default mock (isDirectory: false) + + await handleAtCommand({ + query, + config: mockConfig, + addItem: mockAddItem, + onDebugMessage: mockOnDebugMessage, + messageId: 129, + signal: abortController.signal, + }); + + expect(mockReadManyFilesExecute).toHaveBeenCalledWith( + { paths: [unescapedPath] }, // Expect unescaped path + abortController.signal, + ); + }); +}); diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts index e2934840..a5b602ad 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts @@ -6,7 +6,7 @@ import * as fs from 'fs/promises'; import * as path from 'path'; -import { PartListUnion } from '@google/genai'; +import { PartListUnion, PartUnion } from '@google/genai'; import { Config, getErrorMessage, @@ -126,6 +126,7 @@ export async function handleAtCommand({ } const contentLabel = pathPart; + const toolRegistry = config.getToolRegistry(); const readManyFilesTool = toolRegistry.getTool('read_many_files'); @@ -168,7 +169,6 @@ export async function handleAtCommand({ try { const result = await readManyFilesTool.execute(toolArgs, signal); - const fileContent = result.llmContent || ''; toolCallDisplay = { callId: `client-read-${userMessageTimestamp}`, @@ -180,13 +180,22 @@ export async function handleAtCommand({ }; // Prepare the query parts for the LLM - const processedQueryParts = []; + const processedQueryParts: PartUnion[] = []; if (textBefore) { processedQueryParts.push({ text: textBefore }); } - processedQueryParts.push({ - text: `\n--- Content from: ${contentLabel} ---\n${fileContent}\n--- End Content ---`, - }); + + // Process the result from the tool + processedQueryParts.push('\n--- Content from: ${contentLabel} ---\n'); + if (Array.isArray(result.llmContent)) { + for (const part of result.llmContent) { + processedQueryParts.push(part); + } + } else { + processedQueryParts.push(result.llmContent); + } + processedQueryParts.push('\n--- End of content ---\n'); + if (textAfter) { processedQueryParts.push({ text: textAfter }); } diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 19cb244d..92b055bb 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -25,6 +25,7 @@ import { ToolResultDisplay, ToolEditConfirmationDetails, ToolExecuteConfirmationDetails, + partListUnionToString, } from '@gemini-code/server'; import { type Chat, type PartListUnion, type Part } from '@google/genai'; import { @@ -280,13 +281,14 @@ export const useGeminiStream = ( ); if (abortControllerRef.current.signal.aborted) { declineToolExecution( - result.llmContent, + partListUnionToString(result.llmContent), ToolCallStatus.Canceled, request, originalConfirmationDetails, ); return; } + const functionResponse: Part = { functionResponse: { name: request.name, |
