diff options
Diffstat (limited to 'packages/server/src/utils')
| -rw-r--r-- | packages/server/src/utils/editCorrector.test.ts | 33 | ||||
| -rw-r--r-- | packages/server/src/utils/editCorrector.ts | 37 | ||||
| -rw-r--r-- | packages/server/src/utils/nextSpeakerChecker.test.ts | 57 | ||||
| -rw-r--r-- | packages/server/src/utils/nextSpeakerChecker.ts | 2 |
4 files changed, 112 insertions, 17 deletions
diff --git a/packages/server/src/utils/editCorrector.test.ts b/packages/server/src/utils/editCorrector.test.ts index 27c9ffe8..7d6f5a53 100644 --- a/packages/server/src/utils/editCorrector.test.ts +++ b/packages/server/src/utils/editCorrector.test.ts @@ -132,6 +132,7 @@ describe('editCorrector', () => { let mockGeminiClientInstance: Mocked<GeminiClient>; let mockToolRegistry: Mocked<ToolRegistry>; let mockConfigInstance: Config; + const abortSignal = new AbortController().signal; beforeEach(() => { mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>; @@ -187,12 +188,18 @@ describe('editCorrector', () => { callCount = 0; mockResponses.length = 0; - mockGenerateJson = vi.fn().mockImplementation(() => { - const response = mockResponses[callCount]; - callCount++; - if (response === undefined) return Promise.resolve({}); - return Promise.resolve(response); - }); + mockGenerateJson = vi + .fn() + .mockImplementation((_contents, _schema, signal) => { + // Check if the signal is aborted. If so, throw an error or return a specific response. + if (signal && signal.aborted) { + return Promise.reject(new Error('Aborted')); // Or some other specific error/response + } + const response = mockResponses[callCount]; + callCount++; + if (response === undefined) return Promise.resolve({}); + return Promise.resolve(response); + }); mockStartChat = vi.fn(); mockSendMessageStream = vi.fn(); @@ -217,6 +224,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe('replace with "this"'); @@ -234,6 +242,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(result.params.new_string).toBe('replace with this'); @@ -254,6 +263,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe('replace with "this"'); @@ -271,6 +281,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(result.params.new_string).toBe('replace with this'); @@ -292,6 +303,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe('replace with "this"'); @@ -309,6 +321,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(result.params.new_string).toBe('replace with this'); @@ -329,6 +342,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe('replace with foobar'); @@ -351,6 +365,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe(llmNewString); @@ -372,6 +387,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(2); expect(result.params.new_string).toBe(llmNewString); @@ -391,6 +407,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe('replace with "this"'); @@ -412,6 +429,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM); @@ -432,6 +450,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params).toEqual(originalParams); @@ -449,6 +468,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(result.params).toEqual(originalParams); @@ -471,6 +491,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(2); expect(result.params.old_string).toBe(currentContent); diff --git a/packages/server/src/utils/editCorrector.ts b/packages/server/src/utils/editCorrector.ts index 92551478..78663954 100644 --- a/packages/server/src/utils/editCorrector.ts +++ b/packages/server/src/utils/editCorrector.ts @@ -63,6 +63,7 @@ export async function ensureCorrectEdit( currentContent: string, originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\' client: GeminiClient, + abortSignal: AbortSignal, ): Promise<CorrectedEditResult> { const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`; const cachedResult = editCorrectionCache.get(cacheKey); @@ -84,6 +85,7 @@ export async function ensureCorrectEdit( client, finalOldString, originalParams.new_string, + abortSignal, ); } } else if (occurrences > 1) { @@ -108,6 +110,7 @@ export async function ensureCorrectEdit( originalParams.old_string, // original old unescapedOldStringAttempt, // corrected old originalParams.new_string, // original new (which is potentially escaped) + abortSignal, ); } } else if (occurrences === 0) { @@ -115,6 +118,7 @@ export async function ensureCorrectEdit( client, currentContent, unescapedOldStringAttempt, + abortSignal, ); const llmOldOccurrences = countOccurrences( currentContent, @@ -134,6 +138,7 @@ export async function ensureCorrectEdit( originalParams.old_string, // original old llmCorrectedOldString, // corrected old baseNewStringForLLMCorrection, // base new for correction + abortSignal, ); } } else { @@ -180,6 +185,7 @@ export async function ensureCorrectEdit( export async function ensureCorrectFileContent( content: string, client: GeminiClient, + abortSignal: AbortSignal, ): Promise<string> { const cachedResult = fileContentCorrectionCache.get(content); if (cachedResult) { @@ -193,7 +199,11 @@ export async function ensureCorrectFileContent( return content; } - const correctedContent = await correctStringEscaping(content, client); + const correctedContent = await correctStringEscaping( + content, + client, + abortSignal, + ); fileContentCorrectionCache.set(content, correctedContent); return correctedContent; } @@ -215,6 +225,7 @@ export async function correctOldStringMismatch( geminiClient: GeminiClient, fileContent: string, problematicSnippet: string, + abortSignal: AbortSignal, ): Promise<string> { const prompt = ` Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped. @@ -243,6 +254,7 @@ Return ONLY the corrected target snippet in the specified JSON format with the k const result = await geminiClient.generateJson( contents, OLD_STRING_CORRECTION_SCHEMA, + abortSignal, EditModel, EditConfig, ); @@ -257,10 +269,15 @@ Return ONLY the corrected target snippet in the specified JSON format with the k return problematicSnippet; } } catch (error) { + if (abortSignal.aborted) { + throw error; + } + console.error( 'Error during LLM call for old string snippet correction:', error, ); + return problematicSnippet; } } @@ -286,6 +303,7 @@ export async function correctNewString( originalOldString: string, correctedOldString: string, originalNewString: string, + abortSignal: AbortSignal, ): Promise<string> { if (originalOldString === correctedOldString) { return originalNewString; @@ -324,6 +342,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr const result = await geminiClient.generateJson( contents, NEW_STRING_CORRECTION_SCHEMA, + abortSignal, EditModel, EditConfig, ); @@ -338,6 +357,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr return originalNewString; } } catch (error) { + if (abortSignal.aborted) { + throw error; + } + console.error('Error during LLM call for new_string correction:', error); return originalNewString; } @@ -359,6 +382,7 @@ export async function correctNewStringEscaping( geminiClient: GeminiClient, oldString: string, potentiallyProblematicNewString: string, + abortSignal: AbortSignal, ): Promise<string> { const prompt = ` Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello"). @@ -387,6 +411,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr const result = await geminiClient.generateJson( contents, CORRECT_NEW_STRING_ESCAPING_SCHEMA, + abortSignal, EditModel, EditConfig, ); @@ -401,6 +426,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr return potentiallyProblematicNewString; } } catch (error) { + if (abortSignal.aborted) { + throw error; + } + console.error( 'Error during LLM call for new_string escaping correction:', error, @@ -424,6 +453,7 @@ const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = { export async function correctStringEscaping( potentiallyProblematicString: string, client: GeminiClient, + abortSignal: AbortSignal, ): Promise<string> { const prompt = ` Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello"). @@ -447,6 +477,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr const result = await client.generateJson( contents, CORRECT_STRING_ESCAPING_SCHEMA, + abortSignal, EditModel, EditConfig, ); @@ -461,6 +492,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr return potentiallyProblematicString; } } catch (error) { + if (abortSignal.aborted) { + throw error; + } + console.error( 'Error during LLM call for string escaping correction:', error, diff --git a/packages/server/src/utils/nextSpeakerChecker.test.ts b/packages/server/src/utils/nextSpeakerChecker.test.ts index 1d87bffb..872e00f6 100644 --- a/packages/server/src/utils/nextSpeakerChecker.test.ts +++ b/packages/server/src/utils/nextSpeakerChecker.test.ts @@ -44,6 +44,7 @@ describe('checkNextSpeaker', () => { let chatInstance: GeminiChat; let mockGeminiClient: GeminiClient; let MockConfig: Mock; + const abortSignal = new AbortController().signal; beforeEach(() => { MockConfig = vi.mocked(Config); @@ -71,7 +72,7 @@ describe('checkNextSpeaker', () => { mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel 'gemini-pro', // model name - {}, // config + {}, [], // initial history ); @@ -85,7 +86,11 @@ describe('checkNextSpeaker', () => { it('should return null if history is empty', async () => { (chatInstance.getHistory as Mock).mockReturnValue([]); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); expect(mockGeminiClient.generateJson).not.toHaveBeenCalled(); }); @@ -94,7 +99,11 @@ describe('checkNextSpeaker', () => { (chatInstance.getHistory as Mock).mockReturnValue([ { role: 'user', parts: [{ text: 'Hello' }] }, ] as Content[]); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); expect(mockGeminiClient.generateJson).not.toHaveBeenCalled(); }); @@ -109,7 +118,11 @@ describe('checkNextSpeaker', () => { }; (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toEqual(mockApiResponse); expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1); }); @@ -124,7 +137,11 @@ describe('checkNextSpeaker', () => { }; (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toEqual(mockApiResponse); }); @@ -138,7 +155,11 @@ describe('checkNextSpeaker', () => { }; (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toEqual(mockApiResponse); }); @@ -153,7 +174,11 @@ describe('checkNextSpeaker', () => { new Error('API Error'), ); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); consoleWarnSpy.mockRestore(); }); @@ -166,7 +191,11 @@ describe('checkNextSpeaker', () => { reasoning: 'This is incomplete.', } as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); }); @@ -179,7 +208,11 @@ describe('checkNextSpeaker', () => { next_speaker: 123, // Invalid type } as unknown as NextSpeakerResponse); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); }); @@ -192,7 +225,11 @@ describe('checkNextSpeaker', () => { next_speaker: 'neither', // Invalid enum value } as unknown as NextSpeakerResponse); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); }); }); diff --git a/packages/server/src/utils/nextSpeakerChecker.ts b/packages/server/src/utils/nextSpeakerChecker.ts index fb00b39c..66fa4395 100644 --- a/packages/server/src/utils/nextSpeakerChecker.ts +++ b/packages/server/src/utils/nextSpeakerChecker.ts @@ -61,6 +61,7 @@ export interface NextSpeakerResponse { export async function checkNextSpeaker( chat: GeminiChat, geminiClient: GeminiClient, + abortSignal: AbortSignal, ): Promise<NextSpeakerResponse | null> { // We need to capture the curated history because there are many moments when the model will return invalid turns // that when passed back up to the endpoint will break subsequent calls. An example of this is when the model decides @@ -129,6 +130,7 @@ export async function checkNextSpeaker( const parsedResponse = (await geminiClient.generateJson( contents, RESPONSE_SCHEMA, + abortSignal, )) as unknown as NextSpeakerResponse; if ( |
