diff options
| author | Tommaso Sciortino <[email protected]> | 2025-05-30 18:25:47 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-05-30 18:25:47 -0700 |
| commit | 21fba832d1b4ea7af43fb887d9b2b38fcf8210d0 (patch) | |
| tree | 7200d2fac3a55c385e0a2dac34b5282c942364bc /packages/core/src/utils | |
| parent | c81148a0cc8489f657901c2cc7247c0834075e1a (diff) | |
Rename server->core (#638)
Diffstat (limited to 'packages/core/src/utils')
20 files changed, 4623 insertions, 0 deletions
diff --git a/packages/core/src/utils/LruCache.ts b/packages/core/src/utils/LruCache.ts new file mode 100644 index 00000000..076828c4 --- /dev/null +++ b/packages/core/src/utils/LruCache.ts @@ -0,0 +1,41 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export class LruCache<K, V> { + private cache: Map<K, V>; + private maxSize: number; + + constructor(maxSize: number) { + this.cache = new Map<K, V>(); + this.maxSize = maxSize; + } + + get(key: K): V | undefined { + const value = this.cache.get(key); + if (value) { + // Move to end to mark as recently used + this.cache.delete(key); + this.cache.set(key, value); + } + return value; + } + + set(key: K, value: V): void { + if (this.cache.has(key)) { + this.cache.delete(key); + } else if (this.cache.size >= this.maxSize) { + const firstKey = this.cache.keys().next().value; + if (firstKey !== undefined) { + this.cache.delete(firstKey); + } + } + this.cache.set(key, value); + } + + clear(): void { + this.cache.clear(); + } +} diff --git a/packages/core/src/utils/editCorrector.test.ts b/packages/core/src/utils/editCorrector.test.ts new file mode 100644 index 00000000..7d6f5a53 --- /dev/null +++ b/packages/core/src/utils/editCorrector.test.ts @@ -0,0 +1,503 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { vi, describe, it, expect, beforeEach, type Mocked } from 'vitest'; + +// MOCKS +let callCount = 0; +const mockResponses: any[] = []; + +let mockGenerateJson: any; +let mockStartChat: any; +let mockSendMessageStream: any; + +vi.mock('../core/client.js', () => ({ + GeminiClient: vi.fn().mockImplementation(function ( + this: any, + _config: Config, + ) { + this.generateJson = (...params: any[]) => mockGenerateJson(...params); // Corrected: use mockGenerateJson + this.startChat = (...params: any[]) => mockStartChat(...params); // Corrected: use mockStartChat + 
this.sendMessageStream = (...params: any[]) => + mockSendMessageStream(...params); // Corrected: use mockSendMessageStream + return this; + }), +})); +// END MOCKS + +import { + countOccurrences, + ensureCorrectEdit, + unescapeStringForGeminiBug, + resetEditCorrectorCaches_TEST_ONLY, +} from './editCorrector.js'; +import { GeminiClient } from '../core/client.js'; +import type { Config } from '../config/config.js'; +import { ToolRegistry } from '../tools/tool-registry.js'; + +vi.mock('../tools/tool-registry.js'); + +describe('editCorrector', () => { + describe('countOccurrences', () => { + it('should return 0 for empty string', () => { + expect(countOccurrences('', 'a')).toBe(0); + }); + it('should return 0 for empty substring', () => { + expect(countOccurrences('abc', '')).toBe(0); + }); + it('should return 0 if substring is not found', () => { + expect(countOccurrences('abc', 'd')).toBe(0); + }); + it('should return 1 if substring is found once', () => { + expect(countOccurrences('abc', 'b')).toBe(1); + }); + it('should return correct count for multiple occurrences', () => { + expect(countOccurrences('ababa', 'a')).toBe(3); + expect(countOccurrences('ababab', 'ab')).toBe(3); + }); + it('should count non-overlapping occurrences', () => { + expect(countOccurrences('aaaaa', 'aa')).toBe(2); + expect(countOccurrences('ababab', 'aba')).toBe(1); + }); + it('should correctly count occurrences when substring is longer', () => { + expect(countOccurrences('abc', 'abcdef')).toBe(0); + }); + it('should be case sensitive', () => { + expect(countOccurrences('abcABC', 'a')).toBe(1); + expect(countOccurrences('abcABC', 'A')).toBe(1); + }); + }); + + describe('unescapeStringForGeminiBug', () => { + it('should unescape common sequences', () => { + expect(unescapeStringForGeminiBug('\\n')).toBe('\n'); + expect(unescapeStringForGeminiBug('\\t')).toBe('\t'); + expect(unescapeStringForGeminiBug("\\'")).toBe("'"); + expect(unescapeStringForGeminiBug('\\"')).toBe('"'); + 
expect(unescapeStringForGeminiBug('\\`')).toBe('`'); + }); + it('should handle multiple escaped sequences', () => { + expect(unescapeStringForGeminiBug('Hello\\nWorld\\tTest')).toBe( + 'Hello\nWorld\tTest', + ); + }); + it('should not alter already correct sequences', () => { + expect(unescapeStringForGeminiBug('\n')).toBe('\n'); + expect(unescapeStringForGeminiBug('Correct string')).toBe( + 'Correct string', + ); + }); + it('should handle mixed correct and incorrect sequences', () => { + expect(unescapeStringForGeminiBug('\\nCorrect\t\\`')).toBe( + '\nCorrect\t`', + ); + }); + it('should handle backslash followed by actual newline character', () => { + expect(unescapeStringForGeminiBug('\\\n')).toBe('\n'); + expect(unescapeStringForGeminiBug('First line\\\nSecond line')).toBe( + 'First line\nSecond line', + ); + }); + it('should handle multiple backslashes before an escapable character', () => { + expect(unescapeStringForGeminiBug('\\\\n')).toBe('\n'); + expect(unescapeStringForGeminiBug('\\\\\\t')).toBe('\t'); + expect(unescapeStringForGeminiBug('\\\\\\\\`')).toBe('`'); + }); + it('should return empty string for empty input', () => { + expect(unescapeStringForGeminiBug('')).toBe(''); + }); + it('should not alter strings with no targeted escape sequences', () => { + expect(unescapeStringForGeminiBug('abc def')).toBe('abc def'); + expect(unescapeStringForGeminiBug('C:\\Folder\\File')).toBe( + 'C:\\Folder\\File', + ); + }); + it('should correctly process strings with some targeted escapes', () => { + expect(unescapeStringForGeminiBug('C:\\Users\\name')).toBe( + 'C:\\Users\name', + ); + }); + it('should handle complex cases with mixed slashes and characters', () => { + expect( + unescapeStringForGeminiBug('\\\\\\\nLine1\\\nLine2\\tTab\\\\`Tick\\"'), + ).toBe('\nLine1\nLine2\tTab`Tick"'); + }); + }); + + describe('ensureCorrectEdit', () => { + let mockGeminiClientInstance: Mocked<GeminiClient>; + let mockToolRegistry: Mocked<ToolRegistry>; + let mockConfigInstance: 
Config; + const abortSignal = new AbortController().signal; + + beforeEach(() => { + mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>; + const configParams = { + apiKey: 'test-api-key', + model: 'test-model', + sandbox: false as boolean | string, + targetDir: '/test', + debugMode: false, + question: undefined as string | undefined, + fullContext: false, + coreTools: undefined as string[] | undefined, + toolDiscoveryCommand: undefined as string | undefined, + toolCallCommand: undefined as string | undefined, + mcpServerCommand: undefined as string | undefined, + mcpServers: undefined as Record<string, any> | undefined, + userAgent: 'test-agent', + userMemory: '', + geminiMdFileCount: 0, + alwaysSkipModificationConfirmation: false, + }; + mockConfigInstance = { + ...configParams, + getApiKey: vi.fn(() => configParams.apiKey), + getModel: vi.fn(() => configParams.model), + getSandbox: vi.fn(() => configParams.sandbox), + getTargetDir: vi.fn(() => configParams.targetDir), + getToolRegistry: vi.fn(() => mockToolRegistry), + getDebugMode: vi.fn(() => configParams.debugMode), + getQuestion: vi.fn(() => configParams.question), + getFullContext: vi.fn(() => configParams.fullContext), + getCoreTools: vi.fn(() => configParams.coreTools), + getToolDiscoveryCommand: vi.fn(() => configParams.toolDiscoveryCommand), + getToolCallCommand: vi.fn(() => configParams.toolCallCommand), + getMcpServerCommand: vi.fn(() => configParams.mcpServerCommand), + getMcpServers: vi.fn(() => configParams.mcpServers), + getUserAgent: vi.fn(() => configParams.userAgent), + getUserMemory: vi.fn(() => configParams.userMemory), + setUserMemory: vi.fn((mem: string) => { + configParams.userMemory = mem; + }), + getGeminiMdFileCount: vi.fn(() => configParams.geminiMdFileCount), + setGeminiMdFileCount: vi.fn((count: number) => { + configParams.geminiMdFileCount = count; + }), + getAlwaysSkipModificationConfirmation: vi.fn( + () => configParams.alwaysSkipModificationConfirmation, + 
), + setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => { + configParams.alwaysSkipModificationConfirmation = skip; + }), + } as unknown as Config; + + callCount = 0; + mockResponses.length = 0; + mockGenerateJson = vi + .fn() + .mockImplementation((_contents, _schema, signal) => { + // Check if the signal is aborted. If so, throw an error or return a specific response. + if (signal && signal.aborted) { + return Promise.reject(new Error('Aborted')); // Or some other specific error/response + } + const response = mockResponses[callCount]; + callCount++; + if (response === undefined) return Promise.resolve({}); + return Promise.resolve(response); + }); + mockStartChat = vi.fn(); + mockSendMessageStream = vi.fn(); + + mockGeminiClientInstance = new GeminiClient( + mockConfigInstance, + ) as Mocked<GeminiClient>; + resetEditCorrectorCaches_TEST_ONLY(); + }); + + describe('Scenario Group 1: originalParams.old_string matches currentContent directly', () => { + it('Test 1.1: old_string (no literal \\), new_string (escaped by Gemini) -> new_string unescaped', async () => { + const currentContent = 'This is a test string to find me.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find me', + new_string: 'replace with \\"this\\"', + }; + mockResponses.push({ + corrected_new_string_escaping: 'replace with "this"', + }); + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(result.params.new_string).toBe('replace with "this"'); + expect(result.params.old_string).toBe('find me'); + expect(result.occurrences).toBe(1); + }); + it('Test 1.2: old_string (no literal \\), new_string (correctly formatted) -> new_string unchanged', async () => { + const currentContent = 'This is a test string to find me.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find me', + new_string: 'replace with 
this', + }; + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(0); + expect(result.params.new_string).toBe('replace with this'); + expect(result.params.old_string).toBe('find me'); + expect(result.occurrences).toBe(1); + }); + it('Test 1.3: old_string (with literal \\), new_string (escaped by Gemini) -> new_string unchanged (still escaped)', async () => { + const currentContent = 'This is a test string to find\\me.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find\\me', + new_string: 'replace with \\"this\\"', + }; + mockResponses.push({ + corrected_new_string_escaping: 'replace with "this"', + }); + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(result.params.new_string).toBe('replace with "this"'); + expect(result.params.old_string).toBe('find\\me'); + expect(result.occurrences).toBe(1); + }); + it('Test 1.4: old_string (with literal \\), new_string (correctly formatted) -> new_string unchanged', async () => { + const currentContent = 'This is a test string to find\\me.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find\\me', + new_string: 'replace with this', + }; + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(0); + expect(result.params.new_string).toBe('replace with this'); + expect(result.params.old_string).toBe('find\\me'); + expect(result.occurrences).toBe(1); + }); + }); + + describe('Scenario Group 2: originalParams.old_string does NOT match, but unescapeStringForGeminiBug(originalParams.old_string) DOES match', () => { + it('Test 2.1: old_string (over-escaped, no intended literal \\), new_string (escaped by 
Gemini) -> new_string unescaped', async () => { + const currentContent = 'This is a test string to find "me".'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find \\"me\\"', + new_string: 'replace with \\"this\\"', + }; + mockResponses.push({ corrected_new_string: 'replace with "this"' }); + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(result.params.new_string).toBe('replace with "this"'); + expect(result.params.old_string).toBe('find "me"'); + expect(result.occurrences).toBe(1); + }); + it('Test 2.2: old_string (over-escaped, no intended literal \\), new_string (correctly formatted) -> new_string unescaped (harmlessly)', async () => { + const currentContent = 'This is a test string to find "me".'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find \\"me\\"', + new_string: 'replace with this', + }; + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(0); + expect(result.params.new_string).toBe('replace with this'); + expect(result.params.old_string).toBe('find "me"'); + expect(result.occurrences).toBe(1); + }); + it('Test 2.3: old_string (over-escaped, with intended literal \\), new_string (simple) -> new_string corrected', async () => { + const currentContent = 'This is a test string to find \\me.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find \\\\me', + new_string: 'replace with foobar', + }; + mockResponses.push({ + corrected_target_snippet: 'find \\me', + }); + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(result.params.new_string).toBe('replace with foobar'); + 
expect(result.params.old_string).toBe('find \\me'); + expect(result.occurrences).toBe(1); + }); + }); + + describe('Scenario Group 3: LLM Correction Path', () => { + it('Test 3.1: old_string (no literal \\), new_string (escaped by Gemini), LLM re-escapes new_string -> final new_string is double unescaped', async () => { + const currentContent = 'This is a test string to corrected find me.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find me', + new_string: 'replace with \\\\"this\\\\"', + }; + const llmNewString = 'LLM says replace with "that"'; + mockResponses.push({ corrected_new_string_escaping: llmNewString }); + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(result.params.new_string).toBe(llmNewString); + expect(result.params.old_string).toBe('find me'); + expect(result.occurrences).toBe(1); + }); + it('Test 3.2: old_string (with literal \\), new_string (escaped by Gemini), LLM re-escapes new_string -> final new_string is unescaped once', async () => { + const currentContent = 'This is a test string to corrected find me.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find\\me', + new_string: 'replace with \\\\"this\\\\"', + }; + const llmCorrectedOldString = 'corrected find me'; + const llmNewString = 'LLM says replace with "that"'; + mockResponses.push({ corrected_target_snippet: llmCorrectedOldString }); + mockResponses.push({ corrected_new_string: llmNewString }); + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(2); + expect(result.params.new_string).toBe(llmNewString); + expect(result.params.old_string).toBe(llmCorrectedOldString); + expect(result.occurrences).toBe(1); + }); + it('Test 3.3: old_string needs LLM, new_string is fine -> 
old_string corrected, new_string original', async () => { + const currentContent = 'This is a test string to be corrected.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'fiiind me', + new_string: 'replace with "this"', + }; + const llmCorrectedOldString = 'to be corrected'; + mockResponses.push({ corrected_target_snippet: llmCorrectedOldString }); + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(result.params.new_string).toBe('replace with "this"'); + expect(result.params.old_string).toBe(llmCorrectedOldString); + expect(result.occurrences).toBe(1); + }); + it('Test 3.4: LLM correction path, correctNewString returns the originalNewString it was passed (which was unescaped) -> final new_string is unescaped', async () => { + const currentContent = 'This is a test string to corrected find me.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find me', + new_string: 'replace with \\\\"this\\\\"', + }; + const newStringForLLMAndReturnedByLLM = 'replace with "this"'; + mockResponses.push({ + corrected_new_string_escaping: newStringForLLMAndReturnedByLLM, + }); + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM); + expect(result.occurrences).toBe(1); + }); + }); + + describe('Scenario Group 4: No Match Found / Multiple Matches', () => { + it('Test 4.1: No version of old_string (original, unescaped, LLM-corrected) matches -> returns original params, 0 occurrences', async () => { + const currentContent = 'This content has nothing to find.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'nonexistent string', + new_string: 'some new string', + }; + 
mockResponses.push({ corrected_target_snippet: 'still nonexistent' }); + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(result.params).toEqual(originalParams); + expect(result.occurrences).toBe(0); + }); + it('Test 4.2: unescapedOldStringAttempt results in >1 occurrences -> returns original params, count occurrences', async () => { + const currentContent = + 'This content has find "me" and also find "me" again.'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'find "me"', + new_string: 'some new string', + }; + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(0); + expect(result.params).toEqual(originalParams); + expect(result.occurrences).toBe(2); + }); + }); + + describe('Scenario Group 5: Specific unescapeStringForGeminiBug checks (integrated into ensureCorrectEdit)', () => { + it('Test 5.1: old_string needs LLM to become currentContent, new_string also needs correction', async () => { + const currentContent = 'const x = "a\\nbc\\\\"def\\\\"'; + const originalParams = { + file_path: '/test/file.txt', + old_string: 'const x = \\\\"a\\\\nbc\\\\\\\\"def\\\\\\\\"', + new_string: 'const y = \\\\"new\\\\nval\\\\\\\\"content\\\\\\\\"', + }; + const expectedFinalNewString = 'const y = "new\\nval\\\\"content\\\\"'; + mockResponses.push({ corrected_target_snippet: currentContent }); + mockResponses.push({ corrected_new_string: expectedFinalNewString }); + const result = await ensureCorrectEdit( + currentContent, + originalParams, + mockGeminiClientInstance, + abortSignal, + ); + expect(mockGenerateJson).toHaveBeenCalledTimes(2); + expect(result.params.old_string).toBe(currentContent); + expect(result.params.new_string).toBe(expectedFinalNewString); + expect(result.occurrences).toBe(1); + 
}); + }); + }); +}); diff --git a/packages/core/src/utils/editCorrector.ts b/packages/core/src/utils/editCorrector.ts new file mode 100644 index 00000000..78663954 --- /dev/null +++ b/packages/core/src/utils/editCorrector.ts @@ -0,0 +1,593 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Content, + GenerateContentConfig, + SchemaUnion, + Type, +} from '@google/genai'; +import { GeminiClient } from '../core/client.js'; +import { EditToolParams } from '../tools/edit.js'; +import { LruCache } from './LruCache.js'; + +const EditModel = 'gemini-2.5-flash-preview-04-17'; +const EditConfig: GenerateContentConfig = { + thinkingConfig: { + thinkingBudget: 0, + }, +}; + +const MAX_CACHE_SIZE = 50; + +// Cache for ensureCorrectEdit results +const editCorrectionCache = new LruCache<string, CorrectedEditResult>( + MAX_CACHE_SIZE, +); + +// Cache for ensureCorrectFileContent results +const fileContentCorrectionCache = new LruCache<string, string>(MAX_CACHE_SIZE); + +/** + * Defines the structure of the parameters within CorrectedEditResult + */ +interface CorrectedEditParams { + file_path: string; + old_string: string; + new_string: string; +} + +/** + * Defines the result structure for ensureCorrectEdit. + */ +export interface CorrectedEditResult { + params: CorrectedEditParams; + occurrences: number; +} + +/** + * Attempts to correct edit parameters if the original old_string is not found. + * It tries unescaping, and then LLM-based correction. + * Results are cached to avoid redundant processing. + * + * @param currentContent The current content of the file. + * @param originalParams The original EditToolParams + * @param client The GeminiClient for LLM calls. + * @returns A promise resolving to an object containing the (potentially corrected) + * EditToolParams (as CorrectedEditParams) and the final occurrences count. 
+ */ +export async function ensureCorrectEdit( + currentContent: string, + originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\' + client: GeminiClient, + abortSignal: AbortSignal, +): Promise<CorrectedEditResult> { + const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`; + const cachedResult = editCorrectionCache.get(cacheKey); + if (cachedResult) { + return cachedResult; + } + + let finalNewString = originalParams.new_string; + const newStringPotentiallyEscaped = + unescapeStringForGeminiBug(originalParams.new_string) !== + originalParams.new_string; + + let finalOldString = originalParams.old_string; + let occurrences = countOccurrences(currentContent, finalOldString); + + if (occurrences === 1) { + if (newStringPotentiallyEscaped) { + finalNewString = await correctNewStringEscaping( + client, + finalOldString, + originalParams.new_string, + abortSignal, + ); + } + } else if (occurrences > 1) { + const result: CorrectedEditResult = { + params: { ...originalParams }, + occurrences, + }; + editCorrectionCache.set(cacheKey, result); + return result; + } else { + // occurrences is 0 or some other unexpected state initially + const unescapedOldStringAttempt = unescapeStringForGeminiBug( + originalParams.old_string, + ); + occurrences = countOccurrences(currentContent, unescapedOldStringAttempt); + + if (occurrences === 1) { + finalOldString = unescapedOldStringAttempt; + if (newStringPotentiallyEscaped) { + finalNewString = await correctNewString( + client, + originalParams.old_string, // original old + unescapedOldStringAttempt, // corrected old + originalParams.new_string, // original new (which is potentially escaped) + abortSignal, + ); + } + } else if (occurrences === 0) { + const llmCorrectedOldString = await correctOldStringMismatch( + client, + currentContent, + unescapedOldStringAttempt, + abortSignal, + ); + const llmOldOccurrences = countOccurrences( + currentContent, 
+ llmCorrectedOldString, + ); + + if (llmOldOccurrences === 1) { + finalOldString = llmCorrectedOldString; + occurrences = llmOldOccurrences; + + if (newStringPotentiallyEscaped) { + const baseNewStringForLLMCorrection = unescapeStringForGeminiBug( + originalParams.new_string, + ); + finalNewString = await correctNewString( + client, + originalParams.old_string, // original old + llmCorrectedOldString, // corrected old + baseNewStringForLLMCorrection, // base new for correction + abortSignal, + ); + } + } else { + // LLM correction also failed for old_string + const result: CorrectedEditResult = { + params: { ...originalParams }, + occurrences: 0, // Explicitly 0 as LLM failed + }; + editCorrectionCache.set(cacheKey, result); + return result; + } + } else { + // Unescaping old_string resulted in > 1 occurrences + const result: CorrectedEditResult = { + params: { ...originalParams }, + occurrences, // This will be > 1 + }; + editCorrectionCache.set(cacheKey, result); + return result; + } + } + + const { targetString, pair } = trimPairIfPossible( + finalOldString, + finalNewString, + currentContent, + ); + finalOldString = targetString; + finalNewString = pair; + + // Final result construction + const result: CorrectedEditResult = { + params: { + file_path: originalParams.file_path, + old_string: finalOldString, + new_string: finalNewString, + }, + occurrences: countOccurrences(currentContent, finalOldString), // Recalculate occurrences with the final old_string + }; + editCorrectionCache.set(cacheKey, result); + return result; +} + +export async function ensureCorrectFileContent( + content: string, + client: GeminiClient, + abortSignal: AbortSignal, +): Promise<string> { + const cachedResult = fileContentCorrectionCache.get(content); + if (cachedResult) { + return cachedResult; + } + + const contentPotentiallyEscaped = + unescapeStringForGeminiBug(content) !== content; + if (!contentPotentiallyEscaped) { + fileContentCorrectionCache.set(content, content); + return 
content; + } + + const correctedContent = await correctStringEscaping( + content, + client, + abortSignal, + ); + fileContentCorrectionCache.set(content, correctedContent); + return correctedContent; +} + +// Define the expected JSON schema for the LLM response for old_string correction +const OLD_STRING_CORRECTION_SCHEMA: SchemaUnion = { + type: Type.OBJECT, + properties: { + corrected_target_snippet: { + type: Type.STRING, + description: + 'The corrected version of the target snippet that exactly and uniquely matches a segment within the provided file content.', + }, + }, + required: ['corrected_target_snippet'], +}; + +export async function correctOldStringMismatch( + geminiClient: GeminiClient, + fileContent: string, + problematicSnippet: string, + abortSignal: AbortSignal, +): Promise<string> { + const prompt = ` +Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped. + +Task: Analyze the provided file content and the problematic target snippet. Identify the segment in the file content that the snippet was *most likely* intended to match. Output the *exact*, literal text of that segment from the file content. Focus *only* on removing extra escape characters and correcting formatting, whitespace, or minor differences to achieve a PERFECT literal match. The output must be the exact literal text as it appears in the file. 
+ +Problematic target snippet: +\`\`\` +${problematicSnippet} +\`\`\` + +File Content: +\`\`\` +${fileContent} +\`\`\` + +For example, if the problematic target snippet was "\\\\\\nconst greeting = \`Hello \\\\\`\${name}\\\\\`\`;" and the file content had content that looked like "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;", then corrected_target_snippet should likely be "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;" to fix the incorrect escaping to match the original file content. +If the differences are only in whitespace or formatting, apply similar whitespace/formatting changes to the corrected_target_snippet. + +Return ONLY the corrected target snippet in the specified JSON format with the key 'corrected_target_snippet'. If no clear, unique match can be found, return an empty string for 'corrected_target_snippet'. +`.trim(); + + const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }]; + + try { + const result = await geminiClient.generateJson( + contents, + OLD_STRING_CORRECTION_SCHEMA, + abortSignal, + EditModel, + EditConfig, + ); + + if ( + result && + typeof result.corrected_target_snippet === 'string' && + result.corrected_target_snippet.length > 0 + ) { + return result.corrected_target_snippet; + } else { + return problematicSnippet; + } + } catch (error) { + if (abortSignal.aborted) { + throw error; + } + + console.error( + 'Error during LLM call for old string snippet correction:', + error, + ); + + return problematicSnippet; + } +} + +// Define the expected JSON schema for the new_string correction LLM response +const NEW_STRING_CORRECTION_SCHEMA: SchemaUnion = { + type: Type.OBJECT, + properties: { + corrected_new_string: { + type: Type.STRING, + description: + 'The original_new_string adjusted to be a suitable replacement for the corrected_old_string, while maintaining the original intent of the change.', + }, + }, + required: ['corrected_new_string'], +}; + +/** + * Adjusts the new_string to align with a 
corrected old_string, maintaining the original intent. + */ +export async function correctNewString( + geminiClient: GeminiClient, + originalOldString: string, + correctedOldString: string, + originalNewString: string, + abortSignal: AbortSignal, +): Promise<string> { + if (originalOldString === correctedOldString) { + return originalNewString; + } + + const prompt = ` +Context: A text replacement operation was planned. The original text to be replaced (original_old_string) was slightly different from the actual text in the file (corrected_old_string). The original_old_string has now been corrected to match the file content. +We now need to adjust the replacement text (original_new_string) so that it makes sense as a replacement for the corrected_old_string, while preserving the original intent of the change. + +original_old_string (what was initially intended to be found): +\`\`\` +${originalOldString} +\`\`\` + +corrected_old_string (what was actually found in the file and will be replaced): +\`\`\` +${correctedOldString} +\`\`\` + +original_new_string (what was intended to replace original_old_string): +\`\`\` +${originalNewString} +\`\`\` + +Task: Based on the differences between original_old_string and corrected_old_string, and the content of original_new_string, generate a corrected_new_string. This corrected_new_string should be what original_new_string would have been if it was designed to replace corrected_old_string directly, while maintaining the spirit of the original transformation. + +For example, if original_old_string was "\\\\\\nconst greeting = \`Hello \\\\\`\${name}\\\\\`\`;" and corrected_old_string is "\nconst greeting = \`Hello ${'\\`'}\${name}${'\\`'}\`;", and original_new_string was "\\\\\\nconst greeting = \`Hello \\\\\`\${name} \${lastName}\\\\\`\`;", then corrected_new_string should likely be "\nconst greeting = \`Hello ${'\\`'}\${name} \${lastName}${'\\`'}\`;" to fix the incorrect escaping. 
+If the differences are only in whitespace or formatting, apply similar whitespace/formatting changes to the corrected_new_string. + +Return ONLY the corrected string in the specified JSON format with the key 'corrected_new_string'. If no adjustment is deemed necessary or possible, return the original_new_string. + `.trim(); + + const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }]; + + try { + const result = await geminiClient.generateJson( + contents, + NEW_STRING_CORRECTION_SCHEMA, + abortSignal, + EditModel, + EditConfig, + ); + + if ( + result && + typeof result.corrected_new_string === 'string' && + result.corrected_new_string.length > 0 + ) { + return result.corrected_new_string; + } else { + return originalNewString; + } + } catch (error) { + if (abortSignal.aborted) { + throw error; + } + + console.error('Error during LLM call for new_string correction:', error); + return originalNewString; + } +} + +const CORRECT_NEW_STRING_ESCAPING_SCHEMA: SchemaUnion = { + type: Type.OBJECT, + properties: { + corrected_new_string_escaping: { + type: Type.STRING, + description: + 'The new_string with corrected escaping, ensuring it is a proper replacement for the old_string, especially considering potential over-escaping issues from previous LLM generations.', + }, + }, + required: ['corrected_new_string_escaping'], +}; + +export async function correctNewStringEscaping( + geminiClient: GeminiClient, + oldString: string, + potentiallyProblematicNewString: string, + abortSignal: AbortSignal, +): Promise<string> { + const prompt = ` +Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello"). 
+ +old_string (this is the exact text that will be replaced): +\`\`\` +${oldString} +\`\`\` + +potentially_problematic_new_string (this is the text that should replace old_string, but MIGHT have bad escaping, or might be entirely correct): +\`\`\` +${potentiallyProblematicNewString} +\`\`\` + +Task: Analyze the potentially_problematic_new_string. If it's syntactically invalid due to incorrect escaping (e.g., "\n", "\t", "\\", "\\'", "\\""), correct the invalid syntax. The goal is to ensure the new_string, when inserted into the code, will be a valid and correctly interpreted. + +For example, if old_string is "foo" and potentially_problematic_new_string is "bar\\nbaz", the corrected_new_string_escaping should be "bar\nbaz". +If potentially_problematic_new_string is console.log(\\"Hello World\\"), it should be console.log("Hello World"). + +Return ONLY the corrected string in the specified JSON format with the key 'corrected_new_string_escaping'. If no escaping correction is needed, return the original potentially_problematic_new_string. 
+ `.trim(); + + const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }]; + + try { + const result = await geminiClient.generateJson( + contents, + CORRECT_NEW_STRING_ESCAPING_SCHEMA, + abortSignal, + EditModel, + EditConfig, + ); + + if ( + result && + typeof result.corrected_new_string_escaping === 'string' && + result.corrected_new_string_escaping.length > 0 + ) { + return result.corrected_new_string_escaping; + } else { + return potentiallyProblematicNewString; + } + } catch (error) { + if (abortSignal.aborted) { + throw error; + } + + console.error( + 'Error during LLM call for new_string escaping correction:', + error, + ); + return potentiallyProblematicNewString; + } +} + +const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = { + type: Type.OBJECT, + properties: { + corrected_string_escaping: { + type: Type.STRING, + description: + 'The string with corrected escaping, ensuring it is valid, specially considering potential over-escaping issues from previous LLM generations.', + }, + }, + required: ['corrected_string_escaping'], +}; + +export async function correctStringEscaping( + potentiallyProblematicString: string, + client: GeminiClient, + abortSignal: AbortSignal, +): Promise<string> { + const prompt = ` +Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello"). + +potentially_problematic_string (this text MIGHT have bad escaping, or might be entirely correct): +\`\`\` +${potentiallyProblematicString} +\`\`\` + +Task: Analyze the potentially_problematic_string. If it's syntactically invalid due to incorrect escaping (e.g., "\n", "\t", "\\", "\\'", "\\""), correct the invalid syntax. The goal is to ensure the text will be a valid and correctly interpreted. 
+ +For example, if potentially_problematic_string is "bar\\nbaz", the corrected_new_string_escaping should be "bar\nbaz". +If potentially_problematic_string is console.log(\\"Hello World\\"), it should be console.log("Hello World"). + +Return ONLY the corrected string in the specified JSON format with the key 'corrected_string_escaping'. If no escaping correction is needed, return the original potentially_problematic_string. + `.trim(); + + const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }]; + + try { + const result = await client.generateJson( + contents, + CORRECT_STRING_ESCAPING_SCHEMA, + abortSignal, + EditModel, + EditConfig, + ); + + if ( + result && + typeof result.corrected_new_string_escaping === 'string' && + result.corrected_new_string_escaping.length > 0 + ) { + return result.corrected_new_string_escaping; + } else { + return potentiallyProblematicString; + } + } catch (error) { + if (abortSignal.aborted) { + throw error; + } + + console.error( + 'Error during LLM call for string escaping correction:', + error, + ); + return potentiallyProblematicString; + } +} + +function trimPairIfPossible( + target: string, + trimIfTargetTrims: string, + currentContent: string, +) { + const trimmedTargetString = target.trim(); + if (target.length !== trimmedTargetString.length) { + const trimmedTargetOccurrences = countOccurrences( + currentContent, + trimmedTargetString, + ); + + if (trimmedTargetOccurrences === 1) { + const trimmedReactiveString = trimIfTargetTrims.trim(); + return { + targetString: trimmedTargetString, + pair: trimmedReactiveString, + }; + } + } + + return { + targetString: target, + pair: trimIfTargetTrims, + }; +} + +/** + * Unescapes a string that might have been overly escaped by an LLM. + */ +export function unescapeStringForGeminiBug(inputString: string): string { + // Regex explanation: + // \\+ : Matches one or more literal backslash characters. + // (n|t|r|'|"|`|\n) : This is a capturing group. 
It matches one of the following: + // n, t, r, ', ", ` : These match the literal characters 'n', 't', 'r', single quote, double quote, or backtick. + // This handles cases like "\\n", "\\\\`", etc. + // \n : This matches an actual newline character. This handles cases where the input + // string might have something like "\\\n" (a literal backslash followed by a newline). + // g : Global flag, to replace all occurrences. + + return inputString.replace(/\\+(n|t|r|'|"|`|\n)/g, (match, capturedChar) => { + // 'match' is the entire erroneous sequence, e.g., if the input (in memory) was "\\\\`", match is "\\\\`". + // 'capturedChar' is the character that determines the true meaning, e.g., '`'. + + switch (capturedChar) { + case 'n': + return '\n'; // Correctly escaped: \n (newline character) + case 't': + return '\t'; // Correctly escaped: \t (tab character) + case 'r': + return '\r'; // Correctly escaped: \r (carriage return character) + case "'": + return "'"; // Correctly escaped: ' (apostrophe character) + case '"': + return '"'; // Correctly escaped: " (quotation mark character) + case '`': + return '`'; // Correctly escaped: ` (backtick character) + case '\n': // This handles when 'capturedChar' is an actual newline + return '\n'; // Replace the whole erroneous sequence (e.g., "\\\n" in memory) with a clean newline + default: + // This fallback should ideally not be reached if the regex captures correctly. + // It would return the original matched sequence if an unexpected character was captured. 
+ return match; + } + }); +} + +/** + * Counts occurrences of a substring in a string + */ +export function countOccurrences(str: string, substr: string): number { + if (substr === '') { + return 0; + } + let count = 0; + let pos = str.indexOf(substr); + while (pos !== -1) { + count++; + pos = str.indexOf(substr, pos + substr.length); // Start search after the current match + } + return count; +} + +export function resetEditCorrectorCaches_TEST_ONLY() { + editCorrectionCache.clear(); + fileContentCorrectionCache.clear(); +} diff --git a/packages/core/src/utils/errorReporting.test.ts b/packages/core/src/utils/errorReporting.test.ts new file mode 100644 index 00000000..1faba5f6 --- /dev/null +++ b/packages/core/src/utils/errorReporting.test.ts @@ -0,0 +1,220 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest'; + +// Use a type alias for SpyInstance as it's not directly exported +type SpyInstance = ReturnType<typeof vi.spyOn>; +import { reportError } from './errorReporting.js'; +import fs from 'node:fs/promises'; +import os from 'node:os'; + +// Mock dependencies +vi.mock('node:fs/promises'); +vi.mock('node:os'); + +describe('reportError', () => { + let consoleErrorSpy: SpyInstance; + const MOCK_TMP_DIR = '/tmp'; + const MOCK_TIMESTAMP = '2025-01-01T00-00-00-000Z'; + + beforeEach(() => { + vi.resetAllMocks(); + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + (os.tmpdir as Mock).mockReturnValue(MOCK_TMP_DIR); + vi.spyOn(Date.prototype, 'toISOString').mockReturnValue(MOCK_TIMESTAMP); + }); + + afterEach(() => { + consoleErrorSpy.mockRestore(); + vi.restoreAllMocks(); + }); + + const getExpectedReportPath = (type: string) => + `${MOCK_TMP_DIR}/gemini-client-error-${type}-${MOCK_TIMESTAMP}.json`; + + it('should generate a report and log the path', async () => { + const error = new Error('Test error'); + error.stack = 
'Test stack'; + const baseMessage = 'An error occurred.'; + const context = { data: 'test context' }; + const type = 'test-type'; + const expectedReportPath = getExpectedReportPath(type); + + (fs.writeFile as Mock).mockResolvedValue(undefined); + + await reportError(error, baseMessage, context, type); + + expect(os.tmpdir).toHaveBeenCalledTimes(1); + expect(fs.writeFile).toHaveBeenCalledWith( + expectedReportPath, + JSON.stringify( + { + error: { message: 'Test error', stack: error.stack }, + context, + }, + null, + 2, + ), + ); + expect(consoleErrorSpy).toHaveBeenCalledWith( + `${baseMessage} Full report available at: ${expectedReportPath}`, + ); + }); + + it('should handle errors that are plain objects with a message property', async () => { + const error = { message: 'Test plain object error' }; + const baseMessage = 'Another error.'; + const type = 'general'; + const expectedReportPath = getExpectedReportPath(type); + + (fs.writeFile as Mock).mockResolvedValue(undefined); + await reportError(error, baseMessage); + + expect(fs.writeFile).toHaveBeenCalledWith( + expectedReportPath, + JSON.stringify( + { + error: { message: 'Test plain object error' }, + }, + null, + 2, + ), + ); + expect(consoleErrorSpy).toHaveBeenCalledWith( + `${baseMessage} Full report available at: ${expectedReportPath}`, + ); + }); + + it('should handle string errors', async () => { + const error = 'Just a string error'; + const baseMessage = 'String error occurred.'; + const type = 'general'; + const expectedReportPath = getExpectedReportPath(type); + + (fs.writeFile as Mock).mockResolvedValue(undefined); + await reportError(error, baseMessage); + + expect(fs.writeFile).toHaveBeenCalledWith( + expectedReportPath, + JSON.stringify( + { + error: { message: 'Just a string error' }, + }, + null, + 2, + ), + ); + expect(consoleErrorSpy).toHaveBeenCalledWith( + `${baseMessage} Full report available at: ${expectedReportPath}`, + ); + }); + + it('should log fallback message if writing report 
fails', async () => { + const error = new Error('Main error'); + const baseMessage = 'Failed operation.'; + const writeError = new Error('Failed to write file'); + const context = ['some context']; + const type = 'general'; + const expectedReportPath = getExpectedReportPath(type); + + (fs.writeFile as Mock).mockRejectedValue(writeError); + + await reportError(error, baseMessage, context, type); + + expect(fs.writeFile).toHaveBeenCalledWith( + expectedReportPath, + expect.any(String), + ); // It still tries to write + expect(consoleErrorSpy).toHaveBeenCalledWith( + `${baseMessage} Additionally, failed to write detailed error report:`, + writeError, + ); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Original error that triggered report generation:', + error, + ); + expect(consoleErrorSpy).toHaveBeenCalledWith('Original context:', context); + }); + + it('should handle stringification failure of report content (e.g. BigInt in context)', async () => { + const error = new Error('Main error'); + error.stack = 'Main stack'; + const baseMessage = 'Failed operation with BigInt.'; + const context = { a: BigInt(1) }; // BigInt cannot be stringified by JSON.stringify + const type = 'bigint-fail'; + const stringifyError = new TypeError( + 'Do not know how to serialize a BigInt', + ); + const expectedMinimalReportPath = getExpectedReportPath(type); + + // Simulate JSON.stringify throwing an error for the full report + const originalJsonStringify = JSON.stringify; + let callCount = 0; + vi.spyOn(JSON, 'stringify').mockImplementation((value, replacer, space) => { + callCount++; + if (callCount === 1) { + // First call is for the full report content + throw stringifyError; + } + // Subsequent calls (for minimal report) should succeed + return originalJsonStringify(value, replacer, space); + }); + + (fs.writeFile as Mock).mockResolvedValue(undefined); // Mock for the minimal report write + + await reportError(error, baseMessage, context, type); + + 
expect(consoleErrorSpy).toHaveBeenCalledWith( + `${baseMessage} Could not stringify report content (likely due to context):`, + stringifyError, + ); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Original error that triggered report generation:', + error, + ); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Original context could not be stringified or included in report.', + ); + // Check that it attempts to write a minimal report + expect(fs.writeFile).toHaveBeenCalledWith( + expectedMinimalReportPath, + originalJsonStringify( + { error: { message: error.message, stack: error.stack } }, + null, + 2, + ), + ); + expect(consoleErrorSpy).toHaveBeenCalledWith( + `${baseMessage} Partial report (excluding context) available at: ${expectedMinimalReportPath}`, + ); + }); + + it('should generate a report without context if context is not provided', async () => { + const error = new Error('Error without context'); + error.stack = 'No context stack'; + const baseMessage = 'Simple error.'; + const type = 'general'; + const expectedReportPath = getExpectedReportPath(type); + + (fs.writeFile as Mock).mockResolvedValue(undefined); + await reportError(error, baseMessage, undefined, type); + + expect(fs.writeFile).toHaveBeenCalledWith( + expectedReportPath, + JSON.stringify( + { + error: { message: 'Error without context', stack: error.stack }, + }, + null, + 2, + ), + ); + expect(consoleErrorSpy).toHaveBeenCalledWith( + `${baseMessage} Full report available at: ${expectedReportPath}`, + ); + }); +}); diff --git a/packages/core/src/utils/errorReporting.ts b/packages/core/src/utils/errorReporting.ts new file mode 100644 index 00000000..41ce3468 --- /dev/null +++ b/packages/core/src/utils/errorReporting.ts @@ -0,0 +1,117 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import fs from 'node:fs/promises'; +import os from 'node:os'; +import path from 'node:path'; +import { Content } from '@google/genai'; + +interface 
ErrorReportData { + error: { message: string; stack?: string } | { message: string }; + context?: unknown; + additionalInfo?: Record<string, unknown>; +} + +/** + * Generates an error report, writes it to a temporary file, and logs information to console.error. + * @param error The error object. + * @param context The relevant context (e.g., chat history, request contents). + * @param type A string to identify the type of error (e.g., 'startChat', 'generateJson-api'). + * @param baseMessage The initial message to log to console.error before the report path. + */ +export async function reportError( + error: Error | unknown, + baseMessage: string, + context?: Content[] | Record<string, unknown> | unknown[], + type = 'general', +): Promise<void> { + const timestamp = new Date().toISOString().replace(/[:.]/g, '-'); + const reportFileName = `gemini-client-error-${type}-${timestamp}.json`; + const reportPath = path.join(os.tmpdir(), reportFileName); + + let errorToReport: { message: string; stack?: string }; + if (error instanceof Error) { + errorToReport = { message: error.message, stack: error.stack }; + } else if ( + typeof error === 'object' && + error !== null && + 'message' in error + ) { + errorToReport = { + message: String((error as { message: unknown }).message), + }; + } else { + errorToReport = { message: String(error) }; + } + + const reportContent: ErrorReportData = { error: errorToReport }; + + if (context) { + reportContent.context = context; + } + + let stringifiedReportContent: string; + try { + stringifiedReportContent = JSON.stringify(reportContent, null, 2); + } catch (stringifyError) { + // This can happen if context contains something like BigInt + console.error( + `${baseMessage} Could not stringify report content (likely due to context):`, + stringifyError, + ); + console.error('Original error that triggered report generation:', error); + if (context) { + console.error( + 'Original context could not be stringified or included in report.', + ); + 
} + // Fallback: try to report only the error if context was the issue + try { + const minimalReportContent = { error: errorToReport }; + stringifiedReportContent = JSON.stringify(minimalReportContent, null, 2); + // Still try to write the minimal report + await fs.writeFile(reportPath, stringifiedReportContent); + console.error( + `${baseMessage} Partial report (excluding context) available at: ${reportPath}`, + ); + } catch (minimalWriteError) { + console.error( + `${baseMessage} Failed to write even a minimal error report:`, + minimalWriteError, + ); + } + return; + } + + try { + await fs.writeFile(reportPath, stringifiedReportContent); + console.error(`${baseMessage} Full report available at: ${reportPath}`); + } catch (writeError) { + console.error( + `${baseMessage} Additionally, failed to write detailed error report:`, + writeError, + ); + // Log the original error as a fallback if report writing fails + console.error('Original error that triggered report generation:', error); + if (context) { + // Context was stringifiable, but writing the file failed. + // We already have stringifiedReportContent, but it might be too large for console. + // So, we try to log the original context object, and if that fails, its stringified version (truncated). 
+ try { + console.error('Original context:', context); + } catch { + try { + console.error( + 'Original context (stringified, truncated):', + JSON.stringify(context).substring(0, 1000), + ); + } catch { + console.error('Original context could not be logged or stringified.'); + } + } + } + } +} diff --git a/packages/core/src/utils/errors.ts b/packages/core/src/utils/errors.ts new file mode 100644 index 00000000..32139c1a --- /dev/null +++ b/packages/core/src/utils/errors.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export function isNodeError(error: unknown): error is NodeJS.ErrnoException { + return error instanceof Error && 'code' in error; +} + +export function getErrorMessage(error: unknown): string { + if (error instanceof Error) { + return error.message; + } else { + try { + const errorMessage = String(error); + return errorMessage; + } catch { + return 'Failed to get error details'; + } + } +} diff --git a/packages/core/src/utils/fileUtils.test.ts b/packages/core/src/utils/fileUtils.test.ts new file mode 100644 index 00000000..4f4c7c1e --- /dev/null +++ b/packages/core/src/utils/fileUtils.test.ts @@ -0,0 +1,431 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; + +import * as actualNodeFs from 'node:fs'; // For setup/teardown +import fsPromises from 'node:fs/promises'; +import path from 'node:path'; +import os from 'node:os'; +import mime from 'mime-types'; + +import { + isWithinRoot, + isBinaryFile, + detectFileType, + processSingleFileContent, +} from './fileUtils.js'; + +vi.mock('mime-types', () => ({ + default: { lookup: vi.fn() }, + lookup: vi.fn(), +})); + +const mockMimeLookup = mime.lookup as Mock; + +describe('fileUtils', () => { + let tempRootDir: string; + const originalProcessCwd = process.cwd; + + let testTextFilePath: string; + 
let testImageFilePath: string; + let testPdfFilePath: string; + let testBinaryFilePath: string; + let nonExistentFilePath: string; + let directoryPath: string; + + beforeEach(() => { + vi.resetAllMocks(); // Reset all mocks, including mime.lookup + + tempRootDir = actualNodeFs.mkdtempSync( + path.join(os.tmpdir(), 'fileUtils-test-'), + ); + process.cwd = vi.fn(() => tempRootDir); // Mock cwd if necessary for relative path logic within tests + + testTextFilePath = path.join(tempRootDir, 'test.txt'); + testImageFilePath = path.join(tempRootDir, 'image.png'); + testPdfFilePath = path.join(tempRootDir, 'document.pdf'); + testBinaryFilePath = path.join(tempRootDir, 'app.exe'); + nonExistentFilePath = path.join(tempRootDir, 'notfound.txt'); + directoryPath = path.join(tempRootDir, 'subdir'); + + actualNodeFs.mkdirSync(directoryPath, { recursive: true }); // Ensure subdir exists + }); + + afterEach(() => { + if (actualNodeFs.existsSync(tempRootDir)) { + actualNodeFs.rmSync(tempRootDir, { recursive: true, force: true }); + } + process.cwd = originalProcessCwd; + vi.restoreAllMocks(); // Restore any spies + }); + + describe('isWithinRoot', () => { + const root = path.resolve('/project/root'); + + it('should return true for paths directly within the root', () => { + expect(isWithinRoot(path.join(root, 'file.txt'), root)).toBe(true); + expect(isWithinRoot(path.join(root, 'subdir', 'file.txt'), root)).toBe( + true, + ); + }); + + it('should return true for the root path itself', () => { + expect(isWithinRoot(root, root)).toBe(true); + }); + + it('should return false for paths outside the root', () => { + expect( + isWithinRoot(path.resolve('/project/other', 'file.txt'), root), + ).toBe(false); + expect(isWithinRoot(path.resolve('/unrelated', 'file.txt'), root)).toBe( + false, + ); + }); + + it('should return false for paths that only partially match the root prefix', () => { + expect( + isWithinRoot( + path.resolve('/project/root-but-actually-different'), + root, + ), + 
).toBe(false); + }); + + it('should handle paths with trailing slashes correctly', () => { + expect(isWithinRoot(path.join(root, 'file.txt') + path.sep, root)).toBe( + true, + ); + expect(isWithinRoot(root + path.sep, root)).toBe(true); + }); + + it('should handle different path separators (POSIX vs Windows)', () => { + const posixRoot = '/project/root'; + const posixPathInside = '/project/root/file.txt'; + const posixPathOutside = '/project/other/file.txt'; + expect(isWithinRoot(posixPathInside, posixRoot)).toBe(true); + expect(isWithinRoot(posixPathOutside, posixRoot)).toBe(false); + }); + + it('should return false for a root path that is a sub-path of the path to check', () => { + const pathToCheck = path.resolve('/project/root/sub'); + const rootSub = path.resolve('/project/root'); + expect(isWithinRoot(pathToCheck, rootSub)).toBe(true); + + const pathToCheckSuper = path.resolve('/project/root'); + const rootSuper = path.resolve('/project/root/sub'); + expect(isWithinRoot(pathToCheckSuper, rootSuper)).toBe(false); + }); + }); + + describe('isBinaryFile', () => { + let filePathForBinaryTest: string; + + beforeEach(() => { + filePathForBinaryTest = path.join(tempRootDir, 'binaryCheck.tmp'); + }); + + afterEach(() => { + if (actualNodeFs.existsSync(filePathForBinaryTest)) { + actualNodeFs.unlinkSync(filePathForBinaryTest); + } + }); + + it('should return false for an empty file', () => { + actualNodeFs.writeFileSync(filePathForBinaryTest, ''); + expect(isBinaryFile(filePathForBinaryTest)).toBe(false); + }); + + it('should return false for a typical text file', () => { + actualNodeFs.writeFileSync( + filePathForBinaryTest, + 'Hello, world!\nThis is a test file with normal text content.', + ); + expect(isBinaryFile(filePathForBinaryTest)).toBe(false); + }); + + it('should return true for a file with many null bytes', () => { + const binaryContent = Buffer.from([ + 0x48, 0x65, 0x00, 0x6c, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x00, + ]); // "He\0llo\0\0\0\0\0" + 
actualNodeFs.writeFileSync(filePathForBinaryTest, binaryContent); + expect(isBinaryFile(filePathForBinaryTest)).toBe(true); + }); + + it('should return true for a file with high percentage of non-printable ASCII', () => { + const binaryContent = Buffer.from([ + 0x41, 0x42, 0x01, 0x02, 0x03, 0x04, 0x05, 0x43, 0x44, 0x06, + ]); // AB\x01\x02\x03\x04\x05CD\x06 + actualNodeFs.writeFileSync(filePathForBinaryTest, binaryContent); + expect(isBinaryFile(filePathForBinaryTest)).toBe(true); + }); + + it('should return false if file access fails (e.g., ENOENT)', () => { + // Ensure the file does not exist + if (actualNodeFs.existsSync(filePathForBinaryTest)) { + actualNodeFs.unlinkSync(filePathForBinaryTest); + } + expect(isBinaryFile(filePathForBinaryTest)).toBe(false); + }); + }); + + describe('detectFileType', () => { + let filePathForDetectTest: string; + + beforeEach(() => { + filePathForDetectTest = path.join(tempRootDir, 'detectType.tmp'); + // Default: create as a text file for isBinaryFile fallback + actualNodeFs.writeFileSync(filePathForDetectTest, 'Plain text content'); + }); + + afterEach(() => { + if (actualNodeFs.existsSync(filePathForDetectTest)) { + actualNodeFs.unlinkSync(filePathForDetectTest); + } + vi.restoreAllMocks(); // Restore spies on actualNodeFs + }); + + it('should detect image type by extension (png)', () => { + mockMimeLookup.mockReturnValueOnce('image/png'); + expect(detectFileType('file.png')).toBe('image'); + }); + + it('should detect image type by extension (jpeg)', () => { + mockMimeLookup.mockReturnValueOnce('image/jpeg'); + expect(detectFileType('file.jpg')).toBe('image'); + }); + + it('should detect pdf type by extension', () => { + mockMimeLookup.mockReturnValueOnce('application/pdf'); + expect(detectFileType('file.pdf')).toBe('pdf'); + }); + + it('should detect known binary extensions as binary (e.g. 
.zip)', () => { + mockMimeLookup.mockReturnValueOnce('application/zip'); + expect(detectFileType('archive.zip')).toBe('binary'); + }); + it('should detect known binary extensions as binary (e.g. .exe)', () => { + mockMimeLookup.mockReturnValueOnce('application/octet-stream'); // Common for .exe + expect(detectFileType('app.exe')).toBe('binary'); + }); + + it('should use isBinaryFile for unknown extensions and detect as binary', () => { + mockMimeLookup.mockReturnValueOnce(false); // Unknown mime type + // Create a file that isBinaryFile will identify as binary + const binaryContent = Buffer.from([ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, + ]); + actualNodeFs.writeFileSync(filePathForDetectTest, binaryContent); + expect(detectFileType(filePathForDetectTest)).toBe('binary'); + }); + + it('should default to text if mime type is unknown and content is not binary', () => { + mockMimeLookup.mockReturnValueOnce(false); // Unknown mime type + // filePathForDetectTest is already a text file by default from beforeEach + expect(detectFileType(filePathForDetectTest)).toBe('text'); + }); + }); + + describe('processSingleFileContent', () => { + beforeEach(() => { + // Ensure files exist for statSync checks before readFile might be mocked + if (actualNodeFs.existsSync(testTextFilePath)) + actualNodeFs.unlinkSync(testTextFilePath); + if (actualNodeFs.existsSync(testImageFilePath)) + actualNodeFs.unlinkSync(testImageFilePath); + if (actualNodeFs.existsSync(testPdfFilePath)) + actualNodeFs.unlinkSync(testPdfFilePath); + if (actualNodeFs.existsSync(testBinaryFilePath)) + actualNodeFs.unlinkSync(testBinaryFilePath); + }); + + it('should read a text file successfully', async () => { + const content = 'Line 1\\nLine 2\\nLine 3'; + actualNodeFs.writeFileSync(testTextFilePath, content); + const result = await processSingleFileContent( + testTextFilePath, + tempRootDir, + ); + expect(result.llmContent).toBe(content); + expect(result.returnDisplay).toBe(''); + 
expect(result.error).toBeUndefined(); + }); + + it('should handle file not found', async () => { + const result = await processSingleFileContent( + nonExistentFilePath, + tempRootDir, + ); + expect(result.error).toContain('File not found'); + expect(result.returnDisplay).toContain('File not found'); + }); + + it('should handle read errors for text files', async () => { + actualNodeFs.writeFileSync(testTextFilePath, 'content'); // File must exist for initial statSync + const readError = new Error('Simulated read error'); + vi.spyOn(fsPromises, 'readFile').mockRejectedValueOnce(readError); + + const result = await processSingleFileContent( + testTextFilePath, + tempRootDir, + ); + expect(result.error).toContain('Simulated read error'); + expect(result.returnDisplay).toContain('Simulated read error'); + }); + + it('should handle read errors for image/pdf files', async () => { + actualNodeFs.writeFileSync(testImageFilePath, 'content'); // File must exist + mockMimeLookup.mockReturnValue('image/png'); + const readError = new Error('Simulated image read error'); + vi.spyOn(fsPromises, 'readFile').mockRejectedValueOnce(readError); + + const result = await processSingleFileContent( + testImageFilePath, + tempRootDir, + ); + expect(result.error).toContain('Simulated image read error'); + expect(result.returnDisplay).toContain('Simulated image read error'); + }); + + it('should process an image file', async () => { + const fakePngData = Buffer.from('fake png data'); + actualNodeFs.writeFileSync(testImageFilePath, fakePngData); + mockMimeLookup.mockReturnValue('image/png'); + const result = await processSingleFileContent( + testImageFilePath, + tempRootDir, + ); + expect( + (result.llmContent as { inlineData: unknown }).inlineData, + ).toBeDefined(); + expect( + (result.llmContent as { inlineData: { mimeType: string } }).inlineData + .mimeType, + ).toBe('image/png'); + expect( + (result.llmContent as { inlineData: { data: string } }).inlineData.data, + 
).toBe(fakePngData.toString('base64')); + expect(result.returnDisplay).toContain('Read image file: image.png'); + }); + + it('should process a PDF file', async () => { + const fakePdfData = Buffer.from('fake pdf data'); + actualNodeFs.writeFileSync(testPdfFilePath, fakePdfData); + mockMimeLookup.mockReturnValue('application/pdf'); + const result = await processSingleFileContent( + testPdfFilePath, + tempRootDir, + ); + expect( + (result.llmContent as { inlineData: unknown }).inlineData, + ).toBeDefined(); + expect( + (result.llmContent as { inlineData: { mimeType: string } }).inlineData + .mimeType, + ).toBe('application/pdf'); + expect( + (result.llmContent as { inlineData: { data: string } }).inlineData.data, + ).toBe(fakePdfData.toString('base64')); + expect(result.returnDisplay).toContain('Read pdf file: document.pdf'); + }); + + it('should skip binary files', async () => { + actualNodeFs.writeFileSync( + testBinaryFilePath, + Buffer.from([0x00, 0x01, 0x02]), + ); + mockMimeLookup.mockReturnValueOnce('application/octet-stream'); + // isBinaryFile will operate on the real file. 
+ + const result = await processSingleFileContent( + testBinaryFilePath, + tempRootDir, + ); + expect(result.llmContent).toContain( + 'Cannot display content of binary file', + ); + expect(result.returnDisplay).toContain('Skipped binary file: app.exe'); + }); + + it('should handle path being a directory', async () => { + const result = await processSingleFileContent(directoryPath, tempRootDir); + expect(result.error).toContain('Path is a directory'); + expect(result.returnDisplay).toContain('Path is a directory'); + }); + + it('should paginate text files correctly (offset and limit)', async () => { + const lines = Array.from({ length: 20 }, (_, i) => `Line ${i + 1}`); + actualNodeFs.writeFileSync(testTextFilePath, lines.join('\n')); + + const result = await processSingleFileContent( + testTextFilePath, + tempRootDir, + 5, + 5, + ); // Read lines 6-10 + const expectedContent = lines.slice(5, 10).join('\n'); + + expect(result.llmContent).toContain(expectedContent); + expect(result.llmContent).toContain( + '[File content truncated: showing lines 6-10 of 20 total lines. 
Use offset/limit parameters to view more.]', + ); + expect(result.returnDisplay).toBe('(truncated)'); + expect(result.isTruncated).toBe(true); + expect(result.originalLineCount).toBe(20); + expect(result.linesShown).toEqual([6, 10]); + }); + + it('should handle limit exceeding file length', async () => { + const lines = ['Line 1', 'Line 2']; + actualNodeFs.writeFileSync(testTextFilePath, lines.join('\n')); + + const result = await processSingleFileContent( + testTextFilePath, + tempRootDir, + 0, + 10, + ); + const expectedContent = lines.join('\n'); + + expect(result.llmContent).toBe(expectedContent); + expect(result.returnDisplay).toBe(''); + expect(result.isTruncated).toBe(false); + expect(result.originalLineCount).toBe(2); + expect(result.linesShown).toEqual([1, 2]); + }); + + it('should truncate long lines in text files', async () => { + const longLine = 'a'.repeat(2500); + actualNodeFs.writeFileSync( + testTextFilePath, + `Short line\n${longLine}\nAnother short line`, + ); + + const result = await processSingleFileContent( + testTextFilePath, + tempRootDir, + ); + + expect(result.llmContent).toContain('Short line'); + expect(result.llmContent).toContain( + longLine.substring(0, 2000) + '... 
[truncated]', + ); + expect(result.llmContent).toContain('Another short line'); + expect(result.llmContent).toContain( + '[File content partially truncated: some lines exceeded maximum length of 2000 characters.]', + ); + expect(result.isTruncated).toBe(true); + }); + }); +}); diff --git a/packages/core/src/utils/fileUtils.ts b/packages/core/src/utils/fileUtils.ts new file mode 100644 index 00000000..d726c053 --- /dev/null +++ b/packages/core/src/utils/fileUtils.ts @@ -0,0 +1,280 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import fs from 'fs'; +import path from 'path'; +import { PartUnion } from '@google/genai'; +import mime from 'mime-types'; + +// Constants for text file processing +const DEFAULT_MAX_LINES_TEXT_FILE = 2000; +const MAX_LINE_LENGTH_TEXT_FILE = 2000; + +// Default values for encoding and separator format +export const DEFAULT_ENCODING: BufferEncoding = 'utf-8'; + +/** + * Checks if a path is within a given root directory. + * @param pathToCheck The absolute path to check. + * @param rootDirectory The absolute root directory. + * @returns True if the path is within the root directory, false otherwise. + */ +export function isWithinRoot( + pathToCheck: string, + rootDirectory: string, +): boolean { + const normalizedPathToCheck = path.normalize(pathToCheck); + const normalizedRootDirectory = path.normalize(rootDirectory); + + // Ensure the rootDirectory path ends with a separator for correct startsWith comparison, + // unless it's the root path itself (e.g., '/' or 'C:\'). + const rootWithSeparator = + normalizedRootDirectory === path.sep || + normalizedRootDirectory.endsWith(path.sep) + ? normalizedRootDirectory + : normalizedRootDirectory + path.sep; + + return ( + normalizedPathToCheck === normalizedRootDirectory || + normalizedPathToCheck.startsWith(rootWithSeparator) + ); +} + +/** + * Determines if a file is likely binary based on content sampling. + * @param filePath Path to the file. 
+ * @returns True if the file appears to be binary. + */ +export function isBinaryFile(filePath: string): boolean { + try { + const fd = fs.openSync(filePath, 'r'); + // Read up to 4KB or file size, whichever is smaller + const fileSize = fs.fstatSync(fd).size; + if (fileSize === 0) { + // Empty file is not considered binary for content checking + fs.closeSync(fd); + return false; + } + const bufferSize = Math.min(4096, fileSize); + const buffer = Buffer.alloc(bufferSize); + const bytesRead = fs.readSync(fd, buffer, 0, buffer.length, 0); + fs.closeSync(fd); + + if (bytesRead === 0) return false; + + let nonPrintableCount = 0; + for (let i = 0; i < bytesRead; i++) { + if (buffer[i] === 0) return true; // Null byte is a strong indicator + if (buffer[i] < 9 || (buffer[i] > 13 && buffer[i] < 32)) { + nonPrintableCount++; + } + } + // If >30% non-printable characters, consider it binary + return nonPrintableCount / bytesRead > 0.3; + } catch { + // If any error occurs (e.g. file not found, permissions), + // treat as not binary here; let higher-level functions handle existence/access errors. + return false; + } +} + +/** + * Detects the type of file based on extension and content. + * @param filePath Path to the file. + * @returns 'text', 'image', 'pdf', or 'binary'. + */ +export function detectFileType( + filePath: string, +): 'text' | 'image' | 'pdf' | 'binary' { + const ext = path.extname(filePath).toLowerCase(); + const lookedUpMimeType = mime.lookup(filePath); // Returns false if not found, or the mime type string + + if (lookedUpMimeType && lookedUpMimeType.startsWith('image/')) { + return 'image'; + } + if (lookedUpMimeType && lookedUpMimeType === 'application/pdf') { + return 'pdf'; + } + + // Stricter binary check for common non-text extensions before content check + // These are often not well-covered by mime-types or might be misidentified. 
+ if ( + [ + '.zip', + '.tar', + '.gz', + '.exe', + '.dll', + '.so', + '.class', + '.jar', + '.war', + '.7z', + '.doc', + '.docx', + '.xls', + '.xlsx', + '.ppt', + '.pptx', + '.odt', + '.ods', + '.odp', + '.bin', + '.dat', + '.obj', + '.o', + '.a', + '.lib', + '.wasm', + '.pyc', + '.pyo', + ].includes(ext) + ) { + return 'binary'; + } + + // Fallback to content-based check if mime type wasn't conclusive for image/pdf + // and it's not a known binary extension. + if (isBinaryFile(filePath)) { + return 'binary'; + } + + return 'text'; +} + +export interface ProcessedFileReadResult { + llmContent: PartUnion; // string for text, Part for image/pdf/unreadable binary + returnDisplay: string; + error?: string; // Optional error message for the LLM if file processing failed + isTruncated?: boolean; // For text files, indicates if content was truncated + originalLineCount?: number; // For text files + linesShown?: [number, number]; // For text files [startLine, endLine] (1-based for display) +} + +/** + * Reads and processes a single file, handling text, images, and PDFs. + * @param filePath Absolute path to the file. + * @param rootDirectory Absolute path to the project root for relative path display. + * @param offset Optional offset for text files (0-based line number). + * @param limit Optional limit for text files (number of lines to read). + * @returns ProcessedFileReadResult object. 
+ */ +export async function processSingleFileContent( + filePath: string, + rootDirectory: string, + offset?: number, + limit?: number, +): Promise<ProcessedFileReadResult> { + try { + if (!fs.existsSync(filePath)) { + // Sync check is acceptable before async read + return { + llmContent: '', + returnDisplay: 'File not found.', + error: `File not found: ${filePath}`, + }; + } + const stats = fs.statSync(filePath); // Sync check + if (stats.isDirectory()) { + return { + llmContent: '', + returnDisplay: 'Path is a directory.', + error: `Path is a directory, not a file: ${filePath}`, + }; + } + + const fileType = detectFileType(filePath); + const relativePathForDisplay = path + .relative(rootDirectory, filePath) + .replace(/\\/g, '/'); + + switch (fileType) { + case 'binary': { + return { + llmContent: `Cannot display content of binary file: ${relativePathForDisplay}`, + returnDisplay: `Skipped binary file: ${relativePathForDisplay}`, + }; + } + case 'text': { + const content = await fs.promises.readFile(filePath, 'utf8'); + const lines = content.split('\n'); + const originalLineCount = lines.length; + + const startLine = offset || 0; + const effectiveLimit = + limit === undefined ? DEFAULT_MAX_LINES_TEXT_FILE : limit; + // Ensure endLine does not exceed originalLineCount + const endLine = Math.min(startLine + effectiveLimit, originalLineCount); + // Ensure selectedLines doesn't try to slice beyond array bounds if startLine is too high + const actualStartLine = Math.min(startLine, originalLineCount); + const selectedLines = lines.slice(actualStartLine, endLine); + + let linesWereTruncatedInLength = false; + const formattedLines = selectedLines.map((line) => { + if (line.length > MAX_LINE_LENGTH_TEXT_FILE) { + linesWereTruncatedInLength = true; + return ( + line.substring(0, MAX_LINE_LENGTH_TEXT_FILE) + '... 
[truncated]' + ); + } + return line; + }); + + const contentRangeTruncated = endLine < originalLineCount; + const isTruncated = contentRangeTruncated || linesWereTruncatedInLength; + + let llmTextContent = ''; + if (contentRangeTruncated) { + llmTextContent += `[File content truncated: showing lines ${actualStartLine + 1}-${endLine} of ${originalLineCount} total lines. Use offset/limit parameters to view more.]\n`; + } else if (linesWereTruncatedInLength) { + llmTextContent += `[File content partially truncated: some lines exceeded maximum length of ${MAX_LINE_LENGTH_TEXT_FILE} characters.]\n`; + } + llmTextContent += formattedLines.join('\n'); + + return { + llmContent: llmTextContent, + returnDisplay: isTruncated ? '(truncated)' : '', + isTruncated, + originalLineCount, + linesShown: [actualStartLine + 1, endLine], + }; + } + case 'image': + case 'pdf': { + const contentBuffer = await fs.promises.readFile(filePath); + const base64Data = contentBuffer.toString('base64'); + return { + llmContent: { + inlineData: { + data: base64Data, + mimeType: mime.lookup(filePath) || 'application/octet-stream', + }, + }, + returnDisplay: `Read ${fileType} file: ${relativePathForDisplay}`, + }; + } + default: { + // Should not happen with current detectFileType logic + const exhaustiveCheck: never = fileType; + return { + llmContent: `Unhandled file type: ${exhaustiveCheck}`, + returnDisplay: `Skipped unhandled file type: ${relativePathForDisplay}`, + error: `Unhandled file type for ${filePath}`, + }; + } + } + } catch (error) { + const errorMessage = error instanceof Error ? 
error.message : String(error); + const displayPath = path + .relative(rootDirectory, filePath) + .replace(/\\/g, '/'); + return { + llmContent: `Error reading file ${displayPath}: ${errorMessage}`, + returnDisplay: `Error reading file ${displayPath}: ${errorMessage}`, + error: `Error reading file ${filePath}: ${errorMessage}`, + }; + } +} diff --git a/packages/core/src/utils/generateContentResponseUtilities.ts b/packages/core/src/utils/generateContentResponseUtilities.ts new file mode 100644 index 00000000..a1d62124 --- /dev/null +++ b/packages/core/src/utils/generateContentResponseUtilities.ts @@ -0,0 +1,17 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { GenerateContentResponse } from '@google/genai'; + +export function getResponseText( + response: GenerateContentResponse, +): string | undefined { + return ( + response.candidates?.[0]?.content?.parts + ?.map((part) => part.text) + .join('') || undefined + ); +} diff --git a/packages/core/src/utils/getFolderStructure.test.ts b/packages/core/src/utils/getFolderStructure.test.ts new file mode 100644 index 00000000..aecd35c5 --- /dev/null +++ b/packages/core/src/utils/getFolderStructure.test.ts @@ -0,0 +1,278 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest'; +import fsPromises from 'fs/promises'; +import { Dirent as FSDirent } from 'fs'; +import * as nodePath from 'path'; +import { getFolderStructure } from './getFolderStructure.js'; + +vi.mock('path', async (importOriginal) => { + const original = (await importOriginal()) as typeof nodePath; + return { + ...original, + resolve: vi.fn((str) => str), + // Other path functions (basename, join, normalize, etc.) 
will use original implementation + }; +}); + +vi.mock('fs/promises'); + +// Import 'path' again here, it will be the mocked version +import * as path from 'path'; + +// Helper to create Dirent-like objects for mocking fs.readdir +const createDirent = (name: string, type: 'file' | 'dir'): FSDirent => ({ + name, + isFile: () => type === 'file', + isDirectory: () => type === 'dir', + isBlockDevice: () => false, + isCharacterDevice: () => false, + isSymbolicLink: () => false, + isFIFO: () => false, + isSocket: () => false, + parentPath: '', + path: '', +}); + +describe('getFolderStructure', () => { + beforeEach(() => { + vi.resetAllMocks(); + + // path.resolve is now a vi.fn() due to the top-level vi.mock. + // We ensure its implementation is set for each test (or rely on the one from vi.mock). + // vi.resetAllMocks() clears call history but not the implementation set by vi.fn() in vi.mock. + // If we needed to change it per test, we would do it here: + (path.resolve as Mock).mockImplementation((str: string) => str); + + // Re-apply/define the mock implementation for fsPromises.readdir for each test + (fsPromises.readdir as Mock).mockImplementation( + async (dirPath: string | Buffer | URL) => { + // path.normalize here will use the mocked path module. + // Since normalize is spread from original, it should be the real one. 
+ const normalizedPath = path.normalize(dirPath.toString()); + if (mockFsStructure[normalizedPath]) { + return mockFsStructure[normalizedPath]; + } + throw Object.assign( + new Error( + `ENOENT: no such file or directory, scandir '${normalizedPath}'`, + ), + { code: 'ENOENT' }, + ); + }, + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); // Restores spies (like fsPromises.readdir) and resets vi.fn mocks (like path.resolve) + }); + + const mockFsStructure: Record<string, FSDirent[]> = { + '/testroot': [ + createDirent('file1.txt', 'file'), + createDirent('subfolderA', 'dir'), + createDirent('emptyFolder', 'dir'), + createDirent('.hiddenfile', 'file'), + createDirent('node_modules', 'dir'), + ], + '/testroot/subfolderA': [ + createDirent('fileA1.ts', 'file'), + createDirent('fileA2.js', 'file'), + createDirent('subfolderB', 'dir'), + ], + '/testroot/subfolderA/subfolderB': [createDirent('fileB1.md', 'file')], + '/testroot/emptyFolder': [], + '/testroot/node_modules': [createDirent('somepackage', 'dir')], + '/testroot/manyFilesFolder': Array.from({ length: 10 }, (_, i) => + createDirent(`file-${i}.txt`, 'file'), + ), + '/testroot/manyFolders': Array.from({ length: 5 }, (_, i) => + createDirent(`folder-${i}`, 'dir'), + ), + ...Array.from({ length: 5 }, (_, i) => ({ + [`/testroot/manyFolders/folder-${i}`]: [ + createDirent('child.txt', 'file'), + ], + })).reduce((acc, val) => ({ ...acc, ...val }), {}), + '/testroot/deepFolders': [createDirent('level1', 'dir')], + '/testroot/deepFolders/level1': [createDirent('level2', 'dir')], + '/testroot/deepFolders/level1/level2': [createDirent('level3', 'dir')], + '/testroot/deepFolders/level1/level2/level3': [ + createDirent('file.txt', 'file'), + ], + }; + + it('should return basic folder structure', async () => { + const structure = await getFolderStructure('/testroot/subfolderA'); + const expected = ` +Showing up to 200 items (files + folders). 
+ +/testroot/subfolderA/ +├───fileA1.ts +├───fileA2.js +└───subfolderB/ + └───fileB1.md +`.trim(); + expect(structure.trim()).toBe(expected); + }); + + it('should handle an empty folder', async () => { + const structure = await getFolderStructure('/testroot/emptyFolder'); + const expected = ` +Showing up to 200 items (files + folders). + +/testroot/emptyFolder/ +`.trim(); + expect(structure.trim()).toBe(expected.trim()); + }); + + it('should ignore folders specified in ignoredFolders (default)', async () => { + const structure = await getFolderStructure('/testroot'); + const expected = ` +Showing up to 200 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (200 items) was reached. + +/testroot/ +├───.hiddenfile +├───file1.txt +├───emptyFolder/ +├───node_modules/... +└───subfolderA/ + ├───fileA1.ts + ├───fileA2.js + └───subfolderB/ + └───fileB1.md +`.trim(); + expect(structure.trim()).toBe(expected); + }); + + it('should ignore folders specified in custom ignoredFolders', async () => { + const structure = await getFolderStructure('/testroot', { + ignoredFolders: new Set(['subfolderA', 'node_modules']), + }); + const expected = ` +Showing up to 200 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (200 items) was reached. + +/testroot/ +├───.hiddenfile +├───file1.txt +├───emptyFolder/ +├───node_modules/... +└───subfolderA/... +`.trim(); + expect(structure.trim()).toBe(expected); + }); + + it('should filter files by fileIncludePattern', async () => { + const structure = await getFolderStructure('/testroot/subfolderA', { + fileIncludePattern: /\.ts$/, + }); + const expected = ` +Showing up to 200 items (files + folders). 
+ +/testroot/subfolderA/ +├───fileA1.ts +└───subfolderB/ +`.trim(); + expect(structure.trim()).toBe(expected); + }); + + it('should handle maxItems truncation for files within a folder', async () => { + const structure = await getFolderStructure('/testroot/subfolderA', { + maxItems: 3, + }); + const expected = ` +Showing up to 3 items (files + folders). + +/testroot/subfolderA/ +├───fileA1.ts +├───fileA2.js +└───subfolderB/ +`.trim(); + expect(structure.trim()).toBe(expected); + }); + + it('should handle maxItems truncation for subfolders', async () => { + const structure = await getFolderStructure('/testroot/manyFolders', { + maxItems: 4, + }); + const expectedRevised = ` +Showing up to 4 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (4 items) was reached. + +/testroot/manyFolders/ +├───folder-0/ +├───folder-1/ +├───folder-2/ +├───folder-3/ +└───... +`.trim(); + expect(structure.trim()).toBe(expectedRevised); + }); + + it('should handle maxItems that only allows the root folder itself', async () => { + const structure = await getFolderStructure('/testroot/subfolderA', { + maxItems: 1, + }); + const expectedRevisedMax1 = ` +Showing up to 1 items (files + folders). Folders or files indicated with ... contain more items not shown, were ignored, or the display limit (1 items) was reached. + +/testroot/subfolderA/ +├───fileA1.ts +├───... +└───... 
+`.trim(); + expect(structure.trim()).toBe(expectedRevisedMax1); + }); + + it('should handle non-existent directory', async () => { + // Temporarily make fsPromises.readdir throw ENOENT for this specific path + const originalReaddir = fsPromises.readdir; + (fsPromises.readdir as Mock).mockImplementation( + async (p: string | Buffer | URL) => { + if (p === '/nonexistent') { + throw Object.assign(new Error('ENOENT'), { code: 'ENOENT' }); + } + return originalReaddir(p); + }, + ); + + const structure = await getFolderStructure('/nonexistent'); + expect(structure).toContain( + 'Error: Could not read directory "/nonexistent"', + ); + }); + + it('should handle deep folder structure within limits', async () => { + const structure = await getFolderStructure('/testroot/deepFolders', { + maxItems: 10, + }); + const expected = ` +Showing up to 10 items (files + folders). + +/testroot/deepFolders/ +└───level1/ + └───level2/ + └───level3/ + └───file.txt +`.trim(); + expect(structure.trim()).toBe(expected); + }); + + it('should truncate deep folder structure if maxItems is small', async () => { + const structure = await getFolderStructure('/testroot/deepFolders', { + maxItems: 3, + }); + const expected = ` +Showing up to 3 items (files + folders). 
+ +/testroot/deepFolders/ +└───level1/ + └───level2/ + └───level3/ +`.trim(); + expect(structure.trim()).toBe(expected); + }); +}); diff --git a/packages/core/src/utils/getFolderStructure.ts b/packages/core/src/utils/getFolderStructure.ts new file mode 100644 index 00000000..6d921811 --- /dev/null +++ b/packages/core/src/utils/getFolderStructure.ts @@ -0,0 +1,325 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as fs from 'fs/promises'; +import { Dirent } from 'fs'; +import * as path from 'path'; +import { getErrorMessage, isNodeError } from './errors.js'; + +const MAX_ITEMS = 200; +const TRUNCATION_INDICATOR = '...'; +const DEFAULT_IGNORED_FOLDERS = new Set(['node_modules', '.git', 'dist']); + +// --- Interfaces --- + +/** Options for customizing folder structure retrieval. */ +interface FolderStructureOptions { + /** Maximum number of files and folders combined to display. Defaults to 200. */ + maxItems?: number; + /** Set of folder names to ignore completely. Case-sensitive. */ + ignoredFolders?: Set<string>; + /** Optional regex to filter included files by name. */ + fileIncludePattern?: RegExp; +} + +// Define a type for the merged options where fileIncludePattern remains optional +type MergedFolderStructureOptions = Required< + Omit<FolderStructureOptions, 'fileIncludePattern'> +> & { + fileIncludePattern?: RegExp; +}; + +/** Represents the full, unfiltered information about a folder and its contents. 
*/ +interface FullFolderInfo { + name: string; + path: string; + files: string[]; + subFolders: FullFolderInfo[]; + totalChildren: number; // Number of files and subfolders included from this folder during BFS scan + totalFiles: number; // Number of files included from this folder during BFS scan + isIgnored?: boolean; // Flag to easily identify ignored folders later + hasMoreFiles?: boolean; // Indicates if files were truncated for this specific folder + hasMoreSubfolders?: boolean; // Indicates if subfolders were truncated for this specific folder +} + +// --- Interfaces --- + +// --- Helper Functions --- + +async function readFullStructure( + rootPath: string, + options: MergedFolderStructureOptions, +): Promise<FullFolderInfo | null> { + const rootName = path.basename(rootPath); + const rootNode: FullFolderInfo = { + name: rootName, + path: rootPath, + files: [], + subFolders: [], + totalChildren: 0, + totalFiles: 0, + }; + + const queue: Array<{ folderInfo: FullFolderInfo; currentPath: string }> = [ + { folderInfo: rootNode, currentPath: rootPath }, + ]; + let currentItemCount = 0; + // Count the root node itself as one item if we are not just listing its content + + const processedPaths = new Set<string>(); // To avoid processing same path if symlinks create loops + + while (queue.length > 0) { + const { folderInfo, currentPath } = queue.shift()!; + + if (processedPaths.has(currentPath)) { + continue; + } + processedPaths.add(currentPath); + + if (currentItemCount >= options.maxItems) { + // If the root itself caused us to exceed, we can't really show anything. + // Otherwise, this folder won't be processed further. + // The parent that queued this would have set its own hasMoreSubfolders flag. 
+ continue; + } + + let entries: Dirent[]; + try { + const rawEntries = await fs.readdir(currentPath, { withFileTypes: true }); + // Sort entries alphabetically by name for consistent processing order + entries = rawEntries.sort((a, b) => a.name.localeCompare(b.name)); + } catch (error: unknown) { + if ( + isNodeError(error) && + (error.code === 'EACCES' || error.code === 'ENOENT') + ) { + console.warn( + `Warning: Could not read directory ${currentPath}: ${error.message}`, + ); + if (currentPath === rootPath && error.code === 'ENOENT') { + return null; // Root directory itself not found + } + // For other EACCES/ENOENT on subdirectories, just skip them. + continue; + } + throw error; + } + + const filesInCurrentDir: string[] = []; + const subFoldersInCurrentDir: FullFolderInfo[] = []; + + // Process files first in the current directory + for (const entry of entries) { + if (entry.isFile()) { + if (currentItemCount >= options.maxItems) { + folderInfo.hasMoreFiles = true; + break; + } + const fileName = entry.name; + if ( + !options.fileIncludePattern || + options.fileIncludePattern.test(fileName) + ) { + filesInCurrentDir.push(fileName); + currentItemCount++; + folderInfo.totalFiles++; + folderInfo.totalChildren++; + } + } + } + folderInfo.files = filesInCurrentDir; + + // Then process directories and queue them + for (const entry of entries) { + if (entry.isDirectory()) { + // Check if adding this directory ITSELF would meet or exceed maxItems + // (currentItemCount refers to items *already* added before this one) + if (currentItemCount >= options.maxItems) { + folderInfo.hasMoreSubfolders = true; + break; // Already at limit, cannot add this folder or any more + } + // If adding THIS folder makes us hit the limit exactly, and it might have children, + // it's better to show '...' for the parent, unless this is the very last item slot. + // This logic is tricky. Let's try a simpler: if we can't add this item, mark and break. 
+ + const subFolderName = entry.name; + const subFolderPath = path.join(currentPath, subFolderName); + + if (options.ignoredFolders.has(subFolderName)) { + const ignoredSubFolder: FullFolderInfo = { + name: subFolderName, + path: subFolderPath, + files: [], + subFolders: [], + totalChildren: 0, + totalFiles: 0, + isIgnored: true, + }; + subFoldersInCurrentDir.push(ignoredSubFolder); + currentItemCount++; // Count the ignored folder itself + folderInfo.totalChildren++; // Also counts towards parent's children + continue; + } + + const subFolderNode: FullFolderInfo = { + name: subFolderName, + path: subFolderPath, + files: [], + subFolders: [], + totalChildren: 0, + totalFiles: 0, + }; + subFoldersInCurrentDir.push(subFolderNode); + currentItemCount++; + folderInfo.totalChildren++; // Counts towards parent's children + + // Add to queue for processing its children later + queue.push({ folderInfo: subFolderNode, currentPath: subFolderPath }); + } + } + folderInfo.subFolders = subFoldersInCurrentDir; + } + + return rootNode; +} + +/** + * Reads the directory structure using BFS, respecting maxItems. + * @param node The current node in the reduced structure. + * @param indent The current indentation string. + * @param isLast Sibling indicator. + * @param builder Array to build the string lines. + */ +function formatStructure( + node: FullFolderInfo, + currentIndent: string, + isLastChildOfParent: boolean, + isProcessingRootNode: boolean, + builder: string[], +): void { + const connector = isLastChildOfParent ? '└───' : '├───'; + + // The root node of the structure (the one passed initially to getFolderStructure) + // is not printed with a connector line itself, only its name as a header. + // Its children are printed relative to that conceptual root. + // Ignored root nodes ARE printed with a connector. + if (!isProcessingRootNode || node.isIgnored) { + builder.push( + `${currentIndent}${connector}${node.name}/${node.isIgnored ? 
TRUNCATION_INDICATOR : ''}`, + ); + } + + // Determine the indent for the children of *this* node. + // If *this* node was the root of the whole structure, its children start with no indent before their connectors. + // Otherwise, children's indent extends from the current node's indent. + const indentForChildren = isProcessingRootNode + ? '' + : currentIndent + (isLastChildOfParent ? ' ' : '│ '); + + // Render files of the current node + const fileCount = node.files.length; + for (let i = 0; i < fileCount; i++) { + const isLastFileAmongSiblings = + i === fileCount - 1 && + node.subFolders.length === 0 && + !node.hasMoreSubfolders; + const fileConnector = isLastFileAmongSiblings ? '└───' : '├───'; + builder.push(`${indentForChildren}${fileConnector}${node.files[i]}`); + } + if (node.hasMoreFiles) { + const isLastIndicatorAmongSiblings = + node.subFolders.length === 0 && !node.hasMoreSubfolders; + const fileConnector = isLastIndicatorAmongSiblings ? '└───' : '├───'; + builder.push(`${indentForChildren}${fileConnector}${TRUNCATION_INDICATOR}`); + } + + // Render subfolders of the current node + const subFolderCount = node.subFolders.length; + for (let i = 0; i < subFolderCount; i++) { + const isLastSubfolderAmongSiblings = + i === subFolderCount - 1 && !node.hasMoreSubfolders; + // Children are never the root node being processed initially. + formatStructure( + node.subFolders[i], + indentForChildren, + isLastSubfolderAmongSiblings, + false, + builder, + ); + } + if (node.hasMoreSubfolders) { + builder.push(`${indentForChildren}└───${TRUNCATION_INDICATOR}`); + } +} + +// --- Main Exported Function --- + +/** + * Generates a string representation of a directory's structure, + * limiting the number of items displayed. Ignored folders are shown + * followed by '...' instead of their contents. + * + * @param directory The absolute or relative path to the directory. + * @param options Optional configuration settings. 
+ * @returns A promise resolving to the formatted folder structure string. + */ +export async function getFolderStructure( + directory: string, + options?: FolderStructureOptions, +): Promise<string> { + const resolvedPath = path.resolve(directory); + const mergedOptions: MergedFolderStructureOptions = { + maxItems: options?.maxItems ?? MAX_ITEMS, + ignoredFolders: options?.ignoredFolders ?? DEFAULT_IGNORED_FOLDERS, + fileIncludePattern: options?.fileIncludePattern, + }; + + try { + // 1. Read the structure using BFS, respecting maxItems + const structureRoot = await readFullStructure(resolvedPath, mergedOptions); + + if (!structureRoot) { + return `Error: Could not read directory "${resolvedPath}". Check path and permissions.`; + } + + // 2. Format the structure into a string + const structureLines: string[] = []; + // Pass true for isRoot for the initial call + formatStructure(structureRoot, '', true, true, structureLines); + + // 3. Build the final output string + const displayPath = resolvedPath.replace(/\\/g, '/'); + + let disclaimer = ''; + // Check if truncation occurred anywhere or if ignored folders are present. + // A simple check: if any node indicates more files/subfolders, or is ignored. + let truncationOccurred = false; + function checkForTruncation(node: FullFolderInfo) { + if (node.hasMoreFiles || node.hasMoreSubfolders || node.isIgnored) { + truncationOccurred = true; + } + if (!truncationOccurred) { + for (const sub of node.subFolders) { + checkForTruncation(sub); + if (truncationOccurred) break; + } + } + } + checkForTruncation(structureRoot); + + if (truncationOccurred) { + disclaimer = `Folders or files indicated with ${TRUNCATION_INDICATOR} contain more items not shown, were ignored, or the display limit (${mergedOptions.maxItems} items) was reached.`; + } + + const summary = + `Showing up to ${mergedOptions.maxItems} items (files + folders). 
${disclaimer}`.trim(); + + return `${summary}\n\n${displayPath}/\n${structureLines.join('\n')}`; + } catch (error: unknown) { + console.error(`Error getting folder structure for ${resolvedPath}:`, error); + return `Error processing directory "${resolvedPath}": ${getErrorMessage(error)}`; + } +} diff --git a/packages/core/src/utils/memoryDiscovery.test.ts b/packages/core/src/utils/memoryDiscovery.test.ts new file mode 100644 index 00000000..229f51e5 --- /dev/null +++ b/packages/core/src/utils/memoryDiscovery.test.ts @@ -0,0 +1,382 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + vi, + describe, + it, + expect, + beforeEach, + // afterEach, // Removed unused import + Mocked, +} from 'vitest'; +import * as fsPromises from 'fs/promises'; +import * as fsSync from 'fs'; // For constants +import { Stats, Dirent } from 'fs'; // Import types directly from 'fs' +import * as os from 'os'; +import * as path from 'path'; +import { loadServerHierarchicalMemory } from './memoryDiscovery.js'; +import { GEMINI_CONFIG_DIR, GEMINI_MD_FILENAME } from '../tools/memoryTool.js'; + +// Mock the entire fs/promises module +vi.mock('fs/promises'); +// Mock the parts of fsSync we might use (like constants or existsSync if needed) +vi.mock('fs', async (importOriginal) => { + const actual = await importOriginal<typeof fsSync>(); + return { + ...actual, // Spread actual to get all exports, including Stats and Dirent if they are classes/constructors + constants: { ...actual.constants }, // Preserve constants + // Mock other fsSync functions if directly used by memoryDiscovery, e.g., existsSync + // existsSync: vi.fn(), + }; +}); +vi.mock('os'); + +describe('loadServerHierarchicalMemory', () => { + const mockFs = fsPromises as Mocked<typeof fsPromises>; + const mockOs = os as Mocked<typeof os>; + + const CWD = '/test/project/src'; + const PROJECT_ROOT = '/test/project'; + const USER_HOME = '/test/userhome'; + const GLOBAL_GEMINI_DIR = 
path.join(USER_HOME, GEMINI_CONFIG_DIR); + const GLOBAL_GEMINI_FILE = path.join(GLOBAL_GEMINI_DIR, GEMINI_MD_FILENAME); + + beforeEach(() => { + vi.resetAllMocks(); + + mockOs.homedir.mockReturnValue(USER_HOME); + mockFs.stat.mockRejectedValue(new Error('File not found')); + mockFs.readdir.mockResolvedValue([]); + mockFs.readFile.mockRejectedValue(new Error('File not found')); + mockFs.access.mockRejectedValue(new Error('File not found')); + }); + + it('should return empty memory and count if no GEMINI.md files are found', async () => { + const { memoryContent, fileCount } = await loadServerHierarchicalMemory( + CWD, + false, + ); + expect(memoryContent).toBe(''); + expect(fileCount).toBe(0); + }); + + it('should load only the global GEMINI.md if present and others are not', async () => { + mockFs.access.mockImplementation(async (p) => { + if (p === GLOBAL_GEMINI_FILE) { + return undefined; + } + throw new Error('File not found'); + }); + mockFs.readFile.mockImplementation(async (p) => { + if (p === GLOBAL_GEMINI_FILE) { + return 'Global memory content'; + } + throw new Error('File not found'); + }); + + const { memoryContent, fileCount } = await loadServerHierarchicalMemory( + CWD, + false, + ); + + expect(memoryContent).toBe( + `--- Context from: ${path.relative(CWD, GLOBAL_GEMINI_FILE)} ---\nGlobal memory content\n--- End of Context from: ${path.relative(CWD, GLOBAL_GEMINI_FILE)} ---`, + ); + expect(fileCount).toBe(1); + expect(mockFs.readFile).toHaveBeenCalledWith(GLOBAL_GEMINI_FILE, 'utf-8'); + }); + + it('should load GEMINI.md files by upward traversal from CWD to project root', async () => { + const projectRootGeminiFile = path.join(PROJECT_ROOT, GEMINI_MD_FILENAME); + const srcGeminiFile = path.join(CWD, GEMINI_MD_FILENAME); + + mockFs.stat.mockImplementation(async (p) => { + if (p === path.join(PROJECT_ROOT, '.git')) { + return { isDirectory: () => true } as Stats; + } + throw new Error('File not found'); + }); + + mockFs.access.mockImplementation(async 
(p) => { + if (p === projectRootGeminiFile || p === srcGeminiFile) { + return undefined; + } + throw new Error('File not found'); + }); + + mockFs.readFile.mockImplementation(async (p) => { + if (p === projectRootGeminiFile) { + return 'Project root memory'; + } + if (p === srcGeminiFile) { + return 'Src directory memory'; + } + throw new Error('File not found'); + }); + + const { memoryContent, fileCount } = await loadServerHierarchicalMemory( + CWD, + false, + ); + const expectedContent = + `--- Context from: ${path.relative(CWD, projectRootGeminiFile)} ---\nProject root memory\n--- End of Context from: ${path.relative(CWD, projectRootGeminiFile)} ---\n\n` + + `--- Context from: ${GEMINI_MD_FILENAME} ---\nSrc directory memory\n--- End of Context from: ${GEMINI_MD_FILENAME} ---`; + + expect(memoryContent).toBe(expectedContent); + expect(fileCount).toBe(2); + expect(mockFs.readFile).toHaveBeenCalledWith( + projectRootGeminiFile, + 'utf-8', + ); + expect(mockFs.readFile).toHaveBeenCalledWith(srcGeminiFile, 'utf-8'); + }); + + it('should load GEMINI.md files by downward traversal from CWD', async () => { + const subDir = path.join(CWD, 'subdir'); + const subDirGeminiFile = path.join(subDir, GEMINI_MD_FILENAME); + const cwdGeminiFile = path.join(CWD, GEMINI_MD_FILENAME); + + mockFs.access.mockImplementation(async (p) => { + if (p === cwdGeminiFile || p === subDirGeminiFile) return undefined; + throw new Error('File not found'); + }); + + mockFs.readFile.mockImplementation(async (p) => { + if (p === cwdGeminiFile) return 'CWD memory'; + if (p === subDirGeminiFile) return 'Subdir memory'; + throw new Error('File not found'); + }); + + mockFs.readdir.mockImplementation((async ( + p: fsSync.PathLike, + ): Promise<Dirent[]> => { + if (p === CWD) { + return [ + { + name: GEMINI_MD_FILENAME, + isFile: () => true, + isDirectory: () => false, + }, + { name: 'subdir', isFile: () => false, isDirectory: () => true }, + ] as Dirent[]; + } + if (p === subDir) { + return [ + { + 
name: GEMINI_MD_FILENAME, + isFile: () => true, + isDirectory: () => false, + }, + ] as Dirent[]; + } + return [] as Dirent[]; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + }) as any); + + const { memoryContent, fileCount } = await loadServerHierarchicalMemory( + CWD, + false, + ); + const expectedContent = + `--- Context from: ${GEMINI_MD_FILENAME} ---\nCWD memory\n--- End of Context from: ${GEMINI_MD_FILENAME} ---\n\n` + + `--- Context from: ${path.join('subdir', GEMINI_MD_FILENAME)} ---\nSubdir memory\n--- End of Context from: ${path.join('subdir', GEMINI_MD_FILENAME)} ---`; + + expect(memoryContent).toBe(expectedContent); + expect(fileCount).toBe(2); + }); + + it('should load and correctly order global, upward, and downward GEMINI.md files', async () => { + const projectParentDir = path.dirname(PROJECT_ROOT); + const projectParentGeminiFile = path.join( + projectParentDir, + GEMINI_MD_FILENAME, + ); + const projectRootGeminiFile = path.join(PROJECT_ROOT, GEMINI_MD_FILENAME); + const cwdGeminiFile = path.join(CWD, GEMINI_MD_FILENAME); + const subDir = path.join(CWD, 'sub'); + const subDirGeminiFile = path.join(subDir, GEMINI_MD_FILENAME); + + mockFs.stat.mockImplementation(async (p) => { + if (p === path.join(PROJECT_ROOT, '.git')) { + return { isDirectory: () => true } as Stats; + } + throw new Error('File not found'); + }); + + mockFs.access.mockImplementation(async (p) => { + if ( + p === GLOBAL_GEMINI_FILE || + p === projectParentGeminiFile || + p === projectRootGeminiFile || + p === cwdGeminiFile || + p === subDirGeminiFile + ) { + return undefined; + } + throw new Error('File not found'); + }); + + mockFs.readFile.mockImplementation(async (p) => { + if (p === GLOBAL_GEMINI_FILE) return 'Global memory'; + if (p === projectParentGeminiFile) return 'Project parent memory'; + if (p === projectRootGeminiFile) return 'Project root memory'; + if (p === cwdGeminiFile) return 'CWD memory'; + if (p === subDirGeminiFile) return 'Subdir memory'; + 
throw new Error('File not found'); + }); + + mockFs.readdir.mockImplementation((async ( + p: fsSync.PathLike, + ): Promise<Dirent[]> => { + if (p === CWD) { + return [ + { name: 'sub', isFile: () => false, isDirectory: () => true }, + ] as Dirent[]; + } + if (p === subDir) { + return [ + { + name: GEMINI_MD_FILENAME, + isFile: () => true, + isDirectory: () => false, + }, + ] as Dirent[]; + } + return [] as Dirent[]; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + }) as any); + + const { memoryContent, fileCount } = await loadServerHierarchicalMemory( + CWD, + false, + ); + + const relPathGlobal = path.relative(CWD, GLOBAL_GEMINI_FILE); + const relPathProjectParent = path.relative(CWD, projectParentGeminiFile); + const relPathProjectRoot = path.relative(CWD, projectRootGeminiFile); + const relPathCwd = GEMINI_MD_FILENAME; + const relPathSubDir = path.join('sub', GEMINI_MD_FILENAME); + + const expectedContent = [ + `--- Context from: ${relPathGlobal} ---\nGlobal memory\n--- End of Context from: ${relPathGlobal} ---`, + `--- Context from: ${relPathProjectParent} ---\nProject parent memory\n--- End of Context from: ${relPathProjectParent} ---`, + `--- Context from: ${relPathProjectRoot} ---\nProject root memory\n--- End of Context from: ${relPathProjectRoot} ---`, + `--- Context from: ${relPathCwd} ---\nCWD memory\n--- End of Context from: ${relPathCwd} ---`, + `--- Context from: ${relPathSubDir} ---\nSubdir memory\n--- End of Context from: ${relPathSubDir} ---`, + ].join('\n\n'); + + expect(memoryContent).toBe(expectedContent); + expect(fileCount).toBe(5); + }); + + it('should ignore specified directories during downward scan', async () => { + const ignoredDir = path.join(CWD, 'node_modules'); + const ignoredDirGeminiFile = path.join(ignoredDir, GEMINI_MD_FILENAME); + const regularSubDir = path.join(CWD, 'my_code'); + const regularSubDirGeminiFile = path.join( + regularSubDir, + GEMINI_MD_FILENAME, + ); + + mockFs.access.mockImplementation(async (p) 
=> { + if (p === regularSubDirGeminiFile) return undefined; + if (p === ignoredDirGeminiFile) + throw new Error('Should not access ignored file'); + throw new Error('File not found'); + }); + + mockFs.readFile.mockImplementation(async (p) => { + if (p === regularSubDirGeminiFile) return 'My code memory'; + throw new Error('File not found'); + }); + + mockFs.readdir.mockImplementation((async ( + p: fsSync.PathLike, + ): Promise<Dirent[]> => { + if (p === CWD) { + return [ + { + name: 'node_modules', + isFile: () => false, + isDirectory: () => true, + }, + { name: 'my_code', isFile: () => false, isDirectory: () => true }, + ] as Dirent[]; + } + if (p === regularSubDir) { + return [ + { + name: GEMINI_MD_FILENAME, + isFile: () => true, + isDirectory: () => false, + }, + ] as Dirent[]; + } + if (p === ignoredDir) { + return [ + { + name: GEMINI_MD_FILENAME, + isFile: () => true, + isDirectory: () => false, + }, + ] as Dirent[]; + } + return [] as Dirent[]; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + }) as any); + + const { memoryContent, fileCount } = await loadServerHierarchicalMemory( + CWD, + false, + ); + + const expectedContent = `--- Context from: ${path.join('my_code', GEMINI_MD_FILENAME)} ---\nMy code memory\n--- End of Context from: ${path.join('my_code', GEMINI_MD_FILENAME)} ---`; + + expect(memoryContent).toBe(expectedContent); + expect(fileCount).toBe(1); + expect(mockFs.readFile).not.toHaveBeenCalledWith( + ignoredDirGeminiFile, + 'utf-8', + ); + }); + + it('should respect MAX_DIRECTORIES_TO_SCAN_FOR_MEMORY during downward scan', async () => { + const consoleDebugSpy = vi + .spyOn(console, 'debug') + .mockImplementation(() => {}); + + const dirNames: Dirent[] = []; + for (let i = 0; i < 250; i++) { + dirNames.push({ + name: `deep_dir_${i}`, + isFile: () => false, + isDirectory: () => true, + } as Dirent); + } + + mockFs.readdir.mockImplementation((async ( + p: fsSync.PathLike, + ): Promise<Dirent[]> => { + if (p === CWD) return 
dirNames; + if (p.toString().startsWith(path.join(CWD, 'deep_dir_'))) + return [] as Dirent[]; + return [] as Dirent[]; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + }) as any); + mockFs.access.mockRejectedValue(new Error('not found')); + + await loadServerHierarchicalMemory(CWD, true); + + expect(consoleDebugSpy).toHaveBeenCalledWith( + expect.stringContaining('[DEBUG] [MemoryDiscovery]'), + expect.stringContaining( + 'Max directory scan limit (200) reached. Stopping downward scan at:', + ), + ); + consoleDebugSpy.mockRestore(); + }); +}); diff --git a/packages/core/src/utils/memoryDiscovery.ts b/packages/core/src/utils/memoryDiscovery.ts new file mode 100644 index 00000000..362134d8 --- /dev/null +++ b/packages/core/src/utils/memoryDiscovery.ts @@ -0,0 +1,351 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as fs from 'fs/promises'; +import * as fsSync from 'fs'; +import * as path from 'path'; +import { homedir } from 'os'; +import { GEMINI_CONFIG_DIR, GEMINI_MD_FILENAME } from '../tools/memoryTool.js'; + +// Simple console logger, similar to the one previously in CLI's config.ts +// TODO: Integrate with a more robust server-side logger if available/appropriate. +const logger = { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + debug: (...args: any[]) => + console.debug('[DEBUG] [MemoryDiscovery]', ...args), + // eslint-disable-next-line @typescript-eslint/no-explicit-any + warn: (...args: any[]) => console.warn('[WARN] [MemoryDiscovery]', ...args), + // eslint-disable-next-line @typescript-eslint/no-explicit-any + error: (...args: any[]) => + console.error('[ERROR] [MemoryDiscovery]', ...args), +}; + +// TODO(adh): Refactor to use a shared ignore list with other tools like glob and read-many-files. 
+const DEFAULT_IGNORE_DIRECTORIES = [ + 'node_modules', + '.git', + 'dist', + 'build', + 'out', + 'coverage', + '.vscode', + '.idea', + '.DS_Store', +]; + +const MAX_DIRECTORIES_TO_SCAN_FOR_MEMORY = 200; + +interface GeminiFileContent { + filePath: string; + content: string | null; +} + +async function findProjectRoot(startDir: string): Promise<string | null> { + let currentDir = path.resolve(startDir); + while (true) { + const gitPath = path.join(currentDir, '.git'); + try { + const stats = await fs.stat(gitPath); + if (stats.isDirectory()) { + return currentDir; + } + } catch (error: unknown) { + if (typeof error === 'object' && error !== null && 'code' in error) { + const fsError = error as { code: string; message: string }; + if (fsError.code !== 'ENOENT') { + logger.warn( + `Error checking for .git directory at ${gitPath}: ${fsError.message}`, + ); + } + } else { + logger.warn( + `Non-standard error checking for .git directory at ${gitPath}: ${String(error)}`, + ); + } + } + const parentDir = path.dirname(currentDir); + if (parentDir === currentDir) { + return null; + } + currentDir = parentDir; + } +} + +async function collectDownwardGeminiFiles( + directory: string, + debugMode: boolean, + ignoreDirs: string[], + scannedDirCount: { count: number }, + maxScanDirs: number, +): Promise<string[]> { + if (scannedDirCount.count >= maxScanDirs) { + if (debugMode) + logger.debug( + `Max directory scan limit (${maxScanDirs}) reached. 
Stopping downward scan at: ${directory}`, + ); + return []; + } + scannedDirCount.count++; + + if (debugMode) + logger.debug( + `Scanning downward for ${GEMINI_MD_FILENAME} files in: ${directory} (scanned: ${scannedDirCount.count}/${maxScanDirs})`, + ); + const collectedPaths: string[] = []; + try { + const entries = await fs.readdir(directory, { withFileTypes: true }); + for (const entry of entries) { + const fullPath = path.join(directory, entry.name); + if (entry.isDirectory()) { + if (ignoreDirs.includes(entry.name)) { + if (debugMode) + logger.debug(`Skipping ignored directory: ${fullPath}`); + continue; + } + const subDirPaths = await collectDownwardGeminiFiles( + fullPath, + debugMode, + ignoreDirs, + scannedDirCount, + maxScanDirs, + ); + collectedPaths.push(...subDirPaths); + } else if (entry.isFile() && entry.name === GEMINI_MD_FILENAME) { + try { + await fs.access(fullPath, fsSync.constants.R_OK); + collectedPaths.push(fullPath); + if (debugMode) + logger.debug( + `Found readable downward ${GEMINI_MD_FILENAME}: ${fullPath}`, + ); + } catch { + if (debugMode) + logger.debug( + `Downward ${GEMINI_MD_FILENAME} not readable, skipping: ${fullPath}`, + ); + } + } + } + } catch (error) { + const message = error instanceof Error ? 
error.message : String(error); + logger.warn(`Error scanning directory ${directory}: ${message}`); + if (debugMode) logger.debug(`Failed to scan directory: ${directory}`); + } + return collectedPaths; +} + +async function getGeminiMdFilePathsInternal( + currentWorkingDirectory: string, + userHomePath: string, // Keep userHomePath as a parameter for clarity + debugMode: boolean, +): Promise<string[]> { + const resolvedCwd = path.resolve(currentWorkingDirectory); + const resolvedHome = path.resolve(userHomePath); + const globalMemoryPath = path.join( + resolvedHome, + GEMINI_CONFIG_DIR, + GEMINI_MD_FILENAME, + ); + const paths: string[] = []; + + if (debugMode) + logger.debug( + `Searching for ${GEMINI_MD_FILENAME} starting from CWD: ${resolvedCwd}`, + ); + if (debugMode) logger.debug(`User home directory: ${resolvedHome}`); + + try { + await fs.access(globalMemoryPath, fsSync.constants.R_OK); + paths.push(globalMemoryPath); + if (debugMode) + logger.debug( + `Found readable global ${GEMINI_MD_FILENAME}: ${globalMemoryPath}`, + ); + } catch { + if (debugMode) + logger.debug( + `Global ${GEMINI_MD_FILENAME} not found or not readable: ${globalMemoryPath}`, + ); + } + + const projectRoot = await findProjectRoot(resolvedCwd); + if (debugMode) + logger.debug(`Determined project root: ${projectRoot ?? 'None'}`); + + const upwardPaths: string[] = []; + let currentDir = resolvedCwd; + // Determine the directory that signifies the top of the project or user-specific space. + const ultimateStopDir = projectRoot + ? path.dirname(projectRoot) + : path.dirname(resolvedHome); + + while (currentDir && currentDir !== path.dirname(currentDir)) { + // Loop until filesystem root or currentDir is empty + if (debugMode) { + logger.debug( + `Checking for ${GEMINI_MD_FILENAME} in (upward scan): ${currentDir}`, + ); + } + + // Skip the global .gemini directory itself during upward scan from CWD, + // as global is handled separately and explicitly first. 
+ if (currentDir === path.join(resolvedHome, GEMINI_CONFIG_DIR)) { + if (debugMode) { + logger.debug( + `Upward scan reached global config dir path, stopping upward search here: ${currentDir}`, + ); + } + break; + } + + const potentialPath = path.join(currentDir, GEMINI_MD_FILENAME); + try { + await fs.access(potentialPath, fsSync.constants.R_OK); + // Add to upwardPaths only if it's not the already added globalMemoryPath + if (potentialPath !== globalMemoryPath) { + upwardPaths.unshift(potentialPath); + if (debugMode) { + logger.debug( + `Found readable upward ${GEMINI_MD_FILENAME}: ${potentialPath}`, + ); + } + } + } catch { + if (debugMode) { + logger.debug( + `Upward ${GEMINI_MD_FILENAME} not found or not readable in: ${currentDir}`, + ); + } + } + + // Stop condition: if currentDir is the ultimateStopDir, break after this iteration. + if (currentDir === ultimateStopDir) { + if (debugMode) + logger.debug( + `Reached ultimate stop directory for upward scan: ${currentDir}`, + ); + break; + } + + currentDir = path.dirname(currentDir); + } + paths.push(...upwardPaths); + + if (debugMode) + logger.debug(`Starting downward scan from CWD: ${resolvedCwd}`); + const scannedDirCount = { count: 0 }; + const downwardPaths = await collectDownwardGeminiFiles( + resolvedCwd, + debugMode, + DEFAULT_IGNORE_DIRECTORIES, + scannedDirCount, + MAX_DIRECTORIES_TO_SCAN_FOR_MEMORY, + ); + downwardPaths.sort(); // Sort for consistent ordering, though hierarchy might be more complex + if (debugMode && downwardPaths.length > 0) + logger.debug( + `Found downward ${GEMINI_MD_FILENAME} files (sorted): ${JSON.stringify(downwardPaths)}`, + ); + // Add downward paths only if they haven't been included already (e.g. 
from upward scan) + for (const dPath of downwardPaths) { + if (!paths.includes(dPath)) { + paths.push(dPath); + } + } + + if (debugMode) + logger.debug( + `Final ordered ${GEMINI_MD_FILENAME} paths to read: ${JSON.stringify(paths)}`, + ); + return paths; +} + +async function readGeminiMdFiles( + filePaths: string[], + debugMode: boolean, +): Promise<GeminiFileContent[]> { + const results: GeminiFileContent[] = []; + for (const filePath of filePaths) { + try { + const content = await fs.readFile(filePath, 'utf-8'); + results.push({ filePath, content }); + if (debugMode) + logger.debug( + `Successfully read: ${filePath} (Length: ${content.length})`, + ); + } catch (error: unknown) { + const message = error instanceof Error ? error.message : String(error); + logger.warn( + `Warning: Could not read ${GEMINI_MD_FILENAME} file at ${filePath}. Error: ${message}`, + ); + results.push({ filePath, content: null }); // Still include it with null content + if (debugMode) logger.debug(`Failed to read: ${filePath}`); + } + } + return results; +} + +function concatenateInstructions( + instructionContents: GeminiFileContent[], + // CWD is needed to resolve relative paths for display markers + currentWorkingDirectoryForDisplay: string, +): string { + return instructionContents + .filter((item) => typeof item.content === 'string') + .map((item) => { + const trimmedContent = (item.content as string).trim(); + if (trimmedContent.length === 0) { + return null; + } + const displayPath = path.isAbsolute(item.filePath) + ? path.relative(currentWorkingDirectoryForDisplay, item.filePath) + : item.filePath; + return `--- Context from: ${displayPath} ---\n${trimmedContent}\n--- End of Context from: ${displayPath} ---`; + }) + .filter((block): block is string => block !== null) + .join('\n\n'); +} + +/** + * Loads hierarchical GEMINI.md files and concatenates their content. + * This function is intended for use by the server. 
+ */ +export async function loadServerHierarchicalMemory( + currentWorkingDirectory: string, + debugMode: boolean, +): Promise<{ memoryContent: string; fileCount: number }> { + if (debugMode) + logger.debug( + `Loading server hierarchical memory for CWD: ${currentWorkingDirectory}`, + ); + // For the server, homedir() refers to the server process's home. + // This is consistent with how MemoryTool already finds the global path. + const userHomePath = homedir(); + const filePaths = await getGeminiMdFilePathsInternal( + currentWorkingDirectory, + userHomePath, + debugMode, + ); + if (filePaths.length === 0) { + if (debugMode) logger.debug('No GEMINI.md files found in hierarchy.'); + return { memoryContent: '', fileCount: 0 }; + } + const contentsWithPaths = await readGeminiMdFiles(filePaths, debugMode); + // Pass CWD for relative path display in concatenated content + const combinedInstructions = concatenateInstructions( + contentsWithPaths, + currentWorkingDirectory, + ); + if (debugMode) + logger.debug( + `Combined instructions length: ${combinedInstructions.length}`, + ); + if (debugMode && combinedInstructions.length > 0) + logger.debug( + `Combined instructions (snippet): ${combinedInstructions.substring(0, 500)}...`, + ); + return { memoryContent: combinedInstructions, fileCount: filePaths.length }; +} diff --git a/packages/core/src/utils/messageInspectors.ts b/packages/core/src/utils/messageInspectors.ts new file mode 100644 index 00000000..b2c3cdce --- /dev/null +++ b/packages/core/src/utils/messageInspectors.ts @@ -0,0 +1,15 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Content } from '@google/genai'; + +export function isFunctionResponse(content: Content): boolean { + return ( + content.role === 'user' && + !!content.parts && + content.parts.every((part) => !!part.functionResponse) + ); +} diff --git a/packages/core/src/utils/nextSpeakerChecker.test.ts 
b/packages/core/src/utils/nextSpeakerChecker.test.ts new file mode 100644 index 00000000..872e00f6 --- /dev/null +++ b/packages/core/src/utils/nextSpeakerChecker.test.ts @@ -0,0 +1,235 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, Mock, afterEach } from 'vitest'; +import { Content, GoogleGenAI, Models } from '@google/genai'; +import { GeminiClient } from '../core/client.js'; +import { Config } from '../config/config.js'; +import { checkNextSpeaker, NextSpeakerResponse } from './nextSpeakerChecker.js'; +import { GeminiChat } from '../core/geminiChat.js'; + +// Mock GeminiClient and Config constructor +vi.mock('../core/client.js'); +vi.mock('../config/config.js'); + +// Define mocks for GoogleGenAI and Models instances that will be used across tests +const mockModelsInstance = { + generateContent: vi.fn(), + generateContentStream: vi.fn(), + countTokens: vi.fn(), + embedContent: vi.fn(), + batchEmbedContents: vi.fn(), +} as unknown as Models; + +const mockGoogleGenAIInstance = { + getGenerativeModel: vi.fn().mockReturnValue(mockModelsInstance), + // Add other methods of GoogleGenAI if they are directly used by GeminiChat constructor or its methods +} as unknown as GoogleGenAI; + +vi.mock('@google/genai', async () => { + const actualGenAI = + await vi.importActual<typeof import('@google/genai')>('@google/genai'); + return { + ...actualGenAI, + GoogleGenAI: vi.fn(() => mockGoogleGenAIInstance), // Mock constructor to return the predefined instance + // If Models is instantiated directly in GeminiChat, mock its constructor too + // For now, assuming Models instance is obtained via getGenerativeModel + }; +}); + +describe('checkNextSpeaker', () => { + let chatInstance: GeminiChat; + let mockGeminiClient: GeminiClient; + let MockConfig: Mock; + const abortSignal = new AbortController().signal; + + beforeEach(() => { + MockConfig = vi.mocked(Config); + const 
mockConfigInstance = new MockConfig( + 'test-api-key', + 'gemini-pro', + false, + '.', + false, + undefined, + false, + undefined, + undefined, + undefined, + ); + + mockGeminiClient = new GeminiClient(mockConfigInstance); + + // Reset mocks before each test to ensure test isolation + vi.mocked(mockModelsInstance.generateContent).mockReset(); + vi.mocked(mockModelsInstance.generateContentStream).mockReset(); + + // GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor + chatInstance = new GeminiChat( + mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor + mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel + 'gemini-pro', // model name + {}, + [], // initial history + ); + + // Spy on getHistory for chatInstance + vi.spyOn(chatInstance, 'getHistory'); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should return null if history is empty', async () => { + (chatInstance.getHistory as Mock).mockReturnValue([]); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); + expect(result).toBeNull(); + expect(mockGeminiClient.generateJson).not.toHaveBeenCalled(); + }); + + it('should return null if the last speaker was the user', async () => { + (chatInstance.getHistory as Mock).mockReturnValue([ + { role: 'user', parts: [{ text: 'Hello' }] }, + ] as Content[]); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); + expect(result).toBeNull(); + expect(mockGeminiClient.generateJson).not.toHaveBeenCalled(); + }); + + it("should return { next_speaker: 'model' } when model intends to continue", async () => { + (chatInstance.getHistory as Mock).mockReturnValue([ + { role: 'model', parts: [{ text: 'I will now do something.' 
}] }, + ] as Content[]); + const mockApiResponse: NextSpeakerResponse = { + reasoning: 'Model stated it will do something.', + next_speaker: 'model', + }; + (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); + + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); + expect(result).toEqual(mockApiResponse); + expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1); + }); + + it("should return { next_speaker: 'user' } when model asks a question", async () => { + (chatInstance.getHistory as Mock).mockReturnValue([ + { role: 'model', parts: [{ text: 'What would you like to do?' }] }, + ] as Content[]); + const mockApiResponse: NextSpeakerResponse = { + reasoning: 'Model asked a question.', + next_speaker: 'user', + }; + (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); + + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); + expect(result).toEqual(mockApiResponse); + }); + + it("should return { next_speaker: 'user' } when model makes a statement", async () => { + (chatInstance.getHistory as Mock).mockReturnValue([ + { role: 'model', parts: [{ text: 'This is a statement.' }] }, + ] as Content[]); + const mockApiResponse: NextSpeakerResponse = { + reasoning: 'Model made a statement, awaiting user input.', + next_speaker: 'user', + }; + (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); + + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); + expect(result).toEqual(mockApiResponse); + }); + + it('should return null if geminiClient.generateJson throws an error', async () => { + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => {}); + (chatInstance.getHistory as Mock).mockReturnValue([ + { role: 'model', parts: [{ text: 'Some model output.' 
}] }, + ] as Content[]); + (mockGeminiClient.generateJson as Mock).mockRejectedValue( + new Error('API Error'), + ); + + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); + expect(result).toBeNull(); + consoleWarnSpy.mockRestore(); + }); + + it('should return null if geminiClient.generateJson returns invalid JSON (missing next_speaker)', async () => { + (chatInstance.getHistory as Mock).mockReturnValue([ + { role: 'model', parts: [{ text: 'Some model output.' }] }, + ] as Content[]); + (mockGeminiClient.generateJson as Mock).mockResolvedValue({ + reasoning: 'This is incomplete.', + } as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response + + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); + expect(result).toBeNull(); + }); + + it('should return null if geminiClient.generateJson returns a non-string next_speaker', async () => { + (chatInstance.getHistory as Mock).mockReturnValue([ + { role: 'model', parts: [{ text: 'Some model output.' }] }, + ] as Content[]); + (mockGeminiClient.generateJson as Mock).mockResolvedValue({ + reasoning: 'Model made a statement, awaiting user input.', + next_speaker: 123, // Invalid type + } as unknown as NextSpeakerResponse); + + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); + expect(result).toBeNull(); + }); + + it('should return null if geminiClient.generateJson returns an invalid next_speaker string value', async () => { + (chatInstance.getHistory as Mock).mockReturnValue([ + { role: 'model', parts: [{ text: 'Some model output.' 
}] }, + ] as Content[]); + (mockGeminiClient.generateJson as Mock).mockResolvedValue({ + reasoning: 'Model made a statement, awaiting user input.', + next_speaker: 'neither', // Invalid enum value + } as unknown as NextSpeakerResponse); + + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); + expect(result).toBeNull(); + }); +}); diff --git a/packages/core/src/utils/nextSpeakerChecker.ts b/packages/core/src/utils/nextSpeakerChecker.ts new file mode 100644 index 00000000..66fa4395 --- /dev/null +++ b/packages/core/src/utils/nextSpeakerChecker.ts @@ -0,0 +1,151 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Content, SchemaUnion, Type } from '@google/genai'; +import { GeminiClient } from '../core/client.js'; +import { GeminiChat } from '../core/geminiChat.js'; +import { isFunctionResponse } from './messageInspectors.js'; + +const CHECK_PROMPT = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you). +**Decision Rules (apply in order):** +1. **Model Continues:** If your last response explicitly states an immediate next action *you* intend to take (e.g., "Next, I will...", "Now I'll process...", "Moving on to analyze...", indicates an intended tool call that didn't execute), OR if the response seems clearly incomplete (cut off mid-thought without a natural conclusion), then the **'model'** should speak next. +2. **Question to User:** If your last response ends with a direct question specifically addressed *to the user*, then the **'user'** should speak next. +3. **Waiting for User:** If your last response completed a thought, statement, or task *and* does not meet the criteria for Rule 1 (Model Continues) or Rule 2 (Question to User), it implies a pause expecting user input or reaction. 
+ * Example: /path/to/a/very/long/file.txt -> /path/.../long/file.txt + */ +export function shortenPath(filePath: string, maxLen: number = 35): string { + if (filePath.length <= maxLen) { + return filePath; + } + + const parsedPath = path.parse(filePath); + const root = parsedPath.root; + const separator = path.sep; + + // Get segments of the path *after* the root + const relativePath = filePath.substring(root.length); + const segments = relativePath.split(separator).filter((s) => s !== ''); // Filter out empty segments + + // Handle cases with no segments after root (e.g., "/", "C:\") or only one segment + if (segments.length <= 1) { + // Fallback to simple start/end truncation for very short paths or single segments + const keepLen = Math.floor((maxLen - 3) / 2); + // Ensure keepLen is not negative if maxLen is very small + if (keepLen <= 0) { + return filePath.substring(0, maxLen - 3) + '...'; + } + const start = filePath.substring(0, keepLen); + const end = filePath.substring(filePath.length - keepLen); + return `${start}...${end}`; + } + + const firstDir = segments[0]; + const startComponent = root + firstDir; + + const endPartSegments: string[] = []; + // Base length: startComponent + separator + "..." 
+ let currentLength = startComponent.length + separator.length + 3; + + // Iterate backwards through segments (excluding the first one) + for (let i = segments.length - 1; i >= 1; i--) { + const segment = segments[i]; + // Length needed if we add this segment: current + separator + segment + const lengthWithSegment = currentLength + separator.length + segment.length; + + if (lengthWithSegment <= maxLen) { + endPartSegments.unshift(segment); // Add to the beginning of the end part + currentLength = lengthWithSegment; + } else { + // Adding this segment would exceed maxLen + break; + } + } + + // Construct the final path + let result = startComponent + separator + '...'; + if (endPartSegments.length > 0) { + result += separator + endPartSegments.join(separator); + } + + // As a final check, if the result is somehow still too long (e.g., startComponent + ... is too long) + // fallback to simple truncation of the original path + if (result.length > maxLen) { + const keepLen = Math.floor((maxLen - 3) / 2); + if (keepLen <= 0) { + return filePath.substring(0, maxLen - 3) + '...'; + } + const start = filePath.substring(0, keepLen); + const end = filePath.substring(filePath.length - keepLen); + return `${start}...${end}`; + } + + return result; +} + +/** + * Calculates the relative path from a root directory to a target path. + * Ensures both paths are resolved before calculating. + * Returns '.' if the target path is the same as the root directory. + * + * @param targetPath The absolute or relative path to make relative. + * @param rootDirectory The absolute path of the directory to make the target path relative to. + * @returns The relative path from rootDirectory to targetPath. 
+ */ +export function makeRelative( + targetPath: string, + rootDirectory: string, +): string { + const resolvedTargetPath = path.resolve(targetPath); + const resolvedRootDirectory = path.resolve(rootDirectory); + + const relativePath = path.relative(resolvedRootDirectory, resolvedTargetPath); + + // If the paths are the same, path.relative returns '', return '.' instead + return relativePath || '.'; +} + +/** + * Escapes spaces in a file path. + */ +export function escapePath(filePath: string): string { + let result = ''; + for (let i = 0; i < filePath.length; i++) { + // Only escape spaces that are not already escaped. + if (filePath[i] === ' ' && (i === 0 || filePath[i - 1] !== '\\')) { + result += '\\ '; + } else { + result += filePath[i]; + } + } + return result; +} + +/** + * Unescapes spaces in a file path. + */ +export function unescapePath(filePath: string): string { + return filePath.replace(/\\ /g, ' '); +} diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts new file mode 100644 index 00000000..ea344d60 --- /dev/null +++ b/packages/core/src/utils/retry.test.ts @@ -0,0 +1,238 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { retryWithBackoff } from './retry.js'; + +// Define an interface for the error with a status property +interface HttpError extends Error { + status?: number; +} + +// Helper to create a mock function that fails a certain number of times +const createFailingFunction = ( + failures: number, + successValue: string = 'success', +) => { + let attempts = 0; + return vi.fn(async () => { + attempts++; + if (attempts <= failures) { + // Simulate a retryable error + const error: HttpError = new Error(`Simulated error attempt ${attempts}`); + error.status = 500; // Simulate a server error + throw error; + } + return 
successValue; + }); +}; + +// Custom error for testing non-retryable conditions +class NonRetryableError extends Error { + constructor(message: string) { + super(message); + this.name = 'NonRetryableError'; + } +} + +describe('retryWithBackoff', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should return the result on the first attempt if successful', async () => { + const mockFn = createFailingFunction(0); + const result = await retryWithBackoff(mockFn); + expect(result).toBe('success'); + expect(mockFn).toHaveBeenCalledTimes(1); + }); + + it('should retry and succeed if failures are within maxAttempts', async () => { + const mockFn = createFailingFunction(2); + const promise = retryWithBackoff(mockFn, { + maxAttempts: 3, + initialDelayMs: 10, + }); + + await vi.runAllTimersAsync(); // Ensure all delays and retries complete + + const result = await promise; + expect(result).toBe('success'); + expect(mockFn).toHaveBeenCalledTimes(3); + }); + + it('should throw an error if all attempts fail', async () => { + const mockFn = createFailingFunction(3); + + // 1. Start the retryable operation, which returns a promise. + const promise = retryWithBackoff(mockFn, { + maxAttempts: 3, + initialDelayMs: 10, + }); + + // 2. IMPORTANT: Attach the rejection expectation to the promise *immediately*. + // This ensures a 'catch' handler is present before the promise can reject. + // The result is a new promise that resolves when the assertion is met. + const assertionPromise = expect(promise).rejects.toThrow( + 'Simulated error attempt 3', + ); + + // 3. Now, advance the timers. This will trigger the retries and the + // eventual rejection. The handler attached in step 2 will catch it. + await vi.runAllTimersAsync(); + + // 4. Await the assertion promise itself to ensure the test was successful. + await assertionPromise; + + // 5. Finally, assert the number of calls. 
+ expect(mockFn).toHaveBeenCalledTimes(3); + }); + + it('should not retry if shouldRetry returns false', async () => { + const mockFn = vi.fn(async () => { + throw new NonRetryableError('Non-retryable error'); + }); + const shouldRetry = (error: Error) => !(error instanceof NonRetryableError); + + const promise = retryWithBackoff(mockFn, { + shouldRetry, + initialDelayMs: 10, + }); + + await expect(promise).rejects.toThrow('Non-retryable error'); + expect(mockFn).toHaveBeenCalledTimes(1); + }); + + it('should use default shouldRetry if not provided, retrying on 429', async () => { + const mockFn = vi.fn(async () => { + const error = new Error('Too Many Requests') as any; + error.status = 429; + throw error; + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 10, + }); + + // Attach the rejection expectation *before* running timers + const assertionPromise = + expect(promise).rejects.toThrow('Too Many Requests'); + + // Run timers to trigger retries and eventual rejection + await vi.runAllTimersAsync(); + + // Await the assertion + await assertionPromise; + + expect(mockFn).toHaveBeenCalledTimes(2); + }); + + it('should use default shouldRetry if not provided, not retrying on 400', async () => { + const mockFn = vi.fn(async () => { + const error = new Error('Bad Request') as any; + error.status = 400; + throw error; + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 10, + }); + await expect(promise).rejects.toThrow('Bad Request'); + expect(mockFn).toHaveBeenCalledTimes(1); + }); + + it('should respect maxDelayMs', async () => { + const mockFn = createFailingFunction(3); + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 4, + initialDelayMs: 100, + maxDelayMs: 250, // Max delay is less than 100 * 2 * 2 = 400 + }); + + await vi.advanceTimersByTimeAsync(1000); // Advance well past all delays + await promise; + + const delays = 
setTimeoutSpy.mock.calls.map((call) => call[1] as number); + + // Delays should be around initial, initial*2, maxDelay (due to cap) + // Jitter makes exact assertion hard, so we check ranges / caps + expect(delays.length).toBe(3); + expect(delays[0]).toBeGreaterThanOrEqual(100 * 0.7); + expect(delays[0]).toBeLessThanOrEqual(100 * 1.3); + expect(delays[1]).toBeGreaterThanOrEqual(200 * 0.7); + expect(delays[1]).toBeLessThanOrEqual(200 * 1.3); + // The third delay should be capped by maxDelayMs (250ms), accounting for jitter + expect(delays[2]).toBeGreaterThanOrEqual(250 * 0.7); + expect(delays[2]).toBeLessThanOrEqual(250 * 1.3); + + setTimeoutSpy.mockRestore(); + }); + + it('should handle jitter correctly, ensuring varied delays', async () => { + let mockFn = createFailingFunction(5); + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + + // Run retryWithBackoff multiple times to observe jitter + const runRetry = () => + retryWithBackoff(mockFn, { + maxAttempts: 2, // Only one retry, so one delay + initialDelayMs: 100, + maxDelayMs: 1000, + }); + + // We expect rejections as mockFn fails 5 times + const promise1 = runRetry(); + // Attach the rejection expectation *before* running timers + const assertionPromise1 = expect(promise1).rejects.toThrow(); + await vi.runAllTimersAsync(); // Advance for the delay in the first runRetry + await assertionPromise1; + + const firstDelaySet = setTimeoutSpy.mock.calls.map( + (call) => call[1] as number, + ); + setTimeoutSpy.mockClear(); // Clear calls for the next run + + // Reset mockFn to reset its internal attempt counter for the next run + mockFn = createFailingFunction(5); // Re-initialize with 5 failures + + const promise2 = runRetry(); + // Attach the rejection expectation *before* running timers + const assertionPromise2 = expect(promise2).rejects.toThrow(); + await vi.runAllTimersAsync(); // Advance for the delay in the second runRetry + await assertionPromise2; + + const secondDelaySet = setTimeoutSpy.mock.calls.map( 
+ (call) => call[1] as number, + ); + + // Check that the delays are not exactly the same due to jitter + // This is a probabilistic test, but with +/-30% jitter, it's highly likely they differ. + if (firstDelaySet.length > 0 && secondDelaySet.length > 0) { + // Check the first delay of each set + expect(firstDelaySet[0]).not.toBe(secondDelaySet[0]); + } else { + // If somehow no delays were captured (e.g. test setup issue), fail explicitly + throw new Error('Delays were not captured for jitter test'); + } + + // Ensure delays are within the expected jitter range [70, 130] for initialDelayMs = 100 + [...firstDelaySet, ...secondDelaySet].forEach((d) => { + expect(d).toBeGreaterThanOrEqual(100 * 0.7); + expect(d).toBeLessThanOrEqual(100 * 1.3); + }); + + setTimeoutSpy.mockRestore(); + }); +}); diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts new file mode 100644 index 00000000..1e7d5bcb --- /dev/null +++ b/packages/core/src/utils/retry.ts @@ -0,0 +1,227 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export interface RetryOptions { + maxAttempts: number; + initialDelayMs: number; + maxDelayMs: number; + shouldRetry: (error: Error) => boolean; +} + +const DEFAULT_RETRY_OPTIONS: RetryOptions = { + maxAttempts: 5, + initialDelayMs: 5000, + maxDelayMs: 30000, // 30 seconds + shouldRetry: defaultShouldRetry, +}; + +/** + * Default predicate function to determine if a retry should be attempted. + * Retries on 429 (Too Many Requests) and 5xx server errors. + * @param error The error object. + * @returns True if the error is a transient error, false otherwise. 
+ */ +function defaultShouldRetry(error: Error | unknown): boolean { + // Check for common transient error status codes either in message or a status property + if (error && typeof (error as { status?: number }).status === 'number') { + const status = (error as { status: number }).status; + if (status === 429 || (status >= 500 && status < 600)) { + return true; + } + } + if (error instanceof Error && error.message) { + if (error.message.includes('429')) return true; + if (error.message.match(/5\d{2}/)) return true; + } + return false; +} + +/** + * Delays execution for a specified number of milliseconds. + * @param ms The number of milliseconds to delay. + * @returns A promise that resolves after the delay. + */ +function delay(ms: number): Promise<void> { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +/** + * Retries a function with exponential backoff and jitter. + * @param fn The asynchronous function to retry. + * @param options Optional retry configuration. + * @returns A promise that resolves with the result of the function if successful. + * @throws The last error encountered if all attempts fail. + */ +export async function retryWithBackoff<T>( + fn: () => Promise<T>, + options?: Partial<RetryOptions>, +): Promise<T> { + const { maxAttempts, initialDelayMs, maxDelayMs, shouldRetry } = { + ...DEFAULT_RETRY_OPTIONS, + ...options, + }; + + let attempt = 0; + let currentDelay = initialDelayMs; + + while (attempt < maxAttempts) { + attempt++; + try { + return await fn(); + } catch (error) { + if (attempt >= maxAttempts || !shouldRetry(error as Error)) { + throw error; + } + + const { delayDurationMs, errorStatus } = getDelayDurationAndStatus(error); + + if (delayDurationMs > 0) { + // Respect Retry-After header if present and parsed + console.warn( + `Attempt ${attempt} failed with status ${errorStatus ?? 'unknown'}. 
Retrying after explicit delay of ${delayDurationMs}ms...`, + error, + ); + await delay(delayDurationMs); + // Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time + currentDelay = initialDelayMs; + } else { + // Fallback to exponential backoff with jitter + logRetryAttempt(attempt, error, errorStatus); + // Add jitter: +/- 30% of currentDelay + const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1); + const delayWithJitter = Math.max(0, currentDelay + jitter); + await delay(delayWithJitter); + currentDelay = Math.min(maxDelayMs, currentDelay * 2); + } + } + } + // This line should theoretically be unreachable due to the throw in the catch block. + // Added for type safety and to satisfy the compiler that a promise is always returned. + throw new Error('Retry attempts exhausted'); +} + +/** + * Extracts the HTTP status code from an error object. + * @param error The error object. + * @returns The HTTP status code, or undefined if not found. + */ +function getErrorStatus(error: unknown): number | undefined { + if (typeof error === 'object' && error !== null) { + if ('status' in error && typeof error.status === 'number') { + return error.status; + } + // Check for error.response.status (common in axios errors) + if ( + 'response' in error && + typeof (error as { response?: unknown }).response === 'object' && + (error as { response?: unknown }).response !== null + ) { + const response = ( + error as { response: { status?: unknown; headers?: unknown } } + ).response; + if ('status' in response && typeof response.status === 'number') { + return response.status; + } + } + } + return undefined; +} + +/** + * Extracts the Retry-After delay from an error object's headers. + * @param error The error object. + * @returns The delay in milliseconds, or 0 if not found or invalid. 
+ */ +function getRetryAfterDelayMs(error: unknown): number { + if (typeof error === 'object' && error !== null) { + // Check for error.response.headers (common in axios errors) + if ( + 'response' in error && + typeof (error as { response?: unknown }).response === 'object' && + (error as { response?: unknown }).response !== null + ) { + const response = (error as { response: { headers?: unknown } }).response; + if ( + 'headers' in response && + typeof response.headers === 'object' && + response.headers !== null + ) { + const headers = response.headers as { 'retry-after'?: unknown }; + const retryAfterHeader = headers['retry-after']; + if (typeof retryAfterHeader === 'string') { + const retryAfterSeconds = parseInt(retryAfterHeader, 10); + if (!isNaN(retryAfterSeconds)) { + return retryAfterSeconds * 1000; + } + // It might be an HTTP date + const retryAfterDate = new Date(retryAfterHeader); + if (!isNaN(retryAfterDate.getTime())) { + return Math.max(0, retryAfterDate.getTime() - Date.now()); + } + } + } + } + } + return 0; +} + +/** + * Determines the delay duration based on the error, prioritizing Retry-After header. + * @param error The error object. + * @returns An object containing the delay duration in milliseconds and the error status. + */ +function getDelayDurationAndStatus(error: unknown): { + delayDurationMs: number; + errorStatus: number | undefined; +} { + const errorStatus = getErrorStatus(error); + let delayDurationMs = 0; + + if (errorStatus === 429) { + delayDurationMs = getRetryAfterDelayMs(error); + } + return { delayDurationMs, errorStatus }; +} + +/** + * Logs a message for a retry attempt when using exponential backoff. + * @param attempt The current attempt number. + * @param error The error that caused the retry. + * @param errorStatus The HTTP status code of the error, if available. + */ +function logRetryAttempt( + attempt: number, + error: unknown, + errorStatus?: number, +): void { + let message = `Attempt ${attempt} failed. 
Retrying with backoff...`; + if (errorStatus) { + message = `Attempt ${attempt} failed with status ${errorStatus}. Retrying with backoff...`; + } + + if (errorStatus === 429) { + console.warn(message, error); + } else if (errorStatus && errorStatus >= 500 && errorStatus < 600) { + console.error(message, error); + } else if (error instanceof Error) { + // Fallback for errors that might not have a status but have a message + if (error.message.includes('429')) { + console.warn( + `Attempt ${attempt} failed with 429 error (no Retry-After header). Retrying with backoff...`, + error, + ); + } else if (error.message.match(/5\d{2}/)) { + console.error( + `Attempt ${attempt} failed with 5xx error. Retrying with backoff...`, + error, + ); + } else { + console.warn(message, error); // Default to warn for other errors + } + } else { + console.warn(message, error); // Default to warn if error type is unknown + } +} diff --git a/packages/core/src/utils/schemaValidator.ts b/packages/core/src/utils/schemaValidator.ts new file mode 100644 index 00000000..34ed5843 --- /dev/null +++ b/packages/core/src/utils/schemaValidator.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Simple utility to validate objects against JSON Schemas + */ +export class SchemaValidator { + /** + * Validates data against a JSON schema + * @param schema JSON Schema to validate against + * @param data Data to validate + * @returns True if valid, false otherwise + */ + static validate(schema: Record<string, unknown>, data: unknown): boolean { + // This is a simplified implementation + // In a real application, you would use a library like Ajv for proper validation + + // Check for required fields + if (schema.required && Array.isArray(schema.required)) { + const required = schema.required as string[]; + const dataObj = data as Record<string, unknown>; + + for (const field of required) { + if (dataObj[field] === undefined) { + 
console.error(`Missing required field: ${field}`); + return false; + } + } + } + + // Check property types if properties are defined + if (schema.properties && typeof schema.properties === 'object') { + const properties = schema.properties as Record<string, { type?: string }>; + const dataObj = data as Record<string, unknown>; + + for (const [key, prop] of Object.entries(properties)) { + if (dataObj[key] !== undefined && prop.type) { + const expectedType = prop.type; + const actualType = Array.isArray(dataObj[key]) + ? 'array' + : typeof dataObj[key]; + + if (expectedType !== actualType) { + console.error( + `Type mismatch for property "${key}": expected ${expectedType}, got ${actualType}`, + ); + return false; + } + } + } + } + + return true; + } +} |
